mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2026-06-27 19:31:26 +00:00
Pull request: AGDNS-3523-upd-dnsproxy
Squashed commit of the following: commit 26fda761b768c8fa97c018276813aaf67182be63 Merge: 3a03dda2d31dc811ffAuthor: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Fri Mar 13 14:48:12 2026 +0700 Merge remote-tracking branch 'origin/master' into AGDNS-3523-upd-dnsproxy-middlewares commit 3a03dda2df5620d2c27b041fd214e7a80046bdad Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Fri Mar 13 09:07:03 2026 +0700 all: upd dnsproxy commit 95d669788e87aa63eb79dbc6dba7259196164dde Merge: 3f1ab07f35558089e7Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Thu Mar 12 10:19:45 2026 +0700 Merge remote-tracking branch 'origin/master' into AGDNS-3523-upd-dnsproxy-middlewares # Conflicts: # go.mod # go.sum # internal/dnsforward/config.go # internal/dnsforward/middleware.go commit 3f1ab07f36082d27417462e403428db7b6e222b3 Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Wed Mar 11 13:38:52 2026 +0700 dnsforward: client id in ctx commit c9be3e2fc37170d12cddcfc3157200d1f7ad955b Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Wed Mar 11 13:24:02 2026 +0700 dnsforward: todo commit 99206ebe3549dfe01edb1ceef4af6a270d2a7572 Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Wed Mar 11 12:32:06 2026 +0700 dnsforward: add log middleware commit 4a2222b6f2d17bc47c385b86ec36f357c4513d63 Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Tue Mar 10 12:30:55 2026 +0700 all: upd dnsproxy fix commit b6435001c41b50c072f52b1e24e01b5ecdd4e1fb Merge: 963832d1cad9cb3e8dAuthor: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Tue Mar 10 08:51:43 2026 +0700 Merge remote-tracking branch 'origin/master' into AGDNS-3523-upd-dnsproxy-middlewares commit 963832d1c812182606ae7916da491b9d3a9ecc5e Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Mon Mar 9 15:39:39 2026 +0700 all: upd dnsproxy commit 961b69368db927f0f72c225c71e554117d606e9d Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Tue Mar 3 09:00:50 2026 +0700 all: upd dnsproxy commit 38c7733939eeed4f1d3e571ecb01f8fc82cae0de Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Mon Mar 2 10:41:40 2026 +0700 dnsforward: imp tests commit 5a0f466bb87bf79cefb254e2df8fb9d78e7584c4 Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Fri Feb 27 11:49:58 2026 +0700 dnsforward: imp tests commit 6de11697d27edad2c821ea8f88054d31cb912fdb Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Fri Feb 27 11:46:39 2026 +0700 dnsforward: imp tests commit 52fb326145a5c7be868cedb96a8ffe9573c6b41f Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Fri Feb 27 11:42:16 2026 +0700 dnsforward: imp tests commit bc3f42165de26fa378fe53bd98f2f98903924a98 Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Fri Feb 27 11:25:46 2026 +0700 dnsforward: imp code commit 8d1c43192c9672d66350f206248716ee8faa0f57 Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Wed Feb 25 11:31:12 2026 +0700 all: upd dnsproxy, rm beforerequest
This commit is contained in:
parent
31dc811ffc
commit
2d49f9dc96
12 changed files with 228 additions and 102 deletions
2
go.mod
2
go.mod
|
|
@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome
|
|||
go 1.25.7
|
||||
|
||||
require (
|
||||
github.com/AdguardTeam/dnsproxy v0.80.0
|
||||
github.com/AdguardTeam/dnsproxy v0.81.0
|
||||
github.com/AdguardTeam/golibs v0.35.8
|
||||
github.com/AdguardTeam/urlfilter v0.23.2
|
||||
github.com/NYTimes/gziphandler v1.1.1
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -4,8 +4,8 @@ cloud.google.com/go/auth v0.18.1 h1:IwTEx92GFUo2pJ6Qea0EU3zYvKnTAeRCODxfA/G5UWs=
|
|||
cloud.google.com/go/auth v0.18.1/go.mod h1:GfTYoS9G3CWpRA3Va9doKN9mjPGRS+v41jmZAhBzbrA=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||
github.com/AdguardTeam/dnsproxy v0.80.0 h1:PttgkZfnAe9itH8vGVhpOTS9FLGZq8A49Qa7l4+/11Q=
|
||||
github.com/AdguardTeam/dnsproxy v0.80.0/go.mod h1:gwr+7Dc0e7QddQLC9JLGjL5NSKcqw0ESsNMRI5Q67Ps=
|
||||
github.com/AdguardTeam/dnsproxy v0.81.0 h1:derWNPHd25PbQ2eSpEpg/dw7d7VRA4dNaP5GdG+qEJ8=
|
||||
github.com/AdguardTeam/dnsproxy v0.81.0/go.mod h1:gwr+7Dc0e7QddQLC9JLGjL5NSKcqw0ESsNMRI5Q67Ps=
|
||||
github.com/AdguardTeam/golibs v0.35.8 h1:KsyF3SWwj05Ey4GiAWU6FGD9oJTDNMp1ixVdS+Nw50M=
|
||||
github.com/AdguardTeam/golibs v0.35.8/go.mod h1:kuLQ0yNRTl0Em2FmmXtSri7ZdVT7p62oojyc51RvP38=
|
||||
github.com/AdguardTeam/urlfilter v0.23.2 h1:EiS/PQZO/X2S6cduFW6BBoRLyjd6SqZj1ZiFbU1KaFE=
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ const testTimeout = 1 * time.Second
|
|||
|
||||
// StartLocalhostUpstream is a test helper that starts a DNS server on
|
||||
// localhost.
|
||||
func StartLocalhostUpstream(tb *testing.T, h dns.Handler) (addr *url.URL) {
|
||||
func StartLocalhostUpstream(tb testing.TB, h dns.Handler) (addr *url.URL) {
|
||||
tb.Helper()
|
||||
|
||||
startCh := make(chan netip.AddrPort)
|
||||
|
|
|
|||
|
|
@ -333,6 +333,8 @@ func (s *Server) newProxyConfig(ctx context.Context) (conf *proxy.Config, err er
|
|||
return nil, fmt.Errorf("ratelimit middleware: %w", err)
|
||||
}
|
||||
|
||||
logMw := newLogMiddleware(s.baseLogger, slogutil.LevelTrace)
|
||||
|
||||
httpConf := &proxy.HTTPConfig{
|
||||
ServerHeader: aghhttp.UserAgent(),
|
||||
InsecureEnabled: s.conf.TLSAllowUnencryptedDoH,
|
||||
|
|
@ -349,7 +351,7 @@ func (s *Server) newProxyConfig(ctx context.Context) (conf *proxy.Config, err er
|
|||
CacheOptimisticMaxAge: time.Duration(srvConf.CacheOptimisticMaxAge),
|
||||
UpstreamConfig: srvConf.UpstreamConfig,
|
||||
PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
|
||||
RequestHandler: ratelimitMw.Wrap(s.Wrap(s)),
|
||||
RequestHandler: ratelimitMw.Wrap(logMw.Wrap(s.Wrap(s))),
|
||||
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
|
||||
MaxGoroutines: srvConf.MaxGoroutines,
|
||||
UseDNS64: srvConf.UseDNS64,
|
||||
|
|
|
|||
34
internal/dnsforward/context.go
Normal file
34
internal/dnsforward/context.go
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ctxKey is the type for context keys.
|
||||
type ctxKey int
|
||||
|
||||
// Context key values.
|
||||
const (
|
||||
ctxKeyClientID ctxKey = iota
|
||||
)
|
||||
|
||||
// contextWithClientID returns a new context with the given ID.
|
||||
func contextWithClientID(parent context.Context, id string) (ctx context.Context) {
|
||||
return context.WithValue(parent, ctxKeyClientID, id)
|
||||
}
|
||||
|
||||
// clientIDFromContext returns ID for this request, if any.
|
||||
func clientIDFromContext(ctx context.Context) (id string, ok bool) {
|
||||
v := ctx.Value(ctxKeyClientID)
|
||||
if v == nil {
|
||||
return id, false
|
||||
}
|
||||
|
||||
id, ok = v.(string)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("bad type for ctxKeyClientID: %T(%[1]v)", v))
|
||||
}
|
||||
|
||||
return id, true
|
||||
}
|
||||
|
|
@ -26,7 +26,6 @@ import (
|
|||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
|
|
@ -43,11 +42,6 @@ const DefaultTimeout = 10 * time.Second
|
|||
// faster than ordinary upstreams.
|
||||
const defaultLocalTimeout = 1 * time.Second
|
||||
|
||||
// defaultClientIDCacheCount is the default count of items in the LRU ClientID
|
||||
// cache. The assumption here is that there won't be more than this many
|
||||
// requests between the BeforeRequestHandler stage and the actual processing.
|
||||
const defaultClientIDCacheCount = 1024
|
||||
|
||||
var defaultDNS = []string{
|
||||
"https://dns10.quad9.net/dns-query",
|
||||
}
|
||||
|
|
@ -110,10 +104,6 @@ type Server struct {
|
|||
// bootstrap is the resolver for upstreams' hostnames.
|
||||
bootstrap upstream.Resolver
|
||||
|
||||
// clientIDCache is a temporary storage for ClientIDs that were extracted
|
||||
// during the BeforeRequestHandler stage.
|
||||
clientIDCache cache.Cache
|
||||
|
||||
// dhcpServer is the DHCP server for accessing lease data.
|
||||
dhcpServer DHCP
|
||||
|
||||
|
|
@ -261,11 +251,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
|||
// TODO(e.burkov): Use some case-insensitive string comparison.
|
||||
localDomainSuffix: strings.ToLower(localDomainSuffix),
|
||||
etcHosts: etcHosts,
|
||||
clientIDCache: cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxCount: defaultClientIDCacheCount,
|
||||
}),
|
||||
anonymizer: p.Anonymizer,
|
||||
anonymizer: p.Anonymizer,
|
||||
conf: ServerConfig{
|
||||
ServePlainDNS: true,
|
||||
},
|
||||
|
|
@ -418,7 +404,7 @@ func (s *Server) Exchange(
|
|||
} else {
|
||||
errMsg = "resolving an address: %w"
|
||||
}
|
||||
if err = s.internalProxy.Resolve(dctx); err != nil {
|
||||
if err = s.internalProxy.Resolve(ctx, dctx); err != nil {
|
||||
return "", 0, fmt.Errorf(errMsg, err)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,12 +2,15 @@ package dnsforward
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/syncutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
|
|
@ -16,18 +19,15 @@ var _ proxy.Middleware = (*Server)(nil)
|
|||
|
||||
// Wrap implements the [proxy.Middleware] interface for *Server.
|
||||
//
|
||||
// TODO(d.kolyshev): Move to a dedicated package.
|
||||
// TODO(d.kolyshev): Move to a dedicated package.
|
||||
// TODO(d.kolyshev): Use logger from context.
|
||||
func (s *Server) Wrap(h proxy.Handler) (wrapped proxy.Handler) {
|
||||
f := func(p *proxy.Proxy, pctx *proxy.DNSContext) (err error) {
|
||||
// TODO(f.setrakov): Obtain context from arguments.
|
||||
ctx := context.TODO()
|
||||
|
||||
f := func(ctx context.Context, p *proxy.Proxy, pctx *proxy.DNSContext) (err error) {
|
||||
clientID, err := s.clientIDFromDNSContext(ctx, pctx)
|
||||
if err != nil {
|
||||
s.logger.WarnContext(ctx, "resolving client id", slogutil.KeyError, err)
|
||||
|
||||
pctx.Res = s.NewMsgSERVFAIL(pctx.Req)
|
||||
pctx.Res.Compress = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -43,19 +43,17 @@ func (s *Server) Wrap(h proxy.Handler) (wrapped proxy.Handler) {
|
|||
}
|
||||
|
||||
if clientID != "" {
|
||||
key := [8]byte{}
|
||||
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
||||
s.clientIDCache.Set(key[:], []byte(clientID))
|
||||
ctx = contextWithClientID(ctx, clientID)
|
||||
}
|
||||
|
||||
return h.ServeDNS(p, pctx)
|
||||
return h.ServeDNS(ctx, p, pctx)
|
||||
}
|
||||
|
||||
return proxy.HandlerFunc(f)
|
||||
}
|
||||
|
||||
// serveBlockedResponse sets a protocol-appropriate response for a request that
|
||||
// was blocked by access settings.
|
||||
// was blocked by access settings. pctx must be filled with the request.
|
||||
func (s *Server) serveBlockedResponse(pctx *proxy.DNSContext) (err error) {
|
||||
if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt {
|
||||
// Return nil so that dnsproxy drops the connection and thus prevent DNS
|
||||
|
|
@ -64,7 +62,6 @@ func (s *Server) serveBlockedResponse(pctx *proxy.DNSContext) (err error) {
|
|||
}
|
||||
|
||||
pctx.Res = s.makeResponseREFUSED(pctx.Req)
|
||||
pctx.Res.Compress = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -135,3 +132,83 @@ func (s *Server) clientIDFromDNSContext(
|
|||
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// logMiddleware adds a logger using [slogutil.ContextWithLogger] and logs the
|
||||
// starts and ends of queries at a given level.
|
||||
type logMiddleware struct {
|
||||
attrPool *syncutil.Pool[[]slog.Attr]
|
||||
logger *slog.Logger
|
||||
lvl slog.Level
|
||||
}
|
||||
|
||||
// logMwAttrNum is the number of attributes used by the logger set by
|
||||
// [logMiddleware].
|
||||
const logMwAttrNum = 3
|
||||
|
||||
// newLogMiddleware returns a new *logMiddleware with l as the base logger.
|
||||
func newLogMiddleware(l *slog.Logger, lvl slog.Level) (mw *logMiddleware) {
|
||||
return &logMiddleware{
|
||||
attrPool: syncutil.NewSlicePool[slog.Attr](logMwAttrNum),
|
||||
logger: l,
|
||||
lvl: lvl,
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ proxy.Middleware = (*logMiddleware)(nil)
|
||||
|
||||
// Wrap implements the [proxy.Middleware] interface for *logMiddleware. It adds
|
||||
// a logger to the context and logs the starts and ends of queries at a given
|
||||
// level.
|
||||
func (m *logMiddleware) Wrap(h proxy.Handler) (wrapped proxy.Handler) {
|
||||
f := func(ctx context.Context, p *proxy.Proxy, dctx *proxy.DNSContext) (err error) {
|
||||
startTime := time.Now()
|
||||
|
||||
attrsPtr := m.attrsSlicePtr(dctx.Req)
|
||||
defer m.attrPool.Put(attrsPtr)
|
||||
|
||||
logHdlr := m.logger.Handler().WithAttrs(*attrsPtr)
|
||||
l := slog.New(logHdlr)
|
||||
ctx = slogutil.ContextWithLogger(ctx, l)
|
||||
|
||||
l.Log(ctx, m.lvl, "started")
|
||||
defer m.logFinished(ctx, l, startTime)
|
||||
|
||||
return h.ServeDNS(ctx, p, dctx)
|
||||
}
|
||||
|
||||
return proxy.HandlerFunc(f)
|
||||
}
|
||||
|
||||
// attrsSlicePtr returns a pointer to a slice with the attributes from the
|
||||
// request set. Callers should defer returning attrsPtr back to the pool.
|
||||
func (m *logMiddleware) attrsSlicePtr(r *dns.Msg) (attrsPtr *[]slog.Attr) {
|
||||
attrsPtr = m.attrPool.Get()
|
||||
|
||||
attrs := *attrsPtr
|
||||
|
||||
// Optimize bounds checking.
|
||||
_ = attrs[logMwAttrNum-1]
|
||||
|
||||
attrs[0] = slog.Uint64("id", uint64(r.Id))
|
||||
|
||||
if len(r.Question) > 0 {
|
||||
q := r.Question[0]
|
||||
attrs[1] = slog.String("qtype", dns.Type(q.Qtype).String())
|
||||
attrs[2] = slog.String("target", q.Name)
|
||||
} else {
|
||||
attrs[1] = slog.Attr{}
|
||||
attrs[2] = slog.Attr{}
|
||||
}
|
||||
|
||||
return attrsPtr
|
||||
}
|
||||
|
||||
// logFinished is called at the end of handling of a query.
|
||||
func (m *logMiddleware) logFinished(ctx context.Context, l *slog.Logger, startTime time.Time) {
|
||||
if !l.Enabled(ctx, m.lvl) {
|
||||
return
|
||||
}
|
||||
|
||||
l.Log(ctx, m.lvl, "finished", "elapsed", timeutil.Duration(time.Since(startTime)))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -102,23 +102,8 @@ func TestServer_middlewareTLS(t *testing.T) {
|
|||
wantRCode: dns.RcodeRefused,
|
||||
}}
|
||||
|
||||
localAns := []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: testFQDN,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 3600,
|
||||
Rdlength: 4,
|
||||
},
|
||||
A: net.IP{1, 2, 3, 4},
|
||||
}}
|
||||
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := (&dns.Msg{}).SetReply(req)
|
||||
resp.Answer = localAns
|
||||
|
||||
require.NoError(t, w.WriteMsg(resp))
|
||||
})
|
||||
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||||
localAns := newTestDNSAnswer(testFQDN, net.IP{1, 2, 3, 4})
|
||||
localUpsAddr := newTestUpstream(t, localAns)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
|
@ -130,7 +115,6 @@ func TestServer_middlewareTLS(t *testing.T) {
|
|||
})
|
||||
|
||||
s.conf.UpstreamDNS = []string{localUpsAddr}
|
||||
|
||||
s.conf.AllowedClients = tc.allowedClients
|
||||
s.conf.DisallowedClients = tc.disallowedClients
|
||||
s.conf.BlockedHosts = tc.blockedHosts
|
||||
|
|
@ -140,16 +124,7 @@ func TestServer_middlewareTLS(t *testing.T) {
|
|||
|
||||
startDeferStop(t, s)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: tc.clientSrvName,
|
||||
}
|
||||
|
||||
client := &dns.Client{
|
||||
Net: "tcp-tls",
|
||||
TLSConfig: tlsConfig,
|
||||
Timeout: dnsClientTimeout,
|
||||
}
|
||||
client := newTestTCPClient(tc.clientSrvName)
|
||||
|
||||
req := createTestMessage(tc.host)
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoTLS).String()
|
||||
|
|
@ -157,11 +132,10 @@ func TestServer_middlewareTLS(t *testing.T) {
|
|||
reply, _, err := client.Exchange(req, addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantRCode, reply.Rcode)
|
||||
if tc.wantRCode == dns.RcodeSuccess {
|
||||
assert.Equal(t, localAns, reply.Answer)
|
||||
assertSuccessResponse(t, reply, localAns)
|
||||
} else {
|
||||
assert.Empty(t, reply.Answer)
|
||||
assertRejectedResponse(t, reply, tc.wantRCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -235,23 +209,8 @@ func TestServer_middlewareUDP(t *testing.T) {
|
|||
wantTimeout: true,
|
||||
}}
|
||||
|
||||
localAns := []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: testFQDN,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 3600,
|
||||
Rdlength: 4,
|
||||
},
|
||||
A: net.IP{1, 2, 3, 4},
|
||||
}}
|
||||
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := (&dns.Msg{}).SetReply(req)
|
||||
resp.Answer = localAns
|
||||
|
||||
require.NoError(t, w.WriteMsg(resp))
|
||||
})
|
||||
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||||
localAns := newTestDNSAnswer(testFQDN, net.IP{1, 2, 3, 4})
|
||||
localUpsAddr := newTestUpstream(t, localAns)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
|
@ -287,18 +246,87 @@ func TestServer_middlewareUDP(t *testing.T) {
|
|||
|
||||
reply, _, err := client.Exchange(req, addr)
|
||||
if tc.wantTimeout {
|
||||
wantErr := &net.OpError{}
|
||||
require.ErrorAs(t, err, &wantErr)
|
||||
assert.True(t, wantErr.Timeout())
|
||||
|
||||
assert.Nil(t, reply)
|
||||
assertTimeoutError(t, err, reply)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, reply)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
|
||||
assert.Equal(t, localAns, reply.Answer)
|
||||
assertSuccessResponse(t, reply, localAns)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// newTestDNSAnswer creates a standard A record answer for testing.
|
||||
func newTestDNSAnswer(fqdn string, ip net.IP) (ans []dns.RR) {
|
||||
return []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: fqdn,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 3600,
|
||||
Rdlength: 4,
|
||||
},
|
||||
A: ip,
|
||||
}}
|
||||
}
|
||||
|
||||
// newTestUpstream creates a test upstream handler and returns its address.
|
||||
func newTestUpstream(tb testing.TB, answer []dns.RR) (addr string) {
|
||||
tb.Helper()
|
||||
|
||||
handler := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := (&dns.Msg{}).SetReply(req)
|
||||
resp.Answer = answer
|
||||
|
||||
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
|
||||
})
|
||||
|
||||
return aghtest.StartLocalhostUpstream(tb, handler).String()
|
||||
}
|
||||
|
||||
// newTestTCPClient creates a new TCP client for testing.
|
||||
func newTestTCPClient(clientSrvName string) (c *dns.Client) {
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: clientSrvName,
|
||||
}
|
||||
|
||||
return &dns.Client{
|
||||
Net: "tcp-tls",
|
||||
TLSConfig: tlsConfig,
|
||||
Timeout: dnsClientTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// assertSuccessResponse checks that the response is successful with expected
|
||||
// answer.
|
||||
func assertSuccessResponse(tb testing.TB, reply *dns.Msg, expectedAns []dns.RR) {
|
||||
tb.Helper()
|
||||
|
||||
require.NotNil(tb, reply)
|
||||
|
||||
assert.Equal(tb, dns.RcodeSuccess, reply.Rcode)
|
||||
assert.Equal(tb, expectedAns, reply.Answer)
|
||||
}
|
||||
|
||||
// assertRejectedResponse checks that the response has the expected error code
|
||||
// and no answer.
|
||||
func assertRejectedResponse(tb testing.TB, reply *dns.Msg, wantRCode int) {
|
||||
tb.Helper()
|
||||
|
||||
require.NotNil(tb, reply)
|
||||
|
||||
assert.Equal(tb, wantRCode, reply.Rcode)
|
||||
assert.Empty(tb, reply.Answer)
|
||||
}
|
||||
|
||||
// assertTimeoutError checks that the error is a timeout error and reply is nil.
|
||||
func assertTimeoutError(tb testing.TB, err error, reply *dns.Msg) {
|
||||
tb.Helper()
|
||||
|
||||
wantErr := &net.OpError{}
|
||||
require.ErrorAs(tb, err, &wantErr)
|
||||
|
||||
assert.True(tb, wantErr.Timeout())
|
||||
assert.Nil(tb, reply)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -350,7 +350,7 @@ func (s *Server) genBlockedHost(
|
|||
return s.NewMsgSERVFAIL(request)
|
||||
}
|
||||
|
||||
err = prx.Resolve(newContext)
|
||||
err = prx.Resolve(ctx, newContext)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(
|
||||
ctx,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package dnsforward
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
|
@ -130,9 +129,10 @@ func (s *Server) processInitial(ctx context.Context, dctx *dnsContext) (rc resul
|
|||
|
||||
// Get the ClientID, if any, before getting client-specific filtering
|
||||
// settings.
|
||||
var key [8]byte
|
||||
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
||||
dctx.clientID = string(s.clientIDCache.Get(key[:]))
|
||||
clientID, ok := clientIDFromContext(ctx)
|
||||
if ok {
|
||||
dctx.clientID = clientID
|
||||
}
|
||||
|
||||
// Get the client-specific filtering settings.
|
||||
dctx.protectionEnabled, _ = s.UpdatedProtectionStatus(ctx)
|
||||
|
|
@ -455,7 +455,7 @@ func (s *Server) processUpstream(ctx context.Context, dctx *dnsContext) (rc resu
|
|||
return resultCodeError
|
||||
}
|
||||
|
||||
if dctx.err = prx.Resolve(pctx); dctx.err != nil {
|
||||
if dctx.err = prx.Resolve(ctx, pctx); dctx.err != nil {
|
||||
return resultCodeError
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,10 +12,9 @@ import (
|
|||
var _ proxy.Handler = (*Server)(nil)
|
||||
|
||||
// ServeDNS implements the [proxy.Handler] interface for [*Server].
|
||||
func (s *Server) ServeDNS(_ *proxy.Proxy, pctx *proxy.DNSContext) (err error) {
|
||||
// TODO(s.chzhen): Pass context.
|
||||
ctx := context.TODO()
|
||||
|
||||
//
|
||||
// TODO(d.kolyshev): Use logger from context.
|
||||
func (s *Server) ServeDNS(ctx context.Context, _ *proxy.Proxy, pctx *proxy.DNSContext) (err error) {
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: pctx,
|
||||
result: &filtering.Result{},
|
||||
|
|
|
|||
|
|
@ -200,7 +200,7 @@ func TestServer_ServeDNS(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = s.ServeDNS(nil, dctx)
|
||||
err = s.ServeDNS(testutil.ContextWithTimeout(t, testTimeout), nil, dctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, dctx.Res)
|
||||
|
||||
|
|
@ -335,7 +335,7 @@ func TestServer_ServeDNS_restrictLocal(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = s.ServeDNS(s.dnsProxy, pctx)
|
||||
err = s.ServeDNS(testutil.ContextWithTimeout(t, testTimeout), s.dnsProxy, pctx)
|
||||
require.ErrorIs(t, err, tc.wantErr)
|
||||
|
||||
require.NotNil(t, pctx.Res)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue