1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-08 05:32:39 +00:00

feat(python): implement session based trezorlib

This commit is contained in:
M1nd3r 2025-02-04 15:19:56 +01:00
parent 7f5764b7d4
commit fbff05a89f
23 changed files with 1307 additions and 827 deletions

View File

@ -7,7 +7,7 @@ import typing as t
from importlib import metadata from importlib import metadata
from . import device from . import device
from .client import TrezorClient from .transport.session import Session
try: try:
cryptography_version = metadata.version("cryptography") cryptography_version = metadata.version("cryptography")
@ -361,7 +361,7 @@ def verify_authentication_response(
def authenticate_device( def authenticate_device(
client: TrezorClient, session: Session,
challenge: bytes | None = None, challenge: bytes | None = None,
*, *,
whitelist: t.Collection[bytes] | None = None, whitelist: t.Collection[bytes] | None = None,
@ -371,7 +371,7 @@ def authenticate_device(
if challenge is None: if challenge is None:
challenge = secrets.token_bytes(16) challenge = secrets.token_bytes(16)
resp = device.authenticate(client, challenge) resp = device.authenticate(session, challenge)
return verify_authentication_response( return verify_authentication_response(
challenge, challenge,

View File

@ -19,16 +19,16 @@ from typing import TYPE_CHECKING
from . import messages from . import messages
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .transport.session import Session
def list_names( def list_names(
client: "TrezorClient", session: "Session",
) -> messages.BenchmarkNames: ) -> messages.BenchmarkNames:
return client.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames) return session.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames)
def run(client: "TrezorClient", name: str) -> messages.BenchmarkResult: def run(session: "Session", name: str) -> messages.BenchmarkResult:
return client.call( return session.call(
messages.BenchmarkRun(name=name), expect=messages.BenchmarkResult messages.BenchmarkRun(name=name), expect=messages.BenchmarkResult
) )

View File

@ -18,20 +18,19 @@ from typing import TYPE_CHECKING
from . import messages from . import messages
from .protobuf import dict_to_proto from .protobuf import dict_to_proto
from .tools import session
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address from .tools import Address
from .transport.session import Session
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> str: ) -> str:
return client.call( return session.call(
messages.BinanceGetAddress( messages.BinanceGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
), ),
@ -40,17 +39,16 @@ def get_address(
def get_public_key( def get_public_key(
client: "TrezorClient", address_n: "Address", show_display: bool = False session: "Session", address_n: "Address", show_display: bool = False
) -> bytes: ) -> bytes:
return client.call( return session.call(
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display), messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display),
expect=messages.BinancePublicKey, expect=messages.BinancePublicKey,
).public_key ).public_key
@session
def sign_tx( def sign_tx(
client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False
) -> messages.BinanceSignedTx: ) -> messages.BinanceSignedTx:
msg = tx_json["msgs"][0] msg = tx_json["msgs"][0]
tx_msg = tx_json.copy() tx_msg = tx_json.copy()
@ -59,7 +57,7 @@ def sign_tx(
tx_msg["chunkify"] = chunkify tx_msg["chunkify"] = chunkify
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg) envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
client.call(envelope, expect=messages.BinanceTxRequest) session.call(envelope, expect=messages.BinanceTxRequest)
if "refid" in msg: if "refid" in msg:
msg = dict_to_proto(messages.BinanceCancelMsg, msg) msg = dict_to_proto(messages.BinanceCancelMsg, msg)
@ -70,4 +68,4 @@ def sign_tx(
else: else:
raise ValueError("can not determine msg type") raise ValueError("can not determine msg type")
return client.call(msg, expect=messages.BinanceSignedTx) return session.call(msg, expect=messages.BinanceSignedTx)

View File

@ -25,11 +25,11 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple
from typing_extensions import Protocol, TypedDict from typing_extensions import Protocol, TypedDict
from . import exceptions, messages from . import exceptions, messages
from .tools import _return_success, prepare_message_bytes, session from .tools import _return_success, prepare_message_bytes
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address from .tools import Address
from .transport.session import Session
class ScriptSig(TypedDict): class ScriptSig(TypedDict):
asm: str asm: str
@ -105,7 +105,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType:
def get_public_node( def get_public_node(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
ecdsa_curve_name: Optional[str] = None, ecdsa_curve_name: Optional[str] = None,
show_display: bool = False, show_display: bool = False,
@ -116,12 +116,12 @@ def get_public_node(
unlock_path_mac: Optional[bytes] = None, unlock_path_mac: Optional[bytes] = None,
) -> messages.PublicKey: ) -> messages.PublicKey:
if unlock_path: if unlock_path:
client.call( session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac), messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
expect=messages.UnlockedPathRequest, expect=messages.UnlockedPathRequest,
) )
return client.call( return session.call(
messages.GetPublicKey( messages.GetPublicKey(
address_n=n, address_n=n,
ecdsa_curve_name=ecdsa_curve_name, ecdsa_curve_name=ecdsa_curve_name,
@ -139,7 +139,7 @@ def get_address(*args: Any, **kwargs: Any) -> str:
def get_authenticated_address( def get_authenticated_address(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
n: "Address", n: "Address",
show_display: bool = False, show_display: bool = False,
@ -151,12 +151,12 @@ def get_authenticated_address(
chunkify: bool = False, chunkify: bool = False,
) -> messages.Address: ) -> messages.Address:
if unlock_path: if unlock_path:
client.call( session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac), messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
expect=messages.UnlockedPathRequest, expect=messages.UnlockedPathRequest,
) )
return client.call( return session.call(
messages.GetAddress( messages.GetAddress(
address_n=n, address_n=n,
coin_name=coin_name, coin_name=coin_name,
@ -171,13 +171,13 @@ def get_authenticated_address(
def get_ownership_id( def get_ownership_id(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
n: "Address", n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None, multisig: Optional[messages.MultisigRedeemScriptType] = None,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> bytes: ) -> bytes:
return client.call( return session.call(
messages.GetOwnershipId( messages.GetOwnershipId(
address_n=n, address_n=n,
coin_name=coin_name, coin_name=coin_name,
@ -188,8 +188,9 @@ def get_ownership_id(
).ownership_id ).ownership_id
# TODO this is used by tests only
def get_ownership_proof( def get_ownership_proof(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
n: "Address", n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None, multisig: Optional[messages.MultisigRedeemScriptType] = None,
@ -200,9 +201,9 @@ def get_ownership_proof(
preauthorized: bool = False, preauthorized: bool = False,
) -> Tuple[bytes, bytes]: ) -> Tuple[bytes, bytes]:
if preauthorized: if preauthorized:
client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest) session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
res = client.call( res = session.call(
messages.GetOwnershipProof( messages.GetOwnershipProof(
address_n=n, address_n=n,
coin_name=coin_name, coin_name=coin_name,
@ -219,7 +220,7 @@ def get_ownership_proof(
def sign_message( def sign_message(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
n: "Address", n: "Address",
message: AnyStr, message: AnyStr,
@ -227,7 +228,7 @@ def sign_message(
no_script_type: bool = False, no_script_type: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> messages.MessageSignature: ) -> messages.MessageSignature:
return client.call( return session.call(
messages.SignMessage( messages.SignMessage(
coin_name=coin_name, coin_name=coin_name,
address_n=n, address_n=n,
@ -241,7 +242,7 @@ def sign_message(
def verify_message( def verify_message(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
address: str, address: str,
signature: bytes, signature: bytes,
@ -249,7 +250,7 @@ def verify_message(
chunkify: bool = False, chunkify: bool = False,
) -> bool: ) -> bool:
try: try:
client.call( session.call(
messages.VerifyMessage( messages.VerifyMessage(
address=address, address=address,
signature=signature, signature=signature,
@ -264,9 +265,9 @@ def verify_message(
return False return False
@session # @session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
inputs: Sequence[messages.TxInputType], inputs: Sequence[messages.TxInputType],
outputs: Sequence[messages.TxOutputType], outputs: Sequence[messages.TxOutputType],
@ -314,14 +315,14 @@ def sign_tx(
setattr(signtx, name, value) setattr(signtx, name, value)
if unlock_path: if unlock_path:
client.call( session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac), messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
expect=messages.UnlockedPathRequest, expect=messages.UnlockedPathRequest,
) )
elif preauthorized: elif preauthorized:
client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest) session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
res = client.call(signtx, expect=messages.TxRequest) res = session.call(signtx, expect=messages.TxRequest)
# Prepare structure for signatures # Prepare structure for signatures
signatures: List[Optional[bytes]] = [None] * len(inputs) signatures: List[Optional[bytes]] = [None] * len(inputs)
@ -380,7 +381,7 @@ def sign_tx(
if res.request_type == R.TXPAYMENTREQ: if res.request_type == R.TXPAYMENTREQ:
assert res.details.request_index is not None assert res.details.request_index is not None
msg = payment_reqs[res.details.request_index] msg = payment_reqs[res.details.request_index]
res = client.call(msg, expect=messages.TxRequest) res = session.call(msg, expect=messages.TxRequest)
else: else:
msg = messages.TransactionType() msg = messages.TransactionType()
if res.request_type == R.TXMETA: if res.request_type == R.TXMETA:
@ -410,7 +411,7 @@ def sign_tx(
f"Unknown request type - {res.request_type}." f"Unknown request type - {res.request_type}."
) )
res = client.call(messages.TxAck(tx=msg), expect=messages.TxRequest) res = session.call(messages.TxAck(tx=msg), expect=messages.TxRequest)
for i, sig in zip(inputs, signatures): for i, sig in zip(inputs, signatures):
if i.script_type != messages.InputScriptType.EXTERNAL and sig is None: if i.script_type != messages.InputScriptType.EXTERNAL and sig is None:
@ -420,7 +421,7 @@ def sign_tx(
def authorize_coinjoin( def authorize_coinjoin(
client: "TrezorClient", session: "Session",
coordinator: str, coordinator: str,
max_rounds: int, max_rounds: int,
max_coordinator_fee_rate: int, max_coordinator_fee_rate: int,
@ -429,7 +430,7 @@ def authorize_coinjoin(
coin_name: str, coin_name: str,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> str | None: ) -> str | None:
resp = client.call( resp = session.call(
messages.AuthorizeCoinJoin( messages.AuthorizeCoinJoin(
coordinator=coordinator, coordinator=coordinator,
max_rounds=max_rounds, max_rounds=max_rounds,

View File

@ -35,7 +35,7 @@ from . import messages as m
from . import tools from . import tools
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .transport.session import Session
PROTOCOL_MAGICS = { PROTOCOL_MAGICS = {
"mainnet": 764824073, "mainnet": 764824073,
@ -818,7 +818,7 @@ def _get_collateral_inputs_items(
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_parameters: m.CardanoAddressParametersType, address_parameters: m.CardanoAddressParametersType,
protocol_magic: int = PROTOCOL_MAGICS["mainnet"], protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
network_id: int = NETWORK_IDS["mainnet"], network_id: int = NETWORK_IDS["mainnet"],
@ -826,7 +826,7 @@ def get_address(
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS, derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
chunkify: bool = False, chunkify: bool = False,
) -> str: ) -> str:
return client.call( return session.call(
m.CardanoGetAddress( m.CardanoGetAddress(
address_parameters=address_parameters, address_parameters=address_parameters,
protocol_magic=protocol_magic, protocol_magic=protocol_magic,
@ -840,12 +840,12 @@ def get_address(
def get_public_key( def get_public_key(
client: "TrezorClient", session: "Session",
address_n: List[int], address_n: List[int],
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS, derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
show_display: bool = False, show_display: bool = False,
) -> m.CardanoPublicKey: ) -> m.CardanoPublicKey:
return client.call( return session.call(
m.CardanoGetPublicKey( m.CardanoGetPublicKey(
address_n=address_n, address_n=address_n,
derivation_type=derivation_type, derivation_type=derivation_type,
@ -856,12 +856,12 @@ def get_public_key(
def get_native_script_hash( def get_native_script_hash(
client: "TrezorClient", session: "Session",
native_script: m.CardanoNativeScript, native_script: m.CardanoNativeScript,
display_format: m.CardanoNativeScriptHashDisplayFormat = m.CardanoNativeScriptHashDisplayFormat.HIDE, display_format: m.CardanoNativeScriptHashDisplayFormat = m.CardanoNativeScriptHashDisplayFormat.HIDE,
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS, derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
) -> m.CardanoNativeScriptHash: ) -> m.CardanoNativeScriptHash:
return client.call( return session.call(
m.CardanoGetNativeScriptHash( m.CardanoGetNativeScriptHash(
script=native_script, script=native_script,
display_format=display_format, display_format=display_format,
@ -872,7 +872,7 @@ def get_native_script_hash(
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
signing_mode: m.CardanoTxSigningMode, signing_mode: m.CardanoTxSigningMode,
inputs: List[InputWithPath], inputs: List[InputWithPath],
outputs: List[OutputWithData], outputs: List[OutputWithData],
@ -907,7 +907,7 @@ def sign_tx(
signing_mode, signing_mode,
) )
response = client.call( response = session.call(
m.CardanoSignTxInit( m.CardanoSignTxInit(
signing_mode=signing_mode, signing_mode=signing_mode,
inputs_count=len(inputs), inputs_count=len(inputs),
@ -942,12 +942,12 @@ def sign_tx(
_get_certificates_items(certificates), _get_certificates_items(certificates),
withdrawals, withdrawals,
): ):
response = client.call(tx_item, expect=m.CardanoTxItemAck) response = session.call(tx_item, expect=m.CardanoTxItemAck)
sign_tx_response: Dict[str, Any] = {} sign_tx_response: Dict[str, Any] = {}
if auxiliary_data is not None: if auxiliary_data is not None:
auxiliary_data_supplement = client.call( auxiliary_data_supplement = session.call(
auxiliary_data, expect=m.CardanoTxAuxiliaryDataSupplement auxiliary_data, expect=m.CardanoTxAuxiliaryDataSupplement
) )
if ( if (
@ -958,25 +958,25 @@ def sign_tx(
auxiliary_data_supplement.__dict__ auxiliary_data_supplement.__dict__
) )
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck) response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck)
for tx_item in chain( for tx_item in chain(
_get_mint_items(mint), _get_mint_items(mint),
_get_collateral_inputs_items(collateral_inputs), _get_collateral_inputs_items(collateral_inputs),
required_signers, required_signers,
): ):
response = client.call(tx_item, expect=m.CardanoTxItemAck) response = session.call(tx_item, expect=m.CardanoTxItemAck)
if collateral_return is not None: if collateral_return is not None:
for tx_item in _get_output_items(collateral_return): for tx_item in _get_output_items(collateral_return):
response = client.call(tx_item, expect=m.CardanoTxItemAck) response = session.call(tx_item, expect=m.CardanoTxItemAck)
for reference_input in reference_inputs: for reference_input in reference_inputs:
response = client.call(reference_input, expect=m.CardanoTxItemAck) response = session.call(reference_input, expect=m.CardanoTxItemAck)
sign_tx_response["witnesses"] = [] sign_tx_response["witnesses"] = []
for witness_request in witness_requests: for witness_request in witness_requests:
response = client.call(witness_request, expect=m.CardanoTxWitnessResponse) response = session.call(witness_request, expect=m.CardanoTxWitnessResponse)
sign_tx_response["witnesses"].append( sign_tx_response["witnesses"].append(
{ {
"type": response.type, "type": response.type,
@ -986,9 +986,9 @@ def sign_tx(
} }
) )
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash) response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash)
sign_tx_response["tx_hash"] = response.tx_hash sign_tx_response["tx_hash"] = response.tx_hash
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished) response = session.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished)
return sign_tx_response return sign_tx_response

View File

@ -13,28 +13,24 @@
# #
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations from __future__ import annotations
import logging import logging
import os import os
import warnings import typing as t
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from enum import IntEnum
from mnemonic import Mnemonic from . import mapping, messages, models
from .mapping import ProtobufMapping
from .tools import parse_path
from .transport import Transport, get_transport
from .transport.thp.channel_data import ChannelData
from .transport.thp.protocol_and_channel import ProtocolAndChannel
from .transport.thp.protocol_v1 import ProtocolV1
from .transport.thp.protocol_v2 import ProtocolV2
from . import exceptions, mapping, messages, models if t.TYPE_CHECKING:
from .log import DUMP_BYTES from .transport.session import Session
from .messages import Capability
from .protobuf import MessageType
from .tools import parse_path, session
if TYPE_CHECKING:
from .transport import Transport
from .ui import TrezorClientUI
UI = TypeVar("UI", bound="TrezorClientUI")
MT = TypeVar("MT", bound=MessageType)
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -51,8 +47,205 @@ Or visit https://suite.trezor.io/
""".strip() """.strip()
LOG = logging.getLogger(__name__)
class ProtocolVersion(IntEnum):
UNKNOWN = 0x00
PROTOCOL_V1 = 0x01 # Codec
PROTOCOL_V2 = 0x02 # THP
class TrezorClient:
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
_seedless_session: Session | None = None
_features: messages.Features | None = None
_protocol_version: int
_setup_pin: str | None = None # Should by used only by conftest
def __init__(
self,
transport: Transport,
protobuf_mapping: ProtobufMapping | None = None,
protocol: ProtocolAndChannel | None = None,
) -> None:
self._is_invalidated: bool = False
self.transport = transport
if protobuf_mapping is None:
self.mapping = mapping.DEFAULT_MAPPING
else:
self.mapping = protobuf_mapping
if protocol is None:
self.protocol = self._get_protocol()
else:
self.protocol = protocol
self.protocol.mapping = self.mapping
if isinstance(self.protocol, ProtocolV1):
self._protocol_version = ProtocolVersion.PROTOCOL_V1
elif isinstance(self.protocol, ProtocolV2):
self._protocol_version = ProtocolVersion.PROTOCOL_V2
else:
self._protocol_version = ProtocolVersion.UNKNOWN
@classmethod
def resume(
cls,
transport: Transport,
channel_data: ChannelData,
protobuf_mapping: ProtobufMapping | None = None,
) -> TrezorClient:
if protobuf_mapping is None:
protobuf_mapping = mapping.DEFAULT_MAPPING
protocol_v1 = ProtocolV1(transport, protobuf_mapping)
if channel_data.protocol_version_major == 2:
try:
protocol_v1.write(messages.Ping(message="Sanity check - to resume"))
except Exception as e:
print(type(e))
response = protocol_v1.read()
if (
isinstance(response, messages.Failure)
and response.code == messages.FailureType.InvalidProtocol
):
protocol = ProtocolV2(transport, protobuf_mapping, channel_data)
protocol.write(0, messages.Ping())
response = protocol.read(0)
if not isinstance(response, messages.Success):
LOG.debug("Failed to resume ProtocolV2")
raise Exception("Failed to resume ProtocolV2")
LOG.debug("Protocol V2 detected - can be resumed")
else:
LOG.debug("Failed to resume ProtocolV2")
raise Exception("Failed to resume ProtocolV2")
else:
protocol = ProtocolV1(transport, protobuf_mapping, channel_data)
return TrezorClient(transport, protobuf_mapping, protocol)
def get_session(
self,
passphrase: str | object | None = None,
derive_cardano: bool = False,
session_id: int = 0,
) -> Session:
"""
Returns initialized session (with derived seed).
Will fail if the device is not initialized
"""
from .transport.session import SessionV1, SessionV2
if isinstance(self.protocol, ProtocolV1):
if passphrase is None:
passphrase = ""
return SessionV1.new(self, passphrase, derive_cardano)
if isinstance(self.protocol, ProtocolV2):
assert isinstance(passphrase, str) or passphrase is None
return SessionV2.new(self, passphrase, derive_cardano, session_id)
raise NotImplementedError # TODO
def resume_session(self, session: Session):
"""
Note: this function potentially modifies the input session.
"""
from .debuglink import SessionDebugWrapper
from .transport.session import SessionV1, SessionV2
if isinstance(session, SessionDebugWrapper):
session = session._session
if isinstance(session, SessionV2):
return session
elif isinstance(session, SessionV1):
session.init_session()
return session
else:
raise NotImplementedError
def get_seedless_session(self, new_session: bool = False) -> Session:
from .transport.session import SessionV1, SessionV2
if not new_session and self._seedless_session is not None:
return self._seedless_session
if isinstance(self.protocol, ProtocolV1):
self._seedless_session = SessionV1.new(
client=self,
passphrase="",
derive_cardano=False,
)
elif isinstance(self.protocol, ProtocolV2):
self._seedless_session = SessionV2(client=self, id=b"\x00")
assert self._seedless_session is not None
return self._seedless_session
def invalidate(self) -> None:
self._is_invalidated = True
@property
def features(self) -> messages.Features:
if self._features is None:
self._features = self.protocol.get_features()
assert self._features is not None
return self._features
@property
def protocol_version(self) -> int:
return self._protocol_version
@property
def model(self) -> models.TrezorModel:
f = self.features
model = models.by_name(f.model or "1")
if model is None:
raise RuntimeError(
"Unsupported Trezor model"
f" (internal_model: {f.internal_model}, model: {f.model})"
)
return model
@property
def version(self) -> tuple[int, int, int]:
f = self.features
ver = (
f.major_version,
f.minor_version,
f.patch_version,
)
return ver
@property
def is_invalidated(self) -> bool:
return self._is_invalidated
def refresh_features(self) -> None:
self.protocol.update_features()
self._features = self.protocol.get_features()
def _get_protocol(self) -> ProtocolAndChannel:
self.transport.open()
protocol = ProtocolV1(self.transport, mapping.DEFAULT_MAPPING)
protocol.write(messages.Initialize())
response = protocol.read()
self.transport.close()
if isinstance(response, messages.Failure):
if response.code == messages.FailureType.InvalidProtocol:
LOG.debug("Protocol V2 detected")
protocol = ProtocolV2(self.transport, self.mapping)
return protocol
def get_default_client( def get_default_client(
path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any path: t.Optional[str] = None,
**kwargs: t.Any,
) -> "TrezorClient": ) -> "TrezorClient":
"""Get a client for a connected Trezor device. """Get a client for a connected Trezor device.
@ -62,436 +255,10 @@ def get_default_client(
the value of TREZOR_PATH env variable, or finds first connected Trezor. the value of TREZOR_PATH env variable, or finds first connected Trezor.
If no UI is supplied, instantiates the default CLI UI. If no UI is supplied, instantiates the default CLI UI.
""" """
from .transport import get_transport
from .ui import ClickUI
if path is None: if path is None:
path = os.getenv("TREZOR_PATH") path = os.getenv("TREZOR_PATH")
transport = get_transport(path, prefix_search=True) transport = get_transport(path, prefix_search=True)
if ui is None:
ui = ClickUI()
return TrezorClient(transport, ui, **kwargs) return TrezorClient(transport, **kwargs)
class TrezorClient(Generic[UI]):
"""Trezor client, a connection to a Trezor device.
This class allows you to manage connection state, send and receive protobuf
messages, handle user interactions, and perform some generic tasks
(send a cancel message, initialize or clear a session, ping the device).
"""
model: models.TrezorModel
transport: "Transport"
session_id: Optional[bytes]
ui: UI
features: messages.Features
def __init__(
self,
transport: "Transport",
ui: UI,
session_id: Optional[bytes] = None,
derive_cardano: Optional[bool] = None,
model: Optional[models.TrezorModel] = None,
_init_device: bool = True,
) -> None:
"""Create a TrezorClient instance.
You have to provide a `transport`, i.e., a raw connection to the device. You can
use `trezorlib.transport.get_transport` to find one.
You have to provide a UI implementation for the three kinds of interaction:
- button request (notify the user that their interaction is needed)
- PIN request (on T1, ask the user to input numbers for a PIN matrix)
- passphrase request (ask the user to enter a passphrase) See `trezorlib.ui` for
details.
You can supply a `session_id` you might have saved in the previous session. If
you do, the user might not need to enter their passphrase again.
You can provide Trezor model information. If not provided, it is detected from
the model name reported at initialization time.
By default, the instance will open a connection to the Trezor device, send an
`Initialize` message, set up the `features` field from the response, and connect
to a session. By specifying `_init_device=False`, this step is skipped. Notably,
this means that `client.features` is unset. Use `client.init_device()` or
`client.refresh_features()` to fix that, otherwise A LOT OF THINGS will break.
Only use this if you are _sure_ that you know what you are doing. This feature
might be removed at any time.
"""
LOG.info(f"creating client instance for device: {transport.get_path()}")
# Here, self.model could be set to None. Unless _init_device is False, it will
# get correctly reconfigured as part of the init_device flow.
self.model = model # type: ignore ["None" is incompatible with "TrezorModel"]
if self.model:
self.mapping = self.model.default_mapping
else:
self.mapping = mapping.DEFAULT_MAPPING
self.transport = transport
self.ui = ui
self.session_counter = 0
self.session_id = session_id
if _init_device:
self.init_device(session_id=session_id, derive_cardano=derive_cardano)
def open(self) -> None:
if self.session_counter == 0:
self.transport.begin_session()
self.session_counter += 1
def close(self) -> None:
self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
# TODO call EndSession here?
self.transport.end_session()
def cancel(self) -> None:
self._raw_write(messages.Cancel())
def call_raw(self, msg: MessageType) -> MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
self._raw_write(msg)
return self._raw_read()
def _raw_write(self, msg: MessageType) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
LOG.debug(
f"sending message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
msg_type, msg_bytes = self.mapping.encode(msg)
LOG.log(
DUMP_BYTES,
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
self.transport.write(msg_type, msg_bytes)
def _raw_read(self) -> MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
msg_type, msg_bytes = self.transport.read()
LOG.log(
DUMP_BYTES,
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
msg = self.mapping.decode(msg_type, msg_bytes)
LOG.debug(
f"received message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
return msg
def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType:
try:
pin = self.ui.get_pin(msg.type)
except exceptions.Cancelled:
self.call_raw(messages.Cancel())
raise
if any(d not in "123456789" for d in pin) or not (
1 <= len(pin) <= MAX_PIN_LENGTH
):
self.call_raw(messages.Cancel())
raise ValueError("Invalid PIN provided")
resp = self.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise exceptions.PinException(resp.code, resp.message)
else:
return resp
def _callback_passphrase(self, msg: messages.PassphraseRequest) -> MessageType:
available_on_device = Capability.PassphraseEntry in self.features.capabilities
def send_passphrase(
passphrase: Optional[str] = None, on_device: Optional[bool] = None
) -> MessageType:
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
resp = self.call_raw(msg)
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
self.session_id = resp.state
resp = self.call_raw(messages.Deprecated_PassphraseStateAck())
return resp
# short-circuit old style entry
if msg._on_device is True:
return send_passphrase(None, None)
try:
passphrase = self.ui.get_passphrase(available_on_device=available_on_device)
except exceptions.Cancelled:
self.call_raw(messages.Cancel())
raise
if passphrase is PASSPHRASE_ON_DEVICE:
if not available_on_device:
self.call_raw(messages.Cancel())
raise RuntimeError("Device is not capable of entering passphrase")
else:
return send_passphrase(on_device=True)
# else process host-entered passphrase
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
self.call_raw(messages.Cancel())
raise ValueError("Passphrase too long")
return send_passphrase(passphrase, on_device=False)
def _callback_button(self, msg: messages.ButtonRequest) -> MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
self._raw_write(messages.ButtonAck())
self.ui.button_request(msg)
return self._raw_read()
@session
def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT:
self.check_firmware_version()
resp = self.call_raw(msg)
while True:
if isinstance(resp, messages.PinMatrixRequest):
resp = self._callback_pin(resp)
elif isinstance(resp, messages.PassphraseRequest):
resp = self._callback_passphrase(resp)
elif isinstance(resp, messages.ButtonRequest):
resp = self._callback_button(resp)
elif isinstance(resp, messages.Failure):
if resp.code == messages.FailureType.ActionCancelled:
raise exceptions.Cancelled
raise exceptions.TrezorFailure(resp)
elif not isinstance(resp, expect):
raise exceptions.UnexpectedMessageError(expect, resp)
else:
return resp
def _refresh_features(self, features: messages.Features) -> None:
"""Update internal fields based on passed-in Features message."""
if not self.model:
# Trezor Model One bootloader 1.8.0 or older does not send model name
model = models.by_internal_name(features.internal_model)
if model is None:
model = models.by_name(features.model or "1")
if model is None:
raise RuntimeError(
"Unsupported Trezor model"
f" (internal_model: {features.internal_model}, model: {features.model})"
)
self.model = model
if features.vendor not in self.model.vendors:
raise RuntimeError("Unsupported device")
self.features = features
self.version = (
self.features.major_version,
self.features.minor_version,
self.features.patch_version,
)
self.check_firmware_version(warn_only=True)
if self.features.session_id is not None:
self.session_id = self.features.session_id
self.features.session_id = None
@session
def refresh_features(self) -> messages.Features:
"""Reload features from the device.
Should be called after changing settings or performing operations that affect
device state.
"""
resp = self.call_raw(messages.GetFeatures())
if not isinstance(resp, messages.Features):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._refresh_features(resp)
return resp
@session
def init_device(
self,
*,
session_id: Optional[bytes] = None,
new_session: bool = False,
derive_cardano: Optional[bool] = None,
) -> Optional[bytes]:
"""Initialize the device and return a session ID.
You can optionally specify a session ID. If the session still exists on the
device, the same session ID will be returned and the session is resumed.
Otherwise a different session ID is returned.
Specify `new_session=True` to open a fresh session. Since firmware version
1.9.0/2.3.0, the previous session will remain cached on the device, and can be
resumed by calling `init_device` again with the appropriate session ID.
If neither `new_session` nor `session_id` is specified, the current session ID
will be reused. If no session ID was cached, a new session ID will be allocated
and returned.
# Version notes:
Trezor One older than 1.9.0 does not have session management. Optional arguments
have no effect and the function returns None
Trezor T older than 2.3.0 does not have session cache. Requesting a new session
will overwrite the old one. In addition, this function will always return None.
A valid session_id can be obtained from the `session_id` attribute, but only
after a passphrase-protected call is performed. You can use the following code:
>>> client.init_device()
>>> client.ensure_unlocked()
>>> valid_session_id = client.session_id
"""
if new_session:
self.session_id = None
elif session_id is not None:
self.session_id = session_id
resp = self.call_raw(
messages.Initialize(
session_id=self.session_id,
derive_cardano=derive_cardano,
)
)
if isinstance(resp, messages.Failure):
# can happen if `derive_cardano` does not match the current session
raise exceptions.TrezorFailure(resp)
if not isinstance(resp, messages.Features):
raise exceptions.TrezorException("Unexpected response to Initialize")
if self.session_id is not None and resp.session_id == self.session_id:
LOG.info("Successfully resumed session")
elif session_id is not None:
LOG.info("Failed to resume session")
# TT < 2.3.0 compatibility:
# _refresh_features will clear out the session_id field. We want this function
# to return its value, so that callers can rely on it being either a valid
# session_id, or None if we can't do that.
# Older TT FW does not report session_id in Features and self.session_id might
# be invalid because TT will not allocate a session_id until a passphrase
# exchange happens.
reported_session_id = resp.session_id
self._refresh_features(resp)
return reported_session_id
def is_outdated(self) -> bool:
if self.features.bootloader_mode:
return False
return self.version < self.model.minimum_version
def check_firmware_version(self, warn_only: bool = False) -> None:
if self.is_outdated():
if warn_only:
warnings.warn("Firmware is out of date", stacklevel=2)
else:
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
def ping(self, msg: str, button_protection: bool = False) -> str:
# We would like ping to work on any valid TrezorClient instance, but
# due to the protection modes, we need to go through self.call, and that will
# raise an exception if the firmware is too old.
# So we short-circuit the simplest variant of ping with call_raw.
if not button_protection:
# XXX this should be: `with self:`
try:
self.open()
resp = self.call_raw(messages.Ping(message=msg))
if isinstance(resp, messages.ButtonRequest):
# device is PIN-locked.
# respond and hope for the best
resp = self._callback_button(resp)
resp = messages.Success.ensure_isinstance(resp)
assert resp.message is not None
return resp.message
finally:
self.close()
resp = self.call(
messages.Ping(message=msg, button_protection=button_protection),
expect=messages.Success,
)
assert resp.message is not None
return resp.message
def get_device_id(self) -> Optional[str]:
return self.features.device_id
@session
def lock(self, *, _refresh_features: bool = True) -> None:
"""Lock the device.
If the device does not have a PIN configured, this will do nothing.
Otherwise, a lock screen will be shown and the device will prompt for PIN
before further actions.
This call does _not_ invalidate passphrase cache. If passphrase is in use,
the device will not prompt for it after unlocking.
To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate
passphrase cache, use `clear_session()`.
"""
# Private argument _refresh_features can be used internally to avoid
# refreshing in cases where we will refresh soon anyway. This is used
# in TrezorClient.clear_session()
self.call(messages.LockDevice())
if _refresh_features:
self.refresh_features()
@session
def ensure_unlocked(self) -> None:
"""Ensure the device is unlocked and a passphrase is cached.
If the device is locked, this will prompt for PIN. If passphrase is enabled
and no passphrase is cached for the current session, the device will also
prompt for passphrase.
After calling this method, further actions on the device will not prompt for
PIN or passphrase until the device is locked or the session becomes invalid.
"""
from .btc import get_address
get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
self.refresh_features()
def end_session(self) -> None:
"""Close the current session and clear cached passphrase.
The session will become invalid until `init_device()` is called again.
If passphrase is enabled, further actions will prompt for it again.
This is a no-op in bootloader mode, as it does not support session management.
"""
# since: 2.3.4, 1.9.4
try:
if not self.features.bootloader_mode:
self.call(messages.EndSession())
except exceptions.TrezorFailure:
# A failure most likely means that the FW version does not support
# the EndSession call. We ignore the failure and clear the local session_id.
# The client-side end result is identical.
pass
self.session_id = None
@session
def clear_session(self) -> None:
"""Lock the device and present a fresh session.
The current session will be invalidated and a new one will be started. If the
device has PIN enabled, it will become locked.
Equivalent to calling `lock()`, `end_session()` and `init_device()`.
"""
self.lock(_refresh_features=False)
self.end_session()
self.init_device(new_session=True)

View File

@ -21,55 +21,55 @@ import logging
import re import re
import textwrap import textwrap
import time import time
import typing as t
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from enum import Enum, IntEnum, auto from enum import Enum, IntEnum, auto
from itertools import zip_longest from itertools import zip_longest
from pathlib import Path from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
Sequence,
Tuple,
Union,
)
from mnemonic import Mnemonic from mnemonic import Mnemonic
from . import mapping, messages, models, protobuf from . import btc, mapping, messages, models, protobuf
from .client import TrezorClient from .client import (
from .exceptions import TrezorFailure, UnexpectedMessageError MAX_PASSPHRASE_LENGTH,
MAX_PIN_LENGTH,
PASSPHRASE_ON_DEVICE,
TrezorClient,
)
from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError
from .log import DUMP_BYTES from .log import DUMP_BYTES
from .messages import DebugWaitType from .messages import Capability, DebugWaitType
from .protobuf import MessageType
from .tools import parse_path
from .transport.session import Session, SessionV1
from .transport.thp.protocol_v1 import ProtocolV1
if TYPE_CHECKING: if t.TYPE_CHECKING:
from typing_extensions import Protocol from typing_extensions import Protocol
from .messages import PinMatrixRequestType from .messages import PinMatrixRequestType
from .transport import Transport from .transport import Transport
ExpectedMessage = Union[ ExpectedMessage = t.Union[
protobuf.MessageType, type[protobuf.MessageType], "MessageFilter" protobuf.MessageType, t.Type[protobuf.MessageType], "MessageFilter"
] ]
AnyDict = Dict[str, Any] AnyDict = t.Dict[str, t.Any]
class InputFunc(Protocol): class InputFunc(Protocol):
def __call__( def __call__(
self, self,
hold_ms: int | None = None, hold_ms: int | None = None,
) -> "None": ... ) -> "None": ...
InputFlowType = Generator[None, messages.ButtonRequest, None] InputFlowType = t.Generator[None, messages.ButtonRequest, None]
EXPECTED_RESPONSES_CONTEXT_LINES = 3 EXPECTED_RESPONSES_CONTEXT_LINES = 3
PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0")
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -107,11 +107,11 @@ class UnstructuredJSONReader:
except json.JSONDecodeError: except json.JSONDecodeError:
self.dict = {} self.dict = {}
def top_level_value(self, key: str) -> Any: def top_level_value(self, key: str) -> t.Any:
return self.dict.get(key) return self.dict.get(key)
def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]: def find_objects_with_key_and_value(self, key: str, value: t.Any) -> list[AnyDict]:
def recursively_find(data: Any) -> Iterator[Any]: def recursively_find(data: t.Any) -> t.Iterator[t.Any]:
if isinstance(data, dict): if isinstance(data, dict):
if data.get(key) == value: if data.get(key) == value:
yield data yield data
@ -124,7 +124,7 @@ class UnstructuredJSONReader:
return list(recursively_find(self.dict)) return list(recursively_find(self.dict))
def find_unique_object_with_key_and_value( def find_unique_object_with_key_and_value(
self, key: str, value: Any self, key: str, value: t.Any
) -> AnyDict | None: ) -> AnyDict | None:
objects = self.find_objects_with_key_and_value(key, value) objects = self.find_objects_with_key_and_value(key, value)
if not objects: if not objects:
@ -132,8 +132,10 @@ class UnstructuredJSONReader:
assert len(objects) == 1 assert len(objects) == 1
return objects[0] return objects[0]
def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]: def find_values_by_key(
def recursively_find(data: Any) -> Iterator[Any]: self, key: str, only_type: type | None = None
) -> list[t.Any]:
def recursively_find(data: t.Any) -> t.Iterator[t.Any]:
if isinstance(data, dict): if isinstance(data, dict):
if key in data: if key in data:
yield data[key] yield data[key]
@ -151,8 +153,8 @@ class UnstructuredJSONReader:
return values return values
def find_unique_value_by_key( def find_unique_value_by_key(
self, key: str, default: Any, only_type: type | None = None self, key: str, default: t.Any, only_type: type | None = None
) -> Any: ) -> t.Any:
values = self.find_values_by_key(key, only_type=only_type) values = self.find_values_by_key(key, only_type=only_type)
if not values: if not values:
return default return default
@ -163,7 +165,7 @@ class UnstructuredJSONReader:
class LayoutContent(UnstructuredJSONReader): class LayoutContent(UnstructuredJSONReader):
"""Contains helper functions to extract specific parts of the layout.""" """Contains helper functions to extract specific parts of the layout."""
def __init__(self, json_tokens: Sequence[str]) -> None: def __init__(self, json_tokens: t.Sequence[str]) -> None:
json_str = "".join(json_tokens) json_str = "".join(json_tokens)
super().__init__(json_str) super().__init__(json_str)
@ -429,6 +431,7 @@ class DebugLink:
self.allow_interactions = auto_interact self.allow_interactions = auto_interact
self.mapping = mapping.DEFAULT_MAPPING self.mapping = mapping.DEFAULT_MAPPING
self.protocol = ProtocolV1(self.transport, self.mapping)
# To be set by TrezorClientDebugLink (is not known during creation time) # To be set by TrezorClientDebugLink (is not known during creation time)
self.model: models.TrezorModel | None = None self.model: models.TrezorModel | None = None
self.version: tuple[int, int, int] = (0, 0, 0) self.version: tuple[int, int, int] = (0, 0, 0)
@ -471,10 +474,16 @@ class DebugLink:
return LayoutType.from_model(self.model) return LayoutType.from_model(self.model)
def open(self) -> None: def open(self) -> None:
self.transport.begin_session() self.transport.open()
# raise NotImplementedError
# TODO is this needed?
# self.transport.deprecated_begin_session()
def close(self) -> None: def close(self) -> None:
self.transport.end_session() pass
# raise NotImplementedError
# TODO is this needed?
# self.transport.deprecated_end_session()
def _write(self, msg: protobuf.MessageType) -> None: def _write(self, msg: protobuf.MessageType) -> None:
if self.waiting_for_layout_change: if self.waiting_for_layout_change:
@ -491,15 +500,10 @@ class DebugLink:
DUMP_BYTES, DUMP_BYTES,
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
) )
self.transport.write(msg_type, msg_bytes) self.protocol.write(msg)
def _read(self) -> protobuf.MessageType: def _read(self) -> protobuf.MessageType:
ret_type, ret_bytes = self.transport.read() msg = self.protocol.read()
LOG.log(
DUMP_BYTES,
f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}",
)
msg = self.mapping.decode(ret_type, ret_bytes)
# Collapse tokens to make log use less lines. # Collapse tokens to make log use less lines.
msg_for_log = msg msg_for_log = msg
@ -513,7 +517,7 @@ class DebugLink:
) )
return msg return msg
def _call(self, msg: protobuf.MessageType) -> Any: def _call(self, msg: protobuf.MessageType) -> t.Any:
self._write(msg) self._write(msg)
return self._read() return self._read()
@ -531,6 +535,25 @@ class DebugLink:
raise TrezorFailure(result) raise TrezorFailure(result)
return result return result
def pairing_info(
self,
thp_channel_id: bytes | None = None,
handshake_hash: bytes | None = None,
nfc_secret_host: bytes | None = None,
) -> messages.DebugLinkPairingInfo:
result = self._call(
messages.DebugLinkGetPairingInfo(
channel_id=thp_channel_id,
handshake_hash=handshake_hash,
nfc_secret_host=nfc_secret_host,
)
)
while not isinstance(result, (messages.Failure, messages.DebugLinkPairingInfo)):
result = self._read()
if isinstance(result, messages.Failure):
raise TrezorFailure(result)
return result
def read_layout(self, wait: bool | None = None) -> LayoutContent: def read_layout(self, wait: bool | None = None) -> LayoutContent:
""" """
Force waiting for the layout by setting `wait=True`. Force not waiting by Force waiting for the layout by setting `wait=True`. Force not waiting by
@ -547,7 +570,7 @@ class DebugLink:
def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent: def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent:
# Next layout change will be caused by external event # Next layout change will be caused by external event
# (e.g. device being auto-locked or as a result of device_handler.run(xxx)) # (e.g. device being auto-locked or as a result of device_handler.run_with_session(xxx))
# and not by our debug actions/decisions. # and not by our debug actions/decisions.
# Resetting the debug state so we wait for the next layout change # Resetting the debug state so we wait for the next layout change
# (and do not return the current state). # (and do not return the current state).
@ -562,7 +585,7 @@ class DebugLink:
return LayoutContent(obj.tokens) return LayoutContent(obj.tokens)
@contextmanager @contextmanager
def wait_for_layout_change(self) -> Iterator[None]: def wait_for_layout_change(self) -> t.Iterator[None]:
# make sure some current layout is up by issuing a dummy GetState # make sure some current layout is up by issuing a dummy GetState
self.state() self.state()
@ -615,7 +638,7 @@ class DebugLink:
return "".join([str(matrix.index(p) + 1) for p in pin]) return "".join([str(matrix.index(p) + 1) for p in pin])
def read_recovery_word(self) -> Tuple[str | None, int | None]: def read_recovery_word(self) -> t.Tuple[str | None, int | None]:
state = self.state() state = self.state()
return (state.recovery_fake_word, state.recovery_word_pos) return (state.recovery_fake_word, state.recovery_word_pos)
@ -671,7 +694,7 @@ class DebugLink:
"""Send text input to the device. See `_decision` for more details.""" """Send text input to the device. See `_decision` for more details."""
self._decision(messages.DebugLinkDecision(input=word)) self._decision(messages.DebugLinkDecision(input=word))
def click(self, click: Tuple[int, int], hold_ms: int | None = None) -> None: def click(self, click: t.Tuple[int, int], hold_ms: int | None = None) -> None:
"""Send a click to the device. See `_decision` for more details.""" """Send a click to the device. See `_decision` for more details."""
x, y = click x, y = click
self._decision(messages.DebugLinkDecision(x=x, y=y, hold_ms=hold_ms)) self._decision(messages.DebugLinkDecision(x=x, y=y, hold_ms=hold_ms))
@ -794,10 +817,10 @@ class DebugUI:
self.clear() self.clear()
def clear(self) -> None: def clear(self) -> None:
self.pins: Iterator[str] | None = None self.pins: t.Iterator[str] | None = None
self.passphrase = "" self.passphrase = ""
self.input_flow: Union[ self.input_flow: t.Union[
Generator[None, messages.ButtonRequest, None], object, None t.Generator[None, messages.ButtonRequest, None], object, None
] = None ] = None
def _default_input_flow(self, br: messages.ButtonRequest) -> None: def _default_input_flow(self, br: messages.ButtonRequest) -> None:
@ -829,7 +852,7 @@ class DebugUI:
raise AssertionError("input flow ended prematurely") raise AssertionError("input flow ended prematurely")
else: else:
try: try:
assert isinstance(self.input_flow, Generator) assert isinstance(self.input_flow, t.Generator)
self.input_flow.send(br) self.input_flow.send(br)
except StopIteration: except StopIteration:
self.input_flow = self.INPUT_FLOW_DONE self.input_flow = self.INPUT_FLOW_DONE
@ -851,12 +874,15 @@ class DebugUI:
class MessageFilter: class MessageFilter:
def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None:
def __init__(
self, message_type: t.Type[protobuf.MessageType], **fields: t.Any
) -> None:
self.message_type = message_type self.message_type = message_type
self.fields: Dict[str, Any] = {} self.fields: t.Dict[str, t.Any] = {}
self.update_fields(**fields) self.update_fields(**fields)
def update_fields(self, **fields: Any) -> "MessageFilter": def update_fields(self, **fields: t.Any) -> "MessageFilter":
for name, value in fields.items(): for name, value in fields.items():
try: try:
self.fields[name] = self.from_message_or_type(value) self.fields[name] = self.from_message_or_type(value)
@ -904,7 +930,7 @@ class MessageFilter:
return True return True
def to_string(self, maxwidth: int = 80) -> str: def to_string(self, maxwidth: int = 80) -> str:
fields: list[Tuple[str, str]] = [] fields: list[t.Tuple[str, str]] = []
for field in self.message_type.FIELDS.values(): for field in self.message_type.FIELDS.values():
if field.name not in self.fields: if field.name not in self.fields:
continue continue
@ -934,7 +960,7 @@ class MessageFilter:
class MessageFilterGenerator: class MessageFilterGenerator:
def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]: def __getattr__(self, key: str) -> t.Callable[..., "MessageFilter"]:
message_type = getattr(messages, key) message_type = getattr(messages, key)
return MessageFilter(message_type).update_fields return MessageFilter(message_type).update_fields
@ -942,6 +968,245 @@ class MessageFilterGenerator:
message_filters = MessageFilterGenerator() message_filters = MessageFilterGenerator()
class SessionDebugWrapper(Session):
def __init__(self, session: Session) -> None:
self._session = session
self.reset_debug_features()
if isinstance(session, SessionDebugWrapper):
raise Exception("Cannot wrap already wrapped session!")
@property
def protocol_version(self) -> int:
return self.client.protocol_version
@property
def client(self) -> TrezorClientDebugLink:
assert isinstance(self._session.client, TrezorClientDebugLink)
return self._session.client
@property
def id(self) -> bytes:
return self._session.id
def _write(self, msg: t.Any) -> None:
print("writing message:", msg.__class__.__name__)
self._session._write(self._filter_message(msg))
def _read(self) -> t.Any:
resp = self._filter_message(self._session._read())
print("reading message:", resp.__class__.__name__)
if self.actual_responses is not None:
self.actual_responses.append(resp)
return resp
def set_expected_responses(
self,
expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]],
) -> None:
"""Set a sequence of expected responses to session calls.
Within a given with-block, the list of received responses from device must
match the list of expected responses, otherwise an ``AssertionError`` is raised.
If an expected response is given a field value other than ``None``, that field value
must exactly match the received field value. If a given field is ``None``
(or unspecified) in the expected response, the received field value is not
checked.
Each expected response can also be a tuple ``(bool, message)``. In that case, the
expected response is only evaluated if the first field is ``True``.
This is useful for differentiating sequences between Trezor models:
>>> trezor_one = session.features.model == "1"
>>> session.set_expected_responses([
>>> messages.ButtonRequest(code=ConfirmOutput),
>>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)),
>>> messages.Success(),
>>> ])
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
# make sure all items are (bool, message) tuples
expected_with_validity = (
e if isinstance(e, tuple) else (True, e) for e in expected
)
# only apply those items that are (True, message)
self.expected_responses = [
MessageFilter.from_message_or_type(expected)
for valid, expected in expected_with_validity
if valid
]
self.actual_responses = []
def lock(self, *, _refresh_features: bool = True) -> None:
"""Lock the device.
If the device does not have a PIN configured, this will do nothing.
Otherwise, a lock screen will be shown and the device will prompt for PIN
before further actions.
This call does _not_ invalidate passphrase cache. If passphrase is in use,
the device will not prompt for it after unlocking.
To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate
passphrase cache, use `clear_session()`.
"""
# TODO update the documentation above
# Private argument _refresh_features can be used internally to avoid
# refreshing in cases where we will refresh soon anyway. This is used
# in TrezorClient.clear_session()
self.call(messages.LockDevice())
if _refresh_features:
self.refresh_features()
def cancel(self) -> None:
self._write(messages.Cancel())
def ensure_unlocked(self) -> None:
btc.get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
self.refresh_features()
def set_filter(
self,
message_type: t.Type[protobuf.MessageType],
callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
) -> None:
"""Configure a filter function for a specified message type.
The `callback` must be a function that accepts a protobuf message, and returns
a (possibly modified) protobuf message of the same type. Whenever a message
is sent or received that matches `message_type`, `callback` is invoked on the
message and its result is substituted for the original.
Useful for test scenarios with an active malicious actor on the wire.
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
self.filters[message_type] = callback
def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
message_type = msg.__class__
callback = self.filters.get(message_type)
if callable(callback):
return callback(deepcopy(msg))
else:
return msg
def reset_debug_features(self) -> None:
"""Prepare the debugging session for a new testcase.
Clears all debugging state that might have been modified by a testcase.
"""
self.in_with_statement = False
self.expected_responses: list[MessageFilter] | None = None
self.actual_responses: list[protobuf.MessageType] | None = None
self.filters: t.Dict[
t.Type[protobuf.MessageType],
t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
] = {}
self.button_callback = self.client.button_callback
self.pin_callback = self.client.pin_callback
self.passphrase_callback = self._session.passphrase_callback
self.passphrase = self._session.passphrase
def __enter__(self) -> "SessionDebugWrapper":
# For usage in with/expected_responses
if self.in_with_statement:
raise RuntimeError("Do not nest!")
self.in_with_statement = True
return self
def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# copy expected/actual responses before clearing them
expected_responses = self.expected_responses
actual_responses = self.actual_responses
# grab a copy of the inputflow generator to raise an exception through it
if isinstance(self.client.ui, DebugUI):
input_flow = self.client.ui.input_flow
else:
input_flow = None
self.reset_debug_features()
if exc_type is None:
# If no other exception was raised, evaluate missed responses
# (raises AssertionError on mismatch)
self._verify_responses(expected_responses, actual_responses)
if isinstance(input_flow, t.Generator):
# Ensure that the input flow is exhausted
try:
input_flow.throw(
AssertionError("input flow continues past end of test")
)
except StopIteration:
pass
elif isinstance(input_flow, t.Generator):
# Propagate the exception through the input flow, so that we see in
# traceback where it is stuck.
input_flow.throw(exc_type, value, traceback)
@classmethod
def _verify_responses(
cls,
expected: list[MessageFilter] | None,
actual: list[protobuf.MessageType] | None,
) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
if expected is None and actual is None:
return
assert expected is not None
assert actual is not None
for i, (exp, act) in enumerate(zip_longest(expected, actual)):
if exp is None:
output = cls._expectation_lines(expected, i)
output.append("No more messages were expected, but we got:")
for resp in actual[i:]:
output.append(
textwrap.indent(protobuf.format_message(resp), " ")
)
raise AssertionError("\n".join(output))
if act is None:
output = cls._expectation_lines(expected, i)
output.append("This and the following message was not received.")
raise AssertionError("\n".join(output))
if not exp.match(act):
output = cls._expectation_lines(expected, i)
output.append("Actually received:")
output.append(textwrap.indent(protobuf.format_message(act), " "))
raise AssertionError("\n".join(output))
@staticmethod
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
output: list[str] = []
output.append("Expected responses:")
if start_at > 0:
output.append(f" (...{start_at} previous responses omitted)")
for i in range(start_at, stop_at):
exp = expected[i]
prefix = " " if i != current else ">>> "
output.append(textwrap.indent(exp.to_string(), prefix))
if stop_at < len(expected):
omitted = len(expected) - stop_at
output.append(f" (...{omitted} following responses omitted)")
output.append("")
return output
class TrezorClientDebugLink(TrezorClient): class TrezorClientDebugLink(TrezorClient):
# This class implements automatic responses # This class implements automatic responses
# and other functionality for unit tests # and other functionality for unit tests
@ -967,23 +1232,34 @@ class TrezorClientDebugLink(TrezorClient):
raise raise
# set transport explicitly so that sync_responses can work # set transport explicitly so that sync_responses can work
self.transport = transport super().__init__(transport)
self.reset_debug_features() self.transport = transport
self.ui: DebugUI = DebugUI(self.debug)
self.reset_debug_features(new_seedless_session=True)
self.sync_responses() self.sync_responses()
super().__init__(transport, ui=self.ui)
# So that we can choose right screenshotting logic (T1 vs TT) # So that we can choose right screenshotting logic (T1 vs TT)
# and know the supported debug capabilities # and know the supported debug capabilities
self.debug.model = self.model self.debug.model = self.model
self.debug.version = self.version self.debug.version = self.version
self.passphrase: str | None = None
@property @property
def layout_type(self) -> LayoutType: def layout_type(self) -> LayoutType:
return self.debug.layout_type return self.debug.layout_type
def reset_debug_features(self) -> None: def get_new_client(self) -> TrezorClientDebugLink:
"""Prepare the debugging client for a new testcase. new_client = TrezorClientDebugLink(
self.transport, self.debug.allow_interactions
)
new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir
return new_client
def reset_debug_features(self, new_seedless_session: bool = False) -> None:
"""
Prepare the debugging client for a new testcase.
Clears all debugging state that might have been modified by a testcase. Clears all debugging state that might have been modified by a testcase.
""" """
@ -991,30 +1267,139 @@ class TrezorClientDebugLink(TrezorClient):
self.in_with_statement = False self.in_with_statement = False
self.expected_responses: list[MessageFilter] | None = None self.expected_responses: list[MessageFilter] | None = None
self.actual_responses: list[protobuf.MessageType] | None = None self.actual_responses: list[protobuf.MessageType] | None = None
self.filters: dict[ self.filters: t.Dict[
type[protobuf.MessageType], t.Type[protobuf.MessageType],
Callable[[protobuf.MessageType], protobuf.MessageType] | None, t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
] = {} ] = {}
if new_seedless_session:
self._seedless_session = self.get_seedless_session(new_session=True)
@property
def button_callback(self):
def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
session._write(messages.ButtonAck())
self.ui.button_request(msg)
return session._read()
return _callback_button
@property
def pin_callback(self):
def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any:
try:
pin = self.ui.get_pin(msg.type)
except Cancelled:
session.call_raw(messages.Cancel())
raise
if any(d not in "123456789" for d in pin) or not (
1 <= len(pin) <= MAX_PIN_LENGTH
):
session.call_raw(messages.Cancel())
raise ValueError("Invalid PIN provided")
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise PinException(resp.code, resp.message)
else:
return resp
return _callback_pin
@property
def passphrase_callback(self):
def _callback_passphrase(
session: Session, msg: messages.PassphraseRequest
) -> t.Any:
available_on_device = (
Capability.PassphraseEntry in session.features.capabilities
)
def send_passphrase(
passphrase: str | None = None, on_device: bool | None = None
) -> MessageType:
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
resp = session.call_raw(msg)
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
# session.session_id = resp.state
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
return resp
# short-circuit old style entry
if msg._on_device is True:
return send_passphrase(None, None)
try:
if session.passphrase is None and isinstance(session, SessionV1):
passphrase = self.ui.get_passphrase(
available_on_device=available_on_device
)
else:
passphrase = session.passphrase
except Cancelled:
session.call_raw(messages.Cancel())
raise
if passphrase is PASSPHRASE_ON_DEVICE:
if not available_on_device:
session.call_raw(messages.Cancel())
raise RuntimeError("Device is not capable of entering passphrase")
else:
return send_passphrase(on_device=True)
# else process host-entered passphrase
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
session.call_raw(messages.Cancel())
raise ValueError("Passphrase too long")
return send_passphrase(passphrase, on_device=False)
return _callback_passphrase
def ensure_open(self) -> None: def ensure_open(self) -> None:
"""Only open session if there isn't already an open one.""" """Only open session if there isn't already an open one."""
if self.session_counter == 0: # if self.session_counter == 0:
self.open() # self.open()
# TODO check if is this needed
def open(self) -> None: def open(self) -> None:
super().open() pass
if self.session_counter == 1: # TODO is this needed?
self.debug.open() # self.debug.open()
def close(self) -> None: def close(self) -> None:
if self.session_counter == 1: pass
self.debug.close() # TODO is this needed?
super().close() # self.debug.close()
def lock(self) -> None:
s = SessionDebugWrapper(self.get_seedless_session())
s.lock()
def get_session(
self,
passphrase: str | object | None = "",
derive_cardano: bool = False,
session_id: int = 0,
) -> Session:
if isinstance(passphrase, str):
passphrase = Mnemonic.normalize_string(passphrase)
return super().get_session(passphrase, derive_cardano, session_id)
def set_filter( def set_filter(
self, self,
message_type: type[protobuf.MessageType], message_type: t.Type[protobuf.MessageType],
callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None, callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
) -> None: ) -> None:
"""Configure a filter function for a specified message type. """Configure a filter function for a specified message type.
@ -1039,7 +1424,7 @@ class TrezorClientDebugLink(TrezorClient):
return msg return msg
def set_input_flow( def set_input_flow(
self, input_flow: InputFlowType | Callable[[], InputFlowType] self, input_flow: InputFlowType | t.Callable[[], InputFlowType]
) -> None: ) -> None:
"""Configure a sequence of input events for the current with-block. """Configure a sequence of input events for the current with-block.
@ -1095,7 +1480,7 @@ class TrezorClientDebugLink(TrezorClient):
self.in_with_statement = True self.in_with_statement = True
return self return self
def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None: def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
# copy expected/actual responses before clearing them # copy expected/actual responses before clearing them
@ -1108,21 +1493,23 @@ class TrezorClientDebugLink(TrezorClient):
else: else:
input_flow = None input_flow = None
self.reset_debug_features() self.reset_debug_features(new_seedless_session=False)
if exc_type is None: if exc_type is None:
# If no other exception was raised, evaluate missed responses # If no other exception was raised, evaluate missed responses
# (raises AssertionError on mismatch) # (raises AssertionError on mismatch)
self._verify_responses(expected_responses, actual_responses) self._verify_responses(expected_responses, actual_responses)
elif isinstance(input_flow, Generator): elif isinstance(input_flow, t.Generator):
# Propagate the exception through the input flow, so that we see in # Propagate the exception through the input flow, so that we see in
# traceback where it is stuck. # traceback where it is stuck.
input_flow.throw(exc_type, value, traceback) input_flow.throw(exc_type, value, traceback)
def set_expected_responses( def set_expected_responses(
self, self,
expected: Sequence[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]], expected: t.Sequence[
t.Union["ExpectedMessage", t.Tuple[bool, "ExpectedMessage"]]
],
) -> None: ) -> None:
"""Set a sequence of expected responses to client calls. """Set a sequence of expected responses to client calls.
@ -1161,7 +1548,7 @@ class TrezorClientDebugLink(TrezorClient):
] ]
self.actual_responses = [] self.actual_responses = []
def use_pin_sequence(self, pins: Iterable[str]) -> None: def use_pin_sequence(self, pins: t.Iterable[str]) -> None:
"""Respond to PIN prompts from device with the provided PINs. """Respond to PIN prompts from device with the provided PINs.
The sequence must be at least as long as the expected number of PIN prompts. The sequence must be at least as long as the expected number of PIN prompts.
""" """
@ -1169,6 +1556,7 @@ class TrezorClientDebugLink(TrezorClient):
def use_passphrase(self, passphrase: str) -> None: def use_passphrase(self, passphrase: str) -> None:
"""Respond to passphrase prompts from device with the provided passphrase.""" """Respond to passphrase prompts from device with the provided passphrase."""
self.passphrase = passphrase
self.ui.passphrase = Mnemonic.normalize_string(passphrase) self.ui.passphrase = Mnemonic.normalize_string(passphrase)
def use_mnemonic(self, mnemonic: str) -> None: def use_mnemonic(self, mnemonic: str) -> None:
@ -1178,15 +1566,14 @@ class TrezorClientDebugLink(TrezorClient):
def _raw_read(self) -> protobuf.MessageType: def _raw_read(self) -> protobuf.MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
resp = self.get_seedless_session()._read()
resp = super()._raw_read()
resp = self._filter_message(resp) resp = self._filter_message(resp)
if self.actual_responses is not None: if self.actual_responses is not None:
self.actual_responses.append(resp) self.actual_responses.append(resp)
return resp return resp
def _raw_write(self, msg: protobuf.MessageType) -> None: def _raw_write(self, msg: protobuf.MessageType) -> None:
return super()._raw_write(self._filter_message(msg)) return self.get_seedless_session()._write(self._filter_message(msg))
@staticmethod @staticmethod
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
@ -1256,23 +1643,25 @@ class TrezorClientDebugLink(TrezorClient):
# Start by canceling whatever is on screen. This will work to cancel T1 PIN # Start by canceling whatever is on screen. This will work to cancel T1 PIN
# prompt, which is in TINY mode and does not respond to `Ping`. # prompt, which is in TINY mode and does not respond to `Ping`.
cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel()) # TODO REMOVE: cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel())
self.transport.begin_session() self.transport.open()
try: try:
self.transport.write(*cancel_msg) # self.protocol.write(messages.Cancel())
message = "SYNC" + secrets.token_hex(8) message = "SYNC" + secrets.token_hex(8)
ping_msg = mapping.DEFAULT_MAPPING.encode(messages.Ping(message=message)) self.get_seedless_session()._write(messages.Ping(message=message))
self.transport.write(*ping_msg)
resp = None resp = None
while resp != messages.Success(message=message): while resp != messages.Success(message=message):
msg_id, msg_bytes = self.transport.read()
try: try:
resp = mapping.DEFAULT_MAPPING.decode(msg_id, msg_bytes) resp = self.get_seedless_session()._read()
raise Exception
except Exception: except Exception:
pass pass
finally: finally:
self.transport.end_session() pass # TODO fix
# self.transport.end_session(self.session_id or b"")
def mnemonic_callback(self, _) -> str: def mnemonic_callback(self, _) -> str:
word, pos = self.debug.read_recovery_word() word, pos = self.debug.read_recovery_word()
@ -1285,8 +1674,8 @@ class TrezorClientDebugLink(TrezorClient):
def load_device( def load_device(
client: "TrezorClient", session: "Session",
mnemonic: Union[str, Iterable[str]], mnemonic: str | t.Iterable[str],
pin: str | None, pin: str | None,
passphrase_protection: bool, passphrase_protection: bool,
label: str | None, label: str | None,
@ -1299,12 +1688,12 @@ def load_device(
mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic] mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
if client.features.initialized: if session.features.initialized:
raise RuntimeError( raise RuntimeError(
"Device is initialized already. Call device.wipe() and try again." "Device is initialized already. Call device.wipe() and try again."
) )
client.call( session.call(
messages.LoadDevice( messages.LoadDevice(
mnemonics=mnemonics, mnemonics=mnemonics,
pin=pin, pin=pin,
@ -1316,18 +1705,18 @@ def load_device(
), ),
expect=messages.Success, expect=messages.Success,
) )
client.init_device() session.refresh_features()
# keep the old name for compatibility # keep the old name for compatibility
load_device_by_mnemonic = load_device load_device_by_mnemonic = load_device
def prodtest_t1(client: "TrezorClient") -> None: def prodtest_t1(session: "Session") -> None:
if client.features.bootloader_mode is not True: if session.features.bootloader_mode is not True:
raise RuntimeError("Device must be in bootloader mode") raise RuntimeError("Device must be in bootloader mode")
client.call( session.call(
messages.ProdTestT1( messages.ProdTestT1(
payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC" payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
), ),
@ -1337,8 +1726,8 @@ def prodtest_t1(client: "TrezorClient") -> None:
def record_screen( def record_screen(
debug_client: "TrezorClientDebugLink", debug_client: "TrezorClientDebugLink",
directory: Union[str, None], directory: str | None,
report_func: Union[Callable[[str], None], None] = None, report_func: t.Callable[[str], None] | None = None,
) -> None: ) -> None:
"""Record screen changes into a specified directory. """Record screen changes into a specified directory.
@ -1383,5 +1772,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool:
return debug_client.features.fw_vendor == "EMULATOR" return debug_client.features.fw_vendor == "EMULATOR"
def optiga_set_sec_max(client: "TrezorClient") -> None: def optiga_set_sec_max(session: "Session") -> None:
client.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success) session.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success)

View File

@ -28,16 +28,10 @@ from slip10 import SLIP10
from . import messages from . import messages
from .exceptions import Cancelled, TrezorException from .exceptions import Cancelled, TrezorException
from .tools import ( from .tools import Address, _deprecation_retval_helper, _return_success, parse_path
Address,
_deprecation_retval_helper,
_return_success,
parse_path,
session,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .transport.session import Session
RECOVERY_BACK = "\x08" # backspace character, sent literally RECOVERY_BACK = "\x08" # backspace character, sent literally
@ -46,9 +40,8 @@ SLIP39_EXTENDABLE_MIN_VERSION = (2, 7, 1)
ENTROPY_CHECK_MIN_VERSION = (2, 8, 7) ENTROPY_CHECK_MIN_VERSION = (2, 8, 7)
@session
def apply_settings( def apply_settings(
client: "TrezorClient", session: "Session",
label: Optional[str] = None, label: Optional[str] = None,
language: Optional[str] = None, language: Optional[str] = None,
use_passphrase: Optional[bool] = None, use_passphrase: Optional[bool] = None,
@ -79,13 +72,13 @@ def apply_settings(
haptic_feedback=haptic_feedback, haptic_feedback=haptic_feedback,
) )
out = client.call(settings, expect=messages.Success) out = session.call(settings, expect=messages.Success)
client.refresh_features() session.refresh_features()
return _return_success(out) return _return_success(out)
def _send_language_data( def _send_language_data(
client: "TrezorClient", session: "Session",
request: "messages.TranslationDataRequest", request: "messages.TranslationDataRequest",
language_data: bytes, language_data: bytes,
) -> None: ) -> None:
@ -95,69 +88,63 @@ def _send_language_data(
data_length = response.data_length data_length = response.data_length
data_offset = response.data_offset data_offset = response.data_offset
chunk = language_data[data_offset : data_offset + data_length] chunk = language_data[data_offset : data_offset + data_length]
response = client.call(messages.TranslationDataAck(data_chunk=chunk)) response = session.call(messages.TranslationDataAck(data_chunk=chunk))
@session
def change_language( def change_language(
client: "TrezorClient", session: "Session",
language_data: bytes, language_data: bytes,
show_display: bool | None = None, show_display: bool | None = None,
) -> str | None: ) -> str | None:
data_length = len(language_data) data_length = len(language_data)
msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display) msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)
response = client.call(msg) response = session.call(msg)
if data_length > 0: if data_length > 0:
response = messages.TranslationDataRequest.ensure_isinstance(response) response = messages.TranslationDataRequest.ensure_isinstance(response)
_send_language_data(client, response, language_data) _send_language_data(session, response, language_data)
else: else:
messages.Success.ensure_isinstance(response) messages.Success.ensure_isinstance(response)
client.refresh_features() # changing the language in features session.refresh_features() # changing the language in features
return _return_success(messages.Success(message="Language changed.")) return _return_success(messages.Success(message="Language changed."))
@session def apply_flags(session: "Session", flags: int) -> str | None:
def apply_flags(client: "TrezorClient", flags: int) -> str | None: out = session.call(messages.ApplyFlags(flags=flags), expect=messages.Success)
out = client.call(messages.ApplyFlags(flags=flags), expect=messages.Success) session.refresh_features()
client.refresh_features()
return _return_success(out) return _return_success(out)
@session def change_pin(session: "Session", remove: bool = False) -> str | None:
def change_pin(client: "TrezorClient", remove: bool = False) -> str | None: ret = session.call(messages.ChangePin(remove=remove), expect=messages.Success)
ret = client.call(messages.ChangePin(remove=remove), expect=messages.Success) session.refresh_features()
client.refresh_features()
return _return_success(ret) return _return_success(ret)
@session def change_wipe_code(session: "Session", remove: bool = False) -> str | None:
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> str | None: ret = session.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success)
ret = client.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success) session.refresh_features()
client.refresh_features()
return _return_success(ret) return _return_success(ret)
@session
def sd_protect( def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType session: "Session", operation: messages.SdProtectOperationType
) -> str | None: ) -> str | None:
ret = client.call(messages.SdProtect(operation=operation), expect=messages.Success) ret = session.call(messages.SdProtect(operation=operation), expect=messages.Success)
client.refresh_features() session.refresh_features()
return _return_success(ret) return _return_success(ret)
@session def wipe(session: "Session") -> str | None:
def wipe(client: "TrezorClient") -> str | None: ret = session.call(messages.WipeDevice(), expect=messages.Success)
ret = client.call(messages.WipeDevice(), expect=messages.Success) session.invalidate()
if not client.features.bootloader_mode: # if not session.features.bootloader_mode:
client.init_device() # session.refresh_features()
return _return_success(ret) return _return_success(ret)
@session
def recover( def recover(
client: "TrezorClient", session: "Session",
word_count: int = 24, word_count: int = 24,
passphrase_protection: bool = False, passphrase_protection: bool = False,
pin_protection: bool = True, pin_protection: bool = True,
@ -193,13 +180,13 @@ def recover(
if type is None: if type is None:
type = messages.RecoveryType.NormalRecovery type = messages.RecoveryType.NormalRecovery
if client.features.model == "1" and input_callback is None: if session.features.model == "1" and input_callback is None:
raise RuntimeError("Input callback required for Trezor One") raise RuntimeError("Input callback required for Trezor One")
if word_count not in (12, 18, 24): if word_count not in (12, 18, 24):
raise ValueError("Invalid word count. Use 12/18/24") raise ValueError("Invalid word count. Use 12/18/24")
if client.features.initialized and type == messages.RecoveryType.NormalRecovery: if session.features.initialized and type == messages.RecoveryType.NormalRecovery:
raise RuntimeError( raise RuntimeError(
"Device already initialized. Call device.wipe() and try again." "Device already initialized. Call device.wipe() and try again."
) )
@ -221,20 +208,20 @@ def recover(
msg.label = label msg.label = label
msg.u2f_counter = u2f_counter msg.u2f_counter = u2f_counter
res = client.call(msg) res = session.call(msg)
while isinstance(res, messages.WordRequest): while isinstance(res, messages.WordRequest):
try: try:
assert input_callback is not None assert input_callback is not None
inp = input_callback(res.type) inp = input_callback(res.type)
res = client.call(messages.WordAck(word=inp)) res = session.call(messages.WordAck(word=inp))
except Cancelled: except Cancelled:
res = client.call(messages.Cancel()) res = session.call(messages.Cancel())
# check that the result is a Success # check that the result is a Success
res = messages.Success.ensure_isinstance(res) res = messages.Success.ensure_isinstance(res)
# reinitialize the device # reinitialize the device
client.init_device() session.refresh_features()
return _deprecation_retval_helper(res) return _deprecation_retval_helper(res)
@ -280,7 +267,7 @@ def _seed_from_entropy(
def reset( def reset(
client: "TrezorClient", session: "Session",
display_random: bool = False, display_random: bool = False,
strength: Optional[int] = None, strength: Optional[int] = None,
passphrase_protection: bool = False, passphrase_protection: bool = False,
@ -313,7 +300,7 @@ def reset(
) )
setup( setup(
client, session,
strength=strength, strength=strength,
passphrase_protection=passphrase_protection, passphrase_protection=passphrase_protection,
pin_protection=pin_protection, pin_protection=pin_protection,
@ -331,9 +318,8 @@ def _get_external_entropy() -> bytes:
return secrets.token_bytes(32) return secrets.token_bytes(32)
@session
def setup( def setup(
client: "TrezorClient", session: "Session",
*, *,
strength: Optional[int] = None, strength: Optional[int] = None,
passphrase_protection: bool = True, passphrase_protection: bool = True,
@ -388,19 +374,19 @@ def setup(
check. check.
""" """
if client.features.initialized: if session.features.initialized:
raise RuntimeError( raise RuntimeError(
"Device is initialized already. Call wipe_device() and try again." "Device is initialized already. Call wipe_device() and try again."
) )
if strength is None: if strength is None:
if client.features.model == "1": if session.features.model == "1":
strength = 256 strength = 256
else: else:
strength = 128 strength = 128
if backup_type is None: if backup_type is None:
if client.version < SLIP39_EXTENDABLE_MIN_VERSION: if session.version < SLIP39_EXTENDABLE_MIN_VERSION:
# includes Trezor One 1.x.x # includes Trezor One 1.x.x
backup_type = messages.BackupType.Bip39 backup_type = messages.BackupType.Bip39
else: else:
@ -411,7 +397,7 @@ def setup(
paths = [parse_path("m/84h/0h/0h"), parse_path("m/44h/60h/0h")] paths = [parse_path("m/84h/0h/0h"), parse_path("m/44h/60h/0h")]
if entropy_check_count is None: if entropy_check_count is None:
if client.version < ENTROPY_CHECK_MIN_VERSION: if session.version < ENTROPY_CHECK_MIN_VERSION:
# includes Trezor One 1.x.x # includes Trezor One 1.x.x
entropy_check_count = 0 entropy_check_count = 0
else: else:
@ -431,18 +417,18 @@ def setup(
) )
if entropy_check_count > 0: if entropy_check_count > 0:
xpubs = _reset_with_entropycheck( xpubs = _reset_with_entropycheck(
client, msg, entropy_check_count, paths, _get_entropy session, msg, entropy_check_count, paths, _get_entropy
) )
else: else:
_reset_no_entropycheck(client, msg, _get_entropy) _reset_no_entropycheck(session, msg, _get_entropy)
xpubs = [] xpubs = []
client.init_device() session.refresh_features()
return xpubs return xpubs
def _reset_no_entropycheck( def _reset_no_entropycheck(
client: "TrezorClient", session: "Session",
msg: messages.ResetDevice, msg: messages.ResetDevice,
get_entropy: Callable[[], bytes], get_entropy: Callable[[], bytes],
) -> None: ) -> None:
@ -454,12 +440,12 @@ def _reset_no_entropycheck(
<< Success << Success
""" """
assert msg.entropy_check is False assert msg.entropy_check is False
client.call(msg, expect=messages.EntropyRequest) session.call(msg, expect=messages.EntropyRequest)
client.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success) session.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success)
def _reset_with_entropycheck( def _reset_with_entropycheck(
client: "TrezorClient", session: "Session",
reset_msg: messages.ResetDevice, reset_msg: messages.ResetDevice,
entropy_check_count: int, entropy_check_count: int,
paths: Iterable[Address], paths: Iterable[Address],
@ -495,7 +481,7 @@ def _reset_with_entropycheck(
def get_xpubs() -> list[tuple[Address, str]]: def get_xpubs() -> list[tuple[Address, str]]:
xpubs = [] xpubs = []
for path in paths: for path in paths:
resp = client.call( resp = session.call(
messages.GetPublicKey(address_n=path), expect=messages.PublicKey messages.GetPublicKey(address_n=path), expect=messages.PublicKey
) )
xpubs.append((path, resp.xpub)) xpubs.append((path, resp.xpub))
@ -524,13 +510,13 @@ def _reset_with_entropycheck(
raise TrezorException("Invalid XPUB in entropy check") raise TrezorException("Invalid XPUB in entropy check")
xpubs = [] xpubs = []
resp = client.call(reset_msg, expect=messages.EntropyRequest) resp = session.call(reset_msg, expect=messages.EntropyRequest)
entropy_commitment = resp.entropy_commitment entropy_commitment = resp.entropy_commitment
while True: while True:
# provide external entropy for this round # provide external entropy for this round
external_entropy = get_entropy() external_entropy = get_entropy()
client.call( session.call(
messages.EntropyAck(entropy=external_entropy), messages.EntropyAck(entropy=external_entropy),
expect=messages.EntropyCheckReady, expect=messages.EntropyCheckReady,
) )
@ -540,7 +526,7 @@ def _reset_with_entropycheck(
if entropy_check_count <= 0: if entropy_check_count <= 0:
# last round, wait for a Success and exit the loop # last round, wait for a Success and exit the loop
client.call( session.call(
messages.EntropyCheckContinue(finish=True), messages.EntropyCheckContinue(finish=True),
expect=messages.Success, expect=messages.Success,
) )
@ -549,7 +535,7 @@ def _reset_with_entropycheck(
entropy_check_count -= 1 entropy_check_count -= 1
# Next round starts. # Next round starts.
resp = client.call( resp = session.call(
messages.EntropyCheckContinue(finish=False), messages.EntropyCheckContinue(finish=False),
expect=messages.EntropyRequest, expect=messages.EntropyRequest,
) )
@ -570,13 +556,12 @@ def _reset_with_entropycheck(
return xpubs return xpubs
@session
def backup( def backup(
client: "TrezorClient", session: "Session",
group_threshold: Optional[int] = None, group_threshold: Optional[int] = None,
groups: Iterable[tuple[int, int]] = (), groups: Iterable[tuple[int, int]] = (),
) -> str | None: ) -> str | None:
ret = client.call( ret = session.call(
messages.BackupDevice( messages.BackupDevice(
group_threshold=group_threshold, group_threshold=group_threshold,
groups=[ groups=[
@ -586,37 +571,36 @@ def backup(
), ),
expect=messages.Success, expect=messages.Success,
) )
client.refresh_features() session.refresh_features()
return _return_success(ret) return _return_success(ret)
def cancel_authorization(client: "TrezorClient") -> str | None: def cancel_authorization(session: "Session") -> str | None:
ret = client.call(messages.CancelAuthorization(), expect=messages.Success) ret = session.call(messages.CancelAuthorization(), expect=messages.Success)
return _return_success(ret) return _return_success(ret)
def unlock_path(client: "TrezorClient", n: "Address") -> bytes: def unlock_path(session: "Session", n: "Address") -> bytes:
resp = client.call( resp = session.call(
messages.UnlockPath(address_n=n), expect=messages.UnlockedPathRequest messages.UnlockPath(address_n=n), expect=messages.UnlockedPathRequest
) )
# Cancel the UnlockPath workflow now that we have the authentication code. # Cancel the UnlockPath workflow now that we have the authentication code.
try: try:
client.call(messages.Cancel()) session.call(messages.Cancel())
except Cancelled: except Cancelled:
return resp.mac return resp.mac
else: else:
raise TrezorException("Unexpected response in UnlockPath flow") raise TrezorException("Unexpected response in UnlockPath flow")
@session
def reboot_to_bootloader( def reboot_to_bootloader(
client: "TrezorClient", session: "Session",
boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT, boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT,
firmware_header: Optional[bytes] = None, firmware_header: Optional[bytes] = None,
language_data: bytes = b"", language_data: bytes = b"",
) -> str | None: ) -> str | None:
response = client.call( response = session.call(
messages.RebootToBootloader( messages.RebootToBootloader(
boot_command=boot_command, boot_command=boot_command,
firmware_header=firmware_header, firmware_header=firmware_header,
@ -624,43 +608,38 @@ def reboot_to_bootloader(
) )
) )
if isinstance(response, messages.TranslationDataRequest): if isinstance(response, messages.TranslationDataRequest):
response = _send_language_data(client, response, language_data) response = _send_language_data(session, response, language_data)
return _return_success(messages.Success(message="")) return _return_success(messages.Success(message=""))
@session def show_device_tutorial(session: "Session") -> str | None:
def show_device_tutorial(client: "TrezorClient") -> str | None: ret = session.call(messages.ShowDeviceTutorial(), expect=messages.Success)
ret = client.call(messages.ShowDeviceTutorial(), expect=messages.Success)
return _return_success(ret) return _return_success(ret)
@session def unlock_bootloader(session: "Session") -> str | None:
def unlock_bootloader(client: "TrezorClient") -> str | None: ret = session.call(messages.UnlockBootloader(), expect=messages.Success)
ret = client.call(messages.UnlockBootloader(), expect=messages.Success)
return _return_success(ret) return _return_success(ret)
@session def set_busy(session: "Session", expiry_ms: Optional[int]) -> str | None:
def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> str | None:
"""Sets or clears the busy state of the device. """Sets or clears the busy state of the device.
In the busy state the device shows a "Do not disconnect" message instead of the homescreen. In the busy state the device shows a "Do not disconnect" message instead of the homescreen.
Setting `expiry_ms=None` clears the busy state. Setting `expiry_ms=None` clears the busy state.
""" """
ret = client.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success) ret = session.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success)
client.refresh_features() session.refresh_features()
return _return_success(ret) return _return_success(ret)
def authenticate( def authenticate(session: "Session", challenge: bytes) -> messages.AuthenticityProof:
client: "TrezorClient", challenge: bytes return session.call(
) -> messages.AuthenticityProof:
return client.call(
messages.AuthenticateDevice(challenge=challenge), messages.AuthenticateDevice(challenge=challenge),
expect=messages.AuthenticityProof, expect=messages.AuthenticityProof,
) )
def set_brightness(client: "TrezorClient", value: Optional[int] = None) -> str | None: def set_brightness(session: "Session", value: Optional[int] = None) -> str | None:
ret = client.call(messages.SetBrightness(value=value), expect=messages.Success) ret = session.call(messages.SetBrightness(value=value), expect=messages.Success)
return _return_success(ret) return _return_success(ret)

View File

@ -18,11 +18,11 @@ from datetime import datetime
from typing import TYPE_CHECKING, List, Tuple from typing import TYPE_CHECKING, List, Tuple
from . import exceptions, messages from . import exceptions, messages
from .tools import b58decode, session from .tools import b58decode
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address from .tools import Address
from .transport.session import Session
def name_to_number(name: str) -> int: def name_to_number(name: str) -> int:
@ -319,17 +319,16 @@ def parse_transaction_json(
def get_public_key( def get_public_key(
client: "TrezorClient", n: "Address", show_display: bool = False session: "Session", n: "Address", show_display: bool = False
) -> messages.EosPublicKey: ) -> messages.EosPublicKey:
return client.call( return session.call(
messages.EosGetPublicKey(address_n=n, show_display=show_display), messages.EosGetPublicKey(address_n=n, show_display=show_display),
expect=messages.EosPublicKey, expect=messages.EosPublicKey,
) )
@session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address: "Address", address: "Address",
transaction: dict, transaction: dict,
chain_id: str, chain_id: str,
@ -345,11 +344,11 @@ def sign_tx(
chunkify=chunkify, chunkify=chunkify,
) )
response = client.call(msg) response = session.call(msg)
try: try:
while isinstance(response, messages.EosTxActionRequest): while isinstance(response, messages.EosTxActionRequest):
response = client.call(actions.pop(0)) response = session.call(actions.pop(0))
except IndexError: except IndexError:
# pop from empty list # pop from empty list
raise exceptions.TrezorException( raise exceptions.TrezorException(

View File

@ -18,11 +18,11 @@ import re
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
from . import definitions, exceptions, messages from . import definitions, exceptions, messages
from .tools import prepare_message_bytes, session, unharden from .tools import prepare_message_bytes, unharden
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address from .tools import Address
from .transport.session import Session
def int_to_big_endian(value: int) -> bytes: def int_to_big_endian(value: int) -> bytes:
@ -161,13 +161,13 @@ def network_from_address_n(
def get_address( def get_address(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
show_display: bool = False, show_display: bool = False,
encoded_network: Optional[bytes] = None, encoded_network: Optional[bytes] = None,
chunkify: bool = False, chunkify: bool = False,
) -> str: ) -> str:
resp = client.call( resp = session.call(
messages.EthereumGetAddress( messages.EthereumGetAddress(
address_n=n, address_n=n,
show_display=show_display, show_display=show_display,
@ -181,17 +181,16 @@ def get_address(
def get_public_node( def get_public_node(
client: "TrezorClient", n: "Address", show_display: bool = False session: "Session", n: "Address", show_display: bool = False
) -> messages.EthereumPublicKey: ) -> messages.EthereumPublicKey:
return client.call( return session.call(
messages.EthereumGetPublicKey(address_n=n, show_display=show_display), messages.EthereumGetPublicKey(address_n=n, show_display=show_display),
expect=messages.EthereumPublicKey, expect=messages.EthereumPublicKey,
) )
@session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
nonce: int, nonce: int,
gas_price: int, gas_price: int,
@ -227,13 +226,13 @@ def sign_tx(
data, chunk = data[1024:], data[:1024] data, chunk = data[1024:], data[:1024]
msg.data_initial_chunk = chunk msg.data_initial_chunk = chunk
response = client.call(msg) response = session.call(msg)
assert isinstance(response, messages.EthereumTxRequest) assert isinstance(response, messages.EthereumTxRequest)
while response.data_length is not None: while response.data_length is not None:
data_length = response.data_length data_length = response.data_length
data, chunk = data[data_length:], data[:data_length] data, chunk = data[data_length:], data[:data_length]
response = client.call(messages.EthereumTxAck(data_chunk=chunk)) response = session.call(messages.EthereumTxAck(data_chunk=chunk))
assert isinstance(response, messages.EthereumTxRequest) assert isinstance(response, messages.EthereumTxRequest)
assert response.signature_v is not None assert response.signature_v is not None
@ -248,9 +247,8 @@ def sign_tx(
return response.signature_v, response.signature_r, response.signature_s return response.signature_v, response.signature_r, response.signature_s
@session
def sign_tx_eip1559( def sign_tx_eip1559(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
*, *,
nonce: int, nonce: int,
@ -283,13 +281,13 @@ def sign_tx_eip1559(
chunkify=chunkify, chunkify=chunkify,
) )
response = client.call(msg) response = session.call(msg)
assert isinstance(response, messages.EthereumTxRequest) assert isinstance(response, messages.EthereumTxRequest)
while response.data_length is not None: while response.data_length is not None:
data_length = response.data_length data_length = response.data_length
data, chunk = data[data_length:], data[:data_length] data, chunk = data[data_length:], data[:data_length]
response = client.call(messages.EthereumTxAck(data_chunk=chunk)) response = session.call(messages.EthereumTxAck(data_chunk=chunk))
assert isinstance(response, messages.EthereumTxRequest) assert isinstance(response, messages.EthereumTxRequest)
assert response.signature_v is not None assert response.signature_v is not None
@ -299,13 +297,13 @@ def sign_tx_eip1559(
def sign_message( def sign_message(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
message: AnyStr, message: AnyStr,
encoded_network: Optional[bytes] = None, encoded_network: Optional[bytes] = None,
chunkify: bool = False, chunkify: bool = False,
) -> messages.EthereumMessageSignature: ) -> messages.EthereumMessageSignature:
return client.call( return session.call(
messages.EthereumSignMessage( messages.EthereumSignMessage(
address_n=n, address_n=n,
message=prepare_message_bytes(message), message=prepare_message_bytes(message),
@ -317,7 +315,7 @@ def sign_message(
def sign_typed_data( def sign_typed_data(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
data: Dict[str, Any], data: Dict[str, Any],
*, *,
@ -333,7 +331,7 @@ def sign_typed_data(
metamask_v4_compat=metamask_v4_compat, metamask_v4_compat=metamask_v4_compat,
definitions=definitions, definitions=definitions,
) )
response = client.call(request) response = session.call(request)
# Sending all the types # Sending all the types
while isinstance(response, messages.EthereumTypedDataStructRequest): while isinstance(response, messages.EthereumTypedDataStructRequest):
@ -349,7 +347,7 @@ def sign_typed_data(
members.append(struct_member) members.append(struct_member)
request = messages.EthereumTypedDataStructAck(members=members) request = messages.EthereumTypedDataStructAck(members=members)
response = client.call(request) response = session.call(request)
# Sending the whole message that should be signed # Sending the whole message that should be signed
while isinstance(response, messages.EthereumTypedDataValueRequest): while isinstance(response, messages.EthereumTypedDataValueRequest):
@ -362,7 +360,7 @@ def sign_typed_data(
member_typename = data["primaryType"] member_typename = data["primaryType"]
member_data = data["message"] member_data = data["message"]
else: else:
client.cancel() # TODO session.cancel()
raise exceptions.TrezorException("Root index can only be 0 or 1") raise exceptions.TrezorException("Root index can only be 0 or 1")
# It can be asking for a nested structure (the member path being [X, Y, Z, ...]) # It can be asking for a nested structure (the member path being [X, Y, Z, ...])
@ -385,20 +383,20 @@ def sign_typed_data(
encoded_data = encode_data(member_data, member_typename) encoded_data = encode_data(member_data, member_typename)
request = messages.EthereumTypedDataValueAck(value=encoded_data) request = messages.EthereumTypedDataValueAck(value=encoded_data)
response = client.call(request) response = session.call(request)
return messages.EthereumTypedDataSignature.ensure_isinstance(response) return messages.EthereumTypedDataSignature.ensure_isinstance(response)
def verify_message( def verify_message(
client: "TrezorClient", session: "Session",
address: str, address: str,
signature: bytes, signature: bytes,
message: AnyStr, message: AnyStr,
chunkify: bool = False, chunkify: bool = False,
) -> bool: ) -> bool:
try: try:
client.call( session.call(
messages.EthereumVerifyMessage( messages.EthereumVerifyMessage(
address=address, address=address,
signature=signature, signature=signature,
@ -413,13 +411,13 @@ def verify_message(
def sign_typed_data_hash( def sign_typed_data_hash(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
domain_hash: bytes, domain_hash: bytes,
message_hash: Optional[bytes], message_hash: Optional[bytes],
encoded_network: Optional[bytes] = None, encoded_network: Optional[bytes] = None,
) -> messages.EthereumTypedDataSignature: ) -> messages.EthereumTypedDataSignature:
return client.call( return session.call(
messages.EthereumSignTypedHash( messages.EthereumSignTypedHash(
address_n=n, address_n=n,
domain_separator_hash=domain_hash, domain_separator_hash=domain_hash,

View File

@ -65,3 +65,7 @@ class UnexpectedMessageError(TrezorException):
self.expected = expected self.expected = expected
self.actual = actual self.actual = actual
super().__init__(f"Expected {expected.__name__} but Trezor sent {actual}") super().__init__(f"Expected {expected.__name__} but Trezor sent {actual}")
class DeviceLockedException(TrezorException):
pass

View File

@ -22,37 +22,37 @@ from . import messages
from .tools import _return_success from .tools import _return_success
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .transport.session import Session
def list_credentials(client: "TrezorClient") -> Sequence[messages.WebAuthnCredential]: def list_credentials(session: "Session") -> Sequence[messages.WebAuthnCredential]:
return client.call( return session.call(
messages.WebAuthnListResidentCredentials(), expect=messages.WebAuthnCredentials messages.WebAuthnListResidentCredentials(), expect=messages.WebAuthnCredentials
).credentials ).credentials
def add_credential(client: "TrezorClient", credential_id: bytes) -> str | None: def add_credential(session: "Session", credential_id: bytes) -> str | None:
ret = client.call( ret = session.call(
messages.WebAuthnAddResidentCredential(credential_id=credential_id), messages.WebAuthnAddResidentCredential(credential_id=credential_id),
expect=messages.Success, expect=messages.Success,
) )
return _return_success(ret) return _return_success(ret)
def remove_credential(client: "TrezorClient", index: int) -> str | None: def remove_credential(session: "Session", index: int) -> str | None:
ret = client.call( ret = session.call(
messages.WebAuthnRemoveResidentCredential(index=index), expect=messages.Success messages.WebAuthnRemoveResidentCredential(index=index), expect=messages.Success
) )
return _return_success(ret) return _return_success(ret)
def set_counter(client: "TrezorClient", u2f_counter: int) -> str | None: def set_counter(session: "Session", u2f_counter: int) -> str | None:
ret = client.call( ret = session.call(
messages.SetU2FCounter(u2f_counter=u2f_counter), expect=messages.Success messages.SetU2FCounter(u2f_counter=u2f_counter), expect=messages.Success
) )
return _return_success(ret) return _return_success(ret)
def get_next_counter(client: "TrezorClient") -> int: def get_next_counter(session: "Session") -> int:
ret = client.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter) ret = session.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter)
return ret.u2f_counter return ret.u2f_counter

View File

@ -20,7 +20,6 @@ from hashlib import blake2s
from typing_extensions import Protocol, TypeGuard from typing_extensions import Protocol, TypeGuard
from .. import messages from .. import messages
from ..tools import session
from .core import VendorFirmware from .core import VendorFirmware
from .legacy import LegacyFirmware, LegacyV2Firmware from .legacy import LegacyFirmware, LegacyV2Firmware
@ -38,7 +37,7 @@ if True:
from .vendor import * # noqa: F401, F403 from .vendor import * # noqa: F401, F403
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
T = t.TypeVar("T", bound="FirmwareType") T = t.TypeVar("T", bound="FirmwareType")
@ -72,20 +71,19 @@ def is_onev2(fw: "FirmwareType") -> TypeGuard[LegacyFirmware]:
# ====== Client functions ====== # # ====== Client functions ====== #
@session
def update( def update(
client: "TrezorClient", session: "Session",
data: bytes, data: bytes,
progress_update: t.Callable[[int], t.Any] = lambda _: None, progress_update: t.Callable[[int], t.Any] = lambda _: None,
): ):
if client.features.bootloader_mode is False: if session.features.bootloader_mode is False:
raise RuntimeError("Device must be in bootloader mode") raise RuntimeError("Device must be in bootloader mode")
resp = client.call(messages.FirmwareErase(length=len(data))) resp = session.call(messages.FirmwareErase(length=len(data)))
# TREZORv1 method # TREZORv1 method
if isinstance(resp, messages.Success): if isinstance(resp, messages.Success):
resp = client.call(messages.FirmwareUpload(payload=data)) resp = session.call(messages.FirmwareUpload(payload=data))
progress_update(len(data)) progress_update(len(data))
if isinstance(resp, messages.Success): if isinstance(resp, messages.Success):
return return
@ -97,7 +95,7 @@ def update(
length = resp.length length = resp.length
payload = data[resp.offset : resp.offset + length] payload = data[resp.offset : resp.offset + length]
digest = blake2s(payload).digest() digest = blake2s(payload).digest()
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest)) resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest))
progress_update(length) progress_update(length)
if isinstance(resp, messages.Success): if isinstance(resp, messages.Success):
@ -106,7 +104,7 @@ def update(
raise RuntimeError(f"Unexpected message {resp}") raise RuntimeError(f"Unexpected message {resp}")
def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]) -> bytes: def get_hash(session: "Session", challenge: t.Optional[bytes]) -> bytes:
return client.call( return session.call(
messages.GetFirmwareHash(challenge=challenge), expect=messages.FirmwareHash messages.GetFirmwareHash(challenge=challenge), expect=messages.FirmwareHash
).hash ).hash

View File

@ -17,6 +17,7 @@
from __future__ import annotations from __future__ import annotations
import io import io
import logging
from types import ModuleType from types import ModuleType
from typing import Dict, Optional, Tuple, Type, TypeVar from typing import Dict, Optional, Tuple, Type, TypeVar
@ -25,6 +26,7 @@ from typing_extensions import Self
from . import messages, protobuf from . import messages, protobuf
T = TypeVar("T") T = TypeVar("T")
LOG = logging.getLogger(__name__)
class ProtobufMapping: class ProtobufMapping:
@ -63,11 +65,21 @@ class ProtobufMapping:
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE) wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
if wire_type is None: if wire_type is None:
raise ValueError("Cannot encode class without wire type") raise ValueError("Cannot encode class without wire type")
LOG.debug("encoding wire type %d", wire_type)
buf = io.BytesIO() buf = io.BytesIO()
protobuf.dump_message(buf, msg) protobuf.dump_message(buf, msg)
return wire_type, buf.getvalue() return wire_type, buf.getvalue()
def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes:
"""Serialize a Python protobuf class.
Returns the byte representation of the protobuf message.
"""
buf = io.BytesIO()
protobuf.dump_message(buf, msg)
return buf.getvalue()
def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType: def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType:
"""Deserialize a protobuf message into a Python class.""" """Deserialize a protobuf message into a Python class."""
cls = self.type_to_class[msg_wire_type] cls = self.type_to_class[msg_wire_type]
@ -83,7 +95,9 @@ class ProtobufMapping:
mapping = cls() mapping = cls()
message_types = getattr(module, "MessageType") message_types = getattr(module, "MessageType")
for entry in message_types: thp_message_types = getattr(module, "ThpMessageType")
for entry in (*message_types, *thp_message_types):
msg_class = getattr(module, entry.name, None) msg_class = getattr(module, entry.name, None)
if msg_class is None: if msg_class is None:
raise ValueError( raise ValueError(

View File

@ -43,6 +43,10 @@ class FailureType(IntEnum):
PinMismatch = 12 PinMismatch = 12
WipeCodeMismatch = 13 WipeCodeMismatch = 13
InvalidSession = 14 InvalidSession = 14
ThpUnallocatedSession = 15
InvalidProtocol = 16
BufferError = 17
DeviceIsBusy = 18
FirmwareError = 99 FirmwareError = 99
@ -400,6 +404,34 @@ class TezosBallotType(IntEnum):
Pass = 2 Pass = 2
class ThpMessageType(IntEnum):
ThpCreateNewSession = 1000
ThpPairingRequest = 1006
ThpPairingRequestApproved = 1007
ThpSelectMethod = 1008
ThpPairingPreparationsFinished = 1009
ThpCredentialRequest = 1010
ThpCredentialResponse = 1011
ThpEndRequest = 1012
ThpEndResponse = 1013
ThpCodeEntryCommitment = 1016
ThpCodeEntryChallenge = 1017
ThpCodeEntryCpaceTrezor = 1018
ThpCodeEntryCpaceHostTag = 1019
ThpCodeEntrySecret = 1020
ThpQrCodeTag = 1024
ThpQrCodeSecret = 1025
ThpNfcTagHost = 1032
ThpNfcTagTrezor = 1033
class ThpPairingMethod(IntEnum):
SkipPairing = 1
CodeEntry = 2
QrCode = 3
NFC = 4
class MessageType(IntEnum): class MessageType(IntEnum):
Initialize = 0 Initialize = 0
Ping = 1 Ping = 1
@ -500,6 +532,8 @@ class MessageType(IntEnum):
DebugLinkWatchLayout = 9006 DebugLinkWatchLayout = 9006
DebugLinkResetDebugEvents = 9007 DebugLinkResetDebugEvents = 9007
DebugLinkOptigaSetSecMax = 9008 DebugLinkOptigaSetSecMax = 9008
DebugLinkGetPairingInfo = 9009
DebugLinkPairingInfo = 9010
EthereumGetPublicKey = 450 EthereumGetPublicKey = 450
EthereumPublicKey = 451 EthereumPublicKey = 451
EthereumGetAddress = 56 EthereumGetAddress = 56
@ -4203,6 +4237,52 @@ class DebugLinkState(protobuf.MessageType):
self.mnemonic_type = mnemonic_type self.mnemonic_type = mnemonic_type
class DebugLinkGetPairingInfo(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 9009
FIELDS = {
1: protobuf.Field("channel_id", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("handshake_hash", "bytes", repeated=False, required=False, default=None),
3: protobuf.Field("nfc_secret_host", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
channel_id: Optional["bytes"] = None,
handshake_hash: Optional["bytes"] = None,
nfc_secret_host: Optional["bytes"] = None,
) -> None:
self.channel_id = channel_id
self.handshake_hash = handshake_hash
self.nfc_secret_host = nfc_secret_host
class DebugLinkPairingInfo(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 9010
FIELDS = {
1: protobuf.Field("channel_id", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("handshake_hash", "bytes", repeated=False, required=False, default=None),
3: protobuf.Field("code_entry_code", "uint32", repeated=False, required=False, default=None),
4: protobuf.Field("code_qr_code", "bytes", repeated=False, required=False, default=None),
5: protobuf.Field("nfc_secret_trezor", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
channel_id: Optional["bytes"] = None,
handshake_hash: Optional["bytes"] = None,
code_entry_code: Optional["int"] = None,
code_qr_code: Optional["bytes"] = None,
nfc_secret_trezor: Optional["bytes"] = None,
) -> None:
self.channel_id = channel_id
self.handshake_hash = handshake_hash
self.code_entry_code = code_entry_code
self.code_qr_code = code_qr_code
self.nfc_secret_trezor = nfc_secret_trezor
class DebugLinkStop(protobuf.MessageType): class DebugLinkStop(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 103 MESSAGE_WIRE_TYPE = 103
@ -7863,8 +7943,68 @@ class TezosManagerTransfer(protobuf.MessageType):
self.amount = amount self.amount = amount
class ThpCredentialMetadata(protobuf.MessageType): class ThpDeviceProperties(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("internal_model", "string", repeated=False, required=False, default=None),
2: protobuf.Field("model_variant", "uint32", repeated=False, required=False, default=None),
3: protobuf.Field("protocol_version_major", "uint32", repeated=False, required=False, default=None),
4: protobuf.Field("protocol_version_minor", "uint32", repeated=False, required=False, default=None),
5: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None),
}
def __init__(
self,
*,
pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None,
internal_model: Optional["str"] = None,
model_variant: Optional["int"] = None,
protocol_version_major: Optional["int"] = None,
protocol_version_minor: Optional["int"] = None,
) -> None:
self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else []
self.internal_model = internal_model
self.model_variant = model_variant
self.protocol_version_major = protocol_version_major
self.protocol_version_minor = protocol_version_minor
class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("host_pairing_credential", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_pairing_credential: Optional["bytes"] = None,
) -> None:
self.host_pairing_credential = host_pairing_credential
class ThpCreateNewSession(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1000
FIELDS = {
1: protobuf.Field("passphrase", "string", repeated=False, required=False, default=None),
2: protobuf.Field("on_device", "bool", repeated=False, required=False, default=None),
3: protobuf.Field("derive_cardano", "bool", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
passphrase: Optional["str"] = None,
on_device: Optional["bool"] = None,
derive_cardano: Optional["bool"] = None,
) -> None:
self.passphrase = passphrase
self.on_device = on_device
self.derive_cardano = derive_cardano
class ThpPairingRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1006
FIELDS = { FIELDS = {
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None), 1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
} }
@ -7877,6 +8017,216 @@ class ThpCredentialMetadata(protobuf.MessageType):
self.host_name = host_name self.host_name = host_name
class ThpPairingRequestApproved(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1007
class ThpSelectMethod(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1008
FIELDS = {
1: protobuf.Field("selected_pairing_method", "ThpPairingMethod", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
selected_pairing_method: Optional["ThpPairingMethod"] = None,
) -> None:
self.selected_pairing_method = selected_pairing_method
class ThpPairingPreparationsFinished(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1009
class ThpCodeEntryCommitment(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1016
FIELDS = {
1: protobuf.Field("commitment", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
commitment: Optional["bytes"] = None,
) -> None:
self.commitment = commitment
class ThpCodeEntryChallenge(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1017
FIELDS = {
1: protobuf.Field("challenge", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
challenge: Optional["bytes"] = None,
) -> None:
self.challenge = challenge
class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1018
FIELDS = {
1: protobuf.Field("cpace_trezor_public_key", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
cpace_trezor_public_key: Optional["bytes"] = None,
) -> None:
self.cpace_trezor_public_key = cpace_trezor_public_key
class ThpCodeEntryCpaceHostTag(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1019
FIELDS = {
1: protobuf.Field("cpace_host_public_key", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
cpace_host_public_key: Optional["bytes"] = None,
tag: Optional["bytes"] = None,
) -> None:
self.cpace_host_public_key = cpace_host_public_key
self.tag = tag
class ThpCodeEntrySecret(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1020
FIELDS = {
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
secret: Optional["bytes"] = None,
) -> None:
self.secret = secret
class ThpQrCodeTag(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1024
FIELDS = {
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpQrCodeSecret(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1025
FIELDS = {
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
secret: Optional["bytes"] = None,
) -> None:
self.secret = secret
class ThpNfcTagHost(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1032
FIELDS = {
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpNfcTagTrezor(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1033
FIELDS = {
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpCredentialRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1010
FIELDS = {
1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("autoconnect", "bool", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_static_pubkey: Optional["bytes"] = None,
autoconnect: Optional["bool"] = None,
) -> None:
self.host_static_pubkey = host_static_pubkey
self.autoconnect = autoconnect
class ThpCredentialResponse(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1011
FIELDS = {
1: protobuf.Field("trezor_static_pubkey", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("credential", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
trezor_static_pubkey: Optional["bytes"] = None,
credential: Optional["bytes"] = None,
) -> None:
self.trezor_static_pubkey = trezor_static_pubkey
self.credential = credential
class ThpEndRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1012
class ThpEndResponse(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1013
class ThpCredentialMetadata(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
2: protobuf.Field("autoconnect", "bool", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_name: Optional["str"] = None,
autoconnect: Optional["bool"] = None,
) -> None:
self.host_name = host_name
self.autoconnect = autoconnect
class ThpPairingCredential(protobuf.MessageType): class ThpPairingCredential(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None MESSAGE_WIRE_TYPE = None
FIELDS = { FIELDS = {

View File

@ -19,22 +19,22 @@ from typing import TYPE_CHECKING, Optional
from . import messages from . import messages
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address from .tools import Address
from .transport.session import Session
def get_entropy(client: "TrezorClient", size: int) -> bytes: def get_entropy(session: "Session", size: int) -> bytes:
return client.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy return session.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy
def sign_identity( def sign_identity(
client: "TrezorClient", session: "Session",
identity: messages.IdentityType, identity: messages.IdentityType,
challenge_hidden: bytes, challenge_hidden: bytes,
challenge_visual: str, challenge_visual: str,
ecdsa_curve_name: Optional[str] = None, ecdsa_curve_name: Optional[str] = None,
) -> messages.SignedIdentity: ) -> messages.SignedIdentity:
return client.call( return session.call(
messages.SignIdentity( messages.SignIdentity(
identity=identity, identity=identity,
challenge_hidden=challenge_hidden, challenge_hidden=challenge_hidden,
@ -46,12 +46,12 @@ def sign_identity(
def get_ecdh_session_key( def get_ecdh_session_key(
client: "TrezorClient", session: "Session",
identity: messages.IdentityType, identity: messages.IdentityType,
peer_public_key: bytes, peer_public_key: bytes,
ecdsa_curve_name: Optional[str] = None, ecdsa_curve_name: Optional[str] = None,
) -> messages.ECDHSessionKey: ) -> messages.ECDHSessionKey:
return client.call( return session.call(
messages.GetECDHSessionKey( messages.GetECDHSessionKey(
identity=identity, identity=identity,
peer_public_key=peer_public_key, peer_public_key=peer_public_key,
@ -62,7 +62,7 @@ def get_ecdh_session_key(
def encrypt_keyvalue( def encrypt_keyvalue(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
key: str, key: str,
value: bytes, value: bytes,
@ -70,7 +70,7 @@ def encrypt_keyvalue(
ask_on_decrypt: bool = True, ask_on_decrypt: bool = True,
iv: bytes = b"", iv: bytes = b"",
) -> bytes: ) -> bytes:
return client.call( return session.call(
messages.CipherKeyValue( messages.CipherKeyValue(
address_n=n, address_n=n,
key=key, key=key,
@ -85,7 +85,7 @@ def encrypt_keyvalue(
def decrypt_keyvalue( def decrypt_keyvalue(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
key: str, key: str,
value: bytes, value: bytes,
@ -93,7 +93,7 @@ def decrypt_keyvalue(
ask_on_decrypt: bool = True, ask_on_decrypt: bool = True,
iv: bytes = b"", iv: bytes = b"",
) -> bytes: ) -> bytes:
return client.call( return session.call(
messages.CipherKeyValue( messages.CipherKeyValue(
address_n=n, address_n=n,
key=key, key=key,
@ -107,5 +107,5 @@ def decrypt_keyvalue(
).value ).value
def get_nonce(client: "TrezorClient") -> bytes: def get_nonce(session: "Session") -> bytes:
return client.call(messages.GetNonce(), expect=messages.Nonce).nonce return session.call(messages.GetNonce(), expect=messages.Nonce).nonce

View File

@ -19,8 +19,8 @@ from typing import TYPE_CHECKING
from . import messages from . import messages
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address from .tools import Address
from .transport.session import Session
# MAINNET = 0 # MAINNET = 0
@ -30,13 +30,13 @@ if TYPE_CHECKING:
def get_address( def get_address(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
show_display: bool = False, show_display: bool = False,
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
chunkify: bool = False, chunkify: bool = False,
) -> bytes: ) -> bytes:
return client.call( return session.call(
messages.MoneroGetAddress( messages.MoneroGetAddress(
address_n=n, address_n=n,
show_display=show_display, show_display=show_display,
@ -48,11 +48,11 @@ def get_address(
def get_watch_key( def get_watch_key(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
) -> messages.MoneroWatchKey: ) -> messages.MoneroWatchKey:
return client.call( return session.call(
messages.MoneroGetWatchKey(address_n=n, network_type=network_type), messages.MoneroGetWatchKey(address_n=n, network_type=network_type),
expect=messages.MoneroWatchKey, expect=messages.MoneroWatchKey,
) )

View File

@ -20,8 +20,8 @@ from typing import TYPE_CHECKING
from . import exceptions, messages from . import exceptions, messages
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address from .tools import Address
from .transport.session import Session
TYPE_TRANSACTION_TRANSFER = 0x0101 TYPE_TRANSACTION_TRANSFER = 0x0101
TYPE_IMPORTANCE_TRANSFER = 0x0801 TYPE_IMPORTANCE_TRANSFER = 0x0801
@ -195,13 +195,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig
def get_address( def get_address(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
network: int, network: int,
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> str: ) -> str:
return client.call( return session.call(
messages.NEMGetAddress( messages.NEMGetAddress(
address_n=n, network=network, show_display=show_display, chunkify=chunkify address_n=n, network=network, show_display=show_display, chunkify=chunkify
), ),
@ -210,7 +210,7 @@ def get_address(
def sign_tx( def sign_tx(
client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False session: "Session", n: "Address", transaction: dict, chunkify: bool = False
) -> messages.NEMSignedTx: ) -> messages.NEMSignedTx:
try: try:
msg = create_sign_tx(transaction, chunkify=chunkify) msg = create_sign_tx(transaction, chunkify=chunkify)
@ -219,4 +219,4 @@ def sign_tx(
assert msg.transaction is not None assert msg.transaction is not None
msg.transaction.address_n = n msg.transaction.address_n = n
return client.call(msg, expect=messages.NEMSignedTx) return session.call(msg, expect=messages.NEMSignedTx)

View File

@ -21,20 +21,20 @@ from .protobuf import dict_to_proto
from .tools import dict_from_camelcase from .tools import dict_from_camelcase
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address from .tools import Address
from .transport.session import Session
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment") REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination") REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> str: ) -> str:
return client.call( return session.call(
messages.RippleGetAddress( messages.RippleGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
), ),
@ -43,14 +43,14 @@ def get_address(
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
msg: messages.RippleSignTx, msg: messages.RippleSignTx,
chunkify: bool = False, chunkify: bool = False,
) -> messages.RippleSignedTx: ) -> messages.RippleSignedTx:
msg.address_n = address_n msg.address_n = address_n
msg.chunkify = chunkify msg.chunkify = chunkify
return client.call(msg, expect=messages.RippleSignedTx) return session.call(msg, expect=messages.RippleSignedTx)
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx: def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:

View File

@ -3,27 +3,27 @@ from typing import TYPE_CHECKING, List, Optional
from . import messages from . import messages
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .transport.session import Session
def get_public_key( def get_public_key(
client: "TrezorClient", session: "Session",
address_n: List[int], address_n: List[int],
show_display: bool, show_display: bool,
) -> bytes: ) -> bytes:
return client.call( return session.call(
messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display), messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display),
expect=messages.SolanaPublicKey, expect=messages.SolanaPublicKey,
).public_key ).public_key
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_n: List[int], address_n: List[int],
show_display: bool, show_display: bool,
chunkify: bool = False, chunkify: bool = False,
) -> str: ) -> str:
return client.call( return session.call(
messages.SolanaGetAddress( messages.SolanaGetAddress(
address_n=address_n, address_n=address_n,
show_display=show_display, show_display=show_display,
@ -34,12 +34,12 @@ def get_address(
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address_n: List[int], address_n: List[int],
serialized_tx: bytes, serialized_tx: bytes,
additional_info: Optional[messages.SolanaTxAdditionalInfo], additional_info: Optional[messages.SolanaTxAdditionalInfo],
) -> bytes: ) -> bytes:
return client.call( return session.call(
messages.SolanaSignTx( messages.SolanaSignTx(
address_n=address_n, address_n=address_n,
serialized_tx=serialized_tx, serialized_tx=serialized_tx,

View File

@ -20,8 +20,8 @@ from typing import TYPE_CHECKING, List, Tuple, Union
from . import exceptions, messages from . import exceptions, messages
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address from .tools import Address
from .transport.session import Session
StellarMessageType = Union[ StellarMessageType = Union[
messages.StellarAccountMergeOp, messages.StellarAccountMergeOp,
@ -322,12 +322,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> str: ) -> str:
return client.call( return session.call(
messages.StellarGetAddress( messages.StellarGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
), ),
@ -336,7 +336,7 @@ def get_address(
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
tx: messages.StellarSignTx, tx: messages.StellarSignTx,
operations: List["StellarMessageType"], operations: List["StellarMessageType"],
address_n: "Address", address_n: "Address",
@ -352,10 +352,10 @@ def sign_tx(
# 3. Receive a StellarTxOpRequest message # 3. Receive a StellarTxOpRequest message
# 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message # 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message
# 5. The final message received will be StellarSignedTx which is returned from this method # 5. The final message received will be StellarSignedTx which is returned from this method
resp = client.call(tx) resp = session.call(tx)
try: try:
while isinstance(resp, messages.StellarTxOpRequest): while isinstance(resp, messages.StellarTxOpRequest):
resp = client.call(operations.pop(0)) resp = session.call(operations.pop(0))
except IndexError: except IndexError:
# pop from empty list # pop from empty list
raise exceptions.TrezorException( raise exceptions.TrezorException(

View File

@ -19,17 +19,17 @@ from typing import TYPE_CHECKING
from . import messages from . import messages
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address from .tools import Address
from .transport.session import Session
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> str: ) -> str:
return client.call( return session.call(
messages.TezosGetAddress( messages.TezosGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
), ),
@ -38,12 +38,12 @@ def get_address(
def get_public_key( def get_public_key(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> str: ) -> str:
return client.call( return session.call(
messages.TezosGetPublicKey( messages.TezosGetPublicKey(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
), ),
@ -52,11 +52,11 @@ def get_public_key(
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
sign_tx_msg: messages.TezosSignTx, sign_tx_msg: messages.TezosSignTx,
chunkify: bool = False, chunkify: bool = False,
) -> messages.TezosSignedTx: ) -> messages.TezosSignedTx:
sign_tx_msg.address_n = address_n sign_tx_msg.address_n = address_n
sign_tx_msg.chunkify = chunkify sign_tx_msg.chunkify = chunkify
return client.call(sign_tx_msg, expect=messages.TezosSignedTx) return session.call(sign_tx_msg, expect=messages.TezosSignedTx)

View File

@ -45,7 +45,7 @@ if TYPE_CHECKING:
# More details: https://www.python.org/dev/peps/pep-0612/ # More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar from typing import TypeVar
from typing_extensions import Concatenate, ParamSpec from typing_extensions import ParamSpec
from . import client from . import client
from .messages import Success from .messages import Success
@ -389,23 +389,6 @@ def _return_success(msg: "Success") -> str | None:
return _deprecation_retval_helper(msg.message, stacklevel=1) return _deprecation_retval_helper(msg.message, stacklevel=1)
def session(
f: "Callable[Concatenate[TrezorClient, P], R]",
) -> "Callable[Concatenate[TrezorClient, P], R]":
# Decorator wraps a BaseClient method
# with session activation / deactivation
@functools.wraps(f)
def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R":
__tracebackhide__ = True # for pytest # pylint: disable=W0612
client.open()
try:
return f(client, *args, **kwargs)
finally:
client.close()
return wrapped_f
# de-camelcasifier # de-camelcasifier
# https://stackoverflow.com/a/1176023/222189 # https://stackoverflow.com/a/1176023/222189