Add Linux usbip client and server services

This commit is contained in:
世界 2026-04-21 18:06:57 +08:00
parent 056c45c2ca
commit 9a5179a730
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
10 changed files with 1545 additions and 0 deletions

View file

@ -33,6 +33,8 @@ const (
TypeOCM = "ocm"
TypeOOMKiller = "oom-killer"
TypeHysteriaRealm = "hysteria-realm"
TypeUSBIPServer = "usbip-server"
TypeUSBIPClient = "usbip-client"
TypeACME = "acme"
TypeCloudflareOriginCA = "cloudflare-origin-ca"
)

View file

@ -39,6 +39,7 @@ import (
originca "github.com/sagernet/sing-box/service/origin_ca"
"github.com/sagernet/sing-box/service/resolved"
"github.com/sagernet/sing-box/service/ssmapi"
"github.com/sagernet/sing-box/service/usbip"
E "github.com/sagernet/sing/common/exceptions"
)
@ -135,6 +136,7 @@ func ServiceRegistry() *service.Registry {
resolved.RegisterService(registry)
ssmapi.RegisterService(registry)
usbip.RegisterService(registry)
registerQUICServices(registry)
registerDERPService(registry)

84
option/usbip.go Normal file
View file

@ -0,0 +1,84 @@
package option
import (
"fmt"
"strconv"
"strings"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
)
// USBIPHexUint16 is a uint16 that accepts either a JSON integer or a hex
// string ("0x1d6b", "1d6b", "0X1D6B") on unmarshal, and emits a hex string
// on marshal. Zero means "unset".
type USBIPHexUint16 uint16
func (h USBIPHexUint16) MarshalJSON() ([]byte, error) {
if h == 0 {
return []byte(`""`), nil
}
return json.Marshal(fmt.Sprintf("0x%04x", uint16(h)))
}
func (h *USBIPHexUint16) UnmarshalJSON(data []byte) error {
var asNumber uint64
if err := json.Unmarshal(data, &asNumber); err == nil {
if asNumber > 0xffff {
return E.New("usb id out of uint16 range: ", asNumber)
}
*h = USBIPHexUint16(asNumber)
return nil
}
var asString string
if err := json.Unmarshal(data, &asString); err != nil {
return E.Cause(err, "parse usb id")
}
asString = strings.TrimSpace(asString)
if asString == "" {
*h = 0
return nil
}
parsed, err := strconv.ParseUint(asString, 0, 16)
if err != nil {
// Allow bare hex without 0x prefix.
if parsed2, err2 := strconv.ParseUint(asString, 16, 16); err2 == nil {
*h = USBIPHexUint16(parsed2)
return nil
}
return E.Cause(err, "parse usb id ", asString)
}
*h = USBIPHexUint16(parsed)
return nil
}
// USBIPDeviceMatch selects a USB device. Non-zero fields AND together.
// An all-zero match is rejected at service construction time.
type USBIPDeviceMatch struct {
BusID string `json:"busid,omitempty"`
VendorID USBIPHexUint16 `json:"vendor_id,omitempty"`
ProductID USBIPHexUint16 `json:"product_id,omitempty"`
Serial string `json:"serial,omitempty"`
}
func (m USBIPDeviceMatch) IsZero() bool {
return m.BusID == "" && m.VendorID == 0 && m.ProductID == 0 && m.Serial == ""
}
// USBIPServerServiceOptions configures a usbip-server service. It listens on
// TCP (default :3240) and binds matching local USB devices to the usbip-host
// kernel driver for export. Empty Devices means export nothing.
type USBIPServerServiceOptions struct {
ListenOptions
Devices []USBIPDeviceMatch `json:"devices,omitempty"`
}
// USBIPClientServiceOptions configures a usbip-client service. It connects to
// one remote usbip server and attaches matching remote USB devices to the
// local kernel via vhci_hcd. Empty Devices means import every device the
// remote currently exports.
type USBIPClientServiceOptions struct {
ServerOptions
DialerOptions
Devices []USBIPDeviceMatch `json:"devices,omitempty"`
}

View file

@ -0,0 +1,338 @@
//go:build linux
package usbip
import (
"context"
"encoding/binary"
"net"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
boxService "github.com/sagernet/sing-box/adapter/service"
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
const clientReconnectDelay = 5 * time.Second
type ClientService struct {
boxService.Adapter
ctx context.Context
cancel context.CancelFunc
logger log.ContextLogger
dialer N.Dialer
serverAddr M.Socksaddr
matches []option.USBIPDeviceMatch // empty = import all remote exports
attachMu sync.Mutex // serializes vhci port pick + attach
wg sync.WaitGroup
portsMu sync.Mutex
ports map[int]struct{}
}
func NewClientService(ctx context.Context, logger log.ContextLogger, tag string, options option.USBIPClientServiceOptions) (adapter.Service, error) {
for i, m := range options.Devices {
if m.IsZero() {
return nil, E.New("devices[", i, "]: at least one of busid/vendor_id/product_id/serial is required")
}
}
if options.ServerPort == 0 {
options.ServerPort = DefaultPort
}
if options.Server == "" {
return nil, E.New("missing server address")
}
outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerOptions.ServerIsDomain())
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(ctx)
return &ClientService{
Adapter: boxService.NewAdapter(C.TypeUSBIPClient, tag),
ctx: ctx,
cancel: cancel,
logger: logger,
dialer: outboundDialer,
serverAddr: options.ServerOptions.Build(),
matches: options.Devices,
ports: make(map[int]struct{}),
}, nil
}
func (c *ClientService) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if err := ensureVHCI(); err != nil {
return err
}
c.wg.Add(1)
go c.run()
return nil
}
func (c *ClientService) Close() error {
if c.cancel != nil {
c.cancel()
}
// Wait for workers to detach, bounded by 5s.
done := make(chan struct{})
go func() {
c.wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(5 * time.Second):
c.logger.Warn("shutdown timeout; some vhci ports may remain attached")
}
return nil
}
// run resolves the desired busid set (once) and spawns a worker per busid.
func (c *ClientService) run() {
defer c.wg.Done()
busids := c.resolveBusIDs()
if len(busids) == 0 {
c.logger.Warn("no devices to import; client idle")
return
}
for _, busid := range busids {
c.wg.Add(1)
go c.worker(busid)
}
}
// resolveBusIDs connects once, issues OP_REQ_DEVLIST, and returns the busids
// to attach. For the empty-matches case, returns every remote busid.
// For filtered mode, returns the first match per criterion.
func (c *ClientService) resolveBusIDs() []string {
// Busid-only matches don't require enumeration.
if len(c.matches) > 0 && everyMatchBusIDOnly(c.matches) {
out := make([]string, 0, len(c.matches))
for _, m := range c.matches {
out = append(out, m.BusID)
}
return dedupe(out)
}
for {
if err := c.ctx.Err(); err != nil {
return nil
}
entries, err := c.fetchDevList()
if err != nil {
c.logger.Error("enumerate ", c.serverAddr, ": ", err)
if !sleepCtx(c.ctx, clientReconnectDelay) {
return nil
}
continue
}
if len(c.matches) == 0 {
out := make([]string, 0, len(entries))
for i := range entries {
out = append(out, entries[i].Info.BusIDString())
}
return out
}
var out []string
for _, m := range c.matches {
picked := ""
for i := range entries {
key := DeviceKey{
BusID: entries[i].Info.BusIDString(),
VendorID: entries[i].Info.IDVendor,
ProductID: entries[i].Info.IDProduct,
Serial: entries[i].Info.SerialString(),
}
if Matches(m, key) {
picked = key.BusID
break
}
}
if picked == "" {
c.logger.Warn("no remote device matched ", describeMatch(m))
continue
}
out = append(out, picked)
}
return dedupe(out)
}
}
func (c *ClientService) fetchDevList() ([]DeviceEntry, error) {
conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, c.serverAddr)
if err != nil {
return nil, err
}
defer conn.Close()
if err := binary.Write(conn, binary.BigEndian, OpHeader{Version: ProtocolVersion, Code: OpReqDevList, Status: OpStatusOK}); err != nil {
return nil, E.Cause(err, "send OP_REQ_DEVLIST")
}
header, err := ReadOpHeader(conn)
if err != nil {
return nil, E.Cause(err, "read OP_REP_DEVLIST header")
}
if header.Code != OpRepDevList || header.Status != OpStatusOK {
return nil, E.New("OP_REP_DEVLIST status=", header.Status, " code=0x", hex16(header.Code))
}
return ReadOpRepDevListBody(conn)
}
// worker keeps one remote busid attached to vhci_hcd.0. On any error or
// kernel-side detach, waits clientReconnectDelay and retries.
func (c *ClientService) worker(busid string) {
defer c.wg.Done()
for {
if err := c.ctx.Err(); err != nil {
return
}
port, err := c.attemptAttach(busid)
if err != nil {
c.logger.Error("attach ", busid, ": ", err)
if !sleepCtx(c.ctx, clientReconnectDelay) {
return
}
continue
}
c.logger.Info("attached ", busid, " → vhci port ", port)
c.trackPort(port, true)
c.watchPort(port, busid)
c.trackPort(port, false)
if err := c.ctx.Err(); err != nil {
return
}
c.logger.Info("vhci port ", port, " released; reattaching ", busid)
if !sleepCtx(c.ctx, clientReconnectDelay) {
return
}
}
}
// attemptAttach performs one dial → OP_REQ_IMPORT → vhci attach sequence.
// The returned TCP socket is handed to the kernel on success; on failure the
// connection is closed before return.
func (c *ClientService) attemptAttach(busid string) (int, error) {
conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, c.serverAddr)
if err != nil {
return -1, E.Cause(err, "dial ", c.serverAddr)
}
success := false
defer func() {
if !success {
conn.Close()
}
}()
if err := WriteOpReqImport(conn, busid); err != nil {
return -1, E.Cause(err, "write OP_REQ_IMPORT")
}
header, err := ReadOpHeader(conn)
if err != nil {
return -1, E.Cause(err, "read OP_REP_IMPORT header")
}
if header.Code != OpRepImport {
return -1, E.New("unexpected reply code 0x", hex16(header.Code))
}
if header.Status != OpStatusOK {
return -1, E.New("remote rejected import (status=", header.Status, ")")
}
info, err := ReadOpRepImportBody(conn)
if err != nil {
return -1, E.Cause(err, "read OP_REP_IMPORT body")
}
tcp, ok := conn.(*net.TCPConn)
if !ok {
return -1, E.New("dialed conn is not *net.TCPConn (type=", conn, ")")
}
file, err := tcp.File()
if err != nil {
return -1, E.Cause(err, "dup socket fd")
}
defer file.Close()
c.attachMu.Lock()
defer c.attachMu.Unlock()
port, err := vhciPickFreePort(info.Speed)
if err != nil {
return -1, err
}
if err := vhciAttach(port, file.Fd(), info.DevID(), info.Speed); err != nil {
return -1, E.Cause(err, "vhci attach")
}
success = true
return port, nil
}
// watchPort polls vhci status every 2s and returns when the port is no longer
// in VDEV_ST_USED, or when ctx is canceled (in which case it detaches the port).
func (c *ClientService) watchPort(port int, busid string) {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-c.ctx.Done():
if err := vhciDetach(port); err != nil {
c.logger.Warn("detach port ", port, " (", busid, "): ", err)
}
return
case <-ticker.C:
used, err := vhciPortUsed(port)
if err != nil {
c.logger.Debug("poll port ", port, ": ", err)
continue
}
if !used {
return
}
}
}
}
func (c *ClientService) trackPort(port int, add bool) {
c.portsMu.Lock()
defer c.portsMu.Unlock()
if add {
c.ports[port] = struct{}{}
} else {
delete(c.ports, port)
}
}
func everyMatchBusIDOnly(matches []option.USBIPDeviceMatch) bool {
for _, m := range matches {
if m.BusID == "" || m.VendorID != 0 || m.ProductID != 0 || m.Serial != "" {
return false
}
}
return true
}
func dedupe(in []string) []string {
seen := make(map[string]struct{}, len(in))
out := make([]string, 0, len(in))
for _, s := range in {
if _, ok := seen[s]; ok {
continue
}
seen[s] = struct{}{}
out = append(out, s)
}
return out
}
func sleepCtx(ctx context.Context, d time.Duration) bool {
t := time.NewTimer(d)
defer t.Stop()
select {
case <-ctx.Done():
return false
case <-t.C:
return true
}
}

