diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index b27a323328..867cec4081 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -101,6 +101,27 @@ def get_passphrase( raise exceptions.Cancelled from None +def get_client(transport: Transport) -> TrezorClient: + stored_channels = channel_database.load_stored_channels() + stored_transport_paths = [ch.transport_path for ch in stored_channels] + path = transport.get_path() + if path in stored_transport_paths: + stored_channel_with_correct_transport_path = next( + ch for ch in stored_channels if ch.transport_path == path + ) + try: + client = TrezorClient.resume( + transport, stored_channel_with_correct_transport_path + ) + except Exception: + LOG.debug("Failed to resume a channel. Replacing by a new one.") + channel_database.remove_channel(path) + client = TrezorClient(transport) + else: + client = TrezorClient(transport) + return client + + class TrezorConnection: def __init__( @@ -151,27 +172,7 @@ class TrezorConnection: return transport.get_transport(self.path, prefix_search=True) def get_client(self) -> TrezorClient: - transport = self.get_transport() - - stored_channels = channel_database.load_stored_channels() - stored_transport_paths = [ch.transport_path for ch in stored_channels] - path = transport.get_path() - if path in stored_transport_paths: - stored_channel_with_correct_transport_path = next( - ch for ch in stored_channels if ch.transport_path == path - ) - try: - client = TrezorClient.resume( - transport, stored_channel_with_correct_transport_path - ) - except Exception: - LOG.debug("Failed to resume a channel. Replacing by a new one.") - channel_database.remove_channel(path) - client = TrezorClient(transport) - else: - client = TrezorClient(transport) - - return client + return get_client(self.get_transport()) def get_management_session(self) -> Session: client = self.get_client() diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index a45406f5cd..bdf5206e9f 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -290,22 +290,11 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: """List connected Trezor devices.""" if no_resolve: return enumerate_devices() + from . import get_client - stored_channels = channel_database.load_stored_channels() - stored_transport_paths = [ch.transport_path for ch in stored_channels] for transport in enumerate_devices(): try: - path = transport.get_path() - if path in stored_transport_paths: - stored_channel_with_correct_transport_path = next( - ch for ch in stored_channels if ch.transport_path == path - ) - client = TrezorClient.resume( - transport, stored_channel_with_correct_transport_path - ) - else: - client = TrezorClient(transport) - + client = get_client(transport) description = format_device_name(client.features) # json_string = channel_database.channel_to_str(client.protocol) # print(json_string)