diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 0a49c51eb8..61093711eb 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -14,21 +14,25 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import functools +import os import sys +import typing as t from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional import click -from .. import exceptions, transport +from .. import exceptions, transport, ui from ..client import TrezorClient +from ..messages import Capability from ..transport.new import channel_database from ..transport.new.client import NewTrezorClient from ..transport.new.transport import NewTransport from ..ui import ClickUI, ScriptUI -if TYPE_CHECKING: +if t.TYPE_CHECKING: # Needed to enforce a return value from decorators # More details: https://www.python.org/dev/peps/pep-0612/ from typing import TypeVar @@ -43,7 +47,10 @@ if TYPE_CHECKING: class ChoiceType(click.Choice): - def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None: + + def __init__( + self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True + ) -> None: super().__init__(list(typemap.keys())) self.case_sensitive = case_sensitive if case_sensitive: @@ -51,7 +58,7 @@ class ChoiceType(click.Choice): else: self.typemap = {k.lower(): v for k, v in typemap.items()} - def convert(self, value: Any, param: Any, ctx: click.Context) -> Any: + def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any: if value in self.typemap.values(): return value value = super().convert(value, param, ctx) @@ -60,11 +67,48 @@ class ChoiceType(click.Choice): return self.typemap[value] +def get_passphrase( + passphrase_on_host: bool, available_on_device: bool +) -> t.Union[str, object]: + if available_on_device and not passphrase_on_host: + return ui.PASSPHRASE_ON_DEVICE + + env_passphrase = os.getenv("PASSPHRASE") + if env_passphrase is not None: + ui.echo("Passphrase required. Using PASSPHRASE environment variable.") + return env_passphrase + + while True: + try: + passphrase = ui.prompt( + "Passphrase required", + hide_input=True, + default="", + show_default=False, + ) + # In case user sees the input on the screen, we do not need confirmation + if not ui.CAN_HANDLE_HIDDEN_INPUT: + return passphrase + second = ui.prompt( + "Confirm your passphrase", + hide_input=True, + default="", + show_default=False, + ) + if passphrase == second: + return passphrase + else: + ui.echo("Passphrase did not match. Please try again.") + except click.Abort: + raise exceptions.Cancelled from None + + class NewTrezorConnection: + def __init__( self, path: str, - session_id: Optional[bytes], + session_id: bytes | None, passphrase_on_host: bool, script: bool, ) -> None: @@ -73,6 +117,29 @@ class NewTrezorConnection: self.passphrase_on_host = passphrase_on_host self.script = script + def get_session( + self, + ): + client = self.get_client() + + if self.session_id is not None: + pass # TODO Try resume + features = client.protocol.get_features() + + passphrase_enabled = True # TODO what to do here? + + if not passphrase_enabled: + return client.get_session(derive_cardano=True) + + # TODO Passphrase empty by default - ??? + available_on_device = Capability.PassphraseEntry in features.capabilities + passphrase = get_passphrase(available_on_device, self.passphrase_on_host) + # TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") + session = client.get_session(passphrase=passphrase, derive_cardano=True) + return session + def get_transport(self) -> "NewTransport": try: # look for transport without prefix search @@ -100,6 +167,7 @@ class NewTrezorConnection: ) else: client = NewTrezorClient(transport) + return client @contextmanager @@ -135,10 +203,11 @@ class NewTrezorConnection: class TrezorConnection: + def __init__( self, path: str, - session_id: Optional[bytes], + session_id: bytes | None, passphrase_on_host: bool, script: bool, ) -> None: @@ -205,9 +274,33 @@ class TrezorConnection: # other exceptions may cause a traceback +from ..transport.new.session import Session + + +def with_session( + func: "t.Callable[Concatenate[Session, P], R]", +) -> "t.Callable[P, R]": + + @click.pass_obj + @functools.wraps(func) + def function_with_session( + obj: NewTrezorConnection, *args: "P.args", **kwargs: "P.kwargs" + ) -> "R": + session = obj.get_session() + try: + return func(session, *args, **kwargs) + finally: + pass + # TODO try end session if not resumed + + # the return type of @click.pass_obj is improperly specified and pyright doesn't + # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs) + return function_with_session # type: ignore [is incompatible with return type] + + def new_with_client( - func: "Callable[Concatenate[NewTrezorClient, P], R]", -) -> "Callable[P, R]": + func: "t.Callable[Concatenate[NewTrezorClient, P], R]", +) -> "t.Callable[P, R]": """Wrap a Click command in `with obj.client_context() as client`. Sessions are handled transparently. The user is warned when session did not resume @@ -243,7 +336,9 @@ def new_with_client( return trezorctl_command_with_client # type: ignore [is incompatible with return type] -def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]": +def with_client( + func: "t.Callable[Concatenate[TrezorClient, P], R]", +) -> "t.Callable[P, R]": """Wrap a Click command in `with obj.client_context() as client`. Sessions are handled transparently. The user is warned when session did not resume @@ -305,14 +400,14 @@ class AliasedGroup(click.Group): def __init__( self, - aliases: Optional[Dict[str, click.Command]] = None, - *args: Any, - **kwargs: Any, + aliases: t.Dict[str, click.Command] | None = None, + *args: t.Any, + **kwargs: t.Any, ) -> None: super().__init__(*args, **kwargs) self.aliases = aliases or {} - def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]: + def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: cmd_name = cmd_name.replace("_", "-") # try to look up the real name cmd = super().get_command(ctx, cmd_name) diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 52d46f2e2d..c4fe506662 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -29,6 +29,7 @@ from ..client import TrezorClient from ..transport import DeviceIsBusy, new_enumerate_devices from ..transport.new import channel_database from ..transport.new.client import NewTrezorClient +from ..transport.new.session import Session from ..transport.new.udp import UdpTransport from . import ( AliasedGroup, @@ -53,6 +54,7 @@ from . import ( stellar, tezos, with_client, + with_session, ) F = TypeVar("F", bound=Callable) @@ -334,10 +336,14 @@ def version() -> str: @cli.command() @click.argument("message") @click.option("-b", "--button-protection", is_flag=True) -@with_client -def ping(client: "TrezorClient", message: str, button_protection: bool) -> str: +@with_session +def ping(session: "Session", message: str, button_protection: bool) -> str: """Send ping message.""" - return client.ping(message, button_protection=button_protection) + + # TODO return short-circuit from old client for old Trezors + return session.call( + messages.Ping(message=message, button_protection=button_protection) + ) @cli.command() diff --git a/python/src/trezorlib/transport/new/client.py b/python/src/trezorlib/transport/new/client.py index 151a6a7432..be024d1b33 100644 --- a/python/src/trezorlib/transport/new/client.py +++ b/python/src/trezorlib/transport/new/client.py @@ -28,9 +28,11 @@ class NewTrezorClient: self.mapping = mapping.DEFAULT_MAPPING else: self.mapping = protobuf_mapping - if protocol is None: - self.protocol = self._get_protocol() + try: + self.protocol = self._get_protocol() + except Exception as e: + print(e) else: self.protocol = protocol self.protocol.mapping = self.mapping @@ -52,9 +54,8 @@ class NewTrezorClient: def get_session( self, - passphrase: str = "", + passphrase: str | None = None, derive_cardano: bool = False, - management_session: bool = False, ) -> Session: if isinstance(self.protocol, ProtocolV1): return SessionV1.new(self, passphrase, derive_cardano) diff --git a/python/src/trezorlib/transport/new/protocol_and_channel.py b/python/src/trezorlib/transport/new/protocol_and_channel.py index 8b03d28532..7c7774592e 100644 --- a/python/src/trezorlib/transport/new/protocol_and_channel.py +++ b/python/src/trezorlib/transport/new/protocol_and_channel.py @@ -4,6 +4,7 @@ import logging import struct import typing as t +from ... import exceptions, messages from ...log import DUMP_BYTES from ...mapping import ProtobufMapping from .channel_data import ChannelData @@ -30,12 +31,28 @@ class ProtocolAndChannel: # def read(self, session_id: bytes) -> t.Any: ... + def get_features(self) -> messages.Features: + raise NotImplementedError() + def get_channel_data(self) -> ChannelData: raise NotImplementedError class ProtocolV1(ProtocolAndChannel): HEADER_LEN = struct.calcsize(">HL") + _features: messages.Features + _has_valid_features: bool = False + + def get_features(self) -> messages.Features: + if not self._has_valid_features: + self.write(messages.GetFeatures()) + resp = self.read() + if not isinstance(resp, messages.Features): + raise exceptions.TrezorException("Unexpected response to GetFeatures") + self._features = resp + self._has_valid_features = True + + return self._features def read(self) -> t.Any: msg_type, msg_bytes = self._read() diff --git a/python/src/trezorlib/transport/new/protocol_v2.py b/python/src/trezorlib/transport/new/protocol_v2.py index 9e0cc8f1b4..cd3a9ace87 100644 --- a/python/src/trezorlib/transport/new/protocol_v2.py +++ b/python/src/trezorlib/transport/new/protocol_v2.py @@ -10,12 +10,12 @@ from enum import IntEnum from cryptography.hazmat.primitives.ciphers.aead import AESGCM -from ... import messages +from ... import exceptions, messages from ...mapping import ProtobufMapping from ..thp import checksum, curve25519, thp_io from ..thp.checksum import CHECKSUM_LENGTH from ..thp.packet_header import PacketHeader -from . import control_byte +from . import channel_database, control_byte from .channel_data import ChannelData from .protocol_and_channel import ProtocolAndChannel from .transport import NewTransport @@ -56,9 +56,9 @@ class ProtocolV2(ProtocolAndChannel): sync_bit_send: int sync_bit_receive: int - has_valid_channel: bool = False - has_valid_features: bool = False - features: messages.Features + _has_valid_channel: bool = False + _has_valid_features: bool = False + _features: messages.Features def __init__( self, @@ -75,13 +75,14 @@ class ProtocolV2(ProtocolAndChannel): self.nonce_response = channel_data.nonce_response self.sync_bit_receive = channel_data.sync_bit_receive self.sync_bit_send = channel_data.sync_bit_send - self.has_valid_channel = True + self._has_valid_channel = True def get_channel(self) -> ProtocolV2: - if not self.has_valid_channel: + if not self._has_valid_channel: self._establish_new_channel() - if not self.has_valid_features: - self.update_features() + # TODO - Q: should ask for features now or when needed? + # if not self.has_valid_features: + # self.update_features() return self def get_channel_data(self) -> ChannelData: @@ -98,12 +99,23 @@ class ProtocolV2(ProtocolAndChannel): ) def read(self, session_id: int) -> t.Any: - header, data = self._read_until_valid_crc_check() - # TODO + sid, msg_type, msg_data = self.read_and_decrypt() + if sid != session_id: + raise Exception("Received messsage on different session.") + channel_database.save_channel(self) + return self.mapping.decode(msg_type, msg_data) def write(self, session_id: int, msg: t.Any) -> None: msg_type, msg_data = self.mapping.encode(msg) self._encrypt_and_write(session_id, msg_type, msg_data) + channel_database.save_channel(self) + + def get_features(self) -> messages.Features: + if not self._has_valid_channel: + self._establish_new_channel() + if not self._has_valid_features: + self.update_features() + return self._features def update_features(self) -> None: message = messages.GetFeatures() @@ -111,11 +123,12 @@ class ProtocolV2(ProtocolAndChannel): self.session_id: int = 0 self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) _ = self._read_until_valid_crc_check() # TODO check ACK - session_id, msg_type, msg_data = self.read_and_decrypt() + _, msg_type, msg_data = self.read_and_decrypt() features = self.mapping.decode(msg_type, msg_data) - assert isinstance(features, messages.Features) - self.features = features - self.has_valid_features = True + if not isinstance(features, messages.Features): + raise exceptions.TrezorException("Unexpected response to GetFeatures") + self._features = features + self._has_valid_features = True def _establish_new_channel(self) -> None: self.sync_bit_send = 0 @@ -260,7 +273,7 @@ class ProtocolV2(ProtocolAndChannel): maaa = self.mapping.decode(msg_type, msg_data) assert isinstance(maaa, messages.ThpEndResponse) - self.has_valid_channel = True + self._has_valid_channel = True def _send_ack_0(self): LOG.debug("sending ack 0") @@ -302,8 +315,13 @@ class ProtocolV2(ProtocolAndChannel): def read_and_decrypt(self) -> t.Tuple[int, int, bytes]: header, raw_payload = self._read_until_valid_crc_check() + if control_byte.is_ack(header.ctrl_byte): + return self.read_and_decrypt() if not header.is_encrypted_transport(): print("Trying to decrypt not encrypted message!") + print( + hexlify(header.to_bytes_init()).decode(), hexlify(raw_payload).decode() + ) if not control_byte.is_ack(header.ctrl_byte): LOG.debug( diff --git a/python/src/trezorlib/transport/new/session.py b/python/src/trezorlib/transport/new/session.py index 3498509373..4e48cc7b9e 100644 --- a/python/src/trezorlib/transport/new/session.py +++ b/python/src/trezorlib/transport/new/session.py @@ -19,7 +19,7 @@ class Session: @classmethod def new( - cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool + cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool ) -> Session: raise NotImplementedError @@ -30,7 +30,7 @@ class Session: class SessionV1(Session): @classmethod def new( - cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool + cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool ) -> SessionV1: assert isinstance(client.protocol, ProtocolV1) session = SessionV1(client, b"") @@ -54,9 +54,8 @@ class SessionV2(Session): @classmethod def new( - cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool + cls, client: NewTrezorClient, passphrase: str | None, derive_cardano: bool ) -> SessionV2: - assert isinstance(client.protocol, ProtocolV2) session = SessionV2(client, b"\x00") new_session: ThpNewSession = session.call( @@ -73,9 +72,7 @@ class SessionV2(Session): self.channel: ProtocolV2 = client.protocol.get_channel() self.update_id_and_sid(id) - if not self.channel.has_valid_features: - self.channel.update_features() - self.features = self.channel.features + self.features = self.channel.get_features() def call(self, msg: t.Any) -> t.Any: self.channel.write(self.sid, msg)