diff --git a/trezorlib/transport/hid.py b/trezorlib/transport/hid.py index 656e415906..3bd4a9f7e7 100644 --- a/trezorlib/transport/hid.py +++ b/trezorlib/transport/hid.py @@ -36,8 +36,11 @@ HidDeviceHandle = Any class HidHandle: - def __init__(self, path: str, probe_hid_version: bool = False) -> None: + def __init__( + self, path: bytes, serial: str, probe_hid_version: bool = False + ) -> None: self.path = path + self.serial = serial self.handle = None # type: HidDeviceHandle self.hid_version = None if probe_hid_version else 2 @@ -47,8 +50,19 @@ class HidHandle: self.handle.open_path(self.path) except (IOError, OSError) as e: if sys.platform.startswith("linux"): - e.args = e.args + (UDEV_RULES_STR) + e.args = e.args + (UDEV_RULES_STR,) raise e + + # On some platforms, HID path stays the same over device reconnects. + # That means that someone could unplug a Trezor, plug a different one + # and we wouldn't even know. + # So we check that the serial matches what we expect. + serial = self.handle.get_serial_number_string() + if serial != self.serial: + self.handle.close() + self.handle = None + raise TransportException("Unexpected device on path %s" % self.path) + self.handle.set_nonblocking(True) if self.hid_version is None: @@ -97,14 +111,11 @@ class HidTransport(ProtocolBasedTransport): PATH_PREFIX = "hid" ENABLED = hid is not None - def __init__(self, device: HidDevice, hid_handle: HidHandle = None) -> None: - if hid_handle is None: - hid_handle = HidHandle(device["path"]) - + def __init__(self, device: HidDevice) -> None: self.device = device - self.hid = hid_handle + self.handle = HidHandle(device["path"], device["serial_number"]) - protocol = ProtocolV1(hid_handle) + protocol = ProtocolV1(self.handle) super().__init__(protocol=protocol) def get_path(self) -> str: