mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-05-13 13:57:05 +00:00
service/usbip: add hotplug control push path
This commit is contained in:
parent
9a92967c4e
commit
f3bbd3c07b
5 changed files with 893 additions and 210 deletions
|
|
@ -4,8 +4,9 @@ package usbip
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
|
@ -20,7 +21,15 @@ import (
|
|||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
const clientReconnectDelay = 5 * time.Second
|
||||
const (
|
||||
clientReconnectDelay = 5 * time.Second
|
||||
controlPingInterval = 10 * time.Second
|
||||
controlReadTimeout = 30 * time.Second
|
||||
controlWriteTimeout = 5 * time.Second
|
||||
controlSessionIdleHint = "control session lost"
|
||||
)
|
||||
|
||||
var errImmediateReconnect = errors.New("usbip control reconnect")
|
||||
|
||||
type clientTarget struct {
|
||||
fixedBusID string
|
||||
|
|
@ -34,6 +43,15 @@ func (t clientTarget) description() string {
|
|||
return describeMatch(t.match)
|
||||
}
|
||||
|
||||
type clientAssignedWorker struct {
|
||||
target clientTarget
|
||||
updates chan string
|
||||
}
|
||||
|
||||
type clientBusIDWorker struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type ClientService struct {
|
||||
boxService.Adapter
|
||||
ctx context.Context
|
||||
|
|
@ -43,9 +61,11 @@ type ClientService struct {
|
|||
serverAddr M.Socksaddr
|
||||
matches []option.USBIPDeviceMatch // empty = import all remote exports
|
||||
|
||||
assignMu sync.Mutex
|
||||
targets []clientTarget
|
||||
assigned []string
|
||||
stateMu sync.Mutex
|
||||
targets []clientTarget
|
||||
assigned []string
|
||||
assignedWorkers []*clientAssignedWorker
|
||||
allWorkers map[string]*clientBusIDWorker
|
||||
|
||||
attachMu sync.Mutex // serializes vhci port pick + attach
|
||||
wg sync.WaitGroup
|
||||
|
|
@ -79,6 +99,7 @@ func NewClientService(ctx context.Context, logger log.ContextLogger, tag string,
|
|||
dialer: outboundDialer,
|
||||
serverAddr: options.ServerOptions.Build(),
|
||||
matches: options.Devices,
|
||||
allWorkers: make(map[string]*clientBusIDWorker),
|
||||
ports: make(map[int]struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -90,6 +111,7 @@ func (c *ClientService) Start(stage adapter.StartStage) error {
|
|||
if err := ensureVHCI(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.initializeWorkers()
|
||||
c.wg.Add(1)
|
||||
go c.run()
|
||||
return nil
|
||||
|
|
@ -99,7 +121,6 @@ func (c *ClientService) Close() error {
|
|||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
// Wait for workers to detach, bounded by 5s.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.wg.Wait()
|
||||
|
|
@ -113,35 +134,356 @@ func (c *ClientService) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// run prepares the desired targets and spawns one worker per target.
|
||||
func (c *ClientService) run() {
|
||||
defer c.wg.Done()
|
||||
func (c *ClientService) initializeWorkers() {
|
||||
targets := c.buildTargets()
|
||||
if len(targets) == 0 {
|
||||
c.logger.Warn("no devices to import; client idle")
|
||||
c.stateMu.Lock()
|
||||
c.targets = targets
|
||||
if len(c.matches) == 0 {
|
||||
c.stateMu.Unlock()
|
||||
return
|
||||
}
|
||||
c.assignMu.Lock()
|
||||
c.targets = targets
|
||||
c.assigned = make([]string, len(targets))
|
||||
for i := range targets {
|
||||
c.assigned[i] = targets[i].fixedBusID
|
||||
c.assignedWorkers = make([]*clientAssignedWorker, len(targets))
|
||||
for i, target := range targets {
|
||||
c.assignedWorkers[i] = &clientAssignedWorker{
|
||||
target: target,
|
||||
updates: make(chan string, 1),
|
||||
}
|
||||
}
|
||||
c.assignMu.Unlock()
|
||||
for i := range targets {
|
||||
workers := append([]*clientAssignedWorker(nil), c.assignedWorkers...)
|
||||
c.stateMu.Unlock()
|
||||
|
||||
for _, worker := range workers {
|
||||
c.wg.Add(1)
|
||||
go c.worker(i)
|
||||
go c.runAssignedWorker(worker)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientService) run() {
|
||||
defer c.wg.Done()
|
||||
immediate := true
|
||||
for {
|
||||
if !immediate && !sleepCtx(c.ctx, clientReconnectDelay) {
|
||||
break
|
||||
}
|
||||
err := c.runControlSession()
|
||||
if c.ctx.Err() != nil {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
c.logger.Error("control ", c.serverAddr, ": ", err)
|
||||
}
|
||||
immediate = errors.Is(err, errImmediateReconnect)
|
||||
}
|
||||
c.stopAllWorkers()
|
||||
}
|
||||
|
||||
func (c *ClientService) runControlSession() error {
|
||||
conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, c.serverAddr)
|
||||
if err != nil {
|
||||
return E.Cause(err, "dial ", c.serverAddr)
|
||||
}
|
||||
defer conn.Close()
|
||||
stopCloseOnCancel := closeConnOnContextDone(c.ctx, conn)
|
||||
defer stopCloseOnCancel()
|
||||
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(controlWriteTimeout))
|
||||
_ = conn.SetReadDeadline(time.Now().Add(controlWriteTimeout))
|
||||
if err := WriteControlPreface(conn); err != nil {
|
||||
return E.Cause(err, "write control preface")
|
||||
}
|
||||
if err := WriteControlHello(conn); err != nil {
|
||||
return E.Cause(err, "write control hello")
|
||||
}
|
||||
ack, err := ReadControlFrame(conn)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read control ack")
|
||||
}
|
||||
if ack.Type != controlFrameAck {
|
||||
return E.New("unexpected control ack frame ", ack.Type)
|
||||
}
|
||||
if ack.Version != controlProtocolVersion {
|
||||
return E.New("unsupported control version ", ack.Version)
|
||||
}
|
||||
if ack.Capabilities&controlCapabilities != controlCapabilities {
|
||||
return E.New("missing control capabilities 0x", ack.Capabilities)
|
||||
}
|
||||
_ = conn.SetWriteDeadline(time.Time{})
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if err := c.syncRemoteState(); err != nil {
|
||||
return E.Cause(err, "initial devlist sync")
|
||||
}
|
||||
|
||||
pingDone := make(chan struct{})
|
||||
go c.controlPingLoop(conn, pingDone)
|
||||
defer close(pingDone)
|
||||
|
||||
lastSeq := ack.Sequence
|
||||
for {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(controlReadTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
frame, err := ReadControlFrame(conn)
|
||||
if err != nil {
|
||||
return E.Cause(errImmediateReconnect, controlSessionIdleHint, ": ", err)
|
||||
}
|
||||
switch frame.Type {
|
||||
case controlFrameChanged:
|
||||
if frame.Sequence != lastSeq+1 {
|
||||
return E.Cause(errImmediateReconnect, "control sequence jumped from ", lastSeq, " to ", frame.Sequence)
|
||||
}
|
||||
lastSeq = frame.Sequence
|
||||
if err := c.syncRemoteState(); err != nil {
|
||||
return E.Cause(errImmediateReconnect, "devlist sync after change ", frame.Sequence, ": ", err)
|
||||
}
|
||||
case controlFramePong:
|
||||
default:
|
||||
return E.Cause(errImmediateReconnect, "unexpected control frame ", frame.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientService) controlPingLoop(conn net.Conn, done <-chan struct{}) {
|
||||
ticker := time.NewTicker(controlPingInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(controlWriteTimeout))
|
||||
if err := WriteControlPing(conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
_ = conn.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientService) syncRemoteState() error {
|
||||
entries, err := c.fetchDevList(c.ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.applyRemoteEntries(entries)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientService) applyRemoteEntries(entries []DeviceEntry) {
|
||||
if len(c.matches) == 0 {
|
||||
c.applyRemoteExports(entries)
|
||||
return
|
||||
}
|
||||
c.applyMatchedExports(entries)
|
||||
}
|
||||
|
||||
func (c *ClientService) applyRemoteExports(entries []DeviceEntry) {
|
||||
desired := make(map[string]struct{}, len(entries))
|
||||
for i := range entries {
|
||||
busid := entries[i].Info.BusIDString()
|
||||
if busid == "" {
|
||||
continue
|
||||
}
|
||||
desired[busid] = struct{}{}
|
||||
}
|
||||
|
||||
c.stateMu.Lock()
|
||||
stopWorkers := make([]*clientBusIDWorker, 0)
|
||||
for busid, worker := range c.allWorkers {
|
||||
if _, ok := desired[busid]; ok {
|
||||
continue
|
||||
}
|
||||
stopWorkers = append(stopWorkers, worker)
|
||||
delete(c.allWorkers, busid)
|
||||
}
|
||||
startBusIDs := make([]string, 0)
|
||||
for busid := range desired {
|
||||
if _, ok := c.allWorkers[busid]; ok {
|
||||
continue
|
||||
}
|
||||
startBusIDs = append(startBusIDs, busid)
|
||||
}
|
||||
c.stateMu.Unlock()
|
||||
|
||||
for _, worker := range stopWorkers {
|
||||
worker.cancel()
|
||||
}
|
||||
slices.Sort(startBusIDs)
|
||||
for _, busid := range startBusIDs {
|
||||
c.startRemoteBusIDWorker(busid, busid)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientService) applyMatchedExports(entries []DeviceEntry) {
|
||||
keysByBusID := make(map[string]DeviceKey, len(entries))
|
||||
for i := range entries {
|
||||
busid := entries[i].Info.BusIDString()
|
||||
if busid == "" {
|
||||
continue
|
||||
}
|
||||
keysByBusID[busid] = DeviceKey{
|
||||
BusID: busid,
|
||||
VendorID: entries[i].Info.IDVendor,
|
||||
ProductID: entries[i].Info.IDProduct,
|
||||
Serial: entries[i].Info.SerialString(),
|
||||
}
|
||||
}
|
||||
|
||||
c.stateMu.Lock()
|
||||
if len(c.targets) == 0 {
|
||||
c.stateMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
nextAssigned := make([]string, len(c.targets))
|
||||
reserved := make(map[string]struct{}, len(c.targets))
|
||||
for i, target := range c.targets {
|
||||
if target.fixedBusID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := keysByBusID[target.fixedBusID]; !ok {
|
||||
continue
|
||||
}
|
||||
nextAssigned[i] = target.fixedBusID
|
||||
reserved[target.fixedBusID] = struct{}{}
|
||||
}
|
||||
for i, target := range c.targets {
|
||||
if target.fixedBusID != "" {
|
||||
continue
|
||||
}
|
||||
current := c.assigned[i]
|
||||
if current == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := reserved[current]; ok {
|
||||
continue
|
||||
}
|
||||
key, ok := keysByBusID[current]
|
||||
if !ok || !Matches(target.match, key) {
|
||||
continue
|
||||
}
|
||||
nextAssigned[i] = current
|
||||
reserved[current] = struct{}{}
|
||||
}
|
||||
for i, target := range c.targets {
|
||||
if target.fixedBusID != "" || nextAssigned[i] != "" {
|
||||
continue
|
||||
}
|
||||
nextAssigned[i] = firstMatchingUnclaimedBusID(target.match, entries, reserved)
|
||||
if nextAssigned[i] != "" {
|
||||
reserved[nextAssigned[i]] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
workers := append([]*clientAssignedWorker(nil), c.assignedWorkers...)
|
||||
previous := append([]string(nil), c.assigned...)
|
||||
c.assigned = nextAssigned
|
||||
c.stateMu.Unlock()
|
||||
|
||||
for i, worker := range workers {
|
||||
if previous[i] == nextAssigned[i] {
|
||||
continue
|
||||
}
|
||||
worker.setDesiredBusID(nextAssigned[i])
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientService) runAssignedWorker(worker *clientAssignedWorker) {
|
||||
defer c.wg.Done()
|
||||
|
||||
var current string
|
||||
var runnerCancel context.CancelFunc
|
||||
var runnerDone chan struct{}
|
||||
|
||||
stopRunner := func() {
|
||||
if runnerCancel == nil {
|
||||
return
|
||||
}
|
||||
runnerCancel()
|
||||
<-runnerDone
|
||||
runnerCancel = nil
|
||||
runnerDone = nil
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
stopRunner()
|
||||
return
|
||||
case desired := <-worker.updates:
|
||||
if desired == current {
|
||||
continue
|
||||
}
|
||||
stopRunner()
|
||||
current = desired
|
||||
if desired == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
runCtx, cancel := context.WithCancel(c.ctx)
|
||||
done := make(chan struct{})
|
||||
runnerCancel = cancel
|
||||
runnerDone = done
|
||||
|
||||
c.wg.Add(1)
|
||||
go func(busid string) {
|
||||
defer c.wg.Done()
|
||||
defer close(done)
|
||||
c.runBusIDLoop(runCtx, busid, worker.target.description())
|
||||
}(desired)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *clientAssignedWorker) setDesiredBusID(busid string) {
|
||||
select {
|
||||
case w.updates <- busid:
|
||||
return
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-w.updates:
|
||||
default:
|
||||
}
|
||||
w.updates <- busid
|
||||
}
|
||||
|
||||
func (c *ClientService) startRemoteBusIDWorker(busid, description string) {
|
||||
runCtx, cancel := context.WithCancel(c.ctx)
|
||||
worker := &clientBusIDWorker{cancel: cancel}
|
||||
|
||||
c.stateMu.Lock()
|
||||
c.allWorkers[busid] = worker
|
||||
c.stateMu.Unlock()
|
||||
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
c.runBusIDLoop(runCtx, busid, description)
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *ClientService) stopAllWorkers() {
|
||||
c.stateMu.Lock()
|
||||
workers := make([]*clientBusIDWorker, 0, len(c.allWorkers))
|
||||
for _, worker := range c.allWorkers {
|
||||
workers = append(workers, worker)
|
||||
}
|
||||
c.allWorkers = make(map[string]*clientBusIDWorker)
|
||||
c.stateMu.Unlock()
|
||||
|
||||
for _, worker := range workers {
|
||||
worker.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientService) buildTargets() []clientTarget {
|
||||
if len(c.matches) == 0 {
|
||||
busids := c.snapshotRemoteBusIDs()
|
||||
targets := make([]clientTarget, 0, len(busids))
|
||||
for _, busid := range busids {
|
||||
targets = append(targets, clientTarget{fixedBusID: busid})
|
||||
}
|
||||
return targets
|
||||
return nil
|
||||
}
|
||||
seenFixed := make(map[string]struct{})
|
||||
targets := make([]clientTarget, 0, len(c.matches))
|
||||
|
|
@ -159,36 +501,15 @@ func (c *ClientService) buildTargets() []clientTarget {
|
|||
return targets
|
||||
}
|
||||
|
||||
// snapshotRemoteBusIDs connects once, issues OP_REQ_DEVLIST, and returns the
|
||||
// currently exported remote busids.
|
||||
func (c *ClientService) snapshotRemoteBusIDs() []string {
|
||||
for {
|
||||
if err := c.ctx.Err(); err != nil {
|
||||
return nil
|
||||
}
|
||||
entries, err := c.fetchDevList()
|
||||
if err != nil {
|
||||
c.logger.Error("enumerate ", c.serverAddr, ": ", err)
|
||||
if !sleepCtx(c.ctx, clientReconnectDelay) {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
out := make([]string, 0, len(entries))
|
||||
for i := range entries {
|
||||
out = append(out, entries[i].Info.BusIDString())
|
||||
}
|
||||
return dedupe(out)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientService) fetchDevList() ([]DeviceEntry, error) {
|
||||
conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, c.serverAddr)
|
||||
func (c *ClientService) fetchDevList(ctx context.Context) ([]DeviceEntry, error) {
|
||||
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
if err := binary.Write(conn, binary.BigEndian, OpHeader{Version: ProtocolVersion, Code: OpReqDevList, Status: OpStatusOK}); err != nil {
|
||||
stopCloseOnCancel := closeConnOnContextDone(ctx, conn)
|
||||
defer stopCloseOnCancel()
|
||||
if err := WriteOpHeader(conn, OpReqDevList, OpStatusOK); err != nil {
|
||||
return nil, E.Cause(err, "send OP_REQ_DEVLIST")
|
||||
}
|
||||
header, err := ReadOpHeader(conn)
|
||||
|
|
@ -201,118 +522,41 @@ func (c *ClientService) fetchDevList() ([]DeviceEntry, error) {
|
|||
return ReadOpRepDevListBody(conn)
|
||||
}
|
||||
|
||||
// worker keeps one target attached to vhci_hcd.0. On any error or kernel-side
|
||||
// detach, waits clientReconnectDelay and retries.
|
||||
func (c *ClientService) worker(targetIndex int) {
|
||||
defer c.wg.Done()
|
||||
target := c.targets[targetIndex]
|
||||
func (c *ClientService) runBusIDLoop(ctx context.Context, busid, description string) {
|
||||
for {
|
||||
if err := c.ctx.Err(); err != nil {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
busid, err := c.claimTargetBusID(targetIndex)
|
||||
port, err := c.attemptAttach(ctx, busid)
|
||||
if err != nil {
|
||||
c.logger.Error("assign ", target.description(), ": ", err)
|
||||
if !sleepCtx(c.ctx, clientReconnectDelay) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if busid == "" {
|
||||
if !sleepCtx(c.ctx, clientReconnectDelay) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
port, err := c.attemptAttach(busid)
|
||||
if err != nil {
|
||||
c.releaseTargetBusID(targetIndex, busid)
|
||||
c.logger.Error("attach ", busid, ": ", err)
|
||||
if !sleepCtx(c.ctx, clientReconnectDelay) {
|
||||
c.logger.Error("attach ", description, " (", busid, "): ", err)
|
||||
if !sleepCtx(ctx, clientReconnectDelay) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
c.logger.Info("attached ", busid, " → vhci port ", port)
|
||||
c.trackPort(port, true)
|
||||
c.watchPort(port, busid)
|
||||
c.watchPort(ctx, port, busid)
|
||||
c.trackPort(port, false)
|
||||
c.releaseTargetBusID(targetIndex, busid)
|
||||
if err := c.ctx.Err(); err != nil {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
c.logger.Info("vhci port ", port, " released; reattaching ", busid)
|
||||
if !sleepCtx(c.ctx, clientReconnectDelay) {
|
||||
if !sleepCtx(ctx, clientReconnectDelay) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientService) claimTargetBusID(targetIndex int) (string, error) {
|
||||
target := c.targets[targetIndex]
|
||||
if target.fixedBusID != "" {
|
||||
return target.fixedBusID, nil
|
||||
}
|
||||
c.assignMu.Lock()
|
||||
current := c.assigned[targetIndex]
|
||||
c.assignMu.Unlock()
|
||||
if current != "" {
|
||||
return current, nil
|
||||
}
|
||||
entries, err := c.fetchDevList()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return c.refreshAssignments(targetIndex, entries), nil
|
||||
}
|
||||
|
||||
func (c *ClientService) refreshAssignments(targetIndex int, entries []DeviceEntry) string {
|
||||
c.assignMu.Lock()
|
||||
defer c.assignMu.Unlock()
|
||||
if c.assigned[targetIndex] != "" {
|
||||
return c.assigned[targetIndex]
|
||||
}
|
||||
reserved := make(map[string]struct{}, len(c.assigned))
|
||||
for _, busid := range c.assigned {
|
||||
if busid == "" {
|
||||
continue
|
||||
}
|
||||
reserved[busid] = struct{}{}
|
||||
}
|
||||
for i, target := range c.targets {
|
||||
if target.fixedBusID != "" || c.assigned[i] != "" {
|
||||
continue
|
||||
}
|
||||
busid := firstMatchingUnclaimedBusID(target.match, entries, reserved)
|
||||
if busid == "" {
|
||||
continue
|
||||
}
|
||||
c.assigned[i] = busid
|
||||
reserved[busid] = struct{}{}
|
||||
}
|
||||
return c.assigned[targetIndex]
|
||||
}
|
||||
|
||||
func (c *ClientService) releaseTargetBusID(targetIndex int, busid string) {
|
||||
if c.targets[targetIndex].fixedBusID != "" {
|
||||
return
|
||||
}
|
||||
c.assignMu.Lock()
|
||||
defer c.assignMu.Unlock()
|
||||
if c.assigned[targetIndex] == busid {
|
||||
c.assigned[targetIndex] = ""
|
||||
}
|
||||
}
|
||||
|
||||
// attemptAttach performs one dial → OP_REQ_IMPORT → vhci attach sequence.
|
||||
// The returned TCP socket is handed to the kernel on success; on failure the
|
||||
// connection is closed before return.
|
||||
func (c *ClientService) attemptAttach(busid string) (int, error) {
|
||||
conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, c.serverAddr)
|
||||
func (c *ClientService) attemptAttach(ctx context.Context, busid string) (int, error) {
|
||||
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.serverAddr)
|
||||
if err != nil {
|
||||
return -1, E.Cause(err, "dial ", c.serverAddr)
|
||||
}
|
||||
defer conn.Close()
|
||||
stopCloseOnCancel := closeConnOnContextDone(ctx, conn)
|
||||
defer stopCloseOnCancel()
|
||||
if err := WriteOpReqImport(conn, busid); err != nil {
|
||||
return -1, E.Cause(err, "write OP_REQ_IMPORT")
|
||||
}
|
||||
|
|
@ -351,14 +595,12 @@ func (c *ClientService) attemptAttach(busid string) (int, error) {
|
|||
return port, nil
|
||||
}
|
||||
|
||||
// watchPort polls vhci status every 2s and returns when the port is no longer
|
||||
// in VDEV_ST_USED, or when ctx is canceled (in which case it detaches the port).
|
||||
func (c *ClientService) watchPort(port int, busid string) {
|
||||
func (c *ClientService) watchPort(ctx context.Context, port int, busid string) {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
case <-ctx.Done():
|
||||
if err := vhciDetach(port); err != nil {
|
||||
c.logger.Warn("detach port ", port, " (", busid, "): ", err)
|
||||
}
|
||||
|
|
@ -431,3 +673,17 @@ func sleepCtx(ctx context.Context, d time.Duration) bool {
|
|||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func closeConnOnContextDone(ctx context.Context, conn net.Conn) func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = conn.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
return func() {
|
||||
close(done)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
112
service/usbip/control_protocol.go
Normal file
112
service/usbip/control_protocol.go
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
package usbip
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
controlProtocolVersion uint8 = 1
|
||||
|
||||
controlFrameHello uint8 = 1
|
||||
controlFrameAck uint8 = 2
|
||||
controlFrameChanged uint8 = 3
|
||||
controlFramePing uint8 = 4
|
||||
controlFramePong uint8 = 5
|
||||
|
||||
controlCapabilityChanged uint32 = 1 << 0
|
||||
controlCapabilityPingPong uint32 = 1 << 1
|
||||
controlCapabilities = controlCapabilityChanged | controlCapabilityPingPong
|
||||
|
||||
controlPrefaceSize = 8
|
||||
controlFrameSize = 16
|
||||
)
|
||||
|
||||
var controlPreface = [controlPrefaceSize]byte{'S', 'B', 'U', 'S', 'B', 'I', 'P', '1'}
|
||||
|
||||
type controlFrame struct {
|
||||
Type uint8
|
||||
Version uint8
|
||||
_ uint16
|
||||
Capabilities uint32
|
||||
Sequence uint64
|
||||
}
|
||||
|
||||
func WriteControlPreface(w io.Writer) error {
|
||||
_, err := w.Write(controlPreface[:])
|
||||
return err
|
||||
}
|
||||
|
||||
func IsControlPreface(raw []byte) bool {
|
||||
if len(raw) != len(controlPreface) {
|
||||
return false
|
||||
}
|
||||
for i := range controlPreface {
|
||||
if raw[i] != controlPreface[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func WriteControlHello(w io.Writer) error {
|
||||
return writeControlFrame(w, controlFrame{
|
||||
Type: controlFrameHello,
|
||||
Version: controlProtocolVersion,
|
||||
Capabilities: controlCapabilities,
|
||||
})
|
||||
}
|
||||
|
||||
func WriteControlAck(w io.Writer, sequence uint64) error {
|
||||
return writeControlFrame(w, controlFrame{
|
||||
Type: controlFrameAck,
|
||||
Version: controlProtocolVersion,
|
||||
Capabilities: controlCapabilities,
|
||||
Sequence: sequence,
|
||||
})
|
||||
}
|
||||
|
||||
func WriteControlChanged(w io.Writer, sequence uint64) error {
|
||||
return writeControlFrame(w, controlFrame{
|
||||
Type: controlFrameChanged,
|
||||
Version: controlProtocolVersion,
|
||||
Sequence: sequence,
|
||||
})
|
||||
}
|
||||
|
||||
func WriteControlPing(w io.Writer) error {
|
||||
return writeControlFrame(w, controlFrame{
|
||||
Type: controlFramePing,
|
||||
Version: controlProtocolVersion,
|
||||
})
|
||||
}
|
||||
|
||||
func WriteControlPong(w io.Writer) error {
|
||||
return writeControlFrame(w, controlFrame{
|
||||
Type: controlFramePong,
|
||||
Version: controlProtocolVersion,
|
||||
})
|
||||
}
|
||||
|
||||
func ReadControlFrame(r io.Reader) (controlFrame, error) {
|
||||
var raw [controlFrameSize]byte
|
||||
if _, err := io.ReadFull(r, raw[:]); err != nil {
|
||||
return controlFrame{}, err
|
||||
}
|
||||
return controlFrame{
|
||||
Type: raw[0],
|
||||
Version: raw[1],
|
||||
Capabilities: binary.BigEndian.Uint32(raw[4:8]),
|
||||
Sequence: binary.BigEndian.Uint64(raw[8:16]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func writeControlFrame(w io.Writer, frame controlFrame) error {
|
||||
var raw [controlFrameSize]byte
|
||||
raw[0] = frame.Type
|
||||
raw[1] = frame.Version
|
||||
binary.BigEndian.PutUint32(raw[4:8], frame.Capabilities)
|
||||
binary.BigEndian.PutUint64(raw[8:16], frame.Sequence)
|
||||
_, err := w.Write(raw[:])
|
||||
return err
|
||||
}
|
||||
|
|
@ -97,6 +97,14 @@ func ReadOpHeader(r io.Reader) (OpHeader, error) {
|
|||
return h, nil
|
||||
}
|
||||
|
||||
func ParseOpHeader(raw []byte) OpHeader {
|
||||
return OpHeader{
|
||||
Version: binary.BigEndian.Uint16(raw[:2]),
|
||||
Code: binary.BigEndian.Uint16(raw[2:4]),
|
||||
Status: binary.BigEndian.Uint32(raw[4:8]),
|
||||
}
|
||||
}
|
||||
|
||||
// WriteOpReqImport sends OP_REQ_IMPORT for busid (8 + 32 = 40 bytes).
|
||||
func WriteOpReqImport(w io.Writer, busid string) error {
|
||||
if err := WriteOpHeader(w, OpReqImport, OpStatusOK); err != nil {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,10 @@ package usbip
|
|||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
|
@ -25,6 +28,12 @@ type serverExport struct {
|
|||
originalDriver string
|
||||
}
|
||||
|
||||
type serverControlConn struct {
|
||||
id uint64
|
||||
conn net.Conn
|
||||
send chan controlFrame
|
||||
}
|
||||
|
||||
type ServerService struct {
|
||||
boxService.Adapter
|
||||
ctx context.Context
|
||||
|
|
@ -34,8 +43,13 @@ type ServerService struct {
|
|||
matches []option.USBIPDeviceMatch
|
||||
|
||||
mu sync.Mutex
|
||||
exports []serverExport
|
||||
exports map[string]serverExport
|
||||
listenFD net.Listener
|
||||
|
||||
controlMu sync.Mutex
|
||||
controlSeq uint64
|
||||
controlNextID uint64
|
||||
controlSubs map[uint64]*serverControlConn
|
||||
}
|
||||
|
||||
func NewServerService(ctx context.Context, logger log.ContextLogger, tag string, options option.USBIPServerServiceOptions) (adapter.Service, error) {
|
||||
|
|
@ -54,12 +68,14 @@ func NewServerService(ctx context.Context, logger log.ContextLogger, tag string,
|
|||
cancel: cancel,
|
||||
logger: logger,
|
||||
matches: options.Devices,
|
||||
exports: make(map[string]serverExport),
|
||||
listener: listener.New(listener.Options{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
Network: []string{N.NetworkTCP},
|
||||
Listen: options.ListenOptions,
|
||||
}),
|
||||
controlSubs: make(map[uint64]*serverControlConn),
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
|
@ -71,7 +87,7 @@ func (s *ServerService) Start(stage adapter.StartStage) error {
|
|||
if err := ensureHostDriver(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.bindExports(); err != nil {
|
||||
if _, err := s.reconcileExports(); err != nil {
|
||||
s.rollbackExports()
|
||||
return err
|
||||
}
|
||||
|
|
@ -84,6 +100,7 @@ func (s *ServerService) Start(stage adapter.StartStage) error {
|
|||
s.listenFD = tcpListener
|
||||
s.mu.Unlock()
|
||||
go s.acceptLoop(tcpListener)
|
||||
go s.ueventLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -91,47 +108,58 @@ func (s *ServerService) Close() error {
|
|||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
s.closeControlSubscribers()
|
||||
err := common.Close(common.PtrOrNil(s.listener))
|
||||
s.rollbackExports()
|
||||
return err
|
||||
}
|
||||
|
||||
// bindExports resolves every match against current sysfs state, unbinds from
|
||||
// the current driver, and binds to usbip-host.
|
||||
func (s *ServerService) bindExports() error {
|
||||
func (s *ServerService) reconcileExports() (bool, error) {
|
||||
devices, err := listUSBDevices()
|
||||
if err != nil {
|
||||
return E.Cause(err, "enumerate usb devices")
|
||||
return false, E.Cause(err, "enumerate usb devices")
|
||||
}
|
||||
desired := make(map[string]sysfsDevice)
|
||||
present := make(map[string]struct{}, len(devices))
|
||||
for i := range devices {
|
||||
present[devices[i].BusID] = struct{}{}
|
||||
}
|
||||
seen := make(map[string]bool)
|
||||
for _, m := range s.matches {
|
||||
matched := 0
|
||||
for i := range devices {
|
||||
if !Matches(m, devices[i].key()) {
|
||||
continue
|
||||
}
|
||||
if seen[devices[i].BusID] {
|
||||
matched++
|
||||
continue
|
||||
}
|
||||
if devices[i].DeviceClass == 0x09 {
|
||||
seen[devices[i].BusID] = true
|
||||
matched++
|
||||
s.logger.Warn("skip hub device ", devices[i].BusID, " matched by ", describeMatch(m))
|
||||
continue
|
||||
}
|
||||
if err := s.bindOne(&devices[i]); err != nil {
|
||||
s.logger.Warn("bind ", devices[i].BusID, ": ", err)
|
||||
continue
|
||||
}
|
||||
seen[devices[i].BusID] = true
|
||||
matched++
|
||||
}
|
||||
if matched == 0 {
|
||||
s.logger.Warn("no local device matched ", describeMatch(m))
|
||||
desired[devices[i].BusID] = devices[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
current := s.snapshotExports()
|
||||
changed := false
|
||||
for busid, device := range desired {
|
||||
if _, ok := current[busid]; ok {
|
||||
continue
|
||||
}
|
||||
if err := s.bindOne(&device); err != nil {
|
||||
s.logger.Warn("bind ", busid, ": ", err)
|
||||
continue
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
for busid, export := range current {
|
||||
if _, ok := desired[busid]; ok {
|
||||
continue
|
||||
}
|
||||
_, restore := present[busid]
|
||||
if err := s.releaseExport(export, restore); err != nil {
|
||||
s.logger.Warn("release ", busid, ": ", err)
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
func (s *ServerService) bindOne(d *sysfsDevice) error {
|
||||
|
|
@ -141,9 +169,7 @@ func (s *ServerService) bindOne(d *sysfsDevice) error {
|
|||
}
|
||||
if driver == "usbip-host" {
|
||||
s.logger.Info("device ", d.BusID, " already bound to usbip-host; co-opting")
|
||||
s.mu.Lock()
|
||||
s.exports = append(s.exports, serverExport{busid: d.BusID})
|
||||
s.mu.Unlock()
|
||||
s.setExport(serverExport{busid: d.BusID})
|
||||
return nil
|
||||
}
|
||||
if driver != "" {
|
||||
|
|
@ -165,55 +191,92 @@ func (s *ServerService) bindOne(d *sysfsDevice) error {
|
|||
return E.Cause(err, "bind to usbip-host")
|
||||
}
|
||||
s.logger.Info("exported ", d.BusID, " (previously on ", driverOrNone(driver), ")")
|
||||
s.mu.Lock()
|
||||
s.exports = append(s.exports, serverExport{
|
||||
s.setExport(serverExport{
|
||||
busid: d.BusID,
|
||||
managed: true,
|
||||
originalDriver: driver,
|
||||
})
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerService) releaseExport(export serverExport, restore bool) error {
|
||||
s.deleteExport(export.busid)
|
||||
|
||||
var releaseErr error
|
||||
if err := writeUsbipSockfd(export.busid, -1); err != nil && !os.IsNotExist(err) {
|
||||
releaseErr = err
|
||||
}
|
||||
if !export.managed {
|
||||
s.logger.Info("stopped tracking ", export.busid, " on usbip-host")
|
||||
return releaseErr
|
||||
}
|
||||
if err := hostUnbind(export.busid); err != nil && !os.IsNotExist(err) && releaseErr == nil {
|
||||
releaseErr = err
|
||||
}
|
||||
if err := hostMatchBusID(export.busid, false); err != nil && releaseErr == nil {
|
||||
releaseErr = err
|
||||
}
|
||||
if !restore {
|
||||
s.logger.Info("removed export state for disappeared device ", export.busid)
|
||||
return releaseErr
|
||||
}
|
||||
if export.originalDriver == "" {
|
||||
s.logger.Info("released ", export.busid, " from usbip-host")
|
||||
return releaseErr
|
||||
}
|
||||
if err := bindToDriver(export.busid, export.originalDriver); err != nil {
|
||||
if releaseErr == nil {
|
||||
releaseErr = err
|
||||
}
|
||||
return releaseErr
|
||||
}
|
||||
s.logger.Info("restored ", export.busid, " to ", export.originalDriver)
|
||||
return releaseErr
|
||||
}
|
||||
|
||||
func (s *ServerService) rollbackExports() {
|
||||
s.mu.Lock()
|
||||
exports := s.exports
|
||||
s.exports = nil
|
||||
s.mu.Unlock()
|
||||
for _, e := range exports {
|
||||
if !e.managed {
|
||||
continue
|
||||
exports := s.snapshotExports()
|
||||
for _, export := range exports {
|
||||
_, restore := currentSysfsDevice(export.busid)
|
||||
if err := s.releaseExport(export, restore); err != nil {
|
||||
s.logger.Warn("rollback ", export.busid, ": ", err)
|
||||
}
|
||||
// Release any attached peer.
|
||||
_ = writeUsbipSockfd(e.busid, -1)
|
||||
if err := hostUnbind(e.busid); err != nil {
|
||||
s.logger.Warn("unbind ", e.busid, ": ", err)
|
||||
}
|
||||
if err := hostMatchBusID(e.busid, false); err != nil {
|
||||
s.logger.Debug("match_busid del ", e.busid, ": ", err)
|
||||
}
|
||||
if e.originalDriver == "" {
|
||||
s.logger.Info("released ", e.busid, " from usbip-host")
|
||||
continue
|
||||
}
|
||||
if err := bindToDriver(e.busid, e.originalDriver); err != nil {
|
||||
s.logger.Warn("rebind ", e.busid, " to ", e.originalDriver, ": ", err)
|
||||
continue
|
||||
}
|
||||
s.logger.Info("restored ", e.busid, " to ", e.originalDriver)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerService) currentExports() []string {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make([]string, len(s.exports))
|
||||
for i, e := range s.exports {
|
||||
out[i] = e.busid
|
||||
out := make([]string, 0, len(s.exports))
|
||||
for busid := range s.exports {
|
||||
out = append(out, busid)
|
||||
}
|
||||
slices.Sort(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *ServerService) snapshotExports() map[string]serverExport {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := make(map[string]serverExport, len(s.exports))
|
||||
for busid, export := range s.exports {
|
||||
out[busid] = export
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *ServerService) setExport(export serverExport) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.exports[export.busid] = export
|
||||
}
|
||||
|
||||
func (s *ServerService) deleteExport(busid string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.exports, busid)
|
||||
}
|
||||
|
||||
func (s *ServerService) acceptLoop(ln net.Listener) {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
|
|
@ -237,17 +300,26 @@ func (s *ServerService) acceptLoop(ln net.Listener) {
|
|||
s.logger.Error("accept: ", err)
|
||||
return
|
||||
}
|
||||
go s.handleConn(conn)
|
||||
go s.dispatchConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerService) handleConn(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
header, err := ReadOpHeader(conn)
|
||||
if err != nil {
|
||||
s.logger.Debug("read op header: ", err)
|
||||
func (s *ServerService) dispatchConn(conn net.Conn) {
|
||||
var prefix [controlPrefaceSize]byte
|
||||
if _, err := io.ReadFull(conn, prefix[:]); err != nil {
|
||||
s.logger.Debug("read connection preface: ", err)
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
if IsControlPreface(prefix[:]) {
|
||||
s.handleControlConn(conn)
|
||||
return
|
||||
}
|
||||
s.handleStandardConn(conn, ParseOpHeader(prefix[:]))
|
||||
}
|
||||
|
||||
func (s *ServerService) handleStandardConn(conn net.Conn, header OpHeader) {
|
||||
defer conn.Close()
|
||||
switch header.Code {
|
||||
case OpReqDevList:
|
||||
s.handleDevList(conn)
|
||||
|
|
@ -258,6 +330,71 @@ func (s *ServerService) handleConn(conn net.Conn) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *ServerService) handleControlConn(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
hello, err := ReadControlFrame(conn)
|
||||
if err != nil {
|
||||
s.logger.Debug("read control hello: ", err)
|
||||
return
|
||||
}
|
||||
if hello.Type != controlFrameHello {
|
||||
s.logger.Debug("unexpected control frame ", hello.Type, " before hello")
|
||||
return
|
||||
}
|
||||
if hello.Version != controlProtocolVersion {
|
||||
s.logger.Debug("unsupported control version ", hello.Version)
|
||||
return
|
||||
}
|
||||
if hello.Capabilities&controlCapabilities != controlCapabilities {
|
||||
s.logger.Debug("missing control capabilities 0x", hello.Capabilities)
|
||||
return
|
||||
}
|
||||
|
||||
sub, seq := s.registerControlConn(conn)
|
||||
defer s.unregisterControlConn(sub.id)
|
||||
|
||||
if err := WriteControlAck(conn, seq); err != nil {
|
||||
s.logger.Debug("write control ack: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
readDone := make(chan struct{})
|
||||
go s.readControlConn(sub, readDone)
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case <-readDone:
|
||||
return
|
||||
case frame := <-sub.send:
|
||||
if err := writeControlFrame(conn, frame); err != nil {
|
||||
s.logger.Debug("write control frame: ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerService) readControlConn(sub *serverControlConn, done chan<- struct{}) {
|
||||
defer close(done)
|
||||
for {
|
||||
frame, err := ReadControlFrame(sub.conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
switch frame.Type {
|
||||
case controlFramePing:
|
||||
s.enqueueControlFrame(sub, controlFrame{
|
||||
Type: controlFramePong,
|
||||
Version: controlProtocolVersion,
|
||||
})
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerService) handleDevList(conn net.Conn) {
|
||||
entries := s.buildDevListEntries()
|
||||
if err := WriteOpRepDevList(conn, entries); err != nil {
|
||||
|
|
@ -346,12 +483,124 @@ func (s *ServerService) handleImport(conn net.Conn) {
|
|||
func (s *ServerService) isExported(busid string) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, e := range s.exports {
|
||||
if e.busid == busid {
|
||||
return true
|
||||
_, ok := s.exports[busid]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *ServerService) ueventLoop() {
|
||||
for {
|
||||
listener, err := newUEventListener()
|
||||
if err != nil {
|
||||
if s.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
s.logger.Warn("open uevent listener: ", err)
|
||||
if !sleepCtx(s.ctx, time.Second) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
_ = listener.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
for {
|
||||
err = listener.WaitUSBEvent()
|
||||
if err != nil {
|
||||
close(done)
|
||||
_ = listener.Close()
|
||||
if s.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
s.logger.Warn("read uevent: ", err)
|
||||
if !sleepCtx(s.ctx, time.Second) {
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
changed, reconcileErr := s.reconcileExports()
|
||||
if reconcileErr != nil {
|
||||
s.logger.Warn("reconcile exports: ", reconcileErr)
|
||||
continue
|
||||
}
|
||||
if changed {
|
||||
s.broadcastChanged()
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *ServerService) registerControlConn(conn net.Conn) (*serverControlConn, uint64) {
|
||||
s.controlMu.Lock()
|
||||
defer s.controlMu.Unlock()
|
||||
s.controlNextID++
|
||||
sub := &serverControlConn{
|
||||
id: s.controlNextID,
|
||||
conn: conn,
|
||||
send: make(chan controlFrame, 16),
|
||||
}
|
||||
s.controlSubs[sub.id] = sub
|
||||
return sub, s.controlSeq
|
||||
}
|
||||
|
||||
func (s *ServerService) unregisterControlConn(id uint64) {
|
||||
s.controlMu.Lock()
|
||||
defer s.controlMu.Unlock()
|
||||
delete(s.controlSubs, id)
|
||||
}
|
||||
|
||||
func (s *ServerService) closeControlSubscribers() {
|
||||
s.controlMu.Lock()
|
||||
subs := make([]*serverControlConn, 0, len(s.controlSubs))
|
||||
for _, sub := range s.controlSubs {
|
||||
subs = append(subs, sub)
|
||||
}
|
||||
s.controlSubs = make(map[uint64]*serverControlConn)
|
||||
s.controlMu.Unlock()
|
||||
for _, sub := range subs {
|
||||
_ = sub.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerService) broadcastChanged() {
|
||||
s.controlMu.Lock()
|
||||
s.controlSeq++
|
||||
sequence := s.controlSeq
|
||||
subs := make([]*serverControlConn, 0, len(s.controlSubs))
|
||||
for _, sub := range s.controlSubs {
|
||||
subs = append(subs, sub)
|
||||
}
|
||||
s.controlMu.Unlock()
|
||||
|
||||
frame := controlFrame{
|
||||
Type: controlFrameChanged,
|
||||
Version: controlProtocolVersion,
|
||||
Sequence: sequence,
|
||||
}
|
||||
for _, sub := range subs {
|
||||
s.enqueueControlFrame(sub, frame)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerService) enqueueControlFrame(sub *serverControlConn, frame controlFrame) {
|
||||
select {
|
||||
case sub.send <- frame:
|
||||
default:
|
||||
s.logger.Debug("control subscriber ", sub.id, " lagged behind")
|
||||
_ = sub.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func currentSysfsDevice(busid string) (sysfsDevice, bool) {
|
||||
device, err := readSysfsDevice(busid, sysBusDevicePath(busid))
|
||||
if err != nil {
|
||||
return sysfsDevice{}, false
|
||||
}
|
||||
return device, true
|
||||
}
|
||||
|
||||
func sysBusDevicePath(busid string) string {
|
||||
|
|
|
|||
58
service/usbip/uevent_linux.go
Normal file
58
service/usbip/uevent_linux.go
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
//go:build linux
|
||||
|
||||
package usbip
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type ueventListener struct {
|
||||
fd int
|
||||
}
|
||||
|
||||
func newUEventListener() (*ueventListener, error) {
|
||||
fd, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_DGRAM, unix.NETLINK_KOBJECT_UEVENT)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addr := &unix.SockaddrNetlink{
|
||||
Family: unix.AF_NETLINK,
|
||||
Pid: uint32(os.Getpid()),
|
||||
Groups: 1,
|
||||
}
|
||||
if err := unix.Bind(fd, addr); err != nil {
|
||||
_ = unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
return &ueventListener{fd: fd}, nil
|
||||
}
|
||||
|
||||
func (l *ueventListener) Close() error {
|
||||
return unix.Close(l.fd)
|
||||
}
|
||||
|
||||
func (l *ueventListener) WaitUSBEvent() error {
|
||||
var buf [4096]byte
|
||||
for {
|
||||
n, _, err := unix.Recvfrom(l.fd, buf[:], 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isUSBUEvent(buf[:n]) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isUSBUEvent(raw []byte) bool {
|
||||
fields := bytes.Split(bytes.TrimRight(raw, "\x00"), []byte{0})
|
||||
for _, field := range fields {
|
||||
if bytes.Equal(field, []byte("SUBSYSTEM=usb")) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue