mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2026-06-27 19:31:26 +00:00
Pull request: AGDNS-3523-use-dnsproxy-context
Squashed commit of the following: commit 685390143a84710b42f9948cf46080f7d86c404e Merge: eb023e9dd8824fc791Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Fri Mar 27 20:47:11 2026 +0700 Merge remote-tracking branch 'origin/master' into AGDNS-3523-use-dnsproxy-context commit eb023e9ddffee03c859b112cfefad79708af4cf6 Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Fri Mar 27 15:24:01 2026 +0700 bamboo: add arm64 test commit 93acedaf81ba7ffde7020f59978c54365b7db766 Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Fri Mar 27 15:22:20 2026 +0700 bamboo: add arm64 test commit 99a00e9e2199c522adfbc1e22c05a009f86170ee Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Wed Mar 25 08:17:46 2026 +0700 dnsforward: imp code commit e5863521b2c34466e610c0a2de94a749cbefd9dd Merge: 9783620a7141501408Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Mon Mar 23 12:07:01 2026 +0700 Merge remote-tracking branch 'origin/master' into AGDNS-3523-use-dnsproxy-context commit 9783620a7ed4ec30860da1afd40cedd5e8f78472 Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Mon Mar 23 11:55:33 2026 +0700 dnsforward: imp code commit 13cb50319d2bfcad5cf9215d610c90314d633bc6 Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Wed Mar 11 13:20:12 2026 +0700 dnsforward: use log middleware
This commit is contained in:
parent
8824fc7911
commit
55d0d6ae59
14 changed files with 188 additions and 121 deletions
|
|
@ -50,7 +50,7 @@ func clientIDFromClientServerName(
|
|||
}
|
||||
|
||||
// clientIDFromDNSContextHTTPS extracts the ClientID from the path of the
|
||||
// client's DNS-over-HTTPS request.
|
||||
// client's DNS-over-HTTPS request. pctx must not be nil.
|
||||
func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err error) {
|
||||
r := pctx.HTTPRequest
|
||||
if r == nil {
|
||||
|
|
|
|||
|
|
@ -229,7 +229,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
|||
}
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
clientID, err := srv.clientIDFromDNSContext(ctx, pctx)
|
||||
clientID, err := srv.clientIDFromDNSContext(ctx, testLogger, pctx)
|
||||
assert.Equal(t, tc.wantClientID, clientID)
|
||||
|
||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package dnsforward
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
|
|
@ -23,9 +24,10 @@ func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filter
|
|||
}
|
||||
|
||||
// filterDNSRequest applies the dnsFilter and sets dctx.proxyCtx.Res if the
|
||||
// request was filtered.
|
||||
// request was filtered. l and dctx must not be nil.
|
||||
func (s *Server) filterDNSRequest(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
dctx *dnsContext,
|
||||
) (res *filtering.Result, err error) {
|
||||
pctx := dctx.proxyCtx
|
||||
|
|
@ -47,8 +49,8 @@ func (s *Server) filterDNSRequest(
|
|||
dctx.origQuestion = q
|
||||
req.Question[0].Name = dns.Fqdn(res.CanonName)
|
||||
case res.IsFiltered:
|
||||
s.logger.DebugContext(ctx, "host is filtered", "host", host, "reason", res.Reason)
|
||||
pctx.Res = s.genDNSFilterMessage(ctx, pctx, res)
|
||||
l.DebugContext(ctx, "host is filtered", "reason", res.Reason)
|
||||
pctx.Res = s.genDNSFilterMessage(ctx, l, pctx, res)
|
||||
case res.Reason.In(filtering.Rewritten, filtering.FilteredSafeSearch):
|
||||
pctx.Res = s.getCNAMEWithIPs(ctx, req, res.IPList, res.CanonName)
|
||||
case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts):
|
||||
|
|
@ -92,8 +94,13 @@ func (s *Server) checkHostRules(
|
|||
// filterDNSResponse checks each resource record of answer section of
|
||||
// dctx.proxyCtx.Res. It sets dctx.result and dctx.origResp if at least one of
|
||||
// canonical names, IP addresses, or HTTPS RR hints in it matches the filtering
|
||||
// rules, as well as sets dctx.proxyCtx.Res to the filtered response.
|
||||
func (s *Server) filterDNSResponse(ctx context.Context, dctx *dnsContext) (err error) {
|
||||
// rules, as well as sets dctx.proxyCtx.Res to the filtered response. l and
|
||||
// dctx must not be nil.
|
||||
func (s *Server) filterDNSResponse(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
dctx *dnsContext,
|
||||
) (err error) {
|
||||
setts := dctx.setts
|
||||
if !setts.FilteringEnabled {
|
||||
return nil
|
||||
|
|
@ -126,27 +133,16 @@ func (s *Server) filterDNSResponse(ctx context.Context, dctx *dnsContext) (err e
|
|||
continue
|
||||
}
|
||||
|
||||
s.logger.DebugContext(
|
||||
ctx,
|
||||
"checked",
|
||||
"dns_type", dns.Type(rrtype),
|
||||
"host", host,
|
||||
"name", a.Header().Name,
|
||||
)
|
||||
l.DebugContext(ctx, "checked", "name", a.Header().Name)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("filtering answer at index %d: %w", i, err)
|
||||
} else if res != nil && res.IsFiltered {
|
||||
dctx.result = res
|
||||
dctx.origResp = pctx.Res
|
||||
pctx.Res = s.genDNSFilterMessage(ctx, pctx, res)
|
||||
pctx.Res = s.genDNSFilterMessage(ctx, l, pctx, res)
|
||||
|
||||
s.logger.DebugContext(
|
||||
ctx,
|
||||
"matched by response",
|
||||
"name", pctx.Req.Question[0].Name,
|
||||
"host", host,
|
||||
)
|
||||
l.DebugContext(ctx, "matched by response", "name", pctx.Req.Question[0].Name)
|
||||
|
||||
break
|
||||
}
|
||||
|
|
|
|||
|
|
@ -154,7 +154,8 @@ func TestServer_filterDNSResponse(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
fltErr := s.filterDNSResponse(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
fltErr := s.filterDNSResponse(ctx, testLogger, dctx)
|
||||
require.NoError(t, fltErr)
|
||||
|
||||
res := dctx.result
|
||||
|
|
|
|||
|
|
@ -120,10 +120,15 @@ func ipsFromAnswer(ans []dns.RR) (ip4s, ip6s []net.IP) {
|
|||
return ip4s, ip6s
|
||||
}
|
||||
|
||||
// process adds the resolved IP addresses to the domain's ipsets, if any.
|
||||
func (h *ipsetHandler) process(ctx context.Context, dctx *dnsContext) (rc resultCode) {
|
||||
h.logger.DebugContext(ctx, "started processing")
|
||||
defer h.logger.DebugContext(ctx, "finished processing")
|
||||
// process adds the resolved IP addresses to the domain's ipsets, if any. l and
|
||||
// dctx must not be nil.
|
||||
func (h *ipsetHandler) process(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
dctx *dnsContext,
|
||||
) (rc resultCode) {
|
||||
l.DebugContext(ctx, "started processing")
|
||||
defer l.DebugContext(ctx, "finished processing")
|
||||
|
||||
if h.skipIpsetProcessing(dctx) {
|
||||
return resultCodeSuccess
|
||||
|
|
@ -138,12 +143,12 @@ func (h *ipsetHandler) process(ctx context.Context, dctx *dnsContext) (rc result
|
|||
n, err := h.ipsetMgr.Add(ctx, host, ip4s, ip6s)
|
||||
if err != nil {
|
||||
// Consider ipset errors non-critical to the request.
|
||||
h.logger.ErrorContext(ctx, "adding host ips", slogutil.KeyError, err)
|
||||
l.ErrorContext(ctx, "adding host ips", slogutil.KeyError, err)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
h.logger.DebugContext(ctx, "added new ipset entries", "num", n)
|
||||
l.DebugContext(ctx, "added new ipset entries", "num", n)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ func TestIpsetCtx_process(t *testing.T) {
|
|||
ictx := &ipsetHandler{
|
||||
logger: testLogger,
|
||||
}
|
||||
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), testLogger, dctx)
|
||||
assert.Equal(t, resultCodeSuccess, rc)
|
||||
|
||||
err := ictx.close()
|
||||
|
|
@ -86,7 +86,7 @@ func TestIpsetCtx_process(t *testing.T) {
|
|||
logger: testLogger,
|
||||
}
|
||||
|
||||
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), testLogger, dctx)
|
||||
assert.Equal(t, resultCodeSuccess, rc)
|
||||
assert.Equal(t, []net.IP{ip4}, m.ip4s)
|
||||
assert.Empty(t, m.ip6s)
|
||||
|
|
@ -111,7 +111,7 @@ func TestIpsetCtx_process(t *testing.T) {
|
|||
logger: testLogger,
|
||||
}
|
||||
|
||||
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), testLogger, dctx)
|
||||
assert.Equal(t, resultCodeSuccess, rc)
|
||||
assert.Empty(t, m.ip4s)
|
||||
assert.Equal(t, []net.IP{ip6}, m.ip6s)
|
||||
|
|
|
|||
|
|
@ -17,13 +17,15 @@ import (
|
|||
// type check
|
||||
var _ proxy.Middleware = (*Server)(nil)
|
||||
|
||||
// Wrap implements the [proxy.Middleware] interface for *Server.
|
||||
// Wrap implements the [proxy.Middleware] interface for *Server. ctx must
|
||||
// contain a logger accessible with [slogutil.LoggerFromContext].
|
||||
//
|
||||
// TODO(d.kolyshev): Move to a dedicated package.
|
||||
// TODO(d.kolyshev): Use logger from context.
|
||||
func (s *Server) Wrap(h proxy.Handler) (wrapped proxy.Handler) {
|
||||
f := func(ctx context.Context, p *proxy.Proxy, pctx *proxy.DNSContext) (err error) {
|
||||
clientID, err := s.clientIDFromDNSContext(ctx, pctx)
|
||||
l := slogutil.MustLoggerFromContext(ctx)
|
||||
|
||||
clientID, err := s.clientIDFromDNSContext(ctx, l, pctx)
|
||||
if err != nil {
|
||||
s.logger.WarnContext(ctx, "resolving client id", slogutil.KeyError, err)
|
||||
|
||||
|
|
@ -37,7 +39,7 @@ func (s *Server) Wrap(h proxy.Handler) (wrapped proxy.Handler) {
|
|||
return s.serveBlockedResponse(pctx)
|
||||
}
|
||||
|
||||
blocked = s.isBlockedHost(ctx, pctx.Req.Question)
|
||||
blocked = s.isBlockedHost(ctx, l, pctx.Req.Question)
|
||||
if blocked {
|
||||
return s.serveBlockedResponse(pctx)
|
||||
}
|
||||
|
|
@ -66,8 +68,13 @@ func (s *Server) serveBlockedResponse(pctx *proxy.DNSContext) (err error) {
|
|||
return nil
|
||||
}
|
||||
|
||||
// isBlockedHost checks if the request is in the access blocklist.
|
||||
func (s *Server) isBlockedHost(ctx context.Context, question []dns.Question) (blocked bool) {
|
||||
// isBlockedHost checks if the request is in the access blocklist. l must not
|
||||
// be nil.
|
||||
func (s *Server) isBlockedHost(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
question []dns.Question,
|
||||
) (blocked bool) {
|
||||
if len(question) != 1 {
|
||||
return false
|
||||
}
|
||||
|
|
@ -77,12 +84,7 @@ func (s *Server) isBlockedHost(ctx context.Context, question []dns.Question) (bl
|
|||
host := aghnet.NormalizeDomain(q.Name)
|
||||
|
||||
if s.access.isBlockedHost(host, qt) {
|
||||
s.logger.DebugContext(
|
||||
ctx,
|
||||
"request is in access blocklist",
|
||||
"dns_type", dns.Type(qt),
|
||||
"host", host,
|
||||
)
|
||||
l.DebugContext(ctx, "request is in access blocklist")
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
@ -92,9 +94,11 @@ func (s *Server) isBlockedHost(ctx context.Context, question []dns.Question) (bl
|
|||
|
||||
// clientIDFromDNSContext extracts the client's ID from the server name of the
|
||||
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
|
||||
// is not one of these, clientID is an empty string and err is nil.
|
||||
// is not one of these, clientID is an empty string and err is nil. l and pctx
|
||||
// must not be nil.
|
||||
func (s *Server) clientIDFromDNSContext(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
pctx *proxy.DNSContext,
|
||||
) (clientID string, err error) {
|
||||
proto := pctx.Proto
|
||||
|
|
@ -116,7 +120,7 @@ func (s *Server) clientIDFromDNSContext(
|
|||
return "", nil
|
||||
}
|
||||
|
||||
cliSrvName, err := clientServerName(ctx, s.logger, pctx, proto)
|
||||
cliSrvName, err := clientServerName(ctx, l, pctx, proto)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting client server-name: %w", err)
|
||||
}
|
||||
|
|
@ -135,6 +139,8 @@ func (s *Server) clientIDFromDNSContext(
|
|||
|
||||
// logMiddleware adds a logger using [slogutil.ContextWithLogger] and logs the
|
||||
// starts and ends of queries at a given level.
|
||||
//
|
||||
// TODO(d.kolyshev): Consider moving to dnsproxy.
|
||||
type logMiddleware struct {
|
||||
attrPool *syncutil.Pool[[]slog.Attr]
|
||||
logger *slog.Logger
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package dnsforward
|
|||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
|
|
@ -46,9 +47,10 @@ func ipsFromRules(resRules []*filtering.ResultRule) (ips []netip.Addr) {
|
|||
}
|
||||
|
||||
// genDNSFilterMessage generates a filtered response to req for the filtering
|
||||
// result res.
|
||||
// result res. l, dctx, and res must not be nil.
|
||||
func (s *Server) genDNSFilterMessage(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
dctx *proxy.DNSContext,
|
||||
res *filtering.Result,
|
||||
) (resp *dns.Msg) {
|
||||
|
|
@ -65,9 +67,9 @@ func (s *Server) genDNSFilterMessage(
|
|||
|
||||
switch res.Reason {
|
||||
case filtering.FilteredSafeBrowsing:
|
||||
return s.genBlockedHost(ctx, req, s.dnsFilter.SafeBrowsingBlockHost(), dctx)
|
||||
return s.genBlockedHost(ctx, l, req, s.dnsFilter.SafeBrowsingBlockHost(), dctx)
|
||||
case filtering.FilteredParental:
|
||||
return s.genBlockedHost(ctx, req, s.dnsFilter.ParentalBlockHost(), dctx)
|
||||
return s.genBlockedHost(ctx, l, req, s.dnsFilter.ParentalBlockHost(), dctx)
|
||||
case filtering.FilteredSafeSearch:
|
||||
// If Safe Search generated the necessary IP addresses, use them.
|
||||
// Otherwise, if there were no errors, there are no addresses for the
|
||||
|
|
@ -315,14 +317,17 @@ func (s *Server) makeResponseNullIP(ctx context.Context, req *dns.Msg) (resp *dn
|
|||
return resp
|
||||
}
|
||||
|
||||
// genBlockedHost generates a blocked host response. l, request, and d must not
|
||||
// be nil.
|
||||
func (s *Server) genBlockedHost(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
request *dns.Msg,
|
||||
newAddr string,
|
||||
d *proxy.DNSContext,
|
||||
) (msg *dns.Msg) {
|
||||
if newAddr == "" {
|
||||
s.logger.InfoContext(ctx, "block host not specified")
|
||||
l.InfoContext(ctx, "block host not specified")
|
||||
|
||||
return s.NewMsgSERVFAIL(request)
|
||||
}
|
||||
|
|
@ -345,14 +350,14 @@ func (s *Server) genBlockedHost(
|
|||
|
||||
prx := s.proxy()
|
||||
if prx == nil {
|
||||
s.logger.DebugContext(ctx, "getting current proxy", slogutil.KeyError, srvClosedErr)
|
||||
l.DebugContext(ctx, "getting current proxy", slogutil.KeyError, srvClosedErr)
|
||||
|
||||
return s.NewMsgSERVFAIL(request)
|
||||
}
|
||||
|
||||
err = prx.Resolve(ctx, newContext)
|
||||
if err != nil {
|
||||
s.logger.ErrorContext(
|
||||
l.ErrorContext(
|
||||
ctx,
|
||||
"looking up replacement host",
|
||||
"host", newAddr,
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package dnsforward
|
|||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
|
@ -96,15 +97,20 @@ const mozillaFQDN = "use-application-dns.net."
|
|||
const healthcheckFQDN = "healthcheck.adguardhome.test."
|
||||
|
||||
// processInitial terminates the following processing for some requests if
|
||||
// needed and enriches dctx with some client-specific information.
|
||||
// needed and enriches dctx with some client-specific information. l and dctx
|
||||
// must not be nil.
|
||||
//
|
||||
// TODO(e.burkov): Decompose into less general processors.
|
||||
func (s *Server) processInitial(ctx context.Context, dctx *dnsContext) (rc resultCode) {
|
||||
s.logger.DebugContext(ctx, "started processing initial")
|
||||
defer s.logger.DebugContext(ctx, "finished processing initial")
|
||||
func (s *Server) processInitial(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
dctx *dnsContext,
|
||||
) (rc resultCode) {
|
||||
l.DebugContext(ctx, "started processing initial")
|
||||
defer l.DebugContext(ctx, "finished processing initial")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
s.processClientIP(ctx, pctx.Addr.Addr())
|
||||
s.processClientIP(ctx, l, pctx.Addr.Addr())
|
||||
|
||||
q := pctx.Req.Question[0]
|
||||
qt := q.Qtype
|
||||
|
|
@ -141,10 +147,11 @@ func (s *Server) processInitial(ctx context.Context, dctx *dnsContext) (rc resul
|
|||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// processClientIP sends the client IP address to s.addrProc, if needed.
|
||||
func (s *Server) processClientIP(ctx context.Context, addr netip.Addr) {
|
||||
// processClientIP sends the client IP address to s.addrProc, if needed. l must
|
||||
// not be nil.
|
||||
func (s *Server) processClientIP(ctx context.Context, l *slog.Logger, addr netip.Addr) {
|
||||
if !addr.IsValid() {
|
||||
s.logger.WarnContext(ctx, "bad client address", "addr", addr)
|
||||
l.WarnContext(ctx, "bad client address", "addr", addr)
|
||||
|
||||
return
|
||||
}
|
||||
|
|
@ -159,12 +166,16 @@ func (s *Server) processClientIP(ctx context.Context, addr netip.Addr) {
|
|||
|
||||
// processDDRQuery responds to Discovery of Designated Resolvers (DDR) SVCB
|
||||
// queries. The response contains different types of encryption supported by
|
||||
// current user configuration.
|
||||
// current user configuration. l and dctx must not be nil.
|
||||
//
|
||||
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html.
|
||||
func (s *Server) processDDRQuery(ctx context.Context, dctx *dnsContext) (rc resultCode) {
|
||||
s.logger.DebugContext(ctx, "started processing ddr")
|
||||
defer s.logger.DebugContext(ctx, "finished processing ddr")
|
||||
func (s *Server) processDDRQuery(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
dctx *dnsContext,
|
||||
) (rc resultCode) {
|
||||
l.DebugContext(ctx, "started processing ddr")
|
||||
defer l.DebugContext(ctx, "finished processing ddr")
|
||||
|
||||
if !s.conf.HandleDDR {
|
||||
return resultCodeSuccess
|
||||
|
|
@ -258,12 +269,16 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
|||
|
||||
// processDHCPHosts respond to A requests if the target hostname is known to
|
||||
// the server. It responds with a mapped IP address if the DNS64 is enabled and
|
||||
// the request is for AAAA.
|
||||
// the request is for AAAA. l and dctx must not be nil.
|
||||
//
|
||||
// TODO(a.garipov): Adapt to AAAA as well.
|
||||
func (s *Server) processDHCPHosts(ctx context.Context, dctx *dnsContext) (rc resultCode) {
|
||||
s.logger.DebugContext(ctx, "started processing dhcp hosts")
|
||||
defer s.logger.DebugContext(ctx, "finished processing dhcp hosts")
|
||||
func (s *Server) processDHCPHosts(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
dctx *dnsContext,
|
||||
) (rc resultCode) {
|
||||
l.DebugContext(ctx, "started processing dhcp hosts")
|
||||
defer l.DebugContext(ctx, "finished processing dhcp hosts")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
req := pctx.Req
|
||||
|
|
@ -275,7 +290,7 @@ func (s *Server) processDHCPHosts(ctx context.Context, dctx *dnsContext) (rc res
|
|||
}
|
||||
|
||||
if !pctx.IsPrivateClient {
|
||||
s.logger.DebugContext(
|
||||
l.DebugContext(
|
||||
ctx,
|
||||
"requests for dhcp host",
|
||||
"addr", pctx.Addr,
|
||||
|
|
@ -291,12 +306,12 @@ func (s *Server) processDHCPHosts(ctx context.Context, dctx *dnsContext) (rc res
|
|||
if ip == (netip.Addr{}) {
|
||||
// Go on and process them with filters, including dnsrewrite ones, and
|
||||
// possibly route them to a domain-specific upstream.
|
||||
s.logger.DebugContext(ctx, "no dhcp record", "dhcp_host", dhcpHost)
|
||||
l.DebugContext(ctx, "no dhcp record", "dhcp_host", dhcpHost)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
s.logger.DebugContext(ctx, "dhcp record for", "dhcp_host", dhcpHost, "ip", ip)
|
||||
l.DebugContext(ctx, "dhcp record for", "dhcp_host", dhcpHost, "ip", ip)
|
||||
|
||||
resp := s.replyCompressed(req)
|
||||
switch q.Qtype {
|
||||
|
|
@ -326,10 +341,14 @@ func (s *Server) processDHCPHosts(ctx context.Context, dctx *dnsContext) (rc res
|
|||
}
|
||||
|
||||
// processDHCPAddrs responds to PTR requests if the target IP is leased by the
|
||||
// DHCP server.
|
||||
func (s *Server) processDHCPAddrs(ctx context.Context, dctx *dnsContext) (rc resultCode) {
|
||||
s.logger.DebugContext(ctx, "started processing dhcp addrs")
|
||||
defer s.logger.DebugContext(ctx, "finished processing dhcp addrs")
|
||||
// DHCP server. l and dctx must not be nil.
|
||||
func (s *Server) processDHCPAddrs(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
dctx *dnsContext,
|
||||
) (rc resultCode) {
|
||||
l.DebugContext(ctx, "started processing dhcp addrs")
|
||||
defer l.DebugContext(ctx, "finished processing dhcp addrs")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
if pctx.Res != nil {
|
||||
|
|
@ -351,7 +370,7 @@ func (s *Server) processDHCPAddrs(ctx context.Context, dctx *dnsContext) (rc res
|
|||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
s.logger.DebugContext(ctx, "dhcp client", "addr", addr, "host", host)
|
||||
l.DebugContext(ctx, "dhcp client", "addr", addr, "host", host)
|
||||
|
||||
resp := s.replyCompressed(req)
|
||||
ptr := &dns.PTR{
|
||||
|
|
@ -371,13 +390,15 @@ func (s *Server) processDHCPAddrs(ctx context.Context, dctx *dnsContext) (rc res
|
|||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// Apply filtering logic
|
||||
// processFilteringBeforeRequest applies filtering logic. l and dctx must not
|
||||
// be nil.
|
||||
func (s *Server) processFilteringBeforeRequest(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
dctx *dnsContext,
|
||||
) (rc resultCode) {
|
||||
s.logger.DebugContext(ctx, "started processing filtering before request")
|
||||
defer s.logger.DebugContext(ctx, "finished processing filtering before request")
|
||||
l.DebugContext(ctx, "started processing filtering before request")
|
||||
defer l.DebugContext(ctx, "finished processing filtering before request")
|
||||
|
||||
if dctx.proxyCtx.RequestedPrivateRDNS != (netip.Prefix{}) {
|
||||
// There is no need to filter request for locally served ARPA hostname
|
||||
|
|
@ -397,7 +418,7 @@ func (s *Server) processFilteringBeforeRequest(
|
|||
defer s.serverLock.RUnlock()
|
||||
|
||||
var err error
|
||||
if dctx.result, err = s.filterDNSRequest(ctx, dctx); err != nil {
|
||||
if dctx.result, err = s.filterDNSRequest(ctx, l, dctx); err != nil {
|
||||
dctx.err = err
|
||||
|
||||
return resultCodeError
|
||||
|
|
@ -416,9 +437,14 @@ func ipStringFromAddr(addr net.Addr) (ipStr string) {
|
|||
}
|
||||
|
||||
// processUpstream passes request to upstream servers and handles the response.
|
||||
func (s *Server) processUpstream(ctx context.Context, dctx *dnsContext) (rc resultCode) {
|
||||
s.logger.DebugContext(ctx, "started processing upstream")
|
||||
defer s.logger.DebugContext(ctx, "finished processing upstream")
|
||||
// l and dctx must not be nil.
|
||||
func (s *Server) processUpstream(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
dctx *dnsContext,
|
||||
) (rc resultCode) {
|
||||
l.DebugContext(ctx, "started processing upstream")
|
||||
defer l.DebugContext(ctx, "finished processing upstream")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
req := pctx.Req
|
||||
|
|
@ -433,7 +459,7 @@ func (s *Server) processUpstream(ctx context.Context, dctx *dnsContext) (rc resu
|
|||
// TODO(a.garipov): Route such queries to a custom upstream for the
|
||||
// local domain name if there is one.
|
||||
name := req.Question[0].Name
|
||||
s.logger.DebugContext(
|
||||
l.DebugContext(
|
||||
ctx,
|
||||
"dhcp client hostname was not filtered",
|
||||
"hostname", name[:len(name)-1],
|
||||
|
|
@ -443,7 +469,7 @@ func (s *Server) processUpstream(ctx context.Context, dctx *dnsContext) (rc resu
|
|||
return resultCodeFinish
|
||||
}
|
||||
|
||||
s.setCustomUpstream(ctx, pctx, dctx.clientID)
|
||||
s.setCustomUpstream(ctx, l, pctx, dctx.clientID)
|
||||
|
||||
reqWantsDNSSEC := s.setReqAD(req)
|
||||
|
||||
|
|
@ -532,8 +558,14 @@ func (s *Server) dhcpHostFromRequest(q *dns.Question) (reqHost string) {
|
|||
return reqHost[:len(reqHost)-len(s.localDomainSuffix)-1]
|
||||
}
|
||||
|
||||
// setCustomUpstream sets custom upstream settings in pctx, if necessary.
|
||||
func (s *Server) setCustomUpstream(ctx context.Context, pctx *proxy.DNSContext, clientID string) {
|
||||
// setCustomUpstream sets custom upstream settings in pctx, if necessary. l and
|
||||
// pctx must not be nil.
|
||||
func (s *Server) setCustomUpstream(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
pctx *proxy.DNSContext,
|
||||
clientID string,
|
||||
) {
|
||||
if !pctx.Addr.IsValid() || s.conf.ClientsContainer == nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -541,7 +573,7 @@ func (s *Server) setCustomUpstream(ctx context.Context, pctx *proxy.DNSContext,
|
|||
cliAddr := pctx.Addr.Addr()
|
||||
upsConf := s.conf.ClientsContainer.CustomUpstreamConfig(clientID, cliAddr)
|
||||
if upsConf != nil {
|
||||
s.logger.DebugContext(
|
||||
l.DebugContext(
|
||||
ctx,
|
||||
"using custom upstreams for client with",
|
||||
"ip", cliAddr,
|
||||
|
|
@ -552,10 +584,15 @@ func (s *Server) setCustomUpstream(ctx context.Context, pctx *proxy.DNSContext,
|
|||
}
|
||||
}
|
||||
|
||||
// Apply filtering logic after we have received response from upstream servers
|
||||
func (s *Server) processFilteringAfterResponse(ctx context.Context, dctx *dnsContext) (rc resultCode) {
|
||||
s.logger.DebugContext(ctx, "started processing filtering after response")
|
||||
defer s.logger.DebugContext(ctx, "finished processing filtering after response")
|
||||
// Apply filtering logic after we have received response from upstream servers.
|
||||
// l and dctx must not be nil.
|
||||
func (s *Server) processFilteringAfterResponse(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
dctx *dnsContext,
|
||||
) (rc resultCode) {
|
||||
l.DebugContext(ctx, "started processing filtering after response")
|
||||
defer l.DebugContext(ctx, "finished processing filtering after response")
|
||||
|
||||
switch res := dctx.result; res.Reason {
|
||||
case filtering.NotFilteredAllowList:
|
||||
|
|
@ -580,13 +617,17 @@ func (s *Server) processFilteringAfterResponse(ctx context.Context, dctx *dnsCon
|
|||
|
||||
return resultCodeSuccess
|
||||
default:
|
||||
return s.filterAfterResponse(ctx, dctx)
|
||||
return s.filterAfterResponse(ctx, l, dctx)
|
||||
}
|
||||
}
|
||||
|
||||
// filterAfterResponse returns the result of filtering the response that wasn't
|
||||
// explicitly allowed or rewritten.
|
||||
func (s *Server) filterAfterResponse(ctx context.Context, dctx *dnsContext) (res resultCode) {
|
||||
// explicitly allowed or rewritten. l and dctx must not be nil.
|
||||
func (s *Server) filterAfterResponse(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
dctx *dnsContext,
|
||||
) (res resultCode) {
|
||||
// Check the response only if it's from an upstream. Don't check the
|
||||
// response if the protection is disabled since dnsrewrite rules aren't
|
||||
// applied to it anyway.
|
||||
|
|
@ -594,7 +635,7 @@ func (s *Server) filterAfterResponse(ctx context.Context, dctx *dnsContext) (res
|
|||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
err := s.filterDNSResponse(ctx, dctx)
|
||||
err := s.filterDNSResponse(ctx, l, dctx)
|
||||
if err != nil {
|
||||
dctx.err = err
|
||||
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ func TestServer_ProcessInitial(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
gotRC := s.processInitial(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
gotRC := s.processInitial(testutil.ContextWithTimeout(t, testTimeout), testLogger, dctx)
|
||||
assert.Equal(t, tc.wantRC, gotRC)
|
||||
assert.Equal(t, testClientAddrPort.Addr(), gotAddr)
|
||||
|
||||
|
|
@ -208,7 +208,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
|
|||
},
|
||||
}
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
gotRC := s.processFilteringAfterResponse(ctx, dctx)
|
||||
gotRC := s.processFilteringAfterResponse(ctx, testLogger, dctx)
|
||||
assert.Equal(t, tc.wantRC, gotRC)
|
||||
assert.Equal(t, newResp(dns.RcodeSuccess, tc.req, tc.wantRespAns), dctx.proxyCtx.Res)
|
||||
})
|
||||
|
|
@ -356,7 +356,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
res := s.processDDRQuery(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
res := s.processDDRQuery(testutil.ContextWithTimeout(t, testTimeout), testLogger, dctx)
|
||||
require.Equal(t, tc.wantRes, res)
|
||||
|
||||
if tc.wantRes != resultCodeFinish {
|
||||
|
|
@ -464,7 +464,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
res := s.processDHCPHosts(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
res := s.processDHCPHosts(testutil.ContextWithTimeout(t, testTimeout), testLogger, dctx)
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
if !tc.isLocalCli {
|
||||
|
|
@ -609,7 +609,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res := s.processDHCPHosts(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
res := s.processDHCPHosts(testutil.ContextWithTimeout(t, testTimeout), testLogger, dctx)
|
||||
pctx := dctx.proxyCtx
|
||||
assert.Equal(t, tc.wantRes, res)
|
||||
require.NoError(t, dctx.err)
|
||||
|
|
@ -685,7 +685,7 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
|
|||
)
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
pctx := newPrxCtx()
|
||||
rc := s.processUpstream(ctx, &dnsContext{proxyCtx: pctx})
|
||||
rc := s.processUpstream(ctx, testLogger, &dnsContext{proxyCtx: pctx})
|
||||
require.Equal(t, resultCodeSuccess, rc)
|
||||
require.NotEmpty(t, pctx.Res.Answer)
|
||||
ptr := testutil.RequireTypeAssert[*dns.PTR](t, pctx.Res.Answer[0])
|
||||
|
|
@ -716,7 +716,7 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
|
|||
pctx := newPrxCtx()
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
rc := s.processUpstream(ctx, &dnsContext{proxyCtx: pctx})
|
||||
rc := s.processUpstream(ctx, testLogger, &dnsContext{proxyCtx: pctx})
|
||||
require.Equal(t, resultCodeError, rc)
|
||||
require.Empty(t, pctx.Res.Answer)
|
||||
})
|
||||
|
|
|
|||
|
|
@ -2,18 +2,19 @@ package dnsforward
|
|||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
)
|
||||
|
||||
// type check
|
||||
var _ proxy.Handler = (*Server)(nil)
|
||||
|
||||
// ServeDNS implements the [proxy.Handler] interface for [*Server].
|
||||
//
|
||||
// TODO(d.kolyshev): Use logger from context.
|
||||
// ServeDNS implements the [proxy.Handler] interface for [*Server]. ctx must
|
||||
// contain a logger accessible with [slogutil.LoggerFromContext].
|
||||
func (s *Server) ServeDNS(ctx context.Context, _ *proxy.Proxy, pctx *proxy.DNSContext) (err error) {
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: pctx,
|
||||
|
|
@ -21,7 +22,9 @@ func (s *Server) ServeDNS(ctx context.Context, _ *proxy.Proxy, pctx *proxy.DNSCo
|
|||
startTime: time.Now(),
|
||||
}
|
||||
|
||||
type modProcessFunc func(ctx context.Context, dctx *dnsContext) (rc resultCode)
|
||||
l := slogutil.MustLoggerFromContext(ctx)
|
||||
|
||||
type modProcessFunc func(ctx context.Context, l *slog.Logger, dctx *dnsContext) (rc resultCode)
|
||||
|
||||
// Since [*dnsforward.Server] is used as [proxy.Handler], there is no need
|
||||
// for additional index out of range checking in any of the following
|
||||
|
|
@ -39,7 +42,7 @@ func (s *Server) ServeDNS(ctx context.Context, _ *proxy.Proxy, pctx *proxy.DNSCo
|
|||
s.processQueryLogsAndStats,
|
||||
}
|
||||
for _, process := range mods {
|
||||
r := process(ctx, dctx)
|
||||
r := process(ctx, l, dctx)
|
||||
switch r {
|
||||
case resultCodeSuccess:
|
||||
// continue: call the next filter
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
|
|
@ -200,7 +201,10 @@ func TestServer_ServeDNS(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = s.ServeDNS(testutil.ContextWithTimeout(t, testTimeout), nil, dctx)
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
ctx = slogutil.ContextWithLogger(ctx, testLogger)
|
||||
|
||||
err = s.ServeDNS(ctx, nil, dctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, dctx.Res)
|
||||
|
||||
|
|
@ -335,7 +339,10 @@ func TestServer_ServeDNS_restrictLocal(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = s.ServeDNS(testutil.ContextWithTimeout(t, testTimeout), s.dnsProxy, pctx)
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
ctx = slogutil.ContextWithLogger(ctx, testLogger)
|
||||
|
||||
err = s.ServeDNS(ctx, s.dnsProxy, pctx)
|
||||
require.ErrorIs(t, err, tc.wantErr)
|
||||
|
||||
require.NotNil(t, pctx.Res)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package dnsforward
|
|||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
|
|
@ -13,10 +14,15 @@ import (
|
|||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Write Stats data and logs
|
||||
func (s *Server) processQueryLogsAndStats(ctx context.Context, dctx *dnsContext) (rc resultCode) {
|
||||
s.logger.DebugContext(ctx, "started processing querylog and stats")
|
||||
defer s.logger.DebugContext(ctx, "finished processing querylog and stats")
|
||||
// processQueryLogsAndStats writes stats data and logs. l and dctx must not be
|
||||
// nil.
|
||||
func (s *Server) processQueryLogsAndStats(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
dctx *dnsContext,
|
||||
) (rc resultCode) {
|
||||
l.DebugContext(ctx, "started processing querylog and stats")
|
||||
defer l.DebugContext(ctx, "finished processing querylog and stats")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
q := pctx.Req.Question[0]
|
||||
|
|
@ -27,7 +33,7 @@ func (s *Server) processQueryLogsAndStats(ctx context.Context, dctx *dnsContext)
|
|||
s.anonymizer.Load()(ip)
|
||||
ipStr := net.IP(ip).String()
|
||||
|
||||
s.logger.DebugContext(ctx, "client ip for stats and querylog", "ip", ipStr)
|
||||
l.DebugContext(ctx, "client ip for stats and querylog", "ip", ipStr)
|
||||
|
||||
ids := []string{ipStr}
|
||||
if dctx.clientID != "" {
|
||||
|
|
@ -47,12 +53,10 @@ func (s *Server) processQueryLogsAndStats(ctx context.Context, dctx *dnsContext)
|
|||
if s.shouldLog(host, qt, cl, ids) {
|
||||
s.logQuery(dctx, ip, processingTime)
|
||||
} else {
|
||||
s.logger.DebugContext(
|
||||
l.DebugContext(
|
||||
ctx,
|
||||
"not adding to querylog",
|
||||
"dns_class", dns.Class(cl),
|
||||
"dns_type", dns.Type(qt),
|
||||
"host", host,
|
||||
"ip", ipStr,
|
||||
)
|
||||
}
|
||||
|
|
@ -60,12 +64,10 @@ func (s *Server) processQueryLogsAndStats(ctx context.Context, dctx *dnsContext)
|
|||
if s.shouldCountStat(host, qt, cl, ids) {
|
||||
s.updateStats(dctx, ipStr, processingTime)
|
||||
} else {
|
||||
s.logger.DebugContext(
|
||||
l.DebugContext(
|
||||
ctx,
|
||||
"not counting in stats",
|
||||
"dns_class", dns.Class(cl),
|
||||
"dns_type", dns.Type(qt),
|
||||
"host", host,
|
||||
"ip", ipStr,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -231,7 +231,8 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
|||
clientID: tc.clientID,
|
||||
}
|
||||
|
||||
code := srv.processQueryLogsAndStats(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
code := srv.processQueryLogsAndStats(ctx, testLogger, dctx)
|
||||
assert.Equal(t, tc.wantCode, code)
|
||||
assert.Equal(t, tc.wantLogProto, ql.lastParams.ClientProto)
|
||||
assert.Equal(t, tc.wantStatClient, st.lastEntry.Client)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue