mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-05-13 13:57:05 +00:00
dns: Fix conn pool leak
This commit is contained in:
parent
2e1a7a53ae
commit
7b3a1de7bc
1 changed files with 88 additions and 235 deletions
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
)
|
||||
|
|
@ -53,19 +52,6 @@ type connPoolConnect[T comparable] struct {
|
|||
err error
|
||||
}
|
||||
|
||||
type connPoolDialContext struct {
|
||||
context.Context
|
||||
parent context.Context
|
||||
}
|
||||
|
||||
func (c connPoolDialContext) Deadline() (time.Time, bool) {
|
||||
return c.parent.Deadline()
|
||||
}
|
||||
|
||||
func (c connPoolDialContext) Value(key any) any {
|
||||
return c.parent.Value(key)
|
||||
}
|
||||
|
||||
func NewConnPool[T comparable](options ConnPoolOptions[T]) *ConnPool[T] {
|
||||
return &ConnPool[T]{
|
||||
options: options,
|
||||
|
|
@ -108,67 +94,27 @@ func (p *ConnPool[T]) AcquireShared(ctx context.Context, dial func(context.Conte
|
|||
}
|
||||
|
||||
func (p *ConnPool[T]) Release(conn T, reuse bool) {
|
||||
var (
|
||||
closeConn bool
|
||||
closeErr error
|
||||
)
|
||||
|
||||
p.access.Lock()
|
||||
if p.closed || p.state == nil {
|
||||
closeConn = true
|
||||
closeErr = net.ErrClosed
|
||||
if p.closed {
|
||||
p.access.Unlock()
|
||||
if closeConn {
|
||||
p.options.Close(conn, closeErr)
|
||||
}
|
||||
p.options.Close(conn, net.ErrClosed)
|
||||
return
|
||||
}
|
||||
|
||||
currentState := p.state
|
||||
_, tracked := currentState.all[conn]
|
||||
if !tracked {
|
||||
closeConn = true
|
||||
closeErr = p.closeCause(currentState)
|
||||
state := p.state
|
||||
if _, tracked := state.all[conn]; !tracked {
|
||||
p.access.Unlock()
|
||||
if closeConn {
|
||||
p.options.Close(conn, closeErr)
|
||||
}
|
||||
p.options.Close(conn, net.ErrClosed)
|
||||
return
|
||||
}
|
||||
|
||||
if !reuse || !p.options.IsAlive(conn) {
|
||||
delete(currentState.all, conn)
|
||||
switch p.options.Mode {
|
||||
case ConnPoolSingle:
|
||||
if currentState.hasShared && currentState.shared == conn {
|
||||
var zero T
|
||||
currentState.shared = zero
|
||||
currentState.hasShared = false
|
||||
currentState.sharedClaimed = false
|
||||
currentState.sharedCtx = nil
|
||||
if currentState.sharedCancel != nil {
|
||||
currentState.sharedCancel(net.ErrClosed)
|
||||
currentState.sharedCancel = nil
|
||||
}
|
||||
}
|
||||
case ConnPoolOrdered:
|
||||
if element, loaded := currentState.idleElements[conn]; loaded {
|
||||
currentState.idle.Remove(element)
|
||||
delete(currentState.idleElements, conn)
|
||||
}
|
||||
}
|
||||
closeConn = true
|
||||
closeErr = net.ErrClosed
|
||||
p.removeConn(state, conn, net.ErrClosed)
|
||||
p.access.Unlock()
|
||||
if closeConn {
|
||||
p.options.Close(conn, closeErr)
|
||||
}
|
||||
p.options.Close(conn, net.ErrClosed)
|
||||
return
|
||||
}
|
||||
|
||||
if p.options.Mode == ConnPoolOrdered {
|
||||
if _, loaded := currentState.idleElements[conn]; !loaded {
|
||||
currentState.idleElements[conn] = currentState.idle.PushBack(conn)
|
||||
if _, loaded := state.idleElements[conn]; !loaded {
|
||||
state.idleElements[conn] = state.idle.PushBack(conn)
|
||||
}
|
||||
}
|
||||
p.access.Unlock()
|
||||
|
|
@ -176,42 +122,43 @@ func (p *ConnPool[T]) Release(conn T, reuse bool) {
|
|||
|
||||
func (p *ConnPool[T]) Invalidate(conn T, cause error) {
|
||||
p.access.Lock()
|
||||
if p.closed || p.state == nil {
|
||||
if p.closed {
|
||||
p.access.Unlock()
|
||||
p.options.Close(conn, cause)
|
||||
return
|
||||
}
|
||||
|
||||
currentState := p.state
|
||||
_, tracked := currentState.all[conn]
|
||||
if !tracked {
|
||||
state := p.state
|
||||
if _, tracked := state.all[conn]; !tracked {
|
||||
p.access.Unlock()
|
||||
return
|
||||
}
|
||||
p.removeConn(state, conn, cause)
|
||||
p.access.Unlock()
|
||||
p.options.Close(conn, cause)
|
||||
}
|
||||
|
||||
delete(currentState.all, conn)
|
||||
// removeConn must be called with p.access held.
|
||||
func (p *ConnPool[T]) removeConn(state *connPoolState[T], conn T, cause error) {
|
||||
delete(state.all, conn)
|
||||
switch p.options.Mode {
|
||||
case ConnPoolSingle:
|
||||
if currentState.hasShared && currentState.shared == conn {
|
||||
if state.hasShared && state.shared == conn {
|
||||
var zero T
|
||||
currentState.shared = zero
|
||||
currentState.hasShared = false
|
||||
currentState.sharedClaimed = false
|
||||
currentState.sharedCtx = nil
|
||||
if currentState.sharedCancel != nil {
|
||||
currentState.sharedCancel(cause)
|
||||
currentState.sharedCancel = nil
|
||||
state.shared = zero
|
||||
state.hasShared = false
|
||||
state.sharedClaimed = false
|
||||
state.sharedCtx = nil
|
||||
if state.sharedCancel != nil {
|
||||
state.sharedCancel(cause)
|
||||
state.sharedCancel = nil
|
||||
}
|
||||
}
|
||||
case ConnPoolOrdered:
|
||||
if element, loaded := currentState.idleElements[conn]; loaded {
|
||||
currentState.idle.Remove(element)
|
||||
delete(currentState.idleElements, conn)
|
||||
if element, loaded := state.idleElements[conn]; loaded {
|
||||
state.idle.Remove(element)
|
||||
delete(state.idleElements, conn)
|
||||
}
|
||||
}
|
||||
p.access.Unlock()
|
||||
|
||||
p.options.Close(conn, cause)
|
||||
}
|
||||
|
||||
func (p *ConnPool[T]) Reset() {
|
||||
|
|
@ -220,7 +167,6 @@ func (p *ConnPool[T]) Reset() {
|
|||
p.access.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
oldState := p.state
|
||||
p.state = newConnPoolState[T](p.options.Mode)
|
||||
p.access.Unlock()
|
||||
|
|
@ -234,7 +180,6 @@ func (p *ConnPool[T]) Close() error {
|
|||
p.access.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
p.closed = true
|
||||
oldState := p.state
|
||||
p.state = nil
|
||||
|
|
@ -247,40 +192,47 @@ func (p *ConnPool[T]) Close() error {
|
|||
func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Context) (T, error)) (T, bool, error) {
|
||||
var zero T
|
||||
for {
|
||||
var (
|
||||
staleConn T
|
||||
hasStale bool
|
||||
)
|
||||
|
||||
p.access.Lock()
|
||||
if p.closed {
|
||||
p.access.Unlock()
|
||||
return zero, false, net.ErrClosed
|
||||
}
|
||||
|
||||
currentState := p.state
|
||||
if element := currentState.idle.Front(); element != nil {
|
||||
conn := currentState.idle.Remove(element)
|
||||
delete(currentState.idleElements, conn)
|
||||
current := p.state
|
||||
if element := current.idle.Front(); element != nil {
|
||||
conn := current.idle.Remove(element)
|
||||
delete(current.idleElements, conn)
|
||||
if p.options.IsAlive(conn) {
|
||||
p.access.Unlock()
|
||||
return conn, false, nil
|
||||
}
|
||||
delete(currentState.all, conn)
|
||||
staleConn = conn
|
||||
hasStale = true
|
||||
delete(current.all, conn)
|
||||
p.access.Unlock()
|
||||
p.options.Close(conn, net.ErrClosed)
|
||||
continue
|
||||
}
|
||||
p.access.Unlock()
|
||||
|
||||
if hasStale {
|
||||
p.options.Close(staleConn, net.ErrClosed)
|
||||
continue
|
||||
dialCtx, dialCancel := context.WithCancelCause(ctx)
|
||||
stopStateCancel := context.AfterFunc(current.ctx, func() {
|
||||
dialCancel(context.Cause(current.ctx))
|
||||
})
|
||||
conn, err := dial(dialCtx)
|
||||
stateCancelStopped := stopStateCancel()
|
||||
dialErr := context.Cause(dialCtx)
|
||||
if dialErr == nil && !stateCancelStopped {
|
||||
dialErr = context.Cause(current.ctx)
|
||||
}
|
||||
|
||||
conn, err := p.dial(ctx, currentState, dial)
|
||||
dialCancel(nil)
|
||||
if err != nil {
|
||||
if dialErr != nil {
|
||||
return zero, false, dialErr
|
||||
}
|
||||
return zero, false, err
|
||||
}
|
||||
if dialErr != nil {
|
||||
p.options.Close(conn, dialErr)
|
||||
return zero, false, dialErr
|
||||
}
|
||||
|
||||
p.access.Lock()
|
||||
if p.closed {
|
||||
|
|
@ -288,13 +240,12 @@ func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Cont
|
|||
p.options.Close(conn, net.ErrClosed)
|
||||
return zero, false, net.ErrClosed
|
||||
}
|
||||
if p.state != currentState {
|
||||
cause := p.closeCause(currentState)
|
||||
if p.state != current {
|
||||
p.access.Unlock()
|
||||
p.options.Close(conn, cause)
|
||||
return zero, false, cause
|
||||
p.options.Close(conn, net.ErrClosed)
|
||||
return zero, false, net.ErrClosed
|
||||
}
|
||||
currentState.all[conn] = struct{}{}
|
||||
current.all[conn] = struct{}{}
|
||||
p.access.Unlock()
|
||||
return conn, true, nil
|
||||
}
|
||||
|
|
@ -303,21 +254,12 @@ func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Cont
|
|||
func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Context) (T, error)) (T, context.Context, bool, error) {
|
||||
var zero T
|
||||
for {
|
||||
var (
|
||||
staleConn T
|
||||
hasStale bool
|
||||
state *connPoolConnect[T]
|
||||
current *connPoolState[T]
|
||||
startDial bool
|
||||
)
|
||||
|
||||
p.access.Lock()
|
||||
if p.closed {
|
||||
p.access.Unlock()
|
||||
return zero, nil, false, net.ErrClosed
|
||||
}
|
||||
|
||||
current = p.state
|
||||
current := p.state
|
||||
if current.hasShared {
|
||||
conn := current.shared
|
||||
if p.options.IsAlive(conn) {
|
||||
|
|
@ -327,35 +269,19 @@ func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Conte
|
|||
p.access.Unlock()
|
||||
return conn, connCtx, created, nil
|
||||
}
|
||||
delete(current.all, conn)
|
||||
var zeroConn T
|
||||
current.shared = zeroConn
|
||||
current.hasShared = false
|
||||
current.sharedClaimed = false
|
||||
current.sharedCtx = nil
|
||||
if current.sharedCancel != nil {
|
||||
current.sharedCancel(net.ErrClosed)
|
||||
current.sharedCancel = nil
|
||||
}
|
||||
staleConn = conn
|
||||
hasStale = true
|
||||
p.removeConn(current, conn, net.ErrClosed)
|
||||
p.access.Unlock()
|
||||
p.options.Close(staleConn, net.ErrClosed)
|
||||
p.options.Close(conn, net.ErrClosed)
|
||||
continue
|
||||
}
|
||||
|
||||
if current.connecting == nil {
|
||||
current.connecting = &connPoolConnect[T]{
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
startDial = true
|
||||
startDial := current.connecting == nil
|
||||
if startDial {
|
||||
current.connecting = &connPoolConnect[T]{done: make(chan struct{})}
|
||||
}
|
||||
state = current.connecting
|
||||
state := current.connecting
|
||||
p.access.Unlock()
|
||||
|
||||
if hasStale {
|
||||
continue
|
||||
}
|
||||
if startDial {
|
||||
go p.connectSingle(current, state, ctx, dial)
|
||||
}
|
||||
|
|
@ -381,35 +307,39 @@ func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Conte
|
|||
}
|
||||
|
||||
func (p *ConnPool[T]) connectSingle(current *connPoolState[T], state *connPoolConnect[T], ctx context.Context, dial func(context.Context) (T, error)) {
|
||||
conn, err := p.dial(ctx, current, dial)
|
||||
if err != nil {
|
||||
p.access.Lock()
|
||||
if current.connecting == state {
|
||||
current.connecting = nil
|
||||
dialCtx, dialCancel := context.WithCancelCause(ctx)
|
||||
stopStateCancel := context.AfterFunc(current.ctx, func() {
|
||||
dialCancel(context.Cause(current.ctx))
|
||||
})
|
||||
conn, err := dial(dialCtx)
|
||||
stateCancelStopped := stopStateCancel()
|
||||
dialErr := context.Cause(dialCtx)
|
||||
if dialErr == nil && !stateCancelStopped {
|
||||
dialErr = context.Cause(current.ctx)
|
||||
}
|
||||
dialCancel(nil)
|
||||
if dialErr != nil {
|
||||
if err == nil {
|
||||
p.options.Close(conn, dialErr)
|
||||
}
|
||||
state.err = err
|
||||
p.access.Unlock()
|
||||
close(state.done)
|
||||
return
|
||||
err = dialErr
|
||||
}
|
||||
|
||||
var closeErr error
|
||||
|
||||
p.access.Lock()
|
||||
if current.connecting == state {
|
||||
current.connecting = nil
|
||||
}
|
||||
if p.closed {
|
||||
current.connecting = nil
|
||||
if err != nil {
|
||||
state.err = err
|
||||
} else if p.closed {
|
||||
closeErr = net.ErrClosed
|
||||
state.err = closeErr
|
||||
} else if p.state != current {
|
||||
closeErr = p.closeCause(current)
|
||||
closeErr = net.ErrClosed
|
||||
state.err = closeErr
|
||||
} else {
|
||||
sharedCtx, sharedCancel := context.WithCancelCause(current.ctx)
|
||||
current.shared = conn
|
||||
current.hasShared = true
|
||||
current.sharedClaimed = false
|
||||
current.sharedCtx = sharedCtx
|
||||
current.sharedCancel = sharedCancel
|
||||
current.all[conn] = struct{}{}
|
||||
|
|
@ -439,9 +369,8 @@ func (p *ConnPool[T]) collectShared(current *connPoolState[T], state *connPoolCo
|
|||
return zero, nil, false, false, net.ErrClosed
|
||||
}
|
||||
if p.state != current {
|
||||
cause := p.closeCause(current)
|
||||
p.access.Unlock()
|
||||
return zero, nil, false, false, cause
|
||||
return zero, nil, false, false, net.ErrClosed
|
||||
}
|
||||
if !current.hasShared {
|
||||
p.access.Unlock()
|
||||
|
|
@ -450,16 +379,7 @@ func (p *ConnPool[T]) collectShared(current *connPoolState[T], state *connPoolCo
|
|||
|
||||
conn := current.shared
|
||||
if !p.options.IsAlive(conn) {
|
||||
delete(current.all, conn)
|
||||
var zeroConn T
|
||||
current.shared = zeroConn
|
||||
current.hasShared = false
|
||||
current.sharedClaimed = false
|
||||
current.sharedCtx = nil
|
||||
if current.sharedCancel != nil {
|
||||
current.sharedCancel(net.ErrClosed)
|
||||
current.sharedCancel = nil
|
||||
}
|
||||
p.removeConn(current, conn, net.ErrClosed)
|
||||
p.access.Unlock()
|
||||
p.options.Close(conn, net.ErrClosed)
|
||||
return zero, nil, false, true, nil
|
||||
|
|
@ -472,76 +392,9 @@ func (p *ConnPool[T]) collectShared(current *connPoolState[T], state *connPoolCo
|
|||
return conn, connCtx, created, false, nil
|
||||
}
|
||||
|
||||
func (p *ConnPool[T]) dial(ctx context.Context, current *connPoolState[T], dial func(context.Context) (T, error)) (T, error) {
|
||||
var zero T
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return zero, err
|
||||
}
|
||||
if cause := context.Cause(current.ctx); cause != nil {
|
||||
return zero, cause
|
||||
}
|
||||
|
||||
dialCtx, cancel := context.WithCancelCause(current.ctx)
|
||||
var (
|
||||
stateAccess sync.Mutex
|
||||
dialComplete bool
|
||||
)
|
||||
stopCancel := context.AfterFunc(ctx, func() {
|
||||
stateAccess.Lock()
|
||||
if !dialComplete {
|
||||
cancel(context.Cause(ctx))
|
||||
}
|
||||
stateAccess.Unlock()
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
stateAccess.Lock()
|
||||
dialComplete = true
|
||||
stateAccess.Unlock()
|
||||
stopCancel()
|
||||
cancel(context.Cause(ctx))
|
||||
return zero, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := dial(connPoolDialContext{
|
||||
Context: dialCtx,
|
||||
parent: ctx,
|
||||
})
|
||||
stateAccess.Lock()
|
||||
dialComplete = true
|
||||
stateAccess.Unlock()
|
||||
stopCancel()
|
||||
if err != nil {
|
||||
if cause := context.Cause(dialCtx); cause != nil {
|
||||
return zero, cause
|
||||
}
|
||||
return zero, err
|
||||
}
|
||||
if cause := context.Cause(dialCtx); cause != nil {
|
||||
p.options.Close(conn, cause)
|
||||
return zero, cause
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (p *ConnPool[T]) closeState(state *connPoolState[T], cause error) {
|
||||
if state == nil {
|
||||
return
|
||||
}
|
||||
|
||||
state.cancel(cause)
|
||||
if state.sharedCancel != nil {
|
||||
state.sharedCancel(cause)
|
||||
}
|
||||
for conn := range state.all {
|
||||
p.options.Close(conn, cause)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ConnPool[T]) closeCause(state *connPoolState[T]) error {
|
||||
_ = state
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue