diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 7af36e54d3..32b50d304e 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -196,6 +196,15 @@ class TrezorClient: protocol = ProtocolV2Channel(self.transport, self.mapping) return protocol + def reset_protocol(self): + if self._protocol_version == ProtocolVersion.PROTOCOL_V1: + self.protocol = ProtocolV1Channel(self.transport, self.mapping) + elif self._protocol_version == ProtocolVersion.PROTOCOL_V2: + self.protocol = ProtocolV2Channel(self.transport, self.mapping) + else: + assert False + self._features = None + def is_outdated(self) -> bool: if self.features.bootloader_mode: return False diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 44c22fc48d..e56afed522 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -1440,6 +1440,10 @@ class TrezorClientDebugLink(TrezorClient): self.actual_responses.append(resp) return resp + def reset_protocol(self): + super().reset_protocol() + self._seedless_session = self.get_seedless_session(new_session=True) + def load_device( session: "Session", diff --git a/tests/conftest.py b/tests/conftest.py index 66089865fc..7c24f78471 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -353,11 +353,7 @@ def _client_unlocked( # Get a new client _raw_client = _get_raw_client(request) - _raw_client.protocol = None - _raw_client.__init__( - transport=_raw_client.transport, - auto_interact=_raw_client.debug.allow_interactions, - ) + _raw_client.reset_protocol() if not _raw_client.features.bootloader_mode: _raw_client.refresh_features()