From 5d83b60727efc7c38d084968c2012966a811fc75 Mon Sep 17 00:00:00 2001 From: tannevaled Date: Tue, 2 Jun 2026 15:10:30 +0200 Subject: [PATCH 1/2] reverseproxy: extract DNS SRV+A discovery into a reusable package Move the SRV and A/AAAA resolution + caching logic out of reverseproxy into a new, transport-neutral package (dynamicupstreams) that returns neutral targets instead of *reverseproxy.Upstream. reverseproxy's SRVUpstreams and AUpstreams now delegate to it and build their own upstreams from the targets; behavior is unchanged. This lets other proxies (e.g. third-party layer4 proxies) reuse the same DNS discovery + caching instead of copying it, reducing duplication and maintenance burden. reverseproxy/upstreams.go shrinks substantially. RFC for the de-duplication discussed in caddyserver/caddy-l4#429. --- dynamicupstreams/dynamicupstreams.go | 202 +++++++++++++++++++ dynamicupstreams/dynamicupstreams_test.go | 138 +++++++++++++ modules/caddyhttp/reverseproxy/upstreams.go | 203 +++----------------- 3 files changed, 366 insertions(+), 177 deletions(-) create mode 100644 dynamicupstreams/dynamicupstreams.go create mode 100644 dynamicupstreams/dynamicupstreams_test.go diff --git a/dynamicupstreams/dynamicupstreams.go b/dynamicupstreams/dynamicupstreams.go new file mode 100644 index 000000000..1f0c59072 --- /dev/null +++ b/dynamicupstreams/dynamicupstreams.go @@ -0,0 +1,202 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Package dynamicupstreams provides transport-neutral DNS-based discovery of +// upstream targets, with result caching. It is shared so that different proxies +// (e.g. the HTTP reverse_proxy and third-party layer4 proxies) can discover +// backends from DNS without each copying the resolution and caching logic. +// +// The package intentionally returns neutral [Target] values rather than any +// proxy-specific upstream type; each caller builds its own upstream +// representation from the targets. +package dynamicupstreams + +import ( + "context" + "net" + "strconv" + "sync" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// Target is a single discovered upstream endpoint. It carries the SRV priority +// and weight when available (both zero for non-SRV lookups). +type Target struct { + Host string + Port string + Priority uint16 + Weight uint16 +} + +// SRVLookupFunc resolves SRV records. It matches the signature of +// (*net.Resolver).LookupSRV, so callers can pass that directly (and tests can +// inject a stub). +type SRVLookupFunc func(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error) + +// SRV resolves the given SRV record into targets, caching the result for +// refresh. lookup is the resolver to use (typically (*net.Resolver).LookupSRV). +// +// If the lookup fails but returns some records (e.g. a few invalid names were +// filtered out), those records are still used. If it fails with no records, +// grace > 0 and a result is cached, it is served for about another grace after +// the failure instead of erroring; each later failed lookup renews that window. 
+func SRV(ctx context.Context, lookup SRVLookupFunc, service, proto, name string, refresh, grace time.Duration, logger *zap.Logger) ([]Target, error) { + key := srvKey(service, proto, name) + + // fast path: a fresh cached result under a read lock + srvMu.RLock() + cached := srvCache[key] + srvMu.RUnlock() + if cached.isFresh() { + return cached.targets, nil + } + + srvMu.Lock() + defer srvMu.Unlock() + + // re-check under the write lock in case another goroutine refreshed it + cached = srvCache[key] + if cached.isFresh() { + return cached.targets, nil + } + + _, records, err := lookup(ctx, service, proto, name) + if err != nil { + // From LookupSRV docs: invalid names are filtered out and an error is + // returned alongside any remaining results; only treat it as fatal when + // nothing usable came back. + if len(records) == 0 { + if grace > 0 && cached.targets != nil { + if c := logger.Check(zapcore.ErrorLevel, "SRV lookup failed; using previously cached"); c != nil { + c.Write(zap.String("service", service), zap.String("proto", proto), zap.String("name", name), zap.Error(err)) + } + cached.freshness = time.Now().Add(grace - refresh) + srvCache[key] = cached + return cached.targets, nil + } + return nil, err + } + if c := logger.Check(zapcore.WarnLevel, "SRV records filtered"); c != nil { + c.Write(zap.Error(err)) + } + } + + targets := make([]Target, len(records)) + for i, rec := range records { + targets[i] = Target{ + Host: rec.Target, + Port: strconv.Itoa(int(rec.Port)), + Priority: rec.Priority, + Weight: rec.Weight, + } + } + + // when inserting a brand-new entry (not replacing a stale one), bound the cache + if cached.freshness.IsZero() && len(srvCache) >= maxCacheEntries { + for k := range srvCache { + delete(srvCache, k) + break + } + } + srvCache[key] = cacheEntry{refresh: refresh, freshness: time.Now(), targets: targets} + return targets, nil +} + +// ResetSRV removes the cached result for a single SRV record. 
+func ResetSRV(service, proto, name string) { + srvMu.Lock() + delete(srvCache, srvKey(service, proto, name)) + srvMu.Unlock() +} + +// ResetAllSRV clears the entire SRV cache. +func ResetAllSRV() { + srvMu.Lock() + srvCache = make(map[string]cacheEntry) + srvMu.Unlock() +} + +func srvKey(service, proto, name string) string { + return service + "\x00" + proto + "\x00" + name +} + +const maxCacheEntries = 100 + +type cacheEntry struct { + refresh time.Duration + freshness time.Time + targets []Target +} + +func (e cacheEntry) isFresh() bool { + return !e.freshness.IsZero() && time.Since(e.freshness) < e.refresh +} + +// IPLookupFunc resolves a host's IP addresses. It matches the signature of +// (*net.Resolver).LookupIP, so callers can pass that directly (and tests can +// inject a stub). network is one of "ip", "ip4", "ip6". +type IPLookupFunc func(ctx context.Context, network, host string) ([]net.IP, error) + +// A resolves name's A/AAAA records into targets (one per address, all using the +// given port), caching the result for refresh. network selects the IP versions +// ("ip", "ip4" or "ip6"). lookup is the resolver to use (typically +// (*net.Resolver).LookupIP). 
+func A(ctx context.Context, lookup IPLookupFunc, network, name, port string, refresh time.Duration, logger *zap.Logger) ([]Target, error) { + key := name + "\x00" + port + "\x00" + network + + aMu.RLock() + cached := aCache[key] + aMu.RUnlock() + if cached.isFresh() { + return cached.targets, nil + } + + aMu.Lock() + defer aMu.Unlock() + + cached = aCache[key] + if cached.isFresh() { + return cached.targets, nil + } + + ips, err := lookup(ctx, network, name) + if err != nil { + return nil, err + } + + targets := make([]Target, len(ips)) + for i, ip := range ips { + targets[i] = Target{Host: ip.String(), Port: port} + } + + if cached.freshness.IsZero() && len(aCache) >= maxCacheEntries { + for k := range aCache { + delete(aCache, k) + break + } + } + aCache[key] = cacheEntry{refresh: refresh, freshness: time.Now(), targets: targets} + return targets, nil +} + +var ( + srvMu sync.RWMutex + srvCache = make(map[string]cacheEntry) + + aMu sync.RWMutex + aCache = make(map[string]cacheEntry) +) diff --git a/dynamicupstreams/dynamicupstreams_test.go b/dynamicupstreams/dynamicupstreams_test.go new file mode 100644 index 000000000..362c839c1 --- /dev/null +++ b/dynamicupstreams/dynamicupstreams_test.go @@ -0,0 +1,138 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package dynamicupstreams + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "go.uber.org/zap" +) + +func TestSRVResolvesAndCaches(t *testing.T) { + calls := 0 + lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) { + calls++ + return "", []*net.SRV{ + {Target: "a.example.", Port: 5432, Priority: 1, Weight: 10}, + {Target: "b.example.", Port: 5433, Priority: 1, Weight: 20}, + }, nil + } + + targets, err := SRV(context.Background(), lookup, "svc-cache", "tcp", "x", time.Minute, 0, zap.NewNop()) + if err != nil { + t.Fatalf("SRV: %v", err) + } + if len(targets) != 2 { + t.Fatalf("targets = %d, want 2", len(targets)) + } + if targets[0].Host != "a.example." || targets[0].Port != "5432" || targets[0].Weight != 10 { + t.Errorf("unexpected first target: %+v", targets[0]) + } + + // second call within refresh must be served from cache (no extra lookup) + if _, err := SRV(context.Background(), lookup, "svc-cache", "tcp", "x", time.Minute, 0, zap.NewNop()); err != nil { + t.Fatalf("SRV (cached): %v", err) + } + if calls != 1 { + t.Errorf("lookup calls = %d, want 1 (cached)", calls) + } +} + +func TestSRVErrorWithoutGrace(t *testing.T) { + lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) { + return "", nil, errors.New("dns boom") + } + if _, err := SRV(context.Background(), lookup, "svc-err", "tcp", "x", time.Minute, 0, zap.NewNop()); err == nil { + t.Fatal("expected an error when lookup fails and nothing is cached") + } +} + +func TestSRVGracePeriodServesStale(t *testing.T) { + ok := func(context.Context, string, string, string) (string, []*net.SRV, error) { + return "", []*net.SRV{{Target: "a.example.", Port: 1}}, nil + } + // tiny refresh so the entry is immediately stale on the next call + if _, err := SRV(context.Background(), ok, "svc-grace", "tcp", "x", time.Nanosecond, time.Hour, zap.NewNop()); err != nil { + t.Fatalf("seeding cache: %v", err) + } + + fail := func(context.Context, 
string, string, string) (string, []*net.SRV, error) { + return "", nil, errors.New("dns boom") + } + targets, err := SRV(context.Background(), fail, "svc-grace", "tcp", "x", time.Nanosecond, time.Hour, zap.NewNop()) + if err != nil { + t.Fatalf("grace period should suppress the error: %v", err) + } + if len(targets) != 1 { + t.Errorf("expected the stale cached target to be served, got %d", len(targets)) + } +} + +func TestResetSRV(t *testing.T) { + calls := 0 + lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) { + calls++ + return "", []*net.SRV{{Target: "a.example.", Port: 1}}, nil + } + if _, err := SRV(context.Background(), lookup, "svc-reset", "tcp", "x", time.Minute, 0, zap.NewNop()); err != nil { + t.Fatal(err) + } + ResetSRV("svc-reset", "tcp", "x") + if _, err := SRV(context.Background(), lookup, "svc-reset", "tcp", "x", time.Minute, 0, zap.NewNop()); err != nil { + t.Fatal(err) + } + if calls != 2 { + t.Errorf("lookup calls = %d, want 2 (cache was reset between calls)", calls) + } +} + +func TestAResolvesAndCaches(t *testing.T) { + calls := 0 + lookup := func(context.Context, string, string) ([]net.IP, error) { + calls++ + return []net.IP{net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.2")}, nil + } + + targets, err := A(context.Background(), lookup, "ip", "db.a-test", "5432", time.Minute, zap.NewNop()) + if err != nil { + t.Fatalf("A: %v", err) + } + if len(targets) != 2 { + t.Fatalf("targets = %d, want 2", len(targets)) + } + if targets[0].Host != "10.0.0.1" || targets[0].Port != "5432" { + t.Errorf("unexpected first target: %+v", targets[0]) + } + + if _, err := A(context.Background(), lookup, "ip", "db.a-test", "5432", time.Minute, zap.NewNop()); err != nil { + t.Fatalf("A (cached): %v", err) + } + if calls != 1 { + t.Errorf("lookup calls = %d, want 1 (cached)", calls) + } +} + +func TestAError(t *testing.T) { + lookup := func(context.Context, string, string) ([]net.IP, error) { + return nil, errors.New("dns boom") + } + 
if _, err := A(context.Background(), lookup, "ip", "db.a-err", "5432", time.Minute, zap.NewNop()); err == nil { + t.Fatal("expected an error when the A lookup fails") + } +} diff --git a/modules/caddyhttp/reverseproxy/upstreams.go b/modules/caddyhttp/reverseproxy/upstreams.go index f7077ce78..93a23578a 100644 --- a/modules/caddyhttp/reverseproxy/upstreams.go +++ b/modules/caddyhttp/reverseproxy/upstreams.go @@ -7,14 +7,13 @@ import ( weakrand "math/rand/v2" "net" "net/http" - "strconv" - "sync" "time" "go.uber.org/zap" "go.uber.org/zap/zapcore" "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/dynamicupstreams" ) func init() { @@ -120,101 +119,42 @@ func (su *SRVUpstreams) Provision(ctx caddy.Context) error { } func (su *SRVUpstreams) ResetCache(r *http.Request) error { - srvsMu.Lock() if r == nil { - srvs = make(map[string]srvLookup) - } else { - suAddr, _, _, _ := su.expandedAddr(r) - delete(srvs, suAddr) + dynamicupstreams.ResetAllSRV() + return nil } - srvsMu.Unlock() + _, service, proto, name := su.expandedAddr(r) + dynamicupstreams.ResetSRV(service, proto, name) return nil } func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) { - suAddr, service, proto, name := su.expandedAddr(r) + _, service, proto, name := su.expandedAddr(r) - // first, use a cheap read-lock to return a cached result quickly - srvsMu.RLock() - cached := srvs[suAddr] - srvsMu.RUnlock() - if cached.isFresh() { - return allNew(cached.upstreams), nil - } - - // otherwise, obtain a write-lock to update the cached value - srvsMu.Lock() - defer srvsMu.Unlock() - - // check to see if it's still stale, since we're now in a different - // lock from when we first checked freshness; another goroutine might - // have refreshed it in the meantime before we re-obtained our lock - cached = srvs[suAddr] - if cached.isFresh() { - return allNew(cached.upstreams), nil - } - - if c := su.logger.Check(zapcore.DebugLevel, "refreshing SRV upstreams"); c != nil { - c.Write( 
- zap.String("service", service), - zap.String("proto", proto), - zap.String("name", name), - ) - } - - _, records, err := su.resolver.LookupSRV(r.Context(), service, proto, name) + targets, err := dynamicupstreams.SRV(r.Context(), su.resolver.LookupSRV, + service, proto, name, + time.Duration(su.Refresh), time.Duration(su.GracePeriod), su.logger) if err != nil { - // From LookupSRV docs: "If the response contains invalid names, those records are filtered - // out and an error will be returned alongside the remaining results, if any." Thus, we - // only return an error if no records were also returned. - if len(records) == 0 { - if su.GracePeriod > 0 { - if c := su.logger.Check(zapcore.ErrorLevel, "SRV lookup failed; using previously cached"); c != nil { - c.Write(zap.Error(err)) - } - cached.freshness = time.Now().Add(time.Duration(su.GracePeriod) - time.Duration(su.Refresh)) - srvs[suAddr] = cached - return allNew(cached.upstreams), nil - } - return nil, err - } - if c := su.logger.Check(zapcore.WarnLevel, "SRV records filtered"); c != nil { - c.Write(zap.Error(err)) - } + return nil, err } - upstreams := make([]Upstream, len(records)) - for i, rec := range records { + upstreams := make([]*Upstream, len(targets)) + for i, t := range targets { if c := su.logger.Check(zapcore.DebugLevel, "discovered SRV record"); c != nil { c.Write( - zap.String("target", rec.Target), - zap.Uint16("port", rec.Port), - zap.Uint16("priority", rec.Priority), - zap.Uint16("weight", rec.Weight), + zap.String("target", t.Host), + zap.String("port", t.Port), + zap.Uint16("priority", t.Priority), + zap.Uint16("weight", t.Weight), ) } - addr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port))) + addr := net.JoinHostPort(t.Host, t.Port) if su.DialNetwork != "" { addr = su.DialNetwork + "/" + addr } - upstreams[i] = Upstream{Dial: addr} + upstreams[i] = &Upstream{Dial: addr} } - - // before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full - 
if cached.freshness.IsZero() && len(srvs) >= 100 { - for randomKey := range srvs { - delete(srvs, randomKey) - break - } - } - - srvs[suAddr] = srvLookup{ - srvUpstreams: su, - freshness: time.Now(), - upstreams: upstreams, - } - - return allNew(upstreams), nil + return upstreams, nil } func (su SRVUpstreams) String() string { @@ -247,16 +187,6 @@ func (SRVUpstreams) formattedAddr(service, proto, name string) string { return fmt.Sprintf("_%s._%s.%s", service, proto, name) } -type srvLookup struct { - srvUpstreams SRVUpstreams - freshness time.Time - upstreams []Upstream -} - -func (sl srvLookup) isFresh() bool { - return time.Since(sl.freshness) < time.Duration(sl.srvUpstreams.Refresh) -} - type IPVersions struct { IPv4 *bool `json:"ipv4,omitempty"` IPv6 *bool `json:"ipv6,omitempty"` @@ -357,93 +287,28 @@ func (au *AUpstreams) Provision(ctx caddy.Context) error { func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) { repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer) - // Map ipVersion early, so we can use it as part of the cache-key. - // This should be fairly inexpensive and comes and the upside of - // allowing the same dynamic upstream (name + port combination) - // to be used multiple times with different ip versions. - // - // It also forced a cache-miss if a previously cached dynamic - // upstream changes its ip version, e.g. after a config reload, - // while keeping the cache-invalidation as simple as it currently is. 
ipVersion := resolveIpVersion(au.Versions) - - auStr := repl.ReplaceAll(au.String()+ipVersion, "") - - // first, use a cheap read-lock to return a cached result quickly - aAaaaMu.RLock() - cached := aAaaa[auStr] - aAaaaMu.RUnlock() - if cached.isFresh() { - return allNew(cached.upstreams), nil - } - - // otherwise, obtain a write-lock to update the cached value - aAaaaMu.Lock() - defer aAaaaMu.Unlock() - - // check to see if it's still stale, since we're now in a different - // lock from when we first checked freshness; another goroutine might - // have refreshed it in the meantime before we re-obtained our lock - cached = aAaaa[auStr] - if cached.isFresh() { - return allNew(cached.upstreams), nil - } - name := repl.ReplaceAll(au.Name, "") port := repl.ReplaceAll(au.Port, "") - if c := au.logger.Check(zapcore.DebugLevel, "refreshing A upstreams"); c != nil { - c.Write( - zap.String("version", ipVersion), - zap.String("name", name), - zap.String("port", port), - ) - } - - ips, err := au.resolver.LookupIP(r.Context(), ipVersion, name) + targets, err := dynamicupstreams.A(r.Context(), au.resolver.LookupIP, + ipVersion, name, port, time.Duration(au.Refresh), au.logger) if err != nil { return nil, err } - upstreams := make([]Upstream, len(ips)) - for i, ip := range ips { + upstreams := make([]*Upstream, len(targets)) + for i, t := range targets { if c := au.logger.Check(zapcore.DebugLevel, "discovered A record"); c != nil { - c.Write(zap.String("ip", ip.String())) - } - upstreams[i] = Upstream{ - Dial: net.JoinHostPort(ip.String(), port), + c.Write(zap.String("ip", t.Host)) } + upstreams[i] = &Upstream{Dial: net.JoinHostPort(t.Host, t.Port)} } - - // before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full - if cached.freshness.IsZero() && len(aAaaa) >= 100 { - for randomKey := range aAaaa { - delete(aAaaa, randomKey) - break - } - } - - aAaaa[auStr] = aLookup{ - aUpstreams: au, - freshness: time.Now(), - upstreams: upstreams, 
- } - - return allNew(upstreams), nil + return upstreams, nil } func (au AUpstreams) String() string { return net.JoinHostPort(au.Name, au.Port) } -type aLookup struct { - aUpstreams AUpstreams - freshness time.Time - upstreams []Upstream -} - -func (al aLookup) isFresh() bool { - return time.Since(al.freshness) < time.Duration(al.aUpstreams.Refresh) -} - // MultiUpstreams is a single dynamic upstream source that // aggregates the results of multiple dynamic upstream sources. // All configured sources will be queried in order, with their @@ -548,22 +413,6 @@ func (u *UpstreamResolver) ParseAddresses() error { return nil } -func allNew(upstreams []Upstream) []*Upstream { - results := make([]*Upstream, len(upstreams)) - for i := range upstreams { - results[i] = &Upstream{Dial: upstreams[i].Dial} - } - return results -} - -var ( - srvs = make(map[string]srvLookup) - srvsMu sync.RWMutex - - aAaaa = make(map[string]aLookup) - aAaaaMu sync.RWMutex -) - // Interface guards var ( _ caddy.Provisioner = (*SRVUpstreams)(nil) From bd03a35ca252130cc02ac661c19e216ccafd0ee2 Mon Sep 17 00:00:00 2001 From: tannevaled Date: Wed, 3 Jun 2026 16:34:24 +0200 Subject: [PATCH 2/2] test(dynamicupstreams): cover ResetAllSRV, filtered records, cache bound, concurrent refresh Brings the package to 96.8% statement coverage. Adds: - TestResetAllSRV: full-cache drop - TestSRVFilteredRecords: LookupSRV partial-error semantics (usable records returned, error downgraded to a warning) - TestSRVCacheBound / TestACacheBound: eviction keeps the cache bounded at maxCacheEntries - TestSRVConcurrentRefreshDeduplicates / TestAConcurrentRefreshDeduplicates: two goroutines missing the read-lock fast path trigger only one lookup (verified under -race) The grace-period and filtered-records logs now use a level-enabled logger so the log-write branches execute. 
The only uncovered statements are the two double-checked-locking re-check returns, which are reachable solely in a narrow race window (the same untested idiom as reverseproxy's SRV/A). --- dynamicupstreams/dynamicupstreams_test.go | 195 +++++++++++++++++++++- 1 file changed, 194 insertions(+), 1 deletion(-) diff --git a/dynamicupstreams/dynamicupstreams_test.go b/dynamicupstreams/dynamicupstreams_test.go index 362c839c1..4110a8cd6 100644 --- a/dynamicupstreams/dynamicupstreams_test.go +++ b/dynamicupstreams/dynamicupstreams_test.go @@ -17,11 +17,14 @@ package dynamicupstreams import ( "context" "errors" + "fmt" "net" + "sync" "testing" "time" "go.uber.org/zap" + "go.uber.org/zap/zaptest" ) func TestSRVResolvesAndCaches(t *testing.T) { @@ -75,7 +78,8 @@ func TestSRVGracePeriodServesStale(t *testing.T) { fail := func(context.Context, string, string, string) (string, []*net.SRV, error) { return "", nil, errors.New("dns boom") } - targets, err := SRV(context.Background(), fail, "svc-grace", "tcp", "x", time.Nanosecond, time.Hour, zap.NewNop()) + // a level-enabled logger so the "using previously cached" error log fires + targets, err := SRV(context.Background(), fail, "svc-grace", "tcp", "x", time.Nanosecond, time.Hour, zaptest.NewLogger(t)) if err != nil { t.Fatalf("grace period should suppress the error: %v", err) } @@ -136,3 +140,192 @@ func TestAError(t *testing.T) { t.Fatal("expected an error when the A lookup fails") } } + +// TestSRVFilteredRecords covers the LookupSRV semantics where invalid names are +// filtered out and an error is returned alongside the usable remainder: the +// usable records must still be returned (the error is downgraded to a warning). 
+func TestSRVFilteredRecords(t *testing.T) { + lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) { + return "", []*net.SRV{{Target: "ok.example.", Port: 5432, Priority: 1, Weight: 10}}, + errors.New("some SRV names were filtered out") + } + // a level-enabled logger so the "SRV records filtered" warning fires + targets, err := SRV(context.Background(), lookup, "svc-filtered", "tcp", "x", time.Minute, 0, zaptest.NewLogger(t)) + if err != nil { + t.Fatalf("usable records must be returned despite the partial error: %v", err) + } + if len(targets) != 1 || targets[0].Host != "ok.example." { + t.Fatalf("unexpected targets: %+v", targets) + } +} + +// TestResetAllSRV verifies that the whole SRV cache is dropped. +func TestResetAllSRV(t *testing.T) { + lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) { + return "", []*net.SRV{{Target: "a.example.", Port: 1}}, nil + } + if _, err := SRV(context.Background(), lookup, "svc-reset-all", "tcp", "x", time.Minute, 0, zap.NewNop()); err != nil { + t.Fatalf("SRV: %v", err) + } + srvMu.RLock() + populated := len(srvCache) > 0 + srvMu.RUnlock() + if !populated { + t.Fatal("expected the SRV cache to be populated before reset") + } + + ResetAllSRV() + + srvMu.RLock() + n := len(srvCache) + srvMu.RUnlock() + if n != 0 { + t.Fatalf("srvCache len = %d after ResetAllSRV, want 0", n) + } +} + +// TestSRVCacheBound verifies that inserting a brand-new entry once the cache is +// full evicts an existing one so the cache stays bounded. 
+func TestSRVCacheBound(t *testing.T) { + srvMu.Lock() + srvCache = make(map[string]cacheEntry) + srvMu.Unlock() + + lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) { + return "", []*net.SRV{{Target: "a.example.", Port: 1}}, nil + } + for i := 0; i < maxCacheEntries+5; i++ { + name := fmt.Sprintf("svc-bound-%d", i) + if _, err := SRV(context.Background(), lookup, name, "tcp", "x", time.Minute, 0, zap.NewNop()); err != nil { + t.Fatalf("SRV[%d]: %v", i, err) + } + } + srvMu.RLock() + n := len(srvCache) + srvMu.RUnlock() + if n > maxCacheEntries { + t.Fatalf("srvCache len = %d, want <= %d (cache must stay bounded)", n, maxCacheEntries) + } +} + +// TestACacheBound verifies the same bounding behavior for the A cache. +func TestACacheBound(t *testing.T) { + aMu.Lock() + aCache = make(map[string]cacheEntry) + aMu.Unlock() + + lookup := func(context.Context, string, string) ([]net.IP, error) { + return []net.IP{net.ParseIP("10.0.0.1")}, nil + } + for i := 0; i < maxCacheEntries+5; i++ { + name := fmt.Sprintf("db.bound-%d", i) + if _, err := A(context.Background(), lookup, "ip", name, "5432", time.Minute, zap.NewNop()); err != nil { + t.Fatalf("A[%d]: %v", i, err) + } + } + aMu.RLock() + n := len(aCache) + aMu.RUnlock() + if n > maxCacheEntries { + t.Fatalf("aCache len = %d, want <= %d (cache must stay bounded)", n, maxCacheEntries) + } +} + +// TestSRVConcurrentRefreshDeduplicates checks that concurrent callers for the +// same key share one lookup: the first holds the write lock while resolving, so +// the second blocks on srvMu.RLock() and is then served from the freshly +// populated cache by the fast path (the write-lock re-check is not reached). 
+func TestSRVConcurrentRefreshDeduplicates(t *testing.T) { + srvMu.Lock() + srvCache = make(map[string]cacheEntry) + srvMu.Unlock() + + var calls int + var mu sync.Mutex + inLookup := make(chan struct{}) + release := make(chan struct{}) + lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) { + mu.Lock() + calls++ + first := calls == 1 + mu.Unlock() + if first { + close(inLookup) + <-release // hold the write lock until the second goroutine is queued + } + return "", []*net.SRV{{Target: "a.example.", Port: 1}}, nil + } + + done := make(chan struct{}, 2) + // G1: takes the write lock and blocks inside lookup. + go func() { + _, _ = SRV(context.Background(), lookup, "svc-conc", "tcp", "x", time.Minute, 0, zap.NewNop()) + done <- struct{}{} + }() + <-inLookup // G1 now holds the write lock; cache is still empty + + // G2: blocks on srvMu.RLock() while G1 holds the write lock; once G1 unlocks, the read-lock fast path serves it. + go func() { + _, _ = SRV(context.Background(), lookup, "svc-conc", "tcp", "x", time.Minute, 0, zap.NewNop()) + done <- struct{}{} + }() + time.Sleep(50 * time.Millisecond) // give G2 time to reach srvMu.RLock() + + close(release) // G1 populates the cache and releases the lock + <-done + <-done + + mu.Lock() + got := calls + mu.Unlock() + if got != 1 { + t.Fatalf("lookup calls = %d, want 1 (second goroutine must hit the cache re-check)", got) + } +} + +// TestAConcurrentRefreshDeduplicates is the A-cache equivalent of the above. 
+func TestAConcurrentRefreshDeduplicates(t *testing.T) { + aMu.Lock() + aCache = make(map[string]cacheEntry) + aMu.Unlock() + + var calls int + var mu sync.Mutex + inLookup := make(chan struct{}) + release := make(chan struct{}) + lookup := func(context.Context, string, string) ([]net.IP, error) { + mu.Lock() + calls++ + first := calls == 1 + mu.Unlock() + if first { + close(inLookup) + <-release + } + return []net.IP{net.ParseIP("10.0.0.1")}, nil + } + + done := make(chan struct{}, 2) + go func() { + _, _ = A(context.Background(), lookup, "ip", "db.conc", "5432", time.Minute, zap.NewNop()) + done <- struct{}{} + }() + <-inLookup + + go func() { + _, _ = A(context.Background(), lookup, "ip", "db.conc", "5432", time.Minute, zap.NewNop()) + done <- struct{}{} + }() + time.Sleep(50 * time.Millisecond) + + close(release) + <-done + <-done + + mu.Lock() + got := calls + mu.Unlock() + if got != 1 { + t.Fatalf("lookup calls = %d, want 1 (second goroutine must hit the cache re-check)", got) + } +}