diff --git a/constant/proxy.go b/constant/proxy.go index 868a3bb85..3b1196dc6 100644 --- a/constant/proxy.go +++ b/constant/proxy.go @@ -33,6 +33,8 @@ const ( TypeOCM = "ocm" TypeOOMKiller = "oom-killer" TypeHysteriaRealm = "hysteria-realm" + TypeUSBIPServer = "usbip-server" + TypeUSBIPClient = "usbip-client" TypeACME = "acme" TypeCloudflareOriginCA = "cloudflare-origin-ca" ) diff --git a/include/registry.go b/include/registry.go index 2d478bfe8..596ff18a7 100644 --- a/include/registry.go +++ b/include/registry.go @@ -39,6 +39,7 @@ import ( originca "github.com/sagernet/sing-box/service/origin_ca" "github.com/sagernet/sing-box/service/resolved" "github.com/sagernet/sing-box/service/ssmapi" + "github.com/sagernet/sing-box/service/usbip" E "github.com/sagernet/sing/common/exceptions" ) @@ -135,6 +136,7 @@ func ServiceRegistry() *service.Registry { resolved.RegisterService(registry) ssmapi.RegisterService(registry) + usbip.RegisterService(registry) registerQUICServices(registry) registerDERPService(registry) diff --git a/option/usbip.go b/option/usbip.go new file mode 100644 index 000000000..e2a7be199 --- /dev/null +++ b/option/usbip.go @@ -0,0 +1,84 @@ +package option + +import ( + "fmt" + "strconv" + "strings" + + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" +) + +// USBIPHexUint16 is a uint16 that accepts either a JSON integer or a hex +// string ("0x1d6b", "1d6b", "0X1D6B") on unmarshal, and emits a hex string +// on marshal. Zero means "unset". +type USBIPHexUint16 uint16 + +func (h USBIPHexUint16) MarshalJSON() ([]byte, error) { + if h == 0 { + return []byte(`""`), nil + } + return json.Marshal(fmt.Sprintf("0x%04x", uint16(h))) +} + +func (h *USBIPHexUint16) UnmarshalJSON(data []byte) error { + var asNumber uint64 + if err := json.Unmarshal(data, &asNumber); err == nil { + if asNumber > 0xffff { + return E.New("usb id out of uint16 range: ", asNumber) + } + *h = USBIPHexUint16(asNumber) + return nil + } + var asString string + if err := json.Unmarshal(data, &asString); err != nil { + return E.Cause(err, "parse usb id") + } + asString = strings.TrimSpace(asString) + if asString == "" { + *h = 0 + return nil + } + parsed, err := strconv.ParseUint(asString, 0, 16) + if err != nil { + // Allow bare hex without 0x prefix. + if parsed2, err2 := strconv.ParseUint(asString, 16, 16); err2 == nil { + *h = USBIPHexUint16(parsed2) + return nil + } + return E.Cause(err, "parse usb id ", asString) + } + *h = USBIPHexUint16(parsed) + return nil +} + +// USBIPDeviceMatch selects a USB device. Non-zero fields AND together. +// An all-zero match is rejected at service construction time. +type USBIPDeviceMatch struct { + BusID string `json:"busid,omitempty"` + VendorID USBIPHexUint16 `json:"vendor_id,omitempty"` + ProductID USBIPHexUint16 `json:"product_id,omitempty"` + Serial string `json:"serial,omitempty"` +} + +func (m USBIPDeviceMatch) IsZero() bool { + return m.BusID == "" && m.VendorID == 0 && m.ProductID == 0 && m.Serial == "" +} + +// USBIPServerServiceOptions configures a usbip-server service. It listens on +// TCP (default :3240) and binds matching local USB devices to the usbip-host +// kernel driver for export. Empty Devices means export nothing. +type USBIPServerServiceOptions struct { + ListenOptions + Devices []USBIPDeviceMatch `json:"devices,omitempty"` +} + +// USBIPClientServiceOptions configures a usbip-client service. It connects to +// one remote usbip server and attaches matching remote USB devices to the +// local kernel via vhci_hcd. Empty Devices means import every device the +// remote currently exports. +type USBIPClientServiceOptions struct { + ServerOptions + DialerOptions + Devices []USBIPDeviceMatch `json:"devices,omitempty"` +} diff --git a/service/usbip/client_linux.go b/service/usbip/client_linux.go new file mode 100644 index 000000000..7ce85b128 --- /dev/null +++ b/service/usbip/client_linux.go @@ -0,0 +1,338 @@ +//go:build linux + +package usbip + +import ( + "context" + "encoding/binary" + "net" + "sync" + "time" + + "github.com/sagernet/sing-box/adapter" + boxService "github.com/sagernet/sing-box/adapter/service" + "github.com/sagernet/sing-box/common/dialer" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +const clientReconnectDelay = 5 * time.Second + +type ClientService struct { + boxService.Adapter + ctx context.Context + cancel context.CancelFunc + logger log.ContextLogger + dialer N.Dialer + serverAddr M.Socksaddr + matches []option.USBIPDeviceMatch // empty = import all remote exports + + attachMu sync.Mutex // serializes vhci port pick + attach + wg sync.WaitGroup + + portsMu sync.Mutex + ports map[int]struct{} +} + +func NewClientService(ctx context.Context, logger log.ContextLogger, tag string, options option.USBIPClientServiceOptions) (adapter.Service, error) { + for i, m := range options.Devices { + if m.IsZero() { + return nil, E.New("devices[", i, "]: at least one of busid/vendor_id/product_id/serial is required") + } + } + if options.ServerPort == 0 { + options.ServerPort = DefaultPort + } + if options.Server == "" { + return nil, E.New("missing server address") + } + outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerOptions.ServerIsDomain()) + if err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(ctx) + return &ClientService{ + Adapter: boxService.NewAdapter(C.TypeUSBIPClient, tag), + ctx: ctx, + cancel: cancel, + logger: logger, + dialer: outboundDialer, + serverAddr: options.ServerOptions.Build(), + matches: options.Devices, + ports: make(map[int]struct{}), + }, nil +} + +func (c *ClientService) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + if err := ensureVHCI(); err != nil { + return err + } + c.wg.Add(1) + go c.run() + return nil +} + +func (c *ClientService) Close() error { + if c.cancel != nil { + c.cancel() + } + // Wait for workers to detach, bounded by 5s. + done := make(chan struct{}) + go func() { + c.wg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(5 * time.Second): + c.logger.Warn("shutdown timeout; some vhci ports may remain attached") + } + return nil +} + +// run resolves the desired busid set (once) and spawns a worker per busid. +func (c *ClientService) run() { + defer c.wg.Done() + busids := c.resolveBusIDs() + if len(busids) == 0 { + c.logger.Warn("no devices to import; client idle") + return + } + for _, busid := range busids { + c.wg.Add(1) + go c.worker(busid) + } +} + +// resolveBusIDs connects once, issues OP_REQ_DEVLIST, and returns the busids +// to attach. For the empty-matches case, returns every remote busid. +// For filtered mode, returns the first match per criterion. +func (c *ClientService) resolveBusIDs() []string { + // Busid-only matches don't require enumeration. + if len(c.matches) > 0 && everyMatchBusIDOnly(c.matches) { + out := make([]string, 0, len(c.matches)) + for _, m := range c.matches { + out = append(out, m.BusID) + } + return dedupe(out) + } + for { + if err := c.ctx.Err(); err != nil { + return nil + } + entries, err := c.fetchDevList() + if err != nil { + c.logger.Error("enumerate ", c.serverAddr, ": ", err) + if !sleepCtx(c.ctx, clientReconnectDelay) { + return nil + } + continue + } + if len(c.matches) == 0 { + out := make([]string, 0, len(entries)) + for i := range entries { + out = append(out, entries[i].Info.BusIDString()) + } + return out + } + var out []string + for _, m := range c.matches { + picked := "" + for i := range entries { + key := DeviceKey{ + BusID: entries[i].Info.BusIDString(), + VendorID: entries[i].Info.IDVendor, + ProductID: entries[i].Info.IDProduct, + Serial: entries[i].Info.SerialString(), + } + if Matches(m, key) { + picked = key.BusID + break + } + } + if picked == "" { + c.logger.Warn("no remote device matched ", describeMatch(m)) + continue + } + out = append(out, picked) + } + return dedupe(out) + } +} + +func (c *ClientService) fetchDevList() ([]DeviceEntry, error) { + conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, c.serverAddr) + if err != nil { + return nil, err + } + defer conn.Close() + if err := binary.Write(conn, binary.BigEndian, OpHeader{Version: ProtocolVersion, Code: OpReqDevList, Status: OpStatusOK}); err != nil { + return nil, E.Cause(err, "send OP_REQ_DEVLIST") + } + header, err := ReadOpHeader(conn) + if err != nil { + return nil, E.Cause(err, "read OP_REP_DEVLIST header") + } + if header.Code != OpRepDevList || header.Status != OpStatusOK { + return nil, E.New("OP_REP_DEVLIST status=", header.Status, " code=0x", hex16(header.Code)) + } + return ReadOpRepDevListBody(conn) +} + +// worker keeps one remote busid attached to vhci_hcd.0. On any error or +// kernel-side detach, waits clientReconnectDelay and retries. +func (c *ClientService) worker(busid string) { + defer c.wg.Done() + for { + if err := c.ctx.Err(); err != nil { + return + } + port, err := c.attemptAttach(busid) + if err != nil { + c.logger.Error("attach ", busid, ": ", err) + if !sleepCtx(c.ctx, clientReconnectDelay) { + return + } + continue + } + c.logger.Info("attached ", busid, " → vhci port ", port) + c.trackPort(port, true) + c.watchPort(port, busid) + c.trackPort(port, false) + if err := c.ctx.Err(); err != nil { + return + } + c.logger.Info("vhci port ", port, " released; reattaching ", busid) + if !sleepCtx(c.ctx, clientReconnectDelay) { + return + } + } +} + +// attemptAttach performs one dial → OP_REQ_IMPORT → vhci attach sequence. +// The returned TCP socket is handed to the kernel on success; on failure the +// connection is closed before return. +func (c *ClientService) attemptAttach(busid string) (int, error) { + conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, c.serverAddr) + if err != nil { + return -1, E.Cause(err, "dial ", c.serverAddr) + } + success := false + defer func() { + if !success { + conn.Close() + } + }() + if err := WriteOpReqImport(conn, busid); err != nil { + return -1, E.Cause(err, "write OP_REQ_IMPORT") + } + header, err := ReadOpHeader(conn) + if err != nil { + return -1, E.Cause(err, "read OP_REP_IMPORT header") + } + if header.Code != OpRepImport { + return -1, E.New("unexpected reply code 0x", hex16(header.Code)) + } + if header.Status != OpStatusOK { + return -1, E.New("remote rejected import (status=", header.Status, ")") + } + info, err := ReadOpRepImportBody(conn) + if err != nil { + return -1, E.Cause(err, "read OP_REP_IMPORT body") + } + tcp, ok := conn.(*net.TCPConn) + if !ok { + return -1, E.New("dialed conn is not *net.TCPConn (type=", conn, ")") + } + file, err := tcp.File() + if err != nil { + return -1, E.Cause(err, "dup socket fd") + } + defer file.Close() + c.attachMu.Lock() + defer c.attachMu.Unlock() + port, err := vhciPickFreePort(info.Speed) + if err != nil { + return -1, err + } + if err := vhciAttach(port, file.Fd(), info.DevID(), info.Speed); err != nil { + return -1, E.Cause(err, "vhci attach") + } + success = true + return port, nil +} + +// watchPort polls vhci status every 2s and returns when the port is no longer +// in VDEV_ST_USED, or when ctx is canceled (in which case it detaches the port). +func (c *ClientService) watchPort(port int, busid string) { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-c.ctx.Done(): + if err := vhciDetach(port); err != nil { + c.logger.Warn("detach port ", port, " (", busid, "): ", err) + } + return + case <-ticker.C: + used, err := vhciPortUsed(port) + if err != nil { + c.logger.Debug("poll port ", port, ": ", err) + continue + } + if !used { + return + } + } + } +} + +func (c *ClientService) trackPort(port int, add bool) { + c.portsMu.Lock() + defer c.portsMu.Unlock() + if add { + c.ports[port] = struct{}{} + } else { + delete(c.ports, port) + } +} + +func everyMatchBusIDOnly(matches []option.USBIPDeviceMatch) bool { + for _, m := range matches { + if m.BusID == "" || m.VendorID != 0 || m.ProductID != 0 || m.Serial != "" { + return false + } + } + return true +} + +func dedupe(in []string) []string { + seen := make(map[string]struct{}, len(in)) + out := make([]string, 0, len(in)) + for _, s := range in { + if _, ok := seen[s]; ok { + continue + } + seen[s] = struct{}{} + out = append(out, s) + } + return out +} + +func sleepCtx(ctx context.Context, d time.Duration) bool { + t := time.NewTimer(d) + defer t.Stop() + select { + case <-ctx.Done(): + return false + case <-t.C: + return true + } +} diff --git a/service/usbip/match.go b/service/usbip/match.go new file mode 100644 index 000000000..54f34f055 --- /dev/null +++ b/service/usbip/match.go @@ -0,0 +1,37 @@ +package usbip + +import ( + "github.com/sagernet/sing-box/option" +) + +// DeviceKey is the minimal set of fields needed to evaluate a USBIPDeviceMatch. +// Both the local sysfs enumerator and the remote OP_REP_DEVLIST parser populate +// this before running matches. +type DeviceKey struct { + BusID string + VendorID uint16 + ProductID uint16 + Serial string +} + +// Matches reports whether d satisfies m. Non-zero fields AND together; an +// all-zero match is treated as non-matching (callers should reject such +// configs earlier). +func Matches(m option.USBIPDeviceMatch, d DeviceKey) bool { + if m.IsZero() { + return false + } + if m.BusID != "" && m.BusID != d.BusID { + return false + } + if m.VendorID != 0 && uint16(m.VendorID) != d.VendorID { + return false + } + if m.ProductID != 0 && uint16(m.ProductID) != d.ProductID { + return false + } + if m.Serial != "" && m.Serial != d.Serial { + return false + } + return true +} diff --git a/service/usbip/protocol.go b/service/usbip/protocol.go new file mode 100644 index 000000000..219286072 --- /dev/null +++ b/service/usbip/protocol.go @@ -0,0 +1,255 @@ +package usbip + +import ( + "encoding/binary" + "io" + "strings" + + E "github.com/sagernet/sing/common/exceptions" +) + +// Wire constants. +const ( + DefaultPort = 3240 + + ProtocolVersion uint16 = 0x0111 + + OpReqDevList uint16 = 0x8005 + OpRepDevList uint16 = 0x0005 + OpReqImport uint16 = 0x8003 + OpRepImport uint16 = 0x0003 + + OpStatusOK uint32 = 0 + OpStatusError uint32 = 1 + + maxOpRepDevListEntries = 4096 + maxOpRepDevListBodyBytes = 8 << 20 + deviceInfoWireSize = 312 + deviceInterfaceWireSize = 4 +) + +// USB speeds (enum usb_device_speed). +const ( + SpeedUnknown uint32 = 0 + SpeedLow uint32 = 1 + SpeedFull uint32 = 2 + SpeedHigh uint32 = 3 + SpeedWireless uint32 = 4 + SpeedSuper uint32 = 5 + SpeedSuperPlus uint32 = 6 +) + +// OpHeader is the 8-byte header prefix of every OP message. +type OpHeader struct { + Version uint16 + Code uint16 + Status uint32 +} + +// DeviceInfoTruncated is the 312-byte device descriptor shared by OP_REP_DEVLIST +// entries and OP_REP_IMPORT bodies. +type DeviceInfoTruncated struct { + Path [256]byte + BusID [32]byte + BusNum uint32 + DevNum uint32 + Speed uint32 + IDVendor uint16 + IDProduct uint16 + BCDDevice uint16 + BDeviceClass uint8 + BDeviceSubClass uint8 + BDeviceProtocol uint8 + BConfigurationValue uint8 + BNumConfigurations uint8 + BNumInterfaces uint8 +} + +// DeviceInterface is the 4-byte per-interface descriptor carried in OP_REP_DEVLIST. +type DeviceInterface struct { + BInterfaceClass uint8 + BInterfaceSubClass uint8 + BInterfaceProtocol uint8 + Padding uint8 +} + +// DeviceEntry is one element of an OP_REP_DEVLIST body. +type DeviceEntry struct { + Info DeviceInfoTruncated + Interfaces []DeviceInterface +} + +// WriteOpHeader emits the 8-byte OP header. +func WriteOpHeader(w io.Writer, code uint16, status uint32) error { + return binary.Write(w, binary.BigEndian, OpHeader{ + Version: ProtocolVersion, + Code: code, + Status: status, + }) +} + +// ReadOpHeader consumes the 8-byte OP header and returns it. +func ReadOpHeader(r io.Reader) (OpHeader, error) { + var h OpHeader + if err := binary.Read(r, binary.BigEndian, &h); err != nil { + return h, err + } + return h, nil +} + +// WriteOpReqImport sends OP_REQ_IMPORT for busid (8 + 32 = 40 bytes). +func WriteOpReqImport(w io.Writer, busid string) error { + if err := WriteOpHeader(w, OpReqImport, OpStatusOK); err != nil { + return err + } + var field [32]byte + if len(busid) >= len(field) { + return E.New("busid too long: ", busid) + } + copy(field[:], busid) + return binary.Write(w, binary.BigEndian, field) +} + +// ReadOpReqImportBody reads the 32-byte busid that follows the OP header. +func ReadOpReqImportBody(r io.Reader) (string, error) { + var field [32]byte + if _, err := io.ReadFull(r, field[:]); err != nil { + return "", err + } + return cstring(field[:]), nil +} + +// WriteOpRepImport sends OP_REP_IMPORT. If status != OpStatusOK, info is omitted. +func WriteOpRepImport(w io.Writer, status uint32, info *DeviceInfoTruncated) error { + if err := WriteOpHeader(w, OpRepImport, status); err != nil { + return err + } + if status != OpStatusOK { + return nil + } + if info == nil { + return E.New("OP_REP_IMPORT success without device info") + } + return binary.Write(w, binary.BigEndian, info) +} + +// ReadOpRepImportBody reads the 312-byte device info that follows a successful header. +func ReadOpRepImportBody(r io.Reader) (DeviceInfoTruncated, error) { + var info DeviceInfoTruncated + if err := binary.Read(r, binary.BigEndian, &info); err != nil { + return info, err + } + return info, nil +} + +// WriteOpRepDevList emits an OP_REP_DEVLIST response with the given entries. +func WriteOpRepDevList(w io.Writer, entries []DeviceEntry) error { + if err := WriteOpHeader(w, OpRepDevList, OpStatusOK); err != nil { + return err + } + if err := binary.Write(w, binary.BigEndian, uint32(len(entries))); err != nil { + return err + } + for i := range entries { + if err := binary.Write(w, binary.BigEndian, &entries[i].Info); err != nil { + return err + } + for j := range entries[i].Interfaces { + if err := binary.Write(w, binary.BigEndian, &entries[i].Interfaces[j]); err != nil { + return err + } + } + } + return nil +} + +// ReadOpRepDevListBody reads what follows a devlist OP header: a uint32 count +// plus that many device entries with their per-interface tails. +func ReadOpRepDevListBody(r io.Reader) ([]DeviceEntry, error) { + var count uint32 + if err := binary.Read(r, binary.BigEndian, &count); err != nil { + return nil, err + } + if count > maxOpRepDevListEntries { + return nil, E.New("OP_REP_DEVLIST device count too large: ", count) + } + bodyBytes := uint64(4) + entries := make([]DeviceEntry, int(count)) + for i := range entries { + if err := binary.Read(r, binary.BigEndian, &entries[i].Info); err != nil { + return nil, err + } + bodyBytes += deviceInfoWireSize + if bodyBytes > maxOpRepDevListBodyBytes { + return nil, E.New("OP_REP_DEVLIST body too large") + } + n := int(entries[i].Info.BNumInterfaces) + if n > 0 { + bodyBytes += uint64(n) * deviceInterfaceWireSize + if bodyBytes > maxOpRepDevListBodyBytes { + return nil, E.New("OP_REP_DEVLIST interface data too large") + } + entries[i].Interfaces = make([]DeviceInterface, n) + for j := range entries[i].Interfaces { + if err := binary.Read(r, binary.BigEndian, &entries[i].Interfaces[j]); err != nil { + return nil, err + } + } + } + } + return entries, nil +} + +// BusID extracts the null-terminated busid from a DeviceInfoTruncated. +func (d *DeviceInfoTruncated) BusIDString() string { + return cstring(d.BusID[:]) +} + +// PathString extracts the null-terminated sysfs path. +func (d *DeviceInfoTruncated) PathString() string { + return cstring(d.Path[:]) +} + +func (d *DeviceInfoTruncated) SerialString() string { + meta := trailingCString(d.Path[:]) + serial, ok := strings.CutPrefix(meta, "serial=") + if !ok { + return "" + } + return serial +} + +// DevID packs busnum/devnum the way vhci_hcd.attach expects. +func (d *DeviceInfoTruncated) DevID() uint32 { + return (d.BusNum << 16) | (d.DevNum & 0xffff) +} + +func encodePathField(dst *[256]byte, path, serial string) { + copy(dst[:], path) + if serial == "" || len(path) >= len(dst)-1 { + return + } + copy(dst[len(path)+1:], "serial="+serial) +} + +func cstring(b []byte) string { + for i, c := range b { + if c == 0 { + return string(b[:i]) + } + } + return string(b) +} + +func trailingCString(b []byte) string { + for i, c := range b { + if c != 0 { + continue + } + if i+1 >= len(b) { + return "" + } + return cstring(b[i+1:]) + } + return "" +} diff --git a/service/usbip/register_linux.go b/service/usbip/register_linux.go new file mode 100644 index 000000000..7735e63c3 --- /dev/null +++ b/service/usbip/register_linux.go @@ -0,0 +1,14 @@ +//go:build linux + +package usbip + +import ( + boxService "github.com/sagernet/sing-box/adapter/service" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" +) + +func RegisterService(registry *boxService.Registry) { + boxService.Register[option.USBIPServerServiceOptions](registry, C.TypeUSBIPServer, NewServerService) + boxService.Register[option.USBIPClientServiceOptions](registry, C.TypeUSBIPClient, NewClientService) +} diff --git a/service/usbip/register_stub.go b/service/usbip/register_stub.go new file mode 100644 index 000000000..c1ff5b20d --- /dev/null +++ b/service/usbip/register_stub.go @@ -0,0 +1,23 @@ +//go:build !linux + +package usbip + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + boxService "github.com/sagernet/sing-box/adapter/service" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func RegisterService(registry *boxService.Registry) { + boxService.Register[option.USBIPServerServiceOptions](registry, C.TypeUSBIPServer, func(ctx context.Context, logger log.ContextLogger, tag string, options option.USBIPServerServiceOptions) (adapter.Service, error) { + return nil, E.New("usbip-server service is only supported on Linux") + }) + boxService.Register[option.USBIPClientServiceOptions](registry, C.TypeUSBIPClient, func(ctx context.Context, logger log.ContextLogger, tag string, options option.USBIPClientServiceOptions) (adapter.Service, error) { + return nil, E.New("usbip-client service is only supported on Linux") + }) +} diff --git a/service/usbip/server_linux.go b/service/usbip/server_linux.go new file mode 100644 index 000000000..57c5770b1 --- /dev/null +++ b/service/usbip/server_linux.go @@ -0,0 +1,381 @@ +//go:build linux + +package usbip + +import ( + "context" + "net" + "sync" + + "github.com/sagernet/sing-box/adapter" + boxService "github.com/sagernet/sing-box/adapter/service" + "github.com/sagernet/sing-box/common/listener" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" +) + +type serverExport struct { + busid string + managed bool + originalDriver string +} + +type ServerService struct { + boxService.Adapter + ctx context.Context + cancel context.CancelFunc + logger log.ContextLogger + listener *listener.Listener + matches []option.USBIPDeviceMatch + + mu sync.Mutex + exports []serverExport + listenFD net.Listener +} + +func NewServerService(ctx context.Context, logger log.ContextLogger, tag string, options option.USBIPServerServiceOptions) (adapter.Service, error) { + for i, m := range options.Devices { + if m.IsZero() { + return nil, E.New("devices[", i, "]: at least one of busid/vendor_id/product_id/serial is required") + } + } + if options.ListenPort == 0 { + options.ListenPort = DefaultPort + } + ctx, cancel := context.WithCancel(ctx) + s := &ServerService{ + Adapter: boxService.NewAdapter(C.TypeUSBIPServer, tag), + ctx: ctx, + cancel: cancel, + logger: logger, + matches: options.Devices, + listener: listener.New(listener.Options{ + Context: ctx, + Logger: logger, + Network: []string{N.NetworkTCP}, + Listen: options.ListenOptions, + }), + } + return s, nil +} + +func (s *ServerService) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + if err := ensureHostDriver(); err != nil { + return err + } + if err := s.bindExports(); err != nil { + s.rollbackExports() + return err + } + tcpListener, err := s.listener.ListenTCP() + if err != nil { + s.rollbackExports() + return err + } + s.mu.Lock() + s.listenFD = tcpListener + s.mu.Unlock() + go s.acceptLoop(tcpListener) + return nil +} + +func (s *ServerService) Close() error { + if s.cancel != nil { + s.cancel() + } + err := common.Close(common.PtrOrNil(s.listener)) + s.rollbackExports() + return err +} + +// bindExports resolves every match against current sysfs state, unbinds from +// the current driver, and binds to usbip-host. +func (s *ServerService) bindExports() error { + devices, err := listUSBDevices() + if err != nil { + return E.Cause(err, "enumerate usb devices") + } + seen := make(map[string]bool) + for _, m := range s.matches { + matched := 0 + for i := range devices { + if !Matches(m, devices[i].key()) { + continue + } + if seen[devices[i].BusID] { + matched++ + continue + } + if err := s.bindOne(&devices[i]); err != nil { + s.logger.Warn("bind ", devices[i].BusID, ": ", err) + continue + } + seen[devices[i].BusID] = true + matched++ + } + if matched == 0 { + s.logger.Warn("no local device matched ", describeMatch(m)) + } + } + return nil +} + +func (s *ServerService) bindOne(d *sysfsDevice) error { + driver, err := currentDriver(d.BusID) + if err != nil { + return err + } + if driver == "usbip-host" { + s.logger.Info("device ", d.BusID, " already bound to usbip-host; co-opting") + s.mu.Lock() + s.exports = append(s.exports, serverExport{busid: d.BusID}) + s.mu.Unlock() + return nil + } + if driver != "" { + if err := unbindFromDriver(d.BusID, driver); err != nil { + return E.Cause(err, "unbind from ", driver) + } + } + if err := hostMatchBusID(d.BusID, true); err != nil { + if driver != "" { + _ = bindToDriver(d.BusID, driver) + } + return E.Cause(err, "match_busid add") + } + if err := hostBind(d.BusID); err != nil { + _ = hostMatchBusID(d.BusID, false) + if driver != "" { + _ = bindToDriver(d.BusID, driver) + } + return E.Cause(err, "bind to usbip-host") + } + s.logger.Info("exported ", d.BusID, " (previously on ", driverOrNone(driver), ")") + s.mu.Lock() + s.exports = append(s.exports, serverExport{ + busid: d.BusID, + managed: true, + originalDriver: driver, + }) + s.mu.Unlock() + return nil +} + +func (s *ServerService) rollbackExports() { + s.mu.Lock() + exports := s.exports + s.exports = nil + s.mu.Unlock() + for _, e := range exports { + if !e.managed { + continue + } + // Release any attached peer. + _ = writeUsbipSockfd(e.busid, -1) + if err := hostUnbind(e.busid); err != nil { + s.logger.Warn("unbind ", e.busid, ": ", err) + } + if err := hostMatchBusID(e.busid, false); err != nil { + s.logger.Debug("match_busid del ", e.busid, ": ", err) + } + if e.originalDriver == "" { + s.logger.Info("released ", e.busid, " from usbip-host") + continue + } + if err := bindToDriver(e.busid, e.originalDriver); err != nil { + s.logger.Warn("rebind ", e.busid, " to ", e.originalDriver, ": ", err) + continue + } + s.logger.Info("restored ", e.busid, " to ", e.originalDriver) + } +} + +func (s *ServerService) currentExports() []string { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]string, len(s.exports)) + for i, e := range s.exports { + out[i] = e.busid + } + return out +} + +func (s *ServerService) acceptLoop(ln net.Listener) { + for { + conn, err := ln.Accept() + if err != nil { + select { + case <-s.ctx.Done(): + return + default: + } + if E.IsClosed(err) { + return + } + s.logger.Error("accept: ", err) + return + } + go s.handleConn(conn) + } +} + +func (s *ServerService) handleConn(conn net.Conn) { + defer conn.Close() + header, err := ReadOpHeader(conn) + if err != nil { + s.logger.Debug("read op header: ", err) + return + } + switch header.Code { + case OpReqDevList: + s.handleDevList(conn) + case OpReqImport: + s.handleImport(conn) + default: + s.logger.Debug("unknown opcode 0x", hex16(header.Code)) + } +} + +func (s *ServerService) handleDevList(conn net.Conn) { + entries := s.buildDevListEntries() + if err := WriteOpRepDevList(conn, entries); err != nil { + s.logger.Debug("write devlist: ", err) + } +} + +func (s *ServerService) buildDevListEntries() []DeviceEntry { + busids := s.currentExports() + if len(busids) == 0 { + return nil + } + entries := make([]DeviceEntry, 0, len(busids)) + for _, busid := range busids { + d, err := readSysfsDevice(busid, sysBusDevicePath(busid)) + if err != nil { + s.logger.Debug("refresh ", busid, ": ", err) + continue + } + entries = append(entries, DeviceEntry{ + Info: d.toProtocol(), + Interfaces: d.Interfaces, + }) + } + return entries +} + +func (s *ServerService) handleImport(conn net.Conn) { + busid, err := ReadOpReqImportBody(conn) + if err != nil { + s.logger.Debug("read import body: ", err) + return + } + if !s.isExported(busid) { + s.logger.Info("import rejected (unknown busid): ", busid) + _ = WriteOpRepImport(conn, OpStatusError, nil) + return + } + status, err := readUsbipStatus(busid) + if err != nil || status != 1 { + s.logger.Info("import rejected (busid ", busid, " status=", status, " err=", err, ")") + _ = WriteOpRepImport(conn, OpStatusError, nil) + return + } + dev, err := readSysfsDevice(busid, sysBusDevicePath(busid)) + if err != nil { + s.logger.Warn("refresh ", busid, ": ", err) + _ = WriteOpRepImport(conn, OpStatusError, nil) + return + } + tcp, ok := conn.(*net.TCPConn) + if !ok { + s.logger.Warn("import requires *net.TCPConn, got ", conn) + _ = WriteOpRepImport(conn, OpStatusError, nil) + return + } + file, err := tcp.File() + if err != nil { + s.logger.Warn("dup socket fd: ", err) + _ = WriteOpRepImport(conn, OpStatusError, nil) + return + } + defer file.Close() + if err := writeUsbipSockfd(busid, int(file.Fd())); err != nil { + s.logger.Warn("hand off ", busid, " to kernel: ", err) + _ = WriteOpRepImport(conn, OpStatusError, nil) + return + } + info := dev.toProtocol() + if err := WriteOpRepImport(conn, OpStatusOK, &info); err != nil { + s.logger.Warn("reply import ", busid, ": ", err) + _ = writeUsbipSockfd(busid, -1) + return + } + s.logger.Info("attached ", busid, " to remote ", conn.RemoteAddr()) +} + +func (s *ServerService) isExported(busid string) bool { + s.mu.Lock() + defer s.mu.Unlock() + for _, e := range s.exports { + if e.busid == busid { + return true + } + } + return false +} + +func sysBusDevicePath(busid string) string { + return sysBusUSBDevices + "/" + busid +} + +func describeMatch(m option.USBIPDeviceMatch) string { + var parts []string + if m.BusID != "" { + parts = append(parts, "busid="+m.BusID) + } + if m.VendorID != 0 { + parts = append(parts, "vendor_id=0x"+hex16(uint16(m.VendorID))) + } + if m.ProductID != 0 { + parts = append(parts, "product_id=0x"+hex16(uint16(m.ProductID))) + } + if m.Serial != "" { + parts = append(parts, "serial="+m.Serial) + } + return "{" + joinComma(parts) + "}" +} + +func driverOrNone(d string) string { + if d == "" { + return "(no driver)" + } + return d +} + +func hex16(v uint16) string { + const hexdigits = "0123456789abcdef" + return string([]byte{ + hexdigits[(v>>12)&0xf], + hexdigits[(v>>8)&0xf], + hexdigits[(v>>4)&0xf], + hexdigits[v&0xf], + }) +} + +func joinComma(parts []string) string { + out := "" + for i, p := range parts { + if i > 0 { + out += "," + } + out += p + } + return out +} diff --git a/service/usbip/sysfs_linux.go b/service/usbip/sysfs_linux.go new file mode 100644 index 000000000..754a893fe --- /dev/null +++ b/service/usbip/sysfs_linux.go @@ -0,0 +1,409 @@ +//go:build linux + +package usbip + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + + E "github.com/sagernet/sing/common/exceptions" +) + +const ( + sysBusUSBDevices = "/sys/bus/usb/devices" + sysUsbipHostDriver = "/sys/bus/usb/drivers/usbip-host" + sysVHCIControllerV0 = "/sys/devices/platform/vhci_hcd.0" +) + +// sysfsDevice captures the subset of USB device attributes needed for +// matching, export, and OP_REP_DEVLIST emission. +type sysfsDevice struct { + BusID string + Path string // sysfs path, e.g. /sys/bus/usb/devices/1-2 + BusNum uint32 + DevNum uint32 + Speed uint32 + VendorID uint16 + ProductID uint16 + BCDDevice uint16 + DeviceClass uint8 + DeviceSubClass uint8 + DeviceProtocol uint8 + ConfigValue uint8 + NumConfigs uint8 + NumInterfaces uint8 + Serial string + Interfaces []DeviceInterface +} + +func (d *sysfsDevice) key() DeviceKey { + return DeviceKey{ + BusID: d.BusID, + VendorID: d.VendorID, + ProductID: d.ProductID, + Serial: d.Serial, + } +} + +func (d *sysfsDevice) toProtocol() DeviceInfoTruncated { + var info DeviceInfoTruncated + encodePathField(&info.Path, d.Path, d.Serial) + copy(info.BusID[:], d.BusID) + info.BusNum = d.BusNum + info.DevNum = d.DevNum + info.Speed = d.Speed + info.IDVendor = d.VendorID + info.IDProduct = d.ProductID + info.BCDDevice = d.BCDDevice + info.BDeviceClass = d.DeviceClass + info.BDeviceSubClass = d.DeviceSubClass + info.BDeviceProtocol = d.DeviceProtocol + info.BConfigurationValue = d.ConfigValue + info.BNumConfigurations = d.NumConfigs + info.BNumInterfaces = d.NumInterfaces + return info +} + +type vhciStatusRecord struct { + hub string + port int + state int +} + +// ensureHostDriver verifies the usbip-host kernel driver is loaded. +func ensureHostDriver() error { + if _, err := os.Stat(sysUsbipHostDriver); err != nil { + return E.Cause(err, "usbip-host driver not present; modprobe usbip-host") + } + return nil +} + +// ensureVHCI verifies the vhci_hcd controller is loaded. +func ensureVHCI() error { + if _, err := os.Stat(sysVHCIControllerV0); err != nil { + return E.Cause(err, "vhci_hcd.0 not present; modprobe vhci-hcd") + } + return nil +} + +// listUSBDevices enumerates /sys/bus/usb/devices, returning non-interface +// device entries that expose idVendor. +func listUSBDevices() ([]sysfsDevice, error) { + entries, err := os.ReadDir(sysBusUSBDevices) + if err != nil { + return nil, err + } + var devices []sysfsDevice + for _, entry := range entries { + name := entry.Name() + if strings.Contains(name, ":") { + continue + } + path := filepath.Join(sysBusUSBDevices, name) + device, err := readSysfsDevice(name, path) + if err != nil { + continue + } + devices = append(devices, device) + } + return devices, nil +} + +// readSysfsDevice populates a sysfsDevice from the attributes at path. +func readSysfsDevice(busid, path string) (sysfsDevice, error) { + d := sysfsDevice{BusID: busid, Path: path} + vendor, err := readHexU16(path, "idVendor") + if err != nil { + return d, err + } + d.VendorID = vendor + d.ProductID, _ = readHexU16(path, "idProduct") + d.BCDDevice, _ = readHexU16(path, "bcdDevice") + if v, err := readDecU32(path, "busnum"); err == nil { + d.BusNum = v + } + if v, err := readDecU32(path, "devnum"); err == nil { + d.DevNum = v + } + d.Speed = speedCodeFromString(readString(path, "speed")) + d.DeviceClass, _ = readHexU8(path, "bDeviceClass") + d.DeviceSubClass, _ = readHexU8(path, "bDeviceSubClass") + d.DeviceProtocol, _ = readHexU8(path, "bDeviceProtocol") + d.ConfigValue, _ = readDecU8(path, "bConfigurationValue") + d.NumConfigs, _ = readDecU8(path, "bNumConfigurations") + d.NumInterfaces, _ = readDecU8(path, "bNumInterfaces") + d.Serial = readString(path, "serial") + d.Interfaces = readInterfaces(path, busid, d.ConfigValue, int(d.NumInterfaces)) + return d, nil +} + +// readInterfaces reads the per-interface descriptors sibling to the device node. +func readInterfaces(devicePath, busid string, configValue uint8, count int) []DeviceInterface { + if count == 0 { + return nil + } + interfaces := make([]DeviceInterface, count) + for i := 0; i < count; i++ { + name := fmt.Sprintf("%s:%d.%d", busid, configValue, i) + ipath := filepath.Join(filepath.Dir(devicePath), name) + class, _ := readHexU8(ipath, "bInterfaceClass") + subClass, _ := readHexU8(ipath, "bInterfaceSubClass") + protocol, _ := readHexU8(ipath, "bInterfaceProtocol") + interfaces[i] = DeviceInterface{ + BInterfaceClass: class, + BInterfaceSubClass: subClass, + BInterfaceProtocol: protocol, + } + } + return interfaces +} + +// currentDriver returns the driver currently bound to busid, or "" if none. +func currentDriver(busid string) (string, error) { + link, err := os.Readlink(filepath.Join(sysBusUSBDevices, busid, "driver")) + if err != nil { + if os.IsNotExist(err) { + return "", nil + } + return "", err + } + return filepath.Base(link), nil +} + +// unbindFromDriver detaches busid from driver. +func unbindFromDriver(busid, driver string) error { + path := filepath.Join("/sys/bus/usb/drivers", driver, "unbind") + return writeSysfs(path, busid) +} + +// bindToDriver attaches busid to driver. +func bindToDriver(busid, driver string) error { + path := filepath.Join("/sys/bus/usb/drivers", driver, "bind") + return writeSysfs(path, busid) +} + +// hostMatchBusID writes "add " / "del " to the usbip-host +// match_busid attribute. Returns nil on ENOENT for "del" (idempotent). +func hostMatchBusID(busid string, add bool) error { + verb := "del" + if add { + verb = "add" + } + path := filepath.Join(sysUsbipHostDriver, "match_busid") + return writeSysfs(path, verb+" "+busid) +} + +// hostBind attaches busid to usbip-host. +func hostBind(busid string) error { + return writeSysfs(filepath.Join(sysUsbipHostDriver, "bind"), busid) +} + +// hostUnbind detaches busid from usbip-host. +func hostUnbind(busid string) error { + return writeSysfs(filepath.Join(sysUsbipHostDriver, "unbind"), busid) +} + +// readUsbipStatus returns the usbip_status attribute value for busid. +// 1 = AVAILABLE, 2 = USED, 3 = ERROR. +func readUsbipStatus(busid string) (int, error) { + raw, err := os.ReadFile(filepath.Join(sysBusUSBDevices, busid, "usbip_status")) + if err != nil { + return 0, err + } + v, err := strconv.Atoi(strings.TrimSpace(string(raw))) + if err != nil { + return 0, err + } + return v, nil +} + +// writeUsbipSockfd hands the fd for busid to the usbip-host kernel driver. +// Passing -1 as fd releases the connection. +func writeUsbipSockfd(busid string, fd int) error { + return writeSysfs(filepath.Join(sysBusUSBDevices, busid, "usbip_sockfd"), strconv.Itoa(fd)) +} + +// vhciPickFreePort scans the vhci_hcd.0 status table and returns a free port +// from the hub that matches the remote device speed. +func vhciPickFreePort(speed uint32) (int, error) { + records, err := readVHCIStatus() + if err != nil { + return -1, err + } + targetHub := vhciHubForSpeed(speed) + for _, record := range records { + if record.hub != targetHub || record.state != 4 { + continue + } + return record.port, nil + } + return -1, E.New("no free ", targetHub, " vhci port") +} + +// vhciPortUsed reports whether the given port is currently in VDEV_ST_USED (6). +func vhciPortUsed(port int) (bool, error) { + records, err := readVHCIStatus() + if err != nil { + return false, err + } + for _, record := range records { + if record.port != port { + continue + } + return record.state == 6, nil // VDEV_ST_USED + } + return false, nil +} + +// vhciAttach writes "port fd devid speed" to the attach attribute. +func vhciAttach(port int, fd uintptr, devid uint32, speed uint32) error { + line := fmt.Sprintf("%d %d %d %d", port, int(fd), devid, speed) + return writeSysfs(filepath.Join(sysVHCIControllerV0, "attach"), line) +} + +// vhciDetach writes the port number to the detach attribute. +func vhciDetach(port int) error { + return writeSysfs(filepath.Join(sysVHCIControllerV0, "detach"), strconv.Itoa(port)) +} + +func readVHCIStatus() ([]vhciStatusRecord, error) { + raw, err := os.ReadFile(filepath.Join(sysVHCIControllerV0, "status")) + if err != nil { + return nil, err + } + return parseVHCIStatus(string(raw)), nil +} + +func parseVHCIStatus(raw string) []vhciStatusRecord { + scanner := bufio.NewScanner(strings.NewReader(raw)) + records := make([]vhciStatusRecord, 0) + first := true + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if first { + first = false + continue + } + if line == "" { + continue + } + fields := strings.Fields(line) + if len(fields) < 3 { + continue + } + port, err := strconv.Atoi(fields[1]) + if err != nil { + continue + } + state, err := strconv.Atoi(fields[2]) + if err != nil { + continue + } + records = append(records, vhciStatusRecord{ + hub: fields[0], + port: port, + state: state, + }) + } + return records +} + +func vhciHubForSpeed(speed uint32) string { + switch speed { + case SpeedSuper, SpeedSuperPlus: + return "ss" + default: + return "hs" + } +} + +// --- small helpers ------------------------------------------------------ + +func writeSysfs(path, content string) error { + f, err := os.OpenFile(path, os.O_WRONLY, 0) + if err != nil { + return err + } + defer f.Close() + _, err = f.WriteString(content) + return err +} + +func readString(dir, attr string) string { + raw, err := os.ReadFile(filepath.Join(dir, attr)) + if err != nil { + return "" + } + return strings.TrimSpace(string(raw)) +} + +func readHexU16(dir, attr string) (uint16, error) { + s := readString(dir, attr) + if s == "" { + return 0, E.New(attr, " missing") + } + v, err := strconv.ParseUint(s, 16, 16) + if err != nil { + return 0, err + } + return uint16(v), nil +} + +func readHexU8(dir, attr string) (uint8, error) { + s := readString(dir, attr) + if s == "" { + return 0, E.New(attr, " missing") + } + v, err := strconv.ParseUint(s, 16, 8) + if err != nil { + return 0, err + } + return uint8(v), nil +} + +func readDecU8(dir, attr string) (uint8, error) { + s := readString(dir, attr) + if s == "" { + return 0, E.New(attr, " missing") + } + v, err := strconv.ParseUint(s, 10, 8) + if err != nil { + return 0, err + } + return uint8(v), nil +} + +func readDecU32(dir, attr string) (uint32, error) { + s := readString(dir, attr) + if s == "" { + return 0, E.New(attr, " missing") + } + v, err := strconv.ParseUint(s, 10, 32) + if err != nil { + return 0, err + } + return uint32(v), nil +} + +// speedCodeFromString maps the sysfs "speed" attribute to the USB/IP wire +// enum usb_device_speed. +func speedCodeFromString(s string) uint32 { + switch s { + case "1.5": + return SpeedLow + case "12": + return SpeedFull + case "480": + return SpeedHigh + case "5000": + return SpeedSuper + case "10000": + return SpeedSuperPlus + default: + return SpeedUnknown + } +}