mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-05-13 13:57:05 +00:00
dns: Fix deadline
This commit is contained in:
parent
31252a7e95
commit
228eb2df78
3 changed files with 28 additions and 7 deletions
|
|
@ -4,6 +4,8 @@ import (
|
|||
"context"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
|
|
@ -13,6 +15,7 @@ import (
|
|||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio/deadline"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
|
@ -71,6 +74,7 @@ func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
|
|||
return nil, E.Cause(err, "dial TCP connection")
|
||||
}
|
||||
defer conn.Close()
|
||||
defer setConnDeadline(ctx, conn, deadline.NeedAdditionalReadDeadline(conn))()
|
||||
err = WriteMessage(conn, 0, message)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "write request")
|
||||
|
|
@ -82,6 +86,20 @@ func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
|
|||
return response, nil
|
||||
}
|
||||
|
||||
func setConnDeadline(ctx context.Context, conn net.Conn, needClose bool) func() {
|
||||
if needClose {
|
||||
stop := context.AfterFunc(ctx, func() {
|
||||
conn.Close()
|
||||
})
|
||||
return func() { stop() }
|
||||
}
|
||||
if d, ok := ctx.Deadline(); ok {
|
||||
conn.SetDeadline(d)
|
||||
return func() { conn.SetDeadline(time.Time{}) }
|
||||
}
|
||||
return func() {}
|
||||
}
|
||||
|
||||
func ReadMessage(reader io.Reader) (*mDNS.Msg, error) {
|
||||
var responseLen uint16
|
||||
err := binary.Read(reader, binary.BigEndian, &responseLen)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package transport
|
|||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
|
|
@ -12,6 +11,7 @@ import (
|
|||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/bufio/deadline"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
|
@ -38,7 +38,8 @@ type TLSTransport struct {
|
|||
|
||||
type tlsDNSConn struct {
|
||||
tls.Conn
|
||||
queryId uint16
|
||||
queryId uint16
|
||||
needDeadlineClose bool
|
||||
}
|
||||
|
||||
func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
|
||||
|
|
@ -104,7 +105,10 @@ func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
|
|||
if err != nil {
|
||||
return nil, E.Cause(err, "dial TLS connection")
|
||||
}
|
||||
return &tlsDNSConn{Conn: tlsConn}, nil
|
||||
return &tlsDNSConn{
|
||||
Conn: tlsConn,
|
||||
needDeadlineClose: deadline.NeedAdditionalReadDeadline(tlsConn.NetConn()),
|
||||
}, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -125,9 +129,7 @@ func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
|
|||
}
|
||||
|
||||
func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
conn.SetDeadline(deadline)
|
||||
}
|
||||
defer setConnDeadline(ctx, conn, conn.needDeadlineClose)()
|
||||
conn.queryId++
|
||||
err := WriteMessage(conn, conn.queryId, message)
|
||||
if err != nil {
|
||||
|
|
@ -137,6 +139,5 @@ func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tl
|
|||
if err != nil {
|
||||
return nil, E.Cause(err, "read response")
|
||||
}
|
||||
conn.SetDeadline(time.Time{})
|
||||
return response, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio/deadline"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
|
@ -130,6 +131,7 @@ func (t *UDPTransport) exchangeTCP(ctx context.Context, message *mDNS.Msg) (*mDN
|
|||
return nil, E.Cause(err, "dial TCP connection")
|
||||
}
|
||||
defer conn.Close()
|
||||
defer setConnDeadline(ctx, conn, deadline.NeedAdditionalReadDeadline(conn))()
|
||||
err = WriteMessage(conn, message.Id, message)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "write request")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue