diff --git a/docs/developers/hello_world_feature_TT.md b/docs/developers/hello_world_feature_TT.md index fffa7e0f60..cd120fbb5d 100644 --- a/docs/developers/hello_world_feature_TT.md +++ b/docs/developers/hello_world_feature_TT.md @@ -154,7 +154,6 @@ if TYPE_CHECKING: from .protobuf import MessageType -@expect(messages.HelloWorldResponse, field="text", ret_type=str) def say_hello( client: "TrezorClient", name: str, @@ -166,8 +165,9 @@ def say_hello( name=name, amount=amount, show_display=show_display, - ) - ) + ), + expect=messages.HelloWorldResponse, + ).text ``` Code above is sending `HelloWorldRequest` into Trezor and is expecting to get `HelloWorldResponse` back (from which it extracts the `text` string as a response). diff --git a/python/.changelog.d/4464.added.1 b/python/.changelog.d/4464.added.1 new file mode 100644 index 0000000000..2bca4ba58b --- /dev/null +++ b/python/.changelog.d/4464.added.1 @@ -0,0 +1 @@ +Added an `expect` argument to `TrezorClient.call()`, to enforce the returned message type. diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index c912ba00ed..fa7992ab0e 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -14,6 +14,8 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import os import warnings @@ -24,14 +26,15 @@ from mnemonic import Mnemonic from . import exceptions, mapping, messages, models from .log import DUMP_BYTES from .messages import Capability +from .protobuf import MessageType from .tools import expect, parse_path, session if TYPE_CHECKING: - from .protobuf import MessageType from .transport import Transport from .ui import TrezorClientUI UI = TypeVar("UI", bound="TrezorClientUI") +MT = TypeVar("MT", bound=MessageType) LOG = logging.getLogger(__name__) @@ -149,12 +152,12 @@ class TrezorClient(Generic[UI]): def cancel(self) -> None: self._raw_write(messages.Cancel()) - def call_raw(self, msg: "MessageType") -> "MessageType": + def call_raw(self, msg: MessageType) -> MessageType: __tracebackhide__ = True # for pytest # pylint: disable=W0612 self._raw_write(msg) return self._raw_read() - def _raw_write(self, msg: "MessageType") -> None: + def _raw_write(self, msg: MessageType) -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 LOG.debug( f"sending message: {msg.__class__.__name__}", @@ -167,7 +170,7 @@ class TrezorClient(Generic[UI]): ) self.transport.write(msg_type, msg_bytes) - def _raw_read(self) -> "MessageType": + def _raw_read(self) -> MessageType: __tracebackhide__ = True # for pytest # pylint: disable=W0612 msg_type, msg_bytes = self.transport.read() LOG.log( @@ -181,7 +184,7 @@ class TrezorClient(Generic[UI]): ) return msg - def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType": + def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType: try: pin = self.ui.get_pin(msg.type) except exceptions.Cancelled: @@ -204,12 +207,12 @@ class TrezorClient(Generic[UI]): else: return resp - def _callback_passphrase(self, msg: messages.PassphraseRequest) -> "MessageType": + def _callback_passphrase(self, msg: messages.PassphraseRequest) -> MessageType: available_on_device = Capability.PassphraseEntry in self.features.capabilities def send_passphrase( passphrase: Optional[str] = None, on_device: Optional[bool] = None - ) -> "MessageType": + ) -> MessageType: msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) resp = self.call_raw(msg) if isinstance(resp, messages.Deprecated_PassphraseStateRequest): @@ -244,7 +247,7 @@ class TrezorClient(Generic[UI]): return send_passphrase(passphrase, on_device=False) - def _callback_button(self, msg: messages.ButtonRequest) -> "MessageType": + def _callback_button(self, msg: messages.ButtonRequest) -> MessageType: __tracebackhide__ = True # for pytest # pylint: disable=W0612 # do this raw - send ButtonAck first, notify UI later self._raw_write(messages.ButtonAck()) @@ -252,7 +255,7 @@ class TrezorClient(Generic[UI]): return self._raw_read() @session - def call(self, msg: "MessageType") -> "MessageType": + def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT: self.check_firmware_version() resp = self.call_raw(msg) while True: @@ -266,6 +269,8 @@ class TrezorClient(Generic[UI]): if resp.code == messages.FailureType.ActionCancelled: raise exceptions.Cancelled raise exceptions.TrezorFailure(resp) + elif not isinstance(resp, expect): + raise exceptions.UnexpectedMessageError(expect, resp) else: return resp @@ -397,7 +402,7 @@ class TrezorClient(Generic[UI]): self, msg: str, button_protection: bool = False, - ) -> "MessageType": + ) -> MessageType: # We would like ping to work on any valid TrezorClient instance, but # due to the protection modes, we need to go through self.call, and that will # raise an exception if the firmware is too old. diff --git a/python/src/trezorlib/exceptions.py b/python/src/trezorlib/exceptions.py index fd7133d12f..99f0048dd3 100644 --- a/python/src/trezorlib/exceptions.py +++ b/python/src/trezorlib/exceptions.py @@ -14,10 +14,13 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + from typing import TYPE_CHECKING if TYPE_CHECKING: from .messages import Failure + from .protobuf import MessageType class TrezorException(Exception): @@ -25,7 +28,7 @@ class TrezorException(Exception): class TrezorFailure(TrezorException): - def __init__(self, failure: "Failure") -> None: + def __init__(self, failure: Failure) -> None: self.failure = failure self.code = failure.code self.message = failure.message @@ -55,3 +58,10 @@ class Cancelled(TrezorException): class OutdatedFirmwareError(TrezorException): pass + + +class UnexpectedMessageError(TrezorException): + def __init__(self, expected: type[MessageType], actual: MessageType) -> None: + self.expected = expected + self.actual = actual + super().__init__(f"Expected {expected.__name__} but Trezor sent {actual}")