diff --git a/dns/transport/tcp.go b/dns/transport/tcp.go index 59333de8d..f8249437a 100644 --- a/dns/transport/tcp.go +++ b/dns/transport/tcp.go @@ -4,6 +4,8 @@ import ( "context" "encoding/binary" "io" + "net" + "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" @@ -13,6 +15,7 @@ import ( "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio/deadline" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -71,6 +74,7 @@ func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M return nil, E.Cause(err, "dial TCP connection") } defer conn.Close() + defer setConnDeadline(ctx, conn, deadline.NeedAdditionalReadDeadline(conn))() err = WriteMessage(conn, 0, message) if err != nil { return nil, E.Cause(err, "write request") @@ -82,6 +86,20 @@ func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M return response, nil } +func setConnDeadline(ctx context.Context, conn net.Conn, needClose bool) func() { + if needClose { + stop := context.AfterFunc(ctx, func() { + conn.Close() + }) + return func() { stop() } + } + if d, ok := ctx.Deadline(); ok { + conn.SetDeadline(d) + return func() { conn.SetDeadline(time.Time{}) } + } + return func() {} +} + func ReadMessage(reader io.Reader) (*mDNS.Msg, error) { var responseLen uint16 err := binary.Read(reader, binary.BigEndian, &responseLen) diff --git a/dns/transport/tls.go b/dns/transport/tls.go index 43978b6ff..b7ef25fb7 100644 --- a/dns/transport/tls.go +++ b/dns/transport/tls.go @@ -2,7 +2,6 @@ package transport import ( "context" - "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" @@ -12,6 +11,7 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio/deadline" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" @@ -38,7 +38,8 @@ type TLSTransport struct { type tlsDNSConn struct { tls.Conn - queryId uint16 + queryId uint16 + needDeadlineClose bool } func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) { @@ -104,7 +105,10 @@ func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M if err != nil { return nil, E.Cause(err, "dial TLS connection") } - return &tlsDNSConn{Conn: tlsConn}, nil + return &tlsDNSConn{ + Conn: tlsConn, + needDeadlineClose: deadline.NeedAdditionalReadDeadline(tlsConn.NetConn()), + }, nil }) if err != nil { return nil, err @@ -125,9 +129,7 @@ func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M } func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) { - if deadline, ok := ctx.Deadline(); ok { - conn.SetDeadline(deadline) - } + defer setConnDeadline(ctx, conn, conn.needDeadlineClose)() conn.queryId++ err := WriteMessage(conn, conn.queryId, message) if err != nil { @@ -137,6 +139,5 @@ func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tl if err != nil { return nil, E.Cause(err, "read response") } - conn.SetDeadline(time.Time{}) return response, nil } diff --git a/dns/transport/udp.go b/dns/transport/udp.go index c9f520e31..7203b5ad4 100644 --- a/dns/transport/udp.go +++ b/dns/transport/udp.go @@ -13,6 +13,7 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio/deadline" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" @@ -130,6 +131,7 @@ func (t *UDPTransport) exchangeTCP(ctx context.Context, message *mDNS.Msg) (*mDN return nil, E.Cause(err, "dial TCP connection") } defer conn.Close() + defer setConnDeadline(ctx, conn, deadline.NeedAdditionalReadDeadline(conn))() err = WriteMessage(conn, message.Id, message) if err != nil { return nil, E.Cause(err, "write request")