diff --git a/protocol/tailscale/endpoint.go b/protocol/tailscale/endpoint.go index f265a470d..59e52079b 100644 --- a/protocol/tailscale/endpoint.go +++ b/protocol/tailscale/endpoint.go @@ -110,6 +110,7 @@ type Endpoint struct { systemInterfaceName string systemInterfaceMTU uint32 serverStarted bool + started atomic.Bool systemTun tun.Tun systemDialer *dialer.DefaultDialer fallbackTCPCloser func() @@ -429,6 +430,7 @@ func (t *Endpoint) postStart() error { } t.filter = localBackend.ExportFilter() go t.watchState() + t.started.Store(true) return nil } @@ -492,6 +494,7 @@ func (t *Endpoint) watchState() { func (t *Endpoint) Close() error { var err error + t.started.Store(false) if t.serverStarted { err = common.Close(common.PtrOrNil(t.server)) t.serverStarted = false @@ -516,6 +519,9 @@ func (t *Endpoint) DialContext(ctx context.Context, network string, destination case N.NetworkUDP: t.logger.InfoContext(ctx, "outbound packet connection to ", destination) } + if !t.started.Load() { + return nil, E.New("Tailscale is not ready yet") + } if destination.IsDomain() { destinationAddresses, err := t.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{}) if err != nil { @@ -572,6 +578,9 @@ func (t *Endpoint) DialContext(ctx context.Context, network string, destination } func (t *Endpoint) listenPacketWithAddress(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + if !t.started.Load() { + return nil, E.New("Tailscale is not ready yet") + } if t.systemDialer != nil { return t.systemDialer.ListenPacket(ctx, destination) } @@ -639,6 +648,9 @@ func (t *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (n } func (t *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { + if !t.started.Load() { + return nil, E.New("Tailscale is not ready yet") + } tsFilter := t.filter.Load() if tsFilter != nil { var ipProto ipproto.Proto @@ -732,6 +744,9 @@ func (t *Endpoint) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, } func (t *Endpoint) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { + if !t.started.Load() { + return nil, E.New("Tailscale is not ready yet") + } ctx := log.ContextWithNewID(t.ctx) var destination tun.DirectRouteDestination var err error diff --git a/protocol/wireguard/endpoint.go b/protocol/wireguard/endpoint.go index 9fdc4814a..2975b05cb 100644 --- a/protocol/wireguard/endpoint.go +++ b/protocol/wireguard/endpoint.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/netip" + "sync/atomic" "time" "github.com/sagernet/sing-box/adapter" @@ -41,6 +42,7 @@ type Endpoint struct { logger logger.ContextLogger localAddresses []netip.Prefix endpoint *wireguard.Endpoint + started atomic.Bool } func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardEndpointOptions) (adapter.Endpoint, error) { @@ -120,16 +122,24 @@ func (w *Endpoint) Start(stage adapter.StartStage) error { case adapter.StartStateStart: return w.endpoint.Start(false) case adapter.StartStatePostStart: - return w.endpoint.Start(true) + err := w.endpoint.Start(true) + if err != nil { + return err + } + w.started.Store(true) } return nil } func (w *Endpoint) Close() error { + w.started.Store(false) return w.endpoint.Close() } func (w *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { + if !w.started.Load() { + return nil, E.New("WireGuard is not ready yet") + } var ipVersion uint8 if !destination.IsIPv6() { ipVersion = 4 @@ -210,6 +220,9 @@ func (w *Endpoint) DialContext(ctx context.Context, network string, destination case N.NetworkUDP: w.logger.InfoContext(ctx, "outbound packet connection to ", destination) } + if !w.started.Load() { + return nil, E.New("WireGuard is not ready yet") + } if destination.IsDomain() { destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{}) if err != nil { @@ -224,6 +237,9 @@ func (w *Endpoint) DialContext(ctx context.Context, network string, destination func (w *Endpoint) ListenPacketWithDestination(ctx context.Context, destination M.Socksaddr) (net.PacketConn, netip.Addr, error) { w.logger.InfoContext(ctx, "outbound packet connection to ", destination) + if !w.started.Load() { + return nil, netip.Addr{}, E.New("WireGuard is not ready yet") + } if destination.IsDomain() { destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{}) if err != nil { @@ -257,9 +273,15 @@ func (w *Endpoint) PreferredDomain(domain string) bool { } func (w *Endpoint) PreferredAddress(address netip.Addr) bool { + if !w.started.Load() { + return false + } return w.endpoint.Lookup(address) != nil } func (w *Endpoint) NewDirectRouteConnection(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) { + if !w.started.Load() { + return nil, E.New("WireGuard is not ready yet") + } return w.endpoint.NewDirectRouteConnection(metadata, routeContext, timeout) }