1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-10 15:30:55 +00:00

feat(python): implement session based trezorlib

[no changelog]
This commit is contained in:
M1nd3r 2024-12-02 15:44:10 +01:00
parent 4a18f67f8f
commit 6b1fc71ce3
61 changed files with 4063 additions and 1663 deletions

View File

@ -95,6 +95,15 @@ class Emulator:
raise RuntimeError raise RuntimeError
return self._client return self._client
@client.setter
def client(self, new_client: TrezorClientDebugLink) -> None:
"""Setter for the client property to update _client."""
if not isinstance(new_client, TrezorClientDebugLink):
raise TypeError(
f"Expected a TrezorClientDebugLink, got {type(new_client).__name__}."
)
self._client = new_client
def make_args(self) -> List[str]: def make_args(self) -> List[str]:
return [] return []
@ -112,7 +121,7 @@ class Emulator:
start = time.monotonic() start = time.monotonic()
try: try:
while True: while True:
if transport._ping(): if transport.ping():
break break
if self.process.poll() is not None: if self.process.poll() is not None:
raise RuntimeError("Emulator process died") raise RuntimeError("Emulator process died")

View File

@ -7,7 +7,7 @@ import typing as t
from importlib import metadata from importlib import metadata
from . import device from . import device
from .client import TrezorClient from .transport.session import Session
try: try:
cryptography_version = metadata.version("cryptography") cryptography_version = metadata.version("cryptography")
@ -361,7 +361,7 @@ def verify_authentication_response(
def authenticate_device( def authenticate_device(
client: TrezorClient, session: Session,
challenge: bytes | None = None, challenge: bytes | None = None,
*, *,
whitelist: t.Collection[bytes] | None = None, whitelist: t.Collection[bytes] | None = None,
@ -371,7 +371,7 @@ def authenticate_device(
if challenge is None: if challenge is None:
challenge = secrets.token_bytes(16) challenge = secrets.token_bytes(16)
resp = device.authenticate(client, challenge) resp = device.authenticate(session, challenge)
return verify_authentication_response( return verify_authentication_response(
challenge, challenge,

View File

@ -20,17 +20,17 @@ from . import messages
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .transport.session import Session
@expect(messages.BenchmarkNames) @expect(messages.BenchmarkNames)
def list_names( def list_names(
client: "TrezorClient", session: "Session",
) -> "MessageType": ) -> "MessageType":
return client.call(messages.BenchmarkListNames()) return session.call(messages.BenchmarkListNames())
@expect(messages.BenchmarkResult) @expect(messages.BenchmarkResult)
def run(client: "TrezorClient", name: str) -> "MessageType": def run(session: "Session", name: str) -> "MessageType":
return client.call(messages.BenchmarkRun(name=name)) return session.call(messages.BenchmarkRun(name=name))

View File

@ -18,22 +18,22 @@ from typing import TYPE_CHECKING
from . import messages from . import messages
from .protobuf import dict_to_proto from .protobuf import dict_to_proto
from .tools import expect, session from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.session import Session
@expect(messages.BinanceAddress, field="address", ret_type=str) @expect(messages.BinanceAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.BinanceGetAddress( messages.BinanceGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
) )
@ -42,16 +42,15 @@ def get_address(
@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes) @expect(messages.BinancePublicKey, field="public_key", ret_type=bytes)
def get_public_key( def get_public_key(
client: "TrezorClient", address_n: "Address", show_display: bool = False session: "Session", address_n: "Address", show_display: bool = False
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display) messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display)
) )
@session
def sign_tx( def sign_tx(
client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False
) -> messages.BinanceSignedTx: ) -> messages.BinanceSignedTx:
msg = tx_json["msgs"][0] msg = tx_json["msgs"][0]
tx_msg = tx_json.copy() tx_msg = tx_json.copy()
@ -60,7 +59,7 @@ def sign_tx(
tx_msg["chunkify"] = chunkify tx_msg["chunkify"] = chunkify
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg) envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
response = client.call(envelope) response = session.call(envelope)
if not isinstance(response, messages.BinanceTxRequest): if not isinstance(response, messages.BinanceTxRequest):
raise RuntimeError( raise RuntimeError(
@ -77,7 +76,7 @@ def sign_tx(
else: else:
raise ValueError("can not determine msg type") raise ValueError("can not determine msg type")
response = client.call(msg) response = session.call(msg)
if not isinstance(response, messages.BinanceSignedTx): if not isinstance(response, messages.BinanceSignedTx):
raise RuntimeError( raise RuntimeError(

View File

@ -13,7 +13,6 @@
# #
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import warnings import warnings
from copy import copy from copy import copy
from decimal import Decimal from decimal import Decimal
@ -23,12 +22,12 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple
from typing_extensions import Protocol, TypedDict from typing_extensions import Protocol, TypedDict
from . import exceptions, messages from . import exceptions, messages
from .tools import expect, prepare_message_bytes, session from .tools import expect, prepare_message_bytes
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.session import Session
class ScriptSig(TypedDict): class ScriptSig(TypedDict):
asm: str asm: str
@ -105,7 +104,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType:
@expect(messages.PublicKey) @expect(messages.PublicKey)
def get_public_node( def get_public_node(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
ecdsa_curve_name: Optional[str] = None, ecdsa_curve_name: Optional[str] = None,
show_display: bool = False, show_display: bool = False,
@ -116,13 +115,13 @@ def get_public_node(
unlock_path_mac: Optional[bytes] = None, unlock_path_mac: Optional[bytes] = None,
) -> "MessageType": ) -> "MessageType":
if unlock_path: if unlock_path:
res = client.call( res = session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
) )
if not isinstance(res, messages.UnlockedPathRequest): if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message") raise exceptions.TrezorException("Unexpected message")
return client.call( return session.call(
messages.GetPublicKey( messages.GetPublicKey(
address_n=n, address_n=n,
ecdsa_curve_name=ecdsa_curve_name, ecdsa_curve_name=ecdsa_curve_name,
@ -141,7 +140,7 @@ def get_address(*args: Any, **kwargs: Any):
@expect(messages.Address) @expect(messages.Address)
def get_authenticated_address( def get_authenticated_address(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
n: "Address", n: "Address",
show_display: bool = False, show_display: bool = False,
@ -153,13 +152,13 @@ def get_authenticated_address(
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
if unlock_path: if unlock_path:
res = client.call( res = session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
) )
if not isinstance(res, messages.UnlockedPathRequest): if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message") raise exceptions.TrezorException("Unexpected message")
return client.call( return session.call(
messages.GetAddress( messages.GetAddress(
address_n=n, address_n=n,
coin_name=coin_name, coin_name=coin_name,
@ -172,15 +171,16 @@ def get_authenticated_address(
) )
# TODO this is used by tests only
@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes) @expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
def get_ownership_id( def get_ownership_id(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
n: "Address", n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None, multisig: Optional[messages.MultisigRedeemScriptType] = None,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.GetOwnershipId( messages.GetOwnershipId(
address_n=n, address_n=n,
coin_name=coin_name, coin_name=coin_name,
@ -190,8 +190,9 @@ def get_ownership_id(
) )
# TODO this is used by tests only
def get_ownership_proof( def get_ownership_proof(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
n: "Address", n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None, multisig: Optional[messages.MultisigRedeemScriptType] = None,
@ -202,11 +203,11 @@ def get_ownership_proof(
preauthorized: bool = False, preauthorized: bool = False,
) -> Tuple[bytes, bytes]: ) -> Tuple[bytes, bytes]:
if preauthorized: if preauthorized:
res = client.call(messages.DoPreauthorized()) res = session.call(messages.DoPreauthorized())
if not isinstance(res, messages.PreauthorizedRequest): if not isinstance(res, messages.PreauthorizedRequest):
raise exceptions.TrezorException("Unexpected message") raise exceptions.TrezorException("Unexpected message")
res = client.call( res = session.call(
messages.GetOwnershipProof( messages.GetOwnershipProof(
address_n=n, address_n=n,
coin_name=coin_name, coin_name=coin_name,
@ -226,7 +227,7 @@ def get_ownership_proof(
@expect(messages.MessageSignature) @expect(messages.MessageSignature)
def sign_message( def sign_message(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
n: "Address", n: "Address",
message: AnyStr, message: AnyStr,
@ -234,7 +235,7 @@ def sign_message(
no_script_type: bool = False, no_script_type: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.SignMessage( messages.SignMessage(
coin_name=coin_name, coin_name=coin_name,
address_n=n, address_n=n,
@ -247,7 +248,7 @@ def sign_message(
def verify_message( def verify_message(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
address: str, address: str,
signature: bytes, signature: bytes,
@ -255,7 +256,7 @@ def verify_message(
chunkify: bool = False, chunkify: bool = False,
) -> bool: ) -> bool:
try: try:
resp = client.call( resp = session.call(
messages.VerifyMessage( messages.VerifyMessage(
address=address, address=address,
signature=signature, signature=signature,
@ -269,9 +270,9 @@ def verify_message(
return isinstance(resp, messages.Success) return isinstance(resp, messages.Success)
@session # @session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
inputs: Sequence[messages.TxInputType], inputs: Sequence[messages.TxInputType],
outputs: Sequence[messages.TxOutputType], outputs: Sequence[messages.TxOutputType],
@ -319,17 +320,17 @@ def sign_tx(
setattr(signtx, name, value) setattr(signtx, name, value)
if unlock_path: if unlock_path:
res = client.call( res = session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
) )
if not isinstance(res, messages.UnlockedPathRequest): if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message") raise exceptions.TrezorException("Unexpected message")
elif preauthorized: elif preauthorized:
res = client.call(messages.DoPreauthorized()) res = session.call(messages.DoPreauthorized())
if not isinstance(res, messages.PreauthorizedRequest): if not isinstance(res, messages.PreauthorizedRequest):
raise exceptions.TrezorException("Unexpected message") raise exceptions.TrezorException("Unexpected message")
res = client.call(signtx) res = session.call(signtx)
# Prepare structure for signatures # Prepare structure for signatures
signatures: List[Optional[bytes]] = [None] * len(inputs) signatures: List[Optional[bytes]] = [None] * len(inputs)
@ -388,7 +389,7 @@ def sign_tx(
if res.request_type == R.TXPAYMENTREQ: if res.request_type == R.TXPAYMENTREQ:
assert res.details.request_index is not None assert res.details.request_index is not None
msg = payment_reqs[res.details.request_index] msg = payment_reqs[res.details.request_index]
res = client.call(msg) res = session.call(msg)
else: else:
msg = messages.TransactionType() msg = messages.TransactionType()
if res.request_type == R.TXMETA: if res.request_type == R.TXMETA:
@ -418,7 +419,7 @@ def sign_tx(
f"Unknown request type - {res.request_type}." f"Unknown request type - {res.request_type}."
) )
res = client.call(messages.TxAck(tx=msg)) res = session.call(messages.TxAck(tx=msg))
if not isinstance(res, messages.TxRequest): if not isinstance(res, messages.TxRequest):
raise exceptions.TrezorException("Unexpected message") raise exceptions.TrezorException("Unexpected message")
@ -432,7 +433,7 @@ def sign_tx(
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def authorize_coinjoin( def authorize_coinjoin(
client: "TrezorClient", session: "Session",
coordinator: str, coordinator: str,
max_rounds: int, max_rounds: int,
max_coordinator_fee_rate: int, max_coordinator_fee_rate: int,
@ -441,7 +442,7 @@ def authorize_coinjoin(
coin_name: str, coin_name: str,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.AuthorizeCoinJoin( messages.AuthorizeCoinJoin(
coordinator=coordinator, coordinator=coordinator,
max_rounds=max_rounds, max_rounds=max_rounds,

View File

@ -35,8 +35,8 @@ from . import exceptions, messages, tools
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .transport.session import Session
PROTOCOL_MAGICS = { PROTOCOL_MAGICS = {
"mainnet": 764824073, "mainnet": 764824073,
@ -825,7 +825,7 @@ def _get_collateral_inputs_items(
@expect(messages.CardanoAddress, field="address", ret_type=str) @expect(messages.CardanoAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_parameters: messages.CardanoAddressParametersType, address_parameters: messages.CardanoAddressParametersType,
protocol_magic: int = PROTOCOL_MAGICS["mainnet"], protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
network_id: int = NETWORK_IDS["mainnet"], network_id: int = NETWORK_IDS["mainnet"],
@ -833,7 +833,7 @@ def get_address(
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.CardanoGetAddress( messages.CardanoGetAddress(
address_parameters=address_parameters, address_parameters=address_parameters,
protocol_magic=protocol_magic, protocol_magic=protocol_magic,
@ -847,12 +847,12 @@ def get_address(
@expect(messages.CardanoPublicKey) @expect(messages.CardanoPublicKey)
def get_public_key( def get_public_key(
client: "TrezorClient", session: "Session",
address_n: List[int], address_n: List[int],
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
show_display: bool = False, show_display: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.CardanoGetPublicKey( messages.CardanoGetPublicKey(
address_n=address_n, address_n=address_n,
derivation_type=derivation_type, derivation_type=derivation_type,
@ -863,12 +863,12 @@ def get_public_key(
@expect(messages.CardanoNativeScriptHash) @expect(messages.CardanoNativeScriptHash)
def get_native_script_hash( def get_native_script_hash(
client: "TrezorClient", session: "Session",
native_script: messages.CardanoNativeScript, native_script: messages.CardanoNativeScript,
display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE, display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE,
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.CardanoGetNativeScriptHash( messages.CardanoGetNativeScriptHash(
script=native_script, script=native_script,
display_format=display_format, display_format=display_format,
@ -878,7 +878,7 @@ def get_native_script_hash(
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
signing_mode: messages.CardanoTxSigningMode, signing_mode: messages.CardanoTxSigningMode,
inputs: List[InputWithPath], inputs: List[InputWithPath],
outputs: List[OutputWithData], outputs: List[OutputWithData],
@ -915,7 +915,7 @@ def sign_tx(
signing_mode, signing_mode,
) )
response = client.call( response = session.call(
messages.CardanoSignTxInit( messages.CardanoSignTxInit(
signing_mode=signing_mode, signing_mode=signing_mode,
inputs_count=len(inputs), inputs_count=len(inputs),
@ -951,14 +951,14 @@ def sign_tx(
_get_certificates_items(certificates), _get_certificates_items(certificates),
withdrawals, withdrawals,
): ):
response = client.call(tx_item) response = session.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck): if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response: Dict[str, Any] = {} sign_tx_response: Dict[str, Any] = {}
if auxiliary_data is not None: if auxiliary_data is not None:
auxiliary_data_supplement = client.call(auxiliary_data) auxiliary_data_supplement = session.call(auxiliary_data)
if not isinstance( if not isinstance(
auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement
): ):
@ -971,7 +971,7 @@ def sign_tx(
auxiliary_data_supplement.__dict__ auxiliary_data_supplement.__dict__
) )
response = client.call(messages.CardanoTxHostAck()) response = session.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoTxItemAck): if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR
@ -980,24 +980,24 @@ def sign_tx(
_get_collateral_inputs_items(collateral_inputs), _get_collateral_inputs_items(collateral_inputs),
required_signers, required_signers,
): ):
response = client.call(tx_item) response = session.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck): if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR
if collateral_return is not None: if collateral_return is not None:
for tx_item in _get_output_items(collateral_return): for tx_item in _get_output_items(collateral_return):
response = client.call(tx_item) response = session.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck): if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR
for reference_input in reference_inputs: for reference_input in reference_inputs:
response = client.call(reference_input) response = session.call(reference_input)
if not isinstance(response, messages.CardanoTxItemAck): if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["witnesses"] = [] sign_tx_response["witnesses"] = []
for witness_request in witness_requests: for witness_request in witness_requests:
response = client.call(witness_request) response = session.call(witness_request)
if not isinstance(response, messages.CardanoTxWitnessResponse): if not isinstance(response, messages.CardanoTxWitnessResponse):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["witnesses"].append( sign_tx_response["witnesses"].append(
@ -1009,12 +1009,12 @@ def sign_tx(
} }
) )
response = client.call(messages.CardanoTxHostAck()) response = session.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoTxBodyHash): if not isinstance(response, messages.CardanoTxBodyHash):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["tx_hash"] = response.tx_hash sign_tx_response["tx_hash"] = response.tx_hash
response = client.call(messages.CardanoTxHostAck()) response = session.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoSignTxFinished): if not isinstance(response, messages.CardanoSignTxFinished):
raise UNEXPECTED_RESPONSE_ERROR raise UNEXPECTED_RESPONSE_ERROR

View File

@ -14,33 +14,42 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import functools import functools
import logging
import os
import sys import sys
import typing as t
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import click import click
from .. import exceptions, transport from .. import exceptions, transport, ui
from ..client import TrezorClient from ..client import ProtocolVersion, TrezorClient
from ..ui import ClickUI, ScriptUI from ..messages import Capability
from ..transport import Transport
from ..transport.session import Session, SessionV1, SessionV2
from ..transport.thp.channel_database import get_channel_db
if TYPE_CHECKING: LOG = logging.getLogger(__name__)
if t.TYPE_CHECKING:
# Needed to enforce a return value from decorators # Needed to enforce a return value from decorators
# More details: https://www.python.org/dev/peps/pep-0612/ # More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar
from typing_extensions import Concatenate, ParamSpec from typing_extensions import Concatenate, ParamSpec
from ..transport import Transport
from ..ui import TrezorClientUI
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = t.TypeVar("R")
FuncWithSession = t.Callable[Concatenate[Session, P], R]
class ChoiceType(click.Choice): class ChoiceType(click.Choice):
def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None:
def __init__(
self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True
) -> None:
super().__init__(list(typemap.keys())) super().__init__(list(typemap.keys()))
self.case_sensitive = case_sensitive self.case_sensitive = case_sensitive
if case_sensitive: if case_sensitive:
@ -48,7 +57,7 @@ class ChoiceType(click.Choice):
else: else:
self.typemap = {k.lower(): v for k, v in typemap.items()} self.typemap = {k.lower(): v for k, v in typemap.items()}
def convert(self, value: Any, param: Any, ctx: click.Context) -> Any: def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any:
if value in self.typemap.values(): if value in self.typemap.values():
return value return value
value = super().convert(value, param, ctx) value = super().convert(value, param, ctx)
@ -57,11 +66,69 @@ class ChoiceType(click.Choice):
return self.typemap[value] return self.typemap[value]
def get_passphrase(
passphrase_on_host: bool, available_on_device: bool
) -> t.Union[str, object]:
if available_on_device and not passphrase_on_host:
return ui.PASSPHRASE_ON_DEVICE
env_passphrase = os.getenv("PASSPHRASE")
if env_passphrase is not None:
ui.echo("Passphrase required. Using PASSPHRASE environment variable.")
return env_passphrase
while True:
try:
passphrase = ui.prompt(
"Passphrase required",
hide_input=True,
default="",
show_default=False,
)
# In case user sees the input on the screen, we do not need confirmation
if not ui.CAN_HANDLE_HIDDEN_INPUT:
return passphrase
second = ui.prompt(
"Confirm your passphrase",
hide_input=True,
default="",
show_default=False,
)
if passphrase == second:
return passphrase
else:
ui.echo("Passphrase did not match. Please try again.")
except click.Abort:
raise exceptions.Cancelled from None
def get_client(transport: Transport) -> TrezorClient:
stored_channels = get_channel_db().load_stored_channels()
stored_transport_paths = [ch.transport_path for ch in stored_channels]
path = transport.get_path()
if path in stored_transport_paths:
stored_channel_with_correct_transport_path = next(
ch for ch in stored_channels if ch.transport_path == path
)
try:
client = TrezorClient.resume(
transport, stored_channel_with_correct_transport_path
)
except Exception:
LOG.debug("Failed to resume a channel. Replacing by a new one.")
get_channel_db().remove_channel(path)
client = TrezorClient(transport)
else:
client = TrezorClient(transport)
return client
class TrezorConnection: class TrezorConnection:
def __init__( def __init__(
self, self,
path: str, path: str,
session_id: Optional[bytes], session_id: bytes | None,
passphrase_on_host: bool, passphrase_on_host: bool,
script: bool, script: bool,
) -> None: ) -> None:
@ -70,6 +137,54 @@ class TrezorConnection:
self.passphrase_on_host = passphrase_on_host self.passphrase_on_host = passphrase_on_host
self.script = script self.script = script
def get_session(
self,
derive_cardano: bool = False,
empty_passphrase: bool = False,
must_resume: bool = False,
) -> Session:
client = self.get_client()
if must_resume and self.session_id is None:
click.echo("Failed to resume session - no session id provided")
raise RuntimeError("Failed to resume session - no session id provided")
# Try resume session from id
if self.session_id is not None:
if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
session = SessionV1.resume_from_id(
client=client, session_id=self.session_id
)
elif client.protocol_version is ProtocolVersion.PROTOCOL_V2:
session = SessionV2(client, self.session_id)
# TODO fix resumption on THP
else:
raise Exception("Unsupported client protocol", client.protocol_version)
if must_resume:
if session.id != self.session_id or session.id is None:
click.echo("Failed to resume session")
RuntimeError("Failed to resume session - no session id provided")
return session
features = client.protocol.get_features()
passphrase_enabled = True # TODO what to do here?
if not passphrase_enabled:
return client.get_session(derive_cardano=derive_cardano)
if empty_passphrase:
passphrase = ""
else:
available_on_device = Capability.PassphraseEntry in features.capabilities
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
)
return session
def get_transport(self) -> "Transport": def get_transport(self) -> "Transport":
try: try:
# look for transport without prefix search # look for transport without prefix search
@ -82,19 +197,13 @@ class TrezorConnection:
# if this fails, we want the exception to bubble up to the caller # if this fails, we want the exception to bubble up to the caller
return transport.get_transport(self.path, prefix_search=True) return transport.get_transport(self.path, prefix_search=True)
def get_ui(self) -> "TrezorClientUI":
if self.script:
# It is alright to return just the class object instead of instance,
# as the ScriptUI class object itself is the implementation of TrezorClientUI
# (ScriptUI is just a set of staticmethods)
return ScriptUI
else:
return ClickUI(passphrase_on_host=self.passphrase_on_host)
def get_client(self) -> TrezorClient: def get_client(self) -> TrezorClient:
transport = self.get_transport() return get_client(self.get_transport())
ui = self.get_ui()
return TrezorClient(transport, ui=ui, session_id=self.session_id) def get_management_session(self) -> Session:
client = self.get_client()
management_session = client.get_management_session()
return management_session
@contextmanager @contextmanager
def client_context(self): def client_context(self):
@ -128,7 +237,57 @@ class TrezorConnection:
# other exceptions may cause a traceback # other exceptions may cause a traceback
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]": def with_session(
func: "t.Callable[Concatenate[Session, P], R]|None" = None,
*,
empty_passphrase: bool = False,
derive_cardano: bool = False,
management: bool = False,
must_resume: bool = False,
) -> t.Callable[[FuncWithSession], t.Callable[P, R]]:
"""Provides a Click command with parameter `session=obj.get_session(...)` or
`session=obj.get_management_session()` based on the parameters provided.
If default parameters are ok, this decorator can be used without parentheses.
TODO: handle resumption of sessions and their (potential) closure.
"""
def decorator(
func: FuncWithSession,
) -> "t.Callable[P, R]":
@click.pass_obj
@functools.wraps(func)
def function_with_session(
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
if management:
session = obj.get_management_session()
else:
session = obj.get_session(
derive_cardano=derive_cardano,
empty_passphrase=empty_passphrase,
must_resume=must_resume,
)
try:
return func(session, *args, **kwargs)
finally:
pass
# TODO try end session if not resumed
return function_with_session
# If the decorator @get_session is used without parentheses
if func and callable(func):
return decorator(func) # type: ignore [Function return type]
return decorator
def with_client(
func: "t.Callable[Concatenate[TrezorClient, P], R]",
) -> "t.Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`. """Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume Sessions are handled transparently. The user is warned when session did not resume
@ -142,23 +301,62 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R": ) -> "R":
with obj.client_context() as client: with obj.client_context() as client:
session_was_resumed = obj.session_id == client.session_id # session_was_resumed = obj.session_id == client.session_id
if not session_was_resumed and obj.session_id is not None: # if not session_was_resumed and obj.session_id is not None:
# tried to resume but failed # # tried to resume but failed
click.echo("Warning: failed to resume session.", err=True) # click.echo("Warning: failed to resume session.", err=True)
click.echo(
"Warning: resume session detection is not implemented yet!", err=True
)
try: try:
return func(client, *args, **kwargs) return func(client, *args, **kwargs)
finally: finally:
if not session_was_resumed: if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
try: get_channel_db().save_channel(client.protocol)
client.end_session() # if not session_was_resumed:
except Exception: # try:
pass # client.end_session()
# except Exception:
# pass
return trezorctl_command_with_client return trezorctl_command_with_client
# def with_client(
# func: "t.Callable[Concatenate[TrezorClient, P], R]",
# ) -> "t.Callable[P, R]":
# """Wrap a Click command in `with obj.client_context() as client`.
# Sessions are handled transparently. The user is warned when session did not resume
# cleanly. The session is closed after the command completes - unless the session
# was resumed, in which case it should remain open.
# """
# @click.pass_obj
# @functools.wraps(func)
# def trezorctl_command_with_client(
# obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
# ) -> "R":
# with obj.client_context() as client:
# session_was_resumed = obj.session_id == client.session_id
# if not session_was_resumed and obj.session_id is not None:
# # tried to resume but failed
# click.echo("Warning: failed to resume session.", err=True)
# try:
# return func(client, *args, **kwargs)
# finally:
# if not session_was_resumed:
# try:
# client.end_session()
# except Exception:
# pass
# # the return type of @click.pass_obj is improperly specified and pyright doesn't
# # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
# return trezorctl_command_with_client
class AliasedGroup(click.Group): class AliasedGroup(click.Group):
"""Command group that handles aliases and Click 6.x compatibility. """Command group that handles aliases and Click 6.x compatibility.
@ -188,14 +386,14 @@ class AliasedGroup(click.Group):
def __init__( def __init__(
self, self,
aliases: Optional[Dict[str, click.Command]] = None, aliases: t.Dict[str, click.Command] | None = None,
*args: Any, *args: t.Any,
**kwargs: Any, **kwargs: t.Any,
) -> None: ) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.aliases = aliases or {} self.aliases = aliases or {}
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]: def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
cmd_name = cmd_name.replace("_", "-") cmd_name = cmd_name.replace("_", "-")
# try to look up the real name # try to look up the real name
cmd = super().get_command(ctx, cmd_name) cmd = super().get_command(ctx, cmd_name)

View File

@ -20,17 +20,15 @@ from typing import TYPE_CHECKING, List, Optional
import click import click
from .. import benchmark from .. import benchmark
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
def list_names_patern( def list_names_patern(session: "Session", pattern: Optional[str] = None) -> List[str]:
client: "TrezorClient", pattern: Optional[str] = None names = list(benchmark.list_names(session).names)
) -> List[str]:
names = list(benchmark.list_names(client).names)
if pattern is None: if pattern is None:
return names return names
return [name for name in names if fnmatch(name, pattern)] return [name for name in names if fnmatch(name, pattern)]
@ -43,10 +41,10 @@ def cli() -> None:
@cli.command() @cli.command()
@click.argument("pattern", required=False) @click.argument("pattern", required=False)
@with_client @with_session(empty_passphrase=True)
def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None: def list_names(session: "Session", pattern: Optional[str] = None) -> None:
"""List names of all supported benchmarks""" """List names of all supported benchmarks"""
names = list_names_patern(client, pattern) names = list_names_patern(session, pattern)
if len(names) == 0: if len(names) == 0:
click.echo("No benchmark satisfies the pattern.") click.echo("No benchmark satisfies the pattern.")
else: else:
@ -56,13 +54,13 @@ def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
@cli.command() @cli.command()
@click.argument("pattern", required=False) @click.argument("pattern", required=False)
@with_client @with_session(empty_passphrase=True)
def run(client: "TrezorClient", pattern: Optional[str]) -> None: def run(session: "Session", pattern: Optional[str]) -> None:
"""Run benchmark""" """Run benchmark"""
names = list_names_patern(client, pattern) names = list_names_patern(session, pattern)
if len(names) == 0: if len(names) == 0:
click.echo("No benchmark satisfies the pattern.") click.echo("No benchmark satisfies the pattern.")
else: else:
for name in names: for name in names:
result = benchmark.run(client, name) result = benchmark.run(session, name)
click.echo(f"{name}: {result.value} {result.unit}") click.echo(f"{name}: {result.value} {result.unit}")

View File

@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
import click import click
from .. import binance, tools from .. import binance, tools
from . import with_client from ..transport.session import Session
from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from .. import messages from .. import messages
from ..client import TrezorClient
PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0" PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0"
@ -39,23 +39,23 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool session: "Session", address: str, show_display: bool, chunkify: bool
) -> str: ) -> str:
"""Get Binance address for specified path.""" """Get Binance address for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return binance.get_address(client, address_n, show_display, chunkify) return binance.get_address(session, address_n, show_display, chunkify)
@cli.command() @cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Binance public key.""" """Get Binance public key."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return binance.get_public_key(client, address_n, show_display).hex() return binance.get_public_key(session, address_n, show_display).hex()
@cli.command() @cli.command()
@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def sign_tx( def sign_tx(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool session: "Session", address: str, file: TextIO, chunkify: bool
) -> "messages.BinanceSignedTx": ) -> "messages.BinanceSignedTx":
"""Sign Binance transaction. """Sign Binance transaction.
Transaction must be provided as a JSON file. Transaction must be provided as a JSON file.
""" """
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify) return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify)

View File

@ -13,6 +13,7 @@
# #
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import base64 import base64
import json import json
@ -22,10 +23,10 @@ import click
import construct as c import construct as c
from .. import btc, messages, protobuf, tools from .. import btc, messages, protobuf, tools
from . import ChoiceType, with_client from . import ChoiceType, with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
PURPOSE_BIP44 = 44 PURPOSE_BIP44 = 44
PURPOSE_BIP48 = 48 PURPOSE_BIP48 = 48
@ -174,15 +175,15 @@ def cli() -> None:
help="Sort pubkeys lexicographically using BIP-67", help="Sort pubkeys lexicographically using BIP-67",
) )
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", session: "Session",
coin: str, coin: str,
address: str, address: str,
script_type: Optional[messages.InputScriptType], script_type: messages.InputScriptType | None,
show_display: bool, show_display: bool,
multisig_xpub: List[str], multisig_xpub: List[str],
multisig_threshold: Optional[int], multisig_threshold: int | None,
multisig_suffix_length: int, multisig_suffix_length: int,
multisig_sort_pubkeys: bool, multisig_sort_pubkeys: bool,
chunkify: bool, chunkify: bool,
@ -235,7 +236,7 @@ def get_address(
multisig = None multisig = None
return btc.get_address( return btc.get_address(
client, session,
coin, coin,
address_n, address_n,
show_display, show_display,
@ -252,9 +253,9 @@ def get_address(
@click.option("-e", "--curve") @click.option("-e", "--curve")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_node( def get_public_node(
client: "TrezorClient", session: "Session",
coin: str, coin: str,
address: str, address: str,
curve: Optional[str], curve: Optional[str],
@ -266,7 +267,7 @@ def get_public_node(
if script_type is None: if script_type is None:
script_type = guess_script_type_from_path(address_n) script_type = guess_script_type_from_path(address_n)
result = btc.get_public_node( result = btc.get_public_node(
client, session,
address_n, address_n,
ecdsa_curve_name=curve, ecdsa_curve_name=curve,
show_display=show_display, show_display=show_display,
@ -292,7 +293,7 @@ def _append_descriptor_checksum(desc: str) -> str:
def _get_descriptor( def _get_descriptor(
client: "TrezorClient", session: "Session",
coin: Optional[str], coin: Optional[str],
account: int, account: int,
purpose: Optional[int], purpose: Optional[int],
@ -326,7 +327,7 @@ def _get_descriptor(
n = tools.parse_path(path) n = tools.parse_path(path)
pub = btc.get_public_node( pub = btc.get_public_node(
client, session,
n, n,
show_display=show_display, show_display=show_display,
coin_name=coin, coin_name=coin,
@ -363,9 +364,9 @@ def _get_descriptor(
@click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE)) @click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE))
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_descriptor( def get_descriptor(
client: "TrezorClient", session: "Session",
coin: Optional[str], coin: Optional[str],
account: int, account: int,
account_type: Optional[int], account_type: Optional[int],
@ -375,7 +376,7 @@ def get_descriptor(
"""Get descriptor of given account.""" """Get descriptor of given account."""
try: try:
return _get_descriptor( return _get_descriptor(
client, coin, account, account_type, script_type, show_display session, coin, account, account_type, script_type, show_display
) )
except ValueError as e: except ValueError as e:
raise click.ClickException(str(e)) raise click.ClickException(str(e))
@ -390,8 +391,8 @@ def get_descriptor(
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False) @click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@click.argument("json_file", type=click.File()) @click.argument("json_file", type=click.File())
@with_client @with_session
def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None:
"""Sign transaction. """Sign transaction.
Transaction data must be provided in a JSON file. See `transaction-format.md` for Transaction data must be provided in a JSON file. See `transaction-format.md` for
@ -416,7 +417,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
} }
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
client, session,
coin, coin,
inputs, inputs,
outputs, outputs,
@ -447,9 +448,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
) )
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@click.argument("message") @click.argument("message")
@with_client @with_session
def sign_message( def sign_message(
client: "TrezorClient", session: "Session",
coin: str, coin: str,
address: str, address: str,
message: str, message: str,
@ -462,7 +463,7 @@ def sign_message(
if script_type is None: if script_type is None:
script_type = guess_script_type_from_path(address_n) script_type = guess_script_type_from_path(address_n)
res = btc.sign_message( res = btc.sign_message(
client, session,
coin, coin,
address_n, address_n,
message, message,
@ -483,9 +484,9 @@ def sign_message(
@click.argument("address") @click.argument("address")
@click.argument("signature") @click.argument("signature")
@click.argument("message") @click.argument("message")
@with_client @with_session
def verify_message( def verify_message(
client: "TrezorClient", session: "Session",
coin: str, coin: str,
address: str, address: str,
signature: str, signature: str,
@ -495,7 +496,7 @@ def verify_message(
"""Verify message.""" """Verify message."""
signature_bytes = base64.b64decode(signature) signature_bytes = base64.b64decode(signature)
return btc.verify_message( return btc.verify_message(
client, coin, address, signature_bytes, message, chunkify=chunkify session, coin, address, signature_bytes, message, chunkify=chunkify
) )

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
import click import click
from .. import cardano, messages, tools from .. import cardano, messages, tools
from . import ChoiceType, with_client from . import ChoiceType, with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0" PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0"
@ -62,9 +62,9 @@ def cli() -> None:
@click.option("-i", "--include-network-id", is_flag=True) @click.option("-i", "--include-network-id", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@click.option("-T", "--tag-cbor-sets", is_flag=True) @click.option("-T", "--tag-cbor-sets", is_flag=True)
@with_client @with_session(derive_cardano=True)
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
file: TextIO, file: TextIO,
signing_mode: messages.CardanoTxSigningMode, signing_mode: messages.CardanoTxSigningMode,
protocol_magic: int, protocol_magic: int,
@ -123,9 +123,8 @@ def sign_tx(
for p in transaction["additional_witness_requests"] for p in transaction["additional_witness_requests"]
] ]
client.init_device(derive_cardano=True)
sign_tx_response = cardano.sign_tx( sign_tx_response = cardano.sign_tx(
client, session,
signing_mode, signing_mode,
inputs, inputs,
outputs, outputs,
@ -209,9 +208,9 @@ def sign_tx(
default=messages.CardanoDerivationType.ICARUS, default=messages.CardanoDerivationType.ICARUS,
) )
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session(derive_cardano=True)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address: str, address: str,
address_type: messages.CardanoAddressType, address_type: messages.CardanoAddressType,
staking_address: str, staking_address: str,
@ -262,9 +261,8 @@ def get_address(
script_staking_hash_bytes, script_staking_hash_bytes,
) )
client.init_device(derive_cardano=True)
return cardano.get_address( return cardano.get_address(
client, session,
address_parameters, address_parameters,
protocol_magic, protocol_magic,
network_id, network_id,
@ -283,18 +281,17 @@ def get_address(
default=messages.CardanoDerivationType.ICARUS, default=messages.CardanoDerivationType.ICARUS,
) )
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session(derive_cardano=True)
def get_public_key( def get_public_key(
client: "TrezorClient", session: "Session",
address: str, address: str,
derivation_type: messages.CardanoDerivationType, derivation_type: messages.CardanoDerivationType,
show_display: bool, show_display: bool,
) -> messages.CardanoPublicKey: ) -> messages.CardanoPublicKey:
"""Get Cardano public key.""" """Get Cardano public key."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
client.init_device(derive_cardano=True)
return cardano.get_public_key( return cardano.get_public_key(
client, address_n, derivation_type=derivation_type, show_display=show_display session, address_n, derivation_type=derivation_type, show_display=show_display
) )
@ -312,9 +309,9 @@ def get_public_key(
type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}), type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}),
default=messages.CardanoDerivationType.ICARUS, default=messages.CardanoDerivationType.ICARUS,
) )
@with_client @with_session(derive_cardano=True)
def get_native_script_hash( def get_native_script_hash(
client: "TrezorClient", session: "Session",
file: TextIO, file: TextIO,
display_format: messages.CardanoNativeScriptHashDisplayFormat, display_format: messages.CardanoNativeScriptHashDisplayFormat,
derivation_type: messages.CardanoDerivationType, derivation_type: messages.CardanoDerivationType,
@ -323,7 +320,6 @@ def get_native_script_hash(
native_script_json = json.load(file) native_script_json = json.load(file)
native_script = cardano.parse_native_script(native_script_json) native_script = cardano.parse_native_script(native_script_json)
client.init_device(derive_cardano=True)
return cardano.get_native_script_hash( return cardano.get_native_script_hash(
client, native_script, display_format, derivation_type=derivation_type session, native_script, display_format, derivation_type=derivation_type
) )

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Tuple
import click import click
from .. import misc, tools from .. import misc, tools
from . import ChoiceType, with_client from . import ChoiceType, with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
PROMPT_TYPE = ChoiceType( PROMPT_TYPE = ChoiceType(
@ -42,10 +42,10 @@ def cli() -> None:
@cli.command() @cli.command()
@click.argument("size", type=int) @click.argument("size", type=int)
@with_client @with_session(empty_passphrase=True)
def get_entropy(client: "TrezorClient", size: int) -> str: def get_entropy(session: "Session", size: int) -> str:
"""Get random bytes from device.""" """Get random bytes from device."""
return misc.get_entropy(client, size).hex() return misc.get_entropy(session, size).hex()
@cli.command() @cli.command()
@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str:
) )
@click.argument("key") @click.argument("key")
@click.argument("value") @click.argument("value")
@with_client @with_session(empty_passphrase=True)
def encrypt_keyvalue( def encrypt_keyvalue(
client: "TrezorClient", session: "Session",
address: str, address: str,
key: str, key: str,
value: str, value: str,
@ -75,7 +75,7 @@ def encrypt_keyvalue(
ask_on_encrypt, ask_on_decrypt = prompt ask_on_encrypt, ask_on_decrypt = prompt
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return misc.encrypt_keyvalue( return misc.encrypt_keyvalue(
client, session,
address_n, address_n,
key, key,
value.encode(), value.encode(),
@ -91,9 +91,9 @@ def encrypt_keyvalue(
) )
@click.argument("key") @click.argument("key")
@click.argument("value") @click.argument("value")
@with_client @with_session(empty_passphrase=True)
def decrypt_keyvalue( def decrypt_keyvalue(
client: "TrezorClient", session: "Session",
address: str, address: str,
key: str, key: str,
value: str, value: str,
@ -112,7 +112,7 @@ def decrypt_keyvalue(
ask_on_encrypt, ask_on_decrypt = prompt ask_on_encrypt, ask_on_decrypt = prompt
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return misc.decrypt_keyvalue( return misc.decrypt_keyvalue(
client, session,
address_n, address_n,
key, key,
bytes.fromhex(value), bytes.fromhex(value),

View File

@ -18,13 +18,12 @@ from typing import TYPE_CHECKING, Union
import click import click
from .. import mapping, messages, protobuf
from ..client import TrezorClient
from ..debuglink import TrezorClientDebugLink from ..debuglink import TrezorClientDebugLink
from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max
from ..debuglink import prodtest_t1 as debuglink_prodtest_t1 from ..debuglink import prodtest_t1 as debuglink_prodtest_t1
from ..debuglink import record_screen from ..debuglink import record_screen
from . import with_client from ..transport.session import Session
from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from . import TrezorConnection from . import TrezorConnection
@ -35,51 +34,51 @@ def cli() -> None:
"""Miscellaneous debug features.""" """Miscellaneous debug features."""
@cli.command() # @cli.command()
@click.argument("message_name_or_type") # @click.argument("message_name_or_type")
@click.argument("hex_data") # @click.argument("hex_data")
@click.pass_obj # @click.pass_obj
def send_bytes( # def send_bytes(
obj: "TrezorConnection", message_name_or_type: str, hex_data: str # obj: "NewTrezorConnection", message_name_or_type: str, hex_data: str
) -> None: # ) -> None:
"""Send raw bytes to Trezor. # """Send raw bytes to Trezor.
Message type and message data must be specified separately, due to how message # Message type and message data must be specified separately, due to how message
chunking works on the transport level. Message length is calculated and sent # chunking works on the transport level. Message length is calculated and sent
automatically, and it is currently impossible to explicitly specify invalid length. # automatically, and it is currently impossible to explicitly specify invalid length.
MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum, # MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum,
in which case the value of that enum is used. # in which case the value of that enum is used.
""" # """
if message_name_or_type.isdigit(): # if message_name_or_type.isdigit():
message_type = int(message_name_or_type) # message_type = int(message_name_or_type)
else: # else:
message_type = getattr(messages.MessageType, message_name_or_type) # message_type = getattr(messages.MessageType, message_name_or_type)
if not isinstance(message_type, int): # if not isinstance(message_type, int):
raise click.ClickException("Invalid message type.") # raise click.ClickException("Invalid message type.")
try: # try:
message_data = bytes.fromhex(hex_data) # message_data = bytes.fromhex(hex_data)
except Exception as e: # except Exception as e:
raise click.ClickException("Invalid hex data.") from e # raise click.ClickException("Invalid hex data.") from e
transport = obj.get_transport() # transport = obj.get_transport()
transport.begin_session() # transport.deprecated_begin_session()
transport.write(message_type, message_data) # transport.write(message_type, message_data)
response_type, response_data = transport.read() # response_type, response_data = transport.read()
transport.end_session() # transport.deprecated_end_session()
click.echo(f"Response type: {response_type}") # click.echo(f"Response type: {response_type}")
click.echo(f"Response data: {response_data.hex()}") # click.echo(f"Response data: {response_data.hex()}")
try: # try:
msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data) # msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
click.echo("Parsed message:") # click.echo("Parsed message:")
click.echo(protobuf.format_message(msg)) # click.echo(protobuf.format_message(msg))
except Exception as e: # except Exception as e:
click.echo(f"Could not parse response: {e}") # click.echo(f"Could not parse response: {e}")
@cli.command() @cli.command()
@ -106,17 +105,17 @@ def record_screen_from_connection(
@cli.command() @cli.command()
@with_client @with_session(management=True)
def prodtest_t1(client: "TrezorClient") -> str: def prodtest_t1(session: "Session") -> str:
"""Perform a prodtest on Model One. """Perform a prodtest on Model One.
Only available on PRODTEST firmware and on T1B1. Formerly named self-test. Only available on PRODTEST firmware and on T1B1. Formerly named self-test.
""" """
return debuglink_prodtest_t1(client) return debuglink_prodtest_t1(session)
@cli.command() @cli.command()
@with_client @with_session(management=True)
def optiga_set_sec_max(client: "TrezorClient") -> str: def optiga_set_sec_max(session: "Session") -> str:
"""Set Optiga's security event counter to maximum.""" """Set Optiga's security event counter to maximum."""
return debuglink_optiga_set_sec_max(client) return debuglink_optiga_set_sec_max(session)

View File

@ -25,11 +25,11 @@ import requests
from .. import debuglink, device, exceptions, messages, ui from .. import debuglink, device, exceptions, messages, ui
from ..tools import format_path from ..tools import format_path
from . import ChoiceType, with_client from . import ChoiceType, with_session
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from ..client import TrezorClient
from ..protobuf import MessageType from ..protobuf import MessageType
from ..transport.session import Session
from . import TrezorConnection from . import TrezorConnection
RECOVERY_DEVICE_INPUT_METHOD = { RECOVERY_DEVICE_INPUT_METHOD = {
@ -65,17 +65,18 @@ def cli() -> None:
help="Wipe device in bootloader mode. This also erases the firmware.", help="Wipe device in bootloader mode. This also erases the firmware.",
is_flag=True, is_flag=True,
) )
@with_client @with_session(management=True)
def wipe(client: "TrezorClient", bootloader: bool) -> str: def wipe(session: "Session", bootloader: bool) -> str:
"""Reset device to factory defaults and remove all private data.""" """Reset device to factory defaults and remove all private data."""
features = session.features
if bootloader: if bootloader:
if not client.features.bootloader_mode: if not features.bootloader_mode:
click.echo("Please switch your device to bootloader mode.") click.echo("Please switch your device to bootloader mode.")
sys.exit(1) sys.exit(1)
else: else:
click.echo("Wiping user data and firmware!") click.echo("Wiping user data and firmware!")
else: else:
if client.features.bootloader_mode: if features.bootloader_mode:
click.echo( click.echo(
"Your device is in bootloader mode. This operation would also erase firmware." "Your device is in bootloader mode. This operation would also erase firmware."
) )
@ -88,7 +89,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str:
click.echo("Wiping user data!") click.echo("Wiping user data!")
try: try:
return device.wipe(client) return device.wipe(
session
) # TODO decide where the wipe should happen - management or regular session
except exceptions.TrezorFailure as e: except exceptions.TrezorFailure as e:
click.echo("Action failed: {} {}".format(*e.args)) click.echo("Action failed: {} {}".format(*e.args))
sys.exit(3) sys.exit(3)
@ -104,9 +107,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str:
@click.option("-a", "--academic", is_flag=True) @click.option("-a", "--academic", is_flag=True)
@click.option("-b", "--needs-backup", is_flag=True) @click.option("-b", "--needs-backup", is_flag=True)
@click.option("-n", "--no-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True)
@with_client @with_session(management=True)
def load( def load(
client: "TrezorClient", session: "Session",
mnemonic: t.Sequence[str], mnemonic: t.Sequence[str],
pin: str, pin: str,
passphrase_protection: bool, passphrase_protection: bool,
@ -137,7 +140,7 @@ def load(
try: try:
return debuglink.load_device( return debuglink.load_device(
client, session,
mnemonic=list(mnemonic), mnemonic=list(mnemonic),
pin=pin, pin=pin,
passphrase_protection=passphrase_protection, passphrase_protection=passphrase_protection,
@ -172,9 +175,9 @@ def load(
) )
@click.option("-d", "--dry-run", is_flag=True) @click.option("-d", "--dry-run", is_flag=True)
@click.option("-b", "--unlock-repeated-backup", is_flag=True) @click.option("-b", "--unlock-repeated-backup", is_flag=True)
@with_client @with_session(management=True)
def recover( def recover(
client: "TrezorClient", session: "Session",
words: str, words: str,
expand: bool, expand: bool,
pin_protection: bool, pin_protection: bool,
@ -202,7 +205,7 @@ def recover(
type = messages.RecoveryType.UnlockRepeatedBackup type = messages.RecoveryType.UnlockRepeatedBackup
return device.recover( return device.recover(
client, session,
word_count=int(words), word_count=int(words),
passphrase_protection=passphrase_protection, passphrase_protection=passphrase_protection,
pin_protection=pin_protection, pin_protection=pin_protection,
@ -224,9 +227,9 @@ def recover(
@click.option("-n", "--no-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True)
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE)) @click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE))
@click.option("-e", "--entropy-check-count", type=click.IntRange(0)) @click.option("-e", "--entropy-check-count", type=click.IntRange(0))
@with_client @with_session(management=True)
def setup( def setup(
client: "TrezorClient", session: "Session",
strength: int | None, strength: int | None,
passphrase_protection: bool, passphrase_protection: bool,
pin_protection: bool, pin_protection: bool,
@ -244,7 +247,7 @@ def setup(
BT = messages.BackupType BT = messages.BackupType
if backup_type is None: if backup_type is None:
if client.version >= (2, 7, 1): if session.version >= (2, 7, 1):
# SLIP39 extendable was introduced in 2.7.1 # SLIP39 extendable was introduced in 2.7.1
backup_type = BT.Slip39_Single_Extendable backup_type = BT.Slip39_Single_Extendable
else: else:
@ -254,10 +257,10 @@ def setup(
if ( if (
backup_type backup_type
in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable) in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable)
and messages.Capability.Shamir not in client.features.capabilities and messages.Capability.Shamir not in session.features.capabilities
) or ( ) or (
backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable) backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable)
and messages.Capability.ShamirGroups not in client.features.capabilities and messages.Capability.ShamirGroups not in session.features.capabilities
): ):
click.echo( click.echo(
"WARNING: Your Trezor device does not indicate support for the requested\n" "WARNING: Your Trezor device does not indicate support for the requested\n"
@ -265,7 +268,7 @@ def setup(
) )
resp, path_xpubs = device.reset_entropy_check( resp, path_xpubs = device.reset_entropy_check(
client, session,
strength=strength, strength=strength,
passphrase_protection=passphrase_protection, passphrase_protection=passphrase_protection,
pin_protection=pin_protection, pin_protection=pin_protection,
@ -289,23 +292,21 @@ def setup(
@cli.command() @cli.command()
@click.option("-t", "--group-threshold", type=int) @click.option("-t", "--group-threshold", type=int)
@click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N") @click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N")
@with_client @with_session(management=True)
def backup( def backup(
client: "TrezorClient", session: "Session",
group_threshold: int | None = None, group_threshold: int | None = None,
groups: t.Sequence[tuple[int, int]] = (), groups: t.Sequence[tuple[int, int]] = (),
) -> str: ) -> str:
"""Perform device seed backup.""" """Perform device seed backup."""
return device.backup(client, group_threshold, groups) return device.backup(session, group_threshold, groups)
@cli.command() @cli.command()
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS)) @click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
@with_client @with_session(management=True)
def sd_protect( def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> str:
client: "TrezorClient", operation: messages.SdProtectOperationType
) -> str:
"""Secure the device with SD card protection. """Secure the device with SD card protection.
When SD card protection is enabled, a randomly generated secret is stored When SD card protection is enabled, a randomly generated secret is stored
@ -319,9 +320,9 @@ def sd_protect(
off - Remove SD card secret protection. off - Remove SD card secret protection.
refresh - Replace the current SD card secret with a new one. refresh - Replace the current SD card secret with a new one.
""" """
if client.features.model == "1": if session.features.model == "1":
raise click.ClickException("Trezor One does not support SD card protection.") raise click.ClickException("Trezor One does not support SD card protection.")
return device.sd_protect(client, operation) return device.sd_protect(session, operation)
@cli.command() @cli.command()
@ -331,24 +332,24 @@ def reboot_to_bootloader(obj: "TrezorConnection") -> str:
Currently only supported on Trezor Model One. Currently only supported on Trezor Model One.
""" """
# avoid using @with_client because it closes the session afterwards, # avoid using @with_management_session because it closes the session afterwards,
# which triggers double prompt on device # which triggers double prompt on device
with obj.client_context() as client: with obj.client_context() as client:
return device.reboot_to_bootloader(client) return device.reboot_to_bootloader(client.get_management_session())
@cli.command() @cli.command()
@with_client @with_session(management=True)
def tutorial(client: "TrezorClient") -> str: def tutorial(session: "Session") -> str:
"""Show on-device tutorial.""" """Show on-device tutorial."""
return device.show_device_tutorial(client) return device.show_device_tutorial(session)
@cli.command() @cli.command()
@with_client @with_session(management=True)
def unlock_bootloader(client: "TrezorClient") -> str: def unlock_bootloader(session: "Session") -> str:
"""Unlocks bootloader. Irreversible.""" """Unlocks bootloader. Irreversible."""
return device.unlock_bootloader(client) return device.unlock_bootloader(session)
@cli.command() @cli.command()
@ -359,11 +360,11 @@ def unlock_bootloader(client: "TrezorClient") -> str:
type=int, type=int,
help="Dialog expiry in seconds.", help="Dialog expiry in seconds.",
) )
@with_client @with_session(management=True)
def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> str: def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> str:
"""Show a "Do not disconnect" dialog.""" """Show a "Do not disconnect" dialog."""
if enable is False: if enable is False:
return device.set_busy(client, None) return device.set_busy(session, None)
if expiry is None: if expiry is None:
raise click.ClickException("Missing option '-e' / '--expiry'.") raise click.ClickException("Missing option '-e' / '--expiry'.")
@ -373,7 +374,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) ->
f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer." f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer."
) )
return device.set_busy(client, expiry * 1000) return device.set_busy(session, expiry * 1000)
PUBKEY_WHITELIST_URL_TEMPLATE = ( PUBKEY_WHITELIST_URL_TEMPLATE = (
@ -393,9 +394,9 @@ PUBKEY_WHITELIST_URL_TEMPLATE = (
is_flag=True, is_flag=True,
help="Do not check intermediate certificates against the whitelist.", help="Do not check intermediate certificates against the whitelist.",
) )
@with_client @with_session(management=True)
def authenticate( def authenticate(
client: "TrezorClient", session: "Session",
hex_challenge: str | None, hex_challenge: str | None,
root: t.BinaryIO | None, root: t.BinaryIO | None,
raw: bool | None, raw: bool | None,
@ -420,7 +421,7 @@ def authenticate(
challenge = bytes.fromhex(hex_challenge) challenge = bytes.fromhex(hex_challenge)
if raw: if raw:
msg = device.authenticate(client, challenge) msg = device.authenticate(session, challenge)
click.echo(f"Challenge: {hex_challenge}") click.echo(f"Challenge: {hex_challenge}")
click.echo(f"Signature of challenge: {msg.signature.hex()}") click.echo(f"Signature of challenge: {msg.signature.hex()}")
@ -468,14 +469,14 @@ def authenticate(
else: else:
whitelist_json = requests.get( whitelist_json = requests.get(
PUBKEY_WHITELIST_URL_TEMPLATE.format( PUBKEY_WHITELIST_URL_TEMPLATE.format(
model=client.model.internal_name.lower() model=session.model.internal_name.lower()
) )
).json() ).json()
whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]] whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]]
try: try:
authentication.authenticate_device( authentication.authenticate_device(
client, challenge, root_pubkey=root_bytes, whitelist=whitelist session, challenge, root_pubkey=root_bytes, whitelist=whitelist
) )
except authentication.DeviceNotAuthentic: except authentication.DeviceNotAuthentic:
click.echo("Device is not authentic.") click.echo("Device is not authentic.")

View File

@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
import click import click
from .. import eos, tools from .. import eos, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from .. import messages from .. import messages
from ..client import TrezorClient from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0" PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0"
@ -37,11 +37,11 @@ def cli() -> None:
@cli.command() @cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Eos public key in base58 encoding.""" """Get Eos public key in base58 encoding."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
res = eos.get_public_key(client, address_n, show_display) res = eos.get_public_key(session, address_n, show_display)
return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}" return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}"
@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def sign_transaction( def sign_transaction(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool session: "Session", address: str, file: TextIO, chunkify: bool
) -> "messages.EosSignedTx": ) -> "messages.EosSignedTx":
"""Sign EOS transaction.""" """Sign EOS transaction."""
tx_json = json.load(file) tx_json = json.load(file)
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return eos.sign_tx( return eos.sign_tx(
client, session,
address_n, address_n,
tx_json["transaction"], tx_json["transaction"],
tx_json["chain_id"], tx_json["chain_id"],

View File

@ -26,14 +26,14 @@ import click
from .. import _rlp, definitions, ethereum, tools from .. import _rlp, definitions, ethereum, tools
from ..messages import EthereumDefinitions from ..messages import EthereumDefinitions
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
import web3 import web3
from eth_typing import ChecksumAddress # noqa: I900 from eth_typing import ChecksumAddress # noqa: I900
from web3.types import Wei from web3.types import Wei
from ..client import TrezorClient from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0" PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0"
@ -268,24 +268,24 @@ def cli(
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool session: "Session", address: str, show_display: bool, chunkify: bool
) -> str: ) -> str:
"""Get Ethereum address in hex encoding.""" """Get Ethereum address in hex encoding."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
return ethereum.get_address(client, address_n, show_display, network, chunkify) return ethereum.get_address(session, address_n, show_display, network, chunkify)
@cli.command() @cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict: def get_public_node(session: "Session", address: str, show_display: bool) -> dict:
"""Get Ethereum public node of given path.""" """Get Ethereum public node of given path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
result = ethereum.get_public_node(client, address_n, show_display=show_display) result = ethereum.get_public_node(session, address_n, show_display=show_display)
return { return {
"node": { "node": {
"depth": result.node.depth, "depth": result.node.depth,
@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@click.argument("to_address") @click.argument("to_address")
@click.argument("amount", callback=_amount_to_int) @click.argument("amount", callback=_amount_to_int)
@with_client @with_session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
chain_id: int, chain_id: int,
address: str, address: str,
amount: int, amount: int,
@ -400,7 +400,7 @@ def sign_tx(
encoded_network = DEFINITIONS_SOURCE.get_network(chain_id) encoded_network = DEFINITIONS_SOURCE.get_network(chain_id)
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
from_address = ethereum.get_address( from_address = ethereum.get_address(
client, address_n, encoded_network=encoded_network session, address_n, encoded_network=encoded_network
) )
if token: if token:
@ -446,7 +446,7 @@ def sign_tx(
assert max_gas_fee is not None assert max_gas_fee is not None
assert max_priority_fee is not None assert max_priority_fee is not None
sig = ethereum.sign_tx_eip1559( sig = ethereum.sign_tx_eip1559(
client, session,
n=address_n, n=address_n,
nonce=nonce, nonce=nonce,
gas_limit=gas_limit, gas_limit=gas_limit,
@ -465,7 +465,7 @@ def sign_tx(
gas_price = _get_web3().eth.gas_price gas_price = _get_web3().eth.gas_price
assert gas_price is not None assert gas_price is not None
sig = ethereum.sign_tx( sig = ethereum.sign_tx(
client, session,
n=address_n, n=address_n,
tx_type=tx_type, tx_type=tx_type,
nonce=nonce, nonce=nonce,
@ -526,14 +526,14 @@ def sign_tx(
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@click.argument("message") @click.argument("message")
@with_client @with_session
def sign_message( def sign_message(
client: "TrezorClient", address: str, message: str, chunkify: bool session: "Session", address: str, message: str, chunkify: bool
) -> Dict[str, str]: ) -> Dict[str, str]:
"""Sign message with Ethereum address.""" """Sign message with Ethereum address."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify) ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify)
output = { output = {
"message": message, "message": message,
"address": ret.address, "address": ret.address,
@ -550,9 +550,9 @@ def sign_message(
help="Be compatible with Metamask's signTypedData_v4 implementation", help="Be compatible with Metamask's signTypedData_v4 implementation",
) )
@click.argument("file", type=click.File("r")) @click.argument("file", type=click.File("r"))
@with_client @with_session
def sign_typed_data( def sign_typed_data(
client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO session: "Session", address: str, metamask_v4_compat: bool, file: TextIO
) -> Dict[str, str]: ) -> Dict[str, str]:
"""Sign typed data (EIP-712) with Ethereum address. """Sign typed data (EIP-712) with Ethereum address.
@ -565,7 +565,7 @@ def sign_typed_data(
defs = EthereumDefinitions(encoded_network=network) defs = EthereumDefinitions(encoded_network=network)
data = json.loads(file.read()) data = json.loads(file.read())
ret = ethereum.sign_typed_data( ret = ethereum.sign_typed_data(
client, session,
address_n, address_n,
data, data,
metamask_v4_compat=metamask_v4_compat, metamask_v4_compat=metamask_v4_compat,
@ -583,9 +583,9 @@ def sign_typed_data(
@click.argument("address") @click.argument("address")
@click.argument("signature") @click.argument("signature")
@click.argument("message") @click.argument("message")
@with_client @with_session
def verify_message( def verify_message(
client: "TrezorClient", session: "Session",
address: str, address: str,
signature: str, signature: str,
message: str, message: str,
@ -594,7 +594,7 @@ def verify_message(
"""Verify message signed with Ethereum address.""" """Verify message signed with Ethereum address."""
signature_bytes = ethereum.decode_hex(signature) signature_bytes = ethereum.decode_hex(signature)
return ethereum.verify_message( return ethereum.verify_message(
client, address, signature_bytes, message, chunkify=chunkify session, address, signature_bytes, message, chunkify=chunkify
) )
@ -602,9 +602,9 @@ def verify_message(
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.argument("domain_hash_hex") @click.argument("domain_hash_hex")
@click.argument("message_hash_hex") @click.argument("message_hash_hex")
@with_client @with_session
def sign_typed_data_hash( def sign_typed_data_hash(
client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str
) -> Dict[str, str]: ) -> Dict[str, str]:
""" """
Sign hash of typed data (EIP-712) with Ethereum address. Sign hash of typed data (EIP-712) with Ethereum address.
@ -618,7 +618,7 @@ def sign_typed_data_hash(
message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
ret = ethereum.sign_typed_data_hash( ret = ethereum.sign_typed_data_hash(
client, address_n, domain_hash, message_hash, network session, address_n, domain_hash, message_hash, network
) )
output = { output = {
"domain_hash": domain_hash_hex, "domain_hash": domain_hash_hex,

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING
import click import click
from .. import fido from .. import fido
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"} ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
@ -40,10 +40,10 @@ def credentials() -> None:
@credentials.command(name="list") @credentials.command(name="list")
@with_client @with_session(empty_passphrase=True)
def credentials_list(client: "TrezorClient") -> None: def credentials_list(session: "Session") -> None:
"""List all resident credentials on the device.""" """List all resident credentials on the device."""
creds = fido.list_credentials(client) creds = fido.list_credentials(session)
for cred in creds: for cred in creds:
click.echo("") click.echo("")
click.echo(f"WebAuthn credential at index {cred.index}:") click.echo(f"WebAuthn credential at index {cred.index}:")
@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None:
@credentials.command(name="add") @credentials.command(name="add")
@click.argument("hex_credential_id") @click.argument("hex_credential_id")
@with_client @with_session(empty_passphrase=True)
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str: def credentials_add(session: "Session", hex_credential_id: str) -> str:
"""Add the credential with the given ID as a resident credential. """Add the credential with the given ID as a resident credential.
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string. HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
""" """
return fido.add_credential(client, bytes.fromhex(hex_credential_id)) return fido.add_credential(session, bytes.fromhex(hex_credential_id))
@credentials.command(name="remove") @credentials.command(name="remove")
@click.option( @click.option(
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
) )
@with_client @with_session(empty_passphrase=True)
def credentials_remove(client: "TrezorClient", index: int) -> str: def credentials_remove(session: "Session", index: int) -> str:
"""Remove the resident credential at the given index.""" """Remove the resident credential at the given index."""
return fido.remove_credential(client, index) return fido.remove_credential(session, index)
# #
@ -110,19 +110,19 @@ def counter() -> None:
@counter.command(name="set") @counter.command(name="set")
@click.argument("counter", type=int) @click.argument("counter", type=int)
@with_client @with_session(empty_passphrase=True)
def counter_set(client: "TrezorClient", counter: int) -> str: def counter_set(session: "Session", counter: int) -> str:
"""Set FIDO/U2F counter value.""" """Set FIDO/U2F counter value."""
return fido.set_counter(client, counter) return fido.set_counter(session, counter)
@counter.command(name="get-next") @counter.command(name="get-next")
@with_client @with_session(empty_passphrase=True)
def counter_get_next(client: "TrezorClient") -> int: def counter_get_next(session: "Session") -> int:
"""Get-and-increase value of FIDO/U2F counter. """Get-and-increase value of FIDO/U2F counter.
FIDO counter value cannot be read directly. On each U2F exchange, the counter value FIDO counter value cannot be read directly. On each U2F exchange, the counter value
is returned and atomically increased. This command performs the same operation is returned and atomically increased. This command performs the same operation
and returns the counter value. and returns the counter value.
""" """
return fido.get_next_counter(client) return fido.get_next_counter(session)

View File

@ -37,10 +37,11 @@ import requests
from .. import device, exceptions, firmware, messages, models from .. import device, exceptions, firmware, messages, models
from ..firmware import models as fw_models from ..firmware import models as fw_models
from ..models import TrezorModel from ..models import TrezorModel
from . import ChoiceType, with_client from . import ChoiceType, with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..client import TrezorClient
from ..transport.session import Session
from . import TrezorConnection from . import TrezorConnection
MODEL_CHOICE = ChoiceType( MODEL_CHOICE = ChoiceType(
@ -74,9 +75,9 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool:
This is the case from bootloader version 1.8.0, and also holds for firmware version This is the case from bootloader version 1.8.0, and also holds for firmware version
1.8.0 because that installs the appropriate bootloader. 1.8.0 because that installs the appropriate bootloader.
""" """
f = client.features features = client.features
version = (f.major_version, f.minor_version, f.patch_version) version = client.version
bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0) bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0)
return bootloader_onev2 return bootloader_onev2
@ -306,25 +307,26 @@ def find_best_firmware_version(
If the specified version is not found, prints the closest available version If the specified version is not found, prints the closest available version
(higher than the specified one, if existing). (higher than the specified one, if existing).
""" """
features = client.features
model = client.model
if bitcoin_only is None: if bitcoin_only is None:
bitcoin_only = _should_use_bitcoin_only(client.features) bitcoin_only = _should_use_bitcoin_only(features)
def version_str(version: Iterable[int]) -> str: def version_str(version: Iterable[int]) -> str:
return ".".join(map(str, version)) return ".".join(map(str, version))
f = client.features releases = get_all_firmware_releases(model, bitcoin_only, beta)
releases = get_all_firmware_releases(client.model, bitcoin_only, beta)
highest_version = releases[0]["version"] highest_version = releases[0]["version"]
if version: if version:
want_version = [int(x) for x in version.split(".")] want_version = [int(x) for x in version.split(".")]
if len(want_version) != 3: if len(want_version) != 3:
click.echo("Please use the 'X.Y.Z' version format.") click.echo("Please use the 'X.Y.Z' version format.")
if want_version[0] != f.major_version: if want_version[0] != features.major_version:
click.echo( click.echo(
f"Warning: Trezor {client.model.name} firmware version should be " f"Warning: Trezor {model.name} firmware version should be "
f"{f.major_version}.X.Y (requested: {version})" f"{features.major_version}.X.Y (requested: {version})"
) )
else: else:
want_version = highest_version want_version = highest_version
@ -359,8 +361,8 @@ def find_best_firmware_version(
# to the newer one, in that case update to the minimal # to the newer one, in that case update to the minimal
# compatible version first # compatible version first
# Choosing the version key to compare based on (not) being in BL mode # Choosing the version key to compare based on (not) being in BL mode
client_version = [f.major_version, f.minor_version, f.patch_version] client_version = client.version
if f.bootloader_mode: if features.bootloader_mode:
key_to_compare = "min_bootloader_version" key_to_compare = "min_bootloader_version"
else: else:
key_to_compare = "min_firmware_version" key_to_compare = "min_firmware_version"
@ -447,11 +449,11 @@ def extract_embedded_fw(
def upload_firmware_into_device( def upload_firmware_into_device(
client: "TrezorClient", session: "Session",
firmware_data: bytes, firmware_data: bytes,
) -> None: ) -> None:
"""Perform the final act of loading the firmware into Trezor.""" """Perform the final act of loading the firmware into Trezor."""
f = client.features f = session.features
try: try:
if f.major_version == 1 and f.firmware_present is not False: if f.major_version == 1 and f.firmware_present is not False:
# Trezor One does not send ButtonRequest # Trezor One does not send ButtonRequest
@ -461,7 +463,7 @@ def upload_firmware_into_device(
with click.progressbar( with click.progressbar(
label="Uploading", length=len(firmware_data), show_eta=False label="Uploading", length=len(firmware_data), show_eta=False
) as bar: ) as bar:
firmware.update(client, firmware_data, bar.update) firmware.update(session, firmware_data, bar.update)
except exceptions.Cancelled: except exceptions.Cancelled:
click.echo("Update aborted on device.") click.echo("Update aborted on device.")
except exceptions.TrezorException as e: except exceptions.TrezorException as e:
@ -654,6 +656,7 @@ def update(
against data.trezor.io information, if available. against data.trezor.io information, if available.
""" """
with obj.client_context() as client: with obj.client_context() as client:
management_session = client.get_management_session()
if sum(bool(x) for x in (filename, url, version)) > 1: if sum(bool(x) for x in (filename, url, version)) > 1:
click.echo("You can use only one of: filename, url, version.") click.echo("You can use only one of: filename, url, version.")
sys.exit(1) sys.exit(1)
@ -709,7 +712,7 @@ def update(
if _is_strict_update(client, firmware_data): if _is_strict_update(client, firmware_data):
header_size = _get_firmware_header_size(firmware_data) header_size = _get_firmware_header_size(firmware_data)
device.reboot_to_bootloader( device.reboot_to_bootloader(
client, management_session,
boot_command=messages.BootCommand.INSTALL_UPGRADE, boot_command=messages.BootCommand.INSTALL_UPGRADE,
firmware_header=firmware_data[:header_size], firmware_header=firmware_data[:header_size],
language_data=language_data, language_data=language_data,
@ -719,7 +722,7 @@ def update(
click.echo( click.echo(
"WARNING: Seamless installation not possible, language data will not be uploaded." "WARNING: Seamless installation not possible, language data will not be uploaded."
) )
device.reboot_to_bootloader(client) device.reboot_to_bootloader(management_session)
click.echo("Waiting for bootloader...") click.echo("Waiting for bootloader...")
while True: while True:
@ -735,13 +738,15 @@ def update(
click.echo("Please switch your device to bootloader mode.") click.echo("Please switch your device to bootloader mode.")
sys.exit(1) sys.exit(1)
upload_firmware_into_device(client=client, firmware_data=firmware_data) upload_firmware_into_device(
session=client.get_management_session(), firmware_data=firmware_data
)
@cli.command() @cli.command()
@click.argument("hex_challenge", required=False) @click.argument("hex_challenge", required=False)
@with_client @with_session(management=True)
def get_hash(client: "TrezorClient", hex_challenge: Optional[str]) -> str: def get_hash(session: "Session", hex_challenge: Optional[str]) -> str:
"""Get a hash of the installed firmware combined with the optional challenge.""" """Get a hash of the installed firmware combined with the optional challenge."""
challenge = bytes.fromhex(hex_challenge) if hex_challenge else None challenge = bytes.fromhex(hex_challenge) if hex_challenge else None
return firmware.get_hash(client, challenge).hex() return firmware.get_hash(session, challenge).hex()

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Dict
import click import click
from .. import messages, monero, tools from .. import messages, monero, tools
from . import ChoiceType, with_client from . import ChoiceType, with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h" PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h"
@ -42,9 +42,9 @@ def cli() -> None:
default=messages.MoneroNetworkType.MAINNET, default=messages.MoneroNetworkType.MAINNET,
) )
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address: str, address: str,
show_display: bool, show_display: bool,
network_type: messages.MoneroNetworkType, network_type: messages.MoneroNetworkType,
@ -52,7 +52,7 @@ def get_address(
) -> bytes: ) -> bytes:
"""Get Monero address for specified path.""" """Get Monero address for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return monero.get_address(client, address_n, show_display, network_type, chunkify) return monero.get_address(session, address_n, show_display, network_type, chunkify)
@cli.command() @cli.command()
@ -63,13 +63,13 @@ def get_address(
type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}), type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}),
default=messages.MoneroNetworkType.MAINNET, default=messages.MoneroNetworkType.MAINNET,
) )
@with_client @with_session
def get_watch_key( def get_watch_key(
client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType session: "Session", address: str, network_type: messages.MoneroNetworkType
) -> Dict[str, str]: ) -> Dict[str, str]:
"""Get Monero watch key for specified path.""" """Get Monero watch key for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
res = monero.get_watch_key(client, address_n, network_type) res = monero.get_watch_key(session, address_n, network_type)
# TODO: could be made required in MoneroWatchKey # TODO: could be made required in MoneroWatchKey
assert res.address is not None assert res.address is not None
assert res.watch_key is not None assert res.watch_key is not None

View File

@ -21,10 +21,10 @@ import click
import requests import requests
from .. import nem, tools from .. import nem, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h" PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h"
@ -39,9 +39,9 @@ def cli() -> None:
@click.option("-N", "--network", type=int, default=0x68) @click.option("-N", "--network", type=int, default=0x68)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address: str, address: str,
network: int, network: int,
show_display: bool, show_display: bool,
@ -49,7 +49,7 @@ def get_address(
) -> str: ) -> str:
"""Get NEM address for specified path.""" """Get NEM address for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return nem.get_address(client, address_n, network, show_display, chunkify) return nem.get_address(session, address_n, network, show_display, chunkify)
@cli.command() @cli.command()
@ -58,9 +58,9 @@ def get_address(
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-b", "--broadcast", help="NIS to announce transaction to") @click.option("-b", "--broadcast", help="NIS to announce transaction to")
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address: str, address: str,
file: TextIO, file: TextIO,
broadcast: Optional[str], broadcast: Optional[str],
@ -71,7 +71,7 @@ def sign_tx(
Transaction file is expected in the NIS (RequestPrepareAnnounce) format. Transaction file is expected in the NIS (RequestPrepareAnnounce) format.
""" """
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify) transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify)
payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()} payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()}

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
import click import click
from .. import ripple, tools from .. import ripple, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0" PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0"
@ -37,13 +37,13 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool session: "Session", address: str, show_display: bool, chunkify: bool
) -> str: ) -> str:
"""Get Ripple address""" """Get Ripple address"""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return ripple.get_address(client, address_n, show_display, chunkify) return ripple.get_address(session, address_n, show_display, chunkify)
@cli.command() @cli.command()
@ -51,13 +51,13 @@ def get_address(
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None: def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None:
"""Sign Ripple transaction""" """Sign Ripple transaction"""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
msg = ripple.create_sign_tx_msg(json.load(file)) msg = ripple.create_sign_tx_msg(json.load(file))
result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify) result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify)
click.echo("Signature:") click.echo("Signature:")
click.echo(result.signature.hex()) click.echo(result.signature.hex())
click.echo() click.echo()

View File

@ -24,10 +24,11 @@ import click
import requests import requests
from .. import device, messages, toif from .. import device, messages, toif
from . import AliasedGroup, ChoiceType, with_client from ..transport.session import Session
from . import AliasedGroup, ChoiceType, with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient pass
try: try:
from PIL import Image from PIL import Image
@ -180,18 +181,18 @@ def cli() -> None:
@cli.command() @cli.command()
@click.option("-r", "--remove", is_flag=True, hidden=True) @click.option("-r", "--remove", is_flag=True, hidden=True)
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
@with_client @with_session(management=True)
def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str: def pin(session: "Session", enable: Optional[bool], remove: bool) -> str:
"""Set, change or remove PIN.""" """Set, change or remove PIN."""
# Remove argument is there for backwards compatibility # Remove argument is there for backwards compatibility
return device.change_pin(client, remove=_should_remove(enable, remove)) return device.change_pin(session, remove=_should_remove(enable, remove))
@cli.command() @cli.command()
@click.option("-r", "--remove", is_flag=True, hidden=True) @click.option("-r", "--remove", is_flag=True, hidden=True)
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
@with_client @with_session(management=True)
def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str: def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> str:
"""Set or remove the wipe code. """Set or remove the wipe code.
The wipe code functions as a "self-destruct PIN". If the wipe code is ever The wipe code functions as a "self-destruct PIN". If the wipe code is ever
@ -199,32 +200,32 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> s
removed and the device will be reset to factory defaults. removed and the device will be reset to factory defaults.
""" """
# Remove argument is there for backwards compatibility # Remove argument is there for backwards compatibility
return device.change_wipe_code(client, remove=_should_remove(enable, remove)) return device.change_wipe_code(session, remove=_should_remove(enable, remove))
@cli.command() @cli.command()
# keep the deprecated -l/--label option, make it do nothing # keep the deprecated -l/--label option, make it do nothing
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.argument("label") @click.argument("label")
@with_client @with_session(management=True)
def label(client: "TrezorClient", label: str) -> str: def label(session: "Session", label: str) -> str:
"""Set new device label.""" """Set new device label."""
return device.apply_settings(client, label=label) return device.apply_settings(session, label=label)
@cli.command() @cli.command()
@with_client @with_session(management=True)
def brightness(client: "TrezorClient") -> str: def brightness(session: "Session") -> str:
"""Set display brightness.""" """Set display brightness."""
return device.set_brightness(client) return device.set_brightness(session)
@cli.command() @cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False})) @click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client @with_session(management=True)
def haptic_feedback(client: "TrezorClient", enable: bool) -> str: def haptic_feedback(session: "Session", enable: bool) -> str:
"""Enable or disable haptic feedback.""" """Enable or disable haptic feedback."""
return device.apply_settings(client, haptic_feedback=enable) return device.apply_settings(session, haptic_feedback=enable)
@cli.command() @cli.command()
@ -233,9 +234,9 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> str:
"-r", "--remove", is_flag=True, default=False, help="Switch back to english." "-r", "--remove", is_flag=True, default=False, help="Switch back to english."
) )
@click.option("-d/-D", "--display/--no-display", default=None) @click.option("-d/-D", "--display/--no-display", default=None)
@with_client @with_session(management=True)
def language( def language(
client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None session: "Session", path_or_url: str | None, remove: bool, display: bool | None
) -> str: ) -> str:
"""Set new language with translations.""" """Set new language with translations."""
if remove != (path_or_url is None): if remove != (path_or_url is None):
@ -260,29 +261,29 @@ def language(
f"Failed to load translations from {path_or_url}" f"Failed to load translations from {path_or_url}"
) from None ) from None
return device.change_language( return device.change_language(
client, language_data=language_data, show_display=display session, language_data=language_data, show_display=display
) )
@cli.command() @cli.command()
@click.argument("rotation", type=ChoiceType(ROTATION)) @click.argument("rotation", type=ChoiceType(ROTATION))
@with_client @with_session(management=True)
def display_rotation(client: "TrezorClient", rotation: messages.DisplayRotation) -> str: def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> str:
"""Set display rotation. """Set display rotation.
Configure display rotation for Trezor Model T. The options are Configure display rotation for Trezor Model T. The options are
north, east, south or west. north, east, south or west.
""" """
return device.apply_settings(client, display_rotation=rotation) return device.apply_settings(session, display_rotation=rotation)
@cli.command() @cli.command()
@click.argument("delay", type=str) @click.argument("delay", type=str)
@with_client @with_session(management=True)
def auto_lock_delay(client: "TrezorClient", delay: str) -> str: def auto_lock_delay(session: "Session", delay: str) -> str:
"""Set auto-lock delay (in seconds).""" """Set auto-lock delay (in seconds)."""
if not client.features.pin_protection: if not session.features.pin_protection:
raise click.ClickException("Set up a PIN first") raise click.ClickException("Set up a PIN first")
value, unit = delay[:-1], delay[-1:] value, unit = delay[:-1], delay[-1:]
@ -291,13 +292,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
seconds = float(value) * units[unit] seconds = float(value) * units[unit]
else: else:
seconds = float(delay) # assume seconds if no unit is specified seconds = float(delay) # assume seconds if no unit is specified
return device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000)) return device.apply_settings(session, auto_lock_delay_ms=int(seconds * 1000))
@cli.command() @cli.command()
@click.argument("flags") @click.argument("flags")
@with_client @with_session(management=True)
def flags(client: "TrezorClient", flags: str) -> str: def flags(session: "Session", flags: str) -> str:
"""Set device flags.""" """Set device flags."""
if flags.lower().startswith("0b"): if flags.lower().startswith("0b"):
flags_int = int(flags, 2) flags_int = int(flags, 2)
@ -305,7 +306,7 @@ def flags(client: "TrezorClient", flags: str) -> str:
flags_int = int(flags, 16) flags_int = int(flags, 16)
else: else:
flags_int = int(flags) flags_int = int(flags)
return device.apply_flags(client, flags=flags_int) return device.apply_flags(session, flags=flags_int)
@cli.command() @cli.command()
@ -314,8 +315,8 @@ def flags(client: "TrezorClient", flags: str) -> str:
"-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False "-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False
) )
@click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)") @click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)")
@with_client @with_session(management=True)
def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: def homescreen(session: "Session", filename: str, quality: int) -> str:
"""Set new homescreen. """Set new homescreen.
To revert to default homescreen, use 'trezorctl set homescreen default' To revert to default homescreen, use 'trezorctl set homescreen default'
@ -327,39 +328,39 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
if not path.exists() or not path.is_file(): if not path.exists() or not path.is_file():
raise click.ClickException("Cannot open file") raise click.ClickException("Cannot open file")
if client.features.model == "1": if session.features.model == "1":
img = image_to_t1(path) img = image_to_t1(path)
else: else:
if client.features.homescreen_format == messages.HomescreenFormat.Jpeg: if session.features.homescreen_format == messages.HomescreenFormat.Jpeg:
width = ( width = (
client.features.homescreen_width session.features.homescreen_width
if client.features.homescreen_width is not None if session.features.homescreen_width is not None
else 240 else 240
) )
height = ( height = (
client.features.homescreen_height session.features.homescreen_height
if client.features.homescreen_height is not None if session.features.homescreen_height is not None
else 240 else 240
) )
img = image_to_jpeg(path, width, height, quality) img = image_to_jpeg(path, width, height, quality)
elif client.features.homescreen_format == messages.HomescreenFormat.ToiG: elif session.features.homescreen_format == messages.HomescreenFormat.ToiG:
width = client.features.homescreen_width width = session.features.homescreen_width
height = client.features.homescreen_height height = session.features.homescreen_height
if width is None or height is None: if width is None or height is None:
raise click.ClickException("Device did not report homescreen size.") raise click.ClickException("Device did not report homescreen size.")
img = image_to_toif(path, width, height, True) img = image_to_toif(path, width, height, True)
elif ( elif (
client.features.homescreen_format == messages.HomescreenFormat.Toif session.features.homescreen_format == messages.HomescreenFormat.Toif
or client.features.homescreen_format is None or session.features.homescreen_format is None
): ):
width = ( width = (
client.features.homescreen_width session.features.homescreen_width
if client.features.homescreen_width is not None if session.features.homescreen_width is not None
else 144 else 144
) )
height = ( height = (
client.features.homescreen_height session.features.homescreen_height
if client.features.homescreen_height is not None if session.features.homescreen_height is not None
else 144 else 144
) )
img = image_to_toif(path, width, height, False) img = image_to_toif(path, width, height, False)
@ -369,7 +370,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
"Unknown image format requested by the device." "Unknown image format requested by the device."
) )
return device.apply_settings(client, homescreen=img) return device.apply_settings(session, homescreen=img)
@cli.command() @cli.command()
@ -377,9 +378,9 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
"--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.' "--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.'
) )
@click.argument("level", type=ChoiceType(SAFETY_LEVELS)) @click.argument("level", type=ChoiceType(SAFETY_LEVELS))
@with_client @with_session(management=True)
def safety_checks( def safety_checks(
client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel session: "Session", always: bool, level: messages.SafetyCheckLevel
) -> str: ) -> str:
"""Set safety check level. """Set safety check level.
@ -392,18 +393,18 @@ def safety_checks(
""" """
if always and level == messages.SafetyCheckLevel.PromptTemporarily: if always and level == messages.SafetyCheckLevel.PromptTemporarily:
level = messages.SafetyCheckLevel.PromptAlways level = messages.SafetyCheckLevel.PromptAlways
return device.apply_settings(client, safety_checks=level) return device.apply_settings(session, safety_checks=level)
@cli.command() @cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False})) @click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client @with_session(management=True)
def experimental_features(client: "TrezorClient", enable: bool) -> str: def experimental_features(session: "Session", enable: bool) -> str:
"""Enable or disable experimental message types. """Enable or disable experimental message types.
This is a developer feature. Use with caution. This is a developer feature. Use with caution.
""" """
return device.apply_settings(client, experimental_features=enable) return device.apply_settings(session, experimental_features=enable)
# #
@ -426,25 +427,25 @@ passphrase = cast(AliasedGroup, passphrase_main)
@passphrase.command(name="on") @passphrase.command(name="on")
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None) @click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
@with_client @with_session(management=True)
def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> str: def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> str:
"""Enable passphrase.""" """Enable passphrase."""
if client.features.passphrase_protection is not True: if session.features.passphrase_protection is not True:
use_passphrase = True use_passphrase = True
else: else:
use_passphrase = None use_passphrase = None
return device.apply_settings( return device.apply_settings(
client, session,
use_passphrase=use_passphrase, use_passphrase=use_passphrase,
passphrase_always_on_device=force_on_device, passphrase_always_on_device=force_on_device,
) )
@passphrase.command(name="off") @passphrase.command(name="off")
@with_client @with_session(management=True)
def passphrase_off(client: "TrezorClient") -> str: def passphrase_off(session: "Session") -> str:
"""Disable passphrase.""" """Disable passphrase."""
return device.apply_settings(client, use_passphrase=False) return device.apply_settings(session, use_passphrase=False)
# Registering the aliases for backwards compatibility # Registering the aliases for backwards compatibility
@ -457,10 +458,10 @@ passphrase.aliases = {
@passphrase.command(name="hide") @passphrase.command(name="hide")
@click.argument("hide", type=ChoiceType({"on": True, "off": False})) @click.argument("hide", type=ChoiceType({"on": True, "off": False}))
@with_client @with_session(management=True)
def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> str: def hide_passphrase_from_host(session: "Session", hide: bool) -> str:
"""Enable or disable hiding passphrase coming from host. """Enable or disable hiding passphrase coming from host.
This is a developer feature. Use with caution. This is a developer feature. Use with caution.
""" """
return device.apply_settings(client, hide_passphrase_from_host=hide) return device.apply_settings(session, hide_passphrase_from_host=hide)

View File

@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
import click import click
from .. import messages, solana, tools from .. import messages, solana, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h" PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h"
DEFAULT_PATH = "m/44h/501h/0h/0h" DEFAULT_PATH = "m/44h/501h/0h/0h"
@ -21,40 +21,40 @@ def cli() -> None:
@cli.command() @cli.command()
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_key( def get_public_key(
client: "TrezorClient", session: "Session",
address: str, address: str,
show_display: bool, show_display: bool,
) -> messages.SolanaPublicKey: ) -> messages.SolanaPublicKey:
"""Get Solana public key.""" """Get Solana public key."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return solana.get_public_key(client, address_n, show_display) return solana.get_public_key(session, address_n, show_display)
@cli.command() @cli.command()
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address: str, address: str,
show_display: bool, show_display: bool,
chunkify: bool, chunkify: bool,
) -> messages.SolanaAddress: ) -> messages.SolanaAddress:
"""Get Solana address.""" """Get Solana address."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return solana.get_address(client, address_n, show_display, chunkify) return solana.get_address(session, address_n, show_display, chunkify)
@cli.command() @cli.command()
@click.argument("serialized_tx", type=str) @click.argument("serialized_tx", type=str)
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-a", "--additional-information-file", type=click.File("r")) @click.option("-a", "--additional-information-file", type=click.File("r"))
@with_client @with_session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address: str, address: str,
serialized_tx: str, serialized_tx: str,
additional_information_file: Optional[TextIO], additional_information_file: Optional[TextIO],
@ -78,7 +78,7 @@ def sign_tx(
) )
return solana.sign_tx( return solana.sign_tx(
client, session,
address_n, address_n,
bytes.fromhex(serialized_tx), bytes.fromhex(serialized_tx),
additional_information, additional_information,

View File

@ -21,10 +21,10 @@ from typing import TYPE_CHECKING
import click import click
from .. import stellar, tools from .. import stellar, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
try: try:
from stellar_sdk import ( from stellar_sdk import (
@ -52,13 +52,13 @@ def cli() -> None:
) )
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool session: "Session", address: str, show_display: bool, chunkify: bool
) -> str: ) -> str:
"""Get Stellar public address.""" """Get Stellar public address."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return stellar.get_address(client, address_n, show_display, chunkify) return stellar.get_address(session, address_n, show_display, chunkify)
@cli.command() @cli.command()
@ -77,9 +77,9 @@ def get_address(
help="Network passphrase (blank for public network).", help="Network passphrase (blank for public network).",
) )
@click.argument("b64envelope") @click.argument("b64envelope")
@with_client @with_session
def sign_transaction( def sign_transaction(
client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str session: "Session", b64envelope: str, address: str, network_passphrase: str
) -> bytes: ) -> bytes:
"""Sign a base64-encoded transaction envelope. """Sign a base64-encoded transaction envelope.
@ -109,6 +109,6 @@ def sign_transaction(
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
tx, operations = stellar.from_envelope(envelope) tx, operations = stellar.from_envelope(envelope)
resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase) resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase)
return base64.b64encode(resp.signature) return base64.b64encode(resp.signature)

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
import click import click
from .. import messages, protobuf, tezos, tools from .. import messages, protobuf, tezos, tools
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h" PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h"
@ -37,23 +37,23 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool session: "Session", address: str, show_display: bool, chunkify: bool
) -> str: ) -> str:
"""Get Tezos address for specified path.""" """Get Tezos address for specified path."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return tezos.get_address(client, address_n, show_display, chunkify) return tezos.get_address(session, address_n, show_display, chunkify)
@cli.command() @cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Tezos public key.""" """Get Tezos public key."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return tezos.get_public_key(client, address_n, show_display) return tezos.get_public_key(session, address_n, show_display)
@cli.command() @cli.command()
@ -61,11 +61,11 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def sign_tx( def sign_tx(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool session: "Session", address: str, file: TextIO, chunkify: bool
) -> messages.TezosSignedTx: ) -> messages.TezosSignedTx:
"""Sign Tezos transaction.""" """Sign Tezos transaction."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file)) msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))
return tezos.sign_tx(client, address_n, msg, chunkify=chunkify) return tezos.sign_tx(session, address_n, msg, chunkify=chunkify)

View File

@ -24,9 +24,12 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca
import click import click
from .. import __version__, log, messages, protobuf, ui from .. import __version__, log, messages, protobuf
from ..client import TrezorClient from ..client import ProtocolVersion, TrezorClient
from ..transport import DeviceIsBusy, enumerate_devices from ..transport import DeviceIsBusy, enumerate_devices
from ..transport.session import Session
from ..transport.thp import channel_database
from ..transport.thp.channel_database import get_channel_db
from ..transport.udp import UdpTransport from ..transport.udp import UdpTransport
from . import ( from . import (
AliasedGroup, AliasedGroup,
@ -50,6 +53,7 @@ from . import (
stellar, stellar,
tezos, tezos,
with_client, with_client,
with_session,
) )
F = TypeVar("F", bound=Callable) F = TypeVar("F", bound=Callable)
@ -193,6 +197,13 @@ def configure_logging(verbose: int) -> None:
"--record", "--record",
help="Record screen changes into a specified directory.", help="Record screen changes into a specified directory.",
) )
@click.option(
"-n",
"--no-store",
is_flag=True,
help="Do not store channels data between commands.",
default=False,
)
@click.version_option(version=__version__) @click.version_option(version=__version__)
@click.pass_context @click.pass_context
def cli_main( def cli_main(
@ -204,9 +215,10 @@ def cli_main(
script: bool, script: bool,
session_id: Optional[str], session_id: Optional[str],
record: Optional[str], record: Optional[str],
no_store: bool,
) -> None: ) -> None:
configure_logging(verbose) configure_logging(verbose)
channel_database.set_channel_database(should_not_store=no_store)
bytes_session_id: Optional[bytes] = None bytes_session_id: Optional[bytes] = None
if session_id is not None: if session_id is not None:
try: try:
@ -214,6 +226,7 @@ def cli_main(
except ValueError: except ValueError:
raise click.ClickException(f"Not a valid session id: {session_id}") raise click.ClickException(f"Not a valid session id: {session_id}")
# ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script)
ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script) ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script)
# Optionally record the screen into a specified directory. # Optionally record the screen into a specified directory.
@ -285,18 +298,23 @@ def format_device_name(features: messages.Features) -> str:
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
"""List connected Trezor devices.""" """List connected Trezor devices."""
if no_resolve: if no_resolve:
return enumerate_devices() for d in enumerate_devices():
print(d.get_path())
return
from . import get_client
for transport in enumerate_devices(): for transport in enumerate_devices():
try: try:
client = TrezorClient(transport, ui=ui.ClickUI()) client = get_client(transport)
description = format_device_name(client.features) description = format_device_name(client.features)
client.end_session() if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
get_channel_db().save_channel(client.protocol)
except DeviceIsBusy: except DeviceIsBusy:
description = "Device is in use by another process" description = "Device is in use by another process"
except Exception: except Exception as e:
description = "Failed to read details" description = "Failed to read details " + str(type(e))
click.echo(f"{transport} - {description}") click.echo(f"{transport.get_path()} - {description}")
return None return None
@ -314,15 +332,19 @@ def version() -> str:
@cli.command() @cli.command()
@click.argument("message") @click.argument("message")
@click.option("-b", "--button-protection", is_flag=True) @click.option("-b", "--button-protection", is_flag=True)
@with_client @with_session(empty_passphrase=True)
def ping(client: "TrezorClient", message: str, button_protection: bool) -> str: def ping(session: "Session", message: str, button_protection: bool) -> str:
"""Send ping message.""" """Send ping message."""
return client.ping(message, button_protection=button_protection)
# TODO return short-circuit from old client for old Trezors
return session.ping(message, button_protection)
@cli.command() @cli.command()
@click.pass_obj @click.pass_obj
def get_session(obj: TrezorConnection) -> str: def get_session(
obj: TrezorConnection, passphrase: str = "", derive_cardano: bool = False
) -> str:
"""Get a session ID for subsequent commands. """Get a session ID for subsequent commands.
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
@ -336,23 +358,44 @@ def get_session(obj: TrezorConnection) -> str:
obj.session_id = None obj.session_id = None
with obj.client_context() as client: with obj.client_context() as client:
if client.features.model == "1" and client.version < (1, 9, 0): if client.features.model == "1" and client.version < (1, 9, 0):
raise click.ClickException( raise click.ClickException(
"Upgrade your firmware to enable session support." "Upgrade your firmware to enable session support."
) )
client.ensure_unlocked() # client.ensure_unlocked()
if client.session_id is None: session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
)
if session.id is None:
raise click.ClickException("Passphrase not enabled or firmware too old.") raise click.ClickException("Passphrase not enabled or firmware too old.")
else: else:
return client.session_id.hex() return session.id.hex()
@cli.command() @cli.command()
@with_client @with_session(must_resume=True, empty_passphrase=True)
def clear_session(client: "TrezorClient") -> None: def clear_session(session: "Session") -> None:
"""Clear session (remove cached PIN, passphrase, etc.).""" """Clear session (remove cached PIN, passphrase, etc.)."""
return client.clear_session() if session is None:
click.echo("Cannot clear session as it was not properly resumed.")
return
session.call(messages.LockDevice())
session.end()
# TODO different behaviour than main, not sure if ok
@cli.command()
def delete_channels() -> None:
"""
Delete cached channels.
Do not use together with the `-n` (`--no-store`) flag,
as the JSON database will not be deleted in that case.
"""
get_channel_db().clear_stored_channels()
click.echo("Deleted stored channels")
@cli.command() @cli.command()

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
# This file is part of the Trezor project. # This file is part of the Trezor project.
# #
# Copyright (C) 2012-2022 SatoshiLabs and contributors # Copyright (C) 2012-2024 SatoshiLabs and contributors
# #
# This library is free software: you can redistribute it and/or modify # This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 # it under the terms of the GNU Lesser General Public License version 3
@ -21,47 +21,44 @@ import logging
import re import re
import textwrap import textwrap
import time import time
import typing as t
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from enum import Enum, IntEnum, auto from enum import Enum, IntEnum, auto
from itertools import zip_longest from itertools import zip_longest
from pathlib import Path from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
Sequence,
Tuple,
Union,
)
from mnemonic import Mnemonic from mnemonic import Mnemonic
from . import mapping, messages, models, protobuf from . import btc, mapping, messages, models, protobuf
from .client import TrezorClient from .client import (
from .exceptions import TrezorFailure MAX_PASSPHRASE_LENGTH,
MAX_PIN_LENGTH,
PASSPHRASE_ON_DEVICE,
TrezorClient,
)
from .exceptions import Cancelled, PinException, TrezorFailure
from .log import DUMP_BYTES from .log import DUMP_BYTES
from .messages import DebugWaitType from .messages import Capability, DebugWaitType
from .tools import expect from .tools import expect, parse_path
from .transport.session import Session, SessionV1
from .transport.thp.protocol_v1 import ProtocolV1
if TYPE_CHECKING: if t.TYPE_CHECKING:
from typing_extensions import Protocol from typing_extensions import Protocol
from .messages import PinMatrixRequestType from .messages import PinMatrixRequestType
from .transport import Transport from .transport import Transport
ExpectedMessage = Union[ ExpectedMessage = t.Union[
protobuf.MessageType, type[protobuf.MessageType], "MessageFilter" protobuf.MessageType, t.Type[protobuf.MessageType], "MessageFilter"
] ]
AnyDict = Dict[str, Any] AnyDict = t.Dict[str, t.Any]
class InputFunc(Protocol): class InputFunc(Protocol):
def __call__( def __call__(
self, self,
hold_ms: int | None = None, hold_ms: int | None = None,
@ -70,6 +67,7 @@ if TYPE_CHECKING:
EXPECTED_RESPONSES_CONTEXT_LINES = 3 EXPECTED_RESPONSES_CONTEXT_LINES = 3
PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0")
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -104,11 +102,13 @@ class UnstructuredJSONReader:
except json.JSONDecodeError: except json.JSONDecodeError:
self.dict = {} self.dict = {}
def top_level_value(self, key: str) -> Any: def top_level_value(self, key: str) -> t.Any:
return self.dict.get(key) return self.dict.get(key)
def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]: def find_objects_with_key_and_value(
def recursively_find(data: Any) -> Iterator[Any]: self, key: str, value: t.Any
) -> list["AnyDict"]:
def recursively_find(data: t.Any) -> t.Iterator[t.Any]:
if isinstance(data, dict): if isinstance(data, dict):
if data.get(key) == value: if data.get(key) == value:
yield data yield data
@ -121,7 +121,7 @@ class UnstructuredJSONReader:
return list(recursively_find(self.dict)) return list(recursively_find(self.dict))
def find_unique_object_with_key_and_value( def find_unique_object_with_key_and_value(
self, key: str, value: Any self, key: str, value: t.Any
) -> AnyDict | None: ) -> AnyDict | None:
objects = self.find_objects_with_key_and_value(key, value) objects = self.find_objects_with_key_and_value(key, value)
if not objects: if not objects:
@ -129,8 +129,10 @@ class UnstructuredJSONReader:
assert len(objects) == 1 assert len(objects) == 1
return objects[0] return objects[0]
def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]: def find_values_by_key(
def recursively_find(data: Any) -> Iterator[Any]: self, key: str, only_type: type | None = None
) -> list[t.Any]:
def recursively_find(data: t.Any) -> t.Iterator[t.Any]:
if isinstance(data, dict): if isinstance(data, dict):
if key in data: if key in data:
yield data[key] yield data[key]
@ -148,8 +150,8 @@ class UnstructuredJSONReader:
return values return values
def find_unique_value_by_key( def find_unique_value_by_key(
self, key: str, default: Any, only_type: type | None = None self, key: str, default: t.Any, only_type: type | None = None
) -> Any: ) -> t.Any:
values = self.find_values_by_key(key, only_type=only_type) values = self.find_values_by_key(key, only_type=only_type)
if not values: if not values:
return default return default
@ -160,7 +162,7 @@ class UnstructuredJSONReader:
class LayoutContent(UnstructuredJSONReader): class LayoutContent(UnstructuredJSONReader):
"""Contains helper functions to extract specific parts of the layout.""" """Contains helper functions to extract specific parts of the layout."""
def __init__(self, json_tokens: Sequence[str]) -> None: def __init__(self, json_tokens: t.Sequence[str]) -> None:
json_str = "".join(json_tokens) json_str = "".join(json_tokens)
super().__init__(json_str) super().__init__(json_str)
@ -422,11 +424,13 @@ def _make_input_func(
class DebugLink: class DebugLink:
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
self.transport = transport self.transport = transport
self.allow_interactions = auto_interact self.allow_interactions = auto_interact
self.mapping = mapping.DEFAULT_MAPPING self.mapping = mapping.DEFAULT_MAPPING
self.protocol = ProtocolV1(self.transport, self.mapping)
# To be set by TrezorClientDebugLink (is not known during creation time) # To be set by TrezorClientDebugLink (is not known during creation time)
self.model: models.TrezorModel | None = None self.model: models.TrezorModel | None = None
self.version: tuple[int, int, int] = (0, 0, 0) self.version: tuple[int, int, int] = (0, 0, 0)
@ -479,10 +483,16 @@ class DebugLink:
self.screen_text_file = file_path self.screen_text_file = file_path
def open(self) -> None: def open(self) -> None:
self.transport.begin_session() self.transport.open()
# raise NotImplementedError
# TODO is this needed?
# self.transport.deprecated_begin_session()
def close(self) -> None: def close(self) -> None:
self.transport.end_session() pass
# raise NotImplementedError
# TODO is this needed?
# self.transport.deprecated_end_session()
def _write(self, msg: protobuf.MessageType) -> None: def _write(self, msg: protobuf.MessageType) -> None:
if self.waiting_for_layout_change: if self.waiting_for_layout_change:
@ -499,15 +509,10 @@ class DebugLink:
DUMP_BYTES, DUMP_BYTES,
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
) )
self.transport.write(msg_type, msg_bytes) self.protocol.write(msg)
def _read(self) -> protobuf.MessageType: def _read(self) -> protobuf.MessageType:
ret_type, ret_bytes = self.transport.read() msg = self.protocol.read()
LOG.log(
DUMP_BYTES,
f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}",
)
msg = self.mapping.decode(ret_type, ret_bytes)
# Collapse tokens to make log use less lines. # Collapse tokens to make log use less lines.
msg_for_log = msg msg_for_log = msg
@ -521,18 +526,27 @@ class DebugLink:
) )
return msg return msg
def _call(self, msg: protobuf.MessageType) -> Any: def _call(self, msg: protobuf.MessageType) -> t.Any:
self._write(msg) self._write(msg)
return self._read() return self._read()
def state(self, wait_type: DebugWaitType | None = None) -> messages.DebugLinkState: def state(
self,
wait_type: DebugWaitType | None = None,
thp_channel_id: bytes | None = None,
) -> messages.DebugLinkState:
if wait_type is None: if wait_type is None:
wait_type = ( wait_type = (
DebugWaitType.CURRENT_LAYOUT DebugWaitType.CURRENT_LAYOUT
if self.has_global_layout if self.has_global_layout
else DebugWaitType.IMMEDIATE else DebugWaitType.IMMEDIATE
) )
result = self._call(messages.DebugLinkGetState(wait_layout=wait_type)) result = self._call(
messages.DebugLinkGetState(
wait_layout=wait_type,
thp_channel_id=thp_channel_id,
)
)
while not isinstance(result, (messages.Failure, messages.DebugLinkState)): while not isinstance(result, (messages.Failure, messages.DebugLinkState)):
result = self._read() result = self._read()
if isinstance(result, messages.Failure): if isinstance(result, messages.Failure):
@ -544,7 +558,7 @@ class DebugLink:
def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent: def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent:
# Next layout change will be caused by external event # Next layout change will be caused by external event
# (e.g. device being auto-locked or as a result of device_handler.run(xxx)) # (e.g. device being auto-locked or as a result of device_handler.run_with_session(xxx))
# and not by our debug actions/decisions. # and not by our debug actions/decisions.
# Resetting the debug state so we wait for the next layout change # Resetting the debug state so we wait for the next layout change
# (and do not return the current state). # (and do not return the current state).
@ -560,7 +574,7 @@ class DebugLink:
return LayoutContent(obj.tokens) return LayoutContent(obj.tokens)
@contextmanager @contextmanager
def wait_for_layout_change(self) -> Iterator[LayoutContent]: def wait_for_layout_change(self) -> t.Iterator[LayoutContent]:
# set up a dummy layout content object to be yielded # set up a dummy layout content object to be yielded
layout_content = LayoutContent( layout_content = LayoutContent(
["DUMMY CONTENT, WAIT UNTIL THE END OF THE BLOCK :("] ["DUMMY CONTENT, WAIT UNTIL THE END OF THE BLOCK :("]
@ -622,7 +636,7 @@ class DebugLink:
return "".join([str(matrix.index(p) + 1) for p in pin]) return "".join([str(matrix.index(p) + 1) for p in pin])
def read_recovery_word(self) -> Tuple[str | None, int | None]: def read_recovery_word(self) -> t.Tuple[str | None, int | None]:
state = self.state() state = self.state()
return (state.recovery_fake_word, state.recovery_word_pos) return (state.recovery_fake_word, state.recovery_word_pos)
@ -700,7 +714,7 @@ class DebugLink:
def click( def click(
self, self,
click: Tuple[int, int], click: t.Tuple[int, int],
hold_ms: int | None = None, hold_ms: int | None = None,
wait: bool | None = None, wait: bool | None = None,
) -> LayoutContent: ) -> LayoutContent:
@ -862,10 +876,10 @@ class DebugUI:
self.clear() self.clear()
def clear(self) -> None: def clear(self) -> None:
self.pins: Iterator[str] | None = None self.pins: t.Iterator[str] | None = None
self.passphrase = "" self.passphrase = ""
self.input_flow: Union[ self.input_flow: t.Union[
Generator[None, messages.ButtonRequest, None], object, None t.Generator[None, messages.ButtonRequest, None], object, None
] = None ] = None
def _default_input_flow(self, br: messages.ButtonRequest) -> None: def _default_input_flow(self, br: messages.ButtonRequest) -> None:
@ -896,7 +910,7 @@ class DebugUI:
raise AssertionError("input flow ended prematurely") raise AssertionError("input flow ended prematurely")
else: else:
try: try:
assert isinstance(self.input_flow, Generator) assert isinstance(self.input_flow, t.Generator)
self.input_flow.send(br) self.input_flow.send(br)
except StopIteration: except StopIteration:
self.input_flow = self.INPUT_FLOW_DONE self.input_flow = self.INPUT_FLOW_DONE
@ -918,12 +932,15 @@ class DebugUI:
class MessageFilter: class MessageFilter:
def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None:
def __init__(
self, message_type: t.Type[protobuf.MessageType], **fields: t.Any
) -> None:
self.message_type = message_type self.message_type = message_type
self.fields: Dict[str, Any] = {} self.fields: t.Dict[str, t.Any] = {}
self.update_fields(**fields) self.update_fields(**fields)
def update_fields(self, **fields: Any) -> "MessageFilter": def update_fields(self, **fields: t.Any) -> "MessageFilter":
for name, value in fields.items(): for name, value in fields.items():
try: try:
self.fields[name] = self.from_message_or_type(value) self.fields[name] = self.from_message_or_type(value)
@ -971,7 +988,7 @@ class MessageFilter:
return True return True
def to_string(self, maxwidth: int = 80) -> str: def to_string(self, maxwidth: int = 80) -> str:
fields: list[Tuple[str, str]] = [] fields: list[t.Tuple[str, str]] = []
for field in self.message_type.FIELDS.values(): for field in self.message_type.FIELDS.values():
if field.name not in self.fields: if field.name not in self.fields:
continue continue
@ -1001,7 +1018,8 @@ class MessageFilter:
class MessageFilterGenerator: class MessageFilterGenerator:
def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]:
def __getattr__(self, key: str) -> t.Callable[..., "MessageFilter"]:
message_type = getattr(messages, key) message_type = getattr(messages, key)
return MessageFilter(message_type).update_fields return MessageFilter(message_type).update_fields
@ -1009,6 +1027,245 @@ class MessageFilterGenerator:
message_filters = MessageFilterGenerator() message_filters = MessageFilterGenerator()
class SessionDebugWrapper(Session):
def __init__(self, session: Session) -> None:
self._session = session
self.reset_debug_features()
if isinstance(session, SessionDebugWrapper):
raise Exception("Cannot wrap already wrapped session!")
@property
def protocol_version(self) -> int:
return self.client.protocol_version
@property
def client(self) -> TrezorClientDebugLink:
assert isinstance(self._session.client, TrezorClientDebugLink)
return self._session.client
@property
def id(self) -> bytes:
return self._session.id
def _write(self, msg: t.Any) -> None:
print("writing message:", msg.__class__.__name__)
self._session._write(self._filter_message(msg))
def _read(self) -> t.Any:
resp = self._filter_message(self._session._read())
print("reading message:", resp.__class__.__name__)
if self.actual_responses is not None:
self.actual_responses.append(resp)
return resp
def set_expected_responses(
self,
expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]],
) -> None:
"""Set a sequence of expected responses to session calls.
Within a given with-block, the list of received responses from device must
match the list of expected responses, otherwise an ``AssertionError`` is raised.
If an expected response is given a field value other than ``None``, that field value
must exactly match the received field value. If a given field is ``None``
(or unspecified) in the expected response, the received field value is not
checked.
Each expected response can also be a tuple ``(bool, message)``. In that case, the
expected response is only evaluated if the first field is ``True``.
This is useful for differentiating sequences between Trezor models:
>>> trezor_one = session.features.model == "1"
>>> session.set_expected_responses([
>>> messages.ButtonRequest(code=ConfirmOutput),
>>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)),
>>> messages.Success(),
>>> ])
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
# make sure all items are (bool, message) tuples
expected_with_validity = (
e if isinstance(e, tuple) else (True, e) for e in expected
)
# only apply those items that are (True, message)
self.expected_responses = [
MessageFilter.from_message_or_type(expected)
for valid, expected in expected_with_validity
if valid
]
self.actual_responses = []
def lock(self, *, _refresh_features: bool = True) -> None:
"""Lock the device.
If the device does not have a PIN configured, this will do nothing.
Otherwise, a lock screen will be shown and the device will prompt for PIN
before further actions.
This call does _not_ invalidate passphrase cache. If passphrase is in use,
the device will not prompt for it after unlocking.
To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate
passphrase cache, use `clear_session()`.
"""
# TODO update the documentation above
# Private argument _refresh_features can be used internally to avoid
# refreshing in cases where we will refresh soon anyway. This is used
# in TrezorClient.clear_session()
self.call(messages.LockDevice())
if _refresh_features:
self.refresh_features()
def cancel(self) -> None:
self._write(messages.Cancel())
def ensure_unlocked(self) -> None:
btc.get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
self.refresh_features()
def set_filter(
self,
message_type: t.Type[protobuf.MessageType],
callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
) -> None:
"""Configure a filter function for a specified message type.
The `callback` must be a function that accepts a protobuf message, and returns
a (possibly modified) protobuf message of the same type. Whenever a message
is sent or received that matches `message_type`, `callback` is invoked on the
message and its result is substituted for the original.
Useful for test scenarios with an active malicious actor on the wire.
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
self.filters[message_type] = callback
def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
message_type = msg.__class__
callback = self.filters.get(message_type)
if callable(callback):
return callback(deepcopy(msg))
else:
return msg
def reset_debug_features(self) -> None:
"""Prepare the debugging session for a new testcase.
Clears all debugging state that might have been modified by a testcase.
"""
self.in_with_statement = False
self.expected_responses: list[MessageFilter] | None = None
self.actual_responses: list[protobuf.MessageType] | None = None
self.filters: t.Dict[
t.Type[protobuf.MessageType],
t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
] = {}
self.button_callback = self.client.button_callback
self.pin_callback = self.client.pin_callback
self.passphrase_callback = self._session.passphrase_callback
self.passphrase = self._session.passphrase
def __enter__(self) -> "SessionDebugWrapper":
# For usage in with/expected_responses
if self.in_with_statement:
raise RuntimeError("Do not nest!")
self.in_with_statement = True
return self
def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# copy expected/actual responses before clearing them
expected_responses = self.expected_responses
actual_responses = self.actual_responses
# grab a copy of the inputflow generator to raise an exception through it
if isinstance(self.client.ui, DebugUI):
input_flow = self.client.ui.input_flow
else:
input_flow = None
self.reset_debug_features()
if exc_type is None:
# If no other exception was raised, evaluate missed responses
# (raises AssertionError on mismatch)
self._verify_responses(expected_responses, actual_responses)
if isinstance(input_flow, t.Generator):
# Ensure that the input flow is exhausted
try:
input_flow.throw(
AssertionError("input flow continues past end of test")
)
except StopIteration:
pass
elif isinstance(input_flow, t.Generator):
# Propagate the exception through the input flow, so that we see in
# traceback where it is stuck.
input_flow.throw(exc_type, value, traceback)
@classmethod
def _verify_responses(
cls,
expected: list[MessageFilter] | None,
actual: list[protobuf.MessageType] | None,
) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
if expected is None and actual is None:
return
assert expected is not None
assert actual is not None
for i, (exp, act) in enumerate(zip_longest(expected, actual)):
if exp is None:
output = cls._expectation_lines(expected, i)
output.append("No more messages were expected, but we got:")
for resp in actual[i:]:
output.append(
textwrap.indent(protobuf.format_message(resp), " ")
)
raise AssertionError("\n".join(output))
if act is None:
output = cls._expectation_lines(expected, i)
output.append("This and the following message was not received.")
raise AssertionError("\n".join(output))
if not exp.match(act):
output = cls._expectation_lines(expected, i)
output.append("Actually received:")
output.append(textwrap.indent(protobuf.format_message(act), " "))
raise AssertionError("\n".join(output))
@staticmethod
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
output: list[str] = []
output.append("Expected responses:")
if start_at > 0:
output.append(f" (...{start_at} previous responses omitted)")
for i in range(start_at, stop_at):
exp = expected[i]
prefix = " " if i != current else ">>> "
output.append(textwrap.indent(exp.to_string(), prefix))
if stop_at < len(expected):
omitted = len(expected) - stop_at
output.append(f" (...{omitted} following responses omitted)")
output.append("")
return output
class TrezorClientDebugLink(TrezorClient): class TrezorClientDebugLink(TrezorClient):
# This class implements automatic responses # This class implements automatic responses
# and other functionality for unit tests # and other functionality for unit tests
@ -1034,54 +1291,165 @@ class TrezorClientDebugLink(TrezorClient):
raise raise
# set transport explicitly so that sync_responses can work # set transport explicitly so that sync_responses can work
super().__init__(transport)
self.transport = transport self.transport = transport
self.ui: DebugUI = DebugUI(self.debug)
self.reset_debug_features() self.reset_debug_features(new_management_session=True)
self.sync_responses() self.sync_responses()
super().__init__(transport, ui=self.ui)
# So that we can choose right screenshotting logic (T1 vs TT) # So that we can choose right screenshotting logic (T1 vs TT)
# and know the supported debug capabilities # and know the supported debug capabilities
self.debug.model = self.model self.debug.model = self.model
self.debug.version = self.version self.debug.version = self.version
self.passphrase: str | None = None
@property @property
def layout_type(self) -> LayoutType: def layout_type(self) -> LayoutType:
return self.debug.layout_type return self.debug.layout_type
def reset_debug_features(self) -> None: def get_new_client(self) -> TrezorClientDebugLink:
"""Prepare the debugging client for a new testcase. return TrezorClientDebugLink(self.transport, self.debug.allow_interactions)
def reset_debug_features(self, new_management_session: bool = False) -> None:
"""
Prepare the debugging client for a new testcase.
Clears all debugging state that might have been modified by a testcase. Clears all debugging state that might have been modified by a testcase.
""" """
self.ui: DebugUI = DebugUI(self.debug) self.ui: DebugUI = DebugUI(self.debug)
# self.pin_callback = self.ui.debug_callback_button
self.in_with_statement = False self.in_with_statement = False
self.expected_responses: list[MessageFilter] | None = None self.expected_responses: list[MessageFilter] | None = None
self.actual_responses: list[protobuf.MessageType] | None = None self.actual_responses: list[protobuf.MessageType] | None = None
self.filters: dict[ self.filters: t.Dict[
type[protobuf.MessageType], t.Type[protobuf.MessageType],
Callable[[protobuf.MessageType], protobuf.MessageType] | None, t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
] = {} ] = {}
if new_management_session:
self._management_session = self.get_management_session(new_session=True)
@property
def button_callback(self):
def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
session._write(messages.ButtonAck())
self.ui.button_request(msg)
return session._read()
return _callback_button
@property
def pin_callback(self):
def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any:
try:
pin = self.ui.get_pin(msg.type)
except Cancelled:
session.call_raw(messages.Cancel())
raise
if any(d not in "123456789" for d in pin) or not (
1 <= len(pin) <= MAX_PIN_LENGTH
):
session.call_raw(messages.Cancel())
raise ValueError("Invalid PIN provided")
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise PinException(resp.code, resp.message)
else:
return resp
return _callback_pin
@property
def passphrase_callback(self):
def _callback_passphrase(
session: Session, msg: messages.PassphraseRequest
) -> t.Any:
available_on_device = (
Capability.PassphraseEntry in session.features.capabilities
)
def send_passphrase(
passphrase: str | None = None, on_device: bool | None = None
) -> t.Any:
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
resp = session.call_raw(msg)
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
# session.session_id = resp.state
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
return resp
# short-circuit old style entry
if msg._on_device is True:
return send_passphrase(None, None)
try:
if session.passphrase is None and isinstance(session, SessionV1):
passphrase = self.ui.get_passphrase(
available_on_device=available_on_device
)
else:
passphrase = session.passphrase
except Cancelled:
session.call_raw(messages.Cancel())
raise
if passphrase is PASSPHRASE_ON_DEVICE:
if not available_on_device:
session.call_raw(messages.Cancel())
raise RuntimeError("Device is not capable of entering passphrase")
else:
return send_passphrase(on_device=True)
# else process host-entered passphrase
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
session.call_raw(messages.Cancel())
raise ValueError("Passphrase too long")
return send_passphrase(passphrase, on_device=False)
return _callback_passphrase
def ensure_open(self) -> None: def ensure_open(self) -> None:
"""Only open session if there isn't already an open one.""" """Only open session if there isn't already an open one."""
if self.session_counter == 0: # if self.session_counter == 0:
self.open() # self.open()
# TODO check if is this needed
def open(self) -> None: def open(self) -> None:
super().open() pass
if self.session_counter == 1: # TODO is this needed?
self.debug.open() # self.debug.open()
def close(self) -> None: def close(self) -> None:
if self.session_counter == 1: pass
self.debug.close() # TODO is this needed?
super().close() # self.debug.close()
def get_session(
self,
passphrase: str | object | None = "",
derive_cardano: bool = False,
) -> Session:
if isinstance(passphrase, str):
passphrase = Mnemonic.normalize_string(passphrase)
return super().get_session(passphrase, derive_cardano)
def set_filter( def set_filter(
self, self,
message_type: type[protobuf.MessageType], message_type: t.Type[protobuf.MessageType],
callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None, callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
) -> None: ) -> None:
"""Configure a filter function for a specified message type. """Configure a filter function for a specified message type.
@ -1106,7 +1474,8 @@ class TrezorClientDebugLink(TrezorClient):
return msg return msg
def set_input_flow( def set_input_flow(
self, input_flow: Generator[None, messages.ButtonRequest | None, None] self,
input_flow: t.Generator[None, messages.ButtonRequest | None, None],
) -> None: ) -> None:
"""Configure a sequence of input events for the current with-block. """Configure a sequence of input events for the current with-block.
@ -1140,6 +1509,7 @@ class TrezorClientDebugLink(TrezorClient):
if not hasattr(input_flow, "send"): if not hasattr(input_flow, "send"):
raise RuntimeError("input_flow should be a generator function") raise RuntimeError("input_flow should be a generator function")
self.ui.input_flow = input_flow self.ui.input_flow = input_flow
assert input_flow is not None
input_flow.send(None) # start the generator input_flow.send(None) # start the generator
def watch_layout(self, watch: bool = True) -> None: def watch_layout(self, watch: bool = True) -> None:
@ -1162,7 +1532,7 @@ class TrezorClientDebugLink(TrezorClient):
self.in_with_statement = True self.in_with_statement = True
return self return self
def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None: def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
# copy expected/actual responses before clearing them # copy expected/actual responses before clearing them
@ -1175,20 +1545,21 @@ class TrezorClientDebugLink(TrezorClient):
else: else:
input_flow = None input_flow = None
self.reset_debug_features() self.reset_debug_features(new_management_session=False)
if exc_type is None: if exc_type is None:
# If no other exception was raised, evaluate missed responses # If no other exception was raised, evaluate missed responses
# (raises AssertionError on mismatch) # (raises AssertionError on mismatch)
self._verify_responses(expected_responses, actual_responses) self._verify_responses(expected_responses, actual_responses)
elif isinstance(input_flow, Generator): elif isinstance(input_flow, t.Generator):
# Propagate the exception through the input flow, so that we see in # Propagate the exception through the input flow, so that we see in
# traceback where it is stuck. # traceback where it is stuck.
input_flow.throw(exc_type, value, traceback) input_flow.throw(exc_type, value, traceback)
def set_expected_responses( def set_expected_responses(
self, expected: list[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]] self,
expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]],
) -> None: ) -> None:
"""Set a sequence of expected responses to client calls. """Set a sequence of expected responses to client calls.
@ -1227,7 +1598,7 @@ class TrezorClientDebugLink(TrezorClient):
] ]
self.actual_responses = [] self.actual_responses = []
def use_pin_sequence(self, pins: Iterable[str]) -> None: def use_pin_sequence(self, pins: t.Iterable[str]) -> None:
"""Respond to PIN prompts from device with the provided PINs. """Respond to PIN prompts from device with the provided PINs.
The sequence must be at least as long as the expected number of PIN prompts. The sequence must be at least as long as the expected number of PIN prompts.
""" """
@ -1235,6 +1606,7 @@ class TrezorClientDebugLink(TrezorClient):
def use_passphrase(self, passphrase: str) -> None: def use_passphrase(self, passphrase: str) -> None:
"""Respond to passphrase prompts from device with the provided passphrase.""" """Respond to passphrase prompts from device with the provided passphrase."""
self.passphrase = passphrase
self.ui.passphrase = Mnemonic.normalize_string(passphrase) self.ui.passphrase = Mnemonic.normalize_string(passphrase)
def use_mnemonic(self, mnemonic: str) -> None: def use_mnemonic(self, mnemonic: str) -> None:
@ -1244,15 +1616,14 @@ class TrezorClientDebugLink(TrezorClient):
def _raw_read(self) -> protobuf.MessageType: def _raw_read(self) -> protobuf.MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
resp = self.get_management_session()._read()
resp = super()._raw_read()
resp = self._filter_message(resp) resp = self._filter_message(resp)
if self.actual_responses is not None: if self.actual_responses is not None:
self.actual_responses.append(resp) self.actual_responses.append(resp)
return resp return resp
def _raw_write(self, msg: protobuf.MessageType) -> None: def _raw_write(self, msg: protobuf.MessageType) -> None:
return super()._raw_write(self._filter_message(msg)) return self.get_management_session()._write(self._filter_message(msg))
@staticmethod @staticmethod
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
@ -1322,23 +1693,25 @@ class TrezorClientDebugLink(TrezorClient):
# Start by canceling whatever is on screen. This will work to cancel T1 PIN # Start by canceling whatever is on screen. This will work to cancel T1 PIN
# prompt, which is in TINY mode and does not respond to `Ping`. # prompt, which is in TINY mode and does not respond to `Ping`.
cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel()) # TODO REMOVE: cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel())
self.transport.begin_session() self.transport.open()
try: try:
self.transport.write(*cancel_msg) # self.protocol.write(messages.Cancel())
message = "SYNC" + secrets.token_hex(8) message = "SYNC" + secrets.token_hex(8)
ping_msg = mapping.DEFAULT_MAPPING.encode(messages.Ping(message=message)) self.get_management_session()._write(messages.Ping(message=message))
self.transport.write(*ping_msg)
resp = None resp = None
while resp != messages.Success(message=message): while resp != messages.Success(message=message):
msg_id, msg_bytes = self.transport.read()
try: try:
resp = mapping.DEFAULT_MAPPING.decode(msg_id, msg_bytes) resp = self.get_management_session()._read()
raise Exception
except Exception: except Exception:
pass pass
finally: finally:
self.transport.end_session() pass # TODO fix
# self.transport.end_session(self.session_id or b"")
def mnemonic_callback(self, _) -> str: def mnemonic_callback(self, _) -> str:
word, pos = self.debug.read_recovery_word() word, pos = self.debug.read_recovery_word()
@ -1352,8 +1725,8 @@ class TrezorClientDebugLink(TrezorClient):
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def load_device( def load_device(
client: "TrezorClient", session: "Session",
mnemonic: Union[str, Iterable[str]], mnemonic: str | t.Iterable[str],
pin: str | None, pin: str | None,
passphrase_protection: bool, passphrase_protection: bool,
label: str | None, label: str | None,
@ -1366,12 +1739,12 @@ def load_device(
mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic] mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
if client.features.initialized: if session.features.initialized:
raise RuntimeError( raise RuntimeError(
"Device is initialized already. Call device.wipe() and try again." "Device is initialized already. Call device.wipe() and try again."
) )
resp = client.call( resp = session.call(
messages.LoadDevice( messages.LoadDevice(
mnemonics=mnemonics, mnemonics=mnemonics,
pin=pin, pin=pin,
@ -1382,7 +1755,7 @@ def load_device(
no_backup=no_backup, no_backup=no_backup,
) )
) )
client.init_device() session.refresh_features()
return resp return resp
@ -1391,11 +1764,11 @@ load_device_by_mnemonic = load_device
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType: def prodtest_t1(session: "Session") -> protobuf.MessageType:
if client.features.bootloader_mode is not True: if session.features.bootloader_mode is not True:
raise RuntimeError("Device must be in bootloader mode") raise RuntimeError("Device must be in bootloader mode")
return client.call( return session.call(
messages.ProdTestT1( messages.ProdTestT1(
payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC" payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
) )
@ -1404,8 +1777,8 @@ def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType:
def record_screen( def record_screen(
debug_client: "TrezorClientDebugLink", debug_client: "TrezorClientDebugLink",
directory: Union[str, None], directory: str | None,
report_func: Union[Callable[[str], None], None] = None, report_func: t.Callable[[str], None] | None = None,
) -> None: ) -> None:
"""Record screen changes into a specified directory. """Record screen changes into a specified directory.
@ -1451,5 +1824,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool:
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def optiga_set_sec_max(client: "TrezorClient") -> protobuf.MessageType: def optiga_set_sec_max(session: "Session") -> protobuf.MessageType:
return client.call(messages.DebugLinkOptigaSetSecMax()) return session.call(messages.DebugLinkOptigaSetSecMax())

View File

@ -27,20 +27,19 @@ from slip10 import SLIP10
from . import messages from . import messages
from .exceptions import Cancelled, TrezorException from .exceptions import Cancelled, TrezorException
from .tools import Address, expect, parse_path, session from .tools import Address, expect, parse_path
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .transport.session import Session
RECOVERY_BACK = "\x08" # backspace character, sent literally RECOVERY_BACK = "\x08" # backspace character, sent literally
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session
def apply_settings( def apply_settings(
client: "TrezorClient", session: "Session",
label: Optional[str] = None, label: Optional[str] = None,
language: Optional[str] = None, language: Optional[str] = None,
use_passphrase: Optional[bool] = None, use_passphrase: Optional[bool] = None,
@ -71,13 +70,13 @@ def apply_settings(
haptic_feedback=haptic_feedback, haptic_feedback=haptic_feedback,
) )
out = client.call(settings) out = session.call(settings)
client.refresh_features() session.refresh_features()
return out return out
def _send_language_data( def _send_language_data(
client: "TrezorClient", session: "Session",
request: "messages.TranslationDataRequest", request: "messages.TranslationDataRequest",
language_data: bytes, language_data: bytes,
) -> "MessageType": ) -> "MessageType":
@ -87,76 +86,70 @@ def _send_language_data(
data_length = response.data_length data_length = response.data_length
data_offset = response.data_offset data_offset = response.data_offset
chunk = language_data[data_offset : data_offset + data_length] chunk = language_data[data_offset : data_offset + data_length]
response = client.call(messages.TranslationDataAck(data_chunk=chunk)) response = session.call(messages.TranslationDataAck(data_chunk=chunk))
return response return response
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session
def change_language( def change_language(
client: "TrezorClient", session: "Session",
language_data: bytes, language_data: bytes,
show_display: bool | None = None, show_display: bool | None = None,
) -> "MessageType": ) -> "MessageType":
data_length = len(language_data) data_length = len(language_data)
msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display) msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)
response = client.call(msg) response = session.call(msg)
if data_length > 0: if data_length > 0:
assert isinstance(response, messages.TranslationDataRequest) assert isinstance(response, messages.TranslationDataRequest)
response = _send_language_data(client, response, language_data) response = _send_language_data(session, response, language_data)
assert isinstance(response, messages.Success) assert isinstance(response, messages.Success)
client.refresh_features() # changing the language in features session.refresh_features() # changing the language in features
return response return response
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session def apply_flags(session: "Session", flags: int) -> "MessageType":
def apply_flags(client: "TrezorClient", flags: int) -> "MessageType": out = session.call(messages.ApplyFlags(flags=flags))
out = client.call(messages.ApplyFlags(flags=flags)) session.refresh_features()
client.refresh_features()
return out return out
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session def change_pin(session: "Session", remove: bool = False) -> "MessageType":
def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType": ret = session.call(messages.ChangePin(remove=remove))
ret = client.call(messages.ChangePin(remove=remove)) session.refresh_features()
client.refresh_features()
return ret return ret
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session def change_wipe_code(session: "Session", remove: bool = False) -> "MessageType":
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType": ret = session.call(messages.ChangeWipeCode(remove=remove))
ret = client.call(messages.ChangeWipeCode(remove=remove)) session.refresh_features()
client.refresh_features()
return ret return ret
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session
def sd_protect( def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType session: "Session", operation: messages.SdProtectOperationType
) -> "MessageType": ) -> "MessageType":
ret = client.call(messages.SdProtect(operation=operation)) ret = session.call(messages.SdProtect(operation=operation))
client.refresh_features() session.refresh_features()
return ret return ret
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session def wipe(session: "Session") -> "MessageType":
def wipe(client: "TrezorClient") -> "MessageType":
ret = client.call(messages.WipeDevice()) ret = session.call(messages.WipeDevice())
if not client.features.bootloader_mode: # if not session.features.bootloader_mode:
client.init_device() # session.refresh_features()
return ret return ret
@session
def recover( def recover(
client: "TrezorClient", session: "Session",
word_count: int = 24, word_count: int = 24,
passphrase_protection: bool = False, passphrase_protection: bool = False,
pin_protection: bool = True, pin_protection: bool = True,
@ -192,13 +185,13 @@ def recover(
if type is None: if type is None:
type = messages.RecoveryType.NormalRecovery type = messages.RecoveryType.NormalRecovery
if client.features.model == "1" and input_callback is None: if session.features.model == "1" and input_callback is None:
raise RuntimeError("Input callback required for Trezor One") raise RuntimeError("Input callback required for Trezor One")
if word_count not in (12, 18, 24): if word_count not in (12, 18, 24):
raise ValueError("Invalid word count. Use 12/18/24") raise ValueError("Invalid word count. Use 12/18/24")
if client.features.initialized and type == messages.RecoveryType.NormalRecovery: if session.features.initialized and type == messages.RecoveryType.NormalRecovery:
raise RuntimeError( raise RuntimeError(
"Device already initialized. Call device.wipe() and try again." "Device already initialized. Call device.wipe() and try again."
) )
@ -220,17 +213,17 @@ def recover(
msg.label = label msg.label = label
msg.u2f_counter = u2f_counter msg.u2f_counter = u2f_counter
res = client.call(msg) res = session.call(msg)
while isinstance(res, messages.WordRequest): while isinstance(res, messages.WordRequest):
try: try:
assert input_callback is not None assert input_callback is not None
inp = input_callback(res.type) inp = input_callback(res.type)
res = client.call(messages.WordAck(word=inp)) res = session.call(messages.WordAck(word=inp))
except Cancelled: except Cancelled:
res = client.call(messages.Cancel()) res = session.call(messages.Cancel())
client.init_device() session.refresh_features()
return res return res
@ -279,9 +272,8 @@ def reset(*args: Any, **kwargs: Any) -> "MessageType":
return reset_entropy_check(*args, **kwargs)[0] return reset_entropy_check(*args, **kwargs)[0]
@session
def reset_entropy_check( def reset_entropy_check(
client: "TrezorClient", session: "Session",
display_random: bool = False, display_random: bool = False,
strength: Optional[int] = None, strength: Optional[int] = None,
passphrase_protection: bool = False, passphrase_protection: bool = False,
@ -307,13 +299,13 @@ def reset_entropy_check(
DeprecationWarning, DeprecationWarning,
) )
if client.features.initialized: if session.features.initialized:
raise RuntimeError( raise RuntimeError(
"Device is initialized already. Call wipe_device() and try again." "Device is initialized already. Call wipe_device() and try again."
) )
if strength is None: if strength is None:
if client.features.model == "1": if session.features.model == "1":
strength = 256 strength = 256
else: else:
strength = 128 strength = 128
@ -335,7 +327,7 @@ def reset_entropy_check(
entropy_check=entropy_check_count is not None, entropy_check=entropy_check_count is not None,
) )
resp = client.call(msg) resp = session.call(msg)
if not isinstance(resp, messages.EntropyRequest): if not isinstance(resp, messages.EntropyRequest):
raise RuntimeError("Invalid response, expected EntropyRequest") raise RuntimeError("Invalid response, expected EntropyRequest")
@ -344,7 +336,7 @@ def reset_entropy_check(
external_entropy = os.urandom(32) external_entropy = os.urandom(32)
entropy_commitment = resp.entropy_commitment entropy_commitment = resp.entropy_commitment
resp = client.call(messages.EntropyAck(entropy=external_entropy)) resp = session.call(messages.EntropyAck(entropy=external_entropy))
if entropy_check_count is None: if entropy_check_count is None:
break break
@ -353,18 +345,18 @@ def reset_entropy_check(
return resp, [] return resp, []
for path in paths: for path in paths:
resp = client.call(messages.GetPublicKey(address_n=path)) resp = session.call(messages.GetPublicKey(address_n=path))
if not isinstance(resp, messages.PublicKey): if not isinstance(resp, messages.PublicKey):
return resp, [] return resp, []
xpubs.append(resp.xpub) xpubs.append(resp.xpub)
if entropy_check_count <= 0: if entropy_check_count <= 0:
resp = client.call(messages.EntropyCheckContinue(finish=True)) resp = session.call(messages.EntropyCheckContinue(finish=True))
break break
entropy_check_count -= 1 entropy_check_count -= 1
resp = client.call(messages.EntropyCheckContinue(finish=False)) resp = session.call(messages.EntropyCheckContinue(finish=False))
if not isinstance(resp, messages.EntropyRequest): if not isinstance(resp, messages.EntropyRequest):
raise RuntimeError("Invalid response, expected EntropyRequest") raise RuntimeError("Invalid response, expected EntropyRequest")
@ -385,18 +377,17 @@ def reset_entropy_check(
if slip10.get_xpub_from_path(path) != xpub: if slip10.get_xpub_from_path(path) != xpub:
raise RuntimeError("Invalid XPUB in entropy check") raise RuntimeError("Invalid XPUB in entropy check")
client.init_device() session.refresh_features()
return resp, zip(paths, xpubs) return resp, zip(paths, xpubs)
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session
def backup( def backup(
client: "TrezorClient", session: "Session",
group_threshold: Optional[int] = None, group_threshold: Optional[int] = None,
groups: Iterable[tuple[int, int]] = (), groups: Iterable[tuple[int, int]] = (),
) -> "MessageType": ) -> "MessageType":
ret = client.call( ret = session.call(
messages.BackupDevice( messages.BackupDevice(
group_threshold=group_threshold, group_threshold=group_threshold,
groups=[ groups=[
@ -405,37 +396,36 @@ def backup(
], ],
) )
) )
client.refresh_features() session.refresh_features()
return ret return ret
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def cancel_authorization(client: "TrezorClient") -> "MessageType": def cancel_authorization(session: "Session") -> "MessageType":
return client.call(messages.CancelAuthorization()) return session.call(messages.CancelAuthorization())
@expect(messages.UnlockedPathRequest, field="mac", ret_type=bytes) @expect(messages.UnlockedPathRequest, field="mac", ret_type=bytes)
def unlock_path(client: "TrezorClient", n: "Address") -> "MessageType": def unlock_path(session: "Session", n: "Address") -> "MessageType":
resp = client.call(messages.UnlockPath(address_n=n)) resp = session.call(messages.UnlockPath(address_n=n))
# Cancel the UnlockPath workflow now that we have the authentication code. # Cancel the UnlockPath workflow now that we have the authentication code.
try: try:
client.call(messages.Cancel()) session.call(messages.Cancel())
except Cancelled: except Cancelled:
return resp return resp
else: else:
raise TrezorException("Unexpected response in UnlockPath flow") raise TrezorException("Unexpected response in UnlockPath flow")
@session
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def reboot_to_bootloader( def reboot_to_bootloader(
client: "TrezorClient", session: "Session",
boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT, boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT,
firmware_header: Optional[bytes] = None, firmware_header: Optional[bytes] = None,
language_data: bytes = b"", language_data: bytes = b"",
) -> "MessageType": ) -> "MessageType":
response = client.call( response = session.call(
messages.RebootToBootloader( messages.RebootToBootloader(
boot_command=boot_command, boot_command=boot_command,
firmware_header=firmware_header, firmware_header=firmware_header,
@ -443,42 +433,37 @@ def reboot_to_bootloader(
) )
) )
if isinstance(response, messages.TranslationDataRequest): if isinstance(response, messages.TranslationDataRequest):
response = _send_language_data(client, response, language_data) response = _send_language_data(session, response, language_data)
return response return response
@session
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def show_device_tutorial(client: "TrezorClient") -> "MessageType": def show_device_tutorial(session: "Session") -> "MessageType":
return client.call(messages.ShowDeviceTutorial()) return session.call(messages.ShowDeviceTutorial())
@session
@expect(messages.Success, field="message", ret_type=str)
def unlock_bootloader(client: "TrezorClient") -> "MessageType":
return client.call(messages.UnlockBootloader())
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
@session def unlock_bootloader(session: "Session") -> "MessageType":
def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> "MessageType": return session.call(messages.UnlockBootloader())
@expect(messages.Success, field="message", ret_type=str)
def set_busy(session: "Session", expiry_ms: Optional[int]) -> "MessageType":
"""Sets or clears the busy state of the device. """Sets or clears the busy state of the device.
In the busy state the device shows a "Do not disconnect" message instead of the homescreen. In the busy state the device shows a "Do not disconnect" message instead of the homescreen.
Setting `expiry_ms=None` clears the busy state. Setting `expiry_ms=None` clears the busy state.
""" """
ret = client.call(messages.SetBusy(expiry_ms=expiry_ms)) ret = session.call(messages.SetBusy(expiry_ms=expiry_ms))
client.refresh_features() session.refresh_features()
return ret return ret
@expect(messages.AuthenticityProof) @expect(messages.AuthenticityProof)
def authenticate(client: "TrezorClient", challenge: bytes): def authenticate(session: "Session", challenge: bytes):
return client.call(messages.AuthenticateDevice(challenge=challenge)) return session.call(messages.AuthenticateDevice(challenge=challenge))
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def set_brightness( def set_brightness(session: "Session", value: Optional[int] = None) -> "MessageType":
client: "TrezorClient", value: Optional[int] = None return session.call(messages.SetBrightness(value=value))
) -> "MessageType":
return client.call(messages.SetBrightness(value=value))

View File

@ -18,12 +18,12 @@ from datetime import datetime
from typing import TYPE_CHECKING, List, Tuple from typing import TYPE_CHECKING, List, Tuple
from . import exceptions, messages from . import exceptions, messages
from .tools import b58decode, expect, session from .tools import b58decode, expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.session import Session
def name_to_number(name: str) -> int: def name_to_number(name: str) -> int:
@ -321,17 +321,16 @@ def parse_transaction_json(
@expect(messages.EosPublicKey) @expect(messages.EosPublicKey)
def get_public_key( def get_public_key(
client: "TrezorClient", n: "Address", show_display: bool = False session: "Session", n: "Address", show_display: bool = False
) -> "MessageType": ) -> "MessageType":
response = client.call( response = session.call(
messages.EosGetPublicKey(address_n=n, show_display=show_display) messages.EosGetPublicKey(address_n=n, show_display=show_display)
) )
return response return response
@session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address: "Address", address: "Address",
transaction: dict, transaction: dict,
chain_id: str, chain_id: str,
@ -347,11 +346,11 @@ def sign_tx(
chunkify=chunkify, chunkify=chunkify,
) )
response = client.call(msg) response = session.call(msg)
try: try:
while isinstance(response, messages.EosTxActionRequest): while isinstance(response, messages.EosTxActionRequest):
response = client.call(actions.pop(0)) response = session.call(actions.pop(0))
except IndexError: except IndexError:
# pop from empty list # pop from empty list
raise exceptions.TrezorException( raise exceptions.TrezorException(

View File

@ -18,12 +18,12 @@ import re
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
from . import definitions, exceptions, messages from . import definitions, exceptions, messages
from .tools import expect, prepare_message_bytes, session, unharden from .tools import expect, prepare_message_bytes, unharden
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.session import Session
def int_to_big_endian(value: int) -> bytes: def int_to_big_endian(value: int) -> bytes:
@ -163,13 +163,13 @@ def network_from_address_n(
@expect(messages.EthereumAddress, field="address", ret_type=str) @expect(messages.EthereumAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
show_display: bool = False, show_display: bool = False,
encoded_network: Optional[bytes] = None, encoded_network: Optional[bytes] = None,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.EthereumGetAddress( messages.EthereumGetAddress(
address_n=n, address_n=n,
show_display=show_display, show_display=show_display,
@ -181,16 +181,15 @@ def get_address(
@expect(messages.EthereumPublicKey) @expect(messages.EthereumPublicKey)
def get_public_node( def get_public_node(
client: "TrezorClient", n: "Address", show_display: bool = False session: "Session", n: "Address", show_display: bool = False
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.EthereumGetPublicKey(address_n=n, show_display=show_display) messages.EthereumGetPublicKey(address_n=n, show_display=show_display)
) )
@session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
nonce: int, nonce: int,
gas_price: int, gas_price: int,
@ -226,13 +225,13 @@ def sign_tx(
data, chunk = data[1024:], data[:1024] data, chunk = data[1024:], data[:1024]
msg.data_initial_chunk = chunk msg.data_initial_chunk = chunk
response = client.call(msg) response = session.call(msg)
assert isinstance(response, messages.EthereumTxRequest) assert isinstance(response, messages.EthereumTxRequest)
while response.data_length is not None: while response.data_length is not None:
data_length = response.data_length data_length = response.data_length
data, chunk = data[data_length:], data[:data_length] data, chunk = data[data_length:], data[:data_length]
response = client.call(messages.EthereumTxAck(data_chunk=chunk)) response = session.call(messages.EthereumTxAck(data_chunk=chunk))
assert isinstance(response, messages.EthereumTxRequest) assert isinstance(response, messages.EthereumTxRequest)
assert response.signature_v is not None assert response.signature_v is not None
@ -247,9 +246,8 @@ def sign_tx(
return response.signature_v, response.signature_r, response.signature_s return response.signature_v, response.signature_r, response.signature_s
@session
def sign_tx_eip1559( def sign_tx_eip1559(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
*, *,
nonce: int, nonce: int,
@ -282,13 +280,13 @@ def sign_tx_eip1559(
chunkify=chunkify, chunkify=chunkify,
) )
response = client.call(msg) response = session.call(msg)
assert isinstance(response, messages.EthereumTxRequest) assert isinstance(response, messages.EthereumTxRequest)
while response.data_length is not None: while response.data_length is not None:
data_length = response.data_length data_length = response.data_length
data, chunk = data[data_length:], data[:data_length] data, chunk = data[data_length:], data[:data_length]
response = client.call(messages.EthereumTxAck(data_chunk=chunk)) response = session.call(messages.EthereumTxAck(data_chunk=chunk))
assert isinstance(response, messages.EthereumTxRequest) assert isinstance(response, messages.EthereumTxRequest)
assert response.signature_v is not None assert response.signature_v is not None
@ -299,13 +297,13 @@ def sign_tx_eip1559(
@expect(messages.EthereumMessageSignature) @expect(messages.EthereumMessageSignature)
def sign_message( def sign_message(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
message: AnyStr, message: AnyStr,
encoded_network: Optional[bytes] = None, encoded_network: Optional[bytes] = None,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.EthereumSignMessage( messages.EthereumSignMessage(
address_n=n, address_n=n,
message=prepare_message_bytes(message), message=prepare_message_bytes(message),
@ -317,7 +315,7 @@ def sign_message(
@expect(messages.EthereumTypedDataSignature) @expect(messages.EthereumTypedDataSignature)
def sign_typed_data( def sign_typed_data(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
data: Dict[str, Any], data: Dict[str, Any],
*, *,
@ -333,7 +331,7 @@ def sign_typed_data(
metamask_v4_compat=metamask_v4_compat, metamask_v4_compat=metamask_v4_compat,
definitions=definitions, definitions=definitions,
) )
response = client.call(request) response = session.call(request)
# Sending all the types # Sending all the types
while isinstance(response, messages.EthereumTypedDataStructRequest): while isinstance(response, messages.EthereumTypedDataStructRequest):
@ -349,7 +347,7 @@ def sign_typed_data(
members.append(struct_member) members.append(struct_member)
request = messages.EthereumTypedDataStructAck(members=members) request = messages.EthereumTypedDataStructAck(members=members)
response = client.call(request) response = session.call(request)
# Sending the whole message that should be signed # Sending the whole message that should be signed
while isinstance(response, messages.EthereumTypedDataValueRequest): while isinstance(response, messages.EthereumTypedDataValueRequest):
@ -362,7 +360,7 @@ def sign_typed_data(
member_typename = data["primaryType"] member_typename = data["primaryType"]
member_data = data["message"] member_data = data["message"]
else: else:
client.cancel() # TODO session.cancel()
raise exceptions.TrezorException("Root index can only be 0 or 1") raise exceptions.TrezorException("Root index can only be 0 or 1")
# It can be asking for a nested structure (the member path being [X, Y, Z, ...]) # It can be asking for a nested structure (the member path being [X, Y, Z, ...])
@ -385,20 +383,20 @@ def sign_typed_data(
encoded_data = encode_data(member_data, member_typename) encoded_data = encode_data(member_data, member_typename)
request = messages.EthereumTypedDataValueAck(value=encoded_data) request = messages.EthereumTypedDataValueAck(value=encoded_data)
response = client.call(request) response = session.call(request)
return response return response
def verify_message( def verify_message(
client: "TrezorClient", session: "Session",
address: str, address: str,
signature: bytes, signature: bytes,
message: AnyStr, message: AnyStr,
chunkify: bool = False, chunkify: bool = False,
) -> bool: ) -> bool:
try: try:
resp = client.call( resp = session.call(
messages.EthereumVerifyMessage( messages.EthereumVerifyMessage(
address=address, address=address,
signature=signature, signature=signature,
@ -413,13 +411,13 @@ def verify_message(
@expect(messages.EthereumTypedDataSignature) @expect(messages.EthereumTypedDataSignature)
def sign_typed_data_hash( def sign_typed_data_hash(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
domain_hash: bytes, domain_hash: bytes,
message_hash: Optional[bytes], message_hash: Optional[bytes],
encoded_network: Optional[bytes] = None, encoded_network: Optional[bytes] = None,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.EthereumSignTypedHash( messages.EthereumSignTypedHash(
address_n=n, address_n=n,
domain_separator_hash=domain_hash, domain_separator_hash=domain_hash,

View File

@ -20,8 +20,8 @@ from . import messages
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .transport.session import Session
@expect( @expect(
@ -29,27 +29,27 @@ if TYPE_CHECKING:
field="credentials", field="credentials",
ret_type=List[messages.WebAuthnCredential], ret_type=List[messages.WebAuthnCredential],
) )
def list_credentials(client: "TrezorClient") -> "MessageType": def list_credentials(session: "Session") -> "MessageType":
return client.call(messages.WebAuthnListResidentCredentials()) return session.call(messages.WebAuthnListResidentCredentials())
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType": def add_credential(session: "Session", credential_id: bytes) -> "MessageType":
return client.call( return session.call(
messages.WebAuthnAddResidentCredential(credential_id=credential_id) messages.WebAuthnAddResidentCredential(credential_id=credential_id)
) )
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def remove_credential(client: "TrezorClient", index: int) -> "MessageType": def remove_credential(session: "Session", index: int) -> "MessageType":
return client.call(messages.WebAuthnRemoveResidentCredential(index=index)) return session.call(messages.WebAuthnRemoveResidentCredential(index=index))
@expect(messages.Success, field="message", ret_type=str) @expect(messages.Success, field="message", ret_type=str)
def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType": def set_counter(session: "Session", u2f_counter: int) -> "MessageType":
return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter)) return session.call(messages.SetU2FCounter(u2f_counter=u2f_counter))
@expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int) @expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int)
def get_next_counter(client: "TrezorClient") -> "MessageType": def get_next_counter(session: "Session") -> "MessageType":
return client.call(messages.GetNextU2FCounter()) return session.call(messages.GetNextU2FCounter())

View File

@ -20,7 +20,7 @@ from hashlib import blake2s
from typing_extensions import Protocol, TypeGuard from typing_extensions import Protocol, TypeGuard
from .. import messages from .. import messages
from ..tools import expect, session from ..tools import expect
from .core import VendorFirmware from .core import VendorFirmware
from .legacy import LegacyFirmware, LegacyV2Firmware from .legacy import LegacyFirmware, LegacyV2Firmware
@ -38,7 +38,7 @@ if True:
from .vendor import * # noqa: F401, F403 from .vendor import * # noqa: F401, F403
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
T = t.TypeVar("T", bound="FirmwareType") T = t.TypeVar("T", bound="FirmwareType")
@ -72,20 +72,19 @@ def is_onev2(fw: "FirmwareType") -> TypeGuard[LegacyFirmware]:
# ====== Client functions ====== # # ====== Client functions ====== #
@session
def update( def update(
client: "TrezorClient", session: "Session",
data: bytes, data: bytes,
progress_update: t.Callable[[int], t.Any] = lambda _: None, progress_update: t.Callable[[int], t.Any] = lambda _: None,
): ):
if client.features.bootloader_mode is False: if session.features.bootloader_mode is False:
raise RuntimeError("Device must be in bootloader mode") raise RuntimeError("Device must be in bootloader mode")
resp = client.call(messages.FirmwareErase(length=len(data))) resp = session.call(messages.FirmwareErase(length=len(data)))
# TREZORv1 method # TREZORv1 method
if isinstance(resp, messages.Success): if isinstance(resp, messages.Success):
resp = client.call(messages.FirmwareUpload(payload=data)) resp = session.call(messages.FirmwareUpload(payload=data))
progress_update(len(data)) progress_update(len(data))
if isinstance(resp, messages.Success): if isinstance(resp, messages.Success):
return return
@ -97,7 +96,7 @@ def update(
length = resp.length length = resp.length
payload = data[resp.offset : resp.offset + length] payload = data[resp.offset : resp.offset + length]
digest = blake2s(payload).digest() digest = blake2s(payload).digest()
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest)) resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest))
progress_update(length) progress_update(length)
if isinstance(resp, messages.Success): if isinstance(resp, messages.Success):
@ -107,5 +106,5 @@ def update(
@expect(messages.FirmwareHash, field="hash", ret_type=bytes) @expect(messages.FirmwareHash, field="hash", ret_type=bytes)
def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]): def get_hash(session: "Session", challenge: t.Optional[bytes]):
return client.call(messages.GetFirmwareHash(challenge=challenge)) return session.call(messages.GetFirmwareHash(challenge=challenge))

View File

@ -17,6 +17,7 @@
from __future__ import annotations from __future__ import annotations
import io import io
import logging
from types import ModuleType from types import ModuleType
from typing import Dict, Optional, Tuple, Type, TypeVar from typing import Dict, Optional, Tuple, Type, TypeVar
@ -25,6 +26,7 @@ from typing_extensions import Self
from . import messages, protobuf from . import messages, protobuf
T = TypeVar("T") T = TypeVar("T")
LOG = logging.getLogger(__name__)
class ProtobufMapping: class ProtobufMapping:
@ -63,11 +65,21 @@ class ProtobufMapping:
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE) wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
if wire_type is None: if wire_type is None:
raise ValueError("Cannot encode class without wire type") raise ValueError("Cannot encode class without wire type")
LOG.debug("encoding wire type %d", wire_type)
buf = io.BytesIO() buf = io.BytesIO()
protobuf.dump_message(buf, msg) protobuf.dump_message(buf, msg)
return wire_type, buf.getvalue() return wire_type, buf.getvalue()
def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes:
"""Serialize a Python protobuf class.
Returns the byte representation of the protobuf message.
"""
buf = io.BytesIO()
protobuf.dump_message(buf, msg)
return buf.getvalue()
def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType: def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType:
"""Deserialize a protobuf message into a Python class.""" """Deserialize a protobuf message into a Python class."""
cls = self.type_to_class[msg_wire_type] cls = self.type_to_class[msg_wire_type]
@ -83,7 +95,9 @@ class ProtobufMapping:
mapping = cls() mapping = cls()
message_types = getattr(module, "MessageType") message_types = getattr(module, "MessageType")
for entry in message_types: thp_message_types = getattr(module, "ThpMessageType")
for entry in (*message_types, *thp_message_types):
msg_class = getattr(module, entry.name, None) msg_class = getattr(module, entry.name, None)
if msg_class is None: if msg_class is None:
raise ValueError( raise ValueError(

View File

@ -43,6 +43,8 @@ class FailureType(IntEnum):
PinMismatch = 12 PinMismatch = 12
WipeCodeMismatch = 13 WipeCodeMismatch = 13
InvalidSession = 14 InvalidSession = 14
ThpUnallocatedSession = 15
InvalidProtocol = 16
FirmwareError = 99 FirmwareError = 99
@ -400,6 +402,34 @@ class TezosBallotType(IntEnum):
Pass = 2 Pass = 2
class ThpMessageType(IntEnum):
ThpCreateNewSession = 1000
ThpNewSession = 1001
ThpStartPairingRequest = 1008
ThpPairingPreparationsFinished = 1009
ThpCredentialRequest = 1010
ThpCredentialResponse = 1011
ThpEndRequest = 1012
ThpEndResponse = 1013
ThpCodeEntryCommitment = 1016
ThpCodeEntryChallenge = 1017
ThpCodeEntryCpaceHost = 1018
ThpCodeEntryCpaceTrezor = 1019
ThpCodeEntryTag = 1020
ThpCodeEntrySecret = 1021
ThpQrCodeTag = 1024
ThpQrCodeSecret = 1025
ThpNfcUnidirectionalTag = 1032
ThpNfcUnidirectionalSecret = 1033
class ThpPairingMethod(IntEnum):
NoMethod = 1
CodeEntry = 2
QrCode = 3
NFC_Unidirectional = 4
class MessageType(IntEnum): class MessageType(IntEnum):
Initialize = 0 Initialize = 0
Ping = 1 Ping = 1
@ -4136,6 +4166,7 @@ class DebugLinkGetState(protobuf.MessageType):
1: protobuf.Field("wait_word_list", "bool", repeated=False, required=False, default=None), 1: protobuf.Field("wait_word_list", "bool", repeated=False, required=False, default=None),
2: protobuf.Field("wait_word_pos", "bool", repeated=False, required=False, default=None), 2: protobuf.Field("wait_word_pos", "bool", repeated=False, required=False, default=None),
3: protobuf.Field("wait_layout", "DebugWaitType", repeated=False, required=False, default=DebugWaitType.IMMEDIATE), 3: protobuf.Field("wait_layout", "DebugWaitType", repeated=False, required=False, default=DebugWaitType.IMMEDIATE),
4: protobuf.Field("thp_channel_id", "bytes", repeated=False, required=False, default=None),
} }
def __init__( def __init__(
@ -4144,10 +4175,12 @@ class DebugLinkGetState(protobuf.MessageType):
wait_word_list: Optional["bool"] = None, wait_word_list: Optional["bool"] = None,
wait_word_pos: Optional["bool"] = None, wait_word_pos: Optional["bool"] = None,
wait_layout: Optional["DebugWaitType"] = DebugWaitType.IMMEDIATE, wait_layout: Optional["DebugWaitType"] = DebugWaitType.IMMEDIATE,
thp_channel_id: Optional["bytes"] = None,
) -> None: ) -> None:
self.wait_word_list = wait_word_list self.wait_word_list = wait_word_list
self.wait_word_pos = wait_word_pos self.wait_word_pos = wait_word_pos
self.wait_layout = wait_layout self.wait_layout = wait_layout
self.thp_channel_id = thp_channel_id
class DebugLinkState(protobuf.MessageType): class DebugLinkState(protobuf.MessageType):
@ -4166,6 +4199,9 @@ class DebugLinkState(protobuf.MessageType):
11: protobuf.Field("reset_word_pos", "uint32", repeated=False, required=False, default=None), 11: protobuf.Field("reset_word_pos", "uint32", repeated=False, required=False, default=None),
12: protobuf.Field("mnemonic_type", "BackupType", repeated=False, required=False, default=None), 12: protobuf.Field("mnemonic_type", "BackupType", repeated=False, required=False, default=None),
13: protobuf.Field("tokens", "string", repeated=True, required=False, default=None), 13: protobuf.Field("tokens", "string", repeated=True, required=False, default=None),
14: protobuf.Field("thp_pairing_code_entry_code", "uint32", repeated=False, required=False, default=None),
15: protobuf.Field("thp_pairing_code_qr_code", "bytes", repeated=False, required=False, default=None),
16: protobuf.Field("thp_pairing_code_nfc_unidirectional", "bytes", repeated=False, required=False, default=None),
} }
def __init__( def __init__(
@ -4184,6 +4220,9 @@ class DebugLinkState(protobuf.MessageType):
recovery_word_pos: Optional["int"] = None, recovery_word_pos: Optional["int"] = None,
reset_word_pos: Optional["int"] = None, reset_word_pos: Optional["int"] = None,
mnemonic_type: Optional["BackupType"] = None, mnemonic_type: Optional["BackupType"] = None,
thp_pairing_code_entry_code: Optional["int"] = None,
thp_pairing_code_qr_code: Optional["bytes"] = None,
thp_pairing_code_nfc_unidirectional: Optional["bytes"] = None,
) -> None: ) -> None:
self.tokens: Sequence["str"] = tokens if tokens is not None else [] self.tokens: Sequence["str"] = tokens if tokens is not None else []
self.layout = layout self.layout = layout
@ -4198,6 +4237,9 @@ class DebugLinkState(protobuf.MessageType):
self.recovery_word_pos = recovery_word_pos self.recovery_word_pos = recovery_word_pos
self.reset_word_pos = reset_word_pos self.reset_word_pos = reset_word_pos
self.mnemonic_type = mnemonic_type self.mnemonic_type = mnemonic_type
self.thp_pairing_code_entry_code = thp_pairing_code_entry_code
self.thp_pairing_code_qr_code = thp_pairing_code_qr_code
self.thp_pairing_code_nfc_unidirectional = thp_pairing_code_nfc_unidirectional
class DebugLinkStop(protobuf.MessageType): class DebugLinkStop(protobuf.MessageType):
@ -7860,6 +7902,280 @@ class TezosManagerTransfer(protobuf.MessageType):
self.amount = amount self.amount = amount
class ThpDeviceProperties(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("internal_model", "string", repeated=False, required=False, default=None),
2: protobuf.Field("model_variant", "uint32", repeated=False, required=False, default=None),
3: protobuf.Field("bootloader_mode", "bool", repeated=False, required=False, default=None),
4: protobuf.Field("protocol_version", "uint32", repeated=False, required=False, default=None),
5: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None),
}
def __init__(
self,
*,
pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None,
internal_model: Optional["str"] = None,
model_variant: Optional["int"] = None,
bootloader_mode: Optional["bool"] = None,
protocol_version: Optional["int"] = None,
) -> None:
self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else []
self.internal_model = internal_model
self.model_variant = model_variant
self.bootloader_mode = bootloader_mode
self.protocol_version = protocol_version
class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("host_pairing_credential", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None),
}
def __init__(
self,
*,
pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None,
host_pairing_credential: Optional["bytes"] = None,
) -> None:
self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else []
self.host_pairing_credential = host_pairing_credential
class ThpCreateNewSession(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1000
FIELDS = {
1: protobuf.Field("passphrase", "string", repeated=False, required=False, default=None),
2: protobuf.Field("on_device", "bool", repeated=False, required=False, default=None),
3: protobuf.Field("derive_cardano", "bool", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
passphrase: Optional["str"] = None,
on_device: Optional["bool"] = None,
derive_cardano: Optional["bool"] = None,
) -> None:
self.passphrase = passphrase
self.on_device = on_device
self.derive_cardano = derive_cardano
class ThpNewSession(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1001
FIELDS = {
1: protobuf.Field("new_session_id", "uint32", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
new_session_id: Optional["int"] = None,
) -> None:
self.new_session_id = new_session_id
class ThpStartPairingRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1008
FIELDS = {
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_name: Optional["str"] = None,
) -> None:
self.host_name = host_name
class ThpPairingPreparationsFinished(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1009
class ThpCodeEntryCommitment(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1016
FIELDS = {
1: protobuf.Field("commitment", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
commitment: Optional["bytes"] = None,
) -> None:
self.commitment = commitment
class ThpCodeEntryChallenge(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1017
FIELDS = {
1: protobuf.Field("challenge", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
challenge: Optional["bytes"] = None,
) -> None:
self.challenge = challenge
class ThpCodeEntryCpaceHost(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1018
FIELDS = {
1: protobuf.Field("cpace_host_public_key", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
cpace_host_public_key: Optional["bytes"] = None,
) -> None:
self.cpace_host_public_key = cpace_host_public_key
class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1019
FIELDS = {
1: protobuf.Field("cpace_trezor_public_key", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
cpace_trezor_public_key: Optional["bytes"] = None,
) -> None:
self.cpace_trezor_public_key = cpace_trezor_public_key
class ThpCodeEntryTag(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1020
FIELDS = {
2: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpCodeEntrySecret(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1021
FIELDS = {
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
secret: Optional["bytes"] = None,
) -> None:
self.secret = secret
class ThpQrCodeTag(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1024
FIELDS = {
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpQrCodeSecret(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1025
FIELDS = {
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
secret: Optional["bytes"] = None,
) -> None:
self.secret = secret
class ThpNfcUnidirectionalTag(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1032
FIELDS = {
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpNfcUnidirectionalSecret(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1033
FIELDS = {
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
secret: Optional["bytes"] = None,
) -> None:
self.secret = secret
class ThpCredentialRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1010
FIELDS = {
1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_static_pubkey: Optional["bytes"] = None,
) -> None:
self.host_static_pubkey = host_static_pubkey
class ThpCredentialResponse(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1011
FIELDS = {
1: protobuf.Field("trezor_static_pubkey", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("credential", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
trezor_static_pubkey: Optional["bytes"] = None,
credential: Optional["bytes"] = None,
) -> None:
self.trezor_static_pubkey = trezor_static_pubkey
self.credential = credential
class ThpEndRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1012
class ThpEndResponse(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1013
class ThpCredentialMetadata(protobuf.MessageType): class ThpCredentialMetadata(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None MESSAGE_WIRE_TYPE = None
FIELDS = { FIELDS = {

View File

@ -20,25 +20,25 @@ from . import messages
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.session import Session
@expect(messages.Entropy, field="entropy", ret_type=bytes) @expect(messages.Entropy, field="entropy", ret_type=bytes)
def get_entropy(client: "TrezorClient", size: int) -> "MessageType": def get_entropy(session: "Session", size: int) -> "MessageType":
return client.call(messages.GetEntropy(size=size)) return session.call(messages.GetEntropy(size=size))
@expect(messages.SignedIdentity) @expect(messages.SignedIdentity)
def sign_identity( def sign_identity(
client: "TrezorClient", session: "Session",
identity: messages.IdentityType, identity: messages.IdentityType,
challenge_hidden: bytes, challenge_hidden: bytes,
challenge_visual: str, challenge_visual: str,
ecdsa_curve_name: Optional[str] = None, ecdsa_curve_name: Optional[str] = None,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.SignIdentity( messages.SignIdentity(
identity=identity, identity=identity,
challenge_hidden=challenge_hidden, challenge_hidden=challenge_hidden,
@ -50,12 +50,12 @@ def sign_identity(
@expect(messages.ECDHSessionKey) @expect(messages.ECDHSessionKey)
def get_ecdh_session_key( def get_ecdh_session_key(
client: "TrezorClient", session: "Session",
identity: messages.IdentityType, identity: messages.IdentityType,
peer_public_key: bytes, peer_public_key: bytes,
ecdsa_curve_name: Optional[str] = None, ecdsa_curve_name: Optional[str] = None,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.GetECDHSessionKey( messages.GetECDHSessionKey(
identity=identity, identity=identity,
peer_public_key=peer_public_key, peer_public_key=peer_public_key,
@ -66,7 +66,7 @@ def get_ecdh_session_key(
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes) @expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
def encrypt_keyvalue( def encrypt_keyvalue(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
key: str, key: str,
value: bytes, value: bytes,
@ -74,7 +74,7 @@ def encrypt_keyvalue(
ask_on_decrypt: bool = True, ask_on_decrypt: bool = True,
iv: bytes = b"", iv: bytes = b"",
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.CipherKeyValue( messages.CipherKeyValue(
address_n=n, address_n=n,
key=key, key=key,
@ -89,7 +89,7 @@ def encrypt_keyvalue(
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes) @expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
def decrypt_keyvalue( def decrypt_keyvalue(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
key: str, key: str,
value: bytes, value: bytes,
@ -97,7 +97,7 @@ def decrypt_keyvalue(
ask_on_decrypt: bool = True, ask_on_decrypt: bool = True,
iv: bytes = b"", iv: bytes = b"",
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.CipherKeyValue( messages.CipherKeyValue(
address_n=n, address_n=n,
key=key, key=key,
@ -111,5 +111,5 @@ def decrypt_keyvalue(
@expect(messages.Nonce, field="nonce", ret_type=bytes) @expect(messages.Nonce, field="nonce", ret_type=bytes)
def get_nonce(client: "TrezorClient"): def get_nonce(session: "Session"):
return client.call(messages.GetNonce()) return session.call(messages.GetNonce())

View File

@ -20,9 +20,9 @@ from . import messages
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.session import Session
# MAINNET = 0 # MAINNET = 0
@ -33,13 +33,13 @@ if TYPE_CHECKING:
@expect(messages.MoneroAddress, field="address", ret_type=bytes) @expect(messages.MoneroAddress, field="address", ret_type=bytes)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
show_display: bool = False, show_display: bool = False,
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.MoneroGetAddress( messages.MoneroGetAddress(
address_n=n, address_n=n,
show_display=show_display, show_display=show_display,
@ -51,10 +51,10 @@ def get_address(
@expect(messages.MoneroWatchKey) @expect(messages.MoneroWatchKey)
def get_watch_key( def get_watch_key(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.MoneroGetWatchKey(address_n=n, network_type=network_type) messages.MoneroGetWatchKey(address_n=n, network_type=network_type)
) )

View File

@ -21,9 +21,9 @@ from . import exceptions, messages
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.session import Session
TYPE_TRANSACTION_TRANSFER = 0x0101 TYPE_TRANSACTION_TRANSFER = 0x0101
TYPE_IMPORTANCE_TRANSFER = 0x0801 TYPE_IMPORTANCE_TRANSFER = 0x0801
@ -198,13 +198,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig
@expect(messages.NEMAddress, field="address", ret_type=str) @expect(messages.NEMAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
network: int, network: int,
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.NEMGetAddress( messages.NEMGetAddress(
address_n=n, network=network, show_display=show_display, chunkify=chunkify address_n=n, network=network, show_display=show_display, chunkify=chunkify
) )
@ -213,7 +213,7 @@ def get_address(
@expect(messages.NEMSignedTx) @expect(messages.NEMSignedTx)
def sign_tx( def sign_tx(
client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False session: "Session", n: "Address", transaction: dict, chunkify: bool = False
) -> "MessageType": ) -> "MessageType":
try: try:
msg = create_sign_tx(transaction, chunkify=chunkify) msg = create_sign_tx(transaction, chunkify=chunkify)
@ -222,4 +222,4 @@ def sign_tx(
assert msg.transaction is not None assert msg.transaction is not None
msg.transaction.address_n = n msg.transaction.address_n = n
return client.call(msg) return session.call(msg)

View File

@ -21,9 +21,9 @@ from .protobuf import dict_to_proto
from .tools import dict_from_camelcase, expect from .tools import dict_from_camelcase, expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.session import Session
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment") REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination") REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
@ -31,12 +31,12 @@ REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
@expect(messages.RippleAddress, field="address", ret_type=str) @expect(messages.RippleAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.RippleGetAddress( messages.RippleGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
) )
@ -45,14 +45,14 @@ def get_address(
@expect(messages.RippleSignedTx) @expect(messages.RippleSignedTx)
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
msg: messages.RippleSignTx, msg: messages.RippleSignTx,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
msg.address_n = address_n msg.address_n = address_n
msg.chunkify = chunkify msg.chunkify = chunkify
return client.call(msg) return session.call(msg)
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx: def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:

View File

@ -4,29 +4,29 @@ from . import messages
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .transport.session import Session
@expect(messages.SolanaPublicKey) @expect(messages.SolanaPublicKey)
def get_public_key( def get_public_key(
client: "TrezorClient", session: "Session",
address_n: List[int], address_n: List[int],
show_display: bool, show_display: bool,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display) messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display)
) )
@expect(messages.SolanaAddress) @expect(messages.SolanaAddress)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_n: List[int], address_n: List[int],
show_display: bool, show_display: bool,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.SolanaGetAddress( messages.SolanaGetAddress(
address_n=address_n, address_n=address_n,
show_display=show_display, show_display=show_display,
@ -37,12 +37,12 @@ def get_address(
@expect(messages.SolanaTxSignature) @expect(messages.SolanaTxSignature)
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address_n: List[int], address_n: List[int],
serialized_tx: bytes, serialized_tx: bytes,
additional_info: Optional[messages.SolanaTxAdditionalInfo], additional_info: Optional[messages.SolanaTxAdditionalInfo],
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.SolanaSignTx( messages.SolanaSignTx(
address_n=address_n, address_n=address_n,
serialized_tx=serialized_tx, serialized_tx=serialized_tx,

View File

@ -21,9 +21,9 @@ from . import exceptions, messages
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.session import Session
StellarMessageType = Union[ StellarMessageType = Union[
messages.StellarAccountMergeOp, messages.StellarAccountMergeOp,
@ -325,12 +325,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
@expect(messages.StellarAddress, field="address", ret_type=str) @expect(messages.StellarAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.StellarGetAddress( messages.StellarGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
) )
@ -338,7 +338,7 @@ def get_address(
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
tx: messages.StellarSignTx, tx: messages.StellarSignTx,
operations: List["StellarMessageType"], operations: List["StellarMessageType"],
address_n: "Address", address_n: "Address",
@ -354,10 +354,10 @@ def sign_tx(
# 3. Receive a StellarTxOpRequest message # 3. Receive a StellarTxOpRequest message
# 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message # 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message
# 5. The final message received will be StellarSignedTx which is returned from this method # 5. The final message received will be StellarSignedTx which is returned from this method
resp = client.call(tx) resp = session.call(tx)
try: try:
while isinstance(resp, messages.StellarTxOpRequest): while isinstance(resp, messages.StellarTxOpRequest):
resp = client.call(operations.pop(0)) resp = session.call(operations.pop(0))
except IndexError: except IndexError:
# pop from empty list # pop from empty list
raise exceptions.TrezorException( raise exceptions.TrezorException(

View File

@ -20,19 +20,19 @@ from . import messages
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .tools import Address from .tools import Address
from .transport.session import Session
@expect(messages.TezosAddress, field="address", ret_type=str) @expect(messages.TezosAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.TezosGetAddress( messages.TezosGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
) )
@ -41,12 +41,12 @@ def get_address(
@expect(messages.TezosPublicKey, field="public_key", ret_type=str) @expect(messages.TezosPublicKey, field="public_key", ret_type=str)
def get_public_key( def get_public_key(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.TezosGetPublicKey( messages.TezosGetPublicKey(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
) )
@ -55,11 +55,11 @@ def get_public_key(
@expect(messages.TezosSignedTx) @expect(messages.TezosSignedTx)
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
address_n: "Address", address_n: "Address",
sign_tx_msg: messages.TezosSignTx, sign_tx_msg: messages.TezosSignTx,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
sign_tx_msg.address_n = address_n sign_tx_msg.address_n = address_n
sign_tx_msg.chunkify = chunkify sign_tx_msg.chunkify = chunkify
return client.call(sign_tx_msg) return session.call(sign_tx_msg)

View File

@ -40,7 +40,7 @@ if TYPE_CHECKING:
# More details: https://www.python.org/dev/peps/pep-0612/ # More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar from typing import TypeVar
from typing_extensions import Concatenate, ParamSpec from typing_extensions import ParamSpec
from . import client from . import client
from .protobuf import MessageType from .protobuf import MessageType
@ -301,23 +301,6 @@ def expect(
return decorator return decorator
def session(
f: "Callable[Concatenate[TrezorClient, P], R]",
) -> "Callable[Concatenate[TrezorClient, P], R]":
# Decorator wraps a BaseClient method
# with session activation / deactivation
@functools.wraps(f)
def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R":
__tracebackhide__ = True # for pytest # pylint: disable=W0612
client.open()
try:
return f(client, *args, **kwargs)
finally:
client.close()
return wrapped_f
# de-camelcasifier # de-camelcasifier
# https://stackoverflow.com/a/1176023/222189 # https://stackoverflow.com/a/1176023/222189

View File

@ -1,6 +1,6 @@
# This file is part of the Trezor project. # This file is part of the Trezor project.
# #
# Copyright (C) 2012-2022 SatoshiLabs and contributors # Copyright (C) 2012-2024 SatoshiLabs and contributors
# #
# This library is free software: you can redistribute it and/or modify # This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 # it under the terms of the GNU Lesser General Public License version 3
@ -14,24 +14,18 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
from typing import ( import typing as t
TYPE_CHECKING,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
from ..exceptions import TrezorException from ..exceptions import TrezorException
if TYPE_CHECKING: if t.TYPE_CHECKING:
from ..models import TrezorModel from ..models import TrezorModel
T = TypeVar("T", bound="Transport") T = t.TypeVar("T", bound="Transport")
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -41,7 +35,7 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
""".strip() """.strip()
MessagePayload = Tuple[int, bytes] MessagePayload = t.Tuple[int, bytes]
class TransportException(TrezorException): class TransportException(TrezorException):
@ -53,72 +47,54 @@ class DeviceIsBusy(TransportException):
class Transport: class Transport:
"""Raw connection to a Trezor device.
Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB
or USB-HID connection, or UDP socket of listening emulator(s).
It can also enumerate devices available over this communication link, and return
them as instances.
Transport instance is a thing that:
- can be identified and requested by a string URI-like path
- can open and close sessions, which enclose related operations
- can read and write protobuf messages
You need to implement a new Transport subclass if you invent a new way to connect
a Trezor device to a computer.
"""
PATH_PREFIX: str PATH_PREFIX: str
ENABLED = False
def __str__(self) -> str: @classmethod
return self.get_path() def enumerate(
cls: t.Type["T"], models: t.Iterable["TrezorModel"] | None = None
) -> t.Iterable["T"]:
raise NotImplementedError
@classmethod
def find_by_path(cls: t.Type["T"], path: str, prefix_search: bool = False) -> "T":
for device in cls.enumerate():
if device.get_path() == path:
return device
if prefix_search and device.get_path().startswith(path):
return device
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
def get_path(self) -> str: def get_path(self) -> str:
raise NotImplementedError raise NotImplementedError
def begin_session(self) -> None:
raise NotImplementedError
def end_session(self) -> None:
raise NotImplementedError
def read(self) -> MessagePayload:
raise NotImplementedError
def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError
def find_debug(self: "T") -> "T": def find_debug(self: "T") -> "T":
raise NotImplementedError raise NotImplementedError
@classmethod def open(self) -> None:
def enumerate(
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["T"]:
raise NotImplementedError raise NotImplementedError
@classmethod def close(self) -> None:
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T": raise NotImplementedError
for device in cls.enumerate():
if (
path is None
or device.get_path() == path
or (prefix_search and device.get_path().startswith(path))
):
return device
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") def write_chunk(self, chunk: bytes) -> None:
raise NotImplementedError
def read_chunk(self) -> bytes:
raise NotImplementedError
CHUNK_SIZE: t.ClassVar[int]
def all_transports() -> Iterable[Type["Transport"]]: def all_transports() -> t.Iterable[t.Type["Transport"]]:
from .bridge import BridgeTransport from .bridge import BridgeTransport
from .hid import HidTransport from .hid import HidTransport
from .udp import UdpTransport from .udp import UdpTransport
from .webusb import WebUsbTransport from .webusb import WebUsbTransport
transports: Tuple[Type["Transport"], ...] = ( transports: t.Tuple[t.Type["Transport"], ...] = (
BridgeTransport, BridgeTransport,
HidTransport, HidTransport,
UdpTransport, UdpTransport,
@ -128,9 +104,9 @@ def all_transports() -> Iterable[Type["Transport"]]:
def enumerate_devices( def enumerate_devices(
models: Optional[Iterable["TrezorModel"]] = None, models: t.Iterable["TrezorModel"] | None = None,
) -> Sequence["Transport"]: ) -> t.Sequence["Transport"]:
devices: List["Transport"] = [] devices: t.List["Transport"] = []
for transport in all_transports(): for transport in all_transports():
name = transport.__name__ name = transport.__name__
try: try:
@ -145,9 +121,7 @@ def enumerate_devices(
return devices return devices
def get_transport( def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport":
path: Optional[str] = None, prefix_search: bool = False
) -> "Transport":
if path is None: if path is None:
try: try:
return next(iter(enumerate_devices())) return next(iter(enumerate_devices()))

View File

@ -1,6 +1,6 @@
# This file is part of the Trezor project. # This file is part of the Trezor project.
# #
# Copyright (C) 2012-2022 SatoshiLabs and contributors # Copyright (C) 2012-2024 SatoshiLabs and contributors
# #
# This library is free software: you can redistribute it and/or modify # This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 # it under the terms of the GNU Lesser General Public License version 3
@ -14,24 +14,30 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
import struct import struct
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional import typing as t
import requests import requests
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from . import DeviceIsBusy, MessagePayload, Transport, TransportException from . import DeviceIsBusy, MessagePayload, Transport, TransportException
if TYPE_CHECKING: if t.TYPE_CHECKING:
from ..models import TrezorModel from ..models import TrezorModel
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
PROTOCOL_VERSION_1 = 1
PROTOCOL_VERSION_2 = 2
TREZORD_HOST = "http://127.0.0.1:21325" TREZORD_HOST = "http://127.0.0.1:21325"
TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"} TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"}
TREZORD_VERSION_MODERN = (2, 0, 25) TREZORD_VERSION_MODERN = (2, 0, 25)
TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value
CONNECTION = requests.Session() CONNECTION = requests.Session()
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER) CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
@ -45,7 +51,7 @@ class BridgeException(TransportException):
super().__init__(f"trezord: {path} failed with code {status}: {message}") super().__init__(f"trezord: {path} failed with code {status}: {message}")
def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: def call_bridge(path: str, data: str | None = None) -> requests.Response:
url = TREZORD_HOST + "/" + path url = TREZORD_HOST + "/" + path
r = CONNECTION.post(url, data=data) r = CONNECTION.post(url, data=data)
if r.status_code != 200: if r.status_code != 200:
@ -53,10 +59,54 @@ def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
return r return r
def is_legacy_bridge() -> bool: def get_bridge_version() -> t.Tuple[int, ...]:
config = call_bridge("configure").json() config = call_bridge("configure").json()
version_tuple = tuple(map(int, config["version"].split("."))) return tuple(map(int, config["version"].split(".")))
return version_tuple < TREZORD_VERSION_MODERN
def is_legacy_bridge() -> bool:
return get_bridge_version() < TREZORD_VERSION_MODERN
def supports_protocolV2() -> bool:
return get_bridge_version() >= TREZORD_VERSION_THP_SUPPORT
def detect_protocol_version(transport: "BridgeTransport") -> int:
from .. import mapping, messages
from ..messages import FailureType
protocol_version = PROTOCOL_VERSION_1
request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize())
transport.deprecated_begin_session()
transport.deprecated_write(request_type, request_data)
response_type, response_data = transport.deprecated_read()
response = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
transport.deprecated_begin_session()
if isinstance(response, messages.Failure):
if response.code == FailureType.InvalidProtocol:
LOG.debug("Protocol V2 detected")
protocol_version = PROTOCOL_VERSION_2
return protocol_version
def _is_transport_valid(transport: "BridgeTransport") -> bool:
is_valid = (
supports_protocolV2()
or detect_protocol_version(transport) == PROTOCOL_VERSION_1
)
if not is_valid:
LOG.warning("Detected unsupported Bridge transport!")
return is_valid
def filter_invalid_bridge_transports(
transports: t.Iterable["BridgeTransport"],
) -> t.Sequence["BridgeTransport"]:
"""Filters out invalid bridge transports. Keeps only valid ones."""
return [t for t in transports if _is_transport_valid(t)]
class BridgeHandle: class BridgeHandle:
@ -84,7 +134,7 @@ class BridgeHandleModern(BridgeHandle):
class BridgeHandleLegacy(BridgeHandle): class BridgeHandleLegacy(BridgeHandle):
def __init__(self, transport: "BridgeTransport") -> None: def __init__(self, transport: "BridgeTransport") -> None:
super().__init__(transport) super().__init__(transport)
self.request: Optional[str] = None self.request: str | None = None
def write_buf(self, buf: bytes) -> None: def write_buf(self, buf: bytes) -> None:
if self.request is not None: if self.request is not None:
@ -112,13 +162,12 @@ class BridgeTransport(Transport):
ENABLED: bool = True ENABLED: bool = True
def __init__( def __init__(
self, device: Dict[str, Any], legacy: bool, debug: bool = False self, device: t.Dict[str, t.Any], legacy: bool, debug: bool = False
) -> None: ) -> None:
if legacy and debug: if legacy and debug:
raise TransportException("Debugging not supported on legacy Bridge") raise TransportException("Debugging not supported on legacy Bridge")
self.device = device self.device = device
self.session: Optional[str] = None self.session: str | None = device["session"]
self.debug = debug self.debug = debug
self.legacy = legacy self.legacy = legacy
@ -135,7 +184,7 @@ class BridgeTransport(Transport):
raise TransportException("Debug device not available") raise TransportException("Debug device not available")
return BridgeTransport(self.device, self.legacy, debug=True) return BridgeTransport(self.device, self.legacy, debug=True)
def _call(self, action: str, data: Optional[str] = None) -> requests.Response: def _call(self, action: str, data: str | None = None) -> requests.Response:
session = self.session or "null" session = self.session or "null"
uri = action + "/" + str(session) uri = action + "/" + str(session)
if self.debug: if self.debug:
@ -144,17 +193,20 @@ class BridgeTransport(Transport):
@classmethod @classmethod
def enumerate( def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None cls, _models: t.Iterable["TrezorModel"] | None = None
) -> Iterable["BridgeTransport"]: ) -> t.Iterable["BridgeTransport"]:
try: try:
legacy = is_legacy_bridge() legacy = is_legacy_bridge()
return [ return filter_invalid_bridge_transports(
BridgeTransport(dev, legacy) for dev in call_bridge("enumerate").json() [
] BridgeTransport(dev, legacy)
for dev in call_bridge("enumerate").json()
]
)
except Exception: except Exception:
return [] return []
def begin_session(self) -> None: def deprecated_begin_session(self) -> None:
try: try:
data = self._call("acquire/" + self.device["path"]) data = self._call("acquire/" + self.device["path"])
except BridgeException as e: except BridgeException as e:
@ -163,18 +215,32 @@ class BridgeTransport(Transport):
raise raise
self.session = data.json()["session"] self.session = data.json()["session"]
def end_session(self) -> None: def deprecated_end_session(self) -> None:
if not self.session: if not self.session:
return return
self._call("release") self._call("release")
self.session = None self.session = None
def write(self, message_type: int, message_data: bytes) -> None: def deprecated_write(self, message_type: int, message_data: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data)) header = struct.pack(">HL", message_type, len(message_data))
self.handle.write_buf(header + message_data) self.handle.write_buf(header + message_data)
def read(self) -> MessagePayload: def deprecated_read(self) -> MessagePayload:
data = self.handle.read_buf() data = self.handle.read_buf()
headerlen = struct.calcsize(">HL") headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", data[:headerlen]) msg_type, datalen = struct.unpack(">HL", data[:headerlen])
return msg_type, data[headerlen : headerlen + datalen] return msg_type, data[headerlen : headerlen + datalen]
def open(self) -> None:
pass
# TODO self.handle.open()
def close(self) -> None:
pass
# TODO self.handle.close()
def write_chunk(self, chunk: bytes) -> None: # TODO check if it works :)
self.handle.write_buf(chunk)
def read_chunk(self) -> bytes: # TODO check if it works :)
return self.handle.read_buf()

View File

@ -1,6 +1,6 @@
# This file is part of the Trezor project. # This file is part of the Trezor project.
# #
# Copyright (C) 2012-2022 SatoshiLabs and contributors # Copyright (C) 2012-2024 SatoshiLabs and contributors
# #
# This library is free software: you can redistribute it and/or modify # This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 # it under the terms of the GNU Lesser General Public License version 3
@ -14,15 +14,16 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
import sys import sys
import time import time
from typing import Any, Dict, Iterable, List, Optional import typing as t
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from ..models import TREZOR_ONE, TrezorModel from ..models import TREZOR_ONE, TrezorModel
from . import UDEV_RULES_STR, TransportException from . import UDEV_RULES_STR, Transport, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -35,23 +36,61 @@ except Exception as e:
HID_IMPORTED = False HID_IMPORTED = False
HidDevice = Dict[str, Any] HidDevice = t.Dict[str, t.Any]
HidDeviceHandle = Any HidDeviceHandle = t.Any
class HidHandle: class HidTransport(Transport):
def __init__( """
self, path: bytes, serial: str, probe_hid_version: bool = False HidTransport implements transport over USB HID interface.
) -> None: """
self.path = path
self.serial = serial PATH_PREFIX = "hid"
ENABLED = HID_IMPORTED
def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None:
self.device = device
self.device_path = device["path"]
self.device_serial_number = device["serial_number"]
self.handle: HidDeviceHandle = None self.handle: HidDeviceHandle = None
self.hid_version = None if probe_hid_version else 2 self.hid_version = None if probe_hid_version else 2
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
@classmethod
def enumerate(
cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False
) -> t.Iterable["HidTransport"]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: t.List["HidTransport"] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids:
continue
if debug:
if not is_debuglink(dev):
continue
else:
if not is_wirelink(dev):
continue
devices.append(HidTransport(dev))
return devices
def find_debug(self) -> "HidTransport":
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def open(self) -> None: def open(self) -> None:
self.handle = hid.device() self.handle = hid.device()
try: try:
self.handle.open_path(self.path) self.handle.open_path(self.device_path)
except (IOError, OSError) as e: except (IOError, OSError) as e:
if sys.platform.startswith("linux"): if sys.platform.startswith("linux"):
e.args = e.args + (UDEV_RULES_STR,) e.args = e.args + (UDEV_RULES_STR,)
@ -62,11 +101,11 @@ class HidHandle:
# and we wouldn't even know. # and we wouldn't even know.
# So we check that the serial matches what we expect. # So we check that the serial matches what we expect.
serial = self.handle.get_serial_number_string() serial = self.handle.get_serial_number_string()
if serial != self.serial: if serial != self.device_serial_number:
self.handle.close() self.handle.close()
self.handle = None self.handle = None
raise TransportException( raise TransportException(
f"Unexpected device {serial} on path {self.path.decode()}" f"Unexpected device {serial} on path {self.device_path.decode()}"
) )
self.handle.set_nonblocking(True) self.handle.set_nonblocking(True)
@ -77,7 +116,7 @@ class HidHandle:
def close(self) -> None: def close(self) -> None:
if self.handle is not None: if self.handle is not None:
# reload serial, because device.wipe() can reset it # reload serial, because device.wipe() can reset it
self.serial = self.handle.get_serial_number_string() self.device_serial_number = self.handle.get_serial_number_string()
self.handle.close() self.handle.close()
self.handle = None self.handle = None
@ -115,53 +154,6 @@ class HidHandle:
raise TransportException("Unknown HID version") raise TransportException("Unknown HID version")
class HidTransport(ProtocolBasedTransport):
"""
HidTransport implements transport over USB HID interface.
"""
PATH_PREFIX = "hid"
ENABLED = HID_IMPORTED
def __init__(self, device: HidDevice) -> None:
self.device = device
self.handle = HidHandle(device["path"], device["serial_number"])
super().__init__(protocol=ProtocolV1(self.handle))
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
@classmethod
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
) -> Iterable["HidTransport"]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["HidTransport"] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids:
continue
if debug:
if not is_debuglink(dev):
continue
else:
if not is_wirelink(dev):
continue
devices.append(HidTransport(dev))
return devices
def find_debug(self) -> "HidTransport":
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def is_wirelink(dev: HidDevice) -> bool: def is_wirelink(dev: HidDevice) -> bool:
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0 return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0

View File

@ -1,165 +0,0 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import struct
from typing import Tuple
from typing_extensions import Protocol as StructuralType
from . import MessagePayload, Transport
REPLEN = 64
V2_FIRST_CHUNK = 0x01
V2_NEXT_CHUNK = 0x02
V2_BEGIN_SESSION = 0x03
V2_END_SESSION = 0x04
LOG = logging.getLogger(__name__)
class Handle(StructuralType):
"""PEP 544 structural type for Handle functionality.
(called a "Protocol" in the proposed PEP, name which is impractical here)
Handle is a "physical" layer for a protocol.
It can open/close a connection and read/write bare data in 64-byte chunks.
Functionally we gain nothing from making this an (abstract) base class for handle
implementations, so this definition is for type hinting purposes only. You can,
but don't have to, inherit from it.
"""
def open(self) -> None: ...
def close(self) -> None: ...
def read_chunk(self) -> bytes: ...
def write_chunk(self, chunk: bytes) -> None: ...
class Protocol:
"""Wire protocol that can communicate with a Trezor device, given a Handle.
A Protocol implements the part of the Transport API that relates to communicating
logical messages over a physical layer. It is a thing that can:
- open and close sessions,
- send and receive protobuf messages,
given the ability to:
- open and close physical connections,
- and send and receive binary chunks.
For now, the class also handles session counting and opening the underlying Handle.
This will probably be removed in the future.
We will need a new Protocol class if we change the way a Trezor device encapsulates
its messages.
"""
def __init__(self, handle: Handle) -> None:
self.handle = handle
self.session_counter = 0
# XXX we might be able to remove this now that TrezorClient does session handling
def begin_session(self) -> None:
if self.session_counter == 0:
self.handle.open()
self.session_counter += 1
def end_session(self) -> None:
self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
self.handle.close()
def read(self) -> MessagePayload:
raise NotImplementedError
def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError
class ProtocolBasedTransport(Transport):
"""Transport that implements its communications through a Protocol.
Intended as a base class for implementations that proxy their communication
operations to a Protocol.
"""
def __init__(self, protocol: Protocol) -> None:
self.protocol = protocol
def write(self, message_type: int, message_data: bytes) -> None:
self.protocol.write(message_type, message_data)
def read(self) -> MessagePayload:
return self.protocol.read()
def begin_session(self) -> None:
self.protocol.begin_session()
def end_session(self) -> None:
self.protocol.end_session()
class ProtocolV1(Protocol):
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
Does not understand sessions.
"""
HEADER_LEN = struct.calcsize(">HL")
def write(self, message_type: int, message_data: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data))
buffer = bytearray(b"##" + header + message_data)
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: REPLEN - 1]
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
buffer = buffer[63:]
def read(self) -> MessagePayload:
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
return msg_type, buffer[:datalen]
def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError("Unexpected magic characters")
try:
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.handle.read_chunk()
if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters")
return chunk[1:]

View File

@ -0,0 +1,210 @@
from __future__ import annotations
import logging
import typing as t
from .. import exceptions, messages, models
from .thp.protocol_v1 import ProtocolV1
from .thp.protocol_v2 import ProtocolV2
if t.TYPE_CHECKING:
from ..client import TrezorClient
LOG = logging.getLogger(__name__)
class Session:
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
def __init__(
self, client: TrezorClient, id: bytes, passphrase: str | object | None = None
) -> None:
self.client = client
self._id = id
self.passphrase = passphrase
@classmethod
def new(
cls, client: TrezorClient, passphrase: str | object | None, derive_cardano: bool
) -> Session:
raise NotImplementedError
def call(self, msg: t.Any) -> t.Any:
# TODO self.check_firmware_version()
resp = self.call_raw(msg)
while True:
if isinstance(resp, messages.PinMatrixRequest):
if self.pin_callback is None:
raise Exception # TODO
resp = self.pin_callback(self, resp)
elif isinstance(resp, messages.PassphraseRequest):
if self.passphrase_callback is None:
raise Exception # TODO
resp = self.passphrase_callback(self, resp)
elif isinstance(resp, messages.ButtonRequest):
if self.button_callback is None:
raise Exception # TODO
resp = self.button_callback(self, resp)
elif isinstance(resp, messages.Failure):
if resp.code == messages.FailureType.ActionCancelled:
raise exceptions.Cancelled
raise exceptions.TrezorFailure(resp)
else:
return resp
def call_raw(self, msg: t.Any) -> t.Any:
self._write(msg)
return self._read()
def _write(self, msg: t.Any) -> None:
raise NotImplementedError
def _read(self) -> t.Any:
raise NotImplementedError
def refresh_features(self) -> None:
self.client.refresh_features()
def end(self) -> t.Any:
return self.call(messages.EndSession())
def ping(self, message: str, button_protection: bool | None = None) -> str:
resp: messages.Success = self.call(
messages.Ping(message=message, button_protection=button_protection)
)
return resp.message or ""
@property
def features(self) -> messages.Features:
return self.client.features
@property
def model(self) -> models.TrezorModel:
return self.client.model
@property
def version(self) -> t.Tuple[int, int, int]:
return self.client.version
@property
def id(self) -> bytes:
return self._id
@id.setter
def id(self, value: bytes) -> None:
if not isinstance(value, bytes):
raise ValueError("id must be of type bytes")
self._id = value
class SessionV1(Session):
derive_cardano: bool | None = False
@classmethod
def new(
cls,
client: TrezorClient,
passphrase: str | object = "",
derive_cardano: bool = False,
session_id: bytes | None = None,
) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1)
session = SessionV1(client, id=session_id or b"")
session._init_callbacks()
session.passphrase = passphrase
session.derive_cardano = derive_cardano
session.init_session(session.derive_cardano)
return session
@classmethod
def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1)
session = SessionV1(client, session_id)
session.init_session()
return session
def _init_callbacks(self) -> None:
self.button_callback = self.client.button_callback
if self.button_callback is None:
self.button_callback = _callback_button
self.pin_callback = self.client.pin_callback
self.passphrase_callback = self.client.passphrase_callback
def _write(self, msg: t.Any) -> None:
if t.TYPE_CHECKING:
assert isinstance(self.client.protocol, ProtocolV1)
self.client.protocol.write(msg)
def _read(self) -> t.Any:
if t.TYPE_CHECKING:
assert isinstance(self.client.protocol, ProtocolV1)
return self.client.protocol.read()
def init_session(self, derive_cardano: bool | None = None):
if self.id == b"":
session_id = None
else:
session_id = self.id
resp: messages.Features = self.call_raw(
messages.Initialize(session_id=session_id, derive_cardano=derive_cardano)
)
if isinstance(self.passphrase, str):
self.passphrase_callback = _send_passphrase
self._id = resp.session_id
def _send_passphrase(session: Session, resp: t.Any) -> None:
assert isinstance(session.passphrase, str)
return session.call(messages.PassphraseAck(passphrase=session.passphrase))
def _callback_button(session: Session, msg: t.Any) -> t.Any:
print("Please confirm action on your Trezor device.") # TODO how to handle UI?
return session.call(messages.ButtonAck())
class SessionV2(Session):
@classmethod
def new(
cls, client: TrezorClient, passphrase: str | None, derive_cardano: bool
) -> SessionV2:
assert isinstance(client.protocol, ProtocolV2)
session = cls(client, b"\x00")
new_session: messages.ThpNewSession = session.call(
messages.ThpCreateNewSession(
passphrase=passphrase, derive_cardano=derive_cardano
)
)
assert new_session.new_session_id is not None
session_id = new_session.new_session_id
session.update_id_and_sid(session_id.to_bytes(1, "big"))
return session
def __init__(self, client: TrezorClient, id: bytes) -> None:
super().__init__(client, id)
assert isinstance(client.protocol, ProtocolV2)
self.pin_callback = client.pin_callback
self.button_callback = client.button_callback
if self.button_callback is None:
self.button_callback = _callback_button
self.channel: ProtocolV2 = client.protocol.get_channel()
self.update_id_and_sid(id)
def _write(self, msg: t.Any) -> None:
LOG.debug("writing message %s", type(msg))
self.channel.write(self.sid, msg)
def _read(self) -> t.Any:
msg = self.channel.read(self.sid)
LOG.debug("reading message %s", type(msg))
return msg
def update_id_and_sid(self, id: bytes) -> None:
self._id = id
self.sid = int.from_bytes(id, "big") # TODO update to extract only sid

View File

@ -0,0 +1,102 @@
# from storage.cache_thp import ChannelCache
# from trezor import log
# from trezor.wire.thp import ThpError
# def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool:
# """
# Checks if:
# - an ACK message is expected
# - the received ACK message acknowledges correct sequence number (bit)
# """
# if not _is_ack_expected(cache):
# return False
# if not _has_ack_correct_sync_bit(cache, ack_bit):
# return False
# return True
# def _is_ack_expected(cache: ChannelCache) -> bool:
# is_expected: bool = not is_sending_allowed(cache)
# if __debug__ and not is_expected:
# log.debug(__name__, "Received unexpected ACK message")
# return is_expected
# def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
# is_correct: bool = get_send_seq_bit(cache) == sync_bit
# if __debug__ and not is_correct:
# log.debug(__name__, "Received ACK message with wrong ack bit")
# return is_correct
# def is_sending_allowed(cache: ChannelCache) -> bool:
# """
# Checks whether sending a message in the provided channel is allowed.
# Note: Sending a message in a channel before receipt of ACK message for the previously
# sent message (in the channel) is prohibited, as it can lead to desynchronization.
# """
# return bool(cache.sync >> 7)
# def get_send_seq_bit(cache: ChannelCache) -> int:
# """
# Returns the sequential number (bit) of the next message to be sent
# in the provided channel.
# """
# return (cache.sync & 0x20) >> 5
# def get_expected_receive_seq_bit(cache: ChannelCache) -> int:
# """
# Returns the (expected) sequential number (bit) of the next message
# to be received in the provided channel.
# """
# return (cache.sync & 0x40) >> 6
# def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None:
# """
# Set the flag whether sending a message in this channel is allowed or not.
# """
# cache.sync &= 0x7F
# if sending_allowed:
# cache.sync |= 0x80
# def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
# """
# Set the expected sequential number (bit) of the next message to be received
# in the provided channel
# """
# if __debug__:
# log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit)
# if seq_bit not in (0, 1):
# raise ThpError("Unexpected receive sync bit")
# # set second bit to "seq_bit" value
# cache.sync &= 0xBF
# if seq_bit:
# cache.sync |= 0x40
# def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
# if seq_bit not in (0, 1):
# raise ThpError("Unexpected send seq bit")
# if __debug__:
# log.debug(__name__, "setting sync send seq bit to %d", seq_bit)
# # set third bit to "seq_bit" value
# cache.sync &= 0xDF
# if seq_bit:
# cache.sync |= 0x20
# def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None:
# """
# Set the sequential bit of the "next message to be send" to the opposite value,
# i.e. 1 -> 0 and 0 -> 1
# """
# _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache))

View File

@ -0,0 +1,40 @@
from __future__ import annotations
from binascii import hexlify
class ChannelData:
def __init__(
self,
protocol_version: int,
transport_path: str,
channel_id: int,
key_request: bytes,
key_response: bytes,
nonce_request: int,
nonce_response: int,
sync_bit_send: int,
sync_bit_receive: int,
) -> None:
self.protocol_version: int = protocol_version
self.transport_path: str = transport_path
self.channel_id: int = channel_id
self.key_request: str = hexlify(key_request).decode()
self.key_response: str = hexlify(key_response).decode()
self.nonce_request: int = nonce_request
self.nonce_response: int = nonce_response
self.sync_bit_receive: int = sync_bit_receive
self.sync_bit_send: int = sync_bit_send
def to_dict(self):
return {
"protocol_version": self.protocol_version,
"transport_path": self.transport_path,
"channel_id": self.channel_id,
"key_request": self.key_request,
"key_response": self.key_response,
"nonce_request": self.nonce_request,
"nonce_response": self.nonce_response,
"sync_bit_send": self.sync_bit_send,
"sync_bit_receive": self.sync_bit_receive,
}

View File

@ -0,0 +1,146 @@
from __future__ import annotations
import json
import logging
import os
import typing as t
from ..thp.channel_data import ChannelData
from .protocol_and_channel import ProtocolAndChannel
LOG = logging.getLogger(__name__)
db: "ChannelDatabase | None" = None
def get_channel_db() -> ChannelDatabase:
if db is None:
set_channel_database(should_not_store=True)
assert db is not None
return db
class ChannelDatabase:
def load_stored_channels(self) -> t.List[ChannelData]: ...
def clear_stored_channels(self) -> None: ...
def read_all_channels(self) -> t.List: ...
def save_all_channels(self, channels: t.List[t.Dict]) -> None: ...
def save_channel(self, new_channel: ProtocolAndChannel): ...
def remove_channel(self, transport_path: str) -> None: ...
class DummyChannelDatabase(ChannelDatabase):
def load_stored_channels(self) -> t.List[ChannelData]:
return []
def clear_stored_channels(self) -> None:
pass
def read_all_channels(self) -> t.List:
return []
def save_all_channels(self, channels: t.List[t.Dict]) -> None:
return
def save_channel(self, new_channel: ProtocolAndChannel):
pass
def remove_channel(self, transport_path: str) -> None:
pass
class JsonChannelDatabase(ChannelDatabase):
def __init__(self, data_path: str) -> None:
self.data_path = data_path
super().__init__()
def load_stored_channels(self) -> t.List[ChannelData]:
dicts = self.read_all_channels()
return [dict_to_channel_data(d) for d in dicts]
def clear_stored_channels(self) -> None:
LOG.debug("Clearing contents of %s", self.data_path)
with open(self.data_path, "w") as f:
json.dump([], f)
try:
os.remove(self.data_path)
except Exception as e:
LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e)))
def read_all_channels(self) -> t.List:
ensure_file_exists(self.data_path)
with open(self.data_path, "r") as f:
return json.load(f)
def save_all_channels(self, channels: t.List[t.Dict]) -> None:
LOG.debug("saving all channels")
with open(self.data_path, "w") as f:
json.dump(channels, f, indent=4)
def save_channel(self, new_channel: ProtocolAndChannel):
LOG.debug("save channel")
channels = self.read_all_channels()
transport_path = new_channel.transport.get_path()
# If the channel is found in database: replace the old entry by the new
for i, channel in enumerate(channels):
if channel["transport_path"] == transport_path:
LOG.debug("Modified channel entry for %s", transport_path)
channels[i] = new_channel.get_channel_data().to_dict()
self.save_all_channels(channels)
return
# Channel was not found: add a new channel entry
LOG.debug("Created a new channel entry on path %s", transport_path)
channels.append(new_channel.get_channel_data().to_dict())
self.save_all_channels(channels)
def remove_channel(self, transport_path: str) -> None:
LOG.debug(
"Removing channel with path %s from the channel database.",
transport_path,
)
channels = self.read_all_channels()
remaining_channels = [
ch for ch in channels if ch["transport_path"] != transport_path
]
self.save_all_channels(remaining_channels)
def dict_to_channel_data(dict: t.Dict) -> ChannelData:
return ChannelData(
protocol_version=dict["protocol_version"],
transport_path=dict["transport_path"],
channel_id=dict["channel_id"],
key_request=bytes.fromhex(dict["key_request"]),
key_response=bytes.fromhex(dict["key_response"]),
nonce_request=dict["nonce_request"],
nonce_response=dict["nonce_response"],
sync_bit_send=dict["sync_bit_send"],
sync_bit_receive=dict["sync_bit_receive"],
)
def ensure_file_exists(file_path: str) -> None:
LOG.debug("checking if file %s exists", file_path)
if not os.path.exists(file_path):
os.makedirs(os.path.dirname(file_path), exist_ok=True)
LOG.debug("File %s does not exist. Creating a new one.", file_path)
with open(file_path, "w") as f:
json.dump([], f)
def set_channel_database(should_not_store: bool):
global db
if should_not_store:
db = DummyChannelDatabase()
else:
from platformdirs import user_cache_dir
APP_NAME = "@trezor" # TODO
DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json")
db = JsonChannelDatabase(DATA_PATH)

View File

@ -0,0 +1,19 @@
import zlib
CHECKSUM_LENGTH = 4
def compute(data: bytes) -> bytes:
"""
Returns a CRC-32 checksum of the provided `data`.
"""
return zlib.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
def is_valid(checksum: bytes, data: bytes) -> bool:
"""
Checks whether the CRC-32 checksum of the `data` is the same
as the checksum provided in `checksum`.
"""
data_checksum = compute(data)
return checksum == data_checksum

View File

@ -0,0 +1,59 @@
CODEC_V1 = 0x3F
CONTINUATION_PACKET = 0x80
HANDSHAKE_INIT_REQ = 0x00
HANDSHAKE_INIT_RES = 0x01
HANDSHAKE_COMP_REQ = 0x02
HANDSHAKE_COMP_RES = 0x03
ENCRYPTED_TRANSPORT = 0x04
CONTINUATION_PACKET_MASK = 0x80
ACK_MASK = 0xF7
DATA_MASK = 0xE7
ACK_MESSAGE = 0x20
_ERROR = 0x42
CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x41
TREZOR_STATE_UNPAIRED = b"\x00"
TREZOR_STATE_PAIRED = b"\x01"
def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int:
if seq_bit == 0:
return ctrl_byte & 0xEF
if seq_bit == 1:
return ctrl_byte | 0x10
raise Exception("Unexpected sequence bit")
def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int:
if ack_bit == 0:
return ctrl_byte & 0xF7
if ack_bit == 1:
return ctrl_byte | 0x08
raise Exception("Unexpected acknowledgement bit")
def get_seq_bit(ctrl_byte: int) -> int:
return (ctrl_byte & 0x10) >> 4
def is_ack(ctrl_byte: int) -> bool:
return ctrl_byte & ACK_MASK == ACK_MESSAGE
def is_continuation(ctrl_byte: int) -> bool:
return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET
def is_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
def is_handshake_init_req(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ
def is_handshake_comp_req(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ

View File

@ -0,0 +1,116 @@
from typing import Tuple
p = 2**255 - 19
J = 486662
c3 = 19681161376707505956807079304988542015446066515923890162744021073123829784752 # sqrt(-1)
c4 = 7237005577332262213973186563042994240829374041602535252466099000494570602493 # (p - 5) // 8
a24 = 121666 # (J + 2) // 4
def decode_scalar(scalar: bytes) -> int:
# decodeScalar25519 from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
if len(scalar) != 32:
raise ValueError("Invalid length of scalar")
array = bytearray(scalar)
array[0] &= 248
array[31] &= 127
array[31] |= 64
return int.from_bytes(array, "little")
def decode_coordinate(coordinate: bytes) -> int:
# decodeUCoordinate from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
if len(coordinate) != 32:
raise ValueError("Invalid length of coordinate")
array = bytearray(coordinate)
array[-1] &= 0x7F
return int.from_bytes(array, "little") % p
def encode_coordinate(coordinate: int) -> bytes:
# encodeUCoordinate from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
return coordinate.to_bytes(32, "little")
def get_private_key(secret: bytes) -> bytes:
return decode_scalar(secret).to_bytes(32, "little")
def get_public_key(private_key: bytes) -> bytes:
base_point = int.to_bytes(9, 32, "little")
return multiply(private_key, base_point)
def multiply(private_scalar: bytes, public_point: bytes):
# X25519 from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
def ladder_operation(
x1: int, x2: int, z2: int, x3: int, z3: int
) -> Tuple[int, int, int, int]:
# https://hyperelliptic.org/EFD/g1p/auto-montgom-xz.html#ladder-ladd-1987-m-3
# (x4, z4) = 2 * (x2, z2)
# (x5, z5) = (x2, z2) + (x3, z3)
# where (x1, 1) = (x3, z3) - (x2, z2)
a = (x2 + z2) % p
aa = (a * a) % p
b = (x2 - z2) % p
bb = (b * b) % p
e = (aa - bb) % p
c = (x3 + z3) % p
d = (x3 - z3) % p
da = (d * a) % p
cb = (c * b) % p
t0 = (da + cb) % p
x5 = (t0 * t0) % p
t1 = (da - cb) % p
t2 = (t1 * t1) % p
z5 = (x1 * t2) % p
x4 = (aa * bb) % p
t3 = (a24 * e) % p
t4 = (bb + t3) % p
z4 = (e * t4) % p
return x4, z4, x5, z5
def conditional_swap(first: int, second: int, condition: int):
# Returns (second, first) if condition is true and (first, second) otherwise
# Must be implemented in a way that it is constant time
true_mask = -condition
false_mask = ~true_mask
return (first & false_mask) | (second & true_mask), (second & false_mask) | (
first & true_mask
)
k = decode_scalar(private_scalar)
u = decode_coordinate(public_point)
x_1 = u
x_2 = 1
z_2 = 0
x_3 = u
z_3 = 1
swap = 0
for i in reversed(range(256)):
bit = (k >> i) & 1
swap = bit ^ swap
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
swap = bit
x_2, z_2, x_3, z_3 = ladder_operation(x_1, x_2, z_2, x_3, z_3)
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
x = pow(z_2, p - 2, p) * x_2 % p
return encode_coordinate(x)

View File

@ -0,0 +1,82 @@
import struct
CODEC_V1 = 0x3F
CONTINUATION_PACKET = 0x80
HANDSHAKE_INIT_REQ = 0x00
HANDSHAKE_INIT_RES = 0x01
HANDSHAKE_COMP_REQ = 0x02
HANDSHAKE_COMP_RES = 0x03
ENCRYPTED_TRANSPORT = 0x04
CONTINUATION_PACKET_MASK = 0x80
ACK_MASK = 0xF7
DATA_MASK = 0xE7
ACK_MESSAGE = 0x20
_ERROR = 0x42
CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x41
TREZOR_STATE_UNPAIRED = b"\x00"
TREZOR_STATE_PAIRED = b"\x01"
BROADCAST_CHANNEL_ID = 0xFFFF
class MessageHeader:
format_str_init = ">BHH"
format_str_cont = ">BH"
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.data_length = length
def to_bytes_init(self) -> bytes:
return struct.pack(
self.format_str_init, self.ctrl_byte, self.cid, self.data_length
)
def to_bytes_cont(self) -> bytes:
return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid)
def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
struct.pack_into(
self.format_str_init,
buffer,
buffer_offset,
self.ctrl_byte,
self.cid,
self.data_length,
)
def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
struct.pack_into(
self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
)
def is_ack(self) -> bool:
return self.ctrl_byte & ACK_MASK == ACK_MESSAGE
def is_channel_allocation_response(self):
return (
self.cid == BROADCAST_CHANNEL_ID
and self.ctrl_byte == _CHANNEL_ALLOCATION_RES
)
def is_handshake_init_response(self) -> bool:
return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES
def is_handshake_comp_response(self) -> bool:
return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES
def is_encrypted_transport(self) -> bool:
return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
@classmethod
def get_error_header(cls, cid: int, length: int):
return cls(_ERROR, cid, length)
@classmethod
def get_channel_allocation_request_header(cls, length: int):
return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length)

View File

@ -0,0 +1,32 @@
from __future__ import annotations
import logging
from ... import messages
from ...mapping import ProtobufMapping
from .. import Transport
from ..thp.channel_data import ChannelData
LOG = logging.getLogger(__name__)
class ProtocolAndChannel:
def __init__(
self,
transport: Transport,
mapping: ProtobufMapping,
channel_data: ChannelData | None = None,
) -> None:
self.transport = transport
self.mapping = mapping
self.channel_keys = channel_data
def get_features(self) -> messages.Features:
raise NotImplementedError()
def get_channel_data(self) -> ChannelData:
raise NotImplementedError
def update_features(self) -> None:
raise NotImplementedError

View File

@ -0,0 +1,97 @@
from __future__ import annotations
import logging
import struct
import typing as t
from ... import exceptions, messages
from ...log import DUMP_BYTES
from .protocol_and_channel import ProtocolAndChannel
LOG = logging.getLogger(__name__)
class ProtocolV1(ProtocolAndChannel):
HEADER_LEN = struct.calcsize(">HL")
_features: messages.Features | None = None
def get_features(self) -> messages.Features:
if self._features is None:
self.update_features()
assert self._features is not None
return self._features
def update_features(self) -> None:
self.write(messages.GetFeatures())
resp = self.read()
if not isinstance(resp, messages.Features):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._features = resp
def read(self) -> t.Any:
msg_type, msg_bytes = self._read()
LOG.log(
DUMP_BYTES,
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
msg = self.mapping.decode(msg_type, msg_bytes)
LOG.debug(
f"received message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
self.transport.close()
return msg
def write(self, msg: t.Any) -> None:
LOG.debug(
f"sending message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
msg_type, msg_bytes = self.mapping.encode(msg)
LOG.log(
DUMP_BYTES,
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
self._write(msg_type, msg_bytes)
def _write(self, message_type: int, message_data: bytes) -> None:
chunk_size = self.transport.CHUNK_SIZE
header = struct.pack(">HL", message_type, len(message_data))
buffer = bytearray(b"##" + header + message_data)
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: chunk_size - 1]
chunk = chunk.ljust(chunk_size, b"\x00")
self.transport.write_chunk(chunk)
buffer = buffer[63:]
def _read(self) -> t.Tuple[int, bytes]:
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
return msg_type, buffer[:datalen]
def read_first(self) -> t.Tuple[int, int, bytes]:
chunk = self.transport.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError("Unexpected magic characters")
try:
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.transport.read_chunk()
if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters")
return chunk[1:]

View File

@ -0,0 +1,404 @@
from __future__ import annotations
import hashlib
import hmac
import logging
import os
import typing as t
from binascii import hexlify
from enum import IntEnum
import click
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from ... import exceptions, messages
from ...mapping import ProtobufMapping
from .. import Transport
from ..thp import checksum, curve25519, thp_io
from ..thp.channel_data import ChannelData
from ..thp.checksum import CHECKSUM_LENGTH
from ..thp.message_header import MessageHeader
from . import control_byte
from .channel_database import ChannelDatabase, get_channel_db
from .protocol_and_channel import ProtocolAndChannel
LOG = logging.getLogger(__name__)
MANAGEMENT_SESSION_ID: int = 0
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
hash = hashlib.sha256(val_1)
hash.update(val_2)
return hash.digest()
def _hkdf(chaining_key: bytes, input: bytes):
temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest()
output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest()
ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256)
ctx_output_2.update(b"\x02")
output_2 = ctx_output_2.digest()
return (output_1, output_2)
def _get_iv_from_nonce(nonce: int) -> bytes:
if not nonce <= 0xFFFFFFFFFFFFFFFF:
raise ValueError("Nonce overflow, terminate the channel")
return bytes(4) + nonce.to_bytes(8, "big")
class ProtocolV2(ProtocolAndChannel):
channel_id: int
channel_database: ChannelDatabase
key_request: bytes
key_response: bytes
nonce_request: int
nonce_response: int
sync_bit_send: int
sync_bit_receive: int
_has_valid_channel: bool = False
_features: messages.Features | None = None
def __init__(
self,
transport: Transport,
mapping: ProtobufMapping,
channel_data: ChannelData | None = None,
) -> None:
self.channel_database: ChannelDatabase = get_channel_db()
super().__init__(transport, mapping, channel_data)
if channel_data is not None:
self.channel_id = channel_data.channel_id
self.key_request = bytes.fromhex(channel_data.key_request)
self.key_response = bytes.fromhex(channel_data.key_response)
self.nonce_request = channel_data.nonce_request
self.nonce_response = channel_data.nonce_response
self.sync_bit_receive = channel_data.sync_bit_receive
self.sync_bit_send = channel_data.sync_bit_send
self._has_valid_channel = True
def get_channel(self) -> ProtocolV2:
if not self._has_valid_channel:
self._establish_new_channel()
return self
def get_channel_data(self) -> ChannelData:
return ChannelData(
protocol_version=2,
transport_path=self.transport.get_path(),
channel_id=self.channel_id,
key_request=self.key_request,
key_response=self.key_response,
nonce_request=self.nonce_request,
nonce_response=self.nonce_response,
sync_bit_receive=self.sync_bit_receive,
sync_bit_send=self.sync_bit_send,
)
def read(self, session_id: int) -> t.Any:
sid, msg_type, msg_data = self.read_and_decrypt()
if sid != session_id:
raise Exception("Received messsage on a different session.")
self.channel_database.save_channel(self)
return self.mapping.decode(msg_type, msg_data)
def write(self, session_id: int, msg: t.Any) -> None:
msg_type, msg_data = self.mapping.encode(msg)
self._encrypt_and_write(session_id, msg_type, msg_data)
self.channel_database.save_channel(self)
def get_features(self) -> messages.Features:
if not self._has_valid_channel:
self._establish_new_channel()
if self._features is None:
self.update_features()
assert self._features is not None
return self._features
def update_features(self) -> None:
message = messages.GetFeatures()
message_type, message_data = self.mapping.encode(message)
self.session_id: int = 0
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
_ = self._read_until_valid_crc_check() # TODO check ACK
_, msg_type, msg_data = self.read_and_decrypt()
features = self.mapping.decode(msg_type, msg_data)
if not isinstance(features, messages.Features):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._features = features
def _establish_new_channel(self) -> None:
self.sync_bit_send = 0
self.sync_bit_receive = 0
# Send channel allocation request
# Note that [:8] on the following line is required when tests use
# WITH_MOCK_URANDOM. Without [:8] such tests will (almost always) fail.
channel_id_request_nonce = os.urandom(8)[:8]
thp_io.write_payload_to_wire_and_add_checksum(
self.transport,
MessageHeader.get_channel_allocation_request_header(12),
channel_id_request_nonce,
)
# Read channel allocation response
header, payload = self._read_until_valid_crc_check()
if not self._is_valid_channel_allocation_response(
header, payload, channel_id_request_nonce
):
# TODO raise exception here, I guess
raise Exception("Invalid channel allocation response.")
self.channel_id = int.from_bytes(payload[8:10], "big")
self.device_properties = payload[10:]
# Send handshake init request
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
# Note that [:32] on the following line is required when tests use
# WITH_MOCK_URANDOM. Without [:32] such tests will (almost always) fail.
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)[:32])
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport, ha_init_req_header, host_ephemeral_pubkey
)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
click.echo("Received message is not a valid ACK", err=True)
# Read handshake init response
header, payload = self._read_until_valid_crc_check()
self._send_ack_0()
if not header.is_handshake_init_response():
click.echo(
"Received message is not a valid handshake init response message",
err=True,
)
trezor_ephemeral_pubkey = payload[:32]
encrypted_trezor_static_pubkey = payload[32:80]
noise_tag = payload[80:96]
# TODO check noise tag
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
# Prepare and send handshake completion request
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
h = _sha256_of_two(h, host_ephemeral_pubkey)
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
ck, k = _hkdf(
PROTOCOL_NAME,
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
)
aes_ctx = AESGCM(k)
try:
trezor_masked_static_pubkey = aes_ctx.decrypt(
IV_1, encrypted_trezor_static_pubkey, h
)
except Exception as e:
click.echo(
f"Exception of type{type(e)}", err=True
) # TODO how to handle potential exceptions? Q for Matejcik
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
)
aes_ctx = AESGCM(k)
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
h = _sha256_of_two(h, tag_of_empty_string)
# TODO: search for saved credentials (or possibly not, as we skip pairing phase)
zeroes_32 = int.to_bytes(0, 32, "little")
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey)
aes_ctx = AESGCM(k)
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h)
h = _sha256_of_two(h, encrypted_host_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey)
)
msg_data = self.mapping.encode_without_wire_type(
messages.ThpHandshakeCompletionReqNoisePayload(
pairing_methods=[
messages.ThpPairingMethod.NoMethod,
]
)
)
aes_ctx = AESGCM(k)
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
h = _sha256_of_two(h, encrypted_payload)
ha_completion_req_header = MessageHeader(
0x12,
self.channel_id,
len(encrypted_host_static_pubkey)
+ len(encrypted_payload)
+ CHECKSUM_LENGTH,
)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport,
ha_completion_req_header,
encrypted_host_static_pubkey + encrypted_payload,
)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
click.echo("Received message is not a valid ACK", err=True)
# Read handshake completion response, ignore payload as we do not care about the state
header, _ = self._read_until_valid_crc_check()
if not header.is_handshake_comp_response():
click.echo(
"Received message is not a valid handshake completion response",
err=True,
)
self._send_ack_1()
self.key_request, self.key_response = _hkdf(ck, b"")
self.nonce_request = 0
self.nonce_response = 1
# Send StartPairingReqest message
message = messages.ThpStartPairingRequest()
message_type, message_data = self.mapping.encode(message)
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
click.echo("Received message is not a valid ACK", err=True)
# Read
_, msg_type, msg_data = self.read_and_decrypt()
maaa = self.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, messages.ThpEndResponse)
self._has_valid_channel = True
def _send_ack_0(self):
LOG.debug("sending ack 0")
header = MessageHeader(0x20, self.channel_id, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
def _send_ack_1(self):
LOG.debug("sending ack 1")
header = MessageHeader(0x28, self.channel_id, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
def _encrypt_and_write(
self,
session_id: int,
message_type: int,
message_data: bytes,
ctrl_byte: int | None = None,
) -> None:
assert self.key_request is not None
aes_ctx = AESGCM(self.key_request)
if ctrl_byte is None:
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send)
self.sync_bit_send = 1 - self.sync_bit_send
sid = session_id.to_bytes(1, "big")
msg_type = message_type.to_bytes(2, "big")
data = sid + msg_type + message_data
nonce = _get_iv_from_nonce(self.nonce_request)
self.nonce_request += 1
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
header = MessageHeader(
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport, header, encrypted_message
)
def read_and_decrypt(self) -> t.Tuple[int, int, bytes]:
header, raw_payload = self._read_until_valid_crc_check()
if control_byte.is_ack(header.ctrl_byte):
return self.read_and_decrypt()
if not header.is_encrypted_transport():
click.echo(
"Trying to decrypt not encrypted message!"
+ hexlify(header.to_bytes_init() + raw_payload).decode(),
err=True,
)
if not control_byte.is_ack(header.ctrl_byte):
LOG.debug(
"--> Get sequence bit %d %s %s",
control_byte.get_seq_bit(header.ctrl_byte),
"from control byte",
hexlify(header.ctrl_byte.to_bytes(1, "big")).decode(),
)
if control_byte.get_seq_bit(header.ctrl_byte):
self._send_ack_1()
else:
self._send_ack_0()
aes_ctx = AESGCM(self.key_response)
nonce = _get_iv_from_nonce(self.nonce_response)
self.nonce_response += 1
message = aes_ctx.decrypt(nonce, raw_payload, b"")
session_id = message[0]
message_type = message[1:3]
message_data = message[3:]
return (
session_id,
int.from_bytes(message_type, "big"),
message_data,
)
def _read_until_valid_crc_check(
self,
) -> t.Tuple[MessageHeader, bytes]:
is_valid = False
header, payload, chksum = thp_io.read(self.transport)
while not is_valid:
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
if not is_valid:
click.echo(
"Received a message with an invalid checksum:"
+ hexlify(header.to_bytes_init() + payload + chksum).decode(),
err=True,
)
header, payload, chksum = thp_io.read(self.transport)
return header, payload
def _is_valid_channel_allocation_response(
self, header: MessageHeader, payload: bytes, original_nonce: bytes
) -> bool:
if not header.is_channel_allocation_response():
click.echo(
"Received message is not a channel allocation response", err=True
)
return False
if len(payload) < 10:
click.echo("Invalid channel allocation response payload", err=True)
return False
if payload[:8] != original_nonce:
click.echo(
"Invalid channel allocation response payload (nonce mismatch)", err=True
)
return False
return True
class ControlByteType(IntEnum):
CHANNEL_ALLOCATION_RES = 1
HANDSHAKE_INIT_RES = 2
HANDSHAKE_COMP_RES = 3
ACK = 4
ENCRYPTED_TRANSPORT = 5

View File

@ -0,0 +1,93 @@
import struct
from typing import Tuple
from .. import Transport
from ..thp import checksum
from .message_header import MessageHeader
INIT_HEADER_LENGTH = 5
CONT_HEADER_LENGTH = 3
MAX_PAYLOAD_LEN = 60000
MESSAGE_TYPE_LENGTH = 2
CONTINUATION_PACKET = 0x80
def write_payload_to_wire_and_add_checksum(
transport: Transport, header: MessageHeader, transport_payload: bytes
):
chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload)
data = transport_payload + chksum
write_payload_to_wire(transport, header, data)
def write_payload_to_wire(
transport: Transport, header: MessageHeader, transport_payload: bytes
):
transport.open()
buffer = bytearray(transport_payload)
chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH]
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
transport.write_chunk(chunk)
buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :]
while buffer:
chunk = (
header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH]
)
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
transport.write_chunk(chunk)
buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :]
def read(transport: Transport) -> Tuple[MessageHeader, bytes, bytes]:
"""
Reads from the given wire transport.
Returns `Tuple[MessageHeader, bytes, bytes]`:
1. `header` (`MessageHeader`): Header of the message.
2. `data` (`bytes`): Contents of the message (if any).
3. `checksum` (`bytes`): crc32 checksum of the header + data.
"""
buffer = bytearray()
# Read header with first part of message data
header, first_chunk = read_first(transport)
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < header.data_length:
buffer.extend(read_next(transport, header.cid))
data_len = header.data_length - checksum.CHECKSUM_LENGTH
msg_data = buffer[:data_len]
chksum = buffer[data_len : data_len + checksum.CHECKSUM_LENGTH]
return (header, msg_data, chksum)
def read_first(transport: Transport) -> Tuple[MessageHeader, bytes]:
chunk = transport.read_chunk()
try:
ctrl_byte, cid, data_length = struct.unpack(
MessageHeader.format_str_init, chunk[:INIT_HEADER_LENGTH]
)
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[INIT_HEADER_LENGTH:]
return MessageHeader(ctrl_byte, cid, data_length), data
def read_next(transport: Transport, cid: int) -> bytes:
chunk = transport.read_chunk()
ctrl_byte, read_cid = struct.unpack(
MessageHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH]
)
if ctrl_byte != CONTINUATION_PACKET:
raise RuntimeError("Continuation packet with incorrect control byte")
if read_cid != cid:
raise RuntimeError("Continuation packet for different channel")
return chunk[CONT_HEADER_LENGTH:]

View File

@ -1,6 +1,6 @@
# This file is part of the Trezor project. # This file is part of the Trezor project.
# #
# Copyright (C) 2012-2022 SatoshiLabs and contributors # Copyright (C) 2012-2024 SatoshiLabs and contributors
# #
# This library is free software: you can redistribute it and/or modify # This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 # it under the terms of the GNU Lesser General Public License version 3
@ -14,14 +14,15 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
import socket import socket
import time import time
from typing import TYPE_CHECKING, Iterable, Optional from typing import TYPE_CHECKING, Iterable, Tuple
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from . import TransportException from . import Transport, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
if TYPE_CHECKING: if TYPE_CHECKING:
from ..models import TrezorModel from ..models import TrezorModel
@ -31,14 +32,18 @@ SOCKET_TIMEOUT = 10
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class UdpTransport(ProtocolBasedTransport): class UdpTransport(Transport):
DEFAULT_HOST = "127.0.0.1" DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 21324 DEFAULT_PORT = 21324
PATH_PREFIX = "udp" PATH_PREFIX = "udp"
ENABLED: bool = True ENABLED: bool = True
CHUNK_SIZE = 64
def __init__(self, device: Optional[str] = None) -> None: def __init__(
self,
device: str | None = None,
) -> None:
if not device: if not device:
host = UdpTransport.DEFAULT_HOST host = UdpTransport.DEFAULT_HOST
port = UdpTransport.DEFAULT_PORT port = UdpTransport.DEFAULT_PORT
@ -46,24 +51,17 @@ class UdpTransport(ProtocolBasedTransport):
devparts = device.split(":") devparts = device.split(":")
host = devparts[0] host = devparts[0]
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
self.device = (host, port) self.device: Tuple[str, int] = (host, port)
self.socket: Optional[socket.socket] = None
super().__init__(protocol=ProtocolV1(self)) self.socket: socket.socket | None = None
super().__init__()
def get_path(self) -> str:
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
def find_debug(self) -> "UdpTransport":
host, port = self.device
return UdpTransport(f"{host}:{port + 1}")
@classmethod @classmethod
def _try_path(cls, path: str) -> "UdpTransport": def _try_path(cls, path: str) -> "UdpTransport":
d = cls(path) d = cls(path)
try: try:
d.open() d.open()
if d._ping(): if d.ping():
return d return d
else: else:
raise TransportException( raise TransportException(
@ -77,7 +75,7 @@ class UdpTransport(ProtocolBasedTransport):
@classmethod @classmethod
def enumerate( def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None cls, _models: Iterable["TrezorModel"] | None = None
) -> Iterable["UdpTransport"]: ) -> Iterable["UdpTransport"]:
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}" default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
try: try:
@ -99,20 +97,8 @@ class UdpTransport(ProtocolBasedTransport):
else: else:
raise TransportException(f"No UDP device at {path}") raise TransportException(f"No UDP device at {path}")
def wait_until_ready(self, timeout: float = 10) -> None: def get_path(self) -> str:
try: return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
self.open()
start = time.monotonic()
while True:
if self._ping():
break
elapsed = time.monotonic() - start
if elapsed >= timeout:
raise TransportException("Timed out waiting for connection.")
time.sleep(0.05)
finally:
self.close()
def open(self) -> None: def open(self) -> None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@ -124,18 +110,9 @@ class UdpTransport(ProtocolBasedTransport):
self.socket.close() self.socket.close()
self.socket = None self.socket = None
def _ping(self) -> bool:
"""Test if the device is listening."""
assert self.socket is not None
resp = None
try:
self.socket.sendall(b"PINGPING")
resp = self.socket.recv(8)
except Exception:
pass
return resp == b"PONGPONG"
def write_chunk(self, chunk: bytes) -> None: def write_chunk(self, chunk: bytes) -> None:
if self.socket is None:
self.open()
assert self.socket is not None assert self.socket is not None
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException("Unexpected data length") raise TransportException("Unexpected data length")
@ -143,6 +120,8 @@ class UdpTransport(ProtocolBasedTransport):
self.socket.sendall(chunk) self.socket.sendall(chunk)
def read_chunk(self) -> bytes: def read_chunk(self) -> bytes:
if self.socket is None:
self.open()
assert self.socket is not None assert self.socket is not None
while True: while True:
try: try:
@ -154,3 +133,33 @@ class UdpTransport(ProtocolBasedTransport):
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException(f"Unexpected chunk size: {len(chunk)}") raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return bytearray(chunk) return bytearray(chunk)
def find_debug(self) -> "UdpTransport":
host, port = self.device
return UdpTransport(f"{host}:{port + 1}")
def wait_until_ready(self, timeout: float = 10) -> None:
try:
self.open()
start = time.monotonic()
while True:
if self.ping():
break
elapsed = time.monotonic() - start
if elapsed >= timeout:
raise TransportException("Timed out waiting for connection.")
time.sleep(0.05)
finally:
self.close()
def ping(self) -> bool:
"""Test if the device is listening."""
assert self.socket is not None
resp = None
try:
self.socket.sendall(b"PINGPING")
resp = self.socket.recv(8)
except Exception:
pass
return resp == b"PONGPONG"

View File

@ -1,6 +1,6 @@
# This file is part of the Trezor project. # This file is part of the Trezor project.
# #
# Copyright (C) 2012-2022 SatoshiLabs and contributors # Copyright (C) 2012-2024 SatoshiLabs and contributors
# #
# This library is free software: you can redistribute it and/or modify # This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3 # it under the terms of the GNU Lesser General Public License version 3
@ -14,16 +14,17 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import atexit import atexit
import logging import logging
import sys import sys
import time import time
from typing import Iterable, List, Optional from typing import Iterable, List
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from ..models import TREZORS, TrezorModel from ..models import TREZORS, TrezorModel
from . import UDEV_RULES_STR, DeviceIsBusy, TransportException from . import UDEV_RULES_STR, DeviceIsBusy, Transport, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -44,13 +45,69 @@ USB_COMM_TIMEOUT_MS = 300
WEBUSB_CHUNK_SIZE = 64 WEBUSB_CHUNK_SIZE = 64
class WebUsbHandle: class WebUsbTransport(Transport):
def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None: """
WebUsbTransport implements transport over WebUSB interface.
"""
PATH_PREFIX = "webusb"
ENABLED = USB_IMPORTED
context = None
CHUNK_SIZE = 64
def __init__(
self,
device: "usb1.USBDevice",
debug: bool = False,
) -> None:
self.device = device self.device = device
self.debug = debug
self.interface = DEBUG_INTERFACE if debug else INTERFACE self.interface = DEBUG_INTERFACE if debug else INTERFACE
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
self.count = 0 self.handle: usb1.USBDeviceHandle | None = None
self.handle: Optional["usb1.USBDeviceHandle"] = None
super().__init__()
@classmethod
def enumerate(
cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False
) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in usb_ids:
continue
if not is_vendor_class(dev):
continue
if usb_reset:
handle = dev.open()
handle.resetDevice()
handle.close()
continue
try:
# workaround for issue #223:
# on certain combinations of Windows USB drivers and libusb versions,
# Trezor is returned twice (possibly because Windows know it as both
# a HID and a WebUSB device), and one of the returned devices is
# non-functional.
dev.getProduct()
devices.append(WebUsbTransport(dev))
except usb1.USBErrorNotSupported:
pass
return devices
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
def open(self) -> None: def open(self) -> None:
self.handle = self.device.open() self.handle = self.device.open()
@ -64,6 +121,8 @@ class WebUsbHandle:
self.handle.claimInterface(self.interface) self.handle.claimInterface(self.interface)
except usb1.USBErrorAccess as e: except usb1.USBErrorAccess as e:
raise DeviceIsBusy(self.device) from e raise DeviceIsBusy(self.device) from e
except usb1.USBErrorBusy as e:
raise DeviceIsBusy(self.device) from e
def close(self) -> None: def close(self) -> None:
if self.handle is not None: if self.handle is not None:
@ -75,6 +134,8 @@ class WebUsbHandle:
self.handle = None self.handle = None
def write_chunk(self, chunk: bytes) -> None: def write_chunk(self, chunk: bytes) -> None:
if self.handle is None:
self.open()
assert self.handle is not None assert self.handle is not None
if len(chunk) != WEBUSB_CHUNK_SIZE: if len(chunk) != WEBUSB_CHUNK_SIZE:
raise TransportException(f"Unexpected chunk size: {len(chunk)}") raise TransportException(f"Unexpected chunk size: {len(chunk)}")
@ -97,6 +158,8 @@ class WebUsbHandle:
return return
def read_chunk(self) -> bytes: def read_chunk(self) -> bytes:
if self.handle is None:
self.open()
assert self.handle is not None assert self.handle is not None
endpoint = 0x80 | self.endpoint endpoint = 0x80 | self.endpoint
while True: while True:
@ -117,70 +180,6 @@ class WebUsbHandle:
raise TransportException(f"Unexpected chunk size: {len(chunk)}") raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return chunk return chunk
class WebUsbTransport(ProtocolBasedTransport):
"""
WebUsbTransport implements transport over WebUSB interface.
"""
PATH_PREFIX = "webusb"
ENABLED = USB_IMPORTED
context = None
def __init__(
self,
device: "usb1.USBDevice",
handle: Optional[WebUsbHandle] = None,
debug: bool = False,
) -> None:
if handle is None:
handle = WebUsbHandle(device, debug)
self.device = device
self.handle = handle
self.debug = debug
super().__init__(protocol=ProtocolV1(handle))
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
@classmethod
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False
) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in usb_ids:
continue
if not is_vendor_class(dev):
continue
try:
# workaround for issue #223:
# on certain combinations of Windows USB drivers and libusb versions,
# Trezor is returned twice (possibly because Windows know it as both
# a HID and a WebUSB device), and one of the returned devices is
# non-functional.
dev.getProduct()
devices.append(WebUsbTransport(dev))
except usb1.USBErrorNotSupported:
pass
except usb1.USBErrorPipe:
if usb_reset:
handle = dev.open()
handle.resetDevice()
handle.close()
return devices
def find_debug(self) -> "WebUsbTransport": def find_debug(self) -> "WebUsbTransport":
# For v1 protocol, find debug USB interface for the same serial number # For v1 protocol, find debug USB interface for the same serial number
return WebUsbTransport(self.device, debug=True) return WebUsbTransport(self.device, debug=True)