mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-07 21:22:41 +00:00
feat(python): implement session based trezorlib
This commit is contained in:
parent
7f5764b7d4
commit
fbff05a89f
@ -7,7 +7,7 @@ import typing as t
|
||||
from importlib import metadata
|
||||
|
||||
from . import device
|
||||
from .client import TrezorClient
|
||||
from .transport.session import Session
|
||||
|
||||
try:
|
||||
cryptography_version = metadata.version("cryptography")
|
||||
@ -361,7 +361,7 @@ def verify_authentication_response(
|
||||
|
||||
|
||||
def authenticate_device(
|
||||
client: TrezorClient,
|
||||
session: Session,
|
||||
challenge: bytes | None = None,
|
||||
*,
|
||||
whitelist: t.Collection[bytes] | None = None,
|
||||
@ -371,7 +371,7 @@ def authenticate_device(
|
||||
if challenge is None:
|
||||
challenge = secrets.token_bytes(16)
|
||||
|
||||
resp = device.authenticate(client, challenge)
|
||||
resp = device.authenticate(session, challenge)
|
||||
|
||||
return verify_authentication_response(
|
||||
challenge,
|
||||
|
@ -19,16 +19,16 @@ from typing import TYPE_CHECKING
|
||||
from . import messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def list_names(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
) -> messages.BenchmarkNames:
|
||||
return client.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames)
|
||||
return session.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames)
|
||||
|
||||
|
||||
def run(client: "TrezorClient", name: str) -> messages.BenchmarkResult:
|
||||
return client.call(
|
||||
def run(session: "Session", name: str) -> messages.BenchmarkResult:
|
||||
return session.call(
|
||||
messages.BenchmarkRun(name=name), expect=messages.BenchmarkResult
|
||||
)
|
||||
|
@ -18,20 +18,19 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from . import messages
|
||||
from .protobuf import dict_to_proto
|
||||
from .tools import session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.BinanceGetAddress(
|
||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||
),
|
||||
@ -40,17 +39,16 @@ def get_address(
|
||||
|
||||
|
||||
def get_public_key(
|
||||
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
||||
session: "Session", address_n: "Address", show_display: bool = False
|
||||
) -> bytes:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display),
|
||||
expect=messages.BinancePublicKey,
|
||||
).public_key
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx(
|
||||
client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False
|
||||
session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False
|
||||
) -> messages.BinanceSignedTx:
|
||||
msg = tx_json["msgs"][0]
|
||||
tx_msg = tx_json.copy()
|
||||
@ -59,7 +57,7 @@ def sign_tx(
|
||||
tx_msg["chunkify"] = chunkify
|
||||
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
|
||||
|
||||
client.call(envelope, expect=messages.BinanceTxRequest)
|
||||
session.call(envelope, expect=messages.BinanceTxRequest)
|
||||
|
||||
if "refid" in msg:
|
||||
msg = dict_to_proto(messages.BinanceCancelMsg, msg)
|
||||
@ -70,4 +68,4 @@ def sign_tx(
|
||||
else:
|
||||
raise ValueError("can not determine msg type")
|
||||
|
||||
return client.call(msg, expect=messages.BinanceSignedTx)
|
||||
return session.call(msg, expect=messages.BinanceSignedTx)
|
||||
|
@ -25,11 +25,11 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple
|
||||
from typing_extensions import Protocol, TypedDict
|
||||
|
||||
from . import exceptions, messages
|
||||
from .tools import _return_success, prepare_message_bytes, session
|
||||
from .tools import _return_success, prepare_message_bytes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
class ScriptSig(TypedDict):
|
||||
asm: str
|
||||
@ -105,7 +105,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType:
|
||||
|
||||
|
||||
def get_public_node(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
ecdsa_curve_name: Optional[str] = None,
|
||||
show_display: bool = False,
|
||||
@ -116,12 +116,12 @@ def get_public_node(
|
||||
unlock_path_mac: Optional[bytes] = None,
|
||||
) -> messages.PublicKey:
|
||||
if unlock_path:
|
||||
client.call(
|
||||
session.call(
|
||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
|
||||
expect=messages.UnlockedPathRequest,
|
||||
)
|
||||
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.GetPublicKey(
|
||||
address_n=n,
|
||||
ecdsa_curve_name=ecdsa_curve_name,
|
||||
@ -139,7 +139,7 @@ def get_address(*args: Any, **kwargs: Any) -> str:
|
||||
|
||||
|
||||
def get_authenticated_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
show_display: bool = False,
|
||||
@ -151,12 +151,12 @@ def get_authenticated_address(
|
||||
chunkify: bool = False,
|
||||
) -> messages.Address:
|
||||
if unlock_path:
|
||||
client.call(
|
||||
session.call(
|
||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
|
||||
expect=messages.UnlockedPathRequest,
|
||||
)
|
||||
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.GetAddress(
|
||||
address_n=n,
|
||||
coin_name=coin_name,
|
||||
@ -171,13 +171,13 @@ def get_authenticated_address(
|
||||
|
||||
|
||||
def get_ownership_id(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||
) -> bytes:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.GetOwnershipId(
|
||||
address_n=n,
|
||||
coin_name=coin_name,
|
||||
@ -188,8 +188,9 @@ def get_ownership_id(
|
||||
).ownership_id
|
||||
|
||||
|
||||
# TODO this is used by tests only
|
||||
def get_ownership_proof(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||
@ -200,9 +201,9 @@ def get_ownership_proof(
|
||||
preauthorized: bool = False,
|
||||
) -> Tuple[bytes, bytes]:
|
||||
if preauthorized:
|
||||
client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
|
||||
session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
|
||||
|
||||
res = client.call(
|
||||
res = session.call(
|
||||
messages.GetOwnershipProof(
|
||||
address_n=n,
|
||||
coin_name=coin_name,
|
||||
@ -219,7 +220,7 @@ def get_ownership_proof(
|
||||
|
||||
|
||||
def sign_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
message: AnyStr,
|
||||
@ -227,7 +228,7 @@ def sign_message(
|
||||
no_script_type: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> messages.MessageSignature:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.SignMessage(
|
||||
coin_name=coin_name,
|
||||
address_n=n,
|
||||
@ -241,7 +242,7 @@ def sign_message(
|
||||
|
||||
|
||||
def verify_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
address: str,
|
||||
signature: bytes,
|
||||
@ -249,7 +250,7 @@ def verify_message(
|
||||
chunkify: bool = False,
|
||||
) -> bool:
|
||||
try:
|
||||
client.call(
|
||||
session.call(
|
||||
messages.VerifyMessage(
|
||||
address=address,
|
||||
signature=signature,
|
||||
@ -264,9 +265,9 @@ def verify_message(
|
||||
return False
|
||||
|
||||
|
||||
@session
|
||||
# @session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
inputs: Sequence[messages.TxInputType],
|
||||
outputs: Sequence[messages.TxOutputType],
|
||||
@ -314,14 +315,14 @@ def sign_tx(
|
||||
setattr(signtx, name, value)
|
||||
|
||||
if unlock_path:
|
||||
client.call(
|
||||
session.call(
|
||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
|
||||
expect=messages.UnlockedPathRequest,
|
||||
)
|
||||
elif preauthorized:
|
||||
client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
|
||||
session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
|
||||
|
||||
res = client.call(signtx, expect=messages.TxRequest)
|
||||
res = session.call(signtx, expect=messages.TxRequest)
|
||||
|
||||
# Prepare structure for signatures
|
||||
signatures: List[Optional[bytes]] = [None] * len(inputs)
|
||||
@ -380,7 +381,7 @@ def sign_tx(
|
||||
if res.request_type == R.TXPAYMENTREQ:
|
||||
assert res.details.request_index is not None
|
||||
msg = payment_reqs[res.details.request_index]
|
||||
res = client.call(msg, expect=messages.TxRequest)
|
||||
res = session.call(msg, expect=messages.TxRequest)
|
||||
else:
|
||||
msg = messages.TransactionType()
|
||||
if res.request_type == R.TXMETA:
|
||||
@ -410,7 +411,7 @@ def sign_tx(
|
||||
f"Unknown request type - {res.request_type}."
|
||||
)
|
||||
|
||||
res = client.call(messages.TxAck(tx=msg), expect=messages.TxRequest)
|
||||
res = session.call(messages.TxAck(tx=msg), expect=messages.TxRequest)
|
||||
|
||||
for i, sig in zip(inputs, signatures):
|
||||
if i.script_type != messages.InputScriptType.EXTERNAL and sig is None:
|
||||
@ -420,7 +421,7 @@ def sign_tx(
|
||||
|
||||
|
||||
def authorize_coinjoin(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coordinator: str,
|
||||
max_rounds: int,
|
||||
max_coordinator_fee_rate: int,
|
||||
@ -429,7 +430,7 @@ def authorize_coinjoin(
|
||||
coin_name: str,
|
||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||
) -> str | None:
|
||||
resp = client.call(
|
||||
resp = session.call(
|
||||
messages.AuthorizeCoinJoin(
|
||||
coordinator=coordinator,
|
||||
max_rounds=max_rounds,
|
||||
|
@ -35,7 +35,7 @@ from . import messages as m
|
||||
from . import tools
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .transport.session import Session
|
||||
|
||||
PROTOCOL_MAGICS = {
|
||||
"mainnet": 764824073,
|
||||
@ -818,7 +818,7 @@ def _get_collateral_inputs_items(
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_parameters: m.CardanoAddressParametersType,
|
||||
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
|
||||
network_id: int = NETWORK_IDS["mainnet"],
|
||||
@ -826,7 +826,7 @@ def get_address(
|
||||
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
return client.call(
|
||||
return session.call(
|
||||
m.CardanoGetAddress(
|
||||
address_parameters=address_parameters,
|
||||
protocol_magic=protocol_magic,
|
||||
@ -840,12 +840,12 @@ def get_address(
|
||||
|
||||
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: List[int],
|
||||
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
|
||||
show_display: bool = False,
|
||||
) -> m.CardanoPublicKey:
|
||||
return client.call(
|
||||
return session.call(
|
||||
m.CardanoGetPublicKey(
|
||||
address_n=address_n,
|
||||
derivation_type=derivation_type,
|
||||
@ -856,12 +856,12 @@ def get_public_key(
|
||||
|
||||
|
||||
def get_native_script_hash(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
native_script: m.CardanoNativeScript,
|
||||
display_format: m.CardanoNativeScriptHashDisplayFormat = m.CardanoNativeScriptHashDisplayFormat.HIDE,
|
||||
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
|
||||
) -> m.CardanoNativeScriptHash:
|
||||
return client.call(
|
||||
return session.call(
|
||||
m.CardanoGetNativeScriptHash(
|
||||
script=native_script,
|
||||
display_format=display_format,
|
||||
@ -872,7 +872,7 @@ def get_native_script_hash(
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
signing_mode: m.CardanoTxSigningMode,
|
||||
inputs: List[InputWithPath],
|
||||
outputs: List[OutputWithData],
|
||||
@ -907,7 +907,7 @@ def sign_tx(
|
||||
signing_mode,
|
||||
)
|
||||
|
||||
response = client.call(
|
||||
response = session.call(
|
||||
m.CardanoSignTxInit(
|
||||
signing_mode=signing_mode,
|
||||
inputs_count=len(inputs),
|
||||
@ -942,12 +942,12 @@ def sign_tx(
|
||||
_get_certificates_items(certificates),
|
||||
withdrawals,
|
||||
):
|
||||
response = client.call(tx_item, expect=m.CardanoTxItemAck)
|
||||
response = session.call(tx_item, expect=m.CardanoTxItemAck)
|
||||
|
||||
sign_tx_response: Dict[str, Any] = {}
|
||||
|
||||
if auxiliary_data is not None:
|
||||
auxiliary_data_supplement = client.call(
|
||||
auxiliary_data_supplement = session.call(
|
||||
auxiliary_data, expect=m.CardanoTxAuxiliaryDataSupplement
|
||||
)
|
||||
if (
|
||||
@ -958,25 +958,25 @@ def sign_tx(
|
||||
auxiliary_data_supplement.__dict__
|
||||
)
|
||||
|
||||
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck)
|
||||
response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck)
|
||||
|
||||
for tx_item in chain(
|
||||
_get_mint_items(mint),
|
||||
_get_collateral_inputs_items(collateral_inputs),
|
||||
required_signers,
|
||||
):
|
||||
response = client.call(tx_item, expect=m.CardanoTxItemAck)
|
||||
response = session.call(tx_item, expect=m.CardanoTxItemAck)
|
||||
|
||||
if collateral_return is not None:
|
||||
for tx_item in _get_output_items(collateral_return):
|
||||
response = client.call(tx_item, expect=m.CardanoTxItemAck)
|
||||
response = session.call(tx_item, expect=m.CardanoTxItemAck)
|
||||
|
||||
for reference_input in reference_inputs:
|
||||
response = client.call(reference_input, expect=m.CardanoTxItemAck)
|
||||
response = session.call(reference_input, expect=m.CardanoTxItemAck)
|
||||
|
||||
sign_tx_response["witnesses"] = []
|
||||
for witness_request in witness_requests:
|
||||
response = client.call(witness_request, expect=m.CardanoTxWitnessResponse)
|
||||
response = session.call(witness_request, expect=m.CardanoTxWitnessResponse)
|
||||
sign_tx_response["witnesses"].append(
|
||||
{
|
||||
"type": response.type,
|
||||
@ -986,9 +986,9 @@ def sign_tx(
|
||||
}
|
||||
)
|
||||
|
||||
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash)
|
||||
response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash)
|
||||
sign_tx_response["tx_hash"] = response.tx_hash
|
||||
|
||||
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished)
|
||||
response = session.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished)
|
||||
|
||||
return sign_tx_response
|
||||
|
@ -13,28 +13,24 @@
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
||||
import typing as t
|
||||
from enum import IntEnum
|
||||
|
||||
from mnemonic import Mnemonic
|
||||
from . import mapping, messages, models
|
||||
from .mapping import ProtobufMapping
|
||||
from .tools import parse_path
|
||||
from .transport import Transport, get_transport
|
||||
from .transport.thp.channel_data import ChannelData
|
||||
from .transport.thp.protocol_and_channel import ProtocolAndChannel
|
||||
from .transport.thp.protocol_v1 import ProtocolV1
|
||||
from .transport.thp.protocol_v2 import ProtocolV2
|
||||
|
||||
from . import exceptions, mapping, messages, models
|
||||
from .log import DUMP_BYTES
|
||||
from .messages import Capability
|
||||
from .protobuf import MessageType
|
||||
from .tools import parse_path, session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .transport import Transport
|
||||
from .ui import TrezorClientUI
|
||||
|
||||
UI = TypeVar("UI", bound="TrezorClientUI")
|
||||
MT = TypeVar("MT", bound=MessageType)
|
||||
if t.TYPE_CHECKING:
|
||||
from .transport.session import Session
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -51,8 +47,205 @@ Or visit https://suite.trezor.io/
|
||||
""".strip()
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtocolVersion(IntEnum):
|
||||
UNKNOWN = 0x00
|
||||
PROTOCOL_V1 = 0x01 # Codec
|
||||
PROTOCOL_V2 = 0x02 # THP
|
||||
|
||||
|
||||
class TrezorClient:
|
||||
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||
|
||||
_seedless_session: Session | None = None
|
||||
_features: messages.Features | None = None
|
||||
_protocol_version: int
|
||||
_setup_pin: str | None = None # Should by used only by conftest
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: Transport,
|
||||
protobuf_mapping: ProtobufMapping | None = None,
|
||||
protocol: ProtocolAndChannel | None = None,
|
||||
) -> None:
|
||||
self._is_invalidated: bool = False
|
||||
self.transport = transport
|
||||
|
||||
if protobuf_mapping is None:
|
||||
self.mapping = mapping.DEFAULT_MAPPING
|
||||
else:
|
||||
self.mapping = protobuf_mapping
|
||||
if protocol is None:
|
||||
self.protocol = self._get_protocol()
|
||||
else:
|
||||
self.protocol = protocol
|
||||
self.protocol.mapping = self.mapping
|
||||
|
||||
if isinstance(self.protocol, ProtocolV1):
|
||||
self._protocol_version = ProtocolVersion.PROTOCOL_V1
|
||||
elif isinstance(self.protocol, ProtocolV2):
|
||||
self._protocol_version = ProtocolVersion.PROTOCOL_V2
|
||||
else:
|
||||
self._protocol_version = ProtocolVersion.UNKNOWN
|
||||
|
||||
@classmethod
|
||||
def resume(
|
||||
cls,
|
||||
transport: Transport,
|
||||
channel_data: ChannelData,
|
||||
protobuf_mapping: ProtobufMapping | None = None,
|
||||
) -> TrezorClient:
|
||||
if protobuf_mapping is None:
|
||||
protobuf_mapping = mapping.DEFAULT_MAPPING
|
||||
protocol_v1 = ProtocolV1(transport, protobuf_mapping)
|
||||
if channel_data.protocol_version_major == 2:
|
||||
try:
|
||||
protocol_v1.write(messages.Ping(message="Sanity check - to resume"))
|
||||
except Exception as e:
|
||||
print(type(e))
|
||||
response = protocol_v1.read()
|
||||
if (
|
||||
isinstance(response, messages.Failure)
|
||||
and response.code == messages.FailureType.InvalidProtocol
|
||||
):
|
||||
protocol = ProtocolV2(transport, protobuf_mapping, channel_data)
|
||||
protocol.write(0, messages.Ping())
|
||||
response = protocol.read(0)
|
||||
if not isinstance(response, messages.Success):
|
||||
LOG.debug("Failed to resume ProtocolV2")
|
||||
raise Exception("Failed to resume ProtocolV2")
|
||||
LOG.debug("Protocol V2 detected - can be resumed")
|
||||
else:
|
||||
LOG.debug("Failed to resume ProtocolV2")
|
||||
raise Exception("Failed to resume ProtocolV2")
|
||||
else:
|
||||
protocol = ProtocolV1(transport, protobuf_mapping, channel_data)
|
||||
return TrezorClient(transport, protobuf_mapping, protocol)
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
passphrase: str | object | None = None,
|
||||
derive_cardano: bool = False,
|
||||
session_id: int = 0,
|
||||
) -> Session:
|
||||
"""
|
||||
Returns initialized session (with derived seed).
|
||||
|
||||
Will fail if the device is not initialized
|
||||
"""
|
||||
from .transport.session import SessionV1, SessionV2
|
||||
|
||||
if isinstance(self.protocol, ProtocolV1):
|
||||
if passphrase is None:
|
||||
passphrase = ""
|
||||
return SessionV1.new(self, passphrase, derive_cardano)
|
||||
if isinstance(self.protocol, ProtocolV2):
|
||||
assert isinstance(passphrase, str) or passphrase is None
|
||||
return SessionV2.new(self, passphrase, derive_cardano, session_id)
|
||||
raise NotImplementedError # TODO
|
||||
|
||||
def resume_session(self, session: Session):
|
||||
"""
|
||||
Note: this function potentially modifies the input session.
|
||||
"""
|
||||
from .debuglink import SessionDebugWrapper
|
||||
from .transport.session import SessionV1, SessionV2
|
||||
|
||||
if isinstance(session, SessionDebugWrapper):
|
||||
session = session._session
|
||||
|
||||
if isinstance(session, SessionV2):
|
||||
return session
|
||||
elif isinstance(session, SessionV1):
|
||||
session.init_session()
|
||||
return session
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_seedless_session(self, new_session: bool = False) -> Session:
|
||||
from .transport.session import SessionV1, SessionV2
|
||||
|
||||
if not new_session and self._seedless_session is not None:
|
||||
return self._seedless_session
|
||||
if isinstance(self.protocol, ProtocolV1):
|
||||
self._seedless_session = SessionV1.new(
|
||||
client=self,
|
||||
passphrase="",
|
||||
derive_cardano=False,
|
||||
)
|
||||
elif isinstance(self.protocol, ProtocolV2):
|
||||
self._seedless_session = SessionV2(client=self, id=b"\x00")
|
||||
assert self._seedless_session is not None
|
||||
return self._seedless_session
|
||||
|
||||
def invalidate(self) -> None:
|
||||
self._is_invalidated = True
|
||||
|
||||
@property
|
||||
def features(self) -> messages.Features:
|
||||
if self._features is None:
|
||||
self._features = self.protocol.get_features()
|
||||
assert self._features is not None
|
||||
return self._features
|
||||
|
||||
@property
|
||||
def protocol_version(self) -> int:
|
||||
return self._protocol_version
|
||||
|
||||
@property
|
||||
def model(self) -> models.TrezorModel:
|
||||
f = self.features
|
||||
model = models.by_name(f.model or "1")
|
||||
|
||||
if model is None:
|
||||
raise RuntimeError(
|
||||
"Unsupported Trezor model"
|
||||
f" (internal_model: {f.internal_model}, model: {f.model})"
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def version(self) -> tuple[int, int, int]:
|
||||
f = self.features
|
||||
ver = (
|
||||
f.major_version,
|
||||
f.minor_version,
|
||||
f.patch_version,
|
||||
)
|
||||
return ver
|
||||
|
||||
@property
|
||||
def is_invalidated(self) -> bool:
|
||||
return self._is_invalidated
|
||||
|
||||
def refresh_features(self) -> None:
|
||||
self.protocol.update_features()
|
||||
self._features = self.protocol.get_features()
|
||||
|
||||
def _get_protocol(self) -> ProtocolAndChannel:
|
||||
self.transport.open()
|
||||
|
||||
protocol = ProtocolV1(self.transport, mapping.DEFAULT_MAPPING)
|
||||
|
||||
protocol.write(messages.Initialize())
|
||||
|
||||
response = protocol.read()
|
||||
self.transport.close()
|
||||
if isinstance(response, messages.Failure):
|
||||
if response.code == messages.FailureType.InvalidProtocol:
|
||||
LOG.debug("Protocol V2 detected")
|
||||
protocol = ProtocolV2(self.transport, self.mapping)
|
||||
return protocol
|
||||
|
||||
|
||||
def get_default_client(
|
||||
path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any
|
||||
path: t.Optional[str] = None,
|
||||
**kwargs: t.Any,
|
||||
) -> "TrezorClient":
|
||||
"""Get a client for a connected Trezor device.
|
||||
|
||||
@ -62,436 +255,10 @@ def get_default_client(
|
||||
the value of TREZOR_PATH env variable, or finds first connected Trezor.
|
||||
If no UI is supplied, instantiates the default CLI UI.
|
||||
"""
|
||||
from .transport import get_transport
|
||||
from .ui import ClickUI
|
||||
|
||||
if path is None:
|
||||
path = os.getenv("TREZOR_PATH")
|
||||
|
||||
transport = get_transport(path, prefix_search=True)
|
||||
if ui is None:
|
||||
ui = ClickUI()
|
||||
|
||||
return TrezorClient(transport, ui, **kwargs)
|
||||
|
||||
|
||||
class TrezorClient(Generic[UI]):
|
||||
"""Trezor client, a connection to a Trezor device.
|
||||
|
||||
This class allows you to manage connection state, send and receive protobuf
|
||||
messages, handle user interactions, and perform some generic tasks
|
||||
(send a cancel message, initialize or clear a session, ping the device).
|
||||
"""
|
||||
|
||||
model: models.TrezorModel
|
||||
transport: "Transport"
|
||||
session_id: Optional[bytes]
|
||||
ui: UI
|
||||
features: messages.Features
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: "Transport",
|
||||
ui: UI,
|
||||
session_id: Optional[bytes] = None,
|
||||
derive_cardano: Optional[bool] = None,
|
||||
model: Optional[models.TrezorModel] = None,
|
||||
_init_device: bool = True,
|
||||
) -> None:
|
||||
"""Create a TrezorClient instance.
|
||||
|
||||
You have to provide a `transport`, i.e., a raw connection to the device. You can
|
||||
use `trezorlib.transport.get_transport` to find one.
|
||||
|
||||
You have to provide a UI implementation for the three kinds of interaction:
|
||||
- button request (notify the user that their interaction is needed)
|
||||
- PIN request (on T1, ask the user to input numbers for a PIN matrix)
|
||||
- passphrase request (ask the user to enter a passphrase) See `trezorlib.ui` for
|
||||
details.
|
||||
|
||||
You can supply a `session_id` you might have saved in the previous session. If
|
||||
you do, the user might not need to enter their passphrase again.
|
||||
|
||||
You can provide Trezor model information. If not provided, it is detected from
|
||||
the model name reported at initialization time.
|
||||
|
||||
By default, the instance will open a connection to the Trezor device, send an
|
||||
`Initialize` message, set up the `features` field from the response, and connect
|
||||
to a session. By specifying `_init_device=False`, this step is skipped. Notably,
|
||||
this means that `client.features` is unset. Use `client.init_device()` or
|
||||
`client.refresh_features()` to fix that, otherwise A LOT OF THINGS will break.
|
||||
Only use this if you are _sure_ that you know what you are doing. This feature
|
||||
might be removed at any time.
|
||||
"""
|
||||
LOG.info(f"creating client instance for device: {transport.get_path()}")
|
||||
# Here, self.model could be set to None. Unless _init_device is False, it will
|
||||
# get correctly reconfigured as part of the init_device flow.
|
||||
self.model = model # type: ignore ["None" is incompatible with "TrezorModel"]
|
||||
if self.model:
|
||||
self.mapping = self.model.default_mapping
|
||||
else:
|
||||
self.mapping = mapping.DEFAULT_MAPPING
|
||||
self.transport = transport
|
||||
self.ui = ui
|
||||
self.session_counter = 0
|
||||
self.session_id = session_id
|
||||
if _init_device:
|
||||
self.init_device(session_id=session_id, derive_cardano=derive_cardano)
|
||||
|
||||
def open(self) -> None:
|
||||
if self.session_counter == 0:
|
||||
self.transport.begin_session()
|
||||
self.session_counter += 1
|
||||
|
||||
def close(self) -> None:
|
||||
self.session_counter = max(self.session_counter - 1, 0)
|
||||
if self.session_counter == 0:
|
||||
# TODO call EndSession here?
|
||||
self.transport.end_session()
|
||||
|
||||
def cancel(self) -> None:
|
||||
self._raw_write(messages.Cancel())
|
||||
|
||||
def call_raw(self, msg: MessageType) -> MessageType:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
self._raw_write(msg)
|
||||
return self._raw_read()
|
||||
|
||||
def _raw_write(self, msg: MessageType) -> None:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
LOG.debug(
|
||||
f"sending message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
msg_type, msg_bytes = self.mapping.encode(msg)
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
self.transport.write(msg_type, msg_bytes)
|
||||
|
||||
def _raw_read(self) -> MessageType:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
msg_type, msg_bytes = self.transport.read()
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
msg = self.mapping.decode(msg_type, msg_bytes)
|
||||
LOG.debug(
|
||||
f"received message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
return msg
|
||||
|
||||
def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType:
|
||||
try:
|
||||
pin = self.ui.get_pin(msg.type)
|
||||
except exceptions.Cancelled:
|
||||
self.call_raw(messages.Cancel())
|
||||
raise
|
||||
|
||||
if any(d not in "123456789" for d in pin) or not (
|
||||
1 <= len(pin) <= MAX_PIN_LENGTH
|
||||
):
|
||||
self.call_raw(messages.Cancel())
|
||||
raise ValueError("Invalid PIN provided")
|
||||
|
||||
resp = self.call_raw(messages.PinMatrixAck(pin=pin))
|
||||
if isinstance(resp, messages.Failure) and resp.code in (
|
||||
messages.FailureType.PinInvalid,
|
||||
messages.FailureType.PinCancelled,
|
||||
messages.FailureType.PinExpected,
|
||||
):
|
||||
raise exceptions.PinException(resp.code, resp.message)
|
||||
else:
|
||||
return resp
|
||||
|
||||
def _callback_passphrase(self, msg: messages.PassphraseRequest) -> MessageType:
|
||||
available_on_device = Capability.PassphraseEntry in self.features.capabilities
|
||||
|
||||
def send_passphrase(
|
||||
passphrase: Optional[str] = None, on_device: Optional[bool] = None
|
||||
) -> MessageType:
|
||||
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
||||
resp = self.call_raw(msg)
|
||||
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
||||
self.session_id = resp.state
|
||||
resp = self.call_raw(messages.Deprecated_PassphraseStateAck())
|
||||
return resp
|
||||
|
||||
# short-circuit old style entry
|
||||
if msg._on_device is True:
|
||||
return send_passphrase(None, None)
|
||||
|
||||
try:
|
||||
passphrase = self.ui.get_passphrase(available_on_device=available_on_device)
|
||||
except exceptions.Cancelled:
|
||||
self.call_raw(messages.Cancel())
|
||||
raise
|
||||
|
||||
if passphrase is PASSPHRASE_ON_DEVICE:
|
||||
if not available_on_device:
|
||||
self.call_raw(messages.Cancel())
|
||||
raise RuntimeError("Device is not capable of entering passphrase")
|
||||
else:
|
||||
return send_passphrase(on_device=True)
|
||||
|
||||
# else process host-entered passphrase
|
||||
if not isinstance(passphrase, str):
|
||||
raise RuntimeError("Passphrase must be a str")
|
||||
passphrase = Mnemonic.normalize_string(passphrase)
|
||||
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
|
||||
self.call_raw(messages.Cancel())
|
||||
raise ValueError("Passphrase too long")
|
||||
|
||||
return send_passphrase(passphrase, on_device=False)
|
||||
|
||||
def _callback_button(self, msg: messages.ButtonRequest) -> MessageType:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
# do this raw - send ButtonAck first, notify UI later
|
||||
self._raw_write(messages.ButtonAck())
|
||||
self.ui.button_request(msg)
|
||||
return self._raw_read()
|
||||
|
||||
@session
|
||||
def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT:
|
||||
self.check_firmware_version()
|
||||
resp = self.call_raw(msg)
|
||||
while True:
|
||||
if isinstance(resp, messages.PinMatrixRequest):
|
||||
resp = self._callback_pin(resp)
|
||||
elif isinstance(resp, messages.PassphraseRequest):
|
||||
resp = self._callback_passphrase(resp)
|
||||
elif isinstance(resp, messages.ButtonRequest):
|
||||
resp = self._callback_button(resp)
|
||||
elif isinstance(resp, messages.Failure):
|
||||
if resp.code == messages.FailureType.ActionCancelled:
|
||||
raise exceptions.Cancelled
|
||||
raise exceptions.TrezorFailure(resp)
|
||||
elif not isinstance(resp, expect):
|
||||
raise exceptions.UnexpectedMessageError(expect, resp)
|
||||
else:
|
||||
return resp
|
||||
|
||||
def _refresh_features(self, features: messages.Features) -> None:
|
||||
"""Update internal fields based on passed-in Features message."""
|
||||
|
||||
if not self.model:
|
||||
# Trezor Model One bootloader 1.8.0 or older does not send model name
|
||||
model = models.by_internal_name(features.internal_model)
|
||||
if model is None:
|
||||
model = models.by_name(features.model or "1")
|
||||
if model is None:
|
||||
raise RuntimeError(
|
||||
"Unsupported Trezor model"
|
||||
f" (internal_model: {features.internal_model}, model: {features.model})"
|
||||
)
|
||||
self.model = model
|
||||
|
||||
if features.vendor not in self.model.vendors:
|
||||
raise RuntimeError("Unsupported device")
|
||||
|
||||
self.features = features
|
||||
self.version = (
|
||||
self.features.major_version,
|
||||
self.features.minor_version,
|
||||
self.features.patch_version,
|
||||
)
|
||||
self.check_firmware_version(warn_only=True)
|
||||
if self.features.session_id is not None:
|
||||
self.session_id = self.features.session_id
|
||||
self.features.session_id = None
|
||||
|
||||
@session
|
||||
def refresh_features(self) -> messages.Features:
|
||||
"""Reload features from the device.
|
||||
|
||||
Should be called after changing settings or performing operations that affect
|
||||
device state.
|
||||
"""
|
||||
resp = self.call_raw(messages.GetFeatures())
|
||||
if not isinstance(resp, messages.Features):
|
||||
raise exceptions.TrezorException("Unexpected response to GetFeatures")
|
||||
self._refresh_features(resp)
|
||||
return resp
|
||||
|
||||
@session
|
||||
def init_device(
|
||||
self,
|
||||
*,
|
||||
session_id: Optional[bytes] = None,
|
||||
new_session: bool = False,
|
||||
derive_cardano: Optional[bool] = None,
|
||||
) -> Optional[bytes]:
|
||||
"""Initialize the device and return a session ID.
|
||||
|
||||
You can optionally specify a session ID. If the session still exists on the
|
||||
device, the same session ID will be returned and the session is resumed.
|
||||
Otherwise a different session ID is returned.
|
||||
|
||||
Specify `new_session=True` to open a fresh session. Since firmware version
|
||||
1.9.0/2.3.0, the previous session will remain cached on the device, and can be
|
||||
resumed by calling `init_device` again with the appropriate session ID.
|
||||
|
||||
If neither `new_session` nor `session_id` is specified, the current session ID
|
||||
will be reused. If no session ID was cached, a new session ID will be allocated
|
||||
and returned.
|
||||
|
||||
# Version notes:
|
||||
|
||||
Trezor One older than 1.9.0 does not have session management. Optional arguments
|
||||
have no effect and the function returns None
|
||||
|
||||
Trezor T older than 2.3.0 does not have session cache. Requesting a new session
|
||||
will overwrite the old one. In addition, this function will always return None.
|
||||
A valid session_id can be obtained from the `session_id` attribute, but only
|
||||
after a passphrase-protected call is performed. You can use the following code:
|
||||
|
||||
>>> client.init_device()
|
||||
>>> client.ensure_unlocked()
|
||||
>>> valid_session_id = client.session_id
|
||||
"""
|
||||
if new_session:
|
||||
self.session_id = None
|
||||
elif session_id is not None:
|
||||
self.session_id = session_id
|
||||
|
||||
resp = self.call_raw(
|
||||
messages.Initialize(
|
||||
session_id=self.session_id,
|
||||
derive_cardano=derive_cardano,
|
||||
)
|
||||
)
|
||||
if isinstance(resp, messages.Failure):
|
||||
# can happen if `derive_cardano` does not match the current session
|
||||
raise exceptions.TrezorFailure(resp)
|
||||
if not isinstance(resp, messages.Features):
|
||||
raise exceptions.TrezorException("Unexpected response to Initialize")
|
||||
|
||||
if self.session_id is not None and resp.session_id == self.session_id:
|
||||
LOG.info("Successfully resumed session")
|
||||
elif session_id is not None:
|
||||
LOG.info("Failed to resume session")
|
||||
|
||||
# TT < 2.3.0 compatibility:
|
||||
# _refresh_features will clear out the session_id field. We want this function
|
||||
# to return its value, so that callers can rely on it being either a valid
|
||||
# session_id, or None if we can't do that.
|
||||
# Older TT FW does not report session_id in Features and self.session_id might
|
||||
# be invalid because TT will not allocate a session_id until a passphrase
|
||||
# exchange happens.
|
||||
reported_session_id = resp.session_id
|
||||
self._refresh_features(resp)
|
||||
return reported_session_id
|
||||
|
||||
def is_outdated(self) -> bool:
|
||||
if self.features.bootloader_mode:
|
||||
return False
|
||||
return self.version < self.model.minimum_version
|
||||
|
||||
def check_firmware_version(self, warn_only: bool = False) -> None:
|
||||
if self.is_outdated():
|
||||
if warn_only:
|
||||
warnings.warn("Firmware is out of date", stacklevel=2)
|
||||
else:
|
||||
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
|
||||
|
||||
def ping(self, msg: str, button_protection: bool = False) -> str:
|
||||
# We would like ping to work on any valid TrezorClient instance, but
|
||||
# due to the protection modes, we need to go through self.call, and that will
|
||||
# raise an exception if the firmware is too old.
|
||||
# So we short-circuit the simplest variant of ping with call_raw.
|
||||
if not button_protection:
|
||||
# XXX this should be: `with self:`
|
||||
try:
|
||||
self.open()
|
||||
resp = self.call_raw(messages.Ping(message=msg))
|
||||
if isinstance(resp, messages.ButtonRequest):
|
||||
# device is PIN-locked.
|
||||
# respond and hope for the best
|
||||
resp = self._callback_button(resp)
|
||||
resp = messages.Success.ensure_isinstance(resp)
|
||||
assert resp.message is not None
|
||||
return resp.message
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
resp = self.call(
|
||||
messages.Ping(message=msg, button_protection=button_protection),
|
||||
expect=messages.Success,
|
||||
)
|
||||
assert resp.message is not None
|
||||
return resp.message
|
||||
|
||||
def get_device_id(self) -> Optional[str]:
|
||||
return self.features.device_id
|
||||
|
||||
@session
|
||||
def lock(self, *, _refresh_features: bool = True) -> None:
|
||||
"""Lock the device.
|
||||
|
||||
If the device does not have a PIN configured, this will do nothing.
|
||||
Otherwise, a lock screen will be shown and the device will prompt for PIN
|
||||
before further actions.
|
||||
|
||||
This call does _not_ invalidate passphrase cache. If passphrase is in use,
|
||||
the device will not prompt for it after unlocking.
|
||||
|
||||
To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate
|
||||
passphrase cache, use `clear_session()`.
|
||||
"""
|
||||
# Private argument _refresh_features can be used internally to avoid
|
||||
# refreshing in cases where we will refresh soon anyway. This is used
|
||||
# in TrezorClient.clear_session()
|
||||
self.call(messages.LockDevice())
|
||||
if _refresh_features:
|
||||
self.refresh_features()
|
||||
|
||||
@session
|
||||
def ensure_unlocked(self) -> None:
|
||||
"""Ensure the device is unlocked and a passphrase is cached.
|
||||
|
||||
If the device is locked, this will prompt for PIN. If passphrase is enabled
|
||||
and no passphrase is cached for the current session, the device will also
|
||||
prompt for passphrase.
|
||||
|
||||
After calling this method, further actions on the device will not prompt for
|
||||
PIN or passphrase until the device is locked or the session becomes invalid.
|
||||
"""
|
||||
from .btc import get_address
|
||||
|
||||
get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
|
||||
self.refresh_features()
|
||||
|
||||
def end_session(self) -> None:
|
||||
"""Close the current session and clear cached passphrase.
|
||||
|
||||
The session will become invalid until `init_device()` is called again.
|
||||
If passphrase is enabled, further actions will prompt for it again.
|
||||
|
||||
This is a no-op in bootloader mode, as it does not support session management.
|
||||
"""
|
||||
# since: 2.3.4, 1.9.4
|
||||
try:
|
||||
if not self.features.bootloader_mode:
|
||||
self.call(messages.EndSession())
|
||||
except exceptions.TrezorFailure:
|
||||
# A failure most likely means that the FW version does not support
|
||||
# the EndSession call. We ignore the failure and clear the local session_id.
|
||||
# The client-side end result is identical.
|
||||
pass
|
||||
self.session_id = None
|
||||
|
||||
@session
|
||||
def clear_session(self) -> None:
|
||||
"""Lock the device and present a fresh session.
|
||||
|
||||
The current session will be invalidated and a new one will be started. If the
|
||||
device has PIN enabled, it will become locked.
|
||||
|
||||
Equivalent to calling `lock()`, `end_session()` and `init_device()`.
|
||||
"""
|
||||
self.lock(_refresh_features=False)
|
||||
self.end_session()
|
||||
self.init_device(new_session=True)
|
||||
return TrezorClient(transport, **kwargs)
|
||||
|
@ -21,55 +21,55 @@ import logging
|
||||
import re
|
||||
import textwrap
|
||||
import time
|
||||
import typing as t
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from enum import Enum, IntEnum, auto
|
||||
from itertools import zip_longest
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from mnemonic import Mnemonic
|
||||
|
||||
from . import mapping, messages, models, protobuf
|
||||
from .client import TrezorClient
|
||||
from .exceptions import TrezorFailure, UnexpectedMessageError
|
||||
from . import btc, mapping, messages, models, protobuf
|
||||
from .client import (
|
||||
MAX_PASSPHRASE_LENGTH,
|
||||
MAX_PIN_LENGTH,
|
||||
PASSPHRASE_ON_DEVICE,
|
||||
TrezorClient,
|
||||
)
|
||||
from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError
|
||||
from .log import DUMP_BYTES
|
||||
from .messages import DebugWaitType
|
||||
from .messages import Capability, DebugWaitType
|
||||
from .protobuf import MessageType
|
||||
from .tools import parse_path
|
||||
from .transport.session import Session, SessionV1
|
||||
from .transport.thp.protocol_v1 import ProtocolV1
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if t.TYPE_CHECKING:
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from .messages import PinMatrixRequestType
|
||||
from .transport import Transport
|
||||
|
||||
ExpectedMessage = Union[
|
||||
protobuf.MessageType, type[protobuf.MessageType], "MessageFilter"
|
||||
ExpectedMessage = t.Union[
|
||||
protobuf.MessageType, t.Type[protobuf.MessageType], "MessageFilter"
|
||||
]
|
||||
|
||||
AnyDict = Dict[str, Any]
|
||||
AnyDict = t.Dict[str, t.Any]
|
||||
|
||||
class InputFunc(Protocol):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hold_ms: int | None = None,
|
||||
) -> "None": ...
|
||||
|
||||
InputFlowType = Generator[None, messages.ButtonRequest, None]
|
||||
InputFlowType = t.Generator[None, messages.ButtonRequest, None]
|
||||
|
||||
|
||||
EXPECTED_RESPONSES_CONTEXT_LINES = 3
|
||||
PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0")
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -107,11 +107,11 @@ class UnstructuredJSONReader:
|
||||
except json.JSONDecodeError:
|
||||
self.dict = {}
|
||||
|
||||
def top_level_value(self, key: str) -> Any:
|
||||
def top_level_value(self, key: str) -> t.Any:
|
||||
return self.dict.get(key)
|
||||
|
||||
def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]:
|
||||
def recursively_find(data: Any) -> Iterator[Any]:
|
||||
def find_objects_with_key_and_value(self, key: str, value: t.Any) -> list[AnyDict]:
|
||||
def recursively_find(data: t.Any) -> t.Iterator[t.Any]:
|
||||
if isinstance(data, dict):
|
||||
if data.get(key) == value:
|
||||
yield data
|
||||
@ -124,7 +124,7 @@ class UnstructuredJSONReader:
|
||||
return list(recursively_find(self.dict))
|
||||
|
||||
def find_unique_object_with_key_and_value(
|
||||
self, key: str, value: Any
|
||||
self, key: str, value: t.Any
|
||||
) -> AnyDict | None:
|
||||
objects = self.find_objects_with_key_and_value(key, value)
|
||||
if not objects:
|
||||
@ -132,8 +132,10 @@ class UnstructuredJSONReader:
|
||||
assert len(objects) == 1
|
||||
return objects[0]
|
||||
|
||||
def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]:
|
||||
def recursively_find(data: Any) -> Iterator[Any]:
|
||||
def find_values_by_key(
|
||||
self, key: str, only_type: type | None = None
|
||||
) -> list[t.Any]:
|
||||
def recursively_find(data: t.Any) -> t.Iterator[t.Any]:
|
||||
if isinstance(data, dict):
|
||||
if key in data:
|
||||
yield data[key]
|
||||
@ -151,8 +153,8 @@ class UnstructuredJSONReader:
|
||||
return values
|
||||
|
||||
def find_unique_value_by_key(
|
||||
self, key: str, default: Any, only_type: type | None = None
|
||||
) -> Any:
|
||||
self, key: str, default: t.Any, only_type: type | None = None
|
||||
) -> t.Any:
|
||||
values = self.find_values_by_key(key, only_type=only_type)
|
||||
if not values:
|
||||
return default
|
||||
@ -163,7 +165,7 @@ class UnstructuredJSONReader:
|
||||
class LayoutContent(UnstructuredJSONReader):
|
||||
"""Contains helper functions to extract specific parts of the layout."""
|
||||
|
||||
def __init__(self, json_tokens: Sequence[str]) -> None:
|
||||
def __init__(self, json_tokens: t.Sequence[str]) -> None:
|
||||
json_str = "".join(json_tokens)
|
||||
super().__init__(json_str)
|
||||
|
||||
@ -429,6 +431,7 @@ class DebugLink:
|
||||
self.allow_interactions = auto_interact
|
||||
self.mapping = mapping.DEFAULT_MAPPING
|
||||
|
||||
self.protocol = ProtocolV1(self.transport, self.mapping)
|
||||
# To be set by TrezorClientDebugLink (is not known during creation time)
|
||||
self.model: models.TrezorModel | None = None
|
||||
self.version: tuple[int, int, int] = (0, 0, 0)
|
||||
@ -471,10 +474,16 @@ class DebugLink:
|
||||
return LayoutType.from_model(self.model)
|
||||
|
||||
def open(self) -> None:
|
||||
self.transport.begin_session()
|
||||
self.transport.open()
|
||||
# raise NotImplementedError
|
||||
# TODO is this needed?
|
||||
# self.transport.deprecated_begin_session()
|
||||
|
||||
def close(self) -> None:
|
||||
self.transport.end_session()
|
||||
pass
|
||||
# raise NotImplementedError
|
||||
# TODO is this needed?
|
||||
# self.transport.deprecated_end_session()
|
||||
|
||||
def _write(self, msg: protobuf.MessageType) -> None:
|
||||
if self.waiting_for_layout_change:
|
||||
@ -491,15 +500,10 @@ class DebugLink:
|
||||
DUMP_BYTES,
|
||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
self.transport.write(msg_type, msg_bytes)
|
||||
self.protocol.write(msg)
|
||||
|
||||
def _read(self) -> protobuf.MessageType:
|
||||
ret_type, ret_bytes = self.transport.read()
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}",
|
||||
)
|
||||
msg = self.mapping.decode(ret_type, ret_bytes)
|
||||
msg = self.protocol.read()
|
||||
|
||||
# Collapse tokens to make log use less lines.
|
||||
msg_for_log = msg
|
||||
@ -513,7 +517,7 @@ class DebugLink:
|
||||
)
|
||||
return msg
|
||||
|
||||
def _call(self, msg: protobuf.MessageType) -> Any:
|
||||
def _call(self, msg: protobuf.MessageType) -> t.Any:
|
||||
self._write(msg)
|
||||
return self._read()
|
||||
|
||||
@ -531,6 +535,25 @@ class DebugLink:
|
||||
raise TrezorFailure(result)
|
||||
return result
|
||||
|
||||
def pairing_info(
|
||||
self,
|
||||
thp_channel_id: bytes | None = None,
|
||||
handshake_hash: bytes | None = None,
|
||||
nfc_secret_host: bytes | None = None,
|
||||
) -> messages.DebugLinkPairingInfo:
|
||||
result = self._call(
|
||||
messages.DebugLinkGetPairingInfo(
|
||||
channel_id=thp_channel_id,
|
||||
handshake_hash=handshake_hash,
|
||||
nfc_secret_host=nfc_secret_host,
|
||||
)
|
||||
)
|
||||
while not isinstance(result, (messages.Failure, messages.DebugLinkPairingInfo)):
|
||||
result = self._read()
|
||||
if isinstance(result, messages.Failure):
|
||||
raise TrezorFailure(result)
|
||||
return result
|
||||
|
||||
def read_layout(self, wait: bool | None = None) -> LayoutContent:
|
||||
"""
|
||||
Force waiting for the layout by setting `wait=True`. Force not waiting by
|
||||
@ -547,7 +570,7 @@ class DebugLink:
|
||||
|
||||
def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent:
|
||||
# Next layout change will be caused by external event
|
||||
# (e.g. device being auto-locked or as a result of device_handler.run(xxx))
|
||||
# (e.g. device being auto-locked or as a result of device_handler.run_with_session(xxx))
|
||||
# and not by our debug actions/decisions.
|
||||
# Resetting the debug state so we wait for the next layout change
|
||||
# (and do not return the current state).
|
||||
@ -562,7 +585,7 @@ class DebugLink:
|
||||
return LayoutContent(obj.tokens)
|
||||
|
||||
@contextmanager
|
||||
def wait_for_layout_change(self) -> Iterator[None]:
|
||||
def wait_for_layout_change(self) -> t.Iterator[None]:
|
||||
# make sure some current layout is up by issuing a dummy GetState
|
||||
self.state()
|
||||
|
||||
@ -615,7 +638,7 @@ class DebugLink:
|
||||
|
||||
return "".join([str(matrix.index(p) + 1) for p in pin])
|
||||
|
||||
def read_recovery_word(self) -> Tuple[str | None, int | None]:
|
||||
def read_recovery_word(self) -> t.Tuple[str | None, int | None]:
|
||||
state = self.state()
|
||||
return (state.recovery_fake_word, state.recovery_word_pos)
|
||||
|
||||
@ -671,7 +694,7 @@ class DebugLink:
|
||||
"""Send text input to the device. See `_decision` for more details."""
|
||||
self._decision(messages.DebugLinkDecision(input=word))
|
||||
|
||||
def click(self, click: Tuple[int, int], hold_ms: int | None = None) -> None:
|
||||
def click(self, click: t.Tuple[int, int], hold_ms: int | None = None) -> None:
|
||||
"""Send a click to the device. See `_decision` for more details."""
|
||||
x, y = click
|
||||
self._decision(messages.DebugLinkDecision(x=x, y=y, hold_ms=hold_ms))
|
||||
@ -794,10 +817,10 @@ class DebugUI:
|
||||
self.clear()
|
||||
|
||||
def clear(self) -> None:
|
||||
self.pins: Iterator[str] | None = None
|
||||
self.pins: t.Iterator[str] | None = None
|
||||
self.passphrase = ""
|
||||
self.input_flow: Union[
|
||||
Generator[None, messages.ButtonRequest, None], object, None
|
||||
self.input_flow: t.Union[
|
||||
t.Generator[None, messages.ButtonRequest, None], object, None
|
||||
] = None
|
||||
|
||||
def _default_input_flow(self, br: messages.ButtonRequest) -> None:
|
||||
@ -829,7 +852,7 @@ class DebugUI:
|
||||
raise AssertionError("input flow ended prematurely")
|
||||
else:
|
||||
try:
|
||||
assert isinstance(self.input_flow, Generator)
|
||||
assert isinstance(self.input_flow, t.Generator)
|
||||
self.input_flow.send(br)
|
||||
except StopIteration:
|
||||
self.input_flow = self.INPUT_FLOW_DONE
|
||||
@ -851,12 +874,15 @@ class DebugUI:
|
||||
|
||||
|
||||
class MessageFilter:
|
||||
def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None:
|
||||
|
||||
def __init__(
|
||||
self, message_type: t.Type[protobuf.MessageType], **fields: t.Any
|
||||
) -> None:
|
||||
self.message_type = message_type
|
||||
self.fields: Dict[str, Any] = {}
|
||||
self.fields: t.Dict[str, t.Any] = {}
|
||||
self.update_fields(**fields)
|
||||
|
||||
def update_fields(self, **fields: Any) -> "MessageFilter":
|
||||
def update_fields(self, **fields: t.Any) -> "MessageFilter":
|
||||
for name, value in fields.items():
|
||||
try:
|
||||
self.fields[name] = self.from_message_or_type(value)
|
||||
@ -904,7 +930,7 @@ class MessageFilter:
|
||||
return True
|
||||
|
||||
def to_string(self, maxwidth: int = 80) -> str:
|
||||
fields: list[Tuple[str, str]] = []
|
||||
fields: list[t.Tuple[str, str]] = []
|
||||
for field in self.message_type.FIELDS.values():
|
||||
if field.name not in self.fields:
|
||||
continue
|
||||
@ -934,7 +960,7 @@ class MessageFilter:
|
||||
|
||||
|
||||
class MessageFilterGenerator:
|
||||
def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]:
|
||||
def __getattr__(self, key: str) -> t.Callable[..., "MessageFilter"]:
|
||||
message_type = getattr(messages, key)
|
||||
return MessageFilter(message_type).update_fields
|
||||
|
||||
@ -942,6 +968,245 @@ class MessageFilterGenerator:
|
||||
message_filters = MessageFilterGenerator()
|
||||
|
||||
|
||||
class SessionDebugWrapper(Session):
|
||||
def __init__(self, session: Session) -> None:
|
||||
self._session = session
|
||||
self.reset_debug_features()
|
||||
if isinstance(session, SessionDebugWrapper):
|
||||
raise Exception("Cannot wrap already wrapped session!")
|
||||
|
||||
@property
|
||||
def protocol_version(self) -> int:
|
||||
return self.client.protocol_version
|
||||
|
||||
@property
|
||||
def client(self) -> TrezorClientDebugLink:
|
||||
assert isinstance(self._session.client, TrezorClientDebugLink)
|
||||
return self._session.client
|
||||
|
||||
@property
|
||||
def id(self) -> bytes:
|
||||
return self._session.id
|
||||
|
||||
def _write(self, msg: t.Any) -> None:
|
||||
print("writing message:", msg.__class__.__name__)
|
||||
self._session._write(self._filter_message(msg))
|
||||
|
||||
def _read(self) -> t.Any:
|
||||
resp = self._filter_message(self._session._read())
|
||||
print("reading message:", resp.__class__.__name__)
|
||||
if self.actual_responses is not None:
|
||||
self.actual_responses.append(resp)
|
||||
return resp
|
||||
|
||||
def set_expected_responses(
|
||||
self,
|
||||
expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]],
|
||||
) -> None:
|
||||
"""Set a sequence of expected responses to session calls.
|
||||
|
||||
Within a given with-block, the list of received responses from device must
|
||||
match the list of expected responses, otherwise an ``AssertionError`` is raised.
|
||||
|
||||
If an expected response is given a field value other than ``None``, that field value
|
||||
must exactly match the received field value. If a given field is ``None``
|
||||
(or unspecified) in the expected response, the received field value is not
|
||||
checked.
|
||||
|
||||
Each expected response can also be a tuple ``(bool, message)``. In that case, the
|
||||
expected response is only evaluated if the first field is ``True``.
|
||||
This is useful for differentiating sequences between Trezor models:
|
||||
|
||||
>>> trezor_one = session.features.model == "1"
|
||||
>>> session.set_expected_responses([
|
||||
>>> messages.ButtonRequest(code=ConfirmOutput),
|
||||
>>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)),
|
||||
>>> messages.Success(),
|
||||
>>> ])
|
||||
"""
|
||||
if not self.in_with_statement:
|
||||
raise RuntimeError("Must be called inside 'with' statement")
|
||||
|
||||
# make sure all items are (bool, message) tuples
|
||||
expected_with_validity = (
|
||||
e if isinstance(e, tuple) else (True, e) for e in expected
|
||||
)
|
||||
|
||||
# only apply those items that are (True, message)
|
||||
self.expected_responses = [
|
||||
MessageFilter.from_message_or_type(expected)
|
||||
for valid, expected in expected_with_validity
|
||||
if valid
|
||||
]
|
||||
self.actual_responses = []
|
||||
|
||||
def lock(self, *, _refresh_features: bool = True) -> None:
|
||||
"""Lock the device.
|
||||
|
||||
If the device does not have a PIN configured, this will do nothing.
|
||||
Otherwise, a lock screen will be shown and the device will prompt for PIN
|
||||
before further actions.
|
||||
|
||||
This call does _not_ invalidate passphrase cache. If passphrase is in use,
|
||||
the device will not prompt for it after unlocking.
|
||||
|
||||
To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate
|
||||
passphrase cache, use `clear_session()`.
|
||||
"""
|
||||
# TODO update the documentation above
|
||||
# Private argument _refresh_features can be used internally to avoid
|
||||
# refreshing in cases where we will refresh soon anyway. This is used
|
||||
# in TrezorClient.clear_session()
|
||||
self.call(messages.LockDevice())
|
||||
if _refresh_features:
|
||||
self.refresh_features()
|
||||
|
||||
def cancel(self) -> None:
|
||||
self._write(messages.Cancel())
|
||||
|
||||
def ensure_unlocked(self) -> None:
|
||||
btc.get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
|
||||
self.refresh_features()
|
||||
|
||||
def set_filter(
|
||||
self,
|
||||
message_type: t.Type[protobuf.MessageType],
|
||||
callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||
) -> None:
|
||||
"""Configure a filter function for a specified message type.
|
||||
|
||||
The `callback` must be a function that accepts a protobuf message, and returns
|
||||
a (possibly modified) protobuf message of the same type. Whenever a message
|
||||
is sent or received that matches `message_type`, `callback` is invoked on the
|
||||
message and its result is substituted for the original.
|
||||
|
||||
Useful for test scenarios with an active malicious actor on the wire.
|
||||
"""
|
||||
if not self.in_with_statement:
|
||||
raise RuntimeError("Must be called inside 'with' statement")
|
||||
|
||||
self.filters[message_type] = callback
|
||||
|
||||
def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
|
||||
message_type = msg.__class__
|
||||
callback = self.filters.get(message_type)
|
||||
if callable(callback):
|
||||
return callback(deepcopy(msg))
|
||||
else:
|
||||
return msg
|
||||
|
||||
def reset_debug_features(self) -> None:
|
||||
"""Prepare the debugging session for a new testcase.
|
||||
|
||||
Clears all debugging state that might have been modified by a testcase.
|
||||
"""
|
||||
self.in_with_statement = False
|
||||
self.expected_responses: list[MessageFilter] | None = None
|
||||
self.actual_responses: list[protobuf.MessageType] | None = None
|
||||
self.filters: t.Dict[
|
||||
t.Type[protobuf.MessageType],
|
||||
t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||
] = {}
|
||||
self.button_callback = self.client.button_callback
|
||||
self.pin_callback = self.client.pin_callback
|
||||
self.passphrase_callback = self._session.passphrase_callback
|
||||
self.passphrase = self._session.passphrase
|
||||
|
||||
def __enter__(self) -> "SessionDebugWrapper":
|
||||
# For usage in with/expected_responses
|
||||
if self.in_with_statement:
|
||||
raise RuntimeError("Do not nest!")
|
||||
self.in_with_statement = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
# copy expected/actual responses before clearing them
|
||||
expected_responses = self.expected_responses
|
||||
actual_responses = self.actual_responses
|
||||
|
||||
# grab a copy of the inputflow generator to raise an exception through it
|
||||
if isinstance(self.client.ui, DebugUI):
|
||||
input_flow = self.client.ui.input_flow
|
||||
else:
|
||||
input_flow = None
|
||||
|
||||
self.reset_debug_features()
|
||||
|
||||
if exc_type is None:
|
||||
# If no other exception was raised, evaluate missed responses
|
||||
# (raises AssertionError on mismatch)
|
||||
self._verify_responses(expected_responses, actual_responses)
|
||||
if isinstance(input_flow, t.Generator):
|
||||
# Ensure that the input flow is exhausted
|
||||
try:
|
||||
input_flow.throw(
|
||||
AssertionError("input flow continues past end of test")
|
||||
)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
elif isinstance(input_flow, t.Generator):
|
||||
# Propagate the exception through the input flow, so that we see in
|
||||
# traceback where it is stuck.
|
||||
input_flow.throw(exc_type, value, traceback)
|
||||
|
||||
@classmethod
|
||||
def _verify_responses(
|
||||
cls,
|
||||
expected: list[MessageFilter] | None,
|
||||
actual: list[protobuf.MessageType] | None,
|
||||
) -> None:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
if expected is None and actual is None:
|
||||
return
|
||||
|
||||
assert expected is not None
|
||||
assert actual is not None
|
||||
|
||||
for i, (exp, act) in enumerate(zip_longest(expected, actual)):
|
||||
if exp is None:
|
||||
output = cls._expectation_lines(expected, i)
|
||||
output.append("No more messages were expected, but we got:")
|
||||
for resp in actual[i:]:
|
||||
output.append(
|
||||
textwrap.indent(protobuf.format_message(resp), " ")
|
||||
)
|
||||
raise AssertionError("\n".join(output))
|
||||
|
||||
if act is None:
|
||||
output = cls._expectation_lines(expected, i)
|
||||
output.append("This and the following message was not received.")
|
||||
raise AssertionError("\n".join(output))
|
||||
|
||||
if not exp.match(act):
|
||||
output = cls._expectation_lines(expected, i)
|
||||
output.append("Actually received:")
|
||||
output.append(textwrap.indent(protobuf.format_message(act), " "))
|
||||
raise AssertionError("\n".join(output))
|
||||
|
||||
@staticmethod
|
||||
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
|
||||
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
|
||||
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
|
||||
output: list[str] = []
|
||||
output.append("Expected responses:")
|
||||
if start_at > 0:
|
||||
output.append(f" (...{start_at} previous responses omitted)")
|
||||
for i in range(start_at, stop_at):
|
||||
exp = expected[i]
|
||||
prefix = " " if i != current else ">>> "
|
||||
output.append(textwrap.indent(exp.to_string(), prefix))
|
||||
if stop_at < len(expected):
|
||||
omitted = len(expected) - stop_at
|
||||
output.append(f" (...{omitted} following responses omitted)")
|
||||
|
||||
output.append("")
|
||||
return output
|
||||
|
||||
|
||||
class TrezorClientDebugLink(TrezorClient):
|
||||
# This class implements automatic responses
|
||||
# and other functionality for unit tests
|
||||
@ -967,23 +1232,34 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
raise
|
||||
|
||||
# set transport explicitly so that sync_responses can work
|
||||
self.transport = transport
|
||||
super().__init__(transport)
|
||||
|
||||
self.reset_debug_features()
|
||||
self.transport = transport
|
||||
self.ui: DebugUI = DebugUI(self.debug)
|
||||
|
||||
self.reset_debug_features(new_seedless_session=True)
|
||||
self.sync_responses()
|
||||
super().__init__(transport, ui=self.ui)
|
||||
|
||||
# So that we can choose right screenshotting logic (T1 vs TT)
|
||||
# and know the supported debug capabilities
|
||||
self.debug.model = self.model
|
||||
self.debug.version = self.version
|
||||
self.passphrase: str | None = None
|
||||
|
||||
@property
|
||||
def layout_type(self) -> LayoutType:
|
||||
return self.debug.layout_type
|
||||
|
||||
def reset_debug_features(self) -> None:
|
||||
"""Prepare the debugging client for a new testcase.
|
||||
def get_new_client(self) -> TrezorClientDebugLink:
|
||||
new_client = TrezorClientDebugLink(
|
||||
self.transport, self.debug.allow_interactions
|
||||
)
|
||||
new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir
|
||||
return new_client
|
||||
|
||||
def reset_debug_features(self, new_seedless_session: bool = False) -> None:
|
||||
"""
|
||||
Prepare the debugging client for a new testcase.
|
||||
|
||||
Clears all debugging state that might have been modified by a testcase.
|
||||
"""
|
||||
@ -991,30 +1267,139 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
self.in_with_statement = False
|
||||
self.expected_responses: list[MessageFilter] | None = None
|
||||
self.actual_responses: list[protobuf.MessageType] | None = None
|
||||
self.filters: dict[
|
||||
type[protobuf.MessageType],
|
||||
Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||
self.filters: t.Dict[
|
||||
t.Type[protobuf.MessageType],
|
||||
t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||
] = {}
|
||||
if new_seedless_session:
|
||||
self._seedless_session = self.get_seedless_session(new_session=True)
|
||||
|
||||
@property
|
||||
def button_callback(self):
|
||||
|
||||
def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
# do this raw - send ButtonAck first, notify UI later
|
||||
session._write(messages.ButtonAck())
|
||||
self.ui.button_request(msg)
|
||||
return session._read()
|
||||
|
||||
return _callback_button
|
||||
|
||||
@property
|
||||
def pin_callback(self):
|
||||
|
||||
def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any:
|
||||
try:
|
||||
pin = self.ui.get_pin(msg.type)
|
||||
except Cancelled:
|
||||
session.call_raw(messages.Cancel())
|
||||
raise
|
||||
|
||||
if any(d not in "123456789" for d in pin) or not (
|
||||
1 <= len(pin) <= MAX_PIN_LENGTH
|
||||
):
|
||||
session.call_raw(messages.Cancel())
|
||||
raise ValueError("Invalid PIN provided")
|
||||
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
|
||||
if isinstance(resp, messages.Failure) and resp.code in (
|
||||
messages.FailureType.PinInvalid,
|
||||
messages.FailureType.PinCancelled,
|
||||
messages.FailureType.PinExpected,
|
||||
):
|
||||
raise PinException(resp.code, resp.message)
|
||||
else:
|
||||
return resp
|
||||
|
||||
return _callback_pin
|
||||
|
||||
@property
|
||||
def passphrase_callback(self):
|
||||
def _callback_passphrase(
|
||||
session: Session, msg: messages.PassphraseRequest
|
||||
) -> t.Any:
|
||||
available_on_device = (
|
||||
Capability.PassphraseEntry in session.features.capabilities
|
||||
)
|
||||
|
||||
def send_passphrase(
|
||||
passphrase: str | None = None, on_device: bool | None = None
|
||||
) -> MessageType:
|
||||
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
||||
resp = session.call_raw(msg)
|
||||
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
||||
# session.session_id = resp.state
|
||||
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
|
||||
return resp
|
||||
|
||||
# short-circuit old style entry
|
||||
if msg._on_device is True:
|
||||
return send_passphrase(None, None)
|
||||
|
||||
try:
|
||||
if session.passphrase is None and isinstance(session, SessionV1):
|
||||
passphrase = self.ui.get_passphrase(
|
||||
available_on_device=available_on_device
|
||||
)
|
||||
else:
|
||||
passphrase = session.passphrase
|
||||
except Cancelled:
|
||||
session.call_raw(messages.Cancel())
|
||||
raise
|
||||
|
||||
if passphrase is PASSPHRASE_ON_DEVICE:
|
||||
if not available_on_device:
|
||||
session.call_raw(messages.Cancel())
|
||||
raise RuntimeError("Device is not capable of entering passphrase")
|
||||
else:
|
||||
return send_passphrase(on_device=True)
|
||||
|
||||
# else process host-entered passphrase
|
||||
if not isinstance(passphrase, str):
|
||||
raise RuntimeError("Passphrase must be a str")
|
||||
passphrase = Mnemonic.normalize_string(passphrase)
|
||||
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
|
||||
session.call_raw(messages.Cancel())
|
||||
raise ValueError("Passphrase too long")
|
||||
|
||||
return send_passphrase(passphrase, on_device=False)
|
||||
|
||||
return _callback_passphrase
|
||||
|
||||
def ensure_open(self) -> None:
|
||||
"""Only open session if there isn't already an open one."""
|
||||
if self.session_counter == 0:
|
||||
self.open()
|
||||
# if self.session_counter == 0:
|
||||
# self.open()
|
||||
# TODO check if is this needed
|
||||
|
||||
def open(self) -> None:
|
||||
super().open()
|
||||
if self.session_counter == 1:
|
||||
self.debug.open()
|
||||
pass
|
||||
# TODO is this needed?
|
||||
# self.debug.open()
|
||||
|
||||
def close(self) -> None:
|
||||
if self.session_counter == 1:
|
||||
self.debug.close()
|
||||
super().close()
|
||||
pass
|
||||
# TODO is this needed?
|
||||
# self.debug.close()
|
||||
|
||||
def lock(self) -> None:
|
||||
s = SessionDebugWrapper(self.get_seedless_session())
|
||||
s.lock()
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
passphrase: str | object | None = "",
|
||||
derive_cardano: bool = False,
|
||||
session_id: int = 0,
|
||||
) -> Session:
|
||||
if isinstance(passphrase, str):
|
||||
passphrase = Mnemonic.normalize_string(passphrase)
|
||||
return super().get_session(passphrase, derive_cardano, session_id)
|
||||
|
||||
def set_filter(
|
||||
self,
|
||||
message_type: type[protobuf.MessageType],
|
||||
callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||
message_type: t.Type[protobuf.MessageType],
|
||||
callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||
) -> None:
|
||||
"""Configure a filter function for a specified message type.
|
||||
|
||||
@ -1039,7 +1424,7 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
return msg
|
||||
|
||||
def set_input_flow(
|
||||
self, input_flow: InputFlowType | Callable[[], InputFlowType]
|
||||
self, input_flow: InputFlowType | t.Callable[[], InputFlowType]
|
||||
) -> None:
|
||||
"""Configure a sequence of input events for the current with-block.
|
||||
|
||||
@ -1095,7 +1480,7 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
self.in_with_statement = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None:
|
||||
def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
# copy expected/actual responses before clearing them
|
||||
@ -1108,21 +1493,23 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
else:
|
||||
input_flow = None
|
||||
|
||||
self.reset_debug_features()
|
||||
self.reset_debug_features(new_seedless_session=False)
|
||||
|
||||
if exc_type is None:
|
||||
# If no other exception was raised, evaluate missed responses
|
||||
# (raises AssertionError on mismatch)
|
||||
self._verify_responses(expected_responses, actual_responses)
|
||||
|
||||
elif isinstance(input_flow, Generator):
|
||||
elif isinstance(input_flow, t.Generator):
|
||||
# Propagate the exception through the input flow, so that we see in
|
||||
# traceback where it is stuck.
|
||||
input_flow.throw(exc_type, value, traceback)
|
||||
|
||||
def set_expected_responses(
|
||||
self,
|
||||
expected: Sequence[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]],
|
||||
expected: t.Sequence[
|
||||
t.Union["ExpectedMessage", t.Tuple[bool, "ExpectedMessage"]]
|
||||
],
|
||||
) -> None:
|
||||
"""Set a sequence of expected responses to client calls.
|
||||
|
||||
@ -1161,7 +1548,7 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
]
|
||||
self.actual_responses = []
|
||||
|
||||
def use_pin_sequence(self, pins: Iterable[str]) -> None:
|
||||
def use_pin_sequence(self, pins: t.Iterable[str]) -> None:
|
||||
"""Respond to PIN prompts from device with the provided PINs.
|
||||
The sequence must be at least as long as the expected number of PIN prompts.
|
||||
"""
|
||||
@ -1169,6 +1556,7 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
|
||||
def use_passphrase(self, passphrase: str) -> None:
|
||||
"""Respond to passphrase prompts from device with the provided passphrase."""
|
||||
self.passphrase = passphrase
|
||||
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
|
||||
|
||||
def use_mnemonic(self, mnemonic: str) -> None:
|
||||
@ -1178,15 +1566,14 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
|
||||
def _raw_read(self) -> protobuf.MessageType:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
resp = super()._raw_read()
|
||||
resp = self.get_seedless_session()._read()
|
||||
resp = self._filter_message(resp)
|
||||
if self.actual_responses is not None:
|
||||
self.actual_responses.append(resp)
|
||||
return resp
|
||||
|
||||
def _raw_write(self, msg: protobuf.MessageType) -> None:
|
||||
return super()._raw_write(self._filter_message(msg))
|
||||
return self.get_seedless_session()._write(self._filter_message(msg))
|
||||
|
||||
@staticmethod
|
||||
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
|
||||
@ -1256,23 +1643,25 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
|
||||
# Start by canceling whatever is on screen. This will work to cancel T1 PIN
|
||||
# prompt, which is in TINY mode and does not respond to `Ping`.
|
||||
cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel())
|
||||
self.transport.begin_session()
|
||||
# TODO REMOVE: cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel())
|
||||
self.transport.open()
|
||||
try:
|
||||
self.transport.write(*cancel_msg)
|
||||
|
||||
# self.protocol.write(messages.Cancel())
|
||||
message = "SYNC" + secrets.token_hex(8)
|
||||
ping_msg = mapping.DEFAULT_MAPPING.encode(messages.Ping(message=message))
|
||||
self.transport.write(*ping_msg)
|
||||
self.get_seedless_session()._write(messages.Ping(message=message))
|
||||
resp = None
|
||||
while resp != messages.Success(message=message):
|
||||
msg_id, msg_bytes = self.transport.read()
|
||||
try:
|
||||
resp = mapping.DEFAULT_MAPPING.decode(msg_id, msg_bytes)
|
||||
resp = self.get_seedless_session()._read()
|
||||
|
||||
raise Exception
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
finally:
|
||||
self.transport.end_session()
|
||||
pass # TODO fix
|
||||
# self.transport.end_session(self.session_id or b"")
|
||||
|
||||
def mnemonic_callback(self, _) -> str:
|
||||
word, pos = self.debug.read_recovery_word()
|
||||
@ -1285,8 +1674,8 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
|
||||
|
||||
def load_device(
|
||||
client: "TrezorClient",
|
||||
mnemonic: Union[str, Iterable[str]],
|
||||
session: "Session",
|
||||
mnemonic: str | t.Iterable[str],
|
||||
pin: str | None,
|
||||
passphrase_protection: bool,
|
||||
label: str | None,
|
||||
@ -1299,12 +1688,12 @@ def load_device(
|
||||
|
||||
mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
|
||||
|
||||
if client.features.initialized:
|
||||
if session.features.initialized:
|
||||
raise RuntimeError(
|
||||
"Device is initialized already. Call device.wipe() and try again."
|
||||
)
|
||||
|
||||
client.call(
|
||||
session.call(
|
||||
messages.LoadDevice(
|
||||
mnemonics=mnemonics,
|
||||
pin=pin,
|
||||
@ -1316,18 +1705,18 @@ def load_device(
|
||||
),
|
||||
expect=messages.Success,
|
||||
)
|
||||
client.init_device()
|
||||
session.refresh_features()
|
||||
|
||||
|
||||
# keep the old name for compatibility
|
||||
load_device_by_mnemonic = load_device
|
||||
|
||||
|
||||
def prodtest_t1(client: "TrezorClient") -> None:
|
||||
if client.features.bootloader_mode is not True:
|
||||
def prodtest_t1(session: "Session") -> None:
|
||||
if session.features.bootloader_mode is not True:
|
||||
raise RuntimeError("Device must be in bootloader mode")
|
||||
|
||||
client.call(
|
||||
session.call(
|
||||
messages.ProdTestT1(
|
||||
payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
|
||||
),
|
||||
@ -1337,8 +1726,8 @@ def prodtest_t1(client: "TrezorClient") -> None:
|
||||
|
||||
def record_screen(
|
||||
debug_client: "TrezorClientDebugLink",
|
||||
directory: Union[str, None],
|
||||
report_func: Union[Callable[[str], None], None] = None,
|
||||
directory: str | None,
|
||||
report_func: t.Callable[[str], None] | None = None,
|
||||
) -> None:
|
||||
"""Record screen changes into a specified directory.
|
||||
|
||||
@ -1383,5 +1772,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool:
|
||||
return debug_client.features.fw_vendor == "EMULATOR"
|
||||
|
||||
|
||||
def optiga_set_sec_max(client: "TrezorClient") -> None:
|
||||
client.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success)
|
||||
def optiga_set_sec_max(session: "Session") -> None:
|
||||
session.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success)
|
||||
|
@ -28,16 +28,10 @@ from slip10 import SLIP10
|
||||
|
||||
from . import messages
|
||||
from .exceptions import Cancelled, TrezorException
|
||||
from .tools import (
|
||||
Address,
|
||||
_deprecation_retval_helper,
|
||||
_return_success,
|
||||
parse_path,
|
||||
session,
|
||||
)
|
||||
from .tools import Address, _deprecation_retval_helper, _return_success, parse_path
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
RECOVERY_BACK = "\x08" # backspace character, sent literally
|
||||
@ -46,9 +40,8 @@ SLIP39_EXTENDABLE_MIN_VERSION = (2, 7, 1)
|
||||
ENTROPY_CHECK_MIN_VERSION = (2, 8, 7)
|
||||
|
||||
|
||||
@session
|
||||
def apply_settings(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
label: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
use_passphrase: Optional[bool] = None,
|
||||
@ -79,13 +72,13 @@ def apply_settings(
|
||||
haptic_feedback=haptic_feedback,
|
||||
)
|
||||
|
||||
out = client.call(settings, expect=messages.Success)
|
||||
client.refresh_features()
|
||||
out = session.call(settings, expect=messages.Success)
|
||||
session.refresh_features()
|
||||
return _return_success(out)
|
||||
|
||||
|
||||
def _send_language_data(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
request: "messages.TranslationDataRequest",
|
||||
language_data: bytes,
|
||||
) -> None:
|
||||
@ -95,69 +88,63 @@ def _send_language_data(
|
||||
data_length = response.data_length
|
||||
data_offset = response.data_offset
|
||||
chunk = language_data[data_offset : data_offset + data_length]
|
||||
response = client.call(messages.TranslationDataAck(data_chunk=chunk))
|
||||
response = session.call(messages.TranslationDataAck(data_chunk=chunk))
|
||||
|
||||
|
||||
@session
|
||||
def change_language(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
language_data: bytes,
|
||||
show_display: bool | None = None,
|
||||
) -> str | None:
|
||||
data_length = len(language_data)
|
||||
msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)
|
||||
|
||||
response = client.call(msg)
|
||||
response = session.call(msg)
|
||||
if data_length > 0:
|
||||
response = messages.TranslationDataRequest.ensure_isinstance(response)
|
||||
_send_language_data(client, response, language_data)
|
||||
_send_language_data(session, response, language_data)
|
||||
else:
|
||||
messages.Success.ensure_isinstance(response)
|
||||
client.refresh_features() # changing the language in features
|
||||
session.refresh_features() # changing the language in features
|
||||
return _return_success(messages.Success(message="Language changed."))
|
||||
|
||||
|
||||
@session
|
||||
def apply_flags(client: "TrezorClient", flags: int) -> str | None:
|
||||
out = client.call(messages.ApplyFlags(flags=flags), expect=messages.Success)
|
||||
client.refresh_features()
|
||||
def apply_flags(session: "Session", flags: int) -> str | None:
|
||||
out = session.call(messages.ApplyFlags(flags=flags), expect=messages.Success)
|
||||
session.refresh_features()
|
||||
return _return_success(out)
|
||||
|
||||
|
||||
@session
|
||||
def change_pin(client: "TrezorClient", remove: bool = False) -> str | None:
|
||||
ret = client.call(messages.ChangePin(remove=remove), expect=messages.Success)
|
||||
client.refresh_features()
|
||||
def change_pin(session: "Session", remove: bool = False) -> str | None:
|
||||
ret = session.call(messages.ChangePin(remove=remove), expect=messages.Success)
|
||||
session.refresh_features()
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
@session
|
||||
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> str | None:
|
||||
ret = client.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success)
|
||||
client.refresh_features()
|
||||
def change_wipe_code(session: "Session", remove: bool = False) -> str | None:
|
||||
ret = session.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success)
|
||||
session.refresh_features()
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
@session
|
||||
def sd_protect(
|
||||
client: "TrezorClient", operation: messages.SdProtectOperationType
|
||||
session: "Session", operation: messages.SdProtectOperationType
|
||||
) -> str | None:
|
||||
ret = client.call(messages.SdProtect(operation=operation), expect=messages.Success)
|
||||
client.refresh_features()
|
||||
ret = session.call(messages.SdProtect(operation=operation), expect=messages.Success)
|
||||
session.refresh_features()
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
@session
|
||||
def wipe(client: "TrezorClient") -> str | None:
|
||||
ret = client.call(messages.WipeDevice(), expect=messages.Success)
|
||||
if not client.features.bootloader_mode:
|
||||
client.init_device()
|
||||
def wipe(session: "Session") -> str | None:
|
||||
ret = session.call(messages.WipeDevice(), expect=messages.Success)
|
||||
session.invalidate()
|
||||
# if not session.features.bootloader_mode:
|
||||
# session.refresh_features()
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
@session
|
||||
def recover(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
word_count: int = 24,
|
||||
passphrase_protection: bool = False,
|
||||
pin_protection: bool = True,
|
||||
@ -193,13 +180,13 @@ def recover(
|
||||
if type is None:
|
||||
type = messages.RecoveryType.NormalRecovery
|
||||
|
||||
if client.features.model == "1" and input_callback is None:
|
||||
if session.features.model == "1" and input_callback is None:
|
||||
raise RuntimeError("Input callback required for Trezor One")
|
||||
|
||||
if word_count not in (12, 18, 24):
|
||||
raise ValueError("Invalid word count. Use 12/18/24")
|
||||
|
||||
if client.features.initialized and type == messages.RecoveryType.NormalRecovery:
|
||||
if session.features.initialized and type == messages.RecoveryType.NormalRecovery:
|
||||
raise RuntimeError(
|
||||
"Device already initialized. Call device.wipe() and try again."
|
||||
)
|
||||
@ -221,20 +208,20 @@ def recover(
|
||||
msg.label = label
|
||||
msg.u2f_counter = u2f_counter
|
||||
|
||||
res = client.call(msg)
|
||||
res = session.call(msg)
|
||||
|
||||
while isinstance(res, messages.WordRequest):
|
||||
try:
|
||||
assert input_callback is not None
|
||||
inp = input_callback(res.type)
|
||||
res = client.call(messages.WordAck(word=inp))
|
||||
res = session.call(messages.WordAck(word=inp))
|
||||
except Cancelled:
|
||||
res = client.call(messages.Cancel())
|
||||
res = session.call(messages.Cancel())
|
||||
|
||||
# check that the result is a Success
|
||||
res = messages.Success.ensure_isinstance(res)
|
||||
# reinitialize the device
|
||||
client.init_device()
|
||||
session.refresh_features()
|
||||
|
||||
return _deprecation_retval_helper(res)
|
||||
|
||||
@ -280,7 +267,7 @@ def _seed_from_entropy(
|
||||
|
||||
|
||||
def reset(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
display_random: bool = False,
|
||||
strength: Optional[int] = None,
|
||||
passphrase_protection: bool = False,
|
||||
@ -313,7 +300,7 @@ def reset(
|
||||
)
|
||||
|
||||
setup(
|
||||
client,
|
||||
session,
|
||||
strength=strength,
|
||||
passphrase_protection=passphrase_protection,
|
||||
pin_protection=pin_protection,
|
||||
@ -331,9 +318,8 @@ def _get_external_entropy() -> bytes:
|
||||
return secrets.token_bytes(32)
|
||||
|
||||
|
||||
@session
|
||||
def setup(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
*,
|
||||
strength: Optional[int] = None,
|
||||
passphrase_protection: bool = True,
|
||||
@ -388,19 +374,19 @@ def setup(
|
||||
check.
|
||||
"""
|
||||
|
||||
if client.features.initialized:
|
||||
if session.features.initialized:
|
||||
raise RuntimeError(
|
||||
"Device is initialized already. Call wipe_device() and try again."
|
||||
)
|
||||
|
||||
if strength is None:
|
||||
if client.features.model == "1":
|
||||
if session.features.model == "1":
|
||||
strength = 256
|
||||
else:
|
||||
strength = 128
|
||||
|
||||
if backup_type is None:
|
||||
if client.version < SLIP39_EXTENDABLE_MIN_VERSION:
|
||||
if session.version < SLIP39_EXTENDABLE_MIN_VERSION:
|
||||
# includes Trezor One 1.x.x
|
||||
backup_type = messages.BackupType.Bip39
|
||||
else:
|
||||
@ -411,7 +397,7 @@ def setup(
|
||||
paths = [parse_path("m/84h/0h/0h"), parse_path("m/44h/60h/0h")]
|
||||
|
||||
if entropy_check_count is None:
|
||||
if client.version < ENTROPY_CHECK_MIN_VERSION:
|
||||
if session.version < ENTROPY_CHECK_MIN_VERSION:
|
||||
# includes Trezor One 1.x.x
|
||||
entropy_check_count = 0
|
||||
else:
|
||||
@ -431,18 +417,18 @@ def setup(
|
||||
)
|
||||
if entropy_check_count > 0:
|
||||
xpubs = _reset_with_entropycheck(
|
||||
client, msg, entropy_check_count, paths, _get_entropy
|
||||
session, msg, entropy_check_count, paths, _get_entropy
|
||||
)
|
||||
else:
|
||||
_reset_no_entropycheck(client, msg, _get_entropy)
|
||||
_reset_no_entropycheck(session, msg, _get_entropy)
|
||||
xpubs = []
|
||||
|
||||
client.init_device()
|
||||
session.refresh_features()
|
||||
return xpubs
|
||||
|
||||
|
||||
def _reset_no_entropycheck(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
msg: messages.ResetDevice,
|
||||
get_entropy: Callable[[], bytes],
|
||||
) -> None:
|
||||
@ -454,12 +440,12 @@ def _reset_no_entropycheck(
|
||||
<< Success
|
||||
"""
|
||||
assert msg.entropy_check is False
|
||||
client.call(msg, expect=messages.EntropyRequest)
|
||||
client.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success)
|
||||
session.call(msg, expect=messages.EntropyRequest)
|
||||
session.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success)
|
||||
|
||||
|
||||
def _reset_with_entropycheck(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
reset_msg: messages.ResetDevice,
|
||||
entropy_check_count: int,
|
||||
paths: Iterable[Address],
|
||||
@ -495,7 +481,7 @@ def _reset_with_entropycheck(
|
||||
def get_xpubs() -> list[tuple[Address, str]]:
|
||||
xpubs = []
|
||||
for path in paths:
|
||||
resp = client.call(
|
||||
resp = session.call(
|
||||
messages.GetPublicKey(address_n=path), expect=messages.PublicKey
|
||||
)
|
||||
xpubs.append((path, resp.xpub))
|
||||
@ -524,13 +510,13 @@ def _reset_with_entropycheck(
|
||||
raise TrezorException("Invalid XPUB in entropy check")
|
||||
|
||||
xpubs = []
|
||||
resp = client.call(reset_msg, expect=messages.EntropyRequest)
|
||||
resp = session.call(reset_msg, expect=messages.EntropyRequest)
|
||||
entropy_commitment = resp.entropy_commitment
|
||||
|
||||
while True:
|
||||
# provide external entropy for this round
|
||||
external_entropy = get_entropy()
|
||||
client.call(
|
||||
session.call(
|
||||
messages.EntropyAck(entropy=external_entropy),
|
||||
expect=messages.EntropyCheckReady,
|
||||
)
|
||||
@ -540,7 +526,7 @@ def _reset_with_entropycheck(
|
||||
|
||||
if entropy_check_count <= 0:
|
||||
# last round, wait for a Success and exit the loop
|
||||
client.call(
|
||||
session.call(
|
||||
messages.EntropyCheckContinue(finish=True),
|
||||
expect=messages.Success,
|
||||
)
|
||||
@ -549,7 +535,7 @@ def _reset_with_entropycheck(
|
||||
entropy_check_count -= 1
|
||||
|
||||
# Next round starts.
|
||||
resp = client.call(
|
||||
resp = session.call(
|
||||
messages.EntropyCheckContinue(finish=False),
|
||||
expect=messages.EntropyRequest,
|
||||
)
|
||||
@ -570,13 +556,12 @@ def _reset_with_entropycheck(
|
||||
return xpubs
|
||||
|
||||
|
||||
@session
|
||||
def backup(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
group_threshold: Optional[int] = None,
|
||||
groups: Iterable[tuple[int, int]] = (),
|
||||
) -> str | None:
|
||||
ret = client.call(
|
||||
ret = session.call(
|
||||
messages.BackupDevice(
|
||||
group_threshold=group_threshold,
|
||||
groups=[
|
||||
@ -586,37 +571,36 @@ def backup(
|
||||
),
|
||||
expect=messages.Success,
|
||||
)
|
||||
client.refresh_features()
|
||||
session.refresh_features()
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
def cancel_authorization(client: "TrezorClient") -> str | None:
|
||||
ret = client.call(messages.CancelAuthorization(), expect=messages.Success)
|
||||
def cancel_authorization(session: "Session") -> str | None:
|
||||
ret = session.call(messages.CancelAuthorization(), expect=messages.Success)
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
def unlock_path(client: "TrezorClient", n: "Address") -> bytes:
|
||||
resp = client.call(
|
||||
def unlock_path(session: "Session", n: "Address") -> bytes:
|
||||
resp = session.call(
|
||||
messages.UnlockPath(address_n=n), expect=messages.UnlockedPathRequest
|
||||
)
|
||||
|
||||
# Cancel the UnlockPath workflow now that we have the authentication code.
|
||||
try:
|
||||
client.call(messages.Cancel())
|
||||
session.call(messages.Cancel())
|
||||
except Cancelled:
|
||||
return resp.mac
|
||||
else:
|
||||
raise TrezorException("Unexpected response in UnlockPath flow")
|
||||
|
||||
|
||||
@session
|
||||
def reboot_to_bootloader(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT,
|
||||
firmware_header: Optional[bytes] = None,
|
||||
language_data: bytes = b"",
|
||||
) -> str | None:
|
||||
response = client.call(
|
||||
response = session.call(
|
||||
messages.RebootToBootloader(
|
||||
boot_command=boot_command,
|
||||
firmware_header=firmware_header,
|
||||
@ -624,43 +608,38 @@ def reboot_to_bootloader(
|
||||
)
|
||||
)
|
||||
if isinstance(response, messages.TranslationDataRequest):
|
||||
response = _send_language_data(client, response, language_data)
|
||||
response = _send_language_data(session, response, language_data)
|
||||
return _return_success(messages.Success(message=""))
|
||||
|
||||
|
||||
@session
|
||||
def show_device_tutorial(client: "TrezorClient") -> str | None:
|
||||
ret = client.call(messages.ShowDeviceTutorial(), expect=messages.Success)
|
||||
def show_device_tutorial(session: "Session") -> str | None:
|
||||
ret = session.call(messages.ShowDeviceTutorial(), expect=messages.Success)
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
@session
|
||||
def unlock_bootloader(client: "TrezorClient") -> str | None:
|
||||
ret = client.call(messages.UnlockBootloader(), expect=messages.Success)
|
||||
def unlock_bootloader(session: "Session") -> str | None:
|
||||
ret = session.call(messages.UnlockBootloader(), expect=messages.Success)
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
@session
|
||||
def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> str | None:
|
||||
def set_busy(session: "Session", expiry_ms: Optional[int]) -> str | None:
|
||||
"""Sets or clears the busy state of the device.
|
||||
|
||||
In the busy state the device shows a "Do not disconnect" message instead of the homescreen.
|
||||
Setting `expiry_ms=None` clears the busy state.
|
||||
"""
|
||||
ret = client.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success)
|
||||
client.refresh_features()
|
||||
ret = session.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success)
|
||||
session.refresh_features()
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
def authenticate(
|
||||
client: "TrezorClient", challenge: bytes
|
||||
) -> messages.AuthenticityProof:
|
||||
return client.call(
|
||||
def authenticate(session: "Session", challenge: bytes) -> messages.AuthenticityProof:
|
||||
return session.call(
|
||||
messages.AuthenticateDevice(challenge=challenge),
|
||||
expect=messages.AuthenticityProof,
|
||||
)
|
||||
|
||||
|
||||
def set_brightness(client: "TrezorClient", value: Optional[int] = None) -> str | None:
|
||||
ret = client.call(messages.SetBrightness(value=value), expect=messages.Success)
|
||||
def set_brightness(session: "Session", value: Optional[int] = None) -> str | None:
|
||||
ret = session.call(messages.SetBrightness(value=value), expect=messages.Success)
|
||||
return _return_success(ret)
|
||||
|
@ -18,11 +18,11 @@ from datetime import datetime
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
|
||||
from . import exceptions, messages
|
||||
from .tools import b58decode, session
|
||||
from .tools import b58decode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def name_to_number(name: str) -> int:
|
||||
@ -319,17 +319,16 @@ def parse_transaction_json(
|
||||
|
||||
|
||||
def get_public_key(
|
||||
client: "TrezorClient", n: "Address", show_display: bool = False
|
||||
session: "Session", n: "Address", show_display: bool = False
|
||||
) -> messages.EosPublicKey:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.EosGetPublicKey(address_n=n, show_display=show_display),
|
||||
expect=messages.EosPublicKey,
|
||||
)
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: "Address",
|
||||
transaction: dict,
|
||||
chain_id: str,
|
||||
@ -345,11 +344,11 @@ def sign_tx(
|
||||
chunkify=chunkify,
|
||||
)
|
||||
|
||||
response = client.call(msg)
|
||||
response = session.call(msg)
|
||||
|
||||
try:
|
||||
while isinstance(response, messages.EosTxActionRequest):
|
||||
response = client.call(actions.pop(0))
|
||||
response = session.call(actions.pop(0))
|
||||
except IndexError:
|
||||
# pop from empty list
|
||||
raise exceptions.TrezorException(
|
||||
|
@ -18,11 +18,11 @@ import re
|
||||
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
|
||||
|
||||
from . import definitions, exceptions, messages
|
||||
from .tools import prepare_message_bytes, session, unharden
|
||||
from .tools import prepare_message_bytes, unharden
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def int_to_big_endian(value: int) -> bytes:
|
||||
@ -161,13 +161,13 @@ def network_from_address_n(
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
show_display: bool = False,
|
||||
encoded_network: Optional[bytes] = None,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
resp = client.call(
|
||||
resp = session.call(
|
||||
messages.EthereumGetAddress(
|
||||
address_n=n,
|
||||
show_display=show_display,
|
||||
@ -181,17 +181,16 @@ def get_address(
|
||||
|
||||
|
||||
def get_public_node(
|
||||
client: "TrezorClient", n: "Address", show_display: bool = False
|
||||
session: "Session", n: "Address", show_display: bool = False
|
||||
) -> messages.EthereumPublicKey:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.EthereumGetPublicKey(address_n=n, show_display=show_display),
|
||||
expect=messages.EthereumPublicKey,
|
||||
)
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
nonce: int,
|
||||
gas_price: int,
|
||||
@ -227,13 +226,13 @@ def sign_tx(
|
||||
data, chunk = data[1024:], data[:1024]
|
||||
msg.data_initial_chunk = chunk
|
||||
|
||||
response = client.call(msg)
|
||||
response = session.call(msg)
|
||||
assert isinstance(response, messages.EthereumTxRequest)
|
||||
|
||||
while response.data_length is not None:
|
||||
data_length = response.data_length
|
||||
data, chunk = data[data_length:], data[:data_length]
|
||||
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||
response = session.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||
assert isinstance(response, messages.EthereumTxRequest)
|
||||
|
||||
assert response.signature_v is not None
|
||||
@ -248,9 +247,8 @@ def sign_tx(
|
||||
return response.signature_v, response.signature_r, response.signature_s
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx_eip1559(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
*,
|
||||
nonce: int,
|
||||
@ -283,13 +281,13 @@ def sign_tx_eip1559(
|
||||
chunkify=chunkify,
|
||||
)
|
||||
|
||||
response = client.call(msg)
|
||||
response = session.call(msg)
|
||||
assert isinstance(response, messages.EthereumTxRequest)
|
||||
|
||||
while response.data_length is not None:
|
||||
data_length = response.data_length
|
||||
data, chunk = data[data_length:], data[:data_length]
|
||||
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||
response = session.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||
assert isinstance(response, messages.EthereumTxRequest)
|
||||
|
||||
assert response.signature_v is not None
|
||||
@ -299,13 +297,13 @@ def sign_tx_eip1559(
|
||||
|
||||
|
||||
def sign_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
message: AnyStr,
|
||||
encoded_network: Optional[bytes] = None,
|
||||
chunkify: bool = False,
|
||||
) -> messages.EthereumMessageSignature:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.EthereumSignMessage(
|
||||
address_n=n,
|
||||
message=prepare_message_bytes(message),
|
||||
@ -317,7 +315,7 @@ def sign_message(
|
||||
|
||||
|
||||
def sign_typed_data(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
data: Dict[str, Any],
|
||||
*,
|
||||
@ -333,7 +331,7 @@ def sign_typed_data(
|
||||
metamask_v4_compat=metamask_v4_compat,
|
||||
definitions=definitions,
|
||||
)
|
||||
response = client.call(request)
|
||||
response = session.call(request)
|
||||
|
||||
# Sending all the types
|
||||
while isinstance(response, messages.EthereumTypedDataStructRequest):
|
||||
@ -349,7 +347,7 @@ def sign_typed_data(
|
||||
members.append(struct_member)
|
||||
|
||||
request = messages.EthereumTypedDataStructAck(members=members)
|
||||
response = client.call(request)
|
||||
response = session.call(request)
|
||||
|
||||
# Sending the whole message that should be signed
|
||||
while isinstance(response, messages.EthereumTypedDataValueRequest):
|
||||
@ -362,7 +360,7 @@ def sign_typed_data(
|
||||
member_typename = data["primaryType"]
|
||||
member_data = data["message"]
|
||||
else:
|
||||
client.cancel()
|
||||
# TODO session.cancel()
|
||||
raise exceptions.TrezorException("Root index can only be 0 or 1")
|
||||
|
||||
# It can be asking for a nested structure (the member path being [X, Y, Z, ...])
|
||||
@ -385,20 +383,20 @@ def sign_typed_data(
|
||||
encoded_data = encode_data(member_data, member_typename)
|
||||
|
||||
request = messages.EthereumTypedDataValueAck(value=encoded_data)
|
||||
response = client.call(request)
|
||||
response = session.call(request)
|
||||
|
||||
return messages.EthereumTypedDataSignature.ensure_isinstance(response)
|
||||
|
||||
|
||||
def verify_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
signature: bytes,
|
||||
message: AnyStr,
|
||||
chunkify: bool = False,
|
||||
) -> bool:
|
||||
try:
|
||||
client.call(
|
||||
session.call(
|
||||
messages.EthereumVerifyMessage(
|
||||
address=address,
|
||||
signature=signature,
|
||||
@ -413,13 +411,13 @@ def verify_message(
|
||||
|
||||
|
||||
def sign_typed_data_hash(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
domain_hash: bytes,
|
||||
message_hash: Optional[bytes],
|
||||
encoded_network: Optional[bytes] = None,
|
||||
) -> messages.EthereumTypedDataSignature:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.EthereumSignTypedHash(
|
||||
address_n=n,
|
||||
domain_separator_hash=domain_hash,
|
||||
|
@ -65,3 +65,7 @@ class UnexpectedMessageError(TrezorException):
|
||||
self.expected = expected
|
||||
self.actual = actual
|
||||
super().__init__(f"Expected {expected.__name__} but Trezor sent {actual}")
|
||||
|
||||
|
||||
class DeviceLockedException(TrezorException):
|
||||
pass
|
||||
|
@ -22,37 +22,37 @@ from . import messages
|
||||
from .tools import _return_success
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def list_credentials(client: "TrezorClient") -> Sequence[messages.WebAuthnCredential]:
|
||||
return client.call(
|
||||
def list_credentials(session: "Session") -> Sequence[messages.WebAuthnCredential]:
|
||||
return session.call(
|
||||
messages.WebAuthnListResidentCredentials(), expect=messages.WebAuthnCredentials
|
||||
).credentials
|
||||
|
||||
|
||||
def add_credential(client: "TrezorClient", credential_id: bytes) -> str | None:
|
||||
ret = client.call(
|
||||
def add_credential(session: "Session", credential_id: bytes) -> str | None:
|
||||
ret = session.call(
|
||||
messages.WebAuthnAddResidentCredential(credential_id=credential_id),
|
||||
expect=messages.Success,
|
||||
)
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
def remove_credential(client: "TrezorClient", index: int) -> str | None:
|
||||
ret = client.call(
|
||||
def remove_credential(session: "Session", index: int) -> str | None:
|
||||
ret = session.call(
|
||||
messages.WebAuthnRemoveResidentCredential(index=index), expect=messages.Success
|
||||
)
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
def set_counter(client: "TrezorClient", u2f_counter: int) -> str | None:
|
||||
ret = client.call(
|
||||
def set_counter(session: "Session", u2f_counter: int) -> str | None:
|
||||
ret = session.call(
|
||||
messages.SetU2FCounter(u2f_counter=u2f_counter), expect=messages.Success
|
||||
)
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
def get_next_counter(client: "TrezorClient") -> int:
|
||||
ret = client.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter)
|
||||
def get_next_counter(session: "Session") -> int:
|
||||
ret = session.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter)
|
||||
return ret.u2f_counter
|
||||
|
@ -20,7 +20,6 @@ from hashlib import blake2s
|
||||
from typing_extensions import Protocol, TypeGuard
|
||||
|
||||
from .. import messages
|
||||
from ..tools import session
|
||||
from .core import VendorFirmware
|
||||
from .legacy import LegacyFirmware, LegacyV2Firmware
|
||||
|
||||
@ -38,7 +37,7 @@ if True:
|
||||
from .vendor import * # noqa: F401, F403
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
T = t.TypeVar("T", bound="FirmwareType")
|
||||
|
||||
@ -72,20 +71,19 @@ def is_onev2(fw: "FirmwareType") -> TypeGuard[LegacyFirmware]:
|
||||
# ====== Client functions ====== #
|
||||
|
||||
|
||||
@session
|
||||
def update(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
data: bytes,
|
||||
progress_update: t.Callable[[int], t.Any] = lambda _: None,
|
||||
):
|
||||
if client.features.bootloader_mode is False:
|
||||
if session.features.bootloader_mode is False:
|
||||
raise RuntimeError("Device must be in bootloader mode")
|
||||
|
||||
resp = client.call(messages.FirmwareErase(length=len(data)))
|
||||
resp = session.call(messages.FirmwareErase(length=len(data)))
|
||||
|
||||
# TREZORv1 method
|
||||
if isinstance(resp, messages.Success):
|
||||
resp = client.call(messages.FirmwareUpload(payload=data))
|
||||
resp = session.call(messages.FirmwareUpload(payload=data))
|
||||
progress_update(len(data))
|
||||
if isinstance(resp, messages.Success):
|
||||
return
|
||||
@ -97,7 +95,7 @@ def update(
|
||||
length = resp.length
|
||||
payload = data[resp.offset : resp.offset + length]
|
||||
digest = blake2s(payload).digest()
|
||||
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest))
|
||||
resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest))
|
||||
progress_update(length)
|
||||
|
||||
if isinstance(resp, messages.Success):
|
||||
@ -106,7 +104,7 @@ def update(
|
||||
raise RuntimeError(f"Unexpected message {resp}")
|
||||
|
||||
|
||||
def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]) -> bytes:
|
||||
return client.call(
|
||||
def get_hash(session: "Session", challenge: t.Optional[bytes]) -> bytes:
|
||||
return session.call(
|
||||
messages.GetFirmwareHash(challenge=challenge), expect=messages.FirmwareHash
|
||||
).hash
|
||||
|
@ -17,6 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
from types import ModuleType
|
||||
from typing import Dict, Optional, Tuple, Type, TypeVar
|
||||
|
||||
@ -25,6 +26,7 @@ from typing_extensions import Self
|
||||
from . import messages, protobuf
|
||||
|
||||
T = TypeVar("T")
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtobufMapping:
|
||||
@ -63,11 +65,21 @@ class ProtobufMapping:
|
||||
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
|
||||
if wire_type is None:
|
||||
raise ValueError("Cannot encode class without wire type")
|
||||
|
||||
LOG.debug("encoding wire type %d", wire_type)
|
||||
buf = io.BytesIO()
|
||||
protobuf.dump_message(buf, msg)
|
||||
return wire_type, buf.getvalue()
|
||||
|
||||
def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes:
|
||||
"""Serialize a Python protobuf class.
|
||||
|
||||
Returns the byte representation of the protobuf message.
|
||||
"""
|
||||
|
||||
buf = io.BytesIO()
|
||||
protobuf.dump_message(buf, msg)
|
||||
return buf.getvalue()
|
||||
|
||||
def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType:
|
||||
"""Deserialize a protobuf message into a Python class."""
|
||||
cls = self.type_to_class[msg_wire_type]
|
||||
@ -83,7 +95,9 @@ class ProtobufMapping:
|
||||
mapping = cls()
|
||||
|
||||
message_types = getattr(module, "MessageType")
|
||||
for entry in message_types:
|
||||
thp_message_types = getattr(module, "ThpMessageType")
|
||||
|
||||
for entry in (*message_types, *thp_message_types):
|
||||
msg_class = getattr(module, entry.name, None)
|
||||
if msg_class is None:
|
||||
raise ValueError(
|
||||
|
352
python/src/trezorlib/messages.py
generated
352
python/src/trezorlib/messages.py
generated
@ -43,6 +43,10 @@ class FailureType(IntEnum):
|
||||
PinMismatch = 12
|
||||
WipeCodeMismatch = 13
|
||||
InvalidSession = 14
|
||||
ThpUnallocatedSession = 15
|
||||
InvalidProtocol = 16
|
||||
BufferError = 17
|
||||
DeviceIsBusy = 18
|
||||
FirmwareError = 99
|
||||
|
||||
|
||||
@ -400,6 +404,34 @@ class TezosBallotType(IntEnum):
|
||||
Pass = 2
|
||||
|
||||
|
||||
class ThpMessageType(IntEnum):
|
||||
ThpCreateNewSession = 1000
|
||||
ThpPairingRequest = 1006
|
||||
ThpPairingRequestApproved = 1007
|
||||
ThpSelectMethod = 1008
|
||||
ThpPairingPreparationsFinished = 1009
|
||||
ThpCredentialRequest = 1010
|
||||
ThpCredentialResponse = 1011
|
||||
ThpEndRequest = 1012
|
||||
ThpEndResponse = 1013
|
||||
ThpCodeEntryCommitment = 1016
|
||||
ThpCodeEntryChallenge = 1017
|
||||
ThpCodeEntryCpaceTrezor = 1018
|
||||
ThpCodeEntryCpaceHostTag = 1019
|
||||
ThpCodeEntrySecret = 1020
|
||||
ThpQrCodeTag = 1024
|
||||
ThpQrCodeSecret = 1025
|
||||
ThpNfcTagHost = 1032
|
||||
ThpNfcTagTrezor = 1033
|
||||
|
||||
|
||||
class ThpPairingMethod(IntEnum):
|
||||
SkipPairing = 1
|
||||
CodeEntry = 2
|
||||
QrCode = 3
|
||||
NFC = 4
|
||||
|
||||
|
||||
class MessageType(IntEnum):
|
||||
Initialize = 0
|
||||
Ping = 1
|
||||
@ -500,6 +532,8 @@ class MessageType(IntEnum):
|
||||
DebugLinkWatchLayout = 9006
|
||||
DebugLinkResetDebugEvents = 9007
|
||||
DebugLinkOptigaSetSecMax = 9008
|
||||
DebugLinkGetPairingInfo = 9009
|
||||
DebugLinkPairingInfo = 9010
|
||||
EthereumGetPublicKey = 450
|
||||
EthereumPublicKey = 451
|
||||
EthereumGetAddress = 56
|
||||
@ -4203,6 +4237,52 @@ class DebugLinkState(protobuf.MessageType):
|
||||
self.mnemonic_type = mnemonic_type
|
||||
|
||||
|
||||
class DebugLinkGetPairingInfo(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 9009
|
||||
FIELDS = {
|
||||
1: protobuf.Field("channel_id", "bytes", repeated=False, required=False, default=None),
|
||||
2: protobuf.Field("handshake_hash", "bytes", repeated=False, required=False, default=None),
|
||||
3: protobuf.Field("nfc_secret_host", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
channel_id: Optional["bytes"] = None,
|
||||
handshake_hash: Optional["bytes"] = None,
|
||||
nfc_secret_host: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.channel_id = channel_id
|
||||
self.handshake_hash = handshake_hash
|
||||
self.nfc_secret_host = nfc_secret_host
|
||||
|
||||
|
||||
class DebugLinkPairingInfo(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 9010
|
||||
FIELDS = {
|
||||
1: protobuf.Field("channel_id", "bytes", repeated=False, required=False, default=None),
|
||||
2: protobuf.Field("handshake_hash", "bytes", repeated=False, required=False, default=None),
|
||||
3: protobuf.Field("code_entry_code", "uint32", repeated=False, required=False, default=None),
|
||||
4: protobuf.Field("code_qr_code", "bytes", repeated=False, required=False, default=None),
|
||||
5: protobuf.Field("nfc_secret_trezor", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
channel_id: Optional["bytes"] = None,
|
||||
handshake_hash: Optional["bytes"] = None,
|
||||
code_entry_code: Optional["int"] = None,
|
||||
code_qr_code: Optional["bytes"] = None,
|
||||
nfc_secret_trezor: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.channel_id = channel_id
|
||||
self.handshake_hash = handshake_hash
|
||||
self.code_entry_code = code_entry_code
|
||||
self.code_qr_code = code_qr_code
|
||||
self.nfc_secret_trezor = nfc_secret_trezor
|
||||
|
||||
|
||||
class DebugLinkStop(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 103
|
||||
|
||||
@ -7863,8 +7943,68 @@ class TezosManagerTransfer(protobuf.MessageType):
|
||||
self.amount = amount
|
||||
|
||||
|
||||
class ThpCredentialMetadata(protobuf.MessageType):
|
||||
class ThpDeviceProperties(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = None
|
||||
FIELDS = {
|
||||
1: protobuf.Field("internal_model", "string", repeated=False, required=False, default=None),
|
||||
2: protobuf.Field("model_variant", "uint32", repeated=False, required=False, default=None),
|
||||
3: protobuf.Field("protocol_version_major", "uint32", repeated=False, required=False, default=None),
|
||||
4: protobuf.Field("protocol_version_minor", "uint32", repeated=False, required=False, default=None),
|
||||
5: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None,
|
||||
internal_model: Optional["str"] = None,
|
||||
model_variant: Optional["int"] = None,
|
||||
protocol_version_major: Optional["int"] = None,
|
||||
protocol_version_minor: Optional["int"] = None,
|
||||
) -> None:
|
||||
self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else []
|
||||
self.internal_model = internal_model
|
||||
self.model_variant = model_variant
|
||||
self.protocol_version_major = protocol_version_major
|
||||
self.protocol_version_minor = protocol_version_minor
|
||||
|
||||
|
||||
class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = None
|
||||
FIELDS = {
|
||||
1: protobuf.Field("host_pairing_credential", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
host_pairing_credential: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.host_pairing_credential = host_pairing_credential
|
||||
|
||||
|
||||
class ThpCreateNewSession(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1000
|
||||
FIELDS = {
|
||||
1: protobuf.Field("passphrase", "string", repeated=False, required=False, default=None),
|
||||
2: protobuf.Field("on_device", "bool", repeated=False, required=False, default=None),
|
||||
3: protobuf.Field("derive_cardano", "bool", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
passphrase: Optional["str"] = None,
|
||||
on_device: Optional["bool"] = None,
|
||||
derive_cardano: Optional["bool"] = None,
|
||||
) -> None:
|
||||
self.passphrase = passphrase
|
||||
self.on_device = on_device
|
||||
self.derive_cardano = derive_cardano
|
||||
|
||||
|
||||
class ThpPairingRequest(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1006
|
||||
FIELDS = {
|
||||
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
|
||||
}
|
||||
@ -7877,6 +8017,216 @@ class ThpCredentialMetadata(protobuf.MessageType):
|
||||
self.host_name = host_name
|
||||
|
||||
|
||||
class ThpPairingRequestApproved(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1007
|
||||
|
||||
|
||||
class ThpSelectMethod(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1008
|
||||
FIELDS = {
|
||||
1: protobuf.Field("selected_pairing_method", "ThpPairingMethod", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
selected_pairing_method: Optional["ThpPairingMethod"] = None,
|
||||
) -> None:
|
||||
self.selected_pairing_method = selected_pairing_method
|
||||
|
||||
|
||||
class ThpPairingPreparationsFinished(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1009
|
||||
|
||||
|
||||
class ThpCodeEntryCommitment(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1016
|
||||
FIELDS = {
|
||||
1: protobuf.Field("commitment", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
commitment: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.commitment = commitment
|
||||
|
||||
|
||||
class ThpCodeEntryChallenge(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1017
|
||||
FIELDS = {
|
||||
1: protobuf.Field("challenge", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
challenge: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.challenge = challenge
|
||||
|
||||
|
||||
class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1018
|
||||
FIELDS = {
|
||||
1: protobuf.Field("cpace_trezor_public_key", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cpace_trezor_public_key: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.cpace_trezor_public_key = cpace_trezor_public_key
|
||||
|
||||
|
||||
class ThpCodeEntryCpaceHostTag(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1019
|
||||
FIELDS = {
|
||||
1: protobuf.Field("cpace_host_public_key", "bytes", repeated=False, required=False, default=None),
|
||||
2: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cpace_host_public_key: Optional["bytes"] = None,
|
||||
tag: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.cpace_host_public_key = cpace_host_public_key
|
||||
self.tag = tag
|
||||
|
||||
|
||||
class ThpCodeEntrySecret(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1020
|
||||
FIELDS = {
|
||||
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
secret: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.secret = secret
|
||||
|
||||
|
||||
class ThpQrCodeTag(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1024
|
||||
FIELDS = {
|
||||
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tag: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.tag = tag
|
||||
|
||||
|
||||
class ThpQrCodeSecret(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1025
|
||||
FIELDS = {
|
||||
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
secret: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.secret = secret
|
||||
|
||||
|
||||
class ThpNfcTagHost(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1032
|
||||
FIELDS = {
|
||||
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tag: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.tag = tag
|
||||
|
||||
|
||||
class ThpNfcTagTrezor(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1033
|
||||
FIELDS = {
|
||||
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tag: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.tag = tag
|
||||
|
||||
|
||||
class ThpCredentialRequest(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1010
|
||||
FIELDS = {
|
||||
1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None),
|
||||
2: protobuf.Field("autoconnect", "bool", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
host_static_pubkey: Optional["bytes"] = None,
|
||||
autoconnect: Optional["bool"] = None,
|
||||
) -> None:
|
||||
self.host_static_pubkey = host_static_pubkey
|
||||
self.autoconnect = autoconnect
|
||||
|
||||
|
||||
class ThpCredentialResponse(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1011
|
||||
FIELDS = {
|
||||
1: protobuf.Field("trezor_static_pubkey", "bytes", repeated=False, required=False, default=None),
|
||||
2: protobuf.Field("credential", "bytes", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
trezor_static_pubkey: Optional["bytes"] = None,
|
||||
credential: Optional["bytes"] = None,
|
||||
) -> None:
|
||||
self.trezor_static_pubkey = trezor_static_pubkey
|
||||
self.credential = credential
|
||||
|
||||
|
||||
class ThpEndRequest(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1012
|
||||
|
||||
|
||||
class ThpEndResponse(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = 1013
|
||||
|
||||
|
||||
class ThpCredentialMetadata(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = None
|
||||
FIELDS = {
|
||||
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
|
||||
2: protobuf.Field("autoconnect", "bool", repeated=False, required=False, default=None),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
host_name: Optional["str"] = None,
|
||||
autoconnect: Optional["bool"] = None,
|
||||
) -> None:
|
||||
self.host_name = host_name
|
||||
self.autoconnect = autoconnect
|
||||
|
||||
|
||||
class ThpPairingCredential(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = None
|
||||
FIELDS = {
|
||||
|
@ -19,22 +19,22 @@ from typing import TYPE_CHECKING, Optional
|
||||
from . import messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def get_entropy(client: "TrezorClient", size: int) -> bytes:
|
||||
return client.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy
|
||||
def get_entropy(session: "Session", size: int) -> bytes:
|
||||
return session.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy
|
||||
|
||||
|
||||
def sign_identity(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
identity: messages.IdentityType,
|
||||
challenge_hidden: bytes,
|
||||
challenge_visual: str,
|
||||
ecdsa_curve_name: Optional[str] = None,
|
||||
) -> messages.SignedIdentity:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.SignIdentity(
|
||||
identity=identity,
|
||||
challenge_hidden=challenge_hidden,
|
||||
@ -46,12 +46,12 @@ def sign_identity(
|
||||
|
||||
|
||||
def get_ecdh_session_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
identity: messages.IdentityType,
|
||||
peer_public_key: bytes,
|
||||
ecdsa_curve_name: Optional[str] = None,
|
||||
) -> messages.ECDHSessionKey:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.GetECDHSessionKey(
|
||||
identity=identity,
|
||||
peer_public_key=peer_public_key,
|
||||
@ -62,7 +62,7 @@ def get_ecdh_session_key(
|
||||
|
||||
|
||||
def encrypt_keyvalue(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
key: str,
|
||||
value: bytes,
|
||||
@ -70,7 +70,7 @@ def encrypt_keyvalue(
|
||||
ask_on_decrypt: bool = True,
|
||||
iv: bytes = b"",
|
||||
) -> bytes:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.CipherKeyValue(
|
||||
address_n=n,
|
||||
key=key,
|
||||
@ -85,7 +85,7 @@ def encrypt_keyvalue(
|
||||
|
||||
|
||||
def decrypt_keyvalue(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
key: str,
|
||||
value: bytes,
|
||||
@ -93,7 +93,7 @@ def decrypt_keyvalue(
|
||||
ask_on_decrypt: bool = True,
|
||||
iv: bytes = b"",
|
||||
) -> bytes:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.CipherKeyValue(
|
||||
address_n=n,
|
||||
key=key,
|
||||
@ -107,5 +107,5 @@ def decrypt_keyvalue(
|
||||
).value
|
||||
|
||||
|
||||
def get_nonce(client: "TrezorClient") -> bytes:
|
||||
return client.call(messages.GetNonce(), expect=messages.Nonce).nonce
|
||||
def get_nonce(session: "Session") -> bytes:
|
||||
return session.call(messages.GetNonce(), expect=messages.Nonce).nonce
|
||||
|
@ -19,8 +19,8 @@ from typing import TYPE_CHECKING
|
||||
from . import messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
# MAINNET = 0
|
||||
@ -30,13 +30,13 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
show_display: bool = False,
|
||||
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
|
||||
chunkify: bool = False,
|
||||
) -> bytes:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.MoneroGetAddress(
|
||||
address_n=n,
|
||||
show_display=show_display,
|
||||
@ -48,11 +48,11 @@ def get_address(
|
||||
|
||||
|
||||
def get_watch_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
|
||||
) -> messages.MoneroWatchKey:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.MoneroGetWatchKey(address_n=n, network_type=network_type),
|
||||
expect=messages.MoneroWatchKey,
|
||||
)
|
||||
|
@ -20,8 +20,8 @@ from typing import TYPE_CHECKING
|
||||
from . import exceptions, messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
TYPE_TRANSACTION_TRANSFER = 0x0101
|
||||
TYPE_IMPORTANCE_TRANSFER = 0x0801
|
||||
@ -195,13 +195,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
network: int,
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.NEMGetAddress(
|
||||
address_n=n, network=network, show_display=show_display, chunkify=chunkify
|
||||
),
|
||||
@ -210,7 +210,7 @@ def get_address(
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False
|
||||
session: "Session", n: "Address", transaction: dict, chunkify: bool = False
|
||||
) -> messages.NEMSignedTx:
|
||||
try:
|
||||
msg = create_sign_tx(transaction, chunkify=chunkify)
|
||||
@ -219,4 +219,4 @@ def sign_tx(
|
||||
|
||||
assert msg.transaction is not None
|
||||
msg.transaction.address_n = n
|
||||
return client.call(msg, expect=messages.NEMSignedTx)
|
||||
return session.call(msg, expect=messages.NEMSignedTx)
|
||||
|
@ -21,20 +21,20 @@ from .protobuf import dict_to_proto
|
||||
from .tools import dict_from_camelcase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
|
||||
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.RippleGetAddress(
|
||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||
),
|
||||
@ -43,14 +43,14 @@ def get_address(
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
msg: messages.RippleSignTx,
|
||||
chunkify: bool = False,
|
||||
) -> messages.RippleSignedTx:
|
||||
msg.address_n = address_n
|
||||
msg.chunkify = chunkify
|
||||
return client.call(msg, expect=messages.RippleSignedTx)
|
||||
return session.call(msg, expect=messages.RippleSignedTx)
|
||||
|
||||
|
||||
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:
|
||||
|
@ -3,27 +3,27 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
from . import messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: List[int],
|
||||
show_display: bool,
|
||||
) -> bytes:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display),
|
||||
expect=messages.SolanaPublicKey,
|
||||
).public_key
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: List[int],
|
||||
show_display: bool,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.SolanaGetAddress(
|
||||
address_n=address_n,
|
||||
show_display=show_display,
|
||||
@ -34,12 +34,12 @@ def get_address(
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: List[int],
|
||||
serialized_tx: bytes,
|
||||
additional_info: Optional[messages.SolanaTxAdditionalInfo],
|
||||
) -> bytes:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.SolanaSignTx(
|
||||
address_n=address_n,
|
||||
serialized_tx=serialized_tx,
|
||||
|
@ -20,8 +20,8 @@ from typing import TYPE_CHECKING, List, Tuple, Union
|
||||
from . import exceptions, messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
StellarMessageType = Union[
|
||||
messages.StellarAccountMergeOp,
|
||||
@ -322,12 +322,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.StellarGetAddress(
|
||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||
),
|
||||
@ -336,7 +336,7 @@ def get_address(
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
tx: messages.StellarSignTx,
|
||||
operations: List["StellarMessageType"],
|
||||
address_n: "Address",
|
||||
@ -352,10 +352,10 @@ def sign_tx(
|
||||
# 3. Receive a StellarTxOpRequest message
|
||||
# 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message
|
||||
# 5. The final message received will be StellarSignedTx which is returned from this method
|
||||
resp = client.call(tx)
|
||||
resp = session.call(tx)
|
||||
try:
|
||||
while isinstance(resp, messages.StellarTxOpRequest):
|
||||
resp = client.call(operations.pop(0))
|
||||
resp = session.call(operations.pop(0))
|
||||
except IndexError:
|
||||
# pop from empty list
|
||||
raise exceptions.TrezorException(
|
||||
|
@ -19,17 +19,17 @@ from typing import TYPE_CHECKING
|
||||
from . import messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.TezosGetAddress(
|
||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||
),
|
||||
@ -38,12 +38,12 @@ def get_address(
|
||||
|
||||
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.TezosGetPublicKey(
|
||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||
),
|
||||
@ -52,11 +52,11 @@ def get_public_key(
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
sign_tx_msg: messages.TezosSignTx,
|
||||
chunkify: bool = False,
|
||||
) -> messages.TezosSignedTx:
|
||||
sign_tx_msg.address_n = address_n
|
||||
sign_tx_msg.chunkify = chunkify
|
||||
return client.call(sign_tx_msg, expect=messages.TezosSignedTx)
|
||||
return session.call(sign_tx_msg, expect=messages.TezosSignedTx)
|
||||
|
@ -45,7 +45,7 @@ if TYPE_CHECKING:
|
||||
# More details: https://www.python.org/dev/peps/pep-0612/
|
||||
from typing import TypeVar
|
||||
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from . import client
|
||||
from .messages import Success
|
||||
@ -389,23 +389,6 @@ def _return_success(msg: "Success") -> str | None:
|
||||
return _deprecation_retval_helper(msg.message, stacklevel=1)
|
||||
|
||||
|
||||
def session(
|
||||
f: "Callable[Concatenate[TrezorClient, P], R]",
|
||||
) -> "Callable[Concatenate[TrezorClient, P], R]":
|
||||
# Decorator wraps a BaseClient method
|
||||
# with session activation / deactivation
|
||||
@functools.wraps(f)
|
||||
def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R":
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
client.open()
|
||||
try:
|
||||
return f(client, *args, **kwargs)
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
return wrapped_f
|
||||
|
||||
|
||||
# de-camelcasifier
|
||||
# https://stackoverflow.com/a/1176023/222189
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user