usbip: harden linux interop teardown

This commit is contained in:
世界 2026-04-24 10:36:18 +08:00
parent 794e1bcf15
commit f09bad8f79
No known key found for this signature in database
GPG key ID: CD109927C34A63C4
8 changed files with 432 additions and 81 deletions

View file

@ -173,7 +173,7 @@ func (c *ClientService) applyControlSnapshot(snapshot controlDeviceSnapshot) {
c.remoteMu.Lock()
c.remoteDevicesV2 = devices
c.remoteMu.Unlock()
c.applyRemoteEntries(deviceInfoV2ToEntries(values, true))
c.applyRemoteDeviceState(values)
}
func (c *ClientService) applyControlDelta(delta controlDeviceDelta) {
@ -198,7 +198,7 @@ func (c *ClientService) applyControlDelta(delta controlDeviceDelta) {
}
values := sortedDeviceInfoV2Values(c.remoteDevicesV2)
c.remoteMu.Unlock()
c.applyRemoteEntries(deviceInfoV2ToEntries(values, true))
c.applyRemoteDeviceState(values)
}
func (c *ClientService) clearControlDeviceState() {

View file

@ -340,6 +340,22 @@ func (c *ClientService) applyRemoteEntries(entries []DeviceEntry) {
c.applyMatchedExports(entries)
}
func (c *ClientService) applyRemoteDeviceState(devices []DeviceInfoV2) {
availableEntries := deviceInfoV2ToEntries(devices, true)
if len(c.matches) == 0 {
c.applyRemoteExports(availableEntries)
return
}
knownKeys := make(map[string]DeviceKey, len(devices))
for _, device := range devices {
if device.BusID == "" {
continue
}
knownKeys[device.BusID] = device.key()
}
c.applyMatchedExportsWithRetained(availableEntries, knownKeys)
}
func (c *ClientService) applyRemoteExports(entries []DeviceEntry) {
desired := make(map[string]struct{}, len(entries))
for i := range entries {
@ -377,12 +393,17 @@ func (c *ClientService) applyRemoteExports(entries []DeviceEntry) {
}
func (c *ClientService) applyMatchedExports(entries []DeviceEntry) {
c.applyMatchedExportsWithRetained(entries, nil)
}
func (c *ClientService) applyMatchedExportsWithRetained(entries []DeviceEntry, knownKeys map[string]DeviceKey) {
c.stateMu.Lock()
if len(c.targets) == 0 {
c.stateMu.Unlock()
return
}
nextAssigned := assignMatchedBusIDs(c.targets, c.assigned, entries)
activeCurrent := c.activeCurrentAssignmentsLocked(c.assigned, knownKeys)
nextAssigned := assignMatchedBusIDsWithRetained(c.targets, c.assigned, entries, knownKeys, activeCurrent)
workers := append([]*clientAssignedWorker(nil), c.assignedWorkers...)
previous := append([]string(nil), c.assigned...)
c.assigned = nextAssigned
@ -395,6 +416,29 @@ func (c *ClientService) applyMatchedExports(entries []DeviceEntry) {
}
}
func (c *ClientService) activeCurrentAssignmentsLocked(current []string, knownKeys map[string]DeviceKey) map[string]struct{} {
if len(knownKeys) == 0 {
return nil
}
var activeCurrent map[string]struct{}
for _, busid := range current {
if busid == "" {
continue
}
if _, ok := knownKeys[busid]; !ok {
continue
}
if !c.isBusIDActive(busid) {
continue
}
if activeCurrent == nil {
activeCurrent = make(map[string]struct{})
}
activeCurrent[busid] = struct{}{}
}
return activeCurrent
}
func (c *ClientService) runAssignedWorker(worker *clientAssignedWorker) {
defer c.wg.Done()
var current string

View file

@ -23,6 +23,9 @@ import (
const (
clientReconnectDelay = 5 * time.Second
clientShutdownTimeout = 15 * time.Second
clientDetachTimeout = 10 * time.Second
clientDetachPoll = 100 * time.Millisecond
controlPingInterval = 10 * time.Second
controlReadTimeout = 30 * time.Second
controlWriteTimeout = 5 * time.Second
@ -145,7 +148,7 @@ func (c *ClientService) Close() error {
}()
select {
case <-done:
case <-time.After(5 * time.Second):
case <-time.After(clientShutdownTimeout):
c.logger.Warn("shutdown timeout; some vhci ports may remain attached")
}
return nil
@ -363,6 +366,22 @@ func (c *ClientService) applyRemoteEntries(entries []DeviceEntry) {
c.applyMatchedExports(entries)
}
func (c *ClientService) applyRemoteDeviceState(devices []DeviceInfoV2) {
availableEntries := deviceInfoV2ToEntries(devices, true)
if len(c.matches) == 0 {
c.applyRemoteExports(availableEntries)
return
}
knownKeys := make(map[string]DeviceKey, len(devices))
for _, device := range devices {
if device.BusID == "" {
continue
}
knownKeys[device.BusID] = device.key()
}
c.applyMatchedExportsWithRetained(availableEntries, knownKeys)
}
func (c *ClientService) applyRemoteExports(entries []DeviceEntry) {
desired := make(map[string]struct{}, len(entries))
for i := range entries {
@ -405,12 +424,17 @@ func (c *ClientService) applyRemoteExports(entries []DeviceEntry) {
}
func (c *ClientService) applyMatchedExports(entries []DeviceEntry) {
c.applyMatchedExportsWithRetained(entries, nil)
}
func (c *ClientService) applyMatchedExportsWithRetained(entries []DeviceEntry, knownKeys map[string]DeviceKey) {
c.stateMu.Lock()
if len(c.targets) == 0 {
c.stateMu.Unlock()
return
}
nextAssigned := assignMatchedBusIDs(c.targets, c.assigned, entries)
activeCurrent := c.activeCurrentAssignmentsLocked(c.assigned, knownKeys)
nextAssigned := assignMatchedBusIDsWithRetained(c.targets, c.assigned, entries, knownKeys, activeCurrent)
workers := append([]*clientAssignedWorker(nil), c.assignedWorkers...)
previous := append([]string(nil), c.assigned...)
c.assigned = nextAssigned
@ -424,6 +448,29 @@ func (c *ClientService) applyMatchedExports(entries []DeviceEntry) {
}
}
func (c *ClientService) activeCurrentAssignmentsLocked(current []string, knownKeys map[string]DeviceKey) map[string]struct{} {
if len(knownKeys) == 0 {
return nil
}
var activeCurrent map[string]struct{}
for _, busid := range current {
if busid == "" {
continue
}
if _, ok := knownKeys[busid]; !ok {
continue
}
if !c.isBusIDActive(busid) {
continue
}
if activeCurrent == nil {
activeCurrent = make(map[string]struct{})
}
activeCurrent[busid] = struct{}{}
}
return activeCurrent
}
func (c *ClientService) runAssignedWorker(worker *clientAssignedWorker) {
defer c.wg.Done()
@ -678,6 +725,7 @@ func (c *ClientService) watchPort(ctx context.Context, port int, busid string) {
if err := c.ops.vhciDetach(port); err != nil {
c.logger.Warn("detach port ", port, " (", busid, "): ", err)
}
c.waitVHCIPortIdle(port, busid)
return
case <-settleDeadline.C:
if !seenUsed {
@ -685,6 +733,7 @@ func (c *ClientService) watchPort(ctx context.Context, port int, busid string) {
if err := c.ops.vhciDetach(port); err != nil {
c.logger.Warn("detach port ", port, " (", busid, "): ", err)
}
c.waitVHCIPortIdle(port, busid)
return
}
case <-ticker.C:
@ -708,6 +757,25 @@ func (c *ClientService) watchPort(ctx context.Context, port int, busid string) {
}
}
func (c *ClientService) waitVHCIPortIdle(port int, busid string) {
deadline := time.Now().Add(clientDetachTimeout)
for {
used, err := c.ops.vhciPortUsed(port)
if err == nil && !used {
return
}
if time.Now().After(deadline) {
if err != nil {
c.logger.Warn("poll detached vhci port ", port, " (", busid, "): ", err)
} else {
c.logger.Warn("vhci port ", port, " stayed used after detach for ", busid)
}
return
}
time.Sleep(clientDetachPoll)
}
}
func (c *ClientService) trackPort(port int, add bool) {
c.portsMu.Lock()
defer c.portsMu.Unlock()
@ -786,6 +854,16 @@ func isBusIDOnlyMatch(m option.USBIPDeviceMatch) bool {
}
func assignMatchedBusIDs(targets []clientTarget, current []string, entries []DeviceEntry) []string {
return assignMatchedBusIDsWithRetained(targets, current, entries, nil, nil)
}
func assignMatchedBusIDsWithRetained(
targets []clientTarget,
current []string,
entries []DeviceEntry,
knownKeys map[string]DeviceKey,
activeCurrent map[string]struct{},
) []string {
if len(targets) == 0 {
return nil
}
@ -797,6 +875,16 @@ func assignMatchedBusIDs(targets []clientTarget, current []string, entries []Dev
}
keysByBusID[busid] = entryDeviceKey(entries[i])
}
currentKey := func(busid string) (DeviceKey, bool) {
if key, ok := keysByBusID[busid]; ok {
return key, true
}
if _, active := activeCurrent[busid]; !active {
return DeviceKey{}, false
}
key, ok := knownKeys[busid]
return key, ok
}
nextAssigned := make([]string, len(targets))
reserved := make(map[string]struct{}, len(targets))
@ -804,11 +892,18 @@ func assignMatchedBusIDs(targets []clientTarget, current []string, entries []Dev
if target.fixedBusID == "" {
continue
}
if _, ok := keysByBusID[target.fixedBusID]; !ok {
if _, ok := keysByBusID[target.fixedBusID]; ok {
nextAssigned[i] = target.fixedBusID
reserved[target.fixedBusID] = struct{}{}
continue
}
nextAssigned[i] = target.fixedBusID
reserved[target.fixedBusID] = struct{}{}
if i >= len(current) || current[i] != target.fixedBusID {
continue
}
if _, ok := currentKey(target.fixedBusID); ok {
nextAssigned[i] = target.fixedBusID
reserved[target.fixedBusID] = struct{}{}
}
}
for i, target := range targets {
if target.fixedBusID != "" || i >= len(current) {
@ -820,7 +915,7 @@ func assignMatchedBusIDs(targets []clientTarget, current []string, entries []Dev
if _, ok := reserved[current[i]]; ok {
continue
}
key, ok := keysByBusID[current[i]]
key, ok := currentKey(current[i])
if !ok || !Matches(target.match, key) {
continue
}

View file

@ -86,7 +86,11 @@ func (h *usbipConnHandoff) Close() error {
func (h *usbipConnHandoff) startRelay(ctx context.Context, logger log.ContextLogger, side string, busid string) bool {
if !h.relay() {
return false
go func() {
<-ctx.Done()
_ = h.conn.Close()
}()
return true
}
relayConn := h.relayConn
h.relayConn = nil

View file

@ -33,6 +33,9 @@ const (
testACMProductID uint16 = 0x0104
testHIDProductID uint16 = 0x0105
testUDCCount = 2
testUSBIPTeardownTimeout = 20 * time.Second
testUSBIPTeardownPollInterval = 100 * time.Millisecond
)
var testHIDReportDescriptor = []byte{
@ -212,30 +215,114 @@ func releaseTestUDC(name string) {
testUDCMu.Unlock()
}
func waitForUSBIPTeardown(condition func() bool) bool {
deadline := time.Now().Add(testUSBIPTeardownTimeout)
for {
if condition() {
return true
}
if time.Now().After(deadline) {
return false
}
time.Sleep(testUSBIPTeardownPollInterval)
}
}
func detachUsedVHCIPorts() {
records, err := readVHCIStatus()
if err != nil {
return
}
for _, record := range records {
if record.state == 6 {
_ = vhciDetach(record.port)
}
}
}
func allVHCIPortsIdle() bool {
records, err := readVHCIStatus()
if err != nil {
return true
}
for _, record := range records {
if record.state == 6 {
return false
}
}
return true
}
func waitForAllVHCIPortsIdle(t *testing.T) {
t.Helper()
require.Eventually(t, allVHCIPortsIdle, testUSBIPTeardownTimeout, testUSBIPTeardownPollInterval)
}
func waitForVHCIPortIdle(t *testing.T, port int) {
t.Helper()
require.Eventually(t, func() bool {
used, err := vhciPortUsed(port)
return err == nil && !used
}, testUSBIPTeardownTimeout, testUSBIPTeardownPollInterval)
}
func waitForUSBIPHostAvailable(busid string) bool {
return waitForUSBIPTeardown(func() bool {
status, err := readUsbipStatus(busid)
if err != nil {
return os.IsNotExist(err) || isMissingUSBDeviceError(err)
}
return status == usbipStatusAvailable
})
}
func waitForDriverAway(busid string, driver string) bool {
return waitForUSBIPTeardown(func() bool {
current, err := currentDriver(busid)
if err != nil {
return os.IsNotExist(err) || isMissingUSBDeviceError(err)
}
return current != driver
})
}
func waitForSysfsPathGone(path string) bool {
return waitForUSBIPTeardown(func() bool {
_, err := os.Stat(path)
return os.IsNotExist(err)
})
}
func waitForGadgetNodesGone(nodes map[string]string) bool {
return waitForUSBIPTeardown(func() bool {
for _, path := range nodes {
if _, err := os.Stat(path); err == nil {
return false
}
}
return true
})
}
func shutdownUSBIPHostDevice(busid string) {
status, err := readUsbipStatus(busid)
if err == nil && status == usbipStatusUsed {
_ = writeUsbipSockfd(busid, -1)
_ = waitForUSBIPHostAvailable(busid)
}
if driver, err := currentDriver(busid); err == nil && driver == "usbip-host" {
_ = hostUnbind(busid)
_ = hostMatchBusID(busid, false)
_ = waitForDriverAway(busid, "usbip-host")
}
}
func resetUSBIPInteropState(t *testing.T) {
t.Helper()
requireRoot(t)
records, err := readVHCIStatus()
if err == nil {
for _, record := range records {
if record.state == 6 {
_ = vhciDetach(record.port)
}
}
require.Eventually(t, func() bool {
records, err = readVHCIStatus()
if err != nil {
return false
}
for _, record := range records {
if record.state == 6 {
return false
}
}
return true
}, 10*time.Second, 100*time.Millisecond)
}
detachUsedVHCIPorts()
waitForAllVHCIPortsIdle(t)
devices, err := listUSBDevices()
if err != nil {
@ -245,12 +332,7 @@ func resetUSBIPInteropState(t *testing.T) {
if !strings.HasPrefix(device.Serial, "codex-usbip-") {
continue
}
driver, err := currentDriver(device.BusID)
if err != nil || driver != "usbip-host" {
continue
}
_ = hostUnbind(device.BusID)
_ = hostMatchBusID(device.BusID, false)
shutdownUSBIPHostDevice(device.BusID)
_ = bindToDriver(device.BusID, "usb")
}
@ -277,10 +359,10 @@ func resetUSBIPInteropState(t *testing.T) {
require.Eventually(t, func() bool {
paths, _ := filepath.Glob("/sys/kernel/config/usb_gadget/codex_usbip_*")
return len(paths) == 0
}, 10*time.Second, 100*time.Millisecond)
}, testUSBIPTeardownTimeout, testUSBIPTeardownPollInterval)
require.Eventually(t, func() bool {
return len(importedNodeSnapshot("/dev/ttyACM*")) == 0 && len(importedNodeSnapshot("/dev/hidraw*")) == 0
}, 10*time.Second, 100*time.Millisecond)
}, testUSBIPTeardownTimeout, testUSBIPTeardownPollInterval)
testUDCMu.Lock()
testAllocatedUDC = make(map[string]struct{})
@ -623,13 +705,27 @@ func readExactlyWithin(reader io.Reader, size int, timeout time.Duration) ([]byt
func openRawTTY(t *testing.T, path string) *rawFile {
t.Helper()
file, err := os.OpenFile(path, os.O_RDWR, 0)
require.NoError(t, err)
state, err := term.MakeRaw(int(file.Fd()))
require.NoError(t, err)
return &rawFile{
file: file,
state: state,
var lastErr error
deadline := time.Now().Add(testUSBIPTeardownTimeout)
for {
file, err := os.OpenFile(path, os.O_RDWR, 0)
if err == nil {
state, err := term.MakeRaw(int(file.Fd()))
if err == nil {
return &rawFile{
file: file,
state: state,
}
}
lastErr = err
_ = file.Close()
} else {
lastErr = err
}
if time.Now().After(deadline) {
require.NoErrorf(t, lastErr, "open raw tty %s", path)
}
time.Sleep(testUSBIPTeardownPollInterval)
}
}
@ -739,13 +835,14 @@ func (g *testVirtualGadget) Close() {
defer releaseTestUDC(g.udcName)
if g.busid != "" {
if driver, err := currentDriver(g.busid); err == nil && driver == "usbip-host" {
_ = hostUnbind(g.busid)
_ = hostMatchBusID(g.busid, false)
}
shutdownUSBIPHostDevice(g.busid)
}
_ = writeSysfsLine(filepath.Join(g.path, "UDC"), "")
if g.busid != "" {
_ = waitForSysfsPathGone(sysBusDevicePath(g.busid))
}
_ = waitForGadgetNodesGone(g.nodes)
for _, function := range g.functions {
_ = os.Remove(filepath.Join(g.path, "configs/c.1", function.instance))
@ -758,30 +855,7 @@ func (g *testVirtualGadget) Close() {
_ = os.RemoveAll(filepath.Join(g.path, "strings/0x409"))
_ = os.RemoveAll(g.path)
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
if _, err := os.Stat(g.path); err == nil {
time.Sleep(100 * time.Millisecond)
continue
}
if g.busid != "" {
if _, err := os.Stat(sysBusDevicePath(g.busid)); err == nil {
time.Sleep(100 * time.Millisecond)
continue
}
}
remainingNode := false
for _, path := range g.nodes {
if _, err := os.Stat(path); err == nil {
remainingNode = true
break
}
}
if !remainingNode {
return
}
time.Sleep(100 * time.Millisecond)
}
_ = waitForSysfsPathGone(g.path)
})
}
@ -926,7 +1000,9 @@ func TestUSBIPInteropOurServerWithOfficialClientACM(t *testing.T) {
require.Contains(t, portOutput, fmt.Sprintf("Port %02d", port))
runUSBIP(t, tools, "detach", "--port="+strconv.Itoa(port))
waitForVHCIPortIdle(t, port)
waitForPathGone(t, importedTTY)
require.NoError(t, server.Close())
_ = server
}
@ -938,7 +1014,7 @@ func TestUSBIPInteropOurServerWithOfficialClientHID(t *testing.T) {
requireVHCI(t)
gadget := newTestHIDGadget(t)
_, address := startRealUSBIPServer(t, []option.USBIPDeviceMatch{{Serial: gadget.serial}})
server, address := startRealUSBIPServer(t, []option.USBIPDeviceMatch{{Serial: gadget.serial}})
beforePorts := usedVHCIPorts(t)
beforeHID := importedNodeSnapshot("/dev/hidraw*")
@ -954,7 +1030,9 @@ func TestUSBIPInteropOurServerWithOfficialClientHID(t *testing.T) {
require.Contains(t, portOutput, fmt.Sprintf("Port %02d", port))
runUSBIP(t, tools, "detach", "--port="+strconv.Itoa(port))
waitForVHCIPortIdle(t, port)
waitForPathGone(t, importedHID)
require.NoError(t, server.Close())
}
func TestUSBIPInteropOurClientWithOfficialServerACM(t *testing.T) {
@ -983,6 +1061,7 @@ func TestUSBIPInteropOurClientWithOfficialServerACM(t *testing.T) {
gadget.exerciseImportedIO(t, importedTTY)
require.NoError(t, client.Close())
waitForAllVHCIPortsIdle(t)
waitForPathGone(t, importedTTY)
}
@ -1012,6 +1091,7 @@ func TestUSBIPInteropOurClientWithOfficialServerHID(t *testing.T) {
gadget.exerciseImportedIO(t, importedHID)
require.NoError(t, client.Close())
waitForAllVHCIPortsIdle(t)
waitForPathGone(t, importedHID)
}
@ -1047,6 +1127,7 @@ func TestUSBIPOfficialServerHasStaticDiscoveryOnly(t *testing.T) {
ensureNoNewImportedNode(t, "/dev/hidraw*", beforeHID, 3*time.Second)
require.NoError(t, client.Close())
waitForAllVHCIPortsIdle(t)
waitForPathGone(t, importedTTY)
}
@ -1081,6 +1162,11 @@ func TestUSBIPControlHotplugACMReattach(t *testing.T) {
second := newTestACMGadget(t)
secondImportedTTY := waitForNewImportedNode(t, "/dev/ttyACM*", secondBefore)
second.exerciseImportedIO(t, secondImportedTTY)
require.NoError(t, client.Close())
waitForAllVHCIPortsIdle(t)
waitForPathGone(t, secondImportedTTY)
require.NoError(t, server.Close())
}
func TestUSBIPControlImportAllACMAndHID(t *testing.T) {
@ -1089,7 +1175,7 @@ func TestUSBIPControlImportAllACMAndHID(t *testing.T) {
requireVHCI(t)
ensureTestUDCs(t, testUDCCount)
_, address := startRealUSBIPServer(t, []option.USBIPDeviceMatch{
server, address := startRealUSBIPServer(t, []option.USBIPDeviceMatch{
{VendorID: option.USBIPHexUint16(testVendorID), ProductID: option.USBIPHexUint16(testACMProductID)},
{VendorID: option.USBIPHexUint16(testVendorID), ProductID: option.USBIPHexUint16(testHIDProductID)},
})
@ -1109,4 +1195,10 @@ func TestUSBIPControlImportAllACMAndHID(t *testing.T) {
acm.exerciseImportedIO(t, importedTTY)
hid.exerciseImportedIO(t, importedHID)
require.NoError(t, client.Close())
waitForAllVHCIPortsIdle(t)
waitForPathGone(t, importedTTY)
waitForPathGone(t, importedHID)
require.NoError(t, server.Close())
}

View file

@ -535,6 +535,51 @@ func TestClientApplyRemoteExportsKeepsActiveBusIDWorker(t *testing.T) {
require.NotContains(t, client.allWorkers, "1-1")
}
func TestClientApplyControlDeviceStateKeepsActiveMatchedBusyBusID(t *testing.T) {
t.Parallel()
match := option.USBIPDeviceMatch{VendorID: 0x1d6b, ProductID: 0x0002}
target := clientTarget{match: match}
device := newTestDevice("1-1", 0x1d6b, 0x0002, "serial-1", SpeedHigh)
busyDevice := deviceInfoV2FromEntry(device.toDeviceEntry(), "linux-sysfs", "linux-busid:1-1", deviceStateBusy, usbipStatusUsed, "used")
worker := &clientAssignedWorker{target: target, updates: make(chan string, 1)}
client := &ClientService{
matches: []option.USBIPDeviceMatch{match},
targets: []clientTarget{target},
assigned: []string{"1-1"},
assignedWorkers: []*clientAssignedWorker{worker},
activeBusIDs: map[string]struct{}{"1-1": {}},
}
client.applyRemoteDeviceState([]DeviceInfoV2{busyDevice})
require.Equal(t, []string{"1-1"}, client.assigned)
select {
case update := <-worker.updates:
t.Fatalf("unexpected assignment update %q", update)
default:
}
idleWorker := &clientAssignedWorker{target: target, updates: make(chan string, 1)}
idleClient := &ClientService{
matches: []option.USBIPDeviceMatch{match},
targets: []clientTarget{target},
assigned: []string{""},
assignedWorkers: []*clientAssignedWorker{idleWorker},
activeBusIDs: make(map[string]struct{}),
}
idleClient.applyRemoteDeviceState([]DeviceInfoV2{busyDevice})
require.Equal(t, []string{""}, idleClient.assigned)
select {
case update := <-idleWorker.updates:
t.Fatalf("unexpected assignment update %q", update)
default:
}
}
func TestClientShouldRetryBusIDRefreshesImportAllState(t *testing.T) {
t.Parallel()
@ -846,11 +891,14 @@ func TestServerReconcileExportsReleasesRemovedExports(t *testing.T) {
device := newTestDevice("1-1", 0x1d6b, 0x0002, "regular", SpeedHigh)
store := newTestDeviceStore(device)
store.setStatus("1-1", usbipStatusUsed)
ops := newTestUSBIPOps(t)
var actions []string
ops.listUSBDevices = store.listUSBDevices
ops.readUsbipStatus = store.readUsbipStatus
ops.writeUsbipSockfd = func(busid string, fd int) error {
actions = append(actions, "sockfd "+busid)
store.setStatus(busid, usbipStatusAvailable)
return nil
}
ops.hostUnbind = func(busid string) error {
@ -919,6 +967,9 @@ func TestServerReleaseExportRetainsTrackingOnFailure(t *testing.T) {
}
ops := newTestUSBIPOps(t)
ops.readUsbipStatus = func(string) (int, error) {
return usbipStatusAvailable, nil
}
ops.writeUsbipSockfd = func(string, int) error {
return nil
}
@ -2183,6 +2234,7 @@ func TestClientRunControlSessionSyncsAssignmentsOnChanged(t *testing.T) {
controlSubs: make(map[uint64]*serverControlConn),
ops: serverOps,
}
server.refreshControlState()
serverAddr, closeServer := startDispatchServer(t, server)
defer closeServer()

View file

@ -40,6 +40,11 @@ type serverControlConn struct {
send chan controlOutboundMessage
}
const (
usbipExportReleaseTimeout = 10 * time.Second
usbipExportReleasePollInterval = 100 * time.Millisecond
)
type ServerService struct {
boxService.Adapter
ctx context.Context
@ -227,8 +232,19 @@ func (s *ServerService) releaseExport(export serverExport, restore bool) error {
s.logger.Info("stopped tracking ", export.busid, " on usbip-host")
return nil
}
if err := s.ops.writeUsbipSockfd(export.busid, -1); err != nil && !os.IsNotExist(err) {
return err
status, statusErr := s.ops.readUsbipStatus(export.busid)
if statusErr != nil && !os.IsNotExist(statusErr) && !isMissingUSBDeviceError(statusErr) {
return statusErr
}
if statusErr == nil && status == usbipStatusUsed {
if err := s.ops.writeUsbipSockfd(export.busid, -1); err != nil && !os.IsNotExist(err) {
return err
}
if restore {
if err := s.waitUSBIPStatusAvailable(export.busid, usbipExportReleaseTimeout); err != nil {
return err
}
}
}
if err := s.ops.hostUnbind(export.busid); err != nil && !os.IsNotExist(err) && !(isMissingUSBDeviceError(err) && !restore) {
return err
@ -254,6 +270,27 @@ func (s *ServerService) releaseExport(export serverExport, restore bool) error {
return nil
}
func (s *ServerService) waitUSBIPStatusAvailable(busid string, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
for {
status, err := s.ops.readUsbipStatus(busid)
if err != nil {
if os.IsNotExist(err) || isMissingUSBDeviceError(err) {
return nil
}
} else if status == usbipStatusAvailable {
return nil
}
if time.Now().After(deadline) {
if err != nil {
return E.Cause(err, "wait for ", busid, " usbip status available")
}
return E.New("timed out waiting for ", busid, " usbip status available")
}
time.Sleep(usbipExportReleasePollInterval)
}
}
func (s *ServerService) rollbackExports() {
exports := s.snapshotExports()
for _, export := range exports {

View file

@ -27,6 +27,16 @@ func isBusIDOnlyMatch(m option.USBIPDeviceMatch) bool {
}
func assignMatchedBusIDs(targets []clientTarget, current []string, entries []DeviceEntry) []string {
return assignMatchedBusIDsWithRetained(targets, current, entries, nil, nil)
}
func assignMatchedBusIDsWithRetained(
targets []clientTarget,
current []string,
entries []DeviceEntry,
knownKeys map[string]DeviceKey,
activeCurrent map[string]struct{},
) []string {
if len(targets) == 0 {
return nil
}
@ -38,17 +48,34 @@ func assignMatchedBusIDs(targets []clientTarget, current []string, entries []Dev
}
keysByBusID[busid] = entryDeviceKey(entries[i])
}
currentKey := func(busid string) (DeviceKey, bool) {
if key, ok := keysByBusID[busid]; ok {
return key, true
}
if _, active := activeCurrent[busid]; !active {
return DeviceKey{}, false
}
key, ok := knownKeys[busid]
return key, ok
}
nextAssigned := make([]string, len(targets))
reserved := make(map[string]struct{}, len(targets))
for i, target := range targets {
if target.fixedBusID == "" {
continue
}
if _, ok := keysByBusID[target.fixedBusID]; !ok {
if _, ok := keysByBusID[target.fixedBusID]; ok {
nextAssigned[i] = target.fixedBusID
reserved[target.fixedBusID] = struct{}{}
continue
}
nextAssigned[i] = target.fixedBusID
reserved[target.fixedBusID] = struct{}{}
if i >= len(current) || current[i] != target.fixedBusID {
continue
}
if _, ok := currentKey(target.fixedBusID); ok {
nextAssigned[i] = target.fixedBusID
reserved[target.fixedBusID] = struct{}{}
}
}
for i, target := range targets {
if target.fixedBusID != "" || i >= len(current) || current[i] == "" {
@ -57,7 +84,7 @@ func assignMatchedBusIDs(targets []clientTarget, current []string, entries []Dev
if _, ok := reserved[current[i]]; ok {
continue
}
key, ok := keysByBusID[current[i]]
key, ok := currentKey(current[i])
if !ok || !Matches(target.match, key) {
continue
}