mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-22 12:32:02 +00:00
wip: channel database changes
This commit is contained in:
parent
2439395cad
commit
6a3ee2f70f
@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
@ -30,7 +31,7 @@ from ..client import ProtocolVersion, TrezorClient
|
||||
from ..messages import Capability
|
||||
from ..transport import Transport
|
||||
from ..transport.session import Session, SessionV1
|
||||
from ..transport.thp.channel_database import get_channel_db
|
||||
from ..transport.thp.channel_database import ChannelDatabase, get_channel_db
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -102,8 +103,9 @@ def get_passphrase(
|
||||
raise exceptions.Cancelled from None
|
||||
|
||||
|
||||
def get_client(transport: Transport) -> TrezorClient:
|
||||
stored_channels = get_channel_db().load_stored_channels()
|
||||
def get_client(transport: Transport, channel_database: ChannelDatabase) -> TrezorClient:
|
||||
stored_channels = channel_database.load_stored_channels()
|
||||
|
||||
stored_transport_paths = [ch.transport_path for ch in stored_channels]
|
||||
path = transport.get_path()
|
||||
if path in stored_transport_paths:
|
||||
@ -120,6 +122,7 @@ def get_client(transport: Transport) -> TrezorClient:
|
||||
client = TrezorClient(transport)
|
||||
else:
|
||||
client = TrezorClient(transport)
|
||||
atexit.register(lambda: channel_database.save_channel(client.protocol))
|
||||
return client
|
||||
|
||||
|
||||
@ -131,11 +134,13 @@ class TrezorConnection:
|
||||
session_id: bytes | None,
|
||||
passphrase_on_host: bool,
|
||||
script: bool,
|
||||
channel_database: ChannelDatabase,
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.session_id = session_id
|
||||
self.passphrase_on_host = passphrase_on_host
|
||||
self.script = script
|
||||
self.channel_database = channel_database
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
@ -195,7 +200,7 @@ class TrezorConnection:
|
||||
return transport.get_transport(self.path, prefix_search=True)
|
||||
|
||||
def get_client(self) -> TrezorClient:
|
||||
return get_client(self.get_transport())
|
||||
return get_client(self.get_transport(), self.channel_database)
|
||||
|
||||
def get_seedless_session(self) -> Session:
|
||||
client = self.get_client()
|
||||
|
@ -226,7 +226,9 @@ def cli_main(
|
||||
except ValueError:
|
||||
raise click.ClickException(f"Not a valid session id: {session_id}")
|
||||
|
||||
ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script)
|
||||
ctx.obj = TrezorConnection(
|
||||
path, bytes_session_id, passphrase_on_host, script, get_channel_db()
|
||||
)
|
||||
|
||||
# Optionally record the screen into a specified directory.
|
||||
if record:
|
||||
@ -305,7 +307,7 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
|
||||
|
||||
for transport in enumerate_devices():
|
||||
try:
|
||||
client = get_client(transport)
|
||||
client = get_client(transport, get_channel_db())
|
||||
description = format_device_name(client.features)
|
||||
except DeviceIsBusy:
|
||||
description = "Device is in use by another process"
|
||||
|
Loading…
Reference in New Issue
Block a user