37
service/usbip/match.go Normal file
View file

@ -0,0 +1,37 @@
package usbip
import (
"github.com/sagernet/sing-box/option"
)
// DeviceKey is the minimal set of fields needed to evaluate a USBIPDeviceMatch.
// Both the local sysfs enumerator and the remote OP_REP_DEVLIST parser populate
// this before running matches.
type DeviceKey struct {
BusID string
VendorID uint16
ProductID uint16
Serial string
}
// Matches reports whether d satisfies m. Non-zero fields AND together; an
// all-zero match is treated as non-matching (callers should reject such
// configs earlier).
func Matches(m option.USBIPDeviceMatch, d DeviceKey) bool {
if m.IsZero() {
return false
}
if m.BusID != "" && m.BusID != d.BusID {
return false
}
if m.VendorID != 0 && uint16(m.VendorID) != d.VendorID {
return false
}
if m.ProductID != 0 && uint16(m.ProductID) != d.ProductID {
return false
}
if m.Serial != "" && m.Serial != d.Serial {
return false
}
return true
}

255
service/usbip/protocol.go Normal file
View file

@ -0,0 +1,255 @@
package usbip
import (
"encoding/binary"
"io"
"strings"
E "github.com/sagernet/sing/common/exceptions"
)
// Wire constants.
const (
DefaultPort = 3240
ProtocolVersion uint16 = 0x0111
OpReqDevList uint16 = 0x8005
OpRepDevList uint16 = 0x0005
OpReqImport uint16 = 0x8003
OpRepImport uint16 = 0x0003
OpStatusOK uint32 = 0
OpStatusError uint32 = 1
maxOpRepDevListEntries = 4096
maxOpRepDevListBodyBytes = 8 << 20
deviceInfoWireSize = 312
deviceInterfaceWireSize = 4
)
// USB speeds (enum usb_device_speed).
const (
SpeedUnknown uint32 = 0
SpeedLow uint32 = 1
SpeedFull uint32 = 2
SpeedHigh uint32 = 3
SpeedWireless uint32 = 4
SpeedSuper uint32 = 5
SpeedSuperPlus uint32 = 6
)
// OpHeader is the 8-byte header prefix of every OP message.
type OpHeader struct {
Version uint16
Code uint16
Status uint32
}
// DeviceInfoTruncated is the 312-byte device descriptor shared by OP_REP_DEVLIST
// entries and OP_REP_IMPORT bodies.
type DeviceInfoTruncated struct {
Path [256]byte
BusID [32]byte
BusNum uint32
DevNum uint32
Speed uint32
IDVendor uint16
IDProduct uint16
BCDDevice uint16
BDeviceClass uint8
BDeviceSubClass uint8
BDeviceProtocol uint8
BConfigurationValue uint8
BNumConfigurations uint8
BNumInterfaces uint8
}
// DeviceInterface is the 4-byte per-interface descriptor carried in OP_REP_DEVLIST.
type DeviceInterface struct {
BInterfaceClass uint8
BInterfaceSubClass uint8
BInterfaceProtocol uint8
Padding uint8
}
// DeviceEntry is one element of an OP_REP_DEVLIST body.
type DeviceEntry struct {
Info DeviceInfoTruncated
Interfaces []DeviceInterface
}
// WriteOpHeader emits the 8-byte OP header.
func WriteOpHeader(w io.Writer, code uint16, status uint32) error {
return binary.Write(w, binary.BigEndian, OpHeader{
Version: ProtocolVersion,
Code: code,
Status: status,
})
}
// ReadOpHeader consumes the 8-byte OP header and returns it.
func ReadOpHeader(r io.Reader) (OpHeader, error) {
var h OpHeader
if err := binary.Read(r, binary.BigEndian, &h); err != nil {
return h, err
}
return h, nil
}
// WriteOpReqImport sends OP_REQ_IMPORT for busid (8 + 32 = 40 bytes).
func WriteOpReqImport(w io.Writer, busid string) error {
if err := WriteOpHeader(w, OpReqImport, OpStatusOK); err != nil {
return err
}
var field [32]byte
if len(busid) >= len(field) {
return E.New("busid too long: ", busid)
}
copy(field[:], busid)
return binary.Write(w, binary.BigEndian, field)
}
// ReadOpReqImportBody reads the 32-byte busid that follows the OP header.
func ReadOpReqImportBody(r io.Reader) (string, error) {
var field [32]byte
if _, err := io.ReadFull(r, field[:]); err != nil {
return "", err
}
return cstring(field[:]), nil
}
// WriteOpRepImport sends OP_REP_IMPORT. If status != OpStatusOK, info is omitted.
func WriteOpRepImport(w io.Writer, status uint32, info *DeviceInfoTruncated) error {
if err := WriteOpHeader(w, OpRepImport, status); err != nil {
return err
}
if status != OpStatusOK {
return nil
}
if info == nil {
return E.New("OP_REP_IMPORT success without device info")
}
return binary.Write(w, binary.BigEndian, info)
}
// ReadOpRepImportBody reads the 312-byte device info that follows a successful header.
func ReadOpRepImportBody(r io.Reader) (DeviceInfoTruncated, error) {
var info DeviceInfoTruncated
if err := binary.Read(r, binary.BigEndian, &info); err != nil {
return info, err
}
return info, nil
}
// WriteOpRepDevList emits an OP_REP_DEVLIST response with the given entries.
func WriteOpRepDevList(w io.Writer, entries []DeviceEntry) error {
if err := WriteOpHeader(w, OpRepDevList, OpStatusOK); err != nil {
return err
}
if err := binary.Write(w, binary.BigEndian, uint32(len(entries))); err != nil {
return err
}
for i := range entries {
if err := binary.Write(w, binary.BigEndian, &entries[i].Info); err != nil {
return err
}
for j := range entries[i].Interfaces {
if err := binary.Write(w, binary.BigEndian, &entries[i].Interfaces[j]); err != nil {
return err
}
}
}
return nil
}
// ReadOpRepDevListBody reads what follows a devlist OP header: a uint32 count
// plus that many device entries with their per-interface tails.
func ReadOpRepDevListBody(r io.Reader) ([]DeviceEntry, error) {
var count uint32
if err := binary.Read(r, binary.BigEndian, &count); err != nil {
return nil, err
}
if count > maxOpRepDevListEntries {
return nil, E.New("OP_REP_DEVLIST device count too large: ", count)
}
bodyBytes := uint64(4)
entries := make([]DeviceEntry, int(count))
for i := range entries {
if err := binary.Read(r, binary.BigEndian, &entries[i].Info); err != nil {
return nil, err
}
bodyBytes += deviceInfoWireSize
if bodyBytes > maxOpRepDevListBodyBytes {
return nil, E.New("OP_REP_DEVLIST body too large")
}
n := int(entries[i].Info.BNumInterfaces)
if n > 0 {
bodyBytes += uint64(n) * deviceInterfaceWireSize
if bodyBytes > maxOpRepDevListBodyBytes {
return nil, E.New("OP_REP_DEVLIST interface data too large")
}
entries[i].Interfaces = make([]DeviceInterface, n)
for j := range entries[i].Interfaces {
if err := binary.Read(r, binary.BigEndian, &entries[i].Interfaces[j]); err != nil {
return nil, err
}
}
}
}
return entries, nil
}
// BusID extracts the null-terminated busid from a DeviceInfoTruncated.
func (d *DeviceInfoTruncated) BusIDString() string {
return cstring(d.BusID[:])
}
// PathString extracts the null-terminated sysfs path.
func (d *DeviceInfoTruncated) PathString() string {
return cstring(d.Path[:])
}
func (d *DeviceInfoTruncated) SerialString() string {
meta := trailingCString(d.Path[:])
serial, ok := strings.CutPrefix(meta, "serial=")
if !ok {
return ""
}
return serial
}
// DevID packs busnum/devnum the way vhci_hcd.attach expects.
func (d *DeviceInfoTruncated) DevID() uint32 {
return (d.BusNum << 16) | (d.DevNum & 0xffff)
}
func encodePathField(dst *[256]byte, path, serial string) {
copy(dst[:], path)
if serial == "" || len(path) >= len(dst)-1 {
return
}
copy(dst[len(path)+1:], "serial="+serial)
}
func cstring(b []byte) string {
for i, c := range b {
if c == 0 {
return string(b[:i])
}
}
return string(b)
}
func trailingCString(b []byte) string {
for i, c := range b {
if c != 0 {
continue
}
if i+1 >= len(b) {
return ""
}
return cstring(b[i+1:])
}
return ""
}

View file

@ -0,0 +1,14 @@
//go:build linux
package usbip
import (
boxService "github.com/sagernet/sing-box/adapter/service"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
)
func RegisterService(registry *boxService.Registry) {
boxService.Register[option.USBIPServerServiceOptions](registry, C.TypeUSBIPServer, NewServerService)
boxService.Register[option.USBIPClientServiceOptions](registry, C.TypeUSBIPClient, NewClientService)
}

View file

@ -0,0 +1,23 @@
//go:build !linux
package usbip
import (
"context"
"github.com/sagernet/sing-box/adapter"
boxService "github.com/sagernet/sing-box/adapter/service"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
func RegisterService(registry *boxService.Registry) {
boxService.Register[option.USBIPServerServiceOptions](registry, C.TypeUSBIPServer, func(ctx context.Context, logger log.ContextLogger, tag string, options option.USBIPServerServiceOptions) (adapter.Service, error) {
return nil, E.New("usbip-server service is only supported on Linux")
})
boxService.Register[option.USBIPClientServiceOptions](registry, C.TypeUSBIPClient, func(ctx context.Context, logger log.ContextLogger, tag string, options option.USBIPClientServiceOptions) (adapter.Service, error) {
return nil, E.New("usbip-client service is only supported on Linux")
})
}

View file

@ -0,0 +1,381 @@
//go:build linux
package usbip
import (
"context"
"net"
"sync"
"github.com/sagernet/sing-box/adapter"
boxService "github.com/sagernet/sing-box/adapter/service"
"github.com/sagernet/sing-box/common/listener"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
)
type serverExport struct {
busid string
managed bool
originalDriver string
}
type ServerService struct {
boxService.Adapter
ctx context.Context
cancel context.CancelFunc
logger log.ContextLogger
listener *listener.Listener
matches []option.USBIPDeviceMatch
mu sync.Mutex
exports []serverExport
listenFD net.Listener
}
func NewServerService(ctx context.Context, logger log.ContextLogger, tag string, options option.USBIPServerServiceOptions) (adapter.Service, error) {
for i, m := range options.Devices {
if m.IsZero() {
return nil, E.New("devices[", i, "]: at least one of busid/vendor_id/product_id/serial is required")
}
}
if options.ListenPort == 0 {
options.ListenPort = DefaultPort
}
ctx, cancel := context.WithCancel(ctx)
s := &ServerService{
Adapter: boxService.NewAdapter(C.TypeUSBIPServer, tag),
ctx: ctx,
cancel: cancel,
logger: logger,
matches: options.Devices,
listener: listener.New(listener.Options{
Context: ctx,
Logger: logger,
Network: []string{N.NetworkTCP},
Listen: options.ListenOptions,
}),
}
return s, nil
}
func (s *ServerService) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if err := ensureHostDriver(); err != nil {
return err
}
if err := s.bindExports(); err != nil {
s.rollbackExports()
return err
}
tcpListener, err := s.listener.ListenTCP()
if err != nil {
s.rollbackExports()
return err
}
s.mu.Lock()
s.listenFD = tcpListener
s.mu.Unlock()
go s.acceptLoop(tcpListener)
return nil
}
func (s *ServerService) Close() error {
if s.cancel != nil {
s.cancel()
}
err := common.Close(common.PtrOrNil(s.listener))
s.rollbackExports()
return err
}
// bindExports resolves every match against current sysfs state, unbinds from
// the current driver, and binds to usbip-host.
func (s *ServerService) bindExports() error {
devices, err := listUSBDevices()
if err != nil {
return E.Cause(err, "enumerate usb devices")
}
seen := make(map[string]bool)
for _, m := range s.matches {
matched := 0
for i := range devices {
if !Matches(m, devices[i].key()) {
continue
}
if seen[devices[i].BusID] {
matched++
continue
}
if err := s.bindOne(&devices[i]); err != nil {
s.logger.Warn("bind ", devices[i].BusID, ": ", err)
continue
}
seen[devices[i].BusID] = true
matched++
}
if matched == 0 {
s.logger.Warn("no local device matched ", describeMatch(m))
}
}
return nil
}
func (s *ServerService) bindOne(d *sysfsDevice) error {
driver, err := currentDriver(d.BusID)
if err != nil {
return err
}
if driver == "usbip-host" {
s.logger.Info("device ", d.BusID, " already bound to usbip-host; co-opting")
s.mu.Lock()
s.exports = append(s.exports, serverExport{busid: d.BusID})
s.mu.Unlock()
return nil
}
if driver != "" {
if err := unbindFromDriver(d.BusID, driver); err != nil {
return E.Cause(err, "unbind from ", driver)
}
}
if err := hostMatchBusID(d.BusID, true); err != nil {
if driver != "" {
_ = bindToDriver(d.BusID, driver)
}
return E.Cause(err, "match_busid add")
}
if err := hostBind(d.BusID); err != nil {
_ = hostMatchBusID(d.BusID, false)
if driver != "" {
_ = bindToDriver(d.BusID, driver)
}
return E.Cause(err, "bind to usbip-host")
}
s.logger.Info("exported ", d.BusID, " (previously on ", driverOrNone(driver), ")")
s.mu.Lock()
s.exports = append(s.exports, serverExport{
busid: d.BusID,
managed: true,
originalDriver: driver,
})
s.mu.Unlock()
return nil
}
func (s *ServerService) rollbackExports() {
s.mu.Lock()
exports := s.exports
s.exports = nil
s.mu.Unlock()
for _, e := range exports {
if !e.managed {
continue
}
// Release any attached peer.
_ = writeUsbipSockfd(e.busid, -1)
if err := hostUnbind(e.busid); err != nil {
s.logger.Warn("unbind ", e.busid, ": ", err)
}
if err := hostMatchBusID(e.busid, false); err != nil {
s.logger.Debug("match_busid del ", e.busid, ": ", err)
}
if e.originalDriver == "" {
s.logger.Info("released ", e.busid, " from usbip-host")
continue
}
if err := bindToDriver(e.busid, e.originalDriver); err != nil {
s.logger.Warn("rebind ", e.busid, " to ", e.originalDriver, ": ", err)
continue
}
s.logger.Info("restored ", e.busid, " to ", e.originalDriver)
}
}
func (s *ServerService) currentExports() []string {
s.mu.Lock()
defer s.mu.Unlock()
out := make([]string, len(s.exports))
for i, e := range s.exports {
out[i] = e.busid
}
return out
}
func (s *ServerService) acceptLoop(ln net.Listener) {
for {
conn, err := ln.Accept()
if err != nil {
select {
case <-s.ctx.Done():
return
default:
}
if E.IsClosed(err) {
return
}
s.logger.Error("accept: ", err)
return
}
go s.handleConn(conn)
}
}
func (s *ServerService) handleConn(conn net.Conn) {
defer conn.Close()
header, err := ReadOpHeader(conn)
if err != nil {
s.logger.Debug("read op header: ", err)
return
}
switch header.Code {
case OpReqDevList:
s.handleDevList(conn)
case OpReqImport:
s.handleImport(conn)
default:
s.logger.Debug("unknown opcode 0x", hex16(header.Code))
}
}
func (s *ServerService) handleDevList(conn net.Conn) {
entries := s.buildDevListEntries()
if err := WriteOpRepDevList(conn, entries); err != nil {
s.logger.Debug("write devlist: ", err)
}
}
func (s *ServerService) buildDevListEntries() []DeviceEntry {
busids := s.currentExports()
if len(busids) == 0 {
return nil
}
entries := make([]DeviceEntry, 0, len(busids))
for _, busid := range busids {
d, err := readSysfsDevice(busid, sysBusDevicePath(busid))
if err != nil {
s.logger.Debug("refresh ", busid, ": ", err)
continue
}
entries = append(entries, DeviceEntry{
Info: d.toProtocol(),
Interfaces: d.Interfaces,
})
}
return entries
}
func (s *ServerService) handleImport(conn net.Conn) {
busid, err := ReadOpReqImportBody(conn)
if err != nil {
s.logger.Debug("read import body: ", err)
return
}
if !s.isExported(busid) {
s.logger.Info("import rejected (unknown busid): ", busid)
_ = WriteOpRepImport(conn, OpStatusError, nil)
return
}
status, err := readUsbipStatus(busid)
if err != nil || status != 1 {
s.logger.Info("import rejected (busid ", busid, " status=", status, " err=", err, ")")
_ = WriteOpRepImport(conn, OpStatusError, nil)
return
}
dev, err := readSysfsDevice(busid, sysBusDevicePath(busid))
if err != nil {
s.logger.Warn("refresh ", busid, ": ", err)
_ = WriteOpRepImport(conn, OpStatusError, nil)
return
}
tcp, ok := conn.(*net.TCPConn)
if !ok {
s.logger.Warn("import requires *net.TCPConn, got ", conn)
_ = WriteOpRepImport(conn, OpStatusError, nil)
return
}
file, err := tcp.File()
if err != nil {
s.logger.Warn("dup socket fd: ", err)
_ = WriteOpRepImport(conn, OpStatusError, nil)
return
}
defer file.Close()
if err := writeUsbipSockfd(busid, int(file.Fd())); err != nil {
s.logger.Warn("hand off ", busid, " to kernel: ", err)
_ = WriteOpRepImport(conn, OpStatusError, nil)
return
}
info := dev.toProtocol()
if err := WriteOpRepImport(conn, OpStatusOK, &info); err != nil {
s.logger.Warn("reply import ", busid, ": ", err)
_ = writeUsbipSockfd(busid, -1)
return
}
s.logger.Info("attached ", busid, " to remote ", conn.RemoteAddr())
}
func (s *ServerService) isExported(busid string) bool {
s.mu.Lock()
defer s.mu.Unlock()
for _, e := range s.exports {
if e.busid == busid {
return true
}
}
return false
}
func sysBusDevicePath(busid string) string {
return sysBusUSBDevices + "/" + busid
}
func describeMatch(m option.USBIPDeviceMatch) string {
var parts []string
if m.BusID != "" {
parts = append(parts, "busid="+m.BusID)
}
if m.VendorID != 0 {
parts = append(parts, "vendor_id=0x"+hex16(uint16(m.VendorID)))
}
if m.ProductID != 0 {
parts = append(parts, "product_id=0x"+hex16(uint16(m.ProductID)))
}
if m.Serial != "" {
parts = append(parts, "serial="+m.Serial)
}
return "{" + joinComma(parts) + "}"
}
func driverOrNone(d string) string {
if d == "" {
return "(no driver)"
}
return d
}
func hex16(v uint16) string {
const hexdigits = "0123456789abcdef"
return string([]byte{
hexdigits[(v>>12)&0xf],
hexdigits[(v>>8)&0xf],
hexdigits[(v>>4)&0xf],
hexdigits[v&0xf],
})
}
func joinComma(parts []string) string {
out := ""
for i, p := range parts {
if i > 0 {
out += ","
}
out += p
}
return out
}

View file

@ -0,0 +1,409 @@
//go:build linux
package usbip
import (
"bufio"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
E "github.com/sagernet/sing/common/exceptions"
)
const (
sysBusUSBDevices = "/sys/bus/usb/devices"
sysUsbipHostDriver = "/sys/bus/usb/drivers/usbip-host"
sysVHCIControllerV0 = "/sys/devices/platform/vhci_hcd.0"
)
// sysfsDevice captures the subset of USB device attributes needed for
// matching, export, and OP_REP_DEVLIST emission.
type sysfsDevice struct {
BusID string
Path string // sysfs path, e.g. /sys/bus/usb/devices/1-2
BusNum uint32
DevNum uint32
Speed uint32
VendorID uint16
ProductID uint16
BCDDevice uint16
DeviceClass uint8
DeviceSubClass uint8
DeviceProtocol uint8
ConfigValue uint8
NumConfigs uint8
NumInterfaces uint8
Serial string
Interfaces []DeviceInterface
}
func (d *sysfsDevice) key() DeviceKey {
return DeviceKey{
BusID: d.BusID,
VendorID: d.VendorID,
ProductID: d.ProductID,
Serial: d.Serial,
}
}
func (d *sysfsDevice) toProtocol() DeviceInfoTruncated {
var info DeviceInfoTruncated
encodePathField(&info.Path, d.Path, d.Serial)
copy(info.BusID[:], d.BusID)
info.BusNum = d.BusNum
info.DevNum = d.DevNum
info.Speed = d.Speed
info.IDVendor = d.VendorID
info.IDProduct = d.ProductID
info.BCDDevice = d.BCDDevice
info.BDeviceClass = d.DeviceClass
info.BDeviceSubClass = d.DeviceSubClass
info.BDeviceProtocol = d.DeviceProtocol
info.BConfigurationValue = d.ConfigValue
info.BNumConfigurations = d.NumConfigs
info.BNumInterfaces = d.NumInterfaces
return info
}
type vhciStatusRecord struct {
hub string
port int
state int
}
// ensureHostDriver verifies the usbip-host kernel driver is loaded.
func ensureHostDriver() error {
if _, err := os.Stat(sysUsbipHostDriver); err != nil {
return E.Cause(err, "usbip-host driver not present; modprobe usbip-host")
}
return nil
}
// ensureVHCI verifies the vhci_hcd controller is loaded.
func ensureVHCI() error {
if _, err := os.Stat(sysVHCIControllerV0); err != nil {
return E.Cause(err, "vhci_hcd.0 not present; modprobe vhci-hcd")
}
return nil
}
// listUSBDevices enumerates /sys/bus/usb/devices, returning non-interface
// device entries that expose idVendor.
func listUSBDevices() ([]sysfsDevice, error) {
entries, err := os.ReadDir(sysBusUSBDevices)
if err != nil {
return nil, err
}
var devices []sysfsDevice
for _, entry := range entries {
name := entry.Name()
if strings.Contains(name, ":") {
continue
}
path := filepath.Join(sysBusUSBDevices, name)
device, err := readSysfsDevice(name, path)
if err != nil {
continue
}
devices = append(devices, device)
}
return devices, nil
}
// readSysfsDevice populates a sysfsDevice from the attributes at path.
func readSysfsDevice(busid, path string) (sysfsDevice, error) {
d := sysfsDevice{BusID: busid, Path: path}
vendor, err := readHexU16(path, "idVendor")
if err != nil {
return d, err
}
d.VendorID = vendor
d.ProductID, _ = readHexU16(path, "idProduct")
d.BCDDevice, _ = readHexU16(path, "bcdDevice")
if v, err := readDecU32(path, "busnum"); err == nil {
d.BusNum = v
}
if v, err := readDecU32(path, "devnum"); err == nil {
d.DevNum = v
}
d.Speed = speedCodeFromString(readString(path, "speed"))
d.DeviceClass, _ = readHexU8(path, "bDeviceClass")
d.DeviceSubClass, _ = readHexU8(path, "bDeviceSubClass")
d.DeviceProtocol, _ = readHexU8(path, "bDeviceProtocol")
d.ConfigValue, _ = readDecU8(path, "bConfigurationValue")
d.NumConfigs, _ = readDecU8(path, "bNumConfigurations")
d.NumInterfaces, _ = readDecU8(path, "bNumInterfaces")
d.Serial = readString(path, "serial")
d.Interfaces = readInterfaces(path, busid, d.ConfigValue, int(d.NumInterfaces))
return d, nil
}
// readInterfaces reads the per-interface descriptors sibling to the device node.
func readInterfaces(devicePath, busid string, configValue uint8, count int) []DeviceInterface {
if count == 0 {
return nil
}
interfaces := make([]DeviceInterface, count)
for i := 0; i < count; i++ {
name := fmt.Sprintf("%s:%d.%d", busid, configValue, i)
ipath := filepath.Join(filepath.Dir(devicePath), name)
class, _ := readHexU8(ipath, "bInterfaceClass")
subClass, _ := readHexU8(ipath, "bInterfaceSubClass")
protocol, _ := readHexU8(ipath, "bInterfaceProtocol")
interfaces[i] = DeviceInterface{
BInterfaceClass: class,
BInterfaceSubClass: subClass,
BInterfaceProtocol: protocol,
}
}
return interfaces
}
// currentDriver returns the driver currently bound to busid, or "" if none.
func currentDriver(busid string) (string, error) {
link, err := os.Readlink(filepath.Join(sysBusUSBDevices, busid, "driver"))
if err != nil {
if os.IsNotExist(err) {
return "", nil
}
return "", err
}
return filepath.Base(link), nil
}
// unbindFromDriver detaches busid from driver.
func unbindFromDriver(busid, driver string) error {
path := filepath.Join("/sys/bus/usb/drivers", driver, "unbind")
return writeSysfs(path, busid)
}
// bindToDriver attaches busid to driver.
func bindToDriver(busid, driver string) error {
path := filepath.Join("/sys/bus/usb/drivers", driver, "bind")
return writeSysfs(path, busid)
}
// hostMatchBusID writes "add <busid>" / "del <busid>" to the usbip-host
// match_busid attribute. Returns nil on ENOENT for "del" (idempotent).
func hostMatchBusID(busid string, add bool) error {
verb := "del"
if add {
verb = "add"
}
path := filepath.Join(sysUsbipHostDriver, "match_busid")
return writeSysfs(path, verb+" "+busid)
}
// hostBind attaches busid to usbip-host.
func hostBind(busid string) error {
return writeSysfs(filepath.Join(sysUsbipHostDriver, "bind"), busid)
}
// hostUnbind detaches busid from usbip-host.
func hostUnbind(busid string) error {
return writeSysfs(filepath.Join(sysUsbipHostDriver, "unbind"), busid)
}
// readUsbipStatus returns the usbip_status attribute value for busid.
// 1 = AVAILABLE, 2 = USED, 3 = ERROR.
func readUsbipStatus(busid string) (int, error) {
raw, err := os.ReadFile(filepath.Join(sysBusUSBDevices, busid, "usbip_status"))
if err != nil {
return 0, err
}
v, err := strconv.Atoi(strings.TrimSpace(string(raw)))
if err != nil {
return 0, err
}
return v, nil
}
// writeUsbipSockfd hands the fd for busid to the usbip-host kernel driver.
// Passing -1 as fd releases the connection.
func writeUsbipSockfd(busid string, fd int) error {
return writeSysfs(filepath.Join(sysBusUSBDevices, busid, "usbip_sockfd"), strconv.Itoa(fd))
}
// vhciPickFreePort scans the vhci_hcd.0 status table and returns a free port
// from the hub that matches the remote device speed.
func vhciPickFreePort(speed uint32) (int, error) {
records, err := readVHCIStatus()
if err != nil {
return -1, err
}
targetHub := vhciHubForSpeed(speed)
for _, record := range records {
if record.hub != targetHub || record.state != 4 {
continue
}
return record.port, nil
}
return -1, E.New("no free ", targetHub, " vhci port")
}
// vhciPortUsed reports whether the given port is currently in VDEV_ST_USED (6).
func vhciPortUsed(port int) (bool, error) {
records, err := readVHCIStatus()
if err != nil {
return false, err
}
for _, record := range records {
if record.port != port {
continue
}
return record.state == 6, nil // VDEV_ST_USED
}
return false, nil
}
// vhciAttach writes "port fd devid speed" to the attach attribute.
func vhciAttach(port int, fd uintptr, devid uint32, speed uint32) error {
line := fmt.Sprintf("%d %d %d %d", port, int(fd), devid, speed)
return writeSysfs(filepath.Join(sysVHCIControllerV0, "attach"), line)
}
// vhciDetach writes the port number to the detach attribute.
func vhciDetach(port int) error {
return writeSysfs(filepath.Join(sysVHCIControllerV0, "detach"), strconv.Itoa(port))
}
func readVHCIStatus() ([]vhciStatusRecord, error) {
raw, err := os.ReadFile(filepath.Join(sysVHCIControllerV0, "status"))
if err != nil {
return nil, err
}
return parseVHCIStatus(string(raw)), nil
}
func parseVHCIStatus(raw string) []vhciStatusRecord {
scanner := bufio.NewScanner(strings.NewReader(raw))
records := make([]vhciStatusRecord, 0)
first := true
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if first {
first = false
continue
}
if line == "" {
continue
}
fields := strings.Fields(line)
if len(fields) < 3 {
continue
}
port, err := strconv.Atoi(fields[1])
if err != nil {
continue
}
state, err := strconv.Atoi(fields[2])
if err != nil {
continue
}
records = append(records, vhciStatusRecord{
hub: fields[0],
port: port,
state: state,
})
}
return records
}
func vhciHubForSpeed(speed uint32) string {
switch speed {
case SpeedSuper, SpeedSuperPlus:
return "ss"
default:
return "hs"
}
}
// --- small helpers ------------------------------------------------------
func writeSysfs(path, content string) error {
f, err := os.OpenFile(path, os.O_WRONLY, 0)
if err != nil {
return err
}
defer f.Close()
_, err = f.WriteString(content)
return err
}
func readString(dir, attr string) string {
raw, err := os.ReadFile(filepath.Join(dir, attr))
if err != nil {
return ""
}
return strings.TrimSpace(string(raw))
}
func readHexU16(dir, attr string) (uint16, error) {
s := readString(dir, attr)
if s == "" {
return 0, E.New(attr, " missing")
}
v, err := strconv.ParseUint(s, 16, 16)
if err != nil {
return 0, err
}
return uint16(v), nil
}
func readHexU8(dir, attr string) (uint8, error) {
s := readString(dir, attr)
if s == "" {
return 0, E.New(attr, " missing")
}
v, err := strconv.ParseUint(s, 16, 8)
if err != nil {
return 0, err
}
return uint8(v), nil
}
func readDecU8(dir, attr string) (uint8, error) {
s := readString(dir, attr)
if s == "" {
return 0, E.New(attr, " missing")
}
v, err := strconv.ParseUint(s, 10, 8)
if err != nil {
return 0, err
}
return uint8(v), nil
}
func readDecU32(dir, attr string) (uint32, error) {
s := readString(dir, attr)
if s == "" {
return 0, E.New(attr, " missing")
}
v, err := strconv.ParseUint(s, 10, 32)
if err != nil {
return 0, err
}
return uint32(v), nil
}
// speedCodeFromString maps the sysfs "speed" attribute to the USB/IP wire
// enum usb_device_speed.
func speedCodeFromString(s string) uint32 {
switch s {
case "1.5":
return SpeedLow
case "12":
return SpeedFull
case "480":
return SpeedHigh
case "5000":
return SpeedSuper
case "10000":
return SpeedSuperPlus
default:
return SpeedUnknown
}
}