1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-03-17 16:46:05 +00:00

feat(python): introduce expect argument to client.call

This commit is contained in:
matejcik 2025-01-02 15:43:02 +01:00 committed by matejcik
parent 6c4064489a
commit c7231e5de9
4 changed files with 30 additions and 14 deletions

View File

@ -154,7 +154,6 @@ if TYPE_CHECKING:
from .protobuf import MessageType
@expect(messages.HelloWorldResponse, field="text", ret_type=str)
def say_hello(
client: "TrezorClient",
name: str,
@ -166,8 +165,9 @@ def say_hello(
name=name,
amount=amount,
show_display=show_display,
)
)
),
expect=messages.HelloWorldResponse,
).text
```
Code above is sending `HelloWorldRequest` into Trezor and is expecting to get `HelloWorldResponse` back (from which it extracts the `text` string as a response).

View File

@ -0,0 +1 @@
Added an `expect` argument to `TrezorClient.call()`, to enforce the returned message type.

View File

@ -14,6 +14,8 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
import os
import warnings
@ -24,14 +26,15 @@ from mnemonic import Mnemonic
from . import exceptions, mapping, messages, models
from .log import DUMP_BYTES
from .messages import Capability
from .protobuf import MessageType
from .tools import expect, parse_path, session
if TYPE_CHECKING:
from .protobuf import MessageType
from .transport import Transport
from .ui import TrezorClientUI
UI = TypeVar("UI", bound="TrezorClientUI")
MT = TypeVar("MT", bound=MessageType)
LOG = logging.getLogger(__name__)
@ -149,12 +152,12 @@ class TrezorClient(Generic[UI]):
def cancel(self) -> None:
self._raw_write(messages.Cancel())
def call_raw(self, msg: "MessageType") -> "MessageType":
def call_raw(self, msg: MessageType) -> MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
self._raw_write(msg)
return self._raw_read()
def _raw_write(self, msg: "MessageType") -> None:
def _raw_write(self, msg: MessageType) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
LOG.debug(
f"sending message: {msg.__class__.__name__}",
@ -167,7 +170,7 @@ class TrezorClient(Generic[UI]):
)
self.transport.write(msg_type, msg_bytes)
def _raw_read(self) -> "MessageType":
def _raw_read(self) -> MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
msg_type, msg_bytes = self.transport.read()
LOG.log(
@ -181,7 +184,7 @@ class TrezorClient(Generic[UI]):
)
return msg
def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType":
def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType:
try:
pin = self.ui.get_pin(msg.type)
except exceptions.Cancelled:
@ -204,12 +207,12 @@ class TrezorClient(Generic[UI]):
else:
return resp
def _callback_passphrase(self, msg: messages.PassphraseRequest) -> "MessageType":
def _callback_passphrase(self, msg: messages.PassphraseRequest) -> MessageType:
available_on_device = Capability.PassphraseEntry in self.features.capabilities
def send_passphrase(
passphrase: Optional[str] = None, on_device: Optional[bool] = None
) -> "MessageType":
) -> MessageType:
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
resp = self.call_raw(msg)
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
@ -244,7 +247,7 @@ class TrezorClient(Generic[UI]):
return send_passphrase(passphrase, on_device=False)
def _callback_button(self, msg: messages.ButtonRequest) -> "MessageType":
def _callback_button(self, msg: messages.ButtonRequest) -> MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
self._raw_write(messages.ButtonAck())
@ -252,7 +255,7 @@ class TrezorClient(Generic[UI]):
return self._raw_read()
@session
def call(self, msg: "MessageType") -> "MessageType":
def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT:
self.check_firmware_version()
resp = self.call_raw(msg)
while True:
@ -266,6 +269,8 @@ class TrezorClient(Generic[UI]):
if resp.code == messages.FailureType.ActionCancelled:
raise exceptions.Cancelled
raise exceptions.TrezorFailure(resp)
elif not isinstance(resp, expect):
raise exceptions.UnexpectedMessageError(expect, resp)
else:
return resp
@ -397,7 +402,7 @@ class TrezorClient(Generic[UI]):
self,
msg: str,
button_protection: bool = False,
) -> "MessageType":
) -> MessageType:
# We would like ping to work on any valid TrezorClient instance, but
# due to the protection modes, we need to go through self.call, and that will
# raise an exception if the firmware is too old.

View File

@ -14,10 +14,13 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .messages import Failure
from .protobuf import MessageType
class TrezorException(Exception):
@ -25,7 +28,7 @@ class TrezorException(Exception):
class TrezorFailure(TrezorException):
def __init__(self, failure: "Failure") -> None:
def __init__(self, failure: Failure) -> None:
self.failure = failure
self.code = failure.code
self.message = failure.message
@ -55,3 +58,10 @@ class Cancelled(TrezorException):
class OutdatedFirmwareError(TrezorException):
pass
class UnexpectedMessageError(TrezorException):
def __init__(self, expected: type[MessageType], actual: MessageType) -> None:
self.expected = expected
self.actual = actual
super().__init__(f"Expected {expected.__name__} but Trezor sent {actual}")