From 7b3a1de7bc94d420ab05fb235c8d68e7756dcf23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 11 May 2026 20:59:49 +0800 Subject: [PATCH] dns: Fix conn pool leak --- dns/transport/conn_pool.go | 323 ++++++++++--------------------------- 1 file changed, 88 insertions(+), 235 deletions(-) diff --git a/dns/transport/conn_pool.go b/dns/transport/conn_pool.go index 6161e9bdb..ff288b773 100644 --- a/dns/transport/conn_pool.go +++ b/dns/transport/conn_pool.go @@ -4,7 +4,6 @@ import ( "context" "net" "sync" - "time" "github.com/sagernet/sing/common/x/list" ) @@ -53,19 +52,6 @@ type connPoolConnect[T comparable] struct { err error } -type connPoolDialContext struct { - context.Context - parent context.Context -} - -func (c connPoolDialContext) Deadline() (time.Time, bool) { - return c.parent.Deadline() -} - -func (c connPoolDialContext) Value(key any) any { - return c.parent.Value(key) -} - func NewConnPool[T comparable](options ConnPoolOptions[T]) *ConnPool[T] { return &ConnPool[T]{ options: options, @@ -108,67 +94,27 @@ func (p *ConnPool[T]) AcquireShared(ctx context.Context, dial func(context.Conte } func (p *ConnPool[T]) Release(conn T, reuse bool) { - var ( - closeConn bool - closeErr error - ) - p.access.Lock() - if p.closed || p.state == nil { - closeConn = true - closeErr = net.ErrClosed + if p.closed { p.access.Unlock() - if closeConn { - p.options.Close(conn, closeErr) - } + p.options.Close(conn, net.ErrClosed) return } - - currentState := p.state - _, tracked := currentState.all[conn] - if !tracked { - closeConn = true - closeErr = p.closeCause(currentState) + state := p.state + if _, tracked := state.all[conn]; !tracked { p.access.Unlock() - if closeConn { - p.options.Close(conn, closeErr) - } + p.options.Close(conn, net.ErrClosed) return } - if !reuse || !p.options.IsAlive(conn) { - delete(currentState.all, conn) - switch p.options.Mode { - case ConnPoolSingle: - if currentState.hasShared && currentState.shared == conn { - var zero T - currentState.shared = zero - currentState.hasShared = false - currentState.sharedClaimed = false - currentState.sharedCtx = nil - if currentState.sharedCancel != nil { - currentState.sharedCancel(net.ErrClosed) - currentState.sharedCancel = nil - } - } - case ConnPoolOrdered: - if element, loaded := currentState.idleElements[conn]; loaded { - currentState.idle.Remove(element) - delete(currentState.idleElements, conn) - } - } - closeConn = true - closeErr = net.ErrClosed + p.removeConn(state, conn, net.ErrClosed) p.access.Unlock() - if closeConn { - p.options.Close(conn, closeErr) - } + p.options.Close(conn, net.ErrClosed) return } - if p.options.Mode == ConnPoolOrdered { - if _, loaded := currentState.idleElements[conn]; !loaded { - currentState.idleElements[conn] = currentState.idle.PushBack(conn) + if _, loaded := state.idleElements[conn]; !loaded { + state.idleElements[conn] = state.idle.PushBack(conn) } } p.access.Unlock() @@ -176,42 +122,43 @@ func (p *ConnPool[T]) Release(conn T, reuse bool) { func (p *ConnPool[T]) Invalidate(conn T, cause error) { p.access.Lock() - if p.closed || p.state == nil { + if p.closed { p.access.Unlock() p.options.Close(conn, cause) return } - - currentState := p.state - _, tracked := currentState.all[conn] - if !tracked { + state := p.state + if _, tracked := state.all[conn]; !tracked { p.access.Unlock() return } + p.removeConn(state, conn, cause) + p.access.Unlock() + p.options.Close(conn, cause) +} - delete(currentState.all, conn) +// removeConn must be called with p.access held. +func (p *ConnPool[T]) removeConn(state *connPoolState[T], conn T, cause error) { + delete(state.all, conn) switch p.options.Mode { case ConnPoolSingle: - if currentState.hasShared && currentState.shared == conn { + if state.hasShared && state.shared == conn { var zero T - currentState.shared = zero - currentState.hasShared = false - currentState.sharedClaimed = false - currentState.sharedCtx = nil - if currentState.sharedCancel != nil { - currentState.sharedCancel(cause) - currentState.sharedCancel = nil + state.shared = zero + state.hasShared = false + state.sharedClaimed = false + state.sharedCtx = nil + if state.sharedCancel != nil { + state.sharedCancel(cause) + state.sharedCancel = nil } } case ConnPoolOrdered: - if element, loaded := currentState.idleElements[conn]; loaded { - currentState.idle.Remove(element) - delete(currentState.idleElements, conn) + if element, loaded := state.idleElements[conn]; loaded { + state.idle.Remove(element) + delete(state.idleElements, conn) } } - p.access.Unlock() - - p.options.Close(conn, cause) } func (p *ConnPool[T]) Reset() { @@ -220,7 +167,6 @@ func (p *ConnPool[T]) Reset() { p.access.Unlock() return } - oldState := p.state p.state = newConnPoolState[T](p.options.Mode) p.access.Unlock() @@ -234,7 +180,6 @@ func (p *ConnPool[T]) Close() error { p.access.Unlock() return nil } - p.closed = true oldState := p.state p.state = nil @@ -247,40 +192,47 @@ func (p *ConnPool[T]) Close() error { func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Context) (T, error)) (T, bool, error) { var zero T for { - var ( - staleConn T - hasStale bool - ) - p.access.Lock() if p.closed { p.access.Unlock() return zero, false, net.ErrClosed } - - currentState := p.state - if element := currentState.idle.Front(); element != nil { - conn := currentState.idle.Remove(element) - delete(currentState.idleElements, conn) + current := p.state + if element := current.idle.Front(); element != nil { + conn := current.idle.Remove(element) + delete(current.idleElements, conn) if p.options.IsAlive(conn) { p.access.Unlock() return conn, false, nil } - delete(currentState.all, conn) - staleConn = conn - hasStale = true + delete(current.all, conn) + p.access.Unlock() + p.options.Close(conn, net.ErrClosed) + continue } p.access.Unlock() - if hasStale { - p.options.Close(staleConn, net.ErrClosed) - continue + dialCtx, dialCancel := context.WithCancelCause(ctx) + stopStateCancel := context.AfterFunc(current.ctx, func() { + dialCancel(context.Cause(current.ctx)) + }) + conn, err := dial(dialCtx) + stateCancelStopped := stopStateCancel() + dialErr := context.Cause(dialCtx) + if dialErr == nil && !stateCancelStopped { + dialErr = context.Cause(current.ctx) } - - conn, err := p.dial(ctx, currentState, dial) + dialCancel(nil) if err != nil { + if dialErr != nil { + return zero, false, dialErr + } return zero, false, err } + if dialErr != nil { + p.options.Close(conn, dialErr) + return zero, false, dialErr + } p.access.Lock() if p.closed { @@ -288,13 +240,12 @@ func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Cont p.options.Close(conn, net.ErrClosed) return zero, false, net.ErrClosed } - if p.state != currentState { - cause := p.closeCause(currentState) + if p.state != current { p.access.Unlock() - p.options.Close(conn, cause) - return zero, false, cause + p.options.Close(conn, net.ErrClosed) + return zero, false, net.ErrClosed } - currentState.all[conn] = struct{}{} + current.all[conn] = struct{}{} p.access.Unlock() return conn, true, nil } @@ -303,21 +254,12 @@ func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Cont func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Context) (T, error)) (T, context.Context, bool, error) { var zero T for { - var ( - staleConn T - hasStale bool - state *connPoolConnect[T] - current *connPoolState[T] - startDial bool - ) - p.access.Lock() if p.closed { p.access.Unlock() return zero, nil, false, net.ErrClosed } - - current = p.state + current := p.state if current.hasShared { conn := current.shared if p.options.IsAlive(conn) { @@ -327,35 +269,19 @@ func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Conte p.access.Unlock() return conn, connCtx, created, nil } - delete(current.all, conn) - var zeroConn T - current.shared = zeroConn - current.hasShared = false - current.sharedClaimed = false - current.sharedCtx = nil - if current.sharedCancel != nil { - current.sharedCancel(net.ErrClosed) - current.sharedCancel = nil - } - staleConn = conn - hasStale = true + p.removeConn(current, conn, net.ErrClosed) p.access.Unlock() - p.options.Close(staleConn, net.ErrClosed) + p.options.Close(conn, net.ErrClosed) continue } - if current.connecting == nil { - current.connecting = &connPoolConnect[T]{ - done: make(chan struct{}), - } - startDial = true + startDial := current.connecting == nil + if startDial { + current.connecting = &connPoolConnect[T]{done: make(chan struct{})} } - state = current.connecting + state := current.connecting p.access.Unlock() - if hasStale { - continue - } if startDial { go p.connectSingle(current, state, ctx, dial) } @@ -381,35 +307,39 @@ func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Conte } func (p *ConnPool[T]) connectSingle(current *connPoolState[T], state *connPoolConnect[T], ctx context.Context, dial func(context.Context) (T, error)) { - conn, err := p.dial(ctx, current, dial) - if err != nil { - p.access.Lock() - if current.connecting == state { - current.connecting = nil + dialCtx, dialCancel := context.WithCancelCause(ctx) + stopStateCancel := context.AfterFunc(current.ctx, func() { + dialCancel(context.Cause(current.ctx)) + }) + conn, err := dial(dialCtx) + stateCancelStopped := stopStateCancel() + dialErr := context.Cause(dialCtx) + if dialErr == nil && !stateCancelStopped { + dialErr = context.Cause(current.ctx) + } + dialCancel(nil) + if dialErr != nil { + if err == nil { + p.options.Close(conn, dialErr) } - state.err = err - p.access.Unlock() - close(state.done) - return + err = dialErr } var closeErr error - p.access.Lock() - if current.connecting == state { - current.connecting = nil - } - if p.closed { + current.connecting = nil + if err != nil { + state.err = err + } else if p.closed { closeErr = net.ErrClosed state.err = closeErr } else if p.state != current { - closeErr = p.closeCause(current) + closeErr = net.ErrClosed state.err = closeErr } else { sharedCtx, sharedCancel := context.WithCancelCause(current.ctx) current.shared = conn current.hasShared = true - current.sharedClaimed = false current.sharedCtx = sharedCtx current.sharedCancel = sharedCancel current.all[conn] = struct{}{} @@ -439,9 +369,8 @@ func (p *ConnPool[T]) collectShared(current *connPoolState[T], state *connPoolCo return zero, nil, false, false, net.ErrClosed } if p.state != current { - cause := p.closeCause(current) p.access.Unlock() - return zero, nil, false, false, cause + return zero, nil, false, false, net.ErrClosed } if !current.hasShared { p.access.Unlock() @@ -450,16 +379,7 @@ func (p *ConnPool[T]) collectShared(current *connPoolState[T], state *connPoolCo conn := current.shared if !p.options.IsAlive(conn) { - delete(current.all, conn) - var zeroConn T - current.shared = zeroConn - current.hasShared = false - current.sharedClaimed = false - current.sharedCtx = nil - if current.sharedCancel != nil { - current.sharedCancel(net.ErrClosed) - current.sharedCancel = nil - } + p.removeConn(current, conn, net.ErrClosed) p.access.Unlock() p.options.Close(conn, net.ErrClosed) return zero, nil, false, true, nil @@ -472,76 +392,9 @@ func (p *ConnPool[T]) collectShared(current *connPoolState[T], state *connPoolCo return conn, connCtx, created, false, nil } -func (p *ConnPool[T]) dial(ctx context.Context, current *connPoolState[T], dial func(context.Context) (T, error)) (T, error) { - var zero T - - if err := ctx.Err(); err != nil { - return zero, err - } - if cause := context.Cause(current.ctx); cause != nil { - return zero, cause - } - - dialCtx, cancel := context.WithCancelCause(current.ctx) - var ( - stateAccess sync.Mutex - dialComplete bool - ) - stopCancel := context.AfterFunc(ctx, func() { - stateAccess.Lock() - if !dialComplete { - cancel(context.Cause(ctx)) - } - stateAccess.Unlock() - }) - - select { - case <-ctx.Done(): - stateAccess.Lock() - dialComplete = true - stateAccess.Unlock() - stopCancel() - cancel(context.Cause(ctx)) - return zero, ctx.Err() - default: - } - - conn, err := dial(connPoolDialContext{ - Context: dialCtx, - parent: ctx, - }) - stateAccess.Lock() - dialComplete = true - stateAccess.Unlock() - stopCancel() - if err != nil { - if cause := context.Cause(dialCtx); cause != nil { - return zero, cause - } - return zero, err - } - if cause := context.Cause(dialCtx); cause != nil { - p.options.Close(conn, cause) - return zero, cause - } - return conn, nil -} - func (p *ConnPool[T]) closeState(state *connPoolState[T], cause error) { - if state == nil { - return - } - state.cancel(cause) - if state.sharedCancel != nil { - state.sharedCancel(cause) - } for conn := range state.all { p.options.Close(conn, cause) } } - -func (p *ConnPool[T]) closeCause(state *connPoolState[T]) error { - _ = state - return net.ErrClosed -}