mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-10 15:30:55 +00:00
feat(python): implement session based trezorlib
[no changelog]
This commit is contained in:
parent
4a18f67f8f
commit
6b1fc71ce3
@ -95,6 +95,15 @@ class Emulator:
|
|||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
|
@client.setter
|
||||||
|
def client(self, new_client: TrezorClientDebugLink) -> None:
|
||||||
|
"""Setter for the client property to update _client."""
|
||||||
|
if not isinstance(new_client, TrezorClientDebugLink):
|
||||||
|
raise TypeError(
|
||||||
|
f"Expected a TrezorClientDebugLink, got {type(new_client).__name__}."
|
||||||
|
)
|
||||||
|
self._client = new_client
|
||||||
|
|
||||||
def make_args(self) -> List[str]:
|
def make_args(self) -> List[str]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@ -112,7 +121,7 @@ class Emulator:
|
|||||||
start = time.monotonic()
|
start = time.monotonic()
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
if transport._ping():
|
if transport.ping():
|
||||||
break
|
break
|
||||||
if self.process.poll() is not None:
|
if self.process.poll() is not None:
|
||||||
raise RuntimeError("Emulator process died")
|
raise RuntimeError("Emulator process died")
|
||||||
|
@ -7,7 +7,7 @@ import typing as t
|
|||||||
from importlib import metadata
|
from importlib import metadata
|
||||||
|
|
||||||
from . import device
|
from . import device
|
||||||
from .client import TrezorClient
|
from .transport.session import Session
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cryptography_version = metadata.version("cryptography")
|
cryptography_version = metadata.version("cryptography")
|
||||||
@ -361,7 +361,7 @@ def verify_authentication_response(
|
|||||||
|
|
||||||
|
|
||||||
def authenticate_device(
|
def authenticate_device(
|
||||||
client: TrezorClient,
|
session: Session,
|
||||||
challenge: bytes | None = None,
|
challenge: bytes | None = None,
|
||||||
*,
|
*,
|
||||||
whitelist: t.Collection[bytes] | None = None,
|
whitelist: t.Collection[bytes] | None = None,
|
||||||
@ -371,7 +371,7 @@ def authenticate_device(
|
|||||||
if challenge is None:
|
if challenge is None:
|
||||||
challenge = secrets.token_bytes(16)
|
challenge = secrets.token_bytes(16)
|
||||||
|
|
||||||
resp = device.authenticate(client, challenge)
|
resp = device.authenticate(session, challenge)
|
||||||
|
|
||||||
return verify_authentication_response(
|
return verify_authentication_response(
|
||||||
challenge,
|
challenge,
|
||||||
|
@ -20,17 +20,17 @@ from . import messages
|
|||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.BenchmarkNames)
|
@expect(messages.BenchmarkNames)
|
||||||
def list_names(
|
def list_names(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(messages.BenchmarkListNames())
|
return session.call(messages.BenchmarkListNames())
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.BenchmarkResult)
|
@expect(messages.BenchmarkResult)
|
||||||
def run(client: "TrezorClient", name: str) -> "MessageType":
|
def run(session: "Session", name: str) -> "MessageType":
|
||||||
return client.call(messages.BenchmarkRun(name=name))
|
return session.call(messages.BenchmarkRun(name=name))
|
||||||
|
@ -18,22 +18,22 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from . import messages
|
from . import messages
|
||||||
from .protobuf import dict_to_proto
|
from .protobuf import dict_to_proto
|
||||||
from .tools import expect, session
|
from .tools import expect
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.BinanceAddress, field="address", ret_type=str)
|
@expect(messages.BinanceAddress, field="address", ret_type=str)
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: "Address",
|
address_n: "Address",
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.BinanceGetAddress(
|
messages.BinanceGetAddress(
|
||||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||||
)
|
)
|
||||||
@ -42,16 +42,15 @@ def get_address(
|
|||||||
|
|
||||||
@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes)
|
@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes)
|
||||||
def get_public_key(
|
def get_public_key(
|
||||||
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
session: "Session", address_n: "Address", show_display: bool = False
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display)
|
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False
|
session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False
|
||||||
) -> messages.BinanceSignedTx:
|
) -> messages.BinanceSignedTx:
|
||||||
msg = tx_json["msgs"][0]
|
msg = tx_json["msgs"][0]
|
||||||
tx_msg = tx_json.copy()
|
tx_msg = tx_json.copy()
|
||||||
@ -60,7 +59,7 @@ def sign_tx(
|
|||||||
tx_msg["chunkify"] = chunkify
|
tx_msg["chunkify"] = chunkify
|
||||||
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
|
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
|
||||||
|
|
||||||
response = client.call(envelope)
|
response = session.call(envelope)
|
||||||
|
|
||||||
if not isinstance(response, messages.BinanceTxRequest):
|
if not isinstance(response, messages.BinanceTxRequest):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@ -77,7 +76,7 @@ def sign_tx(
|
|||||||
else:
|
else:
|
||||||
raise ValueError("can not determine msg type")
|
raise ValueError("can not determine msg type")
|
||||||
|
|
||||||
response = client.call(msg)
|
response = session.call(msg)
|
||||||
|
|
||||||
if not isinstance(response, messages.BinanceSignedTx):
|
if not isinstance(response, messages.BinanceSignedTx):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
@ -23,12 +22,12 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple
|
|||||||
from typing_extensions import Protocol, TypedDict
|
from typing_extensions import Protocol, TypedDict
|
||||||
|
|
||||||
from . import exceptions, messages
|
from . import exceptions, messages
|
||||||
from .tools import expect, prepare_message_bytes, session
|
from .tools import expect, prepare_message_bytes
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
class ScriptSig(TypedDict):
|
class ScriptSig(TypedDict):
|
||||||
asm: str
|
asm: str
|
||||||
@ -105,7 +104,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType:
|
|||||||
|
|
||||||
@expect(messages.PublicKey)
|
@expect(messages.PublicKey)
|
||||||
def get_public_node(
|
def get_public_node(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
ecdsa_curve_name: Optional[str] = None,
|
ecdsa_curve_name: Optional[str] = None,
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
@ -116,13 +115,13 @@ def get_public_node(
|
|||||||
unlock_path_mac: Optional[bytes] = None,
|
unlock_path_mac: Optional[bytes] = None,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
if unlock_path:
|
if unlock_path:
|
||||||
res = client.call(
|
res = session.call(
|
||||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
|
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
|
||||||
)
|
)
|
||||||
if not isinstance(res, messages.UnlockedPathRequest):
|
if not isinstance(res, messages.UnlockedPathRequest):
|
||||||
raise exceptions.TrezorException("Unexpected message")
|
raise exceptions.TrezorException("Unexpected message")
|
||||||
|
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.GetPublicKey(
|
messages.GetPublicKey(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
ecdsa_curve_name=ecdsa_curve_name,
|
ecdsa_curve_name=ecdsa_curve_name,
|
||||||
@ -141,7 +140,7 @@ def get_address(*args: Any, **kwargs: Any):
|
|||||||
|
|
||||||
@expect(messages.Address)
|
@expect(messages.Address)
|
||||||
def get_authenticated_address(
|
def get_authenticated_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin_name: str,
|
coin_name: str,
|
||||||
n: "Address",
|
n: "Address",
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
@ -153,13 +152,13 @@ def get_authenticated_address(
|
|||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
if unlock_path:
|
if unlock_path:
|
||||||
res = client.call(
|
res = session.call(
|
||||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
|
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
|
||||||
)
|
)
|
||||||
if not isinstance(res, messages.UnlockedPathRequest):
|
if not isinstance(res, messages.UnlockedPathRequest):
|
||||||
raise exceptions.TrezorException("Unexpected message")
|
raise exceptions.TrezorException("Unexpected message")
|
||||||
|
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.GetAddress(
|
messages.GetAddress(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
coin_name=coin_name,
|
coin_name=coin_name,
|
||||||
@ -172,15 +171,16 @@ def get_authenticated_address(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO this is used by tests only
|
||||||
@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
|
@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
|
||||||
def get_ownership_id(
|
def get_ownership_id(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin_name: str,
|
coin_name: str,
|
||||||
n: "Address",
|
n: "Address",
|
||||||
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.GetOwnershipId(
|
messages.GetOwnershipId(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
coin_name=coin_name,
|
coin_name=coin_name,
|
||||||
@ -190,8 +190,9 @@ def get_ownership_id(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO this is used by tests only
|
||||||
def get_ownership_proof(
|
def get_ownership_proof(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin_name: str,
|
coin_name: str,
|
||||||
n: "Address",
|
n: "Address",
|
||||||
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||||
@ -202,11 +203,11 @@ def get_ownership_proof(
|
|||||||
preauthorized: bool = False,
|
preauthorized: bool = False,
|
||||||
) -> Tuple[bytes, bytes]:
|
) -> Tuple[bytes, bytes]:
|
||||||
if preauthorized:
|
if preauthorized:
|
||||||
res = client.call(messages.DoPreauthorized())
|
res = session.call(messages.DoPreauthorized())
|
||||||
if not isinstance(res, messages.PreauthorizedRequest):
|
if not isinstance(res, messages.PreauthorizedRequest):
|
||||||
raise exceptions.TrezorException("Unexpected message")
|
raise exceptions.TrezorException("Unexpected message")
|
||||||
|
|
||||||
res = client.call(
|
res = session.call(
|
||||||
messages.GetOwnershipProof(
|
messages.GetOwnershipProof(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
coin_name=coin_name,
|
coin_name=coin_name,
|
||||||
@ -226,7 +227,7 @@ def get_ownership_proof(
|
|||||||
|
|
||||||
@expect(messages.MessageSignature)
|
@expect(messages.MessageSignature)
|
||||||
def sign_message(
|
def sign_message(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin_name: str,
|
coin_name: str,
|
||||||
n: "Address",
|
n: "Address",
|
||||||
message: AnyStr,
|
message: AnyStr,
|
||||||
@ -234,7 +235,7 @@ def sign_message(
|
|||||||
no_script_type: bool = False,
|
no_script_type: bool = False,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.SignMessage(
|
messages.SignMessage(
|
||||||
coin_name=coin_name,
|
coin_name=coin_name,
|
||||||
address_n=n,
|
address_n=n,
|
||||||
@ -247,7 +248,7 @@ def sign_message(
|
|||||||
|
|
||||||
|
|
||||||
def verify_message(
|
def verify_message(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin_name: str,
|
coin_name: str,
|
||||||
address: str,
|
address: str,
|
||||||
signature: bytes,
|
signature: bytes,
|
||||||
@ -255,7 +256,7 @@ def verify_message(
|
|||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
try:
|
try:
|
||||||
resp = client.call(
|
resp = session.call(
|
||||||
messages.VerifyMessage(
|
messages.VerifyMessage(
|
||||||
address=address,
|
address=address,
|
||||||
signature=signature,
|
signature=signature,
|
||||||
@ -269,9 +270,9 @@ def verify_message(
|
|||||||
return isinstance(resp, messages.Success)
|
return isinstance(resp, messages.Success)
|
||||||
|
|
||||||
|
|
||||||
@session
|
# @session
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin_name: str,
|
coin_name: str,
|
||||||
inputs: Sequence[messages.TxInputType],
|
inputs: Sequence[messages.TxInputType],
|
||||||
outputs: Sequence[messages.TxOutputType],
|
outputs: Sequence[messages.TxOutputType],
|
||||||
@ -319,17 +320,17 @@ def sign_tx(
|
|||||||
setattr(signtx, name, value)
|
setattr(signtx, name, value)
|
||||||
|
|
||||||
if unlock_path:
|
if unlock_path:
|
||||||
res = client.call(
|
res = session.call(
|
||||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
|
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
|
||||||
)
|
)
|
||||||
if not isinstance(res, messages.UnlockedPathRequest):
|
if not isinstance(res, messages.UnlockedPathRequest):
|
||||||
raise exceptions.TrezorException("Unexpected message")
|
raise exceptions.TrezorException("Unexpected message")
|
||||||
elif preauthorized:
|
elif preauthorized:
|
||||||
res = client.call(messages.DoPreauthorized())
|
res = session.call(messages.DoPreauthorized())
|
||||||
if not isinstance(res, messages.PreauthorizedRequest):
|
if not isinstance(res, messages.PreauthorizedRequest):
|
||||||
raise exceptions.TrezorException("Unexpected message")
|
raise exceptions.TrezorException("Unexpected message")
|
||||||
|
|
||||||
res = client.call(signtx)
|
res = session.call(signtx)
|
||||||
|
|
||||||
# Prepare structure for signatures
|
# Prepare structure for signatures
|
||||||
signatures: List[Optional[bytes]] = [None] * len(inputs)
|
signatures: List[Optional[bytes]] = [None] * len(inputs)
|
||||||
@ -388,7 +389,7 @@ def sign_tx(
|
|||||||
if res.request_type == R.TXPAYMENTREQ:
|
if res.request_type == R.TXPAYMENTREQ:
|
||||||
assert res.details.request_index is not None
|
assert res.details.request_index is not None
|
||||||
msg = payment_reqs[res.details.request_index]
|
msg = payment_reqs[res.details.request_index]
|
||||||
res = client.call(msg)
|
res = session.call(msg)
|
||||||
else:
|
else:
|
||||||
msg = messages.TransactionType()
|
msg = messages.TransactionType()
|
||||||
if res.request_type == R.TXMETA:
|
if res.request_type == R.TXMETA:
|
||||||
@ -418,7 +419,7 @@ def sign_tx(
|
|||||||
f"Unknown request type - {res.request_type}."
|
f"Unknown request type - {res.request_type}."
|
||||||
)
|
)
|
||||||
|
|
||||||
res = client.call(messages.TxAck(tx=msg))
|
res = session.call(messages.TxAck(tx=msg))
|
||||||
|
|
||||||
if not isinstance(res, messages.TxRequest):
|
if not isinstance(res, messages.TxRequest):
|
||||||
raise exceptions.TrezorException("Unexpected message")
|
raise exceptions.TrezorException("Unexpected message")
|
||||||
@ -432,7 +433,7 @@ def sign_tx(
|
|||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def authorize_coinjoin(
|
def authorize_coinjoin(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coordinator: str,
|
coordinator: str,
|
||||||
max_rounds: int,
|
max_rounds: int,
|
||||||
max_coordinator_fee_rate: int,
|
max_coordinator_fee_rate: int,
|
||||||
@ -441,7 +442,7 @@ def authorize_coinjoin(
|
|||||||
coin_name: str,
|
coin_name: str,
|
||||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.AuthorizeCoinJoin(
|
messages.AuthorizeCoinJoin(
|
||||||
coordinator=coordinator,
|
coordinator=coordinator,
|
||||||
max_rounds=max_rounds,
|
max_rounds=max_rounds,
|
||||||
|
@ -35,8 +35,8 @@ from . import exceptions, messages, tools
|
|||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
PROTOCOL_MAGICS = {
|
PROTOCOL_MAGICS = {
|
||||||
"mainnet": 764824073,
|
"mainnet": 764824073,
|
||||||
@ -825,7 +825,7 @@ def _get_collateral_inputs_items(
|
|||||||
|
|
||||||
@expect(messages.CardanoAddress, field="address", ret_type=str)
|
@expect(messages.CardanoAddress, field="address", ret_type=str)
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_parameters: messages.CardanoAddressParametersType,
|
address_parameters: messages.CardanoAddressParametersType,
|
||||||
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
|
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
|
||||||
network_id: int = NETWORK_IDS["mainnet"],
|
network_id: int = NETWORK_IDS["mainnet"],
|
||||||
@ -833,7 +833,7 @@ def get_address(
|
|||||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.CardanoGetAddress(
|
messages.CardanoGetAddress(
|
||||||
address_parameters=address_parameters,
|
address_parameters=address_parameters,
|
||||||
protocol_magic=protocol_magic,
|
protocol_magic=protocol_magic,
|
||||||
@ -847,12 +847,12 @@ def get_address(
|
|||||||
|
|
||||||
@expect(messages.CardanoPublicKey)
|
@expect(messages.CardanoPublicKey)
|
||||||
def get_public_key(
|
def get_public_key(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: List[int],
|
address_n: List[int],
|
||||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.CardanoGetPublicKey(
|
messages.CardanoGetPublicKey(
|
||||||
address_n=address_n,
|
address_n=address_n,
|
||||||
derivation_type=derivation_type,
|
derivation_type=derivation_type,
|
||||||
@ -863,12 +863,12 @@ def get_public_key(
|
|||||||
|
|
||||||
@expect(messages.CardanoNativeScriptHash)
|
@expect(messages.CardanoNativeScriptHash)
|
||||||
def get_native_script_hash(
|
def get_native_script_hash(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
native_script: messages.CardanoNativeScript,
|
native_script: messages.CardanoNativeScript,
|
||||||
display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE,
|
display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE,
|
||||||
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.CardanoGetNativeScriptHash(
|
messages.CardanoGetNativeScriptHash(
|
||||||
script=native_script,
|
script=native_script,
|
||||||
display_format=display_format,
|
display_format=display_format,
|
||||||
@ -878,7 +878,7 @@ def get_native_script_hash(
|
|||||||
|
|
||||||
|
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
signing_mode: messages.CardanoTxSigningMode,
|
signing_mode: messages.CardanoTxSigningMode,
|
||||||
inputs: List[InputWithPath],
|
inputs: List[InputWithPath],
|
||||||
outputs: List[OutputWithData],
|
outputs: List[OutputWithData],
|
||||||
@ -915,7 +915,7 @@ def sign_tx(
|
|||||||
signing_mode,
|
signing_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.call(
|
response = session.call(
|
||||||
messages.CardanoSignTxInit(
|
messages.CardanoSignTxInit(
|
||||||
signing_mode=signing_mode,
|
signing_mode=signing_mode,
|
||||||
inputs_count=len(inputs),
|
inputs_count=len(inputs),
|
||||||
@ -951,14 +951,14 @@ def sign_tx(
|
|||||||
_get_certificates_items(certificates),
|
_get_certificates_items(certificates),
|
||||||
withdrawals,
|
withdrawals,
|
||||||
):
|
):
|
||||||
response = client.call(tx_item)
|
response = session.call(tx_item)
|
||||||
if not isinstance(response, messages.CardanoTxItemAck):
|
if not isinstance(response, messages.CardanoTxItemAck):
|
||||||
raise UNEXPECTED_RESPONSE_ERROR
|
raise UNEXPECTED_RESPONSE_ERROR
|
||||||
|
|
||||||
sign_tx_response: Dict[str, Any] = {}
|
sign_tx_response: Dict[str, Any] = {}
|
||||||
|
|
||||||
if auxiliary_data is not None:
|
if auxiliary_data is not None:
|
||||||
auxiliary_data_supplement = client.call(auxiliary_data)
|
auxiliary_data_supplement = session.call(auxiliary_data)
|
||||||
if not isinstance(
|
if not isinstance(
|
||||||
auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement
|
auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement
|
||||||
):
|
):
|
||||||
@ -971,7 +971,7 @@ def sign_tx(
|
|||||||
auxiliary_data_supplement.__dict__
|
auxiliary_data_supplement.__dict__
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.call(messages.CardanoTxHostAck())
|
response = session.call(messages.CardanoTxHostAck())
|
||||||
if not isinstance(response, messages.CardanoTxItemAck):
|
if not isinstance(response, messages.CardanoTxItemAck):
|
||||||
raise UNEXPECTED_RESPONSE_ERROR
|
raise UNEXPECTED_RESPONSE_ERROR
|
||||||
|
|
||||||
@ -980,24 +980,24 @@ def sign_tx(
|
|||||||
_get_collateral_inputs_items(collateral_inputs),
|
_get_collateral_inputs_items(collateral_inputs),
|
||||||
required_signers,
|
required_signers,
|
||||||
):
|
):
|
||||||
response = client.call(tx_item)
|
response = session.call(tx_item)
|
||||||
if not isinstance(response, messages.CardanoTxItemAck):
|
if not isinstance(response, messages.CardanoTxItemAck):
|
||||||
raise UNEXPECTED_RESPONSE_ERROR
|
raise UNEXPECTED_RESPONSE_ERROR
|
||||||
|
|
||||||
if collateral_return is not None:
|
if collateral_return is not None:
|
||||||
for tx_item in _get_output_items(collateral_return):
|
for tx_item in _get_output_items(collateral_return):
|
||||||
response = client.call(tx_item)
|
response = session.call(tx_item)
|
||||||
if not isinstance(response, messages.CardanoTxItemAck):
|
if not isinstance(response, messages.CardanoTxItemAck):
|
||||||
raise UNEXPECTED_RESPONSE_ERROR
|
raise UNEXPECTED_RESPONSE_ERROR
|
||||||
|
|
||||||
for reference_input in reference_inputs:
|
for reference_input in reference_inputs:
|
||||||
response = client.call(reference_input)
|
response = session.call(reference_input)
|
||||||
if not isinstance(response, messages.CardanoTxItemAck):
|
if not isinstance(response, messages.CardanoTxItemAck):
|
||||||
raise UNEXPECTED_RESPONSE_ERROR
|
raise UNEXPECTED_RESPONSE_ERROR
|
||||||
|
|
||||||
sign_tx_response["witnesses"] = []
|
sign_tx_response["witnesses"] = []
|
||||||
for witness_request in witness_requests:
|
for witness_request in witness_requests:
|
||||||
response = client.call(witness_request)
|
response = session.call(witness_request)
|
||||||
if not isinstance(response, messages.CardanoTxWitnessResponse):
|
if not isinstance(response, messages.CardanoTxWitnessResponse):
|
||||||
raise UNEXPECTED_RESPONSE_ERROR
|
raise UNEXPECTED_RESPONSE_ERROR
|
||||||
sign_tx_response["witnesses"].append(
|
sign_tx_response["witnesses"].append(
|
||||||
@ -1009,12 +1009,12 @@ def sign_tx(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.call(messages.CardanoTxHostAck())
|
response = session.call(messages.CardanoTxHostAck())
|
||||||
if not isinstance(response, messages.CardanoTxBodyHash):
|
if not isinstance(response, messages.CardanoTxBodyHash):
|
||||||
raise UNEXPECTED_RESPONSE_ERROR
|
raise UNEXPECTED_RESPONSE_ERROR
|
||||||
sign_tx_response["tx_hash"] = response.tx_hash
|
sign_tx_response["tx_hash"] = response.tx_hash
|
||||||
|
|
||||||
response = client.call(messages.CardanoTxHostAck())
|
response = session.call(messages.CardanoTxHostAck())
|
||||||
if not isinstance(response, messages.CardanoSignTxFinished):
|
if not isinstance(response, messages.CardanoSignTxFinished):
|
||||||
raise UNEXPECTED_RESPONSE_ERROR
|
raise UNEXPECTED_RESPONSE_ERROR
|
||||||
|
|
||||||
|
@ -14,33 +14,42 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import typing as t
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import exceptions, transport
|
from .. import exceptions, transport, ui
|
||||||
from ..client import TrezorClient
|
from ..client import ProtocolVersion, TrezorClient
|
||||||
from ..ui import ClickUI, ScriptUI
|
from ..messages import Capability
|
||||||
|
from ..transport import Transport
|
||||||
|
from ..transport.session import Session, SessionV1, SessionV2
|
||||||
|
from ..transport.thp.channel_database import get_channel_db
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if t.TYPE_CHECKING:
|
||||||
# Needed to enforce a return value from decorators
|
# Needed to enforce a return value from decorators
|
||||||
# More details: https://www.python.org/dev/peps/pep-0612/
|
# More details: https://www.python.org/dev/peps/pep-0612/
|
||||||
from typing import TypeVar
|
|
||||||
|
|
||||||
from typing_extensions import Concatenate, ParamSpec
|
from typing_extensions import Concatenate, ParamSpec
|
||||||
|
|
||||||
from ..transport import Transport
|
|
||||||
from ..ui import TrezorClientUI
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
R = TypeVar("R")
|
R = t.TypeVar("R")
|
||||||
|
FuncWithSession = t.Callable[Concatenate[Session, P], R]
|
||||||
|
|
||||||
|
|
||||||
class ChoiceType(click.Choice):
|
class ChoiceType(click.Choice):
|
||||||
def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None:
|
|
||||||
|
def __init__(
|
||||||
|
self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True
|
||||||
|
) -> None:
|
||||||
super().__init__(list(typemap.keys()))
|
super().__init__(list(typemap.keys()))
|
||||||
self.case_sensitive = case_sensitive
|
self.case_sensitive = case_sensitive
|
||||||
if case_sensitive:
|
if case_sensitive:
|
||||||
@ -48,7 +57,7 @@ class ChoiceType(click.Choice):
|
|||||||
else:
|
else:
|
||||||
self.typemap = {k.lower(): v for k, v in typemap.items()}
|
self.typemap = {k.lower(): v for k, v in typemap.items()}
|
||||||
|
|
||||||
def convert(self, value: Any, param: Any, ctx: click.Context) -> Any:
|
def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any:
|
||||||
if value in self.typemap.values():
|
if value in self.typemap.values():
|
||||||
return value
|
return value
|
||||||
value = super().convert(value, param, ctx)
|
value = super().convert(value, param, ctx)
|
||||||
@ -57,11 +66,69 @@ class ChoiceType(click.Choice):
|
|||||||
return self.typemap[value]
|
return self.typemap[value]
|
||||||
|
|
||||||
|
|
||||||
|
def get_passphrase(
|
||||||
|
passphrase_on_host: bool, available_on_device: bool
|
||||||
|
) -> t.Union[str, object]:
|
||||||
|
if available_on_device and not passphrase_on_host:
|
||||||
|
return ui.PASSPHRASE_ON_DEVICE
|
||||||
|
|
||||||
|
env_passphrase = os.getenv("PASSPHRASE")
|
||||||
|
if env_passphrase is not None:
|
||||||
|
ui.echo("Passphrase required. Using PASSPHRASE environment variable.")
|
||||||
|
return env_passphrase
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
passphrase = ui.prompt(
|
||||||
|
"Passphrase required",
|
||||||
|
hide_input=True,
|
||||||
|
default="",
|
||||||
|
show_default=False,
|
||||||
|
)
|
||||||
|
# In case user sees the input on the screen, we do not need confirmation
|
||||||
|
if not ui.CAN_HANDLE_HIDDEN_INPUT:
|
||||||
|
return passphrase
|
||||||
|
second = ui.prompt(
|
||||||
|
"Confirm your passphrase",
|
||||||
|
hide_input=True,
|
||||||
|
default="",
|
||||||
|
show_default=False,
|
||||||
|
)
|
||||||
|
if passphrase == second:
|
||||||
|
return passphrase
|
||||||
|
else:
|
||||||
|
ui.echo("Passphrase did not match. Please try again.")
|
||||||
|
except click.Abort:
|
||||||
|
raise exceptions.Cancelled from None
|
||||||
|
|
||||||
|
|
||||||
|
def get_client(transport: Transport) -> TrezorClient:
|
||||||
|
stored_channels = get_channel_db().load_stored_channels()
|
||||||
|
stored_transport_paths = [ch.transport_path for ch in stored_channels]
|
||||||
|
path = transport.get_path()
|
||||||
|
if path in stored_transport_paths:
|
||||||
|
stored_channel_with_correct_transport_path = next(
|
||||||
|
ch for ch in stored_channels if ch.transport_path == path
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
client = TrezorClient.resume(
|
||||||
|
transport, stored_channel_with_correct_transport_path
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
LOG.debug("Failed to resume a channel. Replacing by a new one.")
|
||||||
|
get_channel_db().remove_channel(path)
|
||||||
|
client = TrezorClient(transport)
|
||||||
|
else:
|
||||||
|
client = TrezorClient(transport)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
class TrezorConnection:
|
class TrezorConnection:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
session_id: Optional[bytes],
|
session_id: bytes | None,
|
||||||
passphrase_on_host: bool,
|
passphrase_on_host: bool,
|
||||||
script: bool,
|
script: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -70,6 +137,54 @@ class TrezorConnection:
|
|||||||
self.passphrase_on_host = passphrase_on_host
|
self.passphrase_on_host = passphrase_on_host
|
||||||
self.script = script
|
self.script = script
|
||||||
|
|
||||||
|
def get_session(
|
||||||
|
self,
|
||||||
|
derive_cardano: bool = False,
|
||||||
|
empty_passphrase: bool = False,
|
||||||
|
must_resume: bool = False,
|
||||||
|
) -> Session:
|
||||||
|
client = self.get_client()
|
||||||
|
if must_resume and self.session_id is None:
|
||||||
|
click.echo("Failed to resume session - no session id provided")
|
||||||
|
raise RuntimeError("Failed to resume session - no session id provided")
|
||||||
|
|
||||||
|
# Try resume session from id
|
||||||
|
if self.session_id is not None:
|
||||||
|
if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
|
||||||
|
session = SessionV1.resume_from_id(
|
||||||
|
client=client, session_id=self.session_id
|
||||||
|
)
|
||||||
|
elif client.protocol_version is ProtocolVersion.PROTOCOL_V2:
|
||||||
|
session = SessionV2(client, self.session_id)
|
||||||
|
# TODO fix resumption on THP
|
||||||
|
else:
|
||||||
|
raise Exception("Unsupported client protocol", client.protocol_version)
|
||||||
|
if must_resume:
|
||||||
|
if session.id != self.session_id or session.id is None:
|
||||||
|
click.echo("Failed to resume session")
|
||||||
|
RuntimeError("Failed to resume session - no session id provided")
|
||||||
|
return session
|
||||||
|
|
||||||
|
features = client.protocol.get_features()
|
||||||
|
|
||||||
|
passphrase_enabled = True # TODO what to do here?
|
||||||
|
|
||||||
|
if not passphrase_enabled:
|
||||||
|
return client.get_session(derive_cardano=derive_cardano)
|
||||||
|
|
||||||
|
if empty_passphrase:
|
||||||
|
passphrase = ""
|
||||||
|
else:
|
||||||
|
available_on_device = Capability.PassphraseEntry in features.capabilities
|
||||||
|
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
|
||||||
|
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
|
||||||
|
if not isinstance(passphrase, str):
|
||||||
|
raise RuntimeError("Passphrase must be a str")
|
||||||
|
session = client.get_session(
|
||||||
|
passphrase=passphrase, derive_cardano=derive_cardano
|
||||||
|
)
|
||||||
|
return session
|
||||||
|
|
||||||
def get_transport(self) -> "Transport":
|
def get_transport(self) -> "Transport":
|
||||||
try:
|
try:
|
||||||
# look for transport without prefix search
|
# look for transport without prefix search
|
||||||
@ -82,19 +197,13 @@ class TrezorConnection:
|
|||||||
# if this fails, we want the exception to bubble up to the caller
|
# if this fails, we want the exception to bubble up to the caller
|
||||||
return transport.get_transport(self.path, prefix_search=True)
|
return transport.get_transport(self.path, prefix_search=True)
|
||||||
|
|
||||||
def get_ui(self) -> "TrezorClientUI":
|
|
||||||
if self.script:
|
|
||||||
# It is alright to return just the class object instead of instance,
|
|
||||||
# as the ScriptUI class object itself is the implementation of TrezorClientUI
|
|
||||||
# (ScriptUI is just a set of staticmethods)
|
|
||||||
return ScriptUI
|
|
||||||
else:
|
|
||||||
return ClickUI(passphrase_on_host=self.passphrase_on_host)
|
|
||||||
|
|
||||||
def get_client(self) -> TrezorClient:
|
def get_client(self) -> TrezorClient:
|
||||||
transport = self.get_transport()
|
return get_client(self.get_transport())
|
||||||
ui = self.get_ui()
|
|
||||||
return TrezorClient(transport, ui=ui, session_id=self.session_id)
|
def get_management_session(self) -> Session:
|
||||||
|
client = self.get_client()
|
||||||
|
management_session = client.get_management_session()
|
||||||
|
return management_session
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def client_context(self):
|
def client_context(self):
|
||||||
@ -128,7 +237,57 @@ class TrezorConnection:
|
|||||||
# other exceptions may cause a traceback
|
# other exceptions may cause a traceback
|
||||||
|
|
||||||
|
|
||||||
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]":
|
def with_session(
|
||||||
|
func: "t.Callable[Concatenate[Session, P], R]|None" = None,
|
||||||
|
*,
|
||||||
|
empty_passphrase: bool = False,
|
||||||
|
derive_cardano: bool = False,
|
||||||
|
management: bool = False,
|
||||||
|
must_resume: bool = False,
|
||||||
|
) -> t.Callable[[FuncWithSession], t.Callable[P, R]]:
|
||||||
|
"""Provides a Click command with parameter `session=obj.get_session(...)` or
|
||||||
|
`session=obj.get_management_session()` based on the parameters provided.
|
||||||
|
|
||||||
|
If default parameters are ok, this decorator can be used without parentheses.
|
||||||
|
|
||||||
|
TODO: handle resumption of sessions and their (potential) closure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(
|
||||||
|
func: FuncWithSession,
|
||||||
|
) -> "t.Callable[P, R]":
|
||||||
|
|
||||||
|
@click.pass_obj
|
||||||
|
@functools.wraps(func)
|
||||||
|
def function_with_session(
|
||||||
|
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||||
|
) -> "R":
|
||||||
|
if management:
|
||||||
|
session = obj.get_management_session()
|
||||||
|
else:
|
||||||
|
session = obj.get_session(
|
||||||
|
derive_cardano=derive_cardano,
|
||||||
|
empty_passphrase=empty_passphrase,
|
||||||
|
must_resume=must_resume,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
return func(session, *args, **kwargs)
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
# TODO try end session if not resumed
|
||||||
|
|
||||||
|
return function_with_session
|
||||||
|
|
||||||
|
# If the decorator @get_session is used without parentheses
|
||||||
|
if func and callable(func):
|
||||||
|
return decorator(func) # type: ignore [Function return type]
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def with_client(
|
||||||
|
func: "t.Callable[Concatenate[TrezorClient, P], R]",
|
||||||
|
) -> "t.Callable[P, R]":
|
||||||
"""Wrap a Click command in `with obj.client_context() as client`.
|
"""Wrap a Click command in `with obj.client_context() as client`.
|
||||||
|
|
||||||
Sessions are handled transparently. The user is warned when session did not resume
|
Sessions are handled transparently. The user is warned when session did not resume
|
||||||
@ -142,23 +301,62 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[
|
|||||||
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||||
) -> "R":
|
) -> "R":
|
||||||
with obj.client_context() as client:
|
with obj.client_context() as client:
|
||||||
session_was_resumed = obj.session_id == client.session_id
|
# session_was_resumed = obj.session_id == client.session_id
|
||||||
if not session_was_resumed and obj.session_id is not None:
|
# if not session_was_resumed and obj.session_id is not None:
|
||||||
# tried to resume but failed
|
# # tried to resume but failed
|
||||||
click.echo("Warning: failed to resume session.", err=True)
|
# click.echo("Warning: failed to resume session.", err=True)
|
||||||
|
click.echo(
|
||||||
|
"Warning: resume session detection is not implemented yet!", err=True
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
return func(client, *args, **kwargs)
|
return func(client, *args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
if not session_was_resumed:
|
if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
|
||||||
try:
|
get_channel_db().save_channel(client.protocol)
|
||||||
client.end_session()
|
# if not session_was_resumed:
|
||||||
except Exception:
|
# try:
|
||||||
pass
|
# client.end_session()
|
||||||
|
# except Exception:
|
||||||
|
# pass
|
||||||
|
|
||||||
return trezorctl_command_with_client
|
return trezorctl_command_with_client
|
||||||
|
|
||||||
|
|
||||||
|
# def with_client(
|
||||||
|
# func: "t.Callable[Concatenate[TrezorClient, P], R]",
|
||||||
|
# ) -> "t.Callable[P, R]":
|
||||||
|
# """Wrap a Click command in `with obj.client_context() as client`.
|
||||||
|
|
||||||
|
# Sessions are handled transparently. The user is warned when session did not resume
|
||||||
|
# cleanly. The session is closed after the command completes - unless the session
|
||||||
|
# was resumed, in which case it should remain open.
|
||||||
|
# """
|
||||||
|
|
||||||
|
# @click.pass_obj
|
||||||
|
# @functools.wraps(func)
|
||||||
|
# def trezorctl_command_with_client(
|
||||||
|
# obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||||
|
# ) -> "R":
|
||||||
|
# with obj.client_context() as client:
|
||||||
|
# session_was_resumed = obj.session_id == client.session_id
|
||||||
|
# if not session_was_resumed and obj.session_id is not None:
|
||||||
|
# # tried to resume but failed
|
||||||
|
# click.echo("Warning: failed to resume session.", err=True)
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# return func(client, *args, **kwargs)
|
||||||
|
# finally:
|
||||||
|
# if not session_was_resumed:
|
||||||
|
# try:
|
||||||
|
# client.end_session()
|
||||||
|
# except Exception:
|
||||||
|
# pass
|
||||||
|
|
||||||
|
# # the return type of @click.pass_obj is improperly specified and pyright doesn't
|
||||||
|
# # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
|
||||||
|
# return trezorctl_command_with_client
|
||||||
|
|
||||||
|
|
||||||
class AliasedGroup(click.Group):
|
class AliasedGroup(click.Group):
|
||||||
"""Command group that handles aliases and Click 6.x compatibility.
|
"""Command group that handles aliases and Click 6.x compatibility.
|
||||||
|
|
||||||
@ -188,14 +386,14 @@ class AliasedGroup(click.Group):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
aliases: Optional[Dict[str, click.Command]] = None,
|
aliases: t.Dict[str, click.Command] | None = None,
|
||||||
*args: Any,
|
*args: t.Any,
|
||||||
**kwargs: Any,
|
**kwargs: t.Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.aliases = aliases or {}
|
self.aliases = aliases or {}
|
||||||
|
|
||||||
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]:
|
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
|
||||||
cmd_name = cmd_name.replace("_", "-")
|
cmd_name = cmd_name.replace("_", "-")
|
||||||
# try to look up the real name
|
# try to look up the real name
|
||||||
cmd = super().get_command(ctx, cmd_name)
|
cmd = super().get_command(ctx, cmd_name)
|
||||||
|
@ -20,17 +20,15 @@ from typing import TYPE_CHECKING, List, Optional
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import benchmark
|
from .. import benchmark
|
||||||
from . import with_client
|
from . import with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
from ..client import TrezorClient
|
from ..transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
def list_names_patern(
|
def list_names_patern(session: "Session", pattern: Optional[str] = None) -> List[str]:
|
||||||
client: "TrezorClient", pattern: Optional[str] = None
|
names = list(benchmark.list_names(session).names)
|
||||||
) -> List[str]:
|
|
||||||
names = list(benchmark.list_names(client).names)
|
|
||||||
if pattern is None:
|
if pattern is None:
|
||||||
return names
|
return names
|
||||||
return [name for name in names if fnmatch(name, pattern)]
|
return [name for name in names if fnmatch(name, pattern)]
|
||||||
@ -43,10 +41,10 @@ def cli() -> None:
|
|||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("pattern", required=False)
|
@click.argument("pattern", required=False)
|
||||||
@with_client
|
@with_session(empty_passphrase=True)
|
||||||
def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
|
def list_names(session: "Session", pattern: Optional[str] = None) -> None:
|
||||||
"""List names of all supported benchmarks"""
|
"""List names of all supported benchmarks"""
|
||||||
names = list_names_patern(client, pattern)
|
names = list_names_patern(session, pattern)
|
||||||
if len(names) == 0:
|
if len(names) == 0:
|
||||||
click.echo("No benchmark satisfies the pattern.")
|
click.echo("No benchmark satisfies the pattern.")
|
||||||
else:
|
else:
|
||||||
@ -56,13 +54,13 @@ def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
|
|||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("pattern", required=False)
|
@click.argument("pattern", required=False)
|
||||||
@with_client
|
@with_session(empty_passphrase=True)
|
||||||
def run(client: "TrezorClient", pattern: Optional[str]) -> None:
|
def run(session: "Session", pattern: Optional[str]) -> None:
|
||||||
"""Run benchmark"""
|
"""Run benchmark"""
|
||||||
names = list_names_patern(client, pattern)
|
names = list_names_patern(session, pattern)
|
||||||
if len(names) == 0:
|
if len(names) == 0:
|
||||||
click.echo("No benchmark satisfies the pattern.")
|
click.echo("No benchmark satisfies the pattern.")
|
||||||
else:
|
else:
|
||||||
for name in names:
|
for name in names:
|
||||||
result = benchmark.run(client, name)
|
result = benchmark.run(session, name)
|
||||||
click.echo(f"{name}: {result.value} {result.unit}")
|
click.echo(f"{name}: {result.value} {result.unit}")
|
||||||
|
@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import binance, tools
|
from .. import binance, tools
|
||||||
from . import with_client
|
from ..transport.session import Session
|
||||||
|
from . import with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .. import messages
|
from .. import messages
|
||||||
from ..client import TrezorClient
|
|
||||||
|
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0"
|
PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0"
|
||||||
@ -39,23 +39,23 @@ def cli() -> None:
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Get Binance address for specified path."""
|
"""Get Binance address for specified path."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return binance.get_address(client, address_n, show_display, chunkify)
|
return binance.get_address(session, address_n, show_display, chunkify)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
|
||||||
"""Get Binance public key."""
|
"""Get Binance public key."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return binance.get_public_key(client, address_n, show_display).hex()
|
return binance.get_public_key(session, address_n, show_display).hex()
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
|
session: "Session", address: str, file: TextIO, chunkify: bool
|
||||||
) -> "messages.BinanceSignedTx":
|
) -> "messages.BinanceSignedTx":
|
||||||
"""Sign Binance transaction.
|
"""Sign Binance transaction.
|
||||||
|
|
||||||
Transaction must be provided as a JSON file.
|
Transaction must be provided as a JSON file.
|
||||||
"""
|
"""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
|
return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify)
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
#
|
#
|
||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
@ -22,10 +23,10 @@ import click
|
|||||||
import construct as c
|
import construct as c
|
||||||
|
|
||||||
from .. import btc, messages, protobuf, tools
|
from .. import btc, messages, protobuf, tools
|
||||||
from . import ChoiceType, with_client
|
from . import ChoiceType, with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..client import TrezorClient
|
from ..transport.session import Session
|
||||||
|
|
||||||
PURPOSE_BIP44 = 44
|
PURPOSE_BIP44 = 44
|
||||||
PURPOSE_BIP48 = 48
|
PURPOSE_BIP48 = 48
|
||||||
@ -174,15 +175,15 @@ def cli() -> None:
|
|||||||
help="Sort pubkeys lexicographically using BIP-67",
|
help="Sort pubkeys lexicographically using BIP-67",
|
||||||
)
|
)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin: str,
|
coin: str,
|
||||||
address: str,
|
address: str,
|
||||||
script_type: Optional[messages.InputScriptType],
|
script_type: messages.InputScriptType | None,
|
||||||
show_display: bool,
|
show_display: bool,
|
||||||
multisig_xpub: List[str],
|
multisig_xpub: List[str],
|
||||||
multisig_threshold: Optional[int],
|
multisig_threshold: int | None,
|
||||||
multisig_suffix_length: int,
|
multisig_suffix_length: int,
|
||||||
multisig_sort_pubkeys: bool,
|
multisig_sort_pubkeys: bool,
|
||||||
chunkify: bool,
|
chunkify: bool,
|
||||||
@ -235,7 +236,7 @@ def get_address(
|
|||||||
multisig = None
|
multisig = None
|
||||||
|
|
||||||
return btc.get_address(
|
return btc.get_address(
|
||||||
client,
|
session,
|
||||||
coin,
|
coin,
|
||||||
address_n,
|
address_n,
|
||||||
show_display,
|
show_display,
|
||||||
@ -252,9 +253,9 @@ def get_address(
|
|||||||
@click.option("-e", "--curve")
|
@click.option("-e", "--curve")
|
||||||
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
|
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def get_public_node(
|
def get_public_node(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin: str,
|
coin: str,
|
||||||
address: str,
|
address: str,
|
||||||
curve: Optional[str],
|
curve: Optional[str],
|
||||||
@ -266,7 +267,7 @@ def get_public_node(
|
|||||||
if script_type is None:
|
if script_type is None:
|
||||||
script_type = guess_script_type_from_path(address_n)
|
script_type = guess_script_type_from_path(address_n)
|
||||||
result = btc.get_public_node(
|
result = btc.get_public_node(
|
||||||
client,
|
session,
|
||||||
address_n,
|
address_n,
|
||||||
ecdsa_curve_name=curve,
|
ecdsa_curve_name=curve,
|
||||||
show_display=show_display,
|
show_display=show_display,
|
||||||
@ -292,7 +293,7 @@ def _append_descriptor_checksum(desc: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _get_descriptor(
|
def _get_descriptor(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin: Optional[str],
|
coin: Optional[str],
|
||||||
account: int,
|
account: int,
|
||||||
purpose: Optional[int],
|
purpose: Optional[int],
|
||||||
@ -326,7 +327,7 @@ def _get_descriptor(
|
|||||||
|
|
||||||
n = tools.parse_path(path)
|
n = tools.parse_path(path)
|
||||||
pub = btc.get_public_node(
|
pub = btc.get_public_node(
|
||||||
client,
|
session,
|
||||||
n,
|
n,
|
||||||
show_display=show_display,
|
show_display=show_display,
|
||||||
coin_name=coin,
|
coin_name=coin,
|
||||||
@ -363,9 +364,9 @@ def _get_descriptor(
|
|||||||
@click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE))
|
@click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE))
|
||||||
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
|
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def get_descriptor(
|
def get_descriptor(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin: Optional[str],
|
coin: Optional[str],
|
||||||
account: int,
|
account: int,
|
||||||
account_type: Optional[int],
|
account_type: Optional[int],
|
||||||
@ -375,7 +376,7 @@ def get_descriptor(
|
|||||||
"""Get descriptor of given account."""
|
"""Get descriptor of given account."""
|
||||||
try:
|
try:
|
||||||
return _get_descriptor(
|
return _get_descriptor(
|
||||||
client, coin, account, account_type, script_type, show_display
|
session, coin, account, account_type, script_type, show_display
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise click.ClickException(str(e))
|
raise click.ClickException(str(e))
|
||||||
@ -390,8 +391,8 @@ def get_descriptor(
|
|||||||
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
|
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@click.argument("json_file", type=click.File())
|
@click.argument("json_file", type=click.File())
|
||||||
@with_client
|
@with_session
|
||||||
def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
|
def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None:
|
||||||
"""Sign transaction.
|
"""Sign transaction.
|
||||||
|
|
||||||
Transaction data must be provided in a JSON file. See `transaction-format.md` for
|
Transaction data must be provided in a JSON file. See `transaction-format.md` for
|
||||||
@ -416,7 +417,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, serialized_tx = btc.sign_tx(
|
_, serialized_tx = btc.sign_tx(
|
||||||
client,
|
session,
|
||||||
coin,
|
coin,
|
||||||
inputs,
|
inputs,
|
||||||
outputs,
|
outputs,
|
||||||
@ -447,9 +448,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
|
|||||||
)
|
)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@click.argument("message")
|
@click.argument("message")
|
||||||
@with_client
|
@with_session
|
||||||
def sign_message(
|
def sign_message(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin: str,
|
coin: str,
|
||||||
address: str,
|
address: str,
|
||||||
message: str,
|
message: str,
|
||||||
@ -462,7 +463,7 @@ def sign_message(
|
|||||||
if script_type is None:
|
if script_type is None:
|
||||||
script_type = guess_script_type_from_path(address_n)
|
script_type = guess_script_type_from_path(address_n)
|
||||||
res = btc.sign_message(
|
res = btc.sign_message(
|
||||||
client,
|
session,
|
||||||
coin,
|
coin,
|
||||||
address_n,
|
address_n,
|
||||||
message,
|
message,
|
||||||
@ -483,9 +484,9 @@ def sign_message(
|
|||||||
@click.argument("address")
|
@click.argument("address")
|
||||||
@click.argument("signature")
|
@click.argument("signature")
|
||||||
@click.argument("message")
|
@click.argument("message")
|
||||||
@with_client
|
@with_session
|
||||||
def verify_message(
|
def verify_message(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
coin: str,
|
coin: str,
|
||||||
address: str,
|
address: str,
|
||||||
signature: str,
|
signature: str,
|
||||||
@ -495,7 +496,7 @@ def verify_message(
|
|||||||
"""Verify message."""
|
"""Verify message."""
|
||||||
signature_bytes = base64.b64decode(signature)
|
signature_bytes = base64.b64decode(signature)
|
||||||
return btc.verify_message(
|
return btc.verify_message(
|
||||||
client, coin, address, signature_bytes, message, chunkify=chunkify
|
session, coin, address, signature_bytes, message, chunkify=chunkify
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import cardano, messages, tools
|
from .. import cardano, messages, tools
|
||||||
from . import ChoiceType, with_client
|
from . import ChoiceType, with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..client import TrezorClient
|
from ..transport.session import Session
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0"
|
PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0"
|
||||||
|
|
||||||
@ -62,9 +62,9 @@ def cli() -> None:
|
|||||||
@click.option("-i", "--include-network-id", is_flag=True)
|
@click.option("-i", "--include-network-id", is_flag=True)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@click.option("-T", "--tag-cbor-sets", is_flag=True)
|
@click.option("-T", "--tag-cbor-sets", is_flag=True)
|
||||||
@with_client
|
@with_session(derive_cardano=True)
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
file: TextIO,
|
file: TextIO,
|
||||||
signing_mode: messages.CardanoTxSigningMode,
|
signing_mode: messages.CardanoTxSigningMode,
|
||||||
protocol_magic: int,
|
protocol_magic: int,
|
||||||
@ -123,9 +123,8 @@ def sign_tx(
|
|||||||
for p in transaction["additional_witness_requests"]
|
for p in transaction["additional_witness_requests"]
|
||||||
]
|
]
|
||||||
|
|
||||||
client.init_device(derive_cardano=True)
|
|
||||||
sign_tx_response = cardano.sign_tx(
|
sign_tx_response = cardano.sign_tx(
|
||||||
client,
|
session,
|
||||||
signing_mode,
|
signing_mode,
|
||||||
inputs,
|
inputs,
|
||||||
outputs,
|
outputs,
|
||||||
@ -209,9 +208,9 @@ def sign_tx(
|
|||||||
default=messages.CardanoDerivationType.ICARUS,
|
default=messages.CardanoDerivationType.ICARUS,
|
||||||
)
|
)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@with_client
|
@with_session(derive_cardano=True)
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address: str,
|
address: str,
|
||||||
address_type: messages.CardanoAddressType,
|
address_type: messages.CardanoAddressType,
|
||||||
staking_address: str,
|
staking_address: str,
|
||||||
@ -262,9 +261,8 @@ def get_address(
|
|||||||
script_staking_hash_bytes,
|
script_staking_hash_bytes,
|
||||||
)
|
)
|
||||||
|
|
||||||
client.init_device(derive_cardano=True)
|
|
||||||
return cardano.get_address(
|
return cardano.get_address(
|
||||||
client,
|
session,
|
||||||
address_parameters,
|
address_parameters,
|
||||||
protocol_magic,
|
protocol_magic,
|
||||||
network_id,
|
network_id,
|
||||||
@ -283,18 +281,17 @@ def get_address(
|
|||||||
default=messages.CardanoDerivationType.ICARUS,
|
default=messages.CardanoDerivationType.ICARUS,
|
||||||
)
|
)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_session(derive_cardano=True)
|
||||||
def get_public_key(
|
def get_public_key(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address: str,
|
address: str,
|
||||||
derivation_type: messages.CardanoDerivationType,
|
derivation_type: messages.CardanoDerivationType,
|
||||||
show_display: bool,
|
show_display: bool,
|
||||||
) -> messages.CardanoPublicKey:
|
) -> messages.CardanoPublicKey:
|
||||||
"""Get Cardano public key."""
|
"""Get Cardano public key."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
client.init_device(derive_cardano=True)
|
|
||||||
return cardano.get_public_key(
|
return cardano.get_public_key(
|
||||||
client, address_n, derivation_type=derivation_type, show_display=show_display
|
session, address_n, derivation_type=derivation_type, show_display=show_display
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -312,9 +309,9 @@ def get_public_key(
|
|||||||
type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}),
|
type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}),
|
||||||
default=messages.CardanoDerivationType.ICARUS,
|
default=messages.CardanoDerivationType.ICARUS,
|
||||||
)
|
)
|
||||||
@with_client
|
@with_session(derive_cardano=True)
|
||||||
def get_native_script_hash(
|
def get_native_script_hash(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
file: TextIO,
|
file: TextIO,
|
||||||
display_format: messages.CardanoNativeScriptHashDisplayFormat,
|
display_format: messages.CardanoNativeScriptHashDisplayFormat,
|
||||||
derivation_type: messages.CardanoDerivationType,
|
derivation_type: messages.CardanoDerivationType,
|
||||||
@ -323,7 +320,6 @@ def get_native_script_hash(
|
|||||||
native_script_json = json.load(file)
|
native_script_json = json.load(file)
|
||||||
native_script = cardano.parse_native_script(native_script_json)
|
native_script = cardano.parse_native_script(native_script_json)
|
||||||
|
|
||||||
client.init_device(derive_cardano=True)
|
|
||||||
return cardano.get_native_script_hash(
|
return cardano.get_native_script_hash(
|
||||||
client, native_script, display_format, derivation_type=derivation_type
|
session, native_script, display_format, derivation_type=derivation_type
|
||||||
)
|
)
|
||||||
|
@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Tuple
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import misc, tools
|
from .. import misc, tools
|
||||||
from . import ChoiceType, with_client
|
from . import ChoiceType, with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..client import TrezorClient
|
from ..transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
PROMPT_TYPE = ChoiceType(
|
PROMPT_TYPE = ChoiceType(
|
||||||
@ -42,10 +42,10 @@ def cli() -> None:
|
|||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("size", type=int)
|
@click.argument("size", type=int)
|
||||||
@with_client
|
@with_session(empty_passphrase=True)
|
||||||
def get_entropy(client: "TrezorClient", size: int) -> str:
|
def get_entropy(session: "Session", size: int) -> str:
|
||||||
"""Get random bytes from device."""
|
"""Get random bytes from device."""
|
||||||
return misc.get_entropy(client, size).hex()
|
return misc.get_entropy(session, size).hex()
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str:
|
|||||||
)
|
)
|
||||||
@click.argument("key")
|
@click.argument("key")
|
||||||
@click.argument("value")
|
@click.argument("value")
|
||||||
@with_client
|
@with_session(empty_passphrase=True)
|
||||||
def encrypt_keyvalue(
|
def encrypt_keyvalue(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address: str,
|
address: str,
|
||||||
key: str,
|
key: str,
|
||||||
value: str,
|
value: str,
|
||||||
@ -75,7 +75,7 @@ def encrypt_keyvalue(
|
|||||||
ask_on_encrypt, ask_on_decrypt = prompt
|
ask_on_encrypt, ask_on_decrypt = prompt
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return misc.encrypt_keyvalue(
|
return misc.encrypt_keyvalue(
|
||||||
client,
|
session,
|
||||||
address_n,
|
address_n,
|
||||||
key,
|
key,
|
||||||
value.encode(),
|
value.encode(),
|
||||||
@ -91,9 +91,9 @@ def encrypt_keyvalue(
|
|||||||
)
|
)
|
||||||
@click.argument("key")
|
@click.argument("key")
|
||||||
@click.argument("value")
|
@click.argument("value")
|
||||||
@with_client
|
@with_session(empty_passphrase=True)
|
||||||
def decrypt_keyvalue(
|
def decrypt_keyvalue(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address: str,
|
address: str,
|
||||||
key: str,
|
key: str,
|
||||||
value: str,
|
value: str,
|
||||||
@ -112,7 +112,7 @@ def decrypt_keyvalue(
|
|||||||
ask_on_encrypt, ask_on_decrypt = prompt
|
ask_on_encrypt, ask_on_decrypt = prompt
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return misc.decrypt_keyvalue(
|
return misc.decrypt_keyvalue(
|
||||||
client,
|
session,
|
||||||
address_n,
|
address_n,
|
||||||
key,
|
key,
|
||||||
bytes.fromhex(value),
|
bytes.fromhex(value),
|
||||||
|
@ -18,13 +18,12 @@ from typing import TYPE_CHECKING, Union
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import mapping, messages, protobuf
|
|
||||||
from ..client import TrezorClient
|
|
||||||
from ..debuglink import TrezorClientDebugLink
|
from ..debuglink import TrezorClientDebugLink
|
||||||
from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max
|
from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max
|
||||||
from ..debuglink import prodtest_t1 as debuglink_prodtest_t1
|
from ..debuglink import prodtest_t1 as debuglink_prodtest_t1
|
||||||
from ..debuglink import record_screen
|
from ..debuglink import record_screen
|
||||||
from . import with_client
|
from ..transport.session import Session
|
||||||
|
from . import with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from . import TrezorConnection
|
from . import TrezorConnection
|
||||||
@ -35,51 +34,51 @@ def cli() -> None:
|
|||||||
"""Miscellaneous debug features."""
|
"""Miscellaneous debug features."""
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
# @cli.command()
|
||||||
@click.argument("message_name_or_type")
|
# @click.argument("message_name_or_type")
|
||||||
@click.argument("hex_data")
|
# @click.argument("hex_data")
|
||||||
@click.pass_obj
|
# @click.pass_obj
|
||||||
def send_bytes(
|
# def send_bytes(
|
||||||
obj: "TrezorConnection", message_name_or_type: str, hex_data: str
|
# obj: "NewTrezorConnection", message_name_or_type: str, hex_data: str
|
||||||
) -> None:
|
# ) -> None:
|
||||||
"""Send raw bytes to Trezor.
|
# """Send raw bytes to Trezor.
|
||||||
|
|
||||||
Message type and message data must be specified separately, due to how message
|
# Message type and message data must be specified separately, due to how message
|
||||||
chunking works on the transport level. Message length is calculated and sent
|
# chunking works on the transport level. Message length is calculated and sent
|
||||||
automatically, and it is currently impossible to explicitly specify invalid length.
|
# automatically, and it is currently impossible to explicitly specify invalid length.
|
||||||
|
|
||||||
MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum,
|
# MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum,
|
||||||
in which case the value of that enum is used.
|
# in which case the value of that enum is used.
|
||||||
"""
|
# """
|
||||||
if message_name_or_type.isdigit():
|
# if message_name_or_type.isdigit():
|
||||||
message_type = int(message_name_or_type)
|
# message_type = int(message_name_or_type)
|
||||||
else:
|
# else:
|
||||||
message_type = getattr(messages.MessageType, message_name_or_type)
|
# message_type = getattr(messages.MessageType, message_name_or_type)
|
||||||
|
|
||||||
if not isinstance(message_type, int):
|
# if not isinstance(message_type, int):
|
||||||
raise click.ClickException("Invalid message type.")
|
# raise click.ClickException("Invalid message type.")
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
message_data = bytes.fromhex(hex_data)
|
# message_data = bytes.fromhex(hex_data)
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
raise click.ClickException("Invalid hex data.") from e
|
# raise click.ClickException("Invalid hex data.") from e
|
||||||
|
|
||||||
transport = obj.get_transport()
|
# transport = obj.get_transport()
|
||||||
transport.begin_session()
|
# transport.deprecated_begin_session()
|
||||||
transport.write(message_type, message_data)
|
# transport.write(message_type, message_data)
|
||||||
|
|
||||||
response_type, response_data = transport.read()
|
# response_type, response_data = transport.read()
|
||||||
transport.end_session()
|
# transport.deprecated_end_session()
|
||||||
|
|
||||||
click.echo(f"Response type: {response_type}")
|
# click.echo(f"Response type: {response_type}")
|
||||||
click.echo(f"Response data: {response_data.hex()}")
|
# click.echo(f"Response data: {response_data.hex()}")
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
|
# msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
|
||||||
click.echo("Parsed message:")
|
# click.echo("Parsed message:")
|
||||||
click.echo(protobuf.format_message(msg))
|
# click.echo(protobuf.format_message(msg))
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
click.echo(f"Could not parse response: {e}")
|
# click.echo(f"Could not parse response: {e}")
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -106,17 +105,17 @@ def record_screen_from_connection(
|
|||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def prodtest_t1(client: "TrezorClient") -> str:
|
def prodtest_t1(session: "Session") -> str:
|
||||||
"""Perform a prodtest on Model One.
|
"""Perform a prodtest on Model One.
|
||||||
|
|
||||||
Only available on PRODTEST firmware and on T1B1. Formerly named self-test.
|
Only available on PRODTEST firmware and on T1B1. Formerly named self-test.
|
||||||
"""
|
"""
|
||||||
return debuglink_prodtest_t1(client)
|
return debuglink_prodtest_t1(session)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def optiga_set_sec_max(client: "TrezorClient") -> str:
|
def optiga_set_sec_max(session: "Session") -> str:
|
||||||
"""Set Optiga's security event counter to maximum."""
|
"""Set Optiga's security event counter to maximum."""
|
||||||
return debuglink_optiga_set_sec_max(client)
|
return debuglink_optiga_set_sec_max(session)
|
||||||
|
@ -25,11 +25,11 @@ import requests
|
|||||||
|
|
||||||
from .. import debuglink, device, exceptions, messages, ui
|
from .. import debuglink, device, exceptions, messages, ui
|
||||||
from ..tools import format_path
|
from ..tools import format_path
|
||||||
from . import ChoiceType, with_client
|
from . import ChoiceType, with_session
|
||||||
|
|
||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
from ..client import TrezorClient
|
|
||||||
from ..protobuf import MessageType
|
from ..protobuf import MessageType
|
||||||
|
from ..transport.session import Session
|
||||||
from . import TrezorConnection
|
from . import TrezorConnection
|
||||||
|
|
||||||
RECOVERY_DEVICE_INPUT_METHOD = {
|
RECOVERY_DEVICE_INPUT_METHOD = {
|
||||||
@ -65,17 +65,18 @@ def cli() -> None:
|
|||||||
help="Wipe device in bootloader mode. This also erases the firmware.",
|
help="Wipe device in bootloader mode. This also erases the firmware.",
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
)
|
)
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def wipe(client: "TrezorClient", bootloader: bool) -> str:
|
def wipe(session: "Session", bootloader: bool) -> str:
|
||||||
"""Reset device to factory defaults and remove all private data."""
|
"""Reset device to factory defaults and remove all private data."""
|
||||||
|
features = session.features
|
||||||
if bootloader:
|
if bootloader:
|
||||||
if not client.features.bootloader_mode:
|
if not features.bootloader_mode:
|
||||||
click.echo("Please switch your device to bootloader mode.")
|
click.echo("Please switch your device to bootloader mode.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
else:
|
else:
|
||||||
click.echo("Wiping user data and firmware!")
|
click.echo("Wiping user data and firmware!")
|
||||||
else:
|
else:
|
||||||
if client.features.bootloader_mode:
|
if features.bootloader_mode:
|
||||||
click.echo(
|
click.echo(
|
||||||
"Your device is in bootloader mode. This operation would also erase firmware."
|
"Your device is in bootloader mode. This operation would also erase firmware."
|
||||||
)
|
)
|
||||||
@ -88,7 +89,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str:
|
|||||||
click.echo("Wiping user data!")
|
click.echo("Wiping user data!")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return device.wipe(client)
|
return device.wipe(
|
||||||
|
session
|
||||||
|
) # TODO decide where the wipe should happen - management or regular session
|
||||||
except exceptions.TrezorFailure as e:
|
except exceptions.TrezorFailure as e:
|
||||||
click.echo("Action failed: {} {}".format(*e.args))
|
click.echo("Action failed: {} {}".format(*e.args))
|
||||||
sys.exit(3)
|
sys.exit(3)
|
||||||
@ -104,9 +107,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str:
|
|||||||
@click.option("-a", "--academic", is_flag=True)
|
@click.option("-a", "--academic", is_flag=True)
|
||||||
@click.option("-b", "--needs-backup", is_flag=True)
|
@click.option("-b", "--needs-backup", is_flag=True)
|
||||||
@click.option("-n", "--no-backup", is_flag=True)
|
@click.option("-n", "--no-backup", is_flag=True)
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def load(
|
def load(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
mnemonic: t.Sequence[str],
|
mnemonic: t.Sequence[str],
|
||||||
pin: str,
|
pin: str,
|
||||||
passphrase_protection: bool,
|
passphrase_protection: bool,
|
||||||
@ -137,7 +140,7 @@ def load(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
return debuglink.load_device(
|
return debuglink.load_device(
|
||||||
client,
|
session,
|
||||||
mnemonic=list(mnemonic),
|
mnemonic=list(mnemonic),
|
||||||
pin=pin,
|
pin=pin,
|
||||||
passphrase_protection=passphrase_protection,
|
passphrase_protection=passphrase_protection,
|
||||||
@ -172,9 +175,9 @@ def load(
|
|||||||
)
|
)
|
||||||
@click.option("-d", "--dry-run", is_flag=True)
|
@click.option("-d", "--dry-run", is_flag=True)
|
||||||
@click.option("-b", "--unlock-repeated-backup", is_flag=True)
|
@click.option("-b", "--unlock-repeated-backup", is_flag=True)
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def recover(
|
def recover(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
words: str,
|
words: str,
|
||||||
expand: bool,
|
expand: bool,
|
||||||
pin_protection: bool,
|
pin_protection: bool,
|
||||||
@ -202,7 +205,7 @@ def recover(
|
|||||||
type = messages.RecoveryType.UnlockRepeatedBackup
|
type = messages.RecoveryType.UnlockRepeatedBackup
|
||||||
|
|
||||||
return device.recover(
|
return device.recover(
|
||||||
client,
|
session,
|
||||||
word_count=int(words),
|
word_count=int(words),
|
||||||
passphrase_protection=passphrase_protection,
|
passphrase_protection=passphrase_protection,
|
||||||
pin_protection=pin_protection,
|
pin_protection=pin_protection,
|
||||||
@ -224,9 +227,9 @@ def recover(
|
|||||||
@click.option("-n", "--no-backup", is_flag=True)
|
@click.option("-n", "--no-backup", is_flag=True)
|
||||||
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE))
|
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE))
|
||||||
@click.option("-e", "--entropy-check-count", type=click.IntRange(0))
|
@click.option("-e", "--entropy-check-count", type=click.IntRange(0))
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def setup(
|
def setup(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
strength: int | None,
|
strength: int | None,
|
||||||
passphrase_protection: bool,
|
passphrase_protection: bool,
|
||||||
pin_protection: bool,
|
pin_protection: bool,
|
||||||
@ -244,7 +247,7 @@ def setup(
|
|||||||
BT = messages.BackupType
|
BT = messages.BackupType
|
||||||
|
|
||||||
if backup_type is None:
|
if backup_type is None:
|
||||||
if client.version >= (2, 7, 1):
|
if session.version >= (2, 7, 1):
|
||||||
# SLIP39 extendable was introduced in 2.7.1
|
# SLIP39 extendable was introduced in 2.7.1
|
||||||
backup_type = BT.Slip39_Single_Extendable
|
backup_type = BT.Slip39_Single_Extendable
|
||||||
else:
|
else:
|
||||||
@ -254,10 +257,10 @@ def setup(
|
|||||||
if (
|
if (
|
||||||
backup_type
|
backup_type
|
||||||
in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable)
|
in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable)
|
||||||
and messages.Capability.Shamir not in client.features.capabilities
|
and messages.Capability.Shamir not in session.features.capabilities
|
||||||
) or (
|
) or (
|
||||||
backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable)
|
backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable)
|
||||||
and messages.Capability.ShamirGroups not in client.features.capabilities
|
and messages.Capability.ShamirGroups not in session.features.capabilities
|
||||||
):
|
):
|
||||||
click.echo(
|
click.echo(
|
||||||
"WARNING: Your Trezor device does not indicate support for the requested\n"
|
"WARNING: Your Trezor device does not indicate support for the requested\n"
|
||||||
@ -265,7 +268,7 @@ def setup(
|
|||||||
)
|
)
|
||||||
|
|
||||||
resp, path_xpubs = device.reset_entropy_check(
|
resp, path_xpubs = device.reset_entropy_check(
|
||||||
client,
|
session,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
passphrase_protection=passphrase_protection,
|
passphrase_protection=passphrase_protection,
|
||||||
pin_protection=pin_protection,
|
pin_protection=pin_protection,
|
||||||
@ -289,23 +292,21 @@ def setup(
|
|||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-t", "--group-threshold", type=int)
|
@click.option("-t", "--group-threshold", type=int)
|
||||||
@click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N")
|
@click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N")
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def backup(
|
def backup(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
group_threshold: int | None = None,
|
group_threshold: int | None = None,
|
||||||
groups: t.Sequence[tuple[int, int]] = (),
|
groups: t.Sequence[tuple[int, int]] = (),
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Perform device seed backup."""
|
"""Perform device seed backup."""
|
||||||
|
|
||||||
return device.backup(client, group_threshold, groups)
|
return device.backup(session, group_threshold, groups)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
|
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def sd_protect(
|
def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> str:
|
||||||
client: "TrezorClient", operation: messages.SdProtectOperationType
|
|
||||||
) -> str:
|
|
||||||
"""Secure the device with SD card protection.
|
"""Secure the device with SD card protection.
|
||||||
|
|
||||||
When SD card protection is enabled, a randomly generated secret is stored
|
When SD card protection is enabled, a randomly generated secret is stored
|
||||||
@ -319,9 +320,9 @@ def sd_protect(
|
|||||||
off - Remove SD card secret protection.
|
off - Remove SD card secret protection.
|
||||||
refresh - Replace the current SD card secret with a new one.
|
refresh - Replace the current SD card secret with a new one.
|
||||||
"""
|
"""
|
||||||
if client.features.model == "1":
|
if session.features.model == "1":
|
||||||
raise click.ClickException("Trezor One does not support SD card protection.")
|
raise click.ClickException("Trezor One does not support SD card protection.")
|
||||||
return device.sd_protect(client, operation)
|
return device.sd_protect(session, operation)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -331,24 +332,24 @@ def reboot_to_bootloader(obj: "TrezorConnection") -> str:
|
|||||||
|
|
||||||
Currently only supported on Trezor Model One.
|
Currently only supported on Trezor Model One.
|
||||||
"""
|
"""
|
||||||
# avoid using @with_client because it closes the session afterwards,
|
# avoid using @with_management_session because it closes the session afterwards,
|
||||||
# which triggers double prompt on device
|
# which triggers double prompt on device
|
||||||
with obj.client_context() as client:
|
with obj.client_context() as client:
|
||||||
return device.reboot_to_bootloader(client)
|
return device.reboot_to_bootloader(client.get_management_session())
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def tutorial(client: "TrezorClient") -> str:
|
def tutorial(session: "Session") -> str:
|
||||||
"""Show on-device tutorial."""
|
"""Show on-device tutorial."""
|
||||||
return device.show_device_tutorial(client)
|
return device.show_device_tutorial(session)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def unlock_bootloader(client: "TrezorClient") -> str:
|
def unlock_bootloader(session: "Session") -> str:
|
||||||
"""Unlocks bootloader. Irreversible."""
|
"""Unlocks bootloader. Irreversible."""
|
||||||
return device.unlock_bootloader(client)
|
return device.unlock_bootloader(session)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -359,11 +360,11 @@ def unlock_bootloader(client: "TrezorClient") -> str:
|
|||||||
type=int,
|
type=int,
|
||||||
help="Dialog expiry in seconds.",
|
help="Dialog expiry in seconds.",
|
||||||
)
|
)
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> str:
|
def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> str:
|
||||||
"""Show a "Do not disconnect" dialog."""
|
"""Show a "Do not disconnect" dialog."""
|
||||||
if enable is False:
|
if enable is False:
|
||||||
return device.set_busy(client, None)
|
return device.set_busy(session, None)
|
||||||
|
|
||||||
if expiry is None:
|
if expiry is None:
|
||||||
raise click.ClickException("Missing option '-e' / '--expiry'.")
|
raise click.ClickException("Missing option '-e' / '--expiry'.")
|
||||||
@ -373,7 +374,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) ->
|
|||||||
f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer."
|
f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer."
|
||||||
)
|
)
|
||||||
|
|
||||||
return device.set_busy(client, expiry * 1000)
|
return device.set_busy(session, expiry * 1000)
|
||||||
|
|
||||||
|
|
||||||
PUBKEY_WHITELIST_URL_TEMPLATE = (
|
PUBKEY_WHITELIST_URL_TEMPLATE = (
|
||||||
@ -393,9 +394,9 @@ PUBKEY_WHITELIST_URL_TEMPLATE = (
|
|||||||
is_flag=True,
|
is_flag=True,
|
||||||
help="Do not check intermediate certificates against the whitelist.",
|
help="Do not check intermediate certificates against the whitelist.",
|
||||||
)
|
)
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def authenticate(
|
def authenticate(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
hex_challenge: str | None,
|
hex_challenge: str | None,
|
||||||
root: t.BinaryIO | None,
|
root: t.BinaryIO | None,
|
||||||
raw: bool | None,
|
raw: bool | None,
|
||||||
@ -420,7 +421,7 @@ def authenticate(
|
|||||||
challenge = bytes.fromhex(hex_challenge)
|
challenge = bytes.fromhex(hex_challenge)
|
||||||
|
|
||||||
if raw:
|
if raw:
|
||||||
msg = device.authenticate(client, challenge)
|
msg = device.authenticate(session, challenge)
|
||||||
|
|
||||||
click.echo(f"Challenge: {hex_challenge}")
|
click.echo(f"Challenge: {hex_challenge}")
|
||||||
click.echo(f"Signature of challenge: {msg.signature.hex()}")
|
click.echo(f"Signature of challenge: {msg.signature.hex()}")
|
||||||
@ -468,14 +469,14 @@ def authenticate(
|
|||||||
else:
|
else:
|
||||||
whitelist_json = requests.get(
|
whitelist_json = requests.get(
|
||||||
PUBKEY_WHITELIST_URL_TEMPLATE.format(
|
PUBKEY_WHITELIST_URL_TEMPLATE.format(
|
||||||
model=client.model.internal_name.lower()
|
model=session.model.internal_name.lower()
|
||||||
)
|
)
|
||||||
).json()
|
).json()
|
||||||
whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]]
|
whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
authentication.authenticate_device(
|
authentication.authenticate_device(
|
||||||
client, challenge, root_pubkey=root_bytes, whitelist=whitelist
|
session, challenge, root_pubkey=root_bytes, whitelist=whitelist
|
||||||
)
|
)
|
||||||
except authentication.DeviceNotAuthentic:
|
except authentication.DeviceNotAuthentic:
|
||||||
click.echo("Device is not authentic.")
|
click.echo("Device is not authentic.")
|
||||||
|
@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import eos, tools
|
from .. import eos, tools
|
||||||
from . import with_client
|
from . import with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .. import messages
|
from .. import messages
|
||||||
from ..client import TrezorClient
|
from ..transport.session import Session
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0"
|
PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0"
|
||||||
|
|
||||||
@ -37,11 +37,11 @@ def cli() -> None:
|
|||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
|
||||||
"""Get Eos public key in base58 encoding."""
|
"""Get Eos public key in base58 encoding."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
res = eos.get_public_key(client, address_n, show_display)
|
res = eos.get_public_key(session, address_n, show_display)
|
||||||
return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}"
|
return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}"
|
||||||
|
|
||||||
|
|
||||||
@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def sign_transaction(
|
def sign_transaction(
|
||||||
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
|
session: "Session", address: str, file: TextIO, chunkify: bool
|
||||||
) -> "messages.EosSignedTx":
|
) -> "messages.EosSignedTx":
|
||||||
"""Sign EOS transaction."""
|
"""Sign EOS transaction."""
|
||||||
tx_json = json.load(file)
|
tx_json = json.load(file)
|
||||||
|
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return eos.sign_tx(
|
return eos.sign_tx(
|
||||||
client,
|
session,
|
||||||
address_n,
|
address_n,
|
||||||
tx_json["transaction"],
|
tx_json["transaction"],
|
||||||
tx_json["chain_id"],
|
tx_json["chain_id"],
|
||||||
|
@ -26,14 +26,14 @@ import click
|
|||||||
|
|
||||||
from .. import _rlp, definitions, ethereum, tools
|
from .. import _rlp, definitions, ethereum, tools
|
||||||
from ..messages import EthereumDefinitions
|
from ..messages import EthereumDefinitions
|
||||||
from . import with_client
|
from . import with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import web3
|
import web3
|
||||||
from eth_typing import ChecksumAddress # noqa: I900
|
from eth_typing import ChecksumAddress # noqa: I900
|
||||||
from web3.types import Wei
|
from web3.types import Wei
|
||||||
|
|
||||||
from ..client import TrezorClient
|
from ..transport.session import Session
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0"
|
PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0"
|
||||||
|
|
||||||
@ -268,24 +268,24 @@ def cli(
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Get Ethereum address in hex encoding."""
|
"""Get Ethereum address in hex encoding."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
|
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
|
||||||
return ethereum.get_address(client, address_n, show_display, network, chunkify)
|
return ethereum.get_address(session, address_n, show_display, network, chunkify)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict:
|
def get_public_node(session: "Session", address: str, show_display: bool) -> dict:
|
||||||
"""Get Ethereum public node of given path."""
|
"""Get Ethereum public node of given path."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
result = ethereum.get_public_node(client, address_n, show_display=show_display)
|
result = ethereum.get_public_node(session, address_n, show_display=show_display)
|
||||||
return {
|
return {
|
||||||
"node": {
|
"node": {
|
||||||
"depth": result.node.depth,
|
"depth": result.node.depth,
|
||||||
@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) ->
|
|||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@click.argument("to_address")
|
@click.argument("to_address")
|
||||||
@click.argument("amount", callback=_amount_to_int)
|
@click.argument("amount", callback=_amount_to_int)
|
||||||
@with_client
|
@with_session
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
chain_id: int,
|
chain_id: int,
|
||||||
address: str,
|
address: str,
|
||||||
amount: int,
|
amount: int,
|
||||||
@ -400,7 +400,7 @@ def sign_tx(
|
|||||||
encoded_network = DEFINITIONS_SOURCE.get_network(chain_id)
|
encoded_network = DEFINITIONS_SOURCE.get_network(chain_id)
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
from_address = ethereum.get_address(
|
from_address = ethereum.get_address(
|
||||||
client, address_n, encoded_network=encoded_network
|
session, address_n, encoded_network=encoded_network
|
||||||
)
|
)
|
||||||
|
|
||||||
if token:
|
if token:
|
||||||
@ -446,7 +446,7 @@ def sign_tx(
|
|||||||
assert max_gas_fee is not None
|
assert max_gas_fee is not None
|
||||||
assert max_priority_fee is not None
|
assert max_priority_fee is not None
|
||||||
sig = ethereum.sign_tx_eip1559(
|
sig = ethereum.sign_tx_eip1559(
|
||||||
client,
|
session,
|
||||||
n=address_n,
|
n=address_n,
|
||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
gas_limit=gas_limit,
|
gas_limit=gas_limit,
|
||||||
@ -465,7 +465,7 @@ def sign_tx(
|
|||||||
gas_price = _get_web3().eth.gas_price
|
gas_price = _get_web3().eth.gas_price
|
||||||
assert gas_price is not None
|
assert gas_price is not None
|
||||||
sig = ethereum.sign_tx(
|
sig = ethereum.sign_tx(
|
||||||
client,
|
session,
|
||||||
n=address_n,
|
n=address_n,
|
||||||
tx_type=tx_type,
|
tx_type=tx_type,
|
||||||
nonce=nonce,
|
nonce=nonce,
|
||||||
@ -526,14 +526,14 @@ def sign_tx(
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@click.argument("message")
|
@click.argument("message")
|
||||||
@with_client
|
@with_session
|
||||||
def sign_message(
|
def sign_message(
|
||||||
client: "TrezorClient", address: str, message: str, chunkify: bool
|
session: "Session", address: str, message: str, chunkify: bool
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, str]:
|
||||||
"""Sign message with Ethereum address."""
|
"""Sign message with Ethereum address."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
|
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
|
||||||
ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify)
|
ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify)
|
||||||
output = {
|
output = {
|
||||||
"message": message,
|
"message": message,
|
||||||
"address": ret.address,
|
"address": ret.address,
|
||||||
@ -550,9 +550,9 @@ def sign_message(
|
|||||||
help="Be compatible with Metamask's signTypedData_v4 implementation",
|
help="Be compatible with Metamask's signTypedData_v4 implementation",
|
||||||
)
|
)
|
||||||
@click.argument("file", type=click.File("r"))
|
@click.argument("file", type=click.File("r"))
|
||||||
@with_client
|
@with_session
|
||||||
def sign_typed_data(
|
def sign_typed_data(
|
||||||
client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO
|
session: "Session", address: str, metamask_v4_compat: bool, file: TextIO
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, str]:
|
||||||
"""Sign typed data (EIP-712) with Ethereum address.
|
"""Sign typed data (EIP-712) with Ethereum address.
|
||||||
|
|
||||||
@ -565,7 +565,7 @@ def sign_typed_data(
|
|||||||
defs = EthereumDefinitions(encoded_network=network)
|
defs = EthereumDefinitions(encoded_network=network)
|
||||||
data = json.loads(file.read())
|
data = json.loads(file.read())
|
||||||
ret = ethereum.sign_typed_data(
|
ret = ethereum.sign_typed_data(
|
||||||
client,
|
session,
|
||||||
address_n,
|
address_n,
|
||||||
data,
|
data,
|
||||||
metamask_v4_compat=metamask_v4_compat,
|
metamask_v4_compat=metamask_v4_compat,
|
||||||
@ -583,9 +583,9 @@ def sign_typed_data(
|
|||||||
@click.argument("address")
|
@click.argument("address")
|
||||||
@click.argument("signature")
|
@click.argument("signature")
|
||||||
@click.argument("message")
|
@click.argument("message")
|
||||||
@with_client
|
@with_session
|
||||||
def verify_message(
|
def verify_message(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address: str,
|
address: str,
|
||||||
signature: str,
|
signature: str,
|
||||||
message: str,
|
message: str,
|
||||||
@ -594,7 +594,7 @@ def verify_message(
|
|||||||
"""Verify message signed with Ethereum address."""
|
"""Verify message signed with Ethereum address."""
|
||||||
signature_bytes = ethereum.decode_hex(signature)
|
signature_bytes = ethereum.decode_hex(signature)
|
||||||
return ethereum.verify_message(
|
return ethereum.verify_message(
|
||||||
client, address, signature_bytes, message, chunkify=chunkify
|
session, address, signature_bytes, message, chunkify=chunkify
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -602,9 +602,9 @@ def verify_message(
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.argument("domain_hash_hex")
|
@click.argument("domain_hash_hex")
|
||||||
@click.argument("message_hash_hex")
|
@click.argument("message_hash_hex")
|
||||||
@with_client
|
@with_session
|
||||||
def sign_typed_data_hash(
|
def sign_typed_data_hash(
|
||||||
client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str
|
session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Sign hash of typed data (EIP-712) with Ethereum address.
|
Sign hash of typed data (EIP-712) with Ethereum address.
|
||||||
@ -618,7 +618,7 @@ def sign_typed_data_hash(
|
|||||||
message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None
|
message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None
|
||||||
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
|
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
|
||||||
ret = ethereum.sign_typed_data_hash(
|
ret = ethereum.sign_typed_data_hash(
|
||||||
client, address_n, domain_hash, message_hash, network
|
session, address_n, domain_hash, message_hash, network
|
||||||
)
|
)
|
||||||
output = {
|
output = {
|
||||||
"domain_hash": domain_hash_hex,
|
"domain_hash": domain_hash_hex,
|
||||||
|
@ -19,10 +19,10 @@ from typing import TYPE_CHECKING
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import fido
|
from .. import fido
|
||||||
from . import with_client
|
from . import with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..client import TrezorClient
|
from ..transport.session import Session
|
||||||
|
|
||||||
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
|
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
|
||||||
|
|
||||||
@ -40,10 +40,10 @@ def credentials() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@credentials.command(name="list")
|
@credentials.command(name="list")
|
||||||
@with_client
|
@with_session(empty_passphrase=True)
|
||||||
def credentials_list(client: "TrezorClient") -> None:
|
def credentials_list(session: "Session") -> None:
|
||||||
"""List all resident credentials on the device."""
|
"""List all resident credentials on the device."""
|
||||||
creds = fido.list_credentials(client)
|
creds = fido.list_credentials(session)
|
||||||
for cred in creds:
|
for cred in creds:
|
||||||
click.echo("")
|
click.echo("")
|
||||||
click.echo(f"WebAuthn credential at index {cred.index}:")
|
click.echo(f"WebAuthn credential at index {cred.index}:")
|
||||||
@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None:
|
|||||||
|
|
||||||
@credentials.command(name="add")
|
@credentials.command(name="add")
|
||||||
@click.argument("hex_credential_id")
|
@click.argument("hex_credential_id")
|
||||||
@with_client
|
@with_session(empty_passphrase=True)
|
||||||
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str:
|
def credentials_add(session: "Session", hex_credential_id: str) -> str:
|
||||||
"""Add the credential with the given ID as a resident credential.
|
"""Add the credential with the given ID as a resident credential.
|
||||||
|
|
||||||
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
|
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
|
||||||
"""
|
"""
|
||||||
return fido.add_credential(client, bytes.fromhex(hex_credential_id))
|
return fido.add_credential(session, bytes.fromhex(hex_credential_id))
|
||||||
|
|
||||||
|
|
||||||
@credentials.command(name="remove")
|
@credentials.command(name="remove")
|
||||||
@click.option(
|
@click.option(
|
||||||
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
|
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
|
||||||
)
|
)
|
||||||
@with_client
|
@with_session(empty_passphrase=True)
|
||||||
def credentials_remove(client: "TrezorClient", index: int) -> str:
|
def credentials_remove(session: "Session", index: int) -> str:
|
||||||
"""Remove the resident credential at the given index."""
|
"""Remove the resident credential at the given index."""
|
||||||
return fido.remove_credential(client, index)
|
return fido.remove_credential(session, index)
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
@ -110,19 +110,19 @@ def counter() -> None:
|
|||||||
|
|
||||||
@counter.command(name="set")
|
@counter.command(name="set")
|
||||||
@click.argument("counter", type=int)
|
@click.argument("counter", type=int)
|
||||||
@with_client
|
@with_session(empty_passphrase=True)
|
||||||
def counter_set(client: "TrezorClient", counter: int) -> str:
|
def counter_set(session: "Session", counter: int) -> str:
|
||||||
"""Set FIDO/U2F counter value."""
|
"""Set FIDO/U2F counter value."""
|
||||||
return fido.set_counter(client, counter)
|
return fido.set_counter(session, counter)
|
||||||
|
|
||||||
|
|
||||||
@counter.command(name="get-next")
|
@counter.command(name="get-next")
|
||||||
@with_client
|
@with_session(empty_passphrase=True)
|
||||||
def counter_get_next(client: "TrezorClient") -> int:
|
def counter_get_next(session: "Session") -> int:
|
||||||
"""Get-and-increase value of FIDO/U2F counter.
|
"""Get-and-increase value of FIDO/U2F counter.
|
||||||
|
|
||||||
FIDO counter value cannot be read directly. On each U2F exchange, the counter value
|
FIDO counter value cannot be read directly. On each U2F exchange, the counter value
|
||||||
is returned and atomically increased. This command performs the same operation
|
is returned and atomically increased. This command performs the same operation
|
||||||
and returns the counter value.
|
and returns the counter value.
|
||||||
"""
|
"""
|
||||||
return fido.get_next_counter(client)
|
return fido.get_next_counter(session)
|
||||||
|
@ -37,10 +37,11 @@ import requests
|
|||||||
from .. import device, exceptions, firmware, messages, models
|
from .. import device, exceptions, firmware, messages, models
|
||||||
from ..firmware import models as fw_models
|
from ..firmware import models as fw_models
|
||||||
from ..models import TrezorModel
|
from ..models import TrezorModel
|
||||||
from . import ChoiceType, with_client
|
from . import ChoiceType, with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..client import TrezorClient
|
from ..client import TrezorClient
|
||||||
|
from ..transport.session import Session
|
||||||
from . import TrezorConnection
|
from . import TrezorConnection
|
||||||
|
|
||||||
MODEL_CHOICE = ChoiceType(
|
MODEL_CHOICE = ChoiceType(
|
||||||
@ -74,9 +75,9 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool:
|
|||||||
This is the case from bootloader version 1.8.0, and also holds for firmware version
|
This is the case from bootloader version 1.8.0, and also holds for firmware version
|
||||||
1.8.0 because that installs the appropriate bootloader.
|
1.8.0 because that installs the appropriate bootloader.
|
||||||
"""
|
"""
|
||||||
f = client.features
|
features = client.features
|
||||||
version = (f.major_version, f.minor_version, f.patch_version)
|
version = client.version
|
||||||
bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0)
|
bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0)
|
||||||
return bootloader_onev2
|
return bootloader_onev2
|
||||||
|
|
||||||
|
|
||||||
@ -306,25 +307,26 @@ def find_best_firmware_version(
|
|||||||
If the specified version is not found, prints the closest available version
|
If the specified version is not found, prints the closest available version
|
||||||
(higher than the specified one, if existing).
|
(higher than the specified one, if existing).
|
||||||
"""
|
"""
|
||||||
|
features = client.features
|
||||||
|
model = client.model
|
||||||
|
|
||||||
if bitcoin_only is None:
|
if bitcoin_only is None:
|
||||||
bitcoin_only = _should_use_bitcoin_only(client.features)
|
bitcoin_only = _should_use_bitcoin_only(features)
|
||||||
|
|
||||||
def version_str(version: Iterable[int]) -> str:
|
def version_str(version: Iterable[int]) -> str:
|
||||||
return ".".join(map(str, version))
|
return ".".join(map(str, version))
|
||||||
|
|
||||||
f = client.features
|
releases = get_all_firmware_releases(model, bitcoin_only, beta)
|
||||||
|
|
||||||
releases = get_all_firmware_releases(client.model, bitcoin_only, beta)
|
|
||||||
highest_version = releases[0]["version"]
|
highest_version = releases[0]["version"]
|
||||||
|
|
||||||
if version:
|
if version:
|
||||||
want_version = [int(x) for x in version.split(".")]
|
want_version = [int(x) for x in version.split(".")]
|
||||||
if len(want_version) != 3:
|
if len(want_version) != 3:
|
||||||
click.echo("Please use the 'X.Y.Z' version format.")
|
click.echo("Please use the 'X.Y.Z' version format.")
|
||||||
if want_version[0] != f.major_version:
|
if want_version[0] != features.major_version:
|
||||||
click.echo(
|
click.echo(
|
||||||
f"Warning: Trezor {client.model.name} firmware version should be "
|
f"Warning: Trezor {model.name} firmware version should be "
|
||||||
f"{f.major_version}.X.Y (requested: {version})"
|
f"{features.major_version}.X.Y (requested: {version})"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
want_version = highest_version
|
want_version = highest_version
|
||||||
@ -359,8 +361,8 @@ def find_best_firmware_version(
|
|||||||
# to the newer one, in that case update to the minimal
|
# to the newer one, in that case update to the minimal
|
||||||
# compatible version first
|
# compatible version first
|
||||||
# Choosing the version key to compare based on (not) being in BL mode
|
# Choosing the version key to compare based on (not) being in BL mode
|
||||||
client_version = [f.major_version, f.minor_version, f.patch_version]
|
client_version = client.version
|
||||||
if f.bootloader_mode:
|
if features.bootloader_mode:
|
||||||
key_to_compare = "min_bootloader_version"
|
key_to_compare = "min_bootloader_version"
|
||||||
else:
|
else:
|
||||||
key_to_compare = "min_firmware_version"
|
key_to_compare = "min_firmware_version"
|
||||||
@ -447,11 +449,11 @@ def extract_embedded_fw(
|
|||||||
|
|
||||||
|
|
||||||
def upload_firmware_into_device(
|
def upload_firmware_into_device(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
firmware_data: bytes,
|
firmware_data: bytes,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Perform the final act of loading the firmware into Trezor."""
|
"""Perform the final act of loading the firmware into Trezor."""
|
||||||
f = client.features
|
f = session.features
|
||||||
try:
|
try:
|
||||||
if f.major_version == 1 and f.firmware_present is not False:
|
if f.major_version == 1 and f.firmware_present is not False:
|
||||||
# Trezor One does not send ButtonRequest
|
# Trezor One does not send ButtonRequest
|
||||||
@ -461,7 +463,7 @@ def upload_firmware_into_device(
|
|||||||
with click.progressbar(
|
with click.progressbar(
|
||||||
label="Uploading", length=len(firmware_data), show_eta=False
|
label="Uploading", length=len(firmware_data), show_eta=False
|
||||||
) as bar:
|
) as bar:
|
||||||
firmware.update(client, firmware_data, bar.update)
|
firmware.update(session, firmware_data, bar.update)
|
||||||
except exceptions.Cancelled:
|
except exceptions.Cancelled:
|
||||||
click.echo("Update aborted on device.")
|
click.echo("Update aborted on device.")
|
||||||
except exceptions.TrezorException as e:
|
except exceptions.TrezorException as e:
|
||||||
@ -654,6 +656,7 @@ def update(
|
|||||||
against data.trezor.io information, if available.
|
against data.trezor.io information, if available.
|
||||||
"""
|
"""
|
||||||
with obj.client_context() as client:
|
with obj.client_context() as client:
|
||||||
|
management_session = client.get_management_session()
|
||||||
if sum(bool(x) for x in (filename, url, version)) > 1:
|
if sum(bool(x) for x in (filename, url, version)) > 1:
|
||||||
click.echo("You can use only one of: filename, url, version.")
|
click.echo("You can use only one of: filename, url, version.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
@ -709,7 +712,7 @@ def update(
|
|||||||
if _is_strict_update(client, firmware_data):
|
if _is_strict_update(client, firmware_data):
|
||||||
header_size = _get_firmware_header_size(firmware_data)
|
header_size = _get_firmware_header_size(firmware_data)
|
||||||
device.reboot_to_bootloader(
|
device.reboot_to_bootloader(
|
||||||
client,
|
management_session,
|
||||||
boot_command=messages.BootCommand.INSTALL_UPGRADE,
|
boot_command=messages.BootCommand.INSTALL_UPGRADE,
|
||||||
firmware_header=firmware_data[:header_size],
|
firmware_header=firmware_data[:header_size],
|
||||||
language_data=language_data,
|
language_data=language_data,
|
||||||
@ -719,7 +722,7 @@ def update(
|
|||||||
click.echo(
|
click.echo(
|
||||||
"WARNING: Seamless installation not possible, language data will not be uploaded."
|
"WARNING: Seamless installation not possible, language data will not be uploaded."
|
||||||
)
|
)
|
||||||
device.reboot_to_bootloader(client)
|
device.reboot_to_bootloader(management_session)
|
||||||
|
|
||||||
click.echo("Waiting for bootloader...")
|
click.echo("Waiting for bootloader...")
|
||||||
while True:
|
while True:
|
||||||
@ -735,13 +738,15 @@ def update(
|
|||||||
click.echo("Please switch your device to bootloader mode.")
|
click.echo("Please switch your device to bootloader mode.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
upload_firmware_into_device(client=client, firmware_data=firmware_data)
|
upload_firmware_into_device(
|
||||||
|
session=client.get_management_session(), firmware_data=firmware_data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("hex_challenge", required=False)
|
@click.argument("hex_challenge", required=False)
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def get_hash(client: "TrezorClient", hex_challenge: Optional[str]) -> str:
|
def get_hash(session: "Session", hex_challenge: Optional[str]) -> str:
|
||||||
"""Get a hash of the installed firmware combined with the optional challenge."""
|
"""Get a hash of the installed firmware combined with the optional challenge."""
|
||||||
challenge = bytes.fromhex(hex_challenge) if hex_challenge else None
|
challenge = bytes.fromhex(hex_challenge) if hex_challenge else None
|
||||||
return firmware.get_hash(client, challenge).hex()
|
return firmware.get_hash(session, challenge).hex()
|
||||||
|
@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Dict
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import messages, monero, tools
|
from .. import messages, monero, tools
|
||||||
from . import ChoiceType, with_client
|
from . import ChoiceType, with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..client import TrezorClient
|
from ..transport.session import Session
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h"
|
PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h"
|
||||||
|
|
||||||
@ -42,9 +42,9 @@ def cli() -> None:
|
|||||||
default=messages.MoneroNetworkType.MAINNET,
|
default=messages.MoneroNetworkType.MAINNET,
|
||||||
)
|
)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address: str,
|
address: str,
|
||||||
show_display: bool,
|
show_display: bool,
|
||||||
network_type: messages.MoneroNetworkType,
|
network_type: messages.MoneroNetworkType,
|
||||||
@ -52,7 +52,7 @@ def get_address(
|
|||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Get Monero address for specified path."""
|
"""Get Monero address for specified path."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return monero.get_address(client, address_n, show_display, network_type, chunkify)
|
return monero.get_address(session, address_n, show_display, network_type, chunkify)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -63,13 +63,13 @@ def get_address(
|
|||||||
type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}),
|
type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}),
|
||||||
default=messages.MoneroNetworkType.MAINNET,
|
default=messages.MoneroNetworkType.MAINNET,
|
||||||
)
|
)
|
||||||
@with_client
|
@with_session
|
||||||
def get_watch_key(
|
def get_watch_key(
|
||||||
client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType
|
session: "Session", address: str, network_type: messages.MoneroNetworkType
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, str]:
|
||||||
"""Get Monero watch key for specified path."""
|
"""Get Monero watch key for specified path."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
res = monero.get_watch_key(client, address_n, network_type)
|
res = monero.get_watch_key(session, address_n, network_type)
|
||||||
# TODO: could be made required in MoneroWatchKey
|
# TODO: could be made required in MoneroWatchKey
|
||||||
assert res.address is not None
|
assert res.address is not None
|
||||||
assert res.watch_key is not None
|
assert res.watch_key is not None
|
||||||
|
@ -21,10 +21,10 @@ import click
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from .. import nem, tools
|
from .. import nem, tools
|
||||||
from . import with_client
|
from . import with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..client import TrezorClient
|
from ..transport.session import Session
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h"
|
PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h"
|
||||||
|
|
||||||
@ -39,9 +39,9 @@ def cli() -> None:
|
|||||||
@click.option("-N", "--network", type=int, default=0x68)
|
@click.option("-N", "--network", type=int, default=0x68)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address: str,
|
address: str,
|
||||||
network: int,
|
network: int,
|
||||||
show_display: bool,
|
show_display: bool,
|
||||||
@ -49,7 +49,7 @@ def get_address(
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""Get NEM address for specified path."""
|
"""Get NEM address for specified path."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return nem.get_address(client, address_n, network, show_display, chunkify)
|
return nem.get_address(session, address_n, network, show_display, chunkify)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -58,9 +58,9 @@ def get_address(
|
|||||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||||
@click.option("-b", "--broadcast", help="NIS to announce transaction to")
|
@click.option("-b", "--broadcast", help="NIS to announce transaction to")
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address: str,
|
address: str,
|
||||||
file: TextIO,
|
file: TextIO,
|
||||||
broadcast: Optional[str],
|
broadcast: Optional[str],
|
||||||
@ -71,7 +71,7 @@ def sign_tx(
|
|||||||
Transaction file is expected in the NIS (RequestPrepareAnnounce) format.
|
Transaction file is expected in the NIS (RequestPrepareAnnounce) format.
|
||||||
"""
|
"""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
|
transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify)
|
||||||
|
|
||||||
payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()}
|
payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()}
|
||||||
|
|
||||||
|
@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import ripple, tools
|
from .. import ripple, tools
|
||||||
from . import with_client
|
from . import with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..client import TrezorClient
|
from ..transport.session import Session
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0"
|
PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0"
|
||||||
|
|
||||||
@ -37,13 +37,13 @@ def cli() -> None:
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Get Ripple address"""
|
"""Get Ripple address"""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return ripple.get_address(client, address_n, show_display, chunkify)
|
return ripple.get_address(session, address_n, show_display, chunkify)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -51,13 +51,13 @@ def get_address(
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None:
|
def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None:
|
||||||
"""Sign Ripple transaction"""
|
"""Sign Ripple transaction"""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
msg = ripple.create_sign_tx_msg(json.load(file))
|
msg = ripple.create_sign_tx_msg(json.load(file))
|
||||||
|
|
||||||
result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify)
|
result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify)
|
||||||
click.echo("Signature:")
|
click.echo("Signature:")
|
||||||
click.echo(result.signature.hex())
|
click.echo(result.signature.hex())
|
||||||
click.echo()
|
click.echo()
|
||||||
|
@ -24,10 +24,11 @@ import click
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from .. import device, messages, toif
|
from .. import device, messages, toif
|
||||||
from . import AliasedGroup, ChoiceType, with_client
|
from ..transport.session import Session
|
||||||
|
from . import AliasedGroup, ChoiceType, with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..client import TrezorClient
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -180,18 +181,18 @@ def cli() -> None:
|
|||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-r", "--remove", is_flag=True, hidden=True)
|
@click.option("-r", "--remove", is_flag=True, hidden=True)
|
||||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
|
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str:
|
def pin(session: "Session", enable: Optional[bool], remove: bool) -> str:
|
||||||
"""Set, change or remove PIN."""
|
"""Set, change or remove PIN."""
|
||||||
# Remove argument is there for backwards compatibility
|
# Remove argument is there for backwards compatibility
|
||||||
return device.change_pin(client, remove=_should_remove(enable, remove))
|
return device.change_pin(session, remove=_should_remove(enable, remove))
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-r", "--remove", is_flag=True, hidden=True)
|
@click.option("-r", "--remove", is_flag=True, hidden=True)
|
||||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
|
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str:
|
def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> str:
|
||||||
"""Set or remove the wipe code.
|
"""Set or remove the wipe code.
|
||||||
|
|
||||||
The wipe code functions as a "self-destruct PIN". If the wipe code is ever
|
The wipe code functions as a "self-destruct PIN". If the wipe code is ever
|
||||||
@ -199,32 +200,32 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> s
|
|||||||
removed and the device will be reset to factory defaults.
|
removed and the device will be reset to factory defaults.
|
||||||
"""
|
"""
|
||||||
# Remove argument is there for backwards compatibility
|
# Remove argument is there for backwards compatibility
|
||||||
return device.change_wipe_code(client, remove=_should_remove(enable, remove))
|
return device.change_wipe_code(session, remove=_should_remove(enable, remove))
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
# keep the deprecated -l/--label option, make it do nothing
|
# keep the deprecated -l/--label option, make it do nothing
|
||||||
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||||
@click.argument("label")
|
@click.argument("label")
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def label(client: "TrezorClient", label: str) -> str:
|
def label(session: "Session", label: str) -> str:
|
||||||
"""Set new device label."""
|
"""Set new device label."""
|
||||||
return device.apply_settings(client, label=label)
|
return device.apply_settings(session, label=label)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def brightness(client: "TrezorClient") -> str:
|
def brightness(session: "Session") -> str:
|
||||||
"""Set display brightness."""
|
"""Set display brightness."""
|
||||||
return device.set_brightness(client)
|
return device.set_brightness(session)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
|
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def haptic_feedback(client: "TrezorClient", enable: bool) -> str:
|
def haptic_feedback(session: "Session", enable: bool) -> str:
|
||||||
"""Enable or disable haptic feedback."""
|
"""Enable or disable haptic feedback."""
|
||||||
return device.apply_settings(client, haptic_feedback=enable)
|
return device.apply_settings(session, haptic_feedback=enable)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -233,9 +234,9 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> str:
|
|||||||
"-r", "--remove", is_flag=True, default=False, help="Switch back to english."
|
"-r", "--remove", is_flag=True, default=False, help="Switch back to english."
|
||||||
)
|
)
|
||||||
@click.option("-d/-D", "--display/--no-display", default=None)
|
@click.option("-d/-D", "--display/--no-display", default=None)
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def language(
|
def language(
|
||||||
client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None
|
session: "Session", path_or_url: str | None, remove: bool, display: bool | None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Set new language with translations."""
|
"""Set new language with translations."""
|
||||||
if remove != (path_or_url is None):
|
if remove != (path_or_url is None):
|
||||||
@ -260,29 +261,29 @@ def language(
|
|||||||
f"Failed to load translations from {path_or_url}"
|
f"Failed to load translations from {path_or_url}"
|
||||||
) from None
|
) from None
|
||||||
return device.change_language(
|
return device.change_language(
|
||||||
client, language_data=language_data, show_display=display
|
session, language_data=language_data, show_display=display
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("rotation", type=ChoiceType(ROTATION))
|
@click.argument("rotation", type=ChoiceType(ROTATION))
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def display_rotation(client: "TrezorClient", rotation: messages.DisplayRotation) -> str:
|
def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> str:
|
||||||
"""Set display rotation.
|
"""Set display rotation.
|
||||||
|
|
||||||
Configure display rotation for Trezor Model T. The options are
|
Configure display rotation for Trezor Model T. The options are
|
||||||
north, east, south or west.
|
north, east, south or west.
|
||||||
"""
|
"""
|
||||||
return device.apply_settings(client, display_rotation=rotation)
|
return device.apply_settings(session, display_rotation=rotation)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("delay", type=str)
|
@click.argument("delay", type=str)
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
|
def auto_lock_delay(session: "Session", delay: str) -> str:
|
||||||
"""Set auto-lock delay (in seconds)."""
|
"""Set auto-lock delay (in seconds)."""
|
||||||
|
|
||||||
if not client.features.pin_protection:
|
if not session.features.pin_protection:
|
||||||
raise click.ClickException("Set up a PIN first")
|
raise click.ClickException("Set up a PIN first")
|
||||||
|
|
||||||
value, unit = delay[:-1], delay[-1:]
|
value, unit = delay[:-1], delay[-1:]
|
||||||
@ -291,13 +292,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
|
|||||||
seconds = float(value) * units[unit]
|
seconds = float(value) * units[unit]
|
||||||
else:
|
else:
|
||||||
seconds = float(delay) # assume seconds if no unit is specified
|
seconds = float(delay) # assume seconds if no unit is specified
|
||||||
return device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000))
|
return device.apply_settings(session, auto_lock_delay_ms=int(seconds * 1000))
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("flags")
|
@click.argument("flags")
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def flags(client: "TrezorClient", flags: str) -> str:
|
def flags(session: "Session", flags: str) -> str:
|
||||||
"""Set device flags."""
|
"""Set device flags."""
|
||||||
if flags.lower().startswith("0b"):
|
if flags.lower().startswith("0b"):
|
||||||
flags_int = int(flags, 2)
|
flags_int = int(flags, 2)
|
||||||
@ -305,7 +306,7 @@ def flags(client: "TrezorClient", flags: str) -> str:
|
|||||||
flags_int = int(flags, 16)
|
flags_int = int(flags, 16)
|
||||||
else:
|
else:
|
||||||
flags_int = int(flags)
|
flags_int = int(flags)
|
||||||
return device.apply_flags(client, flags=flags_int)
|
return device.apply_flags(session, flags=flags_int)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -314,8 +315,8 @@ def flags(client: "TrezorClient", flags: str) -> str:
|
|||||||
"-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False
|
"-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False
|
||||||
)
|
)
|
||||||
@click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)")
|
@click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)")
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
|
def homescreen(session: "Session", filename: str, quality: int) -> str:
|
||||||
"""Set new homescreen.
|
"""Set new homescreen.
|
||||||
|
|
||||||
To revert to default homescreen, use 'trezorctl set homescreen default'
|
To revert to default homescreen, use 'trezorctl set homescreen default'
|
||||||
@ -327,39 +328,39 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
|
|||||||
if not path.exists() or not path.is_file():
|
if not path.exists() or not path.is_file():
|
||||||
raise click.ClickException("Cannot open file")
|
raise click.ClickException("Cannot open file")
|
||||||
|
|
||||||
if client.features.model == "1":
|
if session.features.model == "1":
|
||||||
img = image_to_t1(path)
|
img = image_to_t1(path)
|
||||||
else:
|
else:
|
||||||
if client.features.homescreen_format == messages.HomescreenFormat.Jpeg:
|
if session.features.homescreen_format == messages.HomescreenFormat.Jpeg:
|
||||||
width = (
|
width = (
|
||||||
client.features.homescreen_width
|
session.features.homescreen_width
|
||||||
if client.features.homescreen_width is not None
|
if session.features.homescreen_width is not None
|
||||||
else 240
|
else 240
|
||||||
)
|
)
|
||||||
height = (
|
height = (
|
||||||
client.features.homescreen_height
|
session.features.homescreen_height
|
||||||
if client.features.homescreen_height is not None
|
if session.features.homescreen_height is not None
|
||||||
else 240
|
else 240
|
||||||
)
|
)
|
||||||
img = image_to_jpeg(path, width, height, quality)
|
img = image_to_jpeg(path, width, height, quality)
|
||||||
elif client.features.homescreen_format == messages.HomescreenFormat.ToiG:
|
elif session.features.homescreen_format == messages.HomescreenFormat.ToiG:
|
||||||
width = client.features.homescreen_width
|
width = session.features.homescreen_width
|
||||||
height = client.features.homescreen_height
|
height = session.features.homescreen_height
|
||||||
if width is None or height is None:
|
if width is None or height is None:
|
||||||
raise click.ClickException("Device did not report homescreen size.")
|
raise click.ClickException("Device did not report homescreen size.")
|
||||||
img = image_to_toif(path, width, height, True)
|
img = image_to_toif(path, width, height, True)
|
||||||
elif (
|
elif (
|
||||||
client.features.homescreen_format == messages.HomescreenFormat.Toif
|
session.features.homescreen_format == messages.HomescreenFormat.Toif
|
||||||
or client.features.homescreen_format is None
|
or session.features.homescreen_format is None
|
||||||
):
|
):
|
||||||
width = (
|
width = (
|
||||||
client.features.homescreen_width
|
session.features.homescreen_width
|
||||||
if client.features.homescreen_width is not None
|
if session.features.homescreen_width is not None
|
||||||
else 144
|
else 144
|
||||||
)
|
)
|
||||||
height = (
|
height = (
|
||||||
client.features.homescreen_height
|
session.features.homescreen_height
|
||||||
if client.features.homescreen_height is not None
|
if session.features.homescreen_height is not None
|
||||||
else 144
|
else 144
|
||||||
)
|
)
|
||||||
img = image_to_toif(path, width, height, False)
|
img = image_to_toif(path, width, height, False)
|
||||||
@ -369,7 +370,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
|
|||||||
"Unknown image format requested by the device."
|
"Unknown image format requested by the device."
|
||||||
)
|
)
|
||||||
|
|
||||||
return device.apply_settings(client, homescreen=img)
|
return device.apply_settings(session, homescreen=img)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -377,9 +378,9 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
|
|||||||
"--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.'
|
"--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.'
|
||||||
)
|
)
|
||||||
@click.argument("level", type=ChoiceType(SAFETY_LEVELS))
|
@click.argument("level", type=ChoiceType(SAFETY_LEVELS))
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def safety_checks(
|
def safety_checks(
|
||||||
client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel
|
session: "Session", always: bool, level: messages.SafetyCheckLevel
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Set safety check level.
|
"""Set safety check level.
|
||||||
|
|
||||||
@ -392,18 +393,18 @@ def safety_checks(
|
|||||||
"""
|
"""
|
||||||
if always and level == messages.SafetyCheckLevel.PromptTemporarily:
|
if always and level == messages.SafetyCheckLevel.PromptTemporarily:
|
||||||
level = messages.SafetyCheckLevel.PromptAlways
|
level = messages.SafetyCheckLevel.PromptAlways
|
||||||
return device.apply_settings(client, safety_checks=level)
|
return device.apply_settings(session, safety_checks=level)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
|
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def experimental_features(client: "TrezorClient", enable: bool) -> str:
|
def experimental_features(session: "Session", enable: bool) -> str:
|
||||||
"""Enable or disable experimental message types.
|
"""Enable or disable experimental message types.
|
||||||
|
|
||||||
This is a developer feature. Use with caution.
|
This is a developer feature. Use with caution.
|
||||||
"""
|
"""
|
||||||
return device.apply_settings(client, experimental_features=enable)
|
return device.apply_settings(session, experimental_features=enable)
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
@ -426,25 +427,25 @@ passphrase = cast(AliasedGroup, passphrase_main)
|
|||||||
|
|
||||||
@passphrase.command(name="on")
|
@passphrase.command(name="on")
|
||||||
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
|
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> str:
|
def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> str:
|
||||||
"""Enable passphrase."""
|
"""Enable passphrase."""
|
||||||
if client.features.passphrase_protection is not True:
|
if session.features.passphrase_protection is not True:
|
||||||
use_passphrase = True
|
use_passphrase = True
|
||||||
else:
|
else:
|
||||||
use_passphrase = None
|
use_passphrase = None
|
||||||
return device.apply_settings(
|
return device.apply_settings(
|
||||||
client,
|
session,
|
||||||
use_passphrase=use_passphrase,
|
use_passphrase=use_passphrase,
|
||||||
passphrase_always_on_device=force_on_device,
|
passphrase_always_on_device=force_on_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@passphrase.command(name="off")
|
@passphrase.command(name="off")
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def passphrase_off(client: "TrezorClient") -> str:
|
def passphrase_off(session: "Session") -> str:
|
||||||
"""Disable passphrase."""
|
"""Disable passphrase."""
|
||||||
return device.apply_settings(client, use_passphrase=False)
|
return device.apply_settings(session, use_passphrase=False)
|
||||||
|
|
||||||
|
|
||||||
# Registering the aliases for backwards compatibility
|
# Registering the aliases for backwards compatibility
|
||||||
@ -457,10 +458,10 @@ passphrase.aliases = {
|
|||||||
|
|
||||||
@passphrase.command(name="hide")
|
@passphrase.command(name="hide")
|
||||||
@click.argument("hide", type=ChoiceType({"on": True, "off": False}))
|
@click.argument("hide", type=ChoiceType({"on": True, "off": False}))
|
||||||
@with_client
|
@with_session(management=True)
|
||||||
def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> str:
|
def hide_passphrase_from_host(session: "Session", hide: bool) -> str:
|
||||||
"""Enable or disable hiding passphrase coming from host.
|
"""Enable or disable hiding passphrase coming from host.
|
||||||
|
|
||||||
This is a developer feature. Use with caution.
|
This is a developer feature. Use with caution.
|
||||||
"""
|
"""
|
||||||
return device.apply_settings(client, hide_passphrase_from_host=hide)
|
return device.apply_settings(session, hide_passphrase_from_host=hide)
|
||||||
|
@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import messages, solana, tools
|
from .. import messages, solana, tools
|
||||||
from . import with_client
|
from . import with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..client import TrezorClient
|
from ..transport.session import Session
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h"
|
PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h"
|
||||||
DEFAULT_PATH = "m/44h/501h/0h/0h"
|
DEFAULT_PATH = "m/44h/501h/0h/0h"
|
||||||
@ -21,40 +21,40 @@ def cli() -> None:
|
|||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
|
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def get_public_key(
|
def get_public_key(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address: str,
|
address: str,
|
||||||
show_display: bool,
|
show_display: bool,
|
||||||
) -> messages.SolanaPublicKey:
|
) -> messages.SolanaPublicKey:
|
||||||
"""Get Solana public key."""
|
"""Get Solana public key."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return solana.get_public_key(client, address_n, show_display)
|
return solana.get_public_key(session, address_n, show_display)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
|
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address: str,
|
address: str,
|
||||||
show_display: bool,
|
show_display: bool,
|
||||||
chunkify: bool,
|
chunkify: bool,
|
||||||
) -> messages.SolanaAddress:
|
) -> messages.SolanaAddress:
|
||||||
"""Get Solana address."""
|
"""Get Solana address."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return solana.get_address(client, address_n, show_display, chunkify)
|
return solana.get_address(session, address_n, show_display, chunkify)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("serialized_tx", type=str)
|
@click.argument("serialized_tx", type=str)
|
||||||
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
|
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
|
||||||
@click.option("-a", "--additional-information-file", type=click.File("r"))
|
@click.option("-a", "--additional-information-file", type=click.File("r"))
|
||||||
@with_client
|
@with_session
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address: str,
|
address: str,
|
||||||
serialized_tx: str,
|
serialized_tx: str,
|
||||||
additional_information_file: Optional[TextIO],
|
additional_information_file: Optional[TextIO],
|
||||||
@ -78,7 +78,7 @@ def sign_tx(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return solana.sign_tx(
|
return solana.sign_tx(
|
||||||
client,
|
session,
|
||||||
address_n,
|
address_n,
|
||||||
bytes.fromhex(serialized_tx),
|
bytes.fromhex(serialized_tx),
|
||||||
additional_information,
|
additional_information,
|
||||||
|
@ -21,10 +21,10 @@ from typing import TYPE_CHECKING
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import stellar, tools
|
from .. import stellar, tools
|
||||||
from . import with_client
|
from . import with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..client import TrezorClient
|
from ..transport.session import Session
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from stellar_sdk import (
|
from stellar_sdk import (
|
||||||
@ -52,13 +52,13 @@ def cli() -> None:
|
|||||||
)
|
)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Get Stellar public address."""
|
"""Get Stellar public address."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return stellar.get_address(client, address_n, show_display, chunkify)
|
return stellar.get_address(session, address_n, show_display, chunkify)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -77,9 +77,9 @@ def get_address(
|
|||||||
help="Network passphrase (blank for public network).",
|
help="Network passphrase (blank for public network).",
|
||||||
)
|
)
|
||||||
@click.argument("b64envelope")
|
@click.argument("b64envelope")
|
||||||
@with_client
|
@with_session
|
||||||
def sign_transaction(
|
def sign_transaction(
|
||||||
client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str
|
session: "Session", b64envelope: str, address: str, network_passphrase: str
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Sign a base64-encoded transaction envelope.
|
"""Sign a base64-encoded transaction envelope.
|
||||||
|
|
||||||
@ -109,6 +109,6 @@ def sign_transaction(
|
|||||||
|
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
tx, operations = stellar.from_envelope(envelope)
|
tx, operations = stellar.from_envelope(envelope)
|
||||||
resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase)
|
resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase)
|
||||||
|
|
||||||
return base64.b64encode(resp.signature)
|
return base64.b64encode(resp.signature)
|
||||||
|
@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import messages, protobuf, tezos, tools
|
from .. import messages, protobuf, tezos, tools
|
||||||
from . import with_client
|
from . import with_session
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..client import TrezorClient
|
from ..transport.session import Session
|
||||||
|
|
||||||
PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h"
|
PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h"
|
||||||
|
|
||||||
@ -37,23 +37,23 @@ def cli() -> None:
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Get Tezos address for specified path."""
|
"""Get Tezos address for specified path."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return tezos.get_address(client, address_n, show_display, chunkify)
|
return tezos.get_address(session, address_n, show_display, chunkify)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-d", "--show-display", is_flag=True)
|
@click.option("-d", "--show-display", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
|
||||||
"""Get Tezos public key."""
|
"""Get Tezos public key."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
return tezos.get_public_key(client, address_n, show_display)
|
return tezos.get_public_key(session, address_n, show_display)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -61,11 +61,11 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
|
|||||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||||
@click.option("-C", "--chunkify", is_flag=True)
|
@click.option("-C", "--chunkify", is_flag=True)
|
||||||
@with_client
|
@with_session
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
|
session: "Session", address: str, file: TextIO, chunkify: bool
|
||||||
) -> messages.TezosSignedTx:
|
) -> messages.TezosSignedTx:
|
||||||
"""Sign Tezos transaction."""
|
"""Sign Tezos transaction."""
|
||||||
address_n = tools.parse_path(address)
|
address_n = tools.parse_path(address)
|
||||||
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))
|
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))
|
||||||
return tezos.sign_tx(client, address_n, msg, chunkify=chunkify)
|
return tezos.sign_tx(session, address_n, msg, chunkify=chunkify)
|
||||||
|
@ -24,9 +24,12 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from .. import __version__, log, messages, protobuf, ui
|
from .. import __version__, log, messages, protobuf
|
||||||
from ..client import TrezorClient
|
from ..client import ProtocolVersion, TrezorClient
|
||||||
from ..transport import DeviceIsBusy, enumerate_devices
|
from ..transport import DeviceIsBusy, enumerate_devices
|
||||||
|
from ..transport.session import Session
|
||||||
|
from ..transport.thp import channel_database
|
||||||
|
from ..transport.thp.channel_database import get_channel_db
|
||||||
from ..transport.udp import UdpTransport
|
from ..transport.udp import UdpTransport
|
||||||
from . import (
|
from . import (
|
||||||
AliasedGroup,
|
AliasedGroup,
|
||||||
@ -50,6 +53,7 @@ from . import (
|
|||||||
stellar,
|
stellar,
|
||||||
tezos,
|
tezos,
|
||||||
with_client,
|
with_client,
|
||||||
|
with_session,
|
||||||
)
|
)
|
||||||
|
|
||||||
F = TypeVar("F", bound=Callable)
|
F = TypeVar("F", bound=Callable)
|
||||||
@ -193,6 +197,13 @@ def configure_logging(verbose: int) -> None:
|
|||||||
"--record",
|
"--record",
|
||||||
help="Record screen changes into a specified directory.",
|
help="Record screen changes into a specified directory.",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"-n",
|
||||||
|
"--no-store",
|
||||||
|
is_flag=True,
|
||||||
|
help="Do not store channels data between commands.",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
@click.version_option(version=__version__)
|
@click.version_option(version=__version__)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def cli_main(
|
def cli_main(
|
||||||
@ -204,9 +215,10 @@ def cli_main(
|
|||||||
script: bool,
|
script: bool,
|
||||||
session_id: Optional[str],
|
session_id: Optional[str],
|
||||||
record: Optional[str],
|
record: Optional[str],
|
||||||
|
no_store: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
configure_logging(verbose)
|
configure_logging(verbose)
|
||||||
|
channel_database.set_channel_database(should_not_store=no_store)
|
||||||
bytes_session_id: Optional[bytes] = None
|
bytes_session_id: Optional[bytes] = None
|
||||||
if session_id is not None:
|
if session_id is not None:
|
||||||
try:
|
try:
|
||||||
@ -214,6 +226,7 @@ def cli_main(
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
raise click.ClickException(f"Not a valid session id: {session_id}")
|
raise click.ClickException(f"Not a valid session id: {session_id}")
|
||||||
|
|
||||||
|
# ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script)
|
||||||
ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script)
|
ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script)
|
||||||
|
|
||||||
# Optionally record the screen into a specified directory.
|
# Optionally record the screen into a specified directory.
|
||||||
@ -285,18 +298,23 @@ def format_device_name(features: messages.Features) -> str:
|
|||||||
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
|
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
|
||||||
"""List connected Trezor devices."""
|
"""List connected Trezor devices."""
|
||||||
if no_resolve:
|
if no_resolve:
|
||||||
return enumerate_devices()
|
for d in enumerate_devices():
|
||||||
|
print(d.get_path())
|
||||||
|
return
|
||||||
|
|
||||||
|
from . import get_client
|
||||||
|
|
||||||
for transport in enumerate_devices():
|
for transport in enumerate_devices():
|
||||||
try:
|
try:
|
||||||
client = TrezorClient(transport, ui=ui.ClickUI())
|
client = get_client(transport)
|
||||||
description = format_device_name(client.features)
|
description = format_device_name(client.features)
|
||||||
client.end_session()
|
if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
|
||||||
|
get_channel_db().save_channel(client.protocol)
|
||||||
except DeviceIsBusy:
|
except DeviceIsBusy:
|
||||||
description = "Device is in use by another process"
|
description = "Device is in use by another process"
|
||||||
except Exception:
|
except Exception as e:
|
||||||
description = "Failed to read details"
|
description = "Failed to read details " + str(type(e))
|
||||||
click.echo(f"{transport} - {description}")
|
click.echo(f"{transport.get_path()} - {description}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@ -314,15 +332,19 @@ def version() -> str:
|
|||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("message")
|
@click.argument("message")
|
||||||
@click.option("-b", "--button-protection", is_flag=True)
|
@click.option("-b", "--button-protection", is_flag=True)
|
||||||
@with_client
|
@with_session(empty_passphrase=True)
|
||||||
def ping(client: "TrezorClient", message: str, button_protection: bool) -> str:
|
def ping(session: "Session", message: str, button_protection: bool) -> str:
|
||||||
"""Send ping message."""
|
"""Send ping message."""
|
||||||
return client.ping(message, button_protection=button_protection)
|
|
||||||
|
# TODO return short-circuit from old client for old Trezors
|
||||||
|
return session.ping(message, button_protection)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.pass_obj
|
@click.pass_obj
|
||||||
def get_session(obj: TrezorConnection) -> str:
|
def get_session(
|
||||||
|
obj: TrezorConnection, passphrase: str = "", derive_cardano: bool = False
|
||||||
|
) -> str:
|
||||||
"""Get a session ID for subsequent commands.
|
"""Get a session ID for subsequent commands.
|
||||||
|
|
||||||
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
|
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
|
||||||
@ -336,23 +358,44 @@ def get_session(obj: TrezorConnection) -> str:
|
|||||||
obj.session_id = None
|
obj.session_id = None
|
||||||
|
|
||||||
with obj.client_context() as client:
|
with obj.client_context() as client:
|
||||||
|
|
||||||
if client.features.model == "1" and client.version < (1, 9, 0):
|
if client.features.model == "1" and client.version < (1, 9, 0):
|
||||||
raise click.ClickException(
|
raise click.ClickException(
|
||||||
"Upgrade your firmware to enable session support."
|
"Upgrade your firmware to enable session support."
|
||||||
)
|
)
|
||||||
|
|
||||||
client.ensure_unlocked()
|
# client.ensure_unlocked()
|
||||||
if client.session_id is None:
|
session = client.get_session(
|
||||||
|
passphrase=passphrase, derive_cardano=derive_cardano
|
||||||
|
)
|
||||||
|
if session.id is None:
|
||||||
raise click.ClickException("Passphrase not enabled or firmware too old.")
|
raise click.ClickException("Passphrase not enabled or firmware too old.")
|
||||||
else:
|
else:
|
||||||
return client.session_id.hex()
|
return session.id.hex()
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@with_client
|
@with_session(must_resume=True, empty_passphrase=True)
|
||||||
def clear_session(client: "TrezorClient") -> None:
|
def clear_session(session: "Session") -> None:
|
||||||
"""Clear session (remove cached PIN, passphrase, etc.)."""
|
"""Clear session (remove cached PIN, passphrase, etc.)."""
|
||||||
return client.clear_session()
|
if session is None:
|
||||||
|
click.echo("Cannot clear session as it was not properly resumed.")
|
||||||
|
return
|
||||||
|
session.call(messages.LockDevice())
|
||||||
|
session.end()
|
||||||
|
# TODO different behaviour than main, not sure if ok
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
def delete_channels() -> None:
|
||||||
|
"""
|
||||||
|
Delete cached channels.
|
||||||
|
|
||||||
|
Do not use together with the `-n` (`--no-store`) flag,
|
||||||
|
as the JSON database will not be deleted in that case.
|
||||||
|
"""
|
||||||
|
get_channel_db().clear_stored_channels()
|
||||||
|
click.echo("Deleted stored channels")
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,6 @@
|
|||||||
# This file is part of the Trezor project.
|
# This file is part of the Trezor project.
|
||||||
#
|
#
|
||||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||||
#
|
#
|
||||||
# This library is free software: you can redistribute it and/or modify
|
# This library is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Lesser General Public License version 3
|
# it under the terms of the GNU Lesser General Public License version 3
|
||||||
@ -21,47 +21,44 @@ import logging
|
|||||||
import re
|
import re
|
||||||
import textwrap
|
import textwrap
|
||||||
import time
|
import time
|
||||||
|
import typing as t
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum, IntEnum, auto
|
from enum import Enum, IntEnum, auto
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
Generator,
|
|
||||||
Iterable,
|
|
||||||
Iterator,
|
|
||||||
Sequence,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from mnemonic import Mnemonic
|
from mnemonic import Mnemonic
|
||||||
|
|
||||||
from . import mapping, messages, models, protobuf
|
from . import btc, mapping, messages, models, protobuf
|
||||||
from .client import TrezorClient
|
from .client import (
|
||||||
from .exceptions import TrezorFailure
|
MAX_PASSPHRASE_LENGTH,
|
||||||
|
MAX_PIN_LENGTH,
|
||||||
|
PASSPHRASE_ON_DEVICE,
|
||||||
|
TrezorClient,
|
||||||
|
)
|
||||||
|
from .exceptions import Cancelled, PinException, TrezorFailure
|
||||||
from .log import DUMP_BYTES
|
from .log import DUMP_BYTES
|
||||||
from .messages import DebugWaitType
|
from .messages import Capability, DebugWaitType
|
||||||
from .tools import expect
|
from .tools import expect, parse_path
|
||||||
|
from .transport.session import Session, SessionV1
|
||||||
|
from .transport.thp.protocol_v1 import ProtocolV1
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
from .messages import PinMatrixRequestType
|
from .messages import PinMatrixRequestType
|
||||||
from .transport import Transport
|
from .transport import Transport
|
||||||
|
|
||||||
ExpectedMessage = Union[
|
ExpectedMessage = t.Union[
|
||||||
protobuf.MessageType, type[protobuf.MessageType], "MessageFilter"
|
protobuf.MessageType, t.Type[protobuf.MessageType], "MessageFilter"
|
||||||
]
|
]
|
||||||
|
|
||||||
AnyDict = Dict[str, Any]
|
AnyDict = t.Dict[str, t.Any]
|
||||||
|
|
||||||
class InputFunc(Protocol):
|
class InputFunc(Protocol):
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
hold_ms: int | None = None,
|
hold_ms: int | None = None,
|
||||||
@ -70,6 +67,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
EXPECTED_RESPONSES_CONTEXT_LINES = 3
|
EXPECTED_RESPONSES_CONTEXT_LINES = 3
|
||||||
|
PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0")
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -104,11 +102,13 @@ class UnstructuredJSONReader:
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
self.dict = {}
|
self.dict = {}
|
||||||
|
|
||||||
def top_level_value(self, key: str) -> Any:
|
def top_level_value(self, key: str) -> t.Any:
|
||||||
return self.dict.get(key)
|
return self.dict.get(key)
|
||||||
|
|
||||||
def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]:
|
def find_objects_with_key_and_value(
|
||||||
def recursively_find(data: Any) -> Iterator[Any]:
|
self, key: str, value: t.Any
|
||||||
|
) -> list["AnyDict"]:
|
||||||
|
def recursively_find(data: t.Any) -> t.Iterator[t.Any]:
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
if data.get(key) == value:
|
if data.get(key) == value:
|
||||||
yield data
|
yield data
|
||||||
@ -121,7 +121,7 @@ class UnstructuredJSONReader:
|
|||||||
return list(recursively_find(self.dict))
|
return list(recursively_find(self.dict))
|
||||||
|
|
||||||
def find_unique_object_with_key_and_value(
|
def find_unique_object_with_key_and_value(
|
||||||
self, key: str, value: Any
|
self, key: str, value: t.Any
|
||||||
) -> AnyDict | None:
|
) -> AnyDict | None:
|
||||||
objects = self.find_objects_with_key_and_value(key, value)
|
objects = self.find_objects_with_key_and_value(key, value)
|
||||||
if not objects:
|
if not objects:
|
||||||
@ -129,8 +129,10 @@ class UnstructuredJSONReader:
|
|||||||
assert len(objects) == 1
|
assert len(objects) == 1
|
||||||
return objects[0]
|
return objects[0]
|
||||||
|
|
||||||
def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]:
|
def find_values_by_key(
|
||||||
def recursively_find(data: Any) -> Iterator[Any]:
|
self, key: str, only_type: type | None = None
|
||||||
|
) -> list[t.Any]:
|
||||||
|
def recursively_find(data: t.Any) -> t.Iterator[t.Any]:
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
if key in data:
|
if key in data:
|
||||||
yield data[key]
|
yield data[key]
|
||||||
@ -148,8 +150,8 @@ class UnstructuredJSONReader:
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
def find_unique_value_by_key(
|
def find_unique_value_by_key(
|
||||||
self, key: str, default: Any, only_type: type | None = None
|
self, key: str, default: t.Any, only_type: type | None = None
|
||||||
) -> Any:
|
) -> t.Any:
|
||||||
values = self.find_values_by_key(key, only_type=only_type)
|
values = self.find_values_by_key(key, only_type=only_type)
|
||||||
if not values:
|
if not values:
|
||||||
return default
|
return default
|
||||||
@ -160,7 +162,7 @@ class UnstructuredJSONReader:
|
|||||||
class LayoutContent(UnstructuredJSONReader):
|
class LayoutContent(UnstructuredJSONReader):
|
||||||
"""Contains helper functions to extract specific parts of the layout."""
|
"""Contains helper functions to extract specific parts of the layout."""
|
||||||
|
|
||||||
def __init__(self, json_tokens: Sequence[str]) -> None:
|
def __init__(self, json_tokens: t.Sequence[str]) -> None:
|
||||||
json_str = "".join(json_tokens)
|
json_str = "".join(json_tokens)
|
||||||
super().__init__(json_str)
|
super().__init__(json_str)
|
||||||
|
|
||||||
@ -422,11 +424,13 @@ def _make_input_func(
|
|||||||
|
|
||||||
|
|
||||||
class DebugLink:
|
class DebugLink:
|
||||||
|
|
||||||
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
|
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.allow_interactions = auto_interact
|
self.allow_interactions = auto_interact
|
||||||
self.mapping = mapping.DEFAULT_MAPPING
|
self.mapping = mapping.DEFAULT_MAPPING
|
||||||
|
|
||||||
|
self.protocol = ProtocolV1(self.transport, self.mapping)
|
||||||
# To be set by TrezorClientDebugLink (is not known during creation time)
|
# To be set by TrezorClientDebugLink (is not known during creation time)
|
||||||
self.model: models.TrezorModel | None = None
|
self.model: models.TrezorModel | None = None
|
||||||
self.version: tuple[int, int, int] = (0, 0, 0)
|
self.version: tuple[int, int, int] = (0, 0, 0)
|
||||||
@ -479,10 +483,16 @@ class DebugLink:
|
|||||||
self.screen_text_file = file_path
|
self.screen_text_file = file_path
|
||||||
|
|
||||||
def open(self) -> None:
|
def open(self) -> None:
|
||||||
self.transport.begin_session()
|
self.transport.open()
|
||||||
|
# raise NotImplementedError
|
||||||
|
# TODO is this needed?
|
||||||
|
# self.transport.deprecated_begin_session()
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
self.transport.end_session()
|
pass
|
||||||
|
# raise NotImplementedError
|
||||||
|
# TODO is this needed?
|
||||||
|
# self.transport.deprecated_end_session()
|
||||||
|
|
||||||
def _write(self, msg: protobuf.MessageType) -> None:
|
def _write(self, msg: protobuf.MessageType) -> None:
|
||||||
if self.waiting_for_layout_change:
|
if self.waiting_for_layout_change:
|
||||||
@ -499,15 +509,10 @@ class DebugLink:
|
|||||||
DUMP_BYTES,
|
DUMP_BYTES,
|
||||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||||
)
|
)
|
||||||
self.transport.write(msg_type, msg_bytes)
|
self.protocol.write(msg)
|
||||||
|
|
||||||
def _read(self) -> protobuf.MessageType:
|
def _read(self) -> protobuf.MessageType:
|
||||||
ret_type, ret_bytes = self.transport.read()
|
msg = self.protocol.read()
|
||||||
LOG.log(
|
|
||||||
DUMP_BYTES,
|
|
||||||
f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}",
|
|
||||||
)
|
|
||||||
msg = self.mapping.decode(ret_type, ret_bytes)
|
|
||||||
|
|
||||||
# Collapse tokens to make log use less lines.
|
# Collapse tokens to make log use less lines.
|
||||||
msg_for_log = msg
|
msg_for_log = msg
|
||||||
@ -521,18 +526,27 @@ class DebugLink:
|
|||||||
)
|
)
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def _call(self, msg: protobuf.MessageType) -> Any:
|
def _call(self, msg: protobuf.MessageType) -> t.Any:
|
||||||
self._write(msg)
|
self._write(msg)
|
||||||
return self._read()
|
return self._read()
|
||||||
|
|
||||||
def state(self, wait_type: DebugWaitType | None = None) -> messages.DebugLinkState:
|
def state(
|
||||||
|
self,
|
||||||
|
wait_type: DebugWaitType | None = None,
|
||||||
|
thp_channel_id: bytes | None = None,
|
||||||
|
) -> messages.DebugLinkState:
|
||||||
if wait_type is None:
|
if wait_type is None:
|
||||||
wait_type = (
|
wait_type = (
|
||||||
DebugWaitType.CURRENT_LAYOUT
|
DebugWaitType.CURRENT_LAYOUT
|
||||||
if self.has_global_layout
|
if self.has_global_layout
|
||||||
else DebugWaitType.IMMEDIATE
|
else DebugWaitType.IMMEDIATE
|
||||||
)
|
)
|
||||||
result = self._call(messages.DebugLinkGetState(wait_layout=wait_type))
|
result = self._call(
|
||||||
|
messages.DebugLinkGetState(
|
||||||
|
wait_layout=wait_type,
|
||||||
|
thp_channel_id=thp_channel_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
while not isinstance(result, (messages.Failure, messages.DebugLinkState)):
|
while not isinstance(result, (messages.Failure, messages.DebugLinkState)):
|
||||||
result = self._read()
|
result = self._read()
|
||||||
if isinstance(result, messages.Failure):
|
if isinstance(result, messages.Failure):
|
||||||
@ -544,7 +558,7 @@ class DebugLink:
|
|||||||
|
|
||||||
def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent:
|
def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent:
|
||||||
# Next layout change will be caused by external event
|
# Next layout change will be caused by external event
|
||||||
# (e.g. device being auto-locked or as a result of device_handler.run(xxx))
|
# (e.g. device being auto-locked or as a result of device_handler.run_with_session(xxx))
|
||||||
# and not by our debug actions/decisions.
|
# and not by our debug actions/decisions.
|
||||||
# Resetting the debug state so we wait for the next layout change
|
# Resetting the debug state so we wait for the next layout change
|
||||||
# (and do not return the current state).
|
# (and do not return the current state).
|
||||||
@ -560,7 +574,7 @@ class DebugLink:
|
|||||||
return LayoutContent(obj.tokens)
|
return LayoutContent(obj.tokens)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def wait_for_layout_change(self) -> Iterator[LayoutContent]:
|
def wait_for_layout_change(self) -> t.Iterator[LayoutContent]:
|
||||||
# set up a dummy layout content object to be yielded
|
# set up a dummy layout content object to be yielded
|
||||||
layout_content = LayoutContent(
|
layout_content = LayoutContent(
|
||||||
["DUMMY CONTENT, WAIT UNTIL THE END OF THE BLOCK :("]
|
["DUMMY CONTENT, WAIT UNTIL THE END OF THE BLOCK :("]
|
||||||
@ -622,7 +636,7 @@ class DebugLink:
|
|||||||
|
|
||||||
return "".join([str(matrix.index(p) + 1) for p in pin])
|
return "".join([str(matrix.index(p) + 1) for p in pin])
|
||||||
|
|
||||||
def read_recovery_word(self) -> Tuple[str | None, int | None]:
|
def read_recovery_word(self) -> t.Tuple[str | None, int | None]:
|
||||||
state = self.state()
|
state = self.state()
|
||||||
return (state.recovery_fake_word, state.recovery_word_pos)
|
return (state.recovery_fake_word, state.recovery_word_pos)
|
||||||
|
|
||||||
@ -700,7 +714,7 @@ class DebugLink:
|
|||||||
|
|
||||||
def click(
|
def click(
|
||||||
self,
|
self,
|
||||||
click: Tuple[int, int],
|
click: t.Tuple[int, int],
|
||||||
hold_ms: int | None = None,
|
hold_ms: int | None = None,
|
||||||
wait: bool | None = None,
|
wait: bool | None = None,
|
||||||
) -> LayoutContent:
|
) -> LayoutContent:
|
||||||
@ -862,10 +876,10 @@ class DebugUI:
|
|||||||
self.clear()
|
self.clear()
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
self.pins: Iterator[str] | None = None
|
self.pins: t.Iterator[str] | None = None
|
||||||
self.passphrase = ""
|
self.passphrase = ""
|
||||||
self.input_flow: Union[
|
self.input_flow: t.Union[
|
||||||
Generator[None, messages.ButtonRequest, None], object, None
|
t.Generator[None, messages.ButtonRequest, None], object, None
|
||||||
] = None
|
] = None
|
||||||
|
|
||||||
def _default_input_flow(self, br: messages.ButtonRequest) -> None:
|
def _default_input_flow(self, br: messages.ButtonRequest) -> None:
|
||||||
@ -896,7 +910,7 @@ class DebugUI:
|
|||||||
raise AssertionError("input flow ended prematurely")
|
raise AssertionError("input flow ended prematurely")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
assert isinstance(self.input_flow, Generator)
|
assert isinstance(self.input_flow, t.Generator)
|
||||||
self.input_flow.send(br)
|
self.input_flow.send(br)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
self.input_flow = self.INPUT_FLOW_DONE
|
self.input_flow = self.INPUT_FLOW_DONE
|
||||||
@ -918,12 +932,15 @@ class DebugUI:
|
|||||||
|
|
||||||
|
|
||||||
class MessageFilter:
|
class MessageFilter:
|
||||||
def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None:
|
|
||||||
|
def __init__(
|
||||||
|
self, message_type: t.Type[protobuf.MessageType], **fields: t.Any
|
||||||
|
) -> None:
|
||||||
self.message_type = message_type
|
self.message_type = message_type
|
||||||
self.fields: Dict[str, Any] = {}
|
self.fields: t.Dict[str, t.Any] = {}
|
||||||
self.update_fields(**fields)
|
self.update_fields(**fields)
|
||||||
|
|
||||||
def update_fields(self, **fields: Any) -> "MessageFilter":
|
def update_fields(self, **fields: t.Any) -> "MessageFilter":
|
||||||
for name, value in fields.items():
|
for name, value in fields.items():
|
||||||
try:
|
try:
|
||||||
self.fields[name] = self.from_message_or_type(value)
|
self.fields[name] = self.from_message_or_type(value)
|
||||||
@ -971,7 +988,7 @@ class MessageFilter:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def to_string(self, maxwidth: int = 80) -> str:
|
def to_string(self, maxwidth: int = 80) -> str:
|
||||||
fields: list[Tuple[str, str]] = []
|
fields: list[t.Tuple[str, str]] = []
|
||||||
for field in self.message_type.FIELDS.values():
|
for field in self.message_type.FIELDS.values():
|
||||||
if field.name not in self.fields:
|
if field.name not in self.fields:
|
||||||
continue
|
continue
|
||||||
@ -1001,7 +1018,8 @@ class MessageFilter:
|
|||||||
|
|
||||||
|
|
||||||
class MessageFilterGenerator:
|
class MessageFilterGenerator:
|
||||||
def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]:
|
|
||||||
|
def __getattr__(self, key: str) -> t.Callable[..., "MessageFilter"]:
|
||||||
message_type = getattr(messages, key)
|
message_type = getattr(messages, key)
|
||||||
return MessageFilter(message_type).update_fields
|
return MessageFilter(message_type).update_fields
|
||||||
|
|
||||||
@ -1009,6 +1027,245 @@ class MessageFilterGenerator:
|
|||||||
message_filters = MessageFilterGenerator()
|
message_filters = MessageFilterGenerator()
|
||||||
|
|
||||||
|
|
||||||
|
class SessionDebugWrapper(Session):
|
||||||
|
def __init__(self, session: Session) -> None:
|
||||||
|
self._session = session
|
||||||
|
self.reset_debug_features()
|
||||||
|
if isinstance(session, SessionDebugWrapper):
|
||||||
|
raise Exception("Cannot wrap already wrapped session!")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def protocol_version(self) -> int:
|
||||||
|
return self.client.protocol_version
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self) -> TrezorClientDebugLink:
|
||||||
|
assert isinstance(self._session.client, TrezorClientDebugLink)
|
||||||
|
return self._session.client
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self) -> bytes:
|
||||||
|
return self._session.id
|
||||||
|
|
||||||
|
def _write(self, msg: t.Any) -> None:
|
||||||
|
print("writing message:", msg.__class__.__name__)
|
||||||
|
self._session._write(self._filter_message(msg))
|
||||||
|
|
||||||
|
def _read(self) -> t.Any:
|
||||||
|
resp = self._filter_message(self._session._read())
|
||||||
|
print("reading message:", resp.__class__.__name__)
|
||||||
|
if self.actual_responses is not None:
|
||||||
|
self.actual_responses.append(resp)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
def set_expected_responses(
|
||||||
|
self,
|
||||||
|
expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]],
|
||||||
|
) -> None:
|
||||||
|
"""Set a sequence of expected responses to session calls.
|
||||||
|
|
||||||
|
Within a given with-block, the list of received responses from device must
|
||||||
|
match the list of expected responses, otherwise an ``AssertionError`` is raised.
|
||||||
|
|
||||||
|
If an expected response is given a field value other than ``None``, that field value
|
||||||
|
must exactly match the received field value. If a given field is ``None``
|
||||||
|
(or unspecified) in the expected response, the received field value is not
|
||||||
|
checked.
|
||||||
|
|
||||||
|
Each expected response can also be a tuple ``(bool, message)``. In that case, the
|
||||||
|
expected response is only evaluated if the first field is ``True``.
|
||||||
|
This is useful for differentiating sequences between Trezor models:
|
||||||
|
|
||||||
|
>>> trezor_one = session.features.model == "1"
|
||||||
|
>>> session.set_expected_responses([
|
||||||
|
>>> messages.ButtonRequest(code=ConfirmOutput),
|
||||||
|
>>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)),
|
||||||
|
>>> messages.Success(),
|
||||||
|
>>> ])
|
||||||
|
"""
|
||||||
|
if not self.in_with_statement:
|
||||||
|
raise RuntimeError("Must be called inside 'with' statement")
|
||||||
|
|
||||||
|
# make sure all items are (bool, message) tuples
|
||||||
|
expected_with_validity = (
|
||||||
|
e if isinstance(e, tuple) else (True, e) for e in expected
|
||||||
|
)
|
||||||
|
|
||||||
|
# only apply those items that are (True, message)
|
||||||
|
self.expected_responses = [
|
||||||
|
MessageFilter.from_message_or_type(expected)
|
||||||
|
for valid, expected in expected_with_validity
|
||||||
|
if valid
|
||||||
|
]
|
||||||
|
self.actual_responses = []
|
||||||
|
|
||||||
|
def lock(self, *, _refresh_features: bool = True) -> None:
|
||||||
|
"""Lock the device.
|
||||||
|
|
||||||
|
If the device does not have a PIN configured, this will do nothing.
|
||||||
|
Otherwise, a lock screen will be shown and the device will prompt for PIN
|
||||||
|
before further actions.
|
||||||
|
|
||||||
|
This call does _not_ invalidate passphrase cache. If passphrase is in use,
|
||||||
|
the device will not prompt for it after unlocking.
|
||||||
|
|
||||||
|
To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate
|
||||||
|
passphrase cache, use `clear_session()`.
|
||||||
|
"""
|
||||||
|
# TODO update the documentation above
|
||||||
|
# Private argument _refresh_features can be used internally to avoid
|
||||||
|
# refreshing in cases where we will refresh soon anyway. This is used
|
||||||
|
# in TrezorClient.clear_session()
|
||||||
|
self.call(messages.LockDevice())
|
||||||
|
if _refresh_features:
|
||||||
|
self.refresh_features()
|
||||||
|
|
||||||
|
def cancel(self) -> None:
|
||||||
|
self._write(messages.Cancel())
|
||||||
|
|
||||||
|
def ensure_unlocked(self) -> None:
|
||||||
|
btc.get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
|
||||||
|
self.refresh_features()
|
||||||
|
|
||||||
|
def set_filter(
|
||||||
|
self,
|
||||||
|
message_type: t.Type[protobuf.MessageType],
|
||||||
|
callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||||
|
) -> None:
|
||||||
|
"""Configure a filter function for a specified message type.
|
||||||
|
|
||||||
|
The `callback` must be a function that accepts a protobuf message, and returns
|
||||||
|
a (possibly modified) protobuf message of the same type. Whenever a message
|
||||||
|
is sent or received that matches `message_type`, `callback` is invoked on the
|
||||||
|
message and its result is substituted for the original.
|
||||||
|
|
||||||
|
Useful for test scenarios with an active malicious actor on the wire.
|
||||||
|
"""
|
||||||
|
if not self.in_with_statement:
|
||||||
|
raise RuntimeError("Must be called inside 'with' statement")
|
||||||
|
|
||||||
|
self.filters[message_type] = callback
|
||||||
|
|
||||||
|
def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
|
||||||
|
message_type = msg.__class__
|
||||||
|
callback = self.filters.get(message_type)
|
||||||
|
if callable(callback):
|
||||||
|
return callback(deepcopy(msg))
|
||||||
|
else:
|
||||||
|
return msg
|
||||||
|
|
||||||
|
def reset_debug_features(self) -> None:
|
||||||
|
"""Prepare the debugging session for a new testcase.
|
||||||
|
|
||||||
|
Clears all debugging state that might have been modified by a testcase.
|
||||||
|
"""
|
||||||
|
self.in_with_statement = False
|
||||||
|
self.expected_responses: list[MessageFilter] | None = None
|
||||||
|
self.actual_responses: list[protobuf.MessageType] | None = None
|
||||||
|
self.filters: t.Dict[
|
||||||
|
t.Type[protobuf.MessageType],
|
||||||
|
t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||||
|
] = {}
|
||||||
|
self.button_callback = self.client.button_callback
|
||||||
|
self.pin_callback = self.client.pin_callback
|
||||||
|
self.passphrase_callback = self._session.passphrase_callback
|
||||||
|
self.passphrase = self._session.passphrase
|
||||||
|
|
||||||
|
def __enter__(self) -> "SessionDebugWrapper":
|
||||||
|
# For usage in with/expected_responses
|
||||||
|
if self.in_with_statement:
|
||||||
|
raise RuntimeError("Do not nest!")
|
||||||
|
self.in_with_statement = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
|
||||||
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
|
||||||
|
# copy expected/actual responses before clearing them
|
||||||
|
expected_responses = self.expected_responses
|
||||||
|
actual_responses = self.actual_responses
|
||||||
|
|
||||||
|
# grab a copy of the inputflow generator to raise an exception through it
|
||||||
|
if isinstance(self.client.ui, DebugUI):
|
||||||
|
input_flow = self.client.ui.input_flow
|
||||||
|
else:
|
||||||
|
input_flow = None
|
||||||
|
|
||||||
|
self.reset_debug_features()
|
||||||
|
|
||||||
|
if exc_type is None:
|
||||||
|
# If no other exception was raised, evaluate missed responses
|
||||||
|
# (raises AssertionError on mismatch)
|
||||||
|
self._verify_responses(expected_responses, actual_responses)
|
||||||
|
if isinstance(input_flow, t.Generator):
|
||||||
|
# Ensure that the input flow is exhausted
|
||||||
|
try:
|
||||||
|
input_flow.throw(
|
||||||
|
AssertionError("input flow continues past end of test")
|
||||||
|
)
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
|
elif isinstance(input_flow, t.Generator):
|
||||||
|
# Propagate the exception through the input flow, so that we see in
|
||||||
|
# traceback where it is stuck.
|
||||||
|
input_flow.throw(exc_type, value, traceback)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _verify_responses(
|
||||||
|
cls,
|
||||||
|
expected: list[MessageFilter] | None,
|
||||||
|
actual: list[protobuf.MessageType] | None,
|
||||||
|
) -> None:
|
||||||
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
|
||||||
|
if expected is None and actual is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
assert expected is not None
|
||||||
|
assert actual is not None
|
||||||
|
|
||||||
|
for i, (exp, act) in enumerate(zip_longest(expected, actual)):
|
||||||
|
if exp is None:
|
||||||
|
output = cls._expectation_lines(expected, i)
|
||||||
|
output.append("No more messages were expected, but we got:")
|
||||||
|
for resp in actual[i:]:
|
||||||
|
output.append(
|
||||||
|
textwrap.indent(protobuf.format_message(resp), " ")
|
||||||
|
)
|
||||||
|
raise AssertionError("\n".join(output))
|
||||||
|
|
||||||
|
if act is None:
|
||||||
|
output = cls._expectation_lines(expected, i)
|
||||||
|
output.append("This and the following message was not received.")
|
||||||
|
raise AssertionError("\n".join(output))
|
||||||
|
|
||||||
|
if not exp.match(act):
|
||||||
|
output = cls._expectation_lines(expected, i)
|
||||||
|
output.append("Actually received:")
|
||||||
|
output.append(textwrap.indent(protobuf.format_message(act), " "))
|
||||||
|
raise AssertionError("\n".join(output))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
|
||||||
|
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
|
||||||
|
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
|
||||||
|
output: list[str] = []
|
||||||
|
output.append("Expected responses:")
|
||||||
|
if start_at > 0:
|
||||||
|
output.append(f" (...{start_at} previous responses omitted)")
|
||||||
|
for i in range(start_at, stop_at):
|
||||||
|
exp = expected[i]
|
||||||
|
prefix = " " if i != current else ">>> "
|
||||||
|
output.append(textwrap.indent(exp.to_string(), prefix))
|
||||||
|
if stop_at < len(expected):
|
||||||
|
omitted = len(expected) - stop_at
|
||||||
|
output.append(f" (...{omitted} following responses omitted)")
|
||||||
|
|
||||||
|
output.append("")
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class TrezorClientDebugLink(TrezorClient):
|
class TrezorClientDebugLink(TrezorClient):
|
||||||
# This class implements automatic responses
|
# This class implements automatic responses
|
||||||
# and other functionality for unit tests
|
# and other functionality for unit tests
|
||||||
@ -1034,54 +1291,165 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
# set transport explicitly so that sync_responses can work
|
# set transport explicitly so that sync_responses can work
|
||||||
|
super().__init__(transport)
|
||||||
|
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
|
self.ui: DebugUI = DebugUI(self.debug)
|
||||||
|
|
||||||
self.reset_debug_features()
|
self.reset_debug_features(new_management_session=True)
|
||||||
self.sync_responses()
|
self.sync_responses()
|
||||||
super().__init__(transport, ui=self.ui)
|
|
||||||
|
|
||||||
# So that we can choose right screenshotting logic (T1 vs TT)
|
# So that we can choose right screenshotting logic (T1 vs TT)
|
||||||
# and know the supported debug capabilities
|
# and know the supported debug capabilities
|
||||||
self.debug.model = self.model
|
self.debug.model = self.model
|
||||||
self.debug.version = self.version
|
self.debug.version = self.version
|
||||||
|
self.passphrase: str | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layout_type(self) -> LayoutType:
|
def layout_type(self) -> LayoutType:
|
||||||
return self.debug.layout_type
|
return self.debug.layout_type
|
||||||
|
|
||||||
def reset_debug_features(self) -> None:
|
def get_new_client(self) -> TrezorClientDebugLink:
|
||||||
"""Prepare the debugging client for a new testcase.
|
return TrezorClientDebugLink(self.transport, self.debug.allow_interactions)
|
||||||
|
|
||||||
|
def reset_debug_features(self, new_management_session: bool = False) -> None:
|
||||||
|
"""
|
||||||
|
Prepare the debugging client for a new testcase.
|
||||||
|
|
||||||
Clears all debugging state that might have been modified by a testcase.
|
Clears all debugging state that might have been modified by a testcase.
|
||||||
"""
|
"""
|
||||||
self.ui: DebugUI = DebugUI(self.debug)
|
self.ui: DebugUI = DebugUI(self.debug)
|
||||||
|
# self.pin_callback = self.ui.debug_callback_button
|
||||||
self.in_with_statement = False
|
self.in_with_statement = False
|
||||||
self.expected_responses: list[MessageFilter] | None = None
|
self.expected_responses: list[MessageFilter] | None = None
|
||||||
self.actual_responses: list[protobuf.MessageType] | None = None
|
self.actual_responses: list[protobuf.MessageType] | None = None
|
||||||
self.filters: dict[
|
self.filters: t.Dict[
|
||||||
type[protobuf.MessageType],
|
t.Type[protobuf.MessageType],
|
||||||
Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||||
] = {}
|
] = {}
|
||||||
|
if new_management_session:
|
||||||
|
self._management_session = self.get_management_session(new_session=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def button_callback(self):
|
||||||
|
|
||||||
|
def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any:
|
||||||
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
# do this raw - send ButtonAck first, notify UI later
|
||||||
|
session._write(messages.ButtonAck())
|
||||||
|
self.ui.button_request(msg)
|
||||||
|
return session._read()
|
||||||
|
|
||||||
|
return _callback_button
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pin_callback(self):
|
||||||
|
|
||||||
|
def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any:
|
||||||
|
try:
|
||||||
|
pin = self.ui.get_pin(msg.type)
|
||||||
|
except Cancelled:
|
||||||
|
session.call_raw(messages.Cancel())
|
||||||
|
raise
|
||||||
|
|
||||||
|
if any(d not in "123456789" for d in pin) or not (
|
||||||
|
1 <= len(pin) <= MAX_PIN_LENGTH
|
||||||
|
):
|
||||||
|
session.call_raw(messages.Cancel())
|
||||||
|
raise ValueError("Invalid PIN provided")
|
||||||
|
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
|
||||||
|
if isinstance(resp, messages.Failure) and resp.code in (
|
||||||
|
messages.FailureType.PinInvalid,
|
||||||
|
messages.FailureType.PinCancelled,
|
||||||
|
messages.FailureType.PinExpected,
|
||||||
|
):
|
||||||
|
raise PinException(resp.code, resp.message)
|
||||||
|
else:
|
||||||
|
return resp
|
||||||
|
|
||||||
|
return _callback_pin
|
||||||
|
|
||||||
|
@property
|
||||||
|
def passphrase_callback(self):
|
||||||
|
def _callback_passphrase(
|
||||||
|
session: Session, msg: messages.PassphraseRequest
|
||||||
|
) -> t.Any:
|
||||||
|
available_on_device = (
|
||||||
|
Capability.PassphraseEntry in session.features.capabilities
|
||||||
|
)
|
||||||
|
|
||||||
|
def send_passphrase(
|
||||||
|
passphrase: str | None = None, on_device: bool | None = None
|
||||||
|
) -> t.Any:
|
||||||
|
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
||||||
|
resp = session.call_raw(msg)
|
||||||
|
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
||||||
|
# session.session_id = resp.state
|
||||||
|
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
|
||||||
|
return resp
|
||||||
|
|
||||||
|
# short-circuit old style entry
|
||||||
|
if msg._on_device is True:
|
||||||
|
return send_passphrase(None, None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if session.passphrase is None and isinstance(session, SessionV1):
|
||||||
|
passphrase = self.ui.get_passphrase(
|
||||||
|
available_on_device=available_on_device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
passphrase = session.passphrase
|
||||||
|
except Cancelled:
|
||||||
|
session.call_raw(messages.Cancel())
|
||||||
|
raise
|
||||||
|
|
||||||
|
if passphrase is PASSPHRASE_ON_DEVICE:
|
||||||
|
if not available_on_device:
|
||||||
|
session.call_raw(messages.Cancel())
|
||||||
|
raise RuntimeError("Device is not capable of entering passphrase")
|
||||||
|
else:
|
||||||
|
return send_passphrase(on_device=True)
|
||||||
|
|
||||||
|
# else process host-entered passphrase
|
||||||
|
if not isinstance(passphrase, str):
|
||||||
|
raise RuntimeError("Passphrase must be a str")
|
||||||
|
passphrase = Mnemonic.normalize_string(passphrase)
|
||||||
|
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
|
||||||
|
session.call_raw(messages.Cancel())
|
||||||
|
raise ValueError("Passphrase too long")
|
||||||
|
|
||||||
|
return send_passphrase(passphrase, on_device=False)
|
||||||
|
|
||||||
|
return _callback_passphrase
|
||||||
|
|
||||||
def ensure_open(self) -> None:
|
def ensure_open(self) -> None:
|
||||||
"""Only open session if there isn't already an open one."""
|
"""Only open session if there isn't already an open one."""
|
||||||
if self.session_counter == 0:
|
# if self.session_counter == 0:
|
||||||
self.open()
|
# self.open()
|
||||||
|
# TODO check if is this needed
|
||||||
|
|
||||||
def open(self) -> None:
|
def open(self) -> None:
|
||||||
super().open()
|
pass
|
||||||
if self.session_counter == 1:
|
# TODO is this needed?
|
||||||
self.debug.open()
|
# self.debug.open()
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
if self.session_counter == 1:
|
pass
|
||||||
self.debug.close()
|
# TODO is this needed?
|
||||||
super().close()
|
# self.debug.close()
|
||||||
|
|
||||||
|
def get_session(
|
||||||
|
self,
|
||||||
|
passphrase: str | object | None = "",
|
||||||
|
derive_cardano: bool = False,
|
||||||
|
) -> Session:
|
||||||
|
if isinstance(passphrase, str):
|
||||||
|
passphrase = Mnemonic.normalize_string(passphrase)
|
||||||
|
return super().get_session(passphrase, derive_cardano)
|
||||||
|
|
||||||
def set_filter(
|
def set_filter(
|
||||||
self,
|
self,
|
||||||
message_type: type[protobuf.MessageType],
|
message_type: t.Type[protobuf.MessageType],
|
||||||
callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Configure a filter function for a specified message type.
|
"""Configure a filter function for a specified message type.
|
||||||
|
|
||||||
@ -1106,7 +1474,8 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
def set_input_flow(
|
def set_input_flow(
|
||||||
self, input_flow: Generator[None, messages.ButtonRequest | None, None]
|
self,
|
||||||
|
input_flow: t.Generator[None, messages.ButtonRequest | None, None],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Configure a sequence of input events for the current with-block.
|
"""Configure a sequence of input events for the current with-block.
|
||||||
|
|
||||||
@ -1140,6 +1509,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
if not hasattr(input_flow, "send"):
|
if not hasattr(input_flow, "send"):
|
||||||
raise RuntimeError("input_flow should be a generator function")
|
raise RuntimeError("input_flow should be a generator function")
|
||||||
self.ui.input_flow = input_flow
|
self.ui.input_flow = input_flow
|
||||||
|
assert input_flow is not None
|
||||||
input_flow.send(None) # start the generator
|
input_flow.send(None) # start the generator
|
||||||
|
|
||||||
def watch_layout(self, watch: bool = True) -> None:
|
def watch_layout(self, watch: bool = True) -> None:
|
||||||
@ -1162,7 +1532,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
self.in_with_statement = True
|
self.in_with_statement = True
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None:
|
def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
|
||||||
# copy expected/actual responses before clearing them
|
# copy expected/actual responses before clearing them
|
||||||
@ -1175,20 +1545,21 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
else:
|
else:
|
||||||
input_flow = None
|
input_flow = None
|
||||||
|
|
||||||
self.reset_debug_features()
|
self.reset_debug_features(new_management_session=False)
|
||||||
|
|
||||||
if exc_type is None:
|
if exc_type is None:
|
||||||
# If no other exception was raised, evaluate missed responses
|
# If no other exception was raised, evaluate missed responses
|
||||||
# (raises AssertionError on mismatch)
|
# (raises AssertionError on mismatch)
|
||||||
self._verify_responses(expected_responses, actual_responses)
|
self._verify_responses(expected_responses, actual_responses)
|
||||||
|
|
||||||
elif isinstance(input_flow, Generator):
|
elif isinstance(input_flow, t.Generator):
|
||||||
# Propagate the exception through the input flow, so that we see in
|
# Propagate the exception through the input flow, so that we see in
|
||||||
# traceback where it is stuck.
|
# traceback where it is stuck.
|
||||||
input_flow.throw(exc_type, value, traceback)
|
input_flow.throw(exc_type, value, traceback)
|
||||||
|
|
||||||
def set_expected_responses(
|
def set_expected_responses(
|
||||||
self, expected: list[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]]
|
self,
|
||||||
|
expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set a sequence of expected responses to client calls.
|
"""Set a sequence of expected responses to client calls.
|
||||||
|
|
||||||
@ -1227,7 +1598,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
]
|
]
|
||||||
self.actual_responses = []
|
self.actual_responses = []
|
||||||
|
|
||||||
def use_pin_sequence(self, pins: Iterable[str]) -> None:
|
def use_pin_sequence(self, pins: t.Iterable[str]) -> None:
|
||||||
"""Respond to PIN prompts from device with the provided PINs.
|
"""Respond to PIN prompts from device with the provided PINs.
|
||||||
The sequence must be at least as long as the expected number of PIN prompts.
|
The sequence must be at least as long as the expected number of PIN prompts.
|
||||||
"""
|
"""
|
||||||
@ -1235,6 +1606,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
def use_passphrase(self, passphrase: str) -> None:
|
def use_passphrase(self, passphrase: str) -> None:
|
||||||
"""Respond to passphrase prompts from device with the provided passphrase."""
|
"""Respond to passphrase prompts from device with the provided passphrase."""
|
||||||
|
self.passphrase = passphrase
|
||||||
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
|
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
|
||||||
|
|
||||||
def use_mnemonic(self, mnemonic: str) -> None:
|
def use_mnemonic(self, mnemonic: str) -> None:
|
||||||
@ -1244,15 +1616,14 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
def _raw_read(self) -> protobuf.MessageType:
|
def _raw_read(self) -> protobuf.MessageType:
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
resp = self.get_management_session()._read()
|
||||||
resp = super()._raw_read()
|
|
||||||
resp = self._filter_message(resp)
|
resp = self._filter_message(resp)
|
||||||
if self.actual_responses is not None:
|
if self.actual_responses is not None:
|
||||||
self.actual_responses.append(resp)
|
self.actual_responses.append(resp)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
def _raw_write(self, msg: protobuf.MessageType) -> None:
|
def _raw_write(self, msg: protobuf.MessageType) -> None:
|
||||||
return super()._raw_write(self._filter_message(msg))
|
return self.get_management_session()._write(self._filter_message(msg))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
|
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
|
||||||
@ -1322,23 +1693,25 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
# Start by canceling whatever is on screen. This will work to cancel T1 PIN
|
# Start by canceling whatever is on screen. This will work to cancel T1 PIN
|
||||||
# prompt, which is in TINY mode and does not respond to `Ping`.
|
# prompt, which is in TINY mode and does not respond to `Ping`.
|
||||||
cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel())
|
# TODO REMOVE: cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel())
|
||||||
self.transport.begin_session()
|
self.transport.open()
|
||||||
try:
|
try:
|
||||||
self.transport.write(*cancel_msg)
|
# self.protocol.write(messages.Cancel())
|
||||||
|
|
||||||
message = "SYNC" + secrets.token_hex(8)
|
message = "SYNC" + secrets.token_hex(8)
|
||||||
ping_msg = mapping.DEFAULT_MAPPING.encode(messages.Ping(message=message))
|
self.get_management_session()._write(messages.Ping(message=message))
|
||||||
self.transport.write(*ping_msg)
|
|
||||||
resp = None
|
resp = None
|
||||||
while resp != messages.Success(message=message):
|
while resp != messages.Success(message=message):
|
||||||
msg_id, msg_bytes = self.transport.read()
|
|
||||||
try:
|
try:
|
||||||
resp = mapping.DEFAULT_MAPPING.decode(msg_id, msg_bytes)
|
resp = self.get_management_session()._read()
|
||||||
|
|
||||||
|
raise Exception
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
self.transport.end_session()
|
pass # TODO fix
|
||||||
|
# self.transport.end_session(self.session_id or b"")
|
||||||
|
|
||||||
def mnemonic_callback(self, _) -> str:
|
def mnemonic_callback(self, _) -> str:
|
||||||
word, pos = self.debug.read_recovery_word()
|
word, pos = self.debug.read_recovery_word()
|
||||||
@ -1352,8 +1725,8 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def load_device(
|
def load_device(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
mnemonic: Union[str, Iterable[str]],
|
mnemonic: str | t.Iterable[str],
|
||||||
pin: str | None,
|
pin: str | None,
|
||||||
passphrase_protection: bool,
|
passphrase_protection: bool,
|
||||||
label: str | None,
|
label: str | None,
|
||||||
@ -1366,12 +1739,12 @@ def load_device(
|
|||||||
|
|
||||||
mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
|
mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
|
||||||
|
|
||||||
if client.features.initialized:
|
if session.features.initialized:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Device is initialized already. Call device.wipe() and try again."
|
"Device is initialized already. Call device.wipe() and try again."
|
||||||
)
|
)
|
||||||
|
|
||||||
resp = client.call(
|
resp = session.call(
|
||||||
messages.LoadDevice(
|
messages.LoadDevice(
|
||||||
mnemonics=mnemonics,
|
mnemonics=mnemonics,
|
||||||
pin=pin,
|
pin=pin,
|
||||||
@ -1382,7 +1755,7 @@ def load_device(
|
|||||||
no_backup=no_backup,
|
no_backup=no_backup,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
client.init_device()
|
session.refresh_features()
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@ -1391,11 +1764,11 @@ load_device_by_mnemonic = load_device
|
|||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType:
|
def prodtest_t1(session: "Session") -> protobuf.MessageType:
|
||||||
if client.features.bootloader_mode is not True:
|
if session.features.bootloader_mode is not True:
|
||||||
raise RuntimeError("Device must be in bootloader mode")
|
raise RuntimeError("Device must be in bootloader mode")
|
||||||
|
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.ProdTestT1(
|
messages.ProdTestT1(
|
||||||
payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
|
payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
|
||||||
)
|
)
|
||||||
@ -1404,8 +1777,8 @@ def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType:
|
|||||||
|
|
||||||
def record_screen(
|
def record_screen(
|
||||||
debug_client: "TrezorClientDebugLink",
|
debug_client: "TrezorClientDebugLink",
|
||||||
directory: Union[str, None],
|
directory: str | None,
|
||||||
report_func: Union[Callable[[str], None], None] = None,
|
report_func: t.Callable[[str], None] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Record screen changes into a specified directory.
|
"""Record screen changes into a specified directory.
|
||||||
|
|
||||||
@ -1451,5 +1824,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool:
|
|||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def optiga_set_sec_max(client: "TrezorClient") -> protobuf.MessageType:
|
def optiga_set_sec_max(session: "Session") -> protobuf.MessageType:
|
||||||
return client.call(messages.DebugLinkOptigaSetSecMax())
|
return session.call(messages.DebugLinkOptigaSetSecMax())
|
||||||
|
@ -27,20 +27,19 @@ from slip10 import SLIP10
|
|||||||
|
|
||||||
from . import messages
|
from . import messages
|
||||||
from .exceptions import Cancelled, TrezorException
|
from .exceptions import Cancelled, TrezorException
|
||||||
from .tools import Address, expect, parse_path, session
|
from .tools import Address, expect, parse_path
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
RECOVERY_BACK = "\x08" # backspace character, sent literally
|
RECOVERY_BACK = "\x08" # backspace character, sent literally
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
|
||||||
def apply_settings(
|
def apply_settings(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
label: Optional[str] = None,
|
label: Optional[str] = None,
|
||||||
language: Optional[str] = None,
|
language: Optional[str] = None,
|
||||||
use_passphrase: Optional[bool] = None,
|
use_passphrase: Optional[bool] = None,
|
||||||
@ -71,13 +70,13 @@ def apply_settings(
|
|||||||
haptic_feedback=haptic_feedback,
|
haptic_feedback=haptic_feedback,
|
||||||
)
|
)
|
||||||
|
|
||||||
out = client.call(settings)
|
out = session.call(settings)
|
||||||
client.refresh_features()
|
session.refresh_features()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _send_language_data(
|
def _send_language_data(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
request: "messages.TranslationDataRequest",
|
request: "messages.TranslationDataRequest",
|
||||||
language_data: bytes,
|
language_data: bytes,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
@ -87,76 +86,70 @@ def _send_language_data(
|
|||||||
data_length = response.data_length
|
data_length = response.data_length
|
||||||
data_offset = response.data_offset
|
data_offset = response.data_offset
|
||||||
chunk = language_data[data_offset : data_offset + data_length]
|
chunk = language_data[data_offset : data_offset + data_length]
|
||||||
response = client.call(messages.TranslationDataAck(data_chunk=chunk))
|
response = session.call(messages.TranslationDataAck(data_chunk=chunk))
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
|
||||||
def change_language(
|
def change_language(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
language_data: bytes,
|
language_data: bytes,
|
||||||
show_display: bool | None = None,
|
show_display: bool | None = None,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
data_length = len(language_data)
|
data_length = len(language_data)
|
||||||
msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)
|
msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)
|
||||||
|
|
||||||
response = client.call(msg)
|
response = session.call(msg)
|
||||||
if data_length > 0:
|
if data_length > 0:
|
||||||
assert isinstance(response, messages.TranslationDataRequest)
|
assert isinstance(response, messages.TranslationDataRequest)
|
||||||
response = _send_language_data(client, response, language_data)
|
response = _send_language_data(session, response, language_data)
|
||||||
assert isinstance(response, messages.Success)
|
assert isinstance(response, messages.Success)
|
||||||
client.refresh_features() # changing the language in features
|
session.refresh_features() # changing the language in features
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
def apply_flags(session: "Session", flags: int) -> "MessageType":
|
||||||
def apply_flags(client: "TrezorClient", flags: int) -> "MessageType":
|
out = session.call(messages.ApplyFlags(flags=flags))
|
||||||
out = client.call(messages.ApplyFlags(flags=flags))
|
session.refresh_features()
|
||||||
client.refresh_features()
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
def change_pin(session: "Session", remove: bool = False) -> "MessageType":
|
||||||
def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType":
|
ret = session.call(messages.ChangePin(remove=remove))
|
||||||
ret = client.call(messages.ChangePin(remove=remove))
|
session.refresh_features()
|
||||||
client.refresh_features()
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
def change_wipe_code(session: "Session", remove: bool = False) -> "MessageType":
|
||||||
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType":
|
ret = session.call(messages.ChangeWipeCode(remove=remove))
|
||||||
ret = client.call(messages.ChangeWipeCode(remove=remove))
|
session.refresh_features()
|
||||||
client.refresh_features()
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
|
||||||
def sd_protect(
|
def sd_protect(
|
||||||
client: "TrezorClient", operation: messages.SdProtectOperationType
|
session: "Session", operation: messages.SdProtectOperationType
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
ret = client.call(messages.SdProtect(operation=operation))
|
ret = session.call(messages.SdProtect(operation=operation))
|
||||||
client.refresh_features()
|
session.refresh_features()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
def wipe(session: "Session") -> "MessageType":
|
||||||
def wipe(client: "TrezorClient") -> "MessageType":
|
|
||||||
ret = client.call(messages.WipeDevice())
|
ret = session.call(messages.WipeDevice())
|
||||||
if not client.features.bootloader_mode:
|
# if not session.features.bootloader_mode:
|
||||||
client.init_device()
|
# session.refresh_features()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def recover(
|
def recover(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
word_count: int = 24,
|
word_count: int = 24,
|
||||||
passphrase_protection: bool = False,
|
passphrase_protection: bool = False,
|
||||||
pin_protection: bool = True,
|
pin_protection: bool = True,
|
||||||
@ -192,13 +185,13 @@ def recover(
|
|||||||
if type is None:
|
if type is None:
|
||||||
type = messages.RecoveryType.NormalRecovery
|
type = messages.RecoveryType.NormalRecovery
|
||||||
|
|
||||||
if client.features.model == "1" and input_callback is None:
|
if session.features.model == "1" and input_callback is None:
|
||||||
raise RuntimeError("Input callback required for Trezor One")
|
raise RuntimeError("Input callback required for Trezor One")
|
||||||
|
|
||||||
if word_count not in (12, 18, 24):
|
if word_count not in (12, 18, 24):
|
||||||
raise ValueError("Invalid word count. Use 12/18/24")
|
raise ValueError("Invalid word count. Use 12/18/24")
|
||||||
|
|
||||||
if client.features.initialized and type == messages.RecoveryType.NormalRecovery:
|
if session.features.initialized and type == messages.RecoveryType.NormalRecovery:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Device already initialized. Call device.wipe() and try again."
|
"Device already initialized. Call device.wipe() and try again."
|
||||||
)
|
)
|
||||||
@ -220,17 +213,17 @@ def recover(
|
|||||||
msg.label = label
|
msg.label = label
|
||||||
msg.u2f_counter = u2f_counter
|
msg.u2f_counter = u2f_counter
|
||||||
|
|
||||||
res = client.call(msg)
|
res = session.call(msg)
|
||||||
|
|
||||||
while isinstance(res, messages.WordRequest):
|
while isinstance(res, messages.WordRequest):
|
||||||
try:
|
try:
|
||||||
assert input_callback is not None
|
assert input_callback is not None
|
||||||
inp = input_callback(res.type)
|
inp = input_callback(res.type)
|
||||||
res = client.call(messages.WordAck(word=inp))
|
res = session.call(messages.WordAck(word=inp))
|
||||||
except Cancelled:
|
except Cancelled:
|
||||||
res = client.call(messages.Cancel())
|
res = session.call(messages.Cancel())
|
||||||
|
|
||||||
client.init_device()
|
session.refresh_features()
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
@ -279,9 +272,8 @@ def reset(*args: Any, **kwargs: Any) -> "MessageType":
|
|||||||
return reset_entropy_check(*args, **kwargs)[0]
|
return reset_entropy_check(*args, **kwargs)[0]
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def reset_entropy_check(
|
def reset_entropy_check(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
display_random: bool = False,
|
display_random: bool = False,
|
||||||
strength: Optional[int] = None,
|
strength: Optional[int] = None,
|
||||||
passphrase_protection: bool = False,
|
passphrase_protection: bool = False,
|
||||||
@ -307,13 +299,13 @@ def reset_entropy_check(
|
|||||||
DeprecationWarning,
|
DeprecationWarning,
|
||||||
)
|
)
|
||||||
|
|
||||||
if client.features.initialized:
|
if session.features.initialized:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Device is initialized already. Call wipe_device() and try again."
|
"Device is initialized already. Call wipe_device() and try again."
|
||||||
)
|
)
|
||||||
|
|
||||||
if strength is None:
|
if strength is None:
|
||||||
if client.features.model == "1":
|
if session.features.model == "1":
|
||||||
strength = 256
|
strength = 256
|
||||||
else:
|
else:
|
||||||
strength = 128
|
strength = 128
|
||||||
@ -335,7 +327,7 @@ def reset_entropy_check(
|
|||||||
entropy_check=entropy_check_count is not None,
|
entropy_check=entropy_check_count is not None,
|
||||||
)
|
)
|
||||||
|
|
||||||
resp = client.call(msg)
|
resp = session.call(msg)
|
||||||
if not isinstance(resp, messages.EntropyRequest):
|
if not isinstance(resp, messages.EntropyRequest):
|
||||||
raise RuntimeError("Invalid response, expected EntropyRequest")
|
raise RuntimeError("Invalid response, expected EntropyRequest")
|
||||||
|
|
||||||
@ -344,7 +336,7 @@ def reset_entropy_check(
|
|||||||
|
|
||||||
external_entropy = os.urandom(32)
|
external_entropy = os.urandom(32)
|
||||||
entropy_commitment = resp.entropy_commitment
|
entropy_commitment = resp.entropy_commitment
|
||||||
resp = client.call(messages.EntropyAck(entropy=external_entropy))
|
resp = session.call(messages.EntropyAck(entropy=external_entropy))
|
||||||
|
|
||||||
if entropy_check_count is None:
|
if entropy_check_count is None:
|
||||||
break
|
break
|
||||||
@ -353,18 +345,18 @@ def reset_entropy_check(
|
|||||||
return resp, []
|
return resp, []
|
||||||
|
|
||||||
for path in paths:
|
for path in paths:
|
||||||
resp = client.call(messages.GetPublicKey(address_n=path))
|
resp = session.call(messages.GetPublicKey(address_n=path))
|
||||||
if not isinstance(resp, messages.PublicKey):
|
if not isinstance(resp, messages.PublicKey):
|
||||||
return resp, []
|
return resp, []
|
||||||
xpubs.append(resp.xpub)
|
xpubs.append(resp.xpub)
|
||||||
|
|
||||||
if entropy_check_count <= 0:
|
if entropy_check_count <= 0:
|
||||||
resp = client.call(messages.EntropyCheckContinue(finish=True))
|
resp = session.call(messages.EntropyCheckContinue(finish=True))
|
||||||
break
|
break
|
||||||
|
|
||||||
entropy_check_count -= 1
|
entropy_check_count -= 1
|
||||||
|
|
||||||
resp = client.call(messages.EntropyCheckContinue(finish=False))
|
resp = session.call(messages.EntropyCheckContinue(finish=False))
|
||||||
if not isinstance(resp, messages.EntropyRequest):
|
if not isinstance(resp, messages.EntropyRequest):
|
||||||
raise RuntimeError("Invalid response, expected EntropyRequest")
|
raise RuntimeError("Invalid response, expected EntropyRequest")
|
||||||
|
|
||||||
@ -385,18 +377,17 @@ def reset_entropy_check(
|
|||||||
if slip10.get_xpub_from_path(path) != xpub:
|
if slip10.get_xpub_from_path(path) != xpub:
|
||||||
raise RuntimeError("Invalid XPUB in entropy check")
|
raise RuntimeError("Invalid XPUB in entropy check")
|
||||||
|
|
||||||
client.init_device()
|
session.refresh_features()
|
||||||
return resp, zip(paths, xpubs)
|
return resp, zip(paths, xpubs)
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
|
||||||
def backup(
|
def backup(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
group_threshold: Optional[int] = None,
|
group_threshold: Optional[int] = None,
|
||||||
groups: Iterable[tuple[int, int]] = (),
|
groups: Iterable[tuple[int, int]] = (),
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
ret = client.call(
|
ret = session.call(
|
||||||
messages.BackupDevice(
|
messages.BackupDevice(
|
||||||
group_threshold=group_threshold,
|
group_threshold=group_threshold,
|
||||||
groups=[
|
groups=[
|
||||||
@ -405,37 +396,36 @@ def backup(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
client.refresh_features()
|
session.refresh_features()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def cancel_authorization(client: "TrezorClient") -> "MessageType":
|
def cancel_authorization(session: "Session") -> "MessageType":
|
||||||
return client.call(messages.CancelAuthorization())
|
return session.call(messages.CancelAuthorization())
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.UnlockedPathRequest, field="mac", ret_type=bytes)
|
@expect(messages.UnlockedPathRequest, field="mac", ret_type=bytes)
|
||||||
def unlock_path(client: "TrezorClient", n: "Address") -> "MessageType":
|
def unlock_path(session: "Session", n: "Address") -> "MessageType":
|
||||||
resp = client.call(messages.UnlockPath(address_n=n))
|
resp = session.call(messages.UnlockPath(address_n=n))
|
||||||
|
|
||||||
# Cancel the UnlockPath workflow now that we have the authentication code.
|
# Cancel the UnlockPath workflow now that we have the authentication code.
|
||||||
try:
|
try:
|
||||||
client.call(messages.Cancel())
|
session.call(messages.Cancel())
|
||||||
except Cancelled:
|
except Cancelled:
|
||||||
return resp
|
return resp
|
||||||
else:
|
else:
|
||||||
raise TrezorException("Unexpected response in UnlockPath flow")
|
raise TrezorException("Unexpected response in UnlockPath flow")
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def reboot_to_bootloader(
|
def reboot_to_bootloader(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT,
|
boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT,
|
||||||
firmware_header: Optional[bytes] = None,
|
firmware_header: Optional[bytes] = None,
|
||||||
language_data: bytes = b"",
|
language_data: bytes = b"",
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
response = client.call(
|
response = session.call(
|
||||||
messages.RebootToBootloader(
|
messages.RebootToBootloader(
|
||||||
boot_command=boot_command,
|
boot_command=boot_command,
|
||||||
firmware_header=firmware_header,
|
firmware_header=firmware_header,
|
||||||
@ -443,42 +433,37 @@ def reboot_to_bootloader(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
if isinstance(response, messages.TranslationDataRequest):
|
if isinstance(response, messages.TranslationDataRequest):
|
||||||
response = _send_language_data(client, response, language_data)
|
response = _send_language_data(session, response, language_data)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def show_device_tutorial(client: "TrezorClient") -> "MessageType":
|
def show_device_tutorial(session: "Session") -> "MessageType":
|
||||||
return client.call(messages.ShowDeviceTutorial())
|
return session.call(messages.ShowDeviceTutorial())
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
|
||||||
def unlock_bootloader(client: "TrezorClient") -> "MessageType":
|
|
||||||
return client.call(messages.UnlockBootloader())
|
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
@session
|
def unlock_bootloader(session: "Session") -> "MessageType":
|
||||||
def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> "MessageType":
|
return session.call(messages.UnlockBootloader())
|
||||||
|
|
||||||
|
|
||||||
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
|
def set_busy(session: "Session", expiry_ms: Optional[int]) -> "MessageType":
|
||||||
"""Sets or clears the busy state of the device.
|
"""Sets or clears the busy state of the device.
|
||||||
|
|
||||||
In the busy state the device shows a "Do not disconnect" message instead of the homescreen.
|
In the busy state the device shows a "Do not disconnect" message instead of the homescreen.
|
||||||
Setting `expiry_ms=None` clears the busy state.
|
Setting `expiry_ms=None` clears the busy state.
|
||||||
"""
|
"""
|
||||||
ret = client.call(messages.SetBusy(expiry_ms=expiry_ms))
|
ret = session.call(messages.SetBusy(expiry_ms=expiry_ms))
|
||||||
client.refresh_features()
|
session.refresh_features()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.AuthenticityProof)
|
@expect(messages.AuthenticityProof)
|
||||||
def authenticate(client: "TrezorClient", challenge: bytes):
|
def authenticate(session: "Session", challenge: bytes):
|
||||||
return client.call(messages.AuthenticateDevice(challenge=challenge))
|
return session.call(messages.AuthenticateDevice(challenge=challenge))
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def set_brightness(
|
def set_brightness(session: "Session", value: Optional[int] = None) -> "MessageType":
|
||||||
client: "TrezorClient", value: Optional[int] = None
|
return session.call(messages.SetBrightness(value=value))
|
||||||
) -> "MessageType":
|
|
||||||
return client.call(messages.SetBrightness(value=value))
|
|
||||||
|
@ -18,12 +18,12 @@ from datetime import datetime
|
|||||||
from typing import TYPE_CHECKING, List, Tuple
|
from typing import TYPE_CHECKING, List, Tuple
|
||||||
|
|
||||||
from . import exceptions, messages
|
from . import exceptions, messages
|
||||||
from .tools import b58decode, expect, session
|
from .tools import b58decode, expect
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
def name_to_number(name: str) -> int:
|
def name_to_number(name: str) -> int:
|
||||||
@ -321,17 +321,16 @@ def parse_transaction_json(
|
|||||||
|
|
||||||
@expect(messages.EosPublicKey)
|
@expect(messages.EosPublicKey)
|
||||||
def get_public_key(
|
def get_public_key(
|
||||||
client: "TrezorClient", n: "Address", show_display: bool = False
|
session: "Session", n: "Address", show_display: bool = False
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
response = client.call(
|
response = session.call(
|
||||||
messages.EosGetPublicKey(address_n=n, show_display=show_display)
|
messages.EosGetPublicKey(address_n=n, show_display=show_display)
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address: "Address",
|
address: "Address",
|
||||||
transaction: dict,
|
transaction: dict,
|
||||||
chain_id: str,
|
chain_id: str,
|
||||||
@ -347,11 +346,11 @@ def sign_tx(
|
|||||||
chunkify=chunkify,
|
chunkify=chunkify,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.call(msg)
|
response = session.call(msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while isinstance(response, messages.EosTxActionRequest):
|
while isinstance(response, messages.EosTxActionRequest):
|
||||||
response = client.call(actions.pop(0))
|
response = session.call(actions.pop(0))
|
||||||
except IndexError:
|
except IndexError:
|
||||||
# pop from empty list
|
# pop from empty list
|
||||||
raise exceptions.TrezorException(
|
raise exceptions.TrezorException(
|
||||||
|
@ -18,12 +18,12 @@ import re
|
|||||||
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from . import definitions, exceptions, messages
|
from . import definitions, exceptions, messages
|
||||||
from .tools import expect, prepare_message_bytes, session, unharden
|
from .tools import expect, prepare_message_bytes, unharden
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
def int_to_big_endian(value: int) -> bytes:
|
def int_to_big_endian(value: int) -> bytes:
|
||||||
@ -163,13 +163,13 @@ def network_from_address_n(
|
|||||||
|
|
||||||
@expect(messages.EthereumAddress, field="address", ret_type=str)
|
@expect(messages.EthereumAddress, field="address", ret_type=str)
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
encoded_network: Optional[bytes] = None,
|
encoded_network: Optional[bytes] = None,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.EthereumGetAddress(
|
messages.EthereumGetAddress(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
show_display=show_display,
|
show_display=show_display,
|
||||||
@ -181,16 +181,15 @@ def get_address(
|
|||||||
|
|
||||||
@expect(messages.EthereumPublicKey)
|
@expect(messages.EthereumPublicKey)
|
||||||
def get_public_node(
|
def get_public_node(
|
||||||
client: "TrezorClient", n: "Address", show_display: bool = False
|
session: "Session", n: "Address", show_display: bool = False
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.EthereumGetPublicKey(address_n=n, show_display=show_display)
|
messages.EthereumGetPublicKey(address_n=n, show_display=show_display)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
nonce: int,
|
nonce: int,
|
||||||
gas_price: int,
|
gas_price: int,
|
||||||
@ -226,13 +225,13 @@ def sign_tx(
|
|||||||
data, chunk = data[1024:], data[:1024]
|
data, chunk = data[1024:], data[:1024]
|
||||||
msg.data_initial_chunk = chunk
|
msg.data_initial_chunk = chunk
|
||||||
|
|
||||||
response = client.call(msg)
|
response = session.call(msg)
|
||||||
assert isinstance(response, messages.EthereumTxRequest)
|
assert isinstance(response, messages.EthereumTxRequest)
|
||||||
|
|
||||||
while response.data_length is not None:
|
while response.data_length is not None:
|
||||||
data_length = response.data_length
|
data_length = response.data_length
|
||||||
data, chunk = data[data_length:], data[:data_length]
|
data, chunk = data[data_length:], data[:data_length]
|
||||||
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
|
response = session.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||||
assert isinstance(response, messages.EthereumTxRequest)
|
assert isinstance(response, messages.EthereumTxRequest)
|
||||||
|
|
||||||
assert response.signature_v is not None
|
assert response.signature_v is not None
|
||||||
@ -247,9 +246,8 @@ def sign_tx(
|
|||||||
return response.signature_v, response.signature_r, response.signature_s
|
return response.signature_v, response.signature_r, response.signature_s
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def sign_tx_eip1559(
|
def sign_tx_eip1559(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
*,
|
*,
|
||||||
nonce: int,
|
nonce: int,
|
||||||
@ -282,13 +280,13 @@ def sign_tx_eip1559(
|
|||||||
chunkify=chunkify,
|
chunkify=chunkify,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.call(msg)
|
response = session.call(msg)
|
||||||
assert isinstance(response, messages.EthereumTxRequest)
|
assert isinstance(response, messages.EthereumTxRequest)
|
||||||
|
|
||||||
while response.data_length is not None:
|
while response.data_length is not None:
|
||||||
data_length = response.data_length
|
data_length = response.data_length
|
||||||
data, chunk = data[data_length:], data[:data_length]
|
data, chunk = data[data_length:], data[:data_length]
|
||||||
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
|
response = session.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||||
assert isinstance(response, messages.EthereumTxRequest)
|
assert isinstance(response, messages.EthereumTxRequest)
|
||||||
|
|
||||||
assert response.signature_v is not None
|
assert response.signature_v is not None
|
||||||
@ -299,13 +297,13 @@ def sign_tx_eip1559(
|
|||||||
|
|
||||||
@expect(messages.EthereumMessageSignature)
|
@expect(messages.EthereumMessageSignature)
|
||||||
def sign_message(
|
def sign_message(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
message: AnyStr,
|
message: AnyStr,
|
||||||
encoded_network: Optional[bytes] = None,
|
encoded_network: Optional[bytes] = None,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.EthereumSignMessage(
|
messages.EthereumSignMessage(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
message=prepare_message_bytes(message),
|
message=prepare_message_bytes(message),
|
||||||
@ -317,7 +315,7 @@ def sign_message(
|
|||||||
|
|
||||||
@expect(messages.EthereumTypedDataSignature)
|
@expect(messages.EthereumTypedDataSignature)
|
||||||
def sign_typed_data(
|
def sign_typed_data(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
data: Dict[str, Any],
|
data: Dict[str, Any],
|
||||||
*,
|
*,
|
||||||
@ -333,7 +331,7 @@ def sign_typed_data(
|
|||||||
metamask_v4_compat=metamask_v4_compat,
|
metamask_v4_compat=metamask_v4_compat,
|
||||||
definitions=definitions,
|
definitions=definitions,
|
||||||
)
|
)
|
||||||
response = client.call(request)
|
response = session.call(request)
|
||||||
|
|
||||||
# Sending all the types
|
# Sending all the types
|
||||||
while isinstance(response, messages.EthereumTypedDataStructRequest):
|
while isinstance(response, messages.EthereumTypedDataStructRequest):
|
||||||
@ -349,7 +347,7 @@ def sign_typed_data(
|
|||||||
members.append(struct_member)
|
members.append(struct_member)
|
||||||
|
|
||||||
request = messages.EthereumTypedDataStructAck(members=members)
|
request = messages.EthereumTypedDataStructAck(members=members)
|
||||||
response = client.call(request)
|
response = session.call(request)
|
||||||
|
|
||||||
# Sending the whole message that should be signed
|
# Sending the whole message that should be signed
|
||||||
while isinstance(response, messages.EthereumTypedDataValueRequest):
|
while isinstance(response, messages.EthereumTypedDataValueRequest):
|
||||||
@ -362,7 +360,7 @@ def sign_typed_data(
|
|||||||
member_typename = data["primaryType"]
|
member_typename = data["primaryType"]
|
||||||
member_data = data["message"]
|
member_data = data["message"]
|
||||||
else:
|
else:
|
||||||
client.cancel()
|
# TODO session.cancel()
|
||||||
raise exceptions.TrezorException("Root index can only be 0 or 1")
|
raise exceptions.TrezorException("Root index can only be 0 or 1")
|
||||||
|
|
||||||
# It can be asking for a nested structure (the member path being [X, Y, Z, ...])
|
# It can be asking for a nested structure (the member path being [X, Y, Z, ...])
|
||||||
@ -385,20 +383,20 @@ def sign_typed_data(
|
|||||||
encoded_data = encode_data(member_data, member_typename)
|
encoded_data = encode_data(member_data, member_typename)
|
||||||
|
|
||||||
request = messages.EthereumTypedDataValueAck(value=encoded_data)
|
request = messages.EthereumTypedDataValueAck(value=encoded_data)
|
||||||
response = client.call(request)
|
response = session.call(request)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def verify_message(
|
def verify_message(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address: str,
|
address: str,
|
||||||
signature: bytes,
|
signature: bytes,
|
||||||
message: AnyStr,
|
message: AnyStr,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
try:
|
try:
|
||||||
resp = client.call(
|
resp = session.call(
|
||||||
messages.EthereumVerifyMessage(
|
messages.EthereumVerifyMessage(
|
||||||
address=address,
|
address=address,
|
||||||
signature=signature,
|
signature=signature,
|
||||||
@ -413,13 +411,13 @@ def verify_message(
|
|||||||
|
|
||||||
@expect(messages.EthereumTypedDataSignature)
|
@expect(messages.EthereumTypedDataSignature)
|
||||||
def sign_typed_data_hash(
|
def sign_typed_data_hash(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
domain_hash: bytes,
|
domain_hash: bytes,
|
||||||
message_hash: Optional[bytes],
|
message_hash: Optional[bytes],
|
||||||
encoded_network: Optional[bytes] = None,
|
encoded_network: Optional[bytes] = None,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.EthereumSignTypedHash(
|
messages.EthereumSignTypedHash(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
domain_separator_hash=domain_hash,
|
domain_separator_hash=domain_hash,
|
||||||
|
@ -20,8 +20,8 @@ from . import messages
|
|||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
@expect(
|
@expect(
|
||||||
@ -29,27 +29,27 @@ if TYPE_CHECKING:
|
|||||||
field="credentials",
|
field="credentials",
|
||||||
ret_type=List[messages.WebAuthnCredential],
|
ret_type=List[messages.WebAuthnCredential],
|
||||||
)
|
)
|
||||||
def list_credentials(client: "TrezorClient") -> "MessageType":
|
def list_credentials(session: "Session") -> "MessageType":
|
||||||
return client.call(messages.WebAuthnListResidentCredentials())
|
return session.call(messages.WebAuthnListResidentCredentials())
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType":
|
def add_credential(session: "Session", credential_id: bytes) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.WebAuthnAddResidentCredential(credential_id=credential_id)
|
messages.WebAuthnAddResidentCredential(credential_id=credential_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def remove_credential(client: "TrezorClient", index: int) -> "MessageType":
|
def remove_credential(session: "Session", index: int) -> "MessageType":
|
||||||
return client.call(messages.WebAuthnRemoveResidentCredential(index=index))
|
return session.call(messages.WebAuthnRemoveResidentCredential(index=index))
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Success, field="message", ret_type=str)
|
@expect(messages.Success, field="message", ret_type=str)
|
||||||
def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType":
|
def set_counter(session: "Session", u2f_counter: int) -> "MessageType":
|
||||||
return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter))
|
return session.call(messages.SetU2FCounter(u2f_counter=u2f_counter))
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int)
|
@expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int)
|
||||||
def get_next_counter(client: "TrezorClient") -> "MessageType":
|
def get_next_counter(session: "Session") -> "MessageType":
|
||||||
return client.call(messages.GetNextU2FCounter())
|
return session.call(messages.GetNextU2FCounter())
|
||||||
|
@ -20,7 +20,7 @@ from hashlib import blake2s
|
|||||||
from typing_extensions import Protocol, TypeGuard
|
from typing_extensions import Protocol, TypeGuard
|
||||||
|
|
||||||
from .. import messages
|
from .. import messages
|
||||||
from ..tools import expect, session
|
from ..tools import expect
|
||||||
from .core import VendorFirmware
|
from .core import VendorFirmware
|
||||||
from .legacy import LegacyFirmware, LegacyV2Firmware
|
from .legacy import LegacyFirmware, LegacyV2Firmware
|
||||||
|
|
||||||
@ -38,7 +38,7 @@ if True:
|
|||||||
from .vendor import * # noqa: F401, F403
|
from .vendor import * # noqa: F401, F403
|
||||||
|
|
||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
from ..client import TrezorClient
|
from ..transport.session import Session
|
||||||
|
|
||||||
T = t.TypeVar("T", bound="FirmwareType")
|
T = t.TypeVar("T", bound="FirmwareType")
|
||||||
|
|
||||||
@ -72,20 +72,19 @@ def is_onev2(fw: "FirmwareType") -> TypeGuard[LegacyFirmware]:
|
|||||||
# ====== Client functions ====== #
|
# ====== Client functions ====== #
|
||||||
|
|
||||||
|
|
||||||
@session
|
|
||||||
def update(
|
def update(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
data: bytes,
|
data: bytes,
|
||||||
progress_update: t.Callable[[int], t.Any] = lambda _: None,
|
progress_update: t.Callable[[int], t.Any] = lambda _: None,
|
||||||
):
|
):
|
||||||
if client.features.bootloader_mode is False:
|
if session.features.bootloader_mode is False:
|
||||||
raise RuntimeError("Device must be in bootloader mode")
|
raise RuntimeError("Device must be in bootloader mode")
|
||||||
|
|
||||||
resp = client.call(messages.FirmwareErase(length=len(data)))
|
resp = session.call(messages.FirmwareErase(length=len(data)))
|
||||||
|
|
||||||
# TREZORv1 method
|
# TREZORv1 method
|
||||||
if isinstance(resp, messages.Success):
|
if isinstance(resp, messages.Success):
|
||||||
resp = client.call(messages.FirmwareUpload(payload=data))
|
resp = session.call(messages.FirmwareUpload(payload=data))
|
||||||
progress_update(len(data))
|
progress_update(len(data))
|
||||||
if isinstance(resp, messages.Success):
|
if isinstance(resp, messages.Success):
|
||||||
return
|
return
|
||||||
@ -97,7 +96,7 @@ def update(
|
|||||||
length = resp.length
|
length = resp.length
|
||||||
payload = data[resp.offset : resp.offset + length]
|
payload = data[resp.offset : resp.offset + length]
|
||||||
digest = blake2s(payload).digest()
|
digest = blake2s(payload).digest()
|
||||||
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest))
|
resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest))
|
||||||
progress_update(length)
|
progress_update(length)
|
||||||
|
|
||||||
if isinstance(resp, messages.Success):
|
if isinstance(resp, messages.Success):
|
||||||
@ -107,5 +106,5 @@ def update(
|
|||||||
|
|
||||||
|
|
||||||
@expect(messages.FirmwareHash, field="hash", ret_type=bytes)
|
@expect(messages.FirmwareHash, field="hash", ret_type=bytes)
|
||||||
def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]):
|
def get_hash(session: "Session", challenge: t.Optional[bytes]):
|
||||||
return client.call(messages.GetFirmwareHash(challenge=challenge))
|
return session.call(messages.GetFirmwareHash(challenge=challenge))
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import io
|
import io
|
||||||
|
import logging
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Dict, Optional, Tuple, Type, TypeVar
|
from typing import Dict, Optional, Tuple, Type, TypeVar
|
||||||
|
|
||||||
@ -25,6 +26,7 @@ from typing_extensions import Self
|
|||||||
from . import messages, protobuf
|
from . import messages, protobuf
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ProtobufMapping:
|
class ProtobufMapping:
|
||||||
@ -63,11 +65,21 @@ class ProtobufMapping:
|
|||||||
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
|
wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE)
|
||||||
if wire_type is None:
|
if wire_type is None:
|
||||||
raise ValueError("Cannot encode class without wire type")
|
raise ValueError("Cannot encode class without wire type")
|
||||||
|
LOG.debug("encoding wire type %d", wire_type)
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
protobuf.dump_message(buf, msg)
|
protobuf.dump_message(buf, msg)
|
||||||
return wire_type, buf.getvalue()
|
return wire_type, buf.getvalue()
|
||||||
|
|
||||||
|
def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes:
|
||||||
|
"""Serialize a Python protobuf class.
|
||||||
|
|
||||||
|
Returns the byte representation of the protobuf message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
buf = io.BytesIO()
|
||||||
|
protobuf.dump_message(buf, msg)
|
||||||
|
return buf.getvalue()
|
||||||
|
|
||||||
def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType:
|
def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType:
|
||||||
"""Deserialize a protobuf message into a Python class."""
|
"""Deserialize a protobuf message into a Python class."""
|
||||||
cls = self.type_to_class[msg_wire_type]
|
cls = self.type_to_class[msg_wire_type]
|
||||||
@ -83,7 +95,9 @@ class ProtobufMapping:
|
|||||||
mapping = cls()
|
mapping = cls()
|
||||||
|
|
||||||
message_types = getattr(module, "MessageType")
|
message_types = getattr(module, "MessageType")
|
||||||
for entry in message_types:
|
thp_message_types = getattr(module, "ThpMessageType")
|
||||||
|
|
||||||
|
for entry in (*message_types, *thp_message_types):
|
||||||
msg_class = getattr(module, entry.name, None)
|
msg_class = getattr(module, entry.name, None)
|
||||||
if msg_class is None:
|
if msg_class is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
316
python/src/trezorlib/messages.py
generated
316
python/src/trezorlib/messages.py
generated
@ -43,6 +43,8 @@ class FailureType(IntEnum):
|
|||||||
PinMismatch = 12
|
PinMismatch = 12
|
||||||
WipeCodeMismatch = 13
|
WipeCodeMismatch = 13
|
||||||
InvalidSession = 14
|
InvalidSession = 14
|
||||||
|
ThpUnallocatedSession = 15
|
||||||
|
InvalidProtocol = 16
|
||||||
FirmwareError = 99
|
FirmwareError = 99
|
||||||
|
|
||||||
|
|
||||||
@ -400,6 +402,34 @@ class TezosBallotType(IntEnum):
|
|||||||
Pass = 2
|
Pass = 2
|
||||||
|
|
||||||
|
|
||||||
|
class ThpMessageType(IntEnum):
|
||||||
|
ThpCreateNewSession = 1000
|
||||||
|
ThpNewSession = 1001
|
||||||
|
ThpStartPairingRequest = 1008
|
||||||
|
ThpPairingPreparationsFinished = 1009
|
||||||
|
ThpCredentialRequest = 1010
|
||||||
|
ThpCredentialResponse = 1011
|
||||||
|
ThpEndRequest = 1012
|
||||||
|
ThpEndResponse = 1013
|
||||||
|
ThpCodeEntryCommitment = 1016
|
||||||
|
ThpCodeEntryChallenge = 1017
|
||||||
|
ThpCodeEntryCpaceHost = 1018
|
||||||
|
ThpCodeEntryCpaceTrezor = 1019
|
||||||
|
ThpCodeEntryTag = 1020
|
||||||
|
ThpCodeEntrySecret = 1021
|
||||||
|
ThpQrCodeTag = 1024
|
||||||
|
ThpQrCodeSecret = 1025
|
||||||
|
ThpNfcUnidirectionalTag = 1032
|
||||||
|
ThpNfcUnidirectionalSecret = 1033
|
||||||
|
|
||||||
|
|
||||||
|
class ThpPairingMethod(IntEnum):
|
||||||
|
NoMethod = 1
|
||||||
|
CodeEntry = 2
|
||||||
|
QrCode = 3
|
||||||
|
NFC_Unidirectional = 4
|
||||||
|
|
||||||
|
|
||||||
class MessageType(IntEnum):
|
class MessageType(IntEnum):
|
||||||
Initialize = 0
|
Initialize = 0
|
||||||
Ping = 1
|
Ping = 1
|
||||||
@ -4136,6 +4166,7 @@ class DebugLinkGetState(protobuf.MessageType):
|
|||||||
1: protobuf.Field("wait_word_list", "bool", repeated=False, required=False, default=None),
|
1: protobuf.Field("wait_word_list", "bool", repeated=False, required=False, default=None),
|
||||||
2: protobuf.Field("wait_word_pos", "bool", repeated=False, required=False, default=None),
|
2: protobuf.Field("wait_word_pos", "bool", repeated=False, required=False, default=None),
|
||||||
3: protobuf.Field("wait_layout", "DebugWaitType", repeated=False, required=False, default=DebugWaitType.IMMEDIATE),
|
3: protobuf.Field("wait_layout", "DebugWaitType", repeated=False, required=False, default=DebugWaitType.IMMEDIATE),
|
||||||
|
4: protobuf.Field("thp_channel_id", "bytes", repeated=False, required=False, default=None),
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -4144,10 +4175,12 @@ class DebugLinkGetState(protobuf.MessageType):
|
|||||||
wait_word_list: Optional["bool"] = None,
|
wait_word_list: Optional["bool"] = None,
|
||||||
wait_word_pos: Optional["bool"] = None,
|
wait_word_pos: Optional["bool"] = None,
|
||||||
wait_layout: Optional["DebugWaitType"] = DebugWaitType.IMMEDIATE,
|
wait_layout: Optional["DebugWaitType"] = DebugWaitType.IMMEDIATE,
|
||||||
|
thp_channel_id: Optional["bytes"] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.wait_word_list = wait_word_list
|
self.wait_word_list = wait_word_list
|
||||||
self.wait_word_pos = wait_word_pos
|
self.wait_word_pos = wait_word_pos
|
||||||
self.wait_layout = wait_layout
|
self.wait_layout = wait_layout
|
||||||
|
self.thp_channel_id = thp_channel_id
|
||||||
|
|
||||||
|
|
||||||
class DebugLinkState(protobuf.MessageType):
|
class DebugLinkState(protobuf.MessageType):
|
||||||
@ -4166,6 +4199,9 @@ class DebugLinkState(protobuf.MessageType):
|
|||||||
11: protobuf.Field("reset_word_pos", "uint32", repeated=False, required=False, default=None),
|
11: protobuf.Field("reset_word_pos", "uint32", repeated=False, required=False, default=None),
|
||||||
12: protobuf.Field("mnemonic_type", "BackupType", repeated=False, required=False, default=None),
|
12: protobuf.Field("mnemonic_type", "BackupType", repeated=False, required=False, default=None),
|
||||||
13: protobuf.Field("tokens", "string", repeated=True, required=False, default=None),
|
13: protobuf.Field("tokens", "string", repeated=True, required=False, default=None),
|
||||||
|
14: protobuf.Field("thp_pairing_code_entry_code", "uint32", repeated=False, required=False, default=None),
|
||||||
|
15: protobuf.Field("thp_pairing_code_qr_code", "bytes", repeated=False, required=False, default=None),
|
||||||
|
16: protobuf.Field("thp_pairing_code_nfc_unidirectional", "bytes", repeated=False, required=False, default=None),
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -4184,6 +4220,9 @@ class DebugLinkState(protobuf.MessageType):
|
|||||||
recovery_word_pos: Optional["int"] = None,
|
recovery_word_pos: Optional["int"] = None,
|
||||||
reset_word_pos: Optional["int"] = None,
|
reset_word_pos: Optional["int"] = None,
|
||||||
mnemonic_type: Optional["BackupType"] = None,
|
mnemonic_type: Optional["BackupType"] = None,
|
||||||
|
thp_pairing_code_entry_code: Optional["int"] = None,
|
||||||
|
thp_pairing_code_qr_code: Optional["bytes"] = None,
|
||||||
|
thp_pairing_code_nfc_unidirectional: Optional["bytes"] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.tokens: Sequence["str"] = tokens if tokens is not None else []
|
self.tokens: Sequence["str"] = tokens if tokens is not None else []
|
||||||
self.layout = layout
|
self.layout = layout
|
||||||
@ -4198,6 +4237,9 @@ class DebugLinkState(protobuf.MessageType):
|
|||||||
self.recovery_word_pos = recovery_word_pos
|
self.recovery_word_pos = recovery_word_pos
|
||||||
self.reset_word_pos = reset_word_pos
|
self.reset_word_pos = reset_word_pos
|
||||||
self.mnemonic_type = mnemonic_type
|
self.mnemonic_type = mnemonic_type
|
||||||
|
self.thp_pairing_code_entry_code = thp_pairing_code_entry_code
|
||||||
|
self.thp_pairing_code_qr_code = thp_pairing_code_qr_code
|
||||||
|
self.thp_pairing_code_nfc_unidirectional = thp_pairing_code_nfc_unidirectional
|
||||||
|
|
||||||
|
|
||||||
class DebugLinkStop(protobuf.MessageType):
|
class DebugLinkStop(protobuf.MessageType):
|
||||||
@ -7860,6 +7902,280 @@ class TezosManagerTransfer(protobuf.MessageType):
|
|||||||
self.amount = amount
|
self.amount = amount
|
||||||
|
|
||||||
|
|
||||||
|
class ThpDeviceProperties(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = None
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("internal_model", "string", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("model_variant", "uint32", repeated=False, required=False, default=None),
|
||||||
|
3: protobuf.Field("bootloader_mode", "bool", repeated=False, required=False, default=None),
|
||||||
|
4: protobuf.Field("protocol_version", "uint32", repeated=False, required=False, default=None),
|
||||||
|
5: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None,
|
||||||
|
internal_model: Optional["str"] = None,
|
||||||
|
model_variant: Optional["int"] = None,
|
||||||
|
bootloader_mode: Optional["bool"] = None,
|
||||||
|
protocol_version: Optional["int"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else []
|
||||||
|
self.internal_model = internal_model
|
||||||
|
self.model_variant = model_variant
|
||||||
|
self.bootloader_mode = bootloader_mode
|
||||||
|
self.protocol_version = protocol_version
|
||||||
|
|
||||||
|
|
||||||
|
class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = None
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("host_pairing_credential", "bytes", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None,
|
||||||
|
host_pairing_credential: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else []
|
||||||
|
self.host_pairing_credential = host_pairing_credential
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCreateNewSession(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1000
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("passphrase", "string", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("on_device", "bool", repeated=False, required=False, default=None),
|
||||||
|
3: protobuf.Field("derive_cardano", "bool", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
passphrase: Optional["str"] = None,
|
||||||
|
on_device: Optional["bool"] = None,
|
||||||
|
derive_cardano: Optional["bool"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.passphrase = passphrase
|
||||||
|
self.on_device = on_device
|
||||||
|
self.derive_cardano = derive_cardano
|
||||||
|
|
||||||
|
|
||||||
|
class ThpNewSession(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1001
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("new_session_id", "uint32", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
new_session_id: Optional["int"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.new_session_id = new_session_id
|
||||||
|
|
||||||
|
|
||||||
|
class ThpStartPairingRequest(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1008
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
host_name: Optional["str"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.host_name = host_name
|
||||||
|
|
||||||
|
|
||||||
|
class ThpPairingPreparationsFinished(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1009
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCodeEntryCommitment(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1016
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("commitment", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
commitment: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.commitment = commitment
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCodeEntryChallenge(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1017
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("challenge", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
challenge: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.challenge = challenge
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCodeEntryCpaceHost(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1018
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("cpace_host_public_key", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
cpace_host_public_key: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.cpace_host_public_key = cpace_host_public_key
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1019
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("cpace_trezor_public_key", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
cpace_trezor_public_key: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.cpace_trezor_public_key = cpace_trezor_public_key
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCodeEntryTag(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1020
|
||||||
|
FIELDS = {
|
||||||
|
2: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
tag: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.tag = tag
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCodeEntrySecret(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1021
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
secret: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.secret = secret
|
||||||
|
|
||||||
|
|
||||||
|
class ThpQrCodeTag(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1024
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
tag: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.tag = tag
|
||||||
|
|
||||||
|
|
||||||
|
class ThpQrCodeSecret(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1025
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
secret: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.secret = secret
|
||||||
|
|
||||||
|
|
||||||
|
class ThpNfcUnidirectionalTag(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1032
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
tag: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.tag = tag
|
||||||
|
|
||||||
|
|
||||||
|
class ThpNfcUnidirectionalSecret(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1033
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
secret: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.secret = secret
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCredentialRequest(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1010
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
host_static_pubkey: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.host_static_pubkey = host_static_pubkey
|
||||||
|
|
||||||
|
|
||||||
|
class ThpCredentialResponse(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1011
|
||||||
|
FIELDS = {
|
||||||
|
1: protobuf.Field("trezor_static_pubkey", "bytes", repeated=False, required=False, default=None),
|
||||||
|
2: protobuf.Field("credential", "bytes", repeated=False, required=False, default=None),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
trezor_static_pubkey: Optional["bytes"] = None,
|
||||||
|
credential: Optional["bytes"] = None,
|
||||||
|
) -> None:
|
||||||
|
self.trezor_static_pubkey = trezor_static_pubkey
|
||||||
|
self.credential = credential
|
||||||
|
|
||||||
|
|
||||||
|
class ThpEndRequest(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1012
|
||||||
|
|
||||||
|
|
||||||
|
class ThpEndResponse(protobuf.MessageType):
|
||||||
|
MESSAGE_WIRE_TYPE = 1013
|
||||||
|
|
||||||
|
|
||||||
class ThpCredentialMetadata(protobuf.MessageType):
|
class ThpCredentialMetadata(protobuf.MessageType):
|
||||||
MESSAGE_WIRE_TYPE = None
|
MESSAGE_WIRE_TYPE = None
|
||||||
FIELDS = {
|
FIELDS = {
|
||||||
|
@ -20,25 +20,25 @@ from . import messages
|
|||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.Entropy, field="entropy", ret_type=bytes)
|
@expect(messages.Entropy, field="entropy", ret_type=bytes)
|
||||||
def get_entropy(client: "TrezorClient", size: int) -> "MessageType":
|
def get_entropy(session: "Session", size: int) -> "MessageType":
|
||||||
return client.call(messages.GetEntropy(size=size))
|
return session.call(messages.GetEntropy(size=size))
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.SignedIdentity)
|
@expect(messages.SignedIdentity)
|
||||||
def sign_identity(
|
def sign_identity(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
identity: messages.IdentityType,
|
identity: messages.IdentityType,
|
||||||
challenge_hidden: bytes,
|
challenge_hidden: bytes,
|
||||||
challenge_visual: str,
|
challenge_visual: str,
|
||||||
ecdsa_curve_name: Optional[str] = None,
|
ecdsa_curve_name: Optional[str] = None,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.SignIdentity(
|
messages.SignIdentity(
|
||||||
identity=identity,
|
identity=identity,
|
||||||
challenge_hidden=challenge_hidden,
|
challenge_hidden=challenge_hidden,
|
||||||
@ -50,12 +50,12 @@ def sign_identity(
|
|||||||
|
|
||||||
@expect(messages.ECDHSessionKey)
|
@expect(messages.ECDHSessionKey)
|
||||||
def get_ecdh_session_key(
|
def get_ecdh_session_key(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
identity: messages.IdentityType,
|
identity: messages.IdentityType,
|
||||||
peer_public_key: bytes,
|
peer_public_key: bytes,
|
||||||
ecdsa_curve_name: Optional[str] = None,
|
ecdsa_curve_name: Optional[str] = None,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.GetECDHSessionKey(
|
messages.GetECDHSessionKey(
|
||||||
identity=identity,
|
identity=identity,
|
||||||
peer_public_key=peer_public_key,
|
peer_public_key=peer_public_key,
|
||||||
@ -66,7 +66,7 @@ def get_ecdh_session_key(
|
|||||||
|
|
||||||
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
|
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
|
||||||
def encrypt_keyvalue(
|
def encrypt_keyvalue(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
key: str,
|
key: str,
|
||||||
value: bytes,
|
value: bytes,
|
||||||
@ -74,7 +74,7 @@ def encrypt_keyvalue(
|
|||||||
ask_on_decrypt: bool = True,
|
ask_on_decrypt: bool = True,
|
||||||
iv: bytes = b"",
|
iv: bytes = b"",
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.CipherKeyValue(
|
messages.CipherKeyValue(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
key=key,
|
key=key,
|
||||||
@ -89,7 +89,7 @@ def encrypt_keyvalue(
|
|||||||
|
|
||||||
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
|
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
|
||||||
def decrypt_keyvalue(
|
def decrypt_keyvalue(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
key: str,
|
key: str,
|
||||||
value: bytes,
|
value: bytes,
|
||||||
@ -97,7 +97,7 @@ def decrypt_keyvalue(
|
|||||||
ask_on_decrypt: bool = True,
|
ask_on_decrypt: bool = True,
|
||||||
iv: bytes = b"",
|
iv: bytes = b"",
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.CipherKeyValue(
|
messages.CipherKeyValue(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
key=key,
|
key=key,
|
||||||
@ -111,5 +111,5 @@ def decrypt_keyvalue(
|
|||||||
|
|
||||||
|
|
||||||
@expect(messages.Nonce, field="nonce", ret_type=bytes)
|
@expect(messages.Nonce, field="nonce", ret_type=bytes)
|
||||||
def get_nonce(client: "TrezorClient"):
|
def get_nonce(session: "Session"):
|
||||||
return client.call(messages.GetNonce())
|
return session.call(messages.GetNonce())
|
||||||
|
@ -20,9 +20,9 @@ from . import messages
|
|||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
# MAINNET = 0
|
# MAINNET = 0
|
||||||
@ -33,13 +33,13 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
@expect(messages.MoneroAddress, field="address", ret_type=bytes)
|
@expect(messages.MoneroAddress, field="address", ret_type=bytes)
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
|
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.MoneroGetAddress(
|
messages.MoneroGetAddress(
|
||||||
address_n=n,
|
address_n=n,
|
||||||
show_display=show_display,
|
show_display=show_display,
|
||||||
@ -51,10 +51,10 @@ def get_address(
|
|||||||
|
|
||||||
@expect(messages.MoneroWatchKey)
|
@expect(messages.MoneroWatchKey)
|
||||||
def get_watch_key(
|
def get_watch_key(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
|
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.MoneroGetWatchKey(address_n=n, network_type=network_type)
|
messages.MoneroGetWatchKey(address_n=n, network_type=network_type)
|
||||||
)
|
)
|
||||||
|
@ -21,9 +21,9 @@ from . import exceptions, messages
|
|||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
TYPE_TRANSACTION_TRANSFER = 0x0101
|
TYPE_TRANSACTION_TRANSFER = 0x0101
|
||||||
TYPE_IMPORTANCE_TRANSFER = 0x0801
|
TYPE_IMPORTANCE_TRANSFER = 0x0801
|
||||||
@ -198,13 +198,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig
|
|||||||
|
|
||||||
@expect(messages.NEMAddress, field="address", ret_type=str)
|
@expect(messages.NEMAddress, field="address", ret_type=str)
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
n: "Address",
|
n: "Address",
|
||||||
network: int,
|
network: int,
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.NEMGetAddress(
|
messages.NEMGetAddress(
|
||||||
address_n=n, network=network, show_display=show_display, chunkify=chunkify
|
address_n=n, network=network, show_display=show_display, chunkify=chunkify
|
||||||
)
|
)
|
||||||
@ -213,7 +213,7 @@ def get_address(
|
|||||||
|
|
||||||
@expect(messages.NEMSignedTx)
|
@expect(messages.NEMSignedTx)
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False
|
session: "Session", n: "Address", transaction: dict, chunkify: bool = False
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
try:
|
try:
|
||||||
msg = create_sign_tx(transaction, chunkify=chunkify)
|
msg = create_sign_tx(transaction, chunkify=chunkify)
|
||||||
@ -222,4 +222,4 @@ def sign_tx(
|
|||||||
|
|
||||||
assert msg.transaction is not None
|
assert msg.transaction is not None
|
||||||
msg.transaction.address_n = n
|
msg.transaction.address_n = n
|
||||||
return client.call(msg)
|
return session.call(msg)
|
||||||
|
@ -21,9 +21,9 @@ from .protobuf import dict_to_proto
|
|||||||
from .tools import dict_from_camelcase, expect
|
from .tools import dict_from_camelcase, expect
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
|
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
|
||||||
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
|
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
|
||||||
@ -31,12 +31,12 @@ REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
|
|||||||
|
|
||||||
@expect(messages.RippleAddress, field="address", ret_type=str)
|
@expect(messages.RippleAddress, field="address", ret_type=str)
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: "Address",
|
address_n: "Address",
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.RippleGetAddress(
|
messages.RippleGetAddress(
|
||||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||||
)
|
)
|
||||||
@ -45,14 +45,14 @@ def get_address(
|
|||||||
|
|
||||||
@expect(messages.RippleSignedTx)
|
@expect(messages.RippleSignedTx)
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: "Address",
|
address_n: "Address",
|
||||||
msg: messages.RippleSignTx,
|
msg: messages.RippleSignTx,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
msg.address_n = address_n
|
msg.address_n = address_n
|
||||||
msg.chunkify = chunkify
|
msg.chunkify = chunkify
|
||||||
return client.call(msg)
|
return session.call(msg)
|
||||||
|
|
||||||
|
|
||||||
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:
|
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:
|
||||||
|
@ -4,29 +4,29 @@ from . import messages
|
|||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.SolanaPublicKey)
|
@expect(messages.SolanaPublicKey)
|
||||||
def get_public_key(
|
def get_public_key(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: List[int],
|
address_n: List[int],
|
||||||
show_display: bool,
|
show_display: bool,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display)
|
messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.SolanaAddress)
|
@expect(messages.SolanaAddress)
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: List[int],
|
address_n: List[int],
|
||||||
show_display: bool,
|
show_display: bool,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.SolanaGetAddress(
|
messages.SolanaGetAddress(
|
||||||
address_n=address_n,
|
address_n=address_n,
|
||||||
show_display=show_display,
|
show_display=show_display,
|
||||||
@ -37,12 +37,12 @@ def get_address(
|
|||||||
|
|
||||||
@expect(messages.SolanaTxSignature)
|
@expect(messages.SolanaTxSignature)
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: List[int],
|
address_n: List[int],
|
||||||
serialized_tx: bytes,
|
serialized_tx: bytes,
|
||||||
additional_info: Optional[messages.SolanaTxAdditionalInfo],
|
additional_info: Optional[messages.SolanaTxAdditionalInfo],
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.SolanaSignTx(
|
messages.SolanaSignTx(
|
||||||
address_n=address_n,
|
address_n=address_n,
|
||||||
serialized_tx=serialized_tx,
|
serialized_tx=serialized_tx,
|
||||||
|
@ -21,9 +21,9 @@ from . import exceptions, messages
|
|||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
StellarMessageType = Union[
|
StellarMessageType = Union[
|
||||||
messages.StellarAccountMergeOp,
|
messages.StellarAccountMergeOp,
|
||||||
@ -325,12 +325,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
|
|||||||
|
|
||||||
@expect(messages.StellarAddress, field="address", ret_type=str)
|
@expect(messages.StellarAddress, field="address", ret_type=str)
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: "Address",
|
address_n: "Address",
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.StellarGetAddress(
|
messages.StellarGetAddress(
|
||||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||||
)
|
)
|
||||||
@ -338,7 +338,7 @@ def get_address(
|
|||||||
|
|
||||||
|
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
tx: messages.StellarSignTx,
|
tx: messages.StellarSignTx,
|
||||||
operations: List["StellarMessageType"],
|
operations: List["StellarMessageType"],
|
||||||
address_n: "Address",
|
address_n: "Address",
|
||||||
@ -354,10 +354,10 @@ def sign_tx(
|
|||||||
# 3. Receive a StellarTxOpRequest message
|
# 3. Receive a StellarTxOpRequest message
|
||||||
# 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message
|
# 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message
|
||||||
# 5. The final message received will be StellarSignedTx which is returned from this method
|
# 5. The final message received will be StellarSignedTx which is returned from this method
|
||||||
resp = client.call(tx)
|
resp = session.call(tx)
|
||||||
try:
|
try:
|
||||||
while isinstance(resp, messages.StellarTxOpRequest):
|
while isinstance(resp, messages.StellarTxOpRequest):
|
||||||
resp = client.call(operations.pop(0))
|
resp = session.call(operations.pop(0))
|
||||||
except IndexError:
|
except IndexError:
|
||||||
# pop from empty list
|
# pop from empty list
|
||||||
raise exceptions.TrezorException(
|
raise exceptions.TrezorException(
|
||||||
|
@ -20,19 +20,19 @@ from . import messages
|
|||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .client import TrezorClient
|
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
from .tools import Address
|
from .tools import Address
|
||||||
|
from .transport.session import Session
|
||||||
|
|
||||||
|
|
||||||
@expect(messages.TezosAddress, field="address", ret_type=str)
|
@expect(messages.TezosAddress, field="address", ret_type=str)
|
||||||
def get_address(
|
def get_address(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: "Address",
|
address_n: "Address",
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.TezosGetAddress(
|
messages.TezosGetAddress(
|
||||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||||
)
|
)
|
||||||
@ -41,12 +41,12 @@ def get_address(
|
|||||||
|
|
||||||
@expect(messages.TezosPublicKey, field="public_key", ret_type=str)
|
@expect(messages.TezosPublicKey, field="public_key", ret_type=str)
|
||||||
def get_public_key(
|
def get_public_key(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: "Address",
|
address_n: "Address",
|
||||||
show_display: bool = False,
|
show_display: bool = False,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
return client.call(
|
return session.call(
|
||||||
messages.TezosGetPublicKey(
|
messages.TezosGetPublicKey(
|
||||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||||
)
|
)
|
||||||
@ -55,11 +55,11 @@ def get_public_key(
|
|||||||
|
|
||||||
@expect(messages.TezosSignedTx)
|
@expect(messages.TezosSignedTx)
|
||||||
def sign_tx(
|
def sign_tx(
|
||||||
client: "TrezorClient",
|
session: "Session",
|
||||||
address_n: "Address",
|
address_n: "Address",
|
||||||
sign_tx_msg: messages.TezosSignTx,
|
sign_tx_msg: messages.TezosSignTx,
|
||||||
chunkify: bool = False,
|
chunkify: bool = False,
|
||||||
) -> "MessageType":
|
) -> "MessageType":
|
||||||
sign_tx_msg.address_n = address_n
|
sign_tx_msg.address_n = address_n
|
||||||
sign_tx_msg.chunkify = chunkify
|
sign_tx_msg.chunkify = chunkify
|
||||||
return client.call(sign_tx_msg)
|
return session.call(sign_tx_msg)
|
||||||
|
@ -40,7 +40,7 @@ if TYPE_CHECKING:
|
|||||||
# More details: https://www.python.org/dev/peps/pep-0612/
|
# More details: https://www.python.org/dev/peps/pep-0612/
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
from typing_extensions import Concatenate, ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from . import client
|
from . import client
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
@ -301,23 +301,6 @@ def expect(
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def session(
|
|
||||||
f: "Callable[Concatenate[TrezorClient, P], R]",
|
|
||||||
) -> "Callable[Concatenate[TrezorClient, P], R]":
|
|
||||||
# Decorator wraps a BaseClient method
|
|
||||||
# with session activation / deactivation
|
|
||||||
@functools.wraps(f)
|
|
||||||
def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R":
|
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
|
||||||
client.open()
|
|
||||||
try:
|
|
||||||
return f(client, *args, **kwargs)
|
|
||||||
finally:
|
|
||||||
client.close()
|
|
||||||
|
|
||||||
return wrapped_f
|
|
||||||
|
|
||||||
|
|
||||||
# de-camelcasifier
|
# de-camelcasifier
|
||||||
# https://stackoverflow.com/a/1176023/222189
|
# https://stackoverflow.com/a/1176023/222189
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# This file is part of the Trezor project.
|
# This file is part of the Trezor project.
|
||||||
#
|
#
|
||||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||||
#
|
#
|
||||||
# This library is free software: you can redistribute it and/or modify
|
# This library is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Lesser General Public License version 3
|
# it under the terms of the GNU Lesser General Public License version 3
|
||||||
@ -14,24 +14,18 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
import typing as t
|
||||||
TYPE_CHECKING,
|
|
||||||
Iterable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
TypeVar,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ..exceptions import TrezorException
|
from ..exceptions import TrezorException
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
from ..models import TrezorModel
|
from ..models import TrezorModel
|
||||||
|
|
||||||
T = TypeVar("T", bound="Transport")
|
T = t.TypeVar("T", bound="Transport")
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -41,7 +35,7 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
|
|||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
MessagePayload = Tuple[int, bytes]
|
MessagePayload = t.Tuple[int, bytes]
|
||||||
|
|
||||||
|
|
||||||
class TransportException(TrezorException):
|
class TransportException(TrezorException):
|
||||||
@ -53,72 +47,54 @@ class DeviceIsBusy(TransportException):
|
|||||||
|
|
||||||
|
|
||||||
class Transport:
|
class Transport:
|
||||||
"""Raw connection to a Trezor device.
|
|
||||||
|
|
||||||
Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB
|
|
||||||
or USB-HID connection, or UDP socket of listening emulator(s).
|
|
||||||
It can also enumerate devices available over this communication link, and return
|
|
||||||
them as instances.
|
|
||||||
|
|
||||||
Transport instance is a thing that:
|
|
||||||
- can be identified and requested by a string URI-like path
|
|
||||||
- can open and close sessions, which enclose related operations
|
|
||||||
- can read and write protobuf messages
|
|
||||||
|
|
||||||
You need to implement a new Transport subclass if you invent a new way to connect
|
|
||||||
a Trezor device to a computer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATH_PREFIX: str
|
PATH_PREFIX: str
|
||||||
ENABLED = False
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
@classmethod
|
||||||
return self.get_path()
|
def enumerate(
|
||||||
|
cls: t.Type["T"], models: t.Iterable["TrezorModel"] | None = None
|
||||||
|
) -> t.Iterable["T"]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def find_by_path(cls: t.Type["T"], path: str, prefix_search: bool = False) -> "T":
|
||||||
|
for device in cls.enumerate():
|
||||||
|
|
||||||
|
if device.get_path() == path:
|
||||||
|
return device
|
||||||
|
|
||||||
|
if prefix_search and device.get_path().startswith(path):
|
||||||
|
return device
|
||||||
|
|
||||||
|
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||||
|
|
||||||
def get_path(self) -> str:
|
def get_path(self) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def begin_session(self) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def end_session(self) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def read(self) -> MessagePayload:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def write(self, message_type: int, message_data: bytes) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def find_debug(self: "T") -> "T":
|
def find_debug(self: "T") -> "T":
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
def open(self) -> None:
|
||||||
def enumerate(
|
|
||||||
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
|
|
||||||
) -> Iterable["T"]:
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
def close(self) -> None:
|
||||||
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
|
raise NotImplementedError
|
||||||
for device in cls.enumerate():
|
|
||||||
if (
|
|
||||||
path is None
|
|
||||||
or device.get_path() == path
|
|
||||||
or (prefix_search and device.get_path().startswith(path))
|
|
||||||
):
|
|
||||||
return device
|
|
||||||
|
|
||||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
def write_chunk(self, chunk: bytes) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def read_chunk(self) -> bytes:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
CHUNK_SIZE: t.ClassVar[int]
|
||||||
|
|
||||||
|
|
||||||
def all_transports() -> Iterable[Type["Transport"]]:
|
def all_transports() -> t.Iterable[t.Type["Transport"]]:
|
||||||
from .bridge import BridgeTransport
|
from .bridge import BridgeTransport
|
||||||
from .hid import HidTransport
|
from .hid import HidTransport
|
||||||
from .udp import UdpTransport
|
from .udp import UdpTransport
|
||||||
from .webusb import WebUsbTransport
|
from .webusb import WebUsbTransport
|
||||||
|
|
||||||
transports: Tuple[Type["Transport"], ...] = (
|
transports: t.Tuple[t.Type["Transport"], ...] = (
|
||||||
BridgeTransport,
|
BridgeTransport,
|
||||||
HidTransport,
|
HidTransport,
|
||||||
UdpTransport,
|
UdpTransport,
|
||||||
@ -128,9 +104,9 @@ def all_transports() -> Iterable[Type["Transport"]]:
|
|||||||
|
|
||||||
|
|
||||||
def enumerate_devices(
|
def enumerate_devices(
|
||||||
models: Optional[Iterable["TrezorModel"]] = None,
|
models: t.Iterable["TrezorModel"] | None = None,
|
||||||
) -> Sequence["Transport"]:
|
) -> t.Sequence["Transport"]:
|
||||||
devices: List["Transport"] = []
|
devices: t.List["Transport"] = []
|
||||||
for transport in all_transports():
|
for transport in all_transports():
|
||||||
name = transport.__name__
|
name = transport.__name__
|
||||||
try:
|
try:
|
||||||
@ -145,9 +121,7 @@ def enumerate_devices(
|
|||||||
return devices
|
return devices
|
||||||
|
|
||||||
|
|
||||||
def get_transport(
|
def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport":
|
||||||
path: Optional[str] = None, prefix_search: bool = False
|
|
||||||
) -> "Transport":
|
|
||||||
if path is None:
|
if path is None:
|
||||||
try:
|
try:
|
||||||
return next(iter(enumerate_devices()))
|
return next(iter(enumerate_devices()))
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# This file is part of the Trezor project.
|
# This file is part of the Trezor project.
|
||||||
#
|
#
|
||||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||||
#
|
#
|
||||||
# This library is free software: you can redistribute it and/or modify
|
# This library is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Lesser General Public License version 3
|
# it under the terms of the GNU Lesser General Public License version 3
|
||||||
@ -14,24 +14,30 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
|
import typing as t
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ..log import DUMP_PACKETS
|
from ..log import DUMP_PACKETS
|
||||||
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
|
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
from ..models import TrezorModel
|
from ..models import TrezorModel
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
PROTOCOL_VERSION_1 = 1
|
||||||
|
PROTOCOL_VERSION_2 = 2
|
||||||
|
|
||||||
TREZORD_HOST = "http://127.0.0.1:21325"
|
TREZORD_HOST = "http://127.0.0.1:21325"
|
||||||
TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"}
|
TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"}
|
||||||
|
|
||||||
TREZORD_VERSION_MODERN = (2, 0, 25)
|
TREZORD_VERSION_MODERN = (2, 0, 25)
|
||||||
|
TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value
|
||||||
|
|
||||||
CONNECTION = requests.Session()
|
CONNECTION = requests.Session()
|
||||||
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
|
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
|
||||||
@ -45,7 +51,7 @@ class BridgeException(TransportException):
|
|||||||
super().__init__(f"trezord: {path} failed with code {status}: {message}")
|
super().__init__(f"trezord: {path} failed with code {status}: {message}")
|
||||||
|
|
||||||
|
|
||||||
def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
|
def call_bridge(path: str, data: str | None = None) -> requests.Response:
|
||||||
url = TREZORD_HOST + "/" + path
|
url = TREZORD_HOST + "/" + path
|
||||||
r = CONNECTION.post(url, data=data)
|
r = CONNECTION.post(url, data=data)
|
||||||
if r.status_code != 200:
|
if r.status_code != 200:
|
||||||
@ -53,10 +59,54 @@ def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
|
|||||||
return r
|
return r
|
||||||
|
|
||||||
|
|
||||||
def is_legacy_bridge() -> bool:
|
def get_bridge_version() -> t.Tuple[int, ...]:
|
||||||
config = call_bridge("configure").json()
|
config = call_bridge("configure").json()
|
||||||
version_tuple = tuple(map(int, config["version"].split(".")))
|
return tuple(map(int, config["version"].split(".")))
|
||||||
return version_tuple < TREZORD_VERSION_MODERN
|
|
||||||
|
|
||||||
|
def is_legacy_bridge() -> bool:
|
||||||
|
return get_bridge_version() < TREZORD_VERSION_MODERN
|
||||||
|
|
||||||
|
|
||||||
|
def supports_protocolV2() -> bool:
|
||||||
|
return get_bridge_version() >= TREZORD_VERSION_THP_SUPPORT
|
||||||
|
|
||||||
|
|
||||||
|
def detect_protocol_version(transport: "BridgeTransport") -> int:
|
||||||
|
from .. import mapping, messages
|
||||||
|
from ..messages import FailureType
|
||||||
|
|
||||||
|
protocol_version = PROTOCOL_VERSION_1
|
||||||
|
request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize())
|
||||||
|
transport.deprecated_begin_session()
|
||||||
|
transport.deprecated_write(request_type, request_data)
|
||||||
|
|
||||||
|
response_type, response_data = transport.deprecated_read()
|
||||||
|
response = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
|
||||||
|
transport.deprecated_begin_session()
|
||||||
|
if isinstance(response, messages.Failure):
|
||||||
|
if response.code == FailureType.InvalidProtocol:
|
||||||
|
LOG.debug("Protocol V2 detected")
|
||||||
|
protocol_version = PROTOCOL_VERSION_2
|
||||||
|
|
||||||
|
return protocol_version
|
||||||
|
|
||||||
|
|
||||||
|
def _is_transport_valid(transport: "BridgeTransport") -> bool:
|
||||||
|
is_valid = (
|
||||||
|
supports_protocolV2()
|
||||||
|
or detect_protocol_version(transport) == PROTOCOL_VERSION_1
|
||||||
|
)
|
||||||
|
if not is_valid:
|
||||||
|
LOG.warning("Detected unsupported Bridge transport!")
|
||||||
|
return is_valid
|
||||||
|
|
||||||
|
|
||||||
|
def filter_invalid_bridge_transports(
|
||||||
|
transports: t.Iterable["BridgeTransport"],
|
||||||
|
) -> t.Sequence["BridgeTransport"]:
|
||||||
|
"""Filters out invalid bridge transports. Keeps only valid ones."""
|
||||||
|
return [t for t in transports if _is_transport_valid(t)]
|
||||||
|
|
||||||
|
|
||||||
class BridgeHandle:
|
class BridgeHandle:
|
||||||
@ -84,7 +134,7 @@ class BridgeHandleModern(BridgeHandle):
|
|||||||
class BridgeHandleLegacy(BridgeHandle):
|
class BridgeHandleLegacy(BridgeHandle):
|
||||||
def __init__(self, transport: "BridgeTransport") -> None:
|
def __init__(self, transport: "BridgeTransport") -> None:
|
||||||
super().__init__(transport)
|
super().__init__(transport)
|
||||||
self.request: Optional[str] = None
|
self.request: str | None = None
|
||||||
|
|
||||||
def write_buf(self, buf: bytes) -> None:
|
def write_buf(self, buf: bytes) -> None:
|
||||||
if self.request is not None:
|
if self.request is not None:
|
||||||
@ -112,13 +162,12 @@ class BridgeTransport(Transport):
|
|||||||
ENABLED: bool = True
|
ENABLED: bool = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, device: Dict[str, Any], legacy: bool, debug: bool = False
|
self, device: t.Dict[str, t.Any], legacy: bool, debug: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
if legacy and debug:
|
if legacy and debug:
|
||||||
raise TransportException("Debugging not supported on legacy Bridge")
|
raise TransportException("Debugging not supported on legacy Bridge")
|
||||||
|
|
||||||
self.device = device
|
self.device = device
|
||||||
self.session: Optional[str] = None
|
self.session: str | None = device["session"]
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.legacy = legacy
|
self.legacy = legacy
|
||||||
|
|
||||||
@ -135,7 +184,7 @@ class BridgeTransport(Transport):
|
|||||||
raise TransportException("Debug device not available")
|
raise TransportException("Debug device not available")
|
||||||
return BridgeTransport(self.device, self.legacy, debug=True)
|
return BridgeTransport(self.device, self.legacy, debug=True)
|
||||||
|
|
||||||
def _call(self, action: str, data: Optional[str] = None) -> requests.Response:
|
def _call(self, action: str, data: str | None = None) -> requests.Response:
|
||||||
session = self.session or "null"
|
session = self.session or "null"
|
||||||
uri = action + "/" + str(session)
|
uri = action + "/" + str(session)
|
||||||
if self.debug:
|
if self.debug:
|
||||||
@ -144,17 +193,20 @@ class BridgeTransport(Transport):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enumerate(
|
def enumerate(
|
||||||
cls, _models: Optional[Iterable["TrezorModel"]] = None
|
cls, _models: t.Iterable["TrezorModel"] | None = None
|
||||||
) -> Iterable["BridgeTransport"]:
|
) -> t.Iterable["BridgeTransport"]:
|
||||||
try:
|
try:
|
||||||
legacy = is_legacy_bridge()
|
legacy = is_legacy_bridge()
|
||||||
return [
|
return filter_invalid_bridge_transports(
|
||||||
BridgeTransport(dev, legacy) for dev in call_bridge("enumerate").json()
|
[
|
||||||
]
|
BridgeTransport(dev, legacy)
|
||||||
|
for dev in call_bridge("enumerate").json()
|
||||||
|
]
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def begin_session(self) -> None:
|
def deprecated_begin_session(self) -> None:
|
||||||
try:
|
try:
|
||||||
data = self._call("acquire/" + self.device["path"])
|
data = self._call("acquire/" + self.device["path"])
|
||||||
except BridgeException as e:
|
except BridgeException as e:
|
||||||
@ -163,18 +215,32 @@ class BridgeTransport(Transport):
|
|||||||
raise
|
raise
|
||||||
self.session = data.json()["session"]
|
self.session = data.json()["session"]
|
||||||
|
|
||||||
def end_session(self) -> None:
|
def deprecated_end_session(self) -> None:
|
||||||
if not self.session:
|
if not self.session:
|
||||||
return
|
return
|
||||||
self._call("release")
|
self._call("release")
|
||||||
self.session = None
|
self.session = None
|
||||||
|
|
||||||
def write(self, message_type: int, message_data: bytes) -> None:
|
def deprecated_write(self, message_type: int, message_data: bytes) -> None:
|
||||||
header = struct.pack(">HL", message_type, len(message_data))
|
header = struct.pack(">HL", message_type, len(message_data))
|
||||||
self.handle.write_buf(header + message_data)
|
self.handle.write_buf(header + message_data)
|
||||||
|
|
||||||
def read(self) -> MessagePayload:
|
def deprecated_read(self) -> MessagePayload:
|
||||||
data = self.handle.read_buf()
|
data = self.handle.read_buf()
|
||||||
headerlen = struct.calcsize(">HL")
|
headerlen = struct.calcsize(">HL")
|
||||||
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
|
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
|
||||||
return msg_type, data[headerlen : headerlen + datalen]
|
return msg_type, data[headerlen : headerlen + datalen]
|
||||||
|
|
||||||
|
def open(self) -> None:
|
||||||
|
pass
|
||||||
|
# TODO self.handle.open()
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
pass
|
||||||
|
# TODO self.handle.close()
|
||||||
|
|
||||||
|
def write_chunk(self, chunk: bytes) -> None: # TODO check if it works :)
|
||||||
|
self.handle.write_buf(chunk)
|
||||||
|
|
||||||
|
def read_chunk(self) -> bytes: # TODO check if it works :)
|
||||||
|
return self.handle.read_buf()
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# This file is part of the Trezor project.
|
# This file is part of the Trezor project.
|
||||||
#
|
#
|
||||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||||
#
|
#
|
||||||
# This library is free software: you can redistribute it and/or modify
|
# This library is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Lesser General Public License version 3
|
# it under the terms of the GNU Lesser General Public License version 3
|
||||||
@ -14,15 +14,16 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, Iterable, List, Optional
|
import typing as t
|
||||||
|
|
||||||
from ..log import DUMP_PACKETS
|
from ..log import DUMP_PACKETS
|
||||||
from ..models import TREZOR_ONE, TrezorModel
|
from ..models import TREZOR_ONE, TrezorModel
|
||||||
from . import UDEV_RULES_STR, TransportException
|
from . import UDEV_RULES_STR, Transport, TransportException
|
||||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -35,23 +36,61 @@ except Exception as e:
|
|||||||
HID_IMPORTED = False
|
HID_IMPORTED = False
|
||||||
|
|
||||||
|
|
||||||
HidDevice = Dict[str, Any]
|
HidDevice = t.Dict[str, t.Any]
|
||||||
HidDeviceHandle = Any
|
HidDeviceHandle = t.Any
|
||||||
|
|
||||||
|
|
||||||
class HidHandle:
|
class HidTransport(Transport):
|
||||||
def __init__(
|
"""
|
||||||
self, path: bytes, serial: str, probe_hid_version: bool = False
|
HidTransport implements transport over USB HID interface.
|
||||||
) -> None:
|
"""
|
||||||
self.path = path
|
|
||||||
self.serial = serial
|
PATH_PREFIX = "hid"
|
||||||
|
ENABLED = HID_IMPORTED
|
||||||
|
|
||||||
|
def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None:
|
||||||
|
self.device = device
|
||||||
|
self.device_path = device["path"]
|
||||||
|
self.device_serial_number = device["serial_number"]
|
||||||
self.handle: HidDeviceHandle = None
|
self.handle: HidDeviceHandle = None
|
||||||
self.hid_version = None if probe_hid_version else 2
|
self.hid_version = None if probe_hid_version else 2
|
||||||
|
|
||||||
|
def get_path(self) -> str:
|
||||||
|
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def enumerate(
|
||||||
|
cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False
|
||||||
|
) -> t.Iterable["HidTransport"]:
|
||||||
|
if models is None:
|
||||||
|
models = {TREZOR_ONE}
|
||||||
|
usb_ids = [id for model in models for id in model.usb_ids]
|
||||||
|
|
||||||
|
devices: t.List["HidTransport"] = []
|
||||||
|
for dev in hid.enumerate(0, 0):
|
||||||
|
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||||
|
if usb_id not in usb_ids:
|
||||||
|
continue
|
||||||
|
if debug:
|
||||||
|
if not is_debuglink(dev):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
if not is_wirelink(dev):
|
||||||
|
continue
|
||||||
|
devices.append(HidTransport(dev))
|
||||||
|
return devices
|
||||||
|
|
||||||
|
def find_debug(self) -> "HidTransport":
|
||||||
|
# For v1 protocol, find debug USB interface for the same serial number
|
||||||
|
for debug in HidTransport.enumerate(debug=True):
|
||||||
|
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||||
|
return debug
|
||||||
|
raise TransportException("Debug HID device not found")
|
||||||
|
|
||||||
def open(self) -> None:
|
def open(self) -> None:
|
||||||
self.handle = hid.device()
|
self.handle = hid.device()
|
||||||
try:
|
try:
|
||||||
self.handle.open_path(self.path)
|
self.handle.open_path(self.device_path)
|
||||||
except (IOError, OSError) as e:
|
except (IOError, OSError) as e:
|
||||||
if sys.platform.startswith("linux"):
|
if sys.platform.startswith("linux"):
|
||||||
e.args = e.args + (UDEV_RULES_STR,)
|
e.args = e.args + (UDEV_RULES_STR,)
|
||||||
@ -62,11 +101,11 @@ class HidHandle:
|
|||||||
# and we wouldn't even know.
|
# and we wouldn't even know.
|
||||||
# So we check that the serial matches what we expect.
|
# So we check that the serial matches what we expect.
|
||||||
serial = self.handle.get_serial_number_string()
|
serial = self.handle.get_serial_number_string()
|
||||||
if serial != self.serial:
|
if serial != self.device_serial_number:
|
||||||
self.handle.close()
|
self.handle.close()
|
||||||
self.handle = None
|
self.handle = None
|
||||||
raise TransportException(
|
raise TransportException(
|
||||||
f"Unexpected device {serial} on path {self.path.decode()}"
|
f"Unexpected device {serial} on path {self.device_path.decode()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.handle.set_nonblocking(True)
|
self.handle.set_nonblocking(True)
|
||||||
@ -77,7 +116,7 @@ class HidHandle:
|
|||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
if self.handle is not None:
|
if self.handle is not None:
|
||||||
# reload serial, because device.wipe() can reset it
|
# reload serial, because device.wipe() can reset it
|
||||||
self.serial = self.handle.get_serial_number_string()
|
self.device_serial_number = self.handle.get_serial_number_string()
|
||||||
self.handle.close()
|
self.handle.close()
|
||||||
self.handle = None
|
self.handle = None
|
||||||
|
|
||||||
@ -115,53 +154,6 @@ class HidHandle:
|
|||||||
raise TransportException("Unknown HID version")
|
raise TransportException("Unknown HID version")
|
||||||
|
|
||||||
|
|
||||||
class HidTransport(ProtocolBasedTransport):
|
|
||||||
"""
|
|
||||||
HidTransport implements transport over USB HID interface.
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATH_PREFIX = "hid"
|
|
||||||
ENABLED = HID_IMPORTED
|
|
||||||
|
|
||||||
def __init__(self, device: HidDevice) -> None:
|
|
||||||
self.device = device
|
|
||||||
self.handle = HidHandle(device["path"], device["serial_number"])
|
|
||||||
|
|
||||||
super().__init__(protocol=ProtocolV1(self.handle))
|
|
||||||
|
|
||||||
def get_path(self) -> str:
|
|
||||||
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def enumerate(
|
|
||||||
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
|
|
||||||
) -> Iterable["HidTransport"]:
|
|
||||||
if models is None:
|
|
||||||
models = {TREZOR_ONE}
|
|
||||||
usb_ids = [id for model in models for id in model.usb_ids]
|
|
||||||
|
|
||||||
devices: List["HidTransport"] = []
|
|
||||||
for dev in hid.enumerate(0, 0):
|
|
||||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
|
||||||
if usb_id not in usb_ids:
|
|
||||||
continue
|
|
||||||
if debug:
|
|
||||||
if not is_debuglink(dev):
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
if not is_wirelink(dev):
|
|
||||||
continue
|
|
||||||
devices.append(HidTransport(dev))
|
|
||||||
return devices
|
|
||||||
|
|
||||||
def find_debug(self) -> "HidTransport":
|
|
||||||
# For v1 protocol, find debug USB interface for the same serial number
|
|
||||||
for debug in HidTransport.enumerate(debug=True):
|
|
||||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
|
||||||
return debug
|
|
||||||
raise TransportException("Debug HID device not found")
|
|
||||||
|
|
||||||
|
|
||||||
def is_wirelink(dev: HidDevice) -> bool:
|
def is_wirelink(dev: HidDevice) -> bool:
|
||||||
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0
|
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0
|
||||||
|
|
||||||
|
@ -1,165 +0,0 @@
|
|||||||
# This file is part of the Trezor project.
|
|
||||||
#
|
|
||||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
|
||||||
#
|
|
||||||
# This library is free software: you can redistribute it and/or modify
|
|
||||||
# it under the terms of the GNU Lesser General Public License version 3
|
|
||||||
# as published by the Free Software Foundation.
|
|
||||||
#
|
|
||||||
# This library is distributed in the hope that it will be useful,
|
|
||||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
# GNU Lesser General Public License for more details.
|
|
||||||
#
|
|
||||||
# You should have received a copy of the License along with this library.
|
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import struct
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from typing_extensions import Protocol as StructuralType
|
|
||||||
|
|
||||||
from . import MessagePayload, Transport
|
|
||||||
|
|
||||||
REPLEN = 64
|
|
||||||
|
|
||||||
V2_FIRST_CHUNK = 0x01
|
|
||||||
V2_NEXT_CHUNK = 0x02
|
|
||||||
V2_BEGIN_SESSION = 0x03
|
|
||||||
V2_END_SESSION = 0x04
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Handle(StructuralType):
|
|
||||||
"""PEP 544 structural type for Handle functionality.
|
|
||||||
(called a "Protocol" in the proposed PEP, name which is impractical here)
|
|
||||||
|
|
||||||
Handle is a "physical" layer for a protocol.
|
|
||||||
It can open/close a connection and read/write bare data in 64-byte chunks.
|
|
||||||
|
|
||||||
Functionally we gain nothing from making this an (abstract) base class for handle
|
|
||||||
implementations, so this definition is for type hinting purposes only. You can,
|
|
||||||
but don't have to, inherit from it.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def open(self) -> None: ...
|
|
||||||
|
|
||||||
def close(self) -> None: ...
|
|
||||||
|
|
||||||
def read_chunk(self) -> bytes: ...
|
|
||||||
|
|
||||||
def write_chunk(self, chunk: bytes) -> None: ...
|
|
||||||
|
|
||||||
|
|
||||||
class Protocol:
|
|
||||||
"""Wire protocol that can communicate with a Trezor device, given a Handle.
|
|
||||||
|
|
||||||
A Protocol implements the part of the Transport API that relates to communicating
|
|
||||||
logical messages over a physical layer. It is a thing that can:
|
|
||||||
- open and close sessions,
|
|
||||||
- send and receive protobuf messages,
|
|
||||||
given the ability to:
|
|
||||||
- open and close physical connections,
|
|
||||||
- and send and receive binary chunks.
|
|
||||||
|
|
||||||
For now, the class also handles session counting and opening the underlying Handle.
|
|
||||||
This will probably be removed in the future.
|
|
||||||
|
|
||||||
We will need a new Protocol class if we change the way a Trezor device encapsulates
|
|
||||||
its messages.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, handle: Handle) -> None:
|
|
||||||
self.handle = handle
|
|
||||||
self.session_counter = 0
|
|
||||||
|
|
||||||
# XXX we might be able to remove this now that TrezorClient does session handling
|
|
||||||
def begin_session(self) -> None:
|
|
||||||
if self.session_counter == 0:
|
|
||||||
self.handle.open()
|
|
||||||
self.session_counter += 1
|
|
||||||
|
|
||||||
def end_session(self) -> None:
|
|
||||||
self.session_counter = max(self.session_counter - 1, 0)
|
|
||||||
if self.session_counter == 0:
|
|
||||||
self.handle.close()
|
|
||||||
|
|
||||||
def read(self) -> MessagePayload:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def write(self, message_type: int, message_data: bytes) -> None:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class ProtocolBasedTransport(Transport):
|
|
||||||
"""Transport that implements its communications through a Protocol.
|
|
||||||
|
|
||||||
Intended as a base class for implementations that proxy their communication
|
|
||||||
operations to a Protocol.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, protocol: Protocol) -> None:
|
|
||||||
self.protocol = protocol
|
|
||||||
|
|
||||||
def write(self, message_type: int, message_data: bytes) -> None:
|
|
||||||
self.protocol.write(message_type, message_data)
|
|
||||||
|
|
||||||
def read(self) -> MessagePayload:
|
|
||||||
return self.protocol.read()
|
|
||||||
|
|
||||||
def begin_session(self) -> None:
|
|
||||||
self.protocol.begin_session()
|
|
||||||
|
|
||||||
def end_session(self) -> None:
|
|
||||||
self.protocol.end_session()
|
|
||||||
|
|
||||||
|
|
||||||
class ProtocolV1(Protocol):
|
|
||||||
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
|
|
||||||
Does not understand sessions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
HEADER_LEN = struct.calcsize(">HL")
|
|
||||||
|
|
||||||
def write(self, message_type: int, message_data: bytes) -> None:
|
|
||||||
header = struct.pack(">HL", message_type, len(message_data))
|
|
||||||
buffer = bytearray(b"##" + header + message_data)
|
|
||||||
|
|
||||||
while buffer:
|
|
||||||
# Report ID, data padded to 63 bytes
|
|
||||||
chunk = b"?" + buffer[: REPLEN - 1]
|
|
||||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
|
||||||
self.handle.write_chunk(chunk)
|
|
||||||
buffer = buffer[63:]
|
|
||||||
|
|
||||||
def read(self) -> MessagePayload:
|
|
||||||
buffer = bytearray()
|
|
||||||
# Read header with first part of message data
|
|
||||||
msg_type, datalen, first_chunk = self.read_first()
|
|
||||||
buffer.extend(first_chunk)
|
|
||||||
|
|
||||||
# Read the rest of the message
|
|
||||||
while len(buffer) < datalen:
|
|
||||||
buffer.extend(self.read_next())
|
|
||||||
|
|
||||||
return msg_type, buffer[:datalen]
|
|
||||||
|
|
||||||
def read_first(self) -> Tuple[int, int, bytes]:
|
|
||||||
chunk = self.handle.read_chunk()
|
|
||||||
if chunk[:3] != b"?##":
|
|
||||||
raise RuntimeError("Unexpected magic characters")
|
|
||||||
try:
|
|
||||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
|
||||||
except Exception:
|
|
||||||
raise RuntimeError("Cannot parse header")
|
|
||||||
|
|
||||||
data = chunk[3 + self.HEADER_LEN :]
|
|
||||||
return msg_type, datalen, data
|
|
||||||
|
|
||||||
def read_next(self) -> bytes:
|
|
||||||
chunk = self.handle.read_chunk()
|
|
||||||
if chunk[:1] != b"?":
|
|
||||||
raise RuntimeError("Unexpected magic characters")
|
|
||||||
return chunk[1:]
|
|
210
python/src/trezorlib/transport/session.py
Normal file
210
python/src/trezorlib/transport/session.py
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
from .. import exceptions, messages, models
|
||||||
|
from .thp.protocol_v1 import ProtocolV1
|
||||||
|
from .thp.protocol_v2 import ProtocolV2
|
||||||
|
|
||||||
|
if t.TYPE_CHECKING:
|
||||||
|
from ..client import TrezorClient
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Session:
|
||||||
|
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||||
|
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||||
|
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, client: TrezorClient, id: bytes, passphrase: str | object | None = None
|
||||||
|
) -> None:
|
||||||
|
self.client = client
|
||||||
|
self._id = id
|
||||||
|
self.passphrase = passphrase
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def new(
|
||||||
|
cls, client: TrezorClient, passphrase: str | object | None, derive_cardano: bool
|
||||||
|
) -> Session:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def call(self, msg: t.Any) -> t.Any:
|
||||||
|
# TODO self.check_firmware_version()
|
||||||
|
resp = self.call_raw(msg)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if isinstance(resp, messages.PinMatrixRequest):
|
||||||
|
if self.pin_callback is None:
|
||||||
|
raise Exception # TODO
|
||||||
|
resp = self.pin_callback(self, resp)
|
||||||
|
elif isinstance(resp, messages.PassphraseRequest):
|
||||||
|
if self.passphrase_callback is None:
|
||||||
|
raise Exception # TODO
|
||||||
|
resp = self.passphrase_callback(self, resp)
|
||||||
|
elif isinstance(resp, messages.ButtonRequest):
|
||||||
|
if self.button_callback is None:
|
||||||
|
raise Exception # TODO
|
||||||
|
resp = self.button_callback(self, resp)
|
||||||
|
elif isinstance(resp, messages.Failure):
|
||||||
|
if resp.code == messages.FailureType.ActionCancelled:
|
||||||
|
raise exceptions.Cancelled
|
||||||
|
raise exceptions.TrezorFailure(resp)
|
||||||
|
else:
|
||||||
|
return resp
|
||||||
|
|
||||||
|
def call_raw(self, msg: t.Any) -> t.Any:
|
||||||
|
self._write(msg)
|
||||||
|
return self._read()
|
||||||
|
|
||||||
|
def _write(self, msg: t.Any) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _read(self) -> t.Any:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def refresh_features(self) -> None:
|
||||||
|
self.client.refresh_features()
|
||||||
|
|
||||||
|
def end(self) -> t.Any:
|
||||||
|
return self.call(messages.EndSession())
|
||||||
|
|
||||||
|
def ping(self, message: str, button_protection: bool | None = None) -> str:
|
||||||
|
resp: messages.Success = self.call(
|
||||||
|
messages.Ping(message=message, button_protection=button_protection)
|
||||||
|
)
|
||||||
|
return resp.message or ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def features(self) -> messages.Features:
|
||||||
|
return self.client.features
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self) -> models.TrezorModel:
|
||||||
|
return self.client.model
|
||||||
|
|
||||||
|
@property
|
||||||
|
def version(self) -> t.Tuple[int, int, int]:
|
||||||
|
return self.client.version
|
||||||
|
|
||||||
|
@property
|
||||||
|
def id(self) -> bytes:
|
||||||
|
return self._id
|
||||||
|
|
||||||
|
@id.setter
|
||||||
|
def id(self, value: bytes) -> None:
|
||||||
|
if not isinstance(value, bytes):
|
||||||
|
raise ValueError("id must be of type bytes")
|
||||||
|
self._id = value
|
||||||
|
|
||||||
|
|
||||||
|
class SessionV1(Session):
|
||||||
|
derive_cardano: bool | None = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def new(
|
||||||
|
cls,
|
||||||
|
client: TrezorClient,
|
||||||
|
passphrase: str | object = "",
|
||||||
|
derive_cardano: bool = False,
|
||||||
|
session_id: bytes | None = None,
|
||||||
|
) -> SessionV1:
|
||||||
|
assert isinstance(client.protocol, ProtocolV1)
|
||||||
|
session = SessionV1(client, id=session_id or b"")
|
||||||
|
|
||||||
|
session._init_callbacks()
|
||||||
|
session.passphrase = passphrase
|
||||||
|
session.derive_cardano = derive_cardano
|
||||||
|
session.init_session(session.derive_cardano)
|
||||||
|
return session
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1:
|
||||||
|
assert isinstance(client.protocol, ProtocolV1)
|
||||||
|
session = SessionV1(client, session_id)
|
||||||
|
session.init_session()
|
||||||
|
return session
|
||||||
|
|
||||||
|
def _init_callbacks(self) -> None:
|
||||||
|
self.button_callback = self.client.button_callback
|
||||||
|
if self.button_callback is None:
|
||||||
|
self.button_callback = _callback_button
|
||||||
|
self.pin_callback = self.client.pin_callback
|
||||||
|
self.passphrase_callback = self.client.passphrase_callback
|
||||||
|
|
||||||
|
def _write(self, msg: t.Any) -> None:
|
||||||
|
if t.TYPE_CHECKING:
|
||||||
|
assert isinstance(self.client.protocol, ProtocolV1)
|
||||||
|
self.client.protocol.write(msg)
|
||||||
|
|
||||||
|
def _read(self) -> t.Any:
|
||||||
|
if t.TYPE_CHECKING:
|
||||||
|
assert isinstance(self.client.protocol, ProtocolV1)
|
||||||
|
return self.client.protocol.read()
|
||||||
|
|
||||||
|
def init_session(self, derive_cardano: bool | None = None):
|
||||||
|
if self.id == b"":
|
||||||
|
session_id = None
|
||||||
|
else:
|
||||||
|
session_id = self.id
|
||||||
|
resp: messages.Features = self.call_raw(
|
||||||
|
messages.Initialize(session_id=session_id, derive_cardano=derive_cardano)
|
||||||
|
)
|
||||||
|
if isinstance(self.passphrase, str):
|
||||||
|
self.passphrase_callback = _send_passphrase
|
||||||
|
self._id = resp.session_id
|
||||||
|
|
||||||
|
|
||||||
|
def _send_passphrase(session: Session, resp: t.Any) -> None:
|
||||||
|
assert isinstance(session.passphrase, str)
|
||||||
|
return session.call(messages.PassphraseAck(passphrase=session.passphrase))
|
||||||
|
|
||||||
|
|
||||||
|
def _callback_button(session: Session, msg: t.Any) -> t.Any:
|
||||||
|
print("Please confirm action on your Trezor device.") # TODO how to handle UI?
|
||||||
|
return session.call(messages.ButtonAck())
|
||||||
|
|
||||||
|
|
||||||
|
class SessionV2(Session):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def new(
|
||||||
|
cls, client: TrezorClient, passphrase: str | None, derive_cardano: bool
|
||||||
|
) -> SessionV2:
|
||||||
|
assert isinstance(client.protocol, ProtocolV2)
|
||||||
|
session = cls(client, b"\x00")
|
||||||
|
new_session: messages.ThpNewSession = session.call(
|
||||||
|
messages.ThpCreateNewSession(
|
||||||
|
passphrase=passphrase, derive_cardano=derive_cardano
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert new_session.new_session_id is not None
|
||||||
|
session_id = new_session.new_session_id
|
||||||
|
session.update_id_and_sid(session_id.to_bytes(1, "big"))
|
||||||
|
return session
|
||||||
|
|
||||||
|
def __init__(self, client: TrezorClient, id: bytes) -> None:
|
||||||
|
super().__init__(client, id)
|
||||||
|
assert isinstance(client.protocol, ProtocolV2)
|
||||||
|
|
||||||
|
self.pin_callback = client.pin_callback
|
||||||
|
self.button_callback = client.button_callback
|
||||||
|
if self.button_callback is None:
|
||||||
|
self.button_callback = _callback_button
|
||||||
|
self.channel: ProtocolV2 = client.protocol.get_channel()
|
||||||
|
self.update_id_and_sid(id)
|
||||||
|
|
||||||
|
def _write(self, msg: t.Any) -> None:
|
||||||
|
LOG.debug("writing message %s", type(msg))
|
||||||
|
self.channel.write(self.sid, msg)
|
||||||
|
|
||||||
|
def _read(self) -> t.Any:
|
||||||
|
msg = self.channel.read(self.sid)
|
||||||
|
LOG.debug("reading message %s", type(msg))
|
||||||
|
return msg
|
||||||
|
|
||||||
|
def update_id_and_sid(self, id: bytes) -> None:
|
||||||
|
self._id = id
|
||||||
|
self.sid = int.from_bytes(id, "big") # TODO update to extract only sid
|
102
python/src/trezorlib/transport/thp/alternating_bit_protocol.py
Normal file
102
python/src/trezorlib/transport/thp/alternating_bit_protocol.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
# from storage.cache_thp import ChannelCache
|
||||||
|
# from trezor import log
|
||||||
|
# from trezor.wire.thp import ThpError
|
||||||
|
|
||||||
|
|
||||||
|
# def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool:
|
||||||
|
# """
|
||||||
|
# Checks if:
|
||||||
|
# - an ACK message is expected
|
||||||
|
# - the received ACK message acknowledges correct sequence number (bit)
|
||||||
|
# """
|
||||||
|
# if not _is_ack_expected(cache):
|
||||||
|
# return False
|
||||||
|
|
||||||
|
# if not _has_ack_correct_sync_bit(cache, ack_bit):
|
||||||
|
# return False
|
||||||
|
|
||||||
|
# return True
|
||||||
|
|
||||||
|
|
||||||
|
# def _is_ack_expected(cache: ChannelCache) -> bool:
|
||||||
|
# is_expected: bool = not is_sending_allowed(cache)
|
||||||
|
# if __debug__ and not is_expected:
|
||||||
|
# log.debug(__name__, "Received unexpected ACK message")
|
||||||
|
# return is_expected
|
||||||
|
|
||||||
|
|
||||||
|
# def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
|
||||||
|
# is_correct: bool = get_send_seq_bit(cache) == sync_bit
|
||||||
|
# if __debug__ and not is_correct:
|
||||||
|
# log.debug(__name__, "Received ACK message with wrong ack bit")
|
||||||
|
# return is_correct
|
||||||
|
|
||||||
|
|
||||||
|
# def is_sending_allowed(cache: ChannelCache) -> bool:
|
||||||
|
# """
|
||||||
|
# Checks whether sending a message in the provided channel is allowed.
|
||||||
|
|
||||||
|
# Note: Sending a message in a channel before receipt of ACK message for the previously
|
||||||
|
# sent message (in the channel) is prohibited, as it can lead to desynchronization.
|
||||||
|
# """
|
||||||
|
# return bool(cache.sync >> 7)
|
||||||
|
|
||||||
|
|
||||||
|
# def get_send_seq_bit(cache: ChannelCache) -> int:
|
||||||
|
# """
|
||||||
|
# Returns the sequential number (bit) of the next message to be sent
|
||||||
|
# in the provided channel.
|
||||||
|
# """
|
||||||
|
# return (cache.sync & 0x20) >> 5
|
||||||
|
|
||||||
|
|
||||||
|
# def get_expected_receive_seq_bit(cache: ChannelCache) -> int:
|
||||||
|
# """
|
||||||
|
# Returns the (expected) sequential number (bit) of the next message
|
||||||
|
# to be received in the provided channel.
|
||||||
|
# """
|
||||||
|
# return (cache.sync & 0x40) >> 6
|
||||||
|
|
||||||
|
|
||||||
|
# def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None:
|
||||||
|
# """
|
||||||
|
# Set the flag whether sending a message in this channel is allowed or not.
|
||||||
|
# """
|
||||||
|
# cache.sync &= 0x7F
|
||||||
|
# if sending_allowed:
|
||||||
|
# cache.sync |= 0x80
|
||||||
|
|
||||||
|
|
||||||
|
# def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
|
||||||
|
# """
|
||||||
|
# Set the expected sequential number (bit) of the next message to be received
|
||||||
|
# in the provided channel
|
||||||
|
# """
|
||||||
|
# if __debug__:
|
||||||
|
# log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit)
|
||||||
|
# if seq_bit not in (0, 1):
|
||||||
|
# raise ThpError("Unexpected receive sync bit")
|
||||||
|
|
||||||
|
# # set second bit to "seq_bit" value
|
||||||
|
# cache.sync &= 0xBF
|
||||||
|
# if seq_bit:
|
||||||
|
# cache.sync |= 0x40
|
||||||
|
|
||||||
|
|
||||||
|
# def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
|
||||||
|
# if seq_bit not in (0, 1):
|
||||||
|
# raise ThpError("Unexpected send seq bit")
|
||||||
|
# if __debug__:
|
||||||
|
# log.debug(__name__, "setting sync send seq bit to %d", seq_bit)
|
||||||
|
# # set third bit to "seq_bit" value
|
||||||
|
# cache.sync &= 0xDF
|
||||||
|
# if seq_bit:
|
||||||
|
# cache.sync |= 0x20
|
||||||
|
|
||||||
|
|
||||||
|
# def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None:
|
||||||
|
# """
|
||||||
|
# Set the sequential bit of the "next message to be send" to the opposite value,
|
||||||
|
# i.e. 1 -> 0 and 0 -> 1
|
||||||
|
# """
|
||||||
|
# _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache))
|
40
python/src/trezorlib/transport/thp/channel_data.py
Normal file
40
python/src/trezorlib/transport/thp/channel_data.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from binascii import hexlify
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelData:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
protocol_version: int,
|
||||||
|
transport_path: str,
|
||||||
|
channel_id: int,
|
||||||
|
key_request: bytes,
|
||||||
|
key_response: bytes,
|
||||||
|
nonce_request: int,
|
||||||
|
nonce_response: int,
|
||||||
|
sync_bit_send: int,
|
||||||
|
sync_bit_receive: int,
|
||||||
|
) -> None:
|
||||||
|
self.protocol_version: int = protocol_version
|
||||||
|
self.transport_path: str = transport_path
|
||||||
|
self.channel_id: int = channel_id
|
||||||
|
self.key_request: str = hexlify(key_request).decode()
|
||||||
|
self.key_response: str = hexlify(key_response).decode()
|
||||||
|
self.nonce_request: int = nonce_request
|
||||||
|
self.nonce_response: int = nonce_response
|
||||||
|
self.sync_bit_receive: int = sync_bit_receive
|
||||||
|
self.sync_bit_send: int = sync_bit_send
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
"protocol_version": self.protocol_version,
|
||||||
|
"transport_path": self.transport_path,
|
||||||
|
"channel_id": self.channel_id,
|
||||||
|
"key_request": self.key_request,
|
||||||
|
"key_response": self.key_response,
|
||||||
|
"nonce_request": self.nonce_request,
|
||||||
|
"nonce_response": self.nonce_response,
|
||||||
|
"sync_bit_send": self.sync_bit_send,
|
||||||
|
"sync_bit_receive": self.sync_bit_receive,
|
||||||
|
}
|
146
python/src/trezorlib/transport/thp/channel_database.py
Normal file
146
python/src/trezorlib/transport/thp/channel_database.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
from ..thp.channel_data import ChannelData
|
||||||
|
from .protocol_and_channel import ProtocolAndChannel
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
db: "ChannelDatabase | None" = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_channel_db() -> ChannelDatabase:
|
||||||
|
if db is None:
|
||||||
|
set_channel_database(should_not_store=True)
|
||||||
|
assert db is not None
|
||||||
|
return db
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelDatabase:
|
||||||
|
|
||||||
|
def load_stored_channels(self) -> t.List[ChannelData]: ...
|
||||||
|
def clear_stored_channels(self) -> None: ...
|
||||||
|
def read_all_channels(self) -> t.List: ...
|
||||||
|
def save_all_channels(self, channels: t.List[t.Dict]) -> None: ...
|
||||||
|
def save_channel(self, new_channel: ProtocolAndChannel): ...
|
||||||
|
def remove_channel(self, transport_path: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class DummyChannelDatabase(ChannelDatabase):
|
||||||
|
|
||||||
|
def load_stored_channels(self) -> t.List[ChannelData]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def clear_stored_channels(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def read_all_channels(self) -> t.List:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def save_all_channels(self, channels: t.List[t.Dict]) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
def save_channel(self, new_channel: ProtocolAndChannel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def remove_channel(self, transport_path: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class JsonChannelDatabase(ChannelDatabase):
|
||||||
|
def __init__(self, data_path: str) -> None:
|
||||||
|
self.data_path = data_path
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def load_stored_channels(self) -> t.List[ChannelData]:
|
||||||
|
dicts = self.read_all_channels()
|
||||||
|
return [dict_to_channel_data(d) for d in dicts]
|
||||||
|
|
||||||
|
def clear_stored_channels(self) -> None:
|
||||||
|
LOG.debug("Clearing contents of %s", self.data_path)
|
||||||
|
with open(self.data_path, "w") as f:
|
||||||
|
json.dump([], f)
|
||||||
|
try:
|
||||||
|
os.remove(self.data_path)
|
||||||
|
except Exception as e:
|
||||||
|
LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e)))
|
||||||
|
|
||||||
|
def read_all_channels(self) -> t.List:
|
||||||
|
ensure_file_exists(self.data_path)
|
||||||
|
with open(self.data_path, "r") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
def save_all_channels(self, channels: t.List[t.Dict]) -> None:
|
||||||
|
LOG.debug("saving all channels")
|
||||||
|
with open(self.data_path, "w") as f:
|
||||||
|
json.dump(channels, f, indent=4)
|
||||||
|
|
||||||
|
def save_channel(self, new_channel: ProtocolAndChannel):
|
||||||
|
|
||||||
|
LOG.debug("save channel")
|
||||||
|
channels = self.read_all_channels()
|
||||||
|
transport_path = new_channel.transport.get_path()
|
||||||
|
|
||||||
|
# If the channel is found in database: replace the old entry by the new
|
||||||
|
for i, channel in enumerate(channels):
|
||||||
|
if channel["transport_path"] == transport_path:
|
||||||
|
LOG.debug("Modified channel entry for %s", transport_path)
|
||||||
|
channels[i] = new_channel.get_channel_data().to_dict()
|
||||||
|
self.save_all_channels(channels)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Channel was not found: add a new channel entry
|
||||||
|
LOG.debug("Created a new channel entry on path %s", transport_path)
|
||||||
|
channels.append(new_channel.get_channel_data().to_dict())
|
||||||
|
self.save_all_channels(channels)
|
||||||
|
|
||||||
|
def remove_channel(self, transport_path: str) -> None:
|
||||||
|
LOG.debug(
|
||||||
|
"Removing channel with path %s from the channel database.",
|
||||||
|
transport_path,
|
||||||
|
)
|
||||||
|
channels = self.read_all_channels()
|
||||||
|
remaining_channels = [
|
||||||
|
ch for ch in channels if ch["transport_path"] != transport_path
|
||||||
|
]
|
||||||
|
self.save_all_channels(remaining_channels)
|
||||||
|
|
||||||
|
|
||||||
|
def dict_to_channel_data(dict: t.Dict) -> ChannelData:
|
||||||
|
return ChannelData(
|
||||||
|
protocol_version=dict["protocol_version"],
|
||||||
|
transport_path=dict["transport_path"],
|
||||||
|
channel_id=dict["channel_id"],
|
||||||
|
key_request=bytes.fromhex(dict["key_request"]),
|
||||||
|
key_response=bytes.fromhex(dict["key_response"]),
|
||||||
|
nonce_request=dict["nonce_request"],
|
||||||
|
nonce_response=dict["nonce_response"],
|
||||||
|
sync_bit_send=dict["sync_bit_send"],
|
||||||
|
sync_bit_receive=dict["sync_bit_receive"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_file_exists(file_path: str) -> None:
|
||||||
|
LOG.debug("checking if file %s exists", file_path)
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||||
|
LOG.debug("File %s does not exist. Creating a new one.", file_path)
|
||||||
|
with open(file_path, "w") as f:
|
||||||
|
json.dump([], f)
|
||||||
|
|
||||||
|
|
||||||
|
def set_channel_database(should_not_store: bool):
|
||||||
|
global db
|
||||||
|
if should_not_store:
|
||||||
|
db = DummyChannelDatabase()
|
||||||
|
else:
|
||||||
|
from platformdirs import user_cache_dir
|
||||||
|
|
||||||
|
APP_NAME = "@trezor" # TODO
|
||||||
|
DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json")
|
||||||
|
|
||||||
|
db = JsonChannelDatabase(DATA_PATH)
|
19
python/src/trezorlib/transport/thp/checksum.py
Normal file
19
python/src/trezorlib/transport/thp/checksum.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import zlib
|
||||||
|
|
||||||
|
CHECKSUM_LENGTH = 4
|
||||||
|
|
||||||
|
|
||||||
|
def compute(data: bytes) -> bytes:
|
||||||
|
"""
|
||||||
|
Returns a CRC-32 checksum of the provided `data`.
|
||||||
|
"""
|
||||||
|
return zlib.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid(checksum: bytes, data: bytes) -> bool:
|
||||||
|
"""
|
||||||
|
Checks whether the CRC-32 checksum of the `data` is the same
|
||||||
|
as the checksum provided in `checksum`.
|
||||||
|
"""
|
||||||
|
data_checksum = compute(data)
|
||||||
|
return checksum == data_checksum
|
59
python/src/trezorlib/transport/thp/control_byte.py
Normal file
59
python/src/trezorlib/transport/thp/control_byte.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
CODEC_V1 = 0x3F
|
||||||
|
CONTINUATION_PACKET = 0x80
|
||||||
|
HANDSHAKE_INIT_REQ = 0x00
|
||||||
|
HANDSHAKE_INIT_RES = 0x01
|
||||||
|
HANDSHAKE_COMP_REQ = 0x02
|
||||||
|
HANDSHAKE_COMP_RES = 0x03
|
||||||
|
ENCRYPTED_TRANSPORT = 0x04
|
||||||
|
|
||||||
|
CONTINUATION_PACKET_MASK = 0x80
|
||||||
|
ACK_MASK = 0xF7
|
||||||
|
DATA_MASK = 0xE7
|
||||||
|
|
||||||
|
ACK_MESSAGE = 0x20
|
||||||
|
_ERROR = 0x42
|
||||||
|
CHANNEL_ALLOCATION_REQ = 0x40
|
||||||
|
_CHANNEL_ALLOCATION_RES = 0x41
|
||||||
|
|
||||||
|
TREZOR_STATE_UNPAIRED = b"\x00"
|
||||||
|
TREZOR_STATE_PAIRED = b"\x01"
|
||||||
|
|
||||||
|
|
||||||
|
def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int:
|
||||||
|
if seq_bit == 0:
|
||||||
|
return ctrl_byte & 0xEF
|
||||||
|
if seq_bit == 1:
|
||||||
|
return ctrl_byte | 0x10
|
||||||
|
raise Exception("Unexpected sequence bit")
|
||||||
|
|
||||||
|
|
||||||
|
def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int:
|
||||||
|
if ack_bit == 0:
|
||||||
|
return ctrl_byte & 0xF7
|
||||||
|
if ack_bit == 1:
|
||||||
|
return ctrl_byte | 0x08
|
||||||
|
raise Exception("Unexpected acknowledgement bit")
|
||||||
|
|
||||||
|
|
||||||
|
def get_seq_bit(ctrl_byte: int) -> int:
|
||||||
|
return (ctrl_byte & 0x10) >> 4
|
||||||
|
|
||||||
|
|
||||||
|
def is_ack(ctrl_byte: int) -> bool:
|
||||||
|
return ctrl_byte & ACK_MASK == ACK_MESSAGE
|
||||||
|
|
||||||
|
|
||||||
|
def is_continuation(ctrl_byte: int) -> bool:
|
||||||
|
return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET
|
||||||
|
|
||||||
|
|
||||||
|
def is_encrypted_transport(ctrl_byte: int) -> bool:
|
||||||
|
return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
|
||||||
|
|
||||||
|
|
||||||
|
def is_handshake_init_req(ctrl_byte: int) -> bool:
|
||||||
|
return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ
|
||||||
|
|
||||||
|
|
||||||
|
def is_handshake_comp_req(ctrl_byte: int) -> bool:
|
||||||
|
return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ
|
116
python/src/trezorlib/transport/thp/curve25519.py
Normal file
116
python/src/trezorlib/transport/thp/curve25519.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
p = 2**255 - 19
|
||||||
|
J = 486662
|
||||||
|
|
||||||
|
c3 = 19681161376707505956807079304988542015446066515923890162744021073123829784752 # sqrt(-1)
|
||||||
|
c4 = 7237005577332262213973186563042994240829374041602535252466099000494570602493 # (p - 5) // 8
|
||||||
|
a24 = 121666 # (J + 2) // 4
|
||||||
|
|
||||||
|
|
||||||
|
def decode_scalar(scalar: bytes) -> int:
|
||||||
|
# decodeScalar25519 from
|
||||||
|
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
|
||||||
|
|
||||||
|
if len(scalar) != 32:
|
||||||
|
raise ValueError("Invalid length of scalar")
|
||||||
|
|
||||||
|
array = bytearray(scalar)
|
||||||
|
array[0] &= 248
|
||||||
|
array[31] &= 127
|
||||||
|
array[31] |= 64
|
||||||
|
|
||||||
|
return int.from_bytes(array, "little")
|
||||||
|
|
||||||
|
|
||||||
|
def decode_coordinate(coordinate: bytes) -> int:
|
||||||
|
# decodeUCoordinate from
|
||||||
|
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
|
||||||
|
if len(coordinate) != 32:
|
||||||
|
raise ValueError("Invalid length of coordinate")
|
||||||
|
|
||||||
|
array = bytearray(coordinate)
|
||||||
|
array[-1] &= 0x7F
|
||||||
|
return int.from_bytes(array, "little") % p
|
||||||
|
|
||||||
|
|
||||||
|
def encode_coordinate(coordinate: int) -> bytes:
|
||||||
|
# encodeUCoordinate from
|
||||||
|
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
|
||||||
|
return coordinate.to_bytes(32, "little")
|
||||||
|
|
||||||
|
|
||||||
|
def get_private_key(secret: bytes) -> bytes:
|
||||||
|
return decode_scalar(secret).to_bytes(32, "little")
|
||||||
|
|
||||||
|
|
||||||
|
def get_public_key(private_key: bytes) -> bytes:
|
||||||
|
base_point = int.to_bytes(9, 32, "little")
|
||||||
|
return multiply(private_key, base_point)
|
||||||
|
|
||||||
|
|
||||||
|
def multiply(private_scalar: bytes, public_point: bytes):
|
||||||
|
# X25519 from
|
||||||
|
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
|
||||||
|
|
||||||
|
def ladder_operation(
|
||||||
|
x1: int, x2: int, z2: int, x3: int, z3: int
|
||||||
|
) -> Tuple[int, int, int, int]:
|
||||||
|
# https://hyperelliptic.org/EFD/g1p/auto-montgom-xz.html#ladder-ladd-1987-m-3
|
||||||
|
# (x4, z4) = 2 * (x2, z2)
|
||||||
|
# (x5, z5) = (x2, z2) + (x3, z3)
|
||||||
|
# where (x1, 1) = (x3, z3) - (x2, z2)
|
||||||
|
|
||||||
|
a = (x2 + z2) % p
|
||||||
|
aa = (a * a) % p
|
||||||
|
b = (x2 - z2) % p
|
||||||
|
bb = (b * b) % p
|
||||||
|
e = (aa - bb) % p
|
||||||
|
c = (x3 + z3) % p
|
||||||
|
d = (x3 - z3) % p
|
||||||
|
da = (d * a) % p
|
||||||
|
cb = (c * b) % p
|
||||||
|
t0 = (da + cb) % p
|
||||||
|
x5 = (t0 * t0) % p
|
||||||
|
t1 = (da - cb) % p
|
||||||
|
t2 = (t1 * t1) % p
|
||||||
|
z5 = (x1 * t2) % p
|
||||||
|
x4 = (aa * bb) % p
|
||||||
|
t3 = (a24 * e) % p
|
||||||
|
t4 = (bb + t3) % p
|
||||||
|
z4 = (e * t4) % p
|
||||||
|
|
||||||
|
return x4, z4, x5, z5
|
||||||
|
|
||||||
|
def conditional_swap(first: int, second: int, condition: int):
|
||||||
|
# Returns (second, first) if condition is true and (first, second) otherwise
|
||||||
|
# Must be implemented in a way that it is constant time
|
||||||
|
true_mask = -condition
|
||||||
|
false_mask = ~true_mask
|
||||||
|
return (first & false_mask) | (second & true_mask), (second & false_mask) | (
|
||||||
|
first & true_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
k = decode_scalar(private_scalar)
|
||||||
|
u = decode_coordinate(public_point)
|
||||||
|
|
||||||
|
x_1 = u
|
||||||
|
x_2 = 1
|
||||||
|
z_2 = 0
|
||||||
|
x_3 = u
|
||||||
|
z_3 = 1
|
||||||
|
swap = 0
|
||||||
|
|
||||||
|
for i in reversed(range(256)):
|
||||||
|
bit = (k >> i) & 1
|
||||||
|
swap = bit ^ swap
|
||||||
|
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
|
||||||
|
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
|
||||||
|
swap = bit
|
||||||
|
x_2, z_2, x_3, z_3 = ladder_operation(x_1, x_2, z_2, x_3, z_3)
|
||||||
|
|
||||||
|
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
|
||||||
|
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
|
||||||
|
|
||||||
|
x = pow(z_2, p - 2, p) * x_2 % p
|
||||||
|
return encode_coordinate(x)
|
82
python/src/trezorlib/transport/thp/message_header.py
Normal file
82
python/src/trezorlib/transport/thp/message_header.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import struct
|
||||||
|
|
||||||
|
CODEC_V1 = 0x3F
|
||||||
|
CONTINUATION_PACKET = 0x80
|
||||||
|
HANDSHAKE_INIT_REQ = 0x00
|
||||||
|
HANDSHAKE_INIT_RES = 0x01
|
||||||
|
HANDSHAKE_COMP_REQ = 0x02
|
||||||
|
HANDSHAKE_COMP_RES = 0x03
|
||||||
|
ENCRYPTED_TRANSPORT = 0x04
|
||||||
|
|
||||||
|
CONTINUATION_PACKET_MASK = 0x80
|
||||||
|
ACK_MASK = 0xF7
|
||||||
|
DATA_MASK = 0xE7
|
||||||
|
|
||||||
|
ACK_MESSAGE = 0x20
|
||||||
|
_ERROR = 0x42
|
||||||
|
CHANNEL_ALLOCATION_REQ = 0x40
|
||||||
|
_CHANNEL_ALLOCATION_RES = 0x41
|
||||||
|
|
||||||
|
TREZOR_STATE_UNPAIRED = b"\x00"
|
||||||
|
TREZOR_STATE_PAIRED = b"\x01"
|
||||||
|
|
||||||
|
BROADCAST_CHANNEL_ID = 0xFFFF
|
||||||
|
|
||||||
|
|
||||||
|
class MessageHeader:
|
||||||
|
format_str_init = ">BHH"
|
||||||
|
format_str_cont = ">BH"
|
||||||
|
|
||||||
|
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
|
||||||
|
self.ctrl_byte = ctrl_byte
|
||||||
|
self.cid = cid
|
||||||
|
self.data_length = length
|
||||||
|
|
||||||
|
def to_bytes_init(self) -> bytes:
|
||||||
|
return struct.pack(
|
||||||
|
self.format_str_init, self.ctrl_byte, self.cid, self.data_length
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_bytes_cont(self) -> bytes:
|
||||||
|
return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid)
|
||||||
|
|
||||||
|
def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
|
||||||
|
struct.pack_into(
|
||||||
|
self.format_str_init,
|
||||||
|
buffer,
|
||||||
|
buffer_offset,
|
||||||
|
self.ctrl_byte,
|
||||||
|
self.cid,
|
||||||
|
self.data_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
|
||||||
|
struct.pack_into(
|
||||||
|
self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_ack(self) -> bool:
|
||||||
|
return self.ctrl_byte & ACK_MASK == ACK_MESSAGE
|
||||||
|
|
||||||
|
def is_channel_allocation_response(self):
|
||||||
|
return (
|
||||||
|
self.cid == BROADCAST_CHANNEL_ID
|
||||||
|
and self.ctrl_byte == _CHANNEL_ALLOCATION_RES
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_handshake_init_response(self) -> bool:
|
||||||
|
return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES
|
||||||
|
|
||||||
|
def is_handshake_comp_response(self) -> bool:
|
||||||
|
return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES
|
||||||
|
|
||||||
|
def is_encrypted_transport(self) -> bool:
|
||||||
|
return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_error_header(cls, cid: int, length: int):
|
||||||
|
return cls(_ERROR, cid, length)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_channel_allocation_request_header(cls, length: int):
|
||||||
|
return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length)
|
32
python/src/trezorlib/transport/thp/protocol_and_channel.py
Normal file
32
python/src/trezorlib/transport/thp/protocol_and_channel.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from ... import messages
|
||||||
|
from ...mapping import ProtobufMapping
|
||||||
|
from .. import Transport
|
||||||
|
from ..thp.channel_data import ChannelData
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolAndChannel:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
transport: Transport,
|
||||||
|
mapping: ProtobufMapping,
|
||||||
|
channel_data: ChannelData | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.transport = transport
|
||||||
|
self.mapping = mapping
|
||||||
|
self.channel_keys = channel_data
|
||||||
|
|
||||||
|
def get_features(self) -> messages.Features:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_channel_data(self) -> ChannelData:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def update_features(self) -> None:
|
||||||
|
raise NotImplementedError
|
97
python/src/trezorlib/transport/thp/protocol_v1.py
Normal file
97
python/src/trezorlib/transport/thp/protocol_v1.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import struct
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
from ... import exceptions, messages
|
||||||
|
from ...log import DUMP_BYTES
|
||||||
|
from .protocol_and_channel import ProtocolAndChannel
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolV1(ProtocolAndChannel):
|
||||||
|
HEADER_LEN = struct.calcsize(">HL")
|
||||||
|
_features: messages.Features | None = None
|
||||||
|
|
||||||
|
def get_features(self) -> messages.Features:
|
||||||
|
if self._features is None:
|
||||||
|
self.update_features()
|
||||||
|
assert self._features is not None
|
||||||
|
return self._features
|
||||||
|
|
||||||
|
def update_features(self) -> None:
|
||||||
|
self.write(messages.GetFeatures())
|
||||||
|
resp = self.read()
|
||||||
|
if not isinstance(resp, messages.Features):
|
||||||
|
raise exceptions.TrezorException("Unexpected response to GetFeatures")
|
||||||
|
self._features = resp
|
||||||
|
|
||||||
|
def read(self) -> t.Any:
|
||||||
|
msg_type, msg_bytes = self._read()
|
||||||
|
LOG.log(
|
||||||
|
DUMP_BYTES,
|
||||||
|
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||||
|
)
|
||||||
|
msg = self.mapping.decode(msg_type, msg_bytes)
|
||||||
|
LOG.debug(
|
||||||
|
f"received message: {msg.__class__.__name__}",
|
||||||
|
extra={"protobuf": msg},
|
||||||
|
)
|
||||||
|
self.transport.close()
|
||||||
|
return msg
|
||||||
|
|
||||||
|
def write(self, msg: t.Any) -> None:
|
||||||
|
LOG.debug(
|
||||||
|
f"sending message: {msg.__class__.__name__}",
|
||||||
|
extra={"protobuf": msg},
|
||||||
|
)
|
||||||
|
msg_type, msg_bytes = self.mapping.encode(msg)
|
||||||
|
LOG.log(
|
||||||
|
DUMP_BYTES,
|
||||||
|
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||||
|
)
|
||||||
|
self._write(msg_type, msg_bytes)
|
||||||
|
|
||||||
|
def _write(self, message_type: int, message_data: bytes) -> None:
|
||||||
|
chunk_size = self.transport.CHUNK_SIZE
|
||||||
|
header = struct.pack(">HL", message_type, len(message_data))
|
||||||
|
buffer = bytearray(b"##" + header + message_data)
|
||||||
|
|
||||||
|
while buffer:
|
||||||
|
# Report ID, data padded to 63 bytes
|
||||||
|
chunk = b"?" + buffer[: chunk_size - 1]
|
||||||
|
chunk = chunk.ljust(chunk_size, b"\x00")
|
||||||
|
self.transport.write_chunk(chunk)
|
||||||
|
buffer = buffer[63:]
|
||||||
|
|
||||||
|
def _read(self) -> t.Tuple[int, bytes]:
|
||||||
|
buffer = bytearray()
|
||||||
|
# Read header with first part of message data
|
||||||
|
msg_type, datalen, first_chunk = self.read_first()
|
||||||
|
buffer.extend(first_chunk)
|
||||||
|
|
||||||
|
# Read the rest of the message
|
||||||
|
while len(buffer) < datalen:
|
||||||
|
buffer.extend(self.read_next())
|
||||||
|
|
||||||
|
return msg_type, buffer[:datalen]
|
||||||
|
|
||||||
|
def read_first(self) -> t.Tuple[int, int, bytes]:
|
||||||
|
chunk = self.transport.read_chunk()
|
||||||
|
if chunk[:3] != b"?##":
|
||||||
|
raise RuntimeError("Unexpected magic characters")
|
||||||
|
try:
|
||||||
|
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||||
|
except Exception:
|
||||||
|
raise RuntimeError("Cannot parse header")
|
||||||
|
|
||||||
|
data = chunk[3 + self.HEADER_LEN :]
|
||||||
|
return msg_type, datalen, data
|
||||||
|
|
||||||
|
def read_next(self) -> bytes:
|
||||||
|
chunk = self.transport.read_chunk()
|
||||||
|
if chunk[:1] != b"?":
|
||||||
|
raise RuntimeError("Unexpected magic characters")
|
||||||
|
return chunk[1:]
|
404
python/src/trezorlib/transport/thp/protocol_v2.py
Normal file
404
python/src/trezorlib/transport/thp/protocol_v2.py
Normal file
@ -0,0 +1,404 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import typing as t
|
||||||
|
from binascii import hexlify
|
||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
|
import click
|
||||||
|
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||||
|
|
||||||
|
from ... import exceptions, messages
|
||||||
|
from ...mapping import ProtobufMapping
|
||||||
|
from .. import Transport
|
||||||
|
from ..thp import checksum, curve25519, thp_io
|
||||||
|
from ..thp.channel_data import ChannelData
|
||||||
|
from ..thp.checksum import CHECKSUM_LENGTH
|
||||||
|
from ..thp.message_header import MessageHeader
|
||||||
|
from . import control_byte
|
||||||
|
from .channel_database import ChannelDatabase, get_channel_db
|
||||||
|
from .protocol_and_channel import ProtocolAndChannel
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MANAGEMENT_SESSION_ID: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
|
||||||
|
hash = hashlib.sha256(val_1)
|
||||||
|
hash.update(val_2)
|
||||||
|
return hash.digest()
|
||||||
|
|
||||||
|
|
||||||
|
def _hkdf(chaining_key: bytes, input: bytes):
|
||||||
|
temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest()
|
||||||
|
output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest()
|
||||||
|
ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256)
|
||||||
|
ctx_output_2.update(b"\x02")
|
||||||
|
output_2 = ctx_output_2.digest()
|
||||||
|
return (output_1, output_2)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_iv_from_nonce(nonce: int) -> bytes:
|
||||||
|
if not nonce <= 0xFFFFFFFFFFFFFFFF:
|
||||||
|
raise ValueError("Nonce overflow, terminate the channel")
|
||||||
|
return bytes(4) + nonce.to_bytes(8, "big")
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolV2(ProtocolAndChannel):
|
||||||
|
channel_id: int
|
||||||
|
channel_database: ChannelDatabase
|
||||||
|
key_request: bytes
|
||||||
|
key_response: bytes
|
||||||
|
nonce_request: int
|
||||||
|
nonce_response: int
|
||||||
|
sync_bit_send: int
|
||||||
|
sync_bit_receive: int
|
||||||
|
|
||||||
|
_has_valid_channel: bool = False
|
||||||
|
_features: messages.Features | None = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
transport: Transport,
|
||||||
|
mapping: ProtobufMapping,
|
||||||
|
channel_data: ChannelData | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.channel_database: ChannelDatabase = get_channel_db()
|
||||||
|
super().__init__(transport, mapping, channel_data)
|
||||||
|
if channel_data is not None:
|
||||||
|
self.channel_id = channel_data.channel_id
|
||||||
|
self.key_request = bytes.fromhex(channel_data.key_request)
|
||||||
|
self.key_response = bytes.fromhex(channel_data.key_response)
|
||||||
|
self.nonce_request = channel_data.nonce_request
|
||||||
|
self.nonce_response = channel_data.nonce_response
|
||||||
|
self.sync_bit_receive = channel_data.sync_bit_receive
|
||||||
|
self.sync_bit_send = channel_data.sync_bit_send
|
||||||
|
self._has_valid_channel = True
|
||||||
|
|
||||||
|
def get_channel(self) -> ProtocolV2:
|
||||||
|
if not self._has_valid_channel:
|
||||||
|
self._establish_new_channel()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def get_channel_data(self) -> ChannelData:
|
||||||
|
return ChannelData(
|
||||||
|
protocol_version=2,
|
||||||
|
transport_path=self.transport.get_path(),
|
||||||
|
channel_id=self.channel_id,
|
||||||
|
key_request=self.key_request,
|
||||||
|
key_response=self.key_response,
|
||||||
|
nonce_request=self.nonce_request,
|
||||||
|
nonce_response=self.nonce_response,
|
||||||
|
sync_bit_receive=self.sync_bit_receive,
|
||||||
|
sync_bit_send=self.sync_bit_send,
|
||||||
|
)
|
||||||
|
|
||||||
|
def read(self, session_id: int) -> t.Any:
|
||||||
|
sid, msg_type, msg_data = self.read_and_decrypt()
|
||||||
|
if sid != session_id:
|
||||||
|
raise Exception("Received messsage on a different session.")
|
||||||
|
self.channel_database.save_channel(self)
|
||||||
|
return self.mapping.decode(msg_type, msg_data)
|
||||||
|
|
||||||
|
def write(self, session_id: int, msg: t.Any) -> None:
|
||||||
|
msg_type, msg_data = self.mapping.encode(msg)
|
||||||
|
self._encrypt_and_write(session_id, msg_type, msg_data)
|
||||||
|
self.channel_database.save_channel(self)
|
||||||
|
|
||||||
|
def get_features(self) -> messages.Features:
|
||||||
|
if not self._has_valid_channel:
|
||||||
|
self._establish_new_channel()
|
||||||
|
if self._features is None:
|
||||||
|
self.update_features()
|
||||||
|
assert self._features is not None
|
||||||
|
return self._features
|
||||||
|
|
||||||
|
def update_features(self) -> None:
|
||||||
|
message = messages.GetFeatures()
|
||||||
|
message_type, message_data = self.mapping.encode(message)
|
||||||
|
self.session_id: int = 0
|
||||||
|
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
|
||||||
|
_ = self._read_until_valid_crc_check() # TODO check ACK
|
||||||
|
_, msg_type, msg_data = self.read_and_decrypt()
|
||||||
|
features = self.mapping.decode(msg_type, msg_data)
|
||||||
|
if not isinstance(features, messages.Features):
|
||||||
|
raise exceptions.TrezorException("Unexpected response to GetFeatures")
|
||||||
|
self._features = features
|
||||||
|
|
||||||
|
def _establish_new_channel(self) -> None:
|
||||||
|
self.sync_bit_send = 0
|
||||||
|
self.sync_bit_receive = 0
|
||||||
|
# Send channel allocation request
|
||||||
|
# Note that [:8] on the following line is required when tests use
|
||||||
|
# WITH_MOCK_URANDOM. Without [:8] such tests will (almost always) fail.
|
||||||
|
channel_id_request_nonce = os.urandom(8)[:8]
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(
|
||||||
|
self.transport,
|
||||||
|
MessageHeader.get_channel_allocation_request_header(12),
|
||||||
|
channel_id_request_nonce,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read channel allocation response
|
||||||
|
header, payload = self._read_until_valid_crc_check()
|
||||||
|
if not self._is_valid_channel_allocation_response(
|
||||||
|
header, payload, channel_id_request_nonce
|
||||||
|
):
|
||||||
|
# TODO raise exception here, I guess
|
||||||
|
raise Exception("Invalid channel allocation response.")
|
||||||
|
|
||||||
|
self.channel_id = int.from_bytes(payload[8:10], "big")
|
||||||
|
self.device_properties = payload[10:]
|
||||||
|
|
||||||
|
# Send handshake init request
|
||||||
|
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
|
||||||
|
# Note that [:32] on the following line is required when tests use
|
||||||
|
# WITH_MOCK_URANDOM. Without [:32] such tests will (almost always) fail.
|
||||||
|
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)[:32])
|
||||||
|
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
|
||||||
|
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(
|
||||||
|
self.transport, ha_init_req_header, host_ephemeral_pubkey
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read ACK
|
||||||
|
header, payload = self._read_until_valid_crc_check()
|
||||||
|
if not header.is_ack() or len(payload) > 0:
|
||||||
|
click.echo("Received message is not a valid ACK", err=True)
|
||||||
|
|
||||||
|
# Read handshake init response
|
||||||
|
header, payload = self._read_until_valid_crc_check()
|
||||||
|
self._send_ack_0()
|
||||||
|
|
||||||
|
if not header.is_handshake_init_response():
|
||||||
|
click.echo(
|
||||||
|
"Received message is not a valid handshake init response message",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
trezor_ephemeral_pubkey = payload[:32]
|
||||||
|
encrypted_trezor_static_pubkey = payload[32:80]
|
||||||
|
noise_tag = payload[80:96]
|
||||||
|
|
||||||
|
# TODO check noise tag
|
||||||
|
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
|
||||||
|
|
||||||
|
# Prepare and send handshake completion request
|
||||||
|
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
|
||||||
|
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||||
|
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
|
||||||
|
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
|
||||||
|
h = _sha256_of_two(h, host_ephemeral_pubkey)
|
||||||
|
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
|
||||||
|
ck, k = _hkdf(
|
||||||
|
PROTOCOL_NAME,
|
||||||
|
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
|
||||||
|
)
|
||||||
|
|
||||||
|
aes_ctx = AESGCM(k)
|
||||||
|
try:
|
||||||
|
trezor_masked_static_pubkey = aes_ctx.decrypt(
|
||||||
|
IV_1, encrypted_trezor_static_pubkey, h
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
click.echo(
|
||||||
|
f"Exception of type{type(e)}", err=True
|
||||||
|
) # TODO how to handle potential exceptions? Q for Matejcik
|
||||||
|
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
|
||||||
|
ck, k = _hkdf(
|
||||||
|
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
|
||||||
|
)
|
||||||
|
aes_ctx = AESGCM(k)
|
||||||
|
|
||||||
|
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
|
||||||
|
h = _sha256_of_two(h, tag_of_empty_string)
|
||||||
|
# TODO: search for saved credentials (or possibly not, as we skip pairing phase)
|
||||||
|
|
||||||
|
zeroes_32 = int.to_bytes(0, 32, "little")
|
||||||
|
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
|
||||||
|
temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey)
|
||||||
|
aes_ctx = AESGCM(k)
|
||||||
|
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h)
|
||||||
|
h = _sha256_of_two(h, encrypted_host_static_pubkey)
|
||||||
|
ck, k = _hkdf(
|
||||||
|
ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey)
|
||||||
|
)
|
||||||
|
msg_data = self.mapping.encode_without_wire_type(
|
||||||
|
messages.ThpHandshakeCompletionReqNoisePayload(
|
||||||
|
pairing_methods=[
|
||||||
|
messages.ThpPairingMethod.NoMethod,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
aes_ctx = AESGCM(k)
|
||||||
|
|
||||||
|
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
|
||||||
|
h = _sha256_of_two(h, encrypted_payload)
|
||||||
|
ha_completion_req_header = MessageHeader(
|
||||||
|
0x12,
|
||||||
|
self.channel_id,
|
||||||
|
len(encrypted_host_static_pubkey)
|
||||||
|
+ len(encrypted_payload)
|
||||||
|
+ CHECKSUM_LENGTH,
|
||||||
|
)
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(
|
||||||
|
self.transport,
|
||||||
|
ha_completion_req_header,
|
||||||
|
encrypted_host_static_pubkey + encrypted_payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read ACK
|
||||||
|
header, payload = self._read_until_valid_crc_check()
|
||||||
|
if not header.is_ack() or len(payload) > 0:
|
||||||
|
click.echo("Received message is not a valid ACK", err=True)
|
||||||
|
|
||||||
|
# Read handshake completion response, ignore payload as we do not care about the state
|
||||||
|
header, _ = self._read_until_valid_crc_check()
|
||||||
|
if not header.is_handshake_comp_response():
|
||||||
|
click.echo(
|
||||||
|
"Received message is not a valid handshake completion response",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
self._send_ack_1()
|
||||||
|
|
||||||
|
self.key_request, self.key_response = _hkdf(ck, b"")
|
||||||
|
self.nonce_request = 0
|
||||||
|
self.nonce_response = 1
|
||||||
|
|
||||||
|
# Send StartPairingReqest message
|
||||||
|
message = messages.ThpStartPairingRequest()
|
||||||
|
message_type, message_data = self.mapping.encode(message)
|
||||||
|
|
||||||
|
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
|
||||||
|
|
||||||
|
# Read ACK
|
||||||
|
header, payload = self._read_until_valid_crc_check()
|
||||||
|
if not header.is_ack() or len(payload) > 0:
|
||||||
|
click.echo("Received message is not a valid ACK", err=True)
|
||||||
|
|
||||||
|
# Read
|
||||||
|
_, msg_type, msg_data = self.read_and_decrypt()
|
||||||
|
maaa = self.mapping.decode(msg_type, msg_data)
|
||||||
|
|
||||||
|
assert isinstance(maaa, messages.ThpEndResponse)
|
||||||
|
self._has_valid_channel = True
|
||||||
|
|
||||||
|
def _send_ack_0(self):
|
||||||
|
LOG.debug("sending ack 0")
|
||||||
|
header = MessageHeader(0x20, self.channel_id, 4)
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
|
||||||
|
|
||||||
|
def _send_ack_1(self):
|
||||||
|
LOG.debug("sending ack 1")
|
||||||
|
header = MessageHeader(0x28, self.channel_id, 4)
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
|
||||||
|
|
||||||
|
def _encrypt_and_write(
|
||||||
|
self,
|
||||||
|
session_id: int,
|
||||||
|
message_type: int,
|
||||||
|
message_data: bytes,
|
||||||
|
ctrl_byte: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
assert self.key_request is not None
|
||||||
|
aes_ctx = AESGCM(self.key_request)
|
||||||
|
|
||||||
|
if ctrl_byte is None:
|
||||||
|
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send)
|
||||||
|
self.sync_bit_send = 1 - self.sync_bit_send
|
||||||
|
|
||||||
|
sid = session_id.to_bytes(1, "big")
|
||||||
|
msg_type = message_type.to_bytes(2, "big")
|
||||||
|
data = sid + msg_type + message_data
|
||||||
|
nonce = _get_iv_from_nonce(self.nonce_request)
|
||||||
|
self.nonce_request += 1
|
||||||
|
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
|
||||||
|
header = MessageHeader(
|
||||||
|
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
|
||||||
|
)
|
||||||
|
|
||||||
|
thp_io.write_payload_to_wire_and_add_checksum(
|
||||||
|
self.transport, header, encrypted_message
|
||||||
|
)
|
||||||
|
|
||||||
|
def read_and_decrypt(self) -> t.Tuple[int, int, bytes]:
|
||||||
|
header, raw_payload = self._read_until_valid_crc_check()
|
||||||
|
if control_byte.is_ack(header.ctrl_byte):
|
||||||
|
return self.read_and_decrypt()
|
||||||
|
if not header.is_encrypted_transport():
|
||||||
|
click.echo(
|
||||||
|
"Trying to decrypt not encrypted message!"
|
||||||
|
+ hexlify(header.to_bytes_init() + raw_payload).decode(),
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not control_byte.is_ack(header.ctrl_byte):
|
||||||
|
LOG.debug(
|
||||||
|
"--> Get sequence bit %d %s %s",
|
||||||
|
control_byte.get_seq_bit(header.ctrl_byte),
|
||||||
|
"from control byte",
|
||||||
|
hexlify(header.ctrl_byte.to_bytes(1, "big")).decode(),
|
||||||
|
)
|
||||||
|
if control_byte.get_seq_bit(header.ctrl_byte):
|
||||||
|
self._send_ack_1()
|
||||||
|
else:
|
||||||
|
self._send_ack_0()
|
||||||
|
aes_ctx = AESGCM(self.key_response)
|
||||||
|
nonce = _get_iv_from_nonce(self.nonce_response)
|
||||||
|
self.nonce_response += 1
|
||||||
|
|
||||||
|
message = aes_ctx.decrypt(nonce, raw_payload, b"")
|
||||||
|
session_id = message[0]
|
||||||
|
message_type = message[1:3]
|
||||||
|
message_data = message[3:]
|
||||||
|
return (
|
||||||
|
session_id,
|
||||||
|
int.from_bytes(message_type, "big"),
|
||||||
|
message_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _read_until_valid_crc_check(
|
||||||
|
self,
|
||||||
|
) -> t.Tuple[MessageHeader, bytes]:
|
||||||
|
is_valid = False
|
||||||
|
header, payload, chksum = thp_io.read(self.transport)
|
||||||
|
while not is_valid:
|
||||||
|
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
|
||||||
|
if not is_valid:
|
||||||
|
click.echo(
|
||||||
|
"Received a message with an invalid checksum:"
|
||||||
|
+ hexlify(header.to_bytes_init() + payload + chksum).decode(),
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
header, payload, chksum = thp_io.read(self.transport)
|
||||||
|
|
||||||
|
return header, payload
|
||||||
|
|
||||||
|
def _is_valid_channel_allocation_response(
|
||||||
|
self, header: MessageHeader, payload: bytes, original_nonce: bytes
|
||||||
|
) -> bool:
|
||||||
|
if not header.is_channel_allocation_response():
|
||||||
|
click.echo(
|
||||||
|
"Received message is not a channel allocation response", err=True
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
if len(payload) < 10:
|
||||||
|
click.echo("Invalid channel allocation response payload", err=True)
|
||||||
|
return False
|
||||||
|
if payload[:8] != original_nonce:
|
||||||
|
click.echo(
|
||||||
|
"Invalid channel allocation response payload (nonce mismatch)", err=True
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
class ControlByteType(IntEnum):
|
||||||
|
CHANNEL_ALLOCATION_RES = 1
|
||||||
|
HANDSHAKE_INIT_RES = 2
|
||||||
|
HANDSHAKE_COMP_RES = 3
|
||||||
|
ACK = 4
|
||||||
|
ENCRYPTED_TRANSPORT = 5
|
93
python/src/trezorlib/transport/thp/thp_io.py
Normal file
93
python/src/trezorlib/transport/thp/thp_io.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
import struct
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from .. import Transport
|
||||||
|
from ..thp import checksum
|
||||||
|
from .message_header import MessageHeader
|
||||||
|
|
||||||
|
INIT_HEADER_LENGTH = 5
|
||||||
|
CONT_HEADER_LENGTH = 3
|
||||||
|
MAX_PAYLOAD_LEN = 60000
|
||||||
|
MESSAGE_TYPE_LENGTH = 2
|
||||||
|
|
||||||
|
CONTINUATION_PACKET = 0x80
|
||||||
|
|
||||||
|
|
||||||
|
def write_payload_to_wire_and_add_checksum(
|
||||||
|
transport: Transport, header: MessageHeader, transport_payload: bytes
|
||||||
|
):
|
||||||
|
chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload)
|
||||||
|
data = transport_payload + chksum
|
||||||
|
write_payload_to_wire(transport, header, data)
|
||||||
|
|
||||||
|
|
||||||
|
def write_payload_to_wire(
|
||||||
|
transport: Transport, header: MessageHeader, transport_payload: bytes
|
||||||
|
):
|
||||||
|
transport.open()
|
||||||
|
buffer = bytearray(transport_payload)
|
||||||
|
chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH]
|
||||||
|
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
|
||||||
|
transport.write_chunk(chunk)
|
||||||
|
|
||||||
|
buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :]
|
||||||
|
while buffer:
|
||||||
|
chunk = (
|
||||||
|
header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH]
|
||||||
|
)
|
||||||
|
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
|
||||||
|
transport.write_chunk(chunk)
|
||||||
|
buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :]
|
||||||
|
|
||||||
|
|
||||||
|
def read(transport: Transport) -> Tuple[MessageHeader, bytes, bytes]:
|
||||||
|
"""
|
||||||
|
Reads from the given wire transport.
|
||||||
|
|
||||||
|
Returns `Tuple[MessageHeader, bytes, bytes]`:
|
||||||
|
1. `header` (`MessageHeader`): Header of the message.
|
||||||
|
2. `data` (`bytes`): Contents of the message (if any).
|
||||||
|
3. `checksum` (`bytes`): crc32 checksum of the header + data.
|
||||||
|
|
||||||
|
"""
|
||||||
|
buffer = bytearray()
|
||||||
|
|
||||||
|
# Read header with first part of message data
|
||||||
|
header, first_chunk = read_first(transport)
|
||||||
|
buffer.extend(first_chunk)
|
||||||
|
|
||||||
|
# Read the rest of the message
|
||||||
|
while len(buffer) < header.data_length:
|
||||||
|
buffer.extend(read_next(transport, header.cid))
|
||||||
|
|
||||||
|
data_len = header.data_length - checksum.CHECKSUM_LENGTH
|
||||||
|
msg_data = buffer[:data_len]
|
||||||
|
chksum = buffer[data_len : data_len + checksum.CHECKSUM_LENGTH]
|
||||||
|
|
||||||
|
return (header, msg_data, chksum)
|
||||||
|
|
||||||
|
|
||||||
|
def read_first(transport: Transport) -> Tuple[MessageHeader, bytes]:
|
||||||
|
chunk = transport.read_chunk()
|
||||||
|
try:
|
||||||
|
ctrl_byte, cid, data_length = struct.unpack(
|
||||||
|
MessageHeader.format_str_init, chunk[:INIT_HEADER_LENGTH]
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
raise RuntimeError("Cannot parse header")
|
||||||
|
|
||||||
|
data = chunk[INIT_HEADER_LENGTH:]
|
||||||
|
return MessageHeader(ctrl_byte, cid, data_length), data
|
||||||
|
|
||||||
|
|
||||||
|
def read_next(transport: Transport, cid: int) -> bytes:
|
||||||
|
chunk = transport.read_chunk()
|
||||||
|
ctrl_byte, read_cid = struct.unpack(
|
||||||
|
MessageHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH]
|
||||||
|
)
|
||||||
|
if ctrl_byte != CONTINUATION_PACKET:
|
||||||
|
raise RuntimeError("Continuation packet with incorrect control byte")
|
||||||
|
if read_cid != cid:
|
||||||
|
raise RuntimeError("Continuation packet for different channel")
|
||||||
|
|
||||||
|
return chunk[CONT_HEADER_LENGTH:]
|
@ -1,6 +1,6 @@
|
|||||||
# This file is part of the Trezor project.
|
# This file is part of the Trezor project.
|
||||||
#
|
#
|
||||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||||
#
|
#
|
||||||
# This library is free software: you can redistribute it and/or modify
|
# This library is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Lesser General Public License version 3
|
# it under the terms of the GNU Lesser General Public License version 3
|
||||||
@ -14,14 +14,15 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Iterable, Optional
|
from typing import TYPE_CHECKING, Iterable, Tuple
|
||||||
|
|
||||||
from ..log import DUMP_PACKETS
|
from ..log import DUMP_PACKETS
|
||||||
from . import TransportException
|
from . import Transport, TransportException
|
||||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..models import TrezorModel
|
from ..models import TrezorModel
|
||||||
@ -31,14 +32,18 @@ SOCKET_TIMEOUT = 10
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UdpTransport(ProtocolBasedTransport):
|
class UdpTransport(Transport):
|
||||||
|
|
||||||
DEFAULT_HOST = "127.0.0.1"
|
DEFAULT_HOST = "127.0.0.1"
|
||||||
DEFAULT_PORT = 21324
|
DEFAULT_PORT = 21324
|
||||||
PATH_PREFIX = "udp"
|
PATH_PREFIX = "udp"
|
||||||
ENABLED: bool = True
|
ENABLED: bool = True
|
||||||
|
CHUNK_SIZE = 64
|
||||||
|
|
||||||
def __init__(self, device: Optional[str] = None) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
device: str | None = None,
|
||||||
|
) -> None:
|
||||||
if not device:
|
if not device:
|
||||||
host = UdpTransport.DEFAULT_HOST
|
host = UdpTransport.DEFAULT_HOST
|
||||||
port = UdpTransport.DEFAULT_PORT
|
port = UdpTransport.DEFAULT_PORT
|
||||||
@ -46,24 +51,17 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
devparts = device.split(":")
|
devparts = device.split(":")
|
||||||
host = devparts[0]
|
host = devparts[0]
|
||||||
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
|
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
|
||||||
self.device = (host, port)
|
self.device: Tuple[str, int] = (host, port)
|
||||||
self.socket: Optional[socket.socket] = None
|
|
||||||
|
|
||||||
super().__init__(protocol=ProtocolV1(self))
|
self.socket: socket.socket | None = None
|
||||||
|
super().__init__()
|
||||||
def get_path(self) -> str:
|
|
||||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
|
||||||
|
|
||||||
def find_debug(self) -> "UdpTransport":
|
|
||||||
host, port = self.device
|
|
||||||
return UdpTransport(f"{host}:{port + 1}")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _try_path(cls, path: str) -> "UdpTransport":
|
def _try_path(cls, path: str) -> "UdpTransport":
|
||||||
d = cls(path)
|
d = cls(path)
|
||||||
try:
|
try:
|
||||||
d.open()
|
d.open()
|
||||||
if d._ping():
|
if d.ping():
|
||||||
return d
|
return d
|
||||||
else:
|
else:
|
||||||
raise TransportException(
|
raise TransportException(
|
||||||
@ -77,7 +75,7 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enumerate(
|
def enumerate(
|
||||||
cls, _models: Optional[Iterable["TrezorModel"]] = None
|
cls, _models: Iterable["TrezorModel"] | None = None
|
||||||
) -> Iterable["UdpTransport"]:
|
) -> Iterable["UdpTransport"]:
|
||||||
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
|
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
|
||||||
try:
|
try:
|
||||||
@ -99,20 +97,8 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
else:
|
else:
|
||||||
raise TransportException(f"No UDP device at {path}")
|
raise TransportException(f"No UDP device at {path}")
|
||||||
|
|
||||||
def wait_until_ready(self, timeout: float = 10) -> None:
|
def get_path(self) -> str:
|
||||||
try:
|
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||||
self.open()
|
|
||||||
start = time.monotonic()
|
|
||||||
while True:
|
|
||||||
if self._ping():
|
|
||||||
break
|
|
||||||
elapsed = time.monotonic() - start
|
|
||||||
if elapsed >= timeout:
|
|
||||||
raise TransportException("Timed out waiting for connection.")
|
|
||||||
|
|
||||||
time.sleep(0.05)
|
|
||||||
finally:
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
def open(self) -> None:
|
def open(self) -> None:
|
||||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
@ -124,18 +110,9 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
self.socket.close()
|
self.socket.close()
|
||||||
self.socket = None
|
self.socket = None
|
||||||
|
|
||||||
def _ping(self) -> bool:
|
|
||||||
"""Test if the device is listening."""
|
|
||||||
assert self.socket is not None
|
|
||||||
resp = None
|
|
||||||
try:
|
|
||||||
self.socket.sendall(b"PINGPING")
|
|
||||||
resp = self.socket.recv(8)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return resp == b"PONGPONG"
|
|
||||||
|
|
||||||
def write_chunk(self, chunk: bytes) -> None:
|
def write_chunk(self, chunk: bytes) -> None:
|
||||||
|
if self.socket is None:
|
||||||
|
self.open()
|
||||||
assert self.socket is not None
|
assert self.socket is not None
|
||||||
if len(chunk) != 64:
|
if len(chunk) != 64:
|
||||||
raise TransportException("Unexpected data length")
|
raise TransportException("Unexpected data length")
|
||||||
@ -143,6 +120,8 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
self.socket.sendall(chunk)
|
self.socket.sendall(chunk)
|
||||||
|
|
||||||
def read_chunk(self) -> bytes:
|
def read_chunk(self) -> bytes:
|
||||||
|
if self.socket is None:
|
||||||
|
self.open()
|
||||||
assert self.socket is not None
|
assert self.socket is not None
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -154,3 +133,33 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
if len(chunk) != 64:
|
if len(chunk) != 64:
|
||||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||||
return bytearray(chunk)
|
return bytearray(chunk)
|
||||||
|
|
||||||
|
def find_debug(self) -> "UdpTransport":
|
||||||
|
host, port = self.device
|
||||||
|
return UdpTransport(f"{host}:{port + 1}")
|
||||||
|
|
||||||
|
def wait_until_ready(self, timeout: float = 10) -> None:
|
||||||
|
try:
|
||||||
|
self.open()
|
||||||
|
start = time.monotonic()
|
||||||
|
while True:
|
||||||
|
if self.ping():
|
||||||
|
break
|
||||||
|
elapsed = time.monotonic() - start
|
||||||
|
if elapsed >= timeout:
|
||||||
|
raise TransportException("Timed out waiting for connection.")
|
||||||
|
|
||||||
|
time.sleep(0.05)
|
||||||
|
finally:
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def ping(self) -> bool:
|
||||||
|
"""Test if the device is listening."""
|
||||||
|
assert self.socket is not None
|
||||||
|
resp = None
|
||||||
|
try:
|
||||||
|
self.socket.sendall(b"PINGPING")
|
||||||
|
resp = self.socket.recv(8)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return resp == b"PONGPONG"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# This file is part of the Trezor project.
|
# This file is part of the Trezor project.
|
||||||
#
|
#
|
||||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||||
#
|
#
|
||||||
# This library is free software: you can redistribute it and/or modify
|
# This library is free software: you can redistribute it and/or modify
|
||||||
# it under the terms of the GNU Lesser General Public License version 3
|
# it under the terms of the GNU Lesser General Public License version 3
|
||||||
@ -14,16 +14,17 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import atexit
|
import atexit
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import Iterable, List, Optional
|
from typing import Iterable, List
|
||||||
|
|
||||||
from ..log import DUMP_PACKETS
|
from ..log import DUMP_PACKETS
|
||||||
from ..models import TREZORS, TrezorModel
|
from ..models import TREZORS, TrezorModel
|
||||||
from . import UDEV_RULES_STR, DeviceIsBusy, TransportException
|
from . import UDEV_RULES_STR, DeviceIsBusy, Transport, TransportException
|
||||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -44,13 +45,69 @@ USB_COMM_TIMEOUT_MS = 300
|
|||||||
WEBUSB_CHUNK_SIZE = 64
|
WEBUSB_CHUNK_SIZE = 64
|
||||||
|
|
||||||
|
|
||||||
class WebUsbHandle:
|
class WebUsbTransport(Transport):
|
||||||
def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None:
|
"""
|
||||||
|
WebUsbTransport implements transport over WebUSB interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATH_PREFIX = "webusb"
|
||||||
|
ENABLED = USB_IMPORTED
|
||||||
|
context = None
|
||||||
|
CHUNK_SIZE = 64
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
device: "usb1.USBDevice",
|
||||||
|
debug: bool = False,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.debug = debug
|
||||||
|
|
||||||
self.interface = DEBUG_INTERFACE if debug else INTERFACE
|
self.interface = DEBUG_INTERFACE if debug else INTERFACE
|
||||||
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
|
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
|
||||||
self.count = 0
|
self.handle: usb1.USBDeviceHandle | None = None
|
||||||
self.handle: Optional["usb1.USBDeviceHandle"] = None
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def enumerate(
|
||||||
|
cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False
|
||||||
|
) -> Iterable["WebUsbTransport"]:
|
||||||
|
if cls.context is None:
|
||||||
|
cls.context = usb1.USBContext()
|
||||||
|
cls.context.open()
|
||||||
|
atexit.register(cls.context.close)
|
||||||
|
|
||||||
|
if models is None:
|
||||||
|
models = TREZORS
|
||||||
|
usb_ids = [id for model in models for id in model.usb_ids]
|
||||||
|
devices: List["WebUsbTransport"] = []
|
||||||
|
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||||
|
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||||
|
if usb_id not in usb_ids:
|
||||||
|
continue
|
||||||
|
if not is_vendor_class(dev):
|
||||||
|
continue
|
||||||
|
if usb_reset:
|
||||||
|
handle = dev.open()
|
||||||
|
handle.resetDevice()
|
||||||
|
handle.close()
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
# workaround for issue #223:
|
||||||
|
# on certain combinations of Windows USB drivers and libusb versions,
|
||||||
|
# Trezor is returned twice (possibly because Windows know it as both
|
||||||
|
# a HID and a WebUSB device), and one of the returned devices is
|
||||||
|
# non-functional.
|
||||||
|
dev.getProduct()
|
||||||
|
devices.append(WebUsbTransport(dev))
|
||||||
|
except usb1.USBErrorNotSupported:
|
||||||
|
pass
|
||||||
|
return devices
|
||||||
|
|
||||||
|
def get_path(self) -> str:
|
||||||
|
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||||
|
|
||||||
def open(self) -> None:
|
def open(self) -> None:
|
||||||
self.handle = self.device.open()
|
self.handle = self.device.open()
|
||||||
@ -64,6 +121,8 @@ class WebUsbHandle:
|
|||||||
self.handle.claimInterface(self.interface)
|
self.handle.claimInterface(self.interface)
|
||||||
except usb1.USBErrorAccess as e:
|
except usb1.USBErrorAccess as e:
|
||||||
raise DeviceIsBusy(self.device) from e
|
raise DeviceIsBusy(self.device) from e
|
||||||
|
except usb1.USBErrorBusy as e:
|
||||||
|
raise DeviceIsBusy(self.device) from e
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
if self.handle is not None:
|
if self.handle is not None:
|
||||||
@ -75,6 +134,8 @@ class WebUsbHandle:
|
|||||||
self.handle = None
|
self.handle = None
|
||||||
|
|
||||||
def write_chunk(self, chunk: bytes) -> None:
|
def write_chunk(self, chunk: bytes) -> None:
|
||||||
|
if self.handle is None:
|
||||||
|
self.open()
|
||||||
assert self.handle is not None
|
assert self.handle is not None
|
||||||
if len(chunk) != WEBUSB_CHUNK_SIZE:
|
if len(chunk) != WEBUSB_CHUNK_SIZE:
|
||||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||||
@ -97,6 +158,8 @@ class WebUsbHandle:
|
|||||||
return
|
return
|
||||||
|
|
||||||
def read_chunk(self) -> bytes:
|
def read_chunk(self) -> bytes:
|
||||||
|
if self.handle is None:
|
||||||
|
self.open()
|
||||||
assert self.handle is not None
|
assert self.handle is not None
|
||||||
endpoint = 0x80 | self.endpoint
|
endpoint = 0x80 | self.endpoint
|
||||||
while True:
|
while True:
|
||||||
@ -117,70 +180,6 @@ class WebUsbHandle:
|
|||||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||||
return chunk
|
return chunk
|
||||||
|
|
||||||
|
|
||||||
class WebUsbTransport(ProtocolBasedTransport):
|
|
||||||
"""
|
|
||||||
WebUsbTransport implements transport over WebUSB interface.
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATH_PREFIX = "webusb"
|
|
||||||
ENABLED = USB_IMPORTED
|
|
||||||
context = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
device: "usb1.USBDevice",
|
|
||||||
handle: Optional[WebUsbHandle] = None,
|
|
||||||
debug: bool = False,
|
|
||||||
) -> None:
|
|
||||||
if handle is None:
|
|
||||||
handle = WebUsbHandle(device, debug)
|
|
||||||
|
|
||||||
self.device = device
|
|
||||||
self.handle = handle
|
|
||||||
self.debug = debug
|
|
||||||
|
|
||||||
super().__init__(protocol=ProtocolV1(handle))
|
|
||||||
|
|
||||||
def get_path(self) -> str:
|
|
||||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def enumerate(
|
|
||||||
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False
|
|
||||||
) -> Iterable["WebUsbTransport"]:
|
|
||||||
if cls.context is None:
|
|
||||||
cls.context = usb1.USBContext()
|
|
||||||
cls.context.open()
|
|
||||||
atexit.register(cls.context.close)
|
|
||||||
|
|
||||||
if models is None:
|
|
||||||
models = TREZORS
|
|
||||||
usb_ids = [id for model in models for id in model.usb_ids]
|
|
||||||
devices: List["WebUsbTransport"] = []
|
|
||||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
|
||||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
|
||||||
if usb_id not in usb_ids:
|
|
||||||
continue
|
|
||||||
if not is_vendor_class(dev):
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
# workaround for issue #223:
|
|
||||||
# on certain combinations of Windows USB drivers and libusb versions,
|
|
||||||
# Trezor is returned twice (possibly because Windows know it as both
|
|
||||||
# a HID and a WebUSB device), and one of the returned devices is
|
|
||||||
# non-functional.
|
|
||||||
dev.getProduct()
|
|
||||||
devices.append(WebUsbTransport(dev))
|
|
||||||
except usb1.USBErrorNotSupported:
|
|
||||||
pass
|
|
||||||
except usb1.USBErrorPipe:
|
|
||||||
if usb_reset:
|
|
||||||
handle = dev.open()
|
|
||||||
handle.resetDevice()
|
|
||||||
handle.close()
|
|
||||||
return devices
|
|
||||||
|
|
||||||
def find_debug(self) -> "WebUsbTransport":
|
def find_debug(self) -> "WebUsbTransport":
|
||||||
# For v1 protocol, find debug USB interface for the same serial number
|
# For v1 protocol, find debug USB interface for the same serial number
|
||||||
return WebUsbTransport(self.device, debug=True)
|
return WebUsbTransport(self.device, debug=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user