service/usbip: add hotplug control push path

This commit is contained in:
世界 2026-04-22 00:21:43 +08:00
parent 9a92967c4e
commit f3bbd3c07b
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
5 changed files with 893 additions and 210 deletions

View file

@ -4,8 +4,9 @@ package usbip
import (
"context"
"encoding/binary"
"errors"
"net"
"slices"
"sync"
"time"
@ -20,7 +21,15 @@ import (
N "github.com/sagernet/sing/common/network"
)
const clientReconnectDelay = 5 * time.Second
const (
clientReconnectDelay = 5 * time.Second
controlPingInterval = 10 * time.Second
controlReadTimeout = 30 * time.Second
controlWriteTimeout = 5 * time.Second
controlSessionIdleHint = "control session lost"
)
var errImmediateReconnect = errors.New("usbip control reconnect")
type clientTarget struct {
fixedBusID string
@ -34,6 +43,15 @@ func (t clientTarget) description() string {
return describeMatch(t.match)
}
type clientAssignedWorker struct {
target clientTarget
updates chan string
}
type clientBusIDWorker struct {
cancel context.CancelFunc
}
type ClientService struct {
boxService.Adapter
ctx context.Context
@ -43,9 +61,11 @@ type ClientService struct {
serverAddr M.Socksaddr
matches []option.USBIPDeviceMatch // empty = import all remote exports
assignMu sync.Mutex
targets []clientTarget
assigned []string
stateMu sync.Mutex
targets []clientTarget
assigned []string
assignedWorkers []*clientAssignedWorker
allWorkers map[string]*clientBusIDWorker
attachMu sync.Mutex // serializes vhci port pick + attach
wg sync.WaitGroup
@ -79,6 +99,7 @@ func NewClientService(ctx context.Context, logger log.ContextLogger, tag string,
dialer: outboundDialer,
serverAddr: options.ServerOptions.Build(),
matches: options.Devices,
allWorkers: make(map[string]*clientBusIDWorker),
ports: make(map[int]struct{}),
}, nil
}
@ -90,6 +111,7 @@ func (c *ClientService) Start(stage adapter.StartStage) error {
if err := ensureVHCI(); err != nil {
return err
}
c.initializeWorkers()
c.wg.Add(1)
go c.run()
return nil
@ -99,7 +121,6 @@ func (c *ClientService) Close() error {
if c.cancel != nil {
c.cancel()
}
// Wait for workers to detach, bounded by 5s.
done := make(chan struct{})
go func() {
c.wg.Wait()
@ -113,35 +134,356 @@ func (c *ClientService) Close() error {
return nil
}
// run prepares the desired targets and spawns one worker per target.
func (c *ClientService) run() {
defer c.wg.Done()
func (c *ClientService) initializeWorkers() {
targets := c.buildTargets()
if len(targets) == 0 {
c.logger.Warn("no devices to import; client idle")
c.stateMu.Lock()
c.targets = targets
if len(c.matches) == 0 {
c.stateMu.Unlock()
return
}
c.assignMu.Lock()
c.targets = targets
c.assigned = make([]string, len(targets))
for i := range targets {
c.assigned[i] = targets[i].fixedBusID
c.assignedWorkers = make([]*clientAssignedWorker, len(targets))
for i, target := range targets {
c.assignedWorkers[i] = &clientAssignedWorker{
target: target,
updates: make(chan string, 1),
}
}
c.assignMu.Unlock()
for i := range targets {
workers := append([]*clientAssignedWorker(nil), c.assignedWorkers...)
c.stateMu.Unlock()
for _, worker := range workers {
c.wg.Add(1)
go c.worker(i)
go c.runAssignedWorker(worker)
}
}
func (c *ClientService) run() {
defer c.wg.Done()
immediate := true
for {
if !immediate && !sleepCtx(c.ctx, clientReconnectDelay) {
break
}
err := c.runControlSession()
if c.ctx.Err() != nil {
break
}
if err != nil {
c.logger.Error("control ", c.serverAddr, ": ", err)
}
immediate = errors.Is(err, errImmediateReconnect)
}
c.stopAllWorkers()
}
func (c *ClientService) runControlSession() error {
conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, c.serverAddr)
if err != nil {
return E.Cause(err, "dial ", c.serverAddr)
}
defer conn.Close()
stopCloseOnCancel := closeConnOnContextDone(c.ctx, conn)
defer stopCloseOnCancel()
_ = conn.SetWriteDeadline(time.Now().Add(controlWriteTimeout))
_ = conn.SetReadDeadline(time.Now().Add(controlWriteTimeout))
if err := WriteControlPreface(conn); err != nil {
return E.Cause(err, "write control preface")
}
if err := WriteControlHello(conn); err != nil {
return E.Cause(err, "write control hello")
}
ack, err := ReadControlFrame(conn)
if err != nil {
return E.Cause(err, "read control ack")
}
if ack.Type != controlFrameAck {
return E.New("unexpected control ack frame ", ack.Type)
}
if ack.Version != controlProtocolVersion {
return E.New("unsupported control version ", ack.Version)
}
if ack.Capabilities&controlCapabilities != controlCapabilities {
return E.New("missing control capabilities 0x", ack.Capabilities)
}
_ = conn.SetWriteDeadline(time.Time{})
_ = conn.SetReadDeadline(time.Time{})
if err := c.syncRemoteState(); err != nil {
return E.Cause(err, "initial devlist sync")
}
pingDone := make(chan struct{})
go c.controlPingLoop(conn, pingDone)
defer close(pingDone)
lastSeq := ack.Sequence
for {
if err := conn.SetReadDeadline(time.Now().Add(controlReadTimeout)); err != nil {
return err
}
frame, err := ReadControlFrame(conn)
if err != nil {
return E.Cause(errImmediateReconnect, controlSessionIdleHint, ": ", err)
}
switch frame.Type {
case controlFrameChanged:
if frame.Sequence != lastSeq+1 {
return E.Cause(errImmediateReconnect, "control sequence jumped from ", lastSeq, " to ", frame.Sequence)
}
lastSeq = frame.Sequence
if err := c.syncRemoteState(); err != nil {
return E.Cause(errImmediateReconnect, "devlist sync after change ", frame.Sequence, ": ", err)
}
case controlFramePong:
default:
return E.Cause(errImmediateReconnect, "unexpected control frame ", frame.Type)
}
}
}
func (c *ClientService) controlPingLoop(conn net.Conn, done <-chan struct{}) {
ticker := time.NewTicker(controlPingInterval)
defer ticker.Stop()
for {
select {
case <-c.ctx.Done():
return
case <-done:
return
case <-ticker.C:
_ = conn.SetWriteDeadline(time.Now().Add(controlWriteTimeout))
if err := WriteControlPing(conn); err != nil {
_ = conn.Close()
return
}
_ = conn.SetWriteDeadline(time.Time{})
}
}
}
func (c *ClientService) syncRemoteState() error {
entries, err := c.fetchDevList(c.ctx)
if err != nil {
return err
}
c.applyRemoteEntries(entries)
return nil
}
func (c *ClientService) applyRemoteEntries(entries []DeviceEntry) {
if len(c.matches) == 0 {
c.applyRemoteExports(entries)
return
}
c.applyMatchedExports(entries)
}
func (c *ClientService) applyRemoteExports(entries []DeviceEntry) {
desired := make(map[string]struct{}, len(entries))
for i := range entries {
busid := entries[i].Info.BusIDString()
if busid == "" {
continue
}
desired[busid] = struct{}{}
}
c.stateMu.Lock()
stopWorkers := make([]*clientBusIDWorker, 0)
for busid, worker := range c.allWorkers {
if _, ok := desired[busid]; ok {
continue
}
stopWorkers = append(stopWorkers, worker)
delete(c.allWorkers, busid)
}
startBusIDs := make([]string, 0)
for busid := range desired {
if _, ok := c.allWorkers[busid]; ok {
continue
}
startBusIDs = append(startBusIDs, busid)
}
c.stateMu.Unlock()
for _, worker := range stopWorkers {
worker.cancel()
}
slices.Sort(startBusIDs)
for _, busid := range startBusIDs {
c.startRemoteBusIDWorker(busid, busid)
}
}
func (c *ClientService) applyMatchedExports(entries []DeviceEntry) {
keysByBusID := make(map[string]DeviceKey, len(entries))
for i := range entries {
busid := entries[i].Info.BusIDString()
if busid == "" {
continue
}
keysByBusID[busid] = DeviceKey{
BusID: busid,
VendorID: entries[i].Info.IDVendor,
ProductID: entries[i].Info.IDProduct,
Serial: entries[i].Info.SerialString(),
}
}
c.stateMu.Lock()
if len(c.targets) == 0 {
c.stateMu.Unlock()
return
}
nextAssigned := make([]string, len(c.targets))
reserved := make(map[string]struct{}, len(c.targets))
for i, target := range c.targets {
if target.fixedBusID == "" {
continue
}
if _, ok := keysByBusID[target.fixedBusID]; !ok {
continue
}
nextAssigned[i] = target.fixedBusID
reserved[target.fixedBusID] = struct{}{}
}
for i, target := range c.targets {
if target.fixedBusID != "" {
continue
}
current := c.assigned[i]
if current == "" {
continue
}
if _, ok := reserved[current]; ok {
continue
}
key, ok := keysByBusID[current]
if !ok || !Matches(target.match, key) {
continue
}
nextAssigned[i] = current
reserved[current] = struct{}{}
}
for i, target := range c.targets {
if target.fixedBusID != "" || nextAssigned[i] != "" {
continue
}
nextAssigned[i] = firstMatchingUnclaimedBusID(target.match, entries, reserved)
if nextAssigned[i] != "" {
reserved[nextAssigned[i]] = struct{}{}
}
}
workers := append([]*clientAssignedWorker(nil), c.assignedWorkers...)
previous := append([]string(nil), c.assigned...)
c.assigned = nextAssigned
c.stateMu.Unlock()
for i, worker := range workers {
if previous[i] == nextAssigned[i] {
continue
}
worker.setDesiredBusID(nextAssigned[i])
}
}
func (c *ClientService) runAssignedWorker(worker *clientAssignedWorker) {
defer c.wg.Done()
var current string
var runnerCancel context.CancelFunc
var runnerDone chan struct{}
stopRunner := func() {
if runnerCancel == nil {
return
}
runnerCancel()
<-runnerDone
runnerCancel = nil
runnerDone = nil
}
for {
select {
case <-c.ctx.Done():
stopRunner()
return
case desired := <-worker.updates:
if desired == current {
continue
}
stopRunner()
current = desired
if desired == "" {
continue
}
runCtx, cancel := context.WithCancel(c.ctx)
done := make(chan struct{})
runnerCancel = cancel
runnerDone = done
c.wg.Add(1)
go func(busid string) {
defer c.wg.Done()
defer close(done)
c.runBusIDLoop(runCtx, busid, worker.target.description())
}(desired)
}
}
}
func (w *clientAssignedWorker) setDesiredBusID(busid string) {
select {
case w.updates <- busid:
return
default:
}
select {
case <-w.updates:
default:
}
w.updates <- busid
}
func (c *ClientService) startRemoteBusIDWorker(busid, description string) {
runCtx, cancel := context.WithCancel(c.ctx)
worker := &clientBusIDWorker{cancel: cancel}
c.stateMu.Lock()
c.allWorkers[busid] = worker
c.stateMu.Unlock()
c.wg.Add(1)
go func() {
defer c.wg.Done()
c.runBusIDLoop(runCtx, busid, description)
}()
}
func (c *ClientService) stopAllWorkers() {
c.stateMu.Lock()
workers := make([]*clientBusIDWorker, 0, len(c.allWorkers))
for _, worker := range c.allWorkers {
workers = append(workers, worker)
}
c.allWorkers = make(map[string]*clientBusIDWorker)
c.stateMu.Unlock()
for _, worker := range workers {
worker.cancel()
}
}
func (c *ClientService) buildTargets() []clientTarget {
if len(c.matches) == 0 {
busids := c.snapshotRemoteBusIDs()
targets := make([]clientTarget, 0, len(busids))
for _, busid := range busids {
targets = append(targets, clientTarget{fixedBusID: busid})
}
return targets
return nil
}
seenFixed := make(map[string]struct{})
targets := make([]clientTarget, 0, len(c.matches))
@ -159,36 +501,15 @@ func (c *ClientService) buildTargets() []clientTarget {
return targets
}
// snapshotRemoteBusIDs connects once, issues OP_REQ_DEVLIST, and returns the
// currently exported remote busids.
func (c *ClientService) snapshotRemoteBusIDs() []string {
for {
if err := c.ctx.Err(); err != nil {
return nil
}
entries, err := c.fetchDevList()
if err != nil {
c.logger.Error("enumerate ", c.serverAddr, ": ", err)
if !sleepCtx(c.ctx, clientReconnectDelay) {
return nil
}
continue
}
out := make([]string, 0, len(entries))
for i := range entries {
out = append(out, entries[i].Info.BusIDString())
}
return dedupe(out)
}
}
func (c *ClientService) fetchDevList() ([]DeviceEntry, error) {
conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, c.serverAddr)
func (c *ClientService) fetchDevList(ctx context.Context) ([]DeviceEntry, error) {
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr)
if err != nil {
return nil, err
}
defer conn.Close()
if err := binary.Write(conn, binary.BigEndian, OpHeader{Version: ProtocolVersion, Code: OpReqDevList, Status: OpStatusOK}); err != nil {
stopCloseOnCancel := closeConnOnContextDone(ctx, conn)
defer stopCloseOnCancel()
if err := WriteOpHeader(conn, OpReqDevList, OpStatusOK); err != nil {
return nil, E.Cause(err, "send OP_REQ_DEVLIST")
}
header, err := ReadOpHeader(conn)
@ -201,118 +522,41 @@ func (c *ClientService) fetchDevList() ([]DeviceEntry, error) {
return ReadOpRepDevListBody(conn)
}
// worker keeps one target attached to vhci_hcd.0. On any error or kernel-side
// detach, waits clientReconnectDelay and retries.
func (c *ClientService) worker(targetIndex int) {
defer c.wg.Done()
target := c.targets[targetIndex]
func (c *ClientService) runBusIDLoop(ctx context.Context, busid, description string) {
for {
if err := c.ctx.Err(); err != nil {
if err := ctx.Err(); err != nil {
return
}
busid, err := c.claimTargetBusID(targetIndex)
port, err := c.attemptAttach(ctx, busid)
if err != nil {
c.logger.Error("assign ", target.description(), ": ", err)
if !sleepCtx(c.ctx, clientReconnectDelay) {
return
}
continue
}
if busid == "" {
if !sleepCtx(c.ctx, clientReconnectDelay) {
return
}
continue
}
port, err := c.attemptAttach(busid)
if err != nil {
c.releaseTargetBusID(targetIndex, busid)
c.logger.Error("attach ", busid, ": ", err)
if !sleepCtx(c.ctx, clientReconnectDelay) {
c.logger.Error("attach ", description, " (", busid, "): ", err)
if !sleepCtx(ctx, clientReconnectDelay) {
return
}
continue
}
c.logger.Info("attached ", busid, " → vhci port ", port)
c.trackPort(port, true)
c.watchPort(port, busid)
c.watchPort(ctx, port, busid)
c.trackPort(port, false)
c.releaseTargetBusID(targetIndex, busid)
if err := c.ctx.Err(); err != nil {
if err := ctx.Err(); err != nil {
return
}
c.logger.Info("vhci port ", port, " released; reattaching ", busid)
if !sleepCtx(c.ctx, clientReconnectDelay) {
if !sleepCtx(ctx, clientReconnectDelay) {
return
}
}
}
func (c *ClientService) claimTargetBusID(targetIndex int) (string, error) {
target := c.targets[targetIndex]
if target.fixedBusID != "" {
return target.fixedBusID, nil
}
c.assignMu.Lock()
current := c.assigned[targetIndex]
c.assignMu.Unlock()
if current != "" {
return current, nil
}
entries, err := c.fetchDevList()
if err != nil {
return "", err
}
return c.refreshAssignments(targetIndex, entries), nil
}
func (c *ClientService) refreshAssignments(targetIndex int, entries []DeviceEntry) string {
c.assignMu.Lock()
defer c.assignMu.Unlock()
if c.assigned[targetIndex] != "" {
return c.assigned[targetIndex]
}
reserved := make(map[string]struct{}, len(c.assigned))
for _, busid := range c.assigned {
if busid == "" {
continue
}
reserved[busid] = struct{}{}
}
for i, target := range c.targets {
if target.fixedBusID != "" || c.assigned[i] != "" {
continue
}
busid := firstMatchingUnclaimedBusID(target.match, entries, reserved)
if busid == "" {
continue
}
c.assigned[i] = busid
reserved[busid] = struct{}{}
}
return c.assigned[targetIndex]
}
func (c *ClientService) releaseTargetBusID(targetIndex int, busid string) {
if c.targets[targetIndex].fixedBusID != "" {
return
}
c.assignMu.Lock()
defer c.assignMu.Unlock()
if c.assigned[targetIndex] == busid {
c.assigned[targetIndex] = ""
}
}
// attemptAttach performs one dial → OP_REQ_IMPORT → vhci attach sequence.
// The returned TCP socket is handed to the kernel on success; on failure the
// connection is closed before return.
func (c *ClientService) attemptAttach(busid string) (int, error) {
conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, c.serverAddr)
func (c *ClientService) attemptAttach(ctx context.Context, busid string) (int, error) {
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr)
if err != nil {
return -1, E.Cause(err, "dial ", c.serverAddr)
}
defer conn.Close()
stopCloseOnCancel := closeConnOnContextDone(ctx, conn)
defer stopCloseOnCancel()
if err := WriteOpReqImport(conn, busid); err != nil {
return -1, E.Cause(err, "write OP_REQ_IMPORT")
}
@ -351,14 +595,12 @@ func (c *ClientService) attemptAttach(busid string) (int, error) {
return port, nil
}
// watchPort polls vhci status every 2s and returns when the port is no longer
// in VDEV_ST_USED, or when ctx is canceled (in which case it detaches the port).
func (c *ClientService) watchPort(port int, busid string) {
func (c *ClientService) watchPort(ctx context.Context, port int, busid string) {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-c.ctx.Done():
case <-ctx.Done():
if err := vhciDetach(port); err != nil {
c.logger.Warn("detach port ", port, " (", busid, "): ", err)
}
@ -431,3 +673,17 @@ func sleepCtx(ctx context.Context, d time.Duration) bool {
return true
}
}
func closeConnOnContextDone(ctx context.Context, conn net.Conn) func() {
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
_ = conn.Close()
case <-done:
}
}()
return func() {
close(done)
}
}

View file

@ -0,0 +1,112 @@
package usbip
import (
"encoding/binary"
"io"
)
const (
controlProtocolVersion uint8 = 1
controlFrameHello uint8 = 1
controlFrameAck uint8 = 2
controlFrameChanged uint8 = 3
controlFramePing uint8 = 4
controlFramePong uint8 = 5
controlCapabilityChanged uint32 = 1 << 0
controlCapabilityPingPong uint32 = 1 << 1
controlCapabilities = controlCapabilityChanged | controlCapabilityPingPong
controlPrefaceSize = 8
controlFrameSize = 16
)
var controlPreface = [controlPrefaceSize]byte{'S', 'B', 'U', 'S', 'B', 'I', 'P', '1'}
type controlFrame struct {
Type uint8
Version uint8
_ uint16
Capabilities uint32
Sequence uint64
}
func WriteControlPreface(w io.Writer) error {
_, err := w.Write(controlPreface[:])
return err
}
func IsControlPreface(raw []byte) bool {
if len(raw) != len(controlPreface) {
return false
}
for i := range controlPreface {
if raw[i] != controlPreface[i] {
return false
}
}
return true
}
func WriteControlHello(w io.Writer) error {
return writeControlFrame(w, controlFrame{
Type: controlFrameHello,
Version: controlProtocolVersion,
Capabilities: controlCapabilities,
})
}
func WriteControlAck(w io.Writer, sequence uint64) error {
return writeControlFrame(w, controlFrame{
Type: controlFrameAck,
Version: controlProtocolVersion,
Capabilities: controlCapabilities,
Sequence: sequence,
})
}
func WriteControlChanged(w io.Writer, sequence uint64) error {
return writeControlFrame(w, controlFrame{
Type: controlFrameChanged,
Version: controlProtocolVersion,
Sequence: sequence,
})
}
func WriteControlPing(w io.Writer) error {
return writeControlFrame(w, controlFrame{
Type: controlFramePing,
Version: controlProtocolVersion,
})
}
func WriteControlPong(w io.Writer) error {
return writeControlFrame(w, controlFrame{
Type: controlFramePong,
Version: controlProtocolVersion,
})
}
func ReadControlFrame(r io.Reader) (controlFrame, error) {
var raw [controlFrameSize]byte
if _, err := io.ReadFull(r, raw[:]); err != nil {
return controlFrame{}, err
}
return controlFrame{
Type: raw[0],
Version: raw[1],
Capabilities: binary.BigEndian.Uint32(raw[4:8]),
Sequence: binary.BigEndian.Uint64(raw[8:16]),
}, nil
}
func writeControlFrame(w io.Writer, frame controlFrame) error {
var raw [controlFrameSize]byte
raw[0] = frame.Type
raw[1] = frame.Version
binary.BigEndian.PutUint32(raw[4:8], frame.Capabilities)
binary.BigEndian.PutUint64(raw[8:16], frame.Sequence)
_, err := w.Write(raw[:])
return err
}

View file

@ -97,6 +97,14 @@ func ReadOpHeader(r io.Reader) (OpHeader, error) {
return h, nil
}
func ParseOpHeader(raw []byte) OpHeader {
return OpHeader{
Version: binary.BigEndian.Uint16(raw[:2]),
Code: binary.BigEndian.Uint16(raw[2:4]),
Status: binary.BigEndian.Uint32(raw[4:8]),
}
}
// WriteOpReqImport sends OP_REQ_IMPORT for busid (8 + 32 = 40 bytes).
func WriteOpReqImport(w io.Writer, busid string) error {
if err := WriteOpHeader(w, OpReqImport, OpStatusOK); err != nil {

View file

@ -4,7 +4,10 @@ package usbip
import (
"context"
"io"
"net"
"os"
"slices"
"sync"
"time"
@ -25,6 +28,12 @@ type serverExport struct {
originalDriver string
}
type serverControlConn struct {
id uint64
conn net.Conn
send chan controlFrame
}
type ServerService struct {
boxService.Adapter
ctx context.Context
@ -34,8 +43,13 @@ type ServerService struct {
matches []option.USBIPDeviceMatch
mu sync.Mutex
exports []serverExport
exports map[string]serverExport
listenFD net.Listener
controlMu sync.Mutex
controlSeq uint64
controlNextID uint64
controlSubs map[uint64]*serverControlConn
}
func NewServerService(ctx context.Context, logger log.ContextLogger, tag string, options option.USBIPServerServiceOptions) (adapter.Service, error) {
@ -54,12 +68,14 @@ func NewServerService(ctx context.Context, logger log.ContextLogger, tag string,
cancel: cancel,
logger: logger,
matches: options.Devices,
exports: make(map[string]serverExport),
listener: listener.New(listener.Options{
Context: ctx,
Logger: logger,
Network: []string{N.NetworkTCP},
Listen: options.ListenOptions,
}),
controlSubs: make(map[uint64]*serverControlConn),
}
return s, nil
}
@ -71,7 +87,7 @@ func (s *ServerService) Start(stage adapter.StartStage) error {
if err := ensureHostDriver(); err != nil {
return err
}
if err := s.bindExports(); err != nil {
if _, err := s.reconcileExports(); err != nil {
s.rollbackExports()
return err
}
@ -84,6 +100,7 @@ func (s *ServerService) Start(stage adapter.StartStage) error {
s.listenFD = tcpListener
s.mu.Unlock()
go s.acceptLoop(tcpListener)
go s.ueventLoop()
return nil
}
@ -91,47 +108,58 @@ func (s *ServerService) Close() error {
if s.cancel != nil {
s.cancel()
}
s.closeControlSubscribers()
err := common.Close(common.PtrOrNil(s.listener))
s.rollbackExports()
return err
}
// bindExports resolves every match against current sysfs state, unbinds from
// the current driver, and binds to usbip-host.
func (s *ServerService) bindExports() error {
func (s *ServerService) reconcileExports() (bool, error) {
devices, err := listUSBDevices()
if err != nil {
return E.Cause(err, "enumerate usb devices")
return false, E.Cause(err, "enumerate usb devices")
}
desired := make(map[string]sysfsDevice)
present := make(map[string]struct{}, len(devices))
for i := range devices {
present[devices[i].BusID] = struct{}{}
}
seen := make(map[string]bool)
for _, m := range s.matches {
matched := 0
for i := range devices {
if !Matches(m, devices[i].key()) {
continue
}
if seen[devices[i].BusID] {
matched++
continue
}
if devices[i].DeviceClass == 0x09 {
seen[devices[i].BusID] = true
matched++
s.logger.Warn("skip hub device ", devices[i].BusID, " matched by ", describeMatch(m))
continue
}
if err := s.bindOne(&devices[i]); err != nil {
s.logger.Warn("bind ", devices[i].BusID, ": ", err)
continue
}
seen[devices[i].BusID] = true
matched++
}
if matched == 0 {
s.logger.Warn("no local device matched ", describeMatch(m))
desired[devices[i].BusID] = devices[i]
}
}
return nil
current := s.snapshotExports()
changed := false
for busid, device := range desired {
if _, ok := current[busid]; ok {
continue
}
if err := s.bindOne(&device); err != nil {
s.logger.Warn("bind ", busid, ": ", err)
continue
}
changed = true
}
for busid, export := range current {
if _, ok := desired[busid]; ok {
continue
}
_, restore := present[busid]
if err := s.releaseExport(export, restore); err != nil {
s.logger.Warn("release ", busid, ": ", err)
}
changed = true
}
return changed, nil
}
func (s *ServerService) bindOne(d *sysfsDevice) error {
@ -141,9 +169,7 @@ func (s *ServerService) bindOne(d *sysfsDevice) error {
}
if driver == "usbip-host" {
s.logger.Info("device ", d.BusID, " already bound to usbip-host; co-opting")
s.mu.Lock()
s.exports = append(s.exports, serverExport{busid: d.BusID})
s.mu.Unlock()
s.setExport(serverExport{busid: d.BusID})
return nil
}
if driver != "" {
@ -165,55 +191,92 @@ func (s *ServerService) bindOne(d *sysfsDevice) error {
return E.Cause(err, "bind to usbip-host")
}
s.logger.Info("exported ", d.BusID, " (previously on ", driverOrNone(driver), ")")
s.mu.Lock()
s.exports = append(s.exports, serverExport{
s.setExport(serverExport{
busid: d.BusID,
managed: true,
originalDriver: driver,
})
s.mu.Unlock()
return nil
}
func (s *ServerService) releaseExport(export serverExport, restore bool) error {
s.deleteExport(export.busid)
var releaseErr error
if err := writeUsbipSockfd(export.busid, -1); err != nil && !os.IsNotExist(err) {
releaseErr = err
}
if !export.managed {
s.logger.Info("stopped tracking ", export.busid, " on usbip-host")
return releaseErr
}
if err := hostUnbind(export.busid); err != nil && !os.IsNotExist(err) && releaseErr == nil {
releaseErr = err
}
if err := hostMatchBusID(export.busid, false); err != nil && releaseErr == nil {
releaseErr = err
}
if !restore {
s.logger.Info("removed export state for disappeared device ", export.busid)
return releaseErr
}
if export.originalDriver == "" {
s.logger.Info("released ", export.busid, " from usbip-host")
return releaseErr
}
if err := bindToDriver(export.busid, export.originalDriver); err != nil {
if releaseErr == nil {
releaseErr = err
}
return releaseErr
}
s.logger.Info("restored ", export.busid, " to ", export.originalDriver)
return releaseErr
}
func (s *ServerService) rollbackExports() {
s.mu.Lock()
exports := s.exports
s.exports = nil
s.mu.Unlock()
for _, e := range exports {
if !e.managed {
continue
exports := s.snapshotExports()
for _, export := range exports {
_, restore := currentSysfsDevice(export.busid)
if err := s.releaseExport(export, restore); err != nil {
s.logger.Warn("rollback ", export.busid, ": ", err)
}
// Release any attached peer.
_ = writeUsbipSockfd(e.busid, -1)
if err := hostUnbind(e.busid); err != nil {
s.logger.Warn("unbind ", e.busid, ": ", err)
}
if err := hostMatchBusID(e.busid, false); err != nil {
s.logger.Debug("match_busid del ", e.busid, ": ", err)
}
if e.originalDriver == "" {
s.logger.Info("released ", e.busid, " from usbip-host")
continue
}
if err := bindToDriver(e.busid, e.originalDriver); err != nil {
s.logger.Warn("rebind ", e.busid, " to ", e.originalDriver, ": ", err)
continue
}
s.logger.Info("restored ", e.busid, " to ", e.originalDriver)
}
}
func (s *ServerService) currentExports() []string {
s.mu.Lock()
defer s.mu.Unlock()
out := make([]string, len(s.exports))
for i, e := range s.exports {
out[i] = e.busid
out := make([]string, 0, len(s.exports))
for busid := range s.exports {
out = append(out, busid)
}
slices.Sort(out)
return out
}
func (s *ServerService) snapshotExports() map[string]serverExport {
s.mu.Lock()
defer s.mu.Unlock()
out := make(map[string]serverExport, len(s.exports))
for busid, export := range s.exports {
out[busid] = export
}
return out
}
func (s *ServerService) setExport(export serverExport) {
s.mu.Lock()
defer s.mu.Unlock()
s.exports[export.busid] = export
}
func (s *ServerService) deleteExport(busid string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.exports, busid)
}
func (s *ServerService) acceptLoop(ln net.Listener) {
for {
conn, err := ln.Accept()
@ -237,17 +300,26 @@ func (s *ServerService) acceptLoop(ln net.Listener) {
s.logger.Error("accept: ", err)
return
}
go s.handleConn(conn)
go s.dispatchConn(conn)
}
}
func (s *ServerService) handleConn(conn net.Conn) {
defer conn.Close()
header, err := ReadOpHeader(conn)
if err != nil {
s.logger.Debug("read op header: ", err)
func (s *ServerService) dispatchConn(conn net.Conn) {
var prefix [controlPrefaceSize]byte
if _, err := io.ReadFull(conn, prefix[:]); err != nil {
s.logger.Debug("read connection preface: ", err)
_ = conn.Close()
return
}
if IsControlPreface(prefix[:]) {
s.handleControlConn(conn)
return
}
s.handleStandardConn(conn, ParseOpHeader(prefix[:]))
}
func (s *ServerService) handleStandardConn(conn net.Conn, header OpHeader) {
defer conn.Close()
switch header.Code {
case OpReqDevList:
s.handleDevList(conn)
@ -258,6 +330,71 @@ func (s *ServerService) handleConn(conn net.Conn) {
}
}
func (s *ServerService) handleControlConn(conn net.Conn) {
defer conn.Close()
hello, err := ReadControlFrame(conn)
if err != nil {
s.logger.Debug("read control hello: ", err)
return
}
if hello.Type != controlFrameHello {
s.logger.Debug("unexpected control frame ", hello.Type, " before hello")
return
}
if hello.Version != controlProtocolVersion {
s.logger.Debug("unsupported control version ", hello.Version)
return
}
if hello.Capabilities&controlCapabilities != controlCapabilities {
s.logger.Debug("missing control capabilities 0x", hello.Capabilities)
return
}
sub, seq := s.registerControlConn(conn)
defer s.unregisterControlConn(sub.id)
if err := WriteControlAck(conn, seq); err != nil {
s.logger.Debug("write control ack: ", err)
return
}
readDone := make(chan struct{})
go s.readControlConn(sub, readDone)
for {
select {
case <-s.ctx.Done():
return
case <-readDone:
return
case frame := <-sub.send:
if err := writeControlFrame(conn, frame); err != nil {
s.logger.Debug("write control frame: ", err)
return
}
}
}
}
func (s *ServerService) readControlConn(sub *serverControlConn, done chan<- struct{}) {
defer close(done)
for {
frame, err := ReadControlFrame(sub.conn)
if err != nil {
return
}
switch frame.Type {
case controlFramePing:
s.enqueueControlFrame(sub, controlFrame{
Type: controlFramePong,
Version: controlProtocolVersion,
})
default:
return
}
}
}
func (s *ServerService) handleDevList(conn net.Conn) {
entries := s.buildDevListEntries()
if err := WriteOpRepDevList(conn, entries); err != nil {
@ -346,12 +483,124 @@ func (s *ServerService) handleImport(conn net.Conn) {
func (s *ServerService) isExported(busid string) bool {
s.mu.Lock()
defer s.mu.Unlock()
for _, e := range s.exports {
if e.busid == busid {
return true
_, ok := s.exports[busid]
return ok
}
func (s *ServerService) ueventLoop() {
for {
listener, err := newUEventListener()
if err != nil {
if s.ctx.Err() != nil {
return
}
s.logger.Warn("open uevent listener: ", err)
if !sleepCtx(s.ctx, time.Second) {
return
}
continue
}
done := make(chan struct{})
go func() {
select {
case <-s.ctx.Done():
_ = listener.Close()
case <-done:
}
}()
for {
err = listener.WaitUSBEvent()
if err != nil {
close(done)
_ = listener.Close()
if s.ctx.Err() != nil {
return
}
s.logger.Warn("read uevent: ", err)
if !sleepCtx(s.ctx, time.Second) {
return
}
break
}
changed, reconcileErr := s.reconcileExports()
if reconcileErr != nil {
s.logger.Warn("reconcile exports: ", reconcileErr)
continue
}
if changed {
s.broadcastChanged()
}
}
}
return false
}
func (s *ServerService) registerControlConn(conn net.Conn) (*serverControlConn, uint64) {
s.controlMu.Lock()
defer s.controlMu.Unlock()
s.controlNextID++
sub := &serverControlConn{
id: s.controlNextID,
conn: conn,
send: make(chan controlFrame, 16),
}
s.controlSubs[sub.id] = sub
return sub, s.controlSeq
}
func (s *ServerService) unregisterControlConn(id uint64) {
s.controlMu.Lock()
defer s.controlMu.Unlock()
delete(s.controlSubs, id)
}
func (s *ServerService) closeControlSubscribers() {
s.controlMu.Lock()
subs := make([]*serverControlConn, 0, len(s.controlSubs))
for _, sub := range s.controlSubs {
subs = append(subs, sub)
}
s.controlSubs = make(map[uint64]*serverControlConn)
s.controlMu.Unlock()
for _, sub := range subs {
_ = sub.conn.Close()
}
}
func (s *ServerService) broadcastChanged() {
s.controlMu.Lock()
s.controlSeq++
sequence := s.controlSeq
subs := make([]*serverControlConn, 0, len(s.controlSubs))
for _, sub := range s.controlSubs {
subs = append(subs, sub)
}
s.controlMu.Unlock()
frame := controlFrame{
Type: controlFrameChanged,
Version: controlProtocolVersion,
Sequence: sequence,
}
for _, sub := range subs {
s.enqueueControlFrame(sub, frame)
}
}
func (s *ServerService) enqueueControlFrame(sub *serverControlConn, frame controlFrame) {
select {
case sub.send <- frame:
default:
s.logger.Debug("control subscriber ", sub.id, " lagged behind")
_ = sub.conn.Close()
}
}
func currentSysfsDevice(busid string) (sysfsDevice, bool) {
device, err := readSysfsDevice(busid, sysBusDevicePath(busid))
if err != nil {
return sysfsDevice{}, false
}
return device, true
}
func sysBusDevicePath(busid string) string {

View file

@ -0,0 +1,58 @@
//go:build linux
package usbip
import (
"bytes"
"os"
"golang.org/x/sys/unix"
)
type ueventListener struct {
fd int
}
func newUEventListener() (*ueventListener, error) {
fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_DGRAM, unix.NETLINK_KOBJECT_UEVENT)
if err != nil {
return nil, err
}
addr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Pid: uint32(os.Getpid()),
Groups: 1,
}
if err := unix.Bind(fd, addr); err != nil {
_ = unix.Close(fd)
return nil, err
}
return &ueventListener{fd: fd}, nil
}
func (l *ueventListener) Close() error {
return unix.Close(l.fd)
}
func (l *ueventListener) WaitUSBEvent() error {
var buf [4096]byte
for {
n, _, err := unix.Recvfrom(l.fd, buf[:], 0)
if err != nil {
return err
}
if isUSBUEvent(buf[:n]) {
return nil
}
}
}
func isUSBUEvent(raw []byte) bool {
fields := bytes.Split(bytes.TrimRight(raw, "\x00"), []byte{0})
for _, field := range fields {
if bytes.Equal(field, []byte("SUBSYSTEM=usb")) {
return true
}
}
return false
}