mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2026-06-29 04:11:17 +00:00
Pull request 2579: AGDNS-3523-upd-dnsproxy
Squashed commit of the following: commit 96f8d8057e9d69e5fcf649a5d63f3a935dca3e67 Merge: 44e5f12074c1dcfee4Author: Eugene Burkov <E.Burkov@AdGuard.COM> Date: Wed Feb 25 17:49:37 2026 +0300 Merge branch 'master' into AGDNS-3523-upd-dnsproxy commit 44e5f1207e5201c6527f2f316ceafac2b3593868 Merge: d474280e982f2ac68aAuthor: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Fri Feb 20 10:38:45 2026 +0700 Merge remote-tracking branch 'origin/master' into AGDNS-3523-upd-dnsproxy commit d474280e944ecfab6bfd877089fc7ad79f3ecfed Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Thu Feb 19 20:07:13 2026 +0700 all: upd dnsproxy commit b19a7cd851c62fafafed340825e7ad269f5a578a Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Thu Feb 19 13:50:05 2026 +0700 all: upd dnsproxy commit 56b67426e0ff804880ac74536f01cdbc4db3605b Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Thu Feb 19 10:06:42 2026 +0700 all: imp code commit 716f8186bd5e74152b0677d57a26ca4fc2f991a9 Merge: 96c178bb394a3a4fa6Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Thu Feb 19 09:21:49 2026 +0700 Merge remote-tracking branch 'origin/master' into AGDNS-3523-upd-dnsproxy # Conflicts: # go.mod # go.sum commit 96c178bb3d2b6cfcd1cb9377b71c3c23987eab6e Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Tue Feb 17 10:27:27 2026 +0700 next: dnssvc todos commit f84834f68f399e6ce841c7c5765f3180acb3c417 Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Tue Feb 17 09:11:54 2026 +0700 all: upd dnsproxy commit 6ad1a30c087256374882b6e24f5c879121c34834 Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Mon Feb 16 13:17:09 2026 +0700 scripts: fix commit cc7c6d3181a427f4e79f4af686d3363f5c341001 Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Mon Feb 16 13:11:24 2026 +0700 all: upd golibs commit cd2662cf6bc52216c9b2ec55ffd34a8ade434170 Merge: d3d2bd6f2a165cdb68Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Mon Feb 16 11:49:19 2026 +0700 Merge remote-tracking branch 'refs/remotes/origin/master' into AGDNS-3523-upd-dnsproxy commit d3d2bd6f27de6b721f73a48e1c4c8e1187d51046 Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Fri Feb 6 12:23:43 2026 +0700 all: todo commit e16df07e1062e67cc3d085048c4399cf375997bf Author: Dimitry Kolyshev <dkolyshev@adguard.com> Date: Fri Feb 6 12:06:37 2026 +0700 dnsforward: upd dnsproxy
This commit is contained in:
parent
4c1dcfee40
commit
8c9756f32f
12 changed files with 482 additions and 408 deletions
2
go.mod
2
go.mod
|
|
@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome
|
|||
go 1.25.7
|
||||
|
||||
require (
|
||||
github.com/AdguardTeam/dnsproxy v0.78.2
|
||||
github.com/AdguardTeam/dnsproxy v0.79.0
|
||||
github.com/AdguardTeam/golibs v0.35.8
|
||||
github.com/AdguardTeam/urlfilter v0.23.1
|
||||
github.com/NYTimes/gziphandler v1.1.1
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -4,8 +4,8 @@ cloud.google.com/go/auth v0.18.1 h1:IwTEx92GFUo2pJ6Qea0EU3zYvKnTAeRCODxfA/G5UWs=
|
|||
cloud.google.com/go/auth v0.18.1/go.mod h1:GfTYoS9G3CWpRA3Va9doKN9mjPGRS+v41jmZAhBzbrA=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||
github.com/AdguardTeam/dnsproxy v0.78.2 h1:g+ba4vh72hAv9zIE+OPSEnu77utSKxIF6u2jNhYAR7g=
|
||||
github.com/AdguardTeam/dnsproxy v0.78.2/go.mod h1:gwr+7Dc0e7QddQLC9JLGjL5NSKcqw0ESsNMRI5Q67Ps=
|
||||
github.com/AdguardTeam/dnsproxy v0.79.0 h1:wvNTny4u6x95bWGRyyqr1PVkHbYyAhPsv4EvnqVlmf4=
|
||||
github.com/AdguardTeam/dnsproxy v0.79.0/go.mod h1:gwr+7Dc0e7QddQLC9JLGjL5NSKcqw0ESsNMRI5Q67Ps=
|
||||
github.com/AdguardTeam/golibs v0.35.8 h1:KsyF3SWwj05Ey4GiAWU6FGD9oJTDNMp1ixVdS+Nw50M=
|
||||
github.com/AdguardTeam/golibs v0.35.8/go.mod h1:kuLQ0yNRTl0Em2FmmXtSri7ZdVT7p62oojyc51RvP38=
|
||||
github.com/AdguardTeam/urlfilter v0.23.1 h1:ifoms1xhof83+IPz96NsZt0h8knXOlL/lNP1cHjndfE=
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import (
|
|||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/ratelimit"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
|
|
@ -48,11 +49,11 @@ type Config struct {
|
|||
|
||||
// RatelimitSubnetLenIPv4 is a subnet length for IPv4 addresses used for
|
||||
// rate limiting requests.
|
||||
RatelimitSubnetLenIPv4 int `yaml:"ratelimit_subnet_len_ipv4"`
|
||||
RatelimitSubnetLenIPv4 uint `yaml:"ratelimit_subnet_len_ipv4"`
|
||||
|
||||
// RatelimitSubnetLenIPv6 is a subnet length for IPv6 addresses used for
|
||||
// rate limiting requests.
|
||||
RatelimitSubnetLenIPv6 int `yaml:"ratelimit_subnet_len_ipv6"`
|
||||
RatelimitSubnetLenIPv6 uint `yaml:"ratelimit_subnet_len_ipv6"`
|
||||
|
||||
// RatelimitWhitelist is the list of whitelisted client IP addresses.
|
||||
RatelimitWhitelist []netip.Addr `yaml:"ratelimit_whitelist"`
|
||||
|
|
@ -325,13 +326,14 @@ func (s *Server) newProxyConfig(ctx context.Context) (conf *proxy.Config, err er
|
|||
srvConf := s.conf
|
||||
trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies)
|
||||
|
||||
ratelimitMw, err := newRatelimitMw(s.baseLogger, srvConf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ratelimit middleware: %w", err)
|
||||
}
|
||||
|
||||
conf = &proxy.Config{
|
||||
Logger: s.baseLogger.With(slogutil.KeyPrefix, aghslog.PrefixDNSProxy),
|
||||
HTTP3: srvConf.ServeHTTP3,
|
||||
Ratelimit: int(srvConf.Ratelimit),
|
||||
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
|
||||
RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6,
|
||||
RatelimitWhitelist: srvConf.RatelimitWhitelist,
|
||||
RefuseAny: srvConf.RefuseAny,
|
||||
TrustedProxies: netutil.SliceSubnetSet(trustedPrefixes),
|
||||
CacheMinTTL: srvConf.CacheMinTTL,
|
||||
|
|
@ -342,7 +344,7 @@ func (s *Server) newProxyConfig(ctx context.Context) (conf *proxy.Config, err er
|
|||
UpstreamConfig: srvConf.UpstreamConfig,
|
||||
PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
|
||||
BeforeRequestHandler: s,
|
||||
RequestHandler: s.handleDNSRequest,
|
||||
RequestHandler: ratelimitMw.Wrap(s),
|
||||
HTTPSServerName: aghhttp.UserAgent(),
|
||||
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
|
||||
MaxGoroutines: srvConf.MaxGoroutines,
|
||||
|
|
@ -395,6 +397,29 @@ func (s *Server) newProxyConfig(ctx context.Context) (conf *proxy.Config, err er
|
|||
return conf, nil
|
||||
}
|
||||
|
||||
// newRatelimitMw returns the ratelimit middleware. In case of invalid
|
||||
// ratelimit configuration returns an error. l must not be nil.
|
||||
func newRatelimitMw(
|
||||
l *slog.Logger,
|
||||
conf ServerConfig,
|
||||
) (mw proxy.Middleware, err error) {
|
||||
if conf.Ratelimit == 0 {
|
||||
return proxy.MiddlewareFunc(proxy.PassThrough), nil
|
||||
}
|
||||
|
||||
rlConf := &ratelimit.Config{
|
||||
Logger: l.With(slogutil.KeyPrefix, "ratelimit"),
|
||||
Ratelimit: uint(conf.Ratelimit),
|
||||
SubnetLenIPv4: conf.RatelimitSubnetLenIPv4,
|
||||
SubnetLenIPv6: conf.RatelimitSubnetLenIPv6,
|
||||
}
|
||||
if err = rlConf.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
|
||||
return ratelimit.NewMiddleware(rlConf), nil
|
||||
}
|
||||
|
||||
// prepareCacheConfig prepares the cache configuration and returns an error if
|
||||
// there is one.
|
||||
func prepareCacheConfig(
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ func newRR(tb testing.TB, name string, qtype uint16, ttl uint32, val any) (rr dn
|
|||
return rr
|
||||
}
|
||||
|
||||
func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||
func TestServer_ServeDNS_dns64(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const (
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import (
|
|||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
|
|
@ -16,200 +15,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
|
||||
rules := `
|
||||
||blocked.domain^
|
||||
@@||allowed.domain^
|
||||
||cname.specific^$dnstype=~CNAME
|
||||
||0.0.0.1^$dnstype=~A
|
||||
||::1^$dnstype=~AAAA
|
||||
0.0.0.0 duplicate.domain
|
||||
0.0.0.0 duplicate.domain
|
||||
0.0.0.0 blocked.by.hostrule
|
||||
`
|
||||
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
filters := []filtering.Filter{{
|
||||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f, err := filtering.New(&filtering.Config{
|
||||
Logger: testLogger,
|
||||
ProtectionEnabled: true,
|
||||
ApplyClientFiltering: applyEmptyClientFiltering,
|
||||
BlockedServices: emptyFilteringBlockedServices(),
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, filters)
|
||||
require.NoError(t, err)
|
||||
f.SetEnabled(true)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{
|
||||
OnEnabled: func() (ok bool) { return false },
|
||||
OnHostByIP: func(ip netip.Addr) (_ string) { panic(testutil.UnexpectedCall(ip)) },
|
||||
OnIPByHost: func(host string) (_ netip.Addr) { panic(testutil.UnexpectedCall(host)) },
|
||||
},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
Logger: testLogger,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &forwardConf)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.Upstream{
|
||||
CName: map[string][]string{
|
||||
"cname.exception.": {"cname.specific."},
|
||||
"should.block.": {"blocked.domain."},
|
||||
"allowed.first.": {"allowed.domain.", "blocked.domain."},
|
||||
"blocked.first.": {"blocked.domain.", "allowed.domain."},
|
||||
},
|
||||
IPv4: map[string][]net.IP{
|
||||
"a.exception.": {{0, 0, 0, 1}},
|
||||
},
|
||||
IPv6: map[string][]net.IP{
|
||||
"aaaa.exception.": {net.ParseIP("::1")},
|
||||
},
|
||||
},
|
||||
}
|
||||
startDeferStop(t, s)
|
||||
|
||||
testCases := []struct {
|
||||
req *dns.Msg
|
||||
name string
|
||||
wantRCode int
|
||||
wantAns []dns.RR
|
||||
}{{
|
||||
req: createTestMessage(aghtest.ReqFQDN),
|
||||
name: "pass",
|
||||
wantRCode: dns.RcodeNameError,
|
||||
wantAns: nil,
|
||||
}, {
|
||||
req: createTestMessage("cname.exception."),
|
||||
name: "cname_exception",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.CNAME{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "cname.exception.",
|
||||
Rrtype: dns.TypeCNAME,
|
||||
},
|
||||
Target: "cname.specific.",
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("should.block."),
|
||||
name: "blocked_by_cname",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "should.block.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("a.exception."),
|
||||
name: "a_exception",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "a.exception.",
|
||||
Rrtype: dns.TypeA,
|
||||
},
|
||||
A: net.IP{0, 0, 0, 1},
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessageWithType("aaaa.exception.", dns.TypeAAAA),
|
||||
name: "aaaa_exception",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "aaaa.exception.",
|
||||
Rrtype: dns.TypeAAAA,
|
||||
},
|
||||
AAAA: net.ParseIP("::1"),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("allowed.first."),
|
||||
name: "allowed_first",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "allowed.first.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("blocked.first."),
|
||||
name: "blocked_first",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "blocked.first.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("duplicate.domain."),
|
||||
name: "duplicate_domain",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "duplicate.domain.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessageWithType("blocked.domain.", dns.TypeHTTPS),
|
||||
name: "blocked_https_req",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: nil,
|
||||
}, {
|
||||
req: createTestMessageWithType("blocked.by.hostrule.", dns.TypeHTTPS),
|
||||
name: "blocked_host_rule_https_req",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
dctx := &proxy.DNSContext{
|
||||
Proto: proxy.ProtoUDP,
|
||||
Req: tc.req,
|
||||
Addr: testClientAddrPort,
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = s.handleDNSRequest(nil, dctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, dctx.Res)
|
||||
|
||||
assert.Equal(t, tc.wantRCode, dctx.Res.Rcode)
|
||||
assert.Equal(t, tc.wantAns, dctx.Res.Answer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
func TestServer_filterDNSResponse(t *testing.T) {
|
||||
const (
|
||||
passedIPv4Str = "1.1.1.1"
|
||||
blockedIPv4Str = "1.2.3.4"
|
||||
|
|
|
|||
|
|
@ -52,11 +52,11 @@ type jsonDNSConfig struct {
|
|||
|
||||
// RatelimitSubnetLenIPv4 is a subnet length for IPv4 addresses used for
|
||||
// rate limiting requests.
|
||||
RatelimitSubnetLenIPv4 *int `json:"ratelimit_subnet_len_ipv4"`
|
||||
RatelimitSubnetLenIPv4 *uint `json:"ratelimit_subnet_len_ipv4"`
|
||||
|
||||
// RatelimitSubnetLenIPv6 is a subnet length for IPv6 addresses used for
|
||||
// rate limiting requests.
|
||||
RatelimitSubnetLenIPv6 *int `json:"ratelimit_subnet_len_ipv6"`
|
||||
RatelimitSubnetLenIPv6 *uint `json:"ratelimit_subnet_len_ipv6"`
|
||||
|
||||
// UpstreamTimeout is an upstream timeout in seconds.
|
||||
UpstreamTimeout *int `json:"upstream_timeout"`
|
||||
|
|
@ -519,7 +519,7 @@ func (req *jsonDNSConfig) checkUpstreamTimeout() (err error) {
|
|||
|
||||
// checkInclusion returns an error if a ptr is not nil and points to value,
|
||||
// that not in the inclusive range between minN and maxN.
|
||||
func checkInclusion(ptr *int, minN, maxN int) (err error) {
|
||||
func checkInclusion(ptr *uint, minN, maxN uint) (err error) {
|
||||
if ptr == nil {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -828,7 +828,7 @@ func (s *Server) handleSetProtection(w http.ResponseWriter, r *http.Request) {
|
|||
// -> dnsforward.ServeHTTP
|
||||
// -> proxy.ServeHTTP
|
||||
// -> proxy.handleDNSRequest
|
||||
// -> dnsforward.handleDNSRequest
|
||||
// -> dnsforward.ServeDNS
|
||||
func (s *Server) handleDoH(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
l := s.logger
|
||||
|
|
|
|||
|
|
@ -80,57 +80,6 @@ const (
|
|||
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
|
||||
const ddrHostFQDN = "_dns.resolver.arpa."
|
||||
|
||||
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
|
||||
func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error {
|
||||
// TODO(s.chzhen): Pass context.
|
||||
ctx := context.TODO()
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: pctx,
|
||||
result: &filtering.Result{},
|
||||
startTime: time.Now(),
|
||||
}
|
||||
|
||||
type modProcessFunc func(ctx context.Context, dctx *dnsContext) (rc resultCode)
|
||||
|
||||
// Since (*dnsforward.Server).handleDNSRequest(...) is used as
|
||||
// proxy.(Config).RequestHandler, there is no need for additional index
|
||||
// out of range checking in any of the following functions, because the
|
||||
// (*proxy.Proxy).handleDNSRequest method performs it before calling the
|
||||
// appropriate handler.
|
||||
mods := []modProcessFunc{
|
||||
s.processInitial,
|
||||
s.processDDRQuery,
|
||||
s.processDHCPHosts,
|
||||
s.processDHCPAddrs,
|
||||
s.processFilteringBeforeRequest,
|
||||
s.processUpstream,
|
||||
s.processFilteringAfterResponse,
|
||||
s.ipset.process,
|
||||
s.processQueryLogsAndStats,
|
||||
}
|
||||
for _, process := range mods {
|
||||
r := process(ctx, dctx)
|
||||
switch r {
|
||||
case resultCodeSuccess:
|
||||
// continue: call the next filter
|
||||
|
||||
case resultCodeFinish:
|
||||
return nil
|
||||
|
||||
case resultCodeError:
|
||||
return dctx.err
|
||||
}
|
||||
}
|
||||
|
||||
if pctx.Res != nil {
|
||||
// Some devices require DNS message compression.
|
||||
pctx.Res.Compress = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// mozillaFQDN is the domain used to signal the Firefox browser to not use its
|
||||
// own DoH server.
|
||||
//
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import (
|
|||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
|
|
@ -637,140 +636,6 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Rewrite this test to use the whole server instead of just
|
||||
// testing the [handleDNSRequest] method. See comment on
|
||||
// "from_external_for_local" test case.
|
||||
func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) {
|
||||
intAddr := netip.MustParseAddr("192.168.1.1")
|
||||
intPTRQuestion, err := netutil.IPToReversedAddr(intAddr.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
extAddr := netip.MustParseAddr("254.253.252.1")
|
||||
extPTRQuestion, err := netutil.IPToReversedAddr(extAddr.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
const (
|
||||
extPTRAnswer = "host1.example.net."
|
||||
intPTRAnswer = "some.local-client."
|
||||
)
|
||||
|
||||
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := cmp.Or(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer),
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer),
|
||||
(&dns.Msg{}).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
|
||||
})
|
||||
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||||
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
// TODO(s.chzhen): Add tests where EDNSClientSubnet.Enabled is true.
|
||||
// Improve Config declaration for tests.
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{localUpsAddr},
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
UsePrivateRDNS: true,
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
startDeferStop(t, s)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
question string
|
||||
wantErr error
|
||||
wantAns []dns.RR
|
||||
isPrivate bool
|
||||
}{{
|
||||
name: "from_local_for_external",
|
||||
question: extPTRQuestion,
|
||||
wantErr: nil,
|
||||
wantAns: []dns.RR{&dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(extPTRQuestion),
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
Rdlength: uint16(len(extPTRAnswer) + 1),
|
||||
},
|
||||
Ptr: dns.Fqdn(extPTRAnswer),
|
||||
}},
|
||||
isPrivate: true,
|
||||
}, {
|
||||
// In theory this case is not reproducible because [proxy.Proxy] should
|
||||
// respond to such queries with NXDOMAIN before they reach
|
||||
// [Server.handleDNSRequest].
|
||||
name: "from_external_for_local",
|
||||
question: intPTRQuestion,
|
||||
wantErr: upstream.ErrNoUpstreams,
|
||||
wantAns: nil,
|
||||
isPrivate: false,
|
||||
}, {
|
||||
name: "from_local_for_local",
|
||||
question: intPTRQuestion,
|
||||
wantErr: nil,
|
||||
wantAns: []dns.RR{&dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(intPTRQuestion),
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
Rdlength: uint16(len(intPTRAnswer) + 1),
|
||||
},
|
||||
Ptr: dns.Fqdn(intPTRAnswer),
|
||||
}},
|
||||
isPrivate: true,
|
||||
}, {
|
||||
name: "from_external_for_external",
|
||||
question: extPTRQuestion,
|
||||
wantErr: nil,
|
||||
wantAns: []dns.RR{&dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(extPTRQuestion),
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
Rdlength: uint16(len(extPTRAnswer) + 1),
|
||||
},
|
||||
Ptr: dns.Fqdn(extPTRAnswer),
|
||||
}},
|
||||
isPrivate: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
pref, extErr := netutil.ExtractReversedAddr(tc.question)
|
||||
require.NoError(t, extErr)
|
||||
|
||||
req := createTestMessageWithType(dns.Fqdn(tc.question), dns.TypePTR)
|
||||
pctx := &proxy.DNSContext{
|
||||
Req: req,
|
||||
IsPrivateClient: tc.isPrivate,
|
||||
}
|
||||
// TODO(e.burkov): Configure the subnet set properly.
|
||||
if netutil.IsLocallyServed(pref.Addr()) {
|
||||
pctx.RequestedPrivateRDNS = pref
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = s.handleDNSRequest(s.dnsProxy, pctx)
|
||||
require.ErrorIs(t, err, tc.wantErr)
|
||||
|
||||
require.NotNil(t, pctx.Res)
|
||||
assert.Equal(t, tc.wantAns, pctx.Res.Answer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ProcessUpstream_localPTR(t *testing.T) {
|
||||
const locDomain = "some.local."
|
||||
const reqAddr = "1.1.168.192.in-addr.arpa."
|
||||
|
|
|
|||
60
internal/dnsforward/requesthandler.go
Normal file
60
internal/dnsforward/requesthandler.go
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
)
|
||||
|
||||
// type check
|
||||
var _ proxy.Handler = (*Server)(nil)
|
||||
|
||||
// ServeDNS implements the [proxy.Handler] interface for [*Server].
|
||||
func (s *Server) ServeDNS(_ *proxy.Proxy, pctx *proxy.DNSContext) (err error) {
|
||||
// TODO(s.chzhen): Pass context.
|
||||
ctx := context.TODO()
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: pctx,
|
||||
result: &filtering.Result{},
|
||||
startTime: time.Now(),
|
||||
}
|
||||
|
||||
type modProcessFunc func(ctx context.Context, dctx *dnsContext) (rc resultCode)
|
||||
|
||||
// Since [*dnsforward.Server] is used as [proxy.Handler], there is no need
|
||||
// for additional index out of range checking in any of the following
|
||||
// functions, because the (*proxy.Proxy).handleDNSRequest method performs it
|
||||
// before calling the appropriate handler.
|
||||
mods := []modProcessFunc{
|
||||
s.processInitial,
|
||||
s.processDDRQuery,
|
||||
s.processDHCPHosts,
|
||||
s.processDHCPAddrs,
|
||||
s.processFilteringBeforeRequest,
|
||||
s.processUpstream,
|
||||
s.processFilteringAfterResponse,
|
||||
s.ipset.process,
|
||||
s.processQueryLogsAndStats,
|
||||
}
|
||||
for _, process := range mods {
|
||||
r := process(ctx, dctx)
|
||||
switch r {
|
||||
case resultCodeSuccess:
|
||||
// continue: call the next filter
|
||||
case resultCodeFinish:
|
||||
return nil
|
||||
case resultCodeError:
|
||||
return dctx.err
|
||||
}
|
||||
}
|
||||
|
||||
if pctx.Res != nil {
|
||||
// Some devices require DNS message compression.
|
||||
pctx.Res.Compress = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
345
internal/dnsforward/requesthandler_internal_test.go
Normal file
345
internal/dnsforward/requesthandler_internal_test.go
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
package dnsforward
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServer_ServeDNS(t *testing.T) {
|
||||
rules := `
|
||||
||blocked.domain^
|
||||
@@||allowed.domain^
|
||||
||cname.specific^$dnstype=~CNAME
|
||||
||0.0.0.1^$dnstype=~A
|
||||
||::1^$dnstype=~AAAA
|
||||
0.0.0.0 duplicate.domain
|
||||
0.0.0.0 duplicate.domain
|
||||
0.0.0.0 blocked.by.hostrule
|
||||
`
|
||||
|
||||
forwardConf := ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
Config: Config{
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{
|
||||
Enabled: false,
|
||||
},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
filters := []filtering.Filter{{
|
||||
ID: 0, Data: []byte(rules),
|
||||
}}
|
||||
|
||||
f, err := filtering.New(&filtering.Config{
|
||||
Logger: testLogger,
|
||||
ProtectionEnabled: true,
|
||||
ApplyClientFiltering: applyEmptyClientFiltering,
|
||||
BlockedServices: emptyFilteringBlockedServices(),
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, filters)
|
||||
require.NoError(t, err)
|
||||
f.SetEnabled(true)
|
||||
|
||||
s, err := NewServer(DNSCreateParams{
|
||||
DHCPServer: &testDHCP{
|
||||
OnEnabled: func() (ok bool) { return false },
|
||||
OnHostByIP: func(ip netip.Addr) (_ string) { panic(testutil.UnexpectedCall(ip)) },
|
||||
OnIPByHost: func(host string) (_ netip.Addr) { panic(testutil.UnexpectedCall(host)) },
|
||||
},
|
||||
DNSFilter: f,
|
||||
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||
Logger: testLogger,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &forwardConf)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
&aghtest.Upstream{
|
||||
CName: map[string][]string{
|
||||
"cname.exception.": {"cname.specific."},
|
||||
"should.block.": {"blocked.domain."},
|
||||
"allowed.first.": {"allowed.domain.", "blocked.domain."},
|
||||
"blocked.first.": {"blocked.domain.", "allowed.domain."},
|
||||
},
|
||||
IPv4: map[string][]net.IP{
|
||||
"a.exception.": {{0, 0, 0, 1}},
|
||||
},
|
||||
IPv6: map[string][]net.IP{
|
||||
"aaaa.exception.": {net.ParseIP("::1")},
|
||||
},
|
||||
},
|
||||
}
|
||||
startDeferStop(t, s)
|
||||
|
||||
testCases := []struct {
|
||||
req *dns.Msg
|
||||
name string
|
||||
wantRCode int
|
||||
wantAns []dns.RR
|
||||
}{{
|
||||
req: createTestMessage(aghtest.ReqFQDN),
|
||||
name: "pass",
|
||||
wantRCode: dns.RcodeNameError,
|
||||
wantAns: nil,
|
||||
}, {
|
||||
req: createTestMessage("cname.exception."),
|
||||
name: "cname_exception",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.CNAME{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "cname.exception.",
|
||||
Rrtype: dns.TypeCNAME,
|
||||
},
|
||||
Target: "cname.specific.",
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("should.block."),
|
||||
name: "blocked_by_cname",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "should.block.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("a.exception."),
|
||||
name: "a_exception",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "a.exception.",
|
||||
Rrtype: dns.TypeA,
|
||||
},
|
||||
A: net.IP{0, 0, 0, 1},
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessageWithType("aaaa.exception.", dns.TypeAAAA),
|
||||
name: "aaaa_exception",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "aaaa.exception.",
|
||||
Rrtype: dns.TypeAAAA,
|
||||
},
|
||||
AAAA: net.ParseIP("::1"),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("allowed.first."),
|
||||
name: "allowed_first",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "allowed.first.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("blocked.first."),
|
||||
name: "blocked_first",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "blocked.first.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessage("duplicate.domain."),
|
||||
name: "duplicate_domain",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: "duplicate.domain.",
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: netutil.IPv4Zero(),
|
||||
}},
|
||||
}, {
|
||||
req: createTestMessageWithType("blocked.domain.", dns.TypeHTTPS),
|
||||
name: "blocked_https_req",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: nil,
|
||||
}, {
|
||||
req: createTestMessageWithType("blocked.by.hostrule.", dns.TypeHTTPS),
|
||||
name: "blocked_host_rule_https_req",
|
||||
wantRCode: dns.RcodeSuccess,
|
||||
wantAns: nil,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
dctx := &proxy.DNSContext{
|
||||
Proto: proxy.ProtoUDP,
|
||||
Req: tc.req,
|
||||
Addr: testClientAddrPort,
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = s.ServeDNS(nil, dctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, dctx.Res)
|
||||
|
||||
assert.Equal(t, tc.wantRCode, dctx.Res.Rcode)
|
||||
assert.Equal(t, tc.wantAns, dctx.Res.Answer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Rewrite this test to use the whole server instead of just
|
||||
// testing the [Handle] method. See comment on "from_external_for_local" test
|
||||
// case.
|
||||
func TestServer_ServeDNS_restrictLocal(t *testing.T) {
|
||||
intAddr := netip.MustParseAddr("192.168.1.1")
|
||||
intPTRQuestion, err := netutil.IPToReversedAddr(intAddr.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
extAddr := netip.MustParseAddr("254.253.252.1")
|
||||
extPTRQuestion, err := netutil.IPToReversedAddr(extAddr.AsSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
const (
|
||||
extPTRAnswer = "host1.example.net."
|
||||
intPTRAnswer = "some.local-client."
|
||||
)
|
||||
|
||||
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
resp := cmp.Or(
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer),
|
||||
aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer),
|
||||
(&dns.Msg{}).SetRcode(req, dns.RcodeNameError),
|
||||
)
|
||||
|
||||
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
|
||||
})
|
||||
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||||
|
||||
s := createTestServer(t, &filtering.Config{
|
||||
BlockingMode: filtering.BlockingModeDefault,
|
||||
}, ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
// TODO(s.chzhen): Add tests where EDNSClientSubnet.Enabled is true.
|
||||
// Improve Config declaration for tests.
|
||||
Config: Config{
|
||||
UpstreamDNS: []string{localUpsAddr},
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
UsePrivateRDNS: true,
|
||||
LocalPTRResolvers: []string{localUpsAddr},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
startDeferStop(t, s)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
question string
|
||||
wantErr error
|
||||
wantAns []dns.RR
|
||||
isPrivate bool
|
||||
}{{
|
||||
name: "from_local_for_external",
|
||||
question: extPTRQuestion,
|
||||
wantErr: nil,
|
||||
wantAns: []dns.RR{&dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(extPTRQuestion),
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
Rdlength: uint16(len(extPTRAnswer) + 1),
|
||||
},
|
||||
Ptr: dns.Fqdn(extPTRAnswer),
|
||||
}},
|
||||
isPrivate: true,
|
||||
}, {
|
||||
// In theory this case is not reproducible because [proxy.Proxy] should
|
||||
// respond to such queries with NXDOMAIN before they reach
|
||||
// [Server.Handle].
|
||||
name: "from_external_for_local",
|
||||
question: intPTRQuestion,
|
||||
wantErr: upstream.ErrNoUpstreams,
|
||||
wantAns: nil,
|
||||
isPrivate: false,
|
||||
}, {
|
||||
name: "from_local_for_local",
|
||||
question: intPTRQuestion,
|
||||
wantErr: nil,
|
||||
wantAns: []dns.RR{&dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(intPTRQuestion),
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
Rdlength: uint16(len(intPTRAnswer) + 1),
|
||||
},
|
||||
Ptr: dns.Fqdn(intPTRAnswer),
|
||||
}},
|
||||
isPrivate: true,
|
||||
}, {
|
||||
name: "from_external_for_external",
|
||||
question: extPTRQuestion,
|
||||
wantErr: nil,
|
||||
wantAns: []dns.RR{&dns.PTR{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dns.Fqdn(extPTRQuestion),
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 60,
|
||||
Rdlength: uint16(len(extPTRAnswer) + 1),
|
||||
},
|
||||
Ptr: dns.Fqdn(extPTRAnswer),
|
||||
}},
|
||||
isPrivate: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
pref, extErr := netutil.ExtractReversedAddr(tc.question)
|
||||
require.NoError(t, extErr)
|
||||
|
||||
req := createTestMessageWithType(dns.Fqdn(tc.question), dns.TypePTR)
|
||||
pctx := &proxy.DNSContext{
|
||||
Req: req,
|
||||
IsPrivateClient: tc.isPrivate,
|
||||
}
|
||||
// TODO(e.burkov): Configure the subnet set properly.
|
||||
if netutil.IsLocallyServed(pref.Addr()) {
|
||||
pctx.RequestedPrivateRDNS = pref
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = s.ServeDNS(s.dnsProxy, pctx)
|
||||
require.ErrorIs(t, err, tc.wantErr)
|
||||
|
||||
require.NotNil(t, pctx.Res)
|
||||
assert.Equal(t, tc.wantAns, pctx.Res.Answer)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -17,9 +17,11 @@ import (
|
|||
"github.com/AdguardTeam/AdGuardHome/internal/aghslog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/ratelimit"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
|
||||
// Service is the AdGuard Home DNS service. A nil *Service is a valid
|
||||
|
|
@ -62,11 +64,15 @@ func New(c *Config) (svc *Service, err error) {
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
rlMw, err := newRatelimitMw(c.Logger, c.Ratelimit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ratelimit middleware: %w", err)
|
||||
}
|
||||
|
||||
svc = &Service{
|
||||
logger: c.Logger,
|
||||
proxyConf: &proxy.Config{
|
||||
UpstreamMode: c.UpstreamMode,
|
||||
Ratelimit: c.Ratelimit,
|
||||
DNS64Prefs: c.DNS64Prefixes,
|
||||
CacheSizeBytes: c.CacheSize,
|
||||
CacheEnabled: c.CacheEnabled,
|
||||
|
|
@ -97,14 +103,14 @@ func New(c *Config) (svc *Service, err error) {
|
|||
UpstreamConfig: &proxy.UpstreamConfig{
|
||||
Upstreams: upstreams,
|
||||
},
|
||||
UDPListenAddr: udpAddrs(c.Addresses),
|
||||
TCPListenAddr: tcpAddrs(c.Addresses),
|
||||
UpstreamMode: svc.proxyConf.UpstreamMode,
|
||||
Ratelimit: svc.proxyConf.Ratelimit,
|
||||
DNS64Prefs: svc.proxyConf.DNS64Prefs,
|
||||
CacheEnabled: svc.proxyConf.CacheEnabled,
|
||||
RefuseAny: svc.proxyConf.RefuseAny,
|
||||
UseDNS64: svc.proxyConf.UseDNS64,
|
||||
UDPListenAddr: udpAddrs(c.Addresses),
|
||||
TCPListenAddr: tcpAddrs(c.Addresses),
|
||||
UpstreamMode: svc.proxyConf.UpstreamMode,
|
||||
RequestHandler: rlMw.Wrap(proxy.DefaultHandler{}),
|
||||
DNS64Prefs: svc.proxyConf.DNS64Prefs,
|
||||
CacheEnabled: svc.proxyConf.CacheEnabled,
|
||||
RefuseAny: svc.proxyConf.RefuseAny,
|
||||
UseDNS64: svc.proxyConf.UseDNS64,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("proxy: %w", err)
|
||||
|
|
@ -113,6 +119,26 @@ func New(c *Config) (svc *Service, err error) {
|
|||
return svc, nil
|
||||
}
|
||||
|
||||
// newRatelimitMw returns the ratelimit middleware. In case of invalid
|
||||
// ratelimit configuration returns an error. l must not be nil.
|
||||
func newRatelimitMw(l *slog.Logger, limit int) (mw proxy.Middleware, err error) {
|
||||
if limit == 0 {
|
||||
return proxy.MiddlewareFunc(proxy.PassThrough), nil
|
||||
}
|
||||
|
||||
rlConf := &ratelimit.Config{
|
||||
Logger: l.With(slogutil.KeyPrefix, "ratelimit"),
|
||||
Ratelimit: uint(limit),
|
||||
SubnetLenIPv4: netutil.IPv4BitLen,
|
||||
SubnetLenIPv6: netutil.IPv6BitLen,
|
||||
}
|
||||
if err = rlConf.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid configuration: %w", err)
|
||||
}
|
||||
|
||||
return ratelimit.NewMiddleware(rlConf), nil
|
||||
}
|
||||
|
||||
// addressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It
|
||||
// accepts a slice of addresses and other upstream parameters, and returns a
|
||||
// slice of upstreams. logger must not be nil.
|
||||
|
|
@ -245,6 +271,7 @@ func (svc *Service) Config() (c *Config) {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO(d.kolyshev): Fill ratelimit.
|
||||
c = &Config{
|
||||
Logger: svc.logger,
|
||||
UpstreamMode: svc.proxyConf.UpstreamMode,
|
||||
|
|
@ -254,7 +281,6 @@ func (svc *Service) Config() (c *Config) {
|
|||
DNS64Prefixes: svc.proxyConf.DNS64Prefs,
|
||||
UpstreamTimeout: svc.upstreamTimeout,
|
||||
CacheSize: svc.proxyConf.CacheSizeBytes,
|
||||
Ratelimit: svc.proxyConf.Ratelimit,
|
||||
BootstrapPreferIPv6: svc.bootstrapPreferIPv6,
|
||||
CacheEnabled: svc.proxyConf.CacheEnabled,
|
||||
RefuseAny: svc.proxyConf.RefuseAny,
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ func TestService_HandleGetSettingsAll(t *testing.T) {
|
|||
BootstrapServers: []string{"94.140.14.140", "94.140.14.141"},
|
||||
UpstreamServers: []string{"94.140.14.14", "1.1.1.1"},
|
||||
UpstreamTimeout: aghhttp.JSONDuration(1 * time.Second),
|
||||
Ratelimit: 100,
|
||||
CacheSize: 1048576,
|
||||
BootstrapPreferIPv6: true,
|
||||
RefuseAny: true,
|
||||
|
|
@ -45,7 +44,6 @@ func TestService_HandleGetSettingsAll(t *testing.T) {
|
|||
BootstrapServers: wantDNS.BootstrapServers,
|
||||
UpstreamTimeout: time.Duration(wantDNS.UpstreamTimeout),
|
||||
CacheSize: 1048576,
|
||||
Ratelimit: 100,
|
||||
BootstrapPreferIPv6: true,
|
||||
RefuseAny: true,
|
||||
UseDNS64: true,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue