mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-07 21:22:41 +00:00
feat(python): implement session based trezorlib
This commit is contained in:
parent
7f5764b7d4
commit
fbff05a89f
@ -7,7 +7,7 @@ import typing as t
|
|||||||
from importlib import metadata
|
from importlib import metadata
|
||||||
|
|
||||||
from . import device
|
from . import device
|
||||||
from .client import TrezorClient
|
from .transport.session import Session
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cryptography_version = metadata.version("cryptography")
|
cryptography_version = metadata.version("cryptography")
|
||||||
@ -361,7 +361,7 @@ def verify_authentication_response(
|
|||||||
|
|
||||||
|
|
||||||
def authenticate_device(
|
def authenticate_device(
|
||||||
client: TrezorClient,
|
session: Session,
|
||||||
challenge: bytes | None = None,
|
challenge: bytes | None = None,
|
||||||
*,
|
*,
|
||||||
whitelist: t.Collection[bytes] | None = None,
|
whitelist: t.Collection[bytes] | None = None,
|
||||||
@ -371,7 +371,7 @@ def authenticate_device(
|
|||||||
if challenge is None:
|
if challenge is None:
|
||||||
challenge = secrets.token_bytes(16)
|
challenge = secrets.token_bytes(16)
|
||||||
|
|
||||||
resp = device.authenticate(client, challenge)
|
resp = device.authenticate(session, challenge)
|
||||||
|
|
||||||
return verify_authentication_response(
|
return verify_authentication_response(
|
||||||
challenge,
|
challenge,
|
||||||
|
@ -19,16 +19,16 @@ from typing import TYPE_CHECKING
|
|||||||
from . import messages
|
from . import messages
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
def list_names(
|
def list_names(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
) -> messages.BenchmarkNames:
|
) -> messages.BenchmarkNames:
|
||||||
return client.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames)
|
return session.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames)
|
||||||
|
|
||||||
|
|
||||||
def run(client: "TrezorClient", name: str) -> messages.BenchmarkResult:
|
def run(session: "Session", name: str) -> messages.BenchmarkResult:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.BenchmarkRun(name=name), expect=messages.BenchmarkResult
|
messages.BenchmarkRun(name=name), expect=messages.BenchmarkResult
|
||||||
)
|
)
|
||||||
|
@ -18,20 +18,19 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from . import messages
|
from . import messages
|
||||||
from .protobuf import dict_to_proto
|
from .protobuf import dict_to_proto
|
||||||
from .tools import session
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: "Address",
|
address_n: "Address",
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.BinanceGetAddress(
|
messages.BinanceGetAddress(
|
||||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||||
),
|
),
|
||||||
@ -40,17 +39,16 @@ def get_address(
|
|||||||
|
|
||||||
|
|
||||||
def get_public_key(
|
def get_public_key(
|
||||||
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
session: "Session", address_n: "Address", show_display: bool = False
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display),
|
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display),
|
||||||
expect=messages.BinancePublicKey,
|
expect=messages.BinancePublicKey,
|
||||||
).public_key
|
).public_key
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False
|
session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False
|
||||||
) -> messages.BinanceSignedTx:
|
) -> messages.BinanceSignedTx:
|
||||||
msg = tx_json["msgs"][0]
|
msg = tx_json["msgs"][0]
|
||||||
tx_msg = tx_json.copy()
|
tx_msg = tx_json.copy()
|
||||||
@ -59,7 +57,7 @@ def sign_tx(
|
|||||||
tx_msg["chunkify"] = chunkify
|
tx_msg["chunkify"] = chunkify
|
||||||
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
|
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
|
||||||
|
|
||||||
client.call(envelope, expect=messages.BinanceTxRequest)
|
session.call(envelope, expect=messages.BinanceTxRequest)
|
||||||
|
|
||||||
if "refid" in msg:
|
if "refid" in msg:
|
||||||
msg = dict_to_proto(messages.BinanceCancelMsg, msg)
|
msg = dict_to_proto(messages.BinanceCancelMsg, msg)
|
||||||
@ -70,4 +68,4 @@ def sign_tx(
|
|||||||
else:
|
else:
|
||||||
raise ValueError("can not determine msg type")
|
raise ValueError("can not determine msg type")
|
||||||
|
|
||||||
return client.call(msg, expect=messages.BinanceSignedTx)
|
return session.call(msg, expect=messages.BinanceSignedTx)
|
||||||
|
@ -25,11 +25,11 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple
|
|||||||
from typing_extensions import Protocol, TypedDict
|
from typing_extensions import Protocol, TypedDict
|
||||||
|
|
||||||
from . import exceptions, messages
|
from . import exceptions, messages
|
||||||
from .tools import _return_success, prepare_message_bytes, session
|
from .tools import _return_success, prepare_message_bytes
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
class ScriptSig(TypedDict):
|
class ScriptSig(TypedDict):
|
||||||
asm: str
|
asm: str
|
||||||
@ -105,7 +105,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType:
|
|||||||
|
|
||||||
|
|
||||||
def get_public_node(
|
def get_public_node(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
ecdsa_curve_name: Optional[str] = None,
|
ecdsa_curve_name: Optional[str] = None,
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
@ -116,12 +116,12 @@ def get_public_node(
|
|||||||
unlock_path_mac: Optional[bytes] = None,
|
unlock_path_mac: Optional[bytes] = None,
|
||||||
) -> messages.PublicKey:
|
) -> messages.PublicKey:
|
||||||
if unlock_path:
|
if unlock_path:
|
||||||
client.call(
|
session.call(
|
||||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
|
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
|
||||||
expect=messages.UnlockedPathRequest,
|
expect=messages.UnlockedPathRequest,
|
||||||
)
|
)
|
||||||
|
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.GetPublicKey(
|
messages.GetPublicKey(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
ecdsa_curve_name=ecdsa_curve_name,
|
ecdsa_curve_name=ecdsa_curve_name,
|
||||||
@ -139,7 +139,7 @@ def get_address(*args: Any, **kwargs: Any) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def get_authenticated_address(
|
def get_authenticated_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin_name: str,
|
coin_name: str,
|
||||||
n: "Address",
|
n: "Address",
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
@ -151,12 +151,12 @@ def get_authenticated_address(
|
|||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> messages.Address:
|
) -> messages.Address:
|
||||||
if unlock_path:
|
if unlock_path:
|
||||||
client.call(
|
session.call(
|
||||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
|
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
|
||||||
expect=messages.UnlockedPathRequest,
|
expect=messages.UnlockedPathRequest,
|
||||||
)
|
)
|
||||||
|
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.GetAddress(
|
messages.GetAddress(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
coin_name=coin_name,
|
coin_name=coin_name,
|
||||||
@ -171,13 +171,13 @@ def get_authenticated_address(
|
|||||||
|
|
||||||
|
|
||||||
def get_ownership_id(
|
def get_ownership_id(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin_name: str,
|
coin_name: str,
|
||||||
n: "Address",
|
n: "Address",
|
||||||
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.GetOwnershipId(
|
messages.GetOwnershipId(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
coin_name=coin_name,
|
coin_name=coin_name,
|
||||||
@ -188,8 +188,9 @@ def get_ownership_id(
|
|||||||
).ownership_id
|
).ownership_id
|
||||||
|
|
||||||
|
|
||||||
|
# TODO this is used by tests only
|
||||||
def get_ownership_proof(
|
def get_ownership_proof(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin_name: str,
|
coin_name: str,
|
||||||
n: "Address",
|
n: "Address",
|
||||||
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||||
@ -200,9 +201,9 @@ def get_ownership_proof(
|
|||||||
preauthorized: bool = False,
|
preauthorized: bool = False,
|
||||||
) -> Tuple[bytes, bytes]:
|
) -> Tuple[bytes, bytes]:
|
||||||
if preauthorized:
|
if preauthorized:
|
||||||
client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
|
session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
|
||||||
|
|
||||||
res = client.call(
|
res = session.call(
|
||||||
messages.GetOwnershipProof(
|
messages.GetOwnershipProof(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
coin_name=coin_name,
|
coin_name=coin_name,
|
||||||
@ -219,7 +220,7 @@ def get_ownership_proof(
|
|||||||
|
|
||||||
|
|
||||||
def sign_message(
|
def sign_message(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin_name: str,
|
coin_name: str,
|
||||||
n: "Address",
|
n: "Address",
|
||||||
message: AnyStr,
|
message: AnyStr,
|
||||||
@ -227,7 +228,7 @@ def sign_message(
|
|||||||
no_script_type: bool = False,
|
no_script_type: bool = False,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> messages.MessageSignature:
|
) -> messages.MessageSignature:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.SignMessage(
|
messages.SignMessage(
|
||||||
coin_name=coin_name,
|
coin_name=coin_name,
|
||||||
address_n=n,
|
address_n=n,
|
||||||
@ -241,7 +242,7 @@ def sign_message(
|
|||||||
|
|
||||||
|
|
||||||
def verify_message(
|
def verify_message(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin_name: str,
|
coin_name: str,
|
||||||
address: str,
|
address: str,
|
||||||
signature: bytes,
|
signature: bytes,
|
||||||
@ -249,7 +250,7 @@ def verify_message(
|
|||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
try:
|
try:
|
||||||
client.call(
|
session.call(
|
||||||
messages.VerifyMessage(
|
messages.VerifyMessage(
|
||||||
address=address,
|
address=address,
|
||||||
signature=signature,
|
signature=signature,
|
||||||
@ -264,9 +265,9 @@ def verify_message(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@session
|
# @session
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin_name: str,
|
coin_name: str,
|
||||||
inputs: Sequence[messages.TxInputType],
|
inputs: Sequence[messages.TxInputType],
|
||||||
outputs: Sequence[messages.TxOutputType],
|
outputs: Sequence[messages.TxOutputType],
|
||||||
@ -314,14 +315,14 @@ def sign_tx(
|
|||||||
setattr(signtx, name, value)
|
setattr(signtx, name, value)
|
||||||
|
|
||||||
if unlock_path:
|
if unlock_path:
|
||||||
client.call(
|
session.call(
|
||||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
|
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
|
||||||
expect=messages.UnlockedPathRequest,
|
expect=messages.UnlockedPathRequest,
|
||||||
)
|
)
|
||||||
elif preauthorized:
|
elif preauthorized:
|
||||||
client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
|
session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
|
||||||
|
|
||||||
res = client.call(signtx, expect=messages.TxRequest)
|
res = session.call(signtx, expect=messages.TxRequest)
|
||||||
|
|
||||||
# Prepare structure for signatures
|
# Prepare structure for signatures
|
||||||
signatures: List[Optional[bytes]] = [None] * len(inputs)
|
signatures: List[Optional[bytes]] = [None] * len(inputs)
|
||||||
@ -380,7 +381,7 @@ def sign_tx(
|
|||||||
if res.request_type == R.TXPAYMENTREQ:
|
if res.request_type == R.TXPAYMENTREQ:
|
||||||
assert res.details.request_index is not None
|
assert res.details.request_index is not None
|
||||||
msg = payment_reqs[res.details.request_index]
|
msg = payment_reqs[res.details.request_index]
|
||||||
res = client.call(msg, expect=messages.TxRequest)
|
res = session.call(msg, expect=messages.TxRequest)
|
||||||
else:
|
else:
|
||||||
msg = messages.TransactionType()
|
msg = messages.TransactionType()
|
||||||
if res.request_type == R.TXMETA:
|
if res.request_type == R.TXMETA:
|
||||||
@ -410,7 +411,7 @@ def sign_tx(
|
|||||||
f"Unknown request type - {res.request_type}."
|
f"Unknown request type - {res.request_type}."
|
||||||
)
|
)
|
||||||
|
|
||||||
res = client.call(messages.TxAck(tx=msg), expect=messages.TxRequest)
|
res = session.call(messages.TxAck(tx=msg), expect=messages.TxRequest)
|
||||||
|
|
||||||
for i, sig in zip(inputs, signatures):
|
for i, sig in zip(inputs, signatures):
|
||||||
if i.script_type != messages.InputScriptType.EXTERNAL and sig is None:
|
if i.script_type != messages.InputScriptType.EXTERNAL and sig is None:
|
||||||
@ -420,7 +421,7 @@ def sign_tx(
|
|||||||
|
|
||||||
|
|
||||||
def authorize_coinjoin(
|
def authorize_coinjoin(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coordinator: str,
|
coordinator: str,
|
||||||
max_rounds: int,
|
max_rounds: int,
|
||||||
max_coordinator_fee_rate: int,
|
max_coordinator_fee_rate: int,
|
||||||
@ -429,7 +430,7 @@ def authorize_coinjoin(
|
|||||||
coin_name: str,
|
coin_name: str,
|
||||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
resp = client.call(
|
resp = session.call(
|
||||||
messages.AuthorizeCoinJoin(
|
messages.AuthorizeCoinJoin(
|
||||||
coordinator=coordinator,
|
coordinator=coordinator,
|
||||||
max_rounds=max_rounds,
|
max_rounds=max_rounds,
|
||||||
|
@ -35,7 +35,7 @@ from . import messages as m
|
|||||||
from . import tools
|
from . import tools
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
from .transport.session import Session
|
||||||
|
|
||||||
PROTOCOL_MAGICS = {
|
PROTOCOL_MAGICS = {
|
||||||
"mainnet": 764824073,
|
"mainnet": 764824073,
|
||||||
@ -818,7 +818,7 @@ def _get_collateral_inputs_items(
|
|||||||
|
|
||||||
|
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_parameters: m.CardanoAddressParametersType,
|
address_parameters: m.CardanoAddressParametersType,
|
||||||
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
|
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
|
||||||
network_id: int = NETWORK_IDS["mainnet"],
|
network_id: int = NETWORK_IDS["mainnet"],
|
||||||
@ -826,7 +826,7 @@ def get_address(
|
|||||||
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
|
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
return client.call(
|
return session.call(
|
||||||
m.CardanoGetAddress(
|
m.CardanoGetAddress(
|
||||||
address_parameters=address_parameters,
|
address_parameters=address_parameters,
|
||||||
protocol_magic=protocol_magic,
|
protocol_magic=protocol_magic,
|
||||||
@ -840,12 +840,12 @@ def get_address(
|
|||||||
|
|
||||||
|
|
||||||
def get_public_key(
|
def get_public_key(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: List[int],
|
address_n: List[int],
|
||||||
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
|
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
) -> m.CardanoPublicKey:
|
) -> m.CardanoPublicKey:
|
||||||
return client.call(
|
return session.call(
|
||||||
m.CardanoGetPublicKey(
|
m.CardanoGetPublicKey(
|
||||||
address_n=address_n,
|
address_n=address_n,
|
||||||
derivation_type=derivation_type,
|
derivation_type=derivation_type,
|
||||||
@ -856,12 +856,12 @@ def get_public_key(
|
|||||||
|
|
||||||
|
|
||||||
def get_native_script_hash(
|
def get_native_script_hash(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
native_script: m.CardanoNativeScript,
|
native_script: m.CardanoNativeScript,
|
||||||
display_format: m.CardanoNativeScriptHashDisplayFormat = m.CardanoNativeScriptHashDisplayFormat.HIDE,
|
display_format: m.CardanoNativeScriptHashDisplayFormat = m.CardanoNativeScriptHashDisplayFormat.HIDE,
|
||||||
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
|
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
|
||||||
) -> m.CardanoNativeScriptHash:
|
) -> m.CardanoNativeScriptHash:
|
||||||
return client.call(
|
return session.call(
|
||||||
m.CardanoGetNativeScriptHash(
|
m.CardanoGetNativeScriptHash(
|
||||||
script=native_script,
|
script=native_script,
|
||||||
display_format=display_format,
|
display_format=display_format,
|
||||||
@ -872,7 +872,7 @@ def get_native_script_hash(
|
|||||||
|
|
||||||
|
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
signing_mode: m.CardanoTxSigningMode,
|
signing_mode: m.CardanoTxSigningMode,
|
||||||
inputs: List[InputWithPath],
|
inputs: List[InputWithPath],
|
||||||
outputs: List[OutputWithData],
|
outputs: List[OutputWithData],
|
||||||
@ -907,7 +907,7 @@ def sign_tx(
|
|||||||
signing_mode,
|
signing_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.call(
|
response = session.call(
|
||||||
m.CardanoSignTxInit(
|
m.CardanoSignTxInit(
|
||||||
signing_mode=signing_mode,
|
signing_mode=signing_mode,
|
||||||
inputs_count=len(inputs),
|
inputs_count=len(inputs),
|
||||||
@ -942,12 +942,12 @@ def sign_tx(
|
|||||||
_get_certificates_items(certificates),
|
_get_certificates_items(certificates),
|
||||||
withdrawals,
|
withdrawals,
|
||||||
):
|
):
|
||||||
response = client.call(tx_item, expect=m.CardanoTxItemAck)
|
response = session.call(tx_item, expect=m.CardanoTxItemAck)
|
||||||
|
|
||||||
sign_tx_response: Dict[str, Any] = {}
|
sign_tx_response: Dict[str, Any] = {}
|
||||||
|
|
||||||
if auxiliary_data is not None:
|
if auxiliary_data is not None:
|
||||||
auxiliary_data_supplement = client.call(
|
auxiliary_data_supplement = session.call(
|
||||||
auxiliary_data, expect=m.CardanoTxAuxiliaryDataSupplement
|
auxiliary_data, expect=m.CardanoTxAuxiliaryDataSupplement
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
@ -958,25 +958,25 @@ def sign_tx(
|
|||||||
auxiliary_data_supplement.__dict__
|
auxiliary_data_supplement.__dict__
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck)
|
response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck)
|
||||||
|
|
||||||
for tx_item in chain(
|
for tx_item in chain(
|
||||||
_get_mint_items(mint),
|
_get_mint_items(mint),
|
||||||
_get_collateral_inputs_items(collateral_inputs),
|
_get_collateral_inputs_items(collateral_inputs),
|
||||||
required_signers,
|
required_signers,
|
||||||
):
|
):
|
||||||
response = client.call(tx_item, expect=m.CardanoTxItemAck)
|
response = session.call(tx_item, expect=m.CardanoTxItemAck)
|
||||||
|
|
||||||
if collateral_return is not None:
|
if collateral_return is not None:
|
||||||
for tx_item in _get_output_items(collateral_return):
|
for tx_item in _get_output_items(collateral_return):
|
||||||
response = client.call(tx_item, expect=m.CardanoTxItemAck)
|
response = session.call(tx_item, expect=m.CardanoTxItemAck)
|
||||||
|
|
||||||
for reference_input in reference_inputs:
|
for reference_input in reference_inputs:
|
||||||
response = client.call(reference_input, expect=m.CardanoTxItemAck)
|
response = session.call(reference_input, expect=m.CardanoTxItemAck)
|
||||||
|
|
||||||
sign_tx_response["witnesses"] = []
|
sign_tx_response["witnesses"] = []
|
||||||
for witness_request in witness_requests:
|
for witness_request in witness_requests:
|
||||||
response = client.call(witness_request, expect=m.CardanoTxWitnessResponse)
|
response = session.call(witness_request, expect=m.CardanoTxWitnessResponse)
|
||||||
sign_tx_response["witnesses"].append(
|
sign_tx_response["witnesses"].append(
|
||||||
{
|
{
|
||||||
"type": response.type,
|
"type": response.type,
|
||||||
@ -986,9 +986,9 @@ def sign_tx(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash)
|
response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash)
|
||||||
sign_tx_response["tx_hash"] = response.tx_hash
|
sign_tx_response["tx_hash"] = response.tx_hash
|
||||||
|
|
||||||
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished)
|
response = session.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished)
|
||||||
|
|
||||||
return sign_tx_response
|
return sign_tx_response
|
||||||
|
@ -13,28 +13,24 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
import typing as t
|
||||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
from enum import IntEnum
|
||||||
|
|
||||||
from mnemonic import Mnemonic
|
from . import mapping, messages, models
|
||||||
|
from .mapping import ProtobufMapping
|
||||||
|
from .tools import parse_path
|
||||||
|
from .transport import Transport, get_transport
|
||||||
|
from .transport.thp.channel_data import ChannelData
|
||||||
|
from .transport.thp.protocol_and_channel import ProtocolAndChannel
|
||||||
|
from .transport.thp.protocol_v1 import ProtocolV1
|
||||||
|
from .transport.thp.protocol_v2 import ProtocolV2
|
||||||
|
|
||||||
from . import exceptions, mapping, messages, models
|
if t.TYPE_CHECKING:
|
||||||
from .log import DUMP_BYTES
|
from .transport.session import Session
|
||||||
from .messages import Capability
|
|
||||||
from .protobuf import MessageType
|
|
||||||
from .tools import parse_path, session
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .transport import Transport
|
|
||||||
from .ui import TrezorClientUI
|
|
||||||
|
|
||||||
UI = TypeVar("UI", bound="TrezorClientUI")
|
|
||||||
MT = TypeVar("MT", bound=MessageType)
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -51,8 +47,205 @@ Or visit https://suite.trezor.io/
|
|||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolVersion(IntEnum):
|
||||||
|
UNKNOWN = 0x00
|
||||||
|
PROTOCOL_V1 = 0x01 # Codec
|
||||||
|
PROTOCOL_V2 = 0x02 # THP
|
||||||
|
|
||||||
|
|
||||||
|
class TrezorClient:
|
||||||
|
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||||
|
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||||
|
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||||
|
|
||||||
|
_seedless_session: Session | None = None
|
||||||
|
_features: messages.Features | None = None
|
||||||
|
_protocol_version: int
|
||||||
|
_setup_pin: str | None = None # Should by used only by conftest
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
transport: Transport,
|
||||||
|
protobuf_mapping: ProtobufMapping | None = None,
|
||||||
|
protocol: ProtocolAndChannel | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._is_invalidated: bool = False
|
||||||
|
self.transport = transport
|
||||||
|
|
||||||
|
if protobuf_mapping is None:
|
||||||
|
self.mapping = mapping.DEFAULT_MAPPING
|
||||||
|
else:
|
||||||
|
self.mapping = protobuf_mapping
|
||||||
|
if protocol is None:
|
||||||
|
self.protocol = self._get_protocol()
|
||||||
|
else:
|
||||||
|
self.protocol = protocol
|
||||||
|
self.protocol.mapping = self.mapping
|
||||||
|
|
||||||
|
if isinstance(self.protocol, ProtocolV1):
|
||||||
|
self._protocol_version = ProtocolVersion.PROTOCOL_V1
|
||||||
|
elif isinstance(self.protocol, ProtocolV2):
|
||||||
|
self._protocol_version = ProtocolVersion.PROTOCOL_V2
|
||||||
|
else:
|
||||||
|
self._protocol_version = ProtocolVersion.UNKNOWN
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def resume(
|
||||||
|
cls,
|
||||||
|
transport: Transport,
|
||||||
|
channel_data: ChannelData,
|
||||||
|
protobuf_mapping: ProtobufMapping | None = None,
|
||||||
|
) -> TrezorClient:
|
||||||
|
if protobuf_mapping is None:
|
||||||
|
protobuf_mapping = mapping.DEFAULT_MAPPING
|
||||||
|
protocol_v1 = ProtocolV1(transport, protobuf_mapping)
|
||||||
|
if channel_data.protocol_version_major == 2:
|
||||||
|
try:
|
||||||
|
protocol_v1.write(messages.Ping(message="Sanity check - to resume"))
|
||||||
|
except Exception as e:
|
||||||
|
print(type(e))
|
||||||
|
response = protocol_v1.read()
|
||||||
|
if (
|
||||||
|
isinstance(response, messages.Failure)
|
||||||
|
and response.code == messages.FailureType.InvalidProtocol
|
||||||
|
):
|
||||||
|
protocol = ProtocolV2(transport, protobuf_mapping, channel_data)
|
||||||
|
protocol.write(0, messages.Ping())
|
||||||
|
response = protocol.read(0)
|
||||||
|
if not isinstance(response, messages.Success):
|
||||||
|
LOG.debug("Failed to resume ProtocolV2")
|
||||||
|
raise Exception("Failed to resume ProtocolV2")
|
||||||
|
LOG.debug("Protocol V2 detected - can be resumed")
|
||||||
|
else:
|
||||||
|
LOG.debug("Failed to resume ProtocolV2")
|
||||||
|
raise Exception("Failed to resume ProtocolV2")
|
||||||
|
else:
|
||||||
|
protocol = ProtocolV1(transport, protobuf_mapping, channel_data)
|
||||||
|
return TrezorClient(transport, protobuf_mapping, protocol)
|
||||||
|
|
||||||
|
def get_session(
|
||||||
|
self,
|
||||||
|
passphrase: str | object | None = None,
|
||||||
|
derive_cardano: bool = False,
|
||||||
|
session_id: int = 0,
|
||||||
|
) -> Session:
|
||||||
|
"""
|
||||||
|
Returns initialized session (with derived seed).
|
||||||
|
|
||||||
|
Will fail if the device is not initialized
|
||||||
|
"""
|
||||||
|
from .transport.session import SessionV1, SessionV2
|
||||||
|
|
||||||
|
if isinstance(self.protocol, ProtocolV1):
|
||||||
|
if passphrase is None:
|
||||||
|
passphrase = ""
|
||||||
|
return SessionV1.new(self, passphrase, derive_cardano)
|
||||||
|
if isinstance(self.protocol, ProtocolV2):
|
||||||
|
assert isinstance(passphrase, str) or passphrase is None
|
||||||
|
return SessionV2.new(self, passphrase, derive_cardano, session_id)
|
||||||
|
raise NotImplementedError # TODO
|
||||||
|
|
||||||
|
def resume_session(self, session: Session):
|
||||||
|
"""
|
||||||
|
Note: this function potentially modifies the input session.
|
||||||
|
"""
|
||||||
|
from .debuglink import SessionDebugWrapper
|
||||||
|
from .transport.session import SessionV1, SessionV2
|
||||||
|
|
||||||
|
if isinstance(session, SessionDebugWrapper):
|
||||||
|
session = session._session
|
||||||
|
|
||||||
|
if isinstance(session, SessionV2):
|
||||||
|
return session
|
||||||
|
elif isinstance(session, SessionV1):
|
||||||
|
session.init_session()
|
||||||
|
return session
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_seedless_session(self, new_session: bool = False) -> Session:
|
||||||
|
from .transport.session import SessionV1, SessionV2
|
||||||
|
|
||||||
|
if not new_session and self._seedless_session is not None:
|
||||||
|
return self._seedless_session
|
||||||
|
if isinstance(self.protocol, ProtocolV1):
|
||||||
|
self._seedless_session = SessionV1.new(
|
||||||
|
client=self,
|
||||||
|
passphrase="",
|
||||||
|
derive_cardano=False,
|
||||||
|
)
|
||||||
|
elif isinstance(self.protocol, ProtocolV2):
|
||||||
|
self._seedless_session = SessionV2(client=self, id=b"\x00")
|
||||||
|
assert self._seedless_session is not None
|
||||||
|
return self._seedless_session
|
||||||
|
|
||||||
|
def invalidate(self) -> None:
|
||||||
|
self._is_invalidated = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def features(self) -> messages.Features:
|
||||||
|
if self._features is None:
|
||||||
|
self._features = self.protocol.get_features()
|
||||||
|
assert self._features is not None
|
||||||
|
return self._features
|
||||||
|
|
||||||
|
@property
|
||||||
|
def protocol_version(self) -> int:
|
||||||
|
return self._protocol_version
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self) -> models.TrezorModel:
|
||||||
|
f = self.features
|
||||||
|
model = models.by_name(f.model or "1")
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Unsupported Trezor model"
|
||||||
|
f" (internal_model: {f.internal_model}, model: {f.model})"
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
@property
|
||||||
|
def version(self) -> tuple[int, int, int]:
|
||||||
|
f = self.features
|
||||||
|
ver = (
|
||||||
|
f.major_version,
|
||||||
|
f.minor_version,
|
||||||
|
f.patch_version,
|
||||||
|
)
|
||||||
|
return ver
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_invalidated(self) -> bool:
|
||||||
|
return self._is_invalidated
|
||||||
|
|
||||||
|
def refresh_features(self) -> None:
|
||||||
|
self.protocol.update_features()
|
||||||
|
self._features = self.protocol.get_features()
|
||||||
|
|
||||||
|
def _get_protocol(self) -> ProtocolAndChannel:
|
||||||
|
self.transport.open()
|
||||||
|
|
||||||
|
protocol = ProtocolV1(self.transport, mapping.DEFAULT_MAPPING)
|
||||||
|
|
||||||
|
protocol.write(messages.Initialize())
|
||||||
|
|
||||||
|
response = protocol.read()
|
||||||
|
self.transport.close()
|
||||||
|
if isinstance(response, messages.Failure):
|
||||||
|
if response.code == messages.FailureType.InvalidProtocol:
|
||||||
|
LOG.debug("Protocol V2 detected")
|
||||||
|
protocol = ProtocolV2(self.transport, self.mapping)
|
||||||
|
return protocol
|
||||||
|
|
||||||
|
|
||||||
def get_default_client(
|
def get_default_client(
|
||||||
path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any
|
path: t.Optional[str] = None,
|
||||||
|
**kwargs: t.Any,
|
||||||
) -> "TrezorClient":
|
) -> "TrezorClient":
|
||||||
"""Get a client for a connected Trezor device.
|
"""Get a client for a connected Trezor device.
|
||||||
|
|
||||||
@ -62,436 +255,10 @@ def get_default_client(
|
|||||||
the value of TREZOR_PATH env variable, or finds first connected Trezor.
|
the value of TREZOR_PATH env variable, or finds first connected Trezor.
|
||||||
If no UI is supplied, instantiates the default CLI UI.
|
If no UI is supplied, instantiates the default CLI UI.
|
||||||
"""
|
"""
|
||||||
from .transport import get_transport
|
|
||||||
from .ui import ClickUI
|
|
||||||
|
|
||||||
if path is None:
|
if path is None:
|
||||||
path = os.getenv("TREZOR_PATH")
|
path = os.getenv("TREZOR_PATH")
|
||||||
|
|
||||||
transport = get_transport(path, prefix_search=True)
|
transport = get_transport(path, prefix_search=True)
|
||||||
if ui is None:
|
|
||||||
ui = ClickUI()
|
|
||||||
|
|
||||||
return TrezorClient(transport, ui, **kwargs)
|
return TrezorClient(transport, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class TrezorClient(Generic[UI]):
|
|
||||||
"""Trezor client, a connection to a Trezor device.
|
|
||||||
|
|
||||||
This class allows you to manage connection state, send and receive protobuf
|
|
||||||
messages, handle user interactions, and perform some generic tasks
|
|
||||||
(send a cancel message, initialize or clear a session, ping the device).
|
|
||||||
"""
|
|
||||||
|
|
||||||
model: models.TrezorModel
|
|
||||||
transport: "Transport"
|
|
||||||
session_id: Optional[bytes]
|
|
||||||
ui: UI
|
|
||||||
features: messages.Features
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
transport: "Transport",
|
|
||||||
ui: UI,
|
|
||||||
session_id: Optional[bytes] = None,
|
|
||||||
derive_cardano: Optional[bool] = None,
|
|
||||||
model: Optional[models.TrezorModel] = None,
|
|
||||||
_init_device: bool = True,
|
|
||||||
) -> None:
|
|
||||||
"""Create a TrezorClient instance.
|
|
||||||
|
|
||||||
You have to provide a `transport`, i.e., a raw connection to the device. You can
|
|
||||||
use `trezorlib.transport.get_transport` to find one.
|
|
||||||
|
|
||||||
You have to provide a UI implementation for the three kinds of interaction:
|
|
||||||
- button request (notify the user that their interaction is needed)
|
|
||||||
- PIN request (on T1, ask the user to input numbers for a PIN matrix)
|
|
||||||
- passphrase request (ask the user to enter a passphrase) See `trezorlib.ui` for
|
|
||||||
details.
|
|
||||||
|
|
||||||
You can supply a `session_id` you might have saved in the previous session. If
|
|
||||||
you do, the user might not need to enter their passphrase again.
|
|
||||||
|
|
||||||
You can provide Trezor model information. If not provided, it is detected from
|
|
||||||
the model name reported at initialization time.
|
|
||||||
|
|
||||||
By default, the instance will open a connection to the Trezor device, send an
|
|
||||||
`Initialize` message, set up the `features` field from the response, and connect
|
|
||||||
to a session. By specifying `_init_device=False`, this step is skipped. Notably,
|
|
||||||
this means that `client.features` is unset. Use `client.init_device()` or
|
|
||||||
`client.refresh_features()` to fix that, otherwise A LOT OF THINGS will break.
|
|
||||||
Only use this if you are _sure_ that you know what you are doing. This feature
|
|
||||||
might be removed at any time.
|
|
||||||
"""
|
|
||||||
LOG.info(f"creating client instance for device: {transport.get_path()}")
|
|
||||||
# Here, self.model could be set to None. Unless _init_device is False, it will
|
|
||||||
# get correctly reconfigured as part of the init_device flow.
|
|
||||||
self.model = model # type: ignore ["None" is incompatible with "TrezorModel"]
|
|
||||||
if self.model:
|
|
||||||
self.mapping = self.model.default_mapping
|
|
||||||
else:
|
|
||||||
self.mapping = mapping.DEFAULT_MAPPING
|
|
||||||
self.transport = transport
|
|
||||||
self.ui = ui
|
|
||||||
self.session_counter = 0
|
|
||||||
self.session_id = session_id
|
|
||||||
if _init_device:
|
|
||||||
self.init_device(session_id=session_id, derive_cardano=derive_cardano)
|
|
||||||
|
|
||||||
def open(self) -> None:
|
|
||||||
if self.session_counter == 0:
|
|
||||||
self.transport.begin_session()
|
|
||||||
self.session_counter += 1
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
self.session_counter = max(self.session_counter - 1, 0)
|
|
||||||
if self.session_counter == 0:
|
|
||||||
# TODO call EndSession here?
|
|
||||||
self.transport.end_session()
|
|
||||||
|
|
||||||
def cancel(self) -> None:
|
|
||||||
self._raw_write(messages.Cancel())
|
|
||||||
|
|
||||||
def call_raw(self, msg: MessageType) -> MessageType:
|
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
|
||||||
self._raw_write(msg)
|
|
||||||
return self._raw_read()
|
|
||||||
|
|
||||||
def _raw_write(self, msg: MessageType) -> None:
|
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
|
||||||
LOG.debug(
|
|
||||||
f"sending message: {msg.__class__.__name__}",
|
|
||||||
extra={"protobuf": msg},
|
|
||||||
)
|
|
||||||
msg_type, msg_bytes = self.mapping.encode(msg)
|
|
||||||
LOG.log(
|
|
||||||
DUMP_BYTES,
|
|
||||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
|
||||||
)
|
|
||||||
self.transport.write(msg_type, msg_bytes)
|
|
||||||
|
|
||||||
def _raw_read(self) -> MessageType:
|
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
|
||||||
msg_type, msg_bytes = self.transport.read()
|
|
||||||
LOG.log(
|
|
||||||
DUMP_BYTES,
|
|
||||||
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
|
||||||
)
|
|
||||||
msg = self.mapping.decode(msg_type, msg_bytes)
|
|
||||||
LOG.debug(
|
|
||||||
f"received message: {msg.__class__.__name__}",
|
|
||||||
extra={"protobuf": msg},
|
|
||||||
)
|
|
||||||
return msg
|
|
||||||
|
|
||||||
def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType:
|
|
||||||
try:
|
|
||||||
pin = self.ui.get_pin(msg.type)
|
|
||||||
except exceptions.Cancelled:
|
|
||||||
self.call_raw(messages.Cancel())
|
|
||||||
raise
|
|
||||||
|
|
||||||
if any(d not in "123456789" for d in pin) or not (
|
|
||||||
1 <= len(pin) <= MAX_PIN_LENGTH
|
|
||||||
):
|
|
||||||
self.call_raw(messages.Cancel())
|
|
||||||
raise ValueError("Invalid PIN provided")
|
|
||||||
|
|
||||||
resp = self.call_raw(messages.PinMatrixAck(pin=pin))
|
|
||||||
if isinstance(resp, messages.Failure) and resp.code in (
|
|
||||||
messages.FailureType.PinInvalid,
|
|
||||||
messages.FailureType.PinCancelled,
|
|
||||||
messages.FailureType.PinExpected,
|
|
||||||
):
|
|
||||||
raise exceptions.PinException(resp.code, resp.message)
|
|
||||||
else:
|
|
||||||
return resp
|
|
||||||
|
|
||||||
def _callback_passphrase(self, msg: messages.PassphraseRequest) -> MessageType:
|
|
||||||
available_on_device = Capability.PassphraseEntry in self.features.capabilities
|
|
||||||
|
|
||||||
def send_passphrase(
|
|
||||||
passphrase: Optional[str] = None, on_device: Optional[bool] = None
|
|
||||||
) -> MessageType:
|
|
||||||
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
|
||||||
resp = self.call_raw(msg)
|
|
||||||
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
|
||||||
self.session_id = resp.state
|
|
||||||
resp = self.call_raw(messages.Deprecated_PassphraseStateAck())
|
|
||||||
return resp
|
|
||||||
|
|
||||||
# short-circuit old style entry
|
|
||||||
if msg._on_device is True:
|
|
||||||
return send_passphrase(None, None)
|
|
||||||
|
|
||||||
try:
|
|
||||||
passphrase = self.ui.get_passphrase(available_on_device=available_on_device)
|
|
||||||
except exceptions.Cancelled:
|
|
||||||
self.call_raw(messages.Cancel())
|
|
||||||
raise
|
|
||||||
|
|
||||||
if passphrase is PASSPHRASE_ON_DEVICE:
|
|
||||||
if not available_on_device:
|
|
||||||
self.call_raw(messages.Cancel())
|
|
||||||
raise RuntimeError("Device is not capable of entering passphrase")
|
|
||||||
else:
|
|
||||||
return send_passphrase(on_device=True)
|
|
||||||
|
|
||||||
# else process host-entered passphrase
|
|
||||||
if not isinstance(passphrase, str):
|
|
||||||
raise RuntimeError("Passphrase must be a str")
|
|
||||||
passphrase = Mnemonic.normalize_string(passphrase)
|
|
||||||
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
|
|
||||||
self.call_raw(messages.Cancel())
|
|
||||||
raise ValueError("Passphrase too long")
|
|
||||||
|
|
||||||
return send_passphrase(passphrase, on_device=False)
|
|
||||||
|
|
||||||
def _callback_button(self, msg: messages.ButtonRequest) -> MessageType:
|
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
|
||||||
# do this raw - send ButtonAck first, notify UI later
|
|
||||||
self._raw_write(messages.ButtonAck())
|
|
||||||
self.ui.button_request(msg)
|
|
||||||
return self._raw_read()
|
|
||||||
|
|
||||||
@session
|
|
||||||
def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT:
|
|
||||||
self.check_firmware_version()
|
|
||||||
resp = self.call_raw(msg)
|
|
||||||
while True:
|
|
||||||
if isinstance(resp, messages.PinMatrixRequest):
|
|
||||||
resp = self._callback_pin(resp)
|
|
||||||
elif isinstance(resp, messages.PassphraseRequest):
|
|
||||||
resp = self._callback_passphrase(resp)
|
|
||||||
elif isinstance(resp, messages.ButtonRequest):
|
|
||||||
resp = self._callback_button(resp)
|
|
||||||
elif isinstance(resp, messages.Failure):
|
|
||||||
if resp.code == messages.FailureType.ActionCancelled:
|
|
||||||
raise exceptions.Cancelled
|
|
||||||
raise exceptions.TrezorFailure(resp)
|
|
||||||
elif not isinstance(resp, expect):
|
|
||||||
raise exceptions.UnexpectedMessageError(expect, resp)
|
|
||||||
else:
|
|
||||||
return resp
|
|
||||||
|
|
||||||
def _refresh_features(self, features: messages.Features) -> None:
|
|
||||||
"""Update internal fields based on passed-in Features message."""
|
|
||||||
|
|
||||||
if not self.model:
|
|
||||||
# Trezor Model One bootloader 1.8.0 or older does not send model name
|
|
||||||
model = models.by_internal_name(features.internal_model)
|
|
||||||
if model is None:
|
|
||||||
model = models.by_name(features.model or "1")
|
|
||||||
if model is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Unsupported Trezor model"
|
|
||||||
f" (internal_model: {features.internal_model}, model: {features.model})"
|
|
||||||
)
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
if features.vendor not in self.model.vendors:
|
|
||||||
raise RuntimeError("Unsupported device")
|
|
||||||
|
|
||||||
self.features = features
|
|
||||||
self.version = (
|
|
||||||
self.features.major_version,
|
|
||||||
self.features.minor_version,
|
|
||||||
self.features.patch_version,
|
|
||||||
)
|
|
||||||
self.check_firmware_version(warn_only=True)
|
|
||||||
if self.features.session_id is not None:
|
|
||||||
self.session_id = self.features.session_id
|
|
||||||
self.features.session_id = None
|
|
||||||
|
|
||||||
@session
|
|
||||||
def refresh_features(self) -> messages.Features:
|
|
||||||
"""Reload features from the device.
|
|
||||||
|
|
||||||
Should be called after changing settings or performing operations that affect
|
|
||||||
device state.
|
|
||||||
"""
|
|
||||||
resp = self.call_raw(messages.GetFeatures())
|
|
||||||
if not isinstance(resp, messages.Features):
|
|
||||||
raise exceptions.TrezorException("Unexpected response to GetFeatures")
|
|
||||||
self._refresh_features(resp)
|
|
||||||
return resp
|
|
||||||
|
|
||||||
@session
|
|
||||||
def init_device(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
session_id: Optional[bytes] = None,
|
|
||||||
new_session: bool = False,
|
|
||||||
derive_cardano: Optional[bool] = None,
|
|
||||||
) -> Optional[bytes]:
|
|
||||||
"""Initialize the device and return a session ID.
|
|
||||||
|
|
||||||
You can optionally specify a session ID. If the session still exists on the
|
|
||||||
device, the same session ID will be returned and the session is resumed.
|
|
||||||
Otherwise a different session ID is returned.
|
|
||||||
|
|
||||||
Specify `new_session=True` to open a fresh session. Since firmware version
|
|
||||||
1.9.0/2.3.0, the previous session will remain cached on the device, and can be
|
|
||||||
resumed by calling `init_device` again with the appropriate session ID.
|
|
||||||
|
|
||||||
If neither `new_session` nor `session_id` is specified, the current session ID
|
|
||||||
will be reused. If no session ID was cached, a new session ID will be allocated
|
|
||||||
and returned.
|
|
||||||
|
|
||||||
# Version notes:
|
|
||||||
|
|
||||||
Trezor One older than 1.9.0 does not have session management. Optional arguments
|
|
||||||
have no effect and the function returns None
|
|
||||||
|
|
||||||
Trezor T older than 2.3.0 does not have session cache. Requesting a new session
|
|
||||||
will overwrite the old one. In addition, this function will always return None.
|
|
||||||
A valid session_id can be obtained from the `session_id` attribute, but only
|
|
||||||
after a passphrase-protected call is performed. You can use the following code:
|
|
||||||
|
|
||||||
>>> client.init_device()
|
|
||||||
>>> client.ensure_unlocked()
|
|
||||||
>>> valid_session_id = client.session_id
|
|
||||||
"""
|
|
||||||
if new_session:
|
|
||||||
self.session_id = None
|
|
||||||
elif session_id is not None:
|
|
||||||
self.session_id = session_id
|
|
||||||
|
|
||||||
resp = self.call_raw(
|
|
||||||
messages.Initialize(
|
|
||||||
session_id=self.session_id,
|
|
||||||
derive_cardano=derive_cardano,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if isinstance(resp, messages.Failure):
|
|
||||||
# can happen if `derive_cardano` does not match the current session
|
|
||||||
raise exceptions.TrezorFailure(resp)
|
|
||||||
if not isinstance(resp, messages.Features):
|
|
||||||
raise exceptions.TrezorException("Unexpected response to Initialize")
|
|
||||||
|
|
||||||
if self.session_id is not None and resp.session_id == self.session_id:
|
|
||||||
LOG.info("Successfully resumed session")
|
|
||||||
elif session_id is not None:
|
|
||||||
LOG.info("Failed to resume session")
|
|
||||||
|
|
||||||
# TT < 2.3.0 compatibility:
|
|
||||||
# _refresh_features will clear out the session_id field. We want this function
|
|
||||||
# to return its value, so that callers can rely on it being either a valid
|
|
||||||
# session_id, or None if we can't do that.
|
|
||||||
# Older TT FW does not report session_id in Features and self.session_id might
|
|
||||||
# be invalid because TT will not allocate a session_id until a passphrase
|
|
||||||
# exchange happens.
|
|
||||||
reported_session_id = resp.session_id
|
|
||||||
self._refresh_features(resp)
|
|
||||||
return reported_session_id
|
|
||||||
|
|
||||||
def is_outdated(self) -> bool:
|
|
||||||
if self.features.bootloader_mode:
|
|
||||||
return False
|
|
||||||
return self.version < self.model.minimum_version
|
|
||||||
|
|
||||||
def check_firmware_version(self, warn_only: bool = False) -> None:
|
|
||||||
if self.is_outdated():
|
|
||||||
if warn_only:
|
|
||||||
warnings.warn("Firmware is out of date", stacklevel=2)
|
|
||||||
else:
|
|
||||||
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
|
|
||||||
|
|
||||||
def ping(self, msg: str, button_protection: bool = False) -> str:
|
|
||||||
# We would like ping to work on any valid TrezorClient instance, but
|
|
||||||
# due to the protection modes, we need to go through self.call, and that will
|
|
||||||
# raise an exception if the firmware is too old.
|
|
||||||
# So we short-circuit the simplest variant of ping with call_raw.
|
|
||||||
if not button_protection:
|
|
||||||
# XXX this should be: `with self:`
|
|
||||||
try:
|
|
||||||
self.open()
|
|
||||||
resp = self.call_raw(messages.Ping(message=msg))
|
|
||||||
if isinstance(resp, messages.ButtonRequest):
|
|
||||||
# device is PIN-locked.
|
|
||||||
# respond and hope for the best
|
|
||||||
resp = self._callback_button(resp)
|
|
||||||
resp = messages.Success.ensure_isinstance(resp)
|
|
||||||
assert resp.message is not None
|
|
||||||
return resp.message
|
|
||||||
finally:
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
resp = self.call(
|
|
||||||
messages.Ping(message=msg, button_protection=button_protection),
|
|
||||||
expect=messages.Success,
|
|
||||||
)
|
|
||||||
assert resp.message is not None
|
|
||||||
return resp.message
|
|
||||||
|
|
||||||
def get_device_id(self) -> Optional[str]:
|
|
||||||
return self.features.device_id
|
|
||||||
|
|
||||||
@session
|
|
||||||
def lock(self, *, _refresh_features: bool = True) -> None:
|
|
||||||
"""Lock the device.
|
|
||||||
|
|
||||||
If the device does not have a PIN configured, this will do nothing.
|
|
||||||
Otherwise, a lock screen will be shown and the device will prompt for PIN
|
|
||||||
before further actions.
|
|
||||||
|
|
||||||
This call does _not_ invalidate passphrase cache. If passphrase is in use,
|
|
||||||
the device will not prompt for it after unlocking.
|
|
||||||
|
|
||||||
To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate
|
|
||||||
passphrase cache, use `clear_session()`.
|
|
||||||
"""
|
|
||||||
# Private argument _refresh_features can be used internally to avoid
|
|
||||||
# refreshing in cases where we will refresh soon anyway. This is used
|
|
||||||
# in TrezorClient.clear_session()
|
|
||||||
self.call(messages.LockDevice())
|
|
||||||
if _refresh_features:
|
|
||||||
self.refresh_features()
|
|
||||||
|
|
||||||
@session
|
|
||||||
def ensure_unlocked(self) -> None:
|
|
||||||
"""Ensure the device is unlocked and a passphrase is cached.
|
|
||||||
|
|
||||||
If the device is locked, this will prompt for PIN. If passphrase is enabled
|
|
||||||
and no passphrase is cached for the current session, the device will also
|
|
||||||
prompt for passphrase.
|
|
||||||
|
|
||||||
After calling this method, further actions on the device will not prompt for
|
|
||||||
PIN or passphrase until the device is locked or the session becomes invalid.
|
|
||||||
"""
|
|
||||||
from .btc import get_address
|
|
||||||
|
|
||||||
get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
|
|
||||||
self.refresh_features()
|
|
||||||
|
|
||||||
def end_session(self) -> None:
|
|
||||||
"""Close the current session and clear cached passphrase.
|
|
||||||
|
|
||||||
The session will become invalid until `init_device()` is called again.
|
|
||||||
If passphrase is enabled, further actions will prompt for it again.
|
|
||||||
|
|
||||||
This is a no-op in bootloader mode, as it does not support session management.
|
|
||||||
"""
|
|
||||||
# since: 2.3.4, 1.9.4
|
|
||||||
try:
|
|
||||||
if not self.features.bootloader_mode:
|
|
||||||
self.call(messages.EndSession())
|
|
||||||
except exceptions.TrezorFailure:
|
|
||||||
# A failure most likely means that the FW version does not support
|
|
||||||
# the EndSession call. We ignore the failure and clear the local session_id.
|
|
||||||
# The client-side end result is identical.
|
|
||||||
pass
|
|
||||||
self.session_id = None
|
|
||||||
|
|
||||||
@session
|
|
||||||
def clear_session(self) -> None:
|
|
||||||
"""Lock the device and present a fresh session.
|
|
||||||
|
|
||||||
The current session will be invalidated and a new one will be started. If the
|
|
||||||
device has PIN enabled, it will become locked.
|
|
||||||
|
|
||||||
Equivalent to calling `lock()`, `end_session()` and `init_device()`.
|
|
||||||
"""
|
|
||||||
self.lock(_refresh_features=False)
|
|
||||||
self.end_session()
|
|
||||||
self.init_device(new_session=True)
|
|
||||||
|
@ -21,55 +21,55 @@ import logging
|
|||||||
import re
|
import re
|
||||||
import textwrap
|
import textwrap
|
||||||
import time
|
import time
|
||||||
|
import typing as t
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum, IntEnum, auto
|
from enum import Enum, IntEnum, auto
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
Generator,
|
|
||||||
Iterable,
|
|
||||||
Iterator,
|
|
||||||
Sequence,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from mnemonic import Mnemonic
|
from mnemonic import Mnemonic
|
||||||
|
|
||||||
from . import mapping, messages, models, protobuf
|
from . import btc, mapping, messages, models, protobuf
|
||||||
from .client import TrezorClient
|
from .client import (
|
||||||
from .exceptions import TrezorFailure, UnexpectedMessageError
|
MAX_PASSPHRASE_LENGTH,
|
||||||
|
MAX_PIN_LENGTH,
|
||||||
|
PASSPHRASE_ON_DEVICE,
|
||||||
|
TrezorClient,
|
||||||
|
)
|
||||||
|
from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError
|
||||||
from .log import DUMP_BYTES
|
from .log import DUMP_BYTES
|
||||||
from .messages import DebugWaitType
|
from .messages import Capability, DebugWaitType
|
||||||
|
from .protobuf import MessageType
|
||||||
|
from .tools import parse_path
|
||||||
|
from .transport.session import Session, SessionV1
|
||||||
|
from .transport.thp.protocol_v1 import ProtocolV1
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
from .messages import PinMatrixRequestType
|
from .messages import PinMatrixRequestType
|
||||||
from .transport import Transport
|
from .transport import Transport
|
||||||
|
|
||||||
ExpectedMessage = Union[
|
ExpectedMessage = t.Union[
|
||||||
protobuf.MessageType, type[protobuf.MessageType], "MessageFilter"
|
protobuf.MessageType, t.Type[protobuf.MessageType], "MessageFilter"
|
||||||
]
|
]
|
||||||
|
|
||||||
AnyDict = Dict[str, Any]
|
AnyDict = t.Dict[str, t.Any]
|
||||||
|
|
||||||
class InputFunc(Protocol):
|
class InputFunc(Protocol):
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
hold_ms: int | None = None,
|
hold_ms: int | None = None,
|
||||||
) -> "None": ...
|
) -> "None": ...
|
||||||
|
|
||||||
InputFlowType = Generator[None, messages.ButtonRequest, None]
|
InputFlowType = t.Generator[None, messages.ButtonRequest, None]
|
||||||
|
|
||||||
|
|
||||||
EXPECTED_RESPONSES_CONTEXT_LINES = 3
|
EXPECTED_RESPONSES_CONTEXT_LINES = 3
|
||||||
|
PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0")
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -107,11 +107,11 @@ class UnstructuredJSONReader:
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
self.dict = {}
|
self.dict = {}
|
||||||
|
|
||||||
def top_level_value(self, key: str) -> Any:
|
def top_level_value(self, key: str) -> t.Any:
|
||||||
return self.dict.get(key)
|
return self.dict.get(key)
|
||||||
|
|
||||||
def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]:
|
def find_objects_with_key_and_value(self, key: str, value: t.Any) -> list[AnyDict]:
|
||||||
def recursively_find(data: Any) -> Iterator[Any]:
|
def recursively_find(data: t.Any) -> t.Iterator[t.Any]:
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
if data.get(key) == value:
|
if data.get(key) == value:
|
||||||
yield data
|
yield data
|
||||||
@ -124,7 +124,7 @@ class UnstructuredJSONReader:
|
|||||||
return list(recursively_find(self.dict))
|
return list(recursively_find(self.dict))
|
||||||
|
|
||||||
def find_unique_object_with_key_and_value(
|
def find_unique_object_with_key_and_value(
|
||||||
self, key: str, value: Any
|
self, key: str, value: t.Any
|
||||||
) -> AnyDict | None:
|
) -> AnyDict | None:
|
||||||
objects = self.find_objects_with_key_and_value(key, value)
|
objects = self.find_objects_with_key_and_value(key, value)
|
||||||
if not objects:
|
if not objects:
|
||||||
@ -132,8 +132,10 @@ class UnstructuredJSONReader:
|
|||||||
assert len(objects) == 1
|
assert len(objects) == 1
|
||||||
return objects[0]
|
return objects[0]
|
||||||
|
|
||||||
def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]:
|
def find_values_by_key(
|
||||||
def recursively_find(data: Any) -> Iterator[Any]:
|
self, key: str, only_type: type | None = None
|
||||||
|
) -> list[t.Any]:
|
||||||
|
def recursively_find(data: t.Any) -> t.Iterator[t.Any]:
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
if key in data:
|
if key in data:
|
||||||
yield data[key]
|
yield data[key]
|
||||||
@ -151,8 +153,8 @@ class UnstructuredJSONReader:
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
def find_unique_value_by_key(
|
def find_unique_value_by_key(
|
||||||
self, key: str, default: Any, only_type: type | None = None
|
self, key: str, default: t.Any, only_type: type | None = None
|
||||||
) -> Any:
|
) -> t.Any:
|
||||||
values = self.find_values_by_key(key, only_type=only_type)
|
values = self.find_values_by_key(key, only_type=only_type)
|
||||||
if not values:
|
if not values:
|
||||||
return default
|
return default
|
||||||
@ -163,7 +165,7 @@ class UnstructuredJSONReader:
|
|||||||
class LayoutContent(UnstructuredJSONReader):
|
class LayoutContent(UnstructuredJSONReader):
|
||||||
"""Contains helper functions to extract specific parts of the layout."""
|
"""Contains helper functions to extract specific parts of the layout."""
|
||||||
|
|
||||||
def __init__(self, json_tokens: Sequence[str]) -> None:
|
def __init__(self, json_tokens: t.Sequence[str]) -> None:
|
||||||
json_str = "".join(json_tokens)
|
json_str = "".join(json_tokens)
|
||||||
super().__init__(json_str)
|
super().__init__(json_str)
|
||||||
|
|
||||||
@ -429,6 +431,7 @@ class DebugLink:
|
|||||||
self.allow_interactions = auto_interact
|
self.allow_interactions = auto_interact
|
||||||
self.mapping = mapping.DEFAULT_MAPPING
|
self.mapping = mapping.DEFAULT_MAPPING
|
||||||
|
|
||||||
|
self.protocol = ProtocolV1(self.transport, self.mapping)
|
||||||
# To be set by TrezorClientDebugLink (is not known during creation time)
|
# To be set by TrezorClientDebugLink (is not known during creation time)
|
||||||
self.model: models.TrezorModel | None = None
|
self.model: models.TrezorModel | None = None
|
||||||
self.version: tuple[int, int, int] = (0, 0, 0)
|
self.version: tuple[int, int, int] = (0, 0, 0)
|
||||||
@ -471,10 +474,16 @@ class DebugLink:
|
|||||||
return LayoutType.from_model(self.model)
|
return LayoutType.from_model(self.model)
|
||||||
|
|
||||||
def open(self) -> None:
|
def open(self) -> None:
|
||||||
self.transport.begin_session()
|
self.transport.open()
|
||||||
|
# raise NotImplementedError
|
||||||
|
# TODO is this needed?
|
||||||
|
# self.transport.deprecated_begin_session()
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
self.transport.end_session()
|
pass
|
||||||
|
# raise NotImplementedError
|
||||||
|
# TODO is this needed?
|
||||||
|
# self.transport.deprecated_end_session()
|
||||||
|
|
||||||
def _write(self, msg: protobuf.MessageType) -> None:
|
def _write(self, msg: protobuf.MessageType) -> None:
|
||||||
if self.waiting_for_layout_change:
|
if self.waiting_for_layout_change:
|
||||||
@ -491,15 +500,10 @@ class DebugLink:
|
|||||||
DUMP_BYTES,
|
DUMP_BYTES,
|
||||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||||
)
|
)
|
||||||
self.transport.write(msg_type, msg_bytes)
|
self.protocol.write(msg)
|
||||||
|
|
||||||
def _read(self) -> protobuf.MessageType:
|
def _read(self) -> protobuf.MessageType:
|
||||||
ret_type, ret_bytes = self.transport.read()
|
msg = self.protocol.read()
|
||||||
LOG.log(
|
|
||||||
DUMP_BYTES,
|
|
||||||
f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}",
|
|
||||||
)
|
|
||||||
msg = self.mapping.decode(ret_type, ret_bytes)
|
|
||||||
|
|
||||||
# Collapse tokens to make log use less lines.
|
# Collapse tokens to make log use less lines.
|
||||||
msg_for_log = msg
|
msg_for_log = msg
|
||||||
@ -513,7 +517,7 @@ class DebugLink:
|
|||||||
)
|
)
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def _call(self, msg: protobuf.MessageType) -> Any:
|
def _call(self, msg: protobuf.MessageType) -> t.Any:
|
||||||
self._write(msg)
|
self._write(msg)
|
||||||
return self._read()
|
return self._read()
|
||||||
|
|
||||||
@ -531,6 +535,25 @@ class DebugLink:
|
|||||||
raise TrezorFailure(result)
|
raise TrezorFailure(result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def pairing_info(
|
||||||
|
self,
|
||||||
|
thp_channel_id: bytes | None = None,
|
||||||
|
handshake_hash: bytes | None = None,
|
||||||
|
nfc_secret_host: bytes | None = None,
|
||||||
|
) -> messages.DebugLinkPairingInfo:
|
||||||
|
result = self._call(
|
||||||
|
messages.DebugLinkGetPairingInfo(
|
||||||
|
channel_id=thp_channel_id,
|
||||||
|
handshake_hash=handshake_hash,
|
||||||
|
nfc_secret_host=nfc_secret_host,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
while not isinstance(result, (messages.Failure, messages.DebugLinkPairingInfo)):
|
||||||
|
result = self._read()
|
||||||
|
if isinstance(result, messages.Failure):
|
||||||
|
raise TrezorFailure(result)
|
||||||
|
return result
|
||||||
|
|
||||||
def read_layout(self, wait: bool | None = None) -> LayoutContent:
|
def read_layout(self, wait: bool | None = None) -> LayoutContent:
|
||||||
"""
|
"""
|
||||||
Force waiting for the layout by setting `wait=True`. Force not waiting by
|
Force waiting for the layout by setting `wait=True`. Force not waiting by
|
||||||
@ -547,7 +570,7 @@ class DebugLink:
|
|||||||
|
|
||||||
def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent:
|
def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent:
|
||||||
# Next layout change will be caused by external event
|
# Next layout change will be caused by external event
|
||||||
# (e.g. device being auto-locked or as a result of device_handler.run(xxx))
|
# (e.g. device being auto-locked or as a result of device_handler.run_with_session(xxx))
|
||||||
# and not by our debug actions/decisions.
|
# and not by our debug actions/decisions.
|
||||||
# Resetting the debug state so we wait for the next layout change
|
# Resetting the debug state so we wait for the next layout change
|
||||||
# (and do not return the current state).
|
# (and do not return the current state).
|
||||||
@ -562,7 +585,7 @@ class DebugLink:
|
|||||||
return LayoutContent(obj.tokens)
|
return LayoutContent(obj.tokens)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def wait_for_layout_change(self) -> Iterator[None]:
|
def wait_for_layout_change(self) -> t.Iterator[None]:
|
||||||
# make sure some current layout is up by issuing a dummy GetState
|
# make sure some current layout is up by issuing a dummy GetState
|
||||||
self.state()
|
self.state()
|
||||||
|
|
||||||
@ -615,7 +638,7 @@ class DebugLink:
|
|||||||
|
|
||||||
return "".join([str(matrix.index(p) + 1) for p in pin])
|
return "".join([str(matrix.index(p) + 1) for p in pin])
|
||||||
|
|
||||||
def read_recovery_word(self) -> Tuple[str | None, int | None]:
|
def read_recovery_word(self) -> t.Tuple[str | None, int | None]:
|
||||||
state = self.state()
|
state = self.state()
|
||||||
return (state.recovery_fake_word, state.recovery_word_pos)
|
return (state.recovery_fake_word, state.recovery_word_pos)
|
||||||
|
|
||||||
@ -671,7 +694,7 @@ class DebugLink:
|
|||||||
"""Send text input to the device. See `_decision` for more details."""
|
"""Send text input to the device. See `_decision` for more details."""
|
||||||
self._decision(messages.DebugLinkDecision(input=word))
|
self._decision(messages.DebugLinkDecision(input=word))
|
||||||
|
|
||||||
def click(self, click: Tuple[int, int], hold_ms: int | None = None) -> None:
|
def click(self, click: t.Tuple[int, int], hold_ms: int | None = None) -> None:
|
||||||
"""Send a click to the device. See `_decision` for more details."""
|
"""Send a click to the device. See `_decision` for more details."""
|
||||||
x, y = click
|
x, y = click
|
||||||
self._decision(messages.DebugLinkDecision(x=x, y=y, hold_ms=hold_ms))
|
self._decision(messages.DebugLinkDecision(x=x, y=y, hold_ms=hold_ms))
|
||||||
@ -794,10 +817,10 @@ class DebugUI:
|
|||||||
self.clear()
|
self.clear()
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
self.pins: Iterator[str] | None = None
|
self.pins: t.Iterator[str] | None = None
|
||||||
self.passphrase = ""
|
self.passphrase = ""
|
||||||
self.input_flow: Union[
|
self.input_flow: t.Union[
|
||||||
Generator[None, messages.ButtonRequest, None], object, None
|
t.Generator[None, messages.ButtonRequest, None], object, None
|
||||||
] = None
|
] = None
|
||||||
|
|
||||||
def _default_input_flow(self, br: messages.ButtonRequest) -> None:
|
def _default_input_flow(self, br: messages.ButtonRequest) -> None:
|
||||||
@ -829,7 +852,7 @@ class DebugUI:
|
|||||||
raise AssertionError("input flow ended prematurely")
|
raise AssertionError("input flow ended prematurely")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
assert isinstance(self.input_flow, Generator)
|
assert isinstance(self.input_flow, t.Generator)
|
||||||
self.input_flow.send(br)
|
self.input_flow.send(br)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
self.input_flow = self.INPUT_FLOW_DONE
|
self.input_flow = self.INPUT_FLOW_DONE
|
||||||
@ -851,12 +874,15 @@ class DebugUI:
|
|||||||
|
|
||||||
|
|
||||||
class MessageFilter:
|
class MessageFilter:
|
||||||
def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None:
|
|
||||||
|
def __init__(
|
||||||
|
self, message_type: t.Type[protobuf.MessageType], **fields: t.Any
|
||||||
|
) -> None:
|
||||||
self.message_type = message_type
|
self.message_type = message_type
|
||||||
self.fields: Dict[str, Any] = {}
|
self.fields: t.Dict[str, t.Any] = {}
|
||||||
self.update_fields(**fields)
|
self.update_fields(**fields)
|
||||||
|
|
||||||
def update_fields(self, **fields: Any) -> "MessageFilter":
|
def update_fields(self, **fields: t.Any) -> "MessageFilter":
|
||||||
for name, value in fields.items():
|
for name, value in fields.items():
|
||||||
try:
|
try:
|
||||||
self.fields[name] = self.from_message_or_type(value)
|
self.fields[name] = self.from_message_or_type(value)
|
||||||
@ -904,7 +930,7 @@ class MessageFilter:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def to_string(self, maxwidth: int = 80) -> str:
|
def to_string(self, maxwidth: int = 80) -> str:
|
||||||
fields: list[Tuple[str, str]] = []
|
fields: list[t.Tuple[str, str]] = []
|
||||||
for field in self.message_type.FIELDS.values():
|
for field in self.message_type.FIELDS.values():
|
||||||
if field.name not in self.fields:
|
if field.name not in self.fields:
|
||||||
continue
|
continue
|
||||||
@ -934,7 +960,7 @@ class MessageFilter:
|
|||||||
|
|
||||||
|
|
||||||
class MessageFilterGenerator:
|
class MessageFilterGenerator:
|
||||||
def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]:
|
def __getattr__(self, key: str) -> t.Callable[..., "MessageFilter"]:
|
||||||
message_type = getattr(messages, key)
|
message_type = getattr(messages, key)
|
||||||
return MessageFilter(message_type).update_fields
|
return MessageFilter(message_type).update_fields
|
||||||
|
|
||||||
@ -942,6 +968,245 @@ class MessageFilterGenerator:
|
|||||||
message_filters = MessageFilterGenerator()
|
message_filters = MessageFilterGenerator()
|
||||||
|
|
||||||
|
|
||||||
|
class SessionDebugWrapper(Session):
|
||||||
|
def __init__(self, session: Session) -> None:
|
||||||
|
self._session = session
|
||||||
|
self.reset_debug_features()
|
||||||
|
if isinstance(session, SessionDebugWrapper):
|
||||||
|
raise Exception("Cannot wrap already wrapped session!")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def protocol_version(self) -> int:
|
||||||
|
return self.client.protocol_version
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self) -> TrezorClientDebugLink:
|
||||||
|
assert isinstance(self._session.client, TrezorClientDebugLink)
|
||||||
|
return self._session.client
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self) -> bytes:
|
||||||
|
return self._session.id
|
||||||
|
|
||||||
|
def _write(self, msg: t.Any) -> None:
|
||||||
|
print("writing message:", msg.__class__.__name__)
|
||||||
|
self._session._write(self._filter_message(msg))
|
||||||
|
|
||||||
|
def _read(self) -> t.Any:
|
||||||
|
resp = self._filter_message(self._session._read())
|
||||||
|
print("reading message:", resp.__class__.__name__)
|
||||||
|
if self.actual_responses is not None:
|
||||||
|
self.actual_responses.append(resp)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
def set_expected_responses(
|
||||||
|
self,
|
||||||
|
expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]],
|
||||||
|
) -> None:
|
||||||
|
"""Set a sequence of expected responses to session calls.
|
||||||
|
|
||||||
|
Within a given with-block, the list of received responses from device must
|
||||||
|
match the list of expected responses, otherwise an ``AssertionError`` is raised.
|
||||||
|
|
||||||
|
If an expected response is given a field value other than ``None``, that field value
|
||||||
|
must exactly match the received field value. If a given field is ``None``
|
||||||
|
(or unspecified) in the expected response, the received field value is not
|
||||||
|
checked.
|
||||||
|
|
||||||
|
Each expected response can also be a tuple ``(bool, message)``. In that case, the
|
||||||
|
expected response is only evaluated if the first field is ``True``.
|
||||||
|
This is useful for differentiating sequences between Trezor models:
|
||||||
|
|
||||||
|
>>> trezor_one = session.features.model == "1"
|
||||||
|
>>> session.set_expected_responses([
|
||||||
|
>>> messages.ButtonRequest(code=ConfirmOutput),
|
||||||
|
>>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)),
|
||||||
|
>>> messages.Success(),
|
||||||
|
>>> ])
|
||||||
|
"""
|
||||||
|
if not self.in_with_statement:
|
||||||
|
raise RuntimeError("Must be called inside 'with' statement")
|
||||||
|
|
||||||
|
# make sure all items are (bool, message) tuples
|
||||||
|
expected_with_validity = (
|
||||||
|
e if isinstance(e, tuple) else (True, e) for e in expected
|
||||||
|
)
|
||||||
|
|
||||||
|
# only apply those items that are (True, message)
|
||||||
|
self.expected_responses = [
|
||||||
|
MessageFilter.from_message_or_type(expected)
|
||||||
|
for valid, expected in expected_with_validity
|
||||||
|
if valid
|
||||||
|
]
|
||||||
|
self.actual_responses = []
|
||||||
|
|
||||||
|
def lock(self, *, _refresh_features: bool = True) -> None:
|
||||||
|
"""Lock the device.
|
||||||
|
|
||||||
|
If the device does not have a PIN configured, this will do nothing.
|
||||||
|
Otherwise, a lock screen will be shown and the device will prompt for PIN
|
||||||
|
before further actions.
|
||||||
|
|
||||||
|
This call does _not_ invalidate passphrase cache. If passphrase is in use,
|
||||||
|
the device will not prompt for it after unlocking.
|
||||||
|
|
||||||
|
To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate
|
||||||
|
passphrase cache, use `clear_session()`.
|
||||||
|
"""
|
||||||
|
# TODO update the documentation above
|
||||||
|
# Private argument _refresh_features can be used internally to avoid
|
||||||
|
# refreshing in cases where we will refresh soon anyway. This is used
|
||||||
|
# in TrezorClient.clear_session()
|
||||||
|
self.call(messages.LockDevice())
|
||||||
|
if _refresh_features:
|
||||||
|
self.refresh_features()
|
||||||
|
|
||||||
|
def cancel(self) -> None:
|
||||||
|
self._write(messages.Cancel())
|
||||||
|
|
||||||
|
def ensure_unlocked(self) -> None:
|
||||||
|
btc.get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
|
||||||
|
self.refresh_features()
|
||||||
|
|
||||||
|
def set_filter(
|
||||||
|
self,
|
||||||
|
message_type: t.Type[protobuf.MessageType],
|
||||||
|
callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||||
|
) -> None:
|
||||||
|
"""Configure a filter function for a specified message type.
|
||||||
|
|
||||||
|
The `callback` must be a function that accepts a protobuf message, and returns
|
||||||
|
a (possibly modified) protobuf message of the same type. Whenever a message
|
||||||
|
is sent or received that matches `message_type`, `callback` is invoked on the
|
||||||
|
message and its result is substituted for the original.
|
||||||
|
|
||||||
|
Useful for test scenarios with an active malicious actor on the wire.
|
||||||
|
"""
|
||||||
|
if not self.in_with_statement:
|
||||||
|
raise RuntimeError("Must be called inside 'with' statement")
|
||||||
|
|
||||||
|
self.filters[message_type] = callback
|
||||||
|
|
||||||
|
def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
|
||||||
|
message_type = msg.__class__
|
||||||
|
callback = self.filters.get(message_type)
|
||||||
|
if callable(callback):
|
||||||
|
return callback(deepcopy(msg))
|
||||||
|
else:
|
||||||
|
return msg
|
||||||
|
|
||||||
|
def reset_debug_features(self) -> None:
|
||||||
|
"""Prepare the debugging session for a new testcase.
|
||||||
|
|
||||||
|
Clears all debugging state that might have been modified by a testcase.
|
||||||
|
"""
|
||||||
|
self.in_with_statement = False
|
||||||
|
self.expected_responses: list[MessageFilter] | None = None
|
||||||
|
self.actual_responses: list[protobuf.MessageType] | None = None
|
||||||
|
self.filters: t.Dict[
|
||||||
|
t.Type[protobuf.MessageType],
|
||||||
|
t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||||
|
] = {}
|
||||||
|
self.button_callback = self.client.button_callback
|
||||||
|
self.pin_callback = self.client.pin_callback
|
||||||
|
self.passphrase_callback = self._session.passphrase_callback
|
||||||
|
self.passphrase = self._session.passphrase
|
||||||
|
|
||||||
|
def __enter__(self) -> "SessionDebugWrapper":
|
||||||
|
# For usage in with/expected_responses
|
||||||
|
if self.in_with_statement:
|
||||||
|
raise RuntimeError("Do not nest!")
|
||||||
|
self.in_with_statement = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
|
||||||
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
|
||||||
|
# copy expected/actual responses before clearing them
|
||||||
|
expected_responses = self.expected_responses
|
||||||
|
actual_responses = self.actual_responses
|
||||||
|
|
||||||
|
# grab a copy of the inputflow generator to raise an exception through it
|
||||||
|
if isinstance(self.client.ui, DebugUI):
|
||||||
|
input_flow = self.client.ui.input_flow
|
||||||
|
else:
|
||||||
|
input_flow = None
|
||||||
|
|
||||||
|
self.reset_debug_features()
|
||||||
|
|
||||||
|
if exc_type is None:
|
||||||
|
# If no other exception was raised, evaluate missed responses
|
||||||
|
# (raises AssertionError on mismatch)
|
||||||
|
self._verify_responses(expected_responses, actual_responses)
|
||||||
|
if isinstance(input_flow, t.Generator):
|
||||||
|
# Ensure that the input flow is exhausted
|
||||||
|
try:
|
||||||
|
input_flow.throw(
|
||||||
|
AssertionError("input flow continues past end of test")
|
||||||
|
)
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
|
elif isinstance(input_flow, t.Generator):
|
||||||
|
# Propagate the exception through the input flow, so that we see in
|
||||||
|
# traceback where it is stuck.
|
||||||
|
input_flow.throw(exc_type, value, traceback)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _verify_responses(
|
||||||
|
cls,
|
||||||
|
expected: list[MessageFilter] | None,
|
||||||
|
actual: list[protobuf.MessageType] | None,
|
||||||
|
) -> None:
|
||||||
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
|
||||||
|
if expected is None and actual is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
assert expected is not None
|
||||||
|
assert actual is not None
|
||||||
|
|
||||||
|
for i, (exp, act) in enumerate(zip_longest(expected, actual)):
|
||||||
|
if exp is None:
|
||||||
|
output = cls._expectation_lines(expected, i)
|
||||||
|
output.append("No more messages were expected, but we got:")
|
||||||
|
for resp in actual[i:]:
|
||||||
|
output.append(
|
||||||
|
textwrap.indent(protobuf.format_message(resp), " ")
|
||||||
|
)
|
||||||
|
raise AssertionError("\n".join(output))
|
||||||
|
|
||||||
|
if act is None:
|
||||||
|
output = cls._expectation_lines(expected, i)
|
||||||
|
output.append("This and the following message was not received.")
|
||||||
|
raise AssertionError("\n".join(output))
|
||||||
|
|
||||||
|
if not exp.match(act):
|
||||||
|
output = cls._expectation_lines(expected, i)
|
||||||
|
output.append("Actually received:")
|
||||||
|
output.append(textwrap.indent(protobuf.format_message(act), " "))
|
||||||
|
raise AssertionError("\n".join(output))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
|
||||||
|
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
|
||||||
|
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
|
||||||
|
output: list[str] = []
|
||||||
|
output.append("Expected responses:")
|
||||||
|
if start_at > 0:
|
||||||
|
output.append(f" (...{start_at} previous responses omitted)")
|
||||||
|
for i in range(start_at, stop_at):
|
||||||
|
exp = expected[i]
|
||||||
|
prefix = " " if i != current else ">>> "
|
||||||
|
output.append(textwrap.indent(exp.to_string(), prefix))
|
||||||
|
if stop_at < len(expected):
|
||||||
|
omitted = len(expected) - stop_at
|
||||||
|
output.append(f" (...{omitted} following responses omitted)")
|
||||||
|
|
||||||
|
output.append("")
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class TrezorClientDebugLink(TrezorClient):
|
class TrezorClientDebugLink(TrezorClient):
|
||||||
# This class implements automatic responses
|
# This class implements automatic responses
|
||||||
# and other functionality for unit tests
|
# and other functionality for unit tests
|
||||||
@ -967,23 +1232,34 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
# set transport explicitly so that sync_responses can work
|
# set transport explicitly so that sync_responses can work
|
||||||
self.transport = transport
|
super().__init__(transport)
|
||||||
|
|
||||||
self.reset_debug_features()
|
self.transport = transport
|
||||||
|
self.ui: DebugUI = DebugUI(self.debug)
|
||||||
|
|
||||||
|
self.reset_debug_features(new_seedless_session=True)
|
||||||
self.sync_responses()
|
self.sync_responses()
|
||||||
super().__init__(transport, ui=self.ui)
|
|
||||||
|
|
||||||
# So that we can choose right screenshotting logic (T1 vs TT)
|
# So that we can choose right screenshotting logic (T1 vs TT)
|
||||||
# and know the supported debug capabilities
|
# and know the supported debug capabilities
|
||||||
self.debug.model = self.model
|
self.debug.model = self.model
|
||||||
self.debug.version = self.version
|
self.debug.version = self.version
|
||||||
|
self.passphrase: str | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layout_type(self) -> LayoutType:
|
def layout_type(self) -> LayoutType:
|
||||||
return self.debug.layout_type
|
return self.debug.layout_type
|
||||||
|
|
||||||
def reset_debug_features(self) -> None:
|
def get_new_client(self) -> TrezorClientDebugLink:
|
||||||
"""Prepare the debugging client for a new testcase.
|
new_client = TrezorClientDebugLink(
|
||||||
|
self.transport, self.debug.allow_interactions
|
||||||
|
)
|
||||||
|
new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir
|
||||||
|
return new_client
|
||||||
|
|
||||||
|
def reset_debug_features(self, new_seedless_session: bool = False) -> None:
|
||||||
|
"""
|
||||||
|
Prepare the debugging client for a new testcase.
|
||||||
|
|
||||||
Clears all debugging state that might have been modified by a testcase.
|
Clears all debugging state that might have been modified by a testcase.
|
||||||
"""
|
"""
|
||||||
@ -991,30 +1267,139 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
self.in_with_statement = False
|
self.in_with_statement = False
|
||||||
self.expected_responses: list[MessageFilter] | None = None
|
self.expected_responses: list[MessageFilter] | None = None
|
||||||
self.actual_responses: list[protobuf.MessageType] | None = None
|
self.actual_responses: list[protobuf.MessageType] | None = None
|
||||||
self.filters: dict[
|
self.filters: t.Dict[
|
||||||
type[protobuf.MessageType],
|
t.Type[protobuf.MessageType],
|
||||||
Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||||
] = {}
|
] = {}
|
||||||
|
if new_seedless_session:
|
||||||
|
self._seedless_session = self.get_seedless_session(new_session=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def button_callback(self):
|
||||||
|
|
||||||
|
def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any:
|
||||||
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
# do this raw - send ButtonAck first, notify UI later
|
||||||
|
session._write(messages.ButtonAck())
|
||||||
|
self.ui.button_request(msg)
|
||||||
|
return session._read()
|
||||||
|
|
||||||
|
return _callback_button
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pin_callback(self):
|
||||||
|
|
||||||
|
def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any:
|
||||||
|
try:
|
||||||
|
pin = self.ui.get_pin(msg.type)
|
||||||
|
except Cancelled:
|
||||||
|
session.call_raw(messages.Cancel())
|
||||||
|
raise
|
||||||
|
|
||||||
|
if any(d not in "123456789" for d in pin) or not (
|
||||||
|
1 <= len(pin) <= MAX_PIN_LENGTH
|
||||||
|
):
|
||||||
|
session.call_raw(messages.Cancel())
|
||||||
|
raise ValueError("Invalid PIN provided")
|
||||||
|
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
|
||||||
|
if isinstance(resp, messages.Failure) and resp.code in (
|
||||||
|
messages.FailureType.PinInvalid,
|
||||||
|
messages.FailureType.PinCancelled,
|
||||||
|
messages.FailureType.PinExpected,
|
||||||
|
):
|
||||||
|
raise PinException(resp.code, resp.message)
|
||||||
|
else:
|
||||||
|
return resp
|
||||||
|
|
||||||
|
return _callback_pin
|
||||||
|
|
||||||
|
@property
|
||||||
|
def passphrase_callback(self):
|
||||||
|
def _callback_passphrase(
|
||||||
|
session: Session, msg: messages.PassphraseRequest
|
||||||
|
) -> t.Any:
|
||||||
|
available_on_device = (
|
||||||
|
Capability.PassphraseEntry in session.features.capabilities
|
||||||
|
)
|
||||||
|
|
||||||
|
def send_passphrase(
|
||||||
|
passphrase: str | None = None, on_device: bool | None = None
|
||||||
|
) -> MessageType:
|
||||||
|
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
||||||
|
resp = session.call_raw(msg)
|
||||||
|
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
||||||
|
# session.session_id = resp.state
|
||||||
|
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
|
||||||
|
return resp
|
||||||
|
|
||||||
|
# short-circuit old style entry
|
||||||
|
if msg._on_device is True:
|
||||||
|
return send_passphrase(None, None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if session.passphrase is None and isinstance(session, SessionV1):
|
||||||
|
passphrase = self.ui.get_passphrase(
|
||||||
|
available_on_device=available_on_device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
passphrase = session.passphrase
|
||||||
|
except Cancelled:
|
||||||
|
session.call_raw(messages.Cancel())
|
||||||
|
raise
|
||||||
|
|
||||||
|
if passphrase is PASSPHRASE_ON_DEVICE:
|
||||||
|
if not available_on_device:
|
||||||
|
session.call_raw(messages.Cancel())
|
||||||
|
raise RuntimeError("Device is not capable of entering passphrase")
|
||||||
|
else:
|
||||||
|
return send_passphrase(on_device=True)
|
||||||
|
|
||||||
|
# else process host-entered passphrase
|
||||||
|
if not isinstance(passphrase, str):
|
||||||
|
raise RuntimeError("Passphrase must be a str")
|
||||||
|
passphrase = Mnemonic.normalize_string(passphrase)
|
||||||
|
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
|
||||||
|
session.call_raw(messages.Cancel())
|
||||||
|
raise ValueError("Passphrase too long")
|
||||||
|
|
||||||
|
return send_passphrase(passphrase, on_device=False)
|
||||||
|
|
||||||
|
return _callback_passphrase
|
||||||
|
|
||||||
def ensure_open(self) -> None:
|
def ensure_open(self) -> None:
|
||||||
"""Only open session if there isn't already an open one."""
|
"""Only open session if there isn't already an open one."""
|
||||||
if self.session_counter == 0:
|
# if self.session_counter == 0:
|
||||||
self.open()
|
# self.open()
|
||||||
|
# TODO check if is this needed
|
||||||
|
|
||||||
def open(self) -> None:
|
def open(self) -> None:
|
||||||
super().open()
|
pass
|
||||||
if self.session_counter == 1:
|
# TODO is this needed?
|
||||||
self.debug.open()
|
# self.debug.open()
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
if self.session_counter == 1:
|
pass
|
||||||
self.debug.close()
|
# TODO is this needed?
|
||||||
super().close()
|
# self.debug.close()
|
||||||
|
|
||||||
|
def lock(self) -> None:
|
||||||
|
s = SessionDebugWrapper(self.get_seedless_session())
|
||||||
|
s.lock()
|
||||||
|
|
||||||
|
def get_session(
|
||||||
|
self,
|
||||||
|
passphrase: str | object | None = "",
|
||||||
|
derive_cardano: bool = False,
|
||||||
|
session_id: int = 0,
|
||||||
|
) -> Session:
|
||||||
|
if isinstance(passphrase, str):
|
||||||
|
passphrase = Mnemonic.normalize_string(passphrase)
|
||||||
|
return super().get_session(passphrase, derive_cardano, session_id)
|
||||||
|
|
||||||
def set_filter(
|
def set_filter(
|
||||||
self,
|
self,
|
||||||
message_type: type[protobuf.MessageType],
|
message_type: t.Type[protobuf.MessageType],
|
||||||
callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Configure a filter function for a specified message type.
|
"""Configure a filter function for a specified message type.
|
||||||
|
|
||||||
@ -1039,7 +1424,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
def set_input_flow(
|
def set_input_flow(
|
||||||
self, input_flow: InputFlowType | Callable[[], InputFlowType]
|
self, input_flow: InputFlowType | t.Callable[[], InputFlowType]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Configure a sequence of input events for the current with-block.
|
"""Configure a sequence of input events for the current with-block.
|
||||||
|
|
||||||
@ -1095,7 +1480,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
self.in_with_statement = True
|
self.in_with_statement = True
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None:
|
def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
|
||||||
# copy expected/actual responses before clearing them
|
# copy expected/actual responses before clearing them
|
||||||
@ -1108,21 +1493,23 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
else:
|
else:
|
||||||
input_flow = None
|
input_flow = None
|
||||||
|
|
||||||
self.reset_debug_features()
|
self.reset_debug_features(new_seedless_session=False)
|
||||||
|
|
||||||
if exc_type is None:
|
if exc_type is None:
|
||||||
# If no other exception was raised, evaluate missed responses
|
# If no other exception was raised, evaluate missed responses
|
||||||
# (raises AssertionError on mismatch)
|
# (raises AssertionError on mismatch)
|
||||||
self._verify_responses(expected_responses, actual_responses)
|
self._verify_responses(expected_responses, actual_responses)
|
||||||
|
|
||||||
elif isinstance(input_flow, Generator):
|
elif isinstance(input_flow, t.Generator):
|
||||||
# Propagate the exception through the input flow, so that we see in
|
# Propagate the exception through the input flow, so that we see in
|
||||||
# traceback where it is stuck.
|
# traceback where it is stuck.
|
||||||
input_flow.throw(exc_type, value, traceback)
|
input_flow.throw(exc_type, value, traceback)
|
||||||
|
|
||||||
def set_expected_responses(
|
def set_expected_responses(
|
||||||
self,
|
self,
|
||||||
expected: Sequence[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]],
|
expected: t.Sequence[
|
||||||
|
t.Union["ExpectedMessage", t.Tuple[bool, "ExpectedMessage"]]
|
||||||
|
],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set a sequence of expected responses to client calls.
|
"""Set a sequence of expected responses to client calls.
|
||||||
|
|
||||||
@ -1161,7 +1548,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
]
|
]
|
||||||
self.actual_responses = []
|
self.actual_responses = []
|
||||||
|
|
||||||
def use_pin_sequence(self, pins: Iterable[str]) -> None:
|
def use_pin_sequence(self, pins: t.Iterable[str]) -> None:
|
||||||
"""Respond to PIN prompts from device with the provided PINs.
|
"""Respond to PIN prompts from device with the provided PINs.
|
||||||
The sequence must be at least as long as the expected number of PIN prompts.
|
The sequence must be at least as long as the expected number of PIN prompts.
|
||||||
"""
|
"""
|
||||||
@ -1169,6 +1556,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
def use_passphrase(self, passphrase: str) -> None:
|
def use_passphrase(self, passphrase: str) -> None:
|
||||||
"""Respond to passphrase prompts from device with the provided passphrase."""
|
"""Respond to passphrase prompts from device with the provided passphrase."""
|
||||||
|
self.passphrase = passphrase
|
||||||
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
|
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
|
||||||
|
|
||||||
def use_mnemonic(self, mnemonic: str) -> None:
|
def use_mnemonic(self, mnemonic: str) -> None:
|
||||||
@ -1178,15 +1566,14 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
def _raw_read(self) -> protobuf.MessageType:
|
def _raw_read(self) -> protobuf.MessageType:
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
resp = self.get_seedless_session()._read()
|
||||||
resp = super()._raw_read()
|
|
||||||
resp = self._filter_message(resp)
|
resp = self._filter_message(resp)
|
||||||
if self.actual_responses is not None:
|
if self.actual_responses is not None:
|
||||||
self.actual_responses.append(resp)
|
self.actual_responses.append(resp)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
def _raw_write(self, msg: protobuf.MessageType) -> None:
|
def _raw_write(self, msg: protobuf.MessageType) -> None:
|
||||||
return super()._raw_write(self._filter_message(msg))
|
return self.get_seedless_session()._write(self._filter_message(msg))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
|
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
|
||||||
@ -1256,23 +1643,25 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
# Start by canceling whatever is on screen. This will work to cancel T1 PIN
|
# Start by canceling whatever is on screen. This will work to cancel T1 PIN
|
||||||
# prompt, which is in TINY mode and does not respond to `Ping`.
|
# prompt, which is in TINY mode and does not respond to `Ping`.
|
||||||
cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel())
|
# TODO REMOVE: cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel())
|
||||||
self.transport.begin_session()
|
self.transport.open()
|
||||||
try:
|
try:
|
||||||
self.transport.write(*cancel_msg)
|
# self.protocol.write(messages.Cancel())
|
||||||
|
|
||||||
message = "SYNC" + secrets.token_hex(8)
|
message = "SYNC" + secrets.token_hex(8)
|
||||||
ping_msg = mapping.DEFAULT_MAPPING.encode(messages.Ping(message=message))
|
self.get_seedless_session()._write(messages.Ping(message=message))
|
||||||
self.transport.write(*ping_msg)
|
|
||||||
resp = None
|
resp = None
|
||||||
while resp != messages.Success(message=message):
|
while resp != messages.Success(message=message):
|
||||||
msg_id, msg_bytes = self.transport.read()
|
|
||||||
try:
|
try:
|
||||||
resp = mapping.DEFAULT_MAPPING.decode(msg_id, msg_bytes)
|
resp = self.get_seedless_session()._read()
|
||||||
|
|
||||||
|
raise Exception
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
self.transport.end_session()
|
pass # TODO fix
|
||||||
|
# self.transport.end_session(self.session_id or b"")
|
||||||
|
|
||||||
def mnemonic_callback(self, _) -> str:
|
def mnemonic_callback(self, _) -> str:
|
||||||
word, pos = self.debug.read_recovery_word()
|
word, pos = self.debug.read_recovery_word()
|
||||||
@ -1285,8 +1674,8 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
|
|
||||||
def load_device(
|
def load_device(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
mnemonic: Union[str, Iterable[str]],
|
mnemonic: str | t.Iterable[str],
|
||||||
pin: str | None,
|
pin: str | None,
|
||||||
passphrase_protection: bool,
|
passphrase_protection: bool,
|
||||||
label: str | None,
|
label: str | None,
|
||||||
@ -1299,12 +1688,12 @@ def load_device(
|
|||||||
|
|
||||||
mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
|
mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
|
||||||
|
|
||||||
if client.features.initialized:
|
if session.features.initialized:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Device is initialized already. Call device.wipe() and try again."
|
"Device is initialized already. Call device.wipe() and try again."
|
||||||
)
|
)
|
||||||
|
|
||||||
client.call(
|
session.call(
|
||||||
messages.LoadDevice(
|
messages.LoadDevice(
|
||||||
mnemonics=mnemonics,
|
mnemonics=mnemonics,
|
||||||
pin=pin,
|
pin=pin,
|
||||||
@ -1316,18 +1705,18 @@ def load_device(
|
|||||||
),
|
),
|
||||||
expect=messages.Success,
|
expect=messages.Success,
|
||||||
)
|
)
|
||||||
client.init_device()
|
session.refresh_features()
|
||||||
|
|
||||||
|
|
||||||
# keep the old name for compatibility
|
# keep the old name for compatibility
|
||||||
load_device_by_mnemonic = load_device
|
load_device_by_mnemonic = load_device
|
||||||
|
|
||||||
|
|
||||||
def prodtest_t1(client: "TrezorClient") -> None:
|
def prodtest_t1(session: "Session") -> None:
|
||||||
if client.features.bootloader_mode is not True:
|
if session.features.bootloader_mode is not True:
|
||||||
raise RuntimeError("Device must be in bootloader mode")
|
raise RuntimeError("Device must be in bootloader mode")
|
||||||
|
|
||||||
client.call(
|
session.call(
|
||||||
messages.ProdTestT1(
|
messages.ProdTestT1(
|
||||||
payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
|
payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
|
||||||
),
|
),
|
||||||
@ -1337,8 +1726,8 @@ def prodtest_t1(client: "TrezorClient") -> None:
|
|||||||
|
|
||||||
def record_screen(
|
def record_screen(
|
||||||
debug_client: "TrezorClientDebugLink",
|
debug_client: "TrezorClientDebugLink",
|
||||||
directory: Union[str, None],
|
directory: str | None,
|
||||||
report_func: Union[Callable[[str], None], None] = None,
|
report_func: t.Callable[[str], None] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Record screen changes into a specified directory.
|
"""Record screen changes into a specified directory.
|
||||||
|
|
||||||
@ -1383,5 +1772,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool:
|
|||||||
return debug_client.features.fw_vendor == "EMULATOR"
|
return debug_client.features.fw_vendor == "EMULATOR"
|
||||||
|
|
||||||
|
|
||||||
def optiga_set_sec_max(client: "TrezorClient") -> None:
|
def optiga_set_sec_max(session: "Session") -> None:
|
||||||
client.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success)
|
session.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success)
|
||||||
|
@ -28,16 +28,10 @@ from slip10 import SLIP10
|
|||||||
|
|
||||||
from . import messages
|
from . import messages
|
||||||
from .exceptions import Cancelled, TrezorException
|
from .exceptions import Cancelled, TrezorException
|
||||||
from .tools import (
|
from .tools import Address, _deprecation_retval_helper, _return_success, parse_path
|
||||||
Address,
|
|
||||||
_deprecation_retval_helper,
|
|
||||||
_return_success,
|
|
||||||
parse_path,
|
|
||||||
session,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
RECOVERY_BACK = "\x08" # backspace character, sent literally
|
RECOVERY_BACK = "\x08" # backspace character, sent literally
|
||||||
@ -46,9 +40,8 @@ SLIP39_EXTENDABLE_MIN_VERSION = (2, 7, 1)
|
|||||||
ENTROPY_CHECK_MIN_VERSION = (2, 8, 7)
|
ENTROPY_CHECK_MIN_VERSION = (2, 8, 7)
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def apply_settings(
|
def apply_settings(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
label: Optional[str] = None,
|
label: Optional[str] = None,
|
||||||
language: Optional[str] = None,
|
language: Optional[str] = None,
|
||||||
use_passphrase: Optional[bool] = None,
|
use_passphrase: Optional[bool] = None,
|
||||||
@ -79,13 +72,13 @@ def apply_settings(
|
|||||||
haptic_feedback=haptic_feedback,
|
haptic_feedback=haptic_feedback,
|
||||||
)
|
)
|
||||||
|
|
||||||
out = client.call(settings, expect=messages.Success)
|
out = session.call(settings, expect=messages.Success)
|
||||||
client.refresh_features()
|
session.refresh_features()
|
||||||
return _return_success(out)
|
return _return_success(out)
|
||||||
|
|
||||||
|
|
||||||
def _send_language_data(
|
def _send_language_data(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
request: "messages.TranslationDataRequest",
|
request: "messages.TranslationDataRequest",
|
||||||
language_data: bytes,
|
language_data: bytes,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -95,69 +88,63 @@ def _send_language_data(
|
|||||||
data_length = response.data_length
|
data_length = response.data_length
|
||||||
data_offset = response.data_offset
|
data_offset = response.data_offset
|
||||||
chunk = language_data[data_offset : data_offset + data_length]
|
chunk = language_data[data_offset : data_offset + data_length]
|
||||||
response = client.call(messages.TranslationDataAck(data_chunk=chunk))
|
response = session.call(messages.TranslationDataAck(data_chunk=chunk))
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def change_language(
|
def change_language(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
language_data: bytes,
|
language_data: bytes,
|
||||||
show_display: bool | None = None,
|
show_display: bool | None = None,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
data_length = len(language_data)
|
data_length = len(language_data)
|
||||||
msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)
|
msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)
|
||||||
|
|
||||||
response = client.call(msg)
|
response = session.call(msg)
|
||||||
if data_length > 0:
|
if data_length > 0:
|
||||||
response = messages.TranslationDataRequest.ensure_isinstance(response)
|
response = messages.TranslationDataRequest.ensure_isinstance(response)
|
||||||
_send_language_data(client, response, language_data)
|
_send_language_data(session, response, language_data)
|
||||||
else:
|
else:
|
||||||
messages.Success.ensure_isinstance(response)
|
messages.Success.ensure_isinstance(response)
|
||||||
client.refresh_features() # changing the language in features
|
session.refresh_features() # changing the language in features
|
||||||
return _return_success(messages.Success(message="Language changed."))
|
return _return_success(messages.Success(message="Language changed."))
|
||||||
|
|
||||||
|
|
||||||
@session
|
def apply_flags(session: "Session", flags: int) -> str | None:
|
||||||
def apply_flags(client: "TrezorClient", flags: int) -> str | None:
|
out = session.call(messages.ApplyFlags(flags=flags), expect=messages.Success)
|
||||||
out = client.call(messages.ApplyFlags(flags=flags), expect=messages.Success)
|
session.refresh_features()
|
||||||
client.refresh_features()
|
|
||||||
return _return_success(out)
|
return _return_success(out)
|
||||||
|
|
||||||
|
|
||||||
@session
|
def change_pin(session: "Session", remove: bool = False) -> str | None:
|
||||||
def change_pin(client: "TrezorClient", remove: bool = False) -> str | None:
|
ret = session.call(messages.ChangePin(remove=remove), expect=messages.Success)
|
||||||
ret = client.call(messages.ChangePin(remove=remove), expect=messages.Success)
|
session.refresh_features()
|
||||||
client.refresh_features()
|
|
||||||
return _return_success(ret)
|
return _return_success(ret)
|
||||||
|
|
||||||
|
|
||||||
@session
|
def change_wipe_code(session: "Session", remove: bool = False) -> str | None:
|
||||||
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> str | None:
|
ret = session.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success)
|
||||||
ret = client.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success)
|
session.refresh_features()
|
||||||
client.refresh_features()
|
|
||||||
return _return_success(ret)
|
return _return_success(ret)
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def sd_protect(
|
def sd_protect(
|
||||||
client: "TrezorClient", operation: messages.SdProtectOperationType
|
session: "Session", operation: messages.SdProtectOperationType
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
ret = client.call(messages.SdProtect(operation=operation), expect=messages.Success)
|
ret = session.call(messages.SdProtect(operation=operation), expect=messages.Success)
|
||||||
client.refresh_features()
|
session.refresh_features()
|
||||||
return _return_success(ret)
|
return _return_success(ret)
|
||||||
|
|
||||||
|
|
||||||
@session
|
def wipe(session: "Session") -> str | None:
|
||||||
def wipe(client: "TrezorClient") -> str | None:
|
ret = session.call(messages.WipeDevice(), expect=messages.Success)
|
||||||
ret = client.call(messages.WipeDevice(), expect=messages.Success)
|
session.invalidate()
|
||||||
if not client.features.bootloader_mode:
|
# if not session.features.bootloader_mode:
|
||||||
client.init_device()
|
# session.refresh_features()
|
||||||
return _return_success(ret)
|
return _return_success(ret)
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def recover(
|
def recover(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
word_count: int = 24,
|
word_count: int = 24,
|
||||||
passphrase_protection: bool = False,
|
passphrase_protection: bool = False,
|
||||||
pin_protection: bool = True,
|
pin_protection: bool = True,
|
||||||
@ -193,13 +180,13 @@ def recover(
|
|||||||
if type is None:
|
if type is None:
|
||||||
type = messages.RecoveryType.NormalRecovery
|
type = messages.RecoveryType.NormalRecovery
|
||||||
|
|
||||||
if client.features.model == "1" and input_callback is None:
|
if session.features.model == "1" and input_callback is None:
|
||||||
raise RuntimeError("Input callback required for Trezor One")
|
raise RuntimeError("Input callback required for Trezor One")
|
||||||
|
|
||||||
if word_count not in (12, 18, 24):
|
if word_count not in (12, 18, 24):
|
||||||
raise ValueError("Invalid word count. Use 12/18/24")
|
raise ValueError("Invalid word count. Use 12/18/24")
|
||||||
|
|
||||||
if client.features.initialized and type == messages.RecoveryType.NormalRecovery:
|
if session.features.initialized and type == messages.RecoveryType.NormalRecovery:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Device already initialized. Call device.wipe() and try again."
|
"Device already initialized. Call device.wipe() and try again."
|
||||||
)
|
)
|
||||||
@ -221,20 +208,20 @@ def recover(
|
|||||||
msg.label = label
|
msg.label = label
|
||||||
msg.u2f_counter = u2f_counter
|
msg.u2f_counter = u2f_counter
|
||||||
|
|
||||||
res = client.call(msg)
|
res = session.call(msg)
|
||||||
|
|
||||||
while isinstance(res, messages.WordRequest):
|
while isinstance(res, messages.WordRequest):
|
||||||
try:
|
try:
|
||||||
assert input_callback is not None
|
assert input_callback is not None
|
||||||
inp = input_callback(res.type)
|
inp = input_callback(res.type)
|
||||||
res = client.call(messages.WordAck(word=inp))
|
res = session.call(messages.WordAck(word=inp))
|
||||||
except Cancelled:
|
except Cancelled:
|
||||||
res = client.call(messages.Cancel())
|
res = session.call(messages.Cancel())
|
||||||
|
|
||||||
# check that the result is a Success
|
# check that the result is a Success
|
||||||
res = messages.Success.ensure_isinstance(res)
|
res = messages.Success.ensure_isinstance(res)
|
||||||
# reinitialize the device
|
# reinitialize the device
|
||||||
client.init_device()
|
session.refresh_features()
|
||||||
|
|
||||||
return _deprecation_retval_helper(res)
|
return _deprecation_retval_helper(res)
|
||||||
|
|
||||||
@ -280,7 +267,7 @@ def _seed_from_entropy(
|
|||||||
|
|
||||||
|
|
||||||
def reset(
|
def reset(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
display_random: bool = False,
|
display_random: bool = False,
|
||||||
strength: Optional[int] = None,
|
strength: Optional[int] = None,
|
||||||
passphrase_protection: bool = False,
|
passphrase_protection: bool = False,
|
||||||
@ -313,7 +300,7 @@ def reset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
client,
|
session,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
passphrase_protection=passphrase_protection,
|
passphrase_protection=passphrase_protection,
|
||||||
pin_protection=pin_protection,
|
pin_protection=pin_protection,
|
||||||
@ -331,9 +318,8 @@ def _get_external_entropy() -> bytes:
|
|||||||
return secrets.token_bytes(32)
|
return secrets.token_bytes(32)
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def setup(
|
def setup(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
*,
|
*,
|
||||||
strength: Optional[int] = None,
|
strength: Optional[int] = None,
|
||||||
passphrase_protection: bool = True,
|
passphrase_protection: bool = True,
|
||||||
@ -388,19 +374,19 @@ def setup(
|
|||||||
check.
|
check.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if client.features.initialized:
|
if session.features.initialized:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Device is initialized already. Call wipe_device() and try again."
|
"Device is initialized already. Call wipe_device() and try again."
|
||||||
)
|
)
|
||||||
|
|
||||||
if strength is None:
|
if strength is None:
|
||||||
if client.features.model == "1":
|
if session.features.model == "1":
|
||||||
strength = 256
|
strength = 256
|
||||||
else:
|
else:
|
||||||
strength = 128
|
strength = 128
|
||||||
|
|
||||||
if backup_type is None:
|
if backup_type is None:
|
||||||
if client.version < SLIP39_EXTENDABLE_MIN_VERSION:
|
if session.version < SLIP39_EXTENDABLE_MIN_VERSION:
|
||||||
# includes Trezor One 1.x.x
|
# includes Trezor One 1.x.x
|
||||||
backup_type = messages.BackupType.Bip39
|
backup_type = messages.BackupType.Bip39
|
||||||
else:
|
else:
|
||||||
@ -411,7 +397,7 @@ def setup(
|
|||||||
paths = [parse_path("m/84h/0h/0h"), parse_path("m/44h/60h/0h")]
|
paths = [parse_path("m/84h/0h/0h"), parse_path("m/44h/60h/0h")]
|
||||||
|
|
||||||
if entropy_check_count is None:
|
if entropy_check_count is None:
|
||||||
if client.version < ENTROPY_CHECK_MIN_VERSION:
|
if session.version < ENTROPY_CHECK_MIN_VERSION:
|
||||||
# includes Trezor One 1.x.x
|
# includes Trezor One 1.x.x
|
||||||
entropy_check_count = 0
|
entropy_check_count = 0
|
||||||
else:
|
else:
|
||||||
@ -431,18 +417,18 @@ def setup(
|
|||||||
)
|
)
|
||||||
if entropy_check_count > 0:
|
if entropy_check_count > 0:
|
||||||
xpubs = _reset_with_entropycheck(
|
xpubs = _reset_with_entropycheck(
|
||||||
client, msg, entropy_check_count, paths, _get_entropy
|
session, msg, entropy_check_count, paths, _get_entropy
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
_reset_no_entropycheck(client, msg, _get_entropy)
|
_reset_no_entropycheck(session, msg, _get_entropy)
|
||||||
xpubs = []
|
xpubs = []
|
||||||
|
|
||||||
client.init_device()
|
session.refresh_features()
|
||||||
return xpubs
|
return xpubs
|
||||||
|
|
||||||
|
|
||||||
def _reset_no_entropycheck(
|
def _reset_no_entropycheck(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
msg: messages.ResetDevice,
|
msg: messages.ResetDevice,
|
||||||
get_entropy: Callable[[], bytes],
|
get_entropy: Callable[[], bytes],
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -454,12 +440,12 @@ def _reset_no_entropycheck(
|
|||||||
<< Success
|
<< Success
|
||||||
"""
|
"""
|
||||||
assert msg.entropy_check is False
|
assert msg.entropy_check is False
|
||||||
client.call(msg, expect=messages.EntropyRequest)
|
session.call(msg, expect=messages.EntropyRequest)
|
||||||
client.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success)
|
session.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success)
|
||||||
|
|
||||||
|
|
||||||
def _reset_with_entropycheck(
|
def _reset_with_entropycheck(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
reset_msg: messages.ResetDevice,
|
reset_msg: messages.ResetDevice,
|
||||||
entropy_check_count: int,
|
entropy_check_count: int,
|
||||||
paths: Iterable[Address],
|
paths: Iterable[Address],
|
||||||
@ -495,7 +481,7 @@ def _reset_with_entropycheck(
|
|||||||
def get_xpubs() -> list[tuple[Address, str]]:
|
def get_xpubs() -> list[tuple[Address, str]]:
|
||||||
xpubs = []
|
xpubs = []
|
||||||
for path in paths:
|
for path in paths:
|
||||||
resp = client.call(
|
resp = session.call(
|
||||||
messages.GetPublicKey(address_n=path), expect=messages.PublicKey
|
messages.GetPublicKey(address_n=path), expect=messages.PublicKey
|
||||||
)
|
)
|
||||||
xpubs.append((path, resp.xpub))
|
xpubs.append((path, resp.xpub))
|
||||||
@ -524,13 +510,13 @@ def _reset_with_entropycheck(
|
|||||||
raise TrezorException("Invalid XPUB in entropy check")
|
raise TrezorException("Invalid XPUB in entropy check")
|
||||||
|
|
||||||
xpubs = []
|
xpubs = []
|
||||||
resp = client.call(reset_msg, expect=messages.EntropyRequest)
|
resp = session.call(reset_msg, expect=messages.EntropyRequest)
|
||||||
entropy_commitment = resp.entropy_commitment
|
entropy_commitment = resp.entropy_commitment
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# provide external entropy for this round
|
# provide external entropy for this round
|
||||||
external_entropy = get_entropy()
|
external_entropy = get_entropy()
|
||||||
client.call(
|
session.call(
|
||||||
messages.EntropyAck(entropy=external_entropy),
|
messages.EntropyAck(entropy=external_entropy),
|
||||||
expect=messages.EntropyCheckReady,
|
expect=messages.EntropyCheckReady,
|
||||||
)
|
)
|
||||||
@ -540,7 +526,7 @@ def _reset_with_entropycheck(
|
|||||||
|
|
||||||
if entropy_check_count <= 0:
|
if entropy_check_count <= 0:
|
||||||
# last round, wait for a Success and exit the loop
|
# last round, wait for a Success and exit the loop
|
||||||
client.call(
|
session.call(
|
||||||
messages.EntropyCheckContinue(finish=True),
|
messages.EntropyCheckContinue(finish=True),
|
||||||
expect=messages.Success,
|
expect=messages.Success,
|
||||||
)
|
)
|
||||||
@ -549,7 +535,7 @@ def _reset_with_entropycheck(
|
|||||||
entropy_check_count -= 1
|
entropy_check_count -= 1
|
||||||
|
|
||||||
# Next round starts.
|
# Next round starts.
|
||||||
resp = client.call(
|
resp = session.call(
|
||||||
messages.EntropyCheckContinue(finish=False),
|
messages.EntropyCheckContinue(finish=False),
|
||||||
expect=messages.EntropyRequest,
|
expect=messages.EntropyRequest,
|
||||||
)
|
)
|
||||||
@ -570,13 +556,12 @@ def _reset_with_entropycheck(
|
|||||||
return xpubs
|
return xpubs
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def backup(
|
def backup(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
group_threshold: Optional[int] = None,
|
group_threshold: Optional[int] = None,
|
||||||
groups: Iterable[tuple[int, int]] = (),
|
groups: Iterable[tuple[int, int]] = (),
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
ret = client.call(
|
ret = session.call(
|
||||||
messages.BackupDevice(
|
messages.BackupDevice(
|
||||||
group_threshold=group_threshold,
|
group_threshold=group_threshold,
|
||||||
groups=[
|
groups=[
|
||||||
@ -586,37 +571,36 @@ def backup(
|
|||||||
),
|
),
|
||||||
expect=messages.Success,
|
expect=messages.Success,
|
||||||
)
|
)
|
||||||
client.refresh_features()
|
session.refresh_features()
|
||||||
return _return_success(ret)
|
return _return_success(ret)
|
||||||
|
|
||||||
|
|
||||||
def cancel_authorization(client: "TrezorClient") -> str | None:
|
def cancel_authorization(session: "Session") -> str | None:
|
||||||
ret = client.call(messages.CancelAuthorization(), expect=messages.Success)
|
ret = session.call(messages.CancelAuthorization(), expect=messages.Success)
|
||||||
return _return_success(ret)
|
return _return_success(ret)
|
||||||
|
|
||||||
|
|
||||||
def unlock_path(client: "TrezorClient", n: "Address") -> bytes:
|
def unlock_path(session: "Session", n: "Address") -> bytes:
|
||||||
resp = client.call(
|
resp = session.call(
|
||||||
messages.UnlockPath(address_n=n), expect=messages.UnlockedPathRequest
|
messages.UnlockPath(address_n=n), expect=messages.UnlockedPathRequest
|
||||||
)
|
)
|
||||||
|
|
||||||
# Cancel the UnlockPath workflow now that we have the authentication code.
|
# Cancel the UnlockPath workflow now that we have the authentication code.
|
||||||
try:
|
try:
|
||||||
client.call(messages.Cancel())
|
session.call(messages.Cancel())
|
||||||
except Cancelled:
|
except Cancelled:
|
||||||
return resp.mac
|
return resp.mac
|
||||||
else:
|
else:
|
||||||
raise TrezorException("Unexpected response in UnlockPath flow")
|
raise TrezorException("Unexpected response in UnlockPath flow")
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def reboot_to_bootloader(
|
def reboot_to_bootloader(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT,
|
boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT,
|
||||||
firmware_header: Optional[bytes] = None,
|
firmware_header: Optional[bytes] = None,
|
||||||
language_data: bytes = b"",
|
language_data: bytes = b"",
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
response = client.call(
|
response = session.call(
|
||||||
messages.RebootToBootloader(
|
messages.RebootToBootloader(
|
||||||
boot_command=boot_command,
|
boot_command=boot_command,
|
||||||
firmware_header=firmware_header,
|
firmware_header=firmware_header,
|
||||||
@ -624,43 +608,38 @@ def reboot_to_bootloader(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
if isinstance(response, messages.TranslationDataRequest):
|
if isinstance(response, messages.TranslationDataRequest):
|
||||||
response = _send_language_data(client, response, language_data)
|
response = _send_language_data(session, response, language_data)
|
||||||
return _return_success(messages.Success(message=""))
|
return _return_success(messages.Success(message=""))
|
||||||
|
|
||||||
|
|
||||||
@session
|
def show_device_tutorial(session: "Session") -> str | None:
|
||||||
def show_device_tutorial(client: "TrezorClient") -> str | None:
|
ret = session.call(messages.ShowDeviceTutorial(), expect=messages.Success)
|
||||||
ret = client.call(messages.ShowDeviceTutorial(), expect=messages.Success)
|
|
||||||
return _return_success(ret)
|
return _return_success(ret)
|
||||||
|
|
||||||
|
|
||||||
@session
|
def unlock_bootloader(session: "Session") -> str | None:
|
||||||
def unlock_bootloader(client: "TrezorClient") -> str | None:
|
ret = session.call(messages.UnlockBootloader(), expect=messages.Success)
|
||||||
ret = client.call(messages.UnlockBootloader(), expect=messages.Success)
|
|
||||||
return _return_success(ret)
|
return _return_success(ret)
|
||||||
|
|
||||||
|
|
||||||
@session
|
def set_busy(session: "Session", expiry_ms: Optional[int]) -> str | None:
|
||||||
def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> str | None:
|
|
||||||
"""Sets or clears the busy state of the device.
|
"""Sets or clears the busy state of the device.
|
||||||
|
|
||||||
In the busy state the device shows a "Do not disconnect" message instead of the homescreen.
|
In the busy state the device shows a "Do not disconnect" message instead of the homescreen.
|
||||||
Setting `expiry_ms=None` clears the busy state.
|
Setting `expiry_ms=None` clears the busy state.
|
||||||
"""
|
"""
|
||||||
ret = client.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success)
|
ret = session.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success)
|
||||||
client.refresh_features()
|
session.refresh_features()
|
||||||
return _return_success(ret)
|
return _return_success(ret)
|
||||||
|
|
||||||
|
|
||||||
def authenticate(
|
def authenticate(session: "Session", challenge: bytes) -> messages.AuthenticityProof:
|
||||||
client: "TrezorClient", challenge: bytes
|
return session.call(
|
||||||
) -> messages.AuthenticityProof:
|
|
||||||
return client.call(
|
|
||||||
messages.AuthenticateDevice(challenge=challenge),
|
messages.AuthenticateDevice(challenge=challenge),
|
||||||
expect=messages.AuthenticityProof,
|
expect=messages.AuthenticityProof,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def set_brightness(client: "TrezorClient", value: Optional[int] = None) -> str | None:
|
def set_brightness(session: "Session", value: Optional[int] = None) -> str | None:
|
||||||
ret = client.call(messages.SetBrightness(value=value), expect=messages.Success)
|
ret = session.call(messages.SetBrightness(value=value), expect=messages.Success)
|
||||||
return _return_success(ret)
|
return _return_success(ret)
|
||||||
|
@ -18,11 +18,11 @@ from datetime import datetime
|
|||||||
from typing import TYPE_CHECKING, List, Tuple
|
from typing import TYPE_CHECKING, List, Tuple
|
||||||
|
|
||||||
from . import exceptions, messages
|
from . import exceptions, messages
|
||||||
from .tools import b58decode, session
|
from .tools import b58decode
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
def name_to_number(name: str) -> int:
|
def name_to_number(name: str) -> int:
|
||||||
@ -319,17 +319,16 @@ def parse_transaction_json(
|
|||||||
|
|
||||||
|
|
||||||
def get_public_key(
|
def get_public_key(
|
||||||
client: "TrezorClient", n: "Address", show_display: bool = False
|
session: "Session", n: "Address", show_display: bool = False
|
||||||
) -> messages.EosPublicKey:
|
) -> messages.EosPublicKey:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.EosGetPublicKey(address_n=n, show_display=show_display),
|
messages.EosGetPublicKey(address_n=n, show_display=show_display),
|
||||||
expect=messages.EosPublicKey,
|
expect=messages.EosPublicKey,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address: "Address",
|
address: "Address",
|
||||||
transaction: dict,
|
transaction: dict,
|
||||||
chain_id: str,
|
chain_id: str,
|
||||||
@ -345,11 +344,11 @@ def sign_tx(
|
|||||||
chunkify=chunkify,
|
chunkify=chunkify,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.call(msg)
|
response = session.call(msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while isinstance(response, messages.EosTxActionRequest):
|
while isinstance(response, messages.EosTxActionRequest):
|
||||||
response = client.call(actions.pop(0))
|
response = session.call(actions.pop(0))
|
||||||
except IndexError:
|
except IndexError:
|
||||||
# pop from empty list
|
# pop from empty list
|
||||||
raise exceptions.TrezorException(
|
raise exceptions.TrezorException(
|
||||||
|
@ -18,11 +18,11 @@ import re
|
|||||||
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from . import definitions, exceptions, messages
|
from . import definitions, exceptions, messages
|
||||||
from .tools import prepare_message_bytes, session, unharden
|
from .tools import prepare_message_bytes, unharden
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
def int_to_big_endian(value: int) -> bytes:
|
def int_to_big_endian(value: int) -> bytes:
|
||||||
@ -161,13 +161,13 @@ def network_from_address_n(
|
|||||||
|
|
||||||
|
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
encoded_network: Optional[bytes] = None,
|
encoded_network: Optional[bytes] = None,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
resp = client.call(
|
resp = session.call(
|
||||||
messages.EthereumGetAddress(
|
messages.EthereumGetAddress(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
show_display=show_display,
|
show_display=show_display,
|
||||||
@ -181,17 +181,16 @@ def get_address(
|
|||||||
|
|
||||||
|
|
||||||
def get_public_node(
|
def get_public_node(
|
||||||
client: "TrezorClient", n: "Address", show_display: bool = False
|
session: "Session", n: "Address", show_display: bool = False
|
||||||
) -> messages.EthereumPublicKey:
|
) -> messages.EthereumPublicKey:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.EthereumGetPublicKey(address_n=n, show_display=show_display),
|
messages.EthereumGetPublicKey(address_n=n, show_display=show_display),
|
||||||
expect=messages.EthereumPublicKey,
|
expect=messages.EthereumPublicKey,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
nonce: int,
|
nonce: int,
|
||||||
gas_price: int,
|
gas_price: int,
|
||||||
@ -227,13 +226,13 @@ def sign_tx(
|
|||||||
data, chunk = data[1024:], data[:1024]
|
data, chunk = data[1024:], data[:1024]
|
||||||
msg.data_initial_chunk = chunk
|
msg.data_initial_chunk = chunk
|
||||||
|
|
||||||
response = client.call(msg)
|
response = session.call(msg)
|
||||||
assert isinstance(response, messages.EthereumTxRequest)
|
assert isinstance(response, messages.EthereumTxRequest)
|
||||||
|
|
||||||
while response.data_length is not None:
|
while response.data_length is not None:
|
||||||
data_length = response.data_length
|
data_length = response.data_length
|
||||||
data, chunk = data[data_length:], data[:data_length]
|
data, chunk = data[data_length:], data[:data_length]
|
||||||
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
|
response = session.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||||
assert isinstance(response, messages.EthereumTxRequest)
|
assert isinstance(response, messages.EthereumTxRequest)
|
||||||
|
|
||||||
assert response.signature_v is not None
|
assert response.signature_v is not None
|
||||||
@ -248,9 +247,8 @@ def sign_tx(
|
|||||||
return response.signature_v, response.signature_r, response.signature_s
|
return response.signature_v, response.signature_r, response.signature_s
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def sign_tx_eip1559(
|
def sign_tx_eip1559(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
*,
|
*,
|
||||||
nonce: int,
|
nonce: int,
|
||||||
@ -283,13 +281,13 @@ def sign_tx_eip1559(
|
|||||||
chunkify=chunkify,
|
chunkify=chunkify,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.call(msg)
|
response = session.call(msg)
|
||||||
assert isinstance(response, messages.EthereumTxRequest)
|
assert isinstance(response, messages.EthereumTxRequest)
|
||||||
|
|
||||||
while response.data_length is not None:
|
while response.data_length is not None:
|
||||||
data_length = response.data_length
|
data_length = response.data_length
|
||||||
data, chunk = data[data_length:], data[:data_length]
|
data, chunk = data[data_length:], data[:data_length]
|
||||||
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
|
response = session.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||||
assert isinstance(response, messages.EthereumTxRequest)
|
assert isinstance(response, messages.EthereumTxRequest)
|
||||||
|
|
||||||
assert response.signature_v is not None
|
assert response.signature_v is not None
|
||||||
@ -299,13 +297,13 @@ def sign_tx_eip1559(
|
|||||||
|
|
||||||
|
|
||||||
def sign_message(
|
def sign_message(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
message: AnyStr,
|
message: AnyStr,
|
||||||
encoded_network: Optional[bytes] = None,
|
encoded_network: Optional[bytes] = None,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> messages.EthereumMessageSignature:
|
) -> messages.EthereumMessageSignature:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.EthereumSignMessage(
|
messages.EthereumSignMessage(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
message=prepare_message_bytes(message),
|
message=prepare_message_bytes(message),
|
||||||
@ -317,7 +315,7 @@ def sign_message(
|
|||||||
|
|
||||||
|
|
||||||
def sign_typed_data(
|
def sign_typed_data(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
data: Dict[str, Any],
|
data: Dict[str, Any],
|
||||||
*,
|
*,
|
||||||
@ -333,7 +331,7 @@ def sign_typed_data(
|
|||||||
metamask_v4_compat=metamask_v4_compat,
|
metamask_v4_compat=metamask_v4_compat,
|
||||||
definitions=definitions,
|
definitions=definitions,
|
||||||
)
|
)
|
||||||
response = client.call(request)
|
response = session.call(request)
|
||||||
|
|
||||||
# Sending all the types
|
# Sending all the types
|
||||||
while isinstance(response, messages.EthereumTypedDataStructRequest):
|
while isinstance(response, messages.EthereumTypedDataStructRequest):
|
||||||
@ -349,7 +347,7 @@ def sign_typed_data(
|
|||||||
members.append(struct_member)
|
members.append(struct_member)
|
||||||
|
|
||||||
request = messages.EthereumTypedDataStructAck(members=members)
|
request = messages.EthereumTypedDataStructAck(members=members)
|
||||||
response = client.call(request)
|
response = session.call(request)
|
||||||
|
|
||||||
# Sending the whole message that should be signed
|
# Sending the whole message that should be signed
|
||||||
while isinstance(response, messages.EthereumTypedDataValueRequest):
|
while isinstance(response, messages.EthereumTypedDataValueRequest):
|
||||||
@ -362,7 +360,7 @@ def sign_typed_data(
|
|||||||
member_typename = data["primaryType"]
|
member_typename = data["primaryType"]
|
||||||
member_data = data["message"]
|
member_data = data["message"]
|
||||||
else:
|
else:
|
||||||
client.cancel()
|
# TODO session.cancel()
|
||||||
raise exceptions.TrezorException("Root index can only be 0 or 1")
|
raise exceptions.TrezorException("Root index can only be 0 or 1")
|
||||||
|
|
||||||
# It can be asking for a nested structure (the member path being [X, Y, Z, ...])
|
# It can be asking for a nested structure (the member path being [X, Y, Z, ...])
|
||||||
@ -385,20 +383,20 @@ def sign_typed_data(
|
|||||||
encoded_data = encode_data(member_data, member_typename)
|
encoded_data = encode_data(member_data, member_typename)
|
||||||
|
|
||||||
request = messages.EthereumTypedDataValueAck(value=encoded_data)
|
request = messages.EthereumTypedDataValueAck(value=encoded_data)
|
||||||
response = client.call(request)
|
response = session.call(request)
|
||||||
|
|
||||||
return messages.EthereumTypedDataSignature.ensure_isinstance(response)
|
return messages.EthereumTypedDataSignature.ensure_isinstance(response)
|
||||||
|
|
||||||
|
|
||||||
def verify_message(
|
def verify_message(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address: str,
|
address: str,
|
||||||
signature: bytes,
|
signature: bytes,
|
||||||
message: AnyStr,
|
message: AnyStr,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
try:
|
try:
|
||||||
client.call(
|
session.call(
|
||||||
messages.EthereumVerifyMessage(
|
messages.EthereumVerifyMessage(
|
||||||
address=address,
|
address=address,
|
||||||
signature=signature,
|
signature=signature,
|
||||||
@ -413,13 +411,13 @@ def verify_message(
|
|||||||
|
|
||||||
|
|
||||||
def sign_typed_data_hash(
|
def sign_typed_data_hash(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
domain_hash: bytes,
|
domain_hash: bytes,
|
||||||
message_hash: Optional[bytes],
|
message_hash: Optional[bytes],
|
||||||
encoded_network: Optional[bytes] = None,
|
encoded_network: Optional[bytes] = None,
|
||||||
) -> messages.EthereumTypedDataSignature:
|
) -> messages.EthereumTypedDataSignature:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.EthereumSignTypedHash(
|
messages.EthereumSignTypedHash(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
domain_separator_hash=domain_hash,
|
domain_separator_hash=domain_hash,
|
||||||
|
@ -65,3 +65,7 @@ class UnexpectedMessageError(TrezorException):
|
|||||||
self.expected = expected
|
self.expected = expected
|
||||||
self.actual = actual
|
self.actual = actual
|
||||||
super().__init__(f"Expected {expected.__name__} but Trezor sent {actual}")
|
super().__init__(f"Expected {expected.__name__} but Trezor sent {actual}")
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceLockedException(TrezorException):
|
||||||
|
pass
|
||||||
|
@ -22,37 +22,37 @@ from . import messages
|
|||||||
from .tools import _return_success
|
from .tools import _return_success
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
def list_credentials(client: "TrezorClient") -> Sequence[messages.WebAuthnCredential]:
|
def list_credentials(session: "Session") -> Sequence[messages.WebAuthnCredential]:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.WebAuthnListResidentCredentials(), expect=messages.WebAuthnCredentials
|
messages.WebAuthnListResidentCredentials(), expect=messages.WebAuthnCredentials
|
||||||
).credentials
|
).credentials
|
||||||
|
|
||||||
|
|
||||||
def add_credential(client: "TrezorClient", credential_id: bytes) -> str | None:
|
def add_credential(session: "Session", credential_id: bytes) -> str | None:
|
||||||
ret = client.call(
|
ret = session.call(
|
||||||
messages.WebAuthnAddResidentCredential(credential_id=credential_id),
|
messages.WebAuthnAddResidentCredential(credential_id=credential_id),
|
||||||
expect=messages.Success,
|
expect=messages.Success,
|
||||||
)
|
)
|
||||||
return _return_success(ret)
|
return _return_success(ret)
|
||||||
|
|
||||||
|
|
||||||
def remove_credential(client: "TrezorClient", index: int) -> str | None:
|
def remove_credential(session: "Session", index: int) -> str | None:
|
||||||
ret = client.call(
|
ret = session.call(
|
||||||
messages.WebAuthnRemoveResidentCredential(index=index), expect=messages.Success
|
messages.WebAuthnRemoveResidentCredential(index=index), expect=messages.Success
|
||||||
)
|
)
|
||||||
return _return_success(ret)
|
return _return_success(ret)
|
||||||
|
|
||||||
|
|
||||||
def set_counter(client: "TrezorClient", u2f_counter: int) -> str | None:
|
def set_counter(session: "Session", u2f_counter: int) -> str | None:
|
||||||
ret = client.call(
|
ret = session.call(
|
||||||
messages.SetU2FCounter(u2f_counter=u2f_counter), expect=messages.Success
|
messages.SetU2FCounter(u2f_counter=u2f_counter), expect=messages.Success
|
||||||
)
|
)
|
||||||
return _return_success(ret)
|
return _return_success(ret)
|
||||||
|
|
||||||
|
|
||||||
def get_next_counter(client: "TrezorClient") -> int:
|
def get_next_counter(session: "Session") -> int:
|
||||||
ret = client.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter)
|
ret = session.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter)
|
||||||
return ret.u2f_counter
|
return ret.u2f_counter
|
||||||
|
@ -20,7 +20,6 @@ from hashlib import blake2s
|
|||||||
from typing_extensions import Protocol, TypeGuard
|
from typing_extensions import Protocol, TypeGuard
|
||||||
|
|
||||||
from .. import messages
|
from .. import messages
|
||||||
from ..tools import session
|
|
||||||
from .core import VendorFirmware
|
from .core import VendorFirmware
|
||||||
from .legacy import LegacyFirmware, LegacyV2Firmware
|
from .legacy import LegacyFirmware, LegacyV2Firmware
|
||||||
|
|
||||||
@ -38,7 +37,7 @@ if True:
|
|||||||
from .vendor import * # noqa: F401, F403
|
from .vendor import * # noqa: F401, F403
|
||||||
|
|
||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
from ..client import TrezorClient
|
from ..transport.session import Session
|
||||||
|
|
||||||
T = t.TypeVar("T", bound="FirmwareType")
|
T = t.TypeVar("T", bound="FirmwareType")
|
||||||
|
|
||||||
@ -72,20 +71,19 @@ def is_onev2(fw: "FirmwareType") -> TypeGuard[LegacyFirmware]:
|
|||||||
# ====== Client functions ====== #
|
# ====== Client functions ====== #
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def update(
|
def update(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
data: bytes,
|
data: bytes,
|
||||||
progress_update: t.Callable[[int], t.Any] = lambda _: None,
|
progress_update: t.Callable[[int], t.Any] = lambda _: None,
|
||||||
):
|
):
|
||||||
if client.features.bootloader_mode is False:
|
if session.features.bootloader_mode is False:
|
||||||
raise RuntimeError("Device must be in bootloader mode")
|
raise RuntimeError("Device must be in bootloader mode")
|
||||||
|
|
||||||
resp = client.call(messages.FirmwareErase(length=len(data)))
|
resp = session.call(messages.FirmwareErase(length=len(data)))
|
||||||
|
|
||||||
# TREZORv1 method
|
# TREZORv1 method
|
||||||
if isinstance(resp, messages.Success):
|
if isinstance(resp, messages.Success):
|
||||||
resp = client.call(messages.FirmwareUpload(payload=data))
|
resp = session.call(messages.FirmwareUpload(payload=data))
|
||||||
progress_update(len(data))
|
progress_update(len(data))
|
||||||
if isinstance(resp, messages.Success):
|
if isinstance(resp, messages.Success):
|
||||||
return
|
return
|
||||||
@ -97,7 +95,7 @@ def update(
|
|||||||
length = resp.length
|
length = resp.length
|
||||||
payload = data[resp.offset : resp.offset + length]
|
payload = data[resp.offset : resp.offset + length]
|
||||||
digest = blake2s(payload).digest()
|
digest = blake2s(payload).digest()
|
||||||
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest))
|
resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest))
|
||||||
progress_update(length)
|
progress_update(length)
|
||||||
|
|
||||||
if isinstance(resp, messages.Success):
|
if isinstance(resp, messages.Success):
|
||||||
@ -106,7 +104,7 @@ def update(
|
|||||||
raise RuntimeError(f"Unexpected message {resp}")
|
raise RuntimeError(f"Unexpected message {resp}")
|
||||||
|
|
||||||
|
|
||||||
def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]) -> bytes:
|
def get_hash(session: "Session", challenge: t.Optional[bytes]) -> bytes:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.GetFirmwareHash(challenge=challenge), expect=messages.FirmwareHash
|
messages.GetFirmwareHash(challenge=challenge), expect=messages.FirmwareHash
|
||||||
).hash
|
).hash
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import io
|
import io
|
||||||
|
import logging
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Dict, Optional, Tuple, Type, TypeVar
|
from typing import Dict, Optional, Tuple, Type, TypeVar
|
||||||
|
|
||||||
@ -25,6 +26,7 @@ from typing_extensions import Self
|
|||||||
from . import messages, protobuf
|
from . import messages, protobuf
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ProtobufMapping:
|
class ProtobufMapping:
|
||||||
@ -63,11 +65,21 @@ class ProtobufMapping:
|
|||||||
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
|
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
|
||||||
if wire_type is None:
|
if wire_type is None:
|
||||||
raise ValueError("Cannot encode class without wire type")
|
raise ValueError("Cannot encode class without wire type")
|
||||||
|
LOG.debug("encoding wire type %d", wire_type)
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
protobuf.dump_message(buf, msg)
|
protobuf.dump_message(buf, msg)
|
||||||
return wire_type, buf.getvalue()
|
return wire_type, buf.getvalue()
|
||||||
|
|
||||||
|
def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes:
|
||||||
|
"""Serialize a Python protobuf class.
|
||||||
|
|
||||||
|
Returns the byte representation of the protobuf message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
buf = io.BytesIO()
|
||||||
|
protobuf.dump_message(buf, msg)
|
||||||
|
return buf.getvalue()
|
||||||
|
|
||||||
def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType:
|
def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType:
|
||||||
"""Deserialize a protobuf message into a Python class."""
|
"""Deserialize a protobuf message into a Python class."""
|
||||||
cls = self.type_to_class[msg_wire_type]
|
cls = self.type_to_class[msg_wire_type]
|
||||||
@ -83,7 +95,9 @@ class ProtobufMapping:
|
|||||||
mapping = cls()
|
mapping = cls()
|
||||||
|
|
||||||
message_types = getattr(module, "MessageType")
|
message_types = getattr(module, "MessageType")
|
||||||
for entry in message_types:
|
thp_message_types = getattr(module, "ThpMessageType")
|
||||||
|
|
||||||
|
for entry in (*message_types, *thp_message_types):
|
||||||
msg_class = getattr(module, entry.name, None)
|
msg_class = getattr(module, entry.name, None)
|
||||||
if msg_class is None:
|
if msg_class is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
352
python/src/trezorlib/messages.py
generated
352
python/src/trezorlib/messages.py
generated
@ -43,6 +43,10 @@ class FailureType(IntEnum):
|
|||||||
PinMismatch = 12
|
PinMismatch = 12
|
||||||
WipeCodeMismatch = 13
|
WipeCodeMismatch = 13
|
||||||
InvalidSession = 14
|
InvalidSession = 14
|
||||||
|
ThpUnallocatedSession = 15
|
||||||
|
InvalidProtocol = 16
|
||||||
|
BufferError = 17
|
||||||
|
DeviceIsBusy = 18
|
||||||
FirmwareError = 99
|
FirmwareError = 99
|
||||||
|
|
||||||
|
|
||||||
@ -400,6 +404,34 @@ class TezosBallotType(IntEnum):
|
|||||||
Pass = 2
|
Pass = 2
|
||||||
|
|
||||||
|
|
||||||
|
class ThpMessageType(IntEnum):
|
||||||
|
ThpCreateNewSession = 1000
|
||||||
|
ThpPairingRequest = 1006
|
||||||
|
ThpPairingRequestApproved = 1007
|
||||||
|
ThpSelectMethod = 1008
|
||||||
|
ThpPairingPreparationsFinished = 1009
|
||||||
|
ThpCredentialRequest = 1010
|
||||||
|
ThpCredentialResponse = 1011
|
||||||
|
ThpEndRequest = 1012
|
||||||
|
ThpEndResponse = 1013
|
||||||
|
ThpCodeEntryCommitment = 1016
|
||||||
|
ThpCodeEntryChallenge = 1017
|
||||||
|
ThpCodeEntryCpaceTrezor = 1018
|
||||||
|
ThpCodeEntryCpaceHostTag = 1019
|
||||||
|
ThpCodeEntrySecret = 1020
|
||||||
|
ThpQrCodeTag = 1024
|
||||||
|
ThpQrCodeSecret = 1025
|
||||||
|
ThpNfcTagHost = 1032
|
||||||
|
ThpNfcTagTrezor = 1033
|
||||||
|
|
||||||
|
|
||||||
|
class ThpPairingMethod(IntEnum):
|
||||||
|
SkipPairing = 1
|
||||||
|
CodeEntry = 2
|
||||||
|
QrCode = 3
|
||||||
|
NFC = 4
|
||||||
|
|
||||||
|
|
||||||
class MessageType(IntEnum):
|
class MessageType(IntEnum):
|
||||||
Initialize = 0
|
Initialize = 0
|
||||||
Ping = 1
|
Ping = 1
|
||||||
@ -500,6 +532,8 @@ class MessageType(IntEnum):
|
|||||||
DebugLinkWatchLayout = 9006
|
DebugLinkWatchLayout = 9006
|
||||||
DebugLinkResetDebugEvents = 9007
|
DebugLinkResetDebugEvents = 9007
|
||||||
DebugLinkOptigaSetSecMax = 9008
|
DebugLinkOptigaSetSecMax = 9008
|
||||||
|
DebugLinkGetPairingInfo = 9009
|
||||||
|
DebugLinkPairingInfo = 9010
|
||||||
EthereumGetPublicKey = 450
|
EthereumGetPublicKey = 450
|
||||||
EthereumPublicKey = 451
|
EthereumPublicKey = 451
|
||||||
EthereumGetAddress = 56
|
EthereumGetAddress = 56
|
||||||
@ -4203,6 +4237,52 @@ class DebugLinkState(protobuf.MessageType):
|
|||||||
self.mnemonic_type = mnemonic_type
|
self.mnemonic_type = mnemonic_type
|
||||||
|
|
||||||
|
|
||||||
|
class DebugLinkGetPairingInfo(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 9009
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("channel_id", "bytes", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("handshake_hash", "bytes", repeated=False, required=False, default=None),
|
||||||
|
3: protobuf.Field("nfc_secret_host", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel_id: Optional["bytes"] = None,
|
||||||
|
handshake_hash: Optional["bytes"] = None,
|
||||||
|
nfc_secret_host: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.channel_id = channel_id
|
||||||
|
self.handshake_hash = handshake_hash
|
||||||
|
self.nfc_secret_host = nfc_secret_host
|
||||||
|
|
||||||
|
|
||||||
|
class DebugLinkPairingInfo(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 9010
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("channel_id", "bytes", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("handshake_hash", "bytes", repeated=False, required=False, default=None),
|
||||||
|
3: protobuf.Field("code_entry_code", "uint32", repeated=False, required=False, default=None),
|
||||||
|
4: protobuf.Field("code_qr_code", "bytes", repeated=False, required=False, default=None),
|
||||||
|
5: protobuf.Field("nfc_secret_trezor", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel_id: Optional["bytes"] = None,
|
||||||
|
handshake_hash: Optional["bytes"] = None,
|
||||||
|
code_entry_code: Optional["int"] = None,
|
||||||
|
code_qr_code: Optional["bytes"] = None,
|
||||||
|
nfc_secret_trezor: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.channel_id = channel_id
|
||||||
|
self.handshake_hash = handshake_hash
|
||||||
|
self.code_entry_code = code_entry_code
|
||||||
|
self.code_qr_code = code_qr_code
|
||||||
|
self.nfc_secret_trezor = nfc_secret_trezor
|
||||||
|
|
||||||
|
|
||||||
class DebugLinkStop(protobuf.MessageType):
|
class DebugLinkStop(protobuf.MessageType):
|
||||||
MESSAGE_WIRE_TYPE = 103
|
MESSAGE_WIRE_TYPE = 103
|
||||||
|
|
||||||
@ -7863,8 +7943,68 @@ class TezosManagerTransfer(protobuf.MessageType):
|
|||||||
self.amount = amount
|
self.amount = amount
|
||||||
|
|
||||||
|
|
||||||
class ThpCredentialMetadata(protobuf.MessageType):
|
class ThpDeviceProperties(protobuf.MessageType):
|
||||||
MESSAGE_WIRE_TYPE = None
|
MESSAGE_WIRE_TYPE = None
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("internal_model", "string", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("model_variant", "uint32", repeated=False, required=False, default=None),
|
||||||
|
3: protobuf.Field("protocol_version_major", "uint32", repeated=False, required=False, default=None),
|
||||||
|
4: protobuf.Field("protocol_version_minor", "uint32", repeated=False, required=False, default=None),
|
||||||
|
5: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None,
|
||||||
|
internal_model: Optional["str"] = None,
|
||||||
|
model_variant: Optional["int"] = None,
|
||||||
|
protocol_version_major: Optional["int"] = None,
|
||||||
|
protocol_version_minor: Optional["int"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else []
|
||||||
|
self.internal_model = internal_model
|
||||||
|
self.model_variant = model_variant
|
||||||
|
self.protocol_version_major = protocol_version_major
|
||||||
|
self.protocol_version_minor = protocol_version_minor
|
||||||
|
|
||||||
|
|
||||||
|
class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = None
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("host_pairing_credential", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
host_pairing_credential: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.host_pairing_credential = host_pairing_credential
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCreateNewSession(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1000
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("passphrase", "string", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("on_device", "bool", repeated=False, required=False, default=None),
|
||||||
|
3: protobuf.Field("derive_cardano", "bool", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
passphrase: Optional["str"] = None,
|
||||||
|
on_device: Optional["bool"] = None,
|
||||||
|
derive_cardano: Optional["bool"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.passphrase = passphrase
|
||||||
|
self.on_device = on_device
|
||||||
|
self.derive_cardano = derive_cardano
|
||||||
|
|
||||||
|
|
||||||
|
class ThpPairingRequest(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1006
|
||||||
FIELDS = {
|
FIELDS = {
|
||||||
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
|
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
|
||||||
}
|
}
|
||||||
@ -7877,6 +8017,216 @@ class ThpCredentialMetadata(protobuf.MessageType):
|
|||||||
self.host_name = host_name
|
self.host_name = host_name
|
||||||
|
|
||||||
|
|
||||||
|
class ThpPairingRequestApproved(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1007
|
||||||
|
|
||||||
|
|
||||||
|
class ThpSelectMethod(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1008
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("selected_pairing_method", "ThpPairingMethod", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
selected_pairing_method: Optional["ThpPairingMethod"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.selected_pairing_method = selected_pairing_method
|
||||||
|
|
||||||
|
|
||||||
|
class ThpPairingPreparationsFinished(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1009
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCodeEntryCommitment(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1016
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("commitment", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
commitment: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.commitment = commitment
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCodeEntryChallenge(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1017
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("challenge", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
challenge: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.challenge = challenge
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1018
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("cpace_trezor_public_key", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
cpace_trezor_public_key: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.cpace_trezor_public_key = cpace_trezor_public_key
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCodeEntryCpaceHostTag(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1019
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("cpace_host_public_key", "bytes", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
cpace_host_public_key: Optional["bytes"] = None,
|
||||||
|
tag: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.cpace_host_public_key = cpace_host_public_key
|
||||||
|
self.tag = tag
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCodeEntrySecret(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1020
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
secret: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.secret = secret
|
||||||
|
|
||||||
|
|
||||||
|
class ThpQrCodeTag(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1024
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
tag: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.tag = tag
|
||||||
|
|
||||||
|
|
||||||
|
class ThpQrCodeSecret(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1025
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
secret: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.secret = secret
|
||||||
|
|
||||||
|
|
||||||
|
class ThpNfcTagHost(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1032
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
tag: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.tag = tag
|
||||||
|
|
||||||
|
|
||||||
|
class ThpNfcTagTrezor(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1033
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
tag: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.tag = tag
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCredentialRequest(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1010
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("autoconnect", "bool", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
host_static_pubkey: Optional["bytes"] = None,
|
||||||
|
autoconnect: Optional["bool"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.host_static_pubkey = host_static_pubkey
|
||||||
|
self.autoconnect = autoconnect
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCredentialResponse(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1011
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("trezor_static_pubkey", "bytes", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("credential", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
trezor_static_pubkey: Optional["bytes"] = None,
|
||||||
|
credential: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.trezor_static_pubkey = trezor_static_pubkey
|
||||||
|
self.credential = credential
|
||||||
|
|
||||||
|
|
||||||
|
class ThpEndRequest(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1012
|
||||||
|
|
||||||
|
|
||||||
|
class ThpEndResponse(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1013
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCredentialMetadata(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = None
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("autoconnect", "bool", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
host_name: Optional["str"] = None,
|
||||||
|
autoconnect: Optional["bool"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.host_name = host_name
|
||||||
|
self.autoconnect = autoconnect
|
||||||
|
|
||||||
|
|
||||||
class ThpPairingCredential(protobuf.MessageType):
|
class ThpPairingCredential(protobuf.MessageType):
|
||||||
MESSAGE_WIRE_TYPE = None
|
MESSAGE_WIRE_TYPE = None
|
||||||
FIELDS = {
|
FIELDS = {
|
||||||
|
@ -19,22 +19,22 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
from . import messages
|
from . import messages
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
def get_entropy(client: "TrezorClient", size: int) -> bytes:
|
def get_entropy(session: "Session", size: int) -> bytes:
|
||||||
return client.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy
|
return session.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy
|
||||||
|
|
||||||
|
|
||||||
def sign_identity(
|
def sign_identity(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
identity: messages.IdentityType,
|
identity: messages.IdentityType,
|
||||||
challenge_hidden: bytes,
|
challenge_hidden: bytes,
|
||||||
challenge_visual: str,
|
challenge_visual: str,
|
||||||
ecdsa_curve_name: Optional[str] = None,
|
ecdsa_curve_name: Optional[str] = None,
|
||||||
) -> messages.SignedIdentity:
|
) -> messages.SignedIdentity:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.SignIdentity(
|
messages.SignIdentity(
|
||||||
identity=identity,
|
identity=identity,
|
||||||
challenge_hidden=challenge_hidden,
|
challenge_hidden=challenge_hidden,
|
||||||
@ -46,12 +46,12 @@ def sign_identity(
|
|||||||
|
|
||||||
|
|
||||||
def get_ecdh_session_key(
|
def get_ecdh_session_key(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
identity: messages.IdentityType,
|
identity: messages.IdentityType,
|
||||||
peer_public_key: bytes,
|
peer_public_key: bytes,
|
||||||
ecdsa_curve_name: Optional[str] = None,
|
ecdsa_curve_name: Optional[str] = None,
|
||||||
) -> messages.ECDHSessionKey:
|
) -> messages.ECDHSessionKey:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.GetECDHSessionKey(
|
messages.GetECDHSessionKey(
|
||||||
identity=identity,
|
identity=identity,
|
||||||
peer_public_key=peer_public_key,
|
peer_public_key=peer_public_key,
|
||||||
@ -62,7 +62,7 @@ def get_ecdh_session_key(
|
|||||||
|
|
||||||
|
|
||||||
def encrypt_keyvalue(
|
def encrypt_keyvalue(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
key: str,
|
key: str,
|
||||||
value: bytes,
|
value: bytes,
|
||||||
@ -70,7 +70,7 @@ def encrypt_keyvalue(
|
|||||||
ask_on_decrypt: bool = True,
|
ask_on_decrypt: bool = True,
|
||||||
iv: bytes = b"",
|
iv: bytes = b"",
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.CipherKeyValue(
|
messages.CipherKeyValue(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
key=key,
|
key=key,
|
||||||
@ -85,7 +85,7 @@ def encrypt_keyvalue(
|
|||||||
|
|
||||||
|
|
||||||
def decrypt_keyvalue(
|
def decrypt_keyvalue(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
key: str,
|
key: str,
|
||||||
value: bytes,
|
value: bytes,
|
||||||
@ -93,7 +93,7 @@ def decrypt_keyvalue(
|
|||||||
ask_on_decrypt: bool = True,
|
ask_on_decrypt: bool = True,
|
||||||
iv: bytes = b"",
|
iv: bytes = b"",
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.CipherKeyValue(
|
messages.CipherKeyValue(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
key=key,
|
key=key,
|
||||||
@ -107,5 +107,5 @@ def decrypt_keyvalue(
|
|||||||
).value
|
).value
|
||||||
|
|
||||||
|
|
||||||
def get_nonce(client: "TrezorClient") -> bytes:
|
def get_nonce(session: "Session") -> bytes:
|
||||||
return client.call(messages.GetNonce(), expect=messages.Nonce).nonce
|
return session.call(messages.GetNonce(), expect=messages.Nonce).nonce
|
||||||
|
@ -19,8 +19,8 @@ from typing import TYPE_CHECKING
|
|||||||
from . import messages
|
from . import messages
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
# MAINNET = 0
|
# MAINNET = 0
|
||||||
@ -30,13 +30,13 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
|
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.MoneroGetAddress(
|
messages.MoneroGetAddress(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
show_display=show_display,
|
show_display=show_display,
|
||||||
@ -48,11 +48,11 @@ def get_address(
|
|||||||
|
|
||||||
|
|
||||||
def get_watch_key(
|
def get_watch_key(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
|
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
|
||||||
) -> messages.MoneroWatchKey:
|
) -> messages.MoneroWatchKey:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.MoneroGetWatchKey(address_n=n, network_type=network_type),
|
messages.MoneroGetWatchKey(address_n=n, network_type=network_type),
|
||||||
expect=messages.MoneroWatchKey,
|
expect=messages.MoneroWatchKey,
|
||||||
)
|
)
|
||||||
|
@ -20,8 +20,8 @@ from typing import TYPE_CHECKING
|
|||||||
from . import exceptions, messages
|
from . import exceptions, messages
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
TYPE_TRANSACTION_TRANSFER = 0x0101
|
TYPE_TRANSACTION_TRANSFER = 0x0101
|
||||||
TYPE_IMPORTANCE_TRANSFER = 0x0801
|
TYPE_IMPORTANCE_TRANSFER = 0x0801
|
||||||
@ -195,13 +195,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig
|
|||||||
|
|
||||||
|
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
network: int,
|
network: int,
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.NEMGetAddress(
|
messages.NEMGetAddress(
|
||||||
address_n=n, network=network, show_display=show_display, chunkify=chunkify
|
address_n=n, network=network, show_display=show_display, chunkify=chunkify
|
||||||
),
|
),
|
||||||
@ -210,7 +210,7 @@ def get_address(
|
|||||||
|
|
||||||
|
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False
|
session: "Session", n: "Address", transaction: dict, chunkify: bool = False
|
||||||
) -> messages.NEMSignedTx:
|
) -> messages.NEMSignedTx:
|
||||||
try:
|
try:
|
||||||
msg = create_sign_tx(transaction, chunkify=chunkify)
|
msg = create_sign_tx(transaction, chunkify=chunkify)
|
||||||
@ -219,4 +219,4 @@ def sign_tx(
|
|||||||
|
|
||||||
assert msg.transaction is not None
|
assert msg.transaction is not None
|
||||||
msg.transaction.address_n = n
|
msg.transaction.address_n = n
|
||||||
return client.call(msg, expect=messages.NEMSignedTx)
|
return session.call(msg, expect=messages.NEMSignedTx)
|
||||||
|
@ -21,20 +21,20 @@ from .protobuf import dict_to_proto
|
|||||||
from .tools import dict_from_camelcase
|
from .tools import dict_from_camelcase
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
|
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
|
||||||
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
|
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
|
||||||
|
|
||||||
|
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: "Address",
|
address_n: "Address",
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.RippleGetAddress(
|
messages.RippleGetAddress(
|
||||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||||
),
|
),
|
||||||
@ -43,14 +43,14 @@ def get_address(
|
|||||||
|
|
||||||
|
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: "Address",
|
address_n: "Address",
|
||||||
msg: messages.RippleSignTx,
|
msg: messages.RippleSignTx,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> messages.RippleSignedTx:
|
) -> messages.RippleSignedTx:
|
||||||
msg.address_n = address_n
|
msg.address_n = address_n
|
||||||
msg.chunkify = chunkify
|
msg.chunkify = chunkify
|
||||||
return client.call(msg, expect=messages.RippleSignedTx)
|
return session.call(msg, expect=messages.RippleSignedTx)
|
||||||
|
|
||||||
|
|
||||||
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:
|
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:
|
||||||
|
@ -3,27 +3,27 @@ from typing import TYPE_CHECKING, List, Optional
|
|||||||
from . import messages
|
from . import messages
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
def get_public_key(
|
def get_public_key(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: List[int],
|
address_n: List[int],
|
||||||
show_display: bool,
|
show_display: bool,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display),
|
messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display),
|
||||||
expect=messages.SolanaPublicKey,
|
expect=messages.SolanaPublicKey,
|
||||||
).public_key
|
).public_key
|
||||||
|
|
||||||
|
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: List[int],
|
address_n: List[int],
|
||||||
show_display: bool,
|
show_display: bool,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.SolanaGetAddress(
|
messages.SolanaGetAddress(
|
||||||
address_n=address_n,
|
address_n=address_n,
|
||||||
show_display=show_display,
|
show_display=show_display,
|
||||||
@ -34,12 +34,12 @@ def get_address(
|
|||||||
|
|
||||||
|
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: List[int],
|
address_n: List[int],
|
||||||
serialized_tx: bytes,
|
serialized_tx: bytes,
|
||||||
additional_info: Optional[messages.SolanaTxAdditionalInfo],
|
additional_info: Optional[messages.SolanaTxAdditionalInfo],
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.SolanaSignTx(
|
messages.SolanaSignTx(
|
||||||
address_n=address_n,
|
address_n=address_n,
|
||||||
serialized_tx=serialized_tx,
|
serialized_tx=serialized_tx,
|
||||||
|
@ -20,8 +20,8 @@ from typing import TYPE_CHECKING, List, Tuple, Union
|
|||||||
from . import exceptions, messages
|
from . import exceptions, messages
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
StellarMessageType = Union[
|
StellarMessageType = Union[
|
||||||
messages.StellarAccountMergeOp,
|
messages.StellarAccountMergeOp,
|
||||||
@ -322,12 +322,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
|
|||||||
|
|
||||||
|
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: "Address",
|
address_n: "Address",
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.StellarGetAddress(
|
messages.StellarGetAddress(
|
||||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||||
),
|
),
|
||||||
@ -336,7 +336,7 @@ def get_address(
|
|||||||
|
|
||||||
|
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
tx: messages.StellarSignTx,
|
tx: messages.StellarSignTx,
|
||||||
operations: List["StellarMessageType"],
|
operations: List["StellarMessageType"],
|
||||||
address_n: "Address",
|
address_n: "Address",
|
||||||
@ -352,10 +352,10 @@ def sign_tx(
|
|||||||
# 3. Receive a StellarTxOpRequest message
|
# 3. Receive a StellarTxOpRequest message
|
||||||
# 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message
|
# 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message
|
||||||
# 5. The final message received will be StellarSignedTx which is returned from this method
|
# 5. The final message received will be StellarSignedTx which is returned from this method
|
||||||
resp = client.call(tx)
|
resp = session.call(tx)
|
||||||
try:
|
try:
|
||||||
while isinstance(resp, messages.StellarTxOpRequest):
|
while isinstance(resp, messages.StellarTxOpRequest):
|
||||||
resp = client.call(operations.pop(0))
|
resp = session.call(operations.pop(0))
|
||||||
except IndexError:
|
except IndexError:
|
||||||
# pop from empty list
|
# pop from empty list
|
||||||
raise exceptions.TrezorException(
|
raise exceptions.TrezorException(
|
||||||
|
@ -19,17 +19,17 @@ from typing import TYPE_CHECKING
|
|||||||
from . import messages
|
from . import messages
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: "Address",
|
address_n: "Address",
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.TezosGetAddress(
|
messages.TezosGetAddress(
|
||||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||||
),
|
),
|
||||||
@ -38,12 +38,12 @@ def get_address(
|
|||||||
|
|
||||||
|
|
||||||
def get_public_key(
|
def get_public_key(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: "Address",
|
address_n: "Address",
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.TezosGetPublicKey(
|
messages.TezosGetPublicKey(
|
||||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||||
),
|
),
|
||||||
@ -52,11 +52,11 @@ def get_public_key(
|
|||||||
|
|
||||||
|
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: "Address",
|
address_n: "Address",
|
||||||
sign_tx_msg: messages.TezosSignTx,
|
sign_tx_msg: messages.TezosSignTx,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> messages.TezosSignedTx:
|
) -> messages.TezosSignedTx:
|
||||||
sign_tx_msg.address_n = address_n
|
sign_tx_msg.address_n = address_n
|
||||||
sign_tx_msg.chunkify = chunkify
|
sign_tx_msg.chunkify = chunkify
|
||||||
return client.call(sign_tx_msg, expect=messages.TezosSignedTx)
|
return session.call(sign_tx_msg, expect=messages.TezosSignedTx)
|
||||||
|
@ -45,7 +45,7 @@ if TYPE_CHECKING:
|
|||||||
# More details: https://www.python.org/dev/peps/pep-0612/
|
# More details: https://www.python.org/dev/peps/pep-0612/
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
from typing_extensions import Concatenate, ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from . import client
|
from . import client
|
||||||
from .messages import Success
|
from .messages import Success
|
||||||
@ -389,23 +389,6 @@ def _return_success(msg: "Success") -> str | None:
|
|||||||
return _deprecation_retval_helper(msg.message, stacklevel=1)
|
return _deprecation_retval_helper(msg.message, stacklevel=1)
|
||||||
|
|
||||||
|
|
||||||
def session(
|
|
||||||
f: "Callable[Concatenate[TrezorClient, P], R]",
|
|
||||||
) -> "Callable[Concatenate[TrezorClient, P], R]":
|
|
||||||
# Decorator wraps a BaseClient method
|
|
||||||
# with session activation / deactivation
|
|
||||||
@functools.wraps(f)
|
|
||||||
def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R":
|
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
|
||||||
client.open()
|
|
||||||
try:
|
|
||||||
return f(client, *args, **kwargs)
|
|
||||||
finally:
|
|
||||||
client.close()
|
|
||||||
|
|
||||||
return wrapped_f
|
|
||||||
|
|
||||||
|
|
||||||
# de-camelcasifier
|
# de-camelcasifier
|
||||||
# https://stackoverflow.com/a/1176023/222189
|
# https://stackoverflow.com/a/1176023/222189
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user