Pull request 2657: AGDNS-3720-add-tls-config-provider

Squashed commit of the following:

commit 1748706d70718ae68d64cb0b26d30be5c3635a8d
Author: Maksim Kazantsev <m.kazantsev@adguard.com>
Date:   Wed May 20 12:21:56 2026 +0300

    all: imp docs;

commit 90f314adeadd167765a0a86493877f042f4b9805
Author: Maksim Kazantsev <m.kazantsev@adguard.com>
Date:   Tue May 19 20:02:09 2026 +0300

    home: imp code;

commit 76265a91fd138ee344acc644bc3a8cfbb0c458f9
Author: Maksim Kazantsev <m.kazantsev@adguard.com>
Date:   Tue May 19 19:39:35 2026 +0300

    all: add tls config provider; imp tests;
This commit is contained in:
Maksim Kazantsev 2026-05-20 11:34:50 +00:00
parent 6c5e88e29e
commit 3144e1856a
12 changed files with 272 additions and 148 deletions

View file

@ -2,6 +2,8 @@ package aghtest
import (
"context"
"crypto/tls"
"crypto/x509"
"net/http"
"net/netip"
"time"
@ -9,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
nextagh "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
@ -198,3 +201,26 @@ var _ aghhttp.Registrar = (*Registrar)(nil)
func (m *Registrar) Register(method, path string, h http.HandlerFunc) {
m.OnRegister(method, path, h)
}
// TLSConfigProvider is a fake [aghtls.TLSConfigProvider] implementation for
// tests.
// TODO(m.kazantsev): Use in tests.
type TLSConfigProvider struct {
OnTLSConfig func() (conf *tls.Config)
OnRootCAs func() (cert *x509.CertPool)
}
// type check
var _ aghtls.TLSConfigProvider = (*TLSConfigProvider)(nil)
// TLSConfig implements the [aghtls.TLSConfigProvider] interface for
// *TLSConfigProvider.
func (t *TLSConfigProvider) TLSConfig() (conf *tls.Config) {
return t.OnTLSConfig()
}
// RootCAs implements the [aghtls.TLSConfigProvider] interface for
// *TLSConfigProvider.
func (t *TLSConfigProvider) RootCAs() (pool *x509.CertPool) {
return t.OnRootCAs()
}

View file

@ -0,0 +1,39 @@
package aghtls
import (
"crypto/tls"
"crypto/x509"
)
// TLSConfigProvider provides TLS configuration to consumers. Implementations
// must be safe for concurrent use.
//
// TODO(m.kazantsev): Merge with the Manager interface.
// TODO(m.kazantsev): Add at least one real implementation.
type TLSConfigProvider interface {
// TLSConfig returns a clone of the current TLS configuration. conf
// provides its certificates via GetConfigForClient method.
TLSConfig() (conf *tls.Config)
// RootCAs returns the current root CA pool.
RootCAs() (root *x509.CertPool)
}
// type check
var _ TLSConfigProvider = EmptyTLSConfigProvider{}
// EmptyTLSConfigProvider is the implementation of the [TLSConfigProvider]
// interface that does nothing.
type EmptyTLSConfigProvider struct{}
// TLSConfig implements the [TLSConfigProvider] interface for
// EmptyTLSConfigProvider. It always returns nil.
func (EmptyTLSConfigProvider) TLSConfig() (conf *tls.Config) {
return nil
}
// RootCAs implements the [TLSConfigProvider] interface for
// EmptyTLSConfigProvider. It always returns nil.
func (EmptyTLSConfigProvider) RootCAs() (root *x509.CertPool) {
return nil
}

View file

@ -544,7 +544,12 @@ func (conf *ServerConfig) loadUpstreams(
upstreams = stringutil.SplitTrimmed(string(data), "\n")
l.DebugContext(ctx, "got upstreams", "number", len(upstreams), "filename", conf.UpstreamDNSFileName)
l.DebugContext(
ctx,
"got upstreams",
"number", len(upstreams),
"filename", conf.UpstreamDNSFileName,
)
return stringutil.FilterOut(upstreams, aghnet.IsCommentOrEmpty), nil
}
@ -652,7 +657,10 @@ func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error)
// ourAddrsSet returns an addrPortSet that contains all the configured listening
// addresses. l must not be nil.
func (conf *ServerConfig) ourAddrsSet(ctx context.Context, l *slog.Logger) (m addrPortSet, err error) {
func (conf *ServerConfig) ourAddrsSet(
ctx context.Context,
l *slog.Logger,
) (m addrPortSet, err error) {
addrs, unspecPorts := conf.collectDNSAddrs()
switch {
case addrs.Len() == 0:
@ -781,8 +789,9 @@ func anyNameMatches(dnsNames []string, sni string) (ok bool) {
return false
}
// Called by 'tls' package when Client Hello is received
// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake.
// onGetCertificate is called by [tls] package when Client Hello is received. If
// the server name (from SNI) supplied by client is incorrect - we terminate the
// ongoing TLS handshake.
func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
if s.conf.TLSConf.StrictSNICheck && !anyNameMatches(s.dnsNames, ch.ServerName) {
// TODO(s.chzhen): Pass context.
@ -798,8 +807,8 @@ func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, er
return s.conf.TLSConf.Cert, nil
}
// preparePlain prepares the plain-DNS configuration for the DNS proxy.
// preparePlain assumes that prepareTLS has already been called.
// preparePlain prepares the plain-DNS configuration for the DNS proxy. The
// method assumes that prepareTLS has already been called.
func (s *Server) preparePlain(ctx context.Context, proxyConf *proxy.Config) (err error) {
if s.conf.ServePlainDNS {
proxyConf.UDPListenAddr = s.conf.UDPListenAddrs

View file

@ -90,7 +90,12 @@ func newUpstreamConfigValidator(
// collectErrResults parses err and returns parsing results containing the
// original upstream configuration line and the corresponding error. err can be
// nil. l must not be nil.
func collectErrResults(ctx context.Context, l *slog.Logger, lines []string, err error) (results []*parseResult) {
func collectErrResults(
ctx context.Context,
l *slog.Logger,
lines []string,
err error,
) (results []*parseResult) {
if err == nil {
return nil
}
@ -132,7 +137,7 @@ func collectErrResults(ctx context.Context, l *slog.Logger, lines []string, err
}
// insertConfResults parses conf and inserts the upstream result into results.
// It can insert multiple results as well as none.
// It can insert multiple results as well as none. conf must not be nil.
func insertConfResults(conf *proxy.UpstreamConfig, results map[string]*upstreamResult) {
insertListResults(conf.Upstreams, results, false)

View file

@ -884,8 +884,8 @@ func (c *configuration) write(
}
if tlsMgr != nil {
tlsConf := tlsMgr.config()
config.TLS = *tlsConf
extTLSConf := tlsMgr.extendedTLSConfig()
config.TLS = *extTLSConf
}
if globalContext.stats != nil {

View file

@ -178,6 +178,10 @@ type versionResponse struct {
Disabled bool `json:"disabled"`
}
// maxPrivilegedPort is the maximum port number. This only applies to Unix, as
// on Windows, [aghnet.CanBindPrivilegedPorts] always returns `true`, `nil`.
const maxPrivilegedPort = 1024
// setAllowedToAutoUpdate sets CanAutoUpdate to true if AdGuard Home is actually
// allowed to perform an automatic update by the OS. l and tlsMgr must not be
// nil.
@ -191,9 +195,9 @@ func (vr *versionResponse) setAllowedToAutoUpdate(
}
canUpdate := true
if tlsConfUsesPrivilegedPorts(tlsMgr.config()) ||
config.HTTPConfig.Address.Port() < 1024 ||
config.DNS.Port < 1024 {
if tlsConfUsesPrivilegedPorts(tlsMgr.extendedTLSConfig()) ||
config.HTTPConfig.Address.Port() < maxPrivilegedPort ||
config.DNS.Port < maxPrivilegedPort {
canUpdate, err = aghnet.CanBindPrivilegedPorts(ctx, l)
if err != nil {
return fmt.Errorf("checking ability to bind privileged ports: %w", err)
@ -206,9 +210,11 @@ func (vr *versionResponse) setAllowedToAutoUpdate(
}
// tlsConfUsesPrivilegedPorts returns true if the provided TLS configuration
// indicates that privileged ports are used.
// indicates that privileged ports are used. c must be valid
func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024)
return c.Enabled && (c.PortHTTPS < maxPrivilegedPort ||
c.PortDNSOverTLS < maxPrivilegedPort ||
c.PortDNSOverQUIC < maxPrivilegedPort)
}
// finishUpdate completes an update procedure. It is intended to be used as a

View file

@ -170,7 +170,7 @@ func initDNSServer(
dnsConf, err := newServerConfig(
&config.DNS,
config.Clients.Sources,
tlsMgr.config(),
tlsMgr.extendedTLSConfig(),
config.HTTPConfig.DoH,
tlsMgr,
httpReg,
@ -212,7 +212,8 @@ func parseSubnetSet(nets []netutil.Prefix) (s netutil.SubnetSet) {
}
}
func isRunning() bool {
// isRunning checks whether the DNS server is running.
func isRunning() (ok bool) {
return globalContext.dnsServer != nil && globalContext.dnsServer.IsRunning()
}
@ -262,7 +263,7 @@ func ipsToUDPAddrs(ips []netip.Addr, port uint16) (udpAddrs []*net.UDPAddr) {
func newServerConfig(
dnsConf *dnsConfig,
clientSrcConf *clientSourcesConfig,
tlsConf *tlsConfigSettings,
extTLSConf *tlsConfigSettings,
dohConf *doHConfig,
tlsMgr *tlsManager,
httpReg aghhttp.Registrar,
@ -274,7 +275,7 @@ func newServerConfig(
fwdConf := dnsConf.Config
fwdConf.ClientsContainer = clientsContainer
intTLSConf, err := newDNSTLSConfig(tlsConf, hosts, dohConf.InsecureEnabled)
intTLSConf, err := newDNSTLSConfig(extTLSConf, hosts, dohConf.InsecureEnabled)
if err != nil {
return nil, fmt.Errorf("constructing tls config: %w", err)
}
@ -322,19 +323,19 @@ func newServerConfig(
}
// newDNSTLSConfig converts values from the configuration file into the internal
// TLS settings for the DNS server. conf must not be nil.
// TLS settings for the DNS server. extTLSConf must not be nil.
func newDNSTLSConfig(
conf *tlsConfigSettings,
extTLSConf *tlsConfigSettings,
addrs []netip.Addr,
allowUnencryptedDoH bool,
) (dnsConf *dnsforward.TLSConfig, err error) {
if !conf.Enabled {
if !extTLSConf.Enabled {
return &dnsforward.TLSConfig{}, nil
}
// TODO(e.burkov): Add tracking for DNSCrypt configuration file changes to
// the [aghtls.Manager].
dnsCryptConf, err := newDNSCryptConfig(conf, addrs)
dnsCryptConf, err := newDNSCryptConfig(extTLSConf, addrs)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
@ -342,23 +343,23 @@ func newDNSTLSConfig(
dnsConf = &dnsforward.TLSConfig{
DNSCryptConf: dnsCryptConf,
ServerName: conf.ServerName,
StrictSNICheck: conf.StrictSNICheck,
ServerName: extTLSConf.ServerName,
StrictSNICheck: extTLSConf.StrictSNICheck,
}
if conf.PortHTTPS != 0 {
dnsConf.HTTPSListenAddrs = ipsToAddrPorts(addrs, conf.PortHTTPS)
if extTLSConf.PortHTTPS != 0 {
dnsConf.HTTPSListenAddrs = ipsToAddrPorts(addrs, extTLSConf.PortHTTPS)
}
if conf.PortDNSOverTLS != 0 {
dnsConf.TLSListenAddrs = ipsToTCPAddrs(addrs, conf.PortDNSOverTLS)
if extTLSConf.PortDNSOverTLS != 0 {
dnsConf.TLSListenAddrs = ipsToTCPAddrs(addrs, extTLSConf.PortDNSOverTLS)
}
if conf.PortDNSOverQUIC != 0 {
dnsConf.QUICListenAddrs = ipsToUDPAddrs(addrs, conf.PortDNSOverQUIC)
if extTLSConf.PortDNSOverQUIC != 0 {
dnsConf.QUICListenAddrs = ipsToUDPAddrs(addrs, extTLSConf.PortDNSOverQUIC)
}
cert, err := tls.X509KeyPair(conf.CertificateChainData, conf.PrivateKeyData)
cert, err := tls.X509KeyPair(extTLSConf.CertificateChainData, extTLSConf.PrivateKeyData)
if err != nil {
err = fmt.Errorf("parsing tls key pair: %w", err)
if allowUnencryptedDoH || dnsCryptConf != nil {
@ -378,20 +379,20 @@ func newDNSTLSConfig(
}
// newDNSCryptConfig converts values from the configuration file into the
// internal DNSCrypt settings for the DNS server. conf must not be nil.
// internal DNSCrypt settings for the DNS server. extTLSConf must not be nil.
func newDNSCryptConfig(
conf *tlsConfigSettings,
extTLSConf *tlsConfigSettings,
addrs []netip.Addr,
) (dnsCryptConf *dnsforward.DNSCryptConfig, err error) {
if conf.PortDNSCrypt == 0 {
if extTLSConf.PortDNSCrypt == 0 {
return nil, nil
}
if conf.DNSCryptConfigFile == "" {
if extTLSConf.DNSCryptConfigFile == "" {
return nil, fmt.Errorf("dnscrypt_config_file: %w", errors.ErrEmptyValue)
}
f, err := os.Open(conf.DNSCryptConfigFile)
f, err := os.Open(extTLSConf.DNSCryptConfigFile)
if err != nil {
return nil, fmt.Errorf("opening dnscrypt config: %w", err)
}
@ -410,8 +411,8 @@ func newDNSCryptConfig(
return &dnsforward.DNSCryptConfig{
ResolverCert: cert,
UDPListenAddrs: ipsToUDPAddrs(addrs, conf.PortDNSCrypt),
TCPListenAddrs: ipsToTCPAddrs(addrs, conf.PortDNSCrypt),
UDPListenAddrs: ipsToUDPAddrs(addrs, extTLSConf.PortDNSCrypt),
TCPListenAddrs: ipsToTCPAddrs(addrs, extTLSConf.PortDNSCrypt),
ProviderName: rc.ProviderName,
}, nil
}
@ -426,16 +427,16 @@ type dnsEncryption struct {
// getDNSEncryption returns the TLS encryption addresses that AdGuard Home
// listens on. tlsMgr must not be nil.
func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) {
tlsConf := tlsMgr.config()
extTLSConf := tlsMgr.extendedTLSConfig()
if !tlsConf.Enabled || len(tlsConf.ServerName) == 0 {
if !extTLSConf.Enabled || extTLSConf.ServerName == "" {
return dnsEncryption{}
}
hostname := tlsConf.ServerName
if tlsConf.PortHTTPS != 0 {
hostname := extTLSConf.ServerName
if extTLSConf.PortHTTPS != 0 {
addr := hostname
if p := tlsConf.PortHTTPS; p != defaultPortHTTPS {
if p := extTLSConf.PortHTTPS; p != defaultPortHTTPS {
addr = netutil.JoinHostPort(addr, p)
}
@ -446,14 +447,14 @@ func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) {
}).String()
}
if p := tlsConf.PortDNSOverTLS; p != 0 {
if p := extTLSConf.PortDNSOverTLS; p != 0 {
de.tls = (&url.URL{
Scheme: "tls",
Host: netutil.JoinHostPort(hostname, p),
}).String()
}
if p := tlsConf.PortDNSOverQUIC; p != 0 {
if p := extTLSConf.PortDNSOverQUIC; p != 0 {
de.quic = (&url.URL{
Scheme: "quic",
Host: netutil.JoinHostPort(hostname, p),
@ -463,7 +464,9 @@ func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) {
return de
}
func startDNSServer() error {
// startDNSServer starts the DNS server, clients container, filters, stats and
// the query log.
func startDNSServer() (err error) {
config.RLock()
defer config.RUnlock()
@ -475,7 +478,7 @@ func startDNSServer() error {
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
err := globalContext.clients.Start(ctx)
err = globalContext.clients.Start(ctx)
if err != nil {
return fmt.Errorf("starting clients container: %w", err)
}

View file

@ -1269,22 +1269,24 @@ func printWebAddrs(ctx context.Context, l *slog.Logger, proto, addr string, port
}
// printHTTPAddresses prints the IP addresses which user can use to access the
// admin interface. proto is either schemeHTTP or schemeHTTPS.
// admin interface. proto is either [urlutil.SchemeHTTPS] or
// [urlutil.SchemeHTTP]. l must not be nil. If proto is [urlutil.SchemeHTTPS],
// then tlsMgr must not be nil.
//
// TODO(s.chzhen): Implement separate functions for HTTP and HTTPS.
func printHTTPAddresses(ctx context.Context, l *slog.Logger, proto string, tlsMgr *tlsManager) {
var tlsConf *tlsConfigSettings
var extTLSConf *tlsConfigSettings
if tlsMgr != nil {
tlsConf = tlsMgr.config()
extTLSConf = tlsMgr.extendedTLSConfig()
}
port := config.HTTPConfig.Address.Port()
if proto == urlutil.SchemeHTTPS {
port = tlsConf.PortHTTPS
port = extTLSConf.PortHTTPS
}
if proto == urlutil.SchemeHTTPS && tlsConf.ServerName != "" {
printWebAddrs(ctx, l, proto, tlsConf.ServerName, tlsConf.PortHTTPS)
if proto == urlutil.SchemeHTTPS && extTLSConf.ServerName != "" {
printWebAddrs(ctx, l, proto, extTLSConf.ServerName, extTLSConf.PortHTTPS)
return
}

14
internal/home/testdata/cert.pem vendored Normal file
View file

@ -0,0 +1,14 @@
-----BEGIN CERTIFICATE-----
MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV
BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3
MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV
MBMGA1UEAwwMQWRHdWFyZCBIb21lMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB
gQCwvwUnPJiOvLcOaWmGu6Y68ksFr13nrXBcsDlhxlXy8PaohVi3XxEmt2OrVjKW
QFw/bdV4fZ9tdWFAVRRkgeGbIZzP7YBD1Ore/O5SQ+DbCCEafvjJCcXQIrTeKFE6
i9G3aSMHs0Pwq2LgV8U5mYotLrvyFiE8QPInJbDDMpaFYwIDAQABo1MwUTAdBgNV
HQ4EFgQUdLUmQpEqrhn4eKO029jYd2AAZEQwHwYDVR0jBBgwFoAUdLUmQpEqrhn4
eKO029jYd2AAZEQwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQB8
LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna
b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq
Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ==
-----END CERTIFICATE-----

16
internal/home/testdata/key.pem vendored Normal file
View file

@ -0,0 +1,16 @@
-----BEGIN PRIVATE KEY-----
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p
aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV
FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX
xTmZii0uu/IWITxA8iclsMMyloVjAgMBAAECgYEAmjzoG1h27UDkIlB9BVWl95TP
QVPLB81D267xNFDnWk1Lgr5zL/pnNjkdYjyjgpkBp1yKyE4gHV4skv5sAFWTcOCU
QCgfPfUn/rDFcxVzAdJVWAa/CpJNaZgjTPR8NTGU+Ztod+wfBESNCP5tbnuw0GbL
MuwdLQJGbzeJYpsNysECQQDfFHYoRNfgxHwMbX24GCoNZIgk12uDmGTA9CS5E+72
9t3V1y4CfXxSkfhqNbd5RWrUBRLEw9BKofBS7L9NMDKDAkEAytQoIueE1vqEAaRg
a3A1YDUekKesU5wKfKfKlXvNgB7Hwh4HuvoQS9RCvVhf/60Dvq8KSu6hSjkFRquj
FQ5roQJBAMwKwyiCD5MfJPeZDmzcbVpiocRQ5Z4wPbffl9dRTDnIA5AciZDthlFg
An/jMjZSMCxNl6UyFcqt5Et1EGVhuFECQQCZLXxaT+qcyHjlHJTMzuMgkz1QFbEp
O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG
O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU
kXS9jgARhhiWXJrk
-----END PRIVATE KEY-----

View file

@ -36,7 +36,7 @@ type tlsManager struct {
// logger is used for logging the operation of the TLS Manager.
logger *slog.Logger
// mu protects status, certLastMod, conf, and servePlainDNS.
// mu protects status, certLastMod, extTLSConf, and servePlainDNS.
mu *sync.Mutex
// status is the current status of the configuration. It is never nil.
@ -54,8 +54,12 @@ type tlsManager struct {
// Resolve it.
web *webAPI
// conf contains the TLS configuration settings. It must not be nil.
conf *tlsConfigSettings
// extTLSConf contains extended TLS configuration settings. It must not be
// nil.
// TODO(m.kazantsev): Add a field of a type of [*tls.Config] which will
// represent the TLS settings. This is why these settings are called
// 'extended'.
extTLSConf *tlsConfigSettings
// confModifier is used to update the global configuration.
confModifier agh.ConfigModifier
@ -111,7 +115,7 @@ func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager,
httpReg: conf.httpReg,
manager: conf.manager,
status: &tlsConfigStatus{},
conf: &conf.tlsSettings,
extTLSConf: &conf.tlsSettings,
servePlainDNS: conf.servePlainDNS,
}
@ -133,21 +137,21 @@ func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager,
m.mu.Lock()
defer m.mu.Unlock()
if !m.conf.Enabled {
if !m.extTLSConf.Enabled {
return m, nil
}
err = m.manager.Set(ctx, aghtls.TLSPair{
CertPath: m.conf.CertificatePath,
KeyPath: m.conf.PrivateKeyPath,
CertPath: m.extTLSConf.CertificatePath,
KeyPath: m.extTLSConf.PrivateKeyPath,
})
if err != nil {
m.logger.ErrorContext(ctx, "setting tls files", slogutil.KeyError, err)
}
err = m.loadTLSConfig(ctx, m.conf, m.status)
err = m.loadTLSConfig(ctx, m.extTLSConf, m.status)
if err != nil {
m.conf.Enabled = false
m.extTLSConf.Enabled = false
return m, err
}
@ -166,22 +170,22 @@ func (m *tlsManager) setWebAPI(webAPI *webAPI) {
m.web = webAPI
}
// config returns a deep copy of the stored TLS configuration.
func (m *tlsManager) config() (conf *tlsConfigSettings) {
// extendedTLSConfig returns a deep copy of the stored TLS configuration.
func (m *tlsManager) extendedTLSConfig() (extTLSConf *tlsConfigSettings) {
m.mu.Lock()
defer m.mu.Unlock()
return m.conf.clone()
return m.extTLSConf.clone()
}
// setCertFileTime sets [tlsManager.certLastMod] from the certificate. If there
// are errors, setCertFileTime logs them. m.mu is expected to be locked.
func (m *tlsManager) setCertFileTime(ctx context.Context) {
if len(m.conf.CertificatePath) == 0 {
if len(m.extTLSConf.CertificatePath) == 0 {
return
}
fi, err := os.Stat(m.conf.CertificatePath)
fi, err := os.Stat(m.extTLSConf.CertificatePath)
if err != nil {
m.logger.ErrorContext(ctx, "looking up certificate path", slogutil.KeyError, err)
@ -203,7 +207,7 @@ func (m *tlsManager) start(ctx context.Context) {
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request.
m.web.tlsConfigChanged(context.Background(), m.conf)
m.web.tlsConfigChanged(context.Background(), m.extTLSConf)
go m.handleCertFileChange(ctx)
}
@ -235,7 +239,7 @@ func (m *tlsManager) reload(ctx context.Context) {
m.mu.Lock()
defer m.mu.Unlock()
tlsConfPtr := m.conf
tlsConfPtr := m.extTLSConf
if !tlsConfPtr.Enabled || len(tlsConfPtr.CertificatePath) == 0 {
return
@ -267,7 +271,7 @@ func (m *tlsManager) reload(ctx context.Context) {
return
}
m.conf = &tlsConf
m.extTLSConf = &tlsConf
m.status = status
m.certLastMod = fi.ModTime().UTC()
@ -280,7 +284,7 @@ func (m *tlsManager) reload(ctx context.Context) {
// The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current
// request.
m.web.tlsConfigChanged(context.Background(), m.conf)
m.web.tlsConfigChanged(context.Background(), m.extTLSConf)
}
// reconfigureDNSServer updates the DNS server configuration using the stored
@ -289,7 +293,7 @@ func (m *tlsManager) reconfigureDNSServer(ctx context.Context) (err error) {
newConf, err := newServerConfig(
&config.DNS,
config.Clients.Sources,
m.conf,
m.extTLSConf,
config.HTTPConfig.DoH,
m,
m.httpReg,
@ -314,7 +318,7 @@ func (m *tlsManager) reconfigureDNSServer(ctx context.Context) (err error) {
// set in status.WarningValidation.
func (m *tlsManager) loadTLSConfig(
ctx context.Context,
tlsConf *tlsConfigSettings,
extTLSConf *tlsConfigSettings,
status *tlsConfigStatus,
) (err error) {
defer func() {
@ -327,13 +331,13 @@ func (m *tlsManager) loadTLSConfig(
}
}()
err = loadCertificateChainData(tlsConf)
err = loadCertificateChainData(extTLSConf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
err = loadPrivateKeyData(tlsConf)
err = loadPrivateKeyData(extTLSConf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
@ -342,9 +346,9 @@ func (m *tlsManager) loadTLSConfig(
err = m.validateCertificates(
ctx,
status,
tlsConf.CertificateChainData,
tlsConf.PrivateKeyData,
tlsConf.ServerName,
extTLSConf.CertificateChainData,
extTLSConf.PrivateKeyData,
extTLSConf.ServerName,
)
return errors.Annotate(err, "validating certificate pair: %w")
@ -353,15 +357,15 @@ func (m *tlsManager) loadTLSConfig(
// loadCertificateChainData loads PEM-encoded certificates chain data to the
// TLS configuration. tlsConf must be not nil. tlsConf.CertificateChainData
// struct field will be modified in case tlsConfig.CertificatePath is not an
// empty string.
func loadCertificateChainData(tlsConf *tlsConfigSettings) (err error) {
tlsConf.CertificateChainData = []byte(tlsConf.CertificateChain)
if tlsConf.CertificatePath != "" {
if tlsConf.CertificateChain != "" {
// empty string. extTLSConf must not be nil.
func loadCertificateChainData(extTLSConf *tlsConfigSettings) (err error) {
extTLSConf.CertificateChainData = []byte(extTLSConf.CertificateChain)
if extTLSConf.CertificatePath != "" {
if extTLSConf.CertificateChain != "" {
return errors.Error("certificate data and file can't be set together")
}
tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath)
extTLSConf.CertificateChainData, err = os.ReadFile(extTLSConf.CertificatePath)
if err != nil {
return fmt.Errorf("reading cert file: %w", err)
}
@ -373,14 +377,15 @@ func loadCertificateChainData(tlsConf *tlsConfigSettings) (err error) {
// loadPrivateKeyData loads PEM-encoded private key data to the TLS
// configuration. tlsConf must be not nil. tlsConf.PrivateKeyData struct field
// will be modified in case tlsConfig.PrivateKeyPath is not an empty string.
func loadPrivateKeyData(tlsConf *tlsConfigSettings) (err error) {
tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey)
if tlsConf.PrivateKeyPath != "" {
if tlsConf.PrivateKey != "" {
// extTLSConf must not be nil.
func loadPrivateKeyData(extTLSConf *tlsConfigSettings) (err error) {
extTLSConf.PrivateKeyData = []byte(extTLSConf.PrivateKey)
if extTLSConf.PrivateKeyPath != "" {
if extTLSConf.PrivateKey != "" {
return errors.Error("private key data and file can't be set together")
}
tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath)
extTLSConf.PrivateKeyData, err = os.ReadFile(extTLSConf.PrivateKeyPath)
if err != nil {
return fmt.Errorf("reading key file: %w", err)
}
@ -460,7 +465,7 @@ func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
m.mu.Lock()
defer m.mu.Unlock()
tlsConf = m.conf.clone()
tlsConf = m.extTLSConf.clone()
servePlainDNS = m.servePlainDNS
}()
@ -494,7 +499,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
defer m.mu.Unlock()
if setts.PrivateKeySaved {
setts.PrivateKey = m.conf.PrivateKey
setts.PrivateKey = m.extTLSConf.PrivateKey
}
if err = m.validateTLSSettings(setts); err != nil {
@ -525,14 +530,14 @@ func (m *tlsManager) setConfig(
status *tlsConfigStatus,
servePlain aghalg.NullBool,
) (restartHTTPS bool) {
if !m.conf.setPrivateFieldsAndCompare(&newConf) {
if !m.extTLSConf.setPrivateFieldsAndCompare(&newConf) {
m.logger.InfoContext(ctx, "config has changed, restarting https server")
restartHTTPS = true
} else {
m.logger.InfoContext(ctx, "config has not changed")
}
m.conf = &newConf
m.extTLSConf = &newConf
m.status = status
@ -587,10 +592,10 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
defer m.mu.Unlock()
if req.PrivateKeySaved {
req.PrivateKey = m.conf.PrivateKey
req.PrivateKey = m.extTLSConf.PrivateKey
}
req.StrictSNICheck = m.conf.StrictSNICheck
req.StrictSNICheck = m.extTLSConf.StrictSNICheck
if err = m.validateTLSSettings(req); err != nil {
aghhttp.ErrorAndLog(ctx, m.logger, r, w, http.StatusBadRequest, "%s", err)

View file

@ -25,44 +25,18 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TODO(s.chzhen): Consider moving to testdata.
var testCertChainData = []byte(`-----BEGIN CERTIFICATE-----
MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV
BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3
MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV
MBMGA1UEAwwMQWRHdWFyZCBIb21lMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB
gQCwvwUnPJiOvLcOaWmGu6Y68ksFr13nrXBcsDlhxlXy8PaohVi3XxEmt2OrVjKW
QFw/bdV4fZ9tdWFAVRRkgeGbIZzP7YBD1Ore/O5SQ+DbCCEafvjJCcXQIrTeKFE6
i9G3aSMHs0Pwq2LgV8U5mYotLrvyFiE8QPInJbDDMpaFYwIDAQABo1MwUTAdBgNV
HQ4EFgQUdLUmQpEqrhn4eKO029jYd2AAZEQwHwYDVR0jBBgwFoAUdLUmQpEqrhn4
eKO029jYd2AAZEQwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQB8
LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna
b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq
Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ==
-----END CERTIFICATE-----`)
var testPrivateKeyData = []byte(`-----BEGIN PRIVATE KEY-----
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p
aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV
FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX
xTmZii0uu/IWITxA8iclsMMyloVjAgMBAAECgYEAmjzoG1h27UDkIlB9BVWl95TP
QVPLB81D267xNFDnWk1Lgr5zL/pnNjkdYjyjgpkBp1yKyE4gHV4skv5sAFWTcOCU
QCgfPfUn/rDFcxVzAdJVWAa/CpJNaZgjTPR8NTGU+Ztod+wfBESNCP5tbnuw0GbL
MuwdLQJGbzeJYpsNysECQQDfFHYoRNfgxHwMbX24GCoNZIgk12uDmGTA9CS5E+72
9t3V1y4CfXxSkfhqNbd5RWrUBRLEw9BKofBS7L9NMDKDAkEAytQoIueE1vqEAaRg
a3A1YDUekKesU5wKfKfKlXvNgB7Hwh4HuvoQS9RCvVhf/60Dvq8KSu6hSjkFRquj
FQ5roQJBAMwKwyiCD5MfJPeZDmzcbVpiocRQ5Z4wPbffl9dRTDnIA5AciZDthlFg
An/jMjZSMCxNl6UyFcqt5Et1EGVhuFECQQCZLXxaT+qcyHjlHJTMzuMgkz1QFbEp
O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG
O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU
kXS9jgARhhiWXJrk
-----END PRIVATE KEY-----`)
// Paths to the test TLS-related data.
const (
testCertificatePath = "./testdata/cert.pem"
testPrivateKeyPath = "./testdata/key.pem"
)
func TestValidateCertificates(t *testing.T) {
ctx := testutil.ContextWithTimeout(t, testTimeout)
@ -92,6 +66,10 @@ func TestValidateCertificates(t *testing.T) {
t.Run("valid", func(t *testing.T) {
status := &tlsConfigStatus{}
testCertChainData := requireReadFile(t, testCertificatePath)
testPrivateKeyData := requireReadFile(t, testPrivateKeyPath)
err = m.validateCertificates(ctx, status, testCertChainData, testPrivateKeyData, "")
assert.Error(t, err)
@ -213,7 +191,7 @@ func newCertAndKey(tb testing.TB, n int64) (certDER []byte, key *rsa.PrivateKey)
}
// writeCertAndKey is a helper function that writes certificate and key to
// specified paths.
// specified paths. key must not be nil.
func writeCertAndKey(
tb testing.TB,
certDER []byte,
@ -310,8 +288,8 @@ func TestTLSManager_Reload(t *testing.T) {
web := newTestWeb(t, &webConfig{})
m.setWebAPI(web)
conf := m.config()
assertCertSerialNumber(t, conf, snBefore)
extTLSConf := m.extendedTLSConfig()
assertCertSerialNumber(t, extTLSConf, snBefore)
certDER, key = newCertAndKey(t, snAfter)
writeCertAndKey(t, certDER, certPath, key, keyPath)
@ -324,8 +302,8 @@ func TestTLSManager_Reload(t *testing.T) {
return globalContext.dnsServer.Stop(testutil.ContextWithTimeout(t, testTimeout))
})
conf = m.config()
assertCertSerialNumber(t, conf, snAfter)
extTLSConf = m.extendedTLSConfig()
assertCertSerialNumber(t, extTLSConf, snAfter)
}
func TestTLSManager_HandleTLSStatus(t *testing.T) {
@ -334,13 +312,16 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) {
err error
)
testCertChain := requireReadFile(t, testCertificatePath)
testPrivateKeyData := requireReadFile(t, testPrivateKeyPath)
m, err := newTLSManager(ctx, &tlsManagerConfig{
logger: testLogger,
confModifier: agh.EmptyConfigModifier{},
manager: aghtls.EmptyManager{},
tlsSettings: tlsConfigSettings{
Enabled: true,
CertificateChain: string(testCertChainData),
CertificateChain: string(testCertChain),
PrivateKey: string(testPrivateKeyData),
},
servePlainDNS: false,
@ -355,7 +336,7 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) {
err = json.NewDecoder(w.Body).Decode(res)
require.NoError(t, err)
wantCertificateChain := base64.StdEncoding.EncodeToString(testCertChainData)
wantCertificateChain := base64.StdEncoding.EncodeToString(testCertChain)
assert.True(t, res.Enabled)
assert.Equal(t, wantCertificateChain, res.CertificateChain)
assert.True(t, res.PrivateKeySaved)
@ -470,9 +451,9 @@ func TestTLSManager_HandleTLSValidate(t *testing.T) {
confModifier: agh.EmptyConfigModifier{},
manager: aghtls.EmptyManager{},
tlsSettings: tlsConfigSettings{
Enabled: true,
CertificateChain: string(testCertChainData),
PrivateKey: string(testPrivateKeyData),
Enabled: true,
CertificatePath: testCertificatePath,
PrivateKeyPath: testPrivateKeyPath,
},
servePlainDNS: false,
})
@ -483,9 +464,9 @@ func TestTLSManager_HandleTLSValidate(t *testing.T) {
setts := &tlsConfigSettingsExt{
tlsConfigSettings: tlsConfigSettings{
Enabled: true,
CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData),
PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData),
Enabled: true,
CertificatePath: testCertificatePath,
PrivateKeyPath: testPrivateKeyPath,
},
}
@ -500,6 +481,9 @@ func TestTLSManager_HandleTLSValidate(t *testing.T) {
err = json.NewDecoder(w.Body).Decode(res)
require.NoError(t, err)
testCertChainData := requireReadFile(t, testCertificatePath)
testPrivateKeyData := requireReadFile(t, testPrivateKeyPath)
cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData)
require.NoError(t, err)
@ -541,7 +525,7 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
})
require.NoError(t, err)
config.DNS.BindHosts = []netip.Addr{netip.MustParseAddr("127.0.0.1")}
config.DNS.BindHosts = []netip.Addr{netutil.IPv4Localhost()}
config.DNS.Port = 0
const wantSerialNumber int64 = 1
@ -571,16 +555,16 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
web := newTestWeb(t, &webConfig{})
m.setWebAPI(web)
conf := m.config()
assertCertSerialNumber(t, conf, wantSerialNumber)
extTLSConf := m.extendedTLSConfig()
assertCertSerialNumber(t, extTLSConf, wantSerialNumber)
// Prepare a request with the new TLS configuration.
setts := &tlsConfigSettingsExt{
tlsConfigSettings: tlsConfigSettings{
Enabled: true,
PortHTTPS: 4433,
CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData),
PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData),
Enabled: true,
PortHTTPS: 4433,
CertificatePath: testCertificatePath,
PrivateKeyPath: testPrivateKeyPath,
},
}
@ -606,6 +590,9 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
err = json.NewDecoder(w.Body).Decode(res)
require.NoError(t, err)
testCertChainData := requireReadFile(t, testCertificatePath)
testPrivateKeyData := requireReadFile(t, testPrivateKeyPath)
cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData)
require.NoError(t, err)
@ -629,3 +616,15 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
return true
}, testTimeout, testTimeout/10)
}
// requireReadFile reads the file at the specified path and returns its content.
//
// TODO(m.kazantsev): Move to golibs/testutil.
func requireReadFile(tb testing.TB, path string) (data []byte) {
tb.Helper()
data, err := os.ReadFile(path)
require.NoError(tb, err)
return data
}