diff --git a/UsbDk/ControlDevice.cpp b/UsbDk/ControlDevice.cpp index 75b09e3..a244a31 100644 --- a/UsbDk/ControlDevice.cpp +++ b/UsbDk/ControlDevice.cpp @@ -451,7 +451,7 @@ PDEVICE_OBJECT CUsbDkControlDevice::GetPDOByDeviceID(const USB_DK_DEVICE_ID &Dev return PDO; } -NTSTATUS CUsbDkControlDevice::ResetUsbDevice(const USB_DK_DEVICE_ID &DeviceID) +NTSTATUS CUsbDkControlDevice::ResetUsbDevice(const USB_DK_DEVICE_ID &DeviceID, bool ForceD0) { PDEVICE_OBJECT PDO = GetPDOByDeviceID(DeviceID); if (PDO == nullptr) @@ -460,7 +460,7 @@ NTSTATUS CUsbDkControlDevice::ResetUsbDevice(const USB_DK_DEVICE_ID &DeviceID) } CWdmUsbDeviceAccess pdoAccess(PDO); - auto status = pdoAccess.Reset(); + auto status = pdoAccess.Reset(ForceD0); ObDereferenceObject(PDO); return status; @@ -707,7 +707,7 @@ void CUsbDkControlDevice::AddRedirectRollBack(const USB_DK_DEVICE_ID &DeviceId, return; } - auto resetRes = ResetUsbDevice(DeviceId); + auto resetRes = ResetUsbDevice(DeviceId, false); if (!NT_SUCCESS(resetRes)) { TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, "%!FUNC! Roll-back reset failed. %!STATUS!", resetRes); @@ -737,7 +737,7 @@ NTSTATUS CUsbDkControlDevice::AddRedirect(const USB_DK_DEVICE_ID &DeviceId, HAND TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_CONTROLDEVICE, "%!FUNC! Success. New redirections list:"); m_Redirections.Dump(); - auto resetRes = ResetUsbDevice(DeviceId); + auto resetRes = ResetUsbDevice(DeviceId, true); if (!NT_SUCCESS(resetRes)) { TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, "%!FUNC! Reset after start redirection failed. %!STATUS!", resetRes); @@ -1029,7 +1029,7 @@ NTSTATUS CUsbDkControlDevice::RemoveRedirect(const USB_DK_DEVICE_ID &DeviceId) { if (NotifyRedirectorRemovalStarted(DeviceId)) { - auto res = ResetUsbDevice(DeviceId); + auto res = ResetUsbDevice(DeviceId, false); if (NT_SUCCESS(res)) { TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_CONTROLDEVICE, diff --git a/UsbDk/ControlDevice.h b/UsbDk/ControlDevice.h index 7e7cd0f..9223036 100644 --- a/UsbDk/ControlDevice.h +++ b/UsbDk/ControlDevice.h @@ -266,7 +266,7 @@ public: { return ReloadPersistentHideRules(); } bool EnumerateDevices(USB_DK_DEVICE_INFO *outBuff, size_t numberAllocatedDevices, size_t &numberExistingDevices); - NTSTATUS ResetUsbDevice(const USB_DK_DEVICE_ID &DeviceId); + NTSTATUS ResetUsbDevice(const USB_DK_DEVICE_ID &DeviceId, bool ForceD0); NTSTATUS AddRedirect(const USB_DK_DEVICE_ID &DeviceId, HANDLE RequestorProcess, PHANDLE ObjectHandle); NTSTATUS AddHideRule(const USB_DK_HIDE_RULE &UsbDkRule); diff --git a/UsbDk/DeviceAccess.cpp b/UsbDk/DeviceAccess.cpp index 58d28cd..6810a18 100644 --- a/UsbDk/DeviceAccess.cpp +++ b/UsbDk/DeviceAccess.cpp @@ -186,6 +186,23 @@ bool CWdmDeviceAccess::QueryPowerData(CM_POWER_DATA& powerData) #endif } +static void PowerRequestCompletion( + _In_ PDEVICE_OBJECT DeviceObject, + _In_ UCHAR MinorFunction, + _In_ POWER_STATE PowerState, + _In_opt_ PVOID Context, + _In_ PIO_STATUS_BLOCK IoStatus +) +{ + UNREFERENCED_PARAMETER(DeviceObject); + UNREFERENCED_PARAMETER(MinorFunction); + UNREFERENCED_PARAMETER(PowerState); + UNREFERENCED_PARAMETER(IoStatus); + CWdmEvent *pev = (CWdmEvent *)Context; + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_DEVACCESS, "%!FUNC! -> D%d", PowerState.DeviceState - 1); + pev->Set(); +} + PWCHAR CWdmDeviceAccess::MakeNonPagedDuplicate(BUS_QUERY_ID_TYPE idType, PWCHAR idData) { auto bufferLength = GetIdBufferLength(idType, idData); @@ -233,9 +250,23 @@ NTSTATUS CWdmDeviceAccess::QueryForInterface(const GUID &guid, __out INTERFACE & return status; } -NTSTATUS CWdmUsbDeviceAccess::Reset() +NTSTATUS CWdmUsbDeviceAccess::Reset(bool ForceD0) { CIoControlIrp Irp; + CM_POWER_DATA powerData; + if (ForceD0 && QueryPowerData(powerData) && powerData.PD_MostRecentPowerState != PowerDeviceD0) + { + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_DEVACCESS, "%!FUNC! device power state D%d", powerData.PD_MostRecentPowerState - 1); + POWER_STATE PowerState; + CWdmEvent Event; + PowerState.DeviceState = PowerDeviceD0; + auto status = PoRequestPowerIrp(m_DevObj, IRP_MN_SET_POWER, PowerState, PowerRequestCompletion, &Event, NULL); + if (NT_SUCCESS(status)) + { + Event.Wait(); + } + } + auto status = Irp.Create(m_DevObj, IOCTL_INTERNAL_USB_CYCLE_PORT); if (!NT_SUCCESS(status)) diff --git a/UsbDk/DeviceAccess.h b/UsbDk/DeviceAccess.h index d107cee..0749102 100644 --- a/UsbDk/DeviceAccess.h +++ b/UsbDk/DeviceAccess.h @@ -71,7 +71,7 @@ public: : CWdmDeviceAccess(WdmDevice) { } - NTSTATUS Reset(); + NTSTATUS Reset(bool ForceD0); NTSTATUS GetDeviceDescriptor(USB_DEVICE_DESCRIPTOR &Descriptor); NTSTATUS GetConfigurationDescriptor(UCHAR Index, USB_CONFIGURATION_DESCRIPTOR &Descriptor, size_t Length);