diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index e3c2825c87..f440f84adf 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -30,7 +30,7 @@ from .. import exceptions, transport, ui from ..client import ProtocolVersion, TrezorClient from ..messages import Capability from ..transport import Transport -from ..transport.session import Session, SessionV1 +from ..transport.session import Session, SessionV1, SessionV2 from ..transport.thp.channel_database import ChannelDatabase, get_channel_db LOG = logging.getLogger(__name__) @@ -148,7 +148,7 @@ class TrezorConnection: empty_passphrase: bool = False, must_resume: bool = False, ) -> Session: - client = self.get_client() + client = self.get_client() # add channel database if must_resume and self.session_id is None: click.echo("Failed to resume session - no session id provided") raise RuntimeError("Failed to resume session - no session id provided") @@ -159,6 +159,9 @@ class TrezorConnection: session = SessionV1.resume_from_id( client=client, session_id=self.session_id ) + elif client.protocol_version is ProtocolVersion.PROTOCOL_V2: + session = SessionV2(client, self.session_id) + # TODO fix resumption on THP else: raise Exception("Unsupported client protocol", client.protocol_version) if must_resume: @@ -361,7 +364,8 @@ def with_client( try: return func(client, *args, **kwargs) finally: - pass + if client.protocol_version == ProtocolVersion.PROTOCOL_V2: + get_channel_db().save_channel(client.protocol) # if not session_was_resumed: # try: # client.end_session() diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 6ac327c1cd..cc9b7a5779 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -25,7 +25,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca import click from .. import __version__, log, messages, protobuf -from ..client import TrezorClient +from ..client import ProtocolVersion, TrezorClient from ..transport import DeviceIsBusy, enumerate_devices from ..transport.session import Session from ..transport.thp import channel_database @@ -309,6 +309,8 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: try: client = get_client(transport, get_channel_db()) description = format_device_name(client.features) + if client.protocol_version == ProtocolVersion.PROTOCOL_V2: + get_channel_db().save_channel(client.protocol) except DeviceIsBusy: description = "Device is in use by another process" except Exception as e: