1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-25 23:01:02 +00:00

fix(python): fix list devices session resumption

[no changelog]
This commit is contained in:
M1nd3r 2024-11-19 18:07:53 +01:00
parent 0e0b322050
commit d6508e3235
2 changed files with 24 additions and 34 deletions

View File

@ -101,6 +101,27 @@ def get_passphrase(
raise exceptions.Cancelled from None raise exceptions.Cancelled from None
def get_client(transport: Transport) -> TrezorClient:
stored_channels = channel_database.load_stored_channels()
stored_transport_paths = [ch.transport_path for ch in stored_channels]
path = transport.get_path()
if path in stored_transport_paths:
stored_channel_with_correct_transport_path = next(
ch for ch in stored_channels if ch.transport_path == path
)
try:
client = TrezorClient.resume(
transport, stored_channel_with_correct_transport_path
)
except Exception:
LOG.debug("Failed to resume a channel. Replacing by a new one.")
channel_database.remove_channel(path)
client = TrezorClient(transport)
else:
client = TrezorClient(transport)
return client
class TrezorConnection: class TrezorConnection:
def __init__( def __init__(
@ -151,27 +172,7 @@ class TrezorConnection:
return transport.get_transport(self.path, prefix_search=True) return transport.get_transport(self.path, prefix_search=True)
def get_client(self) -> TrezorClient: def get_client(self) -> TrezorClient:
transport = self.get_transport() return get_client(self.get_transport())
stored_channels = channel_database.load_stored_channels()
stored_transport_paths = [ch.transport_path for ch in stored_channels]
path = transport.get_path()
if path in stored_transport_paths:
stored_channel_with_correct_transport_path = next(
ch for ch in stored_channels if ch.transport_path == path
)
try:
client = TrezorClient.resume(
transport, stored_channel_with_correct_transport_path
)
except Exception:
LOG.debug("Failed to resume a channel. Replacing by a new one.")
channel_database.remove_channel(path)
client = TrezorClient(transport)
else:
client = TrezorClient(transport)
return client
def get_management_session(self) -> Session: def get_management_session(self) -> Session:
client = self.get_client() client = self.get_client()

View File

@ -290,22 +290,11 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
"""List connected Trezor devices.""" """List connected Trezor devices."""
if no_resolve: if no_resolve:
return enumerate_devices() return enumerate_devices()
from . import get_client
stored_channels = channel_database.load_stored_channels()
stored_transport_paths = [ch.transport_path for ch in stored_channels]
for transport in enumerate_devices(): for transport in enumerate_devices():
try: try:
path = transport.get_path() client = get_client(transport)
if path in stored_transport_paths:
stored_channel_with_correct_transport_path = next(
ch for ch in stored_channels if ch.transport_path == path
)
client = TrezorClient.resume(
transport, stored_channel_with_correct_transport_path
)
else:
client = TrezorClient(transport)
description = format_device_name(client.features) description = format_device_name(client.features)
# json_string = channel_database.channel_to_str(client.protocol) # json_string = channel_database.channel_to_str(client.protocol)
# print(json_string) # print(json_string)