mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-23 05:40:57 +00:00
feat(python): implement session based trezorlib
[no changelog]
This commit is contained in:
parent
309af26ae3
commit
edd750e569
@ -95,6 +95,15 @@ class Emulator:
|
||||
raise RuntimeError
|
||||
return self._client
|
||||
|
||||
@client.setter
|
||||
def client(self, new_client: TrezorClientDebugLink) -> None:
|
||||
"""Setter for the client property to update _client."""
|
||||
if not isinstance(new_client, TrezorClientDebugLink):
|
||||
raise TypeError(
|
||||
f"Expected a TrezorClientDebugLink, got {type(new_client).__name__}."
|
||||
)
|
||||
self._client = new_client
|
||||
|
||||
def make_args(self) -> List[str]:
|
||||
return []
|
||||
|
||||
@ -112,7 +121,7 @@ class Emulator:
|
||||
start = time.monotonic()
|
||||
try:
|
||||
while True:
|
||||
if transport._ping():
|
||||
if transport.ping():
|
||||
break
|
||||
if self.process.poll() is not None:
|
||||
raise RuntimeError("Emulator process died")
|
||||
|
@ -7,7 +7,7 @@ import typing as t
|
||||
from importlib import metadata
|
||||
|
||||
from . import device
|
||||
from .client import TrezorClient
|
||||
from .transport.session import Session
|
||||
|
||||
try:
|
||||
cryptography_version = metadata.version("cryptography")
|
||||
@ -361,7 +361,7 @@ def verify_authentication_response(
|
||||
|
||||
|
||||
def authenticate_device(
|
||||
client: TrezorClient,
|
||||
session: Session,
|
||||
challenge: bytes | None = None,
|
||||
*,
|
||||
whitelist: t.Collection[bytes] | None = None,
|
||||
@ -371,7 +371,7 @@ def authenticate_device(
|
||||
if challenge is None:
|
||||
challenge = secrets.token_bytes(16)
|
||||
|
||||
resp = device.authenticate(client, challenge)
|
||||
resp = device.authenticate(session, challenge)
|
||||
|
||||
return verify_authentication_response(
|
||||
challenge,
|
||||
|
@ -20,17 +20,17 @@ from . import messages
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
@expect(messages.BenchmarkNames)
|
||||
def list_names(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
) -> "MessageType":
|
||||
return client.call(messages.BenchmarkListNames())
|
||||
return session.call(messages.BenchmarkListNames())
|
||||
|
||||
|
||||
@expect(messages.BenchmarkResult)
|
||||
def run(client: "TrezorClient", name: str) -> "MessageType":
|
||||
return client.call(messages.BenchmarkRun(name=name))
|
||||
def run(session: "Session", name: str) -> "MessageType":
|
||||
return session.call(messages.BenchmarkRun(name=name))
|
||||
|
@ -18,22 +18,22 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from . import messages
|
||||
from .protobuf import dict_to_proto
|
||||
from .tools import expect, session
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
@expect(messages.BinanceAddress, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.BinanceGetAddress(
|
||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||
)
|
||||
@ -42,16 +42,15 @@ def get_address(
|
||||
|
||||
@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes)
|
||||
def get_public_key(
|
||||
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
||||
session: "Session", address_n: "Address", show_display: bool = False
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display)
|
||||
)
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx(
|
||||
client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False
|
||||
session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False
|
||||
) -> messages.BinanceSignedTx:
|
||||
msg = tx_json["msgs"][0]
|
||||
tx_msg = tx_json.copy()
|
||||
@ -60,7 +59,7 @@ def sign_tx(
|
||||
tx_msg["chunkify"] = chunkify
|
||||
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
|
||||
|
||||
response = client.call(envelope)
|
||||
response = session.call(envelope)
|
||||
|
||||
if not isinstance(response, messages.BinanceTxRequest):
|
||||
raise RuntimeError(
|
||||
@ -77,7 +76,7 @@ def sign_tx(
|
||||
else:
|
||||
raise ValueError("can not determine msg type")
|
||||
|
||||
response = client.call(msg)
|
||||
response = session.call(msg)
|
||||
|
||||
if not isinstance(response, messages.BinanceSignedTx):
|
||||
raise RuntimeError(
|
||||
|
@ -13,7 +13,6 @@
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import warnings
|
||||
from copy import copy
|
||||
from decimal import Decimal
|
||||
@ -23,12 +22,12 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple
|
||||
from typing_extensions import Protocol, TypedDict
|
||||
|
||||
from . import exceptions, messages
|
||||
from .tools import expect, prepare_message_bytes, session
|
||||
from .tools import expect, prepare_message_bytes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
class ScriptSig(TypedDict):
|
||||
asm: str
|
||||
@ -105,7 +104,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType:
|
||||
|
||||
@expect(messages.PublicKey)
|
||||
def get_public_node(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
ecdsa_curve_name: Optional[str] = None,
|
||||
show_display: bool = False,
|
||||
@ -116,13 +115,13 @@ def get_public_node(
|
||||
unlock_path_mac: Optional[bytes] = None,
|
||||
) -> "MessageType":
|
||||
if unlock_path:
|
||||
res = client.call(
|
||||
res = session.call(
|
||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
|
||||
)
|
||||
if not isinstance(res, messages.UnlockedPathRequest):
|
||||
raise exceptions.TrezorException("Unexpected message")
|
||||
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.GetPublicKey(
|
||||
address_n=n,
|
||||
ecdsa_curve_name=ecdsa_curve_name,
|
||||
@ -141,7 +140,7 @@ def get_address(*args: Any, **kwargs: Any):
|
||||
|
||||
@expect(messages.Address)
|
||||
def get_authenticated_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
show_display: bool = False,
|
||||
@ -153,13 +152,13 @@ def get_authenticated_address(
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
if unlock_path:
|
||||
res = client.call(
|
||||
res = session.call(
|
||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
|
||||
)
|
||||
if not isinstance(res, messages.UnlockedPathRequest):
|
||||
raise exceptions.TrezorException("Unexpected message")
|
||||
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.GetAddress(
|
||||
address_n=n,
|
||||
coin_name=coin_name,
|
||||
@ -172,15 +171,16 @@ def get_authenticated_address(
|
||||
)
|
||||
|
||||
|
||||
# TODO this is used by tests only
|
||||
@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
|
||||
def get_ownership_id(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.GetOwnershipId(
|
||||
address_n=n,
|
||||
coin_name=coin_name,
|
||||
@ -190,8 +190,9 @@ def get_ownership_id(
|
||||
)
|
||||
|
||||
|
||||
# TODO this is used by tests only
|
||||
def get_ownership_proof(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||
@ -202,11 +203,11 @@ def get_ownership_proof(
|
||||
preauthorized: bool = False,
|
||||
) -> Tuple[bytes, bytes]:
|
||||
if preauthorized:
|
||||
res = client.call(messages.DoPreauthorized())
|
||||
res = session.call(messages.DoPreauthorized())
|
||||
if not isinstance(res, messages.PreauthorizedRequest):
|
||||
raise exceptions.TrezorException("Unexpected message")
|
||||
|
||||
res = client.call(
|
||||
res = session.call(
|
||||
messages.GetOwnershipProof(
|
||||
address_n=n,
|
||||
coin_name=coin_name,
|
||||
@ -226,7 +227,7 @@ def get_ownership_proof(
|
||||
|
||||
@expect(messages.MessageSignature)
|
||||
def sign_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
message: AnyStr,
|
||||
@ -234,7 +235,7 @@ def sign_message(
|
||||
no_script_type: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.SignMessage(
|
||||
coin_name=coin_name,
|
||||
address_n=n,
|
||||
@ -247,7 +248,7 @@ def sign_message(
|
||||
|
||||
|
||||
def verify_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
address: str,
|
||||
signature: bytes,
|
||||
@ -255,7 +256,7 @@ def verify_message(
|
||||
chunkify: bool = False,
|
||||
) -> bool:
|
||||
try:
|
||||
resp = client.call(
|
||||
resp = session.call(
|
||||
messages.VerifyMessage(
|
||||
address=address,
|
||||
signature=signature,
|
||||
@ -269,9 +270,9 @@ def verify_message(
|
||||
return isinstance(resp, messages.Success)
|
||||
|
||||
|
||||
@session
|
||||
# @session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
inputs: Sequence[messages.TxInputType],
|
||||
outputs: Sequence[messages.TxOutputType],
|
||||
@ -319,17 +320,17 @@ def sign_tx(
|
||||
setattr(signtx, name, value)
|
||||
|
||||
if unlock_path:
|
||||
res = client.call(
|
||||
res = session.call(
|
||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
|
||||
)
|
||||
if not isinstance(res, messages.UnlockedPathRequest):
|
||||
raise exceptions.TrezorException("Unexpected message")
|
||||
elif preauthorized:
|
||||
res = client.call(messages.DoPreauthorized())
|
||||
res = session.call(messages.DoPreauthorized())
|
||||
if not isinstance(res, messages.PreauthorizedRequest):
|
||||
raise exceptions.TrezorException("Unexpected message")
|
||||
|
||||
res = client.call(signtx)
|
||||
res = session.call(signtx)
|
||||
|
||||
# Prepare structure for signatures
|
||||
signatures: List[Optional[bytes]] = [None] * len(inputs)
|
||||
@ -388,7 +389,7 @@ def sign_tx(
|
||||
if res.request_type == R.TXPAYMENTREQ:
|
||||
assert res.details.request_index is not None
|
||||
msg = payment_reqs[res.details.request_index]
|
||||
res = client.call(msg)
|
||||
res = session.call(msg)
|
||||
else:
|
||||
msg = messages.TransactionType()
|
||||
if res.request_type == R.TXMETA:
|
||||
@ -418,7 +419,7 @@ def sign_tx(
|
||||
f"Unknown request type - {res.request_type}."
|
||||
)
|
||||
|
||||
res = client.call(messages.TxAck(tx=msg))
|
||||
res = session.call(messages.TxAck(tx=msg))
|
||||
|
||||
if not isinstance(res, messages.TxRequest):
|
||||
raise exceptions.TrezorException("Unexpected message")
|
||||
@ -432,7 +433,7 @@ def sign_tx(
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def authorize_coinjoin(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coordinator: str,
|
||||
max_rounds: int,
|
||||
max_coordinator_fee_rate: int,
|
||||
@ -441,7 +442,7 @@ def authorize_coinjoin(
|
||||
coin_name: str,
|
||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.AuthorizeCoinJoin(
|
||||
coordinator=coordinator,
|
||||
max_rounds=max_rounds,
|
||||
|
@ -35,8 +35,8 @@ from . import exceptions, messages, tools
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .transport.session import Session
|
||||
|
||||
PROTOCOL_MAGICS = {
|
||||
"mainnet": 764824073,
|
||||
@ -825,7 +825,7 @@ def _get_collateral_inputs_items(
|
||||
|
||||
@expect(messages.CardanoAddress, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_parameters: messages.CardanoAddressParametersType,
|
||||
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
|
||||
network_id: int = NETWORK_IDS["mainnet"],
|
||||
@ -833,7 +833,7 @@ def get_address(
|
||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.CardanoGetAddress(
|
||||
address_parameters=address_parameters,
|
||||
protocol_magic=protocol_magic,
|
||||
@ -847,12 +847,12 @@ def get_address(
|
||||
|
||||
@expect(messages.CardanoPublicKey)
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: List[int],
|
||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||
show_display: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.CardanoGetPublicKey(
|
||||
address_n=address_n,
|
||||
derivation_type=derivation_type,
|
||||
@ -863,12 +863,12 @@ def get_public_key(
|
||||
|
||||
@expect(messages.CardanoNativeScriptHash)
|
||||
def get_native_script_hash(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
native_script: messages.CardanoNativeScript,
|
||||
display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE,
|
||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.CardanoGetNativeScriptHash(
|
||||
script=native_script,
|
||||
display_format=display_format,
|
||||
@ -878,7 +878,7 @@ def get_native_script_hash(
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
signing_mode: messages.CardanoTxSigningMode,
|
||||
inputs: List[InputWithPath],
|
||||
outputs: List[OutputWithData],
|
||||
@ -915,7 +915,7 @@ def sign_tx(
|
||||
signing_mode,
|
||||
)
|
||||
|
||||
response = client.call(
|
||||
response = session.call(
|
||||
messages.CardanoSignTxInit(
|
||||
signing_mode=signing_mode,
|
||||
inputs_count=len(inputs),
|
||||
@ -951,14 +951,14 @@ def sign_tx(
|
||||
_get_certificates_items(certificates),
|
||||
withdrawals,
|
||||
):
|
||||
response = client.call(tx_item)
|
||||
response = session.call(tx_item)
|
||||
if not isinstance(response, messages.CardanoTxItemAck):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
|
||||
sign_tx_response: Dict[str, Any] = {}
|
||||
|
||||
if auxiliary_data is not None:
|
||||
auxiliary_data_supplement = client.call(auxiliary_data)
|
||||
auxiliary_data_supplement = session.call(auxiliary_data)
|
||||
if not isinstance(
|
||||
auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement
|
||||
):
|
||||
@ -971,7 +971,7 @@ def sign_tx(
|
||||
auxiliary_data_supplement.__dict__
|
||||
)
|
||||
|
||||
response = client.call(messages.CardanoTxHostAck())
|
||||
response = session.call(messages.CardanoTxHostAck())
|
||||
if not isinstance(response, messages.CardanoTxItemAck):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
|
||||
@ -980,24 +980,24 @@ def sign_tx(
|
||||
_get_collateral_inputs_items(collateral_inputs),
|
||||
required_signers,
|
||||
):
|
||||
response = client.call(tx_item)
|
||||
response = session.call(tx_item)
|
||||
if not isinstance(response, messages.CardanoTxItemAck):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
|
||||
if collateral_return is not None:
|
||||
for tx_item in _get_output_items(collateral_return):
|
||||
response = client.call(tx_item)
|
||||
response = session.call(tx_item)
|
||||
if not isinstance(response, messages.CardanoTxItemAck):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
|
||||
for reference_input in reference_inputs:
|
||||
response = client.call(reference_input)
|
||||
response = session.call(reference_input)
|
||||
if not isinstance(response, messages.CardanoTxItemAck):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
|
||||
sign_tx_response["witnesses"] = []
|
||||
for witness_request in witness_requests:
|
||||
response = client.call(witness_request)
|
||||
response = session.call(witness_request)
|
||||
if not isinstance(response, messages.CardanoTxWitnessResponse):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
sign_tx_response["witnesses"].append(
|
||||
@ -1009,12 +1009,12 @@ def sign_tx(
|
||||
}
|
||||
)
|
||||
|
||||
response = client.call(messages.CardanoTxHostAck())
|
||||
response = session.call(messages.CardanoTxHostAck())
|
||||
if not isinstance(response, messages.CardanoTxBodyHash):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
sign_tx_response["tx_hash"] = response.tx_hash
|
||||
|
||||
response = client.call(messages.CardanoTxHostAck())
|
||||
response = session.call(messages.CardanoTxHostAck())
|
||||
if not isinstance(response, messages.CardanoSignTxFinished):
|
||||
raise UNEXPECTED_RESPONSE_ERROR
|
||||
|
||||
|
@ -14,33 +14,42 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import typing as t
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
||||
|
||||
import click
|
||||
|
||||
from .. import exceptions, transport
|
||||
from ..client import TrezorClient
|
||||
from ..ui import ClickUI, ScriptUI
|
||||
from .. import exceptions, transport, ui
|
||||
from ..client import ProtocolVersion, TrezorClient
|
||||
from ..messages import Capability
|
||||
from ..transport import Transport
|
||||
from ..transport.session import Session, SessionV1, SessionV2
|
||||
from ..transport.thp.channel_database import get_channel_db
|
||||
|
||||
if TYPE_CHECKING:
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
# Needed to enforce a return value from decorators
|
||||
# More details: https://www.python.org/dev/peps/pep-0612/
|
||||
from typing import TypeVar
|
||||
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
from ..transport import Transport
|
||||
from ..ui import TrezorClientUI
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
R = t.TypeVar("R")
|
||||
FuncWithSession = t.Callable[Concatenate[Session, P], R]
|
||||
|
||||
|
||||
class ChoiceType(click.Choice):
|
||||
def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None:
|
||||
|
||||
def __init__(
|
||||
self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True
|
||||
) -> None:
|
||||
super().__init__(list(typemap.keys()))
|
||||
self.case_sensitive = case_sensitive
|
||||
if case_sensitive:
|
||||
@ -48,7 +57,7 @@ class ChoiceType(click.Choice):
|
||||
else:
|
||||
self.typemap = {k.lower(): v for k, v in typemap.items()}
|
||||
|
||||
def convert(self, value: Any, param: Any, ctx: click.Context) -> Any:
|
||||
def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any:
|
||||
if value in self.typemap.values():
|
||||
return value
|
||||
value = super().convert(value, param, ctx)
|
||||
@ -57,11 +66,69 @@ class ChoiceType(click.Choice):
|
||||
return self.typemap[value]
|
||||
|
||||
|
||||
def get_passphrase(
|
||||
passphrase_on_host: bool, available_on_device: bool
|
||||
) -> t.Union[str, object]:
|
||||
if available_on_device and not passphrase_on_host:
|
||||
return ui.PASSPHRASE_ON_DEVICE
|
||||
|
||||
env_passphrase = os.getenv("PASSPHRASE")
|
||||
if env_passphrase is not None:
|
||||
ui.echo("Passphrase required. Using PASSPHRASE environment variable.")
|
||||
return env_passphrase
|
||||
|
||||
while True:
|
||||
try:
|
||||
passphrase = ui.prompt(
|
||||
"Passphrase required",
|
||||
hide_input=True,
|
||||
default="",
|
||||
show_default=False,
|
||||
)
|
||||
# In case user sees the input on the screen, we do not need confirmation
|
||||
if not ui.CAN_HANDLE_HIDDEN_INPUT:
|
||||
return passphrase
|
||||
second = ui.prompt(
|
||||
"Confirm your passphrase",
|
||||
hide_input=True,
|
||||
default="",
|
||||
show_default=False,
|
||||
)
|
||||
if passphrase == second:
|
||||
return passphrase
|
||||
else:
|
||||
ui.echo("Passphrase did not match. Please try again.")
|
||||
except click.Abort:
|
||||
raise exceptions.Cancelled from None
|
||||
|
||||
|
||||
def get_client(transport: Transport) -> TrezorClient:
|
||||
stored_channels = get_channel_db().load_stored_channels()
|
||||
stored_transport_paths = [ch.transport_path for ch in stored_channels]
|
||||
path = transport.get_path()
|
||||
if path in stored_transport_paths:
|
||||
stored_channel_with_correct_transport_path = next(
|
||||
ch for ch in stored_channels if ch.transport_path == path
|
||||
)
|
||||
try:
|
||||
client = TrezorClient.resume(
|
||||
transport, stored_channel_with_correct_transport_path
|
||||
)
|
||||
except Exception:
|
||||
LOG.debug("Failed to resume a channel. Replacing by a new one.")
|
||||
get_channel_db().remove_channel(path)
|
||||
client = TrezorClient(transport)
|
||||
else:
|
||||
client = TrezorClient(transport)
|
||||
return client
|
||||
|
||||
|
||||
class TrezorConnection:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
session_id: Optional[bytes],
|
||||
session_id: bytes | None,
|
||||
passphrase_on_host: bool,
|
||||
script: bool,
|
||||
) -> None:
|
||||
@ -70,6 +137,54 @@ class TrezorConnection:
|
||||
self.passphrase_on_host = passphrase_on_host
|
||||
self.script = script
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
derive_cardano: bool = False,
|
||||
empty_passphrase: bool = False,
|
||||
must_resume: bool = False,
|
||||
) -> Session:
|
||||
client = self.get_client()
|
||||
if must_resume and self.session_id is None:
|
||||
click.echo("Failed to resume session - no session id provided")
|
||||
raise RuntimeError("Failed to resume session - no session id provided")
|
||||
|
||||
# Try resume session from id
|
||||
if self.session_id is not None:
|
||||
if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
|
||||
session = SessionV1.resume_from_id(
|
||||
client=client, session_id=self.session_id
|
||||
)
|
||||
elif client.protocol_version is ProtocolVersion.PROTOCOL_V2:
|
||||
session = SessionV2(client, self.session_id)
|
||||
# TODO fix resumption on THP
|
||||
else:
|
||||
raise Exception("Unsupported client protocol", client.protocol_version)
|
||||
if must_resume:
|
||||
if session.id != self.session_id or session.id is None:
|
||||
click.echo("Failed to resume session")
|
||||
RuntimeError("Failed to resume session - no session id provided")
|
||||
return session
|
||||
|
||||
features = client.protocol.get_features()
|
||||
|
||||
passphrase_enabled = True # TODO what to do here?
|
||||
|
||||
if not passphrase_enabled:
|
||||
return client.get_session(derive_cardano=derive_cardano)
|
||||
|
||||
if empty_passphrase:
|
||||
passphrase = ""
|
||||
else:
|
||||
available_on_device = Capability.PassphraseEntry in features.capabilities
|
||||
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
|
||||
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
|
||||
if not isinstance(passphrase, str):
|
||||
raise RuntimeError("Passphrase must be a str")
|
||||
session = client.get_session(
|
||||
passphrase=passphrase, derive_cardano=derive_cardano
|
||||
)
|
||||
return session
|
||||
|
||||
def get_transport(self) -> "Transport":
|
||||
try:
|
||||
# look for transport without prefix search
|
||||
@ -82,19 +197,13 @@ class TrezorConnection:
|
||||
# if this fails, we want the exception to bubble up to the caller
|
||||
return transport.get_transport(self.path, prefix_search=True)
|
||||
|
||||
def get_ui(self) -> "TrezorClientUI":
|
||||
if self.script:
|
||||
# It is alright to return just the class object instead of instance,
|
||||
# as the ScriptUI class object itself is the implementation of TrezorClientUI
|
||||
# (ScriptUI is just a set of staticmethods)
|
||||
return ScriptUI
|
||||
else:
|
||||
return ClickUI(passphrase_on_host=self.passphrase_on_host)
|
||||
|
||||
def get_client(self) -> TrezorClient:
|
||||
transport = self.get_transport()
|
||||
ui = self.get_ui()
|
||||
return TrezorClient(transport, ui=ui, session_id=self.session_id)
|
||||
return get_client(self.get_transport())
|
||||
|
||||
def get_management_session(self) -> Session:
|
||||
client = self.get_client()
|
||||
management_session = client.get_management_session()
|
||||
return management_session
|
||||
|
||||
@contextmanager
|
||||
def client_context(self):
|
||||
@ -128,7 +237,57 @@ class TrezorConnection:
|
||||
# other exceptions may cause a traceback
|
||||
|
||||
|
||||
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]":
|
||||
def with_session(
|
||||
func: "t.Callable[Concatenate[Session, P], R]|None" = None,
|
||||
*,
|
||||
empty_passphrase: bool = False,
|
||||
derive_cardano: bool = False,
|
||||
management: bool = False,
|
||||
must_resume: bool = False,
|
||||
) -> t.Callable[[FuncWithSession], t.Callable[P, R]]:
|
||||
"""Provides a Click command with parameter `session=obj.get_session(...)` or
|
||||
`session=obj.get_management_session()` based on the parameters provided.
|
||||
|
||||
If default parameters are ok, this decorator can be used without parentheses.
|
||||
|
||||
TODO: handle resumption of sessions and their (potential) closure.
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: FuncWithSession,
|
||||
) -> "t.Callable[P, R]":
|
||||
|
||||
@click.pass_obj
|
||||
@functools.wraps(func)
|
||||
def function_with_session(
|
||||
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||
) -> "R":
|
||||
if management:
|
||||
session = obj.get_management_session()
|
||||
else:
|
||||
session = obj.get_session(
|
||||
derive_cardano=derive_cardano,
|
||||
empty_passphrase=empty_passphrase,
|
||||
must_resume=must_resume,
|
||||
)
|
||||
try:
|
||||
return func(session, *args, **kwargs)
|
||||
finally:
|
||||
pass
|
||||
# TODO try end session if not resumed
|
||||
|
||||
return function_with_session
|
||||
|
||||
# If the decorator @get_session is used without parentheses
|
||||
if func and callable(func):
|
||||
return decorator(func) # type: ignore [Function return type]
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def with_client(
|
||||
func: "t.Callable[Concatenate[TrezorClient, P], R]",
|
||||
) -> "t.Callable[P, R]":
|
||||
"""Wrap a Click command in `with obj.client_context() as client`.
|
||||
|
||||
Sessions are handled transparently. The user is warned when session did not resume
|
||||
@ -142,23 +301,62 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[
|
||||
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||
) -> "R":
|
||||
with obj.client_context() as client:
|
||||
session_was_resumed = obj.session_id == client.session_id
|
||||
if not session_was_resumed and obj.session_id is not None:
|
||||
# tried to resume but failed
|
||||
click.echo("Warning: failed to resume session.", err=True)
|
||||
|
||||
# session_was_resumed = obj.session_id == client.session_id
|
||||
# if not session_was_resumed and obj.session_id is not None:
|
||||
# # tried to resume but failed
|
||||
# click.echo("Warning: failed to resume session.", err=True)
|
||||
click.echo(
|
||||
"Warning: resume session detection is not implemented yet!", err=True
|
||||
)
|
||||
try:
|
||||
return func(client, *args, **kwargs)
|
||||
finally:
|
||||
if not session_was_resumed:
|
||||
try:
|
||||
client.end_session()
|
||||
except Exception:
|
||||
pass
|
||||
if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
|
||||
get_channel_db().save_channel(client.protocol)
|
||||
# if not session_was_resumed:
|
||||
# try:
|
||||
# client.end_session()
|
||||
# except Exception:
|
||||
# pass
|
||||
|
||||
return trezorctl_command_with_client
|
||||
|
||||
|
||||
# def with_client(
|
||||
# func: "t.Callable[Concatenate[TrezorClient, P], R]",
|
||||
# ) -> "t.Callable[P, R]":
|
||||
# """Wrap a Click command in `with obj.client_context() as client`.
|
||||
|
||||
# Sessions are handled transparently. The user is warned when session did not resume
|
||||
# cleanly. The session is closed after the command completes - unless the session
|
||||
# was resumed, in which case it should remain open.
|
||||
# """
|
||||
|
||||
# @click.pass_obj
|
||||
# @functools.wraps(func)
|
||||
# def trezorctl_command_with_client(
|
||||
# obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||
# ) -> "R":
|
||||
# with obj.client_context() as client:
|
||||
# session_was_resumed = obj.session_id == client.session_id
|
||||
# if not session_was_resumed and obj.session_id is not None:
|
||||
# # tried to resume but failed
|
||||
# click.echo("Warning: failed to resume session.", err=True)
|
||||
|
||||
# try:
|
||||
# return func(client, *args, **kwargs)
|
||||
# finally:
|
||||
# if not session_was_resumed:
|
||||
# try:
|
||||
# client.end_session()
|
||||
# except Exception:
|
||||
# pass
|
||||
|
||||
# # the return type of @click.pass_obj is improperly specified and pyright doesn't
|
||||
# # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
|
||||
# return trezorctl_command_with_client
|
||||
|
||||
|
||||
class AliasedGroup(click.Group):
|
||||
"""Command group that handles aliases and Click 6.x compatibility.
|
||||
|
||||
@ -188,14 +386,14 @@ class AliasedGroup(click.Group):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
aliases: Optional[Dict[str, click.Command]] = None,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
aliases: t.Dict[str, click.Command] | None = None,
|
||||
*args: t.Any,
|
||||
**kwargs: t.Any,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.aliases = aliases or {}
|
||||
|
||||
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]:
|
||||
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
|
||||
cmd_name = cmd_name.replace("_", "-")
|
||||
# try to look up the real name
|
||||
cmd = super().get_command(ctx, cmd_name)
|
||||
|
@ -20,17 +20,15 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
import click
|
||||
|
||||
from .. import benchmark
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
|
||||
def list_names_patern(
|
||||
client: "TrezorClient", pattern: Optional[str] = None
|
||||
) -> List[str]:
|
||||
names = list(benchmark.list_names(client).names)
|
||||
def list_names_patern(session: "Session", pattern: Optional[str] = None) -> List[str]:
|
||||
names = list(benchmark.list_names(session).names)
|
||||
if pattern is None:
|
||||
return names
|
||||
return [name for name in names if fnmatch(name, pattern)]
|
||||
@ -43,10 +41,10 @@ def cli() -> None:
|
||||
|
||||
@cli.command()
|
||||
@click.argument("pattern", required=False)
|
||||
@with_client
|
||||
def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
|
||||
@with_session(empty_passphrase=True)
|
||||
def list_names(session: "Session", pattern: Optional[str] = None) -> None:
|
||||
"""List names of all supported benchmarks"""
|
||||
names = list_names_patern(client, pattern)
|
||||
names = list_names_patern(session, pattern)
|
||||
if len(names) == 0:
|
||||
click.echo("No benchmark satisfies the pattern.")
|
||||
else:
|
||||
@ -56,13 +54,13 @@ def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
|
||||
|
||||
@cli.command()
|
||||
@click.argument("pattern", required=False)
|
||||
@with_client
|
||||
def run(client: "TrezorClient", pattern: Optional[str]) -> None:
|
||||
@with_session(empty_passphrase=True)
|
||||
def run(session: "Session", pattern: Optional[str]) -> None:
|
||||
"""Run benchmark"""
|
||||
names = list_names_patern(client, pattern)
|
||||
names = list_names_patern(session, pattern)
|
||||
if len(names) == 0:
|
||||
click.echo("No benchmark satisfies the pattern.")
|
||||
else:
|
||||
for name in names:
|
||||
result = benchmark.run(client, name)
|
||||
result = benchmark.run(session, name)
|
||||
click.echo(f"{name}: {result.value} {result.unit}")
|
||||
|
@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
|
||||
import click
|
||||
|
||||
from .. import binance, tools
|
||||
from . import with_client
|
||||
from ..transport.session import Session
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .. import messages
|
||||
from ..client import TrezorClient
|
||||
|
||||
|
||||
PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0"
|
||||
@ -39,23 +39,23 @@ def cli() -> None:
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
||||
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||
) -> str:
|
||||
"""Get Binance address for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
return binance.get_address(client, address_n, show_display, chunkify)
|
||||
return binance.get_address(session, address_n, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||
@with_session
|
||||
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
|
||||
"""Get Binance public key."""
|
||||
address_n = tools.parse_path(address)
|
||||
return binance.get_public_key(client, address_n, show_display).hex()
|
||||
return binance.get_public_key(session, address_n, show_display).hex()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_tx(
|
||||
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
|
||||
session: "Session", address: str, file: TextIO, chunkify: bool
|
||||
) -> "messages.BinanceSignedTx":
|
||||
"""Sign Binance transaction.
|
||||
|
||||
Transaction must be provided as a JSON file.
|
||||
"""
|
||||
address_n = tools.parse_path(address)
|
||||
return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
|
||||
return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify)
|
||||
|
@ -13,6 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
@ -22,10 +23,10 @@ import click
|
||||
import construct as c
|
||||
|
||||
from .. import btc, messages, protobuf, tools
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PURPOSE_BIP44 = 44
|
||||
PURPOSE_BIP48 = 48
|
||||
@ -174,15 +175,15 @@ def cli() -> None:
|
||||
help="Sort pubkeys lexicographically using BIP-67",
|
||||
)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: str,
|
||||
address: str,
|
||||
script_type: Optional[messages.InputScriptType],
|
||||
script_type: messages.InputScriptType | None,
|
||||
show_display: bool,
|
||||
multisig_xpub: List[str],
|
||||
multisig_threshold: Optional[int],
|
||||
multisig_threshold: int | None,
|
||||
multisig_suffix_length: int,
|
||||
multisig_sort_pubkeys: bool,
|
||||
chunkify: bool,
|
||||
@ -235,7 +236,7 @@ def get_address(
|
||||
multisig = None
|
||||
|
||||
return btc.get_address(
|
||||
client,
|
||||
session,
|
||||
coin,
|
||||
address_n,
|
||||
show_display,
|
||||
@ -252,9 +253,9 @@ def get_address(
|
||||
@click.option("-e", "--curve")
|
||||
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_public_node(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: str,
|
||||
address: str,
|
||||
curve: Optional[str],
|
||||
@ -266,7 +267,7 @@ def get_public_node(
|
||||
if script_type is None:
|
||||
script_type = guess_script_type_from_path(address_n)
|
||||
result = btc.get_public_node(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
ecdsa_curve_name=curve,
|
||||
show_display=show_display,
|
||||
@ -292,7 +293,7 @@ def _append_descriptor_checksum(desc: str) -> str:
|
||||
|
||||
|
||||
def _get_descriptor(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: Optional[str],
|
||||
account: int,
|
||||
purpose: Optional[int],
|
||||
@ -326,7 +327,7 @@ def _get_descriptor(
|
||||
|
||||
n = tools.parse_path(path)
|
||||
pub = btc.get_public_node(
|
||||
client,
|
||||
session,
|
||||
n,
|
||||
show_display=show_display,
|
||||
coin_name=coin,
|
||||
@ -363,9 +364,9 @@ def _get_descriptor(
|
||||
@click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE))
|
||||
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_descriptor(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: Optional[str],
|
||||
account: int,
|
||||
account_type: Optional[int],
|
||||
@ -375,7 +376,7 @@ def get_descriptor(
|
||||
"""Get descriptor of given account."""
|
||||
try:
|
||||
return _get_descriptor(
|
||||
client, coin, account, account_type, script_type, show_display
|
||||
session, coin, account, account_type, script_type, show_display
|
||||
)
|
||||
except ValueError as e:
|
||||
raise click.ClickException(str(e))
|
||||
@ -390,8 +391,8 @@ def get_descriptor(
|
||||
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@click.argument("json_file", type=click.File())
|
||||
@with_client
|
||||
def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
|
||||
@with_session
|
||||
def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None:
|
||||
"""Sign transaction.
|
||||
|
||||
Transaction data must be provided in a JSON file. See `transaction-format.md` for
|
||||
@ -416,7 +417,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
|
||||
}
|
||||
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client,
|
||||
session,
|
||||
coin,
|
||||
inputs,
|
||||
outputs,
|
||||
@ -447,9 +448,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
|
||||
)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@click.argument("message")
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: str,
|
||||
address: str,
|
||||
message: str,
|
||||
@ -462,7 +463,7 @@ def sign_message(
|
||||
if script_type is None:
|
||||
script_type = guess_script_type_from_path(address_n)
|
||||
res = btc.sign_message(
|
||||
client,
|
||||
session,
|
||||
coin,
|
||||
address_n,
|
||||
message,
|
||||
@ -483,9 +484,9 @@ def sign_message(
|
||||
@click.argument("address")
|
||||
@click.argument("signature")
|
||||
@click.argument("message")
|
||||
@with_client
|
||||
@with_session
|
||||
def verify_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: str,
|
||||
address: str,
|
||||
signature: str,
|
||||
@ -495,7 +496,7 @@ def verify_message(
|
||||
"""Verify message."""
|
||||
signature_bytes = base64.b64decode(signature)
|
||||
return btc.verify_message(
|
||||
client, coin, address, signature_bytes, message, chunkify=chunkify
|
||||
session, coin, address, signature_bytes, message, chunkify=chunkify
|
||||
)
|
||||
|
||||
|
||||
|
@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
|
||||
import click
|
||||
|
||||
from .. import cardano, messages, tools
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0"
|
||||
|
||||
@ -62,9 +62,9 @@ def cli() -> None:
|
||||
@click.option("-i", "--include-network-id", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@click.option("-T", "--tag-cbor-sets", is_flag=True)
|
||||
@with_client
|
||||
@with_session(derive_cardano=True)
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
file: TextIO,
|
||||
signing_mode: messages.CardanoTxSigningMode,
|
||||
protocol_magic: int,
|
||||
@ -123,9 +123,8 @@ def sign_tx(
|
||||
for p in transaction["additional_witness_requests"]
|
||||
]
|
||||
|
||||
client.init_device(derive_cardano=True)
|
||||
sign_tx_response = cardano.sign_tx(
|
||||
client,
|
||||
session,
|
||||
signing_mode,
|
||||
inputs,
|
||||
outputs,
|
||||
@ -209,9 +208,9 @@ def sign_tx(
|
||||
default=messages.CardanoDerivationType.ICARUS,
|
||||
)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session(derive_cardano=True)
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
address_type: messages.CardanoAddressType,
|
||||
staking_address: str,
|
||||
@ -262,9 +261,8 @@ def get_address(
|
||||
script_staking_hash_bytes,
|
||||
)
|
||||
|
||||
client.init_device(derive_cardano=True)
|
||||
return cardano.get_address(
|
||||
client,
|
||||
session,
|
||||
address_parameters,
|
||||
protocol_magic,
|
||||
network_id,
|
||||
@ -283,18 +281,17 @@ def get_address(
|
||||
default=messages.CardanoDerivationType.ICARUS,
|
||||
)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
@with_session(derive_cardano=True)
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
derivation_type: messages.CardanoDerivationType,
|
||||
show_display: bool,
|
||||
) -> messages.CardanoPublicKey:
|
||||
"""Get Cardano public key."""
|
||||
address_n = tools.parse_path(address)
|
||||
client.init_device(derive_cardano=True)
|
||||
return cardano.get_public_key(
|
||||
client, address_n, derivation_type=derivation_type, show_display=show_display
|
||||
session, address_n, derivation_type=derivation_type, show_display=show_display
|
||||
)
|
||||
|
||||
|
||||
@ -312,9 +309,9 @@ def get_public_key(
|
||||
type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}),
|
||||
default=messages.CardanoDerivationType.ICARUS,
|
||||
)
|
||||
@with_client
|
||||
@with_session(derive_cardano=True)
|
||||
def get_native_script_hash(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
file: TextIO,
|
||||
display_format: messages.CardanoNativeScriptHashDisplayFormat,
|
||||
derivation_type: messages.CardanoDerivationType,
|
||||
@ -323,7 +320,6 @@ def get_native_script_hash(
|
||||
native_script_json = json.load(file)
|
||||
native_script = cardano.parse_native_script(native_script_json)
|
||||
|
||||
client.init_device(derive_cardano=True)
|
||||
return cardano.get_native_script_hash(
|
||||
client, native_script, display_format, derivation_type=derivation_type
|
||||
session, native_script, display_format, derivation_type=derivation_type
|
||||
)
|
||||
|
@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Tuple
|
||||
import click
|
||||
|
||||
from .. import misc, tools
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
|
||||
PROMPT_TYPE = ChoiceType(
|
||||
@ -42,10 +42,10 @@ def cli() -> None:
|
||||
|
||||
@cli.command()
|
||||
@click.argument("size", type=int)
|
||||
@with_client
|
||||
def get_entropy(client: "TrezorClient", size: int) -> str:
|
||||
@with_session(empty_passphrase=True)
|
||||
def get_entropy(session: "Session", size: int) -> str:
|
||||
"""Get random bytes from device."""
|
||||
return misc.get_entropy(client, size).hex()
|
||||
return misc.get_entropy(session, size).hex()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str:
|
||||
)
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
@with_client
|
||||
@with_session(empty_passphrase=True)
|
||||
def encrypt_keyvalue(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
key: str,
|
||||
value: str,
|
||||
@ -75,7 +75,7 @@ def encrypt_keyvalue(
|
||||
ask_on_encrypt, ask_on_decrypt = prompt
|
||||
address_n = tools.parse_path(address)
|
||||
return misc.encrypt_keyvalue(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
key,
|
||||
value.encode(),
|
||||
@ -91,9 +91,9 @@ def encrypt_keyvalue(
|
||||
)
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
@with_client
|
||||
@with_session(empty_passphrase=True)
|
||||
def decrypt_keyvalue(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
key: str,
|
||||
value: str,
|
||||
@ -112,7 +112,7 @@ def decrypt_keyvalue(
|
||||
ask_on_encrypt, ask_on_decrypt = prompt
|
||||
address_n = tools.parse_path(address)
|
||||
return misc.decrypt_keyvalue(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
key,
|
||||
bytes.fromhex(value),
|
||||
|
@ -18,13 +18,12 @@ from typing import TYPE_CHECKING, Union
|
||||
|
||||
import click
|
||||
|
||||
from .. import mapping, messages, protobuf
|
||||
from ..client import TrezorClient
|
||||
from ..debuglink import TrezorClientDebugLink
|
||||
from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max
|
||||
from ..debuglink import prodtest_t1 as debuglink_prodtest_t1
|
||||
from ..debuglink import record_screen
|
||||
from . import with_client
|
||||
from ..transport.session import Session
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import TrezorConnection
|
||||
@ -35,51 +34,51 @@ def cli() -> None:
|
||||
"""Miscellaneous debug features."""
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("message_name_or_type")
|
||||
@click.argument("hex_data")
|
||||
@click.pass_obj
|
||||
def send_bytes(
|
||||
obj: "TrezorConnection", message_name_or_type: str, hex_data: str
|
||||
) -> None:
|
||||
"""Send raw bytes to Trezor.
|
||||
# @cli.command()
|
||||
# @click.argument("message_name_or_type")
|
||||
# @click.argument("hex_data")
|
||||
# @click.pass_obj
|
||||
# def send_bytes(
|
||||
# obj: "NewTrezorConnection", message_name_or_type: str, hex_data: str
|
||||
# ) -> None:
|
||||
# """Send raw bytes to Trezor.
|
||||
|
||||
Message type and message data must be specified separately, due to how message
|
||||
chunking works on the transport level. Message length is calculated and sent
|
||||
automatically, and it is currently impossible to explicitly specify invalid length.
|
||||
# Message type and message data must be specified separately, due to how message
|
||||
# chunking works on the transport level. Message length is calculated and sent
|
||||
# automatically, and it is currently impossible to explicitly specify invalid length.
|
||||
|
||||
MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum,
|
||||
in which case the value of that enum is used.
|
||||
"""
|
||||
if message_name_or_type.isdigit():
|
||||
message_type = int(message_name_or_type)
|
||||
else:
|
||||
message_type = getattr(messages.MessageType, message_name_or_type)
|
||||
# MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum,
|
||||
# in which case the value of that enum is used.
|
||||
# """
|
||||
# if message_name_or_type.isdigit():
|
||||
# message_type = int(message_name_or_type)
|
||||
# else:
|
||||
# message_type = getattr(messages.MessageType, message_name_or_type)
|
||||
|
||||
if not isinstance(message_type, int):
|
||||
raise click.ClickException("Invalid message type.")
|
||||
# if not isinstance(message_type, int):
|
||||
# raise click.ClickException("Invalid message type.")
|
||||
|
||||
try:
|
||||
message_data = bytes.fromhex(hex_data)
|
||||
except Exception as e:
|
||||
raise click.ClickException("Invalid hex data.") from e
|
||||
# try:
|
||||
# message_data = bytes.fromhex(hex_data)
|
||||
# except Exception as e:
|
||||
# raise click.ClickException("Invalid hex data.") from e
|
||||
|
||||
transport = obj.get_transport()
|
||||
transport.begin_session()
|
||||
transport.write(message_type, message_data)
|
||||
# transport = obj.get_transport()
|
||||
# transport.deprecated_begin_session()
|
||||
# transport.write(message_type, message_data)
|
||||
|
||||
response_type, response_data = transport.read()
|
||||
transport.end_session()
|
||||
# response_type, response_data = transport.read()
|
||||
# transport.deprecated_end_session()
|
||||
|
||||
click.echo(f"Response type: {response_type}")
|
||||
click.echo(f"Response data: {response_data.hex()}")
|
||||
# click.echo(f"Response type: {response_type}")
|
||||
# click.echo(f"Response data: {response_data.hex()}")
|
||||
|
||||
try:
|
||||
msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
|
||||
click.echo("Parsed message:")
|
||||
click.echo(protobuf.format_message(msg))
|
||||
except Exception as e:
|
||||
click.echo(f"Could not parse response: {e}")
|
||||
# try:
|
||||
# msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
|
||||
# click.echo("Parsed message:")
|
||||
# click.echo(protobuf.format_message(msg))
|
||||
# except Exception as e:
|
||||
# click.echo(f"Could not parse response: {e}")
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -106,17 +105,17 @@ def record_screen_from_connection(
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def prodtest_t1(client: "TrezorClient") -> str:
|
||||
@with_session(management=True)
|
||||
def prodtest_t1(session: "Session") -> str:
|
||||
"""Perform a prodtest on Model One.
|
||||
|
||||
Only available on PRODTEST firmware and on T1B1. Formerly named self-test.
|
||||
"""
|
||||
return debuglink_prodtest_t1(client)
|
||||
return debuglink_prodtest_t1(session)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def optiga_set_sec_max(client: "TrezorClient") -> str:
|
||||
@with_session(management=True)
|
||||
def optiga_set_sec_max(session: "Session") -> str:
|
||||
"""Set Optiga's security event counter to maximum."""
|
||||
return debuglink_optiga_set_sec_max(client)
|
||||
return debuglink_optiga_set_sec_max(session)
|
||||
|
@ -24,11 +24,11 @@ import click
|
||||
import requests
|
||||
|
||||
from .. import debuglink, device, exceptions, messages, ui
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_session
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..protobuf import MessageType
|
||||
from ..transport.session import Session
|
||||
from . import TrezorConnection
|
||||
|
||||
RECOVERY_DEVICE_INPUT_METHOD = {
|
||||
@ -64,17 +64,18 @@ def cli() -> None:
|
||||
help="Wipe device in bootloader mode. This also erases the firmware.",
|
||||
is_flag=True,
|
||||
)
|
||||
@with_client
|
||||
def wipe(client: "TrezorClient", bootloader: bool) -> str:
|
||||
@with_session(management=True)
|
||||
def wipe(session: "Session", bootloader: bool) -> str:
|
||||
"""Reset device to factory defaults and remove all private data."""
|
||||
features = session.features
|
||||
if bootloader:
|
||||
if not client.features.bootloader_mode:
|
||||
if not features.bootloader_mode:
|
||||
click.echo("Please switch your device to bootloader mode.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
click.echo("Wiping user data and firmware!")
|
||||
else:
|
||||
if client.features.bootloader_mode:
|
||||
if features.bootloader_mode:
|
||||
click.echo(
|
||||
"Your device is in bootloader mode. This operation would also erase firmware."
|
||||
)
|
||||
@ -87,7 +88,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str:
|
||||
click.echo("Wiping user data!")
|
||||
|
||||
try:
|
||||
return device.wipe(client)
|
||||
return device.wipe(
|
||||
session
|
||||
) # TODO decide where the wipe should happen - management or regular session
|
||||
except exceptions.TrezorFailure as e:
|
||||
click.echo("Action failed: {} {}".format(*e.args))
|
||||
sys.exit(3)
|
||||
@ -103,9 +106,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str:
|
||||
@click.option("-a", "--academic", is_flag=True)
|
||||
@click.option("-b", "--needs-backup", is_flag=True)
|
||||
@click.option("-n", "--no-backup", is_flag=True)
|
||||
@with_client
|
||||
@with_session(management=True)
|
||||
def load(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
mnemonic: t.Sequence[str],
|
||||
pin: str,
|
||||
passphrase_protection: bool,
|
||||
@ -136,7 +139,7 @@ def load(
|
||||
|
||||
try:
|
||||
return debuglink.load_device(
|
||||
client,
|
||||
session,
|
||||
mnemonic=list(mnemonic),
|
||||
pin=pin,
|
||||
passphrase_protection=passphrase_protection,
|
||||
@ -171,9 +174,9 @@ def load(
|
||||
)
|
||||
@click.option("-d", "--dry-run", is_flag=True)
|
||||
@click.option("-b", "--unlock-repeated-backup", is_flag=True)
|
||||
@with_client
|
||||
@with_session(management=True)
|
||||
def recover(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
words: str,
|
||||
expand: bool,
|
||||
pin_protection: bool,
|
||||
@ -201,7 +204,7 @@ def recover(
|
||||
type = messages.RecoveryType.UnlockRepeatedBackup
|
||||
|
||||
return device.recover(
|
||||
client,
|
||||
session,
|
||||
word_count=int(words),
|
||||
passphrase_protection=passphrase_protection,
|
||||
pin_protection=pin_protection,
|
||||
@ -222,9 +225,9 @@ def recover(
|
||||
@click.option("-s", "--skip-backup", is_flag=True)
|
||||
@click.option("-n", "--no-backup", is_flag=True)
|
||||
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE))
|
||||
@with_client
|
||||
@with_session(management=True)
|
||||
def setup(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
strength: int | None,
|
||||
passphrase_protection: bool,
|
||||
pin_protection: bool,
|
||||
@ -241,7 +244,7 @@ def setup(
|
||||
BT = messages.BackupType
|
||||
|
||||
if backup_type is None:
|
||||
if client.version >= (2, 7, 1):
|
||||
if session.version >= (2, 7, 1):
|
||||
# SLIP39 extendable was introduced in 2.7.1
|
||||
backup_type = BT.Slip39_Single_Extendable
|
||||
else:
|
||||
@ -251,10 +254,10 @@ def setup(
|
||||
if (
|
||||
backup_type
|
||||
in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable)
|
||||
and messages.Capability.Shamir not in client.features.capabilities
|
||||
and messages.Capability.Shamir not in session.features.capabilities
|
||||
) or (
|
||||
backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable)
|
||||
and messages.Capability.ShamirGroups not in client.features.capabilities
|
||||
and messages.Capability.ShamirGroups not in session.features.capabilities
|
||||
):
|
||||
click.echo(
|
||||
"WARNING: Your Trezor device does not indicate support for the requested\n"
|
||||
@ -262,7 +265,7 @@ def setup(
|
||||
)
|
||||
|
||||
return device.reset(
|
||||
client,
|
||||
session,
|
||||
strength=strength,
|
||||
passphrase_protection=passphrase_protection,
|
||||
pin_protection=pin_protection,
|
||||
@ -277,23 +280,21 @@ def setup(
|
||||
@cli.command()
|
||||
@click.option("-t", "--group-threshold", type=int)
|
||||
@click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N")
|
||||
@with_client
|
||||
@with_session(management=True)
|
||||
def backup(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
group_threshold: int | None = None,
|
||||
groups: t.Sequence[tuple[int, int]] = (),
|
||||
) -> str:
|
||||
"""Perform device seed backup."""
|
||||
|
||||
return device.backup(client, group_threshold, groups)
|
||||
return device.backup(session, group_threshold, groups)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
|
||||
@with_client
|
||||
def sd_protect(
|
||||
client: "TrezorClient", operation: messages.SdProtectOperationType
|
||||
) -> str:
|
||||
@with_session(management=True)
|
||||
def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> str:
|
||||
"""Secure the device with SD card protection.
|
||||
|
||||
When SD card protection is enabled, a randomly generated secret is stored
|
||||
@ -307,9 +308,9 @@ def sd_protect(
|
||||
off - Remove SD card secret protection.
|
||||
refresh - Replace the current SD card secret with a new one.
|
||||
"""
|
||||
if client.features.model == "1":
|
||||
if session.features.model == "1":
|
||||
raise click.ClickException("Trezor One does not support SD card protection.")
|
||||
return device.sd_protect(client, operation)
|
||||
return device.sd_protect(session, operation)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -319,24 +320,24 @@ def reboot_to_bootloader(obj: "TrezorConnection") -> str:
|
||||
|
||||
Currently only supported on Trezor Model One.
|
||||
"""
|
||||
# avoid using @with_client because it closes the session afterwards,
|
||||
# avoid using @with_management_session because it closes the session afterwards,
|
||||
# which triggers double prompt on device
|
||||
with obj.client_context() as client:
|
||||
return device.reboot_to_bootloader(client)
|
||||
return device.reboot_to_bootloader(client.get_management_session())
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def tutorial(client: "TrezorClient") -> str:
|
||||
@with_session(management=True)
|
||||
def tutorial(session: "Session") -> str:
|
||||
"""Show on-device tutorial."""
|
||||
return device.show_device_tutorial(client)
|
||||
return device.show_device_tutorial(session)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def unlock_bootloader(client: "TrezorClient") -> str:
|
||||
@with_session(management=True)
|
||||
def unlock_bootloader(session: "Session") -> str:
|
||||
"""Unlocks bootloader. Irreversible."""
|
||||
return device.unlock_bootloader(client)
|
||||
return device.unlock_bootloader(session)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -347,11 +348,11 @@ def unlock_bootloader(client: "TrezorClient") -> str:
|
||||
type=int,
|
||||
help="Dialog expiry in seconds.",
|
||||
)
|
||||
@with_client
|
||||
def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> str:
|
||||
@with_session(management=True)
|
||||
def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> str:
|
||||
"""Show a "Do not disconnect" dialog."""
|
||||
if enable is False:
|
||||
return device.set_busy(client, None)
|
||||
return device.set_busy(session, None)
|
||||
|
||||
if expiry is None:
|
||||
raise click.ClickException("Missing option '-e' / '--expiry'.")
|
||||
@ -361,7 +362,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) ->
|
||||
f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer."
|
||||
)
|
||||
|
||||
return device.set_busy(client, expiry * 1000)
|
||||
return device.set_busy(session, expiry * 1000)
|
||||
|
||||
|
||||
PUBKEY_WHITELIST_URL_TEMPLATE = (
|
||||
@ -381,9 +382,9 @@ PUBKEY_WHITELIST_URL_TEMPLATE = (
|
||||
is_flag=True,
|
||||
help="Do not check intermediate certificates against the whitelist.",
|
||||
)
|
||||
@with_client
|
||||
@with_session(management=True)
|
||||
def authenticate(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
hex_challenge: str | None,
|
||||
root: t.BinaryIO | None,
|
||||
raw: bool | None,
|
||||
@ -408,7 +409,7 @@ def authenticate(
|
||||
challenge = bytes.fromhex(hex_challenge)
|
||||
|
||||
if raw:
|
||||
msg = device.authenticate(client, challenge)
|
||||
msg = device.authenticate(session, challenge)
|
||||
|
||||
click.echo(f"Challenge: {hex_challenge}")
|
||||
click.echo(f"Signature of challenge: {msg.signature.hex()}")
|
||||
@ -456,14 +457,14 @@ def authenticate(
|
||||
else:
|
||||
whitelist_json = requests.get(
|
||||
PUBKEY_WHITELIST_URL_TEMPLATE.format(
|
||||
model=client.model.internal_name.lower()
|
||||
model=session.model.internal_name.lower()
|
||||
)
|
||||
).json()
|
||||
whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]]
|
||||
|
||||
try:
|
||||
authentication.authenticate_device(
|
||||
client, challenge, root_pubkey=root_bytes, whitelist=whitelist
|
||||
session, challenge, root_pubkey=root_bytes, whitelist=whitelist
|
||||
)
|
||||
except authentication.DeviceNotAuthentic:
|
||||
click.echo("Device is not authentic.")
|
||||
|
@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
|
||||
import click
|
||||
|
||||
from .. import eos, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .. import messages
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0"
|
||||
|
||||
@ -37,11 +37,11 @@ def cli() -> None:
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||
@with_session
|
||||
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
|
||||
"""Get Eos public key in base58 encoding."""
|
||||
address_n = tools.parse_path(address)
|
||||
res = eos.get_public_key(client, address_n, show_display)
|
||||
res = eos.get_public_key(session, address_n, show_display)
|
||||
return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}"
|
||||
|
||||
|
||||
@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_transaction(
|
||||
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
|
||||
session: "Session", address: str, file: TextIO, chunkify: bool
|
||||
) -> "messages.EosSignedTx":
|
||||
"""Sign EOS transaction."""
|
||||
tx_json = json.load(file)
|
||||
|
||||
address_n = tools.parse_path(address)
|
||||
return eos.sign_tx(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
tx_json["transaction"],
|
||||
tx_json["chain_id"],
|
||||
|
@ -26,14 +26,14 @@ import click
|
||||
|
||||
from .. import _rlp, definitions, ethereum, tools
|
||||
from ..messages import EthereumDefinitions
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import web3
|
||||
from eth_typing import ChecksumAddress # noqa: I900
|
||||
from web3.types import Wei
|
||||
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0"
|
||||
|
||||
@ -268,24 +268,24 @@ def cli(
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
||||
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||
) -> str:
|
||||
"""Get Ethereum address in hex encoding."""
|
||||
address_n = tools.parse_path(address)
|
||||
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
|
||||
return ethereum.get_address(client, address_n, show_display, network, chunkify)
|
||||
return ethereum.get_address(session, address_n, show_display, network, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict:
|
||||
@with_session
|
||||
def get_public_node(session: "Session", address: str, show_display: bool) -> dict:
|
||||
"""Get Ethereum public node of given path."""
|
||||
address_n = tools.parse_path(address)
|
||||
result = ethereum.get_public_node(client, address_n, show_display=show_display)
|
||||
result = ethereum.get_public_node(session, address_n, show_display=show_display)
|
||||
return {
|
||||
"node": {
|
||||
"depth": result.node.depth,
|
||||
@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) ->
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@click.argument("to_address")
|
||||
@click.argument("amount", callback=_amount_to_int)
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
chain_id: int,
|
||||
address: str,
|
||||
amount: int,
|
||||
@ -400,7 +400,7 @@ def sign_tx(
|
||||
encoded_network = DEFINITIONS_SOURCE.get_network(chain_id)
|
||||
address_n = tools.parse_path(address)
|
||||
from_address = ethereum.get_address(
|
||||
client, address_n, encoded_network=encoded_network
|
||||
session, address_n, encoded_network=encoded_network
|
||||
)
|
||||
|
||||
if token:
|
||||
@ -446,7 +446,7 @@ def sign_tx(
|
||||
assert max_gas_fee is not None
|
||||
assert max_priority_fee is not None
|
||||
sig = ethereum.sign_tx_eip1559(
|
||||
client,
|
||||
session,
|
||||
n=address_n,
|
||||
nonce=nonce,
|
||||
gas_limit=gas_limit,
|
||||
@ -465,7 +465,7 @@ def sign_tx(
|
||||
gas_price = _get_web3().eth.gas_price
|
||||
assert gas_price is not None
|
||||
sig = ethereum.sign_tx(
|
||||
client,
|
||||
session,
|
||||
n=address_n,
|
||||
tx_type=tx_type,
|
||||
nonce=nonce,
|
||||
@ -526,14 +526,14 @@ def sign_tx(
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@click.argument("message")
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_message(
|
||||
client: "TrezorClient", address: str, message: str, chunkify: bool
|
||||
session: "Session", address: str, message: str, chunkify: bool
|
||||
) -> Dict[str, str]:
|
||||
"""Sign message with Ethereum address."""
|
||||
address_n = tools.parse_path(address)
|
||||
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
|
||||
ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify)
|
||||
ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify)
|
||||
output = {
|
||||
"message": message,
|
||||
"address": ret.address,
|
||||
@ -550,9 +550,9 @@ def sign_message(
|
||||
help="Be compatible with Metamask's signTypedData_v4 implementation",
|
||||
)
|
||||
@click.argument("file", type=click.File("r"))
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_typed_data(
|
||||
client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO
|
||||
session: "Session", address: str, metamask_v4_compat: bool, file: TextIO
|
||||
) -> Dict[str, str]:
|
||||
"""Sign typed data (EIP-712) with Ethereum address.
|
||||
|
||||
@ -565,7 +565,7 @@ def sign_typed_data(
|
||||
defs = EthereumDefinitions(encoded_network=network)
|
||||
data = json.loads(file.read())
|
||||
ret = ethereum.sign_typed_data(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
data,
|
||||
metamask_v4_compat=metamask_v4_compat,
|
||||
@ -583,9 +583,9 @@ def sign_typed_data(
|
||||
@click.argument("address")
|
||||
@click.argument("signature")
|
||||
@click.argument("message")
|
||||
@with_client
|
||||
@with_session
|
||||
def verify_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
signature: str,
|
||||
message: str,
|
||||
@ -594,7 +594,7 @@ def verify_message(
|
||||
"""Verify message signed with Ethereum address."""
|
||||
signature_bytes = ethereum.decode_hex(signature)
|
||||
return ethereum.verify_message(
|
||||
client, address, signature_bytes, message, chunkify=chunkify
|
||||
session, address, signature_bytes, message, chunkify=chunkify
|
||||
)
|
||||
|
||||
|
||||
@ -602,9 +602,9 @@ def verify_message(
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.argument("domain_hash_hex")
|
||||
@click.argument("message_hash_hex")
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_typed_data_hash(
|
||||
client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str
|
||||
session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Sign hash of typed data (EIP-712) with Ethereum address.
|
||||
@ -618,7 +618,7 @@ def sign_typed_data_hash(
|
||||
message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None
|
||||
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
|
||||
ret = ethereum.sign_typed_data_hash(
|
||||
client, address_n, domain_hash, message_hash, network
|
||||
session, address_n, domain_hash, message_hash, network
|
||||
)
|
||||
output = {
|
||||
"domain_hash": domain_hash_hex,
|
||||
|
@ -19,10 +19,10 @@ from typing import TYPE_CHECKING
|
||||
import click
|
||||
|
||||
from .. import fido
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
|
||||
|
||||
@ -40,10 +40,10 @@ def credentials() -> None:
|
||||
|
||||
|
||||
@credentials.command(name="list")
|
||||
@with_client
|
||||
def credentials_list(client: "TrezorClient") -> None:
|
||||
@with_session(empty_passphrase=True)
|
||||
def credentials_list(session: "Session") -> None:
|
||||
"""List all resident credentials on the device."""
|
||||
creds = fido.list_credentials(client)
|
||||
creds = fido.list_credentials(session)
|
||||
for cred in creds:
|
||||
click.echo("")
|
||||
click.echo(f"WebAuthn credential at index {cred.index}:")
|
||||
@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None:
|
||||
|
||||
@credentials.command(name="add")
|
||||
@click.argument("hex_credential_id")
|
||||
@with_client
|
||||
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str:
|
||||
@with_session(empty_passphrase=True)
|
||||
def credentials_add(session: "Session", hex_credential_id: str) -> str:
|
||||
"""Add the credential with the given ID as a resident credential.
|
||||
|
||||
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
|
||||
"""
|
||||
return fido.add_credential(client, bytes.fromhex(hex_credential_id))
|
||||
return fido.add_credential(session, bytes.fromhex(hex_credential_id))
|
||||
|
||||
|
||||
@credentials.command(name="remove")
|
||||
@click.option(
|
||||
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
|
||||
)
|
||||
@with_client
|
||||
def credentials_remove(client: "TrezorClient", index: int) -> str:
|
||||
@with_session(empty_passphrase=True)
|
||||
def credentials_remove(session: "Session", index: int) -> str:
|
||||
"""Remove the resident credential at the given index."""
|
||||
return fido.remove_credential(client, index)
|
||||
return fido.remove_credential(session, index)
|
||||
|
||||
|
||||
#
|
||||
@ -110,19 +110,19 @@ def counter() -> None:
|
||||
|
||||
@counter.command(name="set")
|
||||
@click.argument("counter", type=int)
|
||||
@with_client
|
||||
def counter_set(client: "TrezorClient", counter: int) -> str:
|
||||
@with_session(empty_passphrase=True)
|
||||
def counter_set(session: "Session", counter: int) -> str:
|
||||
"""Set FIDO/U2F counter value."""
|
||||
return fido.set_counter(client, counter)
|
||||
return fido.set_counter(session, counter)
|
||||
|
||||
|
||||
@counter.command(name="get-next")
|
||||
@with_client
|
||||
def counter_get_next(client: "TrezorClient") -> int:
|
||||
@with_session(empty_passphrase=True)
|
||||
def counter_get_next(session: "Session") -> int:
|
||||
"""Get-and-increase value of FIDO/U2F counter.
|
||||
|
||||
FIDO counter value cannot be read directly. On each U2F exchange, the counter value
|
||||
is returned and atomically increased. This command performs the same operation
|
||||
and returns the counter value.
|
||||
"""
|
||||
return fido.get_next_counter(client)
|
||||
return fido.get_next_counter(session)
|
||||
|
@ -37,10 +37,11 @@ import requests
|
||||
from .. import device, exceptions, firmware, messages, models
|
||||
from ..firmware import models as fw_models
|
||||
from ..models import TrezorModel
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
from . import TrezorConnection
|
||||
|
||||
MODEL_CHOICE = ChoiceType(
|
||||
@ -74,9 +75,9 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool:
|
||||
This is the case from bootloader version 1.8.0, and also holds for firmware version
|
||||
1.8.0 because that installs the appropriate bootloader.
|
||||
"""
|
||||
f = client.features
|
||||
version = (f.major_version, f.minor_version, f.patch_version)
|
||||
bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0)
|
||||
features = client.features
|
||||
version = client.version
|
||||
bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0)
|
||||
return bootloader_onev2
|
||||
|
||||
|
||||
@ -306,25 +307,26 @@ def find_best_firmware_version(
|
||||
If the specified version is not found, prints the closest available version
|
||||
(higher than the specified one, if existing).
|
||||
"""
|
||||
features = client.features
|
||||
model = client.model
|
||||
|
||||
if bitcoin_only is None:
|
||||
bitcoin_only = _should_use_bitcoin_only(client.features)
|
||||
bitcoin_only = _should_use_bitcoin_only(features)
|
||||
|
||||
def version_str(version: Iterable[int]) -> str:
|
||||
return ".".join(map(str, version))
|
||||
|
||||
f = client.features
|
||||
|
||||
releases = get_all_firmware_releases(client.model, bitcoin_only, beta)
|
||||
releases = get_all_firmware_releases(model, bitcoin_only, beta)
|
||||
highest_version = releases[0]["version"]
|
||||
|
||||
if version:
|
||||
want_version = [int(x) for x in version.split(".")]
|
||||
if len(want_version) != 3:
|
||||
click.echo("Please use the 'X.Y.Z' version format.")
|
||||
if want_version[0] != f.major_version:
|
||||
if want_version[0] != features.major_version:
|
||||
click.echo(
|
||||
f"Warning: Trezor {client.model.name} firmware version should be "
|
||||
f"{f.major_version}.X.Y (requested: {version})"
|
||||
f"Warning: Trezor {model.name} firmware version should be "
|
||||
f"{features.major_version}.X.Y (requested: {version})"
|
||||
)
|
||||
else:
|
||||
want_version = highest_version
|
||||
@ -359,8 +361,8 @@ def find_best_firmware_version(
|
||||
# to the newer one, in that case update to the minimal
|
||||
# compatible version first
|
||||
# Choosing the version key to compare based on (not) being in BL mode
|
||||
client_version = [f.major_version, f.minor_version, f.patch_version]
|
||||
if f.bootloader_mode:
|
||||
client_version = client.version
|
||||
if features.bootloader_mode:
|
||||
key_to_compare = "min_bootloader_version"
|
||||
else:
|
||||
key_to_compare = "min_firmware_version"
|
||||
@ -447,11 +449,11 @@ def extract_embedded_fw(
|
||||
|
||||
|
||||
def upload_firmware_into_device(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
firmware_data: bytes,
|
||||
) -> None:
|
||||
"""Perform the final act of loading the firmware into Trezor."""
|
||||
f = client.features
|
||||
f = session.features
|
||||
try:
|
||||
if f.major_version == 1 and f.firmware_present is not False:
|
||||
# Trezor One does not send ButtonRequest
|
||||
@ -461,7 +463,7 @@ def upload_firmware_into_device(
|
||||
with click.progressbar(
|
||||
label="Uploading", length=len(firmware_data), show_eta=False
|
||||
) as bar:
|
||||
firmware.update(client, firmware_data, bar.update)
|
||||
firmware.update(session, firmware_data, bar.update)
|
||||
except exceptions.Cancelled:
|
||||
click.echo("Update aborted on device.")
|
||||
except exceptions.TrezorException as e:
|
||||
@ -654,6 +656,7 @@ def update(
|
||||
against data.trezor.io information, if available.
|
||||
"""
|
||||
with obj.client_context() as client:
|
||||
management_session = client.get_management_session()
|
||||
if sum(bool(x) for x in (filename, url, version)) > 1:
|
||||
click.echo("You can use only one of: filename, url, version.")
|
||||
sys.exit(1)
|
||||
@ -709,7 +712,7 @@ def update(
|
||||
if _is_strict_update(client, firmware_data):
|
||||
header_size = _get_firmware_header_size(firmware_data)
|
||||
device.reboot_to_bootloader(
|
||||
client,
|
||||
management_session,
|
||||
boot_command=messages.BootCommand.INSTALL_UPGRADE,
|
||||
firmware_header=firmware_data[:header_size],
|
||||
language_data=language_data,
|
||||
@ -719,7 +722,7 @@ def update(
|
||||
click.echo(
|
||||
"WARNING: Seamless installation not possible, language data will not be uploaded."
|
||||
)
|
||||
device.reboot_to_bootloader(client)
|
||||
device.reboot_to_bootloader(management_session)
|
||||
|
||||
click.echo("Waiting for bootloader...")
|
||||
while True:
|
||||
@ -735,13 +738,15 @@ def update(
|
||||
click.echo("Please switch your device to bootloader mode.")
|
||||
sys.exit(1)
|
||||
|
||||
upload_firmware_into_device(client=client, firmware_data=firmware_data)
|
||||
upload_firmware_into_device(
|
||||
session=client.get_management_session(), firmware_data=firmware_data
|
||||
)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("hex_challenge", required=False)
|
||||
@with_client
|
||||
def get_hash(client: "TrezorClient", hex_challenge: Optional[str]) -> str:
|
||||
@with_session(management=True)
|
||||
def get_hash(session: "Session", hex_challenge: Optional[str]) -> str:
|
||||
"""Get a hash of the installed firmware combined with the optional challenge."""
|
||||
challenge = bytes.fromhex(hex_challenge) if hex_challenge else None
|
||||
return firmware.get_hash(client, challenge).hex()
|
||||
return firmware.get_hash(session, challenge).hex()
|
||||
|
@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Dict
|
||||
import click
|
||||
|
||||
from .. import messages, monero, tools
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h"
|
||||
|
||||
@ -42,9 +42,9 @@ def cli() -> None:
|
||||
default=messages.MoneroNetworkType.MAINNET,
|
||||
)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
show_display: bool,
|
||||
network_type: messages.MoneroNetworkType,
|
||||
@ -52,7 +52,7 @@ def get_address(
|
||||
) -> bytes:
|
||||
"""Get Monero address for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
return monero.get_address(client, address_n, show_display, network_type, chunkify)
|
||||
return monero.get_address(session, address_n, show_display, network_type, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -63,13 +63,13 @@ def get_address(
|
||||
type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}),
|
||||
default=messages.MoneroNetworkType.MAINNET,
|
||||
)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_watch_key(
|
||||
client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType
|
||||
session: "Session", address: str, network_type: messages.MoneroNetworkType
|
||||
) -> Dict[str, str]:
|
||||
"""Get Monero watch key for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
res = monero.get_watch_key(client, address_n, network_type)
|
||||
res = monero.get_watch_key(session, address_n, network_type)
|
||||
# TODO: could be made required in MoneroWatchKey
|
||||
assert res.address is not None
|
||||
assert res.watch_key is not None
|
||||
|
@ -21,10 +21,10 @@ import click
|
||||
import requests
|
||||
|
||||
from .. import nem, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h"
|
||||
|
||||
@ -39,9 +39,9 @@ def cli() -> None:
|
||||
@click.option("-N", "--network", type=int, default=0x68)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
network: int,
|
||||
show_display: bool,
|
||||
@ -49,7 +49,7 @@ def get_address(
|
||||
) -> str:
|
||||
"""Get NEM address for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
return nem.get_address(client, address_n, network, show_display, chunkify)
|
||||
return nem.get_address(session, address_n, network, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -58,9 +58,9 @@ def get_address(
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-b", "--broadcast", help="NIS to announce transaction to")
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
file: TextIO,
|
||||
broadcast: Optional[str],
|
||||
@ -71,7 +71,7 @@ def sign_tx(
|
||||
Transaction file is expected in the NIS (RequestPrepareAnnounce) format.
|
||||
"""
|
||||
address_n = tools.parse_path(address)
|
||||
transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
|
||||
transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify)
|
||||
|
||||
payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()}
|
||||
|
||||
|
@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
|
||||
import click
|
||||
|
||||
from .. import ripple, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0"
|
||||
|
||||
@ -37,13 +37,13 @@ def cli() -> None:
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
||||
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||
) -> str:
|
||||
"""Get Ripple address"""
|
||||
address_n = tools.parse_path(address)
|
||||
return ripple.get_address(client, address_n, show_display, chunkify)
|
||||
return ripple.get_address(session, address_n, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -51,13 +51,13 @@ def get_address(
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None:
|
||||
@with_session
|
||||
def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None:
|
||||
"""Sign Ripple transaction"""
|
||||
address_n = tools.parse_path(address)
|
||||
msg = ripple.create_sign_tx_msg(json.load(file))
|
||||
|
||||
result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify)
|
||||
result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify)
|
||||
click.echo("Signature:")
|
||||
click.echo(result.signature.hex())
|
||||
click.echo()
|
||||
|
@ -24,10 +24,11 @@ import click
|
||||
import requests
|
||||
|
||||
from .. import device, messages, toif
|
||||
from . import AliasedGroup, ChoiceType, with_client
|
||||
from ..transport.session import Session
|
||||
from . import AliasedGroup, ChoiceType, with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
pass
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
@ -180,18 +181,18 @@ def cli() -> None:
|
||||
@cli.command()
|
||||
@click.option("-r", "--remove", is_flag=True, hidden=True)
|
||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
|
||||
@with_client
|
||||
def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str:
|
||||
@with_session(management=True)
|
||||
def pin(session: "Session", enable: Optional[bool], remove: bool) -> str:
|
||||
"""Set, change or remove PIN."""
|
||||
# Remove argument is there for backwards compatibility
|
||||
return device.change_pin(client, remove=_should_remove(enable, remove))
|
||||
return device.change_pin(session, remove=_should_remove(enable, remove))
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-r", "--remove", is_flag=True, hidden=True)
|
||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
|
||||
@with_client
|
||||
def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str:
|
||||
@with_session(management=True)
|
||||
def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> str:
|
||||
"""Set or remove the wipe code.
|
||||
|
||||
The wipe code functions as a "self-destruct PIN". If the wipe code is ever
|
||||
@ -199,32 +200,32 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> s
|
||||
removed and the device will be reset to factory defaults.
|
||||
"""
|
||||
# Remove argument is there for backwards compatibility
|
||||
return device.change_wipe_code(client, remove=_should_remove(enable, remove))
|
||||
return device.change_wipe_code(session, remove=_should_remove(enable, remove))
|
||||
|
||||
|
||||
@cli.command()
|
||||
# keep the deprecated -l/--label option, make it do nothing
|
||||
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.argument("label")
|
||||
@with_client
|
||||
def label(client: "TrezorClient", label: str) -> str:
|
||||
@with_session(management=True)
|
||||
def label(session: "Session", label: str) -> str:
|
||||
"""Set new device label."""
|
||||
return device.apply_settings(client, label=label)
|
||||
return device.apply_settings(session, label=label)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def brightness(client: "TrezorClient") -> str:
|
||||
@with_session(management=True)
|
||||
def brightness(session: "Session") -> str:
|
||||
"""Set display brightness."""
|
||||
return device.set_brightness(client)
|
||||
return device.set_brightness(session)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
|
||||
@with_client
|
||||
def haptic_feedback(client: "TrezorClient", enable: bool) -> str:
|
||||
@with_session(management=True)
|
||||
def haptic_feedback(session: "Session", enable: bool) -> str:
|
||||
"""Enable or disable haptic feedback."""
|
||||
return device.apply_settings(client, haptic_feedback=enable)
|
||||
return device.apply_settings(session, haptic_feedback=enable)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -233,9 +234,9 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> str:
|
||||
"-r", "--remove", is_flag=True, default=False, help="Switch back to english."
|
||||
)
|
||||
@click.option("-d/-D", "--display/--no-display", default=None)
|
||||
@with_client
|
||||
@with_session(management=True)
|
||||
def language(
|
||||
client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None
|
||||
session: "Session", path_or_url: str | None, remove: bool, display: bool | None
|
||||
) -> str:
|
||||
"""Set new language with translations."""
|
||||
if remove != (path_or_url is None):
|
||||
@ -260,29 +261,29 @@ def language(
|
||||
f"Failed to load translations from {path_or_url}"
|
||||
) from None
|
||||
return device.change_language(
|
||||
client, language_data=language_data, show_display=display
|
||||
session, language_data=language_data, show_display=display
|
||||
)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("rotation", type=ChoiceType(ROTATION))
|
||||
@with_client
|
||||
def display_rotation(client: "TrezorClient", rotation: messages.DisplayRotation) -> str:
|
||||
@with_session(management=True)
|
||||
def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> str:
|
||||
"""Set display rotation.
|
||||
|
||||
Configure display rotation for Trezor Model T. The options are
|
||||
north, east, south or west.
|
||||
"""
|
||||
return device.apply_settings(client, display_rotation=rotation)
|
||||
return device.apply_settings(session, display_rotation=rotation)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("delay", type=str)
|
||||
@with_client
|
||||
def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
|
||||
@with_session(management=True)
|
||||
def auto_lock_delay(session: "Session", delay: str) -> str:
|
||||
"""Set auto-lock delay (in seconds)."""
|
||||
|
||||
if not client.features.pin_protection:
|
||||
if not session.features.pin_protection:
|
||||
raise click.ClickException("Set up a PIN first")
|
||||
|
||||
value, unit = delay[:-1], delay[-1:]
|
||||
@ -291,13 +292,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
|
||||
seconds = float(value) * units[unit]
|
||||
else:
|
||||
seconds = float(delay) # assume seconds if no unit is specified
|
||||
return device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000))
|
||||
return device.apply_settings(session, auto_lock_delay_ms=int(seconds * 1000))
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("flags")
|
||||
@with_client
|
||||
def flags(client: "TrezorClient", flags: str) -> str:
|
||||
@with_session(management=True)
|
||||
def flags(session: "Session", flags: str) -> str:
|
||||
"""Set device flags."""
|
||||
if flags.lower().startswith("0b"):
|
||||
flags_int = int(flags, 2)
|
||||
@ -305,7 +306,7 @@ def flags(client: "TrezorClient", flags: str) -> str:
|
||||
flags_int = int(flags, 16)
|
||||
else:
|
||||
flags_int = int(flags)
|
||||
return device.apply_flags(client, flags=flags_int)
|
||||
return device.apply_flags(session, flags=flags_int)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -314,8 +315,8 @@ def flags(client: "TrezorClient", flags: str) -> str:
|
||||
"-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False
|
||||
)
|
||||
@click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)")
|
||||
@with_client
|
||||
def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
|
||||
@with_session(management=True)
|
||||
def homescreen(session: "Session", filename: str, quality: int) -> str:
|
||||
"""Set new homescreen.
|
||||
|
||||
To revert to default homescreen, use 'trezorctl set homescreen default'
|
||||
@ -327,39 +328,39 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
|
||||
if not path.exists() or not path.is_file():
|
||||
raise click.ClickException("Cannot open file")
|
||||
|
||||
if client.features.model == "1":
|
||||
if session.features.model == "1":
|
||||
img = image_to_t1(path)
|
||||
else:
|
||||
if client.features.homescreen_format == messages.HomescreenFormat.Jpeg:
|
||||
if session.features.homescreen_format == messages.HomescreenFormat.Jpeg:
|
||||
width = (
|
||||
client.features.homescreen_width
|
||||
if client.features.homescreen_width is not None
|
||||
session.features.homescreen_width
|
||||
if session.features.homescreen_width is not None
|
||||
else 240
|
||||
)
|
||||
height = (
|
||||
client.features.homescreen_height
|
||||
if client.features.homescreen_height is not None
|
||||
session.features.homescreen_height
|
||||
if session.features.homescreen_height is not None
|
||||
else 240
|
||||
)
|
||||
img = image_to_jpeg(path, width, height, quality)
|
||||
elif client.features.homescreen_format == messages.HomescreenFormat.ToiG:
|
||||
width = client.features.homescreen_width
|
||||
height = client.features.homescreen_height
|
||||
elif session.features.homescreen_format == messages.HomescreenFormat.ToiG:
|
||||
width = session.features.homescreen_width
|
||||
height = session.features.homescreen_height
|
||||
if width is None or height is None:
|
||||
raise click.ClickException("Device did not report homescreen size.")
|
||||
img = image_to_toif(path, width, height, True)
|
||||
elif (
|
||||
client.features.homescreen_format == messages.HomescreenFormat.Toif
|
||||
or client.features.homescreen_format is None
|
||||
session.features.homescreen_format == messages.HomescreenFormat.Toif
|
||||
or session.features.homescreen_format is None
|
||||
):
|
||||
width = (
|
||||
client.features.homescreen_width
|
||||
if client.features.homescreen_width is not None
|
||||
session.features.homescreen_width
|
||||
if session.features.homescreen_width is not None
|
||||
else 144
|
||||
)
|
||||
height = (
|
||||
client.features.homescreen_height
|
||||
if client.features.homescreen_height is not None
|
||||
session.features.homescreen_height
|
||||
if session.features.homescreen_height is not None
|
||||
else 144
|
||||
)
|
||||
img = image_to_toif(path, width, height, False)
|
||||
@ -369,7 +370,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
|
||||
"Unknown image format requested by the device."
|
||||
)
|
||||
|
||||
return device.apply_settings(client, homescreen=img)
|
||||
return device.apply_settings(session, homescreen=img)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -377,9 +378,9 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
|
||||
"--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.'
|
||||
)
|
||||
@click.argument("level", type=ChoiceType(SAFETY_LEVELS))
|
||||
@with_client
|
||||
@with_session(management=True)
|
||||
def safety_checks(
|
||||
client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel
|
||||
session: "Session", always: bool, level: messages.SafetyCheckLevel
|
||||
) -> str:
|
||||
"""Set safety check level.
|
||||
|
||||
@ -392,18 +393,18 @@ def safety_checks(
|
||||
"""
|
||||
if always and level == messages.SafetyCheckLevel.PromptTemporarily:
|
||||
level = messages.SafetyCheckLevel.PromptAlways
|
||||
return device.apply_settings(client, safety_checks=level)
|
||||
return device.apply_settings(session, safety_checks=level)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
|
||||
@with_client
|
||||
def experimental_features(client: "TrezorClient", enable: bool) -> str:
|
||||
@with_session(management=True)
|
||||
def experimental_features(session: "Session", enable: bool) -> str:
|
||||
"""Enable or disable experimental message types.
|
||||
|
||||
This is a developer feature. Use with caution.
|
||||
"""
|
||||
return device.apply_settings(client, experimental_features=enable)
|
||||
return device.apply_settings(session, experimental_features=enable)
|
||||
|
||||
|
||||
#
|
||||
@ -426,25 +427,25 @@ passphrase = cast(AliasedGroup, passphrase_main)
|
||||
|
||||
@passphrase.command(name="on")
|
||||
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
|
||||
@with_client
|
||||
def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> str:
|
||||
@with_session(management=True)
|
||||
def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> str:
|
||||
"""Enable passphrase."""
|
||||
if client.features.passphrase_protection is not True:
|
||||
if session.features.passphrase_protection is not True:
|
||||
use_passphrase = True
|
||||
else:
|
||||
use_passphrase = None
|
||||
return device.apply_settings(
|
||||
client,
|
||||
session,
|
||||
use_passphrase=use_passphrase,
|
||||
passphrase_always_on_device=force_on_device,
|
||||
)
|
||||
|
||||
|
||||
@passphrase.command(name="off")
|
||||
@with_client
|
||||
def passphrase_off(client: "TrezorClient") -> str:
|
||||
@with_session(management=True)
|
||||
def passphrase_off(session: "Session") -> str:
|
||||
"""Disable passphrase."""
|
||||
return device.apply_settings(client, use_passphrase=False)
|
||||
return device.apply_settings(session, use_passphrase=False)
|
||||
|
||||
|
||||
# Registering the aliases for backwards compatibility
|
||||
@ -457,10 +458,10 @@ passphrase.aliases = {
|
||||
|
||||
@passphrase.command(name="hide")
|
||||
@click.argument("hide", type=ChoiceType({"on": True, "off": False}))
|
||||
@with_client
|
||||
def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> str:
|
||||
@with_session(management=True)
|
||||
def hide_passphrase_from_host(session: "Session", hide: bool) -> str:
|
||||
"""Enable or disable hiding passphrase coming from host.
|
||||
|
||||
This is a developer feature. Use with caution.
|
||||
"""
|
||||
return device.apply_settings(client, hide_passphrase_from_host=hide)
|
||||
return device.apply_settings(session, hide_passphrase_from_host=hide)
|
||||
|
@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
|
||||
import click
|
||||
|
||||
from .. import messages, solana, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h"
|
||||
DEFAULT_PATH = "m/44h/501h/0h/0h"
|
||||
@ -21,40 +21,40 @@ def cli() -> None:
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
show_display: bool,
|
||||
) -> messages.SolanaPublicKey:
|
||||
"""Get Solana public key."""
|
||||
address_n = tools.parse_path(address)
|
||||
return solana.get_public_key(client, address_n, show_display)
|
||||
return solana.get_public_key(session, address_n, show_display)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
show_display: bool,
|
||||
chunkify: bool,
|
||||
) -> messages.SolanaAddress:
|
||||
"""Get Solana address."""
|
||||
address_n = tools.parse_path(address)
|
||||
return solana.get_address(client, address_n, show_display, chunkify)
|
||||
return solana.get_address(session, address_n, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("serialized_tx", type=str)
|
||||
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
|
||||
@click.option("-a", "--additional-information-file", type=click.File("r"))
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
serialized_tx: str,
|
||||
additional_information_file: Optional[TextIO],
|
||||
@ -78,7 +78,7 @@ def sign_tx(
|
||||
)
|
||||
|
||||
return solana.sign_tx(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
bytes.fromhex(serialized_tx),
|
||||
additional_information,
|
||||
|
@ -21,10 +21,10 @@ from typing import TYPE_CHECKING
|
||||
import click
|
||||
|
||||
from .. import stellar, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
try:
|
||||
from stellar_sdk import (
|
||||
@ -52,13 +52,13 @@ def cli() -> None:
|
||||
)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
||||
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||
) -> str:
|
||||
"""Get Stellar public address."""
|
||||
address_n = tools.parse_path(address)
|
||||
return stellar.get_address(client, address_n, show_display, chunkify)
|
||||
return stellar.get_address(session, address_n, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -77,9 +77,9 @@ def get_address(
|
||||
help="Network passphrase (blank for public network).",
|
||||
)
|
||||
@click.argument("b64envelope")
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_transaction(
|
||||
client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str
|
||||
session: "Session", b64envelope: str, address: str, network_passphrase: str
|
||||
) -> bytes:
|
||||
"""Sign a base64-encoded transaction envelope.
|
||||
|
||||
@ -109,6 +109,6 @@ def sign_transaction(
|
||||
|
||||
address_n = tools.parse_path(address)
|
||||
tx, operations = stellar.from_envelope(envelope)
|
||||
resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase)
|
||||
resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase)
|
||||
|
||||
return base64.b64encode(resp.signature)
|
||||
|
@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
|
||||
import click
|
||||
|
||||
from .. import messages, protobuf, tezos, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h"
|
||||
|
||||
@ -37,23 +37,23 @@ def cli() -> None:
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
||||
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||
) -> str:
|
||||
"""Get Tezos address for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
return tezos.get_address(client, address_n, show_display, chunkify)
|
||||
return tezos.get_address(session, address_n, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||
@with_session
|
||||
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
|
||||
"""Get Tezos public key."""
|
||||
address_n = tools.parse_path(address)
|
||||
return tezos.get_public_key(client, address_n, show_display)
|
||||
return tezos.get_public_key(session, address_n, show_display)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -61,11 +61,11 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_tx(
|
||||
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
|
||||
session: "Session", address: str, file: TextIO, chunkify: bool
|
||||
) -> messages.TezosSignedTx:
|
||||
"""Sign Tezos transaction."""
|
||||
address_n = tools.parse_path(address)
|
||||
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))
|
||||
return tezos.sign_tx(client, address_n, msg, chunkify=chunkify)
|
||||
return tezos.sign_tx(session, address_n, msg, chunkify=chunkify)
|
||||
|
@ -24,9 +24,12 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca
|
||||
|
||||
import click
|
||||
|
||||
from .. import __version__, log, messages, protobuf, ui
|
||||
from ..client import TrezorClient
|
||||
from .. import __version__, log, messages, protobuf
|
||||
from ..client import ProtocolVersion, TrezorClient
|
||||
from ..transport import DeviceIsBusy, enumerate_devices
|
||||
from ..transport.session import Session
|
||||
from ..transport.thp import channel_database
|
||||
from ..transport.thp.channel_database import get_channel_db
|
||||
from ..transport.udp import UdpTransport
|
||||
from . import (
|
||||
AliasedGroup,
|
||||
@ -50,6 +53,7 @@ from . import (
|
||||
stellar,
|
||||
tezos,
|
||||
with_client,
|
||||
with_session,
|
||||
)
|
||||
|
||||
F = TypeVar("F", bound=Callable)
|
||||
@ -193,6 +197,13 @@ def configure_logging(verbose: int) -> None:
|
||||
"--record",
|
||||
help="Record screen changes into a specified directory.",
|
||||
)
|
||||
@click.option(
|
||||
"-n",
|
||||
"--no-store",
|
||||
is_flag=True,
|
||||
help="Do not store channels data between commands.",
|
||||
default=False,
|
||||
)
|
||||
@click.version_option(version=__version__)
|
||||
@click.pass_context
|
||||
def cli_main(
|
||||
@ -204,9 +215,10 @@ def cli_main(
|
||||
script: bool,
|
||||
session_id: Optional[str],
|
||||
record: Optional[str],
|
||||
no_store: bool,
|
||||
) -> None:
|
||||
configure_logging(verbose)
|
||||
|
||||
channel_database.set_channel_database(should_not_store=no_store)
|
||||
bytes_session_id: Optional[bytes] = None
|
||||
if session_id is not None:
|
||||
try:
|
||||
@ -214,6 +226,7 @@ def cli_main(
|
||||
except ValueError:
|
||||
raise click.ClickException(f"Not a valid session id: {session_id}")
|
||||
|
||||
# ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script)
|
||||
ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script)
|
||||
|
||||
# Optionally record the screen into a specified directory.
|
||||
@ -285,18 +298,23 @@ def format_device_name(features: messages.Features) -> str:
|
||||
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
|
||||
"""List connected Trezor devices."""
|
||||
if no_resolve:
|
||||
return enumerate_devices()
|
||||
for d in enumerate_devices():
|
||||
print(d.get_path())
|
||||
return
|
||||
|
||||
from . import get_client
|
||||
|
||||
for transport in enumerate_devices():
|
||||
try:
|
||||
client = TrezorClient(transport, ui=ui.ClickUI())
|
||||
client = get_client(transport)
|
||||
description = format_device_name(client.features)
|
||||
client.end_session()
|
||||
if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
|
||||
get_channel_db().save_channel(client.protocol)
|
||||
except DeviceIsBusy:
|
||||
description = "Device is in use by another process"
|
||||
except Exception:
|
||||
description = "Failed to read details"
|
||||
click.echo(f"{transport} - {description}")
|
||||
except Exception as e:
|
||||
description = "Failed to read details " + str(type(e))
|
||||
click.echo(f"{transport.get_path()} - {description}")
|
||||
return None
|
||||
|
||||
|
||||
@ -314,15 +332,19 @@ def version() -> str:
|
||||
@cli.command()
|
||||
@click.argument("message")
|
||||
@click.option("-b", "--button-protection", is_flag=True)
|
||||
@with_client
|
||||
def ping(client: "TrezorClient", message: str, button_protection: bool) -> str:
|
||||
@with_session(empty_passphrase=True)
|
||||
def ping(session: "Session", message: str, button_protection: bool) -> str:
|
||||
"""Send ping message."""
|
||||
return client.ping(message, button_protection=button_protection)
|
||||
|
||||
# TODO return short-circuit from old client for old Trezors
|
||||
return session.ping(message, button_protection)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.pass_obj
|
||||
def get_session(obj: TrezorConnection) -> str:
|
||||
def get_session(
|
||||
obj: TrezorConnection, passphrase: str = "", derive_cardano: bool = False
|
||||
) -> str:
|
||||
"""Get a session ID for subsequent commands.
|
||||
|
||||
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
|
||||
@ -336,23 +358,44 @@ def get_session(obj: TrezorConnection) -> str:
|
||||
obj.session_id = None
|
||||
|
||||
with obj.client_context() as client:
|
||||
|
||||
if client.features.model == "1" and client.version < (1, 9, 0):
|
||||
raise click.ClickException(
|
||||
"Upgrade your firmware to enable session support."
|
||||
)
|
||||
|
||||
client.ensure_unlocked()
|
||||
if client.session_id is None:
|
||||
# client.ensure_unlocked()
|
||||
session = client.get_session(
|
||||
passphrase=passphrase, derive_cardano=derive_cardano
|
||||
)
|
||||
if session.id is None:
|
||||
raise click.ClickException("Passphrase not enabled or firmware too old.")
|
||||
else:
|
||||
return client.session_id.hex()
|
||||
return session.id.hex()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def clear_session(client: "TrezorClient") -> None:
|
||||
@with_session(must_resume=True, empty_passphrase=True)
|
||||
def clear_session(session: "Session") -> None:
|
||||
"""Clear session (remove cached PIN, passphrase, etc.)."""
|
||||
return client.clear_session()
|
||||
if session is None:
|
||||
click.echo("Cannot clear session as it was not properly resumed.")
|
||||
return
|
||||
session.call(messages.LockDevice())
|
||||
session.end()
|
||||
# TODO different behaviour than main, not sure if ok
|
||||
|
||||
|
||||
@cli.command()
|
||||
def delete_channels() -> None:
|
||||
"""
|
||||
Delete cached channels.
|
||||
|
||||
Do not use together with the `-n` (`--no-store`) flag,
|
||||
as the JSON database will not be deleted in that case.
|
||||
"""
|
||||
get_channel_db().clear_stored_channels()
|
||||
click.echo("Deleted stored channels")
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,6 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
||||
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
@ -21,47 +21,44 @@ import logging
|
||||
import re
|
||||
import textwrap
|
||||
import time
|
||||
import typing as t
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from enum import Enum, IntEnum, auto
|
||||
from itertools import zip_longest
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from mnemonic import Mnemonic
|
||||
|
||||
from . import mapping, messages, models, protobuf
|
||||
from .client import TrezorClient
|
||||
from .exceptions import TrezorFailure
|
||||
from . import btc, mapping, messages, models, protobuf
|
||||
from .client import (
|
||||
MAX_PASSPHRASE_LENGTH,
|
||||
MAX_PIN_LENGTH,
|
||||
PASSPHRASE_ON_DEVICE,
|
||||
TrezorClient,
|
||||
)
|
||||
from .exceptions import Cancelled, PinException, TrezorFailure
|
||||
from .log import DUMP_BYTES
|
||||
from .messages import DebugWaitType
|
||||
from .tools import expect
|
||||
from .messages import Capability, DebugWaitType
|
||||
from .tools import expect, parse_path
|
||||
from .transport.session import Session, SessionV1
|
||||
from .transport.thp.protocol_v1 import ProtocolV1
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if t.TYPE_CHECKING:
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from .messages import PinMatrixRequestType
|
||||
from .transport import Transport
|
||||
|
||||
ExpectedMessage = Union[
|
||||
protobuf.MessageType, type[protobuf.MessageType], "MessageFilter"
|
||||
ExpectedMessage = t.Union[
|
||||
protobuf.MessageType, t.Type[protobuf.MessageType], "MessageFilter"
|
||||
]
|
||||
|
||||
AnyDict = Dict[str, Any]
|
||||
AnyDict = t.Dict[str, t.Any]
|
||||
|
||||
class InputFunc(Protocol):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hold_ms: int | None = None,
|
||||
@ -70,6 +67,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
EXPECTED_RESPONSES_CONTEXT_LINES = 3
|
||||
PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0")
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -104,11 +102,13 @@ class UnstructuredJSONReader:
|
||||
except json.JSONDecodeError:
|
||||
self.dict = {}
|
||||
|
||||
def top_level_value(self, key: str) -> Any:
|
||||
def top_level_value(self, key: str) -> t.Any:
|
||||
return self.dict.get(key)
|
||||
|
||||
def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]:
|
||||
def recursively_find(data: Any) -> Iterator[Any]:
|
||||
def find_objects_with_key_and_value(
|
||||
self, key: str, value: t.Any
|
||||
) -> list["AnyDict"]:
|
||||
def recursively_find(data: t.Any) -> t.Iterator[t.Any]:
|
||||
if isinstance(data, dict):
|
||||
if data.get(key) == value:
|
||||
yield data
|
||||
@ -121,7 +121,7 @@ class UnstructuredJSONReader:
|
||||
return list(recursively_find(self.dict))
|
||||
|
||||
def find_unique_object_with_key_and_value(
|
||||
self, key: str, value: Any
|
||||
self, key: str, value: t.Any
|
||||
) -> AnyDict | None:
|
||||
objects = self.find_objects_with_key_and_value(key, value)
|
||||
if not objects:
|
||||
@ -129,8 +129,10 @@ class UnstructuredJSONReader:
|
||||
assert len(objects) == 1
|
||||
return objects[0]
|
||||
|
||||
def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]:
|
||||
def recursively_find(data: Any) -> Iterator[Any]:
|
||||
def find_values_by_key(
|
||||
self, key: str, only_type: type | None = None
|
||||
) -> list[t.Any]:
|
||||
def recursively_find(data: t.Any) -> t.Iterator[t.Any]:
|
||||
if isinstance(data, dict):
|
||||
if key in data:
|
||||
yield data[key]
|
||||
@ -148,8 +150,8 @@ class UnstructuredJSONReader:
|
||||
return values
|
||||
|
||||
def find_unique_value_by_key(
|
||||
self, key: str, default: Any, only_type: type | None = None
|
||||
) -> Any:
|
||||
self, key: str, default: t.Any, only_type: type | None = None
|
||||
) -> t.Any:
|
||||
values = self.find_values_by_key(key, only_type=only_type)
|
||||
if not values:
|
||||
return default
|
||||
@ -160,7 +162,7 @@ class UnstructuredJSONReader:
|
||||
class LayoutContent(UnstructuredJSONReader):
|
||||
"""Contains helper functions to extract specific parts of the layout."""
|
||||
|
||||
def __init__(self, json_tokens: Sequence[str]) -> None:
|
||||
def __init__(self, json_tokens: t.Sequence[str]) -> None:
|
||||
json_str = "".join(json_tokens)
|
||||
super().__init__(json_str)
|
||||
|
||||
@ -422,11 +424,13 @@ def _make_input_func(
|
||||
|
||||
|
||||
class DebugLink:
|
||||
|
||||
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
|
||||
self.transport = transport
|
||||
self.allow_interactions = auto_interact
|
||||
self.mapping = mapping.DEFAULT_MAPPING
|
||||
|
||||
self.protocol = ProtocolV1(self.transport, self.mapping)
|
||||
# To be set by TrezorClientDebugLink (is not known during creation time)
|
||||
self.model: models.TrezorModel | None = None
|
||||
self.version: tuple[int, int, int] = (0, 0, 0)
|
||||
@ -479,10 +483,16 @@ class DebugLink:
|
||||
self.screen_text_file = file_path
|
||||
|
||||
def open(self) -> None:
|
||||
self.transport.begin_session()
|
||||
self.transport.open()
|
||||
# raise NotImplementedError
|
||||
# TODO is this needed?
|
||||
# self.transport.deprecated_begin_session()
|
||||
|
||||
def close(self) -> None:
|
||||
self.transport.end_session()
|
||||
pass
|
||||
# raise NotImplementedError
|
||||
# TODO is this needed?
|
||||
# self.transport.deprecated_end_session()
|
||||
|
||||
def _write(self, msg: protobuf.MessageType) -> None:
|
||||
if self.waiting_for_layout_change:
|
||||
@ -499,15 +509,10 @@ class DebugLink:
|
||||
DUMP_BYTES,
|
||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
self.transport.write(msg_type, msg_bytes)
|
||||
self.protocol.write(msg)
|
||||
|
||||
def _read(self) -> protobuf.MessageType:
|
||||
ret_type, ret_bytes = self.transport.read()
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}",
|
||||
)
|
||||
msg = self.mapping.decode(ret_type, ret_bytes)
|
||||
msg = self.protocol.read()
|
||||
|
||||
# Collapse tokens to make log use less lines.
|
||||
msg_for_log = msg
|
||||
@ -521,18 +526,27 @@ class DebugLink:
|
||||
)
|
||||
return msg
|
||||
|
||||
def _call(self, msg: protobuf.MessageType) -> Any:
|
||||
def _call(self, msg: protobuf.MessageType) -> t.Any:
|
||||
self._write(msg)
|
||||
return self._read()
|
||||
|
||||
def state(self, wait_type: DebugWaitType | None = None) -> messages.DebugLinkState:
|
||||
def state(
|
||||
self,
|
||||
wait_type: DebugWaitType | None = None,
|
||||
thp_channel_id: bytes | None = None,
|
||||
) -> messages.DebugLinkState:
|
||||
if wait_type is None:
|
||||
wait_type = (
|
||||
DebugWaitType.CURRENT_LAYOUT
|
||||
if self.has_global_layout
|
||||
else DebugWaitType.IMMEDIATE
|
||||
)
|
||||
result = self._call(messages.DebugLinkGetState(wait_layout=wait_type))
|
||||
result = self._call(
|
||||
messages.DebugLinkGetState(
|
||||
wait_layout=wait_type,
|
||||
thp_channel_id=thp_channel_id,
|
||||
)
|
||||
)
|
||||
while not isinstance(result, (messages.Failure, messages.DebugLinkState)):
|
||||
result = self._read()
|
||||
if isinstance(result, messages.Failure):
|
||||
@ -544,7 +558,7 @@ class DebugLink:
|
||||
|
||||
def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent:
|
||||
# Next layout change will be caused by external event
|
||||
# (e.g. device being auto-locked or as a result of device_handler.run(xxx))
|
||||
# (e.g. device being auto-locked or as a result of device_handler.run_with_session(xxx))
|
||||
# and not by our debug actions/decisions.
|
||||
# Resetting the debug state so we wait for the next layout change
|
||||
# (and do not return the current state).
|
||||
@ -560,7 +574,7 @@ class DebugLink:
|
||||
return LayoutContent(obj.tokens)
|
||||
|
||||
@contextmanager
|
||||
def wait_for_layout_change(self) -> Iterator[LayoutContent]:
|
||||
def wait_for_layout_change(self) -> t.Iterator[LayoutContent]:
|
||||
# set up a dummy layout content object to be yielded
|
||||
layout_content = LayoutContent(
|
||||
["DUMMY CONTENT, WAIT UNTIL THE END OF THE BLOCK :("]
|
||||
@ -622,7 +636,7 @@ class DebugLink:
|
||||
|
||||
return "".join([str(matrix.index(p) + 1) for p in pin])
|
||||
|
||||
def read_recovery_word(self) -> Tuple[str | None, int | None]:
|
||||
def read_recovery_word(self) -> t.Tuple[str | None, int | None]:
|
||||
state = self.state()
|
||||
return (state.recovery_fake_word, state.recovery_word_pos)
|
||||
|
||||
@ -700,7 +714,7 @@ class DebugLink:
|
||||
|
||||
def click(
|
||||
self,
|
||||
click: Tuple[int, int],
|
||||
click: t.Tuple[int, int],
|
||||
hold_ms: int | None = None,
|
||||
wait: bool | None = None,
|
||||
) -> LayoutContent:
|
||||
@ -862,10 +876,10 @@ class DebugUI:
|
||||
self.clear()
|
||||
|
||||
def clear(self) -> None:
|
||||
self.pins: Iterator[str] | None = None
|
||||
self.pins: t.Iterator[str] | None = None
|
||||
self.passphrase = ""
|
||||
self.input_flow: Union[
|
||||
Generator[None, messages.ButtonRequest, None], object, None
|
||||
self.input_flow: t.Union[
|
||||
t.Generator[None, messages.ButtonRequest, None], object, None
|
||||
] = None
|
||||
|
||||
def _default_input_flow(self, br: messages.ButtonRequest) -> None:
|
||||
@ -896,7 +910,7 @@ class DebugUI:
|
||||
raise AssertionError("input flow ended prematurely")
|
||||
else:
|
||||
try:
|
||||
assert isinstance(self.input_flow, Generator)
|
||||
assert isinstance(self.input_flow, t.Generator)
|
||||
self.input_flow.send(br)
|
||||
except StopIteration:
|
||||
self.input_flow = self.INPUT_FLOW_DONE
|
||||
@ -918,12 +932,15 @@ class DebugUI:
|
||||
|
||||
|
||||
class MessageFilter:
|
||||
def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None:
|
||||
|
||||
def __init__(
|
||||
self, message_type: t.Type[protobuf.MessageType], **fields: t.Any
|
||||
) -> None:
|
||||
self.message_type = message_type
|
||||
self.fields: Dict[str, Any] = {}
|
||||
self.fields: t.Dict[str, t.Any] = {}
|
||||
self.update_fields(**fields)
|
||||
|
||||
def update_fields(self, **fields: Any) -> "MessageFilter":
|
||||
def update_fields(self, **fields: t.Any) -> "MessageFilter":
|
||||
for name, value in fields.items():
|
||||
try:
|
||||
self.fields[name] = self.from_message_or_type(value)
|
||||
@ -971,7 +988,7 @@ class MessageFilter:
|
||||
return True
|
||||
|
||||
def to_string(self, maxwidth: int = 80) -> str:
|
||||
fields: list[Tuple[str, str]] = []
|
||||
fields: list[t.Tuple[str, str]] = []
|
||||
for field in self.message_type.FIELDS.values():
|
||||
if field.name not in self.fields:
|
||||
continue
|
||||
@ -1001,7 +1018,8 @@ class MessageFilter:
|
||||
|
||||
|
||||
class MessageFilterGenerator:
|
||||
def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]:
|
||||
|
||||
def __getattr__(self, key: str) -> t.Callable[..., "MessageFilter"]:
|
||||
message_type = getattr(messages, key)
|
||||
return MessageFilter(message_type).update_fields
|
||||
|
||||
@ -1009,6 +1027,245 @@ class MessageFilterGenerator:
|
||||
message_filters = MessageFilterGenerator()
|
||||
|
||||
|
||||
class SessionDebugWrapper(Session):
|
||||
def __init__(self, session: Session) -> None:
|
||||
self._session = session
|
||||
self.reset_debug_features()
|
||||
if isinstance(session, SessionDebugWrapper):
|
||||
raise Exception("Cannot wrap already wrapped session!")
|
||||
|
||||
@property
|
||||
def protocol_version(self) -> int:
|
||||
return self.client.protocol_version
|
||||
|
||||
@property
|
||||
def client(self) -> TrezorClientDebugLink:
|
||||
assert isinstance(self._session.client, TrezorClientDebugLink)
|
||||
return self._session.client
|
||||
|
||||
@property
|
||||
def id(self) -> bytes:
|
||||
return self._session.id
|
||||
|
||||
def _write(self, msg: t.Any) -> None:
|
||||
print("writing message:", msg.__class__.__name__)
|
||||
self._session._write(self._filter_message(msg))
|
||||
|
||||
def _read(self) -> t.Any:
|
||||
resp = self._filter_message(self._session._read())
|
||||
print("reading message:", resp.__class__.__name__)
|
||||
if self.actual_responses is not None:
|
||||
self.actual_responses.append(resp)
|
||||
return resp
|
||||
|
||||
def set_expected_responses(
|
||||
self,
|
||||
expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]],
|
||||
) -> None:
|
||||
"""Set a sequence of expected responses to session calls.
|
||||
|
||||
Within a given with-block, the list of received responses from device must
|
||||
match the list of expected responses, otherwise an ``AssertionError`` is raised.
|
||||
|
||||
If an expected response is given a field value other than ``None``, that field value
|
||||
must exactly match the received field value. If a given field is ``None``
|
||||
(or unspecified) in the expected response, the received field value is not
|
||||
checked.
|
||||
|
||||
Each expected response can also be a tuple ``(bool, message)``. In that case, the
|
||||
expected response is only evaluated if the first field is ``True``.
|
||||
This is useful for differentiating sequences between Trezor models:
|
||||
|
||||
>>> trezor_one = session.features.model == "1"
|
||||
>>> session.set_expected_responses([
|
||||
>>> messages.ButtonRequest(code=ConfirmOutput),
|
||||
>>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)),
|
||||
>>> messages.Success(),
|
||||
>>> ])
|
||||
"""
|
||||
if not self.in_with_statement:
|
||||
raise RuntimeError("Must be called inside 'with' statement")
|
||||
|
||||
# make sure all items are (bool, message) tuples
|
||||
expected_with_validity = (
|
||||
e if isinstance(e, tuple) else (True, e) for e in expected
|
||||
)
|
||||
|
||||
# only apply those items that are (True, message)
|
||||
self.expected_responses = [
|
||||
MessageFilter.from_message_or_type(expected)
|
||||
for valid, expected in expected_with_validity
|
||||
if valid
|
||||
]
|
||||
self.actual_responses = []
|
||||
|
||||
def lock(self, *, _refresh_features: bool = True) -> None:
|
||||
"""Lock the device.
|
||||
|
||||
If the device does not have a PIN configured, this will do nothing.
|
||||
Otherwise, a lock screen will be shown and the device will prompt for PIN
|
||||
before further actions.
|
||||
|
||||
This call does _not_ invalidate passphrase cache. If passphrase is in use,
|
||||
the device will not prompt for it after unlocking.
|
||||
|
||||
To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate
|
||||
passphrase cache, use `clear_session()`.
|
||||
"""
|
||||
# TODO update the documentation above
|
||||
# Private argument _refresh_features can be used internally to avoid
|
||||
# refreshing in cases where we will refresh soon anyway. This is used
|
||||
# in TrezorClient.clear_session()
|
||||
self.call(messages.LockDevice())
|
||||
if _refresh_features:
|
||||
self.refresh_features()
|
||||
|
||||
def cancel(self) -> None:
|
||||
self._write(messages.Cancel())
|
||||
|
||||
def ensure_unlocked(self) -> None:
|
||||
btc.get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
|
||||
self.refresh_features()
|
||||
|
||||
def set_filter(
|
||||
self,
|
||||
message_type: t.Type[protobuf.MessageType],
|
||||
callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||
) -> None:
|
||||
"""Configure a filter function for a specified message type.
|
||||
|
||||
The `callback` must be a function that accepts a protobuf message, and returns
|
||||
a (possibly modified) protobuf message of the same type. Whenever a message
|
||||
is sent or received that matches `message_type`, `callback` is invoked on the
|
||||
message and its result is substituted for the original.
|
||||
|
||||
Useful for test scenarios with an active malicious actor on the wire.
|
||||
"""
|
||||
if not self.in_with_statement:
|
||||
raise RuntimeError("Must be called inside 'with' statement")
|
||||
|
||||
self.filters[message_type] = callback
|
||||
|
||||
def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
|
||||
message_type = msg.__class__
|
||||
callback = self.filters.get(message_type)
|
||||
if callable(callback):
|
||||
return callback(deepcopy(msg))
|
||||
else:
|
||||
return msg
|
||||
|
||||
def reset_debug_features(self) -> None:
|
||||
"""Prepare the debugging session for a new testcase.
|
||||
|
||||
Clears all debugging state that might have been modified by a testcase.
|
||||
"""
|
||||
self.in_with_statement = False
|
||||
self.expected_responses: list[MessageFilter] | None = None
|
||||
self.actual_responses: list[protobuf.MessageType] | None = None
|
||||
self.filters: t.Dict[
|
||||
t.Type[protobuf.MessageType],
|
||||
t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||
] = {}
|
||||
self.button_callback = self.client.button_callback
|
||||
self.pin_callback = self.client.pin_callback
|
||||
self.passphrase_callback = self._session.passphrase_callback
|
||||
self.passphrase = self._session.passphrase
|
||||
|
||||
def __enter__(self) -> "SessionDebugWrapper":
|
||||
# For usage in with/expected_responses
|
||||
if self.in_with_statement:
|
||||
raise RuntimeError("Do not nest!")
|
||||
self.in_with_statement = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
# copy expected/actual responses before clearing them
|
||||
expected_responses = self.expected_responses
|
||||
actual_responses = self.actual_responses
|
||||
|
||||
# grab a copy of the inputflow generator to raise an exception through it
|
||||
if isinstance(self.client.ui, DebugUI):
|
||||
input_flow = self.client.ui.input_flow
|
||||
else:
|
||||
input_flow = None
|
||||
|
||||
self.reset_debug_features()
|
||||
|
||||
if exc_type is None:
|
||||
# If no other exception was raised, evaluate missed responses
|
||||
# (raises AssertionError on mismatch)
|
||||
self._verify_responses(expected_responses, actual_responses)
|
||||
if isinstance(input_flow, t.Generator):
|
||||
# Ensure that the input flow is exhausted
|
||||
try:
|
||||
input_flow.throw(
|
||||
AssertionError("input flow continues past end of test")
|
||||
)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
elif isinstance(input_flow, t.Generator):
|
||||
# Propagate the exception through the input flow, so that we see in
|
||||
# traceback where it is stuck.
|
||||
input_flow.throw(exc_type, value, traceback)
|
||||
|
||||
@classmethod
|
||||
def _verify_responses(
|
||||
cls,
|
||||
expected: list[MessageFilter] | None,
|
||||
actual: list[protobuf.MessageType] | None,
|
||||
) -> None:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
if expected is None and actual is None:
|
||||
return
|
||||
|
||||
assert expected is not None
|
||||
assert actual is not None
|
||||
|
||||
for i, (exp, act) in enumerate(zip_longest(expected, actual)):
|
||||
if exp is None:
|
||||
output = cls._expectation_lines(expected, i)
|
||||
output.append("No more messages were expected, but we got:")
|
||||
for resp in actual[i:]:
|
||||
output.append(
|
||||
textwrap.indent(protobuf.format_message(resp), " ")
|
||||
)
|
||||
raise AssertionError("\n".join(output))
|
||||
|
||||
if act is None:
|
||||
output = cls._expectation_lines(expected, i)
|
||||
output.append("This and the following message was not received.")
|
||||
raise AssertionError("\n".join(output))
|
||||
|
||||
if not exp.match(act):
|
||||
output = cls._expectation_lines(expected, i)
|
||||
output.append("Actually received:")
|
||||
output.append(textwrap.indent(protobuf.format_message(act), " "))
|
||||
raise AssertionError("\n".join(output))
|
||||
|
||||
@staticmethod
|
||||
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
|
||||
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
|
||||
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
|
||||
output: list[str] = []
|
||||
output.append("Expected responses:")
|
||||
if start_at > 0:
|
||||
output.append(f" (...{start_at} previous responses omitted)")
|
||||
for i in range(start_at, stop_at):
|
||||
exp = expected[i]
|
||||
prefix = " " if i != current else ">>> "
|
||||
output.append(textwrap.indent(exp.to_string(), prefix))
|
||||
if stop_at < len(expected):
|
||||
omitted = len(expected) - stop_at
|
||||
output.append(f" (...{omitted} following responses omitted)")
|
||||
|
||||
output.append("")
|
||||
return output
|
||||
|
||||
|
||||
class TrezorClientDebugLink(TrezorClient):
|
||||
# This class implements automatic responses
|
||||
# and other functionality for unit tests
|
||||
@ -1034,54 +1291,165 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
raise
|
||||
|
||||
# set transport explicitly so that sync_responses can work
|
||||
super().__init__(transport)
|
||||
|
||||
self.transport = transport
|
||||
self.ui: DebugUI = DebugUI(self.debug)
|
||||
|
||||
self.reset_debug_features()
|
||||
self.reset_debug_features(new_management_session=True)
|
||||
self.sync_responses()
|
||||
super().__init__(transport, ui=self.ui)
|
||||
|
||||
# So that we can choose right screenshotting logic (T1 vs TT)
|
||||
# and know the supported debug capabilities
|
||||
self.debug.model = self.model
|
||||
self.debug.version = self.version
|
||||
self.passphrase: str | None = None
|
||||
|
||||
@property
|
||||
def layout_type(self) -> LayoutType:
|
||||
return self.debug.layout_type
|
||||
|
||||
def reset_debug_features(self) -> None:
|
||||
"""Prepare the debugging client for a new testcase.
|
||||
def get_new_client(self) -> TrezorClientDebugLink:
|
||||
return TrezorClientDebugLink(self.transport, self.debug.allow_interactions)
|
||||
|
||||
def reset_debug_features(self, new_management_session: bool = False) -> None:
|
||||
"""
|
||||
Prepare the debugging client for a new testcase.
|
||||
|
||||
Clears all debugging state that might have been modified by a testcase.
|
||||
"""
|
||||
self.ui: DebugUI = DebugUI(self.debug)
|
||||
# self.pin_callback = self.ui.debug_callback_button
|
||||
self.in_with_statement = False
|
||||
self.expected_responses: list[MessageFilter] | None = None
|
||||
self.actual_responses: list[protobuf.MessageType] | None = None
|
||||
self.filters: dict[
|
||||
type[protobuf.MessageType],
|
||||
Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||
self.filters: t.Dict[
|
||||
t.Type[protobuf.MessageType],
|
||||
t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||
] = {}
|
||||
if new_management_session:
|
||||
self._management_session = self.get_management_session(new_session=True)
|
||||
|
||||
@property
|
||||
def button_callback(self):
|
||||
|
||||
def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
# do this raw - send ButtonAck first, notify UI later
|
||||
session._write(messages.ButtonAck())
|
||||
self.ui.button_request(msg)
|
||||
return session._read()
|
||||
|
||||
return _callback_button
|
||||
|
||||
@property
|
||||
def pin_callback(self):
|
||||
|
||||
def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any:
|
||||
try:
|
||||
pin = self.ui.get_pin(msg.type)
|
||||
except Cancelled:
|
||||
session.call_raw(messages.Cancel())
|
||||
raise
|
||||
|
||||
if any(d not in "123456789" for d in pin) or not (
|
||||
1 <= len(pin) <= MAX_PIN_LENGTH
|
||||
):
|
||||
session.call_raw(messages.Cancel())
|
||||
raise ValueError("Invalid PIN provided")
|
||||
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
|
||||
if isinstance(resp, messages.Failure) and resp.code in (
|
||||
messages.FailureType.PinInvalid,
|
||||
messages.FailureType.PinCancelled,
|
||||
messages.FailureType.PinExpected,
|
||||
):
|
||||
raise PinException(resp.code, resp.message)
|
||||
else:
|
||||
return resp
|
||||
|
||||
return _callback_pin
|
||||
|
||||
@property
|
||||
def passphrase_callback(self):
|
||||
def _callback_passphrase(
|
||||
session: Session, msg: messages.PassphraseRequest
|
||||
) -> t.Any:
|
||||
available_on_device = (
|
||||
Capability.PassphraseEntry in session.features.capabilities
|
||||
)
|
||||
|
||||
def send_passphrase(
|
||||
passphrase: str | None = None, on_device: bool | None = None
|
||||
) -> t.Any:
|
||||
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
||||
resp = session.call_raw(msg)
|
||||
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
||||
# session.session_id = resp.state
|
||||
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
|
||||
return resp
|
||||
|
||||
# short-circuit old style entry
|
||||
if msg._on_device is True:
|
||||
return send_passphrase(None, None)
|
||||
|
||||
try:
|
||||
if session.passphrase is None and isinstance(session, SessionV1):
|
||||
passphrase = self.ui.get_passphrase(
|
||||
available_on_device=available_on_device
|
||||
)
|
||||
else:
|
||||
passphrase = session.passphrase
|
||||
except Cancelled:
|
||||
session.call_raw(messages.Cancel())
|
||||
raise
|
||||
|
||||
if passphrase is PASSPHRASE_ON_DEVICE:
|
||||
if not available_on_device:
|
||||
session.call_raw(messages.Cancel())
|
||||
raise RuntimeError("Device is not capable of entering passphrase")
|
||||
else:
|
||||
return send_passphrase(on_device=True)
|
||||
|
||||
# else process host-entered passphrase
|
||||
if not isinstance(passphrase, str):
|
||||
raise RuntimeError("Passphrase must be a str")
|
||||
passphrase = Mnemonic.normalize_string(passphrase)
|
||||
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
|
||||
session.call_raw(messages.Cancel())
|
||||
raise ValueError("Passphrase too long")
|
||||
|
||||
return send_passphrase(passphrase, on_device=False)
|
||||
|
||||
return _callback_passphrase
|
||||
|
||||
def ensure_open(self) -> None:
|
||||
"""Only open session if there isn't already an open one."""
|
||||
if self.session_counter == 0:
|
||||
self.open()
|
||||
# if self.session_counter == 0:
|
||||
# self.open()
|
||||
# TODO check if is this needed
|
||||
|
||||
def open(self) -> None:
|
||||
super().open()
|
||||
if self.session_counter == 1:
|
||||
self.debug.open()
|
||||
pass
|
||||
# TODO is this needed?
|
||||
# self.debug.open()
|
||||
|
||||
def close(self) -> None:
|
||||
if self.session_counter == 1:
|
||||
self.debug.close()
|
||||
super().close()
|
||||
pass
|
||||
# TODO is this needed?
|
||||
# self.debug.close()
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
passphrase: str | object | None = "",
|
||||
derive_cardano: bool = False,
|
||||
) -> Session:
|
||||
if isinstance(passphrase, str):
|
||||
passphrase = Mnemonic.normalize_string(passphrase)
|
||||
return super().get_session(passphrase, derive_cardano)
|
||||
|
||||
def set_filter(
|
||||
self,
|
||||
message_type: type[protobuf.MessageType],
|
||||
callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||
message_type: t.Type[protobuf.MessageType],
|
||||
callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||
) -> None:
|
||||
"""Configure a filter function for a specified message type.
|
||||
|
||||
@ -1106,7 +1474,8 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
return msg
|
||||
|
||||
def set_input_flow(
|
||||
self, input_flow: Generator[None, messages.ButtonRequest | None, None]
|
||||
self,
|
||||
input_flow: t.Generator[None, messages.ButtonRequest | None, None],
|
||||
) -> None:
|
||||
"""Configure a sequence of input events for the current with-block.
|
||||
|
||||
@ -1140,6 +1509,7 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
if not hasattr(input_flow, "send"):
|
||||
raise RuntimeError("input_flow should be a generator function")
|
||||
self.ui.input_flow = input_flow
|
||||
assert input_flow is not None
|
||||
input_flow.send(None) # start the generator
|
||||
|
||||
def watch_layout(self, watch: bool = True) -> None:
|
||||
@ -1162,7 +1532,7 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
self.in_with_statement = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None:
|
||||
def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
# copy expected/actual responses before clearing them
|
||||
@ -1175,20 +1545,21 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
else:
|
||||
input_flow = None
|
||||
|
||||
self.reset_debug_features()
|
||||
self.reset_debug_features(new_management_session=False)
|
||||
|
||||
if exc_type is None:
|
||||
# If no other exception was raised, evaluate missed responses
|
||||
# (raises AssertionError on mismatch)
|
||||
self._verify_responses(expected_responses, actual_responses)
|
||||
|
||||
elif isinstance(input_flow, Generator):
|
||||
elif isinstance(input_flow, t.Generator):
|
||||
# Propagate the exception through the input flow, so that we see in
|
||||
# traceback where it is stuck.
|
||||
input_flow.throw(exc_type, value, traceback)
|
||||
|
||||
def set_expected_responses(
|
||||
self, expected: list[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]]
|
||||
self,
|
||||
expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]],
|
||||
) -> None:
|
||||
"""Set a sequence of expected responses to client calls.
|
||||
|
||||
@ -1227,7 +1598,7 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
]
|
||||
self.actual_responses = []
|
||||
|
||||
def use_pin_sequence(self, pins: Iterable[str]) -> None:
|
||||
def use_pin_sequence(self, pins: t.Iterable[str]) -> None:
|
||||
"""Respond to PIN prompts from device with the provided PINs.
|
||||
The sequence must be at least as long as the expected number of PIN prompts.
|
||||
"""
|
||||
@ -1235,6 +1606,7 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
|
||||
def use_passphrase(self, passphrase: str) -> None:
|
||||
"""Respond to passphrase prompts from device with the provided passphrase."""
|
||||
self.passphrase = passphrase
|
||||
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
|
||||
|
||||
def use_mnemonic(self, mnemonic: str) -> None:
|
||||
@ -1244,15 +1616,14 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
|
||||
def _raw_read(self) -> protobuf.MessageType:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
resp = super()._raw_read()
|
||||
resp = self.get_management_session()._read()
|
||||
resp = self._filter_message(resp)
|
||||
if self.actual_responses is not None:
|
||||
self.actual_responses.append(resp)
|
||||
return resp
|
||||
|
||||
def _raw_write(self, msg: protobuf.MessageType) -> None:
|
||||
return super()._raw_write(self._filter_message(msg))
|
||||
return self.get_management_session()._write(self._filter_message(msg))
|
||||
|
||||
@staticmethod
|
||||
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
|
||||
@ -1322,23 +1693,25 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
|
||||
# Start by canceling whatever is on screen. This will work to cancel T1 PIN
|
||||
# prompt, which is in TINY mode and does not respond to `Ping`.
|
||||
cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel())
|
||||
self.transport.begin_session()
|
||||
# TODO REMOVE: cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel())
|
||||
self.transport.open()
|
||||
try:
|
||||
self.transport.write(*cancel_msg)
|
||||
|
||||
# self.protocol.write(messages.Cancel())
|
||||
message = "SYNC" + secrets.token_hex(8)
|
||||
ping_msg = mapping.DEFAULT_MAPPING.encode(messages.Ping(message=message))
|
||||
self.transport.write(*ping_msg)
|
||||
self.get_management_session()._write(messages.Ping(message=message))
|
||||
resp = None
|
||||
while resp != messages.Success(message=message):
|
||||
msg_id, msg_bytes = self.transport.read()
|
||||
try:
|
||||
resp = mapping.DEFAULT_MAPPING.decode(msg_id, msg_bytes)
|
||||
resp = self.get_management_session()._read()
|
||||
|
||||
raise Exception
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
finally:
|
||||
self.transport.end_session()
|
||||
pass # TODO fix
|
||||
# self.transport.end_session(self.session_id or b"")
|
||||
|
||||
def mnemonic_callback(self, _) -> str:
|
||||
word, pos = self.debug.read_recovery_word()
|
||||
@ -1352,8 +1725,8 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def load_device(
|
||||
client: "TrezorClient",
|
||||
mnemonic: Union[str, Iterable[str]],
|
||||
session: "Session",
|
||||
mnemonic: str | t.Iterable[str],
|
||||
pin: str | None,
|
||||
passphrase_protection: bool,
|
||||
label: str | None,
|
||||
@ -1366,12 +1739,12 @@ def load_device(
|
||||
|
||||
mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
|
||||
|
||||
if client.features.initialized:
|
||||
if session.features.initialized:
|
||||
raise RuntimeError(
|
||||
"Device is initialized already. Call device.wipe() and try again."
|
||||
)
|
||||
|
||||
resp = client.call(
|
||||
resp = session.call(
|
||||
messages.LoadDevice(
|
||||
mnemonics=mnemonics,
|
||||
pin=pin,
|
||||
@ -1382,7 +1755,7 @@ def load_device(
|
||||
no_backup=no_backup,
|
||||
)
|
||||
)
|
||||
client.init_device()
|
||||
session.refresh_features()
|
||||
return resp
|
||||
|
||||
|
||||
@ -1391,11 +1764,11 @@ load_device_by_mnemonic = load_device
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType:
|
||||
if client.features.bootloader_mode is not True:
|
||||
def prodtest_t1(session: "Session") -> protobuf.MessageType:
|
||||
if session.features.bootloader_mode is not True:
|
||||
raise RuntimeError("Device must be in bootloader mode")
|
||||
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.ProdTestT1(
|
||||
payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
|
||||
)
|
||||
@ -1404,8 +1777,8 @@ def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType:
|
||||
|
||||
def record_screen(
|
||||
debug_client: "TrezorClientDebugLink",
|
||||
directory: Union[str, None],
|
||||
report_func: Union[Callable[[str], None], None] = None,
|
||||
directory: str | None,
|
||||
report_func: t.Callable[[str], None] | None = None,
|
||||
) -> None:
|
||||
"""Record screen changes into a specified directory.
|
||||
|
||||
@ -1451,5 +1824,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool:
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def optiga_set_sec_max(client: "TrezorClient") -> protobuf.MessageType:
|
||||
return client.call(messages.DebugLinkOptigaSetSecMax())
|
||||
def optiga_set_sec_max(session: "Session") -> protobuf.MessageType:
|
||||
return session.call(messages.DebugLinkOptigaSetSecMax())
|
||||
|
@ -23,20 +23,19 @@ from typing import TYPE_CHECKING, Callable, Iterable, Optional
|
||||
|
||||
from . import messages
|
||||
from .exceptions import Cancelled, TrezorException
|
||||
from .tools import Address, expect, session
|
||||
from .tools import Address, expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
RECOVERY_BACK = "\x08" # backspace character, sent literally
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def apply_settings(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
label: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
use_passphrase: Optional[bool] = None,
|
||||
@ -67,13 +66,13 @@ def apply_settings(
|
||||
haptic_feedback=haptic_feedback,
|
||||
)
|
||||
|
||||
out = client.call(settings)
|
||||
client.refresh_features()
|
||||
out = session.call(settings)
|
||||
session.refresh_features()
|
||||
return out
|
||||
|
||||
|
||||
def _send_language_data(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
request: "messages.TranslationDataRequest",
|
||||
language_data: bytes,
|
||||
) -> "MessageType":
|
||||
@ -83,76 +82,70 @@ def _send_language_data(
|
||||
data_length = response.data_length
|
||||
data_offset = response.data_offset
|
||||
chunk = language_data[data_offset : data_offset + data_length]
|
||||
response = client.call(messages.TranslationDataAck(data_chunk=chunk))
|
||||
response = session.call(messages.TranslationDataAck(data_chunk=chunk))
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def change_language(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
language_data: bytes,
|
||||
show_display: bool | None = None,
|
||||
) -> "MessageType":
|
||||
data_length = len(language_data)
|
||||
msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)
|
||||
|
||||
response = client.call(msg)
|
||||
response = session.call(msg)
|
||||
if data_length > 0:
|
||||
assert isinstance(response, messages.TranslationDataRequest)
|
||||
response = _send_language_data(client, response, language_data)
|
||||
response = _send_language_data(session, response, language_data)
|
||||
assert isinstance(response, messages.Success)
|
||||
client.refresh_features() # changing the language in features
|
||||
session.refresh_features() # changing the language in features
|
||||
return response
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def apply_flags(client: "TrezorClient", flags: int) -> "MessageType":
|
||||
out = client.call(messages.ApplyFlags(flags=flags))
|
||||
client.refresh_features()
|
||||
def apply_flags(session: "Session", flags: int) -> "MessageType":
|
||||
out = session.call(messages.ApplyFlags(flags=flags))
|
||||
session.refresh_features()
|
||||
return out
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType":
|
||||
ret = client.call(messages.ChangePin(remove=remove))
|
||||
client.refresh_features()
|
||||
def change_pin(session: "Session", remove: bool = False) -> "MessageType":
|
||||
ret = session.call(messages.ChangePin(remove=remove))
|
||||
session.refresh_features()
|
||||
return ret
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType":
|
||||
ret = client.call(messages.ChangeWipeCode(remove=remove))
|
||||
client.refresh_features()
|
||||
def change_wipe_code(session: "Session", remove: bool = False) -> "MessageType":
|
||||
ret = session.call(messages.ChangeWipeCode(remove=remove))
|
||||
session.refresh_features()
|
||||
return ret
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def sd_protect(
|
||||
client: "TrezorClient", operation: messages.SdProtectOperationType
|
||||
session: "Session", operation: messages.SdProtectOperationType
|
||||
) -> "MessageType":
|
||||
ret = client.call(messages.SdProtect(operation=operation))
|
||||
client.refresh_features()
|
||||
ret = session.call(messages.SdProtect(operation=operation))
|
||||
session.refresh_features()
|
||||
return ret
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def wipe(client: "TrezorClient") -> "MessageType":
|
||||
ret = client.call(messages.WipeDevice())
|
||||
if not client.features.bootloader_mode:
|
||||
client.init_device()
|
||||
def wipe(session: "Session") -> "MessageType":
|
||||
|
||||
ret = session.call(messages.WipeDevice())
|
||||
# if not session.features.bootloader_mode:
|
||||
# session.refresh_features()
|
||||
return ret
|
||||
|
||||
|
||||
@session
|
||||
def recover(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
word_count: int = 24,
|
||||
passphrase_protection: bool = False,
|
||||
pin_protection: bool = True,
|
||||
@ -188,13 +181,13 @@ def recover(
|
||||
if type is None:
|
||||
type = messages.RecoveryType.NormalRecovery
|
||||
|
||||
if client.features.model == "1" and input_callback is None:
|
||||
if session.features.model == "1" and input_callback is None:
|
||||
raise RuntimeError("Input callback required for Trezor One")
|
||||
|
||||
if word_count not in (12, 18, 24):
|
||||
raise ValueError("Invalid word count. Use 12/18/24")
|
||||
|
||||
if client.features.initialized and type == messages.RecoveryType.NormalRecovery:
|
||||
if session.features.initialized and type == messages.RecoveryType.NormalRecovery:
|
||||
raise RuntimeError(
|
||||
"Device already initialized. Call device.wipe() and try again."
|
||||
)
|
||||
@ -216,24 +209,23 @@ def recover(
|
||||
msg.label = label
|
||||
msg.u2f_counter = u2f_counter
|
||||
|
||||
res = client.call(msg)
|
||||
res = session.call(msg)
|
||||
|
||||
while isinstance(res, messages.WordRequest):
|
||||
try:
|
||||
assert input_callback is not None
|
||||
inp = input_callback(res.type)
|
||||
res = client.call(messages.WordAck(word=inp))
|
||||
res = session.call(messages.WordAck(word=inp))
|
||||
except Cancelled:
|
||||
res = client.call(messages.Cancel())
|
||||
res = session.call(messages.Cancel())
|
||||
|
||||
client.init_device()
|
||||
session.refresh_features()
|
||||
return res
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def reset(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
display_random: bool = False,
|
||||
strength: Optional[int] = None,
|
||||
passphrase_protection: bool = False,
|
||||
@ -257,13 +249,13 @@ def reset(
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
if client.features.initialized:
|
||||
if session.features.initialized:
|
||||
raise RuntimeError(
|
||||
"Device is initialized already. Call wipe_device() and try again."
|
||||
)
|
||||
|
||||
if strength is None:
|
||||
if client.features.model == "1":
|
||||
if session.features.model == "1":
|
||||
strength = 256
|
||||
else:
|
||||
strength = 128
|
||||
@ -280,25 +272,24 @@ def reset(
|
||||
backup_type=backup_type,
|
||||
)
|
||||
|
||||
resp = client.call(msg)
|
||||
resp = session.call(msg)
|
||||
if not isinstance(resp, messages.EntropyRequest):
|
||||
raise RuntimeError("Invalid response, expected EntropyRequest")
|
||||
|
||||
external_entropy = os.urandom(32)
|
||||
# LOG.debug("Computer generated entropy: " + external_entropy.hex())
|
||||
ret = client.call(messages.EntropyAck(entropy=external_entropy))
|
||||
client.init_device()
|
||||
ret = session.call(messages.EntropyAck(entropy=external_entropy))
|
||||
session.refresh_features() # TODO is necessary?
|
||||
return ret
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def backup(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
group_threshold: Optional[int] = None,
|
||||
groups: Iterable[tuple[int, int]] = (),
|
||||
) -> "MessageType":
|
||||
ret = client.call(
|
||||
ret = session.call(
|
||||
messages.BackupDevice(
|
||||
group_threshold=group_threshold,
|
||||
groups=[
|
||||
@ -307,37 +298,36 @@ def backup(
|
||||
],
|
||||
)
|
||||
)
|
||||
client.refresh_features()
|
||||
session.refresh_features()
|
||||
return ret
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def cancel_authorization(client: "TrezorClient") -> "MessageType":
|
||||
return client.call(messages.CancelAuthorization())
|
||||
def cancel_authorization(session: "Session") -> "MessageType":
|
||||
return session.call(messages.CancelAuthorization())
|
||||
|
||||
|
||||
@expect(messages.UnlockedPathRequest, field="mac", ret_type=bytes)
|
||||
def unlock_path(client: "TrezorClient", n: "Address") -> "MessageType":
|
||||
resp = client.call(messages.UnlockPath(address_n=n))
|
||||
def unlock_path(session: "Session", n: "Address") -> "MessageType":
|
||||
resp = session.call(messages.UnlockPath(address_n=n))
|
||||
|
||||
# Cancel the UnlockPath workflow now that we have the authentication code.
|
||||
try:
|
||||
client.call(messages.Cancel())
|
||||
session.call(messages.Cancel())
|
||||
except Cancelled:
|
||||
return resp
|
||||
else:
|
||||
raise TrezorException("Unexpected response in UnlockPath flow")
|
||||
|
||||
|
||||
@session
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def reboot_to_bootloader(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT,
|
||||
firmware_header: Optional[bytes] = None,
|
||||
language_data: bytes = b"",
|
||||
) -> "MessageType":
|
||||
response = client.call(
|
||||
response = session.call(
|
||||
messages.RebootToBootloader(
|
||||
boot_command=boot_command,
|
||||
firmware_header=firmware_header,
|
||||
@ -345,42 +335,37 @@ def reboot_to_bootloader(
|
||||
)
|
||||
)
|
||||
if isinstance(response, messages.TranslationDataRequest):
|
||||
response = _send_language_data(client, response, language_data)
|
||||
response = _send_language_data(session, response, language_data)
|
||||
return response
|
||||
|
||||
|
||||
@session
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def show_device_tutorial(client: "TrezorClient") -> "MessageType":
|
||||
return client.call(messages.ShowDeviceTutorial())
|
||||
|
||||
|
||||
@session
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def unlock_bootloader(client: "TrezorClient") -> "MessageType":
|
||||
return client.call(messages.UnlockBootloader())
|
||||
def show_device_tutorial(session: "Session") -> "MessageType":
|
||||
return session.call(messages.ShowDeviceTutorial())
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
@session
|
||||
def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> "MessageType":
|
||||
def unlock_bootloader(session: "Session") -> "MessageType":
|
||||
return session.call(messages.UnlockBootloader())
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def set_busy(session: "Session", expiry_ms: Optional[int]) -> "MessageType":
|
||||
"""Sets or clears the busy state of the device.
|
||||
|
||||
In the busy state the device shows a "Do not disconnect" message instead of the homescreen.
|
||||
Setting `expiry_ms=None` clears the busy state.
|
||||
"""
|
||||
ret = client.call(messages.SetBusy(expiry_ms=expiry_ms))
|
||||
client.refresh_features()
|
||||
ret = session.call(messages.SetBusy(expiry_ms=expiry_ms))
|
||||
session.refresh_features()
|
||||
return ret
|
||||
|
||||
|
||||
@expect(messages.AuthenticityProof)
|
||||
def authenticate(client: "TrezorClient", challenge: bytes):
|
||||
return client.call(messages.AuthenticateDevice(challenge=challenge))
|
||||
def authenticate(session: "Session", challenge: bytes):
|
||||
return session.call(messages.AuthenticateDevice(challenge=challenge))
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def set_brightness(
|
||||
client: "TrezorClient", value: Optional[int] = None
|
||||
) -> "MessageType":
|
||||
return client.call(messages.SetBrightness(value=value))
|
||||
def set_brightness(session: "Session", value: Optional[int] = None) -> "MessageType":
|
||||
return session.call(messages.SetBrightness(value=value))
|
||||
|
@ -18,12 +18,12 @@ from datetime import datetime
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
|
||||
from . import exceptions, messages
|
||||
from .tools import b58decode, expect, session
|
||||
from .tools import b58decode, expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def name_to_number(name: str) -> int:
|
||||
@ -321,17 +321,16 @@ def parse_transaction_json(
|
||||
|
||||
@expect(messages.EosPublicKey)
|
||||
def get_public_key(
|
||||
client: "TrezorClient", n: "Address", show_display: bool = False
|
||||
session: "Session", n: "Address", show_display: bool = False
|
||||
) -> "MessageType":
|
||||
response = client.call(
|
||||
response = session.call(
|
||||
messages.EosGetPublicKey(address_n=n, show_display=show_display)
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: "Address",
|
||||
transaction: dict,
|
||||
chain_id: str,
|
||||
@ -347,11 +346,11 @@ def sign_tx(
|
||||
chunkify=chunkify,
|
||||
)
|
||||
|
||||
response = client.call(msg)
|
||||
response = session.call(msg)
|
||||
|
||||
try:
|
||||
while isinstance(response, messages.EosTxActionRequest):
|
||||
response = client.call(actions.pop(0))
|
||||
response = session.call(actions.pop(0))
|
||||
except IndexError:
|
||||
# pop from empty list
|
||||
raise exceptions.TrezorException(
|
||||
|
@ -18,12 +18,12 @@ import re
|
||||
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
|
||||
|
||||
from . import definitions, exceptions, messages
|
||||
from .tools import expect, prepare_message_bytes, session, unharden
|
||||
from .tools import expect, prepare_message_bytes, unharden
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def int_to_big_endian(value: int) -> bytes:
|
||||
@ -163,13 +163,13 @@ def network_from_address_n(
|
||||
|
||||
@expect(messages.EthereumAddress, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
show_display: bool = False,
|
||||
encoded_network: Optional[bytes] = None,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.EthereumGetAddress(
|
||||
address_n=n,
|
||||
show_display=show_display,
|
||||
@ -181,16 +181,15 @@ def get_address(
|
||||
|
||||
@expect(messages.EthereumPublicKey)
|
||||
def get_public_node(
|
||||
client: "TrezorClient", n: "Address", show_display: bool = False
|
||||
session: "Session", n: "Address", show_display: bool = False
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.EthereumGetPublicKey(address_n=n, show_display=show_display)
|
||||
)
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
nonce: int,
|
||||
gas_price: int,
|
||||
@ -226,13 +225,13 @@ def sign_tx(
|
||||
data, chunk = data[1024:], data[:1024]
|
||||
msg.data_initial_chunk = chunk
|
||||
|
||||
response = client.call(msg)
|
||||
response = session.call(msg)
|
||||
assert isinstance(response, messages.EthereumTxRequest)
|
||||
|
||||
while response.data_length is not None:
|
||||
data_length = response.data_length
|
||||
data, chunk = data[data_length:], data[:data_length]
|
||||
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||
response = session.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||
assert isinstance(response, messages.EthereumTxRequest)
|
||||
|
||||
assert response.signature_v is not None
|
||||
@ -247,9 +246,8 @@ def sign_tx(
|
||||
return response.signature_v, response.signature_r, response.signature_s
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx_eip1559(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
*,
|
||||
nonce: int,
|
||||
@ -282,13 +280,13 @@ def sign_tx_eip1559(
|
||||
chunkify=chunkify,
|
||||
)
|
||||
|
||||
response = client.call(msg)
|
||||
response = session.call(msg)
|
||||
assert isinstance(response, messages.EthereumTxRequest)
|
||||
|
||||
while response.data_length is not None:
|
||||
data_length = response.data_length
|
||||
data, chunk = data[data_length:], data[:data_length]
|
||||
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||
response = session.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||
assert isinstance(response, messages.EthereumTxRequest)
|
||||
|
||||
assert response.signature_v is not None
|
||||
@ -299,13 +297,13 @@ def sign_tx_eip1559(
|
||||
|
||||
@expect(messages.EthereumMessageSignature)
|
||||
def sign_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
message: AnyStr,
|
||||
encoded_network: Optional[bytes] = None,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.EthereumSignMessage(
|
||||
address_n=n,
|
||||
message=prepare_message_bytes(message),
|
||||
@ -317,7 +315,7 @@ def sign_message(
|
||||
|
||||
@expect(messages.EthereumTypedDataSignature)
|
||||
def sign_typed_data(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
data: Dict[str, Any],
|
||||
*,
|
||||
@ -333,7 +331,7 @@ def sign_typed_data(
|
||||
metamask_v4_compat=metamask_v4_compat,
|
||||
definitions=definitions,
|
||||
)
|
||||
response = client.call(request)
|
||||
response = session.call(request)
|
||||
|
||||
# Sending all the types
|
||||
while isinstance(response, messages.EthereumTypedDataStructRequest):
|
||||
@ -349,7 +347,7 @@ def sign_typed_data(
|
||||
members.append(struct_member)
|
||||
|
||||
request = messages.EthereumTypedDataStructAck(members=members)
|
||||
response = client.call(request)
|
||||
response = session.call(request)
|
||||
|
||||
# Sending the whole message that should be signed
|
||||
while isinstance(response, messages.EthereumTypedDataValueRequest):
|
||||
@ -362,7 +360,7 @@ def sign_typed_data(
|
||||
member_typename = data["primaryType"]
|
||||
member_data = data["message"]
|
||||
else:
|
||||
client.cancel()
|
||||
# TODO session.cancel()
|
||||
raise exceptions.TrezorException("Root index can only be 0 or 1")
|
||||
|
||||
# It can be asking for a nested structure (the member path being [X, Y, Z, ...])
|
||||
@ -385,20 +383,20 @@ def sign_typed_data(
|
||||
encoded_data = encode_data(member_data, member_typename)
|
||||
|
||||
request = messages.EthereumTypedDataValueAck(value=encoded_data)
|
||||
response = client.call(request)
|
||||
response = session.call(request)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def verify_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
signature: bytes,
|
||||
message: AnyStr,
|
||||
chunkify: bool = False,
|
||||
) -> bool:
|
||||
try:
|
||||
resp = client.call(
|
||||
resp = session.call(
|
||||
messages.EthereumVerifyMessage(
|
||||
address=address,
|
||||
signature=signature,
|
||||
@ -413,13 +411,13 @@ def verify_message(
|
||||
|
||||
@expect(messages.EthereumTypedDataSignature)
|
||||
def sign_typed_data_hash(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
domain_hash: bytes,
|
||||
message_hash: Optional[bytes],
|
||||
encoded_network: Optional[bytes] = None,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.EthereumSignTypedHash(
|
||||
address_n=n,
|
||||
domain_separator_hash=domain_hash,
|
||||
|
@ -20,8 +20,8 @@ from . import messages
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
@expect(
|
||||
@ -29,27 +29,27 @@ if TYPE_CHECKING:
|
||||
field="credentials",
|
||||
ret_type=List[messages.WebAuthnCredential],
|
||||
)
|
||||
def list_credentials(client: "TrezorClient") -> "MessageType":
|
||||
return client.call(messages.WebAuthnListResidentCredentials())
|
||||
def list_credentials(session: "Session") -> "MessageType":
|
||||
return session.call(messages.WebAuthnListResidentCredentials())
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType":
|
||||
return client.call(
|
||||
def add_credential(session: "Session", credential_id: bytes) -> "MessageType":
|
||||
return session.call(
|
||||
messages.WebAuthnAddResidentCredential(credential_id=credential_id)
|
||||
)
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def remove_credential(client: "TrezorClient", index: int) -> "MessageType":
|
||||
return client.call(messages.WebAuthnRemoveResidentCredential(index=index))
|
||||
def remove_credential(session: "Session", index: int) -> "MessageType":
|
||||
return session.call(messages.WebAuthnRemoveResidentCredential(index=index))
|
||||
|
||||
|
||||
@expect(messages.Success, field="message", ret_type=str)
|
||||
def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType":
|
||||
return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter))
|
||||
def set_counter(session: "Session", u2f_counter: int) -> "MessageType":
|
||||
return session.call(messages.SetU2FCounter(u2f_counter=u2f_counter))
|
||||
|
||||
|
||||
@expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int)
|
||||
def get_next_counter(client: "TrezorClient") -> "MessageType":
|
||||
return client.call(messages.GetNextU2FCounter())
|
||||
def get_next_counter(session: "Session") -> "MessageType":
|
||||
return session.call(messages.GetNextU2FCounter())
|
||||
|
@ -20,7 +20,7 @@ from hashlib import blake2s
|
||||
from typing_extensions import Protocol, TypeGuard
|
||||
|
||||
from .. import messages
|
||||
from ..tools import expect, session
|
||||
from ..tools import expect
|
||||
from .core import VendorFirmware
|
||||
from .legacy import LegacyFirmware, LegacyV2Firmware
|
||||
|
||||
@ -38,7 +38,7 @@ if True:
|
||||
from .vendor import * # noqa: F401, F403
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
T = t.TypeVar("T", bound="FirmwareType")
|
||||
|
||||
@ -72,20 +72,19 @@ def is_onev2(fw: "FirmwareType") -> TypeGuard[LegacyFirmware]:
|
||||
# ====== Client functions ====== #
|
||||
|
||||
|
||||
@session
|
||||
def update(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
data: bytes,
|
||||
progress_update: t.Callable[[int], t.Any] = lambda _: None,
|
||||
):
|
||||
if client.features.bootloader_mode is False:
|
||||
if session.features.bootloader_mode is False:
|
||||
raise RuntimeError("Device must be in bootloader mode")
|
||||
|
||||
resp = client.call(messages.FirmwareErase(length=len(data)))
|
||||
resp = session.call(messages.FirmwareErase(length=len(data)))
|
||||
|
||||
# TREZORv1 method
|
||||
if isinstance(resp, messages.Success):
|
||||
resp = client.call(messages.FirmwareUpload(payload=data))
|
||||
resp = session.call(messages.FirmwareUpload(payload=data))
|
||||
progress_update(len(data))
|
||||
if isinstance(resp, messages.Success):
|
||||
return
|
||||
@ -97,7 +96,7 @@ def update(
|
||||
length = resp.length
|
||||
payload = data[resp.offset : resp.offset + length]
|
||||
digest = blake2s(payload).digest()
|
||||
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest))
|
||||
resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest))
|
||||
progress_update(length)
|
||||
|
||||
if isinstance(resp, messages.Success):
|
||||
@ -107,5 +106,5 @@ def update(
|
||||
|
||||
|
||||
@expect(messages.FirmwareHash, field="hash", ret_type=bytes)
|
||||
def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]):
|
||||
return client.call(messages.GetFirmwareHash(challenge=challenge))
|
||||
def get_hash(session: "Session", challenge: t.Optional[bytes]):
|
||||
return session.call(messages.GetFirmwareHash(challenge=challenge))
|
||||
|
@ -17,6 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
from types import ModuleType
|
||||
from typing import Dict, Optional, Tuple, Type, TypeVar
|
||||
|
||||
@ -25,6 +26,7 @@ from typing_extensions import Self
|
||||
from . import messages, protobuf
|
||||
|
||||
T = TypeVar("T")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtobufMapping:
|
||||
@ -63,11 +65,21 @@ class ProtobufMapping:
|
||||
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
|
||||
if wire_type is None:
|
||||
raise ValueError("Cannot encode class without wire type")
|
||||
|
||||
LOG.debug("encoding wire type %d", wire_type)
|
||||
buf = io.BytesIO()
|
||||
protobuf.dump_message(buf, msg)
|
||||
return wire_type, buf.getvalue()
|
||||
|
||||
def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes:
|
||||
"""Serialize a Python protobuf class.
|
||||
|
||||
Returns the byte representation of the protobuf message.
|
||||
"""
|
||||
|
||||
buf = io.BytesIO()
|
||||
protobuf.dump_message(buf, msg)
|
||||
return buf.getvalue()
|
||||
|
||||
def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType:
|
||||
"""Deserialize a protobuf message into a Python class."""
|
||||
cls = self.type_to_class[msg_wire_type]
|
||||
@ -83,7 +95,9 @@ class ProtobufMapping:
|
||||
mapping = cls()
|
||||
|
||||
message_types = getattr(module, "MessageType")
|
||||
for entry in message_types:
|
||||
thp_message_types = getattr(module, "ThpMessageType")
|
||||
|
||||
for entry in (*message_types, *thp_message_types):
|
||||
msg_class = getattr(module, entry.name, None)
|
||||
if msg_class is None:
|
||||
raise ValueError(
|
||||
|
316
python/src/trezorlib/messages.py
generated
316
python/src/trezorlib/messages.py
generated
@ -43,6 +43,8 @@ class FailureType(IntEnum):
|
||||
PinMismatch = 12
|
||||
WipeCodeMismatch = 13
|
||||
InvalidSession = 14
|
||||
ThpUnallocatedSession = 15
|
||||
InvalidProtocol = 16
|
||||
FirmwareError = 99
|
||||
|
||||
|
||||
@ -400,6 +402,34 @@ class TezosBallotType(IntEnum):
|
||||
Pass = 2
|
||||
|
||||
|
||||
class ThpMessageType(IntEnum):
|
||||
ThpCreateNewSession = 1000
|
||||
ThpNewSession = 1001
|
||||
ThpStartPairingRequest = 1008
|
||||
ThpPairingPreparationsFinished = 1009
|
||||
ThpCredentialRequest = 1010
|
||||
ThpCredentialResponse = 1011
|
||||
ThpEndRequest = 1012
|
||||
ThpEndResponse = 1013
|
||||
ThpCodeEntryCommitment = 1016
|
||||
ThpCodeEntryChallenge = 1017
|
||||
ThpCodeEntryCpaceHost = 1018
|
||||
ThpCodeEntryCpaceTrezor = 1019
|
||||
ThpCodeEntryTag = 1020
|
||||
ThpCodeEntrySecret = 1021
|
||||
ThpQrCodeTag = 1024
|
||||
ThpQrCodeSecret = 1025
|
||||
ThpNfcUnidirectionalTag = 1032
|
||||
ThpNfcUnidirectionalSecret = 1033
|
||||
|
||||
|
||||
class ThpPairingMethod(IntEnum):
|
||||
NoMethod = 1
|
||||
CodeEntry = 2
|
||||
QrCode = 3
|
||||
NFC_Unidirectional = 4
|
||||
|
||||
|
||||
class MessageType(IntEnum):
|
||||
Initialize = 0
|
||||
Ping = 1
|
||||
@ -4100,6 +4130,7 @@ class DebugLinkGetState(protobuf.MessageType):
|
||||
1: protobuf.Field("wait_word_list", "bool", repeated=False, required=False, default=None),
|
||||
2: protobuf.Field("wait_word_pos", "bool", repeated=False, required=False, default=None),
|
||||
3: protobuf.Field("wait_layout", "DebugWaitType", repeated=False, required=False, default=DebugWaitType.IMMEDIATE),
|
||||
4: protobuf.Field("thp_channel_id", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
@ -4108,10 +4139,12 @@ class DebugLinkGetState(protobuf.MessageType):
|
||||
wait_word_list: Optional["bool"] = None,
|
||||
wait_word_pos: Optional["bool"] = None,
|
||||
wait_layout: Optional["DebugWaitType"] = DebugWaitType.IMMEDIATE,
|
||||
thp_channel_id: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.wait_word_list = wait_word_list
|
||||
self.wait_word_pos = wait_word_pos
|
||||
self.wait_layout = wait_layout
|
||||
self.thp_channel_id = thp_channel_id
|
||||
|
||||
|
||||
class DebugLinkState(protobuf.MessageType):
|
||||
@ -4130,6 +4163,9 @@ class DebugLinkState(protobuf.MessageType):
|
||||
11: protobuf.Field("reset_word_pos", "uint32", repeated=False, required=False, default=None),
|
||||
12: protobuf.Field("mnemonic_type", "BackupType", repeated=False, required=False, default=None),
|
||||
13: protobuf.Field("tokens", "string", repeated=True, required=False, default=None),
|
||||
14: protobuf.Field("thp_pairing_code_entry_code", "uint32", repeated=False, required=False, default=None),
|
||||
15: protobuf.Field("thp_pairing_code_qr_code", "bytes", repeated=False, required=False, default=None),
|
||||
16: protobuf.Field("thp_pairing_code_nfc_unidirectional", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
@ -4148,6 +4184,9 @@ class DebugLinkState(protobuf.MessageType):
|
||||
recovery_word_pos: Optional["int"] = None,
|
||||
reset_word_pos: Optional["int"] = None,
|
||||
mnemonic_type: Optional["BackupType"] = None,
|
||||
thp_pairing_code_entry_code: Optional["int"] = None,
|
||||
thp_pairing_code_qr_code: Optional["bytes"] = None,
|
||||
thp_pairing_code_nfc_unidirectional: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.tokens: Sequence["str"] = tokens if tokens is not None else []
|
||||
self.layout = layout
|
||||
@ -4162,6 +4201,9 @@ class DebugLinkState(protobuf.MessageType):
|
||||
self.recovery_word_pos = recovery_word_pos
|
||||
self.reset_word_pos = reset_word_pos
|
||||
self.mnemonic_type = mnemonic_type
|
||||
self.thp_pairing_code_entry_code = thp_pairing_code_entry_code
|
||||
self.thp_pairing_code_qr_code = thp_pairing_code_qr_code
|
||||
self.thp_pairing_code_nfc_unidirectional = thp_pairing_code_nfc_unidirectional
|
||||
|
||||
|
||||
class DebugLinkStop(protobuf.MessageType):
|
||||
@ -7824,6 +7866,280 @@ class TezosManagerTransfer(protobuf.MessageType):
|
||||
self.amount = amount
|
||||
|
||||
|
||||
class ThpDeviceProperties(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = None
|
||||
FIELDS = {
|
||||
1: protobuf.Field("internal_model", "string", repeated=False, required=False, default=None),
|
||||
2: protobuf.Field("model_variant", "uint32", repeated=False, required=False, default=None),
|
||||
3: protobuf.Field("bootloader_mode", "bool", repeated=False, required=False, default=None),
|
||||
4: protobuf.Field("protocol_version", "uint32", repeated=False, required=False, default=None),
|
||||
5: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None,
|
||||
internal_model: Optional["str"] = None,
|
||||
model_variant: Optional["int"] = None,
|
||||
bootloader_mode: Optional["bool"] = None,
|
||||
protocol_version: Optional["int"] = None,
|
||||
) -> None:
|
||||
self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else []
|
||||
self.internal_model = internal_model
|
||||
self.model_variant = model_variant
|
||||
self.bootloader_mode = bootloader_mode
|
||||
self.protocol_version = protocol_version
|
||||
|
||||
|
||||
class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = None
|
||||
FIELDS = {
|
||||
1: protobuf.Field("host_pairing_credential", "bytes", repeated=False, required=False, default=None),
|
||||
2: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None,
|
||||
host_pairing_credential: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else []
|
||||
self.host_pairing_credential = host_pairing_credential
|
||||
|
||||
|
||||
class ThpCreateNewSession(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1000
|
||||
FIELDS = {
|
||||
1: protobuf.Field("passphrase", "string", repeated=False, required=False, default=None),
|
||||
2: protobuf.Field("on_device", "bool", repeated=False, required=False, default=None),
|
||||
3: protobuf.Field("derive_cardano", "bool", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
passphrase: Optional["str"] = None,
|
||||
on_device: Optional["bool"] = None,
|
||||
derive_cardano: Optional["bool"] = None,
|
||||
) -> None:
|
||||
self.passphrase = passphrase
|
||||
self.on_device = on_device
|
||||
self.derive_cardano = derive_cardano
|
||||
|
||||
|
||||
class ThpNewSession(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1001
|
||||
FIELDS = {
|
||||
1: protobuf.Field("new_session_id", "uint32", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
new_session_id: Optional["int"] = None,
|
||||
) -> None:
|
||||
self.new_session_id = new_session_id
|
||||
|
||||
|
||||
class ThpStartPairingRequest(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1008
|
||||
FIELDS = {
|
||||
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
host_name: Optional["str"] = None,
|
||||
) -> None:
|
||||
self.host_name = host_name
|
||||
|
||||
|
||||
class ThpPairingPreparationsFinished(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1009
|
||||
|
||||
|
||||
class ThpCodeEntryCommitment(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1016
|
||||
FIELDS = {
|
||||
1: protobuf.Field("commitment", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
commitment: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.commitment = commitment
|
||||
|
||||
|
||||
class ThpCodeEntryChallenge(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1017
|
||||
FIELDS = {
|
||||
1: protobuf.Field("challenge", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
challenge: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.challenge = challenge
|
||||
|
||||
|
||||
class ThpCodeEntryCpaceHost(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1018
|
||||
FIELDS = {
|
||||
1: protobuf.Field("cpace_host_public_key", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cpace_host_public_key: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.cpace_host_public_key = cpace_host_public_key
|
||||
|
||||
|
||||
class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1019
|
||||
FIELDS = {
|
||||
1: protobuf.Field("cpace_trezor_public_key", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cpace_trezor_public_key: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.cpace_trezor_public_key = cpace_trezor_public_key
|
||||
|
||||
|
||||
class ThpCodeEntryTag(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1020
|
||||
FIELDS = {
|
||||
2: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tag: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.tag = tag
|
||||
|
||||
|
||||
class ThpCodeEntrySecret(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1021
|
||||
FIELDS = {
|
||||
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
secret: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.secret = secret
|
||||
|
||||
|
||||
class ThpQrCodeTag(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1024
|
||||
FIELDS = {
|
||||
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tag: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.tag = tag
|
||||
|
||||
|
||||
class ThpQrCodeSecret(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1025
|
||||
FIELDS = {
|
||||
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
secret: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.secret = secret
|
||||
|
||||
|
||||
class ThpNfcUnidirectionalTag(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1032
|
||||
FIELDS = {
|
||||
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tag: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.tag = tag
|
||||
|
||||
|
||||
class ThpNfcUnidirectionalSecret(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1033
|
||||
FIELDS = {
|
||||
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
secret: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.secret = secret
|
||||
|
||||
|
||||
class ThpCredentialRequest(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1010
|
||||
FIELDS = {
|
||||
1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
host_static_pubkey: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.host_static_pubkey = host_static_pubkey
|
||||
|
||||
|
||||
class ThpCredentialResponse(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1011
|
||||
FIELDS = {
|
||||
1: protobuf.Field("trezor_static_pubkey", "bytes", repeated=False, required=False, default=None),
|
||||
2: protobuf.Field("credential", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
trezor_static_pubkey: Optional["bytes"] = None,
|
||||
credential: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.trezor_static_pubkey = trezor_static_pubkey
|
||||
self.credential = credential
|
||||
|
||||
|
||||
class ThpEndRequest(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1012
|
||||
|
||||
|
||||
class ThpEndResponse(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1013
|
||||
|
||||
|
||||
class ThpCredentialMetadata(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = None
|
||||
FIELDS = {
|
||||
|
@ -20,25 +20,25 @@ from . import messages
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
@expect(messages.Entropy, field="entropy", ret_type=bytes)
|
||||
def get_entropy(client: "TrezorClient", size: int) -> "MessageType":
|
||||
return client.call(messages.GetEntropy(size=size))
|
||||
def get_entropy(session: "Session", size: int) -> "MessageType":
|
||||
return session.call(messages.GetEntropy(size=size))
|
||||
|
||||
|
||||
@expect(messages.SignedIdentity)
|
||||
def sign_identity(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
identity: messages.IdentityType,
|
||||
challenge_hidden: bytes,
|
||||
challenge_visual: str,
|
||||
ecdsa_curve_name: Optional[str] = None,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.SignIdentity(
|
||||
identity=identity,
|
||||
challenge_hidden=challenge_hidden,
|
||||
@ -50,12 +50,12 @@ def sign_identity(
|
||||
|
||||
@expect(messages.ECDHSessionKey)
|
||||
def get_ecdh_session_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
identity: messages.IdentityType,
|
||||
peer_public_key: bytes,
|
||||
ecdsa_curve_name: Optional[str] = None,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.GetECDHSessionKey(
|
||||
identity=identity,
|
||||
peer_public_key=peer_public_key,
|
||||
@ -66,7 +66,7 @@ def get_ecdh_session_key(
|
||||
|
||||
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
|
||||
def encrypt_keyvalue(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
key: str,
|
||||
value: bytes,
|
||||
@ -74,7 +74,7 @@ def encrypt_keyvalue(
|
||||
ask_on_decrypt: bool = True,
|
||||
iv: bytes = b"",
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.CipherKeyValue(
|
||||
address_n=n,
|
||||
key=key,
|
||||
@ -89,7 +89,7 @@ def encrypt_keyvalue(
|
||||
|
||||
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
|
||||
def decrypt_keyvalue(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
key: str,
|
||||
value: bytes,
|
||||
@ -97,7 +97,7 @@ def decrypt_keyvalue(
|
||||
ask_on_decrypt: bool = True,
|
||||
iv: bytes = b"",
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.CipherKeyValue(
|
||||
address_n=n,
|
||||
key=key,
|
||||
@ -111,5 +111,5 @@ def decrypt_keyvalue(
|
||||
|
||||
|
||||
@expect(messages.Nonce, field="nonce", ret_type=bytes)
|
||||
def get_nonce(client: "TrezorClient"):
|
||||
return client.call(messages.GetNonce())
|
||||
def get_nonce(session: "Session"):
|
||||
return session.call(messages.GetNonce())
|
||||
|
@ -20,9 +20,9 @@ from . import messages
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
# MAINNET = 0
|
||||
@ -33,13 +33,13 @@ if TYPE_CHECKING:
|
||||
|
||||
@expect(messages.MoneroAddress, field="address", ret_type=bytes)
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
show_display: bool = False,
|
||||
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.MoneroGetAddress(
|
||||
address_n=n,
|
||||
show_display=show_display,
|
||||
@ -51,10 +51,10 @@ def get_address(
|
||||
|
||||
@expect(messages.MoneroWatchKey)
|
||||
def get_watch_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.MoneroGetWatchKey(address_n=n, network_type=network_type)
|
||||
)
|
||||
|
@ -21,9 +21,9 @@ from . import exceptions, messages
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
TYPE_TRANSACTION_TRANSFER = 0x0101
|
||||
TYPE_IMPORTANCE_TRANSFER = 0x0801
|
||||
@ -198,13 +198,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig
|
||||
|
||||
@expect(messages.NEMAddress, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
network: int,
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.NEMGetAddress(
|
||||
address_n=n, network=network, show_display=show_display, chunkify=chunkify
|
||||
)
|
||||
@ -213,7 +213,7 @@ def get_address(
|
||||
|
||||
@expect(messages.NEMSignedTx)
|
||||
def sign_tx(
|
||||
client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False
|
||||
session: "Session", n: "Address", transaction: dict, chunkify: bool = False
|
||||
) -> "MessageType":
|
||||
try:
|
||||
msg = create_sign_tx(transaction, chunkify=chunkify)
|
||||
@ -222,4 +222,4 @@ def sign_tx(
|
||||
|
||||
assert msg.transaction is not None
|
||||
msg.transaction.address_n = n
|
||||
return client.call(msg)
|
||||
return session.call(msg)
|
||||
|
@ -21,9 +21,9 @@ from .protobuf import dict_to_proto
|
||||
from .tools import dict_from_camelcase, expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
|
||||
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
|
||||
@ -31,12 +31,12 @@ REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
|
||||
|
||||
@expect(messages.RippleAddress, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.RippleGetAddress(
|
||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||
)
|
||||
@ -45,14 +45,14 @@ def get_address(
|
||||
|
||||
@expect(messages.RippleSignedTx)
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
msg: messages.RippleSignTx,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
msg.address_n = address_n
|
||||
msg.chunkify = chunkify
|
||||
return client.call(msg)
|
||||
return session.call(msg)
|
||||
|
||||
|
||||
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:
|
||||
|
@ -4,29 +4,29 @@ from . import messages
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
@expect(messages.SolanaPublicKey)
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: List[int],
|
||||
show_display: bool,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display)
|
||||
)
|
||||
|
||||
|
||||
@expect(messages.SolanaAddress)
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: List[int],
|
||||
show_display: bool,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.SolanaGetAddress(
|
||||
address_n=address_n,
|
||||
show_display=show_display,
|
||||
@ -37,12 +37,12 @@ def get_address(
|
||||
|
||||
@expect(messages.SolanaTxSignature)
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: List[int],
|
||||
serialized_tx: bytes,
|
||||
additional_info: Optional[messages.SolanaTxAdditionalInfo],
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.SolanaSignTx(
|
||||
address_n=address_n,
|
||||
serialized_tx=serialized_tx,
|
||||
|
@ -21,9 +21,9 @@ from . import exceptions, messages
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
StellarMessageType = Union[
|
||||
messages.StellarAccountMergeOp,
|
||||
@ -325,12 +325,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
|
||||
|
||||
@expect(messages.StellarAddress, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.StellarGetAddress(
|
||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||
)
|
||||
@ -338,7 +338,7 @@ def get_address(
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
tx: messages.StellarSignTx,
|
||||
operations: List["StellarMessageType"],
|
||||
address_n: "Address",
|
||||
@ -354,10 +354,10 @@ def sign_tx(
|
||||
# 3. Receive a StellarTxOpRequest message
|
||||
# 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message
|
||||
# 5. The final message received will be StellarSignedTx which is returned from this method
|
||||
resp = client.call(tx)
|
||||
resp = session.call(tx)
|
||||
try:
|
||||
while isinstance(resp, messages.StellarTxOpRequest):
|
||||
resp = client.call(operations.pop(0))
|
||||
resp = session.call(operations.pop(0))
|
||||
except IndexError:
|
||||
# pop from empty list
|
||||
raise exceptions.TrezorException(
|
||||
|
@ -20,19 +20,19 @@ from . import messages
|
||||
from .tools import expect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .protobuf import MessageType
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
@expect(messages.TezosAddress, field="address", ret_type=str)
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.TezosGetAddress(
|
||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||
)
|
||||
@ -41,12 +41,12 @@ def get_address(
|
||||
|
||||
@expect(messages.TezosPublicKey, field="public_key", ret_type=str)
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.TezosGetPublicKey(
|
||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||
)
|
||||
@ -55,11 +55,11 @@ def get_public_key(
|
||||
|
||||
@expect(messages.TezosSignedTx)
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
sign_tx_msg: messages.TezosSignTx,
|
||||
chunkify: bool = False,
|
||||
) -> "MessageType":
|
||||
sign_tx_msg.address_n = address_n
|
||||
sign_tx_msg.chunkify = chunkify
|
||||
return client.call(sign_tx_msg)
|
||||
return session.call(sign_tx_msg)
|
||||
|
@ -40,7 +40,7 @@ if TYPE_CHECKING:
|
||||
# More details: https://www.python.org/dev/peps/pep-0612/
|
||||
from typing import TypeVar
|
||||
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from . import client
|
||||
from .protobuf import MessageType
|
||||
@ -284,23 +284,6 @@ def expect(
|
||||
return decorator
|
||||
|
||||
|
||||
def session(
|
||||
f: "Callable[Concatenate[TrezorClient, P], R]",
|
||||
) -> "Callable[Concatenate[TrezorClient, P], R]":
|
||||
# Decorator wraps a BaseClient method
|
||||
# with session activation / deactivation
|
||||
@functools.wraps(f)
|
||||
def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R":
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
client.open()
|
||||
try:
|
||||
return f(client, *args, **kwargs)
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
return wrapped_f
|
||||
|
||||
|
||||
# de-camelcasifier
|
||||
# https://stackoverflow.com/a/1176023/222189
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
||||
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
@ -14,24 +14,18 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
import typing as t
|
||||
|
||||
from ..exceptions import TrezorException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if t.TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
|
||||
T = TypeVar("T", bound="Transport")
|
||||
T = t.TypeVar("T", bound="Transport")
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -41,7 +35,7 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
|
||||
""".strip()
|
||||
|
||||
|
||||
MessagePayload = Tuple[int, bytes]
|
||||
MessagePayload = t.Tuple[int, bytes]
|
||||
|
||||
|
||||
class TransportException(TrezorException):
|
||||
@ -53,72 +47,54 @@ class DeviceIsBusy(TransportException):
|
||||
|
||||
|
||||
class Transport:
|
||||
"""Raw connection to a Trezor device.
|
||||
|
||||
Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB
|
||||
or USB-HID connection, or UDP socket of listening emulator(s).
|
||||
It can also enumerate devices available over this communication link, and return
|
||||
them as instances.
|
||||
|
||||
Transport instance is a thing that:
|
||||
- can be identified and requested by a string URI-like path
|
||||
- can open and close sessions, which enclose related operations
|
||||
- can read and write protobuf messages
|
||||
|
||||
You need to implement a new Transport subclass if you invent a new way to connect
|
||||
a Trezor device to a computer.
|
||||
"""
|
||||
|
||||
PATH_PREFIX: str
|
||||
ENABLED = False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.get_path()
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls: t.Type["T"], models: t.Iterable["TrezorModel"] | None = None
|
||||
) -> t.Iterable["T"]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls: t.Type["T"], path: str, prefix_search: bool = False) -> "T":
|
||||
for device in cls.enumerate():
|
||||
|
||||
if device.get_path() == path:
|
||||
return device
|
||||
|
||||
if prefix_search and device.get_path().startswith(path):
|
||||
return device
|
||||
|
||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||
|
||||
def get_path(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def begin_session(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def end_session(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def find_debug(self: "T") -> "T":
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
|
||||
) -> Iterable["T"]:
|
||||
def open(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
|
||||
for device in cls.enumerate():
|
||||
if (
|
||||
path is None
|
||||
or device.get_path() == path
|
||||
or (prefix_search and device.get_path().startswith(path))
|
||||
):
|
||||
return device
|
||||
def close(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
CHUNK_SIZE: t.ClassVar[int]
|
||||
|
||||
|
||||
def all_transports() -> Iterable[Type["Transport"]]:
|
||||
def all_transports() -> t.Iterable[t.Type["Transport"]]:
|
||||
from .bridge import BridgeTransport
|
||||
from .hid import HidTransport
|
||||
from .udp import UdpTransport
|
||||
from .webusb import WebUsbTransport
|
||||
|
||||
transports: Tuple[Type["Transport"], ...] = (
|
||||
transports: t.Tuple[t.Type["Transport"], ...] = (
|
||||
BridgeTransport,
|
||||
HidTransport,
|
||||
UdpTransport,
|
||||
@ -128,9 +104,9 @@ def all_transports() -> Iterable[Type["Transport"]]:
|
||||
|
||||
|
||||
def enumerate_devices(
|
||||
models: Optional[Iterable["TrezorModel"]] = None,
|
||||
) -> Sequence["Transport"]:
|
||||
devices: List["Transport"] = []
|
||||
models: t.Iterable["TrezorModel"] | None = None,
|
||||
) -> t.Sequence["Transport"]:
|
||||
devices: t.List["Transport"] = []
|
||||
for transport in all_transports():
|
||||
name = transport.__name__
|
||||
try:
|
||||
@ -145,9 +121,7 @@ def enumerate_devices(
|
||||
return devices
|
||||
|
||||
|
||||
def get_transport(
|
||||
path: Optional[str] = None, prefix_search: bool = False
|
||||
) -> "Transport":
|
||||
def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport":
|
||||
if path is None:
|
||||
try:
|
||||
return next(iter(enumerate_devices()))
|
||||
|
@ -1,6 +1,6 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
||||
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
@ -14,24 +14,30 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
|
||||
import typing as t
|
||||
|
||||
import requests
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if t.TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
PROTOCOL_VERSION_1 = 1
|
||||
PROTOCOL_VERSION_2 = 2
|
||||
|
||||
TREZORD_HOST = "http://127.0.0.1:21325"
|
||||
TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"}
|
||||
|
||||
TREZORD_VERSION_MODERN = (2, 0, 25)
|
||||
TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value
|
||||
|
||||
CONNECTION = requests.Session()
|
||||
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
|
||||
@ -45,7 +51,7 @@ class BridgeException(TransportException):
|
||||
super().__init__(f"trezord: {path} failed with code {status}: {message}")
|
||||
|
||||
|
||||
def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
|
||||
def call_bridge(path: str, data: str | None = None) -> requests.Response:
|
||||
url = TREZORD_HOST + "/" + path
|
||||
r = CONNECTION.post(url, data=data)
|
||||
if r.status_code != 200:
|
||||
@ -53,10 +59,54 @@ def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
|
||||
return r
|
||||
|
||||
|
||||
def is_legacy_bridge() -> bool:
|
||||
def get_bridge_version() -> t.Tuple[int, ...]:
|
||||
config = call_bridge("configure").json()
|
||||
version_tuple = tuple(map(int, config["version"].split(".")))
|
||||
return version_tuple < TREZORD_VERSION_MODERN
|
||||
return tuple(map(int, config["version"].split(".")))
|
||||
|
||||
|
||||
def is_legacy_bridge() -> bool:
|
||||
return get_bridge_version() < TREZORD_VERSION_MODERN
|
||||
|
||||
|
||||
def supports_protocolV2() -> bool:
|
||||
return get_bridge_version() >= TREZORD_VERSION_THP_SUPPORT
|
||||
|
||||
|
||||
def detect_protocol_version(transport: "BridgeTransport") -> int:
|
||||
from .. import mapping, messages
|
||||
from ..messages import FailureType
|
||||
|
||||
protocol_version = PROTOCOL_VERSION_1
|
||||
request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize())
|
||||
transport.deprecated_begin_session()
|
||||
transport.deprecated_write(request_type, request_data)
|
||||
|
||||
response_type, response_data = transport.deprecated_read()
|
||||
response = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
|
||||
transport.deprecated_begin_session()
|
||||
if isinstance(response, messages.Failure):
|
||||
if response.code == FailureType.InvalidProtocol:
|
||||
LOG.debug("Protocol V2 detected")
|
||||
protocol_version = PROTOCOL_VERSION_2
|
||||
|
||||
return protocol_version
|
||||
|
||||
|
||||
def _is_transport_valid(transport: "BridgeTransport") -> bool:
|
||||
is_valid = (
|
||||
supports_protocolV2()
|
||||
or detect_protocol_version(transport) == PROTOCOL_VERSION_1
|
||||
)
|
||||
if not is_valid:
|
||||
LOG.warning("Detected unsupported Bridge transport!")
|
||||
return is_valid
|
||||
|
||||
|
||||
def filter_invalid_bridge_transports(
|
||||
transports: t.Iterable["BridgeTransport"],
|
||||
) -> t.Sequence["BridgeTransport"]:
|
||||
"""Filters out invalid bridge transports. Keeps only valid ones."""
|
||||
return [t for t in transports if _is_transport_valid(t)]
|
||||
|
||||
|
||||
class BridgeHandle:
|
||||
@ -84,7 +134,7 @@ class BridgeHandleModern(BridgeHandle):
|
||||
class BridgeHandleLegacy(BridgeHandle):
|
||||
def __init__(self, transport: "BridgeTransport") -> None:
|
||||
super().__init__(transport)
|
||||
self.request: Optional[str] = None
|
||||
self.request: str | None = None
|
||||
|
||||
def write_buf(self, buf: bytes) -> None:
|
||||
if self.request is not None:
|
||||
@ -112,13 +162,12 @@ class BridgeTransport(Transport):
|
||||
ENABLED: bool = True
|
||||
|
||||
def __init__(
|
||||
self, device: Dict[str, Any], legacy: bool, debug: bool = False
|
||||
self, device: t.Dict[str, t.Any], legacy: bool, debug: bool = False
|
||||
) -> None:
|
||||
if legacy and debug:
|
||||
raise TransportException("Debugging not supported on legacy Bridge")
|
||||
|
||||
self.device = device
|
||||
self.session: Optional[str] = None
|
||||
self.session: str | None = device["session"]
|
||||
self.debug = debug
|
||||
self.legacy = legacy
|
||||
|
||||
@ -135,7 +184,7 @@ class BridgeTransport(Transport):
|
||||
raise TransportException("Debug device not available")
|
||||
return BridgeTransport(self.device, self.legacy, debug=True)
|
||||
|
||||
def _call(self, action: str, data: Optional[str] = None) -> requests.Response:
|
||||
def _call(self, action: str, data: str | None = None) -> requests.Response:
|
||||
session = self.session or "null"
|
||||
uri = action + "/" + str(session)
|
||||
if self.debug:
|
||||
@ -144,17 +193,20 @@ class BridgeTransport(Transport):
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, _models: Optional[Iterable["TrezorModel"]] = None
|
||||
) -> Iterable["BridgeTransport"]:
|
||||
cls, _models: t.Iterable["TrezorModel"] | None = None
|
||||
) -> t.Iterable["BridgeTransport"]:
|
||||
try:
|
||||
legacy = is_legacy_bridge()
|
||||
return [
|
||||
BridgeTransport(dev, legacy) for dev in call_bridge("enumerate").json()
|
||||
return filter_invalid_bridge_transports(
|
||||
[
|
||||
BridgeTransport(dev, legacy)
|
||||
for dev in call_bridge("enumerate").json()
|
||||
]
|
||||
)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def begin_session(self) -> None:
|
||||
def deprecated_begin_session(self) -> None:
|
||||
try:
|
||||
data = self._call("acquire/" + self.device["path"])
|
||||
except BridgeException as e:
|
||||
@ -163,18 +215,32 @@ class BridgeTransport(Transport):
|
||||
raise
|
||||
self.session = data.json()["session"]
|
||||
|
||||
def end_session(self) -> None:
|
||||
def deprecated_end_session(self) -> None:
|
||||
if not self.session:
|
||||
return
|
||||
self._call("release")
|
||||
self.session = None
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
def deprecated_write(self, message_type: int, message_data: bytes) -> None:
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
self.handle.write_buf(header + message_data)
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
def deprecated_read(self) -> MessagePayload:
|
||||
data = self.handle.read_buf()
|
||||
headerlen = struct.calcsize(">HL")
|
||||
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
|
||||
return msg_type, data[headerlen : headerlen + datalen]
|
||||
|
||||
def open(self) -> None:
|
||||
pass
|
||||
# TODO self.handle.open()
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
# TODO self.handle.close()
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None: # TODO check if it works :)
|
||||
self.handle.write_buf(chunk)
|
||||
|
||||
def read_chunk(self) -> bytes: # TODO check if it works :)
|
||||
return self.handle.read_buf()
|
||||
|
@ -1,6 +1,6 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
||||
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
@ -14,15 +14,16 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
import typing as t
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from ..models import TREZOR_ONE, TrezorModel
|
||||
from . import UDEV_RULES_STR, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
from . import UDEV_RULES_STR, Transport, TransportException
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -35,23 +36,61 @@ except Exception as e:
|
||||
HID_IMPORTED = False
|
||||
|
||||
|
||||
HidDevice = Dict[str, Any]
|
||||
HidDeviceHandle = Any
|
||||
HidDevice = t.Dict[str, t.Any]
|
||||
HidDeviceHandle = t.Any
|
||||
|
||||
|
||||
class HidHandle:
|
||||
def __init__(
|
||||
self, path: bytes, serial: str, probe_hid_version: bool = False
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.serial = serial
|
||||
class HidTransport(Transport):
|
||||
"""
|
||||
HidTransport implements transport over USB HID interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "hid"
|
||||
ENABLED = HID_IMPORTED
|
||||
|
||||
def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None:
|
||||
self.device = device
|
||||
self.device_path = device["path"]
|
||||
self.device_serial_number = device["serial_number"]
|
||||
self.handle: HidDeviceHandle = None
|
||||
self.hid_version = None if probe_hid_version else 2
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False
|
||||
) -> t.Iterable["HidTransport"]:
|
||||
if models is None:
|
||||
models = {TREZOR_ONE}
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
|
||||
devices: t.List["HidTransport"] = []
|
||||
for dev in hid.enumerate(0, 0):
|
||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if debug:
|
||||
if not is_debuglink(dev):
|
||||
continue
|
||||
else:
|
||||
if not is_wirelink(dev):
|
||||
continue
|
||||
devices.append(HidTransport(dev))
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "HidTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||
return debug
|
||||
raise TransportException("Debug HID device not found")
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = hid.device()
|
||||
try:
|
||||
self.handle.open_path(self.path)
|
||||
self.handle.open_path(self.device_path)
|
||||
except (IOError, OSError) as e:
|
||||
if sys.platform.startswith("linux"):
|
||||
e.args = e.args + (UDEV_RULES_STR,)
|
||||
@ -62,11 +101,11 @@ class HidHandle:
|
||||
# and we wouldn't even know.
|
||||
# So we check that the serial matches what we expect.
|
||||
serial = self.handle.get_serial_number_string()
|
||||
if serial != self.serial:
|
||||
if serial != self.device_serial_number:
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
raise TransportException(
|
||||
f"Unexpected device {serial} on path {self.path.decode()}"
|
||||
f"Unexpected device {serial} on path {self.device_path.decode()}"
|
||||
)
|
||||
|
||||
self.handle.set_nonblocking(True)
|
||||
@ -77,7 +116,7 @@ class HidHandle:
|
||||
def close(self) -> None:
|
||||
if self.handle is not None:
|
||||
# reload serial, because device.wipe() can reset it
|
||||
self.serial = self.handle.get_serial_number_string()
|
||||
self.device_serial_number = self.handle.get_serial_number_string()
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
|
||||
@ -115,53 +154,6 @@ class HidHandle:
|
||||
raise TransportException("Unknown HID version")
|
||||
|
||||
|
||||
class HidTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
HidTransport implements transport over USB HID interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "hid"
|
||||
ENABLED = HID_IMPORTED
|
||||
|
||||
def __init__(self, device: HidDevice) -> None:
|
||||
self.device = device
|
||||
self.handle = HidHandle(device["path"], device["serial_number"])
|
||||
|
||||
super().__init__(protocol=ProtocolV1(self.handle))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
|
||||
) -> Iterable["HidTransport"]:
|
||||
if models is None:
|
||||
models = {TREZOR_ONE}
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
|
||||
devices: List["HidTransport"] = []
|
||||
for dev in hid.enumerate(0, 0):
|
||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if debug:
|
||||
if not is_debuglink(dev):
|
||||
continue
|
||||
else:
|
||||
if not is_wirelink(dev):
|
||||
continue
|
||||
devices.append(HidTransport(dev))
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "HidTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||
return debug
|
||||
raise TransportException("Debug HID device not found")
|
||||
|
||||
|
||||
def is_wirelink(dev: HidDevice) -> bool:
|
||||
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0
|
||||
|
||||
|
@ -1,165 +0,0 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
# as published by the Free Software Foundation.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import logging
|
||||
import struct
|
||||
from typing import Tuple
|
||||
|
||||
from typing_extensions import Protocol as StructuralType
|
||||
|
||||
from . import MessagePayload, Transport
|
||||
|
||||
REPLEN = 64
|
||||
|
||||
V2_FIRST_CHUNK = 0x01
|
||||
V2_NEXT_CHUNK = 0x02
|
||||
V2_BEGIN_SESSION = 0x03
|
||||
V2_END_SESSION = 0x04
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Handle(StructuralType):
|
||||
"""PEP 544 structural type for Handle functionality.
|
||||
(called a "Protocol" in the proposed PEP, name which is impractical here)
|
||||
|
||||
Handle is a "physical" layer for a protocol.
|
||||
It can open/close a connection and read/write bare data in 64-byte chunks.
|
||||
|
||||
Functionally we gain nothing from making this an (abstract) base class for handle
|
||||
implementations, so this definition is for type hinting purposes only. You can,
|
||||
but don't have to, inherit from it.
|
||||
"""
|
||||
|
||||
def open(self) -> None: ...
|
||||
|
||||
def close(self) -> None: ...
|
||||
|
||||
def read_chunk(self) -> bytes: ...
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None: ...
|
||||
|
||||
|
||||
class Protocol:
|
||||
"""Wire protocol that can communicate with a Trezor device, given a Handle.
|
||||
|
||||
A Protocol implements the part of the Transport API that relates to communicating
|
||||
logical messages over a physical layer. It is a thing that can:
|
||||
- open and close sessions,
|
||||
- send and receive protobuf messages,
|
||||
given the ability to:
|
||||
- open and close physical connections,
|
||||
- and send and receive binary chunks.
|
||||
|
||||
For now, the class also handles session counting and opening the underlying Handle.
|
||||
This will probably be removed in the future.
|
||||
|
||||
We will need a new Protocol class if we change the way a Trezor device encapsulates
|
||||
its messages.
|
||||
"""
|
||||
|
||||
def __init__(self, handle: Handle) -> None:
|
||||
self.handle = handle
|
||||
self.session_counter = 0
|
||||
|
||||
# XXX we might be able to remove this now that TrezorClient does session handling
|
||||
def begin_session(self) -> None:
|
||||
if self.session_counter == 0:
|
||||
self.handle.open()
|
||||
self.session_counter += 1
|
||||
|
||||
def end_session(self) -> None:
|
||||
self.session_counter = max(self.session_counter - 1, 0)
|
||||
if self.session_counter == 0:
|
||||
self.handle.close()
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ProtocolBasedTransport(Transport):
|
||||
"""Transport that implements its communications through a Protocol.
|
||||
|
||||
Intended as a base class for implementations that proxy their communication
|
||||
operations to a Protocol.
|
||||
"""
|
||||
|
||||
def __init__(self, protocol: Protocol) -> None:
|
||||
self.protocol = protocol
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
self.protocol.write(message_type, message_data)
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
return self.protocol.read()
|
||||
|
||||
def begin_session(self) -> None:
|
||||
self.protocol.begin_session()
|
||||
|
||||
def end_session(self) -> None:
|
||||
self.protocol.end_session()
|
||||
|
||||
|
||||
class ProtocolV1(Protocol):
|
||||
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
|
||||
Does not understand sessions.
|
||||
"""
|
||||
|
||||
HEADER_LEN = struct.calcsize(">HL")
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
buffer = bytearray(b"##" + header + message_data)
|
||||
|
||||
while buffer:
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b"?" + buffer[: REPLEN - 1]
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
self.handle.write_chunk(chunk)
|
||||
buffer = buffer[63:]
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first()
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next())
|
||||
|
||||
return msg_type, buffer[:datalen]
|
||||
|
||||
def read_first(self) -> Tuple[int, int, bytes]:
|
||||
chunk = self.handle.read_chunk()
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
try:
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
|
||||
data = chunk[3 + self.HEADER_LEN :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self) -> bytes:
|
||||
chunk = self.handle.read_chunk()
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
return chunk[1:]
|
210
python/src/trezorlib/transport/session.py
Normal file
210
python/src/trezorlib/transport/session.py
Normal file
@ -0,0 +1,210 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
from .. import exceptions, messages, models
|
||||
from .thp.protocol_v1 import ProtocolV1
|
||||
from .thp.protocol_v2 import ProtocolV2
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Session:
|
||||
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||
|
||||
def __init__(
|
||||
self, client: TrezorClient, id: bytes, passphrase: str | object | None = None
|
||||
) -> None:
|
||||
self.client = client
|
||||
self._id = id
|
||||
self.passphrase = passphrase
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls, client: TrezorClient, passphrase: str | object | None, derive_cardano: bool
|
||||
) -> Session:
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, msg: t.Any) -> t.Any:
|
||||
# TODO self.check_firmware_version()
|
||||
resp = self.call_raw(msg)
|
||||
|
||||
while True:
|
||||
if isinstance(resp, messages.PinMatrixRequest):
|
||||
if self.pin_callback is None:
|
||||
raise Exception # TODO
|
||||
resp = self.pin_callback(self, resp)
|
||||
elif isinstance(resp, messages.PassphraseRequest):
|
||||
if self.passphrase_callback is None:
|
||||
raise Exception # TODO
|
||||
resp = self.passphrase_callback(self, resp)
|
||||
elif isinstance(resp, messages.ButtonRequest):
|
||||
if self.button_callback is None:
|
||||
raise Exception # TODO
|
||||
resp = self.button_callback(self, resp)
|
||||
elif isinstance(resp, messages.Failure):
|
||||
if resp.code == messages.FailureType.ActionCancelled:
|
||||
raise exceptions.Cancelled
|
||||
raise exceptions.TrezorFailure(resp)
|
||||
else:
|
||||
return resp
|
||||
|
||||
def call_raw(self, msg: t.Any) -> t.Any:
|
||||
self._write(msg)
|
||||
return self._read()
|
||||
|
||||
def _write(self, msg: t.Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def _read(self) -> t.Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def refresh_features(self) -> None:
|
||||
self.client.refresh_features()
|
||||
|
||||
def end(self) -> t.Any:
|
||||
return self.call(messages.EndSession())
|
||||
|
||||
def ping(self, message: str, button_protection: bool | None = None) -> str:
|
||||
resp: messages.Success = self.call(
|
||||
messages.Ping(message=message, button_protection=button_protection)
|
||||
)
|
||||
return resp.message or ""
|
||||
|
||||
@property
|
||||
def features(self) -> messages.Features:
|
||||
return self.client.features
|
||||
|
||||
@property
|
||||
def model(self) -> models.TrezorModel:
|
||||
return self.client.model
|
||||
|
||||
@property
|
||||
def version(self) -> t.Tuple[int, int, int]:
|
||||
return self.client.version
|
||||
|
||||
@property
|
||||
def id(self) -> bytes:
|
||||
return self._id
|
||||
|
||||
@id.setter
|
||||
def id(self, value: bytes) -> None:
|
||||
if not isinstance(value, bytes):
|
||||
raise ValueError("id must be of type bytes")
|
||||
self._id = value
|
||||
|
||||
|
||||
class SessionV1(Session):
|
||||
derive_cardano: bool | None = False
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls,
|
||||
client: TrezorClient,
|
||||
passphrase: str | object = "",
|
||||
derive_cardano: bool = False,
|
||||
session_id: bytes | None = None,
|
||||
) -> SessionV1:
|
||||
assert isinstance(client.protocol, ProtocolV1)
|
||||
session = SessionV1(client, id=session_id or b"")
|
||||
|
||||
session._init_callbacks()
|
||||
session.passphrase = passphrase
|
||||
session.derive_cardano = derive_cardano
|
||||
session.init_session(session.derive_cardano)
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1:
|
||||
assert isinstance(client.protocol, ProtocolV1)
|
||||
session = SessionV1(client, session_id)
|
||||
session.init_session()
|
||||
return session
|
||||
|
||||
def _init_callbacks(self) -> None:
|
||||
self.button_callback = self.client.button_callback
|
||||
if self.button_callback is None:
|
||||
self.button_callback = _callback_button
|
||||
self.pin_callback = self.client.pin_callback
|
||||
self.passphrase_callback = self.client.passphrase_callback
|
||||
|
||||
def _write(self, msg: t.Any) -> None:
|
||||
if t.TYPE_CHECKING:
|
||||
assert isinstance(self.client.protocol, ProtocolV1)
|
||||
self.client.protocol.write(msg)
|
||||
|
||||
def _read(self) -> t.Any:
|
||||
if t.TYPE_CHECKING:
|
||||
assert isinstance(self.client.protocol, ProtocolV1)
|
||||
return self.client.protocol.read()
|
||||
|
||||
def init_session(self, derive_cardano: bool | None = None):
|
||||
if self.id == b"":
|
||||
session_id = None
|
||||
else:
|
||||
session_id = self.id
|
||||
resp: messages.Features = self.call_raw(
|
||||
messages.Initialize(session_id=session_id, derive_cardano=derive_cardano)
|
||||
)
|
||||
if isinstance(self.passphrase, str):
|
||||
self.passphrase_callback = _send_passphrase
|
||||
self._id = resp.session_id
|
||||
|
||||
|
||||
def _send_passphrase(session: Session, resp: t.Any) -> None:
|
||||
assert isinstance(session.passphrase, str)
|
||||
return session.call(messages.PassphraseAck(passphrase=session.passphrase))
|
||||
|
||||
|
||||
def _callback_button(session: Session, msg: t.Any) -> t.Any:
|
||||
print("Please confirm action on your Trezor device.") # TODO how to handle UI?
|
||||
return session.call(messages.ButtonAck())
|
||||
|
||||
|
||||
class SessionV2(Session):
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls, client: TrezorClient, passphrase: str | None, derive_cardano: bool
|
||||
) -> SessionV2:
|
||||
assert isinstance(client.protocol, ProtocolV2)
|
||||
session = cls(client, b"\x00")
|
||||
new_session: messages.ThpNewSession = session.call(
|
||||
messages.ThpCreateNewSession(
|
||||
passphrase=passphrase, derive_cardano=derive_cardano
|
||||
)
|
||||
)
|
||||
assert new_session.new_session_id is not None
|
||||
session_id = new_session.new_session_id
|
||||
session.update_id_and_sid(session_id.to_bytes(1, "big"))
|
||||
return session
|
||||
|
||||
def __init__(self, client: TrezorClient, id: bytes) -> None:
|
||||
super().__init__(client, id)
|
||||
assert isinstance(client.protocol, ProtocolV2)
|
||||
|
||||
self.pin_callback = client.pin_callback
|
||||
self.button_callback = client.button_callback
|
||||
if self.button_callback is None:
|
||||
self.button_callback = _callback_button
|
||||
self.channel: ProtocolV2 = client.protocol.get_channel()
|
||||
self.update_id_and_sid(id)
|
||||
|
||||
def _write(self, msg: t.Any) -> None:
|
||||
LOG.debug("writing message %s", type(msg))
|
||||
self.channel.write(self.sid, msg)
|
||||
|
||||
def _read(self) -> t.Any:
|
||||
msg = self.channel.read(self.sid)
|
||||
LOG.debug("reading message %s", type(msg))
|
||||
return msg
|
||||
|
||||
def update_id_and_sid(self, id: bytes) -> None:
|
||||
self._id = id
|
||||
self.sid = int.from_bytes(id, "big") # TODO update to extract only sid
|
102
python/src/trezorlib/transport/thp/alternating_bit_protocol.py
Normal file
102
python/src/trezorlib/transport/thp/alternating_bit_protocol.py
Normal file
@ -0,0 +1,102 @@
|
||||
# from storage.cache_thp import ChannelCache
|
||||
# from trezor import log
|
||||
# from trezor.wire.thp import ThpError
|
||||
|
||||
|
||||
# def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool:
|
||||
# """
|
||||
# Checks if:
|
||||
# - an ACK message is expected
|
||||
# - the received ACK message acknowledges correct sequence number (bit)
|
||||
# """
|
||||
# if not _is_ack_expected(cache):
|
||||
# return False
|
||||
|
||||
# if not _has_ack_correct_sync_bit(cache, ack_bit):
|
||||
# return False
|
||||
|
||||
# return True
|
||||
|
||||
|
||||
# def _is_ack_expected(cache: ChannelCache) -> bool:
|
||||
# is_expected: bool = not is_sending_allowed(cache)
|
||||
# if __debug__ and not is_expected:
|
||||
# log.debug(__name__, "Received unexpected ACK message")
|
||||
# return is_expected
|
||||
|
||||
|
||||
# def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
|
||||
# is_correct: bool = get_send_seq_bit(cache) == sync_bit
|
||||
# if __debug__ and not is_correct:
|
||||
# log.debug(__name__, "Received ACK message with wrong ack bit")
|
||||
# return is_correct
|
||||
|
||||
|
||||
# def is_sending_allowed(cache: ChannelCache) -> bool:
|
||||
# """
|
||||
# Checks whether sending a message in the provided channel is allowed.
|
||||
|
||||
# Note: Sending a message in a channel before receipt of ACK message for the previously
|
||||
# sent message (in the channel) is prohibited, as it can lead to desynchronization.
|
||||
# """
|
||||
# return bool(cache.sync >> 7)
|
||||
|
||||
|
||||
# def get_send_seq_bit(cache: ChannelCache) -> int:
|
||||
# """
|
||||
# Returns the sequential number (bit) of the next message to be sent
|
||||
# in the provided channel.
|
||||
# """
|
||||
# return (cache.sync & 0x20) >> 5
|
||||
|
||||
|
||||
# def get_expected_receive_seq_bit(cache: ChannelCache) -> int:
|
||||
# """
|
||||
# Returns the (expected) sequential number (bit) of the next message
|
||||
# to be received in the provided channel.
|
||||
# """
|
||||
# return (cache.sync & 0x40) >> 6
|
||||
|
||||
|
||||
# def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None:
|
||||
# """
|
||||
# Set the flag whether sending a message in this channel is allowed or not.
|
||||
# """
|
||||
# cache.sync &= 0x7F
|
||||
# if sending_allowed:
|
||||
# cache.sync |= 0x80
|
||||
|
||||
|
||||
# def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
|
||||
# """
|
||||
# Set the expected sequential number (bit) of the next message to be received
|
||||
# in the provided channel
|
||||
# """
|
||||
# if __debug__:
|
||||
# log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit)
|
||||
# if seq_bit not in (0, 1):
|
||||
# raise ThpError("Unexpected receive sync bit")
|
||||
|
||||
# # set second bit to "seq_bit" value
|
||||
# cache.sync &= 0xBF
|
||||
# if seq_bit:
|
||||
# cache.sync |= 0x40
|
||||
|
||||
|
||||
# def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
|
||||
# if seq_bit not in (0, 1):
|
||||
# raise ThpError("Unexpected send seq bit")
|
||||
# if __debug__:
|
||||
# log.debug(__name__, "setting sync send seq bit to %d", seq_bit)
|
||||
# # set third bit to "seq_bit" value
|
||||
# cache.sync &= 0xDF
|
||||
# if seq_bit:
|
||||
# cache.sync |= 0x20
|
||||
|
||||
|
||||
# def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None:
|
||||
# """
|
||||
# Set the sequential bit of the "next message to be send" to the opposite value,
|
||||
# i.e. 1 -> 0 and 0 -> 1
|
||||
# """
|
||||
# _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache))
|
40
python/src/trezorlib/transport/thp/channel_data.py
Normal file
40
python/src/trezorlib/transport/thp/channel_data.py
Normal file
@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from binascii import hexlify
|
||||
|
||||
|
||||
class ChannelData:
|
||||
def __init__(
|
||||
self,
|
||||
protocol_version: int,
|
||||
transport_path: str,
|
||||
channel_id: int,
|
||||
key_request: bytes,
|
||||
key_response: bytes,
|
||||
nonce_request: int,
|
||||
nonce_response: int,
|
||||
sync_bit_send: int,
|
||||
sync_bit_receive: int,
|
||||
) -> None:
|
||||
self.protocol_version: int = protocol_version
|
||||
self.transport_path: str = transport_path
|
||||
self.channel_id: int = channel_id
|
||||
self.key_request: str = hexlify(key_request).decode()
|
||||
self.key_response: str = hexlify(key_response).decode()
|
||||
self.nonce_request: int = nonce_request
|
||||
self.nonce_response: int = nonce_response
|
||||
self.sync_bit_receive: int = sync_bit_receive
|
||||
self.sync_bit_send: int = sync_bit_send
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"protocol_version": self.protocol_version,
|
||||
"transport_path": self.transport_path,
|
||||
"channel_id": self.channel_id,
|
||||
"key_request": self.key_request,
|
||||
"key_response": self.key_response,
|
||||
"nonce_request": self.nonce_request,
|
||||
"nonce_response": self.nonce_response,
|
||||
"sync_bit_send": self.sync_bit_send,
|
||||
"sync_bit_receive": self.sync_bit_receive,
|
||||
}
|
146
python/src/trezorlib/transport/thp/channel_database.py
Normal file
146
python/src/trezorlib/transport/thp/channel_database.py
Normal file
@ -0,0 +1,146 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ..thp.channel_data import ChannelData
|
||||
from .protocol_and_channel import ProtocolAndChannel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
db: "ChannelDatabase | None" = None
|
||||
|
||||
|
||||
def get_channel_db() -> ChannelDatabase:
|
||||
if db is None:
|
||||
set_channel_database(should_not_store=True)
|
||||
assert db is not None
|
||||
return db
|
||||
|
||||
|
||||
class ChannelDatabase:
|
||||
|
||||
def load_stored_channels(self) -> t.List[ChannelData]: ...
|
||||
def clear_stored_channels(self) -> None: ...
|
||||
def read_all_channels(self) -> t.List: ...
|
||||
def save_all_channels(self, channels: t.List[t.Dict]) -> None: ...
|
||||
def save_channel(self, new_channel: ProtocolAndChannel): ...
|
||||
def remove_channel(self, transport_path: str) -> None: ...
|
||||
|
||||
|
||||
class DummyChannelDatabase(ChannelDatabase):
|
||||
|
||||
def load_stored_channels(self) -> t.List[ChannelData]:
|
||||
return []
|
||||
|
||||
def clear_stored_channels(self) -> None:
|
||||
pass
|
||||
|
||||
def read_all_channels(self) -> t.List:
|
||||
return []
|
||||
|
||||
def save_all_channels(self, channels: t.List[t.Dict]) -> None:
|
||||
return
|
||||
|
||||
def save_channel(self, new_channel: ProtocolAndChannel):
|
||||
pass
|
||||
|
||||
def remove_channel(self, transport_path: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class JsonChannelDatabase(ChannelDatabase):
|
||||
def __init__(self, data_path: str) -> None:
|
||||
self.data_path = data_path
|
||||
super().__init__()
|
||||
|
||||
def load_stored_channels(self) -> t.List[ChannelData]:
|
||||
dicts = self.read_all_channels()
|
||||
return [dict_to_channel_data(d) for d in dicts]
|
||||
|
||||
def clear_stored_channels(self) -> None:
|
||||
LOG.debug("Clearing contents of %s", self.data_path)
|
||||
with open(self.data_path, "w") as f:
|
||||
json.dump([], f)
|
||||
try:
|
||||
os.remove(self.data_path)
|
||||
except Exception as e:
|
||||
LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e)))
|
||||
|
||||
def read_all_channels(self) -> t.List:
|
||||
ensure_file_exists(self.data_path)
|
||||
with open(self.data_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
def save_all_channels(self, channels: t.List[t.Dict]) -> None:
|
||||
LOG.debug("saving all channels")
|
||||
with open(self.data_path, "w") as f:
|
||||
json.dump(channels, f, indent=4)
|
||||
|
||||
def save_channel(self, new_channel: ProtocolAndChannel):
|
||||
|
||||
LOG.debug("save channel")
|
||||
channels = self.read_all_channels()
|
||||
transport_path = new_channel.transport.get_path()
|
||||
|
||||
# If the channel is found in database: replace the old entry by the new
|
||||
for i, channel in enumerate(channels):
|
||||
if channel["transport_path"] == transport_path:
|
||||
LOG.debug("Modified channel entry for %s", transport_path)
|
||||
channels[i] = new_channel.get_channel_data().to_dict()
|
||||
self.save_all_channels(channels)
|
||||
return
|
||||
|
||||
# Channel was not found: add a new channel entry
|
||||
LOG.debug("Created a new channel entry on path %s", transport_path)
|
||||
channels.append(new_channel.get_channel_data().to_dict())
|
||||
self.save_all_channels(channels)
|
||||
|
||||
def remove_channel(self, transport_path: str) -> None:
|
||||
LOG.debug(
|
||||
"Removing channel with path %s from the channel database.",
|
||||
transport_path,
|
||||
)
|
||||
channels = self.read_all_channels()
|
||||
remaining_channels = [
|
||||
ch for ch in channels if ch["transport_path"] != transport_path
|
||||
]
|
||||
self.save_all_channels(remaining_channels)
|
||||
|
||||
|
||||
def dict_to_channel_data(dict: t.Dict) -> ChannelData:
|
||||
return ChannelData(
|
||||
protocol_version=dict["protocol_version"],
|
||||
transport_path=dict["transport_path"],
|
||||
channel_id=dict["channel_id"],
|
||||
key_request=bytes.fromhex(dict["key_request"]),
|
||||
key_response=bytes.fromhex(dict["key_response"]),
|
||||
nonce_request=dict["nonce_request"],
|
||||
nonce_response=dict["nonce_response"],
|
||||
sync_bit_send=dict["sync_bit_send"],
|
||||
sync_bit_receive=dict["sync_bit_receive"],
|
||||
)
|
||||
|
||||
|
||||
def ensure_file_exists(file_path: str) -> None:
|
||||
LOG.debug("checking if file %s exists", file_path)
|
||||
if not os.path.exists(file_path):
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
LOG.debug("File %s does not exist. Creating a new one.", file_path)
|
||||
with open(file_path, "w") as f:
|
||||
json.dump([], f)
|
||||
|
||||
|
||||
def set_channel_database(should_not_store: bool):
|
||||
global db
|
||||
if should_not_store:
|
||||
db = DummyChannelDatabase()
|
||||
else:
|
||||
from platformdirs import user_cache_dir
|
||||
|
||||
APP_NAME = "@trezor" # TODO
|
||||
DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json")
|
||||
|
||||
db = JsonChannelDatabase(DATA_PATH)
|
19
python/src/trezorlib/transport/thp/checksum.py
Normal file
19
python/src/trezorlib/transport/thp/checksum.py
Normal file
@ -0,0 +1,19 @@
|
||||
import zlib
|
||||
|
||||
CHECKSUM_LENGTH = 4
|
||||
|
||||
|
||||
def compute(data: bytes) -> bytes:
|
||||
"""
|
||||
Returns a CRC-32 checksum of the provided `data`.
|
||||
"""
|
||||
return zlib.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
|
||||
|
||||
|
||||
def is_valid(checksum: bytes, data: bytes) -> bool:
|
||||
"""
|
||||
Checks whether the CRC-32 checksum of the `data` is the same
|
||||
as the checksum provided in `checksum`.
|
||||
"""
|
||||
data_checksum = compute(data)
|
||||
return checksum == data_checksum
|
59
python/src/trezorlib/transport/thp/control_byte.py
Normal file
59
python/src/trezorlib/transport/thp/control_byte.py
Normal file
@ -0,0 +1,59 @@
|
||||
CODEC_V1 = 0x3F
|
||||
CONTINUATION_PACKET = 0x80
|
||||
HANDSHAKE_INIT_REQ = 0x00
|
||||
HANDSHAKE_INIT_RES = 0x01
|
||||
HANDSHAKE_COMP_REQ = 0x02
|
||||
HANDSHAKE_COMP_RES = 0x03
|
||||
ENCRYPTED_TRANSPORT = 0x04
|
||||
|
||||
CONTINUATION_PACKET_MASK = 0x80
|
||||
ACK_MASK = 0xF7
|
||||
DATA_MASK = 0xE7
|
||||
|
||||
ACK_MESSAGE = 0x20
|
||||
_ERROR = 0x42
|
||||
CHANNEL_ALLOCATION_REQ = 0x40
|
||||
_CHANNEL_ALLOCATION_RES = 0x41
|
||||
|
||||
TREZOR_STATE_UNPAIRED = b"\x00"
|
||||
TREZOR_STATE_PAIRED = b"\x01"
|
||||
|
||||
|
||||
def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int:
|
||||
if seq_bit == 0:
|
||||
return ctrl_byte & 0xEF
|
||||
if seq_bit == 1:
|
||||
return ctrl_byte | 0x10
|
||||
raise Exception("Unexpected sequence bit")
|
||||
|
||||
|
||||
def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int:
|
||||
if ack_bit == 0:
|
||||
return ctrl_byte & 0xF7
|
||||
if ack_bit == 1:
|
||||
return ctrl_byte | 0x08
|
||||
raise Exception("Unexpected acknowledgement bit")
|
||||
|
||||
|
||||
def get_seq_bit(ctrl_byte: int) -> int:
|
||||
return (ctrl_byte & 0x10) >> 4
|
||||
|
||||
|
||||
def is_ack(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & ACK_MASK == ACK_MESSAGE
|
||||
|
||||
|
||||
def is_continuation(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET
|
||||
|
||||
|
||||
def is_encrypted_transport(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
|
||||
|
||||
|
||||
def is_handshake_init_req(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ
|
||||
|
||||
|
||||
def is_handshake_comp_req(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ
|
116
python/src/trezorlib/transport/thp/curve25519.py
Normal file
116
python/src/trezorlib/transport/thp/curve25519.py
Normal file
@ -0,0 +1,116 @@
|
||||
from typing import Tuple
|
||||
|
||||
p = 2**255 - 19
|
||||
J = 486662
|
||||
|
||||
c3 = 19681161376707505956807079304988542015446066515923890162744021073123829784752 # sqrt(-1)
|
||||
c4 = 7237005577332262213973186563042994240829374041602535252466099000494570602493 # (p - 5) // 8
|
||||
a24 = 121666 # (J + 2) // 4
|
||||
|
||||
|
||||
def decode_scalar(scalar: bytes) -> int:
|
||||
# decodeScalar25519 from
|
||||
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
|
||||
|
||||
if len(scalar) != 32:
|
||||
raise ValueError("Invalid length of scalar")
|
||||
|
||||
array = bytearray(scalar)
|
||||
array[0] &= 248
|
||||
array[31] &= 127
|
||||
array[31] |= 64
|
||||
|
||||
return int.from_bytes(array, "little")
|
||||
|
||||
|
||||
def decode_coordinate(coordinate: bytes) -> int:
|
||||
# decodeUCoordinate from
|
||||
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
|
||||
if len(coordinate) != 32:
|
||||
raise ValueError("Invalid length of coordinate")
|
||||
|
||||
array = bytearray(coordinate)
|
||||
array[-1] &= 0x7F
|
||||
return int.from_bytes(array, "little") % p
|
||||
|
||||
|
||||
def encode_coordinate(coordinate: int) -> bytes:
|
||||
# encodeUCoordinate from
|
||||
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
|
||||
return coordinate.to_bytes(32, "little")
|
||||
|
||||
|
||||
def get_private_key(secret: bytes) -> bytes:
|
||||
return decode_scalar(secret).to_bytes(32, "little")
|
||||
|
||||
|
||||
def get_public_key(private_key: bytes) -> bytes:
|
||||
base_point = int.to_bytes(9, 32, "little")
|
||||
return multiply(private_key, base_point)
|
||||
|
||||
|
||||
def multiply(private_scalar: bytes, public_point: bytes):
|
||||
# X25519 from
|
||||
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
|
||||
|
||||
def ladder_operation(
|
||||
x1: int, x2: int, z2: int, x3: int, z3: int
|
||||
) -> Tuple[int, int, int, int]:
|
||||
# https://hyperelliptic.org/EFD/g1p/auto-montgom-xz.html#ladder-ladd-1987-m-3
|
||||
# (x4, z4) = 2 * (x2, z2)
|
||||
# (x5, z5) = (x2, z2) + (x3, z3)
|
||||
# where (x1, 1) = (x3, z3) - (x2, z2)
|
||||
|
||||
a = (x2 + z2) % p
|
||||
aa = (a * a) % p
|
||||
b = (x2 - z2) % p
|
||||
bb = (b * b) % p
|
||||
e = (aa - bb) % p
|
||||
c = (x3 + z3) % p
|
||||
d = (x3 - z3) % p
|
||||
da = (d * a) % p
|
||||
cb = (c * b) % p
|
||||
t0 = (da + cb) % p
|
||||
x5 = (t0 * t0) % p
|
||||
t1 = (da - cb) % p
|
||||
t2 = (t1 * t1) % p
|
||||
z5 = (x1 * t2) % p
|
||||
x4 = (aa * bb) % p
|
||||
t3 = (a24 * e) % p
|
||||
t4 = (bb + t3) % p
|
||||
z4 = (e * t4) % p
|
||||
|
||||
return x4, z4, x5, z5
|
||||
|
||||
def conditional_swap(first: int, second: int, condition: int):
|
||||
# Returns (second, first) if condition is true and (first, second) otherwise
|
||||
# Must be implemented in a way that it is constant time
|
||||
true_mask = -condition
|
||||
false_mask = ~true_mask
|
||||
return (first & false_mask) | (second & true_mask), (second & false_mask) | (
|
||||
first & true_mask
|
||||
)
|
||||
|
||||
k = decode_scalar(private_scalar)
|
||||
u = decode_coordinate(public_point)
|
||||
|
||||
x_1 = u
|
||||
x_2 = 1
|
||||
z_2 = 0
|
||||
x_3 = u
|
||||
z_3 = 1
|
||||
swap = 0
|
||||
|
||||
for i in reversed(range(256)):
|
||||
bit = (k >> i) & 1
|
||||
swap = bit ^ swap
|
||||
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
|
||||
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
|
||||
swap = bit
|
||||
x_2, z_2, x_3, z_3 = ladder_operation(x_1, x_2, z_2, x_3, z_3)
|
||||
|
||||
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
|
||||
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
|
||||
|
||||
x = pow(z_2, p - 2, p) * x_2 % p
|
||||
return encode_coordinate(x)
|
82
python/src/trezorlib/transport/thp/message_header.py
Normal file
82
python/src/trezorlib/transport/thp/message_header.py
Normal file
@ -0,0 +1,82 @@
|
||||
import struct
|
||||
|
||||
CODEC_V1 = 0x3F
|
||||
CONTINUATION_PACKET = 0x80
|
||||
HANDSHAKE_INIT_REQ = 0x00
|
||||
HANDSHAKE_INIT_RES = 0x01
|
||||
HANDSHAKE_COMP_REQ = 0x02
|
||||
HANDSHAKE_COMP_RES = 0x03
|
||||
ENCRYPTED_TRANSPORT = 0x04
|
||||
|
||||
CONTINUATION_PACKET_MASK = 0x80
|
||||
ACK_MASK = 0xF7
|
||||
DATA_MASK = 0xE7
|
||||
|
||||
ACK_MESSAGE = 0x20
|
||||
_ERROR = 0x42
|
||||
CHANNEL_ALLOCATION_REQ = 0x40
|
||||
_CHANNEL_ALLOCATION_RES = 0x41
|
||||
|
||||
TREZOR_STATE_UNPAIRED = b"\x00"
|
||||
TREZOR_STATE_PAIRED = b"\x01"
|
||||
|
||||
BROADCAST_CHANNEL_ID = 0xFFFF
|
||||
|
||||
|
||||
class MessageHeader:
|
||||
format_str_init = ">BHH"
|
||||
format_str_cont = ">BH"
|
||||
|
||||
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
|
||||
self.ctrl_byte = ctrl_byte
|
||||
self.cid = cid
|
||||
self.data_length = length
|
||||
|
||||
def to_bytes_init(self) -> bytes:
|
||||
return struct.pack(
|
||||
self.format_str_init, self.ctrl_byte, self.cid, self.data_length
|
||||
)
|
||||
|
||||
def to_bytes_cont(self) -> bytes:
|
||||
return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid)
|
||||
|
||||
def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
|
||||
struct.pack_into(
|
||||
self.format_str_init,
|
||||
buffer,
|
||||
buffer_offset,
|
||||
self.ctrl_byte,
|
||||
self.cid,
|
||||
self.data_length,
|
||||
)
|
||||
|
||||
def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
|
||||
struct.pack_into(
|
||||
self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
|
||||
)
|
||||
|
||||
def is_ack(self) -> bool:
|
||||
return self.ctrl_byte & ACK_MASK == ACK_MESSAGE
|
||||
|
||||
def is_channel_allocation_response(self):
|
||||
return (
|
||||
self.cid == BROADCAST_CHANNEL_ID
|
||||
and self.ctrl_byte == _CHANNEL_ALLOCATION_RES
|
||||
)
|
||||
|
||||
def is_handshake_init_response(self) -> bool:
|
||||
return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES
|
||||
|
||||
def is_handshake_comp_response(self) -> bool:
|
||||
return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES
|
||||
|
||||
def is_encrypted_transport(self) -> bool:
|
||||
return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
|
||||
|
||||
@classmethod
|
||||
def get_error_header(cls, cid: int, length: int):
|
||||
return cls(_ERROR, cid, length)
|
||||
|
||||
@classmethod
|
||||
def get_channel_allocation_request_header(cls, length: int):
|
||||
return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length)
|
32
python/src/trezorlib/transport/thp/protocol_and_channel.py
Normal file
32
python/src/trezorlib/transport/thp/protocol_and_channel.py
Normal file
@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from ... import messages
|
||||
from ...mapping import ProtobufMapping
|
||||
from .. import Transport
|
||||
from ..thp.channel_data import ChannelData
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtocolAndChannel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: Transport,
|
||||
mapping: ProtobufMapping,
|
||||
channel_data: ChannelData | None = None,
|
||||
) -> None:
|
||||
self.transport = transport
|
||||
self.mapping = mapping
|
||||
self.channel_keys = channel_data
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_channel_data(self) -> ChannelData:
|
||||
raise NotImplementedError
|
||||
|
||||
def update_features(self) -> None:
|
||||
raise NotImplementedError
|
97
python/src/trezorlib/transport/thp/protocol_v1.py
Normal file
97
python/src/trezorlib/transport/thp/protocol_v1.py
Normal file
@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
import typing as t
|
||||
|
||||
from ... import exceptions, messages
|
||||
from ...log import DUMP_BYTES
|
||||
from .protocol_and_channel import ProtocolAndChannel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtocolV1(ProtocolAndChannel):
|
||||
HEADER_LEN = struct.calcsize(">HL")
|
||||
_features: messages.Features | None = None
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
if self._features is None:
|
||||
self.update_features()
|
||||
assert self._features is not None
|
||||
return self._features
|
||||
|
||||
def update_features(self) -> None:
|
||||
self.write(messages.GetFeatures())
|
||||
resp = self.read()
|
||||
if not isinstance(resp, messages.Features):
|
||||
raise exceptions.TrezorException("Unexpected response to GetFeatures")
|
||||
self._features = resp
|
||||
|
||||
def read(self) -> t.Any:
|
||||
msg_type, msg_bytes = self._read()
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
msg = self.mapping.decode(msg_type, msg_bytes)
|
||||
LOG.debug(
|
||||
f"received message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
self.transport.close()
|
||||
return msg
|
||||
|
||||
def write(self, msg: t.Any) -> None:
|
||||
LOG.debug(
|
||||
f"sending message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
msg_type, msg_bytes = self.mapping.encode(msg)
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
self._write(msg_type, msg_bytes)
|
||||
|
||||
def _write(self, message_type: int, message_data: bytes) -> None:
|
||||
chunk_size = self.transport.CHUNK_SIZE
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
buffer = bytearray(b"##" + header + message_data)
|
||||
|
||||
while buffer:
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b"?" + buffer[: chunk_size - 1]
|
||||
chunk = chunk.ljust(chunk_size, b"\x00")
|
||||
self.transport.write_chunk(chunk)
|
||||
buffer = buffer[63:]
|
||||
|
||||
def _read(self) -> t.Tuple[int, bytes]:
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first()
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next())
|
||||
|
||||
return msg_type, buffer[:datalen]
|
||||
|
||||
def read_first(self) -> t.Tuple[int, int, bytes]:
|
||||
chunk = self.transport.read_chunk()
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
try:
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
|
||||
data = chunk[3 + self.HEADER_LEN :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self) -> bytes:
|
||||
chunk = self.transport.read_chunk()
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
return chunk[1:]
|
404
python/src/trezorlib/transport/thp/protocol_v2.py
Normal file
404
python/src/trezorlib/transport/thp/protocol_v2.py
Normal file
@ -0,0 +1,404 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import typing as t
|
||||
from binascii import hexlify
|
||||
from enum import IntEnum
|
||||
|
||||
import click
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
from ... import exceptions, messages
|
||||
from ...mapping import ProtobufMapping
|
||||
from .. import Transport
|
||||
from ..thp import checksum, curve25519, thp_io
|
||||
from ..thp.channel_data import ChannelData
|
||||
from ..thp.checksum import CHECKSUM_LENGTH
|
||||
from ..thp.message_header import MessageHeader
|
||||
from . import control_byte
|
||||
from .channel_database import ChannelDatabase, get_channel_db
|
||||
from .protocol_and_channel import ProtocolAndChannel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
MANAGEMENT_SESSION_ID: int = 0
|
||||
|
||||
|
||||
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
|
||||
hash = hashlib.sha256(val_1)
|
||||
hash.update(val_2)
|
||||
return hash.digest()
|
||||
|
||||
|
||||
def _hkdf(chaining_key: bytes, input: bytes):
|
||||
temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest()
|
||||
output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest()
|
||||
ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256)
|
||||
ctx_output_2.update(b"\x02")
|
||||
output_2 = ctx_output_2.digest()
|
||||
return (output_1, output_2)
|
||||
|
||||
|
||||
def _get_iv_from_nonce(nonce: int) -> bytes:
|
||||
if not nonce <= 0xFFFFFFFFFFFFFFFF:
|
||||
raise ValueError("Nonce overflow, terminate the channel")
|
||||
return bytes(4) + nonce.to_bytes(8, "big")
|
||||
|
||||
|
||||
class ProtocolV2(ProtocolAndChannel):
|
||||
channel_id: int
|
||||
channel_database: ChannelDatabase
|
||||
key_request: bytes
|
||||
key_response: bytes
|
||||
nonce_request: int
|
||||
nonce_response: int
|
||||
sync_bit_send: int
|
||||
sync_bit_receive: int
|
||||
|
||||
_has_valid_channel: bool = False
|
||||
_features: messages.Features | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: Transport,
|
||||
mapping: ProtobufMapping,
|
||||
channel_data: ChannelData | None = None,
|
||||
) -> None:
|
||||
self.channel_database: ChannelDatabase = get_channel_db()
|
||||
super().__init__(transport, mapping, channel_data)
|
||||
if channel_data is not None:
|
||||
self.channel_id = channel_data.channel_id
|
||||
self.key_request = bytes.fromhex(channel_data.key_request)
|
||||
self.key_response = bytes.fromhex(channel_data.key_response)
|
||||
self.nonce_request = channel_data.nonce_request
|
||||
self.nonce_response = channel_data.nonce_response
|
||||
self.sync_bit_receive = channel_data.sync_bit_receive
|
||||
self.sync_bit_send = channel_data.sync_bit_send
|
||||
self._has_valid_channel = True
|
||||
|
||||
def get_channel(self) -> ProtocolV2:
|
||||
if not self._has_valid_channel:
|
||||
self._establish_new_channel()
|
||||
return self
|
||||
|
||||
def get_channel_data(self) -> ChannelData:
|
||||
return ChannelData(
|
||||
protocol_version=2,
|
||||
transport_path=self.transport.get_path(),
|
||||
channel_id=self.channel_id,
|
||||
key_request=self.key_request,
|
||||
key_response=self.key_response,
|
||||
nonce_request=self.nonce_request,
|
||||
nonce_response=self.nonce_response,
|
||||
sync_bit_receive=self.sync_bit_receive,
|
||||
sync_bit_send=self.sync_bit_send,
|
||||
)
|
||||
|
||||
def read(self, session_id: int) -> t.Any:
|
||||
sid, msg_type, msg_data = self.read_and_decrypt()
|
||||
if sid != session_id:
|
||||
raise Exception("Received messsage on a different session.")
|
||||
self.channel_database.save_channel(self)
|
||||
return self.mapping.decode(msg_type, msg_data)
|
||||
|
||||
def write(self, session_id: int, msg: t.Any) -> None:
|
||||
msg_type, msg_data = self.mapping.encode(msg)
|
||||
self._encrypt_and_write(session_id, msg_type, msg_data)
|
||||
self.channel_database.save_channel(self)
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
if not self._has_valid_channel:
|
||||
self._establish_new_channel()
|
||||
if self._features is None:
|
||||
self.update_features()
|
||||
assert self._features is not None
|
||||
return self._features
|
||||
|
||||
def update_features(self) -> None:
|
||||
message = messages.GetFeatures()
|
||||
message_type, message_data = self.mapping.encode(message)
|
||||
self.session_id: int = 0
|
||||
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
|
||||
_ = self._read_until_valid_crc_check() # TODO check ACK
|
||||
_, msg_type, msg_data = self.read_and_decrypt()
|
||||
features = self.mapping.decode(msg_type, msg_data)
|
||||
if not isinstance(features, messages.Features):
|
||||
raise exceptions.TrezorException("Unexpected response to GetFeatures")
|
||||
self._features = features
|
||||
|
||||
def _establish_new_channel(self) -> None:
|
||||
self.sync_bit_send = 0
|
||||
self.sync_bit_receive = 0
|
||||
# Send channel allocation request
|
||||
# Note that [:8] on the following line is required when tests use
|
||||
# WITH_MOCK_URANDOM. Without [:8] such tests will (almost always) fail.
|
||||
channel_id_request_nonce = os.urandom(8)[:8]
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport,
|
||||
MessageHeader.get_channel_allocation_request_header(12),
|
||||
channel_id_request_nonce,
|
||||
)
|
||||
|
||||
# Read channel allocation response
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not self._is_valid_channel_allocation_response(
|
||||
header, payload, channel_id_request_nonce
|
||||
):
|
||||
# TODO raise exception here, I guess
|
||||
raise Exception("Invalid channel allocation response.")
|
||||
|
||||
self.channel_id = int.from_bytes(payload[8:10], "big")
|
||||
self.device_properties = payload[10:]
|
||||
|
||||
# Send handshake init request
|
||||
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
|
||||
# Note that [:32] on the following line is required when tests use
|
||||
# WITH_MOCK_URANDOM. Without [:32] such tests will (almost always) fail.
|
||||
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)[:32])
|
||||
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
|
||||
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport, ha_init_req_header, host_ephemeral_pubkey
|
||||
)
|
||||
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
click.echo("Received message is not a valid ACK", err=True)
|
||||
|
||||
# Read handshake init response
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
self._send_ack_0()
|
||||
|
||||
if not header.is_handshake_init_response():
|
||||
click.echo(
|
||||
"Received message is not a valid handshake init response message",
|
||||
err=True,
|
||||
)
|
||||
|
||||
trezor_ephemeral_pubkey = payload[:32]
|
||||
encrypted_trezor_static_pubkey = payload[32:80]
|
||||
noise_tag = payload[80:96]
|
||||
|
||||
# TODO check noise tag
|
||||
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
|
||||
|
||||
# Prepare and send handshake completion request
|
||||
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
|
||||
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
|
||||
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
|
||||
h = _sha256_of_two(h, host_ephemeral_pubkey)
|
||||
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
|
||||
ck, k = _hkdf(
|
||||
PROTOCOL_NAME,
|
||||
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
|
||||
)
|
||||
|
||||
aes_ctx = AESGCM(k)
|
||||
try:
|
||||
trezor_masked_static_pubkey = aes_ctx.decrypt(
|
||||
IV_1, encrypted_trezor_static_pubkey, h
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
f"Exception of type{type(e)}", err=True
|
||||
) # TODO how to handle potential exceptions? Q for Matejcik
|
||||
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
|
||||
ck, k = _hkdf(
|
||||
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
|
||||
)
|
||||
aes_ctx = AESGCM(k)
|
||||
|
||||
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
|
||||
h = _sha256_of_two(h, tag_of_empty_string)
|
||||
# TODO: search for saved credentials (or possibly not, as we skip pairing phase)
|
||||
|
||||
zeroes_32 = int.to_bytes(0, 32, "little")
|
||||
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
|
||||
temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey)
|
||||
aes_ctx = AESGCM(k)
|
||||
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h)
|
||||
h = _sha256_of_two(h, encrypted_host_static_pubkey)
|
||||
ck, k = _hkdf(
|
||||
ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey)
|
||||
)
|
||||
msg_data = self.mapping.encode_without_wire_type(
|
||||
messages.ThpHandshakeCompletionReqNoisePayload(
|
||||
pairing_methods=[
|
||||
messages.ThpPairingMethod.NoMethod,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
aes_ctx = AESGCM(k)
|
||||
|
||||
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
|
||||
h = _sha256_of_two(h, encrypted_payload)
|
||||
ha_completion_req_header = MessageHeader(
|
||||
0x12,
|
||||
self.channel_id,
|
||||
len(encrypted_host_static_pubkey)
|
||||
+ len(encrypted_payload)
|
||||
+ CHECKSUM_LENGTH,
|
||||
)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport,
|
||||
ha_completion_req_header,
|
||||
encrypted_host_static_pubkey + encrypted_payload,
|
||||
)
|
||||
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
click.echo("Received message is not a valid ACK", err=True)
|
||||
|
||||
# Read handshake completion response, ignore payload as we do not care about the state
|
||||
header, _ = self._read_until_valid_crc_check()
|
||||
if not header.is_handshake_comp_response():
|
||||
click.echo(
|
||||
"Received message is not a valid handshake completion response",
|
||||
err=True,
|
||||
)
|
||||
self._send_ack_1()
|
||||
|
||||
self.key_request, self.key_response = _hkdf(ck, b"")
|
||||
self.nonce_request = 0
|
||||
self.nonce_response = 1
|
||||
|
||||
# Send StartPairingReqest message
|
||||
message = messages.ThpStartPairingRequest()
|
||||
message_type, message_data = self.mapping.encode(message)
|
||||
|
||||
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
|
||||
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
click.echo("Received message is not a valid ACK", err=True)
|
||||
|
||||
# Read
|
||||
_, msg_type, msg_data = self.read_and_decrypt()
|
||||
maaa = self.mapping.decode(msg_type, msg_data)
|
||||
|
||||
assert isinstance(maaa, messages.ThpEndResponse)
|
||||
self._has_valid_channel = True
|
||||
|
||||
def _send_ack_0(self):
|
||||
LOG.debug("sending ack 0")
|
||||
header = MessageHeader(0x20, self.channel_id, 4)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
|
||||
|
||||
def _send_ack_1(self):
|
||||
LOG.debug("sending ack 1")
|
||||
header = MessageHeader(0x28, self.channel_id, 4)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
|
||||
|
||||
def _encrypt_and_write(
|
||||
self,
|
||||
session_id: int,
|
||||
message_type: int,
|
||||
message_data: bytes,
|
||||
ctrl_byte: int | None = None,
|
||||
) -> None:
|
||||
assert self.key_request is not None
|
||||
aes_ctx = AESGCM(self.key_request)
|
||||
|
||||
if ctrl_byte is None:
|
||||
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send)
|
||||
self.sync_bit_send = 1 - self.sync_bit_send
|
||||
|
||||
sid = session_id.to_bytes(1, "big")
|
||||
msg_type = message_type.to_bytes(2, "big")
|
||||
data = sid + msg_type + message_data
|
||||
nonce = _get_iv_from_nonce(self.nonce_request)
|
||||
self.nonce_request += 1
|
||||
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
|
||||
header = MessageHeader(
|
||||
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
|
||||
)
|
||||
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport, header, encrypted_message
|
||||
)
|
||||
|
||||
def read_and_decrypt(self) -> t.Tuple[int, int, bytes]:
|
||||
header, raw_payload = self._read_until_valid_crc_check()
|
||||
if control_byte.is_ack(header.ctrl_byte):
|
||||
return self.read_and_decrypt()
|
||||
if not header.is_encrypted_transport():
|
||||
click.echo(
|
||||
"Trying to decrypt not encrypted message!"
|
||||
+ hexlify(header.to_bytes_init() + raw_payload).decode(),
|
||||
err=True,
|
||||
)
|
||||
|
||||
if not control_byte.is_ack(header.ctrl_byte):
|
||||
LOG.debug(
|
||||
"--> Get sequence bit %d %s %s",
|
||||
control_byte.get_seq_bit(header.ctrl_byte),
|
||||
"from control byte",
|
||||
hexlify(header.ctrl_byte.to_bytes(1, "big")).decode(),
|
||||
)
|
||||
if control_byte.get_seq_bit(header.ctrl_byte):
|
||||
self._send_ack_1()
|
||||
else:
|
||||
self._send_ack_0()
|
||||
aes_ctx = AESGCM(self.key_response)
|
||||
nonce = _get_iv_from_nonce(self.nonce_response)
|
||||
self.nonce_response += 1
|
||||
|
||||
message = aes_ctx.decrypt(nonce, raw_payload, b"")
|
||||
session_id = message[0]
|
||||
message_type = message[1:3]
|
||||
message_data = message[3:]
|
||||
return (
|
||||
session_id,
|
||||
int.from_bytes(message_type, "big"),
|
||||
message_data,
|
||||
)
|
||||
|
||||
def _read_until_valid_crc_check(
|
||||
self,
|
||||
) -> t.Tuple[MessageHeader, bytes]:
|
||||
is_valid = False
|
||||
header, payload, chksum = thp_io.read(self.transport)
|
||||
while not is_valid:
|
||||
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
|
||||
if not is_valid:
|
||||
click.echo(
|
||||
"Received a message with an invalid checksum:"
|
||||
+ hexlify(header.to_bytes_init() + payload + chksum).decode(),
|
||||
err=True,
|
||||
)
|
||||
header, payload, chksum = thp_io.read(self.transport)
|
||||
|
||||
return header, payload
|
||||
|
||||
def _is_valid_channel_allocation_response(
|
||||
self, header: MessageHeader, payload: bytes, original_nonce: bytes
|
||||
) -> bool:
|
||||
if not header.is_channel_allocation_response():
|
||||
click.echo(
|
||||
"Received message is not a channel allocation response", err=True
|
||||
)
|
||||
return False
|
||||
if len(payload) < 10:
|
||||
click.echo("Invalid channel allocation response payload", err=True)
|
||||
return False
|
||||
if payload[:8] != original_nonce:
|
||||
click.echo(
|
||||
"Invalid channel allocation response payload (nonce mismatch)", err=True
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
class ControlByteType(IntEnum):
|
||||
CHANNEL_ALLOCATION_RES = 1
|
||||
HANDSHAKE_INIT_RES = 2
|
||||
HANDSHAKE_COMP_RES = 3
|
||||
ACK = 4
|
||||
ENCRYPTED_TRANSPORT = 5
|
93
python/src/trezorlib/transport/thp/thp_io.py
Normal file
93
python/src/trezorlib/transport/thp/thp_io.py
Normal file
@ -0,0 +1,93 @@
|
||||
import struct
|
||||
from typing import Tuple
|
||||
|
||||
from .. import Transport
|
||||
from ..thp import checksum
|
||||
from .message_header import MessageHeader
|
||||
|
||||
INIT_HEADER_LENGTH = 5
|
||||
CONT_HEADER_LENGTH = 3
|
||||
MAX_PAYLOAD_LEN = 60000
|
||||
MESSAGE_TYPE_LENGTH = 2
|
||||
|
||||
CONTINUATION_PACKET = 0x80
|
||||
|
||||
|
||||
def write_payload_to_wire_and_add_checksum(
|
||||
transport: Transport, header: MessageHeader, transport_payload: bytes
|
||||
):
|
||||
chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload)
|
||||
data = transport_payload + chksum
|
||||
write_payload_to_wire(transport, header, data)
|
||||
|
||||
|
||||
def write_payload_to_wire(
|
||||
transport: Transport, header: MessageHeader, transport_payload: bytes
|
||||
):
|
||||
transport.open()
|
||||
buffer = bytearray(transport_payload)
|
||||
chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH]
|
||||
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
|
||||
transport.write_chunk(chunk)
|
||||
|
||||
buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :]
|
||||
while buffer:
|
||||
chunk = (
|
||||
header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH]
|
||||
)
|
||||
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
|
||||
transport.write_chunk(chunk)
|
||||
buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :]
|
||||
|
||||
|
||||
def read(transport: Transport) -> Tuple[MessageHeader, bytes, bytes]:
|
||||
"""
|
||||
Reads from the given wire transport.
|
||||
|
||||
Returns `Tuple[MessageHeader, bytes, bytes]`:
|
||||
1. `header` (`MessageHeader`): Header of the message.
|
||||
2. `data` (`bytes`): Contents of the message (if any).
|
||||
3. `checksum` (`bytes`): crc32 checksum of the header + data.
|
||||
|
||||
"""
|
||||
buffer = bytearray()
|
||||
|
||||
# Read header with first part of message data
|
||||
header, first_chunk = read_first(transport)
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < header.data_length:
|
||||
buffer.extend(read_next(transport, header.cid))
|
||||
|
||||
data_len = header.data_length - checksum.CHECKSUM_LENGTH
|
||||
msg_data = buffer[:data_len]
|
||||
chksum = buffer[data_len : data_len + checksum.CHECKSUM_LENGTH]
|
||||
|
||||
return (header, msg_data, chksum)
|
||||
|
||||
|
||||
def read_first(transport: Transport) -> Tuple[MessageHeader, bytes]:
|
||||
chunk = transport.read_chunk()
|
||||
try:
|
||||
ctrl_byte, cid, data_length = struct.unpack(
|
||||
MessageHeader.format_str_init, chunk[:INIT_HEADER_LENGTH]
|
||||
)
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
|
||||
data = chunk[INIT_HEADER_LENGTH:]
|
||||
return MessageHeader(ctrl_byte, cid, data_length), data
|
||||
|
||||
|
||||
def read_next(transport: Transport, cid: int) -> bytes:
|
||||
chunk = transport.read_chunk()
|
||||
ctrl_byte, read_cid = struct.unpack(
|
||||
MessageHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH]
|
||||
)
|
||||
if ctrl_byte != CONTINUATION_PACKET:
|
||||
raise RuntimeError("Continuation packet with incorrect control byte")
|
||||
if read_cid != cid:
|
||||
raise RuntimeError("Continuation packet for different channel")
|
||||
|
||||
return chunk[CONT_HEADER_LENGTH:]
|
@ -1,6 +1,6 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
||||
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
@ -14,14 +14,15 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Iterable, Optional
|
||||
from typing import TYPE_CHECKING, Iterable, Tuple
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
from . import Transport, TransportException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
@ -31,14 +32,18 @@ SOCKET_TIMEOUT = 10
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UdpTransport(ProtocolBasedTransport):
|
||||
class UdpTransport(Transport):
|
||||
|
||||
DEFAULT_HOST = "127.0.0.1"
|
||||
DEFAULT_PORT = 21324
|
||||
PATH_PREFIX = "udp"
|
||||
ENABLED: bool = True
|
||||
CHUNK_SIZE = 64
|
||||
|
||||
def __init__(self, device: Optional[str] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
device: str | None = None,
|
||||
) -> None:
|
||||
if not device:
|
||||
host = UdpTransport.DEFAULT_HOST
|
||||
port = UdpTransport.DEFAULT_PORT
|
||||
@ -46,24 +51,17 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
devparts = device.split(":")
|
||||
host = devparts[0]
|
||||
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
|
||||
self.device = (host, port)
|
||||
self.socket: Optional[socket.socket] = None
|
||||
self.device: Tuple[str, int] = (host, port)
|
||||
|
||||
super().__init__(protocol=ProtocolV1(self))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||
|
||||
def find_debug(self) -> "UdpTransport":
|
||||
host, port = self.device
|
||||
return UdpTransport(f"{host}:{port + 1}")
|
||||
self.socket: socket.socket | None = None
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def _try_path(cls, path: str) -> "UdpTransport":
|
||||
d = cls(path)
|
||||
try:
|
||||
d.open()
|
||||
if d._ping():
|
||||
if d.ping():
|
||||
return d
|
||||
else:
|
||||
raise TransportException(
|
||||
@ -77,7 +75,7 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, _models: Optional[Iterable["TrezorModel"]] = None
|
||||
cls, _models: Iterable["TrezorModel"] | None = None
|
||||
) -> Iterable["UdpTransport"]:
|
||||
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
|
||||
try:
|
||||
@ -99,20 +97,8 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
else:
|
||||
raise TransportException(f"No UDP device at {path}")
|
||||
|
||||
def wait_until_ready(self, timeout: float = 10) -> None:
|
||||
try:
|
||||
self.open()
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if self._ping():
|
||||
break
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed >= timeout:
|
||||
raise TransportException("Timed out waiting for connection.")
|
||||
|
||||
time.sleep(0.05)
|
||||
finally:
|
||||
self.close()
|
||||
def get_path(self) -> str:
|
||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||
|
||||
def open(self) -> None:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
@ -124,18 +110,9 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
|
||||
def _ping(self) -> bool:
|
||||
"""Test if the device is listening."""
|
||||
assert self.socket is not None
|
||||
resp = None
|
||||
try:
|
||||
self.socket.sendall(b"PINGPING")
|
||||
resp = self.socket.recv(8)
|
||||
except Exception:
|
||||
pass
|
||||
return resp == b"PONGPONG"
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
if self.socket is None:
|
||||
self.open()
|
||||
assert self.socket is not None
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected data length")
|
||||
@ -143,6 +120,8 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
self.socket.sendall(chunk)
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
if self.socket is None:
|
||||
self.open()
|
||||
assert self.socket is not None
|
||||
while True:
|
||||
try:
|
||||
@ -154,3 +133,33 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
if len(chunk) != 64:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
return bytearray(chunk)
|
||||
|
||||
def find_debug(self) -> "UdpTransport":
|
||||
host, port = self.device
|
||||
return UdpTransport(f"{host}:{port + 1}")
|
||||
|
||||
def wait_until_ready(self, timeout: float = 10) -> None:
|
||||
try:
|
||||
self.open()
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if self.ping():
|
||||
break
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed >= timeout:
|
||||
raise TransportException("Timed out waiting for connection.")
|
||||
|
||||
time.sleep(0.05)
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
def ping(self) -> bool:
|
||||
"""Test if the device is listening."""
|
||||
assert self.socket is not None
|
||||
resp = None
|
||||
try:
|
||||
self.socket.sendall(b"PINGPING")
|
||||
resp = self.socket.recv(8)
|
||||
except Exception:
|
||||
pass
|
||||
return resp == b"PONGPONG"
|
||||
|
@ -1,6 +1,6 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
||||
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
@ -14,16 +14,17 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable, List, Optional
|
||||
from typing import Iterable, List
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from ..models import TREZORS, TrezorModel
|
||||
from . import UDEV_RULES_STR, DeviceIsBusy, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
from . import UDEV_RULES_STR, DeviceIsBusy, Transport, TransportException
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -44,13 +45,69 @@ USB_COMM_TIMEOUT_MS = 300
|
||||
WEBUSB_CHUNK_SIZE = 64
|
||||
|
||||
|
||||
class WebUsbHandle:
|
||||
def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None:
|
||||
class WebUsbTransport(Transport):
|
||||
"""
|
||||
WebUsbTransport implements transport over WebUSB interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "webusb"
|
||||
ENABLED = USB_IMPORTED
|
||||
context = None
|
||||
CHUNK_SIZE = 64
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: "usb1.USBDevice",
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
|
||||
self.device = device
|
||||
self.debug = debug
|
||||
|
||||
self.interface = DEBUG_INTERFACE if debug else INTERFACE
|
||||
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
|
||||
self.count = 0
|
||||
self.handle: Optional["usb1.USBDeviceHandle"] = None
|
||||
self.handle: usb1.USBDeviceHandle | None = None
|
||||
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False
|
||||
) -> Iterable["WebUsbTransport"]:
|
||||
if cls.context is None:
|
||||
cls.context = usb1.USBContext()
|
||||
cls.context.open()
|
||||
atexit.register(cls.context.close)
|
||||
|
||||
if models is None:
|
||||
models = TREZORS
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
devices: List["WebUsbTransport"] = []
|
||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if not is_vendor_class(dev):
|
||||
continue
|
||||
if usb_reset:
|
||||
handle = dev.open()
|
||||
handle.resetDevice()
|
||||
handle.close()
|
||||
continue
|
||||
try:
|
||||
# workaround for issue #223:
|
||||
# on certain combinations of Windows USB drivers and libusb versions,
|
||||
# Trezor is returned twice (possibly because Windows know it as both
|
||||
# a HID and a WebUSB device), and one of the returned devices is
|
||||
# non-functional.
|
||||
dev.getProduct()
|
||||
devices.append(WebUsbTransport(dev))
|
||||
except usb1.USBErrorNotSupported:
|
||||
pass
|
||||
return devices
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = self.device.open()
|
||||
@ -64,6 +121,8 @@ class WebUsbHandle:
|
||||
self.handle.claimInterface(self.interface)
|
||||
except usb1.USBErrorAccess as e:
|
||||
raise DeviceIsBusy(self.device) from e
|
||||
except usb1.USBErrorBusy as e:
|
||||
raise DeviceIsBusy(self.device) from e
|
||||
|
||||
def close(self) -> None:
|
||||
if self.handle is not None:
|
||||
@ -75,6 +134,8 @@ class WebUsbHandle:
|
||||
self.handle = None
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
if self.handle is None:
|
||||
self.open()
|
||||
assert self.handle is not None
|
||||
if len(chunk) != WEBUSB_CHUNK_SIZE:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
@ -97,6 +158,8 @@ class WebUsbHandle:
|
||||
return
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
if self.handle is None:
|
||||
self.open()
|
||||
assert self.handle is not None
|
||||
endpoint = 0x80 | self.endpoint
|
||||
while True:
|
||||
@ -117,70 +180,6 @@ class WebUsbHandle:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
return chunk
|
||||
|
||||
|
||||
class WebUsbTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
WebUsbTransport implements transport over WebUSB interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "webusb"
|
||||
ENABLED = USB_IMPORTED
|
||||
context = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: "usb1.USBDevice",
|
||||
handle: Optional[WebUsbHandle] = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
if handle is None:
|
||||
handle = WebUsbHandle(device, debug)
|
||||
|
||||
self.device = device
|
||||
self.handle = handle
|
||||
self.debug = debug
|
||||
|
||||
super().__init__(protocol=ProtocolV1(handle))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False
|
||||
) -> Iterable["WebUsbTransport"]:
|
||||
if cls.context is None:
|
||||
cls.context = usb1.USBContext()
|
||||
cls.context.open()
|
||||
atexit.register(cls.context.close)
|
||||
|
||||
if models is None:
|
||||
models = TREZORS
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
devices: List["WebUsbTransport"] = []
|
||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if not is_vendor_class(dev):
|
||||
continue
|
||||
try:
|
||||
# workaround for issue #223:
|
||||
# on certain combinations of Windows USB drivers and libusb versions,
|
||||
# Trezor is returned twice (possibly because Windows know it as both
|
||||
# a HID and a WebUSB device), and one of the returned devices is
|
||||
# non-functional.
|
||||
dev.getProduct()
|
||||
devices.append(WebUsbTransport(dev))
|
||||
except usb1.USBErrorNotSupported:
|
||||
pass
|
||||
except usb1.USBErrorPipe:
|
||||
if usb_reset:
|
||||
handle = dev.open()
|
||||
handle.resetDevice()
|
||||
handle.close()
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "WebUsbTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
return WebUsbTransport(self.device, debug=True)
|
||||
|
Loading…
Reference in New Issue
Block a user