1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-22 12:32:02 +00:00

wip: channel database changes

This commit is contained in:
M1nd3r 2025-02-17 12:17:52 +01:00
parent 2439395cad
commit 6a3ee2f70f
2 changed files with 13 additions and 6 deletions

View File

@ -16,6 +16,7 @@
from __future__ import annotations
import atexit
import functools
import logging
import os
@ -30,7 +31,7 @@ from ..client import ProtocolVersion, TrezorClient
from ..messages import Capability
from ..transport import Transport
from ..transport.session import Session, SessionV1
from ..transport.thp.channel_database import get_channel_db
from ..transport.thp.channel_database import ChannelDatabase, get_channel_db
LOG = logging.getLogger(__name__)
@ -102,8 +103,9 @@ def get_passphrase(
raise exceptions.Cancelled from None
def get_client(transport: Transport) -> TrezorClient:
stored_channels = get_channel_db().load_stored_channels()
def get_client(transport: Transport, channel_database: ChannelDatabase) -> TrezorClient:
stored_channels = channel_database.load_stored_channels()
stored_transport_paths = [ch.transport_path for ch in stored_channels]
path = transport.get_path()
if path in stored_transport_paths:
@ -120,6 +122,7 @@ def get_client(transport: Transport) -> TrezorClient:
client = TrezorClient(transport)
else:
client = TrezorClient(transport)
atexit.register(lambda: channel_database.save_channel(client.protocol))
return client
@ -131,11 +134,13 @@ class TrezorConnection:
session_id: bytes | None,
passphrase_on_host: bool,
script: bool,
channel_database: ChannelDatabase,
) -> None:
self.path = path
self.session_id = session_id
self.passphrase_on_host = passphrase_on_host
self.script = script
self.channel_database = channel_database
def get_session(
self,
@ -195,7 +200,7 @@ class TrezorConnection:
return transport.get_transport(self.path, prefix_search=True)
def get_client(self) -> TrezorClient:
return get_client(self.get_transport())
return get_client(self.get_transport(), self.channel_database)
def get_seedless_session(self) -> Session:
client = self.get_client()

View File

@ -226,7 +226,9 @@ def cli_main(
except ValueError:
raise click.ClickException(f"Not a valid session id: {session_id}")
ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script)
ctx.obj = TrezorConnection(
path, bytes_session_id, passphrase_on_host, script, get_channel_db()
)
# Optionally record the screen into a specified directory.
if record:
@ -305,7 +307,7 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
for transport in enumerate_devices():
try:
client = get_client(transport)
client = get_client(transport, get_channel_db())
description = format_device_name(client.features)
except DeviceIsBusy:
description = "Device is in use by another process"