Pull request: AGDNS-3523-use-dnsproxy-context

Squashed commit of the following:

commit 685390143a84710b42f9948cf46080f7d86c404e
Merge: eb023e9dd 8824fc791
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Fri Mar 27 20:47:11 2026 +0700

    Merge remote-tracking branch 'origin/master' into AGDNS-3523-use-dnsproxy-context

commit eb023e9ddffee03c859b112cfefad79708af4cf6
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Fri Mar 27 15:24:01 2026 +0700

    bamboo: add arm64 test

commit 93acedaf81ba7ffde7020f59978c54365b7db766
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Fri Mar 27 15:22:20 2026 +0700

    bamboo: add arm64 test

commit 99a00e9e2199c522adfbc1e22c05a009f86170ee
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Wed Mar 25 08:17:46 2026 +0700

    dnsforward: imp code

commit e5863521b2c34466e610c0a2de94a749cbefd9dd
Merge: 9783620a7 141501408
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Mon Mar 23 12:07:01 2026 +0700

    Merge remote-tracking branch 'origin/master' into AGDNS-3523-use-dnsproxy-context

commit 9783620a7ed4ec30860da1afd40cedd5e8f78472
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Mon Mar 23 11:55:33 2026 +0700

    dnsforward: imp code

commit 13cb50319d2bfcad5cf9215d610c90314d633bc6
Author: Dimitry Kolyshev <dkolyshev@adguard.com>
Date:   Wed Mar 11 13:20:12 2026 +0700

    dnsforward: use log middleware
This commit is contained in:
Dimitry Kolyshev 2026-03-30 03:48:58 +00:00
parent 8824fc7911
commit 55d0d6ae59
14 changed files with 188 additions and 121 deletions

View file

@ -50,7 +50,7 @@ func clientIDFromClientServerName(
}
// clientIDFromDNSContextHTTPS extracts the ClientID from the path of the
// client's DNS-over-HTTPS request.
// client's DNS-over-HTTPS request. pctx must not be nil.
func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err error) {
r := pctx.HTTPRequest
if r == nil {

View file

@ -229,7 +229,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
}
ctx := testutil.ContextWithTimeout(t, testTimeout)
clientID, err := srv.clientIDFromDNSContext(ctx, pctx)
clientID, err := srv.clientIDFromDNSContext(ctx, testLogger, pctx)
assert.Equal(t, tc.wantClientID, clientID)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)

View file

@ -3,6 +3,7 @@ package dnsforward
import (
"context"
"fmt"
"log/slog"
"net"
"slices"
"strings"
@ -23,9 +24,10 @@ func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filter
}
// filterDNSRequest applies the dnsFilter and sets dctx.proxyCtx.Res if the
// request was filtered.
// request was filtered. l and dctx must not be nil.
func (s *Server) filterDNSRequest(
ctx context.Context,
l *slog.Logger,
dctx *dnsContext,
) (res *filtering.Result, err error) {
pctx := dctx.proxyCtx
@ -47,8 +49,8 @@ func (s *Server) filterDNSRequest(
dctx.origQuestion = q
req.Question[0].Name = dns.Fqdn(res.CanonName)
case res.IsFiltered:
s.logger.DebugContext(ctx, "host is filtered", "host", host, "reason", res.Reason)
pctx.Res = s.genDNSFilterMessage(ctx, pctx, res)
l.DebugContext(ctx, "host is filtered", "reason", res.Reason)
pctx.Res = s.genDNSFilterMessage(ctx, l, pctx, res)
case res.Reason.In(filtering.Rewritten, filtering.FilteredSafeSearch):
pctx.Res = s.getCNAMEWithIPs(ctx, req, res.IPList, res.CanonName)
case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts):
@ -92,8 +94,13 @@ func (s *Server) checkHostRules(
// filterDNSResponse checks each resource record of answer section of
// dctx.proxyCtx.Res. It sets dctx.result and dctx.origResp if at least one of
// canonical names, IP addresses, or HTTPS RR hints in it matches the filtering
// rules, as well as sets dctx.proxyCtx.Res to the filtered response.
func (s *Server) filterDNSResponse(ctx context.Context, dctx *dnsContext) (err error) {
// rules, as well as sets dctx.proxyCtx.Res to the filtered response. l and
// dctx must not be nil.
func (s *Server) filterDNSResponse(
ctx context.Context,
l *slog.Logger,
dctx *dnsContext,
) (err error) {
setts := dctx.setts
if !setts.FilteringEnabled {
return nil
@ -126,27 +133,16 @@ func (s *Server) filterDNSResponse(ctx context.Context, dctx *dnsContext) (err e
continue
}
s.logger.DebugContext(
ctx,
"checked",
"dns_type", dns.Type(rrtype),
"host", host,
"name", a.Header().Name,
)
l.DebugContext(ctx, "checked", "name", a.Header().Name)
if err != nil {
return fmt.Errorf("filtering answer at index %d: %w", i, err)
} else if res != nil && res.IsFiltered {
dctx.result = res
dctx.origResp = pctx.Res
pctx.Res = s.genDNSFilterMessage(ctx, pctx, res)
pctx.Res = s.genDNSFilterMessage(ctx, l, pctx, res)
s.logger.DebugContext(
ctx,
"matched by response",
"name", pctx.Req.Question[0].Name,
"host", host,
)
l.DebugContext(ctx, "matched by response", "name", pctx.Req.Question[0].Name)
break
}

View file

@ -154,7 +154,8 @@ func TestServer_filterDNSResponse(t *testing.T) {
},
}
fltErr := s.filterDNSResponse(testutil.ContextWithTimeout(t, testTimeout), dctx)
ctx := testutil.ContextWithTimeout(t, testTimeout)
fltErr := s.filterDNSResponse(ctx, testLogger, dctx)
require.NoError(t, fltErr)
res := dctx.result

View file

@ -120,10 +120,15 @@ func ipsFromAnswer(ans []dns.RR) (ip4s, ip6s []net.IP) {
return ip4s, ip6s
}
// process adds the resolved IP addresses to the domain's ipsets, if any.
func (h *ipsetHandler) process(ctx context.Context, dctx *dnsContext) (rc resultCode) {
h.logger.DebugContext(ctx, "started processing")
defer h.logger.DebugContext(ctx, "finished processing")
// process adds the resolved IP addresses to the domain's ipsets, if any. l and
// dctx must not be nil.
func (h *ipsetHandler) process(
ctx context.Context,
l *slog.Logger,
dctx *dnsContext,
) (rc resultCode) {
l.DebugContext(ctx, "started processing")
defer l.DebugContext(ctx, "finished processing")
if h.skipIpsetProcessing(dctx) {
return resultCodeSuccess
@ -138,12 +143,12 @@ func (h *ipsetHandler) process(ctx context.Context, dctx *dnsContext) (rc result
n, err := h.ipsetMgr.Add(ctx, host, ip4s, ip6s)
if err != nil {
// Consider ipset errors non-critical to the request.
h.logger.ErrorContext(ctx, "adding host ips", slogutil.KeyError, err)
l.ErrorContext(ctx, "adding host ips", slogutil.KeyError, err)
return resultCodeSuccess
}
h.logger.DebugContext(ctx, "added new ipset entries", "num", n)
l.DebugContext(ctx, "added new ipset entries", "num", n)
return resultCodeSuccess
}

View file

@ -63,7 +63,7 @@ func TestIpsetCtx_process(t *testing.T) {
ictx := &ipsetHandler{
logger: testLogger,
}
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), dctx)
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), testLogger, dctx)
assert.Equal(t, resultCodeSuccess, rc)
err := ictx.close()
@ -86,7 +86,7 @@ func TestIpsetCtx_process(t *testing.T) {
logger: testLogger,
}
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), dctx)
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), testLogger, dctx)
assert.Equal(t, resultCodeSuccess, rc)
assert.Equal(t, []net.IP{ip4}, m.ip4s)
assert.Empty(t, m.ip6s)
@ -111,7 +111,7 @@ func TestIpsetCtx_process(t *testing.T) {
logger: testLogger,
}
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), dctx)
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), testLogger, dctx)
assert.Equal(t, resultCodeSuccess, rc)
assert.Empty(t, m.ip4s)
assert.Equal(t, []net.IP{ip6}, m.ip6s)

View file

@ -17,13 +17,15 @@ import (
// type check
var _ proxy.Middleware = (*Server)(nil)
// Wrap implements the [proxy.Middleware] interface for *Server.
// Wrap implements the [proxy.Middleware] interface for *Server. ctx must
// contain a logger accessible with [slogutil.LoggerFromContext].
//
// TODO(d.kolyshev): Move to a dedicated package.
// TODO(d.kolyshev): Use logger from context.
func (s *Server) Wrap(h proxy.Handler) (wrapped proxy.Handler) {
f := func(ctx context.Context, p *proxy.Proxy, pctx *proxy.DNSContext) (err error) {
clientID, err := s.clientIDFromDNSContext(ctx, pctx)
l := slogutil.MustLoggerFromContext(ctx)
clientID, err := s.clientIDFromDNSContext(ctx, l, pctx)
if err != nil {
s.logger.WarnContext(ctx, "resolving client id", slogutil.KeyError, err)
@ -37,7 +39,7 @@ func (s *Server) Wrap(h proxy.Handler) (wrapped proxy.Handler) {
return s.serveBlockedResponse(pctx)
}
blocked = s.isBlockedHost(ctx, pctx.Req.Question)
blocked = s.isBlockedHost(ctx, l, pctx.Req.Question)
if blocked {
return s.serveBlockedResponse(pctx)
}
@ -66,8 +68,13 @@ func (s *Server) serveBlockedResponse(pctx *proxy.DNSContext) (err error) {
return nil
}
// isBlockedHost checks if the request is in the access blocklist.
func (s *Server) isBlockedHost(ctx context.Context, question []dns.Question) (blocked bool) {
// isBlockedHost checks if the request is in the access blocklist. l must not
// be nil.
func (s *Server) isBlockedHost(
ctx context.Context,
l *slog.Logger,
question []dns.Question,
) (blocked bool) {
if len(question) != 1 {
return false
}
@ -77,12 +84,7 @@ func (s *Server) isBlockedHost(ctx context.Context, question []dns.Question) (bl
host := aghnet.NormalizeDomain(q.Name)
if s.access.isBlockedHost(host, qt) {
s.logger.DebugContext(
ctx,
"request is in access blocklist",
"dns_type", dns.Type(qt),
"host", host,
)
l.DebugContext(ctx, "request is in access blocklist")
return true
}
@ -92,9 +94,11 @@ func (s *Server) isBlockedHost(ctx context.Context, question []dns.Question) (bl
// clientIDFromDNSContext extracts the client's ID from the server name of the
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
// is not one of these, clientID is an empty string and err is nil.
// is not one of these, clientID is an empty string and err is nil. l and pctx
// must not be nil.
func (s *Server) clientIDFromDNSContext(
ctx context.Context,
l *slog.Logger,
pctx *proxy.DNSContext,
) (clientID string, err error) {
proto := pctx.Proto
@ -116,7 +120,7 @@ func (s *Server) clientIDFromDNSContext(
return "", nil
}
cliSrvName, err := clientServerName(ctx, s.logger, pctx, proto)
cliSrvName, err := clientServerName(ctx, l, pctx, proto)
if err != nil {
return "", fmt.Errorf("getting client server-name: %w", err)
}
@ -135,6 +139,8 @@ func (s *Server) clientIDFromDNSContext(
// logMiddleware adds a logger using [slogutil.ContextWithLogger] and logs the
// starts and ends of queries at a given level.
//
// TODO(d.kolyshev): Consider moving to dnsproxy.
type logMiddleware struct {
attrPool *syncutil.Pool[[]slog.Attr]
logger *slog.Logger

View file

@ -2,6 +2,7 @@ package dnsforward
import (
"context"
"log/slog"
"net/netip"
"slices"
@ -46,9 +47,10 @@ func ipsFromRules(resRules []*filtering.ResultRule) (ips []netip.Addr) {
}
// genDNSFilterMessage generates a filtered response to req for the filtering
// result res.
// result res. l, dctx, and res must not be nil.
func (s *Server) genDNSFilterMessage(
ctx context.Context,
l *slog.Logger,
dctx *proxy.DNSContext,
res *filtering.Result,
) (resp *dns.Msg) {
@ -65,9 +67,9 @@ func (s *Server) genDNSFilterMessage(
switch res.Reason {
case filtering.FilteredSafeBrowsing:
return s.genBlockedHost(ctx, req, s.dnsFilter.SafeBrowsingBlockHost(), dctx)
return s.genBlockedHost(ctx, l, req, s.dnsFilter.SafeBrowsingBlockHost(), dctx)
case filtering.FilteredParental:
return s.genBlockedHost(ctx, req, s.dnsFilter.ParentalBlockHost(), dctx)
return s.genBlockedHost(ctx, l, req, s.dnsFilter.ParentalBlockHost(), dctx)
case filtering.FilteredSafeSearch:
// If Safe Search generated the necessary IP addresses, use them.
// Otherwise, if there were no errors, there are no addresses for the
@ -315,14 +317,17 @@ func (s *Server) makeResponseNullIP(ctx context.Context, req *dns.Msg) (resp *dn
return resp
}
// genBlockedHost generates a blocked host response. l, request, and d must not
// be nil.
func (s *Server) genBlockedHost(
ctx context.Context,
l *slog.Logger,
request *dns.Msg,
newAddr string,
d *proxy.DNSContext,
) (msg *dns.Msg) {
if newAddr == "" {
s.logger.InfoContext(ctx, "block host not specified")
l.InfoContext(ctx, "block host not specified")
return s.NewMsgSERVFAIL(request)
}
@ -345,14 +350,14 @@ func (s *Server) genBlockedHost(
prx := s.proxy()
if prx == nil {
s.logger.DebugContext(ctx, "getting current proxy", slogutil.KeyError, srvClosedErr)
l.DebugContext(ctx, "getting current proxy", slogutil.KeyError, srvClosedErr)
return s.NewMsgSERVFAIL(request)
}
err = prx.Resolve(ctx, newContext)
if err != nil {
s.logger.ErrorContext(
l.ErrorContext(
ctx,
"looking up replacement host",
"host", newAddr,

View file

@ -2,6 +2,7 @@ package dnsforward
import (
"context"
"log/slog"
"net"
"net/netip"
"strings"
@ -96,15 +97,20 @@ const mozillaFQDN = "use-application-dns.net."
const healthcheckFQDN = "healthcheck.adguardhome.test."
// processInitial terminates the following processing for some requests if
// needed and enriches dctx with some client-specific information.
// needed and enriches dctx with some client-specific information. l and dctx
// must not be nil.
//
// TODO(e.burkov): Decompose into less general processors.
func (s *Server) processInitial(ctx context.Context, dctx *dnsContext) (rc resultCode) {
s.logger.DebugContext(ctx, "started processing initial")
defer s.logger.DebugContext(ctx, "finished processing initial")
func (s *Server) processInitial(
ctx context.Context,
l *slog.Logger,
dctx *dnsContext,
) (rc resultCode) {
l.DebugContext(ctx, "started processing initial")
defer l.DebugContext(ctx, "finished processing initial")
pctx := dctx.proxyCtx
s.processClientIP(ctx, pctx.Addr.Addr())
s.processClientIP(ctx, l, pctx.Addr.Addr())
q := pctx.Req.Question[0]
qt := q.Qtype
@ -141,10 +147,11 @@ func (s *Server) processInitial(ctx context.Context, dctx *dnsContext) (rc resul
return resultCodeSuccess
}
// processClientIP sends the client IP address to s.addrProc, if needed.
func (s *Server) processClientIP(ctx context.Context, addr netip.Addr) {
// processClientIP sends the client IP address to s.addrProc, if needed. l must
// not be nil.
func (s *Server) processClientIP(ctx context.Context, l *slog.Logger, addr netip.Addr) {
if !addr.IsValid() {
s.logger.WarnContext(ctx, "bad client address", "addr", addr)
l.WarnContext(ctx, "bad client address", "addr", addr)
return
}
@ -159,12 +166,16 @@ func (s *Server) processClientIP(ctx context.Context, addr netip.Addr) {
// processDDRQuery responds to Discovery of Designated Resolvers (DDR) SVCB
// queries. The response contains different types of encryption supported by
// current user configuration.
// current user configuration. l and dctx must not be nil.
//
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html.
func (s *Server) processDDRQuery(ctx context.Context, dctx *dnsContext) (rc resultCode) {
s.logger.DebugContext(ctx, "started processing ddr")
defer s.logger.DebugContext(ctx, "finished processing ddr")
func (s *Server) processDDRQuery(
ctx context.Context,
l *slog.Logger,
dctx *dnsContext,
) (rc resultCode) {
l.DebugContext(ctx, "started processing ddr")
defer l.DebugContext(ctx, "finished processing ddr")
if !s.conf.HandleDDR {
return resultCodeSuccess
@ -258,12 +269,16 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
// processDHCPHosts respond to A requests if the target hostname is known to
// the server. It responds with a mapped IP address if the DNS64 is enabled and
// the request is for AAAA.
// the request is for AAAA. l and dctx must not be nil.
//
// TODO(a.garipov): Adapt to AAAA as well.
func (s *Server) processDHCPHosts(ctx context.Context, dctx *dnsContext) (rc resultCode) {
s.logger.DebugContext(ctx, "started processing dhcp hosts")
defer s.logger.DebugContext(ctx, "finished processing dhcp hosts")
func (s *Server) processDHCPHosts(
ctx context.Context,
l *slog.Logger,
dctx *dnsContext,
) (rc resultCode) {
l.DebugContext(ctx, "started processing dhcp hosts")
defer l.DebugContext(ctx, "finished processing dhcp hosts")
pctx := dctx.proxyCtx
req := pctx.Req
@ -275,7 +290,7 @@ func (s *Server) processDHCPHosts(ctx context.Context, dctx *dnsContext) (rc res
}
if !pctx.IsPrivateClient {
s.logger.DebugContext(
l.DebugContext(
ctx,
"requests for dhcp host",
"addr", pctx.Addr,
@ -291,12 +306,12 @@ func (s *Server) processDHCPHosts(ctx context.Context, dctx *dnsContext) (rc res
if ip == (netip.Addr{}) {
// Go on and process them with filters, including dnsrewrite ones, and
// possibly route them to a domain-specific upstream.
s.logger.DebugContext(ctx, "no dhcp record", "dhcp_host", dhcpHost)
l.DebugContext(ctx, "no dhcp record", "dhcp_host", dhcpHost)
return resultCodeSuccess
}
s.logger.DebugContext(ctx, "dhcp record for", "dhcp_host", dhcpHost, "ip", ip)
l.DebugContext(ctx, "dhcp record for", "dhcp_host", dhcpHost, "ip", ip)
resp := s.replyCompressed(req)
switch q.Qtype {
@ -326,10 +341,14 @@ func (s *Server) processDHCPHosts(ctx context.Context, dctx *dnsContext) (rc res
}
// processDHCPAddrs responds to PTR requests if the target IP is leased by the
// DHCP server.
func (s *Server) processDHCPAddrs(ctx context.Context, dctx *dnsContext) (rc resultCode) {
s.logger.DebugContext(ctx, "started processing dhcp addrs")
defer s.logger.DebugContext(ctx, "finished processing dhcp addrs")
// DHCP server. l and dctx must not be nil.
func (s *Server) processDHCPAddrs(
ctx context.Context,
l *slog.Logger,
dctx *dnsContext,
) (rc resultCode) {
l.DebugContext(ctx, "started processing dhcp addrs")
defer l.DebugContext(ctx, "finished processing dhcp addrs")
pctx := dctx.proxyCtx
if pctx.Res != nil {
@ -351,7 +370,7 @@ func (s *Server) processDHCPAddrs(ctx context.Context, dctx *dnsContext) (rc res
return resultCodeSuccess
}
s.logger.DebugContext(ctx, "dhcp client", "addr", addr, "host", host)
l.DebugContext(ctx, "dhcp client", "addr", addr, "host", host)
resp := s.replyCompressed(req)
ptr := &dns.PTR{
@ -371,13 +390,15 @@ func (s *Server) processDHCPAddrs(ctx context.Context, dctx *dnsContext) (rc res
return resultCodeSuccess
}
// Apply filtering logic
// processFilteringBeforeRequest applies filtering logic. l and dctx must not
// be nil.
func (s *Server) processFilteringBeforeRequest(
ctx context.Context,
l *slog.Logger,
dctx *dnsContext,
) (rc resultCode) {
s.logger.DebugContext(ctx, "started processing filtering before request")
defer s.logger.DebugContext(ctx, "finished processing filtering before request")
l.DebugContext(ctx, "started processing filtering before request")
defer l.DebugContext(ctx, "finished processing filtering before request")
if dctx.proxyCtx.RequestedPrivateRDNS != (netip.Prefix{}) {
// There is no need to filter request for locally served ARPA hostname
@ -397,7 +418,7 @@ func (s *Server) processFilteringBeforeRequest(
defer s.serverLock.RUnlock()
var err error
if dctx.result, err = s.filterDNSRequest(ctx, dctx); err != nil {
if dctx.result, err = s.filterDNSRequest(ctx, l, dctx); err != nil {
dctx.err = err
return resultCodeError
@ -416,9 +437,14 @@ func ipStringFromAddr(addr net.Addr) (ipStr string) {
}
// processUpstream passes request to upstream servers and handles the response.
func (s *Server) processUpstream(ctx context.Context, dctx *dnsContext) (rc resultCode) {
s.logger.DebugContext(ctx, "started processing upstream")
defer s.logger.DebugContext(ctx, "finished processing upstream")
// l and dctx must not be nil.
func (s *Server) processUpstream(
ctx context.Context,
l *slog.Logger,
dctx *dnsContext,
) (rc resultCode) {
l.DebugContext(ctx, "started processing upstream")
defer l.DebugContext(ctx, "finished processing upstream")
pctx := dctx.proxyCtx
req := pctx.Req
@ -433,7 +459,7 @@ func (s *Server) processUpstream(ctx context.Context, dctx *dnsContext) (rc resu
// TODO(a.garipov): Route such queries to a custom upstream for the
// local domain name if there is one.
name := req.Question[0].Name
s.logger.DebugContext(
l.DebugContext(
ctx,
"dhcp client hostname was not filtered",
"hostname", name[:len(name)-1],
@ -443,7 +469,7 @@ func (s *Server) processUpstream(ctx context.Context, dctx *dnsContext) (rc resu
return resultCodeFinish
}
s.setCustomUpstream(ctx, pctx, dctx.clientID)
s.setCustomUpstream(ctx, l, pctx, dctx.clientID)
reqWantsDNSSEC := s.setReqAD(req)
@ -532,8 +558,14 @@ func (s *Server) dhcpHostFromRequest(q *dns.Question) (reqHost string) {
return reqHost[:len(reqHost)-len(s.localDomainSuffix)-1]
}
// setCustomUpstream sets custom upstream settings in pctx, if necessary.
func (s *Server) setCustomUpstream(ctx context.Context, pctx *proxy.DNSContext, clientID string) {
// setCustomUpstream sets custom upstream settings in pctx, if necessary. l and
// pctx must not be nil.
func (s *Server) setCustomUpstream(
ctx context.Context,
l *slog.Logger,
pctx *proxy.DNSContext,
clientID string,
) {
if !pctx.Addr.IsValid() || s.conf.ClientsContainer == nil {
return
}
@ -541,7 +573,7 @@ func (s *Server) setCustomUpstream(ctx context.Context, pctx *proxy.DNSContext,
cliAddr := pctx.Addr.Addr()
upsConf := s.conf.ClientsContainer.CustomUpstreamConfig(clientID, cliAddr)
if upsConf != nil {
s.logger.DebugContext(
l.DebugContext(
ctx,
"using custom upstreams for client with",
"ip", cliAddr,
@ -552,10 +584,15 @@ func (s *Server) setCustomUpstream(ctx context.Context, pctx *proxy.DNSContext,
}
}
// Apply filtering logic after we have received response from upstream servers
func (s *Server) processFilteringAfterResponse(ctx context.Context, dctx *dnsContext) (rc resultCode) {
s.logger.DebugContext(ctx, "started processing filtering after response")
defer s.logger.DebugContext(ctx, "finished processing filtering after response")
// Apply filtering logic after we have received response from upstream servers.
// l and dctx must not be nil.
func (s *Server) processFilteringAfterResponse(
ctx context.Context,
l *slog.Logger,
dctx *dnsContext,
) (rc resultCode) {
l.DebugContext(ctx, "started processing filtering after response")
defer l.DebugContext(ctx, "finished processing filtering after response")
switch res := dctx.result; res.Reason {
case filtering.NotFilteredAllowList:
@ -580,13 +617,17 @@ func (s *Server) processFilteringAfterResponse(ctx context.Context, dctx *dnsCon
return resultCodeSuccess
default:
return s.filterAfterResponse(ctx, dctx)
return s.filterAfterResponse(ctx, l, dctx)
}
}
// filterAfterResponse returns the result of filtering the response that wasn't
// explicitly allowed or rewritten.
func (s *Server) filterAfterResponse(ctx context.Context, dctx *dnsContext) (res resultCode) {
// explicitly allowed or rewritten. l and dctx must not be nil.
func (s *Server) filterAfterResponse(
ctx context.Context,
l *slog.Logger,
dctx *dnsContext,
) (res resultCode) {
// Check the response only if it's from an upstream. Don't check the
// response if the protection is disabled since dnsrewrite rules aren't
// applied to it anyway.
@ -594,7 +635,7 @@ func (s *Server) filterAfterResponse(ctx context.Context, dctx *dnsContext) (res
return resultCodeSuccess
}
err := s.filterDNSResponse(ctx, dctx)
err := s.filterDNSResponse(ctx, l, dctx)
if err != nil {
dctx.err = err

View file

@ -104,7 +104,7 @@ func TestServer_ProcessInitial(t *testing.T) {
},
}
gotRC := s.processInitial(testutil.ContextWithTimeout(t, testTimeout), dctx)
gotRC := s.processInitial(testutil.ContextWithTimeout(t, testTimeout), testLogger, dctx)
assert.Equal(t, tc.wantRC, gotRC)
assert.Equal(t, testClientAddrPort.Addr(), gotAddr)
@ -208,7 +208,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
},
}
ctx := testutil.ContextWithTimeout(t, testTimeout)
gotRC := s.processFilteringAfterResponse(ctx, dctx)
gotRC := s.processFilteringAfterResponse(ctx, testLogger, dctx)
assert.Equal(t, tc.wantRC, gotRC)
assert.Equal(t, newResp(dns.RcodeSuccess, tc.req, tc.wantRespAns), dctx.proxyCtx.Res)
})
@ -356,7 +356,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
},
}
res := s.processDDRQuery(testutil.ContextWithTimeout(t, testTimeout), dctx)
res := s.processDDRQuery(testutil.ContextWithTimeout(t, testTimeout), testLogger, dctx)
require.Equal(t, tc.wantRes, res)
if tc.wantRes != resultCodeFinish {
@ -464,7 +464,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
},
}
res := s.processDHCPHosts(testutil.ContextWithTimeout(t, testTimeout), dctx)
res := s.processDHCPHosts(testutil.ContextWithTimeout(t, testTimeout), testLogger, dctx)
pctx := dctx.proxyCtx
if !tc.isLocalCli {
@ -609,7 +609,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
}
t.Run(tc.name, func(t *testing.T) {
res := s.processDHCPHosts(testutil.ContextWithTimeout(t, testTimeout), dctx)
res := s.processDHCPHosts(testutil.ContextWithTimeout(t, testTimeout), testLogger, dctx)
pctx := dctx.proxyCtx
assert.Equal(t, tc.wantRes, res)
require.NoError(t, dctx.err)
@ -685,7 +685,7 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
)
ctx := testutil.ContextWithTimeout(t, testTimeout)
pctx := newPrxCtx()
rc := s.processUpstream(ctx, &dnsContext{proxyCtx: pctx})
rc := s.processUpstream(ctx, testLogger, &dnsContext{proxyCtx: pctx})
require.Equal(t, resultCodeSuccess, rc)
require.NotEmpty(t, pctx.Res.Answer)
ptr := testutil.RequireTypeAssert[*dns.PTR](t, pctx.Res.Answer[0])
@ -716,7 +716,7 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
pctx := newPrxCtx()
ctx := testutil.ContextWithTimeout(t, testTimeout)
rc := s.processUpstream(ctx, &dnsContext{proxyCtx: pctx})
rc := s.processUpstream(ctx, testLogger, &dnsContext{proxyCtx: pctx})
require.Equal(t, resultCodeError, rc)
require.Empty(t, pctx.Res.Answer)
})

View file

@ -2,18 +2,19 @@ package dnsforward
import (
"context"
"log/slog"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// type check
var _ proxy.Handler = (*Server)(nil)
// ServeDNS implements the [proxy.Handler] interface for [*Server].
//
// TODO(d.kolyshev): Use logger from context.
// ServeDNS implements the [proxy.Handler] interface for [*Server]. ctx must
// contain a logger accessible with [slogutil.LoggerFromContext].
func (s *Server) ServeDNS(ctx context.Context, _ *proxy.Proxy, pctx *proxy.DNSContext) (err error) {
dctx := &dnsContext{
proxyCtx: pctx,
@ -21,7 +22,9 @@ func (s *Server) ServeDNS(ctx context.Context, _ *proxy.Proxy, pctx *proxy.DNSCo
startTime: time.Now(),
}
type modProcessFunc func(ctx context.Context, dctx *dnsContext) (rc resultCode)
l := slogutil.MustLoggerFromContext(ctx)
type modProcessFunc func(ctx context.Context, l *slog.Logger, dctx *dnsContext) (rc resultCode)
// Since [*dnsforward.Server] is used as [proxy.Handler], there is no need
// for additional index out of range checking in any of the following
@ -39,7 +42,7 @@ func (s *Server) ServeDNS(ctx context.Context, _ *proxy.Proxy, pctx *proxy.DNSCo
s.processQueryLogsAndStats,
}
for _, process := range mods {
r := process(ctx, dctx)
r := process(ctx, l, dctx)
switch r {
case resultCodeSuccess:
// continue: call the next filter

View file

@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
@ -200,7 +201,10 @@ func TestServer_ServeDNS(t *testing.T) {
}
t.Run(tc.name, func(t *testing.T) {
err = s.ServeDNS(testutil.ContextWithTimeout(t, testTimeout), nil, dctx)
ctx := testutil.ContextWithTimeout(t, testTimeout)
ctx = slogutil.ContextWithLogger(ctx, testLogger)
err = s.ServeDNS(ctx, nil, dctx)
require.NoError(t, err)
require.NotNil(t, dctx.Res)
@ -335,7 +339,10 @@ func TestServer_ServeDNS_restrictLocal(t *testing.T) {
}
t.Run(tc.name, func(t *testing.T) {
err = s.ServeDNS(testutil.ContextWithTimeout(t, testTimeout), s.dnsProxy, pctx)
ctx := testutil.ContextWithTimeout(t, testTimeout)
ctx = slogutil.ContextWithLogger(ctx, testLogger)
err = s.ServeDNS(ctx, s.dnsProxy, pctx)
require.ErrorIs(t, err, tc.wantErr)
require.NotNil(t, pctx.Res)

View file

@ -2,6 +2,7 @@ package dnsforward
import (
"context"
"log/slog"
"net"
"time"
@ -13,10 +14,15 @@ import (
"github.com/miekg/dns"
)
// Write Stats data and logs
func (s *Server) processQueryLogsAndStats(ctx context.Context, dctx *dnsContext) (rc resultCode) {
s.logger.DebugContext(ctx, "started processing querylog and stats")
defer s.logger.DebugContext(ctx, "finished processing querylog and stats")
// processQueryLogsAndStats writes stats data and logs. l and dctx must not be
// nil.
func (s *Server) processQueryLogsAndStats(
ctx context.Context,
l *slog.Logger,
dctx *dnsContext,
) (rc resultCode) {
l.DebugContext(ctx, "started processing querylog and stats")
defer l.DebugContext(ctx, "finished processing querylog and stats")
pctx := dctx.proxyCtx
q := pctx.Req.Question[0]
@ -27,7 +33,7 @@ func (s *Server) processQueryLogsAndStats(ctx context.Context, dctx *dnsContext)
s.anonymizer.Load()(ip)
ipStr := net.IP(ip).String()
s.logger.DebugContext(ctx, "client ip for stats and querylog", "ip", ipStr)
l.DebugContext(ctx, "client ip for stats and querylog", "ip", ipStr)
ids := []string{ipStr}
if dctx.clientID != "" {
@ -47,12 +53,10 @@ func (s *Server) processQueryLogsAndStats(ctx context.Context, dctx *dnsContext)
if s.shouldLog(host, qt, cl, ids) {
s.logQuery(dctx, ip, processingTime)
} else {
s.logger.DebugContext(
l.DebugContext(
ctx,
"not adding to querylog",
"dns_class", dns.Class(cl),
"dns_type", dns.Type(qt),
"host", host,
"ip", ipStr,
)
}
@ -60,12 +64,10 @@ func (s *Server) processQueryLogsAndStats(ctx context.Context, dctx *dnsContext)
if s.shouldCountStat(host, qt, cl, ids) {
s.updateStats(dctx, ipStr, processingTime)
} else {
s.logger.DebugContext(
l.DebugContext(
ctx,
"not counting in stats",
"dns_class", dns.Class(cl),
"dns_type", dns.Type(qt),
"host", host,
"ip", ipStr,
)
}

View file

@ -231,7 +231,8 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
clientID: tc.clientID,
}
code := srv.processQueryLogsAndStats(testutil.ContextWithTimeout(t, testTimeout), dctx)
ctx := testutil.ContextWithTimeout(t, testTimeout)
code := srv.processQueryLogsAndStats(ctx, testLogger, dctx)
assert.Equal(t, tc.wantCode, code)
assert.Equal(t, tc.wantLogProto, ql.lastParams.ClientProto)
assert.Equal(t, tc.wantStatClient, st.lastEntry.Client)