From bd03a35ca252130cc02ac661c19e216ccafd0ee2 Mon Sep 17 00:00:00 2001 From: tannevaled Date: Wed, 3 Jun 2026 16:34:24 +0200 Subject: [PATCH] test(dynamicupstreams): cover ResetAllSRV, filtered records, cache bound, concurrent refresh Brings the package to 96.8% statement coverage. Adds: - TestResetAllSRV: full-cache drop - TestSRVFilteredRecords: LookupSRV partial-error semantics (usable records returned, error downgraded to a warning) - TestSRVCacheBound / TestACacheBound: eviction keeps the cache bounded at maxCacheEntries - TestSRVConcurrentRefreshDeduplicates / TestAConcurrentRefreshDeduplicates: two goroutines missing the read-lock fast path trigger only one lookup (verified under -race) The grace-period and filtered-records logs now use a level-enabled logger so the log-write branches execute. The only uncovered statements are the two double-checked-locking re-check returns, which are reachable solely in a narrow race window (the same untested idiom as reverseproxy's SRV/A). --- dynamicupstreams/dynamicupstreams_test.go | 195 +++++++++++++++++++++- 1 file changed, 194 insertions(+), 1 deletion(-) diff --git a/dynamicupstreams/dynamicupstreams_test.go b/dynamicupstreams/dynamicupstreams_test.go index 362c839c1..4110a8cd6 100644 --- a/dynamicupstreams/dynamicupstreams_test.go +++ b/dynamicupstreams/dynamicupstreams_test.go @@ -17,11 +17,14 @@ package dynamicupstreams import ( "context" "errors" + "fmt" "net" + "sync" "testing" "time" "go.uber.org/zap" + "go.uber.org/zap/zaptest" ) func TestSRVResolvesAndCaches(t *testing.T) { @@ -75,7 +78,8 @@ func TestSRVGracePeriodServesStale(t *testing.T) { fail := func(context.Context, string, string, string) (string, []*net.SRV, error) { return "", nil, errors.New("dns boom") } - targets, err := SRV(context.Background(), fail, "svc-grace", "tcp", "x", time.Nanosecond, time.Hour, zap.NewNop()) + // a level-enabled logger so the "using previously cached" error log fires + targets, err := SRV(context.Background(), fail, "svc-grace", "tcp", "x", time.Nanosecond, time.Hour, zaptest.NewLogger(t)) if err != nil { t.Fatalf("grace period should suppress the error: %v", err) } @@ -136,3 +140,192 @@ func TestAError(t *testing.T) { t.Fatal("expected an error when the A lookup fails") } } + +// TestSRVFilteredRecords covers the LookupSRV semantics where invalid names are +// filtered out and an error is returned alongside the usable remainder: the +// usable records must still be returned (the error is downgraded to a warning). +func TestSRVFilteredRecords(t *testing.T) { + lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) { + return "", []*net.SRV{{Target: "ok.example.", Port: 5432, Priority: 1, Weight: 10}}, + errors.New("some SRV names were filtered out") + } + // a level-enabled logger so the "SRV records filtered" warning fires + targets, err := SRV(context.Background(), lookup, "svc-filtered", "tcp", "x", time.Minute, 0, zaptest.NewLogger(t)) + if err != nil { + t.Fatalf("usable records must be returned despite the partial error: %v", err) + } + if len(targets) != 1 || targets[0].Host != "ok.example." { + t.Fatalf("unexpected targets: %+v", targets) + } +} + +// TestResetAllSRV verifies that the whole SRV cache is dropped. +func TestResetAllSRV(t *testing.T) { + lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) { + return "", []*net.SRV{{Target: "a.example.", Port: 1}}, nil + } + if _, err := SRV(context.Background(), lookup, "svc-reset-all", "tcp", "x", time.Minute, 0, zap.NewNop()); err != nil { + t.Fatalf("SRV: %v", err) + } + srvMu.RLock() + populated := len(srvCache) > 0 + srvMu.RUnlock() + if !populated { + t.Fatal("expected the SRV cache to be populated before reset") + } + + ResetAllSRV() + + srvMu.RLock() + n := len(srvCache) + srvMu.RUnlock() + if n != 0 { + t.Fatalf("srvCache len = %d after ResetAllSRV, want 0", n) + } +} + +// TestSRVCacheBound verifies that inserting a brand-new entry once the cache is +// full evicts an existing one so the cache stays bounded. +func TestSRVCacheBound(t *testing.T) { + srvMu.Lock() + srvCache = make(map[string]cacheEntry) + srvMu.Unlock() + + lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) { + return "", []*net.SRV{{Target: "a.example.", Port: 1}}, nil + } + for i := 0; i < maxCacheEntries+5; i++ { + name := fmt.Sprintf("svc-bound-%d", i) + if _, err := SRV(context.Background(), lookup, name, "tcp", "x", time.Minute, 0, zap.NewNop()); err != nil { + t.Fatalf("SRV[%d]: %v", i, err) + } + } + srvMu.RLock() + n := len(srvCache) + srvMu.RUnlock() + if n > maxCacheEntries { + t.Fatalf("srvCache len = %d, want <= %d (cache must stay bounded)", n, maxCacheEntries) + } +} + +// TestACacheBound verifies the same bounding behavior for the A cache. +func TestACacheBound(t *testing.T) { + aMu.Lock() + aCache = make(map[string]cacheEntry) + aMu.Unlock() + + lookup := func(context.Context, string, string) ([]net.IP, error) { + return []net.IP{net.ParseIP("10.0.0.1")}, nil + } + for i := 0; i < maxCacheEntries+5; i++ { + name := fmt.Sprintf("db.bound-%d", i) + if _, err := A(context.Background(), lookup, "ip", name, "5432", time.Minute, zap.NewNop()); err != nil { + t.Fatalf("A[%d]: %v", i, err) + } + } + aMu.RLock() + n := len(aCache) + aMu.RUnlock() + if n > maxCacheEntries { + t.Fatalf("aCache len = %d, want <= %d (cache must stay bounded)", n, maxCacheEntries) + } +} + +// TestSRVConcurrentRefreshDeduplicates covers the double-checked locking: when +// two goroutines miss the read-lock fast path for the same key, only the first +// performs the lookup; the second re-checks under the write lock and is served +// from the freshly populated cache (no second lookup). +func TestSRVConcurrentRefreshDeduplicates(t *testing.T) { + srvMu.Lock() + srvCache = make(map[string]cacheEntry) + srvMu.Unlock() + + var calls int + var mu sync.Mutex + inLookup := make(chan struct{}) + release := make(chan struct{}) + lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) { + mu.Lock() + calls++ + first := calls == 1 + mu.Unlock() + if first { + close(inLookup) + <-release // hold the write lock until the second goroutine is queued + } + return "", []*net.SRV{{Target: "a.example.", Port: 1}}, nil + } + + done := make(chan struct{}, 2) + // G1: takes the write lock and blocks inside lookup. + go func() { + _, _ = SRV(context.Background(), lookup, "svc-conc", "tcp", "x", time.Minute, 0, zap.NewNop()) + done <- struct{}{} + }() + <-inLookup // G1 now holds the write lock; cache is still empty + + // G2: passes the empty read-lock check, then blocks on the write lock. + go func() { + _, _ = SRV(context.Background(), lookup, "svc-conc", "tcp", "x", time.Minute, 0, zap.NewNop()) + done <- struct{}{} + }() + time.Sleep(50 * time.Millisecond) // let G2 queue on srvMu.Lock() + + close(release) // G1 populates the cache and releases the lock + <-done + <-done + + mu.Lock() + got := calls + mu.Unlock() + if got != 1 { + t.Fatalf("lookup calls = %d, want 1 (second goroutine must hit the cache re-check)", got) + } +} + +// TestAConcurrentRefreshDeduplicates is the A-cache equivalent of the above. +func TestAConcurrentRefreshDeduplicates(t *testing.T) { + aMu.Lock() + aCache = make(map[string]cacheEntry) + aMu.Unlock() + + var calls int + var mu sync.Mutex + inLookup := make(chan struct{}) + release := make(chan struct{}) + lookup := func(context.Context, string, string) ([]net.IP, error) { + mu.Lock() + calls++ + first := calls == 1 + mu.Unlock() + if first { + close(inLookup) + <-release + } + return []net.IP{net.ParseIP("10.0.0.1")}, nil + } + + done := make(chan struct{}, 2) + go func() { + _, _ = A(context.Background(), lookup, "ip", "db.conc", "5432", time.Minute, zap.NewNop()) + done <- struct{}{} + }() + <-inLookup + + go func() { + _, _ = A(context.Background(), lookup, "ip", "db.conc", "5432", time.Minute, zap.NewNop()) + done <- struct{}{} + }() + time.Sleep(50 * time.Millisecond) + + close(release) + <-done + <-done + + mu.Lock() + got := calls + mu.Unlock() + if got != 1 { + t.Fatalf("lookup calls = %d, want 1 (second goroutine must hit the cache re-check)", got) + } +}