mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2026-06-28 03:41:19 +00:00
Pull request 2560: AGDNS-3568-refresh-tls
Squashed commit of the following:
commit 74888f2e45fc95fe1e992d87c3f527cc7a8c8628
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Jan 20 20:34:29 2026 +0300
aghos: use sets
commit 1110f57a45efdc805c017746f291befbe598e128
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Jan 20 19:57:50 2026 +0300
all: imp code, docs
commit b9ede8aba0b1b6f3e48e57fa4c72e2e5ff91d30c
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Jan 20 19:50:41 2026 +0300
aghtls: imp interface
commit 02e668e7d6ee871837fbbf017b0d1e1ea09a0de5
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Jan 20 18:44:04 2026 +0300
all: use test util, fix code
commit 7a57d6abb46a7302ca94d9aa77eedde6bd5825b7
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date: Tue Jan 20 18:05:09 2026 +0300
all: imp fswatcher, add tls manager
This commit is contained in:
parent
4c360e4ae8
commit
32eb727125
8 changed files with 343 additions and 80 deletions
|
|
@ -69,14 +69,12 @@ func TestNewHostsContainer(t *testing.T) {
|
|||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, &aghtest.FSWatcher{
|
||||
OnStart: func(ctx context.Context) (_ error) {
|
||||
panic(testutil.UnexpectedCall(ctx))
|
||||
},
|
||||
OnEvents: onEvents,
|
||||
OnAdd: onAdd,
|
||||
OnShutdown: func(_ context.Context) (err error) { return nil },
|
||||
}, tc.paths...)
|
||||
watcher := aghtest.NewFSWatcher()
|
||||
watcher.OnEvents = onEvents
|
||||
watcher.OnAdd = onAdd
|
||||
watcher.OnShutdown = func(_ context.Context) (err error) { return nil }
|
||||
|
||||
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, watcher, tc.paths...)
|
||||
if tc.wantErr != nil {
|
||||
require.ErrorIs(t, err, tc.wantErr)
|
||||
|
||||
|
|
@ -99,15 +97,13 @@ func TestNewHostsContainer(t *testing.T) {
|
|||
t.Run("nil_fs", func(t *testing.T) {
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
require.Panics(t, func() {
|
||||
_, _ = aghnet.NewHostsContainer(ctx, testLogger, nil, &aghtest.FSWatcher{
|
||||
OnStart: func(ctx context.Context) (_ error) {
|
||||
panic(testutil.UnexpectedCall(ctx))
|
||||
},
|
||||
// Those shouldn't panic.
|
||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||
OnAdd: func(_ string) (err error) { return nil },
|
||||
OnShutdown: func(_ context.Context) (err error) { return nil },
|
||||
}, p)
|
||||
watcher := aghtest.NewFSWatcher()
|
||||
// Those shouldn't panic.
|
||||
watcher.OnAdd = func(_ string) (err error) { return nil }
|
||||
watcher.OnEvents = func() (e <-chan struct{}) { return nil }
|
||||
watcher.OnShutdown = func(_ context.Context) (err error) { return nil }
|
||||
|
||||
_, _ = aghnet.NewHostsContainer(ctx, testLogger, nil, watcher, p)
|
||||
})
|
||||
})
|
||||
|
||||
|
|
@ -121,12 +117,9 @@ func TestNewHostsContainer(t *testing.T) {
|
|||
t.Run("err_watcher", func(t *testing.T) {
|
||||
const errOnAdd errors.Error = "error"
|
||||
|
||||
errWatcher := &aghtest.FSWatcher{
|
||||
OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) },
|
||||
OnEvents: func() (_ <-chan struct{}) { panic(testutil.UnexpectedCall()) },
|
||||
OnAdd: func(_ string) (err error) { return errOnAdd },
|
||||
OnShutdown: func(_ context.Context) (err error) { return nil },
|
||||
}
|
||||
errWatcher := aghtest.NewFSWatcher()
|
||||
errWatcher.OnAdd = func(_ string) (err error) { return errOnAdd }
|
||||
errWatcher.OnShutdown = func(_ context.Context) (err error) { return nil }
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, errWatcher, p)
|
||||
|
|
@ -167,16 +160,14 @@ func TestHostsContainer_refresh(t *testing.T) {
|
|||
eventsCh := make(chan event, 1)
|
||||
t.Cleanup(func() { close(eventsCh) })
|
||||
|
||||
w := &aghtest.FSWatcher{
|
||||
OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) },
|
||||
OnEvents: func() (e <-chan event) { return eventsCh },
|
||||
OnAdd: func(name string) (err error) {
|
||||
assert.Equal(t, "dir", name)
|
||||
w := aghtest.NewFSWatcher()
|
||||
w.OnEvents = func() (e <-chan event) { return eventsCh }
|
||||
w.OnAdd = func(name string) (err error) {
|
||||
assert.Equal(t, "dir", name)
|
||||
|
||||
return nil
|
||||
},
|
||||
OnShutdown: func(_ context.Context) (err error) { return nil },
|
||||
return nil
|
||||
}
|
||||
w.OnShutdown = func(_ context.Context) (err error) { return nil }
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, w, "dir")
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"io/fs"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
|
|
@ -15,11 +16,11 @@ import (
|
|||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
// event is a convenient alias for an empty struct to signal that watching
|
||||
// Event is a convenient alias for an empty struct to signal that watched file
|
||||
// event happened.
|
||||
type event = struct{}
|
||||
type Event = struct{}
|
||||
|
||||
// FSWatcher tracks all the fyle system events and notifies about those.
|
||||
// FSWatcher tracks all the file system events and notifies about those.
|
||||
//
|
||||
// TODO(e.burkov, a.garipov): Move into another package like aghfs.
|
||||
//
|
||||
|
|
@ -28,11 +29,14 @@ type FSWatcher interface {
|
|||
service.Interface
|
||||
|
||||
// Events returns the channel to notify about the file system events.
|
||||
Events() (e <-chan event)
|
||||
Events() (e <-chan Event)
|
||||
|
||||
// Add starts tracking the file. It returns an error if the file can't be
|
||||
// tracked. It must not be called after Start.
|
||||
// tracked.
|
||||
Add(name string) (err error)
|
||||
|
||||
// Remove stops tracking the file.
|
||||
Remove(name string) (err error)
|
||||
}
|
||||
|
||||
// osWatcher tracks the file system provided by the OS.
|
||||
|
|
@ -40,14 +44,21 @@ type osWatcher struct {
|
|||
// logger is used for logging the operations of the osWatcher.
|
||||
logger *slog.Logger
|
||||
|
||||
// fsys is the file system to track.
|
||||
fsys fs.FS
|
||||
|
||||
// filesMu protects files.
|
||||
filesMu *sync.RWMutex
|
||||
|
||||
// watcher is the actual notifier that is handled by osWatcher.
|
||||
watcher *fsnotify.Watcher
|
||||
|
||||
// events is the channel to notify.
|
||||
events chan event
|
||||
events chan Event
|
||||
|
||||
// files is the set of tracked files.
|
||||
files *container.MapSet[string]
|
||||
// files maps directories to the files tracked in them. If the tracked file
|
||||
// is a directory, it is mapped to itself.
|
||||
files map[string]*container.MapSet[string]
|
||||
}
|
||||
|
||||
// osWatcherPref is a prefix for logging and wrapping errors in osWathcer's
|
||||
|
|
@ -59,17 +70,18 @@ const osWatcherPref = "os watcher"
|
|||
func NewOSWritesWatcher(l *slog.Logger) (w FSWatcher, err error) {
|
||||
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
|
||||
|
||||
var watcher *fsnotify.Watcher
|
||||
watcher, err = fsnotify.NewWatcher()
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating watcher: %w", err)
|
||||
}
|
||||
|
||||
return &osWatcher{
|
||||
logger: l,
|
||||
fsys: osutil.RootDirFS(),
|
||||
filesMu: &sync.RWMutex{},
|
||||
watcher: watcher,
|
||||
events: make(chan event, 1),
|
||||
files: container.NewMapSet[string](),
|
||||
events: make(chan Event, 1),
|
||||
files: map[string]*container.MapSet[string]{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -90,7 +102,7 @@ func (w *osWatcher) Shutdown(_ context.Context) (err error) {
|
|||
}
|
||||
|
||||
// Events implements the FSWatcher interface for *osWatcher.
|
||||
func (w *osWatcher) Events() (e <-chan event) {
|
||||
func (w *osWatcher) Events() (e <-chan Event) {
|
||||
return w.events
|
||||
}
|
||||
|
||||
|
|
@ -100,24 +112,77 @@ func (w *osWatcher) Events() (e <-chan event) {
|
|||
func (w *osWatcher) Add(name string) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
|
||||
|
||||
fi, err := fs.Stat(osutil.RootDirFS(), name)
|
||||
fi, err := fs.Stat(w.fsys, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking file %q: %w", name, err)
|
||||
}
|
||||
|
||||
name = filepath.Join("/", name)
|
||||
w.files.Add(name)
|
||||
|
||||
// Watch the directory and filter the events by the file name, since the
|
||||
// common recomendation to the fsnotify package is to watch the directory
|
||||
// instead of the file itself.
|
||||
//
|
||||
// See https://pkg.go.dev/github.com/fsnotify/fsnotify@v1.7.0#readme-watching-a-file-doesn-t-work-well.
|
||||
dirName := name
|
||||
if !fi.IsDir() {
|
||||
name = filepath.Dir(name)
|
||||
dirName = filepath.Dir(name)
|
||||
}
|
||||
|
||||
return w.watcher.Add(name)
|
||||
w.filesMu.Lock()
|
||||
defer w.filesMu.Unlock()
|
||||
|
||||
names := w.files[dirName]
|
||||
if names == nil {
|
||||
names = container.NewMapSet[string]()
|
||||
w.files[dirName] = names
|
||||
}
|
||||
names.Add(name)
|
||||
|
||||
err = w.watcher.Add(dirName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding %q: %w", dirName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove implements the [FSWatcher] interface for *osWatcher.
|
||||
func (w *osWatcher) Remove(name string) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
|
||||
|
||||
dirName := filepath.Dir(name)
|
||||
|
||||
w.filesMu.Lock()
|
||||
defer w.filesMu.Unlock()
|
||||
|
||||
names, ok := w.files[name]
|
||||
if ok {
|
||||
dirName = name
|
||||
} else {
|
||||
names = w.files[dirName]
|
||||
}
|
||||
|
||||
if !names.Has(name) {
|
||||
// Name is not tracked.
|
||||
return nil
|
||||
}
|
||||
|
||||
names.Delete(name)
|
||||
if names.Len() > 0 {
|
||||
// Some files are still tracked in the directory.
|
||||
return nil
|
||||
}
|
||||
|
||||
// No more files tracked in the directory, unwatch it.
|
||||
delete(w.files, dirName)
|
||||
|
||||
err = w.watcher.Remove(dirName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing %q: %w", dirName, err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// handleEvents notifies about the received file system's event if needed. It
|
||||
|
|
@ -129,14 +194,14 @@ func (w *osWatcher) handleEvents(ctx context.Context) {
|
|||
|
||||
ch := w.watcher.Events
|
||||
for e := range ch {
|
||||
if e.Op&fsnotify.Write == 0 || !w.files.Has(e.Name) {
|
||||
if e.Op&fsnotify.Write == 0 || !w.isTrackedFile(e.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
skipDuplicates(ch)
|
||||
|
||||
select {
|
||||
case w.events <- event{}:
|
||||
case w.events <- Event{}:
|
||||
// Go on.
|
||||
default:
|
||||
w.logger.DebugContext(ctx, "events buffer is full")
|
||||
|
|
@ -144,6 +209,21 @@ func (w *osWatcher) handleEvents(ctx context.Context) {
|
|||
}
|
||||
}
|
||||
|
||||
// isTrackedFile returns true if the file is tracked.
|
||||
func (w *osWatcher) isTrackedFile(name string) (isDir bool) {
|
||||
dirName := filepath.Dir(name)
|
||||
|
||||
w.filesMu.RLock()
|
||||
defer w.filesMu.RUnlock()
|
||||
|
||||
names, isDir := w.files[name]
|
||||
if !isDir {
|
||||
names = w.files[dirName]
|
||||
}
|
||||
|
||||
return names.Has(name)
|
||||
}
|
||||
|
||||
// skipDuplicates drains the given channel of events, assuming that some events
|
||||
// might occur multiple times.
|
||||
func skipDuplicates(ch <-chan fsnotify.Event) {
|
||||
|
|
@ -188,7 +268,7 @@ func (EmptyFSWatcher) Shutdown(_ context.Context) (err error) {
|
|||
|
||||
// Events implements the [FSWatcher] interface for EmptyFSWatcher. It always
|
||||
// returns nil channel.
|
||||
func (EmptyFSWatcher) Events() (e <-chan event) {
|
||||
func (EmptyFSWatcher) Events() (e <-chan Event) {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -197,3 +277,9 @@ func (EmptyFSWatcher) Events() (e <-chan event) {
|
|||
func (EmptyFSWatcher) Add(_ string) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove implements the [FSWatcher] interface for EmptyFSWatcher. It always
|
||||
// returns nil error.
|
||||
func (EmptyFSWatcher) Remove(_ string) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
|
|
@ -20,8 +21,20 @@ import (
|
|||
type FSWatcher struct {
|
||||
OnStart func(ctx context.Context) (err error)
|
||||
OnShutdown func(ctx context.Context) (err error)
|
||||
OnEvents func() (e <-chan struct{})
|
||||
OnEvents func() (e <-chan aghos.Event)
|
||||
OnAdd func(name string) (err error)
|
||||
OnRemove func(name string) (err error)
|
||||
}
|
||||
|
||||
// NewFSWatcher returns a new *FSWatcher all methods of which panic.
|
||||
func NewFSWatcher() (w *FSWatcher) {
|
||||
return &FSWatcher{
|
||||
OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) },
|
||||
OnShutdown: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) },
|
||||
OnEvents: func() (_ <-chan aghos.Event) { panic(testutil.UnexpectedCall()) },
|
||||
OnAdd: func(name string) (_ error) { panic(testutil.UnexpectedCall(name)) },
|
||||
OnRemove: func(name string) (_ error) { panic(testutil.UnexpectedCall(name)) },
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
|
|
@ -38,7 +51,7 @@ func (w *FSWatcher) Shutdown(ctx context.Context) (err error) {
|
|||
}
|
||||
|
||||
// Events implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||
func (w *FSWatcher) Events() (e <-chan struct{}) {
|
||||
func (w *FSWatcher) Events() (e <-chan aghos.Event) {
|
||||
return w.OnEvents()
|
||||
}
|
||||
|
||||
|
|
@ -47,6 +60,11 @@ func (w *FSWatcher) Add(name string) (err error) {
|
|||
return w.OnAdd(name)
|
||||
}
|
||||
|
||||
// Remove implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||
func (w *FSWatcher) Remove(name string) (err error) {
|
||||
return w.OnRemove(name)
|
||||
}
|
||||
|
||||
// ServiceWithConfig is a fake [nextagh.ServiceWithConfig] implementation for
|
||||
// tests.
|
||||
type ServiceWithConfig[ConfigType any] struct {
|
||||
|
|
|
|||
135
internal/aghtls/defaultmanager.go
Normal file
135
internal/aghtls/defaultmanager.go
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
package aghtls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
)
|
||||
|
||||
// DefaultManagerConfig is the configuration structure for [NewDefaultManager].
|
||||
type DefaultManagerConfig struct {
|
||||
// Logger is used for logging the operation of the manager. It must not be
|
||||
// nil.
|
||||
Logger *slog.Logger
|
||||
|
||||
// Watcher is used to watch the TLS certificate and key files. It must not
|
||||
// be nil.
|
||||
Watcher aghos.FSWatcher
|
||||
}
|
||||
|
||||
// DefaultManager is the default implementation of the [Manager] interface.
|
||||
//
|
||||
// TODO(e.burkov): Use.
|
||||
type DefaultManager struct {
|
||||
logger *slog.Logger
|
||||
pair *atomic.Pointer[TLSPair]
|
||||
updates chan UpdateSignal
|
||||
watcher aghos.FSWatcher
|
||||
}
|
||||
|
||||
// NewDefaultManager returns a new properly initialized default manager.
|
||||
func NewDefaultManager(c *DefaultManagerConfig) (mgr *DefaultManager) {
|
||||
return &DefaultManager{
|
||||
logger: c.Logger,
|
||||
pair: &atomic.Pointer[TLSPair]{},
|
||||
// Buffer the channel to avoid missing updates.
|
||||
updates: make(chan UpdateSignal, 1),
|
||||
watcher: c.Watcher,
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ Manager = (*DefaultManager)(nil)
|
||||
|
||||
// Set implements the [Manager] interface for *DefaultManager.
|
||||
func (mgr *DefaultManager) Set(ctx context.Context, certKey *TLSPair) (err error) {
|
||||
old := mgr.pair.Swap(certKey)
|
||||
|
||||
if old != nil {
|
||||
err = errors.Join(
|
||||
mgr.watcher.Remove(old.CertPath),
|
||||
mgr.watcher.Remove(old.KeyPath),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing old certificate and key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if certKey != nil {
|
||||
err = errors.Join(
|
||||
mgr.watcher.Add(certKey.CertPath),
|
||||
mgr.watcher.Add(certKey.KeyPath),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adding new certificate and key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Refresh implements the [service.Refresher] interface for *DefaultManager.
|
||||
func (mgr *DefaultManager) Refresh(ctx context.Context) (err error) {
|
||||
select {
|
||||
case mgr.updates <- UpdateSignal{}:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("refreshing: %w", ctx.Err())
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Start implements the [service.Interface] interface for *DefaultManager.
|
||||
func (mgr *DefaultManager) Start(ctx context.Context) (err error) {
|
||||
err = mgr.watcher.Start(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting watcher: %w", err)
|
||||
}
|
||||
|
||||
go mgr.handleEvents(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown implements the [service.Interface] interface for *DefaultManager.
|
||||
func (mgr *DefaultManager) Shutdown(ctx context.Context) (err error) {
|
||||
defer close(mgr.updates)
|
||||
|
||||
err = mgr.watcher.Shutdown(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("shutting down watcher: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Updates implements the [Manager] interface for *DefaultManager.
|
||||
func (mgr *DefaultManager) Updates(ctx context.Context) (updates <-chan UpdateSignal) {
|
||||
return mgr.updates
|
||||
}
|
||||
|
||||
// handleEvents handles changes of the tracked files. It is intended to be run
|
||||
// in a separate goroutine.
|
||||
func (mgr *DefaultManager) handleEvents(ctx context.Context) {
|
||||
defer slogutil.RecoverAndLog(ctx, mgr.logger)
|
||||
|
||||
eventsCh := mgr.watcher.Events()
|
||||
if eventsCh == nil {
|
||||
mgr.logger.DebugContext(ctx, "watcher does not emit events")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for range eventsCh {
|
||||
err := mgr.Refresh(ctx)
|
||||
if err != nil {
|
||||
mgr.logger.ErrorContext(ctx, "refreshing", slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
36
internal/aghtls/manager.go
Normal file
36
internal/aghtls/manager.go
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
package aghtls
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/AdguardTeam/golibs/service"
|
||||
)
|
||||
|
||||
// TLSPair is a pair of paths to a certificate and a key.
|
||||
type TLSPair struct {
|
||||
// CertPath is the path to the certificate.
|
||||
CertPath string
|
||||
|
||||
// KeyPath is the path to the key.
|
||||
KeyPath string
|
||||
}
|
||||
|
||||
// UpdateSignal is the signal that the TLS certificate and key have been
|
||||
// updated.
|
||||
type UpdateSignal struct{}
|
||||
|
||||
// Manager manages TLS certificates and keys updates.
|
||||
type Manager interface {
|
||||
service.Interface
|
||||
service.Refresher
|
||||
|
||||
// Set sets the TLS certificate and key.
|
||||
Set(ctx context.Context, certKey *TLSPair) (err error)
|
||||
|
||||
// Updates returns a channel that emits signals when the TLS certificate
|
||||
// and/or key have been updated.
|
||||
//
|
||||
// TODO(e.burkov): Move reloading logic to the manager and get rid of this
|
||||
// method.
|
||||
Updates(ctx context.Context) (updates <-chan UpdateSignal)
|
||||
}
|
||||
|
|
@ -2,7 +2,6 @@ package dnsforward
|
|||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
|
|
@ -24,6 +23,7 @@ import (
|
|||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
|
|
@ -1516,21 +1516,20 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
|||
}
|
||||
|
||||
var eventsCalledCounter uint32
|
||||
watcher := aghtest.NewFSWatcher()
|
||||
watcher.OnEvents = func() (e <-chan aghos.Event) {
|
||||
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
|
||||
|
||||
return nil
|
||||
}
|
||||
watcher.OnAdd = func(name string) (err error) {
|
||||
assert.Equal(t, hostsFilename, name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, &aghtest.FSWatcher{
|
||||
OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) },
|
||||
OnEvents: func() (e <-chan struct{}) {
|
||||
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
|
||||
|
||||
return nil
|
||||
},
|
||||
OnAdd: func(name string) (err error) {
|
||||
assert.Equal(t, hostsFilename, name)
|
||||
|
||||
return nil
|
||||
},
|
||||
OnShutdown: func(ctx context.Context) (err error) { panic(testutil.UnexpectedCall(ctx)) },
|
||||
}, hostsFilename)
|
||||
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, watcher, hostsFilename)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter))
|
||||
|
|
|
|||
|
|
@ -366,6 +366,11 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
|||
Host: netutil.JoinHostPort(upstreamHost, hostsListener.Port()),
|
||||
}).String()
|
||||
|
||||
watcher := aghtest.NewFSWatcher()
|
||||
watcher.OnEvents = func() (e <-chan struct{}) { return nil }
|
||||
watcher.OnAdd = func(_ string) (err error) { return nil }
|
||||
watcher.OnShutdown = func(_ context.Context) (err error) { return nil }
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
hc, err := aghnet.NewHostsContainer(
|
||||
ctx,
|
||||
|
|
@ -375,12 +380,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
|||
Data: []byte(hostsListener.Addr().String() + " " + upstreamHost),
|
||||
},
|
||||
},
|
||||
&aghtest.FSWatcher{
|
||||
OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) },
|
||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||
OnAdd: func(_ string) (err error) { return nil },
|
||||
OnShutdown: func(_ context.Context) (err error) { return nil },
|
||||
},
|
||||
watcher,
|
||||
hostsFileName,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
|
|
|||
|
|
@ -42,12 +42,10 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
|
|||
Data: []byte(data),
|
||||
},
|
||||
}
|
||||
watcher := &aghtest.FSWatcher{
|
||||
OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) },
|
||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||
OnAdd: func(name string) (err error) { return nil },
|
||||
OnShutdown: func(_ context.Context) (err error) { return nil },
|
||||
}
|
||||
watcher := aghtest.NewFSWatcher()
|
||||
watcher.OnEvents = func() (e <-chan struct{}) { return nil }
|
||||
watcher.OnAdd = func(name string) (err error) { return nil }
|
||||
watcher.OnShutdown = func(_ context.Context) (err error) { return nil }
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
hc, err := aghnet.NewHostsContainer(ctx, testLogger, files, watcher, "hosts")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue