diff --git a/service/usbip/client_linux.go b/service/usbip/client_linux.go index 0318ef1bb..db6ed0019 100644 --- a/service/usbip/client_linux.go +++ b/service/usbip/client_linux.go @@ -4,8 +4,9 @@ package usbip import ( "context" - "encoding/binary" + "errors" "net" + "slices" "sync" "time" @@ -20,7 +21,15 @@ import ( N "github.com/sagernet/sing/common/network" ) -const clientReconnectDelay = 5 * time.Second +const ( + clientReconnectDelay = 5 * time.Second + controlPingInterval = 10 * time.Second + controlReadTimeout = 30 * time.Second + controlWriteTimeout = 5 * time.Second + controlSessionIdleHint = "control session lost" +) + +var errImmediateReconnect = errors.New("usbip control reconnect") type clientTarget struct { fixedBusID string @@ -34,6 +43,15 @@ func (t clientTarget) description() string { return describeMatch(t.match) } +type clientAssignedWorker struct { + target clientTarget + updates chan string +} + +type clientBusIDWorker struct { + cancel context.CancelFunc +} + type ClientService struct { boxService.Adapter ctx context.Context @@ -43,9 +61,11 @@ type ClientService struct { serverAddr M.Socksaddr matches []option.USBIPDeviceMatch // empty = import all remote exports - assignMu sync.Mutex - targets []clientTarget - assigned []string + stateMu sync.Mutex + targets []clientTarget + assigned []string + assignedWorkers []*clientAssignedWorker + allWorkers map[string]*clientBusIDWorker attachMu sync.Mutex // serializes vhci port pick + attach wg sync.WaitGroup @@ -79,6 +99,7 @@ func NewClientService(ctx context.Context, logger log.ContextLogger, tag string, dialer: outboundDialer, serverAddr: options.ServerOptions.Build(), matches: options.Devices, + allWorkers: make(map[string]*clientBusIDWorker), ports: make(map[int]struct{}), }, nil } @@ -90,6 +111,7 @@ func (c *ClientService) Start(stage adapter.StartStage) error { if err := ensureVHCI(); err != nil { return err } + c.initializeWorkers() c.wg.Add(1) go c.run() return nil @@ -99,7 +121,6 @@ func (c *ClientService) Close() error { if c.cancel != nil { c.cancel() } - // Wait for workers to detach, bounded by 5s. done := make(chan struct{}) go func() { c.wg.Wait() @@ -113,35 +134,356 @@ func (c *ClientService) Close() error { return nil } -// run prepares the desired targets and spawns one worker per target. -func (c *ClientService) run() { - defer c.wg.Done() +func (c *ClientService) initializeWorkers() { targets := c.buildTargets() - if len(targets) == 0 { - c.logger.Warn("no devices to import; client idle") + c.stateMu.Lock() + c.targets = targets + if len(c.matches) == 0 { + c.stateMu.Unlock() return } - c.assignMu.Lock() - c.targets = targets c.assigned = make([]string, len(targets)) - for i := range targets { - c.assigned[i] = targets[i].fixedBusID + c.assignedWorkers = make([]*clientAssignedWorker, len(targets)) + for i, target := range targets { + c.assignedWorkers[i] = &clientAssignedWorker{ + target: target, + updates: make(chan string, 1), + } } - c.assignMu.Unlock() - for i := range targets { + workers := append([]*clientAssignedWorker(nil), c.assignedWorkers...) + c.stateMu.Unlock() + + for _, worker := range workers { c.wg.Add(1) - go c.worker(i) + go c.runAssignedWorker(worker) + } +} + +func (c *ClientService) run() { + defer c.wg.Done() + immediate := true + for { + if !immediate && !sleepCtx(c.ctx, clientReconnectDelay) { + break + } + err := c.runControlSession() + if c.ctx.Err() != nil { + break + } + if err != nil { + c.logger.Error("control ", c.serverAddr, ": ", err) + } + immediate = errors.Is(err, errImmediateReconnect) + } + c.stopAllWorkers() +} + +func (c *ClientService) runControlSession() error { + conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, c.serverAddr) + if err != nil { + return E.Cause(err, "dial ", c.serverAddr) + } + defer conn.Close() + stopCloseOnCancel := closeConnOnContextDone(c.ctx, conn) + defer stopCloseOnCancel() + + _ = conn.SetWriteDeadline(time.Now().Add(controlWriteTimeout)) + _ = conn.SetReadDeadline(time.Now().Add(controlWriteTimeout)) + if err := WriteControlPreface(conn); err != nil { + return E.Cause(err, "write control preface") + } + if err := WriteControlHello(conn); err != nil { + return E.Cause(err, "write control hello") + } + ack, err := ReadControlFrame(conn) + if err != nil { + return E.Cause(err, "read control ack") + } + if ack.Type != controlFrameAck { + return E.New("unexpected control ack frame ", ack.Type) + } + if ack.Version != controlProtocolVersion { + return E.New("unsupported control version ", ack.Version) + } + if ack.Capabilities&controlCapabilities != controlCapabilities { + return E.New("missing control capabilities 0x", ack.Capabilities) + } + _ = conn.SetWriteDeadline(time.Time{}) + _ = conn.SetReadDeadline(time.Time{}) + + if err := c.syncRemoteState(); err != nil { + return E.Cause(err, "initial devlist sync") + } + + pingDone := make(chan struct{}) + go c.controlPingLoop(conn, pingDone) + defer close(pingDone) + + lastSeq := ack.Sequence + for { + if err := conn.SetReadDeadline(time.Now().Add(controlReadTimeout)); err != nil { + return err + } + frame, err := ReadControlFrame(conn) + if err != nil { + return E.Cause(errImmediateReconnect, controlSessionIdleHint, ": ", err) + } + switch frame.Type { + case controlFrameChanged: + if frame.Sequence != lastSeq+1 { + return E.Cause(errImmediateReconnect, "control sequence jumped from ", lastSeq, " to ", frame.Sequence) + } + lastSeq = frame.Sequence + if err := c.syncRemoteState(); err != nil { + return E.Cause(errImmediateReconnect, "devlist sync after change ", frame.Sequence, ": ", err) + } + case controlFramePong: + default: + return E.Cause(errImmediateReconnect, "unexpected control frame ", frame.Type) + } + } +} + +func (c *ClientService) controlPingLoop(conn net.Conn, done <-chan struct{}) { + ticker := time.NewTicker(controlPingInterval) + defer ticker.Stop() + for { + select { + case <-c.ctx.Done(): + return + case <-done: + return + case <-ticker.C: + _ = conn.SetWriteDeadline(time.Now().Add(controlWriteTimeout)) + if err := WriteControlPing(conn); err != nil { + _ = conn.Close() + return + } + _ = conn.SetWriteDeadline(time.Time{}) + } + } +} + +func (c *ClientService) syncRemoteState() error { + entries, err := c.fetchDevList(c.ctx) + if err != nil { + return err + } + c.applyRemoteEntries(entries) + return nil +} + +func (c *ClientService) applyRemoteEntries(entries []DeviceEntry) { + if len(c.matches) == 0 { + c.applyRemoteExports(entries) + return + } + c.applyMatchedExports(entries) +} + +func (c *ClientService) applyRemoteExports(entries []DeviceEntry) { + desired := make(map[string]struct{}, len(entries)) + for i := range entries { + busid := entries[i].Info.BusIDString() + if busid == "" { + continue + } + desired[busid] = struct{}{} + } + + c.stateMu.Lock() + stopWorkers := make([]*clientBusIDWorker, 0) + for busid, worker := range c.allWorkers { + if _, ok := desired[busid]; ok { + continue + } + stopWorkers = append(stopWorkers, worker) + delete(c.allWorkers, busid) + } + startBusIDs := make([]string, 0) + for busid := range desired { + if _, ok := c.allWorkers[busid]; ok { + continue + } + startBusIDs = append(startBusIDs, busid) + } + c.stateMu.Unlock() + + for _, worker := range stopWorkers { + worker.cancel() + } + slices.Sort(startBusIDs) + for _, busid := range startBusIDs { + c.startRemoteBusIDWorker(busid, busid) + } +} + +func (c *ClientService) applyMatchedExports(entries []DeviceEntry) { + keysByBusID := make(map[string]DeviceKey, len(entries)) + for i := range entries { + busid := entries[i].Info.BusIDString() + if busid == "" { + continue + } + keysByBusID[busid] = DeviceKey{ + BusID: busid, + VendorID: entries[i].Info.IDVendor, + ProductID: entries[i].Info.IDProduct, + Serial: entries[i].Info.SerialString(), + } + } + + c.stateMu.Lock() + if len(c.targets) == 0 { + c.stateMu.Unlock() + return + } + + nextAssigned := make([]string, len(c.targets)) + reserved := make(map[string]struct{}, len(c.targets)) + for i, target := range c.targets { + if target.fixedBusID == "" { + continue + } + if _, ok := keysByBusID[target.fixedBusID]; !ok { + continue + } + nextAssigned[i] = target.fixedBusID + reserved[target.fixedBusID] = struct{}{} + } + for i, target := range c.targets { + if target.fixedBusID != "" { + continue + } + current := c.assigned[i] + if current == "" { + continue + } + if _, ok := reserved[current]; ok { + continue + } + key, ok := keysByBusID[current] + if !ok || !Matches(target.match, key) { + continue + } + nextAssigned[i] = current + reserved[current] = struct{}{} + } + for i, target := range c.targets { + if target.fixedBusID != "" || nextAssigned[i] != "" { + continue + } + nextAssigned[i] = firstMatchingUnclaimedBusID(target.match, entries, reserved) + if nextAssigned[i] != "" { + reserved[nextAssigned[i]] = struct{}{} + } + } + + workers := append([]*clientAssignedWorker(nil), c.assignedWorkers...) + previous := append([]string(nil), c.assigned...) + c.assigned = nextAssigned + c.stateMu.Unlock() + + for i, worker := range workers { + if previous[i] == nextAssigned[i] { + continue + } + worker.setDesiredBusID(nextAssigned[i]) + } +} + +func (c *ClientService) runAssignedWorker(worker *clientAssignedWorker) { + defer c.wg.Done() + + var current string + var runnerCancel context.CancelFunc + var runnerDone chan struct{} + + stopRunner := func() { + if runnerCancel == nil { + return + } + runnerCancel() + <-runnerDone + runnerCancel = nil + runnerDone = nil + } + + for { + select { + case <-c.ctx.Done(): + stopRunner() + return + case desired := <-worker.updates: + if desired == current { + continue + } + stopRunner() + current = desired + if desired == "" { + continue + } + + runCtx, cancel := context.WithCancel(c.ctx) + done := make(chan struct{}) + runnerCancel = cancel + runnerDone = done + + c.wg.Add(1) + go func(busid string) { + defer c.wg.Done() + defer close(done) + c.runBusIDLoop(runCtx, busid, worker.target.description()) + }(desired) + } + } +} + +func (w *clientAssignedWorker) setDesiredBusID(busid string) { + select { + case w.updates <- busid: + return + default: + } + select { + case <-w.updates: + default: + } + w.updates <- busid +} + +func (c *ClientService) startRemoteBusIDWorker(busid, description string) { + runCtx, cancel := context.WithCancel(c.ctx) + worker := &clientBusIDWorker{cancel: cancel} + + c.stateMu.Lock() + c.allWorkers[busid] = worker + c.stateMu.Unlock() + + c.wg.Add(1) + go func() { + defer c.wg.Done() + c.runBusIDLoop(runCtx, busid, description) + }() +} + +func (c *ClientService) stopAllWorkers() { + c.stateMu.Lock() + workers := make([]*clientBusIDWorker, 0, len(c.allWorkers)) + for _, worker := range c.allWorkers { + workers = append(workers, worker) + } + c.allWorkers = make(map[string]*clientBusIDWorker) + c.stateMu.Unlock() + + for _, worker := range workers { + worker.cancel() } } func (c *ClientService) buildTargets() []clientTarget { if len(c.matches) == 0 { - busids := c.snapshotRemoteBusIDs() - targets := make([]clientTarget, 0, len(busids)) - for _, busid := range busids { - targets = append(targets, clientTarget{fixedBusID: busid}) - } - return targets + return nil } seenFixed := make(map[string]struct{}) targets := make([]clientTarget, 0, len(c.matches)) @@ -159,36 +501,15 @@ func (c *ClientService) buildTargets() []clientTarget { return targets } -// snapshotRemoteBusIDs connects once, issues OP_REQ_DEVLIST, and returns the -// currently exported remote busids. -func (c *ClientService) snapshotRemoteBusIDs() []string { - for { - if err := c.ctx.Err(); err != nil { - return nil - } - entries, err := c.fetchDevList() - if err != nil { - c.logger.Error("enumerate ", c.serverAddr, ": ", err) - if !sleepCtx(c.ctx, clientReconnectDelay) { - return nil - } - continue - } - out := make([]string, 0, len(entries)) - for i := range entries { - out = append(out, entries[i].Info.BusIDString()) - } - return dedupe(out) - } -} - -func (c *ClientService) fetchDevList() ([]DeviceEntry, error) { - conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, c.serverAddr) +func (c *ClientService) fetchDevList(ctx context.Context) ([]DeviceEntry, error) { + conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr) if err != nil { return nil, err } defer conn.Close() - if err := binary.Write(conn, binary.BigEndian, OpHeader{Version: ProtocolVersion, Code: OpReqDevList, Status: OpStatusOK}); err != nil { + stopCloseOnCancel := closeConnOnContextDone(ctx, conn) + defer stopCloseOnCancel() + if err := WriteOpHeader(conn, OpReqDevList, OpStatusOK); err != nil { return nil, E.Cause(err, "send OP_REQ_DEVLIST") } header, err := ReadOpHeader(conn) @@ -201,118 +522,41 @@ func (c *ClientService) fetchDevList() ([]DeviceEntry, error) { return ReadOpRepDevListBody(conn) } -// worker keeps one target attached to vhci_hcd.0. On any error or kernel-side -// detach, waits clientReconnectDelay and retries. -func (c *ClientService) worker(targetIndex int) { - defer c.wg.Done() - target := c.targets[targetIndex] +func (c *ClientService) runBusIDLoop(ctx context.Context, busid, description string) { for { - if err := c.ctx.Err(); err != nil { + if err := ctx.Err(); err != nil { return } - busid, err := c.claimTargetBusID(targetIndex) + port, err := c.attemptAttach(ctx, busid) if err != nil { - c.logger.Error("assign ", target.description(), ": ", err) - if !sleepCtx(c.ctx, clientReconnectDelay) { - return - } - continue - } - if busid == "" { - if !sleepCtx(c.ctx, clientReconnectDelay) { - return - } - continue - } - port, err := c.attemptAttach(busid) - if err != nil { - c.releaseTargetBusID(targetIndex, busid) - c.logger.Error("attach ", busid, ": ", err) - if !sleepCtx(c.ctx, clientReconnectDelay) { + c.logger.Error("attach ", description, " (", busid, "): ", err) + if !sleepCtx(ctx, clientReconnectDelay) { return } continue } c.logger.Info("attached ", busid, " → vhci port ", port) c.trackPort(port, true) - c.watchPort(port, busid) + c.watchPort(ctx, port, busid) c.trackPort(port, false) - c.releaseTargetBusID(targetIndex, busid) - if err := c.ctx.Err(); err != nil { + if err := ctx.Err(); err != nil { return } c.logger.Info("vhci port ", port, " released; reattaching ", busid) - if !sleepCtx(c.ctx, clientReconnectDelay) { + if !sleepCtx(ctx, clientReconnectDelay) { return } } } -func (c *ClientService) claimTargetBusID(targetIndex int) (string, error) { - target := c.targets[targetIndex] - if target.fixedBusID != "" { - return target.fixedBusID, nil - } - c.assignMu.Lock() - current := c.assigned[targetIndex] - c.assignMu.Unlock() - if current != "" { - return current, nil - } - entries, err := c.fetchDevList() - if err != nil { - return "", err - } - return c.refreshAssignments(targetIndex, entries), nil -} - -func (c *ClientService) refreshAssignments(targetIndex int, entries []DeviceEntry) string { - c.assignMu.Lock() - defer c.assignMu.Unlock() - if c.assigned[targetIndex] != "" { - return c.assigned[targetIndex] - } - reserved := make(map[string]struct{}, len(c.assigned)) - for _, busid := range c.assigned { - if busid == "" { - continue - } - reserved[busid] = struct{}{} - } - for i, target := range c.targets { - if target.fixedBusID != "" || c.assigned[i] != "" { - continue - } - busid := firstMatchingUnclaimedBusID(target.match, entries, reserved) - if busid == "" { - continue - } - c.assigned[i] = busid - reserved[busid] = struct{}{} - } - return c.assigned[targetIndex] -} - -func (c *ClientService) releaseTargetBusID(targetIndex int, busid string) { - if c.targets[targetIndex].fixedBusID != "" { - return - } - c.assignMu.Lock() - defer c.assignMu.Unlock() - if c.assigned[targetIndex] == busid { - c.assigned[targetIndex] = "" - } -} - -// attemptAttach performs one dial → OP_REQ_IMPORT → vhci attach sequence. -// The returned TCP socket is handed to the kernel on success; on failure the -// connection is closed before return. -func (c *ClientService) attemptAttach(busid string) (int, error) { - conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, c.serverAddr) +func (c *ClientService) attemptAttach(ctx context.Context, busid string) (int, error) { + conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr) if err != nil { return -1, E.Cause(err, "dial ", c.serverAddr) } defer conn.Close() + stopCloseOnCancel := closeConnOnContextDone(ctx, conn) + defer stopCloseOnCancel() if err := WriteOpReqImport(conn, busid); err != nil { return -1, E.Cause(err, "write OP_REQ_IMPORT") } @@ -351,14 +595,12 @@ func (c *ClientService) attemptAttach(busid string) (int, error) { return port, nil } -// watchPort polls vhci status every 2s and returns when the port is no longer -// in VDEV_ST_USED, or when ctx is canceled (in which case it detaches the port). -func (c *ClientService) watchPort(port int, busid string) { +func (c *ClientService) watchPort(ctx context.Context, port int, busid string) { ticker := time.NewTicker(2 * time.Second) defer ticker.Stop() for { select { - case <-c.ctx.Done(): + case <-ctx.Done(): if err := vhciDetach(port); err != nil { c.logger.Warn("detach port ", port, " (", busid, "): ", err) } @@ -431,3 +673,17 @@ func sleepCtx(ctx context.Context, d time.Duration) bool { return true } } + +func closeConnOnContextDone(ctx context.Context, conn net.Conn) func() { + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + _ = conn.Close() + case <-done: + } + }() + return func() { + close(done) + } +} diff --git a/service/usbip/control_protocol.go b/service/usbip/control_protocol.go new file mode 100644 index 000000000..d83d41fe1 --- /dev/null +++ b/service/usbip/control_protocol.go @@ -0,0 +1,112 @@ +package usbip + +import ( + "encoding/binary" + "io" +) + +const ( + controlProtocolVersion uint8 = 1 + + controlFrameHello uint8 = 1 + controlFrameAck uint8 = 2 + controlFrameChanged uint8 = 3 + controlFramePing uint8 = 4 + controlFramePong uint8 = 5 + + controlCapabilityChanged uint32 = 1 << 0 + controlCapabilityPingPong uint32 = 1 << 1 + controlCapabilities = controlCapabilityChanged | controlCapabilityPingPong + + controlPrefaceSize = 8 + controlFrameSize = 16 +) + +var controlPreface = [controlPrefaceSize]byte{'S', 'B', 'U', 'S', 'B', 'I', 'P', '1'} + +type controlFrame struct { + Type uint8 + Version uint8 + _ uint16 + Capabilities uint32 + Sequence uint64 +} + +func WriteControlPreface(w io.Writer) error { + _, err := w.Write(controlPreface[:]) + return err +} + +func IsControlPreface(raw []byte) bool { + if len(raw) != len(controlPreface) { + return false + } + for i := range controlPreface { + if raw[i] != controlPreface[i] { + return false + } + } + return true +} + +func WriteControlHello(w io.Writer) error { + return writeControlFrame(w, controlFrame{ + Type: controlFrameHello, + Version: controlProtocolVersion, + Capabilities: controlCapabilities, + }) +} + +func WriteControlAck(w io.Writer, sequence uint64) error { + return writeControlFrame(w, controlFrame{ + Type: controlFrameAck, + Version: controlProtocolVersion, + Capabilities: controlCapabilities, + Sequence: sequence, + }) +} + +func WriteControlChanged(w io.Writer, sequence uint64) error { + return writeControlFrame(w, controlFrame{ + Type: controlFrameChanged, + Version: controlProtocolVersion, + Sequence: sequence, + }) +} + +func WriteControlPing(w io.Writer) error { + return writeControlFrame(w, controlFrame{ + Type: controlFramePing, + Version: controlProtocolVersion, + }) +} + +func WriteControlPong(w io.Writer) error { + return writeControlFrame(w, controlFrame{ + Type: controlFramePong, + Version: controlProtocolVersion, + }) +} + +func ReadControlFrame(r io.Reader) (controlFrame, error) { + var raw [controlFrameSize]byte + if _, err := io.ReadFull(r, raw[:]); err != nil { + return controlFrame{}, err + } + return controlFrame{ + Type: raw[0], + Version: raw[1], + Capabilities: binary.BigEndian.Uint32(raw[4:8]), + Sequence: binary.BigEndian.Uint64(raw[8:16]), + }, nil +} + +func writeControlFrame(w io.Writer, frame controlFrame) error { + var raw [controlFrameSize]byte + raw[0] = frame.Type + raw[1] = frame.Version + binary.BigEndian.PutUint32(raw[4:8], frame.Capabilities) + binary.BigEndian.PutUint64(raw[8:16], frame.Sequence) + _, err := w.Write(raw[:]) + return err +} diff --git a/service/usbip/protocol.go b/service/usbip/protocol.go index 219286072..36c5eed4c 100644 --- a/service/usbip/protocol.go +++ b/service/usbip/protocol.go @@ -97,6 +97,14 @@ func ReadOpHeader(r io.Reader) (OpHeader, error) { return h, nil } +func ParseOpHeader(raw []byte) OpHeader { + return OpHeader{ + Version: binary.BigEndian.Uint16(raw[:2]), + Code: binary.BigEndian.Uint16(raw[2:4]), + Status: binary.BigEndian.Uint32(raw[4:8]), + } +} + // WriteOpReqImport sends OP_REQ_IMPORT for busid (8 + 32 = 40 bytes). func WriteOpReqImport(w io.Writer, busid string) error { if err := WriteOpHeader(w, OpReqImport, OpStatusOK); err != nil { diff --git a/service/usbip/server_linux.go b/service/usbip/server_linux.go index b996a2a4b..7f5e6f5c0 100644 --- a/service/usbip/server_linux.go +++ b/service/usbip/server_linux.go @@ -4,7 +4,10 @@ package usbip import ( "context" + "io" "net" + "os" + "slices" "sync" "time" @@ -25,6 +28,12 @@ type serverExport struct { originalDriver string } +type serverControlConn struct { + id uint64 + conn net.Conn + send chan controlFrame +} + type ServerService struct { boxService.Adapter ctx context.Context @@ -34,8 +43,13 @@ type ServerService struct { matches []option.USBIPDeviceMatch mu sync.Mutex - exports []serverExport + exports map[string]serverExport listenFD net.Listener + + controlMu sync.Mutex + controlSeq uint64 + controlNextID uint64 + controlSubs map[uint64]*serverControlConn } func NewServerService(ctx context.Context, logger log.ContextLogger, tag string, options option.USBIPServerServiceOptions) (adapter.Service, error) { @@ -54,12 +68,14 @@ func NewServerService(ctx context.Context, logger log.ContextLogger, tag string, cancel: cancel, logger: logger, matches: options.Devices, + exports: make(map[string]serverExport), listener: listener.New(listener.Options{ Context: ctx, Logger: logger, Network: []string{N.NetworkTCP}, Listen: options.ListenOptions, }), + controlSubs: make(map[uint64]*serverControlConn), } return s, nil } @@ -71,7 +87,7 @@ func (s *ServerService) Start(stage adapter.StartStage) error { if err := ensureHostDriver(); err != nil { return err } - if err := s.bindExports(); err != nil { + if _, err := s.reconcileExports(); err != nil { s.rollbackExports() return err } @@ -84,6 +100,7 @@ func (s *ServerService) Start(stage adapter.StartStage) error { s.listenFD = tcpListener s.mu.Unlock() go s.acceptLoop(tcpListener) + go s.ueventLoop() return nil } @@ -91,47 +108,58 @@ func (s *ServerService) Close() error { if s.cancel != nil { s.cancel() } + s.closeControlSubscribers() err := common.Close(common.PtrOrNil(s.listener)) s.rollbackExports() return err } -// bindExports resolves every match against current sysfs state, unbinds from -// the current driver, and binds to usbip-host. -func (s *ServerService) bindExports() error { +func (s *ServerService) reconcileExports() (bool, error) { devices, err := listUSBDevices() if err != nil { - return E.Cause(err, "enumerate usb devices") + return false, E.Cause(err, "enumerate usb devices") + } + desired := make(map[string]sysfsDevice) + present := make(map[string]struct{}, len(devices)) + for i := range devices { + present[devices[i].BusID] = struct{}{} } - seen := make(map[string]bool) for _, m := range s.matches { - matched := 0 for i := range devices { if !Matches(m, devices[i].key()) { continue } - if seen[devices[i].BusID] { - matched++ - continue - } if devices[i].DeviceClass == 0x09 { - seen[devices[i].BusID] = true - matched++ s.logger.Warn("skip hub device ", devices[i].BusID, " matched by ", describeMatch(m)) continue } - if err := s.bindOne(&devices[i]); err != nil { - s.logger.Warn("bind ", devices[i].BusID, ": ", err) - continue - } - seen[devices[i].BusID] = true - matched++ - } - if matched == 0 { - s.logger.Warn("no local device matched ", describeMatch(m)) + desired[devices[i].BusID] = devices[i] } } - return nil + + current := s.snapshotExports() + changed := false + for busid, device := range desired { + if _, ok := current[busid]; ok { + continue + } + if err := s.bindOne(&device); err != nil { + s.logger.Warn("bind ", busid, ": ", err) + continue + } + changed = true + } + for busid, export := range current { + if _, ok := desired[busid]; ok { + continue + } + _, restore := present[busid] + if err := s.releaseExport(export, restore); err != nil { + s.logger.Warn("release ", busid, ": ", err) + } + changed = true + } + return changed, nil } func (s *ServerService) bindOne(d *sysfsDevice) error { @@ -141,9 +169,7 @@ func (s *ServerService) bindOne(d *sysfsDevice) error { } if driver == "usbip-host" { s.logger.Info("device ", d.BusID, " already bound to usbip-host; co-opting") - s.mu.Lock() - s.exports = append(s.exports, serverExport{busid: d.BusID}) - s.mu.Unlock() + s.setExport(serverExport{busid: d.BusID}) return nil } if driver != "" { @@ -165,55 +191,92 @@ func (s *ServerService) bindOne(d *sysfsDevice) error { return E.Cause(err, "bind to usbip-host") } s.logger.Info("exported ", d.BusID, " (previously on ", driverOrNone(driver), ")") - s.mu.Lock() - s.exports = append(s.exports, serverExport{ + s.setExport(serverExport{ busid: d.BusID, managed: true, originalDriver: driver, }) - s.mu.Unlock() return nil } +func (s *ServerService) releaseExport(export serverExport, restore bool) error { + s.deleteExport(export.busid) + + var releaseErr error + if err := writeUsbipSockfd(export.busid, -1); err != nil && !os.IsNotExist(err) { + releaseErr = err + } + if !export.managed { + s.logger.Info("stopped tracking ", export.busid, " on usbip-host") + return releaseErr + } + if err := hostUnbind(export.busid); err != nil && !os.IsNotExist(err) && releaseErr == nil { + releaseErr = err + } + if err := hostMatchBusID(export.busid, false); err != nil && releaseErr == nil { + releaseErr = err + } + if !restore { + s.logger.Info("removed export state for disappeared device ", export.busid) + return releaseErr + } + if export.originalDriver == "" { + s.logger.Info("released ", export.busid, " from usbip-host") + return releaseErr + } + if err := bindToDriver(export.busid, export.originalDriver); err != nil { + if releaseErr == nil { + releaseErr = err + } + return releaseErr + } + s.logger.Info("restored ", export.busid, " to ", export.originalDriver) + return releaseErr +} + func (s *ServerService) rollbackExports() { - s.mu.Lock() - exports := s.exports - s.exports = nil - s.mu.Unlock() - for _, e := range exports { - if !e.managed { - continue + exports := s.snapshotExports() + for _, export := range exports { + _, restore := currentSysfsDevice(export.busid) + if err := s.releaseExport(export, restore); err != nil { + s.logger.Warn("rollback ", export.busid, ": ", err) } - // Release any attached peer. - _ = writeUsbipSockfd(e.busid, -1) - if err := hostUnbind(e.busid); err != nil { - s.logger.Warn("unbind ", e.busid, ": ", err) - } - if err := hostMatchBusID(e.busid, false); err != nil { - s.logger.Debug("match_busid del ", e.busid, ": ", err) - } - if e.originalDriver == "" { - s.logger.Info("released ", e.busid, " from usbip-host") - continue - } - if err := bindToDriver(e.busid, e.originalDriver); err != nil { - s.logger.Warn("rebind ", e.busid, " to ", e.originalDriver, ": ", err) - continue - } - s.logger.Info("restored ", e.busid, " to ", e.originalDriver) } } func (s *ServerService) currentExports() []string { s.mu.Lock() defer s.mu.Unlock() - out := make([]string, len(s.exports)) - for i, e := range s.exports { - out[i] = e.busid + out := make([]string, 0, len(s.exports)) + for busid := range s.exports { + out = append(out, busid) + } + slices.Sort(out) + return out +} + +func (s *ServerService) snapshotExports() map[string]serverExport { + s.mu.Lock() + defer s.mu.Unlock() + out := make(map[string]serverExport, len(s.exports)) + for busid, export := range s.exports { + out[busid] = export } return out } +func (s *ServerService) setExport(export serverExport) { + s.mu.Lock() + defer s.mu.Unlock() + s.exports[export.busid] = export +} + +func (s *ServerService) deleteExport(busid string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.exports, busid) +} + func (s *ServerService) acceptLoop(ln net.Listener) { for { conn, err := ln.Accept() @@ -237,17 +300,26 @@ func (s *ServerService) acceptLoop(ln net.Listener) { s.logger.Error("accept: ", err) return } - go s.handleConn(conn) + go s.dispatchConn(conn) } } -func (s *ServerService) handleConn(conn net.Conn) { - defer conn.Close() - header, err := ReadOpHeader(conn) - if err != nil { - s.logger.Debug("read op header: ", err) +func (s *ServerService) dispatchConn(conn net.Conn) { + var prefix [controlPrefaceSize]byte + if _, err := io.ReadFull(conn, prefix[:]); err != nil { + s.logger.Debug("read connection preface: ", err) + _ = conn.Close() return } + if IsControlPreface(prefix[:]) { + s.handleControlConn(conn) + return + } + s.handleStandardConn(conn, ParseOpHeader(prefix[:])) +} + +func (s *ServerService) handleStandardConn(conn net.Conn, header OpHeader) { + defer conn.Close() switch header.Code { case OpReqDevList: s.handleDevList(conn) @@ -258,6 +330,71 @@ func (s *ServerService) handleConn(conn net.Conn) { } } +func (s *ServerService) handleControlConn(conn net.Conn) { + defer conn.Close() + + hello, err := ReadControlFrame(conn) + if err != nil { + s.logger.Debug("read control hello: ", err) + return + } + if hello.Type != controlFrameHello { + s.logger.Debug("unexpected control frame ", hello.Type, " before hello") + return + } + if hello.Version != controlProtocolVersion { + s.logger.Debug("unsupported control version ", hello.Version) + return + } + if hello.Capabilities&controlCapabilities != controlCapabilities { + s.logger.Debug("missing control capabilities 0x", hello.Capabilities) + return + } + + sub, seq := s.registerControlConn(conn) + defer s.unregisterControlConn(sub.id) + + if err := WriteControlAck(conn, seq); err != nil { + s.logger.Debug("write control ack: ", err) + return + } + + readDone := make(chan struct{}) + go s.readControlConn(sub, readDone) + for { + select { + case <-s.ctx.Done(): + return + case <-readDone: + return + case frame := <-sub.send: + if err := writeControlFrame(conn, frame); err != nil { + s.logger.Debug("write control frame: ", err) + return + } + } + } +} + +func (s *ServerService) readControlConn(sub *serverControlConn, done chan<- struct{}) { + defer close(done) + for { + frame, err := ReadControlFrame(sub.conn) + if err != nil { + return + } + switch frame.Type { + case controlFramePing: + s.enqueueControlFrame(sub, controlFrame{ + Type: controlFramePong, + Version: controlProtocolVersion, + }) + default: + return + } + } +} + func (s *ServerService) handleDevList(conn net.Conn) { entries := s.buildDevListEntries() if err := WriteOpRepDevList(conn, entries); err != nil { @@ -346,12 +483,124 @@ func (s *ServerService) handleImport(conn net.Conn) { func (s *ServerService) isExported(busid string) bool { s.mu.Lock() defer s.mu.Unlock() - for _, e := range s.exports { - if e.busid == busid { - return true + _, ok := s.exports[busid] + return ok +} + +func (s *ServerService) ueventLoop() { + for { + listener, err := newUEventListener() + if err != nil { + if s.ctx.Err() != nil { + return + } + s.logger.Warn("open uevent listener: ", err) + if !sleepCtx(s.ctx, time.Second) { + return + } + continue + } + done := make(chan struct{}) + go func() { + select { + case <-s.ctx.Done(): + _ = listener.Close() + case <-done: + } + }() + for { + err = listener.WaitUSBEvent() + if err != nil { + close(done) + _ = listener.Close() + if s.ctx.Err() != nil { + return + } + s.logger.Warn("read uevent: ", err) + if !sleepCtx(s.ctx, time.Second) { + return + } + break + } + changed, reconcileErr := s.reconcileExports() + if reconcileErr != nil { + s.logger.Warn("reconcile exports: ", reconcileErr) + continue + } + if changed { + s.broadcastChanged() + } } } - return false +} + +func (s *ServerService) registerControlConn(conn net.Conn) (*serverControlConn, uint64) { + s.controlMu.Lock() + defer s.controlMu.Unlock() + s.controlNextID++ + sub := &serverControlConn{ + id: s.controlNextID, + conn: conn, + send: make(chan controlFrame, 16), + } + s.controlSubs[sub.id] = sub + return sub, s.controlSeq +} + +func (s *ServerService) unregisterControlConn(id uint64) { + s.controlMu.Lock() + defer s.controlMu.Unlock() + delete(s.controlSubs, id) +} + +func (s *ServerService) closeControlSubscribers() { + s.controlMu.Lock() + subs := make([]*serverControlConn, 0, len(s.controlSubs)) + for _, sub := range s.controlSubs { + subs = append(subs, sub) + } + s.controlSubs = make(map[uint64]*serverControlConn) + s.controlMu.Unlock() + for _, sub := range subs { + _ = sub.conn.Close() + } +} + +func (s *ServerService) broadcastChanged() { + s.controlMu.Lock() + s.controlSeq++ + sequence := s.controlSeq + subs := make([]*serverControlConn, 0, len(s.controlSubs)) + for _, sub := range s.controlSubs { + subs = append(subs, sub) + } + s.controlMu.Unlock() + + frame := controlFrame{ + Type: controlFrameChanged, + Version: controlProtocolVersion, + Sequence: sequence, + } + for _, sub := range subs { + s.enqueueControlFrame(sub, frame) + } +} + +func (s *ServerService) enqueueControlFrame(sub *serverControlConn, frame controlFrame) { + select { + case sub.send <- frame: + default: + s.logger.Debug("control subscriber ", sub.id, " lagged behind") + _ = sub.conn.Close() + } +} + +func currentSysfsDevice(busid string) (sysfsDevice, bool) { + device, err := readSysfsDevice(busid, sysBusDevicePath(busid)) + if err != nil { + return sysfsDevice{}, false + } + return device, true } func sysBusDevicePath(busid string) string { diff --git a/service/usbip/uevent_linux.go b/service/usbip/uevent_linux.go new file mode 100644 index 000000000..2eaeeb53e --- /dev/null +++ b/service/usbip/uevent_linux.go @@ -0,0 +1,58 @@ +//go:build linux + +package usbip + +import ( + "bytes" + "os" + + "golang.org/x/sys/unix" +) + +type ueventListener struct { + fd int +} + +func newUEventListener() (*ueventListener, error) { + fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_DGRAM, unix.NETLINK_KOBJECT_UEVENT) + if err != nil { + return nil, err + } + addr := &unix.SockaddrNetlink{ + Family: unix.AF_NETLINK, + Pid: uint32(os.Getpid()), + Groups: 1, + } + if err := unix.Bind(fd, addr); err != nil { + _ = unix.Close(fd) + return nil, err + } + return &ueventListener{fd: fd}, nil +} + +func (l *ueventListener) Close() error { + return unix.Close(l.fd) +} + +func (l *ueventListener) WaitUSBEvent() error { + var buf [4096]byte + for { + n, _, err := unix.Recvfrom(l.fd, buf[:], 0) + if err != nil { + return err + } + if isUSBUEvent(buf[:n]) { + return nil + } + } +} + +func isUSBUEvent(raw []byte) bool { + fields := bytes.Split(bytes.TrimRight(raw, "\x00"), []byte{0}) + for _, field := range fields { + if bytes.Equal(field, []byte("SUBSYSTEM=usb")) { + return true + } + } + return false +}