mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-05-25 02:08:48 +00:00
feat(python): introduce expect argument to client.call
This commit is contained in:
parent
6c4064489a
commit
c7231e5de9
@ -154,7 +154,6 @@ if TYPE_CHECKING:
|
|||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.HelloWorldResponse, field="text", ret_type=str)
|
|
||||||
def say_hello(
|
def say_hello(
|
||||||
client: "TrezorClient",
|
client: "TrezorClient",
|
||||||
name: str,
|
name: str,
|
||||||
@ -166,8 +165,9 @@ def say_hello(
|
|||||||
name=name,
|
name=name,
|
||||||
amount=amount,
|
amount=amount,
|
||||||
show_display=show_display,
|
show_display=show_display,
|
||||||
)
|
),
|
||||||
)
|
expect=messages.HelloWorldResponse,
|
||||||
|
).text
|
||||||
```
|
```
|
||||||
|
|
||||||
Code above is sending `HelloWorldRequest` into Trezor and is expecting to get `HelloWorldResponse` back (from which it extracts the `text` string as a response).
|
Code above is sending `HelloWorldRequest` into Trezor and is expecting to get `HelloWorldResponse` back (from which it extracts the `text` string as a response).
|
||||||
|
1
python/.changelog.d/4464.added.1
Normal file
1
python/.changelog.d/4464.added.1
Normal file
@ -0,0 +1 @@
|
|||||||
|
Added an `expect` argument to `TrezorClient.call()`, to enforce the returned message type.
|
@ -14,6 +14,8 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
@ -24,14 +26,15 @@ from mnemonic import Mnemonic
|
|||||||
from . import exceptions, mapping, messages, models
|
from . import exceptions, mapping, messages, models
|
||||||
from .log import DUMP_BYTES
|
from .log import DUMP_BYTES
|
||||||
from .messages import Capability
|
from .messages import Capability
|
||||||
|
from .protobuf import MessageType
|
||||||
from .tools import expect, parse_path, session
|
from .tools import expect, parse_path, session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .protobuf import MessageType
|
|
||||||
from .transport import Transport
|
from .transport import Transport
|
||||||
from .ui import TrezorClientUI
|
from .ui import TrezorClientUI
|
||||||
|
|
||||||
UI = TypeVar("UI", bound="TrezorClientUI")
|
UI = TypeVar("UI", bound="TrezorClientUI")
|
||||||
|
MT = TypeVar("MT", bound=MessageType)
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -149,12 +152,12 @@ class TrezorClient(Generic[UI]):
|
|||||||
def cancel(self) -> None:
|
def cancel(self) -> None:
|
||||||
self._raw_write(messages.Cancel())
|
self._raw_write(messages.Cancel())
|
||||||
|
|
||||||
def call_raw(self, msg: "MessageType") -> "MessageType":
|
def call_raw(self, msg: MessageType) -> MessageType:
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
self._raw_write(msg)
|
self._raw_write(msg)
|
||||||
return self._raw_read()
|
return self._raw_read()
|
||||||
|
|
||||||
def _raw_write(self, msg: "MessageType") -> None:
|
def _raw_write(self, msg: MessageType) -> None:
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
f"sending message: {msg.__class__.__name__}",
|
f"sending message: {msg.__class__.__name__}",
|
||||||
@ -167,7 +170,7 @@ class TrezorClient(Generic[UI]):
|
|||||||
)
|
)
|
||||||
self.transport.write(msg_type, msg_bytes)
|
self.transport.write(msg_type, msg_bytes)
|
||||||
|
|
||||||
def _raw_read(self) -> "MessageType":
|
def _raw_read(self) -> MessageType:
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
msg_type, msg_bytes = self.transport.read()
|
msg_type, msg_bytes = self.transport.read()
|
||||||
LOG.log(
|
LOG.log(
|
||||||
@ -181,7 +184,7 @@ class TrezorClient(Generic[UI]):
|
|||||||
)
|
)
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType":
|
def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType:
|
||||||
try:
|
try:
|
||||||
pin = self.ui.get_pin(msg.type)
|
pin = self.ui.get_pin(msg.type)
|
||||||
except exceptions.Cancelled:
|
except exceptions.Cancelled:
|
||||||
@ -204,12 +207,12 @@ class TrezorClient(Generic[UI]):
|
|||||||
else:
|
else:
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
def _callback_passphrase(self, msg: messages.PassphraseRequest) -> "MessageType":
|
def _callback_passphrase(self, msg: messages.PassphraseRequest) -> MessageType:
|
||||||
available_on_device = Capability.PassphraseEntry in self.features.capabilities
|
available_on_device = Capability.PassphraseEntry in self.features.capabilities
|
||||||
|
|
||||||
def send_passphrase(
|
def send_passphrase(
|
||||||
passphrase: Optional[str] = None, on_device: Optional[bool] = None
|
passphrase: Optional[str] = None, on_device: Optional[bool] = None
|
||||||
) -> "MessageType":
|
) -> MessageType:
|
||||||
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
||||||
resp = self.call_raw(msg)
|
resp = self.call_raw(msg)
|
||||||
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
||||||
@ -244,7 +247,7 @@ class TrezorClient(Generic[UI]):
|
|||||||
|
|
||||||
return send_passphrase(passphrase, on_device=False)
|
return send_passphrase(passphrase, on_device=False)
|
||||||
|
|
||||||
def _callback_button(self, msg: messages.ButtonRequest) -> "MessageType":
|
def _callback_button(self, msg: messages.ButtonRequest) -> MessageType:
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
# do this raw - send ButtonAck first, notify UI later
|
# do this raw - send ButtonAck first, notify UI later
|
||||||
self._raw_write(messages.ButtonAck())
|
self._raw_write(messages.ButtonAck())
|
||||||
@ -252,7 +255,7 @@ class TrezorClient(Generic[UI]):
|
|||||||
return self._raw_read()
|
return self._raw_read()
|
||||||
|
|
||||||
@session
|
@session
|
||||||
def call(self, msg: "MessageType") -> "MessageType":
|
def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT:
|
||||||
self.check_firmware_version()
|
self.check_firmware_version()
|
||||||
resp = self.call_raw(msg)
|
resp = self.call_raw(msg)
|
||||||
while True:
|
while True:
|
||||||
@ -266,6 +269,8 @@ class TrezorClient(Generic[UI]):
|
|||||||
if resp.code == messages.FailureType.ActionCancelled:
|
if resp.code == messages.FailureType.ActionCancelled:
|
||||||
raise exceptions.Cancelled
|
raise exceptions.Cancelled
|
||||||
raise exceptions.TrezorFailure(resp)
|
raise exceptions.TrezorFailure(resp)
|
||||||
|
elif not isinstance(resp, expect):
|
||||||
|
raise exceptions.UnexpectedMessageError(expect, resp)
|
||||||
else:
|
else:
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
@ -397,7 +402,7 @@ class TrezorClient(Generic[UI]):
|
|||||||
self,
|
self,
|
||||||
msg: str,
|
msg: str,
|
||||||
button_protection: bool = False,
|
button_protection: bool = False,
|
||||||
) -> "MessageType":
|
) -> MessageType:
|
||||||
# We would like ping to work on any valid TrezorClient instance, but
|
# We would like ping to work on any valid TrezorClient instance, but
|
||||||
# due to the protection modes, we need to go through self.call, and that will
|
# due to the protection modes, we need to go through self.call, and that will
|
||||||
# raise an exception if the firmware is too old.
|
# raise an exception if the firmware is too old.
|
||||||
|
@ -14,10 +14,13 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .messages import Failure
|
from .messages import Failure
|
||||||
|
from .protobuf import MessageType
|
||||||
|
|
||||||
|
|
||||||
class TrezorException(Exception):
|
class TrezorException(Exception):
|
||||||
@ -25,7 +28,7 @@ class TrezorException(Exception):
|
|||||||
|
|
||||||
|
|
||||||
class TrezorFailure(TrezorException):
|
class TrezorFailure(TrezorException):
|
||||||
def __init__(self, failure: "Failure") -> None:
|
def __init__(self, failure: Failure) -> None:
|
||||||
self.failure = failure
|
self.failure = failure
|
||||||
self.code = failure.code
|
self.code = failure.code
|
||||||
self.message = failure.message
|
self.message = failure.message
|
||||||
@ -55,3 +58,10 @@ class Cancelled(TrezorException):
|
|||||||
|
|
||||||
class OutdatedFirmwareError(TrezorException):
|
class OutdatedFirmwareError(TrezorException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UnexpectedMessageError(TrezorException):
|
||||||
|
def __init__(self, expected: type[MessageType], actual: MessageType) -> None:
|
||||||
|
self.expected = expected
|
||||||
|
self.actual = actual
|
||||||
|
super().__init__(f"Expected {expected.__name__} but Trezor sent {actual}")
|
||||||
|
Loading…
Reference in New Issue
Block a user