mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2026-06-28 11:51:22 +00:00
Pull request 2657: AGDNS-3720-add-tls-config-provider
Squashed commit of the following:
commit 1748706d70718ae68d64cb0b26d30be5c3635a8d
Author: Maksim Kazantsev <m.kazantsev@adguard.com>
Date: Wed May 20 12:21:56 2026 +0300
all: imp docs;
commit 90f314adeadd167765a0a86493877f042f4b9805
Author: Maksim Kazantsev <m.kazantsev@adguard.com>
Date: Tue May 19 20:02:09 2026 +0300
home: imp code;
commit 76265a91fd138ee344acc644bc3a8cfbb0c458f9
Author: Maksim Kazantsev <m.kazantsev@adguard.com>
Date: Tue May 19 19:39:35 2026 +0300
all: add tls config provider; imp tests;
This commit is contained in:
parent
6c5e88e29e
commit
3144e1856a
12 changed files with 272 additions and 148 deletions
|
|
@ -2,6 +2,8 @@ package aghtest
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
|
@ -9,6 +11,7 @@ import (
|
|||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
nextagh "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
|
|
@ -198,3 +201,26 @@ var _ aghhttp.Registrar = (*Registrar)(nil)
|
|||
func (m *Registrar) Register(method, path string, h http.HandlerFunc) {
|
||||
m.OnRegister(method, path, h)
|
||||
}
|
||||
|
||||
// TLSConfigProvider is a fake [aghtls.TLSConfigProvider] implementation for
|
||||
// tests.
|
||||
// TODO(m.kazantsev): Use in tests.
|
||||
type TLSConfigProvider struct {
|
||||
OnTLSConfig func() (conf *tls.Config)
|
||||
OnRootCAs func() (cert *x509.CertPool)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ aghtls.TLSConfigProvider = (*TLSConfigProvider)(nil)
|
||||
|
||||
// TLSConfig implements the [aghtls.TLSConfigProvider] interface for
|
||||
// *TLSConfigProvider.
|
||||
func (t *TLSConfigProvider) TLSConfig() (conf *tls.Config) {
|
||||
return t.OnTLSConfig()
|
||||
}
|
||||
|
||||
// RootCAs implements the [aghtls.TLSConfigProvider] interface for
|
||||
// *TLSConfigProvider.
|
||||
func (t *TLSConfigProvider) RootCAs() (pool *x509.CertPool) {
|
||||
return t.OnRootCAs()
|
||||
}
|
||||
|
|
|
|||
39
internal/aghtls/configprovider.go
Normal file
39
internal/aghtls/configprovider.go
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
package aghtls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
)
|
||||
|
||||
// TLSConfigProvider provides TLS configuration to consumers. Implementations
|
||||
// must be safe for concurrent use.
|
||||
//
|
||||
// TODO(m.kazantsev): Merge with the Manager interface.
|
||||
// TODO(m.kazantsev): Add at least one real implementation.
|
||||
type TLSConfigProvider interface {
|
||||
// TLSConfig returns a clone of the current TLS configuration. conf
|
||||
// provides its certificates via GetConfigForClient method.
|
||||
TLSConfig() (conf *tls.Config)
|
||||
|
||||
// RootCAs returns the current root CA pool.
|
||||
RootCAs() (root *x509.CertPool)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ TLSConfigProvider = EmptyTLSConfigProvider{}
|
||||
|
||||
// EmptyTLSConfigProvider is the implementation of the [TLSConfigProvider]
|
||||
// interface that does nothing.
|
||||
type EmptyTLSConfigProvider struct{}
|
||||
|
||||
// TLSConfig implements the [TLSConfigProvider] interface for
|
||||
// EmptyTLSConfigProvider. It always returns nil.
|
||||
func (EmptyTLSConfigProvider) TLSConfig() (conf *tls.Config) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RootCAs implements the [TLSConfigProvider] interface for
|
||||
// EmptyTLSConfigProvider. It always returns nil.
|
||||
func (EmptyTLSConfigProvider) RootCAs() (root *x509.CertPool) {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -544,7 +544,12 @@ func (conf *ServerConfig) loadUpstreams(
|
|||
|
||||
upstreams = stringutil.SplitTrimmed(string(data), "\n")
|
||||
|
||||
l.DebugContext(ctx, "got upstreams", "number", len(upstreams), "filename", conf.UpstreamDNSFileName)
|
||||
l.DebugContext(
|
||||
ctx,
|
||||
"got upstreams",
|
||||
"number", len(upstreams),
|
||||
"filename", conf.UpstreamDNSFileName,
|
||||
)
|
||||
|
||||
return stringutil.FilterOut(upstreams, aghnet.IsCommentOrEmpty), nil
|
||||
}
|
||||
|
|
@ -652,7 +657,10 @@ func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error)
|
|||
|
||||
// ourAddrsSet returns an addrPortSet that contains all the configured listening
|
||||
// addresses. l must not be nil.
|
||||
func (conf *ServerConfig) ourAddrsSet(ctx context.Context, l *slog.Logger) (m addrPortSet, err error) {
|
||||
func (conf *ServerConfig) ourAddrsSet(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
) (m addrPortSet, err error) {
|
||||
addrs, unspecPorts := conf.collectDNSAddrs()
|
||||
switch {
|
||||
case addrs.Len() == 0:
|
||||
|
|
@ -781,8 +789,9 @@ func anyNameMatches(dnsNames []string, sni string) (ok bool) {
|
|||
return false
|
||||
}
|
||||
|
||||
// Called by 'tls' package when Client Hello is received
|
||||
// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake.
|
||||
// onGetCertificate is called by [tls] package when Client Hello is received. If
|
||||
// the server name (from SNI) supplied by client is incorrect - we terminate the
|
||||
// ongoing TLS handshake.
|
||||
func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if s.conf.TLSConf.StrictSNICheck && !anyNameMatches(s.dnsNames, ch.ServerName) {
|
||||
// TODO(s.chzhen): Pass context.
|
||||
|
|
@ -798,8 +807,8 @@ func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, er
|
|||
return s.conf.TLSConf.Cert, nil
|
||||
}
|
||||
|
||||
// preparePlain prepares the plain-DNS configuration for the DNS proxy.
|
||||
// preparePlain assumes that prepareTLS has already been called.
|
||||
// preparePlain prepares the plain-DNS configuration for the DNS proxy. The
|
||||
// method assumes that prepareTLS has already been called.
|
||||
func (s *Server) preparePlain(ctx context.Context, proxyConf *proxy.Config) (err error) {
|
||||
if s.conf.ServePlainDNS {
|
||||
proxyConf.UDPListenAddr = s.conf.UDPListenAddrs
|
||||
|
|
|
|||
|
|
@ -90,7 +90,12 @@ func newUpstreamConfigValidator(
|
|||
// collectErrResults parses err and returns parsing results containing the
|
||||
// original upstream configuration line and the corresponding error. err can be
|
||||
// nil. l must not be nil.
|
||||
func collectErrResults(ctx context.Context, l *slog.Logger, lines []string, err error) (results []*parseResult) {
|
||||
func collectErrResults(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
lines []string,
|
||||
err error,
|
||||
) (results []*parseResult) {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -132,7 +137,7 @@ func collectErrResults(ctx context.Context, l *slog.Logger, lines []string, err
|
|||
}
|
||||
|
||||
// insertConfResults parses conf and inserts the upstream result into results.
|
||||
// It can insert multiple results as well as none.
|
||||
// It can insert multiple results as well as none. conf must not be nil.
|
||||
func insertConfResults(conf *proxy.UpstreamConfig, results map[string]*upstreamResult) {
|
||||
insertListResults(conf.Upstreams, results, false)
|
||||
|
||||
|
|
|
|||
|
|
@ -884,8 +884,8 @@ func (c *configuration) write(
|
|||
}
|
||||
|
||||
if tlsMgr != nil {
|
||||
tlsConf := tlsMgr.config()
|
||||
config.TLS = *tlsConf
|
||||
extTLSConf := tlsMgr.extendedTLSConfig()
|
||||
config.TLS = *extTLSConf
|
||||
}
|
||||
|
||||
if globalContext.stats != nil {
|
||||
|
|
|
|||
|
|
@ -178,6 +178,10 @@ type versionResponse struct {
|
|||
Disabled bool `json:"disabled"`
|
||||
}
|
||||
|
||||
// maxPrivilegedPort is the maximum port number. This only applies to Unix, as
|
||||
// on Windows, [aghnet.CanBindPrivilegedPorts] always returns `true`, `nil`.
|
||||
const maxPrivilegedPort = 1024
|
||||
|
||||
// setAllowedToAutoUpdate sets CanAutoUpdate to true if AdGuard Home is actually
|
||||
// allowed to perform an automatic update by the OS. l and tlsMgr must not be
|
||||
// nil.
|
||||
|
|
@ -191,9 +195,9 @@ func (vr *versionResponse) setAllowedToAutoUpdate(
|
|||
}
|
||||
|
||||
canUpdate := true
|
||||
if tlsConfUsesPrivilegedPorts(tlsMgr.config()) ||
|
||||
config.HTTPConfig.Address.Port() < 1024 ||
|
||||
config.DNS.Port < 1024 {
|
||||
if tlsConfUsesPrivilegedPorts(tlsMgr.extendedTLSConfig()) ||
|
||||
config.HTTPConfig.Address.Port() < maxPrivilegedPort ||
|
||||
config.DNS.Port < maxPrivilegedPort {
|
||||
canUpdate, err = aghnet.CanBindPrivilegedPorts(ctx, l)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking ability to bind privileged ports: %w", err)
|
||||
|
|
@ -206,9 +210,11 @@ func (vr *versionResponse) setAllowedToAutoUpdate(
|
|||
}
|
||||
|
||||
// tlsConfUsesPrivilegedPorts returns true if the provided TLS configuration
|
||||
// indicates that privileged ports are used.
|
||||
// indicates that privileged ports are used. c must be valid
|
||||
func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
|
||||
return c.Enabled && (c.PortHTTPS < 1024 || c.PortDNSOverTLS < 1024 || c.PortDNSOverQUIC < 1024)
|
||||
return c.Enabled && (c.PortHTTPS < maxPrivilegedPort ||
|
||||
c.PortDNSOverTLS < maxPrivilegedPort ||
|
||||
c.PortDNSOverQUIC < maxPrivilegedPort)
|
||||
}
|
||||
|
||||
// finishUpdate completes an update procedure. It is intended to be used as a
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ func initDNSServer(
|
|||
dnsConf, err := newServerConfig(
|
||||
&config.DNS,
|
||||
config.Clients.Sources,
|
||||
tlsMgr.config(),
|
||||
tlsMgr.extendedTLSConfig(),
|
||||
config.HTTPConfig.DoH,
|
||||
tlsMgr,
|
||||
httpReg,
|
||||
|
|
@ -212,7 +212,8 @@ func parseSubnetSet(nets []netutil.Prefix) (s netutil.SubnetSet) {
|
|||
}
|
||||
}
|
||||
|
||||
func isRunning() bool {
|
||||
// isRunning checks whether the DNS server is running.
|
||||
func isRunning() (ok bool) {
|
||||
return globalContext.dnsServer != nil && globalContext.dnsServer.IsRunning()
|
||||
}
|
||||
|
||||
|
|
@ -262,7 +263,7 @@ func ipsToUDPAddrs(ips []netip.Addr, port uint16) (udpAddrs []*net.UDPAddr) {
|
|||
func newServerConfig(
|
||||
dnsConf *dnsConfig,
|
||||
clientSrcConf *clientSourcesConfig,
|
||||
tlsConf *tlsConfigSettings,
|
||||
extTLSConf *tlsConfigSettings,
|
||||
dohConf *doHConfig,
|
||||
tlsMgr *tlsManager,
|
||||
httpReg aghhttp.Registrar,
|
||||
|
|
@ -274,7 +275,7 @@ func newServerConfig(
|
|||
fwdConf := dnsConf.Config
|
||||
fwdConf.ClientsContainer = clientsContainer
|
||||
|
||||
intTLSConf, err := newDNSTLSConfig(tlsConf, hosts, dohConf.InsecureEnabled)
|
||||
intTLSConf, err := newDNSTLSConfig(extTLSConf, hosts, dohConf.InsecureEnabled)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("constructing tls config: %w", err)
|
||||
}
|
||||
|
|
@ -322,19 +323,19 @@ func newServerConfig(
|
|||
}
|
||||
|
||||
// newDNSTLSConfig converts values from the configuration file into the internal
|
||||
// TLS settings for the DNS server. conf must not be nil.
|
||||
// TLS settings for the DNS server. extTLSConf must not be nil.
|
||||
func newDNSTLSConfig(
|
||||
conf *tlsConfigSettings,
|
||||
extTLSConf *tlsConfigSettings,
|
||||
addrs []netip.Addr,
|
||||
allowUnencryptedDoH bool,
|
||||
) (dnsConf *dnsforward.TLSConfig, err error) {
|
||||
if !conf.Enabled {
|
||||
if !extTLSConf.Enabled {
|
||||
return &dnsforward.TLSConfig{}, nil
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Add tracking for DNSCrypt configuration file changes to
|
||||
// the [aghtls.Manager].
|
||||
dnsCryptConf, err := newDNSCryptConfig(conf, addrs)
|
||||
dnsCryptConf, err := newDNSCryptConfig(extTLSConf, addrs)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return nil, err
|
||||
|
|
@ -342,23 +343,23 @@ func newDNSTLSConfig(
|
|||
|
||||
dnsConf = &dnsforward.TLSConfig{
|
||||
DNSCryptConf: dnsCryptConf,
|
||||
ServerName: conf.ServerName,
|
||||
StrictSNICheck: conf.StrictSNICheck,
|
||||
ServerName: extTLSConf.ServerName,
|
||||
StrictSNICheck: extTLSConf.StrictSNICheck,
|
||||
}
|
||||
|
||||
if conf.PortHTTPS != 0 {
|
||||
dnsConf.HTTPSListenAddrs = ipsToAddrPorts(addrs, conf.PortHTTPS)
|
||||
if extTLSConf.PortHTTPS != 0 {
|
||||
dnsConf.HTTPSListenAddrs = ipsToAddrPorts(addrs, extTLSConf.PortHTTPS)
|
||||
}
|
||||
|
||||
if conf.PortDNSOverTLS != 0 {
|
||||
dnsConf.TLSListenAddrs = ipsToTCPAddrs(addrs, conf.PortDNSOverTLS)
|
||||
if extTLSConf.PortDNSOverTLS != 0 {
|
||||
dnsConf.TLSListenAddrs = ipsToTCPAddrs(addrs, extTLSConf.PortDNSOverTLS)
|
||||
}
|
||||
|
||||
if conf.PortDNSOverQUIC != 0 {
|
||||
dnsConf.QUICListenAddrs = ipsToUDPAddrs(addrs, conf.PortDNSOverQUIC)
|
||||
if extTLSConf.PortDNSOverQUIC != 0 {
|
||||
dnsConf.QUICListenAddrs = ipsToUDPAddrs(addrs, extTLSConf.PortDNSOverQUIC)
|
||||
}
|
||||
|
||||
cert, err := tls.X509KeyPair(conf.CertificateChainData, conf.PrivateKeyData)
|
||||
cert, err := tls.X509KeyPair(extTLSConf.CertificateChainData, extTLSConf.PrivateKeyData)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("parsing tls key pair: %w", err)
|
||||
if allowUnencryptedDoH || dnsCryptConf != nil {
|
||||
|
|
@ -378,20 +379,20 @@ func newDNSTLSConfig(
|
|||
}
|
||||
|
||||
// newDNSCryptConfig converts values from the configuration file into the
|
||||
// internal DNSCrypt settings for the DNS server. conf must not be nil.
|
||||
// internal DNSCrypt settings for the DNS server. extTLSConf must not be nil.
|
||||
func newDNSCryptConfig(
|
||||
conf *tlsConfigSettings,
|
||||
extTLSConf *tlsConfigSettings,
|
||||
addrs []netip.Addr,
|
||||
) (dnsCryptConf *dnsforward.DNSCryptConfig, err error) {
|
||||
if conf.PortDNSCrypt == 0 {
|
||||
if extTLSConf.PortDNSCrypt == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if conf.DNSCryptConfigFile == "" {
|
||||
if extTLSConf.DNSCryptConfigFile == "" {
|
||||
return nil, fmt.Errorf("dnscrypt_config_file: %w", errors.ErrEmptyValue)
|
||||
}
|
||||
|
||||
f, err := os.Open(conf.DNSCryptConfigFile)
|
||||
f, err := os.Open(extTLSConf.DNSCryptConfigFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening dnscrypt config: %w", err)
|
||||
}
|
||||
|
|
@ -410,8 +411,8 @@ func newDNSCryptConfig(
|
|||
|
||||
return &dnsforward.DNSCryptConfig{
|
||||
ResolverCert: cert,
|
||||
UDPListenAddrs: ipsToUDPAddrs(addrs, conf.PortDNSCrypt),
|
||||
TCPListenAddrs: ipsToTCPAddrs(addrs, conf.PortDNSCrypt),
|
||||
UDPListenAddrs: ipsToUDPAddrs(addrs, extTLSConf.PortDNSCrypt),
|
||||
TCPListenAddrs: ipsToTCPAddrs(addrs, extTLSConf.PortDNSCrypt),
|
||||
ProviderName: rc.ProviderName,
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -426,16 +427,16 @@ type dnsEncryption struct {
|
|||
// getDNSEncryption returns the TLS encryption addresses that AdGuard Home
|
||||
// listens on. tlsMgr must not be nil.
|
||||
func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) {
|
||||
tlsConf := tlsMgr.config()
|
||||
extTLSConf := tlsMgr.extendedTLSConfig()
|
||||
|
||||
if !tlsConf.Enabled || len(tlsConf.ServerName) == 0 {
|
||||
if !extTLSConf.Enabled || extTLSConf.ServerName == "" {
|
||||
return dnsEncryption{}
|
||||
}
|
||||
|
||||
hostname := tlsConf.ServerName
|
||||
if tlsConf.PortHTTPS != 0 {
|
||||
hostname := extTLSConf.ServerName
|
||||
if extTLSConf.PortHTTPS != 0 {
|
||||
addr := hostname
|
||||
if p := tlsConf.PortHTTPS; p != defaultPortHTTPS {
|
||||
if p := extTLSConf.PortHTTPS; p != defaultPortHTTPS {
|
||||
addr = netutil.JoinHostPort(addr, p)
|
||||
}
|
||||
|
||||
|
|
@ -446,14 +447,14 @@ func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) {
|
|||
}).String()
|
||||
}
|
||||
|
||||
if p := tlsConf.PortDNSOverTLS; p != 0 {
|
||||
if p := extTLSConf.PortDNSOverTLS; p != 0 {
|
||||
de.tls = (&url.URL{
|
||||
Scheme: "tls",
|
||||
Host: netutil.JoinHostPort(hostname, p),
|
||||
}).String()
|
||||
}
|
||||
|
||||
if p := tlsConf.PortDNSOverQUIC; p != 0 {
|
||||
if p := extTLSConf.PortDNSOverQUIC; p != 0 {
|
||||
de.quic = (&url.URL{
|
||||
Scheme: "quic",
|
||||
Host: netutil.JoinHostPort(hostname, p),
|
||||
|
|
@ -463,7 +464,9 @@ func getDNSEncryption(tlsMgr *tlsManager) (de dnsEncryption) {
|
|||
return de
|
||||
}
|
||||
|
||||
func startDNSServer() error {
|
||||
// startDNSServer starts the DNS server, clients container, filters, stats and
|
||||
// the query log.
|
||||
func startDNSServer() (err error) {
|
||||
config.RLock()
|
||||
defer config.RUnlock()
|
||||
|
||||
|
|
@ -475,7 +478,7 @@ func startDNSServer() error {
|
|||
|
||||
// TODO(s.chzhen): Pass context.
|
||||
ctx := context.TODO()
|
||||
err := globalContext.clients.Start(ctx)
|
||||
err = globalContext.clients.Start(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting clients container: %w", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1269,22 +1269,24 @@ func printWebAddrs(ctx context.Context, l *slog.Logger, proto, addr string, port
|
|||
}
|
||||
|
||||
// printHTTPAddresses prints the IP addresses which user can use to access the
|
||||
// admin interface. proto is either schemeHTTP or schemeHTTPS.
|
||||
// admin interface. proto is either [urlutil.SchemeHTTPS] or
|
||||
// [urlutil.SchemeHTTP]. l must not be nil. If proto is [urlutil.SchemeHTTPS],
|
||||
// then tlsMgr must not be nil.
|
||||
//
|
||||
// TODO(s.chzhen): Implement separate functions for HTTP and HTTPS.
|
||||
func printHTTPAddresses(ctx context.Context, l *slog.Logger, proto string, tlsMgr *tlsManager) {
|
||||
var tlsConf *tlsConfigSettings
|
||||
var extTLSConf *tlsConfigSettings
|
||||
if tlsMgr != nil {
|
||||
tlsConf = tlsMgr.config()
|
||||
extTLSConf = tlsMgr.extendedTLSConfig()
|
||||
}
|
||||
|
||||
port := config.HTTPConfig.Address.Port()
|
||||
if proto == urlutil.SchemeHTTPS {
|
||||
port = tlsConf.PortHTTPS
|
||||
port = extTLSConf.PortHTTPS
|
||||
}
|
||||
|
||||
if proto == urlutil.SchemeHTTPS && tlsConf.ServerName != "" {
|
||||
printWebAddrs(ctx, l, proto, tlsConf.ServerName, tlsConf.PortHTTPS)
|
||||
if proto == urlutil.SchemeHTTPS && extTLSConf.ServerName != "" {
|
||||
printWebAddrs(ctx, l, proto, extTLSConf.ServerName, extTLSConf.PortHTTPS)
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
|||
14
internal/home/testdata/cert.pem
vendored
Normal file
14
internal/home/testdata/cert.pem
vendored
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
-----BEGIN CERTIFICATE-----
|
||||
MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV
|
||||
BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3
|
||||
MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV
|
||||
MBMGA1UEAwwMQWRHdWFyZCBIb21lMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB
|
||||
gQCwvwUnPJiOvLcOaWmGu6Y68ksFr13nrXBcsDlhxlXy8PaohVi3XxEmt2OrVjKW
|
||||
QFw/bdV4fZ9tdWFAVRRkgeGbIZzP7YBD1Ore/O5SQ+DbCCEafvjJCcXQIrTeKFE6
|
||||
i9G3aSMHs0Pwq2LgV8U5mYotLrvyFiE8QPInJbDDMpaFYwIDAQABo1MwUTAdBgNV
|
||||
HQ4EFgQUdLUmQpEqrhn4eKO029jYd2AAZEQwHwYDVR0jBBgwFoAUdLUmQpEqrhn4
|
||||
eKO029jYd2AAZEQwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQB8
|
||||
LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna
|
||||
b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq
|
||||
Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ==
|
||||
-----END CERTIFICATE-----
|
||||
16
internal/home/testdata/key.pem
vendored
Normal file
16
internal/home/testdata/key.pem
vendored
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
-----BEGIN PRIVATE KEY-----
|
||||
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p
|
||||
aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV
|
||||
FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX
|
||||
xTmZii0uu/IWITxA8iclsMMyloVjAgMBAAECgYEAmjzoG1h27UDkIlB9BVWl95TP
|
||||
QVPLB81D267xNFDnWk1Lgr5zL/pnNjkdYjyjgpkBp1yKyE4gHV4skv5sAFWTcOCU
|
||||
QCgfPfUn/rDFcxVzAdJVWAa/CpJNaZgjTPR8NTGU+Ztod+wfBESNCP5tbnuw0GbL
|
||||
MuwdLQJGbzeJYpsNysECQQDfFHYoRNfgxHwMbX24GCoNZIgk12uDmGTA9CS5E+72
|
||||
9t3V1y4CfXxSkfhqNbd5RWrUBRLEw9BKofBS7L9NMDKDAkEAytQoIueE1vqEAaRg
|
||||
a3A1YDUekKesU5wKfKfKlXvNgB7Hwh4HuvoQS9RCvVhf/60Dvq8KSu6hSjkFRquj
|
||||
FQ5roQJBAMwKwyiCD5MfJPeZDmzcbVpiocRQ5Z4wPbffl9dRTDnIA5AciZDthlFg
|
||||
An/jMjZSMCxNl6UyFcqt5Et1EGVhuFECQQCZLXxaT+qcyHjlHJTMzuMgkz1QFbEp
|
||||
O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG
|
||||
O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU
|
||||
kXS9jgARhhiWXJrk
|
||||
-----END PRIVATE KEY-----
|
||||
|
|
@ -36,7 +36,7 @@ type tlsManager struct {
|
|||
// logger is used for logging the operation of the TLS Manager.
|
||||
logger *slog.Logger
|
||||
|
||||
// mu protects status, certLastMod, conf, and servePlainDNS.
|
||||
// mu protects status, certLastMod, extTLSConf, and servePlainDNS.
|
||||
mu *sync.Mutex
|
||||
|
||||
// status is the current status of the configuration. It is never nil.
|
||||
|
|
@ -54,8 +54,12 @@ type tlsManager struct {
|
|||
// Resolve it.
|
||||
web *webAPI
|
||||
|
||||
// conf contains the TLS configuration settings. It must not be nil.
|
||||
conf *tlsConfigSettings
|
||||
// extTLSConf contains extended TLS configuration settings. It must not be
|
||||
// nil.
|
||||
// TODO(m.kazantsev): Add a field of a type of [*tls.Config] which will
|
||||
// represent the TLS settings. This is why these settings are called
|
||||
// 'extended'.
|
||||
extTLSConf *tlsConfigSettings
|
||||
|
||||
// confModifier is used to update the global configuration.
|
||||
confModifier agh.ConfigModifier
|
||||
|
|
@ -111,7 +115,7 @@ func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager,
|
|||
httpReg: conf.httpReg,
|
||||
manager: conf.manager,
|
||||
status: &tlsConfigStatus{},
|
||||
conf: &conf.tlsSettings,
|
||||
extTLSConf: &conf.tlsSettings,
|
||||
servePlainDNS: conf.servePlainDNS,
|
||||
}
|
||||
|
||||
|
|
@ -133,21 +137,21 @@ func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager,
|
|||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !m.conf.Enabled {
|
||||
if !m.extTLSConf.Enabled {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
err = m.manager.Set(ctx, aghtls.TLSPair{
|
||||
CertPath: m.conf.CertificatePath,
|
||||
KeyPath: m.conf.PrivateKeyPath,
|
||||
CertPath: m.extTLSConf.CertificatePath,
|
||||
KeyPath: m.extTLSConf.PrivateKeyPath,
|
||||
})
|
||||
if err != nil {
|
||||
m.logger.ErrorContext(ctx, "setting tls files", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
err = m.loadTLSConfig(ctx, m.conf, m.status)
|
||||
err = m.loadTLSConfig(ctx, m.extTLSConf, m.status)
|
||||
if err != nil {
|
||||
m.conf.Enabled = false
|
||||
m.extTLSConf.Enabled = false
|
||||
|
||||
return m, err
|
||||
}
|
||||
|
|
@ -166,22 +170,22 @@ func (m *tlsManager) setWebAPI(webAPI *webAPI) {
|
|||
m.web = webAPI
|
||||
}
|
||||
|
||||
// config returns a deep copy of the stored TLS configuration.
|
||||
func (m *tlsManager) config() (conf *tlsConfigSettings) {
|
||||
// extendedTLSConfig returns a deep copy of the stored TLS configuration.
|
||||
func (m *tlsManager) extendedTLSConfig() (extTLSConf *tlsConfigSettings) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return m.conf.clone()
|
||||
return m.extTLSConf.clone()
|
||||
}
|
||||
|
||||
// setCertFileTime sets [tlsManager.certLastMod] from the certificate. If there
|
||||
// are errors, setCertFileTime logs them. m.mu is expected to be locked.
|
||||
func (m *tlsManager) setCertFileTime(ctx context.Context) {
|
||||
if len(m.conf.CertificatePath) == 0 {
|
||||
if len(m.extTLSConf.CertificatePath) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
fi, err := os.Stat(m.conf.CertificatePath)
|
||||
fi, err := os.Stat(m.extTLSConf.CertificatePath)
|
||||
if err != nil {
|
||||
m.logger.ErrorContext(ctx, "looking up certificate path", slogutil.KeyError, err)
|
||||
|
||||
|
|
@ -203,7 +207,7 @@ func (m *tlsManager) start(ctx context.Context) {
|
|||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
m.web.tlsConfigChanged(context.Background(), m.conf)
|
||||
m.web.tlsConfigChanged(context.Background(), m.extTLSConf)
|
||||
|
||||
go m.handleCertFileChange(ctx)
|
||||
}
|
||||
|
|
@ -235,7 +239,7 @@ func (m *tlsManager) reload(ctx context.Context) {
|
|||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
tlsConfPtr := m.conf
|
||||
tlsConfPtr := m.extTLSConf
|
||||
|
||||
if !tlsConfPtr.Enabled || len(tlsConfPtr.CertificatePath) == 0 {
|
||||
return
|
||||
|
|
@ -267,7 +271,7 @@ func (m *tlsManager) reload(ctx context.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
m.conf = &tlsConf
|
||||
m.extTLSConf = &tlsConf
|
||||
m.status = status
|
||||
|
||||
m.certLastMod = fi.ModTime().UTC()
|
||||
|
|
@ -280,7 +284,7 @@ func (m *tlsManager) reload(ctx context.Context) {
|
|||
// The background context is used because the TLSConfigChanged wraps context
|
||||
// with timeout on its own and shuts down the server, which handles current
|
||||
// request.
|
||||
m.web.tlsConfigChanged(context.Background(), m.conf)
|
||||
m.web.tlsConfigChanged(context.Background(), m.extTLSConf)
|
||||
}
|
||||
|
||||
// reconfigureDNSServer updates the DNS server configuration using the stored
|
||||
|
|
@ -289,7 +293,7 @@ func (m *tlsManager) reconfigureDNSServer(ctx context.Context) (err error) {
|
|||
newConf, err := newServerConfig(
|
||||
&config.DNS,
|
||||
config.Clients.Sources,
|
||||
m.conf,
|
||||
m.extTLSConf,
|
||||
config.HTTPConfig.DoH,
|
||||
m,
|
||||
m.httpReg,
|
||||
|
|
@ -314,7 +318,7 @@ func (m *tlsManager) reconfigureDNSServer(ctx context.Context) (err error) {
|
|||
// set in status.WarningValidation.
|
||||
func (m *tlsManager) loadTLSConfig(
|
||||
ctx context.Context,
|
||||
tlsConf *tlsConfigSettings,
|
||||
extTLSConf *tlsConfigSettings,
|
||||
status *tlsConfigStatus,
|
||||
) (err error) {
|
||||
defer func() {
|
||||
|
|
@ -327,13 +331,13 @@ func (m *tlsManager) loadTLSConfig(
|
|||
}
|
||||
}()
|
||||
|
||||
err = loadCertificateChainData(tlsConf)
|
||||
err = loadCertificateChainData(extTLSConf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
err = loadPrivateKeyData(tlsConf)
|
||||
err = loadPrivateKeyData(extTLSConf)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
|
|
@ -342,9 +346,9 @@ func (m *tlsManager) loadTLSConfig(
|
|||
err = m.validateCertificates(
|
||||
ctx,
|
||||
status,
|
||||
tlsConf.CertificateChainData,
|
||||
tlsConf.PrivateKeyData,
|
||||
tlsConf.ServerName,
|
||||
extTLSConf.CertificateChainData,
|
||||
extTLSConf.PrivateKeyData,
|
||||
extTLSConf.ServerName,
|
||||
)
|
||||
|
||||
return errors.Annotate(err, "validating certificate pair: %w")
|
||||
|
|
@ -353,15 +357,15 @@ func (m *tlsManager) loadTLSConfig(
|
|||
// loadCertificateChainData loads PEM-encoded certificates chain data to the
|
||||
// TLS configuration. tlsConf must be not nil. tlsConf.CertificateChainData
|
||||
// struct field will be modified in case tlsConfig.CertificatePath is not an
|
||||
// empty string.
|
||||
func loadCertificateChainData(tlsConf *tlsConfigSettings) (err error) {
|
||||
tlsConf.CertificateChainData = []byte(tlsConf.CertificateChain)
|
||||
if tlsConf.CertificatePath != "" {
|
||||
if tlsConf.CertificateChain != "" {
|
||||
// empty string. extTLSConf must not be nil.
|
||||
func loadCertificateChainData(extTLSConf *tlsConfigSettings) (err error) {
|
||||
extTLSConf.CertificateChainData = []byte(extTLSConf.CertificateChain)
|
||||
if extTLSConf.CertificatePath != "" {
|
||||
if extTLSConf.CertificateChain != "" {
|
||||
return errors.Error("certificate data and file can't be set together")
|
||||
}
|
||||
|
||||
tlsConf.CertificateChainData, err = os.ReadFile(tlsConf.CertificatePath)
|
||||
extTLSConf.CertificateChainData, err = os.ReadFile(extTLSConf.CertificatePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading cert file: %w", err)
|
||||
}
|
||||
|
|
@ -373,14 +377,15 @@ func loadCertificateChainData(tlsConf *tlsConfigSettings) (err error) {
|
|||
// loadPrivateKeyData loads PEM-encoded private key data to the TLS
|
||||
// configuration. tlsConf must be not nil. tlsConf.PrivateKeyData struct field
|
||||
// will be modified in case tlsConfig.PrivateKeyPath is not an empty string.
|
||||
func loadPrivateKeyData(tlsConf *tlsConfigSettings) (err error) {
|
||||
tlsConf.PrivateKeyData = []byte(tlsConf.PrivateKey)
|
||||
if tlsConf.PrivateKeyPath != "" {
|
||||
if tlsConf.PrivateKey != "" {
|
||||
// extTLSConf must not be nil.
|
||||
func loadPrivateKeyData(extTLSConf *tlsConfigSettings) (err error) {
|
||||
extTLSConf.PrivateKeyData = []byte(extTLSConf.PrivateKey)
|
||||
if extTLSConf.PrivateKeyPath != "" {
|
||||
if extTLSConf.PrivateKey != "" {
|
||||
return errors.Error("private key data and file can't be set together")
|
||||
}
|
||||
|
||||
tlsConf.PrivateKeyData, err = os.ReadFile(tlsConf.PrivateKeyPath)
|
||||
extTLSConf.PrivateKeyData, err = os.ReadFile(extTLSConf.PrivateKeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading key file: %w", err)
|
||||
}
|
||||
|
|
@ -460,7 +465,7 @@ func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
|||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
tlsConf = m.conf.clone()
|
||||
tlsConf = m.extTLSConf.clone()
|
||||
servePlainDNS = m.servePlainDNS
|
||||
}()
|
||||
|
||||
|
|
@ -494,7 +499,7 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
|||
defer m.mu.Unlock()
|
||||
|
||||
if setts.PrivateKeySaved {
|
||||
setts.PrivateKey = m.conf.PrivateKey
|
||||
setts.PrivateKey = m.extTLSConf.PrivateKey
|
||||
}
|
||||
|
||||
if err = m.validateTLSSettings(setts); err != nil {
|
||||
|
|
@ -525,14 +530,14 @@ func (m *tlsManager) setConfig(
|
|||
status *tlsConfigStatus,
|
||||
servePlain aghalg.NullBool,
|
||||
) (restartHTTPS bool) {
|
||||
if !m.conf.setPrivateFieldsAndCompare(&newConf) {
|
||||
if !m.extTLSConf.setPrivateFieldsAndCompare(&newConf) {
|
||||
m.logger.InfoContext(ctx, "config has changed, restarting https server")
|
||||
restartHTTPS = true
|
||||
} else {
|
||||
m.logger.InfoContext(ctx, "config has not changed")
|
||||
}
|
||||
|
||||
m.conf = &newConf
|
||||
m.extTLSConf = &newConf
|
||||
|
||||
m.status = status
|
||||
|
||||
|
|
@ -587,10 +592,10 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
|||
defer m.mu.Unlock()
|
||||
|
||||
if req.PrivateKeySaved {
|
||||
req.PrivateKey = m.conf.PrivateKey
|
||||
req.PrivateKey = m.extTLSConf.PrivateKey
|
||||
}
|
||||
|
||||
req.StrictSNICheck = m.conf.StrictSNICheck
|
||||
req.StrictSNICheck = m.extTLSConf.StrictSNICheck
|
||||
|
||||
if err = m.validateTLSSettings(req); err != nil {
|
||||
aghhttp.ErrorAndLog(ctx, m.logger, r, w, http.StatusBadRequest, "%s", err)
|
||||
|
|
|
|||
|
|
@ -25,44 +25,18 @@ import (
|
|||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TODO(s.chzhen): Consider moving to testdata.
|
||||
var testCertChainData = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIICKzCCAZSgAwIBAgIJAMT9kPVJdM7LMA0GCSqGSIb3DQEBCwUAMC0xFDASBgNV
|
||||
BAoMC0FkR3VhcmQgTHRkMRUwEwYDVQQDDAxBZEd1YXJkIEhvbWUwHhcNMTkwMjI3
|
||||
MDkyNDIzWhcNNDYwNzE0MDkyNDIzWjAtMRQwEgYDVQQKDAtBZEd1YXJkIEx0ZDEV
|
||||
MBMGA1UEAwwMQWRHdWFyZCBIb21lMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB
|
||||
gQCwvwUnPJiOvLcOaWmGu6Y68ksFr13nrXBcsDlhxlXy8PaohVi3XxEmt2OrVjKW
|
||||
QFw/bdV4fZ9tdWFAVRRkgeGbIZzP7YBD1Ore/O5SQ+DbCCEafvjJCcXQIrTeKFE6
|
||||
i9G3aSMHs0Pwq2LgV8U5mYotLrvyFiE8QPInJbDDMpaFYwIDAQABo1MwUTAdBgNV
|
||||
HQ4EFgQUdLUmQpEqrhn4eKO029jYd2AAZEQwHwYDVR0jBBgwFoAUdLUmQpEqrhn4
|
||||
eKO029jYd2AAZEQwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOBgQB8
|
||||
LwlXfbakf7qkVTlCNXgoY7RaJ8rJdPgOZPoCTVToEhT6u/cb1c2qp8QB0dNExDna
|
||||
b0Z+dnODTZqQOJo6z/wIXlcUrnR4cQVvytXt8lFn+26l6Y6EMI26twC/xWr+1swq
|
||||
Muj4FeWHVDerquH4yMr1jsYLD3ci+kc5sbIX6TfVxQ==
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
var testPrivateKeyData = []byte(`-----BEGIN PRIVATE KEY-----
|
||||
MIICeAIBADANBgkqhkiG9w0BAQEFAASCAmIwggJeAgEAAoGBALC/BSc8mI68tw5p
|
||||
aYa7pjrySwWvXeetcFywOWHGVfLw9qiFWLdfESa3Y6tWMpZAXD9t1Xh9n211YUBV
|
||||
FGSB4ZshnM/tgEPU6t787lJD4NsIIRp++MkJxdAitN4oUTqL0bdpIwezQ/CrYuBX
|
||||
xTmZii0uu/IWITxA8iclsMMyloVjAgMBAAECgYEAmjzoG1h27UDkIlB9BVWl95TP
|
||||
QVPLB81D267xNFDnWk1Lgr5zL/pnNjkdYjyjgpkBp1yKyE4gHV4skv5sAFWTcOCU
|
||||
QCgfPfUn/rDFcxVzAdJVWAa/CpJNaZgjTPR8NTGU+Ztod+wfBESNCP5tbnuw0GbL
|
||||
MuwdLQJGbzeJYpsNysECQQDfFHYoRNfgxHwMbX24GCoNZIgk12uDmGTA9CS5E+72
|
||||
9t3V1y4CfXxSkfhqNbd5RWrUBRLEw9BKofBS7L9NMDKDAkEAytQoIueE1vqEAaRg
|
||||
a3A1YDUekKesU5wKfKfKlXvNgB7Hwh4HuvoQS9RCvVhf/60Dvq8KSu6hSjkFRquj
|
||||
FQ5roQJBAMwKwyiCD5MfJPeZDmzcbVpiocRQ5Z4wPbffl9dRTDnIA5AciZDthlFg
|
||||
An/jMjZSMCxNl6UyFcqt5Et1EGVhuFECQQCZLXxaT+qcyHjlHJTMzuMgkz1QFbEp
|
||||
O5EX70gpeGQMPDK0QSWpaazg956njJSDbNCFM4BccrdQbJu1cW4qOsfBAkAMgZuG
|
||||
O88slmgTRHX4JGFmy3rrLiHNI2BbJSuJ++Yllz8beVzh6NfvuY+HKRCmPqoBPATU
|
||||
kXS9jgARhhiWXJrk
|
||||
-----END PRIVATE KEY-----`)
|
||||
// Paths to the test TLS-related data.
|
||||
const (
|
||||
testCertificatePath = "./testdata/cert.pem"
|
||||
testPrivateKeyPath = "./testdata/key.pem"
|
||||
)
|
||||
|
||||
func TestValidateCertificates(t *testing.T) {
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
|
@ -92,6 +66,10 @@ func TestValidateCertificates(t *testing.T) {
|
|||
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
status := &tlsConfigStatus{}
|
||||
|
||||
testCertChainData := requireReadFile(t, testCertificatePath)
|
||||
testPrivateKeyData := requireReadFile(t, testPrivateKeyPath)
|
||||
|
||||
err = m.validateCertificates(ctx, status, testCertChainData, testPrivateKeyData, "")
|
||||
assert.Error(t, err)
|
||||
|
||||
|
|
@ -213,7 +191,7 @@ func newCertAndKey(tb testing.TB, n int64) (certDER []byte, key *rsa.PrivateKey)
|
|||
}
|
||||
|
||||
// writeCertAndKey is a helper function that writes certificate and key to
|
||||
// specified paths.
|
||||
// specified paths. key must not be nil.
|
||||
func writeCertAndKey(
|
||||
tb testing.TB,
|
||||
certDER []byte,
|
||||
|
|
@ -310,8 +288,8 @@ func TestTLSManager_Reload(t *testing.T) {
|
|||
web := newTestWeb(t, &webConfig{})
|
||||
m.setWebAPI(web)
|
||||
|
||||
conf := m.config()
|
||||
assertCertSerialNumber(t, conf, snBefore)
|
||||
extTLSConf := m.extendedTLSConfig()
|
||||
assertCertSerialNumber(t, extTLSConf, snBefore)
|
||||
|
||||
certDER, key = newCertAndKey(t, snAfter)
|
||||
writeCertAndKey(t, certDER, certPath, key, keyPath)
|
||||
|
|
@ -324,8 +302,8 @@ func TestTLSManager_Reload(t *testing.T) {
|
|||
return globalContext.dnsServer.Stop(testutil.ContextWithTimeout(t, testTimeout))
|
||||
})
|
||||
|
||||
conf = m.config()
|
||||
assertCertSerialNumber(t, conf, snAfter)
|
||||
extTLSConf = m.extendedTLSConfig()
|
||||
assertCertSerialNumber(t, extTLSConf, snAfter)
|
||||
}
|
||||
|
||||
func TestTLSManager_HandleTLSStatus(t *testing.T) {
|
||||
|
|
@ -334,13 +312,16 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) {
|
|||
err error
|
||||
)
|
||||
|
||||
testCertChain := requireReadFile(t, testCertificatePath)
|
||||
testPrivateKeyData := requireReadFile(t, testPrivateKeyPath)
|
||||
|
||||
m, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: testLogger,
|
||||
confModifier: agh.EmptyConfigModifier{},
|
||||
manager: aghtls.EmptyManager{},
|
||||
tlsSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
CertificateChain: string(testCertChainData),
|
||||
CertificateChain: string(testCertChain),
|
||||
PrivateKey: string(testPrivateKeyData),
|
||||
},
|
||||
servePlainDNS: false,
|
||||
|
|
@ -355,7 +336,7 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) {
|
|||
err = json.NewDecoder(w.Body).Decode(res)
|
||||
require.NoError(t, err)
|
||||
|
||||
wantCertificateChain := base64.StdEncoding.EncodeToString(testCertChainData)
|
||||
wantCertificateChain := base64.StdEncoding.EncodeToString(testCertChain)
|
||||
assert.True(t, res.Enabled)
|
||||
assert.Equal(t, wantCertificateChain, res.CertificateChain)
|
||||
assert.True(t, res.PrivateKeySaved)
|
||||
|
|
@ -470,9 +451,9 @@ func TestTLSManager_HandleTLSValidate(t *testing.T) {
|
|||
confModifier: agh.EmptyConfigModifier{},
|
||||
manager: aghtls.EmptyManager{},
|
||||
tlsSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
CertificateChain: string(testCertChainData),
|
||||
PrivateKey: string(testPrivateKeyData),
|
||||
Enabled: true,
|
||||
CertificatePath: testCertificatePath,
|
||||
PrivateKeyPath: testPrivateKeyPath,
|
||||
},
|
||||
servePlainDNS: false,
|
||||
})
|
||||
|
|
@ -483,9 +464,9 @@ func TestTLSManager_HandleTLSValidate(t *testing.T) {
|
|||
|
||||
setts := &tlsConfigSettingsExt{
|
||||
tlsConfigSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData),
|
||||
PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData),
|
||||
Enabled: true,
|
||||
CertificatePath: testCertificatePath,
|
||||
PrivateKeyPath: testPrivateKeyPath,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -500,6 +481,9 @@ func TestTLSManager_HandleTLSValidate(t *testing.T) {
|
|||
err = json.NewDecoder(w.Body).Decode(res)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCertChainData := requireReadFile(t, testCertificatePath)
|
||||
testPrivateKeyData := requireReadFile(t, testPrivateKeyPath)
|
||||
|
||||
cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -541,7 +525,7 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
config.DNS.BindHosts = []netip.Addr{netip.MustParseAddr("127.0.0.1")}
|
||||
config.DNS.BindHosts = []netip.Addr{netutil.IPv4Localhost()}
|
||||
config.DNS.Port = 0
|
||||
|
||||
const wantSerialNumber int64 = 1
|
||||
|
|
@ -571,16 +555,16 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
|
|||
web := newTestWeb(t, &webConfig{})
|
||||
m.setWebAPI(web)
|
||||
|
||||
conf := m.config()
|
||||
assertCertSerialNumber(t, conf, wantSerialNumber)
|
||||
extTLSConf := m.extendedTLSConfig()
|
||||
assertCertSerialNumber(t, extTLSConf, wantSerialNumber)
|
||||
|
||||
// Prepare a request with the new TLS configuration.
|
||||
setts := &tlsConfigSettingsExt{
|
||||
tlsConfigSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
PortHTTPS: 4433,
|
||||
CertificateChain: base64.StdEncoding.EncodeToString(testCertChainData),
|
||||
PrivateKey: base64.StdEncoding.EncodeToString(testPrivateKeyData),
|
||||
Enabled: true,
|
||||
PortHTTPS: 4433,
|
||||
CertificatePath: testCertificatePath,
|
||||
PrivateKeyPath: testPrivateKeyPath,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -606,6 +590,9 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
|
|||
err = json.NewDecoder(w.Body).Decode(res)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCertChainData := requireReadFile(t, testCertificatePath)
|
||||
testPrivateKeyData := requireReadFile(t, testPrivateKeyPath)
|
||||
|
||||
cert, err := tls.X509KeyPair(testCertChainData, testPrivateKeyData)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -629,3 +616,15 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
|
|||
return true
|
||||
}, testTimeout, testTimeout/10)
|
||||
}
|
||||
|
||||
// requireReadFile reads the file at the specified path and returns its content.
|
||||
//
|
||||
// TODO(m.kazantsev): Move to golibs/testutil.
|
||||
func requireReadFile(tb testing.TB, path string) (data []byte) {
|
||||
tb.Helper()
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
require.NoError(tb, err)
|
||||
|
||||
return data
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue