Pull request 2560: AGDNS-3568-refresh-tls

Squashed commit of the following:

commit 74888f2e45fc95fe1e992d87c3f527cc7a8c8628
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Jan 20 20:34:29 2026 +0300

    aghos: use sets

commit 1110f57a45efdc805c017746f291befbe598e128
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Jan 20 19:57:50 2026 +0300

    all: imp code, docs

commit b9ede8aba0b1b6f3e48e57fa4c72e2e5ff91d30c
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Jan 20 19:50:41 2026 +0300

    aghtls: imp interface

commit 02e668e7d6ee871837fbbf017b0d1e1ea09a0de5
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Jan 20 18:44:04 2026 +0300

    all: use test util, fix code

commit 7a57d6abb46a7302ca94d9aa77eedde6bd5825b7
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Jan 20 18:05:09 2026 +0300

    all: imp fswatcher, add tls manager
This commit is contained in:
Eugene Burkov 2026-01-21 12:21:48 +00:00
parent 4c360e4ae8
commit 32eb727125
8 changed files with 343 additions and 80 deletions

View file

@ -69,14 +69,12 @@ func TestNewHostsContainer(t *testing.T) {
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, &aghtest.FSWatcher{
OnStart: func(ctx context.Context) (_ error) {
panic(testutil.UnexpectedCall(ctx))
},
OnEvents: onEvents,
OnAdd: onAdd,
OnShutdown: func(_ context.Context) (err error) { return nil },
}, tc.paths...)
watcher := aghtest.NewFSWatcher()
watcher.OnEvents = onEvents
watcher.OnAdd = onAdd
watcher.OnShutdown = func(_ context.Context) (err error) { return nil }
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, watcher, tc.paths...)
if tc.wantErr != nil {
require.ErrorIs(t, err, tc.wantErr)
@ -99,15 +97,13 @@ func TestNewHostsContainer(t *testing.T) {
t.Run("nil_fs", func(t *testing.T) {
ctx := testutil.ContextWithTimeout(t, testTimeout)
require.Panics(t, func() {
_, _ = aghnet.NewHostsContainer(ctx, testLogger, nil, &aghtest.FSWatcher{
OnStart: func(ctx context.Context) (_ error) {
panic(testutil.UnexpectedCall(ctx))
},
// Those shouldn't panic.
OnEvents: func() (e <-chan struct{}) { return nil },
OnAdd: func(_ string) (err error) { return nil },
OnShutdown: func(_ context.Context) (err error) { return nil },
}, p)
watcher := aghtest.NewFSWatcher()
// Those shouldn't panic.
watcher.OnAdd = func(_ string) (err error) { return nil }
watcher.OnEvents = func() (e <-chan struct{}) { return nil }
watcher.OnShutdown = func(_ context.Context) (err error) { return nil }
_, _ = aghnet.NewHostsContainer(ctx, testLogger, nil, watcher, p)
})
})
@ -121,12 +117,9 @@ func TestNewHostsContainer(t *testing.T) {
t.Run("err_watcher", func(t *testing.T) {
const errOnAdd errors.Error = "error"
errWatcher := &aghtest.FSWatcher{
OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) },
OnEvents: func() (_ <-chan struct{}) { panic(testutil.UnexpectedCall()) },
OnAdd: func(_ string) (err error) { return errOnAdd },
OnShutdown: func(_ context.Context) (err error) { return nil },
}
errWatcher := aghtest.NewFSWatcher()
errWatcher.OnAdd = func(_ string) (err error) { return errOnAdd }
errWatcher.OnShutdown = func(_ context.Context) (err error) { return nil }
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, errWatcher, p)
@ -167,16 +160,14 @@ func TestHostsContainer_refresh(t *testing.T) {
eventsCh := make(chan event, 1)
t.Cleanup(func() { close(eventsCh) })
w := &aghtest.FSWatcher{
OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) },
OnEvents: func() (e <-chan event) { return eventsCh },
OnAdd: func(name string) (err error) {
assert.Equal(t, "dir", name)
w := aghtest.NewFSWatcher()
w.OnEvents = func() (e <-chan event) { return eventsCh }
w.OnAdd = func(name string) (err error) {
assert.Equal(t, "dir", name)
return nil
},
OnShutdown: func(_ context.Context) (err error) { return nil },
return nil
}
w.OnShutdown = func(_ context.Context) (err error) { return nil }
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, w, "dir")

View file

@ -6,6 +6,7 @@ import (
"io/fs"
"log/slog"
"path/filepath"
"sync"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
@ -15,11 +16,11 @@ import (
"github.com/fsnotify/fsnotify"
)
// event is a convenient alias for an empty struct to signal that watching
// Event is a convenient alias for an empty struct to signal that watched file
// event happened.
type event = struct{}
type Event = struct{}
// FSWatcher tracks all the fyle system events and notifies about those.
// FSWatcher tracks all the file system events and notifies about those.
//
// TODO(e.burkov, a.garipov): Move into another package like aghfs.
//
@ -28,11 +29,14 @@ type FSWatcher interface {
service.Interface
// Events returns the channel to notify about the file system events.
Events() (e <-chan event)
Events() (e <-chan Event)
// Add starts tracking the file. It returns an error if the file can't be
// tracked. It must not be called after Start.
// tracked.
Add(name string) (err error)
// Remove stops tracking the file.
Remove(name string) (err error)
}
// osWatcher tracks the file system provided by the OS.
@ -40,14 +44,21 @@ type osWatcher struct {
// logger is used for logging the operations of the osWatcher.
logger *slog.Logger
// fsys is the file system to track.
fsys fs.FS
// filesMu protects files.
filesMu *sync.RWMutex
// watcher is the actual notifier that is handled by osWatcher.
watcher *fsnotify.Watcher
// events is the channel to notify.
events chan event
events chan Event
// files is the set of tracked files.
files *container.MapSet[string]
// files maps directories to the files tracked in them. If the tracked file
// is a directory, it is mapped to itself.
files map[string]*container.MapSet[string]
}
// osWatcherPref is a prefix for logging and wrapping errors in osWathcer's
@ -59,17 +70,18 @@ const osWatcherPref = "os watcher"
func NewOSWritesWatcher(l *slog.Logger) (w FSWatcher, err error) {
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
var watcher *fsnotify.Watcher
watcher, err = fsnotify.NewWatcher()
watcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, fmt.Errorf("creating watcher: %w", err)
}
return &osWatcher{
logger: l,
fsys: osutil.RootDirFS(),
filesMu: &sync.RWMutex{},
watcher: watcher,
events: make(chan event, 1),
files: container.NewMapSet[string](),
events: make(chan Event, 1),
files: map[string]*container.MapSet[string]{},
}, nil
}
@ -90,7 +102,7 @@ func (w *osWatcher) Shutdown(_ context.Context) (err error) {
}
// Events implements the FSWatcher interface for *osWatcher.
func (w *osWatcher) Events() (e <-chan event) {
func (w *osWatcher) Events() (e <-chan Event) {
return w.events
}
@ -100,24 +112,77 @@ func (w *osWatcher) Events() (e <-chan event) {
func (w *osWatcher) Add(name string) (err error) {
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
fi, err := fs.Stat(osutil.RootDirFS(), name)
fi, err := fs.Stat(w.fsys, name)
if err != nil {
return fmt.Errorf("checking file %q: %w", name, err)
}
name = filepath.Join("/", name)
w.files.Add(name)
// Watch the directory and filter the events by the file name, since the
// common recomendation to the fsnotify package is to watch the directory
// instead of the file itself.
//
// See https://pkg.go.dev/github.com/fsnotify/fsnotify@v1.7.0#readme-watching-a-file-doesn-t-work-well.
dirName := name
if !fi.IsDir() {
name = filepath.Dir(name)
dirName = filepath.Dir(name)
}
return w.watcher.Add(name)
w.filesMu.Lock()
defer w.filesMu.Unlock()
names := w.files[dirName]
if names == nil {
names = container.NewMapSet[string]()
w.files[dirName] = names
}
names.Add(name)
err = w.watcher.Add(dirName)
if err != nil {
return fmt.Errorf("adding %q: %w", dirName, err)
}
return nil
}
// Remove implements the [FSWatcher] interface for *osWatcher.
func (w *osWatcher) Remove(name string) (err error) {
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
dirName := filepath.Dir(name)
w.filesMu.Lock()
defer w.filesMu.Unlock()
names, ok := w.files[name]
if ok {
dirName = name
} else {
names = w.files[dirName]
}
if !names.Has(name) {
// Name is not tracked.
return nil
}
names.Delete(name)
if names.Len() > 0 {
// Some files are still tracked in the directory.
return nil
}
// No more files tracked in the directory, unwatch it.
delete(w.files, dirName)
err = w.watcher.Remove(dirName)
if err != nil {
return fmt.Errorf("removing %q: %w", dirName, err)
}
return err
}
// handleEvents notifies about the received file system's event if needed. It
@ -129,14 +194,14 @@ func (w *osWatcher) handleEvents(ctx context.Context) {
ch := w.watcher.Events
for e := range ch {
if e.Op&fsnotify.Write == 0 || !w.files.Has(e.Name) {
if e.Op&fsnotify.Write == 0 || !w.isTrackedFile(e.Name) {
continue
}
skipDuplicates(ch)
select {
case w.events <- event{}:
case w.events <- Event{}:
// Go on.
default:
w.logger.DebugContext(ctx, "events buffer is full")
@ -144,6 +209,21 @@ func (w *osWatcher) handleEvents(ctx context.Context) {
}
}
// isTrackedFile returns true if the file is tracked.
func (w *osWatcher) isTrackedFile(name string) (isDir bool) {
dirName := filepath.Dir(name)
w.filesMu.RLock()
defer w.filesMu.RUnlock()
names, isDir := w.files[name]
if !isDir {
names = w.files[dirName]
}
return names.Has(name)
}
// skipDuplicates drains the given channel of events, assuming that some events
// might occur multiple times.
func skipDuplicates(ch <-chan fsnotify.Event) {
@ -188,7 +268,7 @@ func (EmptyFSWatcher) Shutdown(_ context.Context) (err error) {
// Events implements the [FSWatcher] interface for EmptyFSWatcher. It always
// returns nil channel.
func (EmptyFSWatcher) Events() (e <-chan event) {
func (EmptyFSWatcher) Events() (e <-chan Event) {
return nil
}
@ -197,3 +277,9 @@ func (EmptyFSWatcher) Events() (e <-chan event) {
func (EmptyFSWatcher) Add(_ string) (err error) {
return nil
}
// Remove implements the [FSWatcher] interface for EmptyFSWatcher. It always
// returns nil error.
func (EmptyFSWatcher) Remove(_ string) (err error) {
return nil
}

View file

@ -13,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
)
@ -20,8 +21,20 @@ import (
type FSWatcher struct {
OnStart func(ctx context.Context) (err error)
OnShutdown func(ctx context.Context) (err error)
OnEvents func() (e <-chan struct{})
OnEvents func() (e <-chan aghos.Event)
OnAdd func(name string) (err error)
OnRemove func(name string) (err error)
}
// NewFSWatcher returns a new *FSWatcher all methods of which panic.
func NewFSWatcher() (w *FSWatcher) {
return &FSWatcher{
OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) },
OnShutdown: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) },
OnEvents: func() (_ <-chan aghos.Event) { panic(testutil.UnexpectedCall()) },
OnAdd: func(name string) (_ error) { panic(testutil.UnexpectedCall(name)) },
OnRemove: func(name string) (_ error) { panic(testutil.UnexpectedCall(name)) },
}
}
// type check
@ -38,7 +51,7 @@ func (w *FSWatcher) Shutdown(ctx context.Context) (err error) {
}
// Events implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Events() (e <-chan struct{}) {
func (w *FSWatcher) Events() (e <-chan aghos.Event) {
return w.OnEvents()
}
@ -47,6 +60,11 @@ func (w *FSWatcher) Add(name string) (err error) {
return w.OnAdd(name)
}
// Remove implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Remove(name string) (err error) {
return w.OnRemove(name)
}
// ServiceWithConfig is a fake [nextagh.ServiceWithConfig] implementation for
// tests.
type ServiceWithConfig[ConfigType any] struct {

View file

@ -0,0 +1,135 @@
package aghtls
import (
"context"
"fmt"
"log/slog"
"sync/atomic"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// DefaultManagerConfig is the configuration structure for [NewDefaultManager].
type DefaultManagerConfig struct {
// Logger is used for logging the operation of the manager. It must not be
// nil.
Logger *slog.Logger
// Watcher is used to watch the TLS certificate and key files. It must not
// be nil.
Watcher aghos.FSWatcher
}
// DefaultManager is the default implementation of the [Manager] interface.
//
// TODO(e.burkov): Use.
type DefaultManager struct {
logger *slog.Logger
pair *atomic.Pointer[TLSPair]
updates chan UpdateSignal
watcher aghos.FSWatcher
}
// NewDefaultManager returns a new properly initialized default manager.
func NewDefaultManager(c *DefaultManagerConfig) (mgr *DefaultManager) {
return &DefaultManager{
logger: c.Logger,
pair: &atomic.Pointer[TLSPair]{},
// Buffer the channel to avoid missing updates.
updates: make(chan UpdateSignal, 1),
watcher: c.Watcher,
}
}
// type check
var _ Manager = (*DefaultManager)(nil)
// Set implements the [Manager] interface for *DefaultManager.
func (mgr *DefaultManager) Set(ctx context.Context, certKey *TLSPair) (err error) {
old := mgr.pair.Swap(certKey)
if old != nil {
err = errors.Join(
mgr.watcher.Remove(old.CertPath),
mgr.watcher.Remove(old.KeyPath),
)
if err != nil {
return fmt.Errorf("removing old certificate and key: %w", err)
}
}
if certKey != nil {
err = errors.Join(
mgr.watcher.Add(certKey.CertPath),
mgr.watcher.Add(certKey.KeyPath),
)
if err != nil {
return fmt.Errorf("adding new certificate and key: %w", err)
}
}
return nil
}
// Refresh implements the [service.Refresher] interface for *DefaultManager.
func (mgr *DefaultManager) Refresh(ctx context.Context) (err error) {
select {
case mgr.updates <- UpdateSignal{}:
return nil
case <-ctx.Done():
return fmt.Errorf("refreshing: %w", ctx.Err())
default:
return nil
}
}
// Start implements the [service.Interface] interface for *DefaultManager.
func (mgr *DefaultManager) Start(ctx context.Context) (err error) {
err = mgr.watcher.Start(ctx)
if err != nil {
return fmt.Errorf("starting watcher: %w", err)
}
go mgr.handleEvents(ctx)
return nil
}
// Shutdown implements the [service.Interface] interface for *DefaultManager.
func (mgr *DefaultManager) Shutdown(ctx context.Context) (err error) {
defer close(mgr.updates)
err = mgr.watcher.Shutdown(ctx)
if err != nil {
return fmt.Errorf("shutting down watcher: %w", err)
}
return nil
}
// Updates implements the [Manager] interface for *DefaultManager.
func (mgr *DefaultManager) Updates(ctx context.Context) (updates <-chan UpdateSignal) {
return mgr.updates
}
// handleEvents handles changes of the tracked files. It is intended to be run
// in a separate goroutine.
func (mgr *DefaultManager) handleEvents(ctx context.Context) {
defer slogutil.RecoverAndLog(ctx, mgr.logger)
eventsCh := mgr.watcher.Events()
if eventsCh == nil {
mgr.logger.DebugContext(ctx, "watcher does not emit events")
return
}
for range eventsCh {
err := mgr.Refresh(ctx)
if err != nil {
mgr.logger.ErrorContext(ctx, "refreshing", slogutil.KeyError, err)
}
}
}

View file

@ -0,0 +1,36 @@
package aghtls
import (
"context"
"github.com/AdguardTeam/golibs/service"
)
// TLSPair is a pair of paths to a certificate and a key.
type TLSPair struct {
// CertPath is the path to the certificate.
CertPath string
// KeyPath is the path to the key.
KeyPath string
}
// UpdateSignal is the signal that the TLS certificate and key have been
// updated.
type UpdateSignal struct{}
// Manager manages TLS certificates and keys updates.
type Manager interface {
service.Interface
service.Refresher
// Set sets the TLS certificate and key.
Set(ctx context.Context, certKey *TLSPair) (err error)
// Updates returns a channel that emits signals when the TLS certificate
// and/or key have been updated.
//
// TODO(e.burkov): Move reloading logic to the manager and get rid of this
// method.
Updates(ctx context.Context) (updates <-chan UpdateSignal)
}

View file

@ -2,7 +2,6 @@ package dnsforward
import (
"cmp"
"context"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
@ -24,6 +23,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@ -1516,21 +1516,20 @@ func TestPTRResponseFromHosts(t *testing.T) {
}
var eventsCalledCounter uint32
watcher := aghtest.NewFSWatcher()
watcher.OnEvents = func() (e <-chan aghos.Event) {
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
return nil
}
watcher.OnAdd = func(name string) (err error) {
assert.Equal(t, hostsFilename, name)
return nil
}
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, &aghtest.FSWatcher{
OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) },
OnEvents: func() (e <-chan struct{}) {
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
return nil
},
OnAdd: func(name string) (err error) {
assert.Equal(t, hostsFilename, name)
return nil
},
OnShutdown: func(ctx context.Context) (err error) { panic(testutil.UnexpectedCall(ctx)) },
}, hostsFilename)
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, watcher, hostsFilename)
require.NoError(t, err)
t.Cleanup(func() {
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))

View file

@ -366,6 +366,11 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
Host: netutil.JoinHostPort(upstreamHost, hostsListener.Port()),
}).String()
watcher := aghtest.NewFSWatcher()
watcher.OnEvents = func() (e <-chan struct{}) { return nil }
watcher.OnAdd = func(_ string) (err error) { return nil }
watcher.OnShutdown = func(_ context.Context) (err error) { return nil }
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(
ctx,
@ -375,12 +380,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
Data: []byte(hostsListener.Addr().String() + " " + upstreamHost),
},
},
&aghtest.FSWatcher{
OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) },
OnEvents: func() (e <-chan struct{}) { return nil },
OnAdd: func(_ string) (err error) { return nil },
OnShutdown: func(_ context.Context) (err error) { return nil },
},
watcher,
hostsFileName,
)
require.NoError(t, err)

View file

@ -42,12 +42,10 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
Data: []byte(data),
},
}
watcher := &aghtest.FSWatcher{
OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) },
OnEvents: func() (e <-chan struct{}) { return nil },
OnAdd: func(name string) (err error) { return nil },
OnShutdown: func(_ context.Context) (err error) { return nil },
}
watcher := aghtest.NewFSWatcher()
watcher.OnEvents = func() (e <-chan struct{}) { return nil }
watcher.OnAdd = func(name string) (err error) { return nil }
watcher.OnShutdown = func(_ context.Context) (err error) { return nil }
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(ctx, testLogger, files, watcher, "hosts")