mirror of
https://github.com/caddyserver/caddy.git
synced 2026-06-28 04:41:41 +00:00
Merge bd03a35ca2 into 13a4c3f43c
This commit is contained in:
commit
7ced7edabe
3 changed files with 559 additions and 177 deletions
202
dynamicupstreams/dynamicupstreams.go
Normal file
202
dynamicupstreams/dynamicupstreams.go
Normal file
|
|
@ -0,0 +1,202 @@
|
|||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package dynamicupstreams provides transport-neutral DNS-based discovery of
|
||||
// upstream targets, with result caching. It is shared so that different proxies
|
||||
// (e.g. the HTTP reverse_proxy and third-party layer4 proxies) can discover
|
||||
// backends from DNS without each copying the resolution and caching logic.
|
||||
//
|
||||
// The package intentionally returns neutral [Target] values rather than any
|
||||
// proxy-specific upstream type; each caller builds its own upstream
|
||||
// representation from the targets.
|
||||
package dynamicupstreams
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// Target is a single discovered upstream endpoint. It carries the SRV priority
|
||||
// and weight when available (both zero for non-SRV lookups).
|
||||
type Target struct {
|
||||
Host string
|
||||
Port string
|
||||
Priority uint16
|
||||
Weight uint16
|
||||
}
|
||||
|
||||
// SRVLookupFunc resolves SRV records. It matches the signature of
|
||||
// (*net.Resolver).LookupSRV, so callers can pass that directly (and tests can
|
||||
// inject a stub).
|
||||
type SRVLookupFunc func(ctx context.Context, service, proto, name string) (cname string, addrs []*net.SRV, err error)
|
||||
|
||||
// SRV resolves the given SRV record into targets, caching the result for
|
||||
// refresh. lookup is the resolver to use (typically (*net.Resolver).LookupSRV).
|
||||
//
|
||||
// If the lookup fails but returns some records (e.g. a few invalid names were
|
||||
// filtered out), those records are still used. If it fails with no records and
|
||||
// grace > 0, the previously cached result keeps being served for up to grace
|
||||
// past its refresh instead of returning an error.
|
||||
func SRV(ctx context.Context, lookup SRVLookupFunc, service, proto, name string, refresh, grace time.Duration, logger *zap.Logger) ([]Target, error) {
|
||||
key := srvKey(service, proto, name)
|
||||
|
||||
// fast path: a fresh cached result under a read lock
|
||||
srvMu.RLock()
|
||||
cached := srvCache[key]
|
||||
srvMu.RUnlock()
|
||||
if cached.isFresh() {
|
||||
return cached.targets, nil
|
||||
}
|
||||
|
||||
srvMu.Lock()
|
||||
defer srvMu.Unlock()
|
||||
|
||||
// re-check under the write lock in case another goroutine refreshed it
|
||||
cached = srvCache[key]
|
||||
if cached.isFresh() {
|
||||
return cached.targets, nil
|
||||
}
|
||||
|
||||
_, records, err := lookup(ctx, service, proto, name)
|
||||
if err != nil {
|
||||
// From LookupSRV docs: invalid names are filtered out and an error is
|
||||
// returned alongside any remaining results; only treat it as fatal when
|
||||
// nothing usable came back.
|
||||
if len(records) == 0 {
|
||||
if grace > 0 && cached.targets != nil {
|
||||
if c := logger.Check(zapcore.ErrorLevel, "SRV lookup failed; using previously cached"); c != nil {
|
||||
c.Write(zap.String("service", service), zap.String("proto", proto), zap.String("name", name), zap.Error(err))
|
||||
}
|
||||
cached.freshness = time.Now().Add(grace - refresh)
|
||||
srvCache[key] = cached
|
||||
return cached.targets, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if c := logger.Check(zapcore.WarnLevel, "SRV records filtered"); c != nil {
|
||||
c.Write(zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
targets := make([]Target, len(records))
|
||||
for i, rec := range records {
|
||||
targets[i] = Target{
|
||||
Host: rec.Target,
|
||||
Port: strconv.Itoa(int(rec.Port)),
|
||||
Priority: rec.Priority,
|
||||
Weight: rec.Weight,
|
||||
}
|
||||
}
|
||||
|
||||
// when inserting a brand-new entry (not replacing a stale one), bound the cache
|
||||
if cached.freshness.IsZero() && len(srvCache) >= maxCacheEntries {
|
||||
for k := range srvCache {
|
||||
delete(srvCache, k)
|
||||
break
|
||||
}
|
||||
}
|
||||
srvCache[key] = cacheEntry{refresh: refresh, freshness: time.Now(), targets: targets}
|
||||
return targets, nil
|
||||
}
|
||||
|
||||
// ResetSRV removes the cached result for a single SRV record.
|
||||
func ResetSRV(service, proto, name string) {
|
||||
srvMu.Lock()
|
||||
delete(srvCache, srvKey(service, proto, name))
|
||||
srvMu.Unlock()
|
||||
}
|
||||
|
||||
// ResetAllSRV clears the entire SRV cache.
|
||||
func ResetAllSRV() {
|
||||
srvMu.Lock()
|
||||
srvCache = make(map[string]cacheEntry)
|
||||
srvMu.Unlock()
|
||||
}
|
||||
|
||||
func srvKey(service, proto, name string) string {
|
||||
return service + "\x00" + proto + "\x00" + name
|
||||
}
|
||||
|
||||
const maxCacheEntries = 100
|
||||
|
||||
type cacheEntry struct {
|
||||
refresh time.Duration
|
||||
freshness time.Time
|
||||
targets []Target
|
||||
}
|
||||
|
||||
func (e cacheEntry) isFresh() bool {
|
||||
return !e.freshness.IsZero() && time.Since(e.freshness) < e.refresh
|
||||
}
|
||||
|
||||
// IPLookupFunc resolves a host's IP addresses. It matches the signature of
|
||||
// (*net.Resolver).LookupIP, so callers can pass that directly (and tests can
|
||||
// inject a stub). network is one of "ip", "ip4", "ip6".
|
||||
type IPLookupFunc func(ctx context.Context, network, host string) ([]net.IP, error)
|
||||
|
||||
// A resolves name's A/AAAA records into targets (one per address, all using the
|
||||
// given port), caching the result for refresh. network selects the IP versions
|
||||
// ("ip", "ip4" or "ip6"). lookup is the resolver to use (typically
|
||||
// (*net.Resolver).LookupIP).
|
||||
func A(ctx context.Context, lookup IPLookupFunc, network, name, port string, refresh time.Duration, logger *zap.Logger) ([]Target, error) {
|
||||
key := name + "\x00" + port + "\x00" + network
|
||||
|
||||
aMu.RLock()
|
||||
cached := aCache[key]
|
||||
aMu.RUnlock()
|
||||
if cached.isFresh() {
|
||||
return cached.targets, nil
|
||||
}
|
||||
|
||||
aMu.Lock()
|
||||
defer aMu.Unlock()
|
||||
|
||||
cached = aCache[key]
|
||||
if cached.isFresh() {
|
||||
return cached.targets, nil
|
||||
}
|
||||
|
||||
ips, err := lookup(ctx, network, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targets := make([]Target, len(ips))
|
||||
for i, ip := range ips {
|
||||
targets[i] = Target{Host: ip.String(), Port: port}
|
||||
}
|
||||
|
||||
if cached.freshness.IsZero() && len(aCache) >= maxCacheEntries {
|
||||
for k := range aCache {
|
||||
delete(aCache, k)
|
||||
break
|
||||
}
|
||||
}
|
||||
aCache[key] = cacheEntry{refresh: refresh, freshness: time.Now(), targets: targets}
|
||||
return targets, nil
|
||||
}
|
||||
|
||||
var (
|
||||
srvMu sync.RWMutex
|
||||
srvCache = make(map[string]cacheEntry)
|
||||
|
||||
aMu sync.RWMutex
|
||||
aCache = make(map[string]cacheEntry)
|
||||
)
|
||||
331
dynamicupstreams/dynamicupstreams_test.go
Normal file
331
dynamicupstreams/dynamicupstreams_test.go
Normal file
|
|
@ -0,0 +1,331 @@
|
|||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dynamicupstreams
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
func TestSRVResolvesAndCaches(t *testing.T) {
|
||||
calls := 0
|
||||
lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) {
|
||||
calls++
|
||||
return "", []*net.SRV{
|
||||
{Target: "a.example.", Port: 5432, Priority: 1, Weight: 10},
|
||||
{Target: "b.example.", Port: 5433, Priority: 1, Weight: 20},
|
||||
}, nil
|
||||
}
|
||||
|
||||
targets, err := SRV(context.Background(), lookup, "svc-cache", "tcp", "x", time.Minute, 0, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("SRV: %v", err)
|
||||
}
|
||||
if len(targets) != 2 {
|
||||
t.Fatalf("targets = %d, want 2", len(targets))
|
||||
}
|
||||
if targets[0].Host != "a.example." || targets[0].Port != "5432" || targets[0].Weight != 10 {
|
||||
t.Errorf("unexpected first target: %+v", targets[0])
|
||||
}
|
||||
|
||||
// second call within refresh must be served from cache (no extra lookup)
|
||||
if _, err := SRV(context.Background(), lookup, "svc-cache", "tcp", "x", time.Minute, 0, zap.NewNop()); err != nil {
|
||||
t.Fatalf("SRV (cached): %v", err)
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Errorf("lookup calls = %d, want 1 (cached)", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSRVErrorWithoutGrace(t *testing.T) {
|
||||
lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) {
|
||||
return "", nil, errors.New("dns boom")
|
||||
}
|
||||
if _, err := SRV(context.Background(), lookup, "svc-err", "tcp", "x", time.Minute, 0, zap.NewNop()); err == nil {
|
||||
t.Fatal("expected an error when lookup fails and nothing is cached")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSRVGracePeriodServesStale(t *testing.T) {
|
||||
ok := func(context.Context, string, string, string) (string, []*net.SRV, error) {
|
||||
return "", []*net.SRV{{Target: "a.example.", Port: 1}}, nil
|
||||
}
|
||||
// tiny refresh so the entry is immediately stale on the next call
|
||||
if _, err := SRV(context.Background(), ok, "svc-grace", "tcp", "x", time.Nanosecond, time.Hour, zap.NewNop()); err != nil {
|
||||
t.Fatalf("seeding cache: %v", err)
|
||||
}
|
||||
|
||||
fail := func(context.Context, string, string, string) (string, []*net.SRV, error) {
|
||||
return "", nil, errors.New("dns boom")
|
||||
}
|
||||
// a level-enabled logger so the "using previously cached" error log fires
|
||||
targets, err := SRV(context.Background(), fail, "svc-grace", "tcp", "x", time.Nanosecond, time.Hour, zaptest.NewLogger(t))
|
||||
if err != nil {
|
||||
t.Fatalf("grace period should suppress the error: %v", err)
|
||||
}
|
||||
if len(targets) != 1 {
|
||||
t.Errorf("expected the stale cached target to be served, got %d", len(targets))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetSRV(t *testing.T) {
|
||||
calls := 0
|
||||
lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) {
|
||||
calls++
|
||||
return "", []*net.SRV{{Target: "a.example.", Port: 1}}, nil
|
||||
}
|
||||
if _, err := SRV(context.Background(), lookup, "svc-reset", "tcp", "x", time.Minute, 0, zap.NewNop()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ResetSRV("svc-reset", "tcp", "x")
|
||||
if _, err := SRV(context.Background(), lookup, "svc-reset", "tcp", "x", time.Minute, 0, zap.NewNop()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if calls != 2 {
|
||||
t.Errorf("lookup calls = %d, want 2 (cache was reset between calls)", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAResolvesAndCaches(t *testing.T) {
|
||||
calls := 0
|
||||
lookup := func(context.Context, string, string) ([]net.IP, error) {
|
||||
calls++
|
||||
return []net.IP{net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.2")}, nil
|
||||
}
|
||||
|
||||
targets, err := A(context.Background(), lookup, "ip", "db.a-test", "5432", time.Minute, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("A: %v", err)
|
||||
}
|
||||
if len(targets) != 2 {
|
||||
t.Fatalf("targets = %d, want 2", len(targets))
|
||||
}
|
||||
if targets[0].Host != "10.0.0.1" || targets[0].Port != "5432" {
|
||||
t.Errorf("unexpected first target: %+v", targets[0])
|
||||
}
|
||||
|
||||
if _, err := A(context.Background(), lookup, "ip", "db.a-test", "5432", time.Minute, zap.NewNop()); err != nil {
|
||||
t.Fatalf("A (cached): %v", err)
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Errorf("lookup calls = %d, want 1 (cached)", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAError(t *testing.T) {
|
||||
lookup := func(context.Context, string, string) ([]net.IP, error) {
|
||||
return nil, errors.New("dns boom")
|
||||
}
|
||||
if _, err := A(context.Background(), lookup, "ip", "db.a-err", "5432", time.Minute, zap.NewNop()); err == nil {
|
||||
t.Fatal("expected an error when the A lookup fails")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSRVFilteredRecords covers the LookupSRV semantics where invalid names are
|
||||
// filtered out and an error is returned alongside the usable remainder: the
|
||||
// usable records must still be returned (the error is downgraded to a warning).
|
||||
func TestSRVFilteredRecords(t *testing.T) {
|
||||
lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) {
|
||||
return "", []*net.SRV{{Target: "ok.example.", Port: 5432, Priority: 1, Weight: 10}},
|
||||
errors.New("some SRV names were filtered out")
|
||||
}
|
||||
// a level-enabled logger so the "SRV records filtered" warning fires
|
||||
targets, err := SRV(context.Background(), lookup, "svc-filtered", "tcp", "x", time.Minute, 0, zaptest.NewLogger(t))
|
||||
if err != nil {
|
||||
t.Fatalf("usable records must be returned despite the partial error: %v", err)
|
||||
}
|
||||
if len(targets) != 1 || targets[0].Host != "ok.example." {
|
||||
t.Fatalf("unexpected targets: %+v", targets)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResetAllSRV verifies that the whole SRV cache is dropped.
|
||||
func TestResetAllSRV(t *testing.T) {
|
||||
lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) {
|
||||
return "", []*net.SRV{{Target: "a.example.", Port: 1}}, nil
|
||||
}
|
||||
if _, err := SRV(context.Background(), lookup, "svc-reset-all", "tcp", "x", time.Minute, 0, zap.NewNop()); err != nil {
|
||||
t.Fatalf("SRV: %v", err)
|
||||
}
|
||||
srvMu.RLock()
|
||||
populated := len(srvCache) > 0
|
||||
srvMu.RUnlock()
|
||||
if !populated {
|
||||
t.Fatal("expected the SRV cache to be populated before reset")
|
||||
}
|
||||
|
||||
ResetAllSRV()
|
||||
|
||||
srvMu.RLock()
|
||||
n := len(srvCache)
|
||||
srvMu.RUnlock()
|
||||
if n != 0 {
|
||||
t.Fatalf("srvCache len = %d after ResetAllSRV, want 0", n)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSRVCacheBound verifies that inserting a brand-new entry once the cache is
|
||||
// full evicts an existing one so the cache stays bounded.
|
||||
func TestSRVCacheBound(t *testing.T) {
|
||||
srvMu.Lock()
|
||||
srvCache = make(map[string]cacheEntry)
|
||||
srvMu.Unlock()
|
||||
|
||||
lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) {
|
||||
return "", []*net.SRV{{Target: "a.example.", Port: 1}}, nil
|
||||
}
|
||||
for i := 0; i < maxCacheEntries+5; i++ {
|
||||
name := fmt.Sprintf("svc-bound-%d", i)
|
||||
if _, err := SRV(context.Background(), lookup, name, "tcp", "x", time.Minute, 0, zap.NewNop()); err != nil {
|
||||
t.Fatalf("SRV[%d]: %v", i, err)
|
||||
}
|
||||
}
|
||||
srvMu.RLock()
|
||||
n := len(srvCache)
|
||||
srvMu.RUnlock()
|
||||
if n > maxCacheEntries {
|
||||
t.Fatalf("srvCache len = %d, want <= %d (cache must stay bounded)", n, maxCacheEntries)
|
||||
}
|
||||
}
|
||||
|
||||
// TestACacheBound verifies the same bounding behavior for the A cache.
|
||||
func TestACacheBound(t *testing.T) {
|
||||
aMu.Lock()
|
||||
aCache = make(map[string]cacheEntry)
|
||||
aMu.Unlock()
|
||||
|
||||
lookup := func(context.Context, string, string) ([]net.IP, error) {
|
||||
return []net.IP{net.ParseIP("10.0.0.1")}, nil
|
||||
}
|
||||
for i := 0; i < maxCacheEntries+5; i++ {
|
||||
name := fmt.Sprintf("db.bound-%d", i)
|
||||
if _, err := A(context.Background(), lookup, "ip", name, "5432", time.Minute, zap.NewNop()); err != nil {
|
||||
t.Fatalf("A[%d]: %v", i, err)
|
||||
}
|
||||
}
|
||||
aMu.RLock()
|
||||
n := len(aCache)
|
||||
aMu.RUnlock()
|
||||
if n > maxCacheEntries {
|
||||
t.Fatalf("aCache len = %d, want <= %d (cache must stay bounded)", n, maxCacheEntries)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSRVConcurrentRefreshDeduplicates covers the double-checked locking: when
|
||||
// two goroutines miss the read-lock fast path for the same key, only the first
|
||||
// performs the lookup; the second re-checks under the write lock and is served
|
||||
// from the freshly populated cache (no second lookup).
|
||||
func TestSRVConcurrentRefreshDeduplicates(t *testing.T) {
|
||||
srvMu.Lock()
|
||||
srvCache = make(map[string]cacheEntry)
|
||||
srvMu.Unlock()
|
||||
|
||||
var calls int
|
||||
var mu sync.Mutex
|
||||
inLookup := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
lookup := func(context.Context, string, string, string) (string, []*net.SRV, error) {
|
||||
mu.Lock()
|
||||
calls++
|
||||
first := calls == 1
|
||||
mu.Unlock()
|
||||
if first {
|
||||
close(inLookup)
|
||||
<-release // hold the write lock until the second goroutine is queued
|
||||
}
|
||||
return "", []*net.SRV{{Target: "a.example.", Port: 1}}, nil
|
||||
}
|
||||
|
||||
done := make(chan struct{}, 2)
|
||||
// G1: takes the write lock and blocks inside lookup.
|
||||
go func() {
|
||||
_, _ = SRV(context.Background(), lookup, "svc-conc", "tcp", "x", time.Minute, 0, zap.NewNop())
|
||||
done <- struct{}{}
|
||||
}()
|
||||
<-inLookup // G1 now holds the write lock; cache is still empty
|
||||
|
||||
// G2: passes the empty read-lock check, then blocks on the write lock.
|
||||
go func() {
|
||||
_, _ = SRV(context.Background(), lookup, "svc-conc", "tcp", "x", time.Minute, 0, zap.NewNop())
|
||||
done <- struct{}{}
|
||||
}()
|
||||
time.Sleep(50 * time.Millisecond) // let G2 queue on srvMu.Lock()
|
||||
|
||||
close(release) // G1 populates the cache and releases the lock
|
||||
<-done
|
||||
<-done
|
||||
|
||||
mu.Lock()
|
||||
got := calls
|
||||
mu.Unlock()
|
||||
if got != 1 {
|
||||
t.Fatalf("lookup calls = %d, want 1 (second goroutine must hit the cache re-check)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAConcurrentRefreshDeduplicates is the A-cache equivalent of the above.
|
||||
func TestAConcurrentRefreshDeduplicates(t *testing.T) {
|
||||
aMu.Lock()
|
||||
aCache = make(map[string]cacheEntry)
|
||||
aMu.Unlock()
|
||||
|
||||
var calls int
|
||||
var mu sync.Mutex
|
||||
inLookup := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
lookup := func(context.Context, string, string) ([]net.IP, error) {
|
||||
mu.Lock()
|
||||
calls++
|
||||
first := calls == 1
|
||||
mu.Unlock()
|
||||
if first {
|
||||
close(inLookup)
|
||||
<-release
|
||||
}
|
||||
return []net.IP{net.ParseIP("10.0.0.1")}, nil
|
||||
}
|
||||
|
||||
done := make(chan struct{}, 2)
|
||||
go func() {
|
||||
_, _ = A(context.Background(), lookup, "ip", "db.conc", "5432", time.Minute, zap.NewNop())
|
||||
done <- struct{}{}
|
||||
}()
|
||||
<-inLookup
|
||||
|
||||
go func() {
|
||||
_, _ = A(context.Background(), lookup, "ip", "db.conc", "5432", time.Minute, zap.NewNop())
|
||||
done <- struct{}{}
|
||||
}()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
close(release)
|
||||
<-done
|
||||
<-done
|
||||
|
||||
mu.Lock()
|
||||
got := calls
|
||||
mu.Unlock()
|
||||
if got != 1 {
|
||||
t.Fatalf("lookup calls = %d, want 1 (second goroutine must hit the cache re-check)", got)
|
||||
}
|
||||
}
|
||||
|
|
@ -7,14 +7,13 @@ import (
|
|||
weakrand "math/rand/v2"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/dynamicupstreams"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
|
@ -120,101 +119,42 @@ func (su *SRVUpstreams) Provision(ctx caddy.Context) error {
|
|||
}
|
||||
|
||||
func (su *SRVUpstreams) ResetCache(r *http.Request) error {
|
||||
srvsMu.Lock()
|
||||
if r == nil {
|
||||
srvs = make(map[string]srvLookup)
|
||||
} else {
|
||||
suAddr, _, _, _ := su.expandedAddr(r)
|
||||
delete(srvs, suAddr)
|
||||
dynamicupstreams.ResetAllSRV()
|
||||
return nil
|
||||
}
|
||||
srvsMu.Unlock()
|
||||
_, service, proto, name := su.expandedAddr(r)
|
||||
dynamicupstreams.ResetSRV(service, proto, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
|
||||
suAddr, service, proto, name := su.expandedAddr(r)
|
||||
_, service, proto, name := su.expandedAddr(r)
|
||||
|
||||
// first, use a cheap read-lock to return a cached result quickly
|
||||
srvsMu.RLock()
|
||||
cached := srvs[suAddr]
|
||||
srvsMu.RUnlock()
|
||||
if cached.isFresh() {
|
||||
return allNew(cached.upstreams), nil
|
||||
}
|
||||
|
||||
// otherwise, obtain a write-lock to update the cached value
|
||||
srvsMu.Lock()
|
||||
defer srvsMu.Unlock()
|
||||
|
||||
// check to see if it's still stale, since we're now in a different
|
||||
// lock from when we first checked freshness; another goroutine might
|
||||
// have refreshed it in the meantime before we re-obtained our lock
|
||||
cached = srvs[suAddr]
|
||||
if cached.isFresh() {
|
||||
return allNew(cached.upstreams), nil
|
||||
}
|
||||
|
||||
if c := su.logger.Check(zapcore.DebugLevel, "refreshing SRV upstreams"); c != nil {
|
||||
c.Write(
|
||||
zap.String("service", service),
|
||||
zap.String("proto", proto),
|
||||
zap.String("name", name),
|
||||
)
|
||||
}
|
||||
|
||||
_, records, err := su.resolver.LookupSRV(r.Context(), service, proto, name)
|
||||
targets, err := dynamicupstreams.SRV(r.Context(), su.resolver.LookupSRV,
|
||||
service, proto, name,
|
||||
time.Duration(su.Refresh), time.Duration(su.GracePeriod), su.logger)
|
||||
if err != nil {
|
||||
// From LookupSRV docs: "If the response contains invalid names, those records are filtered
|
||||
// out and an error will be returned alongside the remaining results, if any." Thus, we
|
||||
// only return an error if no records were also returned.
|
||||
if len(records) == 0 {
|
||||
if su.GracePeriod > 0 {
|
||||
if c := su.logger.Check(zapcore.ErrorLevel, "SRV lookup failed; using previously cached"); c != nil {
|
||||
c.Write(zap.Error(err))
|
||||
}
|
||||
cached.freshness = time.Now().Add(time.Duration(su.GracePeriod) - time.Duration(su.Refresh))
|
||||
srvs[suAddr] = cached
|
||||
return allNew(cached.upstreams), nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if c := su.logger.Check(zapcore.WarnLevel, "SRV records filtered"); c != nil {
|
||||
c.Write(zap.Error(err))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
upstreams := make([]Upstream, len(records))
|
||||
for i, rec := range records {
|
||||
upstreams := make([]*Upstream, len(targets))
|
||||
for i, t := range targets {
|
||||
if c := su.logger.Check(zapcore.DebugLevel, "discovered SRV record"); c != nil {
|
||||
c.Write(
|
||||
zap.String("target", rec.Target),
|
||||
zap.Uint16("port", rec.Port),
|
||||
zap.Uint16("priority", rec.Priority),
|
||||
zap.Uint16("weight", rec.Weight),
|
||||
zap.String("target", t.Host),
|
||||
zap.String("port", t.Port),
|
||||
zap.Uint16("priority", t.Priority),
|
||||
zap.Uint16("weight", t.Weight),
|
||||
)
|
||||
}
|
||||
addr := net.JoinHostPort(rec.Target, strconv.Itoa(int(rec.Port)))
|
||||
addr := net.JoinHostPort(t.Host, t.Port)
|
||||
if su.DialNetwork != "" {
|
||||
addr = su.DialNetwork + "/" + addr
|
||||
}
|
||||
upstreams[i] = Upstream{Dial: addr}
|
||||
upstreams[i] = &Upstream{Dial: addr}
|
||||
}
|
||||
|
||||
// before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full
|
||||
if cached.freshness.IsZero() && len(srvs) >= 100 {
|
||||
for randomKey := range srvs {
|
||||
delete(srvs, randomKey)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
srvs[suAddr] = srvLookup{
|
||||
srvUpstreams: su,
|
||||
freshness: time.Now(),
|
||||
upstreams: upstreams,
|
||||
}
|
||||
|
||||
return allNew(upstreams), nil
|
||||
return upstreams, nil
|
||||
}
|
||||
|
||||
func (su SRVUpstreams) String() string {
|
||||
|
|
@ -247,16 +187,6 @@ func (SRVUpstreams) formattedAddr(service, proto, name string) string {
|
|||
return fmt.Sprintf("_%s._%s.%s", service, proto, name)
|
||||
}
|
||||
|
||||
type srvLookup struct {
|
||||
srvUpstreams SRVUpstreams
|
||||
freshness time.Time
|
||||
upstreams []Upstream
|
||||
}
|
||||
|
||||
func (sl srvLookup) isFresh() bool {
|
||||
return time.Since(sl.freshness) < time.Duration(sl.srvUpstreams.Refresh)
|
||||
}
|
||||
|
||||
type IPVersions struct {
|
||||
IPv4 *bool `json:"ipv4,omitempty"`
|
||||
IPv6 *bool `json:"ipv6,omitempty"`
|
||||
|
|
@ -357,93 +287,28 @@ func (au *AUpstreams) Provision(ctx caddy.Context) error {
|
|||
func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
|
||||
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
|
||||
|
||||
// Map ipVersion early, so we can use it as part of the cache-key.
|
||||
// This should be fairly inexpensive and comes and the upside of
|
||||
// allowing the same dynamic upstream (name + port combination)
|
||||
// to be used multiple times with different ip versions.
|
||||
//
|
||||
// It also forced a cache-miss if a previously cached dynamic
|
||||
// upstream changes its ip version, e.g. after a config reload,
|
||||
// while keeping the cache-invalidation as simple as it currently is.
|
||||
ipVersion := resolveIpVersion(au.Versions)
|
||||
|
||||
auStr := repl.ReplaceAll(au.String()+ipVersion, "")
|
||||
|
||||
// first, use a cheap read-lock to return a cached result quickly
|
||||
aAaaaMu.RLock()
|
||||
cached := aAaaa[auStr]
|
||||
aAaaaMu.RUnlock()
|
||||
if cached.isFresh() {
|
||||
return allNew(cached.upstreams), nil
|
||||
}
|
||||
|
||||
// otherwise, obtain a write-lock to update the cached value
|
||||
aAaaaMu.Lock()
|
||||
defer aAaaaMu.Unlock()
|
||||
|
||||
// check to see if it's still stale, since we're now in a different
|
||||
// lock from when we first checked freshness; another goroutine might
|
||||
// have refreshed it in the meantime before we re-obtained our lock
|
||||
cached = aAaaa[auStr]
|
||||
if cached.isFresh() {
|
||||
return allNew(cached.upstreams), nil
|
||||
}
|
||||
|
||||
name := repl.ReplaceAll(au.Name, "")
|
||||
port := repl.ReplaceAll(au.Port, "")
|
||||
|
||||
if c := au.logger.Check(zapcore.DebugLevel, "refreshing A upstreams"); c != nil {
|
||||
c.Write(
|
||||
zap.String("version", ipVersion),
|
||||
zap.String("name", name),
|
||||
zap.String("port", port),
|
||||
)
|
||||
}
|
||||
|
||||
ips, err := au.resolver.LookupIP(r.Context(), ipVersion, name)
|
||||
targets, err := dynamicupstreams.A(r.Context(), au.resolver.LookupIP,
|
||||
ipVersion, name, port, time.Duration(au.Refresh), au.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
upstreams := make([]Upstream, len(ips))
|
||||
for i, ip := range ips {
|
||||
upstreams := make([]*Upstream, len(targets))
|
||||
for i, t := range targets {
|
||||
if c := au.logger.Check(zapcore.DebugLevel, "discovered A record"); c != nil {
|
||||
c.Write(zap.String("ip", ip.String()))
|
||||
}
|
||||
upstreams[i] = Upstream{
|
||||
Dial: net.JoinHostPort(ip.String(), port),
|
||||
c.Write(zap.String("ip", t.Host))
|
||||
}
|
||||
upstreams[i] = &Upstream{Dial: net.JoinHostPort(t.Host, t.Port)}
|
||||
}
|
||||
|
||||
// before adding a new one to the cache (as opposed to replacing stale one), make room if cache is full
|
||||
if cached.freshness.IsZero() && len(aAaaa) >= 100 {
|
||||
for randomKey := range aAaaa {
|
||||
delete(aAaaa, randomKey)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
aAaaa[auStr] = aLookup{
|
||||
aUpstreams: au,
|
||||
freshness: time.Now(),
|
||||
upstreams: upstreams,
|
||||
}
|
||||
|
||||
return allNew(upstreams), nil
|
||||
return upstreams, nil
|
||||
}
|
||||
|
||||
func (au AUpstreams) String() string { return net.JoinHostPort(au.Name, au.Port) }
|
||||
|
||||
type aLookup struct {
|
||||
aUpstreams AUpstreams
|
||||
freshness time.Time
|
||||
upstreams []Upstream
|
||||
}
|
||||
|
||||
func (al aLookup) isFresh() bool {
|
||||
return time.Since(al.freshness) < time.Duration(al.aUpstreams.Refresh)
|
||||
}
|
||||
|
||||
// MultiUpstreams is a single dynamic upstream source that
|
||||
// aggregates the results of multiple dynamic upstream sources.
|
||||
// All configured sources will be queried in order, with their
|
||||
|
|
@ -548,22 +413,6 @@ func (u *UpstreamResolver) ParseAddresses() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func allNew(upstreams []Upstream) []*Upstream {
|
||||
results := make([]*Upstream, len(upstreams))
|
||||
for i := range upstreams {
|
||||
results[i] = &Upstream{Dial: upstreams[i].Dial}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
var (
|
||||
srvs = make(map[string]srvLookup)
|
||||
srvsMu sync.RWMutex
|
||||
|
||||
aAaaa = make(map[string]aLookup)
|
||||
aAaaaMu sync.RWMutex
|
||||
)
|
||||
|
||||
// Interface guards
|
||||
var (
|
||||
_ caddy.Provisioner = (*SRVUpstreams)(nil)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue