1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-25 06:40:58 +00:00

feat(python): implement session based trezorlib

[no changelog]
This commit is contained in:
M1nd3r 2024-12-02 15:44:10 +01:00
parent 4a18f67f8f
commit 6b1fc71ce3
61 changed files with 4063 additions and 1663 deletions

View File

@ -95,6 +95,15 @@ class Emulator:
raise RuntimeError
return self._client
@client.setter
def client(self, new_client: TrezorClientDebugLink) -> None:
"""Setter for the client property to update _client."""
if not isinstance(new_client, TrezorClientDebugLink):
raise TypeError(
f"Expected a TrezorClientDebugLink, got {type(new_client).__name__}."
)
self._client = new_client
def make_args(self) -> List[str]:
return []
@ -112,7 +121,7 @@ class Emulator:
start = time.monotonic()
try:
while True:
if transport._ping():
if transport.ping():
break
if self.process.poll() is not None:
raise RuntimeError("Emulator process died")

View File

@ -7,7 +7,7 @@ import typing as t
from importlib import metadata
from . import device
from .client import TrezorClient
from .transport.session import Session
try:
cryptography_version = metadata.version("cryptography")
@ -361,7 +361,7 @@ def verify_authentication_response(
def authenticate_device(
client: TrezorClient,
session: Session,
challenge: bytes | None = None,
*,
whitelist: t.Collection[bytes] | None = None,
@ -371,7 +371,7 @@ def authenticate_device(
if challenge is None:
challenge = secrets.token_bytes(16)
resp = device.authenticate(client, challenge)
resp = device.authenticate(session, challenge)
return verify_authentication_response(
challenge,

View File

@ -20,17 +20,17 @@ from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .transport.session import Session
@expect(messages.BenchmarkNames)
def list_names(
client: "TrezorClient",
session: "Session",
) -> "MessageType":
return client.call(messages.BenchmarkListNames())
return session.call(messages.BenchmarkListNames())
@expect(messages.BenchmarkResult)
def run(client: "TrezorClient", name: str) -> "MessageType":
return client.call(messages.BenchmarkRun(name=name))
def run(session: "Session", name: str) -> "MessageType":
return session.call(messages.BenchmarkRun(name=name))

View File

@ -18,22 +18,22 @@ from typing import TYPE_CHECKING
from . import messages
from .protobuf import dict_to_proto
from .tools import expect, session
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.session import Session
@expect(messages.BinanceAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
session: "Session",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.BinanceGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
)
@ -42,16 +42,15 @@ def get_address(
@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes)
def get_public_key(
client: "TrezorClient", address_n: "Address", show_display: bool = False
session: "Session", address_n: "Address", show_display: bool = False
) -> "MessageType":
return client.call(
return session.call(
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display)
)
@session
def sign_tx(
client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False
session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False
) -> messages.BinanceSignedTx:
msg = tx_json["msgs"][0]
tx_msg = tx_json.copy()
@ -60,7 +59,7 @@ def sign_tx(
tx_msg["chunkify"] = chunkify
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
response = client.call(envelope)
response = session.call(envelope)
if not isinstance(response, messages.BinanceTxRequest):
raise RuntimeError(
@ -77,7 +76,7 @@ def sign_tx(
else:
raise ValueError("can not determine msg type")
response = client.call(msg)
response = session.call(msg)
if not isinstance(response, messages.BinanceSignedTx):
raise RuntimeError(

View File

@ -13,7 +13,6 @@
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import warnings
from copy import copy
from decimal import Decimal
@ -23,12 +22,12 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple
from typing_extensions import Protocol, TypedDict
from . import exceptions, messages
from .tools import expect, prepare_message_bytes, session
from .tools import expect, prepare_message_bytes
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.session import Session
class ScriptSig(TypedDict):
asm: str
@ -105,7 +104,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType:
@expect(messages.PublicKey)
def get_public_node(
client: "TrezorClient",
session: "Session",
n: "Address",
ecdsa_curve_name: Optional[str] = None,
show_display: bool = False,
@ -116,13 +115,13 @@ def get_public_node(
unlock_path_mac: Optional[bytes] = None,
) -> "MessageType":
if unlock_path:
res = client.call(
res = session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
)
if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message")
return client.call(
return session.call(
messages.GetPublicKey(
address_n=n,
ecdsa_curve_name=ecdsa_curve_name,
@ -141,7 +140,7 @@ def get_address(*args: Any, **kwargs: Any):
@expect(messages.Address)
def get_authenticated_address(
client: "TrezorClient",
session: "Session",
coin_name: str,
n: "Address",
show_display: bool = False,
@ -153,13 +152,13 @@ def get_authenticated_address(
chunkify: bool = False,
) -> "MessageType":
if unlock_path:
res = client.call(
res = session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
)
if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message")
return client.call(
return session.call(
messages.GetAddress(
address_n=n,
coin_name=coin_name,
@ -172,15 +171,16 @@ def get_authenticated_address(
)
# TODO this is used by tests only
@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
def get_ownership_id(
client: "TrezorClient",
session: "Session",
coin_name: str,
n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> "MessageType":
return client.call(
return session.call(
messages.GetOwnershipId(
address_n=n,
coin_name=coin_name,
@ -190,8 +190,9 @@ def get_ownership_id(
)
# TODO this is used by tests only
def get_ownership_proof(
client: "TrezorClient",
session: "Session",
coin_name: str,
n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None,
@ -202,11 +203,11 @@ def get_ownership_proof(
preauthorized: bool = False,
) -> Tuple[bytes, bytes]:
if preauthorized:
res = client.call(messages.DoPreauthorized())
res = session.call(messages.DoPreauthorized())
if not isinstance(res, messages.PreauthorizedRequest):
raise exceptions.TrezorException("Unexpected message")
res = client.call(
res = session.call(
messages.GetOwnershipProof(
address_n=n,
coin_name=coin_name,
@ -226,7 +227,7 @@ def get_ownership_proof(
@expect(messages.MessageSignature)
def sign_message(
client: "TrezorClient",
session: "Session",
coin_name: str,
n: "Address",
message: AnyStr,
@ -234,7 +235,7 @@ def sign_message(
no_script_type: bool = False,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.SignMessage(
coin_name=coin_name,
address_n=n,
@ -247,7 +248,7 @@ def sign_message(
def verify_message(
client: "TrezorClient",
session: "Session",
coin_name: str,
address: str,
signature: bytes,
@ -255,7 +256,7 @@ def verify_message(
chunkify: bool = False,
) -> bool:
try:
resp = client.call(
resp = session.call(
messages.VerifyMessage(
address=address,
signature=signature,
@ -269,9 +270,9 @@ def verify_message(
return isinstance(resp, messages.Success)
@session
# @session
def sign_tx(
client: "TrezorClient",
session: "Session",
coin_name: str,
inputs: Sequence[messages.TxInputType],
outputs: Sequence[messages.TxOutputType],
@ -319,17 +320,17 @@ def sign_tx(
setattr(signtx, name, value)
if unlock_path:
res = client.call(
res = session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
)
if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message")
elif preauthorized:
res = client.call(messages.DoPreauthorized())
res = session.call(messages.DoPreauthorized())
if not isinstance(res, messages.PreauthorizedRequest):
raise exceptions.TrezorException("Unexpected message")
res = client.call(signtx)
res = session.call(signtx)
# Prepare structure for signatures
signatures: List[Optional[bytes]] = [None] * len(inputs)
@ -388,7 +389,7 @@ def sign_tx(
if res.request_type == R.TXPAYMENTREQ:
assert res.details.request_index is not None
msg = payment_reqs[res.details.request_index]
res = client.call(msg)
res = session.call(msg)
else:
msg = messages.TransactionType()
if res.request_type == R.TXMETA:
@ -418,7 +419,7 @@ def sign_tx(
f"Unknown request type - {res.request_type}."
)
res = client.call(messages.TxAck(tx=msg))
res = session.call(messages.TxAck(tx=msg))
if not isinstance(res, messages.TxRequest):
raise exceptions.TrezorException("Unexpected message")
@ -432,7 +433,7 @@ def sign_tx(
@expect(messages.Success, field="message", ret_type=str)
def authorize_coinjoin(
client: "TrezorClient",
session: "Session",
coordinator: str,
max_rounds: int,
max_coordinator_fee_rate: int,
@ -441,7 +442,7 @@ def authorize_coinjoin(
coin_name: str,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> "MessageType":
return client.call(
return session.call(
messages.AuthorizeCoinJoin(
coordinator=coordinator,
max_rounds=max_rounds,

View File

@ -35,8 +35,8 @@ from . import exceptions, messages, tools
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .transport.session import Session
PROTOCOL_MAGICS = {
"mainnet": 764824073,
@ -825,7 +825,7 @@ def _get_collateral_inputs_items(
@expect(messages.CardanoAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
session: "Session",
address_parameters: messages.CardanoAddressParametersType,
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
network_id: int = NETWORK_IDS["mainnet"],
@ -833,7 +833,7 @@ def get_address(
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.CardanoGetAddress(
address_parameters=address_parameters,
protocol_magic=protocol_magic,
@ -847,12 +847,12 @@ def get_address(
@expect(messages.CardanoPublicKey)
def get_public_key(
client: "TrezorClient",
session: "Session",
address_n: List[int],
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
show_display: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.CardanoGetPublicKey(
address_n=address_n,
derivation_type=derivation_type,
@ -863,12 +863,12 @@ def get_public_key(
@expect(messages.CardanoNativeScriptHash)
def get_native_script_hash(
client: "TrezorClient",
session: "Session",
native_script: messages.CardanoNativeScript,
display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE,
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
) -> "MessageType":
return client.call(
return session.call(
messages.CardanoGetNativeScriptHash(
script=native_script,
display_format=display_format,
@ -878,7 +878,7 @@ def get_native_script_hash(
def sign_tx(
client: "TrezorClient",
session: "Session",
signing_mode: messages.CardanoTxSigningMode,
inputs: List[InputWithPath],
outputs: List[OutputWithData],
@ -915,7 +915,7 @@ def sign_tx(
signing_mode,
)
response = client.call(
response = session.call(
messages.CardanoSignTxInit(
signing_mode=signing_mode,
inputs_count=len(inputs),
@ -951,14 +951,14 @@ def sign_tx(
_get_certificates_items(certificates),
withdrawals,
):
response = client.call(tx_item)
response = session.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response: Dict[str, Any] = {}
if auxiliary_data is not None:
auxiliary_data_supplement = client.call(auxiliary_data)
auxiliary_data_supplement = session.call(auxiliary_data)
if not isinstance(
auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement
):
@ -971,7 +971,7 @@ def sign_tx(
auxiliary_data_supplement.__dict__
)
response = client.call(messages.CardanoTxHostAck())
response = session.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
@ -980,24 +980,24 @@ def sign_tx(
_get_collateral_inputs_items(collateral_inputs),
required_signers,
):
response = client.call(tx_item)
response = session.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
if collateral_return is not None:
for tx_item in _get_output_items(collateral_return):
response = client.call(tx_item)
response = session.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
for reference_input in reference_inputs:
response = client.call(reference_input)
response = session.call(reference_input)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["witnesses"] = []
for witness_request in witness_requests:
response = client.call(witness_request)
response = session.call(witness_request)
if not isinstance(response, messages.CardanoTxWitnessResponse):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["witnesses"].append(
@ -1009,12 +1009,12 @@ def sign_tx(
}
)
response = client.call(messages.CardanoTxHostAck())
response = session.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoTxBodyHash):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["tx_hash"] = response.tx_hash
response = client.call(messages.CardanoTxHostAck())
response = session.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoSignTxFinished):
raise UNEXPECTED_RESPONSE_ERROR

View File

@ -14,33 +14,42 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import functools
import logging
import os
import sys
import typing as t
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import click
from .. import exceptions, transport
from ..client import TrezorClient
from ..ui import ClickUI, ScriptUI
from .. import exceptions, transport, ui
from ..client import ProtocolVersion, TrezorClient
from ..messages import Capability
from ..transport import Transport
from ..transport.session import Session, SessionV1, SessionV2
from ..transport.thp.channel_database import get_channel_db
if TYPE_CHECKING:
LOG = logging.getLogger(__name__)
if t.TYPE_CHECKING:
# Needed to enforce a return value from decorators
# More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar
from typing_extensions import Concatenate, ParamSpec
from ..transport import Transport
from ..ui import TrezorClientUI
P = ParamSpec("P")
R = TypeVar("R")
R = t.TypeVar("R")
FuncWithSession = t.Callable[Concatenate[Session, P], R]
class ChoiceType(click.Choice):
def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None:
def __init__(
self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True
) -> None:
super().__init__(list(typemap.keys()))
self.case_sensitive = case_sensitive
if case_sensitive:
@ -48,7 +57,7 @@ class ChoiceType(click.Choice):
else:
self.typemap = {k.lower(): v for k, v in typemap.items()}
def convert(self, value: Any, param: Any, ctx: click.Context) -> Any:
def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any:
if value in self.typemap.values():
return value
value = super().convert(value, param, ctx)
@ -57,11 +66,69 @@ class ChoiceType(click.Choice):
return self.typemap[value]
def get_passphrase(
passphrase_on_host: bool, available_on_device: bool
) -> t.Union[str, object]:
if available_on_device and not passphrase_on_host:
return ui.PASSPHRASE_ON_DEVICE
env_passphrase = os.getenv("PASSPHRASE")
if env_passphrase is not None:
ui.echo("Passphrase required. Using PASSPHRASE environment variable.")
return env_passphrase
while True:
try:
passphrase = ui.prompt(
"Passphrase required",
hide_input=True,
default="",
show_default=False,
)
# In case user sees the input on the screen, we do not need confirmation
if not ui.CAN_HANDLE_HIDDEN_INPUT:
return passphrase
second = ui.prompt(
"Confirm your passphrase",
hide_input=True,
default="",
show_default=False,
)
if passphrase == second:
return passphrase
else:
ui.echo("Passphrase did not match. Please try again.")
except click.Abort:
raise exceptions.Cancelled from None
def get_client(transport: Transport) -> TrezorClient:
stored_channels = get_channel_db().load_stored_channels()
stored_transport_paths = [ch.transport_path for ch in stored_channels]
path = transport.get_path()
if path in stored_transport_paths:
stored_channel_with_correct_transport_path = next(
ch for ch in stored_channels if ch.transport_path == path
)
try:
client = TrezorClient.resume(
transport, stored_channel_with_correct_transport_path
)
except Exception:
LOG.debug("Failed to resume a channel. Replacing by a new one.")
get_channel_db().remove_channel(path)
client = TrezorClient(transport)
else:
client = TrezorClient(transport)
return client
class TrezorConnection:
def __init__(
self,
path: str,
session_id: Optional[bytes],
session_id: bytes | None,
passphrase_on_host: bool,
script: bool,
) -> None:
@ -70,6 +137,54 @@ class TrezorConnection:
self.passphrase_on_host = passphrase_on_host
self.script = script
def get_session(
self,
derive_cardano: bool = False,
empty_passphrase: bool = False,
must_resume: bool = False,
) -> Session:
client = self.get_client()
if must_resume and self.session_id is None:
click.echo("Failed to resume session - no session id provided")
raise RuntimeError("Failed to resume session - no session id provided")
# Try resume session from id
if self.session_id is not None:
if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
session = SessionV1.resume_from_id(
client=client, session_id=self.session_id
)
elif client.protocol_version is ProtocolVersion.PROTOCOL_V2:
session = SessionV2(client, self.session_id)
# TODO fix resumption on THP
else:
raise Exception("Unsupported client protocol", client.protocol_version)
if must_resume:
if session.id != self.session_id or session.id is None:
click.echo("Failed to resume session")
RuntimeError("Failed to resume session - no session id provided")
return session
features = client.protocol.get_features()
passphrase_enabled = True # TODO what to do here?
if not passphrase_enabled:
return client.get_session(derive_cardano=derive_cardano)
if empty_passphrase:
passphrase = ""
else:
available_on_device = Capability.PassphraseEntry in features.capabilities
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
)
return session
def get_transport(self) -> "Transport":
try:
# look for transport without prefix search
@ -82,19 +197,13 @@ class TrezorConnection:
# if this fails, we want the exception to bubble up to the caller
return transport.get_transport(self.path, prefix_search=True)
def get_ui(self) -> "TrezorClientUI":
if self.script:
# It is alright to return just the class object instead of instance,
# as the ScriptUI class object itself is the implementation of TrezorClientUI
# (ScriptUI is just a set of staticmethods)
return ScriptUI
else:
return ClickUI(passphrase_on_host=self.passphrase_on_host)
def get_client(self) -> TrezorClient:
transport = self.get_transport()
ui = self.get_ui()
return TrezorClient(transport, ui=ui, session_id=self.session_id)
return get_client(self.get_transport())
def get_management_session(self) -> Session:
client = self.get_client()
management_session = client.get_management_session()
return management_session
@contextmanager
def client_context(self):
@ -128,7 +237,57 @@ class TrezorConnection:
# other exceptions may cause a traceback
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]":
def with_session(
func: "t.Callable[Concatenate[Session, P], R]|None" = None,
*,
empty_passphrase: bool = False,
derive_cardano: bool = False,
management: bool = False,
must_resume: bool = False,
) -> t.Callable[[FuncWithSession], t.Callable[P, R]]:
"""Provides a Click command with parameter `session=obj.get_session(...)` or
`session=obj.get_management_session()` based on the parameters provided.
If default parameters are ok, this decorator can be used without parentheses.
TODO: handle resumption of sessions and their (potential) closure.
"""
def decorator(
func: FuncWithSession,
) -> "t.Callable[P, R]":
@click.pass_obj
@functools.wraps(func)
def function_with_session(
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
if management:
session = obj.get_management_session()
else:
session = obj.get_session(
derive_cardano=derive_cardano,
empty_passphrase=empty_passphrase,
must_resume=must_resume,
)
try:
return func(session, *args, **kwargs)
finally:
pass
# TODO try end session if not resumed
return function_with_session
# If the decorator @get_session is used without parentheses
if func and callable(func):
return decorator(func) # type: ignore [Function return type]
return decorator
def with_client(
func: "t.Callable[Concatenate[TrezorClient, P], R]",
) -> "t.Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume
@ -142,23 +301,62 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
with obj.client_context() as client:
session_was_resumed = obj.session_id == client.session_id
if not session_was_resumed and obj.session_id is not None:
# tried to resume but failed
click.echo("Warning: failed to resume session.", err=True)
# session_was_resumed = obj.session_id == client.session_id
# if not session_was_resumed and obj.session_id is not None:
# # tried to resume but failed
# click.echo("Warning: failed to resume session.", err=True)
click.echo(
"Warning: resume session detection is not implemented yet!", err=True
)
try:
return func(client, *args, **kwargs)
finally:
if not session_was_resumed:
try:
client.end_session()
except Exception:
pass
if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
get_channel_db().save_channel(client.protocol)
# if not session_was_resumed:
# try:
# client.end_session()
# except Exception:
# pass
return trezorctl_command_with_client
# def with_client(
# func: "t.Callable[Concatenate[TrezorClient, P], R]",
# ) -> "t.Callable[P, R]":
# """Wrap a Click command in `with obj.client_context() as client`.
# Sessions are handled transparently. The user is warned when session did not resume
# cleanly. The session is closed after the command completes - unless the session
# was resumed, in which case it should remain open.
# """
# @click.pass_obj
# @functools.wraps(func)
# def trezorctl_command_with_client(
# obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
# ) -> "R":
# with obj.client_context() as client:
# session_was_resumed = obj.session_id == client.session_id
# if not session_was_resumed and obj.session_id is not None:
# # tried to resume but failed
# click.echo("Warning: failed to resume session.", err=True)
# try:
# return func(client, *args, **kwargs)
# finally:
# if not session_was_resumed:
# try:
# client.end_session()
# except Exception:
# pass
# # the return type of @click.pass_obj is improperly specified and pyright doesn't
# # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
# return trezorctl_command_with_client
class AliasedGroup(click.Group):
"""Command group that handles aliases and Click 6.x compatibility.
@ -188,14 +386,14 @@ class AliasedGroup(click.Group):
def __init__(
self,
aliases: Optional[Dict[str, click.Command]] = None,
*args: Any,
**kwargs: Any,
aliases: t.Dict[str, click.Command] | None = None,
*args: t.Any,
**kwargs: t.Any,
) -> None:
super().__init__(*args, **kwargs)
self.aliases = aliases or {}
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]:
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
cmd_name = cmd_name.replace("_", "-")
# try to look up the real name
cmd = super().get_command(ctx, cmd_name)

View File

@ -20,17 +20,15 @@ from typing import TYPE_CHECKING, List, Optional
import click
from .. import benchmark
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
def list_names_patern(
client: "TrezorClient", pattern: Optional[str] = None
) -> List[str]:
names = list(benchmark.list_names(client).names)
def list_names_patern(session: "Session", pattern: Optional[str] = None) -> List[str]:
names = list(benchmark.list_names(session).names)
if pattern is None:
return names
return [name for name in names if fnmatch(name, pattern)]
@ -43,10 +41,10 @@ def cli() -> None:
@cli.command()
@click.argument("pattern", required=False)
@with_client
def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
@with_session(empty_passphrase=True)
def list_names(session: "Session", pattern: Optional[str] = None) -> None:
"""List names of all supported benchmarks"""
names = list_names_patern(client, pattern)
names = list_names_patern(session, pattern)
if len(names) == 0:
click.echo("No benchmark satisfies the pattern.")
else:
@ -56,13 +54,13 @@ def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
@cli.command()
@click.argument("pattern", required=False)
@with_client
def run(client: "TrezorClient", pattern: Optional[str]) -> None:
@with_session(empty_passphrase=True)
def run(session: "Session", pattern: Optional[str]) -> None:
"""Run benchmark"""
names = list_names_patern(client, pattern)
names = list_names_patern(session, pattern)
if len(names) == 0:
click.echo("No benchmark satisfies the pattern.")
else:
for name in names:
result = benchmark.run(client, name)
result = benchmark.run(session, name)
click.echo(f"{name}: {result.value} {result.unit}")

View File

@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import binance, tools
from . import with_client
from ..transport.session import Session
from . import with_session
if TYPE_CHECKING:
from .. import messages
from ..client import TrezorClient
PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0"
@ -39,23 +39,23 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Binance address for specified path."""
address_n = tools.parse_path(address)
return binance.get_address(client, address_n, show_display, chunkify)
return binance.get_address(session, address_n, show_display, chunkify)
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
@with_session
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Binance public key."""
address_n = tools.parse_path(address)
return binance.get_public_key(client, address_n, show_display).hex()
return binance.get_public_key(session, address_n, show_display).hex()
@cli.command()
@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_tx(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
session: "Session", address: str, file: TextIO, chunkify: bool
) -> "messages.BinanceSignedTx":
"""Sign Binance transaction.
Transaction must be provided as a JSON file.
"""
address_n = tools.parse_path(address)
return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify)

View File

@ -13,6 +13,7 @@
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import base64
import json
@ -22,10 +23,10 @@ import click
import construct as c
from .. import btc, messages, protobuf, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PURPOSE_BIP44 = 44
PURPOSE_BIP48 = 48
@ -174,15 +175,15 @@ def cli() -> None:
help="Sort pubkeys lexicographically using BIP-67",
)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
script_type: Optional[messages.InputScriptType],
script_type: messages.InputScriptType | None,
show_display: bool,
multisig_xpub: List[str],
multisig_threshold: Optional[int],
multisig_threshold: int | None,
multisig_suffix_length: int,
multisig_sort_pubkeys: bool,
chunkify: bool,
@ -235,7 +236,7 @@ def get_address(
multisig = None
return btc.get_address(
client,
session,
coin,
address_n,
show_display,
@ -252,9 +253,9 @@ def get_address(
@click.option("-e", "--curve")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session
def get_public_node(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
curve: Optional[str],
@ -266,7 +267,7 @@ def get_public_node(
if script_type is None:
script_type = guess_script_type_from_path(address_n)
result = btc.get_public_node(
client,
session,
address_n,
ecdsa_curve_name=curve,
show_display=show_display,
@ -292,7 +293,7 @@ def _append_descriptor_checksum(desc: str) -> str:
def _get_descriptor(
client: "TrezorClient",
session: "Session",
coin: Optional[str],
account: int,
purpose: Optional[int],
@ -326,7 +327,7 @@ def _get_descriptor(
n = tools.parse_path(path)
pub = btc.get_public_node(
client,
session,
n,
show_display=show_display,
coin_name=coin,
@ -363,9 +364,9 @@ def _get_descriptor(
@click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE))
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session
def get_descriptor(
client: "TrezorClient",
session: "Session",
coin: Optional[str],
account: int,
account_type: Optional[int],
@ -375,7 +376,7 @@ def get_descriptor(
"""Get descriptor of given account."""
try:
return _get_descriptor(
client, coin, account, account_type, script_type, show_display
session, coin, account, account_type, script_type, show_display
)
except ValueError as e:
raise click.ClickException(str(e))
@ -390,8 +391,8 @@ def get_descriptor(
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("json_file", type=click.File())
@with_client
def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
@with_session
def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None:
"""Sign transaction.
Transaction data must be provided in a JSON file. See `transaction-format.md` for
@ -416,7 +417,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
}
_, serialized_tx = btc.sign_tx(
client,
session,
coin,
inputs,
outputs,
@ -447,9 +448,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
)
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("message")
@with_client
@with_session
def sign_message(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
message: str,
@ -462,7 +463,7 @@ def sign_message(
if script_type is None:
script_type = guess_script_type_from_path(address_n)
res = btc.sign_message(
client,
session,
coin,
address_n,
message,
@ -483,9 +484,9 @@ def sign_message(
@click.argument("address")
@click.argument("signature")
@click.argument("message")
@with_client
@with_session
def verify_message(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
signature: str,
@ -495,7 +496,7 @@ def verify_message(
"""Verify message."""
signature_bytes = base64.b64decode(signature)
return btc.verify_message(
client, coin, address, signature_bytes, message, chunkify=chunkify
session, coin, address, signature_bytes, message, chunkify=chunkify
)

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
import click
from .. import cardano, messages, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0"
@ -62,9 +62,9 @@ def cli() -> None:
@click.option("-i", "--include-network-id", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@click.option("-T", "--tag-cbor-sets", is_flag=True)
@with_client
@with_session(derive_cardano=True)
def sign_tx(
client: "TrezorClient",
session: "Session",
file: TextIO,
signing_mode: messages.CardanoTxSigningMode,
protocol_magic: int,
@ -123,9 +123,8 @@ def sign_tx(
for p in transaction["additional_witness_requests"]
]
client.init_device(derive_cardano=True)
sign_tx_response = cardano.sign_tx(
client,
session,
signing_mode,
inputs,
outputs,
@ -209,9 +208,9 @@ def sign_tx(
default=messages.CardanoDerivationType.ICARUS,
)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session(derive_cardano=True)
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
address_type: messages.CardanoAddressType,
staking_address: str,
@ -262,9 +261,8 @@ def get_address(
script_staking_hash_bytes,
)
client.init_device(derive_cardano=True)
return cardano.get_address(
client,
session,
address_parameters,
protocol_magic,
network_id,
@ -283,18 +281,17 @@ def get_address(
default=messages.CardanoDerivationType.ICARUS,
)
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session(derive_cardano=True)
def get_public_key(
client: "TrezorClient",
session: "Session",
address: str,
derivation_type: messages.CardanoDerivationType,
show_display: bool,
) -> messages.CardanoPublicKey:
"""Get Cardano public key."""
address_n = tools.parse_path(address)
client.init_device(derive_cardano=True)
return cardano.get_public_key(
client, address_n, derivation_type=derivation_type, show_display=show_display
session, address_n, derivation_type=derivation_type, show_display=show_display
)
@ -312,9 +309,9 @@ def get_public_key(
type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}),
default=messages.CardanoDerivationType.ICARUS,
)
@with_client
@with_session(derive_cardano=True)
def get_native_script_hash(
client: "TrezorClient",
session: "Session",
file: TextIO,
display_format: messages.CardanoNativeScriptHashDisplayFormat,
derivation_type: messages.CardanoDerivationType,
@ -323,7 +320,6 @@ def get_native_script_hash(
native_script_json = json.load(file)
native_script = cardano.parse_native_script(native_script_json)
client.init_device(derive_cardano=True)
return cardano.get_native_script_hash(
client, native_script, display_format, derivation_type=derivation_type
session, native_script, display_format, derivation_type=derivation_type
)

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Tuple
import click
from .. import misc, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PROMPT_TYPE = ChoiceType(
@ -42,10 +42,10 @@ def cli() -> None:
@cli.command()
@click.argument("size", type=int)
@with_client
def get_entropy(client: "TrezorClient", size: int) -> str:
@with_session(empty_passphrase=True)
def get_entropy(session: "Session", size: int) -> str:
"""Get random bytes from device."""
return misc.get_entropy(client, size).hex()
return misc.get_entropy(session, size).hex()
@cli.command()
@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str:
)
@click.argument("key")
@click.argument("value")
@with_client
@with_session(empty_passphrase=True)
def encrypt_keyvalue(
client: "TrezorClient",
session: "Session",
address: str,
key: str,
value: str,
@ -75,7 +75,7 @@ def encrypt_keyvalue(
ask_on_encrypt, ask_on_decrypt = prompt
address_n = tools.parse_path(address)
return misc.encrypt_keyvalue(
client,
session,
address_n,
key,
value.encode(),
@ -91,9 +91,9 @@ def encrypt_keyvalue(
)
@click.argument("key")
@click.argument("value")
@with_client
@with_session(empty_passphrase=True)
def decrypt_keyvalue(
client: "TrezorClient",
session: "Session",
address: str,
key: str,
value: str,
@ -112,7 +112,7 @@ def decrypt_keyvalue(
ask_on_encrypt, ask_on_decrypt = prompt
address_n = tools.parse_path(address)
return misc.decrypt_keyvalue(
client,
session,
address_n,
key,
bytes.fromhex(value),

View File

@ -18,13 +18,12 @@ from typing import TYPE_CHECKING, Union
import click
from .. import mapping, messages, protobuf
from ..client import TrezorClient
from ..debuglink import TrezorClientDebugLink
from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max
from ..debuglink import prodtest_t1 as debuglink_prodtest_t1
from ..debuglink import record_screen
from . import with_client
from ..transport.session import Session
from . import with_session
if TYPE_CHECKING:
from . import TrezorConnection
@ -35,51 +34,51 @@ def cli() -> None:
"""Miscellaneous debug features."""
@cli.command()
@click.argument("message_name_or_type")
@click.argument("hex_data")
@click.pass_obj
def send_bytes(
obj: "TrezorConnection", message_name_or_type: str, hex_data: str
) -> None:
"""Send raw bytes to Trezor.
# @cli.command()
# @click.argument("message_name_or_type")
# @click.argument("hex_data")
# @click.pass_obj
# def send_bytes(
# obj: "NewTrezorConnection", message_name_or_type: str, hex_data: str
# ) -> None:
# """Send raw bytes to Trezor.
Message type and message data must be specified separately, due to how message
chunking works on the transport level. Message length is calculated and sent
automatically, and it is currently impossible to explicitly specify invalid length.
# Message type and message data must be specified separately, due to how message
# chunking works on the transport level. Message length is calculated and sent
# automatically, and it is currently impossible to explicitly specify invalid length.
MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum,
in which case the value of that enum is used.
"""
if message_name_or_type.isdigit():
message_type = int(message_name_or_type)
else:
message_type = getattr(messages.MessageType, message_name_or_type)
# MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum,
# in which case the value of that enum is used.
# """
# if message_name_or_type.isdigit():
# message_type = int(message_name_or_type)
# else:
# message_type = getattr(messages.MessageType, message_name_or_type)
if not isinstance(message_type, int):
raise click.ClickException("Invalid message type.")
# if not isinstance(message_type, int):
# raise click.ClickException("Invalid message type.")
try:
message_data = bytes.fromhex(hex_data)
except Exception as e:
raise click.ClickException("Invalid hex data.") from e
# try:
# message_data = bytes.fromhex(hex_data)
# except Exception as e:
# raise click.ClickException("Invalid hex data.") from e
transport = obj.get_transport()
transport.begin_session()
transport.write(message_type, message_data)
# transport = obj.get_transport()
# transport.deprecated_begin_session()
# transport.write(message_type, message_data)
response_type, response_data = transport.read()
transport.end_session()
# response_type, response_data = transport.read()
# transport.deprecated_end_session()
click.echo(f"Response type: {response_type}")
click.echo(f"Response data: {response_data.hex()}")
# click.echo(f"Response type: {response_type}")
# click.echo(f"Response data: {response_data.hex()}")
try:
msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
click.echo("Parsed message:")
click.echo(protobuf.format_message(msg))
except Exception as e:
click.echo(f"Could not parse response: {e}")
# try:
# msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
# click.echo("Parsed message:")
# click.echo(protobuf.format_message(msg))
# except Exception as e:
# click.echo(f"Could not parse response: {e}")
@cli.command()
@ -106,17 +105,17 @@ def record_screen_from_connection(
@cli.command()
@with_client
def prodtest_t1(client: "TrezorClient") -> str:
@with_session(management=True)
def prodtest_t1(session: "Session") -> str:
"""Perform a prodtest on Model One.
Only available on PRODTEST firmware and on T1B1. Formerly named self-test.
"""
return debuglink_prodtest_t1(client)
return debuglink_prodtest_t1(session)
@cli.command()
@with_client
def optiga_set_sec_max(client: "TrezorClient") -> str:
@with_session(management=True)
def optiga_set_sec_max(session: "Session") -> str:
"""Set Optiga's security event counter to maximum."""
return debuglink_optiga_set_sec_max(client)
return debuglink_optiga_set_sec_max(session)

View File

@ -25,11 +25,11 @@ import requests
from .. import debuglink, device, exceptions, messages, ui
from ..tools import format_path
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if t.TYPE_CHECKING:
from ..client import TrezorClient
from ..protobuf import MessageType
from ..transport.session import Session
from . import TrezorConnection
RECOVERY_DEVICE_INPUT_METHOD = {
@ -65,17 +65,18 @@ def cli() -> None:
help="Wipe device in bootloader mode. This also erases the firmware.",
is_flag=True,
)
@with_client
def wipe(client: "TrezorClient", bootloader: bool) -> str:
@with_session(management=True)
def wipe(session: "Session", bootloader: bool) -> str:
"""Reset device to factory defaults and remove all private data."""
features = session.features
if bootloader:
if not client.features.bootloader_mode:
if not features.bootloader_mode:
click.echo("Please switch your device to bootloader mode.")
sys.exit(1)
else:
click.echo("Wiping user data and firmware!")
else:
if client.features.bootloader_mode:
if features.bootloader_mode:
click.echo(
"Your device is in bootloader mode. This operation would also erase firmware."
)
@ -88,7 +89,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str:
click.echo("Wiping user data!")
try:
return device.wipe(client)
return device.wipe(
session
) # TODO decide where the wipe should happen - management or regular session
except exceptions.TrezorFailure as e:
click.echo("Action failed: {} {}".format(*e.args))
sys.exit(3)
@ -104,9 +107,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str:
@click.option("-a", "--academic", is_flag=True)
@click.option("-b", "--needs-backup", is_flag=True)
@click.option("-n", "--no-backup", is_flag=True)
@with_client
@with_session(management=True)
def load(
client: "TrezorClient",
session: "Session",
mnemonic: t.Sequence[str],
pin: str,
passphrase_protection: bool,
@ -137,7 +140,7 @@ def load(
try:
return debuglink.load_device(
client,
session,
mnemonic=list(mnemonic),
pin=pin,
passphrase_protection=passphrase_protection,
@ -172,9 +175,9 @@ def load(
)
@click.option("-d", "--dry-run", is_flag=True)
@click.option("-b", "--unlock-repeated-backup", is_flag=True)
@with_client
@with_session(management=True)
def recover(
client: "TrezorClient",
session: "Session",
words: str,
expand: bool,
pin_protection: bool,
@ -202,7 +205,7 @@ def recover(
type = messages.RecoveryType.UnlockRepeatedBackup
return device.recover(
client,
session,
word_count=int(words),
passphrase_protection=passphrase_protection,
pin_protection=pin_protection,
@ -224,9 +227,9 @@ def recover(
@click.option("-n", "--no-backup", is_flag=True)
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE))
@click.option("-e", "--entropy-check-count", type=click.IntRange(0))
@with_client
@with_session(management=True)
def setup(
client: "TrezorClient",
session: "Session",
strength: int | None,
passphrase_protection: bool,
pin_protection: bool,
@ -244,7 +247,7 @@ def setup(
BT = messages.BackupType
if backup_type is None:
if client.version >= (2, 7, 1):
if session.version >= (2, 7, 1):
# SLIP39 extendable was introduced in 2.7.1
backup_type = BT.Slip39_Single_Extendable
else:
@ -254,10 +257,10 @@ def setup(
if (
backup_type
in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable)
and messages.Capability.Shamir not in client.features.capabilities
and messages.Capability.Shamir not in session.features.capabilities
) or (
backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable)
and messages.Capability.ShamirGroups not in client.features.capabilities
and messages.Capability.ShamirGroups not in session.features.capabilities
):
click.echo(
"WARNING: Your Trezor device does not indicate support for the requested\n"
@ -265,7 +268,7 @@ def setup(
)
resp, path_xpubs = device.reset_entropy_check(
client,
session,
strength=strength,
passphrase_protection=passphrase_protection,
pin_protection=pin_protection,
@ -289,23 +292,21 @@ def setup(
@cli.command()
@click.option("-t", "--group-threshold", type=int)
@click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N")
@with_client
@with_session(management=True)
def backup(
client: "TrezorClient",
session: "Session",
group_threshold: int | None = None,
groups: t.Sequence[tuple[int, int]] = (),
) -> str:
"""Perform device seed backup."""
return device.backup(client, group_threshold, groups)
return device.backup(session, group_threshold, groups)
@cli.command()
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
@with_client
def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType
) -> str:
@with_session(management=True)
def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> str:
"""Secure the device with SD card protection.
When SD card protection is enabled, a randomly generated secret is stored
@ -319,9 +320,9 @@ def sd_protect(
off - Remove SD card secret protection.
refresh - Replace the current SD card secret with a new one.
"""
if client.features.model == "1":
if session.features.model == "1":
raise click.ClickException("Trezor One does not support SD card protection.")
return device.sd_protect(client, operation)
return device.sd_protect(session, operation)
@cli.command()
@ -331,24 +332,24 @@ def reboot_to_bootloader(obj: "TrezorConnection") -> str:
Currently only supported on Trezor Model One.
"""
# avoid using @with_client because it closes the session afterwards,
# avoid using @with_management_session because it closes the session afterwards,
# which triggers double prompt on device
with obj.client_context() as client:
return device.reboot_to_bootloader(client)
return device.reboot_to_bootloader(client.get_management_session())
@cli.command()
@with_client
def tutorial(client: "TrezorClient") -> str:
@with_session(management=True)
def tutorial(session: "Session") -> str:
"""Show on-device tutorial."""
return device.show_device_tutorial(client)
return device.show_device_tutorial(session)
@cli.command()
@with_client
def unlock_bootloader(client: "TrezorClient") -> str:
@with_session(management=True)
def unlock_bootloader(session: "Session") -> str:
"""Unlocks bootloader. Irreversible."""
return device.unlock_bootloader(client)
return device.unlock_bootloader(session)
@cli.command()
@ -359,11 +360,11 @@ def unlock_bootloader(client: "TrezorClient") -> str:
type=int,
help="Dialog expiry in seconds.",
)
@with_client
def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> str:
@with_session(management=True)
def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> str:
"""Show a "Do not disconnect" dialog."""
if enable is False:
return device.set_busy(client, None)
return device.set_busy(session, None)
if expiry is None:
raise click.ClickException("Missing option '-e' / '--expiry'.")
@ -373,7 +374,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) ->
f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer."
)
return device.set_busy(client, expiry * 1000)
return device.set_busy(session, expiry * 1000)
PUBKEY_WHITELIST_URL_TEMPLATE = (
@ -393,9 +394,9 @@ PUBKEY_WHITELIST_URL_TEMPLATE = (
is_flag=True,
help="Do not check intermediate certificates against the whitelist.",
)
@with_client
@with_session(management=True)
def authenticate(
client: "TrezorClient",
session: "Session",
hex_challenge: str | None,
root: t.BinaryIO | None,
raw: bool | None,
@ -420,7 +421,7 @@ def authenticate(
challenge = bytes.fromhex(hex_challenge)
if raw:
msg = device.authenticate(client, challenge)
msg = device.authenticate(session, challenge)
click.echo(f"Challenge: {hex_challenge}")
click.echo(f"Signature of challenge: {msg.signature.hex()}")
@ -468,14 +469,14 @@ def authenticate(
else:
whitelist_json = requests.get(
PUBKEY_WHITELIST_URL_TEMPLATE.format(
model=client.model.internal_name.lower()
model=session.model.internal_name.lower()
)
).json()
whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]]
try:
authentication.authenticate_device(
client, challenge, root_pubkey=root_bytes, whitelist=whitelist
session, challenge, root_pubkey=root_bytes, whitelist=whitelist
)
except authentication.DeviceNotAuthentic:
click.echo("Device is not authentic.")

View File

@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import eos, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from .. import messages
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0"
@ -37,11 +37,11 @@ def cli() -> None:
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
@with_session
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Eos public key in base58 encoding."""
address_n = tools.parse_path(address)
res = eos.get_public_key(client, address_n, show_display)
res = eos.get_public_key(session, address_n, show_display)
return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}"
@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_transaction(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
session: "Session", address: str, file: TextIO, chunkify: bool
) -> "messages.EosSignedTx":
"""Sign EOS transaction."""
tx_json = json.load(file)
address_n = tools.parse_path(address)
return eos.sign_tx(
client,
session,
address_n,
tx_json["transaction"],
tx_json["chain_id"],

View File

@ -26,14 +26,14 @@ import click
from .. import _rlp, definitions, ethereum, tools
from ..messages import EthereumDefinitions
from . import with_client
from . import with_session
if TYPE_CHECKING:
import web3
from eth_typing import ChecksumAddress # noqa: I900
from web3.types import Wei
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0"
@ -268,24 +268,24 @@ def cli(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Ethereum address in hex encoding."""
address_n = tools.parse_path(address)
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
return ethereum.get_address(client, address_n, show_display, network, chunkify)
return ethereum.get_address(session, address_n, show_display, network, chunkify)
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict:
@with_session
def get_public_node(session: "Session", address: str, show_display: bool) -> dict:
"""Get Ethereum public node of given path."""
address_n = tools.parse_path(address)
result = ethereum.get_public_node(client, address_n, show_display=show_display)
result = ethereum.get_public_node(session, address_n, show_display=show_display)
return {
"node": {
"depth": result.node.depth,
@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("to_address")
@click.argument("amount", callback=_amount_to_int)
@with_client
@with_session
def sign_tx(
client: "TrezorClient",
session: "Session",
chain_id: int,
address: str,
amount: int,
@ -400,7 +400,7 @@ def sign_tx(
encoded_network = DEFINITIONS_SOURCE.get_network(chain_id)
address_n = tools.parse_path(address)
from_address = ethereum.get_address(
client, address_n, encoded_network=encoded_network
session, address_n, encoded_network=encoded_network
)
if token:
@ -446,7 +446,7 @@ def sign_tx(
assert max_gas_fee is not None
assert max_priority_fee is not None
sig = ethereum.sign_tx_eip1559(
client,
session,
n=address_n,
nonce=nonce,
gas_limit=gas_limit,
@ -465,7 +465,7 @@ def sign_tx(
gas_price = _get_web3().eth.gas_price
assert gas_price is not None
sig = ethereum.sign_tx(
client,
session,
n=address_n,
tx_type=tx_type,
nonce=nonce,
@ -526,14 +526,14 @@ def sign_tx(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("message")
@with_client
@with_session
def sign_message(
client: "TrezorClient", address: str, message: str, chunkify: bool
session: "Session", address: str, message: str, chunkify: bool
) -> Dict[str, str]:
"""Sign message with Ethereum address."""
address_n = tools.parse_path(address)
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify)
ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify)
output = {
"message": message,
"address": ret.address,
@ -550,9 +550,9 @@ def sign_message(
help="Be compatible with Metamask's signTypedData_v4 implementation",
)
@click.argument("file", type=click.File("r"))
@with_client
@with_session
def sign_typed_data(
client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO
session: "Session", address: str, metamask_v4_compat: bool, file: TextIO
) -> Dict[str, str]:
"""Sign typed data (EIP-712) with Ethereum address.
@ -565,7 +565,7 @@ def sign_typed_data(
defs = EthereumDefinitions(encoded_network=network)
data = json.loads(file.read())
ret = ethereum.sign_typed_data(
client,
session,
address_n,
data,
metamask_v4_compat=metamask_v4_compat,
@ -583,9 +583,9 @@ def sign_typed_data(
@click.argument("address")
@click.argument("signature")
@click.argument("message")
@with_client
@with_session
def verify_message(
client: "TrezorClient",
session: "Session",
address: str,
signature: str,
message: str,
@ -594,7 +594,7 @@ def verify_message(
"""Verify message signed with Ethereum address."""
signature_bytes = ethereum.decode_hex(signature)
return ethereum.verify_message(
client, address, signature_bytes, message, chunkify=chunkify
session, address, signature_bytes, message, chunkify=chunkify
)
@ -602,9 +602,9 @@ def verify_message(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.argument("domain_hash_hex")
@click.argument("message_hash_hex")
@with_client
@with_session
def sign_typed_data_hash(
client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str
session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str
) -> Dict[str, str]:
"""
Sign hash of typed data (EIP-712) with Ethereum address.
@ -618,7 +618,7 @@ def sign_typed_data_hash(
message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
ret = ethereum.sign_typed_data_hash(
client, address_n, domain_hash, message_hash, network
session, address_n, domain_hash, message_hash, network
)
output = {
"domain_hash": domain_hash_hex,

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING
import click
from .. import fido
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
@ -40,10 +40,10 @@ def credentials() -> None:
@credentials.command(name="list")
@with_client
def credentials_list(client: "TrezorClient") -> None:
@with_session(empty_passphrase=True)
def credentials_list(session: "Session") -> None:
"""List all resident credentials on the device."""
creds = fido.list_credentials(client)
creds = fido.list_credentials(session)
for cred in creds:
click.echo("")
click.echo(f"WebAuthn credential at index {cred.index}:")
@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None:
@credentials.command(name="add")
@click.argument("hex_credential_id")
@with_client
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str:
@with_session(empty_passphrase=True)
def credentials_add(session: "Session", hex_credential_id: str) -> str:
"""Add the credential with the given ID as a resident credential.
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
"""
return fido.add_credential(client, bytes.fromhex(hex_credential_id))
return fido.add_credential(session, bytes.fromhex(hex_credential_id))
@credentials.command(name="remove")
@click.option(
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
)
@with_client
def credentials_remove(client: "TrezorClient", index: int) -> str:
@with_session(empty_passphrase=True)
def credentials_remove(session: "Session", index: int) -> str:
"""Remove the resident credential at the given index."""
return fido.remove_credential(client, index)
return fido.remove_credential(session, index)
#
@ -110,19 +110,19 @@ def counter() -> None:
@counter.command(name="set")
@click.argument("counter", type=int)
@with_client
def counter_set(client: "TrezorClient", counter: int) -> str:
@with_session(empty_passphrase=True)
def counter_set(session: "Session", counter: int) -> str:
"""Set FIDO/U2F counter value."""
return fido.set_counter(client, counter)
return fido.set_counter(session, counter)
@counter.command(name="get-next")
@with_client
def counter_get_next(client: "TrezorClient") -> int:
@with_session(empty_passphrase=True)
def counter_get_next(session: "Session") -> int:
"""Get-and-increase value of FIDO/U2F counter.
FIDO counter value cannot be read directly. On each U2F exchange, the counter value
is returned and atomically increased. This command performs the same operation
and returns the counter value.
"""
return fido.get_next_counter(client)
return fido.get_next_counter(session)

View File

@ -37,10 +37,11 @@ import requests
from .. import device, exceptions, firmware, messages, models
from ..firmware import models as fw_models
from ..models import TrezorModel
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
from . import TrezorConnection
MODEL_CHOICE = ChoiceType(
@ -74,9 +75,9 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool:
This is the case from bootloader version 1.8.0, and also holds for firmware version
1.8.0 because that installs the appropriate bootloader.
"""
f = client.features
version = (f.major_version, f.minor_version, f.patch_version)
bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0)
features = client.features
version = client.version
bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0)
return bootloader_onev2
@ -306,25 +307,26 @@ def find_best_firmware_version(
If the specified version is not found, prints the closest available version
(higher than the specified one, if existing).
"""
features = client.features
model = client.model
if bitcoin_only is None:
bitcoin_only = _should_use_bitcoin_only(client.features)
bitcoin_only = _should_use_bitcoin_only(features)
def version_str(version: Iterable[int]) -> str:
return ".".join(map(str, version))
f = client.features
releases = get_all_firmware_releases(client.model, bitcoin_only, beta)
releases = get_all_firmware_releases(model, bitcoin_only, beta)
highest_version = releases[0]["version"]
if version:
want_version = [int(x) for x in version.split(".")]
if len(want_version) != 3:
click.echo("Please use the 'X.Y.Z' version format.")
if want_version[0] != f.major_version:
if want_version[0] != features.major_version:
click.echo(
f"Warning: Trezor {client.model.name} firmware version should be "
f"{f.major_version}.X.Y (requested: {version})"
f"Warning: Trezor {model.name} firmware version should be "
f"{features.major_version}.X.Y (requested: {version})"
)
else:
want_version = highest_version
@ -359,8 +361,8 @@ def find_best_firmware_version(
# to the newer one, in that case update to the minimal
# compatible version first
# Choosing the version key to compare based on (not) being in BL mode
client_version = [f.major_version, f.minor_version, f.patch_version]
if f.bootloader_mode:
client_version = client.version
if features.bootloader_mode:
key_to_compare = "min_bootloader_version"
else:
key_to_compare = "min_firmware_version"
@ -447,11 +449,11 @@ def extract_embedded_fw(
def upload_firmware_into_device(
client: "TrezorClient",
session: "Session",
firmware_data: bytes,
) -> None:
"""Perform the final act of loading the firmware into Trezor."""
f = client.features
f = session.features
try:
if f.major_version == 1 and f.firmware_present is not False:
# Trezor One does not send ButtonRequest
@ -461,7 +463,7 @@ def upload_firmware_into_device(
with click.progressbar(
label="Uploading", length=len(firmware_data), show_eta=False
) as bar:
firmware.update(client, firmware_data, bar.update)
firmware.update(session, firmware_data, bar.update)
except exceptions.Cancelled:
click.echo("Update aborted on device.")
except exceptions.TrezorException as e:
@ -654,6 +656,7 @@ def update(
against data.trezor.io information, if available.
"""
with obj.client_context() as client:
management_session = client.get_management_session()
if sum(bool(x) for x in (filename, url, version)) > 1:
click.echo("You can use only one of: filename, url, version.")
sys.exit(1)
@ -709,7 +712,7 @@ def update(
if _is_strict_update(client, firmware_data):
header_size = _get_firmware_header_size(firmware_data)
device.reboot_to_bootloader(
client,
management_session,
boot_command=messages.BootCommand.INSTALL_UPGRADE,
firmware_header=firmware_data[:header_size],
language_data=language_data,
@ -719,7 +722,7 @@ def update(
click.echo(
"WARNING: Seamless installation not possible, language data will not be uploaded."
)
device.reboot_to_bootloader(client)
device.reboot_to_bootloader(management_session)
click.echo("Waiting for bootloader...")
while True:
@ -735,13 +738,15 @@ def update(
click.echo("Please switch your device to bootloader mode.")
sys.exit(1)
upload_firmware_into_device(client=client, firmware_data=firmware_data)
upload_firmware_into_device(
session=client.get_management_session(), firmware_data=firmware_data
)
@cli.command()
@click.argument("hex_challenge", required=False)
@with_client
def get_hash(client: "TrezorClient", hex_challenge: Optional[str]) -> str:
@with_session(management=True)
def get_hash(session: "Session", hex_challenge: Optional[str]) -> str:
"""Get a hash of the installed firmware combined with the optional challenge."""
challenge = bytes.fromhex(hex_challenge) if hex_challenge else None
return firmware.get_hash(client, challenge).hex()
return firmware.get_hash(session, challenge).hex()

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Dict
import click
from .. import messages, monero, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h"
@ -42,9 +42,9 @@ def cli() -> None:
default=messages.MoneroNetworkType.MAINNET,
)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
show_display: bool,
network_type: messages.MoneroNetworkType,
@ -52,7 +52,7 @@ def get_address(
) -> bytes:
"""Get Monero address for specified path."""
address_n = tools.parse_path(address)
return monero.get_address(client, address_n, show_display, network_type, chunkify)
return monero.get_address(session, address_n, show_display, network_type, chunkify)
@cli.command()
@ -63,13 +63,13 @@ def get_address(
type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}),
default=messages.MoneroNetworkType.MAINNET,
)
@with_client
@with_session
def get_watch_key(
client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType
session: "Session", address: str, network_type: messages.MoneroNetworkType
) -> Dict[str, str]:
"""Get Monero watch key for specified path."""
address_n = tools.parse_path(address)
res = monero.get_watch_key(client, address_n, network_type)
res = monero.get_watch_key(session, address_n, network_type)
# TODO: could be made required in MoneroWatchKey
assert res.address is not None
assert res.watch_key is not None

View File

@ -21,10 +21,10 @@ import click
import requests
from .. import nem, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h"
@ -39,9 +39,9 @@ def cli() -> None:
@click.option("-N", "--network", type=int, default=0x68)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
network: int,
show_display: bool,
@ -49,7 +49,7 @@ def get_address(
) -> str:
"""Get NEM address for specified path."""
address_n = tools.parse_path(address)
return nem.get_address(client, address_n, network, show_display, chunkify)
return nem.get_address(session, address_n, network, show_display, chunkify)
@cli.command()
@ -58,9 +58,9 @@ def get_address(
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-b", "--broadcast", help="NIS to announce transaction to")
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_tx(
client: "TrezorClient",
session: "Session",
address: str,
file: TextIO,
broadcast: Optional[str],
@ -71,7 +71,7 @@ def sign_tx(
Transaction file is expected in the NIS (RequestPrepareAnnounce) format.
"""
address_n = tools.parse_path(address)
transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify)
payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()}

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import ripple, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0"
@ -37,13 +37,13 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Ripple address"""
address_n = tools.parse_path(address)
return ripple.get_address(client, address_n, show_display, chunkify)
return ripple.get_address(session, address_n, show_display, chunkify)
@cli.command()
@ -51,13 +51,13 @@ def get_address(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None:
@with_session
def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None:
"""Sign Ripple transaction"""
address_n = tools.parse_path(address)
msg = ripple.create_sign_tx_msg(json.load(file))
result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify)
result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify)
click.echo("Signature:")
click.echo(result.signature.hex())
click.echo()

View File

@ -24,10 +24,11 @@ import click
import requests
from .. import device, messages, toif
from . import AliasedGroup, ChoiceType, with_client
from ..transport.session import Session
from . import AliasedGroup, ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
pass
try:
from PIL import Image
@ -180,18 +181,18 @@ def cli() -> None:
@cli.command()
@click.option("-r", "--remove", is_flag=True, hidden=True)
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
@with_client
def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str:
@with_session(management=True)
def pin(session: "Session", enable: Optional[bool], remove: bool) -> str:
"""Set, change or remove PIN."""
# Remove argument is there for backwards compatibility
return device.change_pin(client, remove=_should_remove(enable, remove))
return device.change_pin(session, remove=_should_remove(enable, remove))
@cli.command()
@click.option("-r", "--remove", is_flag=True, hidden=True)
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
@with_client
def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str:
@with_session(management=True)
def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> str:
"""Set or remove the wipe code.
The wipe code functions as a "self-destruct PIN". If the wipe code is ever
@ -199,32 +200,32 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> s
removed and the device will be reset to factory defaults.
"""
# Remove argument is there for backwards compatibility
return device.change_wipe_code(client, remove=_should_remove(enable, remove))
return device.change_wipe_code(session, remove=_should_remove(enable, remove))
@cli.command()
# keep the deprecated -l/--label option, make it do nothing
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.argument("label")
@with_client
def label(client: "TrezorClient", label: str) -> str:
@with_session(management=True)
def label(session: "Session", label: str) -> str:
"""Set new device label."""
return device.apply_settings(client, label=label)
return device.apply_settings(session, label=label)
@cli.command()
@with_client
def brightness(client: "TrezorClient") -> str:
@with_session(management=True)
def brightness(session: "Session") -> str:
"""Set display brightness."""
return device.set_brightness(client)
return device.set_brightness(session)
@cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client
def haptic_feedback(client: "TrezorClient", enable: bool) -> str:
@with_session(management=True)
def haptic_feedback(session: "Session", enable: bool) -> str:
"""Enable or disable haptic feedback."""
return device.apply_settings(client, haptic_feedback=enable)
return device.apply_settings(session, haptic_feedback=enable)
@cli.command()
@ -233,9 +234,9 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> str:
"-r", "--remove", is_flag=True, default=False, help="Switch back to english."
)
@click.option("-d/-D", "--display/--no-display", default=None)
@with_client
@with_session(management=True)
def language(
client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None
session: "Session", path_or_url: str | None, remove: bool, display: bool | None
) -> str:
"""Set new language with translations."""
if remove != (path_or_url is None):
@ -260,29 +261,29 @@ def language(
f"Failed to load translations from {path_or_url}"
) from None
return device.change_language(
client, language_data=language_data, show_display=display
session, language_data=language_data, show_display=display
)
@cli.command()
@click.argument("rotation", type=ChoiceType(ROTATION))
@with_client
def display_rotation(client: "TrezorClient", rotation: messages.DisplayRotation) -> str:
@with_session(management=True)
def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> str:
"""Set display rotation.
Configure display rotation for Trezor Model T. The options are
north, east, south or west.
"""
return device.apply_settings(client, display_rotation=rotation)
return device.apply_settings(session, display_rotation=rotation)
@cli.command()
@click.argument("delay", type=str)
@with_client
def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
@with_session(management=True)
def auto_lock_delay(session: "Session", delay: str) -> str:
"""Set auto-lock delay (in seconds)."""
if not client.features.pin_protection:
if not session.features.pin_protection:
raise click.ClickException("Set up a PIN first")
value, unit = delay[:-1], delay[-1:]
@ -291,13 +292,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
seconds = float(value) * units[unit]
else:
seconds = float(delay) # assume seconds if no unit is specified
return device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000))
return device.apply_settings(session, auto_lock_delay_ms=int(seconds * 1000))
@cli.command()
@click.argument("flags")
@with_client
def flags(client: "TrezorClient", flags: str) -> str:
@with_session(management=True)
def flags(session: "Session", flags: str) -> str:
"""Set device flags."""
if flags.lower().startswith("0b"):
flags_int = int(flags, 2)
@ -305,7 +306,7 @@ def flags(client: "TrezorClient", flags: str) -> str:
flags_int = int(flags, 16)
else:
flags_int = int(flags)
return device.apply_flags(client, flags=flags_int)
return device.apply_flags(session, flags=flags_int)
@cli.command()
@ -314,8 +315,8 @@ def flags(client: "TrezorClient", flags: str) -> str:
"-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False
)
@click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)")
@with_client
def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
@with_session(management=True)
def homescreen(session: "Session", filename: str, quality: int) -> str:
"""Set new homescreen.
To revert to default homescreen, use 'trezorctl set homescreen default'
@ -327,39 +328,39 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
if not path.exists() or not path.is_file():
raise click.ClickException("Cannot open file")
if client.features.model == "1":
if session.features.model == "1":
img = image_to_t1(path)
else:
if client.features.homescreen_format == messages.HomescreenFormat.Jpeg:
if session.features.homescreen_format == messages.HomescreenFormat.Jpeg:
width = (
client.features.homescreen_width
if client.features.homescreen_width is not None
session.features.homescreen_width
if session.features.homescreen_width is not None
else 240
)
height = (
client.features.homescreen_height
if client.features.homescreen_height is not None
session.features.homescreen_height
if session.features.homescreen_height is not None
else 240
)
img = image_to_jpeg(path, width, height, quality)
elif client.features.homescreen_format == messages.HomescreenFormat.ToiG:
width = client.features.homescreen_width
height = client.features.homescreen_height
elif session.features.homescreen_format == messages.HomescreenFormat.ToiG:
width = session.features.homescreen_width
height = session.features.homescreen_height
if width is None or height is None:
raise click.ClickException("Device did not report homescreen size.")
img = image_to_toif(path, width, height, True)
elif (
client.features.homescreen_format == messages.HomescreenFormat.Toif
or client.features.homescreen_format is None
session.features.homescreen_format == messages.HomescreenFormat.Toif
or session.features.homescreen_format is None
):
width = (
client.features.homescreen_width
if client.features.homescreen_width is not None
session.features.homescreen_width
if session.features.homescreen_width is not None
else 144
)
height = (
client.features.homescreen_height
if client.features.homescreen_height is not None
session.features.homescreen_height
if session.features.homescreen_height is not None
else 144
)
img = image_to_toif(path, width, height, False)
@ -369,7 +370,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
"Unknown image format requested by the device."
)
return device.apply_settings(client, homescreen=img)
return device.apply_settings(session, homescreen=img)
@cli.command()
@ -377,9 +378,9 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
"--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.'
)
@click.argument("level", type=ChoiceType(SAFETY_LEVELS))
@with_client
@with_session(management=True)
def safety_checks(
client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel
session: "Session", always: bool, level: messages.SafetyCheckLevel
) -> str:
"""Set safety check level.
@ -392,18 +393,18 @@ def safety_checks(
"""
if always and level == messages.SafetyCheckLevel.PromptTemporarily:
level = messages.SafetyCheckLevel.PromptAlways
return device.apply_settings(client, safety_checks=level)
return device.apply_settings(session, safety_checks=level)
@cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client
def experimental_features(client: "TrezorClient", enable: bool) -> str:
@with_session(management=True)
def experimental_features(session: "Session", enable: bool) -> str:
"""Enable or disable experimental message types.
This is a developer feature. Use with caution.
"""
return device.apply_settings(client, experimental_features=enable)
return device.apply_settings(session, experimental_features=enable)
#
@ -426,25 +427,25 @@ passphrase = cast(AliasedGroup, passphrase_main)
@passphrase.command(name="on")
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
@with_client
def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> str:
@with_session(management=True)
def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> str:
"""Enable passphrase."""
if client.features.passphrase_protection is not True:
if session.features.passphrase_protection is not True:
use_passphrase = True
else:
use_passphrase = None
return device.apply_settings(
client,
session,
use_passphrase=use_passphrase,
passphrase_always_on_device=force_on_device,
)
@passphrase.command(name="off")
@with_client
def passphrase_off(client: "TrezorClient") -> str:
@with_session(management=True)
def passphrase_off(session: "Session") -> str:
"""Disable passphrase."""
return device.apply_settings(client, use_passphrase=False)
return device.apply_settings(session, use_passphrase=False)
# Registering the aliases for backwards compatibility
@ -457,10 +458,10 @@ passphrase.aliases = {
@passphrase.command(name="hide")
@click.argument("hide", type=ChoiceType({"on": True, "off": False}))
@with_client
def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> str:
@with_session(management=True)
def hide_passphrase_from_host(session: "Session", hide: bool) -> str:
"""Enable or disable hiding passphrase coming from host.
This is a developer feature. Use with caution.
"""
return device.apply_settings(client, hide_passphrase_from_host=hide)
return device.apply_settings(session, hide_passphrase_from_host=hide)

View File

@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
import click
from .. import messages, solana, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h"
DEFAULT_PATH = "m/44h/501h/0h/0h"
@ -21,40 +21,40 @@ def cli() -> None:
@cli.command()
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session
def get_public_key(
client: "TrezorClient",
session: "Session",
address: str,
show_display: bool,
) -> messages.SolanaPublicKey:
"""Get Solana public key."""
address_n = tools.parse_path(address)
return solana.get_public_key(client, address_n, show_display)
return solana.get_public_key(session, address_n, show_display)
@cli.command()
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
show_display: bool,
chunkify: bool,
) -> messages.SolanaAddress:
"""Get Solana address."""
address_n = tools.parse_path(address)
return solana.get_address(client, address_n, show_display, chunkify)
return solana.get_address(session, address_n, show_display, chunkify)
@cli.command()
@click.argument("serialized_tx", type=str)
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-a", "--additional-information-file", type=click.File("r"))
@with_client
@with_session
def sign_tx(
client: "TrezorClient",
session: "Session",
address: str,
serialized_tx: str,
additional_information_file: Optional[TextIO],
@ -78,7 +78,7 @@ def sign_tx(
)
return solana.sign_tx(
client,
session,
address_n,
bytes.fromhex(serialized_tx),
additional_information,

View File

@ -21,10 +21,10 @@ from typing import TYPE_CHECKING
import click
from .. import stellar, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
try:
from stellar_sdk import (
@ -52,13 +52,13 @@ def cli() -> None:
)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Stellar public address."""
address_n = tools.parse_path(address)
return stellar.get_address(client, address_n, show_display, chunkify)
return stellar.get_address(session, address_n, show_display, chunkify)
@cli.command()
@ -77,9 +77,9 @@ def get_address(
help="Network passphrase (blank for public network).",
)
@click.argument("b64envelope")
@with_client
@with_session
def sign_transaction(
client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str
session: "Session", b64envelope: str, address: str, network_passphrase: str
) -> bytes:
"""Sign a base64-encoded transaction envelope.
@ -109,6 +109,6 @@ def sign_transaction(
address_n = tools.parse_path(address)
tx, operations = stellar.from_envelope(envelope)
resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase)
resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase)
return base64.b64encode(resp.signature)

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import messages, protobuf, tezos, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h"
@ -37,23 +37,23 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Tezos address for specified path."""
address_n = tools.parse_path(address)
return tezos.get_address(client, address_n, show_display, chunkify)
return tezos.get_address(session, address_n, show_display, chunkify)
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
@with_session
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Tezos public key."""
address_n = tools.parse_path(address)
return tezos.get_public_key(client, address_n, show_display)
return tezos.get_public_key(session, address_n, show_display)
@cli.command()
@ -61,11 +61,11 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_tx(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
session: "Session", address: str, file: TextIO, chunkify: bool
) -> messages.TezosSignedTx:
"""Sign Tezos transaction."""
address_n = tools.parse_path(address)
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))
return tezos.sign_tx(client, address_n, msg, chunkify=chunkify)
return tezos.sign_tx(session, address_n, msg, chunkify=chunkify)

View File

@ -24,9 +24,12 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca
import click
from .. import __version__, log, messages, protobuf, ui
from ..client import TrezorClient
from .. import __version__, log, messages, protobuf
from ..client import ProtocolVersion, TrezorClient
from ..transport import DeviceIsBusy, enumerate_devices
from ..transport.session import Session
from ..transport.thp import channel_database
from ..transport.thp.channel_database import get_channel_db
from ..transport.udp import UdpTransport
from . import (
AliasedGroup,
@ -50,6 +53,7 @@ from . import (
stellar,
tezos,
with_client,
with_session,
)
F = TypeVar("F", bound=Callable)
@ -193,6 +197,13 @@ def configure_logging(verbose: int) -> None:
"--record",
help="Record screen changes into a specified directory.",
)
@click.option(
"-n",
"--no-store",
is_flag=True,
help="Do not store channels data between commands.",
default=False,
)
@click.version_option(version=__version__)
@click.pass_context
def cli_main(
@ -204,9 +215,10 @@ def cli_main(
script: bool,
session_id: Optional[str],
record: Optional[str],
no_store: bool,
) -> None:
configure_logging(verbose)
channel_database.set_channel_database(should_not_store=no_store)
bytes_session_id: Optional[bytes] = None
if session_id is not None:
try:
@ -214,6 +226,7 @@ def cli_main(
except ValueError:
raise click.ClickException(f"Not a valid session id: {session_id}")
# ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script)
ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script)
# Optionally record the screen into a specified directory.
@ -285,18 +298,23 @@ def format_device_name(features: messages.Features) -> str:
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
"""List connected Trezor devices."""
if no_resolve:
return enumerate_devices()
for d in enumerate_devices():
print(d.get_path())
return
from . import get_client
for transport in enumerate_devices():
try:
client = TrezorClient(transport, ui=ui.ClickUI())
client = get_client(transport)
description = format_device_name(client.features)
client.end_session()
if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
get_channel_db().save_channel(client.protocol)
except DeviceIsBusy:
description = "Device is in use by another process"
except Exception:
description = "Failed to read details"
click.echo(f"{transport} - {description}")
except Exception as e:
description = "Failed to read details " + str(type(e))
click.echo(f"{transport.get_path()} - {description}")
return None
@ -314,15 +332,19 @@ def version() -> str:
@cli.command()
@click.argument("message")
@click.option("-b", "--button-protection", is_flag=True)
@with_client
def ping(client: "TrezorClient", message: str, button_protection: bool) -> str:
@with_session(empty_passphrase=True)
def ping(session: "Session", message: str, button_protection: bool) -> str:
"""Send ping message."""
return client.ping(message, button_protection=button_protection)
# TODO return short-circuit from old client for old Trezors
return session.ping(message, button_protection)
@cli.command()
@click.pass_obj
def get_session(obj: TrezorConnection) -> str:
def get_session(
obj: TrezorConnection, passphrase: str = "", derive_cardano: bool = False
) -> str:
"""Get a session ID for subsequent commands.
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
@ -336,23 +358,44 @@ def get_session(obj: TrezorConnection) -> str:
obj.session_id = None
with obj.client_context() as client:
if client.features.model == "1" and client.version < (1, 9, 0):
raise click.ClickException(
"Upgrade your firmware to enable session support."
)
client.ensure_unlocked()
if client.session_id is None:
# client.ensure_unlocked()
session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
)
if session.id is None:
raise click.ClickException("Passphrase not enabled or firmware too old.")
else:
return client.session_id.hex()
return session.id.hex()
@cli.command()
@with_client
def clear_session(client: "TrezorClient") -> None:
@with_session(must_resume=True, empty_passphrase=True)
def clear_session(session: "Session") -> None:
"""Clear session (remove cached PIN, passphrase, etc.)."""
return client.clear_session()
if session is None:
click.echo("Cannot clear session as it was not properly resumed.")
return
session.call(messages.LockDevice())
session.end()
# TODO different behaviour than main, not sure if ok
@cli.command()
def delete_channels() -> None:
"""
Delete cached channels.
Do not use together with the `-n` (`--no-store`) flag,
as the JSON database will not be deleted in that case.
"""
get_channel_db().clear_stored_channels()
click.echo("Deleted stored channels")
@cli.command()

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
# Copyright (C) 2012-2024 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
@ -21,47 +21,44 @@ import logging
import re
import textwrap
import time
import typing as t
from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime
from enum import Enum, IntEnum, auto
from itertools import zip_longest
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
Sequence,
Tuple,
Union,
)
from mnemonic import Mnemonic
from . import mapping, messages, models, protobuf
from .client import TrezorClient
from .exceptions import TrezorFailure
from . import btc, mapping, messages, models, protobuf
from .client import (
MAX_PASSPHRASE_LENGTH,
MAX_PIN_LENGTH,
PASSPHRASE_ON_DEVICE,
TrezorClient,
)
from .exceptions import Cancelled, PinException, TrezorFailure
from .log import DUMP_BYTES
from .messages import DebugWaitType
from .tools import expect
from .messages import Capability, DebugWaitType
from .tools import expect, parse_path
from .transport.session import Session, SessionV1
from .transport.thp.protocol_v1 import ProtocolV1
if TYPE_CHECKING:
if t.TYPE_CHECKING:
from typing_extensions import Protocol
from .messages import PinMatrixRequestType
from .transport import Transport
ExpectedMessage = Union[
protobuf.MessageType, type[protobuf.MessageType], "MessageFilter"
ExpectedMessage = t.Union[
protobuf.MessageType, t.Type[protobuf.MessageType], "MessageFilter"
]
AnyDict = Dict[str, Any]
AnyDict = t.Dict[str, t.Any]
class InputFunc(Protocol):
def __call__(
self,
hold_ms: int | None = None,
@ -70,6 +67,7 @@ if TYPE_CHECKING:
EXPECTED_RESPONSES_CONTEXT_LINES = 3
PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0")
LOG = logging.getLogger(__name__)
@ -104,11 +102,13 @@ class UnstructuredJSONReader:
except json.JSONDecodeError:
self.dict = {}
def top_level_value(self, key: str) -> Any:
def top_level_value(self, key: str) -> t.Any:
return self.dict.get(key)
def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]:
def recursively_find(data: Any) -> Iterator[Any]:
def find_objects_with_key_and_value(
self, key: str, value: t.Any
) -> list["AnyDict"]:
def recursively_find(data: t.Any) -> t.Iterator[t.Any]:
if isinstance(data, dict):
if data.get(key) == value:
yield data
@ -121,7 +121,7 @@ class UnstructuredJSONReader:
return list(recursively_find(self.dict))
def find_unique_object_with_key_and_value(
self, key: str, value: Any
self, key: str, value: t.Any
) -> AnyDict | None:
objects = self.find_objects_with_key_and_value(key, value)
if not objects:
@ -129,8 +129,10 @@ class UnstructuredJSONReader:
assert len(objects) == 1
return objects[0]
def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]:
def recursively_find(data: Any) -> Iterator[Any]:
def find_values_by_key(
self, key: str, only_type: type | None = None
) -> list[t.Any]:
def recursively_find(data: t.Any) -> t.Iterator[t.Any]:
if isinstance(data, dict):
if key in data:
yield data[key]
@ -148,8 +150,8 @@ class UnstructuredJSONReader:
return values
def find_unique_value_by_key(
self, key: str, default: Any, only_type: type | None = None
) -> Any:
self, key: str, default: t.Any, only_type: type | None = None
) -> t.Any:
values = self.find_values_by_key(key, only_type=only_type)
if not values:
return default
@ -160,7 +162,7 @@ class UnstructuredJSONReader:
class LayoutContent(UnstructuredJSONReader):
"""Contains helper functions to extract specific parts of the layout."""
def __init__(self, json_tokens: Sequence[str]) -> None:
def __init__(self, json_tokens: t.Sequence[str]) -> None:
json_str = "".join(json_tokens)
super().__init__(json_str)
@ -422,11 +424,13 @@ def _make_input_func(
class DebugLink:
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
self.transport = transport
self.allow_interactions = auto_interact
self.mapping = mapping.DEFAULT_MAPPING
self.protocol = ProtocolV1(self.transport, self.mapping)
# To be set by TrezorClientDebugLink (is not known during creation time)
self.model: models.TrezorModel | None = None
self.version: tuple[int, int, int] = (0, 0, 0)
@ -479,10 +483,16 @@ class DebugLink:
self.screen_text_file = file_path
def open(self) -> None:
self.transport.begin_session()
self.transport.open()
# raise NotImplementedError
# TODO is this needed?
# self.transport.deprecated_begin_session()
def close(self) -> None:
self.transport.end_session()
pass
# raise NotImplementedError
# TODO is this needed?
# self.transport.deprecated_end_session()
def _write(self, msg: protobuf.MessageType) -> None:
if self.waiting_for_layout_change:
@ -499,15 +509,10 @@ class DebugLink:
DUMP_BYTES,
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
self.transport.write(msg_type, msg_bytes)
self.protocol.write(msg)
def _read(self) -> protobuf.MessageType:
ret_type, ret_bytes = self.transport.read()
LOG.log(
DUMP_BYTES,
f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}",
)
msg = self.mapping.decode(ret_type, ret_bytes)
msg = self.protocol.read()
# Collapse tokens to make log use less lines.
msg_for_log = msg
@ -521,18 +526,27 @@ class DebugLink:
)
return msg
def _call(self, msg: protobuf.MessageType) -> Any:
def _call(self, msg: protobuf.MessageType) -> t.Any:
self._write(msg)
return self._read()
def state(self, wait_type: DebugWaitType | None = None) -> messages.DebugLinkState:
def state(
self,
wait_type: DebugWaitType | None = None,
thp_channel_id: bytes | None = None,
) -> messages.DebugLinkState:
if wait_type is None:
wait_type = (
DebugWaitType.CURRENT_LAYOUT
if self.has_global_layout
else DebugWaitType.IMMEDIATE
)
result = self._call(messages.DebugLinkGetState(wait_layout=wait_type))
result = self._call(
messages.DebugLinkGetState(
wait_layout=wait_type,
thp_channel_id=thp_channel_id,
)
)
while not isinstance(result, (messages.Failure, messages.DebugLinkState)):
result = self._read()
if isinstance(result, messages.Failure):
@ -544,7 +558,7 @@ class DebugLink:
def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent:
# Next layout change will be caused by external event
# (e.g. device being auto-locked or as a result of device_handler.run(xxx))
# (e.g. device being auto-locked or as a result of device_handler.run_with_session(xxx))
# and not by our debug actions/decisions.
# Resetting the debug state so we wait for the next layout change
# (and do not return the current state).
@ -560,7 +574,7 @@ class DebugLink:
return LayoutContent(obj.tokens)
@contextmanager
def wait_for_layout_change(self) -> Iterator[LayoutContent]:
def wait_for_layout_change(self) -> t.Iterator[LayoutContent]:
# set up a dummy layout content object to be yielded
layout_content = LayoutContent(
["DUMMY CONTENT, WAIT UNTIL THE END OF THE BLOCK :("]
@ -622,7 +636,7 @@ class DebugLink:
return "".join([str(matrix.index(p) + 1) for p in pin])
def read_recovery_word(self) -> Tuple[str | None, int | None]:
def read_recovery_word(self) -> t.Tuple[str | None, int | None]:
state = self.state()
return (state.recovery_fake_word, state.recovery_word_pos)
@ -700,7 +714,7 @@ class DebugLink:
def click(
self,
click: Tuple[int, int],
click: t.Tuple[int, int],
hold_ms: int | None = None,
wait: bool | None = None,
) -> LayoutContent:
@ -862,10 +876,10 @@ class DebugUI:
self.clear()
def clear(self) -> None:
self.pins: Iterator[str] | None = None
self.pins: t.Iterator[str] | None = None
self.passphrase = ""
self.input_flow: Union[
Generator[None, messages.ButtonRequest, None], object, None
self.input_flow: t.Union[
t.Generator[None, messages.ButtonRequest, None], object, None
] = None
def _default_input_flow(self, br: messages.ButtonRequest) -> None:
@ -896,7 +910,7 @@ class DebugUI:
raise AssertionError("input flow ended prematurely")
else:
try:
assert isinstance(self.input_flow, Generator)
assert isinstance(self.input_flow, t.Generator)
self.input_flow.send(br)
except StopIteration:
self.input_flow = self.INPUT_FLOW_DONE
@ -918,12 +932,15 @@ class DebugUI:
class MessageFilter:
def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None:
def __init__(
self, message_type: t.Type[protobuf.MessageType], **fields: t.Any
) -> None:
self.message_type = message_type
self.fields: Dict[str, Any] = {}
self.fields: t.Dict[str, t.Any] = {}
self.update_fields(**fields)
def update_fields(self, **fields: Any) -> "MessageFilter":
def update_fields(self, **fields: t.Any) -> "MessageFilter":
for name, value in fields.items():
try:
self.fields[name] = self.from_message_or_type(value)
@ -971,7 +988,7 @@ class MessageFilter:
return True
def to_string(self, maxwidth: int = 80) -> str:
fields: list[Tuple[str, str]] = []
fields: list[t.Tuple[str, str]] = []
for field in self.message_type.FIELDS.values():
if field.name not in self.fields:
continue
@ -1001,7 +1018,8 @@ class MessageFilter:
class MessageFilterGenerator:
def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]:
def __getattr__(self, key: str) -> t.Callable[..., "MessageFilter"]:
message_type = getattr(messages, key)
return MessageFilter(message_type).update_fields
@ -1009,6 +1027,245 @@ class MessageFilterGenerator:
message_filters = MessageFilterGenerator()
class SessionDebugWrapper(Session):
def __init__(self, session: Session) -> None:
self._session = session
self.reset_debug_features()
if isinstance(session, SessionDebugWrapper):
raise Exception("Cannot wrap already wrapped session!")
@property
def protocol_version(self) -> int:
return self.client.protocol_version
@property
def client(self) -> TrezorClientDebugLink:
assert isinstance(self._session.client, TrezorClientDebugLink)
return self._session.client
@property
def id(self) -> bytes:
return self._session.id
def _write(self, msg: t.Any) -> None:
print("writing message:", msg.__class__.__name__)
self._session._write(self._filter_message(msg))
def _read(self) -> t.Any:
resp = self._filter_message(self._session._read())
print("reading message:", resp.__class__.__name__)
if self.actual_responses is not None:
self.actual_responses.append(resp)
return resp
def set_expected_responses(
self,
expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]],
) -> None:
"""Set a sequence of expected responses to session calls.
Within a given with-block, the list of received responses from device must
match the list of expected responses, otherwise an ``AssertionError`` is raised.
If an expected response is given a field value other than ``None``, that field value
must exactly match the received field value. If a given field is ``None``
(or unspecified) in the expected response, the received field value is not
checked.
Each expected response can also be a tuple ``(bool, message)``. In that case, the
expected response is only evaluated if the first field is ``True``.
This is useful for differentiating sequences between Trezor models:
>>> trezor_one = session.features.model == "1"
>>> session.set_expected_responses([
>>> messages.ButtonRequest(code=ConfirmOutput),
>>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)),
>>> messages.Success(),
>>> ])
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
# make sure all items are (bool, message) tuples
expected_with_validity = (
e if isinstance(e, tuple) else (True, e) for e in expected
)
# only apply those items that are (True, message)
self.expected_responses = [
MessageFilter.from_message_or_type(expected)
for valid, expected in expected_with_validity
if valid
]
self.actual_responses = []
def lock(self, *, _refresh_features: bool = True) -> None:
"""Lock the device.
If the device does not have a PIN configured, this will do nothing.
Otherwise, a lock screen will be shown and the device will prompt for PIN
before further actions.
This call does _not_ invalidate passphrase cache. If passphrase is in use,
the device will not prompt for it after unlocking.
To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate
passphrase cache, use `clear_session()`.
"""
# TODO update the documentation above
# Private argument _refresh_features can be used internally to avoid
# refreshing in cases where we will refresh soon anyway. This is used
# in TrezorClient.clear_session()
self.call(messages.LockDevice())
if _refresh_features:
self.refresh_features()
def cancel(self) -> None:
self._write(messages.Cancel())
def ensure_unlocked(self) -> None:
btc.get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
self.refresh_features()
def set_filter(
self,
message_type: t.Type[protobuf.MessageType],
callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
) -> None:
"""Configure a filter function for a specified message type.
The `callback` must be a function that accepts a protobuf message, and returns
a (possibly modified) protobuf message of the same type. Whenever a message
is sent or received that matches `message_type`, `callback` is invoked on the
message and its result is substituted for the original.
Useful for test scenarios with an active malicious actor on the wire.
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
self.filters[message_type] = callback
def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
message_type = msg.__class__
callback = self.filters.get(message_type)
if callable(callback):
return callback(deepcopy(msg))
else:
return msg
def reset_debug_features(self) -> None:
"""Prepare the debugging session for a new testcase.
Clears all debugging state that might have been modified by a testcase.
"""
self.in_with_statement = False
self.expected_responses: list[MessageFilter] | None = None
self.actual_responses: list[protobuf.MessageType] | None = None
self.filters: t.Dict[
t.Type[protobuf.MessageType],
t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
] = {}
self.button_callback = self.client.button_callback
self.pin_callback = self.client.pin_callback
self.passphrase_callback = self._session.passphrase_callback
self.passphrase = self._session.passphrase
def __enter__(self) -> "SessionDebugWrapper":
# For usage in with/expected_responses
if self.in_with_statement:
raise RuntimeError("Do not nest!")
self.in_with_statement = True
return self
def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# copy expected/actual responses before clearing them
expected_responses = self.expected_responses
actual_responses = self.actual_responses
# grab a copy of the inputflow generator to raise an exception through it
if isinstance(self.client.ui, DebugUI):
input_flow = self.client.ui.input_flow
else:
input_flow = None
self.reset_debug_features()
if exc_type is None:
# If no other exception was raised, evaluate missed responses
# (raises AssertionError on mismatch)
self._verify_responses(expected_responses, actual_responses)
if isinstance(input_flow, t.Generator):
# Ensure that the input flow is exhausted
try:
input_flow.throw(
AssertionError("input flow continues past end of test")
)
except StopIteration:
pass
elif isinstance(input_flow, t.Generator):
# Propagate the exception through the input flow, so that we see in
# traceback where it is stuck.
input_flow.throw(exc_type, value, traceback)
@classmethod
def _verify_responses(
cls,
expected: list[MessageFilter] | None,
actual: list[protobuf.MessageType] | None,
) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
if expected is None and actual is None:
return
assert expected is not None
assert actual is not None
for i, (exp, act) in enumerate(zip_longest(expected, actual)):
if exp is None:
output = cls._expectation_lines(expected, i)
output.append("No more messages were expected, but we got:")
for resp in actual[i:]:
output.append(
textwrap.indent(protobuf.format_message(resp), " ")
)
raise AssertionError("\n".join(output))
if act is None:
output = cls._expectation_lines(expected, i)
output.append("This and the following message was not received.")
raise AssertionError("\n".join(output))
if not exp.match(act):
output = cls._expectation_lines(expected, i)
output.append("Actually received:")
output.append(textwrap.indent(protobuf.format_message(act), " "))
raise AssertionError("\n".join(output))
@staticmethod
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
output: list[str] = []
output.append("Expected responses:")
if start_at > 0:
output.append(f" (...{start_at} previous responses omitted)")
for i in range(start_at, stop_at):
exp = expected[i]
prefix = " " if i != current else ">>> "
output.append(textwrap.indent(exp.to_string(), prefix))
if stop_at < len(expected):
omitted = len(expected) - stop_at
output.append(f" (...{omitted} following responses omitted)")
output.append("")
return output
class TrezorClientDebugLink(TrezorClient):
# This class implements automatic responses
# and other functionality for unit tests
@ -1034,54 +1291,165 @@ class TrezorClientDebugLink(TrezorClient):
raise
# set transport explicitly so that sync_responses can work
super().__init__(transport)
self.transport = transport
self.ui: DebugUI = DebugUI(self.debug)
self.reset_debug_features()
self.reset_debug_features(new_management_session=True)
self.sync_responses()
super().__init__(transport, ui=self.ui)
# So that we can choose right screenshotting logic (T1 vs TT)
# and know the supported debug capabilities
self.debug.model = self.model
self.debug.version = self.version
self.passphrase: str | None = None
@property
def layout_type(self) -> LayoutType:
return self.debug.layout_type
def reset_debug_features(self) -> None:
"""Prepare the debugging client for a new testcase.
def get_new_client(self) -> TrezorClientDebugLink:
return TrezorClientDebugLink(self.transport, self.debug.allow_interactions)
def reset_debug_features(self, new_management_session: bool = False) -> None:
"""
Prepare the debugging client for a new testcase.
Clears all debugging state that might have been modified by a testcase.
"""
self.ui: DebugUI = DebugUI(self.debug)
# self.pin_callback = self.ui.debug_callback_button
self.in_with_statement = False
self.expected_responses: list[MessageFilter] | None = None
self.actual_responses: list[protobuf.MessageType] | None = None
self.filters: dict[
type[protobuf.MessageType],
Callable[[protobuf.MessageType], protobuf.MessageType] | None,
self.filters: t.Dict[
t.Type[protobuf.MessageType],
t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
] = {}
if new_management_session:
self._management_session = self.get_management_session(new_session=True)
@property
def button_callback(self):
def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
session._write(messages.ButtonAck())
self.ui.button_request(msg)
return session._read()
return _callback_button
@property
def pin_callback(self):
def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any:
try:
pin = self.ui.get_pin(msg.type)
except Cancelled:
session.call_raw(messages.Cancel())
raise
if any(d not in "123456789" for d in pin) or not (
1 <= len(pin) <= MAX_PIN_LENGTH
):
session.call_raw(messages.Cancel())
raise ValueError("Invalid PIN provided")
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise PinException(resp.code, resp.message)
else:
return resp
return _callback_pin
@property
def passphrase_callback(self):
def _callback_passphrase(
session: Session, msg: messages.PassphraseRequest
) -> t.Any:
available_on_device = (
Capability.PassphraseEntry in session.features.capabilities
)
def send_passphrase(
passphrase: str | None = None, on_device: bool | None = None
) -> t.Any:
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
resp = session.call_raw(msg)
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
# session.session_id = resp.state
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
return resp
# short-circuit old style entry
if msg._on_device is True:
return send_passphrase(None, None)
try:
if session.passphrase is None and isinstance(session, SessionV1):
passphrase = self.ui.get_passphrase(
available_on_device=available_on_device
)
else:
passphrase = session.passphrase
except Cancelled:
session.call_raw(messages.Cancel())
raise
if passphrase is PASSPHRASE_ON_DEVICE:
if not available_on_device:
session.call_raw(messages.Cancel())
raise RuntimeError("Device is not capable of entering passphrase")
else:
return send_passphrase(on_device=True)
# else process host-entered passphrase
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
session.call_raw(messages.Cancel())
raise ValueError("Passphrase too long")
return send_passphrase(passphrase, on_device=False)
return _callback_passphrase
def ensure_open(self) -> None:
"""Only open session if there isn't already an open one."""
if self.session_counter == 0:
self.open()
# if self.session_counter == 0:
# self.open()
# TODO check if is this needed
def open(self) -> None:
super().open()
if self.session_counter == 1:
self.debug.open()
pass
# TODO is this needed?
# self.debug.open()
def close(self) -> None:
if self.session_counter == 1:
self.debug.close()
super().close()
pass
# TODO is this needed?
# self.debug.close()
def get_session(
self,
passphrase: str | object | None = "",
derive_cardano: bool = False,
) -> Session:
if isinstance(passphrase, str):
passphrase = Mnemonic.normalize_string(passphrase)
return super().get_session(passphrase, derive_cardano)
def set_filter(
self,
message_type: type[protobuf.MessageType],
callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None,
message_type: t.Type[protobuf.MessageType],
callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
) -> None:
"""Configure a filter function for a specified message type.
@ -1106,7 +1474,8 @@ class TrezorClientDebugLink(TrezorClient):
return msg
def set_input_flow(
self, input_flow: Generator[None, messages.ButtonRequest | None, None]
self,
input_flow: t.Generator[None, messages.ButtonRequest | None, None],
) -> None:
"""Configure a sequence of input events for the current with-block.
@ -1140,6 +1509,7 @@ class TrezorClientDebugLink(TrezorClient):
if not hasattr(input_flow, "send"):
raise RuntimeError("input_flow should be a generator function")
self.ui.input_flow = input_flow
assert input_flow is not None
input_flow.send(None) # start the generator
def watch_layout(self, watch: bool = True) -> None:
@ -1162,7 +1532,7 @@ class TrezorClientDebugLink(TrezorClient):
self.in_with_statement = True
return self
def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None:
def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# copy expected/actual responses before clearing them
@ -1175,20 +1545,21 @@ class TrezorClientDebugLink(TrezorClient):
else:
input_flow = None
self.reset_debug_features()
self.reset_debug_features(new_management_session=False)
if exc_type is None:
# If no other exception was raised, evaluate missed responses
# (raises AssertionError on mismatch)
self._verify_responses(expected_responses, actual_responses)
elif isinstance(input_flow, Generator):
elif isinstance(input_flow, t.Generator):
# Propagate the exception through the input flow, so that we see in
# traceback where it is stuck.
input_flow.throw(exc_type, value, traceback)
def set_expected_responses(
self, expected: list[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]]
self,
expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]],
) -> None:
"""Set a sequence of expected responses to client calls.
@ -1227,7 +1598,7 @@ class TrezorClientDebugLink(TrezorClient):
]
self.actual_responses = []
def use_pin_sequence(self, pins: Iterable[str]) -> None:
def use_pin_sequence(self, pins: t.Iterable[str]) -> None:
"""Respond to PIN prompts from device with the provided PINs.
The sequence must be at least as long as the expected number of PIN prompts.
"""
@ -1235,6 +1606,7 @@ class TrezorClientDebugLink(TrezorClient):
def use_passphrase(self, passphrase: str) -> None:
"""Respond to passphrase prompts from device with the provided passphrase."""
self.passphrase = passphrase
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
def use_mnemonic(self, mnemonic: str) -> None:
@ -1244,15 +1616,14 @@ class TrezorClientDebugLink(TrezorClient):
def _raw_read(self) -> protobuf.MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
resp = super()._raw_read()
resp = self.get_management_session()._read()
resp = self._filter_message(resp)
if self.actual_responses is not None:
self.actual_responses.append(resp)
return resp
def _raw_write(self, msg: protobuf.MessageType) -> None:
return super()._raw_write(self._filter_message(msg))
return self.get_management_session()._write(self._filter_message(msg))
@staticmethod
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
@ -1322,23 +1693,25 @@ class TrezorClientDebugLink(TrezorClient):
# Start by canceling whatever is on screen. This will work to cancel T1 PIN
# prompt, which is in TINY mode and does not respond to `Ping`.
cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel())
self.transport.begin_session()
# TODO REMOVE: cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel())
self.transport.open()
try:
self.transport.write(*cancel_msg)
# self.protocol.write(messages.Cancel())
message = "SYNC" + secrets.token_hex(8)
ping_msg = mapping.DEFAULT_MAPPING.encode(messages.Ping(message=message))
self.transport.write(*ping_msg)
self.get_management_session()._write(messages.Ping(message=message))
resp = None
while resp != messages.Success(message=message):
msg_id, msg_bytes = self.transport.read()
try:
resp = mapping.DEFAULT_MAPPING.decode(msg_id, msg_bytes)
resp = self.get_management_session()._read()
raise Exception
except Exception:
pass
finally:
self.transport.end_session()
pass # TODO fix
# self.transport.end_session(self.session_id or b"")
def mnemonic_callback(self, _) -> str:
word, pos = self.debug.read_recovery_word()
@ -1352,8 +1725,8 @@ class TrezorClientDebugLink(TrezorClient):
@expect(messages.Success, field="message", ret_type=str)
def load_device(
client: "TrezorClient",
mnemonic: Union[str, Iterable[str]],
session: "Session",
mnemonic: str | t.Iterable[str],
pin: str | None,
passphrase_protection: bool,
label: str | None,
@ -1366,12 +1739,12 @@ def load_device(
mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
if client.features.initialized:
if session.features.initialized:
raise RuntimeError(
"Device is initialized already. Call device.wipe() and try again."
)
resp = client.call(
resp = session.call(
messages.LoadDevice(
mnemonics=mnemonics,
pin=pin,
@ -1382,7 +1755,7 @@ def load_device(
no_backup=no_backup,
)
)
client.init_device()
session.refresh_features()
return resp
@ -1391,11 +1764,11 @@ load_device_by_mnemonic = load_device
@expect(messages.Success, field="message", ret_type=str)
def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType:
if client.features.bootloader_mode is not True:
def prodtest_t1(session: "Session") -> protobuf.MessageType:
if session.features.bootloader_mode is not True:
raise RuntimeError("Device must be in bootloader mode")
return client.call(
return session.call(
messages.ProdTestT1(
payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
)
@ -1404,8 +1777,8 @@ def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType:
def record_screen(
debug_client: "TrezorClientDebugLink",
directory: Union[str, None],
report_func: Union[Callable[[str], None], None] = None,
directory: str | None,
report_func: t.Callable[[str], None] | None = None,
) -> None:
"""Record screen changes into a specified directory.
@ -1451,5 +1824,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool:
@expect(messages.Success, field="message", ret_type=str)
def optiga_set_sec_max(client: "TrezorClient") -> protobuf.MessageType:
return client.call(messages.DebugLinkOptigaSetSecMax())
def optiga_set_sec_max(session: "Session") -> protobuf.MessageType:
return session.call(messages.DebugLinkOptigaSetSecMax())

View File

@ -27,20 +27,19 @@ from slip10 import SLIP10
from . import messages
from .exceptions import Cancelled, TrezorException
from .tools import Address, expect, parse_path, session
from .tools import Address, expect, parse_path
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .transport.session import Session
RECOVERY_BACK = "\x08" # backspace character, sent literally
@expect(messages.Success, field="message", ret_type=str)
@session
def apply_settings(
client: "TrezorClient",
session: "Session",
label: Optional[str] = None,
language: Optional[str] = None,
use_passphrase: Optional[bool] = None,
@ -71,13 +70,13 @@ def apply_settings(
haptic_feedback=haptic_feedback,
)
out = client.call(settings)
client.refresh_features()
out = session.call(settings)
session.refresh_features()
return out
def _send_language_data(
client: "TrezorClient",
session: "Session",
request: "messages.TranslationDataRequest",
language_data: bytes,
) -> "MessageType":
@ -87,76 +86,70 @@ def _send_language_data(
data_length = response.data_length
data_offset = response.data_offset
chunk = language_data[data_offset : data_offset + data_length]
response = client.call(messages.TranslationDataAck(data_chunk=chunk))
response = session.call(messages.TranslationDataAck(data_chunk=chunk))
return response
@expect(messages.Success, field="message", ret_type=str)
@session
def change_language(
client: "TrezorClient",
session: "Session",
language_data: bytes,
show_display: bool | None = None,
) -> "MessageType":
data_length = len(language_data)
msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)
response = client.call(msg)
response = session.call(msg)
if data_length > 0:
assert isinstance(response, messages.TranslationDataRequest)
response = _send_language_data(client, response, language_data)
response = _send_language_data(session, response, language_data)
assert isinstance(response, messages.Success)
client.refresh_features() # changing the language in features
session.refresh_features() # changing the language in features
return response
@expect(messages.Success, field="message", ret_type=str)
@session
def apply_flags(client: "TrezorClient", flags: int) -> "MessageType":
out = client.call(messages.ApplyFlags(flags=flags))
client.refresh_features()
def apply_flags(session: "Session", flags: int) -> "MessageType":
out = session.call(messages.ApplyFlags(flags=flags))
session.refresh_features()
return out
@expect(messages.Success, field="message", ret_type=str)
@session
def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType":
ret = client.call(messages.ChangePin(remove=remove))
client.refresh_features()
def change_pin(session: "Session", remove: bool = False) -> "MessageType":
ret = session.call(messages.ChangePin(remove=remove))
session.refresh_features()
return ret
@expect(messages.Success, field="message", ret_type=str)
@session
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType":
ret = client.call(messages.ChangeWipeCode(remove=remove))
client.refresh_features()
def change_wipe_code(session: "Session", remove: bool = False) -> "MessageType":
ret = session.call(messages.ChangeWipeCode(remove=remove))
session.refresh_features()
return ret
@expect(messages.Success, field="message", ret_type=str)
@session
def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType
session: "Session", operation: messages.SdProtectOperationType
) -> "MessageType":
ret = client.call(messages.SdProtect(operation=operation))
client.refresh_features()
ret = session.call(messages.SdProtect(operation=operation))
session.refresh_features()
return ret
@expect(messages.Success, field="message", ret_type=str)
@session
def wipe(client: "TrezorClient") -> "MessageType":
ret = client.call(messages.WipeDevice())
if not client.features.bootloader_mode:
client.init_device()
def wipe(session: "Session") -> "MessageType":
ret = session.call(messages.WipeDevice())
# if not session.features.bootloader_mode:
# session.refresh_features()
return ret
@session
def recover(
client: "TrezorClient",
session: "Session",
word_count: int = 24,
passphrase_protection: bool = False,
pin_protection: bool = True,
@ -192,13 +185,13 @@ def recover(
if type is None:
type = messages.RecoveryType.NormalRecovery
if client.features.model == "1" and input_callback is None:
if session.features.model == "1" and input_callback is None:
raise RuntimeError("Input callback required for Trezor One")
if word_count not in (12, 18, 24):
raise ValueError("Invalid word count. Use 12/18/24")
if client.features.initialized and type == messages.RecoveryType.NormalRecovery:
if session.features.initialized and type == messages.RecoveryType.NormalRecovery:
raise RuntimeError(
"Device already initialized. Call device.wipe() and try again."
)
@ -220,17 +213,17 @@ def recover(
msg.label = label
msg.u2f_counter = u2f_counter
res = client.call(msg)
res = session.call(msg)
while isinstance(res, messages.WordRequest):
try:
assert input_callback is not None
inp = input_callback(res.type)
res = client.call(messages.WordAck(word=inp))
res = session.call(messages.WordAck(word=inp))
except Cancelled:
res = client.call(messages.Cancel())
res = session.call(messages.Cancel())
client.init_device()
session.refresh_features()
return res
@ -279,9 +272,8 @@ def reset(*args: Any, **kwargs: Any) -> "MessageType":
return reset_entropy_check(*args, **kwargs)[0]
@session
def reset_entropy_check(
client: "TrezorClient",
session: "Session",
display_random: bool = False,
strength: Optional[int] = None,
passphrase_protection: bool = False,
@ -307,13 +299,13 @@ def reset_entropy_check(
DeprecationWarning,
)
if client.features.initialized:
if session.features.initialized:
raise RuntimeError(
"Device is initialized already. Call wipe_device() and try again."
)
if strength is None:
if client.features.model == "1":
if session.features.model == "1":
strength = 256
else:
strength = 128
@ -335,7 +327,7 @@ def reset_entropy_check(
entropy_check=entropy_check_count is not None,
)
resp = client.call(msg)
resp = session.call(msg)
if not isinstance(resp, messages.EntropyRequest):
raise RuntimeError("Invalid response, expected EntropyRequest")
@ -344,7 +336,7 @@ def reset_entropy_check(
external_entropy = os.urandom(32)
entropy_commitment = resp.entropy_commitment
resp = client.call(messages.EntropyAck(entropy=external_entropy))
resp = session.call(messages.EntropyAck(entropy=external_entropy))
if entropy_check_count is None:
break
@ -353,18 +345,18 @@ def reset_entropy_check(
return resp, []
for path in paths:
resp = client.call(messages.GetPublicKey(address_n=path))
resp = session.call(messages.GetPublicKey(address_n=path))
if not isinstance(resp, messages.PublicKey):
return resp, []
xpubs.append(resp.xpub)
if entropy_check_count <= 0:
resp = client.call(messages.EntropyCheckContinue(finish=True))
resp = session.call(messages.EntropyCheckContinue(finish=True))
break
entropy_check_count -= 1
resp = client.call(messages.EntropyCheckContinue(finish=False))
resp = session.call(messages.EntropyCheckContinue(finish=False))
if not isinstance(resp, messages.EntropyRequest):
raise RuntimeError("Invalid response, expected EntropyRequest")
@ -385,18 +377,17 @@ def reset_entropy_check(
if slip10.get_xpub_from_path(path) != xpub:
raise RuntimeError("Invalid XPUB in entropy check")
client.init_device()
session.refresh_features()
return resp, zip(paths, xpubs)
@expect(messages.Success, field="message", ret_type=str)
@session
def backup(
client: "TrezorClient",
session: "Session",
group_threshold: Optional[int] = None,
groups: Iterable[tuple[int, int]] = (),
) -> "MessageType":
ret = client.call(
ret = session.call(
messages.BackupDevice(
group_threshold=group_threshold,
groups=[
@ -405,37 +396,36 @@ def backup(
],
)
)
client.refresh_features()
session.refresh_features()
return ret
@expect(messages.Success, field="message", ret_type=str)
def cancel_authorization(client: "TrezorClient") -> "MessageType":
return client.call(messages.CancelAuthorization())
def cancel_authorization(session: "Session") -> "MessageType":
return session.call(messages.CancelAuthorization())
@expect(messages.UnlockedPathRequest, field="mac", ret_type=bytes)
def unlock_path(client: "TrezorClient", n: "Address") -> "MessageType":
resp = client.call(messages.UnlockPath(address_n=n))
def unlock_path(session: "Session", n: "Address") -> "MessageType":
resp = session.call(messages.UnlockPath(address_n=n))
# Cancel the UnlockPath workflow now that we have the authentication code.
try:
client.call(messages.Cancel())
session.call(messages.Cancel())
except Cancelled:
return resp
else:
raise TrezorException("Unexpected response in UnlockPath flow")
@session
@expect(messages.Success, field="message", ret_type=str)
def reboot_to_bootloader(
client: "TrezorClient",
session: "Session",
boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT,
firmware_header: Optional[bytes] = None,
language_data: bytes = b"",
) -> "MessageType":
response = client.call(
response = session.call(
messages.RebootToBootloader(
boot_command=boot_command,
firmware_header=firmware_header,
@ -443,42 +433,37 @@ def reboot_to_bootloader(
)
)
if isinstance(response, messages.TranslationDataRequest):
response = _send_language_data(client, response, language_data)
response = _send_language_data(session, response, language_data)
return response
@session
@expect(messages.Success, field="message", ret_type=str)
def show_device_tutorial(client: "TrezorClient") -> "MessageType":
return client.call(messages.ShowDeviceTutorial())
@session
@expect(messages.Success, field="message", ret_type=str)
def unlock_bootloader(client: "TrezorClient") -> "MessageType":
return client.call(messages.UnlockBootloader())
def show_device_tutorial(session: "Session") -> "MessageType":
return session.call(messages.ShowDeviceTutorial())
@expect(messages.Success, field="message", ret_type=str)
@session
def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> "MessageType":
def unlock_bootloader(session: "Session") -> "MessageType":
return session.call(messages.UnlockBootloader())
@expect(messages.Success, field="message", ret_type=str)
def set_busy(session: "Session", expiry_ms: Optional[int]) -> "MessageType":
"""Sets or clears the busy state of the device.
In the busy state the device shows a "Do not disconnect" message instead of the homescreen.
Setting `expiry_ms=None` clears the busy state.
"""
ret = client.call(messages.SetBusy(expiry_ms=expiry_ms))
client.refresh_features()
ret = session.call(messages.SetBusy(expiry_ms=expiry_ms))
session.refresh_features()
return ret
@expect(messages.AuthenticityProof)
def authenticate(client: "TrezorClient", challenge: bytes):
return client.call(messages.AuthenticateDevice(challenge=challenge))
def authenticate(session: "Session", challenge: bytes):
return session.call(messages.AuthenticateDevice(challenge=challenge))
@expect(messages.Success, field="message", ret_type=str)
def set_brightness(
client: "TrezorClient", value: Optional[int] = None
) -> "MessageType":
return client.call(messages.SetBrightness(value=value))
def set_brightness(session: "Session", value: Optional[int] = None) -> "MessageType":
return session.call(messages.SetBrightness(value=value))

View File

@ -18,12 +18,12 @@ from datetime import datetime
from typing import TYPE_CHECKING, List, Tuple
from . import exceptions, messages
from .tools import b58decode, expect, session
from .tools import b58decode, expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.session import Session
def name_to_number(name: str) -> int:
@ -321,17 +321,16 @@ def parse_transaction_json(
@expect(messages.EosPublicKey)
def get_public_key(
client: "TrezorClient", n: "Address", show_display: bool = False
session: "Session", n: "Address", show_display: bool = False
) -> "MessageType":
response = client.call(
response = session.call(
messages.EosGetPublicKey(address_n=n, show_display=show_display)
)
return response
@session
def sign_tx(
client: "TrezorClient",
session: "Session",
address: "Address",
transaction: dict,
chain_id: str,
@ -347,11 +346,11 @@ def sign_tx(
chunkify=chunkify,
)
response = client.call(msg)
response = session.call(msg)
try:
while isinstance(response, messages.EosTxActionRequest):
response = client.call(actions.pop(0))
response = session.call(actions.pop(0))
except IndexError:
# pop from empty list
raise exceptions.TrezorException(

View File

@ -18,12 +18,12 @@ import re
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
from . import definitions, exceptions, messages
from .tools import expect, prepare_message_bytes, session, unharden
from .tools import expect, prepare_message_bytes, unharden
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.session import Session
def int_to_big_endian(value: int) -> bytes:
@ -163,13 +163,13 @@ def network_from_address_n(
@expect(messages.EthereumAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
session: "Session",
n: "Address",
show_display: bool = False,
encoded_network: Optional[bytes] = None,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.EthereumGetAddress(
address_n=n,
show_display=show_display,
@ -181,16 +181,15 @@ def get_address(
@expect(messages.EthereumPublicKey)
def get_public_node(
client: "TrezorClient", n: "Address", show_display: bool = False
session: "Session", n: "Address", show_display: bool = False
) -> "MessageType":
return client.call(
return session.call(
messages.EthereumGetPublicKey(address_n=n, show_display=show_display)
)
@session
def sign_tx(
client: "TrezorClient",
session: "Session",
n: "Address",
nonce: int,
gas_price: int,
@ -226,13 +225,13 @@ def sign_tx(
data, chunk = data[1024:], data[:1024]
msg.data_initial_chunk = chunk
response = client.call(msg)
response = session.call(msg)
assert isinstance(response, messages.EthereumTxRequest)
while response.data_length is not None:
data_length = response.data_length
data, chunk = data[data_length:], data[:data_length]
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
response = session.call(messages.EthereumTxAck(data_chunk=chunk))
assert isinstance(response, messages.EthereumTxRequest)
assert response.signature_v is not None
@ -247,9 +246,8 @@ def sign_tx(
return response.signature_v, response.signature_r, response.signature_s
@session
def sign_tx_eip1559(
client: "TrezorClient",
session: "Session",
n: "Address",
*,
nonce: int,
@ -282,13 +280,13 @@ def sign_tx_eip1559(
chunkify=chunkify,
)
response = client.call(msg)
response = session.call(msg)
assert isinstance(response, messages.EthereumTxRequest)
while response.data_length is not None:
data_length = response.data_length
data, chunk = data[data_length:], data[:data_length]
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
response = session.call(messages.EthereumTxAck(data_chunk=chunk))
assert isinstance(response, messages.EthereumTxRequest)
assert response.signature_v is not None
@ -299,13 +297,13 @@ def sign_tx_eip1559(
@expect(messages.EthereumMessageSignature)
def sign_message(
client: "TrezorClient",
session: "Session",
n: "Address",
message: AnyStr,
encoded_network: Optional[bytes] = None,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.EthereumSignMessage(
address_n=n,
message=prepare_message_bytes(message),
@ -317,7 +315,7 @@ def sign_message(
@expect(messages.EthereumTypedDataSignature)
def sign_typed_data(
client: "TrezorClient",
session: "Session",
n: "Address",
data: Dict[str, Any],
*,
@ -333,7 +331,7 @@ def sign_typed_data(
metamask_v4_compat=metamask_v4_compat,
definitions=definitions,
)
response = client.call(request)
response = session.call(request)
# Sending all the types
while isinstance(response, messages.EthereumTypedDataStructRequest):
@ -349,7 +347,7 @@ def sign_typed_data(
members.append(struct_member)
request = messages.EthereumTypedDataStructAck(members=members)
response = client.call(request)
response = session.call(request)
# Sending the whole message that should be signed
while isinstance(response, messages.EthereumTypedDataValueRequest):
@ -362,7 +360,7 @@ def sign_typed_data(
member_typename = data["primaryType"]
member_data = data["message"]
else:
client.cancel()
# TODO session.cancel()
raise exceptions.TrezorException("Root index can only be 0 or 1")
# It can be asking for a nested structure (the member path being [X, Y, Z, ...])
@ -385,20 +383,20 @@ def sign_typed_data(
encoded_data = encode_data(member_data, member_typename)
request = messages.EthereumTypedDataValueAck(value=encoded_data)
response = client.call(request)
response = session.call(request)
return response
def verify_message(
client: "TrezorClient",
session: "Session",
address: str,
signature: bytes,
message: AnyStr,
chunkify: bool = False,
) -> bool:
try:
resp = client.call(
resp = session.call(
messages.EthereumVerifyMessage(
address=address,
signature=signature,
@ -413,13 +411,13 @@ def verify_message(
@expect(messages.EthereumTypedDataSignature)
def sign_typed_data_hash(
client: "TrezorClient",
session: "Session",
n: "Address",
domain_hash: bytes,
message_hash: Optional[bytes],
encoded_network: Optional[bytes] = None,
) -> "MessageType":
return client.call(
return session.call(
messages.EthereumSignTypedHash(
address_n=n,
domain_separator_hash=domain_hash,

View File

@ -20,8 +20,8 @@ from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .transport.session import Session
@expect(
@ -29,27 +29,27 @@ if TYPE_CHECKING:
field="credentials",
ret_type=List[messages.WebAuthnCredential],
)
def list_credentials(client: "TrezorClient") -> "MessageType":
return client.call(messages.WebAuthnListResidentCredentials())
def list_credentials(session: "Session") -> "MessageType":
return session.call(messages.WebAuthnListResidentCredentials())
@expect(messages.Success, field="message", ret_type=str)
def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType":
return client.call(
def add_credential(session: "Session", credential_id: bytes) -> "MessageType":
return session.call(
messages.WebAuthnAddResidentCredential(credential_id=credential_id)
)
@expect(messages.Success, field="message", ret_type=str)
def remove_credential(client: "TrezorClient", index: int) -> "MessageType":
return client.call(messages.WebAuthnRemoveResidentCredential(index=index))
def remove_credential(session: "Session", index: int) -> "MessageType":
return session.call(messages.WebAuthnRemoveResidentCredential(index=index))
@expect(messages.Success, field="message", ret_type=str)
def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType":
return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter))
def set_counter(session: "Session", u2f_counter: int) -> "MessageType":
return session.call(messages.SetU2FCounter(u2f_counter=u2f_counter))
@expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int)
def get_next_counter(client: "TrezorClient") -> "MessageType":
return client.call(messages.GetNextU2FCounter())
def get_next_counter(session: "Session") -> "MessageType":
return session.call(messages.GetNextU2FCounter())

View File

@ -20,7 +20,7 @@ from hashlib import blake2s
from typing_extensions import Protocol, TypeGuard
from .. import messages
from ..tools import expect, session
from ..tools import expect
from .core import VendorFirmware
from .legacy import LegacyFirmware, LegacyV2Firmware
@ -38,7 +38,7 @@ if True:
from .vendor import * # noqa: F401, F403
if t.TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
T = t.TypeVar("T", bound="FirmwareType")
@ -72,20 +72,19 @@ def is_onev2(fw: "FirmwareType") -> TypeGuard[LegacyFirmware]:
# ====== Client functions ====== #
@session
def update(
client: "TrezorClient",
session: "Session",
data: bytes,
progress_update: t.Callable[[int], t.Any] = lambda _: None,
):
if client.features.bootloader_mode is False:
if session.features.bootloader_mode is False:
raise RuntimeError("Device must be in bootloader mode")
resp = client.call(messages.FirmwareErase(length=len(data)))
resp = session.call(messages.FirmwareErase(length=len(data)))
# TREZORv1 method
if isinstance(resp, messages.Success):
resp = client.call(messages.FirmwareUpload(payload=data))
resp = session.call(messages.FirmwareUpload(payload=data))
progress_update(len(data))
if isinstance(resp, messages.Success):
return
@ -97,7 +96,7 @@ def update(
length = resp.length
payload = data[resp.offset : resp.offset + length]
digest = blake2s(payload).digest()
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest))
resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest))
progress_update(length)
if isinstance(resp, messages.Success):
@ -107,5 +106,5 @@ def update(
@expect(messages.FirmwareHash, field="hash", ret_type=bytes)
def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]):
return client.call(messages.GetFirmwareHash(challenge=challenge))
def get_hash(session: "Session", challenge: t.Optional[bytes]):
return session.call(messages.GetFirmwareHash(challenge=challenge))

View File

@ -17,6 +17,7 @@
from __future__ import annotations
import io
import logging
from types import ModuleType
from typing import Dict, Optional, Tuple, Type, TypeVar
@ -25,6 +26,7 @@ from typing_extensions import Self
from . import messages, protobuf
T = TypeVar("T")
LOG = logging.getLogger(__name__)
class ProtobufMapping:
@ -63,11 +65,21 @@ class ProtobufMapping:
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
if wire_type is None:
raise ValueError("Cannot encode class without wire type")
LOG.debug("encoding wire type %d", wire_type)
buf = io.BytesIO()
protobuf.dump_message(buf, msg)
return wire_type, buf.getvalue()
def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes:
"""Serialize a Python protobuf class.
Returns the byte representation of the protobuf message.
"""
buf = io.BytesIO()
protobuf.dump_message(buf, msg)
return buf.getvalue()
def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType:
"""Deserialize a protobuf message into a Python class."""
cls = self.type_to_class[msg_wire_type]
@ -83,7 +95,9 @@ class ProtobufMapping:
mapping = cls()
message_types = getattr(module, "MessageType")
for entry in message_types:
thp_message_types = getattr(module, "ThpMessageType")
for entry in (*message_types, *thp_message_types):
msg_class = getattr(module, entry.name, None)
if msg_class is None:
raise ValueError(

View File

@ -43,6 +43,8 @@ class FailureType(IntEnum):
PinMismatch = 12
WipeCodeMismatch = 13
InvalidSession = 14
ThpUnallocatedSession = 15
InvalidProtocol = 16
FirmwareError = 99
@ -400,6 +402,34 @@ class TezosBallotType(IntEnum):
Pass = 2
class ThpMessageType(IntEnum):
ThpCreateNewSession = 1000
ThpNewSession = 1001
ThpStartPairingRequest = 1008
ThpPairingPreparationsFinished = 1009
ThpCredentialRequest = 1010
ThpCredentialResponse = 1011
ThpEndRequest = 1012
ThpEndResponse = 1013
ThpCodeEntryCommitment = 1016
ThpCodeEntryChallenge = 1017
ThpCodeEntryCpaceHost = 1018
ThpCodeEntryCpaceTrezor = 1019
ThpCodeEntryTag = 1020
ThpCodeEntrySecret = 1021
ThpQrCodeTag = 1024
ThpQrCodeSecret = 1025
ThpNfcUnidirectionalTag = 1032
ThpNfcUnidirectionalSecret = 1033
class ThpPairingMethod(IntEnum):
NoMethod = 1
CodeEntry = 2
QrCode = 3
NFC_Unidirectional = 4
class MessageType(IntEnum):
Initialize = 0
Ping = 1
@ -4136,6 +4166,7 @@ class DebugLinkGetState(protobuf.MessageType):
1: protobuf.Field("wait_word_list", "bool", repeated=False, required=False, default=None),
2: protobuf.Field("wait_word_pos", "bool", repeated=False, required=False, default=None),
3: protobuf.Field("wait_layout", "DebugWaitType", repeated=False, required=False, default=DebugWaitType.IMMEDIATE),
4: protobuf.Field("thp_channel_id", "bytes", repeated=False, required=False, default=None),
}
def __init__(
@ -4144,10 +4175,12 @@ class DebugLinkGetState(protobuf.MessageType):
wait_word_list: Optional["bool"] = None,
wait_word_pos: Optional["bool"] = None,
wait_layout: Optional["DebugWaitType"] = DebugWaitType.IMMEDIATE,
thp_channel_id: Optional["bytes"] = None,
) -> None:
self.wait_word_list = wait_word_list
self.wait_word_pos = wait_word_pos
self.wait_layout = wait_layout
self.thp_channel_id = thp_channel_id
class DebugLinkState(protobuf.MessageType):
@ -4166,6 +4199,9 @@ class DebugLinkState(protobuf.MessageType):
11: protobuf.Field("reset_word_pos", "uint32", repeated=False, required=False, default=None),
12: protobuf.Field("mnemonic_type", "BackupType", repeated=False, required=False, default=None),
13: protobuf.Field("tokens", "string", repeated=True, required=False, default=None),
14: protobuf.Field("thp_pairing_code_entry_code", "uint32", repeated=False, required=False, default=None),
15: protobuf.Field("thp_pairing_code_qr_code", "bytes", repeated=False, required=False, default=None),
16: protobuf.Field("thp_pairing_code_nfc_unidirectional", "bytes", repeated=False, required=False, default=None),
}
def __init__(
@ -4184,6 +4220,9 @@ class DebugLinkState(protobuf.MessageType):
recovery_word_pos: Optional["int"] = None,
reset_word_pos: Optional["int"] = None,
mnemonic_type: Optional["BackupType"] = None,
thp_pairing_code_entry_code: Optional["int"] = None,
thp_pairing_code_qr_code: Optional["bytes"] = None,
thp_pairing_code_nfc_unidirectional: Optional["bytes"] = None,
) -> None:
self.tokens: Sequence["str"] = tokens if tokens is not None else []
self.layout = layout
@ -4198,6 +4237,9 @@ class DebugLinkState(protobuf.MessageType):
self.recovery_word_pos = recovery_word_pos
self.reset_word_pos = reset_word_pos
self.mnemonic_type = mnemonic_type
self.thp_pairing_code_entry_code = thp_pairing_code_entry_code
self.thp_pairing_code_qr_code = thp_pairing_code_qr_code
self.thp_pairing_code_nfc_unidirectional = thp_pairing_code_nfc_unidirectional
class DebugLinkStop(protobuf.MessageType):
@ -7860,6 +7902,280 @@ class TezosManagerTransfer(protobuf.MessageType):
self.amount = amount
class ThpDeviceProperties(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("internal_model", "string", repeated=False, required=False, default=None),
2: protobuf.Field("model_variant", "uint32", repeated=False, required=False, default=None),
3: protobuf.Field("bootloader_mode", "bool", repeated=False, required=False, default=None),
4: protobuf.Field("protocol_version", "uint32", repeated=False, required=False, default=None),
5: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None),
}
def __init__(
self,
*,
pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None,
internal_model: Optional["str"] = None,
model_variant: Optional["int"] = None,
bootloader_mode: Optional["bool"] = None,
protocol_version: Optional["int"] = None,
) -> None:
self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else []
self.internal_model = internal_model
self.model_variant = model_variant
self.bootloader_mode = bootloader_mode
self.protocol_version = protocol_version
class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {
1: protobuf.Field("host_pairing_credential", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None),
}
def __init__(
self,
*,
pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None,
host_pairing_credential: Optional["bytes"] = None,
) -> None:
self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else []
self.host_pairing_credential = host_pairing_credential
class ThpCreateNewSession(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1000
FIELDS = {
1: protobuf.Field("passphrase", "string", repeated=False, required=False, default=None),
2: protobuf.Field("on_device", "bool", repeated=False, required=False, default=None),
3: protobuf.Field("derive_cardano", "bool", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
passphrase: Optional["str"] = None,
on_device: Optional["bool"] = None,
derive_cardano: Optional["bool"] = None,
) -> None:
self.passphrase = passphrase
self.on_device = on_device
self.derive_cardano = derive_cardano
class ThpNewSession(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1001
FIELDS = {
1: protobuf.Field("new_session_id", "uint32", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
new_session_id: Optional["int"] = None,
) -> None:
self.new_session_id = new_session_id
class ThpStartPairingRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1008
FIELDS = {
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_name: Optional["str"] = None,
) -> None:
self.host_name = host_name
class ThpPairingPreparationsFinished(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1009
class ThpCodeEntryCommitment(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1016
FIELDS = {
1: protobuf.Field("commitment", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
commitment: Optional["bytes"] = None,
) -> None:
self.commitment = commitment
class ThpCodeEntryChallenge(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1017
FIELDS = {
1: protobuf.Field("challenge", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
challenge: Optional["bytes"] = None,
) -> None:
self.challenge = challenge
class ThpCodeEntryCpaceHost(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1018
FIELDS = {
1: protobuf.Field("cpace_host_public_key", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
cpace_host_public_key: Optional["bytes"] = None,
) -> None:
self.cpace_host_public_key = cpace_host_public_key
class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1019
FIELDS = {
1: protobuf.Field("cpace_trezor_public_key", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
cpace_trezor_public_key: Optional["bytes"] = None,
) -> None:
self.cpace_trezor_public_key = cpace_trezor_public_key
class ThpCodeEntryTag(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1020
FIELDS = {
2: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpCodeEntrySecret(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1021
FIELDS = {
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
secret: Optional["bytes"] = None,
) -> None:
self.secret = secret
class ThpQrCodeTag(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1024
FIELDS = {
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpQrCodeSecret(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1025
FIELDS = {
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
secret: Optional["bytes"] = None,
) -> None:
self.secret = secret
class ThpNfcUnidirectionalTag(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1032
FIELDS = {
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
tag: Optional["bytes"] = None,
) -> None:
self.tag = tag
class ThpNfcUnidirectionalSecret(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1033
FIELDS = {
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
secret: Optional["bytes"] = None,
) -> None:
self.secret = secret
class ThpCredentialRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1010
FIELDS = {
1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
host_static_pubkey: Optional["bytes"] = None,
) -> None:
self.host_static_pubkey = host_static_pubkey
class ThpCredentialResponse(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1011
FIELDS = {
1: protobuf.Field("trezor_static_pubkey", "bytes", repeated=False, required=False, default=None),
2: protobuf.Field("credential", "bytes", repeated=False, required=False, default=None),
}
def __init__(
self,
*,
trezor_static_pubkey: Optional["bytes"] = None,
credential: Optional["bytes"] = None,
) -> None:
self.trezor_static_pubkey = trezor_static_pubkey
self.credential = credential
class ThpEndRequest(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1012
class ThpEndResponse(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1013
class ThpCredentialMetadata(protobuf.MessageType):
MESSAGE_WIRE_TYPE = None
FIELDS = {

View File

@ -20,25 +20,25 @@ from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.session import Session
@expect(messages.Entropy, field="entropy", ret_type=bytes)
def get_entropy(client: "TrezorClient", size: int) -> "MessageType":
return client.call(messages.GetEntropy(size=size))
def get_entropy(session: "Session", size: int) -> "MessageType":
return session.call(messages.GetEntropy(size=size))
@expect(messages.SignedIdentity)
def sign_identity(
client: "TrezorClient",
session: "Session",
identity: messages.IdentityType,
challenge_hidden: bytes,
challenge_visual: str,
ecdsa_curve_name: Optional[str] = None,
) -> "MessageType":
return client.call(
return session.call(
messages.SignIdentity(
identity=identity,
challenge_hidden=challenge_hidden,
@ -50,12 +50,12 @@ def sign_identity(
@expect(messages.ECDHSessionKey)
def get_ecdh_session_key(
client: "TrezorClient",
session: "Session",
identity: messages.IdentityType,
peer_public_key: bytes,
ecdsa_curve_name: Optional[str] = None,
) -> "MessageType":
return client.call(
return session.call(
messages.GetECDHSessionKey(
identity=identity,
peer_public_key=peer_public_key,
@ -66,7 +66,7 @@ def get_ecdh_session_key(
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
def encrypt_keyvalue(
client: "TrezorClient",
session: "Session",
n: "Address",
key: str,
value: bytes,
@ -74,7 +74,7 @@ def encrypt_keyvalue(
ask_on_decrypt: bool = True,
iv: bytes = b"",
) -> "MessageType":
return client.call(
return session.call(
messages.CipherKeyValue(
address_n=n,
key=key,
@ -89,7 +89,7 @@ def encrypt_keyvalue(
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
def decrypt_keyvalue(
client: "TrezorClient",
session: "Session",
n: "Address",
key: str,
value: bytes,
@ -97,7 +97,7 @@ def decrypt_keyvalue(
ask_on_decrypt: bool = True,
iv: bytes = b"",
) -> "MessageType":
return client.call(
return session.call(
messages.CipherKeyValue(
address_n=n,
key=key,
@ -111,5 +111,5 @@ def decrypt_keyvalue(
@expect(messages.Nonce, field="nonce", ret_type=bytes)
def get_nonce(client: "TrezorClient"):
return client.call(messages.GetNonce())
def get_nonce(session: "Session"):
return session.call(messages.GetNonce())

View File

@ -20,9 +20,9 @@ from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.session import Session
# MAINNET = 0
@ -33,13 +33,13 @@ if TYPE_CHECKING:
@expect(messages.MoneroAddress, field="address", ret_type=bytes)
def get_address(
client: "TrezorClient",
session: "Session",
n: "Address",
show_display: bool = False,
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.MoneroGetAddress(
address_n=n,
show_display=show_display,
@ -51,10 +51,10 @@ def get_address(
@expect(messages.MoneroWatchKey)
def get_watch_key(
client: "TrezorClient",
session: "Session",
n: "Address",
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
) -> "MessageType":
return client.call(
return session.call(
messages.MoneroGetWatchKey(address_n=n, network_type=network_type)
)

View File

@ -21,9 +21,9 @@ from . import exceptions, messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.session import Session
TYPE_TRANSACTION_TRANSFER = 0x0101
TYPE_IMPORTANCE_TRANSFER = 0x0801
@ -198,13 +198,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig
@expect(messages.NEMAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
session: "Session",
n: "Address",
network: int,
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.NEMGetAddress(
address_n=n, network=network, show_display=show_display, chunkify=chunkify
)
@ -213,7 +213,7 @@ def get_address(
@expect(messages.NEMSignedTx)
def sign_tx(
client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False
session: "Session", n: "Address", transaction: dict, chunkify: bool = False
) -> "MessageType":
try:
msg = create_sign_tx(transaction, chunkify=chunkify)
@ -222,4 +222,4 @@ def sign_tx(
assert msg.transaction is not None
msg.transaction.address_n = n
return client.call(msg)
return session.call(msg)

View File

@ -21,9 +21,9 @@ from .protobuf import dict_to_proto
from .tools import dict_from_camelcase, expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.session import Session
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
@ -31,12 +31,12 @@ REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
@expect(messages.RippleAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
session: "Session",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.RippleGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
)
@ -45,14 +45,14 @@ def get_address(
@expect(messages.RippleSignedTx)
def sign_tx(
client: "TrezorClient",
session: "Session",
address_n: "Address",
msg: messages.RippleSignTx,
chunkify: bool = False,
) -> "MessageType":
msg.address_n = address_n
msg.chunkify = chunkify
return client.call(msg)
return session.call(msg)
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:

View File

@ -4,29 +4,29 @@ from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .transport.session import Session
@expect(messages.SolanaPublicKey)
def get_public_key(
client: "TrezorClient",
session: "Session",
address_n: List[int],
show_display: bool,
) -> "MessageType":
return client.call(
return session.call(
messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display)
)
@expect(messages.SolanaAddress)
def get_address(
client: "TrezorClient",
session: "Session",
address_n: List[int],
show_display: bool,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.SolanaGetAddress(
address_n=address_n,
show_display=show_display,
@ -37,12 +37,12 @@ def get_address(
@expect(messages.SolanaTxSignature)
def sign_tx(
client: "TrezorClient",
session: "Session",
address_n: List[int],
serialized_tx: bytes,
additional_info: Optional[messages.SolanaTxAdditionalInfo],
) -> "MessageType":
return client.call(
return session.call(
messages.SolanaSignTx(
address_n=address_n,
serialized_tx=serialized_tx,

View File

@ -21,9 +21,9 @@ from . import exceptions, messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.session import Session
StellarMessageType = Union[
messages.StellarAccountMergeOp,
@ -325,12 +325,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
@expect(messages.StellarAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
session: "Session",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.StellarGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
)
@ -338,7 +338,7 @@ def get_address(
def sign_tx(
client: "TrezorClient",
session: "Session",
tx: messages.StellarSignTx,
operations: List["StellarMessageType"],
address_n: "Address",
@ -354,10 +354,10 @@ def sign_tx(
# 3. Receive a StellarTxOpRequest message
# 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message
# 5. The final message received will be StellarSignedTx which is returned from this method
resp = client.call(tx)
resp = session.call(tx)
try:
while isinstance(resp, messages.StellarTxOpRequest):
resp = client.call(operations.pop(0))
resp = session.call(operations.pop(0))
except IndexError:
# pop from empty list
raise exceptions.TrezorException(

View File

@ -20,19 +20,19 @@ from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
from .transport.session import Session
@expect(messages.TezosAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
session: "Session",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.TezosGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
)
@ -41,12 +41,12 @@ def get_address(
@expect(messages.TezosPublicKey, field="public_key", ret_type=str)
def get_public_key(
client: "TrezorClient",
session: "Session",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.TezosGetPublicKey(
address_n=address_n, show_display=show_display, chunkify=chunkify
)
@ -55,11 +55,11 @@ def get_public_key(
@expect(messages.TezosSignedTx)
def sign_tx(
client: "TrezorClient",
session: "Session",
address_n: "Address",
sign_tx_msg: messages.TezosSignTx,
chunkify: bool = False,
) -> "MessageType":
sign_tx_msg.address_n = address_n
sign_tx_msg.chunkify = chunkify
return client.call(sign_tx_msg)
return session.call(sign_tx_msg)

View File

@ -40,7 +40,7 @@ if TYPE_CHECKING:
# More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar
from typing_extensions import Concatenate, ParamSpec
from typing_extensions import ParamSpec
from . import client
from .protobuf import MessageType
@ -301,23 +301,6 @@ def expect(
return decorator
def session(
f: "Callable[Concatenate[TrezorClient, P], R]",
) -> "Callable[Concatenate[TrezorClient, P], R]":
# Decorator wraps a BaseClient method
# with session activation / deactivation
@functools.wraps(f)
def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R":
__tracebackhide__ = True # for pytest # pylint: disable=W0612
client.open()
try:
return f(client, *args, **kwargs)
finally:
client.close()
return wrapped_f
# de-camelcasifier
# https://stackoverflow.com/a/1176023/222189

View File

@ -1,6 +1,6 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
# Copyright (C) 2012-2024 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
@ -14,24 +14,18 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
from typing import (
TYPE_CHECKING,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
import typing as t
from ..exceptions import TrezorException
if TYPE_CHECKING:
if t.TYPE_CHECKING:
from ..models import TrezorModel
T = TypeVar("T", bound="Transport")
T = t.TypeVar("T", bound="Transport")
LOG = logging.getLogger(__name__)
@ -41,7 +35,7 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
""".strip()
MessagePayload = Tuple[int, bytes]
MessagePayload = t.Tuple[int, bytes]
class TransportException(TrezorException):
@ -53,72 +47,54 @@ class DeviceIsBusy(TransportException):
class Transport:
"""Raw connection to a Trezor device.
Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB
or USB-HID connection, or UDP socket of listening emulator(s).
It can also enumerate devices available over this communication link, and return
them as instances.
Transport instance is a thing that:
- can be identified and requested by a string URI-like path
- can open and close sessions, which enclose related operations
- can read and write protobuf messages
You need to implement a new Transport subclass if you invent a new way to connect
a Trezor device to a computer.
"""
PATH_PREFIX: str
ENABLED = False
def __str__(self) -> str:
return self.get_path()
@classmethod
def enumerate(
cls: t.Type["T"], models: t.Iterable["TrezorModel"] | None = None
) -> t.Iterable["T"]:
raise NotImplementedError
@classmethod
def find_by_path(cls: t.Type["T"], path: str, prefix_search: bool = False) -> "T":
for device in cls.enumerate():
if device.get_path() == path:
return device
if prefix_search and device.get_path().startswith(path):
return device
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
def get_path(self) -> str:
raise NotImplementedError
def begin_session(self) -> None:
raise NotImplementedError
def end_session(self) -> None:
raise NotImplementedError
def read(self) -> MessagePayload:
raise NotImplementedError
def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError
def find_debug(self: "T") -> "T":
raise NotImplementedError
@classmethod
def enumerate(
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["T"]:
def open(self) -> None:
raise NotImplementedError
@classmethod
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
for device in cls.enumerate():
if (
path is None
or device.get_path() == path
or (prefix_search and device.get_path().startswith(path))
):
return device
def close(self) -> None:
raise NotImplementedError
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
def write_chunk(self, chunk: bytes) -> None:
raise NotImplementedError
def read_chunk(self) -> bytes:
raise NotImplementedError
CHUNK_SIZE: t.ClassVar[int]
def all_transports() -> Iterable[Type["Transport"]]:
def all_transports() -> t.Iterable[t.Type["Transport"]]:
from .bridge import BridgeTransport
from .hid import HidTransport
from .udp import UdpTransport
from .webusb import WebUsbTransport
transports: Tuple[Type["Transport"], ...] = (
transports: t.Tuple[t.Type["Transport"], ...] = (
BridgeTransport,
HidTransport,
UdpTransport,
@ -128,9 +104,9 @@ def all_transports() -> Iterable[Type["Transport"]]:
def enumerate_devices(
models: Optional[Iterable["TrezorModel"]] = None,
) -> Sequence["Transport"]:
devices: List["Transport"] = []
models: t.Iterable["TrezorModel"] | None = None,
) -> t.Sequence["Transport"]:
devices: t.List["Transport"] = []
for transport in all_transports():
name = transport.__name__
try:
@ -145,9 +121,7 @@ def enumerate_devices(
return devices
def get_transport(
path: Optional[str] = None, prefix_search: bool = False
) -> "Transport":
def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport":
if path is None:
try:
return next(iter(enumerate_devices()))

View File

@ -1,6 +1,6 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
# Copyright (C) 2012-2024 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
@ -14,24 +14,30 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
import struct
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
import typing as t
import requests
from ..log import DUMP_PACKETS
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
if TYPE_CHECKING:
if t.TYPE_CHECKING:
from ..models import TrezorModel
LOG = logging.getLogger(__name__)
PROTOCOL_VERSION_1 = 1
PROTOCOL_VERSION_2 = 2
TREZORD_HOST = "http://127.0.0.1:21325"
TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"}
TREZORD_VERSION_MODERN = (2, 0, 25)
TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value
CONNECTION = requests.Session()
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
@ -45,7 +51,7 @@ class BridgeException(TransportException):
super().__init__(f"trezord: {path} failed with code {status}: {message}")
def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
def call_bridge(path: str, data: str | None = None) -> requests.Response:
url = TREZORD_HOST + "/" + path
r = CONNECTION.post(url, data=data)
if r.status_code != 200:
@ -53,10 +59,54 @@ def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
return r
def is_legacy_bridge() -> bool:
def get_bridge_version() -> t.Tuple[int, ...]:
config = call_bridge("configure").json()
version_tuple = tuple(map(int, config["version"].split(".")))
return version_tuple < TREZORD_VERSION_MODERN
return tuple(map(int, config["version"].split(".")))
def is_legacy_bridge() -> bool:
return get_bridge_version() < TREZORD_VERSION_MODERN
def supports_protocolV2() -> bool:
return get_bridge_version() >= TREZORD_VERSION_THP_SUPPORT
def detect_protocol_version(transport: "BridgeTransport") -> int:
from .. import mapping, messages
from ..messages import FailureType
protocol_version = PROTOCOL_VERSION_1
request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize())
transport.deprecated_begin_session()
transport.deprecated_write(request_type, request_data)
response_type, response_data = transport.deprecated_read()
response = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
transport.deprecated_begin_session()
if isinstance(response, messages.Failure):
if response.code == FailureType.InvalidProtocol:
LOG.debug("Protocol V2 detected")
protocol_version = PROTOCOL_VERSION_2
return protocol_version
def _is_transport_valid(transport: "BridgeTransport") -> bool:
is_valid = (
supports_protocolV2()
or detect_protocol_version(transport) == PROTOCOL_VERSION_1
)
if not is_valid:
LOG.warning("Detected unsupported Bridge transport!")
return is_valid
def filter_invalid_bridge_transports(
transports: t.Iterable["BridgeTransport"],
) -> t.Sequence["BridgeTransport"]:
"""Filters out invalid bridge transports. Keeps only valid ones."""
return [t for t in transports if _is_transport_valid(t)]
class BridgeHandle:
@ -84,7 +134,7 @@ class BridgeHandleModern(BridgeHandle):
class BridgeHandleLegacy(BridgeHandle):
def __init__(self, transport: "BridgeTransport") -> None:
super().__init__(transport)
self.request: Optional[str] = None
self.request: str | None = None
def write_buf(self, buf: bytes) -> None:
if self.request is not None:
@ -112,13 +162,12 @@ class BridgeTransport(Transport):
ENABLED: bool = True
def __init__(
self, device: Dict[str, Any], legacy: bool, debug: bool = False
self, device: t.Dict[str, t.Any], legacy: bool, debug: bool = False
) -> None:
if legacy and debug:
raise TransportException("Debugging not supported on legacy Bridge")
self.device = device
self.session: Optional[str] = None
self.session: str | None = device["session"]
self.debug = debug
self.legacy = legacy
@ -135,7 +184,7 @@ class BridgeTransport(Transport):
raise TransportException("Debug device not available")
return BridgeTransport(self.device, self.legacy, debug=True)
def _call(self, action: str, data: Optional[str] = None) -> requests.Response:
def _call(self, action: str, data: str | None = None) -> requests.Response:
session = self.session or "null"
uri = action + "/" + str(session)
if self.debug:
@ -144,17 +193,20 @@ class BridgeTransport(Transport):
@classmethod
def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["BridgeTransport"]:
cls, _models: t.Iterable["TrezorModel"] | None = None
) -> t.Iterable["BridgeTransport"]:
try:
legacy = is_legacy_bridge()
return [
BridgeTransport(dev, legacy) for dev in call_bridge("enumerate").json()
return filter_invalid_bridge_transports(
[
BridgeTransport(dev, legacy)
for dev in call_bridge("enumerate").json()
]
)
except Exception:
return []
def begin_session(self) -> None:
def deprecated_begin_session(self) -> None:
try:
data = self._call("acquire/" + self.device["path"])
except BridgeException as e:
@ -163,18 +215,32 @@ class BridgeTransport(Transport):
raise
self.session = data.json()["session"]
def end_session(self) -> None:
def deprecated_end_session(self) -> None:
if not self.session:
return
self._call("release")
self.session = None
def write(self, message_type: int, message_data: bytes) -> None:
def deprecated_write(self, message_type: int, message_data: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data))
self.handle.write_buf(header + message_data)
def read(self) -> MessagePayload:
def deprecated_read(self) -> MessagePayload:
data = self.handle.read_buf()
headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
return msg_type, data[headerlen : headerlen + datalen]
def open(self) -> None:
pass
# TODO self.handle.open()
def close(self) -> None:
pass
# TODO self.handle.close()
def write_chunk(self, chunk: bytes) -> None: # TODO check if it works :)
self.handle.write_buf(chunk)
def read_chunk(self) -> bytes: # TODO check if it works :)
return self.handle.read_buf()

View File

@ -1,6 +1,6 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
# Copyright (C) 2012-2024 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
@ -14,15 +14,16 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
import sys
import time
from typing import Any, Dict, Iterable, List, Optional
import typing as t
from ..log import DUMP_PACKETS
from ..models import TREZOR_ONE, TrezorModel
from . import UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
from . import UDEV_RULES_STR, Transport, TransportException
LOG = logging.getLogger(__name__)
@ -35,23 +36,61 @@ except Exception as e:
HID_IMPORTED = False
HidDevice = Dict[str, Any]
HidDeviceHandle = Any
HidDevice = t.Dict[str, t.Any]
HidDeviceHandle = t.Any
class HidHandle:
def __init__(
self, path: bytes, serial: str, probe_hid_version: bool = False
) -> None:
self.path = path
self.serial = serial
class HidTransport(Transport):
"""
HidTransport implements transport over USB HID interface.
"""
PATH_PREFIX = "hid"
ENABLED = HID_IMPORTED
def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None:
self.device = device
self.device_path = device["path"]
self.device_serial_number = device["serial_number"]
self.handle: HidDeviceHandle = None
self.hid_version = None if probe_hid_version else 2
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
@classmethod
def enumerate(
cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False
) -> t.Iterable["HidTransport"]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: t.List["HidTransport"] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids:
continue
if debug:
if not is_debuglink(dev):
continue
else:
if not is_wirelink(dev):
continue
devices.append(HidTransport(dev))
return devices
def find_debug(self) -> "HidTransport":
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def open(self) -> None:
self.handle = hid.device()
try:
self.handle.open_path(self.path)
self.handle.open_path(self.device_path)
except (IOError, OSError) as e:
if sys.platform.startswith("linux"):
e.args = e.args + (UDEV_RULES_STR,)
@ -62,11 +101,11 @@ class HidHandle:
# and we wouldn't even know.
# So we check that the serial matches what we expect.
serial = self.handle.get_serial_number_string()
if serial != self.serial:
if serial != self.device_serial_number:
self.handle.close()
self.handle = None
raise TransportException(
f"Unexpected device {serial} on path {self.path.decode()}"
f"Unexpected device {serial} on path {self.device_path.decode()}"
)
self.handle.set_nonblocking(True)
@ -77,7 +116,7 @@ class HidHandle:
def close(self) -> None:
if self.handle is not None:
# reload serial, because device.wipe() can reset it
self.serial = self.handle.get_serial_number_string()
self.device_serial_number = self.handle.get_serial_number_string()
self.handle.close()
self.handle = None
@ -115,53 +154,6 @@ class HidHandle:
raise TransportException("Unknown HID version")
class HidTransport(ProtocolBasedTransport):
"""
HidTransport implements transport over USB HID interface.
"""
PATH_PREFIX = "hid"
ENABLED = HID_IMPORTED
def __init__(self, device: HidDevice) -> None:
self.device = device
self.handle = HidHandle(device["path"], device["serial_number"])
super().__init__(protocol=ProtocolV1(self.handle))
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
@classmethod
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
) -> Iterable["HidTransport"]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["HidTransport"] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids:
continue
if debug:
if not is_debuglink(dev):
continue
else:
if not is_wirelink(dev):
continue
devices.append(HidTransport(dev))
return devices
def find_debug(self) -> "HidTransport":
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def is_wirelink(dev: HidDevice) -> bool:
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0

View File

@ -1,165 +0,0 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import struct
from typing import Tuple
from typing_extensions import Protocol as StructuralType
from . import MessagePayload, Transport
REPLEN = 64
V2_FIRST_CHUNK = 0x01
V2_NEXT_CHUNK = 0x02
V2_BEGIN_SESSION = 0x03
V2_END_SESSION = 0x04
LOG = logging.getLogger(__name__)
class Handle(StructuralType):
"""PEP 544 structural type for Handle functionality.
(called a "Protocol" in the proposed PEP, name which is impractical here)
Handle is a "physical" layer for a protocol.
It can open/close a connection and read/write bare data in 64-byte chunks.
Functionally we gain nothing from making this an (abstract) base class for handle
implementations, so this definition is for type hinting purposes only. You can,
but don't have to, inherit from it.
"""
def open(self) -> None: ...
def close(self) -> None: ...
def read_chunk(self) -> bytes: ...
def write_chunk(self, chunk: bytes) -> None: ...
class Protocol:
"""Wire protocol that can communicate with a Trezor device, given a Handle.
A Protocol implements the part of the Transport API that relates to communicating
logical messages over a physical layer. It is a thing that can:
- open and close sessions,
- send and receive protobuf messages,
given the ability to:
- open and close physical connections,
- and send and receive binary chunks.
For now, the class also handles session counting and opening the underlying Handle.
This will probably be removed in the future.
We will need a new Protocol class if we change the way a Trezor device encapsulates
its messages.
"""
def __init__(self, handle: Handle) -> None:
self.handle = handle
self.session_counter = 0
# XXX we might be able to remove this now that TrezorClient does session handling
def begin_session(self) -> None:
if self.session_counter == 0:
self.handle.open()
self.session_counter += 1
def end_session(self) -> None:
self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
self.handle.close()
def read(self) -> MessagePayload:
raise NotImplementedError
def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError
class ProtocolBasedTransport(Transport):
"""Transport that implements its communications through a Protocol.
Intended as a base class for implementations that proxy their communication
operations to a Protocol.
"""
def __init__(self, protocol: Protocol) -> None:
self.protocol = protocol
def write(self, message_type: int, message_data: bytes) -> None:
self.protocol.write(message_type, message_data)
def read(self) -> MessagePayload:
return self.protocol.read()
def begin_session(self) -> None:
self.protocol.begin_session()
def end_session(self) -> None:
self.protocol.end_session()
class ProtocolV1(Protocol):
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
Does not understand sessions.
"""
HEADER_LEN = struct.calcsize(">HL")
def write(self, message_type: int, message_data: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data))
buffer = bytearray(b"##" + header + message_data)
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: REPLEN - 1]
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
buffer = buffer[63:]
def read(self) -> MessagePayload:
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
return msg_type, buffer[:datalen]
def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError("Unexpected magic characters")
try:
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.handle.read_chunk()
if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters")
return chunk[1:]

View File

@ -0,0 +1,210 @@
from __future__ import annotations
import logging
import typing as t
from .. import exceptions, messages, models
from .thp.protocol_v1 import ProtocolV1
from .thp.protocol_v2 import ProtocolV2
if t.TYPE_CHECKING:
from ..client import TrezorClient
LOG = logging.getLogger(__name__)
class Session:
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
def __init__(
self, client: TrezorClient, id: bytes, passphrase: str | object | None = None
) -> None:
self.client = client
self._id = id
self.passphrase = passphrase
@classmethod
def new(
cls, client: TrezorClient, passphrase: str | object | None, derive_cardano: bool
) -> Session:
raise NotImplementedError
def call(self, msg: t.Any) -> t.Any:
# TODO self.check_firmware_version()
resp = self.call_raw(msg)
while True:
if isinstance(resp, messages.PinMatrixRequest):
if self.pin_callback is None:
raise Exception # TODO
resp = self.pin_callback(self, resp)
elif isinstance(resp, messages.PassphraseRequest):
if self.passphrase_callback is None:
raise Exception # TODO
resp = self.passphrase_callback(self, resp)
elif isinstance(resp, messages.ButtonRequest):
if self.button_callback is None:
raise Exception # TODO
resp = self.button_callback(self, resp)
elif isinstance(resp, messages.Failure):
if resp.code == messages.FailureType.ActionCancelled:
raise exceptions.Cancelled
raise exceptions.TrezorFailure(resp)
else:
return resp
def call_raw(self, msg: t.Any) -> t.Any:
self._write(msg)
return self._read()
def _write(self, msg: t.Any) -> None:
raise NotImplementedError
def _read(self) -> t.Any:
raise NotImplementedError
def refresh_features(self) -> None:
self.client.refresh_features()
def end(self) -> t.Any:
return self.call(messages.EndSession())
def ping(self, message: str, button_protection: bool | None = None) -> str:
resp: messages.Success = self.call(
messages.Ping(message=message, button_protection=button_protection)
)
return resp.message or ""
@property
def features(self) -> messages.Features:
return self.client.features
@property
def model(self) -> models.TrezorModel:
return self.client.model
@property
def version(self) -> t.Tuple[int, int, int]:
return self.client.version
@property
def id(self) -> bytes:
return self._id
@id.setter
def id(self, value: bytes) -> None:
if not isinstance(value, bytes):
raise ValueError("id must be of type bytes")
self._id = value
class SessionV1(Session):
derive_cardano: bool | None = False
@classmethod
def new(
cls,
client: TrezorClient,
passphrase: str | object = "",
derive_cardano: bool = False,
session_id: bytes | None = None,
) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1)
session = SessionV1(client, id=session_id or b"")
session._init_callbacks()
session.passphrase = passphrase
session.derive_cardano = derive_cardano
session.init_session(session.derive_cardano)
return session
@classmethod
def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1)
session = SessionV1(client, session_id)
session.init_session()
return session
def _init_callbacks(self) -> None:
self.button_callback = self.client.button_callback
if self.button_callback is None:
self.button_callback = _callback_button
self.pin_callback = self.client.pin_callback
self.passphrase_callback = self.client.passphrase_callback
def _write(self, msg: t.Any) -> None:
if t.TYPE_CHECKING:
assert isinstance(self.client.protocol, ProtocolV1)
self.client.protocol.write(msg)
def _read(self) -> t.Any:
if t.TYPE_CHECKING:
assert isinstance(self.client.protocol, ProtocolV1)
return self.client.protocol.read()
def init_session(self, derive_cardano: bool | None = None):
if self.id == b"":
session_id = None
else:
session_id = self.id
resp: messages.Features = self.call_raw(
messages.Initialize(session_id=session_id, derive_cardano=derive_cardano)
)
if isinstance(self.passphrase, str):
self.passphrase_callback = _send_passphrase
self._id = resp.session_id
def _send_passphrase(session: Session, resp: t.Any) -> None:
assert isinstance(session.passphrase, str)
return session.call(messages.PassphraseAck(passphrase=session.passphrase))
def _callback_button(session: Session, msg: t.Any) -> t.Any:
print("Please confirm action on your Trezor device.") # TODO how to handle UI?
return session.call(messages.ButtonAck())
class SessionV2(Session):
@classmethod
def new(
cls, client: TrezorClient, passphrase: str | None, derive_cardano: bool
) -> SessionV2:
assert isinstance(client.protocol, ProtocolV2)
session = cls(client, b"\x00")
new_session: messages.ThpNewSession = session.call(
messages.ThpCreateNewSession(
passphrase=passphrase, derive_cardano=derive_cardano
)
)
assert new_session.new_session_id is not None
session_id = new_session.new_session_id
session.update_id_and_sid(session_id.to_bytes(1, "big"))
return session
def __init__(self, client: TrezorClient, id: bytes) -> None:
super().__init__(client, id)
assert isinstance(client.protocol, ProtocolV2)
self.pin_callback = client.pin_callback
self.button_callback = client.button_callback
if self.button_callback is None:
self.button_callback = _callback_button
self.channel: ProtocolV2 = client.protocol.get_channel()
self.update_id_and_sid(id)
def _write(self, msg: t.Any) -> None:
LOG.debug("writing message %s", type(msg))
self.channel.write(self.sid, msg)
def _read(self) -> t.Any:
msg = self.channel.read(self.sid)
LOG.debug("reading message %s", type(msg))
return msg
def update_id_and_sid(self, id: bytes) -> None:
self._id = id
self.sid = int.from_bytes(id, "big") # TODO update to extract only sid

View File

@ -0,0 +1,102 @@
# from storage.cache_thp import ChannelCache
# from trezor import log
# from trezor.wire.thp import ThpError
# def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool:
# """
# Checks if:
# - an ACK message is expected
# - the received ACK message acknowledges correct sequence number (bit)
# """
# if not _is_ack_expected(cache):
# return False
# if not _has_ack_correct_sync_bit(cache, ack_bit):
# return False
# return True
# def _is_ack_expected(cache: ChannelCache) -> bool:
# is_expected: bool = not is_sending_allowed(cache)
# if __debug__ and not is_expected:
# log.debug(__name__, "Received unexpected ACK message")
# return is_expected
# def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
# is_correct: bool = get_send_seq_bit(cache) == sync_bit
# if __debug__ and not is_correct:
# log.debug(__name__, "Received ACK message with wrong ack bit")
# return is_correct
# def is_sending_allowed(cache: ChannelCache) -> bool:
# """
# Checks whether sending a message in the provided channel is allowed.
# Note: Sending a message in a channel before receipt of ACK message for the previously
# sent message (in the channel) is prohibited, as it can lead to desynchronization.
# """
# return bool(cache.sync >> 7)
# def get_send_seq_bit(cache: ChannelCache) -> int:
# """
# Returns the sequential number (bit) of the next message to be sent
# in the provided channel.
# """
# return (cache.sync & 0x20) >> 5
# def get_expected_receive_seq_bit(cache: ChannelCache) -> int:
# """
# Returns the (expected) sequential number (bit) of the next message
# to be received in the provided channel.
# """
# return (cache.sync & 0x40) >> 6
# def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None:
# """
# Set the flag whether sending a message in this channel is allowed or not.
# """
# cache.sync &= 0x7F
# if sending_allowed:
# cache.sync |= 0x80
# def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
# """
# Set the expected sequential number (bit) of the next message to be received
# in the provided channel
# """
# if __debug__:
# log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit)
# if seq_bit not in (0, 1):
# raise ThpError("Unexpected receive sync bit")
# # set second bit to "seq_bit" value
# cache.sync &= 0xBF
# if seq_bit:
# cache.sync |= 0x40
# def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
# if seq_bit not in (0, 1):
# raise ThpError("Unexpected send seq bit")
# if __debug__:
# log.debug(__name__, "setting sync send seq bit to %d", seq_bit)
# # set third bit to "seq_bit" value
# cache.sync &= 0xDF
# if seq_bit:
# cache.sync |= 0x20
# def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None:
# """
# Set the sequential bit of the "next message to be send" to the opposite value,
# i.e. 1 -> 0 and 0 -> 1
# """
# _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache))

View File

@ -0,0 +1,40 @@
from __future__ import annotations
from binascii import hexlify
class ChannelData:
def __init__(
self,
protocol_version: int,
transport_path: str,
channel_id: int,
key_request: bytes,
key_response: bytes,
nonce_request: int,
nonce_response: int,
sync_bit_send: int,
sync_bit_receive: int,
) -> None:
self.protocol_version: int = protocol_version
self.transport_path: str = transport_path
self.channel_id: int = channel_id
self.key_request: str = hexlify(key_request).decode()
self.key_response: str = hexlify(key_response).decode()
self.nonce_request: int = nonce_request
self.nonce_response: int = nonce_response
self.sync_bit_receive: int = sync_bit_receive
self.sync_bit_send: int = sync_bit_send
def to_dict(self):
return {
"protocol_version": self.protocol_version,
"transport_path": self.transport_path,
"channel_id": self.channel_id,
"key_request": self.key_request,
"key_response": self.key_response,
"nonce_request": self.nonce_request,
"nonce_response": self.nonce_response,
"sync_bit_send": self.sync_bit_send,
"sync_bit_receive": self.sync_bit_receive,
}

View File

@ -0,0 +1,146 @@
from __future__ import annotations
import json
import logging
import os
import typing as t
from ..thp.channel_data import ChannelData
from .protocol_and_channel import ProtocolAndChannel
LOG = logging.getLogger(__name__)
db: "ChannelDatabase | None" = None
def get_channel_db() -> ChannelDatabase:
if db is None:
set_channel_database(should_not_store=True)
assert db is not None
return db
class ChannelDatabase:
def load_stored_channels(self) -> t.List[ChannelData]: ...
def clear_stored_channels(self) -> None: ...
def read_all_channels(self) -> t.List: ...
def save_all_channels(self, channels: t.List[t.Dict]) -> None: ...
def save_channel(self, new_channel: ProtocolAndChannel): ...
def remove_channel(self, transport_path: str) -> None: ...
class DummyChannelDatabase(ChannelDatabase):
def load_stored_channels(self) -> t.List[ChannelData]:
return []
def clear_stored_channels(self) -> None:
pass
def read_all_channels(self) -> t.List:
return []
def save_all_channels(self, channels: t.List[t.Dict]) -> None:
return
def save_channel(self, new_channel: ProtocolAndChannel):
pass
def remove_channel(self, transport_path: str) -> None:
pass
class JsonChannelDatabase(ChannelDatabase):
def __init__(self, data_path: str) -> None:
self.data_path = data_path
super().__init__()
def load_stored_channels(self) -> t.List[ChannelData]:
dicts = self.read_all_channels()
return [dict_to_channel_data(d) for d in dicts]
def clear_stored_channels(self) -> None:
LOG.debug("Clearing contents of %s", self.data_path)
with open(self.data_path, "w") as f:
json.dump([], f)
try:
os.remove(self.data_path)
except Exception as e:
LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e)))
def read_all_channels(self) -> t.List:
ensure_file_exists(self.data_path)
with open(self.data_path, "r") as f:
return json.load(f)
def save_all_channels(self, channels: t.List[t.Dict]) -> None:
LOG.debug("saving all channels")
with open(self.data_path, "w") as f:
json.dump(channels, f, indent=4)
def save_channel(self, new_channel: ProtocolAndChannel):
LOG.debug("save channel")
channels = self.read_all_channels()
transport_path = new_channel.transport.get_path()
# If the channel is found in database: replace the old entry by the new
for i, channel in enumerate(channels):
if channel["transport_path"] == transport_path:
LOG.debug("Modified channel entry for %s", transport_path)
channels[i] = new_channel.get_channel_data().to_dict()
self.save_all_channels(channels)
return
# Channel was not found: add a new channel entry
LOG.debug("Created a new channel entry on path %s", transport_path)
channels.append(new_channel.get_channel_data().to_dict())
self.save_all_channels(channels)
def remove_channel(self, transport_path: str) -> None:
LOG.debug(
"Removing channel with path %s from the channel database.",
transport_path,
)
channels = self.read_all_channels()
remaining_channels = [
ch for ch in channels if ch["transport_path"] != transport_path
]
self.save_all_channels(remaining_channels)
def dict_to_channel_data(dict: t.Dict) -> ChannelData:
return ChannelData(
protocol_version=dict["protocol_version"],
transport_path=dict["transport_path"],
channel_id=dict["channel_id"],
key_request=bytes.fromhex(dict["key_request"]),
key_response=bytes.fromhex(dict["key_response"]),
nonce_request=dict["nonce_request"],
nonce_response=dict["nonce_response"],
sync_bit_send=dict["sync_bit_send"],
sync_bit_receive=dict["sync_bit_receive"],
)
def ensure_file_exists(file_path: str) -> None:
LOG.debug("checking if file %s exists", file_path)
if not os.path.exists(file_path):
os.makedirs(os.path.dirname(file_path), exist_ok=True)
LOG.debug("File %s does not exist. Creating a new one.", file_path)
with open(file_path, "w") as f:
json.dump([], f)
def set_channel_database(should_not_store: bool):
global db
if should_not_store:
db = DummyChannelDatabase()
else:
from platformdirs import user_cache_dir
APP_NAME = "@trezor" # TODO
DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json")
db = JsonChannelDatabase(DATA_PATH)

View File

@ -0,0 +1,19 @@
import zlib
CHECKSUM_LENGTH = 4
def compute(data: bytes) -> bytes:
"""
Returns a CRC-32 checksum of the provided `data`.
"""
return zlib.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
def is_valid(checksum: bytes, data: bytes) -> bool:
"""
Checks whether the CRC-32 checksum of the `data` is the same
as the checksum provided in `checksum`.
"""
data_checksum = compute(data)
return checksum == data_checksum

View File

@ -0,0 +1,59 @@
CODEC_V1 = 0x3F
CONTINUATION_PACKET = 0x80
HANDSHAKE_INIT_REQ = 0x00
HANDSHAKE_INIT_RES = 0x01
HANDSHAKE_COMP_REQ = 0x02
HANDSHAKE_COMP_RES = 0x03
ENCRYPTED_TRANSPORT = 0x04
CONTINUATION_PACKET_MASK = 0x80
ACK_MASK = 0xF7
DATA_MASK = 0xE7
ACK_MESSAGE = 0x20
_ERROR = 0x42
CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x41
TREZOR_STATE_UNPAIRED = b"\x00"
TREZOR_STATE_PAIRED = b"\x01"
def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int:
if seq_bit == 0:
return ctrl_byte & 0xEF
if seq_bit == 1:
return ctrl_byte | 0x10
raise Exception("Unexpected sequence bit")
def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int:
if ack_bit == 0:
return ctrl_byte & 0xF7
if ack_bit == 1:
return ctrl_byte | 0x08
raise Exception("Unexpected acknowledgement bit")
def get_seq_bit(ctrl_byte: int) -> int:
return (ctrl_byte & 0x10) >> 4
def is_ack(ctrl_byte: int) -> bool:
return ctrl_byte & ACK_MASK == ACK_MESSAGE
def is_continuation(ctrl_byte: int) -> bool:
return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET
def is_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
def is_handshake_init_req(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ
def is_handshake_comp_req(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ

View File

@ -0,0 +1,116 @@
from typing import Tuple
p = 2**255 - 19
J = 486662
c3 = 19681161376707505956807079304988542015446066515923890162744021073123829784752 # sqrt(-1)
c4 = 7237005577332262213973186563042994240829374041602535252466099000494570602493 # (p - 5) // 8
a24 = 121666 # (J + 2) // 4
def decode_scalar(scalar: bytes) -> int:
# decodeScalar25519 from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
if len(scalar) != 32:
raise ValueError("Invalid length of scalar")
array = bytearray(scalar)
array[0] &= 248
array[31] &= 127
array[31] |= 64
return int.from_bytes(array, "little")
def decode_coordinate(coordinate: bytes) -> int:
# decodeUCoordinate from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
if len(coordinate) != 32:
raise ValueError("Invalid length of coordinate")
array = bytearray(coordinate)
array[-1] &= 0x7F
return int.from_bytes(array, "little") % p
def encode_coordinate(coordinate: int) -> bytes:
# encodeUCoordinate from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
return coordinate.to_bytes(32, "little")
def get_private_key(secret: bytes) -> bytes:
return decode_scalar(secret).to_bytes(32, "little")
def get_public_key(private_key: bytes) -> bytes:
base_point = int.to_bytes(9, 32, "little")
return multiply(private_key, base_point)
def multiply(private_scalar: bytes, public_point: bytes):
# X25519 from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
def ladder_operation(
x1: int, x2: int, z2: int, x3: int, z3: int
) -> Tuple[int, int, int, int]:
# https://hyperelliptic.org/EFD/g1p/auto-montgom-xz.html#ladder-ladd-1987-m-3
# (x4, z4) = 2 * (x2, z2)
# (x5, z5) = (x2, z2) + (x3, z3)
# where (x1, 1) = (x3, z3) - (x2, z2)
a = (x2 + z2) % p
aa = (a * a) % p
b = (x2 - z2) % p
bb = (b * b) % p
e = (aa - bb) % p
c = (x3 + z3) % p
d = (x3 - z3) % p
da = (d * a) % p
cb = (c * b) % p
t0 = (da + cb) % p
x5 = (t0 * t0) % p
t1 = (da - cb) % p
t2 = (t1 * t1) % p
z5 = (x1 * t2) % p
x4 = (aa * bb) % p
t3 = (a24 * e) % p
t4 = (bb + t3) % p
z4 = (e * t4) % p
return x4, z4, x5, z5
def conditional_swap(first: int, second: int, condition: int):
# Returns (second, first) if condition is true and (first, second) otherwise
# Must be implemented in a way that it is constant time
true_mask = -condition
false_mask = ~true_mask
return (first & false_mask) | (second & true_mask), (second & false_mask) | (
first & true_mask
)
k = decode_scalar(private_scalar)
u = decode_coordinate(public_point)
x_1 = u
x_2 = 1
z_2 = 0
x_3 = u
z_3 = 1
swap = 0
for i in reversed(range(256)):
bit = (k >> i) & 1
swap = bit ^ swap
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
swap = bit
x_2, z_2, x_3, z_3 = ladder_operation(x_1, x_2, z_2, x_3, z_3)
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
x = pow(z_2, p - 2, p) * x_2 % p
return encode_coordinate(x)

View File

@ -0,0 +1,82 @@
import struct
CODEC_V1 = 0x3F
CONTINUATION_PACKET = 0x80
HANDSHAKE_INIT_REQ = 0x00
HANDSHAKE_INIT_RES = 0x01
HANDSHAKE_COMP_REQ = 0x02
HANDSHAKE_COMP_RES = 0x03
ENCRYPTED_TRANSPORT = 0x04
CONTINUATION_PACKET_MASK = 0x80
ACK_MASK = 0xF7
DATA_MASK = 0xE7
ACK_MESSAGE = 0x20
_ERROR = 0x42
CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x41
TREZOR_STATE_UNPAIRED = b"\x00"
TREZOR_STATE_PAIRED = b"\x01"
BROADCAST_CHANNEL_ID = 0xFFFF
class MessageHeader:
format_str_init = ">BHH"
format_str_cont = ">BH"
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.data_length = length
def to_bytes_init(self) -> bytes:
return struct.pack(
self.format_str_init, self.ctrl_byte, self.cid, self.data_length
)
def to_bytes_cont(self) -> bytes:
return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid)
def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
struct.pack_into(
self.format_str_init,
buffer,
buffer_offset,
self.ctrl_byte,
self.cid,
self.data_length,
)
def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
struct.pack_into(
self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
)
def is_ack(self) -> bool:
return self.ctrl_byte & ACK_MASK == ACK_MESSAGE
def is_channel_allocation_response(self):
return (
self.cid == BROADCAST_CHANNEL_ID
and self.ctrl_byte == _CHANNEL_ALLOCATION_RES
)
def is_handshake_init_response(self) -> bool:
return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES
def is_handshake_comp_response(self) -> bool:
return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES
def is_encrypted_transport(self) -> bool:
return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
@classmethod
def get_error_header(cls, cid: int, length: int):
return cls(_ERROR, cid, length)
@classmethod
def get_channel_allocation_request_header(cls, length: int):
return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length)

View File

@ -0,0 +1,32 @@
from __future__ import annotations
import logging
from ... import messages
from ...mapping import ProtobufMapping
from .. import Transport
from ..thp.channel_data import ChannelData
LOG = logging.getLogger(__name__)
class ProtocolAndChannel:
def __init__(
self,
transport: Transport,
mapping: ProtobufMapping,
channel_data: ChannelData | None = None,
) -> None:
self.transport = transport
self.mapping = mapping
self.channel_keys = channel_data
def get_features(self) -> messages.Features:
raise NotImplementedError()
def get_channel_data(self) -> ChannelData:
raise NotImplementedError
def update_features(self) -> None:
raise NotImplementedError

View File

@ -0,0 +1,97 @@
from __future__ import annotations
import logging
import struct
import typing as t
from ... import exceptions, messages
from ...log import DUMP_BYTES
from .protocol_and_channel import ProtocolAndChannel
LOG = logging.getLogger(__name__)
class ProtocolV1(ProtocolAndChannel):
HEADER_LEN = struct.calcsize(">HL")
_features: messages.Features | None = None
def get_features(self) -> messages.Features:
if self._features is None:
self.update_features()
assert self._features is not None
return self._features
def update_features(self) -> None:
self.write(messages.GetFeatures())
resp = self.read()
if not isinstance(resp, messages.Features):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._features = resp
def read(self) -> t.Any:
msg_type, msg_bytes = self._read()
LOG.log(
DUMP_BYTES,
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
msg = self.mapping.decode(msg_type, msg_bytes)
LOG.debug(
f"received message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
self.transport.close()
return msg
def write(self, msg: t.Any) -> None:
LOG.debug(
f"sending message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
msg_type, msg_bytes = self.mapping.encode(msg)
LOG.log(
DUMP_BYTES,
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
self._write(msg_type, msg_bytes)
def _write(self, message_type: int, message_data: bytes) -> None:
chunk_size = self.transport.CHUNK_SIZE
header = struct.pack(">HL", message_type, len(message_data))
buffer = bytearray(b"##" + header + message_data)
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: chunk_size - 1]
chunk = chunk.ljust(chunk_size, b"\x00")
self.transport.write_chunk(chunk)
buffer = buffer[63:]
def _read(self) -> t.Tuple[int, bytes]:
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
return msg_type, buffer[:datalen]
def read_first(self) -> t.Tuple[int, int, bytes]:
chunk = self.transport.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError("Unexpected magic characters")
try:
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.transport.read_chunk()
if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters")
return chunk[1:]

View File

@ -0,0 +1,404 @@
from __future__ import annotations
import hashlib
import hmac
import logging
import os
import typing as t
from binascii import hexlify
from enum import IntEnum
import click
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from ... import exceptions, messages
from ...mapping import ProtobufMapping
from .. import Transport
from ..thp import checksum, curve25519, thp_io
from ..thp.channel_data import ChannelData
from ..thp.checksum import CHECKSUM_LENGTH
from ..thp.message_header import MessageHeader
from . import control_byte
from .channel_database import ChannelDatabase, get_channel_db
from .protocol_and_channel import ProtocolAndChannel
LOG = logging.getLogger(__name__)
MANAGEMENT_SESSION_ID: int = 0
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
hash = hashlib.sha256(val_1)
hash.update(val_2)
return hash.digest()
def _hkdf(chaining_key: bytes, input: bytes):
temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest()
output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest()
ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256)
ctx_output_2.update(b"\x02")
output_2 = ctx_output_2.digest()
return (output_1, output_2)
def _get_iv_from_nonce(nonce: int) -> bytes:
if not nonce <= 0xFFFFFFFFFFFFFFFF:
raise ValueError("Nonce overflow, terminate the channel")
return bytes(4) + nonce.to_bytes(8, "big")
class ProtocolV2(ProtocolAndChannel):
channel_id: int
channel_database: ChannelDatabase
key_request: bytes
key_response: bytes
nonce_request: int
nonce_response: int
sync_bit_send: int
sync_bit_receive: int
_has_valid_channel: bool = False
_features: messages.Features | None = None
def __init__(
self,
transport: Transport,
mapping: ProtobufMapping,
channel_data: ChannelData | None = None,
) -> None:
self.channel_database: ChannelDatabase = get_channel_db()
super().__init__(transport, mapping, channel_data)
if channel_data is not None:
self.channel_id = channel_data.channel_id
self.key_request = bytes.fromhex(channel_data.key_request)
self.key_response = bytes.fromhex(channel_data.key_response)
self.nonce_request = channel_data.nonce_request
self.nonce_response = channel_data.nonce_response
self.sync_bit_receive = channel_data.sync_bit_receive
self.sync_bit_send = channel_data.sync_bit_send
self._has_valid_channel = True
def get_channel(self) -> ProtocolV2:
if not self._has_valid_channel:
self._establish_new_channel()
return self
def get_channel_data(self) -> ChannelData:
return ChannelData(
protocol_version=2,
transport_path=self.transport.get_path(),
channel_id=self.channel_id,
key_request=self.key_request,
key_response=self.key_response,
nonce_request=self.nonce_request,
nonce_response=self.nonce_response,
sync_bit_receive=self.sync_bit_receive,
sync_bit_send=self.sync_bit_send,
)
def read(self, session_id: int) -> t.Any:
sid, msg_type, msg_data = self.read_and_decrypt()
if sid != session_id:
raise Exception("Received messsage on a different session.")
self.channel_database.save_channel(self)
return self.mapping.decode(msg_type, msg_data)
def write(self, session_id: int, msg: t.Any) -> None:
msg_type, msg_data = self.mapping.encode(msg)
self._encrypt_and_write(session_id, msg_type, msg_data)
self.channel_database.save_channel(self)
def get_features(self) -> messages.Features:
if not self._has_valid_channel:
self._establish_new_channel()
if self._features is None:
self.update_features()
assert self._features is not None
return self._features
def update_features(self) -> None:
message = messages.GetFeatures()
message_type, message_data = self.mapping.encode(message)
self.session_id: int = 0
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
_ = self._read_until_valid_crc_check() # TODO check ACK
_, msg_type, msg_data = self.read_and_decrypt()
features = self.mapping.decode(msg_type, msg_data)
if not isinstance(features, messages.Features):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._features = features
def _establish_new_channel(self) -> None:
self.sync_bit_send = 0
self.sync_bit_receive = 0
# Send channel allocation request
# Note that [:8] on the following line is required when tests use
# WITH_MOCK_URANDOM. Without [:8] such tests will (almost always) fail.
channel_id_request_nonce = os.urandom(8)[:8]
thp_io.write_payload_to_wire_and_add_checksum(
self.transport,
MessageHeader.get_channel_allocation_request_header(12),
channel_id_request_nonce,
)
# Read channel allocation response
header, payload = self._read_until_valid_crc_check()
if not self._is_valid_channel_allocation_response(
header, payload, channel_id_request_nonce
):
# TODO raise exception here, I guess
raise Exception("Invalid channel allocation response.")
self.channel_id = int.from_bytes(payload[8:10], "big")
self.device_properties = payload[10:]
# Send handshake init request
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
# Note that [:32] on the following line is required when tests use
# WITH_MOCK_URANDOM. Without [:32] such tests will (almost always) fail.
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)[:32])
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport, ha_init_req_header, host_ephemeral_pubkey
)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
click.echo("Received message is not a valid ACK", err=True)
# Read handshake init response
header, payload = self._read_until_valid_crc_check()
self._send_ack_0()
if not header.is_handshake_init_response():
click.echo(
"Received message is not a valid handshake init response message",
err=True,
)
trezor_ephemeral_pubkey = payload[:32]
encrypted_trezor_static_pubkey = payload[32:80]
noise_tag = payload[80:96]
# TODO check noise tag
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
# Prepare and send handshake completion request
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
h = _sha256_of_two(h, host_ephemeral_pubkey)
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
ck, k = _hkdf(
PROTOCOL_NAME,
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
)
aes_ctx = AESGCM(k)
try:
trezor_masked_static_pubkey = aes_ctx.decrypt(
IV_1, encrypted_trezor_static_pubkey, h
)
except Exception as e:
click.echo(
f"Exception of type{type(e)}", err=True
) # TODO how to handle potential exceptions? Q for Matejcik
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
)
aes_ctx = AESGCM(k)
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
h = _sha256_of_two(h, tag_of_empty_string)
# TODO: search for saved credentials (or possibly not, as we skip pairing phase)
zeroes_32 = int.to_bytes(0, 32, "little")
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey)
aes_ctx = AESGCM(k)
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h)
h = _sha256_of_two(h, encrypted_host_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey)
)
msg_data = self.mapping.encode_without_wire_type(
messages.ThpHandshakeCompletionReqNoisePayload(
pairing_methods=[
messages.ThpPairingMethod.NoMethod,
]
)
)
aes_ctx = AESGCM(k)
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
h = _sha256_of_two(h, encrypted_payload)
ha_completion_req_header = MessageHeader(
0x12,
self.channel_id,
len(encrypted_host_static_pubkey)
+ len(encrypted_payload)
+ CHECKSUM_LENGTH,
)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport,
ha_completion_req_header,
encrypted_host_static_pubkey + encrypted_payload,
)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
click.echo("Received message is not a valid ACK", err=True)
# Read handshake completion response, ignore payload as we do not care about the state
header, _ = self._read_until_valid_crc_check()
if not header.is_handshake_comp_response():
click.echo(
"Received message is not a valid handshake completion response",
err=True,
)
self._send_ack_1()
self.key_request, self.key_response = _hkdf(ck, b"")
self.nonce_request = 0
self.nonce_response = 1
# Send StartPairingReqest message
message = messages.ThpStartPairingRequest()
message_type, message_data = self.mapping.encode(message)
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
click.echo("Received message is not a valid ACK", err=True)
# Read
_, msg_type, msg_data = self.read_and_decrypt()
maaa = self.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, messages.ThpEndResponse)
self._has_valid_channel = True
def _send_ack_0(self):
LOG.debug("sending ack 0")
header = MessageHeader(0x20, self.channel_id, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
def _send_ack_1(self):
LOG.debug("sending ack 1")
header = MessageHeader(0x28, self.channel_id, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
def _encrypt_and_write(
self,
session_id: int,
message_type: int,
message_data: bytes,
ctrl_byte: int | None = None,
) -> None:
assert self.key_request is not None
aes_ctx = AESGCM(self.key_request)
if ctrl_byte is None:
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send)
self.sync_bit_send = 1 - self.sync_bit_send
sid = session_id.to_bytes(1, "big")
msg_type = message_type.to_bytes(2, "big")
data = sid + msg_type + message_data
nonce = _get_iv_from_nonce(self.nonce_request)
self.nonce_request += 1
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
header = MessageHeader(
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport, header, encrypted_message
)
def read_and_decrypt(self) -> t.Tuple[int, int, bytes]:
header, raw_payload = self._read_until_valid_crc_check()
if control_byte.is_ack(header.ctrl_byte):
return self.read_and_decrypt()
if not header.is_encrypted_transport():
click.echo(
"Trying to decrypt not encrypted message!"
+ hexlify(header.to_bytes_init() + raw_payload).decode(),
err=True,
)
if not control_byte.is_ack(header.ctrl_byte):
LOG.debug(
"--> Get sequence bit %d %s %s",
control_byte.get_seq_bit(header.ctrl_byte),
"from control byte",
hexlify(header.ctrl_byte.to_bytes(1, "big")).decode(),
)
if control_byte.get_seq_bit(header.ctrl_byte):
self._send_ack_1()
else:
self._send_ack_0()
aes_ctx = AESGCM(self.key_response)
nonce = _get_iv_from_nonce(self.nonce_response)
self.nonce_response += 1
message = aes_ctx.decrypt(nonce, raw_payload, b"")
session_id = message[0]
message_type = message[1:3]
message_data = message[3:]
return (
session_id,
int.from_bytes(message_type, "big"),
message_data,
)
def _read_until_valid_crc_check(
self,
) -> t.Tuple[MessageHeader, bytes]:
is_valid = False
header, payload, chksum = thp_io.read(self.transport)
while not is_valid:
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
if not is_valid:
click.echo(
"Received a message with an invalid checksum:"
+ hexlify(header.to_bytes_init() + payload + chksum).decode(),
err=True,
)
header, payload, chksum = thp_io.read(self.transport)
return header, payload
def _is_valid_channel_allocation_response(
self, header: MessageHeader, payload: bytes, original_nonce: bytes
) -> bool:
if not header.is_channel_allocation_response():
click.echo(
"Received message is not a channel allocation response", err=True
)
return False
if len(payload) < 10:
click.echo("Invalid channel allocation response payload", err=True)
return False
if payload[:8] != original_nonce:
click.echo(
"Invalid channel allocation response payload (nonce mismatch)", err=True
)
return False
return True
class ControlByteType(IntEnum):
CHANNEL_ALLOCATION_RES = 1
HANDSHAKE_INIT_RES = 2
HANDSHAKE_COMP_RES = 3
ACK = 4
ENCRYPTED_TRANSPORT = 5

View File

@ -0,0 +1,93 @@
import struct
from typing import Tuple
from .. import Transport
from ..thp import checksum
from .message_header import MessageHeader
INIT_HEADER_LENGTH = 5
CONT_HEADER_LENGTH = 3
MAX_PAYLOAD_LEN = 60000
MESSAGE_TYPE_LENGTH = 2
CONTINUATION_PACKET = 0x80
def write_payload_to_wire_and_add_checksum(
transport: Transport, header: MessageHeader, transport_payload: bytes
):
chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload)
data = transport_payload + chksum
write_payload_to_wire(transport, header, data)
def write_payload_to_wire(
transport: Transport, header: MessageHeader, transport_payload: bytes
):
transport.open()
buffer = bytearray(transport_payload)
chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH]
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
transport.write_chunk(chunk)
buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :]
while buffer:
chunk = (
header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH]
)
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
transport.write_chunk(chunk)
buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :]
def read(transport: Transport) -> Tuple[MessageHeader, bytes, bytes]:
"""
Reads from the given wire transport.
Returns `Tuple[MessageHeader, bytes, bytes]`:
1. `header` (`MessageHeader`): Header of the message.
2. `data` (`bytes`): Contents of the message (if any).
3. `checksum` (`bytes`): crc32 checksum of the header + data.
"""
buffer = bytearray()
# Read header with first part of message data
header, first_chunk = read_first(transport)
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < header.data_length:
buffer.extend(read_next(transport, header.cid))
data_len = header.data_length - checksum.CHECKSUM_LENGTH
msg_data = buffer[:data_len]
chksum = buffer[data_len : data_len + checksum.CHECKSUM_LENGTH]
return (header, msg_data, chksum)
def read_first(transport: Transport) -> Tuple[MessageHeader, bytes]:
chunk = transport.read_chunk()
try:
ctrl_byte, cid, data_length = struct.unpack(
MessageHeader.format_str_init, chunk[:INIT_HEADER_LENGTH]
)
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[INIT_HEADER_LENGTH:]
return MessageHeader(ctrl_byte, cid, data_length), data
def read_next(transport: Transport, cid: int) -> bytes:
chunk = transport.read_chunk()
ctrl_byte, read_cid = struct.unpack(
MessageHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH]
)
if ctrl_byte != CONTINUATION_PACKET:
raise RuntimeError("Continuation packet with incorrect control byte")
if read_cid != cid:
raise RuntimeError("Continuation packet for different channel")
return chunk[CONT_HEADER_LENGTH:]

View File

@ -1,6 +1,6 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
# Copyright (C) 2012-2024 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
@ -14,14 +14,15 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
import socket
import time
from typing import TYPE_CHECKING, Iterable, Optional
from typing import TYPE_CHECKING, Iterable, Tuple
from ..log import DUMP_PACKETS
from . import TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
from . import Transport, TransportException
if TYPE_CHECKING:
from ..models import TrezorModel
@ -31,14 +32,18 @@ SOCKET_TIMEOUT = 10
LOG = logging.getLogger(__name__)
class UdpTransport(ProtocolBasedTransport):
class UdpTransport(Transport):
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 21324
PATH_PREFIX = "udp"
ENABLED: bool = True
CHUNK_SIZE = 64
def __init__(self, device: Optional[str] = None) -> None:
def __init__(
self,
device: str | None = None,
) -> None:
if not device:
host = UdpTransport.DEFAULT_HOST
port = UdpTransport.DEFAULT_PORT
@ -46,24 +51,17 @@ class UdpTransport(ProtocolBasedTransport):
devparts = device.split(":")
host = devparts[0]
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
self.device = (host, port)
self.socket: Optional[socket.socket] = None
self.device: Tuple[str, int] = (host, port)
super().__init__(protocol=ProtocolV1(self))
def get_path(self) -> str:
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
def find_debug(self) -> "UdpTransport":
host, port = self.device
return UdpTransport(f"{host}:{port + 1}")
self.socket: socket.socket | None = None
super().__init__()
@classmethod
def _try_path(cls, path: str) -> "UdpTransport":
d = cls(path)
try:
d.open()
if d._ping():
if d.ping():
return d
else:
raise TransportException(
@ -77,7 +75,7 @@ class UdpTransport(ProtocolBasedTransport):
@classmethod
def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None
cls, _models: Iterable["TrezorModel"] | None = None
) -> Iterable["UdpTransport"]:
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
try:
@ -99,20 +97,8 @@ class UdpTransport(ProtocolBasedTransport):
else:
raise TransportException(f"No UDP device at {path}")
def wait_until_ready(self, timeout: float = 10) -> None:
try:
self.open()
start = time.monotonic()
while True:
if self._ping():
break
elapsed = time.monotonic() - start
if elapsed >= timeout:
raise TransportException("Timed out waiting for connection.")
time.sleep(0.05)
finally:
self.close()
def get_path(self) -> str:
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
def open(self) -> None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@ -124,18 +110,9 @@ class UdpTransport(ProtocolBasedTransport):
self.socket.close()
self.socket = None
def _ping(self) -> bool:
"""Test if the device is listening."""
assert self.socket is not None
resp = None
try:
self.socket.sendall(b"PINGPING")
resp = self.socket.recv(8)
except Exception:
pass
return resp == b"PONGPONG"
def write_chunk(self, chunk: bytes) -> None:
if self.socket is None:
self.open()
assert self.socket is not None
if len(chunk) != 64:
raise TransportException("Unexpected data length")
@ -143,6 +120,8 @@ class UdpTransport(ProtocolBasedTransport):
self.socket.sendall(chunk)
def read_chunk(self) -> bytes:
if self.socket is None:
self.open()
assert self.socket is not None
while True:
try:
@ -154,3 +133,33 @@ class UdpTransport(ProtocolBasedTransport):
if len(chunk) != 64:
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return bytearray(chunk)
def find_debug(self) -> "UdpTransport":
host, port = self.device
return UdpTransport(f"{host}:{port + 1}")
def wait_until_ready(self, timeout: float = 10) -> None:
try:
self.open()
start = time.monotonic()
while True:
if self.ping():
break
elapsed = time.monotonic() - start
if elapsed >= timeout:
raise TransportException("Timed out waiting for connection.")
time.sleep(0.05)
finally:
self.close()
def ping(self) -> bool:
"""Test if the device is listening."""
assert self.socket is not None
resp = None
try:
self.socket.sendall(b"PINGPING")
resp = self.socket.recv(8)
except Exception:
pass
return resp == b"PONGPONG"

View File

@ -1,6 +1,6 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
# Copyright (C) 2012-2024 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
@ -14,16 +14,17 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import atexit
import logging
import sys
import time
from typing import Iterable, List, Optional
from typing import Iterable, List
from ..log import DUMP_PACKETS
from ..models import TREZORS, TrezorModel
from . import UDEV_RULES_STR, DeviceIsBusy, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
from . import UDEV_RULES_STR, DeviceIsBusy, Transport, TransportException
LOG = logging.getLogger(__name__)
@ -44,13 +45,69 @@ USB_COMM_TIMEOUT_MS = 300
WEBUSB_CHUNK_SIZE = 64
class WebUsbHandle:
def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None:
class WebUsbTransport(Transport):
"""
WebUsbTransport implements transport over WebUSB interface.
"""
PATH_PREFIX = "webusb"
ENABLED = USB_IMPORTED
context = None
CHUNK_SIZE = 64
def __init__(
self,
device: "usb1.USBDevice",
debug: bool = False,
) -> None:
self.device = device
self.debug = debug
self.interface = DEBUG_INTERFACE if debug else INTERFACE
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
self.count = 0
self.handle: Optional["usb1.USBDeviceHandle"] = None
self.handle: usb1.USBDeviceHandle | None = None
super().__init__()
@classmethod
def enumerate(
cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False
) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in usb_ids:
continue
if not is_vendor_class(dev):
continue
if usb_reset:
handle = dev.open()
handle.resetDevice()
handle.close()
continue
try:
# workaround for issue #223:
# on certain combinations of Windows USB drivers and libusb versions,
# Trezor is returned twice (possibly because Windows know it as both
# a HID and a WebUSB device), and one of the returned devices is
# non-functional.
dev.getProduct()
devices.append(WebUsbTransport(dev))
except usb1.USBErrorNotSupported:
pass
return devices
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
def open(self) -> None:
self.handle = self.device.open()
@ -64,6 +121,8 @@ class WebUsbHandle:
self.handle.claimInterface(self.interface)
except usb1.USBErrorAccess as e:
raise DeviceIsBusy(self.device) from e
except usb1.USBErrorBusy as e:
raise DeviceIsBusy(self.device) from e
def close(self) -> None:
if self.handle is not None:
@ -75,6 +134,8 @@ class WebUsbHandle:
self.handle = None
def write_chunk(self, chunk: bytes) -> None:
if self.handle is None:
self.open()
assert self.handle is not None
if len(chunk) != WEBUSB_CHUNK_SIZE:
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
@ -97,6 +158,8 @@ class WebUsbHandle:
return
def read_chunk(self) -> bytes:
if self.handle is None:
self.open()
assert self.handle is not None
endpoint = 0x80 | self.endpoint
while True:
@ -117,70 +180,6 @@ class WebUsbHandle:
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return chunk
class WebUsbTransport(ProtocolBasedTransport):
"""
WebUsbTransport implements transport over WebUSB interface.
"""
PATH_PREFIX = "webusb"
ENABLED = USB_IMPORTED
context = None
def __init__(
self,
device: "usb1.USBDevice",
handle: Optional[WebUsbHandle] = None,
debug: bool = False,
) -> None:
if handle is None:
handle = WebUsbHandle(device, debug)
self.device = device
self.handle = handle
self.debug = debug
super().__init__(protocol=ProtocolV1(handle))
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
@classmethod
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False
) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in usb_ids:
continue
if not is_vendor_class(dev):
continue
try:
# workaround for issue #223:
# on certain combinations of Windows USB drivers and libusb versions,
# Trezor is returned twice (possibly because Windows know it as both
# a HID and a WebUSB device), and one of the returned devices is
# non-functional.
dev.getProduct()
devices.append(WebUsbTransport(dev))
except usb1.USBErrorNotSupported:
pass
except usb1.USBErrorPipe:
if usb_reset:
handle = dev.open()
handle.resetDevice()
handle.close()
return devices
def find_debug(self) -> "WebUsbTransport":
# For v1 protocol, find debug USB interface for the same serial number
return WebUsbTransport(self.device, debug=True)