diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 3ba76a67d9..51e6e98c79 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -26,7 +26,7 @@ from contextlib import contextmanager import click from .. import exceptions, transport, ui -from ..client import TrezorClient +from ..client import PROTOCOL_V2, TrezorClient from ..messages import Capability from ..transport import Transport from ..transport.thp.channel_database import get_channel_db @@ -136,7 +136,7 @@ class TrezorConnection: self.passphrase_on_host = passphrase_on_host self.script = script - def get_session(self, derive_cardano: bool = False): + def get_session(self, derive_cardano: bool = False, empty_passphrase: bool = False): client = self.get_client() if self.session_id is not None: @@ -148,10 +148,12 @@ class TrezorConnection: if not passphrase_enabled: return client.get_session(derive_cardano=derive_cardano) - # TODO Passphrase empty by default - ??? - available_on_device = Capability.PassphraseEntry in features.capabilities - passphrase = get_passphrase(available_on_device, self.passphrase_on_host) - # TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func + if empty_passphrase: + passphrase = "" + else: + available_on_device = Capability.PassphraseEntry in features.capabilities + passphrase = get_passphrase(available_on_device, self.passphrase_on_host) + # TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func if not isinstance(passphrase, str): raise RuntimeError("Passphrase must be a str") session = client.get_session( @@ -310,6 +312,25 @@ def with_session( return function_with_session +def with_default_session( + func: "t.Callable[Concatenate[Session, P], R]", derive_cardano: bool = False +) -> "t.Callable[P, R]": + + @click.pass_obj + @functools.wraps(func) + def function_with_session( + obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" + ) -> "R": + session = obj.get_session(derive_cardano, empty_passphrase=True) + try: + return func(session, *args, **kwargs) + finally: + pass + # TODO try end session if not resumed + + return function_with_session + + def with_management_session( func: "t.Callable[Concatenate[Session, P], R]", ) -> "t.Callable[P, R]": @@ -355,7 +376,8 @@ def with_client( try: return func(client, *args, **kwargs) finally: - get_channel_db().save_channel(client.protocol) + if client.protocol_version == PROTOCOL_V2: + get_channel_db().save_channel(client.protocol) # if not session_was_resumed: # try: # client.end_session() diff --git a/python/src/trezorlib/cli/benchmark.py b/python/src/trezorlib/cli/benchmark.py index a7ebab12f7..bef9266a28 100644 --- a/python/src/trezorlib/cli/benchmark.py +++ b/python/src/trezorlib/cli/benchmark.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, List, Optional import click from .. import benchmark -from . import with_session +from . import with_default_session if TYPE_CHECKING: @@ -41,7 +41,7 @@ def cli() -> None: @cli.command() @click.argument("pattern", required=False) -@with_session +@with_default_session def list_names(session: "Session", pattern: Optional[str] = None) -> None: """List names of all supported benchmarks""" names = list_names_patern(session, pattern) @@ -54,7 +54,7 @@ def list_names(session: "Session", pattern: Optional[str] = None) -> None: @cli.command() @click.argument("pattern", required=False) -@with_session +@with_default_session def run(session: "Session", pattern: Optional[str]) -> None: """Run benchmark""" names = list_names_patern(session, pattern) diff --git a/python/src/trezorlib/cli/crypto.py b/python/src/trezorlib/cli/crypto.py index b8fd2cdcb1..1f5be15255 100644 --- a/python/src/trezorlib/cli/crypto.py +++ b/python/src/trezorlib/cli/crypto.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Tuple import click from .. import misc, tools -from . import ChoiceType, with_session +from . import ChoiceType, with_default_session if TYPE_CHECKING: from ..transport.session import Session @@ -42,7 +42,7 @@ def cli() -> None: @cli.command() @click.argument("size", type=int) -@with_session +@with_default_session def get_entropy(session: "Session", size: int) -> str: """Get random bytes from device.""" return misc.get_entropy(session, size).hex() @@ -55,7 +55,7 @@ def get_entropy(session: "Session", size: int) -> str: ) @click.argument("key") @click.argument("value") -@with_session +@with_default_session def encrypt_keyvalue( session: "Session", address: str, @@ -91,7 +91,7 @@ def encrypt_keyvalue( ) @click.argument("key") @click.argument("value") -@with_session +@with_default_session def decrypt_keyvalue( session: "Session", address: str, diff --git a/python/src/trezorlib/cli/fido.py b/python/src/trezorlib/cli/fido.py index 8d5e5628ba..2071a99fb5 100644 --- a/python/src/trezorlib/cli/fido.py +++ b/python/src/trezorlib/cli/fido.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING import click from .. import fido -from . import with_session +from . import with_default_session if TYPE_CHECKING: from ..transport.session import Session @@ -40,7 +40,7 @@ def credentials() -> None: @credentials.command(name="list") -@with_session +@with_default_session def credentials_list(session: "Session") -> None: """List all resident credentials on the device.""" creds = fido.list_credentials(session) @@ -79,7 +79,7 @@ def credentials_list(session: "Session") -> None: @credentials.command(name="add") @click.argument("hex_credential_id") -@with_session +@with_default_session def credentials_add(session: "Session", hex_credential_id: str) -> str: """Add the credential with the given ID as a resident credential. @@ -92,7 +92,7 @@ def credentials_add(session: "Session", hex_credential_id: str) -> str: @click.option( "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." ) -@with_session +@with_default_session def credentials_remove(session: "Session", index: int) -> str: """Remove the resident credential at the given index.""" return fido.remove_credential(session, index) @@ -110,14 +110,14 @@ def counter() -> None: @counter.command(name="set") @click.argument("counter", type=int) -@with_session +@with_default_session def counter_set(session: "Session", counter: int) -> str: """Set FIDO/U2F counter value.""" return fido.set_counter(session, counter) @counter.command(name="get-next") -@with_session +@with_default_session def counter_get_next(session: "Session") -> int: """Get-and-increase value of FIDO/U2F counter. diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 64308ee0cc..1269092941 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -54,6 +54,7 @@ from . import ( tezos, with_client, with_session, + with_default_session, ) F = TypeVar("F", bound=Callable) @@ -328,7 +329,7 @@ def version() -> str: @cli.command() @click.argument("message") @click.option("-b", "--button-protection", is_flag=True) -@with_session +@with_default_session def ping(session: "Session", message: str, button_protection: bool) -> str: """Send ping message.""" @@ -362,7 +363,7 @@ def get_session( "Upgrade your firmware to enable session support." ) - client.ensure_unlocked() + # client.ensure_unlocked() session = client.get_session( passphrase=passphrase, derive_cardano=derive_cardano ) @@ -376,9 +377,9 @@ def get_session( @with_session def clear_session(session: "Session") -> None: """Clear session (remove cached PIN, passphrase, etc.).""" - # TODO something like old: return client.clear_session() - print("NOT IMPLEMENTED") - raise NotImplementedError + session.call(messages.LockDevice()) + session.end() + # TODO different behaviour than main, not sure if ok @cli.command() diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index e7a7b0597f..266ca9f8a8 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -48,6 +48,9 @@ Or visit https://suite.trezor.io/ LOG = logging.getLogger(__name__) +UNKNOWN = -1 +PROTOCOL_V1 = 1 +PROTOCOL_V2 = 2 class TrezorClient: button_callback: t.Callable[[Session, t.Any], t.Any] | None = None @@ -56,7 +59,7 @@ class TrezorClient: _management_session: Session | None = None _features: messages.Features | None = None - + _protocol_version: int def __init__( self, transport: Transport, @@ -77,6 +80,12 @@ class TrezorClient: else: self.protocol = protocol self.protocol.mapping = self.mapping + if isinstance(self.protocol, ProtocolV1): + self._protocol_version = PROTOCOL_V1 + elif isinstance(self.protocol, ProtocolV2): + self._protocol_version = PROTOCOL_V2 + else: + self._protocol_version = UNKNOWN @classmethod def resume( @@ -170,6 +179,10 @@ class TrezorClient: assert self._features is not None return self._features + @property + def protocol_version(self) -> int: + return self._protocol_version + @property def model(self) -> models.TrezorModel: f = self.features @@ -196,10 +209,6 @@ class TrezorClient: self.protocol.update_features() self._features = self.protocol.get_features() - def ensure_unlocked(self) -> None: - # TODO implement - raise NotImplementedError - def _get_protocol(self) -> ProtocolAndChannel: self.transport.open() diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py index efe40ffc0a..9be4083194 100644 --- a/python/src/trezorlib/transport/session.py +++ b/python/src/trezorlib/transport/session.py @@ -103,6 +103,8 @@ class SessionV1(Session): else: session = SessionV1(client, session_id) session.button_callback = client.button_callback + if session.button_callback is None: + session.button_callback = _callback_button session.pin_callback = client.pin_callback session.passphrase_callback = client.passphrase_callback session.passphrase = passphrase