1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-04-21 17:49:02 +00:00
This commit is contained in:
Petr Sedláček 2025-03-20 15:39:44 +01:00 committed by GitHub
commit 806b2ec0ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
220 changed files with 6945 additions and 6324 deletions

View File

@ -113,6 +113,7 @@ jobs:
name: Device test
runs-on: ubuntu-latest
needs: legacy_emu
timeout-minutes: 30
strategy:
matrix:
coins: [universal, btconly]
@ -120,6 +121,7 @@ jobs:
env:
EMULATOR: 1
TREZOR_PYTEST_SKIP_ALTCOINS: ${{ matrix.coins == 'universal' && '0' || '1' }}
TESTOPTS: "--timeout 120"
steps:
- uses: actions/checkout@v4
with:
@ -148,6 +150,7 @@ jobs:
name: Upgrade test
runs-on: ubuntu-latest
needs: legacy_emu
timeout-minutes: 10
strategy:
matrix:
asan: ${{ fromJSON(github.event_name == 'schedule' && '["noasan", "asan"]' || '["noasan"]') }}
@ -164,7 +167,7 @@ jobs:
- run: chmod +x legacy/firmware/*.elf
- uses: ./.github/actions/environment
- run: nix-shell --run "tests/download_emulators.sh"
- run: nix-shell --run "poetry run pytest tests/upgrade_tests"
- run: nix-shell --run "poetry run pytest --timeout 120 tests/upgrade_tests"
legacy_hwi_test:
name: HWI test

View File

@ -288,9 +288,10 @@ def cli(
label = "Emulator"
assert emulator.client is not None
trezorlib.device.wipe(emulator.client)
trezorlib.device.wipe(emulator.client.get_seedless_session())
trezorlib.debuglink.load_device(
emulator.client,
emulator.client.get_seedless_session(),
mnemonics,
pin=None,
passphrase_protection=False,

View File

@ -2,7 +2,7 @@
import binascii
from trezorlib.client import TrezorClient
from trezorlib.transport_hid import HidTransport
from trezorlib.transport.hid import HidTransport
devices = HidTransport.enumerate()
if len(devices) > 0:

View File

@ -0,0 +1 @@
Changed trezorlib to session-based. Changes also affect trezorctl, python tools, and tests.

View File

@ -14,4 +14,4 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
__version__ = "0.13.11"
__version__ = "0.14.0"

View File

@ -23,6 +23,7 @@ from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, TextIO, Union, cast
from ..debuglink import TrezorClientDebugLink
from ..transport import Transport
from ..transport.udp import UdpTransport
LOG = logging.getLogger(__name__)
@ -103,6 +104,8 @@ class Emulator:
"""
if self._client is None:
raise RuntimeError
if self._client.is_invalidated:
self._client = self._client.get_new_client()
return self._client
def make_args(self) -> List[str]:
@ -116,13 +119,12 @@ class Emulator:
def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None:
assert self.process is not None, "Emulator not started"
transport = self._get_transport()
transport.open()
self.transport.open()
LOG.info("Waiting for emulator to come up...")
start = time.monotonic()
try:
while True:
if transport._ping():
if self.transport.ping():
break
if self.process.poll() is not None:
raise RuntimeError("Emulator process died")
@ -133,7 +135,7 @@ class Emulator:
time.sleep(0.1)
finally:
transport.close()
self.transport.close()
LOG.info(f"Emulator ready after {time.monotonic() - start:.3f} seconds")
@ -164,7 +166,11 @@ class Emulator:
env=env,
)
def start(self) -> None:
def start(
self,
transport: Optional[UdpTransport] = None,
debug_transport: Optional[Transport] = None,
) -> None:
if self.process:
if self.process.poll() is not None:
# process has died, stop and start again
@ -174,6 +180,7 @@ class Emulator:
# process is running, no need to start again
return
self.transport = transport or self._get_transport()
self.process = self.launch_process()
_RUNNING_PIDS.add(self.process)
try:
@ -187,15 +194,16 @@ class Emulator:
(self.profile_dir / "trezor.pid").write_text(str(self.process.pid) + "\n")
(self.profile_dir / "trezor.port").write_text(str(self.port) + "\n")
transport = self._get_transport()
self._client = TrezorClientDebugLink(
transport, auto_interact=self.auto_interact
self.transport,
auto_interact=self.auto_interact,
open_transport=True,
debug_transport=debug_transport,
)
self._client.open()
def stop(self) -> None:
if self._client:
self._client.close()
self._client.close_transport()
self._client = None
if self.process:
@ -219,8 +227,9 @@ class Emulator:
# preserving the recording directory between restarts
self.restart_amount += 1
prev_screenshot_dir = self.client.debug.screenshot_recording_dir
debug_transport = self.client.debug.transport
self.stop()
self.start()
self.start(transport=self.transport, debug_transport=debug_transport)
if prev_screenshot_dir:
self.client.debug.start_recording(
prev_screenshot_dir, refresh_index=self.restart_amount

View File

@ -10,7 +10,7 @@ from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec, utils
from . import device
from .client import TrezorClient
from .transport.session import Session
LOG = logging.getLogger(__name__)
@ -349,7 +349,7 @@ def verify_authentication_response(
def authenticate_device(
client: TrezorClient,
session: Session,
challenge: bytes | None = None,
*,
whitelist: t.Collection[bytes] | None = None,
@ -359,7 +359,7 @@ def authenticate_device(
if challenge is None:
challenge = secrets.token_bytes(16)
resp = device.authenticate(client, challenge)
resp = device.authenticate(session, challenge)
return verify_authentication_response(
challenge,

View File

@ -19,16 +19,16 @@ from typing import TYPE_CHECKING
from . import messages
if TYPE_CHECKING:
from .client import TrezorClient
from .transport.session import Session
def list_names(
client: "TrezorClient",
session: "Session",
) -> messages.BenchmarkNames:
return client.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames)
return session.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames)
def run(client: "TrezorClient", name: str) -> messages.BenchmarkResult:
return client.call(
def run(session: "Session", name: str) -> messages.BenchmarkResult:
return session.call(
messages.BenchmarkRun(name=name), expect=messages.BenchmarkResult
)

View File

@ -18,20 +18,19 @@ from typing import TYPE_CHECKING
from . import messages
from .protobuf import dict_to_proto
from .tools import session
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .transport.session import Session
def get_address(
client: "TrezorClient",
session: "Session",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> str:
return client.call(
return session.call(
messages.BinanceGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
),
@ -40,17 +39,16 @@ def get_address(
def get_public_key(
client: "TrezorClient", address_n: "Address", show_display: bool = False
session: "Session", address_n: "Address", show_display: bool = False
) -> bytes:
return client.call(
return session.call(
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display),
expect=messages.BinancePublicKey,
).public_key
@session
def sign_tx(
client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False
session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False
) -> messages.BinanceSignedTx:
msg = tx_json["msgs"][0]
tx_msg = tx_json.copy()
@ -59,7 +57,7 @@ def sign_tx(
tx_msg["chunkify"] = chunkify
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
client.call(envelope, expect=messages.BinanceTxRequest)
session.call(envelope, expect=messages.BinanceTxRequest)
if "refid" in msg:
msg = dict_to_proto(messages.BinanceCancelMsg, msg)
@ -70,4 +68,4 @@ def sign_tx(
else:
raise ValueError("can not determine msg type")
return client.call(msg, expect=messages.BinanceSignedTx)
return session.call(msg, expect=messages.BinanceSignedTx)

View File

@ -25,11 +25,11 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple
from typing_extensions import Protocol, TypedDict
from . import exceptions, messages
from .tools import _return_success, prepare_message_bytes, session
from .tools import _return_success, prepare_message_bytes
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .transport.session import Session
class ScriptSig(TypedDict):
asm: str
@ -105,7 +105,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType:
def get_public_node(
client: "TrezorClient",
session: "Session",
n: "Address",
ecdsa_curve_name: Optional[str] = None,
show_display: bool = False,
@ -116,12 +116,12 @@ def get_public_node(
unlock_path_mac: Optional[bytes] = None,
) -> messages.PublicKey:
if unlock_path:
client.call(
session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
expect=messages.UnlockedPathRequest,
)
return client.call(
return session.call(
messages.GetPublicKey(
address_n=n,
ecdsa_curve_name=ecdsa_curve_name,
@ -139,7 +139,7 @@ def get_address(*args: Any, **kwargs: Any) -> str:
def get_authenticated_address(
client: "TrezorClient",
session: "Session",
coin_name: str,
n: "Address",
show_display: bool = False,
@ -151,12 +151,12 @@ def get_authenticated_address(
chunkify: bool = False,
) -> messages.Address:
if unlock_path:
client.call(
session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
expect=messages.UnlockedPathRequest,
)
return client.call(
return session.call(
messages.GetAddress(
address_n=n,
coin_name=coin_name,
@ -171,13 +171,13 @@ def get_authenticated_address(
def get_ownership_id(
client: "TrezorClient",
session: "Session",
coin_name: str,
n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> bytes:
return client.call(
return session.call(
messages.GetOwnershipId(
address_n=n,
coin_name=coin_name,
@ -189,7 +189,7 @@ def get_ownership_id(
def get_ownership_proof(
client: "TrezorClient",
session: "Session",
coin_name: str,
n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None,
@ -200,9 +200,9 @@ def get_ownership_proof(
preauthorized: bool = False,
) -> Tuple[bytes, bytes]:
if preauthorized:
client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
res = client.call(
res = session.call(
messages.GetOwnershipProof(
address_n=n,
coin_name=coin_name,
@ -219,7 +219,7 @@ def get_ownership_proof(
def sign_message(
client: "TrezorClient",
session: "Session",
coin_name: str,
n: "Address",
message: AnyStr,
@ -227,7 +227,7 @@ def sign_message(
no_script_type: bool = False,
chunkify: bool = False,
) -> messages.MessageSignature:
return client.call(
return session.call(
messages.SignMessage(
coin_name=coin_name,
address_n=n,
@ -241,7 +241,7 @@ def sign_message(
def verify_message(
client: "TrezorClient",
session: "Session",
coin_name: str,
address: str,
signature: bytes,
@ -249,7 +249,7 @@ def verify_message(
chunkify: bool = False,
) -> bool:
try:
client.call(
session.call(
messages.VerifyMessage(
address=address,
signature=signature,
@ -264,9 +264,8 @@ def verify_message(
return False
@session
def sign_tx(
client: "TrezorClient",
session: "Session",
coin_name: str,
inputs: Sequence[messages.TxInputType],
outputs: Sequence[messages.TxOutputType],
@ -314,14 +313,14 @@ def sign_tx(
setattr(signtx, name, value)
if unlock_path:
client.call(
session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
expect=messages.UnlockedPathRequest,
)
elif preauthorized:
client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
res = client.call(signtx, expect=messages.TxRequest)
res = session.call(signtx, expect=messages.TxRequest)
# Prepare structure for signatures
signatures: List[Optional[bytes]] = [None] * len(inputs)
@ -380,7 +379,7 @@ def sign_tx(
if res.request_type == R.TXPAYMENTREQ:
assert res.details.request_index is not None
msg = payment_reqs[res.details.request_index]
res = client.call(msg, expect=messages.TxRequest)
res = session.call(msg, expect=messages.TxRequest)
else:
msg = messages.TransactionType()
if res.request_type == R.TXMETA:
@ -410,7 +409,7 @@ def sign_tx(
f"Unknown request type - {res.request_type}."
)
res = client.call(messages.TxAck(tx=msg), expect=messages.TxRequest)
res = session.call(messages.TxAck(tx=msg), expect=messages.TxRequest)
for i, sig in zip(inputs, signatures):
if i.script_type != messages.InputScriptType.EXTERNAL and sig is None:
@ -420,7 +419,7 @@ def sign_tx(
def authorize_coinjoin(
client: "TrezorClient",
session: "Session",
coordinator: str,
max_rounds: int,
max_coordinator_fee_rate: int,
@ -429,7 +428,7 @@ def authorize_coinjoin(
coin_name: str,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> str | None:
resp = client.call(
resp = session.call(
messages.AuthorizeCoinJoin(
coordinator=coordinator,
max_rounds=max_rounds,

View File

@ -35,7 +35,7 @@ from . import messages as m
from . import tools
if TYPE_CHECKING:
from .client import TrezorClient
from .transport.session import Session
PROTOCOL_MAGICS = {
"mainnet": 764824073,
@ -818,7 +818,7 @@ def _get_collateral_inputs_items(
def get_address(
client: "TrezorClient",
session: "Session",
address_parameters: m.CardanoAddressParametersType,
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
network_id: int = NETWORK_IDS["mainnet"],
@ -826,7 +826,7 @@ def get_address(
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
chunkify: bool = False,
) -> str:
return client.call(
return session.call(
m.CardanoGetAddress(
address_parameters=address_parameters,
protocol_magic=protocol_magic,
@ -840,12 +840,12 @@ def get_address(
def get_public_key(
client: "TrezorClient",
session: "Session",
address_n: List[int],
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
show_display: bool = False,
) -> m.CardanoPublicKey:
return client.call(
return session.call(
m.CardanoGetPublicKey(
address_n=address_n,
derivation_type=derivation_type,
@ -856,12 +856,12 @@ def get_public_key(
def get_native_script_hash(
client: "TrezorClient",
session: "Session",
native_script: m.CardanoNativeScript,
display_format: m.CardanoNativeScriptHashDisplayFormat = m.CardanoNativeScriptHashDisplayFormat.HIDE,
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
) -> m.CardanoNativeScriptHash:
return client.call(
return session.call(
m.CardanoGetNativeScriptHash(
script=native_script,
display_format=display_format,
@ -872,7 +872,7 @@ def get_native_script_hash(
def sign_tx(
client: "TrezorClient",
session: "Session",
signing_mode: m.CardanoTxSigningMode,
inputs: List[InputWithPath],
outputs: List[OutputWithData],
@ -907,7 +907,7 @@ def sign_tx(
signing_mode,
)
response = client.call(
response = session.call(
m.CardanoSignTxInit(
signing_mode=signing_mode,
inputs_count=len(inputs),
@ -942,12 +942,12 @@ def sign_tx(
_get_certificates_items(certificates),
withdrawals,
):
response = client.call(tx_item, expect=m.CardanoTxItemAck)
response = session.call(tx_item, expect=m.CardanoTxItemAck)
sign_tx_response: Dict[str, Any] = {}
if auxiliary_data is not None:
auxiliary_data_supplement = client.call(
auxiliary_data_supplement = session.call(
auxiliary_data, expect=m.CardanoTxAuxiliaryDataSupplement
)
if (
@ -958,25 +958,25 @@ def sign_tx(
auxiliary_data_supplement.__dict__
)
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck)
response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck)
for tx_item in chain(
_get_mint_items(mint),
_get_collateral_inputs_items(collateral_inputs),
required_signers,
):
response = client.call(tx_item, expect=m.CardanoTxItemAck)
response = session.call(tx_item, expect=m.CardanoTxItemAck)
if collateral_return is not None:
for tx_item in _get_output_items(collateral_return):
response = client.call(tx_item, expect=m.CardanoTxItemAck)
response = session.call(tx_item, expect=m.CardanoTxItemAck)
for reference_input in reference_inputs:
response = client.call(reference_input, expect=m.CardanoTxItemAck)
response = session.call(reference_input, expect=m.CardanoTxItemAck)
sign_tx_response["witnesses"] = []
for witness_request in witness_requests:
response = client.call(witness_request, expect=m.CardanoTxWitnessResponse)
response = session.call(witness_request, expect=m.CardanoTxWitnessResponse)
sign_tx_response["witnesses"].append(
{
"type": response.type,
@ -986,9 +986,9 @@ def sign_tx(
}
)
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash)
response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash)
sign_tx_response["tx_hash"] = response.tx_hash
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished)
response = session.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished)
return sign_tx_response

View File

@ -14,33 +14,44 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import atexit
import functools
import logging
import os
import sys
import typing as t
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import click
from .. import exceptions, transport
from ..client import TrezorClient
from ..ui import ClickUI, ScriptUI
from .. import exceptions, transport, ui
from ..client import PASSPHRASE_ON_DEVICE, ProtocolVersion, TrezorClient
from ..messages import Capability
from ..transport import Transport
from ..transport.session import Session, SessionV1
if TYPE_CHECKING:
LOG = logging.getLogger(__name__)
_TRANSPORT: Transport | None = None
if t.TYPE_CHECKING:
# Needed to enforce a return value from decorators
# More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar
from typing_extensions import Concatenate, ParamSpec
from ..transport import Transport
from ..ui import TrezorClientUI
P = ParamSpec("P")
R = TypeVar("R")
R = t.TypeVar("R")
FuncWithSession = t.Callable[Concatenate[Session, P], R]
class ChoiceType(click.Choice):
def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None:
def __init__(
self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True
) -> None:
super().__init__(list(typemap.keys()))
self.case_sensitive = case_sensitive
if case_sensitive:
@ -48,7 +59,7 @@ class ChoiceType(click.Choice):
else:
self.typemap = {k.lower(): v for k, v in typemap.items()}
def convert(self, value: Any, param: Any, ctx: click.Context) -> Any:
def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any:
if value in self.typemap.values():
return value
value = super().convert(value, param, ctx)
@ -57,11 +68,52 @@ class ChoiceType(click.Choice):
return self.typemap[value]
def get_passphrase(
available_on_device: bool, passphrase_on_host: bool
) -> t.Union[str, object]:
if available_on_device and not passphrase_on_host:
return PASSPHRASE_ON_DEVICE
env_passphrase = os.getenv("PASSPHRASE")
if env_passphrase is not None:
ui.echo("Passphrase required. Using PASSPHRASE environment variable.")
return env_passphrase
while True:
try:
passphrase = ui.prompt(
"Passphrase required",
hide_input=True,
default="",
show_default=False,
)
# In case user sees the input on the screen, we do not need confirmation
if not ui.CAN_HANDLE_HIDDEN_INPUT:
return passphrase
second = ui.prompt(
"Confirm your passphrase",
hide_input=True,
default="",
show_default=False,
)
if passphrase == second:
return passphrase
else:
ui.echo("Passphrase did not match. Please try again.")
except click.Abort:
raise exceptions.Cancelled from None
def get_client(transport: Transport) -> TrezorClient:
return TrezorClient(transport)
class TrezorConnection:
def __init__(
self,
path: str,
session_id: Optional[bytes],
session_id: bytes | None,
passphrase_on_host: bool,
script: bool,
) -> None:
@ -70,31 +122,95 @@ class TrezorConnection:
self.passphrase_on_host = passphrase_on_host
self.script = script
def get_session(
self,
derive_cardano: bool = False,
empty_passphrase: bool = False,
must_resume: bool = False,
) -> Session:
client = self.get_client()
if must_resume and self.session_id is None:
click.echo("Failed to resume session - no session id provided")
raise RuntimeError("Failed to resume session - no session id provided")
# Try resume session from id
if self.session_id is not None:
if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
session = SessionV1.resume_from_id(
client=client, session_id=self.session_id
)
else:
raise Exception("Unsupported client protocol", client.protocol_version)
if must_resume:
if session.id != self.session_id or session.id is None:
click.echo("Failed to resume session")
env_var = os.environ.get("TREZOR_SESSION_ID")
if env_var and bytes.fromhex(env_var) == self.session_id:
click.echo(
"Session-id stored in TREZOR_SESSION_ID is no longer valid. Call 'unset TREZOR_SESSION_ID' to clear it."
)
raise exceptions.FailedSessionResumption()
return session
features = client.protocol.get_features()
passphrase_protection = features.passphrase_protection
if passphrase_protection is None:
raise RuntimeError("Device is locked")
if not passphrase_protection:
return client.get_session(derive_cardano=derive_cardano)
if empty_passphrase:
passphrase = ""
elif self.script:
passphrase = None
else:
available_on_device = Capability.PassphraseEntry in features.capabilities
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano, should_derive=True
)
return session
def get_transport(self) -> "Transport":
global _TRANSPORT
if _TRANSPORT is not None:
return _TRANSPORT
try:
# look for transport without prefix search
return transport.get_transport(self.path, prefix_search=False)
_TRANSPORT = transport.get_transport(self.path, prefix_search=False)
except Exception:
# most likely not found. try again below.
pass
# look for transport with prefix search
# if this fails, we want the exception to bubble up to the caller
return transport.get_transport(self.path, prefix_search=True)
if not _TRANSPORT:
_TRANSPORT = transport.get_transport(self.path, prefix_search=True)
def get_ui(self) -> "TrezorClientUI":
if self.script:
# It is alright to return just the class object instead of instance,
# as the ScriptUI class object itself is the implementation of TrezorClientUI
# (ScriptUI is just a set of staticmethods)
return ScriptUI
else:
return ClickUI(passphrase_on_host=self.passphrase_on_host)
_TRANSPORT.open()
atexit.register(_TRANSPORT.close)
return _TRANSPORT
def get_client(self) -> TrezorClient:
transport = self.get_transport()
ui = self.get_ui()
return TrezorClient(transport, ui=ui, session_id=self.session_id)
client = get_client(self.get_transport())
if self.script:
client.button_callback = ui.ScriptUI.button_request
client.passphrase_callback = ui.ScriptUI.get_passphrase
client.pin_callback = ui.ScriptUI.get_pin
else:
click_ui = ui.ClickUI()
client.button_callback = click_ui.button_request
client.passphrase_callback = click_ui.get_passphrase
client.pin_callback = click_ui.get_pin
return client
def get_seedless_session(self) -> Session:
client = self.get_client()
seedless_session = client.get_seedless_session()
return seedless_session
@contextmanager
def client_context(self):
@ -127,36 +243,98 @@ class TrezorConnection:
raise click.ClickException(str(e)) from e
# other exceptions may cause a traceback
@contextmanager
def session_context(
self,
empty_passphrase: bool = False,
derive_cardano: bool = False,
seedless: bool = False,
must_resume: bool = False,
):
"""Get a session instance as a context manager. Handle errors in a manner
appropriate for end-users.
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`.
Usage:
>>> with obj.session_context() as session:
>>> do_your_actions_here()
"""
try:
if seedless:
session = self.get_seedless_session()
else:
session = self.get_session(
derive_cardano=derive_cardano,
empty_passphrase=empty_passphrase,
must_resume=must_resume,
)
except transport.DeviceIsBusy:
click.echo("Device is in use by another process.")
sys.exit(1)
except exceptions.FailedSessionResumption:
sys.exit(1)
except Exception:
click.echo("Failed to find a Trezor device.")
if self.path is not None:
click.echo(f"Using path: {self.path}")
sys.exit(1)
Sessions are handled transparently. The user is warned when session did not resume
cleanly. The session is closed after the command completes - unless the session
was resumed, in which case it should remain open.
try:
yield session
except exceptions.Cancelled:
# handle cancel action
click.echo("Action was cancelled.")
sys.exit(1)
except exceptions.TrezorException as e:
# handle any Trezor-sent exceptions as user-readable
raise click.ClickException(str(e)) from e
# other exceptions may cause a traceback
def with_session(
func: "t.Callable[Concatenate[Session, P], R]|None" = None,
*,
empty_passphrase: bool = False,
derive_cardano: bool = False,
seedless: bool = False,
must_resume: bool = False,
) -> t.Callable[[FuncWithSession], t.Callable[P, R]]:
"""Provides a Click command with parameter `session=obj.get_session(...)`
based on the parameters provided.
If default parameters are ok, this decorator can be used without parentheses.
"""
@click.pass_obj
@functools.wraps(func)
def trezorctl_command_with_client(
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
with obj.client_context() as client:
session_was_resumed = obj.session_id == client.session_id
if not session_was_resumed and obj.session_id is not None:
# tried to resume but failed
click.echo("Warning: failed to resume session.", err=True)
def decorator(
func: FuncWithSession,
) -> "t.Callable[P, R]":
try:
return func(client, *args, **kwargs)
finally:
if not session_was_resumed:
try:
client.end_session()
except Exception:
pass
@click.pass_obj
@functools.wraps(func)
def function_with_session(
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
is_resume_mandatory = must_resume or obj.session_id is not None
return trezorctl_command_with_client
with obj.session_context(
empty_passphrase=empty_passphrase,
derive_cardano=derive_cardano,
seedless=seedless,
must_resume=is_resume_mandatory,
) as session:
try:
return func(session, *args, **kwargs)
finally:
if not is_resume_mandatory:
session.end()
return function_with_session
# If the decorator @get_session is used without parentheses
if func and callable(func):
return decorator(func) # type: ignore [Function return type]
return decorator
class AliasedGroup(click.Group):
@ -188,14 +366,14 @@ class AliasedGroup(click.Group):
def __init__(
self,
aliases: Optional[Dict[str, click.Command]] = None,
*args: Any,
**kwargs: Any,
aliases: t.Dict[str, click.Command] | None = None,
*args: t.Any,
**kwargs: t.Any,
) -> None:
super().__init__(*args, **kwargs)
self.aliases = aliases or {}
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]:
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
cmd_name = cmd_name.replace("_", "-")
# try to look up the real name
cmd = super().get_command(ctx, cmd_name)

View File

@ -20,17 +20,15 @@ from typing import TYPE_CHECKING, List, Optional
import click
from .. import benchmark
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
def list_names_patern(
client: "TrezorClient", pattern: Optional[str] = None
) -> List[str]:
names = list(benchmark.list_names(client).names)
def list_names_patern(session: "Session", pattern: Optional[str] = None) -> List[str]:
names = list(benchmark.list_names(session).names)
if pattern is None:
return names
return [name for name in names if fnmatch(name, pattern)]
@ -43,10 +41,10 @@ def cli() -> None:
@cli.command()
@click.argument("pattern", required=False)
@with_client
def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
@with_session(empty_passphrase=True)
def list_names(session: "Session", pattern: Optional[str] = None) -> None:
"""List names of all supported benchmarks"""
names = list_names_patern(client, pattern)
names = list_names_patern(session, pattern)
if len(names) == 0:
click.echo("No benchmark satisfies the pattern.")
else:
@ -56,13 +54,13 @@ def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
@cli.command()
@click.argument("pattern", required=False)
@with_client
def run(client: "TrezorClient", pattern: Optional[str]) -> None:
@with_session(empty_passphrase=True)
def run(session: "Session", pattern: Optional[str]) -> None:
"""Run benchmark"""
names = list_names_patern(client, pattern)
names = list_names_patern(session, pattern)
if len(names) == 0:
click.echo("No benchmark satisfies the pattern.")
else:
for name in names:
result = benchmark.run(client, name)
result = benchmark.run(session, name)
click.echo(f"{name}: {result.value} {result.unit}")

View File

@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import binance, tools
from . import with_client
from ..transport.session import Session
from . import with_session
if TYPE_CHECKING:
from .. import messages
from ..client import TrezorClient
PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0"
@ -39,23 +39,23 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Binance address for specified path."""
address_n = tools.parse_path(address)
return binance.get_address(client, address_n, show_display, chunkify)
return binance.get_address(session, address_n, show_display, chunkify)
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
@with_session
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Binance public key."""
address_n = tools.parse_path(address)
return binance.get_public_key(client, address_n, show_display).hex()
return binance.get_public_key(session, address_n, show_display).hex()
@cli.command()
@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_tx(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
session: "Session", address: str, file: TextIO, chunkify: bool
) -> "messages.BinanceSignedTx":
"""Sign Binance transaction.
Transaction must be provided as a JSON file.
"""
address_n = tools.parse_path(address)
return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify)

View File

@ -13,6 +13,7 @@
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import base64
import json
@ -22,10 +23,10 @@ import click
import construct as c
from .. import btc, messages, protobuf, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PURPOSE_BIP44 = 44
PURPOSE_BIP48 = 48
@ -174,15 +175,15 @@ def cli() -> None:
help="Sort pubkeys lexicographically using BIP-67",
)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
script_type: Optional[messages.InputScriptType],
script_type: messages.InputScriptType | None,
show_display: bool,
multisig_xpub: List[str],
multisig_threshold: Optional[int],
multisig_threshold: int | None,
multisig_suffix_length: int,
multisig_sort_pubkeys: bool,
chunkify: bool,
@ -235,7 +236,7 @@ def get_address(
multisig = None
return btc.get_address(
client,
session,
coin,
address_n,
show_display,
@ -252,9 +253,9 @@ def get_address(
@click.option("-e", "--curve")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session
def get_public_node(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
curve: Optional[str],
@ -266,7 +267,7 @@ def get_public_node(
if script_type is None:
script_type = guess_script_type_from_path(address_n)
result = btc.get_public_node(
client,
session,
address_n,
ecdsa_curve_name=curve,
show_display=show_display,
@ -292,7 +293,7 @@ def _append_descriptor_checksum(desc: str) -> str:
def _get_descriptor(
client: "TrezorClient",
session: "Session",
coin: Optional[str],
account: int,
purpose: Optional[int],
@ -326,7 +327,7 @@ def _get_descriptor(
n = tools.parse_path(path)
pub = btc.get_public_node(
client,
session,
n,
show_display=show_display,
coin_name=coin,
@ -363,9 +364,9 @@ def _get_descriptor(
@click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE))
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session
def get_descriptor(
client: "TrezorClient",
session: "Session",
coin: Optional[str],
account: int,
account_type: Optional[int],
@ -375,7 +376,7 @@ def get_descriptor(
"""Get descriptor of given account."""
try:
return _get_descriptor(
client, coin, account, account_type, script_type, show_display
session, coin, account, account_type, script_type, show_display
)
except ValueError as e:
raise click.ClickException(str(e))
@ -390,8 +391,8 @@ def get_descriptor(
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("json_file", type=click.File())
@with_client
def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
@with_session
def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None:
"""Sign transaction.
Transaction data must be provided in a JSON file. See `transaction-format.md` for
@ -416,7 +417,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
}
_, serialized_tx = btc.sign_tx(
client,
session,
coin,
inputs,
outputs,
@ -447,9 +448,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
)
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("message")
@with_client
@with_session
def sign_message(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
message: str,
@ -462,7 +463,7 @@ def sign_message(
if script_type is None:
script_type = guess_script_type_from_path(address_n)
res = btc.sign_message(
client,
session,
coin,
address_n,
message,
@ -483,9 +484,9 @@ def sign_message(
@click.argument("address")
@click.argument("signature")
@click.argument("message")
@with_client
@with_session
def verify_message(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
signature: str,
@ -495,7 +496,7 @@ def verify_message(
"""Verify message."""
signature_bytes = base64.b64decode(signature)
return btc.verify_message(
client, coin, address, signature_bytes, message, chunkify=chunkify
session, coin, address, signature_bytes, message, chunkify=chunkify
)

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
import click
from .. import cardano, messages, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0"
@ -62,9 +62,9 @@ def cli() -> None:
@click.option("-i", "--include-network-id", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@click.option("-T", "--tag-cbor-sets", is_flag=True)
@with_client
@with_session(derive_cardano=True)
def sign_tx(
client: "TrezorClient",
session: "Session",
file: TextIO,
signing_mode: messages.CardanoTxSigningMode,
protocol_magic: int,
@ -123,9 +123,8 @@ def sign_tx(
for p in transaction["additional_witness_requests"]
]
client.init_device(derive_cardano=True)
sign_tx_response = cardano.sign_tx(
client,
session,
signing_mode,
inputs,
outputs,
@ -209,9 +208,9 @@ def sign_tx(
default=messages.CardanoDerivationType.ICARUS,
)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session(derive_cardano=True)
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
address_type: messages.CardanoAddressType,
staking_address: str,
@ -262,9 +261,8 @@ def get_address(
script_staking_hash_bytes,
)
client.init_device(derive_cardano=True)
return cardano.get_address(
client,
session,
address_parameters,
protocol_magic,
network_id,
@ -283,18 +281,17 @@ def get_address(
default=messages.CardanoDerivationType.ICARUS,
)
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session(derive_cardano=True)
def get_public_key(
client: "TrezorClient",
session: "Session",
address: str,
derivation_type: messages.CardanoDerivationType,
show_display: bool,
) -> messages.CardanoPublicKey:
"""Get Cardano public key."""
address_n = tools.parse_path(address)
client.init_device(derive_cardano=True)
return cardano.get_public_key(
client, address_n, derivation_type=derivation_type, show_display=show_display
session, address_n, derivation_type=derivation_type, show_display=show_display
)
@ -312,9 +309,9 @@ def get_public_key(
type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}),
default=messages.CardanoDerivationType.ICARUS,
)
@with_client
@with_session(derive_cardano=True)
def get_native_script_hash(
client: "TrezorClient",
session: "Session",
file: TextIO,
display_format: messages.CardanoNativeScriptHashDisplayFormat,
derivation_type: messages.CardanoDerivationType,
@ -323,7 +320,6 @@ def get_native_script_hash(
native_script_json = json.load(file)
native_script = cardano.parse_native_script(native_script_json)
client.init_device(derive_cardano=True)
return cardano.get_native_script_hash(
client, native_script, display_format, derivation_type=derivation_type
session, native_script, display_format, derivation_type=derivation_type
)

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Tuple
import click
from .. import misc, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PROMPT_TYPE = ChoiceType(
@ -42,10 +42,10 @@ def cli() -> None:
@cli.command()
@click.argument("size", type=int)
@with_client
def get_entropy(client: "TrezorClient", size: int) -> str:
@with_session(empty_passphrase=True)
def get_entropy(session: "Session", size: int) -> str:
"""Get random bytes from device."""
return misc.get_entropy(client, size).hex()
return misc.get_entropy(session, size).hex()
@cli.command()
@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str:
)
@click.argument("key")
@click.argument("value")
@with_client
@with_session(empty_passphrase=True)
def encrypt_keyvalue(
client: "TrezorClient",
session: "Session",
address: str,
key: str,
value: str,
@ -75,7 +75,7 @@ def encrypt_keyvalue(
ask_on_encrypt, ask_on_decrypt = prompt
address_n = tools.parse_path(address)
return misc.encrypt_keyvalue(
client,
session,
address_n,
key,
value.encode(),
@ -91,9 +91,9 @@ def encrypt_keyvalue(
)
@click.argument("key")
@click.argument("value")
@with_client
@with_session(empty_passphrase=True)
def decrypt_keyvalue(
client: "TrezorClient",
session: "Session",
address: str,
key: str,
value: str,
@ -112,7 +112,7 @@ def decrypt_keyvalue(
ask_on_encrypt, ask_on_decrypt = prompt
address_n = tools.parse_path(address)
return misc.decrypt_keyvalue(
client,
session,
address_n,
key,
bytes.fromhex(value),

View File

@ -18,13 +18,12 @@ from typing import TYPE_CHECKING, Union
import click
from .. import mapping, messages, protobuf
from ..client import TrezorClient
from ..debuglink import TrezorClientDebugLink
from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max
from ..debuglink import prodtest_t1 as debuglink_prodtest_t1
from ..debuglink import record_screen
from . import with_client
from ..transport.session import Session
from . import with_session
if TYPE_CHECKING:
from . import TrezorConnection
@ -35,53 +34,6 @@ def cli() -> None:
"""Miscellaneous debug features."""
@cli.command()
@click.argument("message_name_or_type")
@click.argument("hex_data")
@click.pass_obj
def send_bytes(
obj: "TrezorConnection", message_name_or_type: str, hex_data: str
) -> None:
"""Send raw bytes to Trezor.
Message type and message data must be specified separately, due to how message
chunking works on the transport level. Message length is calculated and sent
automatically, and it is currently impossible to explicitly specify invalid length.
MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum,
in which case the value of that enum is used.
"""
if message_name_or_type.isdigit():
message_type = int(message_name_or_type)
else:
message_type = getattr(messages.MessageType, message_name_or_type)
if not isinstance(message_type, int):
raise click.ClickException("Invalid message type.")
try:
message_data = bytes.fromhex(hex_data)
except Exception as e:
raise click.ClickException("Invalid hex data.") from e
transport = obj.get_transport()
transport.begin_session()
transport.write(message_type, message_data)
response_type, response_data = transport.read()
transport.end_session()
click.echo(f"Response type: {response_type}")
click.echo(f"Response data: {response_data.hex()}")
try:
msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
click.echo("Parsed message:")
click.echo(protobuf.format_message(msg))
except Exception as e:
click.echo(f"Could not parse response: {e}")
@cli.command()
@click.argument("directory", required=False)
@click.option("-s", "--stop", is_flag=True, help="Stop the recording")
@ -100,23 +52,22 @@ def record_screen_from_connection(
"""Record screen helper to transform TrezorConnection into TrezorClientDebugLink."""
transport = obj.get_transport()
debug_client = TrezorClientDebugLink(transport, auto_interact=False)
debug_client.open()
record_screen(debug_client, directory, report_func=click.echo)
debug_client.close()
debug_client.close_transport()
@cli.command()
@with_client
def prodtest_t1(client: "TrezorClient") -> None:
@with_session(seedless=True)
def prodtest_t1(session: "Session") -> None:
"""Perform a prodtest on Model One.
Only available on PRODTEST firmware and on T1B1. Formerly named self-test.
"""
debuglink_prodtest_t1(client)
debuglink_prodtest_t1(session)
@cli.command()
@with_client
def optiga_set_sec_max(client: "TrezorClient") -> None:
@with_session(seedless=True)
def optiga_set_sec_max(session: "Session") -> None:
"""Set Optiga's security event counter to maximum."""
debuglink_optiga_set_sec_max(client)
debuglink_optiga_set_sec_max(session)

View File

@ -25,10 +25,10 @@ import requests
from .. import authentication, debuglink, device, exceptions, messages, ui
from ..tools import format_path
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if t.TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
from . import TrezorConnection
RECOVERY_DEVICE_INPUT_METHOD = {
@ -64,17 +64,18 @@ def cli() -> None:
help="Wipe device in bootloader mode. This also erases the firmware.",
is_flag=True,
)
@with_client
def wipe(client: "TrezorClient", bootloader: bool) -> None:
@with_session(seedless=True)
def wipe(session: "Session", bootloader: bool) -> None:
"""Reset device to factory defaults and remove all private data."""
features = session.features
if bootloader:
if not client.features.bootloader_mode:
if not features.bootloader_mode:
click.echo("Please switch your device to bootloader mode.")
sys.exit(1)
else:
click.echo("Wiping user data and firmware!")
else:
if client.features.bootloader_mode:
if features.bootloader_mode:
click.echo(
"Your device is in bootloader mode. This operation would also erase firmware."
)
@ -86,7 +87,11 @@ def wipe(client: "TrezorClient", bootloader: bool) -> None:
else:
click.echo("Wiping user data!")
device.wipe(client)
try:
device.wipe(session)
except exceptions.TrezorFailure as e:
click.echo("Action failed: {} {}".format(*e.args))
sys.exit(3)
@cli.command()
@ -99,9 +104,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> None:
@click.option("-a", "--academic", is_flag=True)
@click.option("-b", "--needs-backup", is_flag=True)
@click.option("-n", "--no-backup", is_flag=True)
@with_client
@with_session(seedless=True)
def load(
client: "TrezorClient",
session: "Session",
mnemonic: t.Sequence[str],
pin: str,
passphrase_protection: bool,
@ -132,7 +137,7 @@ def load(
try:
debuglink.load_device(
client,
session,
mnemonic=list(mnemonic),
pin=pin,
passphrase_protection=passphrase_protection,
@ -167,9 +172,9 @@ def load(
)
@click.option("-d", "--dry-run", is_flag=True)
@click.option("-b", "--unlock-repeated-backup", is_flag=True)
@with_client
@with_session(seedless=True)
def recover(
client: "TrezorClient",
session: "Session",
words: str,
expand: bool,
pin_protection: bool,
@ -197,7 +202,7 @@ def recover(
type = messages.RecoveryType.UnlockRepeatedBackup
device.recover(
client,
session,
word_count=int(words),
passphrase_protection=passphrase_protection,
pin_protection=pin_protection,
@ -219,9 +224,9 @@ def recover(
@click.option("-n", "--no-backup", is_flag=True)
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE))
@click.option("-e", "--entropy-check-count", type=click.IntRange(0))
@with_client
@with_session(seedless=True)
def setup(
client: "TrezorClient",
session: "Session",
strength: int | None,
passphrase_protection: bool,
pin_protection: bool,
@ -241,10 +246,10 @@ def setup(
if (
backup_type
in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable)
and messages.Capability.Shamir not in client.features.capabilities
and messages.Capability.Shamir not in session.features.capabilities
) or (
backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable)
and messages.Capability.ShamirGroups not in client.features.capabilities
and messages.Capability.ShamirGroups not in session.features.capabilities
):
click.echo(
"WARNING: Your Trezor device does not indicate support for the requested\n"
@ -252,7 +257,7 @@ def setup(
)
path_xpubs = device.setup(
client,
session,
strength=strength,
passphrase_protection=passphrase_protection,
pin_protection=pin_protection,
@ -273,22 +278,21 @@ def setup(
@cli.command()
@click.option("-t", "--group-threshold", type=int)
@click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N")
@with_client
@with_session(seedless=True)
def backup(
client: "TrezorClient",
session: "Session",
group_threshold: int | None = None,
groups: t.Sequence[tuple[int, int]] = (),
) -> None:
"""Perform device seed backup."""
device.backup(client, group_threshold, groups)
device.backup(session, group_threshold, groups)
@cli.command()
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
@with_client
def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType
) -> None:
@with_session(seedless=True)
def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> None:
"""Secure the device with SD card protection.
When SD card protection is enabled, a randomly generated secret is stored
@ -302,33 +306,33 @@ def sd_protect(
off - Remove SD card secret protection.
refresh - Replace the current SD card secret with a new one.
"""
if client.features.model == "1":
if session.features.model == "1":
raise click.ClickException("Trezor One does not support SD card protection.")
device.sd_protect(client, operation)
device.sd_protect(session, operation)
@cli.command()
@click.pass_obj
def reboot_to_bootloader(obj: "TrezorConnection") -> None:
"""Reboot device into bootloader mode."""
# avoid using @with_client because it closes the session afterwards,
# avoid using @with_session because it closes the session afterwards,
# which triggers double prompt on device
with obj.client_context() as client:
device.reboot_to_bootloader(client)
device.reboot_to_bootloader(client.get_seedless_session())
@cli.command()
@with_client
def tutorial(client: "TrezorClient") -> None:
@with_session(seedless=True)
def tutorial(session: "Session") -> None:
"""Show on-device tutorial."""
device.show_device_tutorial(client)
device.show_device_tutorial(session)
@cli.command()
@with_client
def unlock_bootloader(client: "TrezorClient") -> None:
@with_session(seedless=True)
def unlock_bootloader(session: "Session") -> None:
"""Unlocks bootloader. Irreversible."""
device.unlock_bootloader(client)
device.unlock_bootloader(session)
@cli.command()
@ -339,12 +343,11 @@ def unlock_bootloader(client: "TrezorClient") -> None:
type=int,
help="Dialog expiry in seconds.",
)
@with_client
def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> None:
@with_session(seedless=True)
def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> None:
"""Show a "Do not disconnect" dialog."""
if enable is False:
device.set_busy(client, None)
return
device.set_busy(session, None)
if expiry is None:
raise click.ClickException("Missing option '-e' / '--expiry'.")
@ -354,7 +357,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) ->
f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer."
)
device.set_busy(client, expiry * 1000)
device.set_busy(session, expiry * 1000)
PUBKEY_WHITELIST_URL_TEMPLATE = (
@ -374,9 +377,9 @@ PUBKEY_WHITELIST_URL_TEMPLATE = (
is_flag=True,
help="Do not check intermediate certificates against the whitelist.",
)
@with_client
@with_session(seedless=True)
def authenticate(
client: "TrezorClient",
session: "Session",
hex_challenge: str | None,
root: t.BinaryIO | None,
raw: bool | None,
@ -397,7 +400,7 @@ def authenticate(
challenge = bytes.fromhex(hex_challenge)
if raw:
msg = device.authenticate(client, challenge)
msg = device.authenticate(session, challenge)
click.echo(f"Challenge: {hex_challenge}")
click.echo(f"Signature of challenge: {msg.signature.hex()}")
@ -436,14 +439,14 @@ def authenticate(
else:
whitelist_json = requests.get(
PUBKEY_WHITELIST_URL_TEMPLATE.format(
model=client.model.internal_name.lower()
model=session.model.internal_name.lower()
)
).json()
whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]]
try:
authentication.authenticate_device(
client, challenge, root_pubkey=root_bytes, whitelist=whitelist
session, challenge, root_pubkey=root_bytes, whitelist=whitelist
)
except authentication.DeviceNotAuthentic:
click.echo("Device is not authentic.")

View File

@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import eos, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from .. import messages
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0"
@ -37,11 +37,11 @@ def cli() -> None:
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
@with_session
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Eos public key in base58 encoding."""
address_n = tools.parse_path(address)
res = eos.get_public_key(client, address_n, show_display)
res = eos.get_public_key(session, address_n, show_display)
return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}"
@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_transaction(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
session: "Session", address: str, file: TextIO, chunkify: bool
) -> "messages.EosSignedTx":
"""Sign EOS transaction."""
tx_json = json.load(file)
address_n = tools.parse_path(address)
return eos.sign_tx(
client,
session,
address_n,
tx_json["transaction"],
tx_json["chain_id"],

View File

@ -26,14 +26,14 @@ import click
from .. import _rlp, definitions, ethereum, tools
from ..messages import EthereumDefinitions
from . import with_client
from . import with_session
if TYPE_CHECKING:
import web3
from eth_typing import ChecksumAddress # noqa: I900
from web3.types import Wei
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0"
@ -268,24 +268,24 @@ def cli(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Ethereum address in hex encoding."""
address_n = tools.parse_path(address)
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
return ethereum.get_address(client, address_n, show_display, network, chunkify)
return ethereum.get_address(session, address_n, show_display, network, chunkify)
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict:
@with_session
def get_public_node(session: "Session", address: str, show_display: bool) -> dict:
"""Get Ethereum public node of given path."""
address_n = tools.parse_path(address)
result = ethereum.get_public_node(client, address_n, show_display=show_display)
result = ethereum.get_public_node(session, address_n, show_display=show_display)
return {
"node": {
"depth": result.node.depth,
@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("to_address")
@click.argument("amount", callback=_amount_to_int)
@with_client
@with_session
def sign_tx(
client: "TrezorClient",
session: "Session",
chain_id: int,
address: str,
amount: int,
@ -400,7 +400,7 @@ def sign_tx(
encoded_network = DEFINITIONS_SOURCE.get_network(chain_id)
address_n = tools.parse_path(address)
from_address = ethereum.get_address(
client, address_n, encoded_network=encoded_network
session, address_n, encoded_network=encoded_network
)
if token:
@ -446,7 +446,7 @@ def sign_tx(
assert max_gas_fee is not None
assert max_priority_fee is not None
sig = ethereum.sign_tx_eip1559(
client,
session,
n=address_n,
nonce=nonce,
gas_limit=gas_limit,
@ -465,7 +465,7 @@ def sign_tx(
gas_price = _get_web3().eth.gas_price
assert gas_price is not None
sig = ethereum.sign_tx(
client,
session,
n=address_n,
tx_type=tx_type,
nonce=nonce,
@ -526,14 +526,14 @@ def sign_tx(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("message")
@with_client
@with_session
def sign_message(
client: "TrezorClient", address: str, message: str, chunkify: bool
session: "Session", address: str, message: str, chunkify: bool
) -> Dict[str, str]:
"""Sign message with Ethereum address."""
address_n = tools.parse_path(address)
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify)
ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify)
output = {
"message": message,
"address": ret.address,
@ -550,9 +550,9 @@ def sign_message(
help="Be compatible with Metamask's signTypedData_v4 implementation",
)
@click.argument("file", type=click.File("r"))
@with_client
@with_session
def sign_typed_data(
client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO
session: "Session", address: str, metamask_v4_compat: bool, file: TextIO
) -> Dict[str, str]:
"""Sign typed data (EIP-712) with Ethereum address.
@ -565,7 +565,7 @@ def sign_typed_data(
defs = EthereumDefinitions(encoded_network=network)
data = json.loads(file.read())
ret = ethereum.sign_typed_data(
client,
session,
address_n,
data,
metamask_v4_compat=metamask_v4_compat,
@ -583,9 +583,9 @@ def sign_typed_data(
@click.argument("address")
@click.argument("signature")
@click.argument("message")
@with_client
@with_session
def verify_message(
client: "TrezorClient",
session: "Session",
address: str,
signature: str,
message: str,
@ -594,7 +594,7 @@ def verify_message(
"""Verify message signed with Ethereum address."""
signature_bytes = ethereum.decode_hex(signature)
return ethereum.verify_message(
client, address, signature_bytes, message, chunkify=chunkify
session, address, signature_bytes, message, chunkify=chunkify
)
@ -602,9 +602,9 @@ def verify_message(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.argument("domain_hash_hex")
@click.argument("message_hash_hex")
@with_client
@with_session
def sign_typed_data_hash(
client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str
session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str
) -> Dict[str, str]:
"""
Sign hash of typed data (EIP-712) with Ethereum address.
@ -618,7 +618,7 @@ def sign_typed_data_hash(
message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
ret = ethereum.sign_typed_data_hash(
client, address_n, domain_hash, message_hash, network
session, address_n, domain_hash, message_hash, network
)
output = {
"domain_hash": domain_hash_hex,

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING
import click
from .. import fido
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
@ -40,10 +40,10 @@ def credentials() -> None:
@credentials.command(name="list")
@with_client
def credentials_list(client: "TrezorClient") -> None:
@with_session(empty_passphrase=True)
def credentials_list(session: "Session") -> None:
"""List all resident credentials on the device."""
creds = fido.list_credentials(client)
creds = fido.list_credentials(session)
for cred in creds:
click.echo("")
click.echo(f"WebAuthn credential at index {cred.index}:")
@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None:
@credentials.command(name="add")
@click.argument("hex_credential_id")
@with_client
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> None:
@with_session(empty_passphrase=True)
def credentials_add(session: "Session", hex_credential_id: str) -> None:
"""Add the credential with the given ID as a resident credential.
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
"""
fido.add_credential(client, bytes.fromhex(hex_credential_id))
fido.add_credential(session, bytes.fromhex(hex_credential_id))
@credentials.command(name="remove")
@click.option(
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
)
@with_client
def credentials_remove(client: "TrezorClient", index: int) -> None:
@with_session(empty_passphrase=True)
def credentials_remove(session: "Session", index: int) -> None:
"""Remove the resident credential at the given index."""
fido.remove_credential(client, index)
fido.remove_credential(session, index)
#
@ -110,19 +110,19 @@ def counter() -> None:
@counter.command(name="set")
@click.argument("counter", type=int)
@with_client
def counter_set(client: "TrezorClient", counter: int) -> None:
@with_session(empty_passphrase=True)
def counter_set(session: "Session", counter: int) -> None:
"""Set FIDO/U2F counter value."""
fido.set_counter(client, counter)
fido.set_counter(session, counter)
@counter.command(name="get-next")
@with_client
def counter_get_next(client: "TrezorClient") -> int:
@with_session(empty_passphrase=True)
def counter_get_next(session: "Session") -> int:
"""Get-and-increase value of FIDO/U2F counter.
FIDO counter value cannot be read directly. On each U2F exchange, the counter value
is returned and atomically increased. This command performs the same operation
and returns the counter value.
"""
return fido.get_next_counter(client)
return fido.get_next_counter(session)

View File

@ -37,10 +37,11 @@ import requests
from .. import device, exceptions, firmware, messages, models
from ..firmware import models as fw_models
from ..models import TrezorModel
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
from . import TrezorConnection
MODEL_CHOICE = ChoiceType(
@ -75,9 +76,9 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool:
This is the case from bootloader version 1.8.0, and also holds for firmware version
1.8.0 because that installs the appropriate bootloader.
"""
f = client.features
version = (f.major_version, f.minor_version, f.patch_version)
bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0)
features = client.features
version = client.version
bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0)
return bootloader_onev2
@ -307,25 +308,26 @@ def find_best_firmware_version(
If the specified version is not found, prints the closest available version
(higher than the specified one, if existing).
"""
features = client.features
model = client.model
if bitcoin_only is None:
bitcoin_only = _should_use_bitcoin_only(client.features)
bitcoin_only = _should_use_bitcoin_only(features)
def version_str(version: Iterable[int]) -> str:
return ".".join(map(str, version))
f = client.features
releases = get_all_firmware_releases(client.model, bitcoin_only, beta)
releases = get_all_firmware_releases(model, bitcoin_only, beta)
highest_version = releases[0]["version"]
if version:
want_version = [int(x) for x in version.split(".")]
if len(want_version) != 3:
click.echo("Please use the 'X.Y.Z' version format.")
if want_version[0] != f.major_version:
if want_version[0] != features.major_version:
click.echo(
f"Warning: Trezor {client.model.name} firmware version should be "
f"{f.major_version}.X.Y (requested: {version})"
f"Warning: Trezor {model.name} firmware version should be "
f"{features.major_version}.X.Y (requested: {version})"
)
else:
want_version = highest_version
@ -360,8 +362,8 @@ def find_best_firmware_version(
# to the newer one, in that case update to the minimal
# compatible version first
# Choosing the version key to compare based on (not) being in BL mode
client_version = [f.major_version, f.minor_version, f.patch_version]
if f.bootloader_mode:
client_version = client.version
if features.bootloader_mode:
key_to_compare = "min_bootloader_version"
else:
key_to_compare = "min_firmware_version"
@ -454,11 +456,11 @@ def extract_embedded_fw(
def upload_firmware_into_device(
client: "TrezorClient",
session: "Session",
firmware_data: bytes,
) -> None:
"""Perform the final act of loading the firmware into Trezor."""
f = client.features
f = session.features
try:
if f.major_version == 1 and f.firmware_present is not False:
# Trezor One does not send ButtonRequest
@ -468,7 +470,7 @@ def upload_firmware_into_device(
with click.progressbar(
label="Uploading", length=len(firmware_data), show_eta=False
) as bar:
firmware.update(client, firmware_data, bar.update)
firmware.update(session, firmware_data, bar.update)
except exceptions.Cancelled:
click.echo("Update aborted on device.")
except exceptions.TrezorException as e:
@ -661,6 +663,7 @@ def update(
against data.trezor.io information, if available.
"""
with obj.client_context() as client:
seedless_session = client.get_seedless_session()
if sum(bool(x) for x in (filename, url, version)) > 1:
click.echo("You can use only one of: filename, url, version.")
sys.exit(1)
@ -716,7 +719,7 @@ def update(
if _is_strict_update(client, firmware_data):
header_size = _get_firmware_header_size(firmware_data)
device.reboot_to_bootloader(
client,
seedless_session,
boot_command=messages.BootCommand.INSTALL_UPGRADE,
firmware_header=firmware_data[:header_size],
language_data=language_data,
@ -726,7 +729,7 @@ def update(
click.echo(
"WARNING: Seamless installation not possible, language data will not be uploaded."
)
device.reboot_to_bootloader(client)
device.reboot_to_bootloader(seedless_session)
click.echo("Waiting for bootloader...")
while True:
@ -742,13 +745,15 @@ def update(
click.echo("Please switch your device to bootloader mode.")
sys.exit(1)
upload_firmware_into_device(client=client, firmware_data=firmware_data)
upload_firmware_into_device(
session=client.get_seedless_session(), firmware_data=firmware_data
)
@cli.command()
@click.argument("hex_challenge", required=False)
@with_client
def get_hash(client: "TrezorClient", hex_challenge: Optional[str]) -> str:
@with_session(seedless=True)
def get_hash(session: "Session", hex_challenge: Optional[str]) -> str:
"""Get a hash of the installed firmware combined with the optional challenge."""
challenge = bytes.fromhex(hex_challenge) if hex_challenge else None
return firmware.get_hash(client, challenge).hex()
return firmware.get_hash(session, challenge).hex()

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Dict
import click
from .. import messages, monero, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h"
@ -42,9 +42,9 @@ def cli() -> None:
default=messages.MoneroNetworkType.MAINNET,
)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
show_display: bool,
network_type: messages.MoneroNetworkType,
@ -52,7 +52,7 @@ def get_address(
) -> bytes:
"""Get Monero address for specified path."""
address_n = tools.parse_path(address)
return monero.get_address(client, address_n, show_display, network_type, chunkify)
return monero.get_address(session, address_n, show_display, network_type, chunkify)
@cli.command()
@ -63,13 +63,13 @@ def get_address(
type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}),
default=messages.MoneroNetworkType.MAINNET,
)
@with_client
@with_session
def get_watch_key(
client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType
session: "Session", address: str, network_type: messages.MoneroNetworkType
) -> Dict[str, str]:
"""Get Monero watch key for specified path."""
address_n = tools.parse_path(address)
res = monero.get_watch_key(client, address_n, network_type)
res = monero.get_watch_key(session, address_n, network_type)
# TODO: could be made required in MoneroWatchKey
assert res.address is not None
assert res.watch_key is not None

View File

@ -21,10 +21,10 @@ import click
import requests
from .. import nem, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h"
@ -39,9 +39,9 @@ def cli() -> None:
@click.option("-N", "--network", type=int, default=0x68)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
network: int,
show_display: bool,
@ -49,7 +49,7 @@ def get_address(
) -> str:
"""Get NEM address for specified path."""
address_n = tools.parse_path(address)
return nem.get_address(client, address_n, network, show_display, chunkify)
return nem.get_address(session, address_n, network, show_display, chunkify)
@cli.command()
@ -58,9 +58,9 @@ def get_address(
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-b", "--broadcast", help="NIS to announce transaction to")
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_tx(
client: "TrezorClient",
session: "Session",
address: str,
file: TextIO,
broadcast: Optional[str],
@ -71,7 +71,7 @@ def sign_tx(
Transaction file is expected in the NIS (RequestPrepareAnnounce) format.
"""
address_n = tools.parse_path(address)
transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify)
payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()}

View File

@ -22,10 +22,10 @@ import typing as t
import click
from .. import messages, nostr, tools
from . import with_client
from . import with_session
if t.TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_TEMPLATE = "m/44h/1237h/{}h/0/0"
@ -38,9 +38,9 @@ def cli() -> None:
@cli.command()
@click.option("-a", "--account", default=0, help="Account index")
@with_client
@with_session
def get_pubkey(
client: "TrezorClient",
session: "Session",
account: int,
) -> str:
"""Return the pubkey derived by the given path."""
@ -48,7 +48,7 @@ def get_pubkey(
address_n = tools.parse_path(PATH_TEMPLATE.format(account))
return nostr.get_pubkey(
client,
session,
address_n,
).hex()
@ -56,9 +56,9 @@ def get_pubkey(
@cli.command()
@click.option("-a", "--account", default=0, help="Account index")
@click.argument("event")
@with_client
@with_session
def sign_event(
client: "TrezorClient",
session: "Session",
account: int,
event: str,
) -> dict[str, str]:
@ -69,7 +69,7 @@ def sign_event(
address_n = tools.parse_path(PATH_TEMPLATE.format(account))
res = nostr.sign_event(
client,
session,
messages.NostrSignEvent(
address_n=address_n,
created_at=event_json["created_at"],

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import ripple, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0"
@ -37,13 +37,13 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Ripple address"""
address_n = tools.parse_path(address)
return ripple.get_address(client, address_n, show_display, chunkify)
return ripple.get_address(session, address_n, show_display, chunkify)
@cli.command()
@ -51,13 +51,13 @@ def get_address(
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None:
@with_session
def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None:
"""Sign Ripple transaction"""
address_n = tools.parse_path(address)
msg = ripple.create_sign_tx_msg(json.load(file))
result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify)
result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify)
click.echo("Signature:")
click.echo(result.signature.hex())
click.echo()

View File

@ -24,10 +24,10 @@ import click
import requests
from .. import device, messages, toif
from . import AliasedGroup, ChoiceType, with_client
from . import AliasedGroup, ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
try:
from PIL import Image
@ -190,18 +190,18 @@ def cli() -> None:
@cli.command()
@click.option("-r", "--remove", is_flag=True, hidden=True)
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
@with_client
def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None:
@with_session(seedless=True)
def pin(session: "Session", enable: Optional[bool], remove: bool) -> None:
"""Set, change or remove PIN."""
# Remove argument is there for backwards compatibility
device.change_pin(client, remove=_should_remove(enable, remove))
device.change_pin(session, remove=_should_remove(enable, remove))
@cli.command()
@click.option("-r", "--remove", is_flag=True, hidden=True)
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
@with_client
def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None:
@with_session(seedless=True)
def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> None:
"""Set or remove the wipe code.
The wipe code functions as a "self-destruct PIN". If the wipe code is ever
@ -209,32 +209,32 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> N
removed and the device will be reset to factory defaults.
"""
# Remove argument is there for backwards compatibility
device.change_wipe_code(client, remove=_should_remove(enable, remove))
device.change_wipe_code(session, remove=_should_remove(enable, remove))
@cli.command()
# keep the deprecated -l/--label option, make it do nothing
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.argument("label")
@with_client
def label(client: "TrezorClient", label: str) -> None:
@with_session(seedless=True)
def label(session: "Session", label: str) -> None:
"""Set new device label."""
device.apply_settings(client, label=label)
device.apply_settings(session, label=label)
@cli.command()
@with_client
def brightness(client: "TrezorClient") -> None:
@with_session(seedless=True)
def brightness(session: "Session") -> None:
"""Set display brightness."""
device.set_brightness(client)
device.set_brightness(session)
@cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client
def haptic_feedback(client: "TrezorClient", enable: bool) -> None:
@with_session(seedless=True)
def haptic_feedback(session: "Session", enable: bool) -> None:
"""Enable or disable haptic feedback."""
device.apply_settings(client, haptic_feedback=enable)
device.apply_settings(session, haptic_feedback=enable)
@cli.command()
@ -243,9 +243,9 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> None:
"-r", "--remove", is_flag=True, default=False, help="Switch back to english."
)
@click.option("-d/-D", "--display/--no-display", default=None)
@with_client
@with_session(seedless=True)
def language(
client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None
session: "Session", path_or_url: str | None, remove: bool, display: bool | None
) -> None:
"""Set new language with translations."""
if remove != (path_or_url is None):
@ -269,30 +269,28 @@ def language(
raise click.ClickException(
f"Failed to load translations from {path_or_url}"
) from None
device.change_language(client, language_data=language_data, show_display=display)
device.change_language(session, language_data=language_data, show_display=display)
@cli.command()
@click.argument("rotation", type=ChoiceType(ROTATION))
@with_client
def display_rotation(
client: "TrezorClient", rotation: messages.DisplayRotation
) -> None:
@with_session(seedless=True)
def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> None:
"""Set display rotation.
Configure display rotation for Trezor Model T. The options are
north, east, south or west.
"""
device.apply_settings(client, display_rotation=rotation)
device.apply_settings(session, display_rotation=rotation)
@cli.command()
@click.argument("delay", type=str)
@with_client
def auto_lock_delay(client: "TrezorClient", delay: str) -> None:
@with_session(seedless=True)
def auto_lock_delay(session: "Session", delay: str) -> None:
"""Set auto-lock delay (in seconds)."""
if not client.features.pin_protection:
if not session.features.pin_protection:
raise click.ClickException("Set up a PIN first")
value, unit = delay[:-1], delay[-1:]
@ -301,13 +299,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> None:
seconds = float(value) * units[unit]
else:
seconds = float(delay) # assume seconds if no unit is specified
device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000))
device.apply_settings(session, auto_lock_delay_ms=int(seconds * 1000))
@cli.command()
@click.argument("flags")
@with_client
def flags(client: "TrezorClient", flags: str) -> None:
@with_session(seedless=True)
def flags(session: "Session", flags: str) -> None:
"""Set device flags."""
if flags.lower().startswith("0b"):
flags_int = int(flags, 2)
@ -315,7 +313,7 @@ def flags(client: "TrezorClient", flags: str) -> None:
flags_int = int(flags, 16)
else:
flags_int = int(flags)
device.apply_flags(client, flags=flags_int)
device.apply_flags(session, flags=flags_int)
@cli.command()
@ -324,8 +322,8 @@ def flags(client: "TrezorClient", flags: str) -> None:
"-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False
)
@click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)")
@with_client
def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
@with_session(seedless=True)
def homescreen(session: "Session", filename: str, quality: int) -> None:
"""Set new homescreen.
To revert to default homescreen, use 'trezorctl set homescreen default'
@ -337,39 +335,39 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
if not path.exists() or not path.is_file():
raise click.ClickException("Cannot open file")
if client.features.model == "1":
if session.features.model == "1":
img = image_to_t1(path)
else:
if client.features.homescreen_format == messages.HomescreenFormat.Jpeg:
if session.features.homescreen_format == messages.HomescreenFormat.Jpeg:
width = (
client.features.homescreen_width
if client.features.homescreen_width is not None
session.features.homescreen_width
if session.features.homescreen_width is not None
else 240
)
height = (
client.features.homescreen_height
if client.features.homescreen_height is not None
session.features.homescreen_height
if session.features.homescreen_height is not None
else 240
)
img = image_to_jpeg(path, width, height, quality)
elif client.features.homescreen_format == messages.HomescreenFormat.ToiG:
width = client.features.homescreen_width
height = client.features.homescreen_height
elif session.features.homescreen_format == messages.HomescreenFormat.ToiG:
width = session.features.homescreen_width
height = session.features.homescreen_height
if width is None or height is None:
raise click.ClickException("Device did not report homescreen size.")
img = image_to_toif(path, width, height, True)
elif (
client.features.homescreen_format == messages.HomescreenFormat.Toif
or client.features.homescreen_format is None
session.features.homescreen_format == messages.HomescreenFormat.Toif
or session.features.homescreen_format is None
):
width = (
client.features.homescreen_width
if client.features.homescreen_width is not None
session.features.homescreen_width
if session.features.homescreen_width is not None
else 144
)
height = (
client.features.homescreen_height
if client.features.homescreen_height is not None
session.features.homescreen_height
if session.features.homescreen_height is not None
else 144
)
img = image_to_toif(path, width, height, False)
@ -379,7 +377,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
"Unknown image format requested by the device."
)
device.apply_settings(client, homescreen=img)
device.apply_settings(session, homescreen=img)
@cli.command()
@ -387,9 +385,9 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
"--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.'
)
@click.argument("level", type=ChoiceType(SAFETY_LEVELS))
@with_client
@with_session(seedless=True)
def safety_checks(
client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel
session: "Session", always: bool, level: messages.SafetyCheckLevel
) -> None:
"""Set safety check level.
@ -402,18 +400,18 @@ def safety_checks(
"""
if always and level == messages.SafetyCheckLevel.PromptTemporarily:
level = messages.SafetyCheckLevel.PromptAlways
device.apply_settings(client, safety_checks=level)
device.apply_settings(session, safety_checks=level)
@cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client
def experimental_features(client: "TrezorClient", enable: bool) -> None:
@with_session(seedless=True)
def experimental_features(session: "Session", enable: bool) -> None:
"""Enable or disable experimental message types.
This is a developer feature. Use with caution.
"""
device.apply_settings(client, experimental_features=enable)
device.apply_settings(session, experimental_features=enable)
#
@ -436,25 +434,25 @@ passphrase = cast(AliasedGroup, passphrase_main)
@passphrase.command(name="on")
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
@with_client
def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> None:
@with_session(seedless=True)
def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> None:
"""Enable passphrase."""
if client.features.passphrase_protection is not True:
if session.features.passphrase_protection is not True:
use_passphrase = True
else:
use_passphrase = None
device.apply_settings(
client,
session,
use_passphrase=use_passphrase,
passphrase_always_on_device=force_on_device,
)
@passphrase.command(name="off")
@with_client
def passphrase_off(client: "TrezorClient") -> None:
@with_session(seedless=True)
def passphrase_off(session: "Session") -> None:
"""Disable passphrase."""
device.apply_settings(client, use_passphrase=False)
device.apply_settings(session, use_passphrase=False)
# Registering the aliases for backwards compatibility
@ -467,10 +465,10 @@ passphrase.aliases = {
@passphrase.command(name="hide")
@click.argument("hide", type=ChoiceType({"on": True, "off": False}))
@with_client
def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> None:
@with_session(seedless=True)
def hide_passphrase_from_host(session: "Session", hide: bool) -> None:
"""Enable or disable hiding passphrase coming from host.
This is a developer feature. Use with caution.
"""
device.apply_settings(client, hide_passphrase_from_host=hide)
device.apply_settings(session, hide_passphrase_from_host=hide)

View File

@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
import click
from .. import messages, solana, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h"
DEFAULT_PATH = "m/44h/501h/0h/0h"
@ -21,40 +21,40 @@ def cli() -> None:
@cli.command()
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session
def get_public_key(
client: "TrezorClient",
session: "Session",
address: str,
show_display: bool,
) -> bytes:
"""Get Solana public key."""
address_n = tools.parse_path(address)
return solana.get_public_key(client, address_n, show_display)
return solana.get_public_key(session, address_n, show_display)
@cli.command()
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
address: str,
show_display: bool,
chunkify: bool,
) -> str:
"""Get Solana address."""
address_n = tools.parse_path(address)
return solana.get_address(client, address_n, show_display, chunkify)
return solana.get_address(session, address_n, show_display, chunkify)
@cli.command()
@click.argument("serialized_tx", type=str)
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
@click.option("-a", "--additional-information-file", type=click.File("r"))
@with_client
@with_session
def sign_tx(
client: "TrezorClient",
session: "Session",
address: str,
serialized_tx: str,
additional_information_file: Optional[TextIO],
@ -78,7 +78,7 @@ def sign_tx(
)
return solana.sign_tx(
client,
session,
address_n,
bytes.fromhex(serialized_tx),
additional_information,

View File

@ -21,10 +21,10 @@ from typing import TYPE_CHECKING
import click
from .. import stellar, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
try:
from stellar_sdk import (
@ -52,13 +52,13 @@ def cli() -> None:
)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Stellar public address."""
address_n = tools.parse_path(address)
return stellar.get_address(client, address_n, show_display, chunkify)
return stellar.get_address(session, address_n, show_display, chunkify)
@cli.command()
@ -77,9 +77,9 @@ def get_address(
help="Network passphrase (blank for public network).",
)
@click.argument("b64envelope")
@with_client
@with_session
def sign_transaction(
client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str
session: "Session", b64envelope: str, address: str, network_passphrase: str
) -> bytes:
"""Sign a base64-encoded transaction envelope.
@ -109,6 +109,6 @@ def sign_transaction(
address_n = tools.parse_path(address)
tx, operations = stellar.from_envelope(envelope)
resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase)
resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase)
return base64.b64encode(resp.signature)

View File

@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
import click
from .. import messages, protobuf, tezos, tools
from . import with_client
from . import with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h"
@ -37,23 +37,23 @@ def cli() -> None:
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
session: "Session", address: str, show_display: bool, chunkify: bool
) -> str:
"""Get Tezos address for specified path."""
address_n = tools.parse_path(address)
return tezos.get_address(client, address_n, show_display, chunkify)
return tezos.get_address(session, address_n, show_display, chunkify)
@cli.command()
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-d", "--show-display", is_flag=True)
@with_client
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
@with_session
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
"""Get Tezos public key."""
address_n = tools.parse_path(address)
return tezos.get_public_key(client, address_n, show_display)
return tezos.get_public_key(session, address_n, show_display)
@cli.command()
@ -61,11 +61,11 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
@click.option("-n", "--address", required=True, help=PATH_HELP)
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def sign_tx(
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
session: "Session", address: str, file: TextIO, chunkify: bool
) -> messages.TezosSignedTx:
"""Sign Tezos transaction."""
address_n = tools.parse_path(address)
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))
return tezos.sign_tx(client, address_n, msg, chunkify=chunkify)
return tezos.sign_tx(session, address_n, msg, chunkify=chunkify)

View File

@ -24,9 +24,9 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca
import click
from .. import __version__, log, messages, protobuf, ui
from ..client import TrezorClient
from .. import __version__, log, messages, protobuf
from ..transport import DeviceIsBusy, enumerate_devices
from ..transport.session import Session
from ..transport.udp import UdpTransport
from . import (
AliasedGroup,
@ -50,7 +50,7 @@ from . import (
solana,
stellar,
tezos,
with_client,
with_session,
)
F = TypeVar("F", bound=Callable)
@ -286,18 +286,24 @@ def format_device_name(features: messages.Features) -> str:
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
"""List connected Trezor devices."""
if no_resolve:
return enumerate_devices()
for d in enumerate_devices():
click.echo(d.get_path())
return
from . import get_client
for transport in enumerate_devices():
try:
client = TrezorClient(transport, ui=ui.ClickUI())
client = get_client(transport)
transport.open()
description = format_device_name(client.features)
client.end_session()
except DeviceIsBusy:
description = "Device is in use by another process"
except Exception:
description = "Failed to read details"
click.echo(f"{transport} - {description}")
except Exception as e:
description = "Failed to read details " + str(type(e))
finally:
transport.close()
click.echo(f"{transport.get_path()} - {description}")
return None
@ -315,23 +321,21 @@ def version() -> str:
@cli.command()
@click.argument("message")
@click.option("-b", "--button-protection", is_flag=True)
@with_client
def ping(client: "TrezorClient", message: str, button_protection: bool) -> str:
@with_session(empty_passphrase=True)
def ping(session: "Session", message: str, button_protection: bool) -> str:
"""Send ping message."""
return client.ping(message, button_protection=button_protection)
return session.ping(message, button_protection)
@cli.command()
@click.pass_obj
def get_session(obj: TrezorConnection) -> str:
@click.option("-c", "derive_cardano", is_flag=True, help="Derive Cardano session.")
def get_session(obj: TrezorConnection, derive_cardano: bool = False) -> str:
"""Get a session ID for subsequent commands.
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
`trezorctl -s SESSION_ID`, or set it to an environment variable `TREZOR_SESSION_ID`,
to avoid having to enter passphrase for subsequent commands.
The session ID is valid until another client starts using Trezor, until the next
get-session call, or until Trezor is disconnected.
"""
# make sure session is not resumed
obj.session_id = None
@ -342,25 +346,26 @@ def get_session(obj: TrezorConnection) -> str:
"Upgrade your firmware to enable session support."
)
client.ensure_unlocked()
if client.session_id is None:
raise click.ClickException("Passphrase not enabled or firmware too old.")
else:
return client.session_id.hex()
session = obj.get_session(derive_cardano=derive_cardano)
if session.id is None:
raise click.ClickException("Passphrase not enabled or firmware too old.")
else:
return session.id.hex()
@cli.command()
@with_client
def clear_session(client: "TrezorClient") -> None:
@with_session(must_resume=True, empty_passphrase=True)
def clear_session(session: "Session") -> None:
"""Clear session (remove cached PIN, passphrase, etc.)."""
return client.clear_session()
session.call(messages.LockDevice())
session.end()
@cli.command()
@with_client
def get_features(client: "TrezorClient") -> messages.Features:
@with_session(seedless=True)
def get_features(session: "Session") -> messages.Features:
"""Retrieve device features and settings."""
return client.features
return session.features
@cli.command()

View File

@ -13,28 +13,23 @@
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
import os
import typing as t
import warnings
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
from mnemonic import Mnemonic
from enum import IntEnum
from . import exceptions, mapping, messages, models
from .log import DUMP_BYTES
from .messages import Capability
from .protobuf import MessageType
from .tools import parse_path, session
from .mapping import ProtobufMapping
from .tools import parse_path
from .transport import Transport, get_transport
from .transport.thp.protocol_and_channel import Channel
from .transport.thp.protocol_v1 import ProtocolV1Channel
if TYPE_CHECKING:
from .transport import Transport
from .ui import TrezorClientUI
UI = TypeVar("UI", bound="TrezorClientUI")
MT = TypeVar("MT", bound=MessageType)
if t.TYPE_CHECKING:
from .transport.session import Session
LOG = logging.getLogger(__name__)
@ -51,330 +46,170 @@ Or visit https://suite.trezor.io/
""".strip()
def get_default_client(
path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any
) -> "TrezorClient":
"""Get a client for a connected Trezor device.
Returns a TrezorClient instance with minimum fuss.
If path is specified, does a prefix-search for the specified device. Otherwise, uses
the value of TREZOR_PATH env variable, or finds first connected Trezor.
If no UI is supplied, instantiates the default CLI UI.
"""
from .transport import get_transport
from .ui import ClickUI
if path is None:
path = os.getenv("TREZOR_PATH")
transport = get_transport(path, prefix_search=True)
if ui is None:
ui = ClickUI()
return TrezorClient(transport, ui, **kwargs)
LOG = logging.getLogger(__name__)
class TrezorClient(Generic[UI]):
"""Trezor client, a connection to a Trezor device.
class ProtocolVersion(IntEnum):
UNKNOWN = 0x00
PROTOCOL_V1 = 0x01 # Codec
PROTOCOL_V2 = 0x02 # THP
This class allows you to manage connection state, send and receive protobuf
messages, handle user interactions, and perform some generic tasks
(send a cancel message, initialize or clear a session, ping the device).
"""
model: models.TrezorModel
transport: "Transport"
session_id: Optional[bytes]
ui: UI
features: messages.Features
class TrezorClient:
button_callback: t.Callable[[Session, messages.ButtonRequest], t.Any] | None = None
passphrase_callback: (
t.Callable[[Session, messages.PassphraseRequest], t.Any] | None
) = None
pin_callback: t.Callable[[Session, messages.PinMatrixRequest], t.Any] | None = None
_seedless_session: Session | None = None
_features: messages.Features | None = None
_protocol_version: int
_setup_pin: str | None = None # Should be used only by conftest
def __init__(
self,
transport: "Transport",
ui: UI,
session_id: Optional[bytes] = None,
derive_cardano: Optional[bool] = None,
model: Optional[models.TrezorModel] = None,
_init_device: bool = True,
transport: Transport,
protobuf_mapping: ProtobufMapping | None = None,
protocol: Channel | None = None,
) -> None:
"""Create a TrezorClient instance.
You have to provide a `transport`, i.e., a raw connection to the device. You can
use `trezorlib.transport.get_transport` to find one.
You have to provide a UI implementation for the three kinds of interaction:
- button request (notify the user that their interaction is needed)
- PIN request (on T1, ask the user to input numbers for a PIN matrix)
- passphrase request (ask the user to enter a passphrase) See `trezorlib.ui` for
details.
You can supply a `session_id` you might have saved in the previous session. If
you do, the user might not need to enter their passphrase again.
You can provide Trezor model information. If not provided, it is detected from
the model name reported at initialization time.
By default, the instance will open a connection to the Trezor device, send an
`Initialize` message, set up the `features` field from the response, and connect
to a session. By specifying `_init_device=False`, this step is skipped. Notably,
this means that `client.features` is unset. Use `client.init_device()` or
`client.refresh_features()` to fix that, otherwise A LOT OF THINGS will break.
Only use this if you are _sure_ that you know what you are doing. This feature
might be removed at any time.
"""
LOG.info(f"creating client instance for device: {transport.get_path()}")
# Here, self.model could be set to None. Unless _init_device is False, it will
# get correctly reconfigured as part of the init_device flow.
self.model = model # type: ignore ["None" is incompatible with "TrezorModel"]
if self.model:
self.mapping = self.model.default_mapping
else:
self.mapping = mapping.DEFAULT_MAPPING
Transport needs to be opened before calling a method (or accessing
an attribute) for the first time. It should be closed after you're
done using the client.
"""
self._is_invalidated: bool = False
self.transport = transport
self.ui = ui
self.session_counter = 0
self.session_id = session_id
if _init_device:
self.init_device(session_id=session_id, derive_cardano=derive_cardano)
def open(self) -> None:
if self.session_counter == 0:
self.transport.begin_session()
self.session_counter += 1
def close(self) -> None:
self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
# TODO call EndSession here?
self.transport.end_session()
def cancel(self) -> None:
self._raw_write(messages.Cancel())
def call_raw(self, msg: MessageType) -> MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
self._raw_write(msg)
return self._raw_read()
def _raw_write(self, msg: MessageType) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
LOG.debug(
f"sending message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
msg_type, msg_bytes = self.mapping.encode(msg)
LOG.log(
DUMP_BYTES,
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
self.transport.write(msg_type, msg_bytes)
def _raw_read(self) -> MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
msg_type, msg_bytes = self.transport.read()
LOG.log(
DUMP_BYTES,
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
msg = self.mapping.decode(msg_type, msg_bytes)
LOG.debug(
f"received message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
return msg
def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType:
try:
pin = self.ui.get_pin(msg.type)
except exceptions.Cancelled:
self.call_raw(messages.Cancel())
raise
if any(d not in "123456789" for d in pin) or not (
1 <= len(pin) <= MAX_PIN_LENGTH
):
self.call_raw(messages.Cancel())
raise ValueError("Invalid PIN provided")
resp = self.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise exceptions.PinException(resp.code, resp.message)
if protobuf_mapping is None:
self.mapping = mapping.DEFAULT_MAPPING
else:
return resp
self.mapping = protobuf_mapping
if protocol is None:
self.protocol = self._get_protocol()
else:
self.protocol = protocol
self.protocol.mapping = self.mapping
def _callback_passphrase(self, msg: messages.PassphraseRequest) -> MessageType:
available_on_device = Capability.PassphraseEntry in self.features.capabilities
if isinstance(self.protocol, ProtocolV1Channel):
self._protocol_version = ProtocolVersion.PROTOCOL_V1
else:
self._protocol_version = ProtocolVersion.UNKNOWN
def send_passphrase(
passphrase: Optional[str] = None, on_device: Optional[bool] = None
) -> MessageType:
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
resp = self.call_raw(msg)
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
self.session_id = resp.state
resp = self.call_raw(messages.Deprecated_PassphraseStateAck())
return resp
@classmethod
def resume(
cls,
transport: Transport,
protobuf_mapping: ProtobufMapping | None = None,
) -> TrezorClient:
if protobuf_mapping is None:
protobuf_mapping = mapping.DEFAULT_MAPPING
protocol = ProtocolV1Channel(transport, protobuf_mapping)
return TrezorClient(transport, protobuf_mapping, protocol)
# short-circuit old style entry
if msg._on_device is True:
return send_passphrase(None, None)
try:
passphrase = self.ui.get_passphrase(available_on_device=available_on_device)
except exceptions.Cancelled:
self.call_raw(messages.Cancel())
raise
if passphrase is PASSPHRASE_ON_DEVICE:
if not available_on_device:
self.call_raw(messages.Cancel())
raise RuntimeError("Device is not capable of entering passphrase")
else:
return send_passphrase(on_device=True)
# else process host-entered passphrase
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
self.call_raw(messages.Cancel())
raise ValueError("Passphrase too long")
return send_passphrase(passphrase, on_device=False)
def _callback_button(self, msg: messages.ButtonRequest) -> MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
self._raw_write(messages.ButtonAck())
self.ui.button_request(msg)
return self._raw_read()
@session
def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT:
self.check_firmware_version()
resp = self.call_raw(msg)
while True:
if isinstance(resp, messages.PinMatrixRequest):
resp = self._callback_pin(resp)
elif isinstance(resp, messages.PassphraseRequest):
resp = self._callback_passphrase(resp)
elif isinstance(resp, messages.ButtonRequest):
resp = self._callback_button(resp)
elif isinstance(resp, messages.Failure):
if resp.code == messages.FailureType.ActionCancelled:
raise exceptions.Cancelled
raise exceptions.TrezorFailure(resp)
elif not isinstance(resp, expect):
raise exceptions.UnexpectedMessageError(expect, resp)
else:
return resp
def _refresh_features(self, features: messages.Features) -> None:
"""Update internal fields based on passed-in Features message."""
if not self.model:
self.model = models.detect(features)
if features.vendor not in self.model.vendors:
raise exceptions.TrezorException(f"Unrecognized vendor: {features.vendor}")
self.features = features
self.version = (
self.features.major_version,
self.features.minor_version,
self.features.patch_version,
)
self.check_firmware_version(warn_only=True)
if self.features.session_id is not None:
self.session_id = self.features.session_id
self.features.session_id = None
@session
def refresh_features(self) -> messages.Features:
"""Reload features from the device.
Should be called after changing settings or performing operations that affect
device state.
"""
resp = self.call_raw(messages.GetFeatures())
if not isinstance(resp, messages.Features):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._refresh_features(resp)
return resp
@session
def init_device(
def get_session(
self,
*,
session_id: Optional[bytes] = None,
new_session: bool = False,
derive_cardano: Optional[bool] = None,
) -> Optional[bytes]:
"""Initialize the device and return a session ID.
You can optionally specify a session ID. If the session still exists on the
device, the same session ID will be returned and the session is resumed.
Otherwise a different session ID is returned.
Specify `new_session=True` to open a fresh session. Since firmware version
1.9.0/2.3.0, the previous session will remain cached on the device, and can be
resumed by calling `init_device` again with the appropriate session ID.
If neither `new_session` nor `session_id` is specified, the current session ID
will be reused. If no session ID was cached, a new session ID will be allocated
and returned.
# Version notes:
Trezor One older than 1.9.0 does not have session management. Optional arguments
have no effect and the function returns None
Trezor T older than 2.3.0 does not have session cache. Requesting a new session
will overwrite the old one. In addition, this function will always return None.
A valid session_id can be obtained from the `session_id` attribute, but only
after a passphrase-protected call is performed. You can use the following code:
>>> client.init_device()
>>> client.ensure_unlocked()
>>> valid_session_id = client.session_id
passphrase: str | object | None = None,
derive_cardano: bool = False,
session_id: bytes | None = None,
should_derive: bool = True,
) -> Session:
"""
if new_session:
self.session_id = None
elif session_id is not None:
self.session_id = session_id
Returns initialized session (with derived seed).
resp = self.call_raw(
messages.Initialize(
session_id=self.session_id,
Will fail if the device is not initialized
"""
from .transport.session import SessionV1, derive_seed
if isinstance(self.protocol, ProtocolV1Channel):
session = SessionV1.new(
self,
derive_cardano=derive_cardano,
session_id=session_id,
)
if should_derive:
if isinstance(passphrase, str):
temporary = self.passphrase_callback
self.passphrase_callback = get_callback_passphrase_v1(
passphrase=passphrase
)
derive_seed(session)
self.passphrase_callback = temporary
elif passphrase is PASSPHRASE_ON_DEVICE:
derive_seed(session)
return session
raise NotImplementedError
def resume_session(self, session: Session) -> Session:
"""
Note: this function potentially modifies the input session.
"""
from .transport.session import SessionV1
if isinstance(session, SessionV1):
session.init_session()
return session
else:
raise NotImplementedError
def get_seedless_session(self, new_session: bool = False) -> Session:
from .transport.session import SessionV1
if not new_session and self._seedless_session is not None:
return self._seedless_session
if isinstance(self.protocol, ProtocolV1Channel):
self._seedless_session = SessionV1.new(client=self, derive_cardano=False)
assert self._seedless_session is not None
return self._seedless_session
def invalidate(self) -> None:
self._is_invalidated = True
@property
def features(self) -> messages.Features:
if self._features is None:
self._features = self.protocol.get_features()
self.check_firmware_version(warn_only=True)
assert self._features is not None
return self._features
@property
def protocol_version(self) -> int:
return self._protocol_version
@property
def model(self) -> models.TrezorModel:
model = models.detect(self.features)
if self.features.vendor not in model.vendors:
raise exceptions.TrezorException(
f"Unrecognized vendor: {self.features.vendor}"
)
return model
@property
def version(self) -> tuple[int, int, int]:
f = self.features
ver = (
f.major_version,
f.minor_version,
f.patch_version,
)
if isinstance(resp, messages.Failure):
# can happen if `derive_cardano` does not match the current session
raise exceptions.TrezorFailure(resp)
if not isinstance(resp, messages.Features):
raise exceptions.TrezorException("Unexpected response to Initialize")
return ver
if self.session_id is not None and resp.session_id == self.session_id:
LOG.info("Successfully resumed session")
elif session_id is not None:
LOG.info("Failed to resume session")
@property
def is_invalidated(self) -> bool:
return self._is_invalidated
# TT < 2.3.0 compatibility:
# _refresh_features will clear out the session_id field. We want this function
# to return its value, so that callers can rely on it being either a valid
# session_id, or None if we can't do that.
# Older TT FW does not report session_id in Features and self.session_id might
# be invalid because TT will not allocate a session_id until a passphrase
# exchange happens.
reported_session_id = resp.session_id
self._refresh_features(resp)
return reported_session_id
def refresh_features(self) -> messages.Features:
self.protocol.update_features()
self._features = self.protocol.get_features()
self.check_firmware_version(warn_only=True)
return self._features
def _get_protocol(self) -> Channel:
protocol = ProtocolV1Channel(self.transport, mapping.DEFAULT_MAPPING)
return protocol
def is_outdated(self) -> bool:
if self.features.bootloader_mode:
@ -388,101 +223,35 @@ class TrezorClient(Generic[UI]):
else:
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
def ping(self, msg: str, button_protection: bool = False) -> str:
# We would like ping to work on any valid TrezorClient instance, but
# due to the protection modes, we need to go through self.call, and that will
# raise an exception if the firmware is too old.
# So we short-circuit the simplest variant of ping with call_raw.
if not button_protection:
# XXX this should be: `with self:`
try:
self.open()
resp = self.call_raw(messages.Ping(message=msg))
if isinstance(resp, messages.ButtonRequest):
# device is PIN-locked.
# respond and hope for the best
resp = self._callback_button(resp)
resp = messages.Success.ensure_isinstance(resp)
assert resp.message is not None
return resp.message
finally:
self.close()
resp = self.call(
messages.Ping(message=msg, button_protection=button_protection),
expect=messages.Success,
)
assert resp.message is not None
return resp.message
def get_default_client(
path: t.Optional[str] = None,
**kwargs: t.Any,
) -> "TrezorClient":
"""Get a client for a connected Trezor device.
def get_device_id(self) -> Optional[str]:
return self.features.device_id
Returns a TrezorClient instance with minimum fuss.
@session
def lock(self, *, _refresh_features: bool = True) -> None:
"""Lock the device.
Transport is opened and should be closed after you're done with the client.
If the device does not have a PIN configured, this will do nothing.
Otherwise, a lock screen will be shown and the device will prompt for PIN
before further actions.
If path is specified, does a prefix-search for the specified device. Otherwise, uses
the value of TREZOR_PATH env variable, or finds first connected Trezor.
"""
This call does _not_ invalidate passphrase cache. If passphrase is in use,
the device will not prompt for it after unlocking.
if path is None:
path = os.getenv("TREZOR_PATH")
To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate
passphrase cache, use `clear_session()`.
"""
# Private argument _refresh_features can be used internally to avoid
# refreshing in cases where we will refresh soon anyway. This is used
# in TrezorClient.clear_session()
self.call(messages.LockDevice())
if _refresh_features:
self.refresh_features()
transport = get_transport(path, prefix_search=True)
transport.open()
@session
def ensure_unlocked(self) -> None:
"""Ensure the device is unlocked and a passphrase is cached.
return TrezorClient(transport, **kwargs)
If the device is locked, this will prompt for PIN. If passphrase is enabled
and no passphrase is cached for the current session, the device will also
prompt for passphrase.
After calling this method, further actions on the device will not prompt for
PIN or passphrase until the device is locked or the session becomes invalid.
"""
from .btc import get_address
def get_callback_passphrase_v1(
passphrase: str = "",
) -> t.Callable[[Session, t.Any], t.Any] | None:
get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
self.refresh_features()
def _callback_passphrase_v1(session: Session, msg: t.Any) -> t.Any:
return session.call(messages.PassphraseAck(passphrase=passphrase))
def end_session(self) -> None:
"""Close the current session and clear cached passphrase.
The session will become invalid until `init_device()` is called again.
If passphrase is enabled, further actions will prompt for it again.
This is a no-op in bootloader mode, as it does not support session management.
"""
# since: 2.3.4, 1.9.4
try:
if not self.features.bootloader_mode:
self.call(messages.EndSession())
except exceptions.TrezorFailure:
# A failure most likely means that the FW version does not support
# the EndSession call. We ignore the failure and clear the local session_id.
# The client-side end result is identical.
pass
self.session_id = None
@session
def clear_session(self) -> None:
"""Lock the device and present a fresh session.
The current session will be invalidated and a new one will be started. If the
device has PIN enabled, it will become locked.
Equivalent to calling `lock()`, `end_session()` and `init_device()`.
"""
self.lock(_refresh_features=False)
self.end_session()
self.init_device(new_session=True)
return _callback_passphrase_v1

File diff suppressed because it is too large Load Diff

View File

@ -28,16 +28,10 @@ from slip10 import SLIP10
from . import messages
from .exceptions import Cancelled, TrezorException
from .tools import (
Address,
_deprecation_retval_helper,
_return_success,
parse_path,
session,
)
from .tools import Address, _deprecation_retval_helper, _return_success, parse_path
if TYPE_CHECKING:
from .client import TrezorClient
from .transport.session import Session
RECOVERY_BACK = "\x08" # backspace character, sent literally
@ -46,9 +40,8 @@ SLIP39_EXTENDABLE_MIN_VERSION = (2, 7, 1)
ENTROPY_CHECK_MIN_VERSION = (2, 8, 7)
@session
def apply_settings(
client: "TrezorClient",
session: "Session",
label: Optional[str] = None,
language: Optional[str] = None,
use_passphrase: Optional[bool] = None,
@ -79,13 +72,13 @@ def apply_settings(
haptic_feedback=haptic_feedback,
)
out = client.call(settings, expect=messages.Success)
client.refresh_features()
out = session.call(settings, expect=messages.Success)
session.refresh_features()
return _return_success(out)
def _send_language_data(
client: "TrezorClient",
session: "Session",
request: "messages.TranslationDataRequest",
language_data: bytes,
) -> None:
@ -95,69 +88,61 @@ def _send_language_data(
data_length = response.data_length
data_offset = response.data_offset
chunk = language_data[data_offset : data_offset + data_length]
response = client.call(messages.TranslationDataAck(data_chunk=chunk))
response = session.call(messages.TranslationDataAck(data_chunk=chunk))
@session
def change_language(
client: "TrezorClient",
session: "Session",
language_data: bytes,
show_display: bool | None = None,
) -> str | None:
data_length = len(language_data)
msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)
response = client.call(msg)
response = session.call(msg)
if data_length > 0:
response = messages.TranslationDataRequest.ensure_isinstance(response)
_send_language_data(client, response, language_data)
_send_language_data(session, response, language_data)
else:
messages.Success.ensure_isinstance(response)
client.refresh_features() # changing the language in features
session.refresh_features() # changing the language in features
return _return_success(messages.Success(message="Language changed."))
@session
def apply_flags(client: "TrezorClient", flags: int) -> str | None:
out = client.call(messages.ApplyFlags(flags=flags), expect=messages.Success)
client.refresh_features()
def apply_flags(session: "Session", flags: int) -> str | None:
out = session.call(messages.ApplyFlags(flags=flags), expect=messages.Success)
session.refresh_features()
return _return_success(out)
@session
def change_pin(client: "TrezorClient", remove: bool = False) -> str | None:
ret = client.call(messages.ChangePin(remove=remove), expect=messages.Success)
client.refresh_features()
def change_pin(session: "Session", remove: bool = False) -> str | None:
ret = session.call(messages.ChangePin(remove=remove), expect=messages.Success)
session.refresh_features()
return _return_success(ret)
@session
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> str | None:
ret = client.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success)
client.refresh_features()
def change_wipe_code(session: "Session", remove: bool = False) -> str | None:
ret = session.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success)
session.refresh_features()
return _return_success(ret)
@session
def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType
session: "Session", operation: messages.SdProtectOperationType
) -> str | None:
ret = client.call(messages.SdProtect(operation=operation), expect=messages.Success)
client.refresh_features()
ret = session.call(messages.SdProtect(operation=operation), expect=messages.Success)
session.refresh_features()
return _return_success(ret)
@session
def wipe(client: "TrezorClient") -> str | None:
ret = client.call(messages.WipeDevice(), expect=messages.Success)
if not client.features.bootloader_mode:
client.init_device()
def wipe(session: "Session") -> str | None:
ret = session.call(messages.WipeDevice(), expect=messages.Success)
session.invalidate()
return _return_success(ret)
@session
def recover(
client: "TrezorClient",
session: "Session",
word_count: int = 24,
passphrase_protection: bool = False,
pin_protection: bool = True,
@ -193,13 +178,13 @@ def recover(
if type is None:
type = messages.RecoveryType.NormalRecovery
if client.features.model == "1" and input_callback is None:
if session.features.model == "1" and input_callback is None:
raise RuntimeError("Input callback required for Trezor One")
if word_count not in (12, 18, 24):
raise ValueError("Invalid word count. Use 12/18/24")
if client.features.initialized and type == messages.RecoveryType.NormalRecovery:
if session.features.initialized and type == messages.RecoveryType.NormalRecovery:
raise RuntimeError(
"Device already initialized. Call device.wipe() and try again."
)
@ -221,20 +206,20 @@ def recover(
msg.label = label
msg.u2f_counter = u2f_counter
res = client.call(msg)
res = session.call(msg)
while isinstance(res, messages.WordRequest):
try:
assert input_callback is not None
inp = input_callback(res.type)
res = client.call(messages.WordAck(word=inp))
res = session.call(messages.WordAck(word=inp))
except Cancelled:
res = client.call(messages.Cancel())
res = session.call(messages.Cancel())
# check that the result is a Success
res = messages.Success.ensure_isinstance(res)
# reinitialize the device
client.init_device()
session.refresh_features()
return _deprecation_retval_helper(res)
@ -280,7 +265,7 @@ def _seed_from_entropy(
def reset(
client: "TrezorClient",
session: "Session",
display_random: bool = False,
strength: Optional[int] = None,
passphrase_protection: bool = False,
@ -313,7 +298,7 @@ def reset(
)
setup(
client,
session,
strength=strength,
passphrase_protection=passphrase_protection,
pin_protection=pin_protection,
@ -331,9 +316,8 @@ def _get_external_entropy() -> bytes:
return secrets.token_bytes(32)
@session
def setup(
client: "TrezorClient",
session: "Session",
*,
strength: Optional[int] = None,
passphrase_protection: bool = True,
@ -388,19 +372,19 @@ def setup(
check.
"""
if client.features.initialized:
if session.features.initialized:
raise RuntimeError(
"Device is initialized already. Call wipe_device() and try again."
)
if strength is None:
if client.features.model == "1":
if session.features.model == "1":
strength = 256
else:
strength = 128
if backup_type is None:
if client.version < SLIP39_EXTENDABLE_MIN_VERSION:
if session.version < SLIP39_EXTENDABLE_MIN_VERSION:
# includes Trezor One 1.x.x
backup_type = messages.BackupType.Bip39
else:
@ -411,7 +395,7 @@ def setup(
paths = [parse_path("m/84h/0h/0h"), parse_path("m/44h/60h/0h")]
if entropy_check_count is None:
if client.version < ENTROPY_CHECK_MIN_VERSION:
if session.version < ENTROPY_CHECK_MIN_VERSION:
# includes Trezor One 1.x.x
entropy_check_count = 0
else:
@ -431,18 +415,18 @@ def setup(
)
if entropy_check_count > 0:
xpubs = _reset_with_entropycheck(
client, msg, entropy_check_count, paths, _get_entropy
session, msg, entropy_check_count, paths, _get_entropy
)
else:
_reset_no_entropycheck(client, msg, _get_entropy)
_reset_no_entropycheck(session, msg, _get_entropy)
xpubs = []
client.init_device()
session.refresh_features()
return xpubs
def _reset_no_entropycheck(
client: "TrezorClient",
session: "Session",
msg: messages.ResetDevice,
get_entropy: Callable[[], bytes],
) -> None:
@ -454,12 +438,12 @@ def _reset_no_entropycheck(
<< Success
"""
assert msg.entropy_check is False
client.call(msg, expect=messages.EntropyRequest)
client.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success)
session.call(msg, expect=messages.EntropyRequest)
session.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success)
def _reset_with_entropycheck(
client: "TrezorClient",
session: "Session",
reset_msg: messages.ResetDevice,
entropy_check_count: int,
paths: Iterable[Address],
@ -495,7 +479,7 @@ def _reset_with_entropycheck(
def get_xpubs() -> list[tuple[Address, str]]:
xpubs = []
for path in paths:
resp = client.call(
resp = session.call(
messages.GetPublicKey(address_n=path), expect=messages.PublicKey
)
xpubs.append((path, resp.xpub))
@ -524,13 +508,13 @@ def _reset_with_entropycheck(
raise TrezorException("Invalid XPUB in entropy check")
xpubs = []
resp = client.call(reset_msg, expect=messages.EntropyRequest)
resp = session.call(reset_msg, expect=messages.EntropyRequest)
entropy_commitment = resp.entropy_commitment
while True:
# provide external entropy for this round
external_entropy = get_entropy()
client.call(
session.call(
messages.EntropyAck(entropy=external_entropy),
expect=messages.EntropyCheckReady,
)
@ -540,7 +524,7 @@ def _reset_with_entropycheck(
if entropy_check_count <= 0:
# last round, wait for a Success and exit the loop
client.call(
session.call(
messages.EntropyCheckContinue(finish=True),
expect=messages.Success,
)
@ -549,7 +533,7 @@ def _reset_with_entropycheck(
entropy_check_count -= 1
# Next round starts.
resp = client.call(
resp = session.call(
messages.EntropyCheckContinue(finish=False),
expect=messages.EntropyRequest,
)
@ -570,13 +554,12 @@ def _reset_with_entropycheck(
return xpubs
@session
def backup(
client: "TrezorClient",
session: "Session",
group_threshold: Optional[int] = None,
groups: Iterable[tuple[int, int]] = (),
) -> str | None:
ret = client.call(
ret = session.call(
messages.BackupDevice(
group_threshold=group_threshold,
groups=[
@ -586,37 +569,36 @@ def backup(
),
expect=messages.Success,
)
client.refresh_features()
session.refresh_features()
return _return_success(ret)
def cancel_authorization(client: "TrezorClient") -> str | None:
ret = client.call(messages.CancelAuthorization(), expect=messages.Success)
def cancel_authorization(session: "Session") -> str | None:
ret = session.call(messages.CancelAuthorization(), expect=messages.Success)
return _return_success(ret)
def unlock_path(client: "TrezorClient", n: "Address") -> bytes:
resp = client.call(
def unlock_path(session: "Session", n: "Address") -> bytes:
resp = session.call(
messages.UnlockPath(address_n=n), expect=messages.UnlockedPathRequest
)
# Cancel the UnlockPath workflow now that we have the authentication code.
try:
client.call(messages.Cancel())
session.call(messages.Cancel())
except Cancelled:
return resp.mac
else:
raise TrezorException("Unexpected response in UnlockPath flow")
@session
def reboot_to_bootloader(
client: "TrezorClient",
session: "Session",
boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT,
firmware_header: Optional[bytes] = None,
language_data: bytes = b"",
) -> str | None:
response = client.call(
response = session.call(
messages.RebootToBootloader(
boot_command=boot_command,
firmware_header=firmware_header,
@ -624,43 +606,38 @@ def reboot_to_bootloader(
)
)
if isinstance(response, messages.TranslationDataRequest):
response = _send_language_data(client, response, language_data)
response = _send_language_data(session, response, language_data)
return _return_success(messages.Success(message=""))
@session
def show_device_tutorial(client: "TrezorClient") -> str | None:
ret = client.call(messages.ShowDeviceTutorial(), expect=messages.Success)
def show_device_tutorial(session: "Session") -> str | None:
ret = session.call(messages.ShowDeviceTutorial(), expect=messages.Success)
return _return_success(ret)
@session
def unlock_bootloader(client: "TrezorClient") -> str | None:
ret = client.call(messages.UnlockBootloader(), expect=messages.Success)
def unlock_bootloader(session: "Session") -> str | None:
ret = session.call(messages.UnlockBootloader(), expect=messages.Success)
return _return_success(ret)
@session
def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> str | None:
def set_busy(session: "Session", expiry_ms: Optional[int]) -> str | None:
"""Sets or clears the busy state of the device.
In the busy state the device shows a "Do not disconnect" message instead of the homescreen.
Setting `expiry_ms=None` clears the busy state.
"""
ret = client.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success)
client.refresh_features()
ret = session.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success)
session.refresh_features()
return _return_success(ret)
def authenticate(
client: "TrezorClient", challenge: bytes
) -> messages.AuthenticityProof:
return client.call(
def authenticate(session: "Session", challenge: bytes) -> messages.AuthenticityProof:
return session.call(
messages.AuthenticateDevice(challenge=challenge),
expect=messages.AuthenticityProof,
)
def set_brightness(client: "TrezorClient", value: Optional[int] = None) -> str | None:
ret = client.call(messages.SetBrightness(value=value), expect=messages.Success)
def set_brightness(session: "Session", value: Optional[int] = None) -> str | None:
ret = session.call(messages.SetBrightness(value=value), expect=messages.Success)
return _return_success(ret)

View File

@ -18,11 +18,11 @@ from datetime import datetime
from typing import TYPE_CHECKING, List, Tuple
from . import exceptions, messages
from .tools import b58decode, session
from .tools import b58decode
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .transport.session import Session
def name_to_number(name: str) -> int:
@ -319,17 +319,16 @@ def parse_transaction_json(
def get_public_key(
client: "TrezorClient", n: "Address", show_display: bool = False
session: "Session", n: "Address", show_display: bool = False
) -> messages.EosPublicKey:
return client.call(
return session.call(
messages.EosGetPublicKey(address_n=n, show_display=show_display),
expect=messages.EosPublicKey,
)
@session
def sign_tx(
client: "TrezorClient",
session: "Session",
address: "Address",
transaction: dict,
chain_id: str,
@ -345,11 +344,11 @@ def sign_tx(
chunkify=chunkify,
)
response = client.call(msg)
response = session.call(msg)
try:
while isinstance(response, messages.EosTxActionRequest):
response = client.call(actions.pop(0))
response = session.call(actions.pop(0))
except IndexError:
# pop from empty list
raise exceptions.TrezorException(

View File

@ -18,11 +18,11 @@ import re
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
from . import definitions, exceptions, messages
from .tools import prepare_message_bytes, session, unharden
from .tools import prepare_message_bytes, unharden
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .transport.session import Session
def int_to_big_endian(value: int) -> bytes:
@ -161,13 +161,13 @@ def network_from_address_n(
def get_address(
client: "TrezorClient",
session: "Session",
n: "Address",
show_display: bool = False,
encoded_network: Optional[bytes] = None,
chunkify: bool = False,
) -> str:
resp = client.call(
resp = session.call(
messages.EthereumGetAddress(
address_n=n,
show_display=show_display,
@ -181,17 +181,16 @@ def get_address(
def get_public_node(
client: "TrezorClient", n: "Address", show_display: bool = False
session: "Session", n: "Address", show_display: bool = False
) -> messages.EthereumPublicKey:
return client.call(
return session.call(
messages.EthereumGetPublicKey(address_n=n, show_display=show_display),
expect=messages.EthereumPublicKey,
)
@session
def sign_tx(
client: "TrezorClient",
session: "Session",
n: "Address",
nonce: int,
gas_price: int,
@ -227,13 +226,13 @@ def sign_tx(
data, chunk = data[1024:], data[:1024]
msg.data_initial_chunk = chunk
response = client.call(msg)
response = session.call(msg)
assert isinstance(response, messages.EthereumTxRequest)
while response.data_length is not None:
data_length = response.data_length
data, chunk = data[data_length:], data[:data_length]
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
response = session.call(messages.EthereumTxAck(data_chunk=chunk))
assert isinstance(response, messages.EthereumTxRequest)
assert response.signature_v is not None
@ -248,9 +247,8 @@ def sign_tx(
return response.signature_v, response.signature_r, response.signature_s
@session
def sign_tx_eip1559(
client: "TrezorClient",
session: "Session",
n: "Address",
*,
nonce: int,
@ -283,13 +281,13 @@ def sign_tx_eip1559(
chunkify=chunkify,
)
response = client.call(msg)
response = session.call(msg)
assert isinstance(response, messages.EthereumTxRequest)
while response.data_length is not None:
data_length = response.data_length
data, chunk = data[data_length:], data[:data_length]
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
response = session.call(messages.EthereumTxAck(data_chunk=chunk))
assert isinstance(response, messages.EthereumTxRequest)
assert response.signature_v is not None
@ -299,13 +297,13 @@ def sign_tx_eip1559(
def sign_message(
client: "TrezorClient",
session: "Session",
n: "Address",
message: AnyStr,
encoded_network: Optional[bytes] = None,
chunkify: bool = False,
) -> messages.EthereumMessageSignature:
return client.call(
return session.call(
messages.EthereumSignMessage(
address_n=n,
message=prepare_message_bytes(message),
@ -317,7 +315,7 @@ def sign_message(
def sign_typed_data(
client: "TrezorClient",
session: "Session",
n: "Address",
data: Dict[str, Any],
*,
@ -333,7 +331,7 @@ def sign_typed_data(
metamask_v4_compat=metamask_v4_compat,
definitions=definitions,
)
response = client.call(request)
response = session.call(request)
# Sending all the types
while isinstance(response, messages.EthereumTypedDataStructRequest):
@ -349,7 +347,7 @@ def sign_typed_data(
members.append(struct_member)
request = messages.EthereumTypedDataStructAck(members=members)
response = client.call(request)
response = session.call(request)
# Sending the whole message that should be signed
while isinstance(response, messages.EthereumTypedDataValueRequest):
@ -362,7 +360,7 @@ def sign_typed_data(
member_typename = data["primaryType"]
member_data = data["message"]
else:
client.cancel()
session.cancel()
raise exceptions.TrezorException("Root index can only be 0 or 1")
# It can be asking for a nested structure (the member path being [X, Y, Z, ...])
@ -385,20 +383,20 @@ def sign_typed_data(
encoded_data = encode_data(member_data, member_typename)
request = messages.EthereumTypedDataValueAck(value=encoded_data)
response = client.call(request)
response = session.call(request)
return messages.EthereumTypedDataSignature.ensure_isinstance(response)
def verify_message(
client: "TrezorClient",
session: "Session",
address: str,
signature: bytes,
message: AnyStr,
chunkify: bool = False,
) -> bool:
try:
client.call(
session.call(
messages.EthereumVerifyMessage(
address=address,
signature=signature,
@ -413,13 +411,13 @@ def verify_message(
def sign_typed_data_hash(
client: "TrezorClient",
session: "Session",
n: "Address",
domain_hash: bytes,
message_hash: Optional[bytes],
encoded_network: Optional[bytes] = None,
) -> messages.EthereumTypedDataSignature:
return client.call(
return session.call(
messages.EthereumSignTypedHash(
address_n=n,
domain_separator_hash=domain_hash,

View File

@ -85,3 +85,10 @@ class UnexpectedMessageError(TrezorException):
self.expected = expected
self.actual = actual
super().__init__(f"Expected {expected.__name__} but Trezor sent {actual}")
class FailedSessionResumption(TrezorException):
"""Provided session_id is not valid / session cannot be resumed.
Raised when `trezorctl -s <sesssion_id>` is used or `TREZOR_SESSION_ID = <session_id>`
is set and resumption of session with the `session_id` fails."""

View File

@ -22,37 +22,37 @@ from . import messages
from .tools import _return_success
if TYPE_CHECKING:
from .client import TrezorClient
from .transport.session import Session
def list_credentials(client: "TrezorClient") -> Sequence[messages.WebAuthnCredential]:
return client.call(
def list_credentials(session: "Session") -> Sequence[messages.WebAuthnCredential]:
return session.call(
messages.WebAuthnListResidentCredentials(), expect=messages.WebAuthnCredentials
).credentials
def add_credential(client: "TrezorClient", credential_id: bytes) -> str | None:
ret = client.call(
def add_credential(session: "Session", credential_id: bytes) -> str | None:
ret = session.call(
messages.WebAuthnAddResidentCredential(credential_id=credential_id),
expect=messages.Success,
)
return _return_success(ret)
def remove_credential(client: "TrezorClient", index: int) -> str | None:
ret = client.call(
def remove_credential(session: "Session", index: int) -> str | None:
ret = session.call(
messages.WebAuthnRemoveResidentCredential(index=index), expect=messages.Success
)
return _return_success(ret)
def set_counter(client: "TrezorClient", u2f_counter: int) -> str | None:
ret = client.call(
def set_counter(session: "Session", u2f_counter: int) -> str | None:
ret = session.call(
messages.SetU2FCounter(u2f_counter=u2f_counter), expect=messages.Success
)
return _return_success(ret)
def get_next_counter(client: "TrezorClient") -> int:
ret = client.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter)
def get_next_counter(session: "Session") -> int:
ret = session.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter)
return ret.u2f_counter

View File

@ -22,7 +22,6 @@ from hashlib import blake2s
from typing_extensions import Protocol, TypeGuard
from .. import messages
from ..tools import session
from .core import VendorFirmware
from .legacy import LegacyFirmware, LegacyV2Firmware
from .models import Model
@ -41,7 +40,7 @@ if True:
from .vendor import * # noqa: F401, F403
if t.TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.session import Session
T = t.TypeVar("T", bound="FirmwareType")
@ -77,20 +76,19 @@ def is_onev2(fw: FirmwareType) -> TypeGuard[LegacyFirmware]:
# ====== Client functions ====== #
@session
def update(
client: TrezorClient,
session: Session,
data: bytes,
progress_update: t.Callable[[int], t.Any] = lambda _: None,
):
if client.features.bootloader_mode is False:
if session.features.bootloader_mode is False:
raise RuntimeError("Device must be in bootloader mode")
resp = client.call(messages.FirmwareErase(length=len(data)))
resp = session.call(messages.FirmwareErase(length=len(data)))
# TREZORv1 method
if isinstance(resp, messages.Success):
resp = client.call(messages.FirmwareUpload(payload=data))
resp = session.call(messages.FirmwareUpload(payload=data))
progress_update(len(data))
if isinstance(resp, messages.Success):
return
@ -102,7 +100,7 @@ def update(
length = resp.length
payload = data[resp.offset : resp.offset + length]
digest = blake2s(payload).digest()
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest))
resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest))
progress_update(length)
if isinstance(resp, messages.Success):
@ -111,7 +109,7 @@ def update(
raise RuntimeError(f"Unexpected message {resp}")
def get_hash(client: TrezorClient, challenge: bytes | None) -> bytes:
return client.call(
def get_hash(session: Session, challenge: bytes | None) -> bytes:
return session.call(
messages.GetFirmwareHash(challenge=challenge), expect=messages.FirmwareHash
).hash

View File

@ -85,6 +85,7 @@ class ProtobufMapping:
mapping = cls()
message_types = getattr(module, "MessageType")
for entry in message_types:
msg_class = getattr(module, entry.name, None)
if msg_class is None:

View File

@ -19,22 +19,22 @@ from typing import TYPE_CHECKING, Optional
from . import messages
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .transport.session import Session
def get_entropy(client: "TrezorClient", size: int) -> bytes:
return client.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy
def get_entropy(session: "Session", size: int) -> bytes:
return session.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy
def sign_identity(
client: "TrezorClient",
session: "Session",
identity: messages.IdentityType,
challenge_hidden: bytes,
challenge_visual: str,
ecdsa_curve_name: Optional[str] = None,
) -> messages.SignedIdentity:
return client.call(
return session.call(
messages.SignIdentity(
identity=identity,
challenge_hidden=challenge_hidden,
@ -46,12 +46,12 @@ def sign_identity(
def get_ecdh_session_key(
client: "TrezorClient",
session: "Session",
identity: messages.IdentityType,
peer_public_key: bytes,
ecdsa_curve_name: Optional[str] = None,
) -> messages.ECDHSessionKey:
return client.call(
return session.call(
messages.GetECDHSessionKey(
identity=identity,
peer_public_key=peer_public_key,
@ -62,7 +62,7 @@ def get_ecdh_session_key(
def encrypt_keyvalue(
client: "TrezorClient",
session: "Session",
n: "Address",
key: str,
value: bytes,
@ -70,7 +70,7 @@ def encrypt_keyvalue(
ask_on_decrypt: bool = True,
iv: bytes = b"",
) -> bytes:
return client.call(
return session.call(
messages.CipherKeyValue(
address_n=n,
key=key,
@ -85,7 +85,7 @@ def encrypt_keyvalue(
def decrypt_keyvalue(
client: "TrezorClient",
session: "Session",
n: "Address",
key: str,
value: bytes,
@ -93,7 +93,7 @@ def decrypt_keyvalue(
ask_on_decrypt: bool = True,
iv: bytes = b"",
) -> bytes:
return client.call(
return session.call(
messages.CipherKeyValue(
address_n=n,
key=key,
@ -107,5 +107,5 @@ def decrypt_keyvalue(
).value
def get_nonce(client: "TrezorClient") -> bytes:
return client.call(messages.GetNonce(), expect=messages.Nonce).nonce
def get_nonce(session: "Session") -> bytes:
return session.call(messages.GetNonce(), expect=messages.Nonce).nonce

View File

@ -19,8 +19,8 @@ from typing import TYPE_CHECKING
from . import messages
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .transport.session import Session
# MAINNET = 0
@ -30,13 +30,13 @@ if TYPE_CHECKING:
def get_address(
client: "TrezorClient",
session: "Session",
n: "Address",
show_display: bool = False,
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
chunkify: bool = False,
) -> bytes:
return client.call(
return session.call(
messages.MoneroGetAddress(
address_n=n,
show_display=show_display,
@ -48,11 +48,11 @@ def get_address(
def get_watch_key(
client: "TrezorClient",
session: "Session",
n: "Address",
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
) -> messages.MoneroWatchKey:
return client.call(
return session.call(
messages.MoneroGetWatchKey(address_n=n, network_type=network_type),
expect=messages.MoneroWatchKey,
)

View File

@ -20,8 +20,8 @@ from typing import TYPE_CHECKING
from . import exceptions, messages
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .transport.session import Session
TYPE_TRANSACTION_TRANSFER = 0x0101
TYPE_IMPORTANCE_TRANSFER = 0x0801
@ -195,13 +195,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig
def get_address(
client: "TrezorClient",
session: "Session",
n: "Address",
network: int,
show_display: bool = False,
chunkify: bool = False,
) -> str:
return client.call(
return session.call(
messages.NEMGetAddress(
address_n=n, network=network, show_display=show_display, chunkify=chunkify
),
@ -210,7 +210,7 @@ def get_address(
def sign_tx(
client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False
session: "Session", n: "Address", transaction: dict, chunkify: bool = False
) -> messages.NEMSignedTx:
try:
msg = create_sign_tx(transaction, chunkify=chunkify)
@ -219,4 +219,4 @@ def sign_tx(
assert msg.transaction is not None
msg.transaction.address_n = n
return client.call(msg, expect=messages.NEMSignedTx)
return session.call(msg, expect=messages.NEMSignedTx)

View File

@ -20,12 +20,12 @@ from typing import TYPE_CHECKING
from . import messages
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .transport.session import Session
def get_pubkey(client: "TrezorClient", n: "Address") -> bytes:
return client.call(
def get_pubkey(session: "Session", n: "Address") -> bytes:
return session.call(
messages.NostrGetPubkey(
address_n=n,
),
@ -34,7 +34,7 @@ def get_pubkey(client: "TrezorClient", n: "Address") -> bytes:
def sign_event(
client: "TrezorClient",
session: "Session",
sign_event: messages.NostrSignEvent,
) -> messages.NostrEventSignature:
return client.call(sign_event, expect=messages.NostrEventSignature)
return session.call(sign_event, expect=messages.NostrEventSignature)

View File

@ -21,20 +21,20 @@ from .protobuf import dict_to_proto
from .tools import dict_from_camelcase
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .transport.session import Session
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
def get_address(
client: "TrezorClient",
session: "Session",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> str:
return client.call(
return session.call(
messages.RippleGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
),
@ -43,14 +43,14 @@ def get_address(
def sign_tx(
client: "TrezorClient",
session: "Session",
address_n: "Address",
msg: messages.RippleSignTx,
chunkify: bool = False,
) -> messages.RippleSignedTx:
msg.address_n = address_n
msg.chunkify = chunkify
return client.call(msg, expect=messages.RippleSignedTx)
return session.call(msg, expect=messages.RippleSignedTx)
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:

View File

@ -3,27 +3,27 @@ from typing import TYPE_CHECKING, List, Optional
from . import messages
if TYPE_CHECKING:
from .client import TrezorClient
from .transport.session import Session
def get_public_key(
client: "TrezorClient",
session: "Session",
address_n: List[int],
show_display: bool,
) -> bytes:
return client.call(
return session.call(
messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display),
expect=messages.SolanaPublicKey,
).public_key
def get_address(
client: "TrezorClient",
session: "Session",
address_n: List[int],
show_display: bool,
chunkify: bool = False,
) -> str:
return client.call(
return session.call(
messages.SolanaGetAddress(
address_n=address_n,
show_display=show_display,
@ -34,12 +34,12 @@ def get_address(
def sign_tx(
client: "TrezorClient",
session: "Session",
address_n: List[int],
serialized_tx: bytes,
additional_info: Optional[messages.SolanaTxAdditionalInfo],
) -> bytes:
return client.call(
return session.call(
messages.SolanaSignTx(
address_n=address_n,
serialized_tx=serialized_tx,

View File

@ -20,8 +20,8 @@ from typing import TYPE_CHECKING, List, Tuple, Union
from . import exceptions, messages
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .transport.session import Session
StellarMessageType = Union[
messages.StellarAccountMergeOp,
@ -322,12 +322,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
def get_address(
client: "TrezorClient",
session: "Session",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> str:
return client.call(
return session.call(
messages.StellarGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
),
@ -336,7 +336,7 @@ def get_address(
def sign_tx(
client: "TrezorClient",
session: "Session",
tx: messages.StellarSignTx,
operations: List["StellarMessageType"],
address_n: "Address",
@ -352,10 +352,10 @@ def sign_tx(
# 3. Receive a StellarTxOpRequest message
# 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message
# 5. The final message received will be StellarSignedTx which is returned from this method
resp = client.call(tx)
resp = session.call(tx)
try:
while isinstance(resp, messages.StellarTxOpRequest):
resp = client.call(operations.pop(0))
resp = session.call(operations.pop(0))
except IndexError:
# pop from empty list
raise exceptions.TrezorException(

View File

@ -19,17 +19,17 @@ from typing import TYPE_CHECKING
from . import messages
if TYPE_CHECKING:
from .client import TrezorClient
from .tools import Address
from .transport.session import Session
def get_address(
client: "TrezorClient",
session: "Session",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> str:
return client.call(
return session.call(
messages.TezosGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
),
@ -38,12 +38,12 @@ def get_address(
def get_public_key(
client: "TrezorClient",
session: "Session",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> str:
return client.call(
return session.call(
messages.TezosGetPublicKey(
address_n=address_n, show_display=show_display, chunkify=chunkify
),
@ -52,11 +52,11 @@ def get_public_key(
def sign_tx(
client: "TrezorClient",
session: "Session",
address_n: "Address",
sign_tx_msg: messages.TezosSignTx,
chunkify: bool = False,
) -> messages.TezosSignedTx:
sign_tx_msg.address_n = address_n
sign_tx_msg.chunkify = chunkify
return client.call(sign_tx_msg, expect=messages.TezosSignedTx)
return session.call(sign_tx_msg, expect=messages.TezosSignedTx)

View File

@ -45,7 +45,7 @@ if TYPE_CHECKING:
# More details: https://www.python.org/dev/peps/pep-0612/
from typing import TypeVar
from typing_extensions import Concatenate, ParamSpec
from typing_extensions import ParamSpec
from . import client
from .messages import Success
@ -389,23 +389,6 @@ def _return_success(msg: "Success") -> str | None:
return _deprecation_retval_helper(msg.message, stacklevel=1)
def session(
f: "Callable[Concatenate[TrezorClient, P], R]",
) -> "Callable[Concatenate[TrezorClient, P], R]":
# Decorator wraps a BaseClient method
# with session activation / deactivation
@functools.wraps(f)
def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R":
__tracebackhide__ = True # for pytest # pylint: disable=W0612
client.open()
try:
return f(client, *args, **kwargs)
finally:
client.close()
return wrapped_f
# de-camelcasifier
# https://stackoverflow.com/a/1176023/222189

View File

@ -17,14 +17,15 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Iterable, Sequence, Tuple, TypeVar
import typing as t
from ..exceptions import TrezorException
if TYPE_CHECKING:
if t.TYPE_CHECKING:
from ..models import TrezorModel
T = TypeVar("T", bound="Transport")
T = t.TypeVar("T", bound="Transport")
LOG = logging.getLogger(__name__)
@ -34,7 +35,7 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
""".strip()
MessagePayload = Tuple[int, bytes]
MessagePayload = t.Tuple[int, bytes]
class TransportException(TrezorException):
@ -50,72 +51,57 @@ class Timeout(TransportException):
class Transport:
"""Raw connection to a Trezor device.
Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB
or USB-HID connection, or UDP socket of listening emulator(s).
It can also enumerate devices available over this communication link, and return
them as instances.
Transport instance is a thing that:
- can be identified and requested by a string URI-like path
- can open and close sessions, which enclose related operations
- can read and write protobuf messages
You need to implement a new Transport subclass if you invent a new way to connect
a Trezor device to a computer.
"""
PATH_PREFIX: str
ENABLED = False
def __str__(self) -> str:
return self.get_path()
@classmethod
def enumerate(
cls: t.Type[T], models: t.Iterable[TrezorModel] | None = None
) -> t.Iterable[T]:
raise NotImplementedError
@classmethod
def find_by_path(cls: t.Type[T], path: str, prefix_search: bool = False) -> T:
for device in cls.enumerate():
if device.get_path() == path:
return device
if prefix_search and device.get_path().startswith(path):
return device
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
def get_path(self) -> str:
raise NotImplementedError
def begin_session(self) -> None:
raise NotImplementedError
def end_session(self) -> None:
raise NotImplementedError
def read(self, timeout: float | None = None) -> MessagePayload:
raise NotImplementedError
def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError
def find_debug(self: T) -> T:
raise NotImplementedError
@classmethod
def enumerate(
cls: type[T], models: Iterable[TrezorModel] | None = None
) -> Iterable[T]:
def open(self) -> None:
raise NotImplementedError
@classmethod
def find_by_path(cls: type[T], path: str, prefix_search: bool = False) -> T:
for device in cls.enumerate():
if (
path is None
or device.get_path() == path
or (prefix_search and device.get_path().startswith(path))
):
return device
def close(self) -> None:
raise NotImplementedError
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
def write_chunk(self, chunk: bytes) -> None:
raise NotImplementedError
def read_chunk(self, timeout: float | None = None) -> bytes:
raise NotImplementedError
def ping(self) -> bool:
raise NotImplementedError
CHUNK_SIZE: t.ClassVar[int | None]
def all_transports() -> Iterable[type["Transport"]]:
def all_transports() -> t.Iterable[t.Type["Transport"]]:
from .bridge import BridgeTransport
from .hid import HidTransport
from .udp import UdpTransport
from .webusb import WebUsbTransport
transports: Tuple[type["Transport"], ...] = (
transports: t.Tuple[t.Type["Transport"], ...] = (
BridgeTransport,
HidTransport,
UdpTransport,
@ -125,9 +111,9 @@ def all_transports() -> Iterable[type["Transport"]]:
def enumerate_devices(
models: Iterable[TrezorModel] | None = None,
) -> Sequence[Transport]:
devices: list[Transport] = []
models: t.Iterable[TrezorModel] | None = None,
) -> t.Sequence[Transport]:
devices: t.List[Transport] = []
for transport in all_transports():
name = transport.__name__
try:

View File

@ -17,16 +17,15 @@
from __future__ import annotations
import logging
import struct
from typing import TYPE_CHECKING, Any, Iterable
import typing as t
import requests
from typing_extensions import Self
from ..log import DUMP_PACKETS
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
from . import DeviceIsBusy, Transport, TransportException
if TYPE_CHECKING:
if t.TYPE_CHECKING:
from ..models import TrezorModel
LOG = logging.getLogger(__name__)
@ -58,10 +57,13 @@ def call_bridge(
return r
def is_legacy_bridge() -> bool:
def get_bridge_version() -> t.Tuple[int, ...]:
config = call_bridge("configure").json()
version_tuple = tuple(map(int, config["version"].split(".")))
return version_tuple < TREZORD_VERSION_MODERN
return tuple(map(int, config["version"].split(".")))
def is_legacy_bridge() -> bool:
return get_bridge_version() < TREZORD_VERSION_MODERN
class BridgeHandle:
@ -115,15 +117,15 @@ class BridgeTransport(Transport):
PATH_PREFIX = "bridge"
ENABLED: bool = True
CHUNK_SIZE = None
def __init__(
self, device: dict[str, Any], legacy: bool, debug: bool = False
self, device: dict[str, t.Any], legacy: bool, debug: bool = False
) -> None:
if legacy and debug:
raise TransportException("Debugging not supported on legacy Bridge")
self.device = device
self.session: str | None = None
self.session: str | None = device["session"]
self.debug = debug
self.legacy = legacy
@ -154,8 +156,8 @@ class BridgeTransport(Transport):
@classmethod
def enumerate(
cls, _models: Iterable[TrezorModel] | None = None
) -> Iterable["BridgeTransport"]:
cls, _models: t.Iterable[TrezorModel] | None = None
) -> t.Iterable["BridgeTransport"]:
try:
legacy = is_legacy_bridge()
return [
@ -164,7 +166,7 @@ class BridgeTransport(Transport):
except Exception:
return []
def begin_session(self) -> None:
def open(self) -> None:
try:
data = self._call("acquire/" + self.device["path"])
except BridgeException as e:
@ -173,18 +175,17 @@ class BridgeTransport(Transport):
raise
self.session = data.json()["session"]
def end_session(self) -> None:
def close(self) -> None:
if not self.session:
return
self._call("release")
self.session = None
def write(self, message_type: int, message_data: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data))
self.handle.write_buf(header + message_data)
def write_chunk(self, chunk: bytes) -> None:
self.handle.write_buf(chunk)
def read(self, timeout: float | None = None) -> MessagePayload:
data = self.handle.read_buf(timeout=timeout)
headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
return msg_type, data[headerlen : headerlen + datalen]
def read_chunk(self, timeout: float | None = None) -> bytes:
return self.handle.read_buf(timeout=timeout)
def ping(self) -> bool:
return self.session is not None

View File

@ -19,12 +19,11 @@ from __future__ import annotations
import logging
import sys
import time
from typing import Any, Dict, Iterable
import typing as t
from ..log import DUMP_PACKETS
from ..models import TREZOR_ONE, TrezorModel
from . import UDEV_RULES_STR, Timeout, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
from . import UDEV_RULES_STR, Timeout, Transport, TransportException
LOG = logging.getLogger(__name__)
@ -37,23 +36,61 @@ except Exception as e:
HID_IMPORTED = False
HidDevice = Dict[str, Any]
HidDeviceHandle = Any
HidDevice = t.Dict[str, t.Any]
HidDeviceHandle = t.Any
class HidHandle:
def __init__(
self, path: bytes, serial: str, probe_hid_version: bool = False
) -> None:
self.path = path
self.serial = serial
class HidTransport(Transport):
"""
HidTransport implements transport over USB HID interface.
"""
PATH_PREFIX = "hid"
ENABLED = HID_IMPORTED
def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None:
self.device = device
self.device_path = device["path"]
self.device_serial_number = device["serial_number"]
self.handle: HidDeviceHandle = None
self.hid_version = None if probe_hid_version else 2
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
@classmethod
def enumerate(
cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False
) -> t.Iterable["HidTransport"]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: t.List["HidTransport"] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids:
continue
if debug:
if not is_debuglink(dev):
continue
else:
if not is_wirelink(dev):
continue
devices.append(HidTransport(dev))
return devices
def find_debug(self) -> "HidTransport":
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def open(self) -> None:
self.handle = hid.device()
try:
self.handle.open_path(self.path)
self.handle.open_path(self.device_path)
except (IOError, OSError) as e:
if sys.platform.startswith("linux"):
e.args = e.args + (UDEV_RULES_STR,)
@ -64,11 +101,11 @@ class HidHandle:
# and we wouldn't even know.
# So we check that the serial matches what we expect.
serial = self.handle.get_serial_number_string()
if serial != self.serial:
if serial != self.device_serial_number:
self.handle.close()
self.handle = None
raise TransportException(
f"Unexpected device {serial} on path {self.path.decode()}"
f"Unexpected device {serial} on path {self.device_path.decode()}"
)
self.handle.set_nonblocking(True)
@ -79,7 +116,7 @@ class HidHandle:
def close(self) -> None:
if self.handle is not None:
# reload serial, because device.wipe() can reset it
self.serial = self.handle.get_serial_number_string()
self.device_serial_number = self.handle.get_serial_number_string()
self.handle.close()
self.handle = None
@ -119,52 +156,8 @@ class HidHandle:
return 1
raise TransportException("Unknown HID version")
class HidTransport(ProtocolBasedTransport):
"""
HidTransport implements transport over USB HID interface.
"""
PATH_PREFIX = "hid"
ENABLED = HID_IMPORTED
def __init__(self, device: HidDevice) -> None:
self.device = device
self.handle = HidHandle(device["path"], device["serial_number"])
super().__init__(protocol=ProtocolV1(self.handle))
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
@classmethod
def enumerate(
cls, models: Iterable[TrezorModel] | None = None, debug: bool = False
) -> Iterable[HidTransport]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: list[HidTransport] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids:
continue
if debug:
if not is_debuglink(dev):
continue
else:
if not is_wirelink(dev):
continue
devices.append(HidTransport(dev))
return devices
def find_debug(self) -> HidTransport:
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def ping(self) -> bool:
return self.handle is not None
def is_wirelink(dev: HidDevice) -> bool:

View File

@ -1,179 +0,0 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
import struct
from typing_extensions import Protocol as StructuralType
from . import MessagePayload, Timeout, Transport
REPLEN = 64
V2_FIRST_CHUNK = 0x01
V2_NEXT_CHUNK = 0x02
V2_BEGIN_SESSION = 0x03
V2_END_SESSION = 0x04
LOG = logging.getLogger(__name__)
_DEFAULT_READ_TIMEOUT: float | None = None
class Handle(StructuralType):
"""PEP 544 structural type for Handle functionality.
(called a "Protocol" in the proposed PEP, name which is impractical here)
Handle is a "physical" layer for a protocol.
It can open/close a connection and read/write bare data in 64-byte chunks.
Functionally we gain nothing from making this an (abstract) base class for handle
implementations, so this definition is for type hinting purposes only. You can,
but don't have to, inherit from it.
"""
def open(self) -> None: ...
def close(self) -> None: ...
def read_chunk(self, timeout: float | None = None) -> bytes: ...
def write_chunk(self, chunk: bytes) -> None: ...
class Protocol:
"""Wire protocol that can communicate with a Trezor device, given a Handle.
A Protocol implements the part of the Transport API that relates to communicating
logical messages over a physical layer. It is a thing that can:
- open and close sessions,
- send and receive protobuf messages,
given the ability to:
- open and close physical connections,
- and send and receive binary chunks.
For now, the class also handles session counting and opening the underlying Handle.
This will probably be removed in the future.
We will need a new Protocol class if we change the way a Trezor device encapsulates
its messages.
"""
def __init__(self, handle: Handle) -> None:
self.handle = handle
self.session_counter = 0
# XXX we might be able to remove this now that TrezorClient does session handling
def begin_session(self) -> None:
if self.session_counter == 0:
self.handle.open()
try:
# Drop queued responses to old requests
while True:
msg = self.handle.read_chunk(timeout=0.1)
LOG.warning("ignored: %s", msg)
except Timeout:
pass
self.session_counter += 1
def end_session(self) -> None:
self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
self.handle.close()
def read(self, timeout: float | None = None) -> MessagePayload:
raise NotImplementedError
def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError
class ProtocolBasedTransport(Transport):
"""Transport that implements its communications through a Protocol.
Intended as a base class for implementations that proxy their communication
operations to a Protocol.
"""
def __init__(self, protocol: Protocol) -> None:
self.protocol = protocol
def write(self, message_type: int, message_data: bytes) -> None:
self.protocol.write(message_type, message_data)
def read(self, timeout: float | None = None) -> MessagePayload:
return self.protocol.read(timeout=timeout)
def begin_session(self) -> None:
self.protocol.begin_session()
def end_session(self) -> None:
self.protocol.end_session()
class ProtocolV1(Protocol):
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
Does not understand sessions.
"""
HEADER_LEN = struct.calcsize(">HL")
def write(self, message_type: int, message_data: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data))
buffer = bytearray(b"##" + header + message_data)
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: REPLEN - 1]
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
buffer = buffer[63:]
def read(self, timeout: float | None = None) -> MessagePayload:
if timeout is None:
timeout = _DEFAULT_READ_TIMEOUT
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first(timeout=timeout)
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next(timeout=timeout))
return msg_type, buffer[:datalen]
def read_first(self, timeout: float | None = None) -> tuple[int, int, bytes]:
chunk = self.handle.read_chunk(timeout=timeout)
if chunk[:3] != b"?##":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
try:
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
except Exception:
raise RuntimeError(f"Cannot parse header: {chunk.hex()}")
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self, timeout: float | None = None) -> bytes:
chunk = self.handle.read_chunk(timeout=timeout)
if chunk[:1] != b"?":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
return chunk[1:]

View File

@ -0,0 +1,174 @@
from __future__ import annotations
import logging
import typing as t
from .. import exceptions, messages, models
from ..protobuf import MessageType
from .thp.protocol_v1 import ProtocolV1Channel
if t.TYPE_CHECKING:
from ..client import TrezorClient
LOG = logging.getLogger(__name__)
MT = t.TypeVar("MT", bound=MessageType)
class Session:
def __init__(self, client: TrezorClient, id: bytes) -> None:
self.client = client
self._id = id
def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT:
self.client.check_firmware_version()
resp = self.call_raw(msg)
while True:
if isinstance(resp, messages.PinMatrixRequest):
if self.client.pin_callback is None:
raise NotImplementedError("Missing pin_callback")
resp = self.client.pin_callback(self, resp)
elif isinstance(resp, messages.PassphraseRequest):
if self.client.passphrase_callback is None:
raise NotImplementedError("Missing passphrase_callback")
resp = self.client.passphrase_callback(self, resp)
elif isinstance(resp, messages.ButtonRequest):
resp = (self.client.button_callback or default_button_callback)(
self, resp
)
elif isinstance(resp, messages.Failure):
if resp.code == messages.FailureType.ActionCancelled:
raise exceptions.Cancelled
raise exceptions.TrezorFailure(resp)
elif not isinstance(resp, expect):
raise exceptions.UnexpectedMessageError(expect, resp)
else:
return resp
def call_raw(self, msg: t.Any) -> t.Any:
self._write(msg)
return self._read()
def _write(self, msg: t.Any) -> None:
raise NotImplementedError
def _read(self) -> t.Any:
raise NotImplementedError
def refresh_features(self) -> None:
self.client.refresh_features()
def end(self) -> t.Any:
return self.call(messages.EndSession())
def cancel(self) -> None:
self._write(messages.Cancel())
def ping(self, message: str, button_protection: bool | None = None) -> str:
# We would like ping to work on any valid TrezorClient instance, but
# due to the protection modes, we need to go through self.call, and that will
# raise an exception if the firmware is too old.
# So we short-circuit the simplest variant of ping with call_raw.
if not button_protection:
resp = self.call_raw(messages.Ping(message=message))
if isinstance(resp, messages.ButtonRequest):
# device is PIN-locked.
# respond and hope for the best
resp = (self.client.button_callback or default_button_callback)(
self, resp
)
resp = messages.Success.ensure_isinstance(resp)
assert resp.message is not None
return resp.message
resp = self.call(
messages.Ping(message=message, button_protection=button_protection),
expect=messages.Success,
)
assert resp.message is not None
return resp.message
def invalidate(self) -> None:
self.client.invalidate()
@property
def features(self) -> messages.Features:
return self.client.features
@property
def model(self) -> models.TrezorModel:
return self.client.model
@property
def version(self) -> t.Tuple[int, int, int]:
return self.client.version
@property
def id(self) -> bytes:
return self._id
@id.setter
def id(self, value: bytes) -> None:
if not isinstance(value, bytes):
raise ValueError("id must be of type bytes")
self._id = value
class SessionV1(Session):
derive_cardano: bool | None = False
@classmethod
def new(
cls,
client: TrezorClient,
derive_cardano: bool = False,
session_id: bytes | None = None,
) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1Channel)
session = SessionV1(client, id=session_id or b"")
session.derive_cardano = derive_cardano
session.init_session(session.derive_cardano)
return session
@classmethod
def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1Channel)
session = SessionV1(client, session_id)
session.init_session()
return session
def _write(self, msg: t.Any) -> None:
if t.TYPE_CHECKING:
assert isinstance(self.client.protocol, ProtocolV1Channel)
self.client.protocol.write(msg)
def _read(self) -> t.Any:
if t.TYPE_CHECKING:
assert isinstance(self.client.protocol, ProtocolV1Channel)
return self.client.protocol.read()
def init_session(self, derive_cardano: bool | None = None) -> None:
if self.id == b"":
session_id = None
else:
session_id = self.id
resp: messages.Features = self.call_raw(
messages.Initialize(session_id=session_id, derive_cardano=derive_cardano)
)
assert isinstance(resp, messages.Features)
if resp.session_id is not None:
self.id = resp.session_id
def default_button_callback(session: Session, msg: t.Any) -> t.Any:
return session.call_raw(messages.ButtonAck())
def derive_seed(session: Session) -> None:
from ..btc import get_address
from ..client import PASSPHRASE_TEST_PATH
get_address(session, "Testnet", PASSPHRASE_TEST_PATH)
session.refresh_features()

View File

@ -0,0 +1,26 @@
from __future__ import annotations
import logging
from ... import messages
from ...mapping import ProtobufMapping
from .. import Transport
LOG = logging.getLogger(__name__)
class Channel:
def __init__(
self,
transport: Transport,
mapping: ProtobufMapping,
) -> None:
self.transport = transport
self.mapping = mapping
def get_features(self) -> messages.Features:
raise NotImplementedError()
def update_features(self) -> None:
raise NotImplementedError

View File

@ -0,0 +1,128 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2025 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
import struct
import typing as t
from ... import exceptions, messages
from ...log import DUMP_BYTES
from .protocol_and_channel import Channel
LOG = logging.getLogger(__name__)
class ProtocolV1Channel(Channel):
_DEFAULT_READ_TIMEOUT: t.ClassVar[float | None] = None
HEADER_LEN: t.ClassVar[int] = struct.calcsize(">HL")
_features: messages.Features | None = None
def get_features(self) -> messages.Features:
if self._features is None:
self.update_features()
assert self._features is not None
return self._features
def update_features(self) -> None:
self.write(messages.GetFeatures())
resp = self.read()
if not isinstance(resp, messages.Features):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._features = resp
def read(self, timeout: float | None = None) -> t.Any:
msg_type, msg_bytes = self._read(timeout=timeout)
LOG.log(
DUMP_BYTES,
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
msg = self.mapping.decode(msg_type, msg_bytes)
LOG.debug(
f"received message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
return msg
def write(self, msg: t.Any) -> None:
LOG.debug(
f"sending message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
msg_type, msg_bytes = self.mapping.encode(msg)
LOG.log(
DUMP_BYTES,
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
self._write(msg_type, msg_bytes)
def _write(self, message_type: int, message_data: bytes) -> None:
chunk_size = self.transport.CHUNK_SIZE
header = struct.pack(">HL", message_type, len(message_data))
if chunk_size is None:
self.transport.write_chunk(header + message_data)
return
buffer = bytearray(b"##" + header + message_data)
while buffer:
# Report ID, data padded to (chunk_size - 1) bytes
chunk = b"?" + buffer[: chunk_size - 1]
chunk = chunk.ljust(chunk_size, b"\x00")
self.transport.write_chunk(chunk)
buffer = buffer[chunk_size - 1 :]
def _read(self, timeout: float | None = None) -> t.Tuple[int, bytes]:
if timeout is None:
timeout = self._DEFAULT_READ_TIMEOUT
if self.transport.CHUNK_SIZE is None:
return self.read_chunkless(timeout=timeout)
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first(timeout=timeout)
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next(timeout=timeout))
return msg_type, buffer[:datalen]
def read_chunkless(self, timeout: float | None = None) -> t.Tuple[int, bytes]:
data = self.transport.read_chunk(timeout=timeout)
msg_type, datalen = struct.unpack(">HL", data[: self.HEADER_LEN])
return msg_type, data[self.HEADER_LEN : self.HEADER_LEN + datalen]
def read_first(self, timeout: float | None = None) -> t.Tuple[int, int, bytes]:
chunk = self.transport.read_chunk(timeout=timeout)
if chunk[:3] != b"?##":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
try:
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
except Exception:
raise RuntimeError(f"Cannot parse header: {chunk.hex()}")
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self, timeout: float | None = None) -> bytes:
chunk = self.transport.read_chunk(timeout=timeout)
if chunk[:1] != b"?":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
return chunk[1:]

View File

@ -19,11 +19,10 @@ from __future__ import annotations
import logging
import socket
import time
from typing import TYPE_CHECKING, Iterable
from typing import TYPE_CHECKING, Iterable, Tuple
from ..log import DUMP_PACKETS
from . import Timeout, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
from . import Timeout, Transport, TransportException
if TYPE_CHECKING:
from ..models import TrezorModel
@ -33,12 +32,13 @@ SOCKET_TIMEOUT = 0.1
LOG = logging.getLogger(__name__)
class UdpTransport(ProtocolBasedTransport):
class UdpTransport(Transport):
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 21324
PATH_PREFIX = "udp"
ENABLED: bool = True
CHUNK_SIZE = 64
def __init__(self, device: str | None = None) -> None:
if not device:
@ -48,24 +48,17 @@ class UdpTransport(ProtocolBasedTransport):
devparts = device.split(":")
host = devparts[0]
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
self.device = (host, port)
self.device: Tuple[str, int] = (host, port)
self.socket: socket.socket | None = None
super().__init__(protocol=ProtocolV1(self))
def get_path(self) -> str:
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
def find_debug(self) -> "UdpTransport":
host, port = self.device
return UdpTransport(f"{host}:{port + 1}")
super().__init__()
@classmethod
def _try_path(cls, path: str) -> "UdpTransport":
d = cls(path)
try:
d.open()
if d._ping():
if d.ping():
return d
else:
raise TransportException(
@ -99,20 +92,8 @@ class UdpTransport(ProtocolBasedTransport):
assert prefix_search # otherwise we would have raised above
return super().find_by_path(path, prefix_search)
def wait_until_ready(self, timeout: float = 10) -> None:
try:
self.open()
start = time.monotonic()
while True:
if self._ping():
break
elapsed = time.monotonic() - start
if elapsed >= timeout:
raise Timeout("Timed out waiting for connection.")
time.sleep(0.05)
finally:
self.close()
def get_path(self) -> str:
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
def open(self) -> None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@ -124,17 +105,6 @@ class UdpTransport(ProtocolBasedTransport):
self.socket.close()
self.socket = None
def _ping(self) -> bool:
"""Test if the device is listening."""
assert self.socket is not None
resp = None
try:
self.socket.sendall(b"PINGPING")
resp = self.socket.recv(8)
except Exception:
pass
return resp == b"PONGPONG"
def write_chunk(self, chunk: bytes) -> None:
assert self.socket is not None
if len(chunk) != 64:
@ -156,3 +126,33 @@ class UdpTransport(ProtocolBasedTransport):
if len(chunk) != 64:
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return bytearray(chunk)
def find_debug(self) -> "UdpTransport":
host, port = self.device
return UdpTransport(f"{host}:{port + 1}")
def wait_until_ready(self, timeout: float = 10) -> None:
try:
self.open()
start = time.monotonic()
while True:
if self.ping():
break
elapsed = time.monotonic() - start
if elapsed >= timeout:
raise Timeout("Timed out waiting for connection.")
time.sleep(0.05)
finally:
self.close()
def ping(self) -> bool:
"""Test if the device is listening."""
assert self.socket is not None
resp = None
try:
self.socket.sendall(b"PINGPING")
resp = self.socket.recv(8)
except Exception:
pass
return resp == b"PONGPONG"

View File

@ -20,14 +20,11 @@ import atexit
import logging
import sys
import time
from typing import Iterable
from typing_extensions import Self
from typing import Iterable, List
from ..log import DUMP_PACKETS
from ..models import TREZORS, TrezorModel
from . import UDEV_RULES_STR, DeviceIsBusy, Timeout, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
from . import UDEV_RULES_STR, DeviceIsBusy, Timeout, Transport, TransportException
LOG = logging.getLogger(__name__)
@ -48,14 +45,70 @@ USB_COMM_TIMEOUT_MS = 300
WEBUSB_CHUNK_SIZE = 64
class WebUsbHandle:
def __init__(self, device: usb1.USBDevice, debug: bool = False) -> None:
class WebUsbTransport(Transport):
"""
WebUsbTransport implements transport over WebUSB interface.
"""
PATH_PREFIX = "webusb"
ENABLED = USB_IMPORTED
context = None
CHUNK_SIZE = 64
def __init__(
self,
device: "usb1.USBDevice",
debug: bool = False,
) -> None:
self.device = device
self.debug = debug
self.interface = DEBUG_INTERFACE if debug else INTERFACE
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
self.count = 0
self.handle: usb1.USBDeviceHandle | None = None
super().__init__()
@classmethod
def enumerate(
cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False
) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in usb_ids:
continue
if not is_vendor_class(dev):
continue
if usb_reset:
handle = dev.open()
handle.resetDevice()
handle.close()
continue
try:
# workaround for issue #223:
# on certain combinations of Windows USB drivers and libusb versions,
# Trezor is returned twice (possibly because Windows know it as both
# a HID and a WebUSB device), and one of the returned devices is
# non-functional.
dev.getProduct()
devices.append(WebUsbTransport(dev))
except usb1.USBErrorNotSupported:
pass
return devices
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
def open(self) -> None:
self.handle = self.device.open()
if self.handle is None:
@ -68,6 +121,8 @@ class WebUsbHandle:
self.handle.claimInterface(self.interface)
except usb1.USBErrorAccess as e:
raise DeviceIsBusy(self.device) from e
except usb1.USBErrorBusy as e:
raise DeviceIsBusy(self.device) from e
def close(self) -> None:
if self.handle is not None:
@ -119,76 +174,13 @@ class WebUsbHandle:
except Exception as e:
raise TransportException(f"USB read failed: {e}") from e
class WebUsbTransport(ProtocolBasedTransport):
"""
WebUsbTransport implements transport over WebUSB interface.
"""
PATH_PREFIX = "webusb"
ENABLED = USB_IMPORTED
context = None
def __init__(
self,
device: usb1.USBDevice,
handle: WebUsbHandle | None = None,
debug: bool = False,
) -> None:
if handle is None:
handle = WebUsbHandle(device, debug)
self.device = device
self.handle = handle
self.debug = debug
super().__init__(protocol=ProtocolV1(handle))
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
@classmethod
def enumerate(
cls,
models: Iterable[TrezorModel] | None = None,
usb_reset: bool = False,
) -> Iterable[WebUsbTransport]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: list[WebUsbTransport] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in usb_ids:
continue
if not is_vendor_class(dev):
continue
try:
# workaround for issue #223:
# on certain combinations of Windows USB drivers and libusb versions,
# Trezor is returned twice (possibly because Windows know it as both
# a HID and a WebUSB device), and one of the returned devices is
# non-functional.
dev.getProduct()
devices.append(WebUsbTransport(dev))
except usb1.USBErrorNotSupported:
pass
except usb1.USBErrorPipe:
if usb_reset:
handle = dev.open()
handle.resetDevice()
handle.close()
return devices
def find_debug(self) -> Self:
def find_debug(self) -> "WebUsbTransport":
# For v1 protocol, find debug USB interface for the same serial number
return self.__class__(self.device, debug=True)
def ping(self) -> bool:
return self.handle is not None
def is_vendor_class(dev: usb1.USBDevice) -> bool:
configurationId = 0

View File

@ -16,16 +16,16 @@
import os
import sys
from typing import Any, Callable, Optional, Union
import typing as t
import click
from mnemonic import Mnemonic
from typing_extensions import Protocol
from . import device, messages
from .client import MAX_PIN_LENGTH, PASSPHRASE_ON_DEVICE
from .exceptions import Cancelled
from .messages import PinMatrixRequestType, WordRequestType
from .client import MAX_PIN_LENGTH
from .exceptions import Cancelled, PinException
from .messages import Capability, PinMatrixRequestType, WordRequestType
from .transport.session import Session
PIN_MATRIX_DESCRIPTION = """
Use the numeric keypad or lowercase letters to describe number positions.
@ -62,19 +62,11 @@ WIPE_CODE_CONFIRM = PinMatrixRequestType.WipeCodeSecond
CAN_HANDLE_HIDDEN_INPUT = sys.stdin and sys.stdin.isatty()
class TrezorClientUI(Protocol):
def button_request(self, br: messages.ButtonRequest) -> None: ...
def get_pin(self, code: Optional[PinMatrixRequestType]) -> str: ...
def get_passphrase(self, available_on_device: bool) -> Union[str, object]: ...
def echo(*args: Any, **kwargs: Any) -> None:
def echo(*args: t.Any, **kwargs: t.Any) -> None:
return click.echo(*args, err=True, **kwargs)
def prompt(text: str, *, hide_input: bool = False, **kwargs: Any) -> Any:
def prompt(text: str, *, hide_input: bool = False, **kwargs: t.Any) -> t.Any:
# Disallowing hidden input and warning user when it would cause issues
if not CAN_HANDLE_HIDDEN_INPUT and hide_input:
hide_input = False
@ -99,14 +91,16 @@ class ClickUI:
return "Please confirm action on your Trezor device."
def button_request(self, br: messages.ButtonRequest) -> None:
def button_request(self, session: Session, br: messages.ButtonRequest) -> t.Any:
prompt = self._prompt_for_button(br)
if prompt != self.last_prompt_shown:
echo(prompt)
if not self.always_prompt:
self.last_prompt_shown = prompt
return session.call_raw(messages.ButtonAck())
def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str:
def get_pin(self, session: Session, request: messages.PinMatrixRequest) -> t.Any:
code = request.type
if code == PIN_CURRENT:
desc = "current PIN"
elif code == PIN_NEW:
@ -129,6 +123,7 @@ class ClickUI:
try:
pin = prompt(f"Please enter {desc}", hide_input=True)
except click.Abort:
session.call_raw(messages.Cancel())
raise Cancelled from None
# translate letters to numbers if letters were used
@ -142,16 +137,33 @@ class ClickUI:
elif len(pin) > MAX_PIN_LENGTH:
echo(f"The value must be at most {MAX_PIN_LENGTH} digits in length.")
else:
return pin
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise PinException(resp.code, resp.message)
else:
return resp
def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
def get_passphrase(
self, session: Session, request: messages.PassphraseRequest
) -> t.Any:
available_on_device = (
Capability.PassphraseEntry in session.features.capabilities
)
if available_on_device and not self.passphrase_on_host:
return PASSPHRASE_ON_DEVICE
return session.call_raw(
messages.PassphraseAck(passphrase=None, on_device=True)
)
env_passphrase = os.getenv("PASSPHRASE")
if env_passphrase is not None:
echo("Passphrase required. Using PASSPHRASE environment variable.")
return env_passphrase
return session.call_raw(
messages.PassphraseAck(passphrase=env_passphrase, on_device=False)
)
while True:
try:
@ -163,7 +175,7 @@ class ClickUI:
)
# In case user sees the input on the screen, we do not need confirmation
if not CAN_HANDLE_HIDDEN_INPUT:
return passphrase
break
second = prompt(
"Confirm your passphrase",
hide_input=True,
@ -171,12 +183,16 @@ class ClickUI:
show_default=False,
)
if passphrase == second:
return passphrase
break
else:
echo("Passphrase did not match. Please try again.")
except click.Abort:
raise Cancelled from None
return session.call_raw(
messages.PassphraseAck(passphrase=passphrase, on_device=False)
)
class ScriptUI:
"""Interface to be used by scripts, not directly by user.
@ -190,13 +206,14 @@ class ScriptUI:
"""
@staticmethod
def button_request(br: messages.ButtonRequest) -> None:
# TODO: send name={br.name} when it will be supported
def button_request(session: Session, br: messages.ButtonRequest) -> t.Any:
code = br.code.name if br.code else None
print(f"?BUTTON code={code} pages={br.pages}")
print(f"?BUTTON code={code} pages={br.pages} name={br.name}")
return session.call_raw(messages.ButtonAck())
@staticmethod
def get_pin(code: Optional[PinMatrixRequestType] = None) -> str:
def get_pin(session: Session, request: messages.PinMatrixRequest) -> t.Any:
code = request.type
if code is None:
print("?PIN")
else:
@ -208,10 +225,22 @@ class ScriptUI:
elif not pin.startswith(":"):
raise RuntimeError("Sent PIN must start with ':'")
else:
return pin[1:]
pin = pin[1:]
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise PinException(resp.code, resp.message)
else:
return resp
@staticmethod
def get_passphrase(available_on_device: bool) -> Union[str, object]:
def get_passphrase(session: Session, request: messages.PassphraseRequest) -> t.Any:
available_on_device = (
Capability.PassphraseEntry in session.features.capabilities
)
if available_on_device:
print("?PASSPHRASE available_on_device")
else:
@ -221,16 +250,21 @@ class ScriptUI:
if passphrase == "CANCEL":
raise Cancelled from None
elif passphrase == "ON_DEVICE":
return PASSPHRASE_ON_DEVICE
return session.call_raw(
messages.PassphraseAck(passphrase=None, on_device=True)
)
elif not passphrase.startswith(":"):
raise RuntimeError("Sent passphrase must start with ':'")
else:
return passphrase[1:]
passphrase = passphrase[1:]
return session.call_raw(
messages.PassphraseAck(passphrase=passphrase, on_device=False)
)
def mnemonic_words(
expand: bool = False, language: str = "english"
) -> Callable[[WordRequestType], str]:
) -> t.Callable[[WordRequestType], str]:
if expand:
wordlist = Mnemonic(language).wordlist
else:

View File

@ -35,7 +35,6 @@ import trezorlib.misc
from trezorlib.client import TrezorClient
from trezorlib.tools import Address
from trezorlib.transport import enumerate_devices
from trezorlib.ui import ClickUI
version_tuple = tuple(map(int, trezorlib.__version__.split(".")))
if not (0, 11) <= version_tuple < (0, 14):
@ -71,16 +70,18 @@ def choose_device(devices: Sequence["Transport"]) -> "Transport":
sys.stderr.write("Available devices:\n")
for d in devices:
try:
client = TrezorClient(d, ui=ClickUI())
d.open()
client = TrezorClient(d)
except IOError:
sys.stderr.write("[-] <device is currently in use>\n")
continue
if client.features.label:
sys.stderr.write(f"[{i}] {client.features.label}\n")
else:
sys.stderr.write(f"[{i}] <no label>\n")
client.close()
if client.features.label:
sys.stderr.write(f"[{i}] {client.features.label}\n")
else:
sys.stderr.write(f"[{i}] <no label>\n")
finally:
d.close()
i += 1
sys.stderr.write("----------------------------\n")
@ -106,7 +107,9 @@ def main() -> None:
devices = wait_for_devices()
transport = choose_device(devices)
client = TrezorClient(transport, ui=ClickUI())
transport.open()
client = TrezorClient(transport)
session = client.get_seedless_session()
rootdir = os.environ["encfs_root"] # Read "man encfs" for more
passw_file = os.path.join(rootdir, "password.dat")
@ -120,7 +123,7 @@ def main() -> None:
sys.stderr.write("Computer asked Trezor for new strong password.\n")
# 32 bytes, good for AES
trezor_entropy = trezorlib.misc.get_entropy(client, 32)
trezor_entropy = trezorlib.misc.get_entropy(session, 32)
urandom_entropy = os.urandom(32)
passw = hashlib.sha256(trezor_entropy + urandom_entropy).digest()
@ -129,7 +132,7 @@ def main() -> None:
bip32_path = Address([10, 0])
passw_encrypted = trezorlib.misc.encrypt_keyvalue(
client, bip32_path, label, passw, False, True
session, bip32_path, label, passw, False, True
)
data = {
@ -144,13 +147,14 @@ def main() -> None:
data = json.load(open(passw_file, "r"))
passw = trezorlib.misc.decrypt_keyvalue(
client,
session,
data["bip32_path"],
data["label"],
bytes.fromhex(data["password_encrypted_hex"]),
False,
True,
)
transport.close()
print(passw)

View File

@ -24,15 +24,19 @@ from trezorlib.tools import parse_path
def main() -> None:
# Use first connected device
client = get_default_client()
session = client.get_session()
# Print out Trezor's features and settings
print(client.features)
print(session.features)
# Get the first address of first BIP44 account
bip32_path = parse_path("44h/0h/0h/0/0")
address = btc.get_address(client, "Bitcoin", bip32_path, True)
address = btc.get_address(session, "Bitcoin", bip32_path, True)
print("Bitcoin address:", address)
# Release underlying transport (USB/BLE/UDP)
client.transport.close()
if __name__ == "__main__":
main()

View File

@ -62,6 +62,8 @@ def main() -> None:
sectoraddrs[sector] + offset, content[offset : offset + step], flash=True
)
debug.close()
if __name__ == "__main__":
main()

View File

@ -58,6 +58,7 @@ def main() -> None:
f.write(mem)
f.close()
debug.close()
if __name__ == "__main__":

View File

@ -39,6 +39,7 @@ def find_debug() -> DebugLink:
def main() -> None:
debug = find_debug()
debug.memory_write(int(sys.argv[1], 16), bytes.fromhex(sys.argv[2]), flash=True)
debug.close()
if __name__ == "__main__":

View File

@ -26,23 +26,24 @@ from urllib.parse import urlparse
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from trezorlib import misc, ui
from trezorlib import misc
from trezorlib.client import TrezorClient
from trezorlib.tools import parse_path
from trezorlib.transport import get_transport
from trezorlib.transport.session import Session
# Return path by BIP-32
BIP32_PATH = parse_path("10016h/0")
# Deriving master key
def getMasterKey(client: TrezorClient) -> str:
def getMasterKey(session: Session) -> str:
bip32_path = BIP32_PATH
ENC_KEY = "Activate TREZOR Password Manager?"
ENC_VALUE = bytes.fromhex(
"2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee"
)
key = misc.encrypt_keyvalue(client, bip32_path, ENC_KEY, ENC_VALUE, True, True)
key = misc.encrypt_keyvalue(session, bip32_path, ENC_KEY, ENC_VALUE, True, True)
return key.hex()
@ -101,7 +102,7 @@ def decryptEntryValue(nonce: str, val: bytes) -> dict:
# Decrypt give entry nonce
def getDecryptedNonce(client: TrezorClient, entry: dict) -> str:
def getDecryptedNonce(session: Session, entry: dict) -> str:
print()
print("Waiting for Trezor input ...")
print()
@ -117,7 +118,7 @@ def getDecryptedNonce(client: TrezorClient, entry: dict) -> str:
ENC_KEY = f"Unlock {item} for user {entry['username']}?"
ENC_VALUE = entry["nonce"]
decrypted_nonce = misc.decrypt_keyvalue(
client, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True
session, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True
)
return decrypted_nonce.hex()
@ -144,13 +145,15 @@ def main() -> None:
print(e)
return
client = TrezorClient(transport=transport, ui=ui.ClickUI())
transport.open()
client = TrezorClient(transport=transport)
session = client.get_seedless_session()
print()
print("Confirm operation on Trezor")
print()
masterKey = getMasterKey(client)
masterKey = getMasterKey(session)
# print('master key:', masterKey)
fileName = getFileEncKey(masterKey)[0]
@ -173,7 +176,7 @@ def main() -> None:
entry_id = input("Select entry number to decrypt: ")
entry_id = str(entry_id)
plain_nonce = getDecryptedNonce(client, entries[entry_id])
plain_nonce = getDecryptedNonce(session, entries[entry_id])
pwdArr = entries[entry_id]["password"]["data"]
pwdHex = "".join([hex(x)[2:].zfill(2) for x in pwdArr])
@ -183,6 +186,8 @@ def main() -> None:
safeNoteHex = "".join([hex(x)[2:].zfill(2) for x in safeNoteArr])
print("safe_note:", decryptEntryValue(plain_nonce, bytes.fromhex(safeNoteHex)))
client.transport.close()
if __name__ == "__main__":
main()

View File

@ -36,12 +36,14 @@ import click
from bottle import post, request, response, run
import trezorlib.mapping
import trezorlib.messages
import trezorlib.models
import trezorlib.transport
import trezorlib.transport.session as transport_session
from trezorlib.client import TrezorClient
from trezorlib.protobuf import format_message
from trezorlib.transport.bridge import BridgeTransport
from trezorlib.ui import TrezorClientUI
from trezorlib.transport.thp.protocol_v1 import ProtocolV1Channel
# ignore bridge. we are the bridge
BridgeTransport.ENABLED = False
@ -59,15 +61,18 @@ logging.basicConfig(
LOG = logging.getLogger()
class SilentUI(TrezorClientUI):
def get_pin(self, _code: t.Any) -> str:
return ""
def pin_callback(
session: transport_session.Session, request: trezorlib.messages.PinMatrixRequest
) -> t.Any:
return session.call_raw(trezorlib.messages.PinMatrixAck(pin=""))
def get_passphrase(self) -> str:
return ""
def button_request(self, _br: t.Any) -> None:
pass
def passphrase_callback(
session: transport_session.Session, request: trezorlib.messages.PassphraseRequest
) -> t.Any:
return session.call_raw(
trezorlib.messages.PassphraseAck(passphrase="", on_device=False)
)
class Session:
@ -102,10 +107,16 @@ class Transport:
self.path = transport.get_path()
self.session: Session | None = None
self.transport = transport
self.protocol = ProtocolV1Channel(transport, trezorlib.mapping.DEFAULT_MAPPING)
client = TrezorClient(transport, ui=SilentUI())
transport.open()
client = TrezorClient(transport)
client.pin_callback = pin_callback
client.passphrase_callback = passphrase_callback
self.model = client.model
client.end_session()
client.get_seedless_session().end()
transport.close()
def acquire(self, sid: str) -> str:
if self.session_id() != sid:
@ -114,11 +125,11 @@ class Transport:
self.session.release()
self.session = Session(self)
self.transport.begin_session()
self.transport.open()
return self.session.id
def release(self) -> None:
self.transport.end_session()
self.transport.close()
self.session = None
def session_id(self) -> str | None:
@ -139,10 +150,10 @@ class Transport:
}
def write(self, msg_id: int, data: bytes) -> None:
self.transport.write(msg_id, data)
self.protocol._write(msg_id, data)
def read(self) -> tuple[int, bytes]:
return self.transport.read()
return self.protocol._read()
@classmethod
def find(cls, path: str) -> Transport | None:

View File

@ -7,14 +7,17 @@
import io
import sys
from trezorlib import misc, ui
from trezorlib import misc
from trezorlib.client import TrezorClient
from trezorlib.transport import get_transport
def main() -> None:
try:
client = TrezorClient(get_transport(), ui=ui.ClickUI())
transport = get_transport()
transport.open()
client = TrezorClient(transport)
session = client.get_seedless_session()
except Exception as e:
print(e)
return
@ -25,10 +28,10 @@ def main() -> None:
with io.open(arg1, "wb") as f:
for _ in range(0, arg2, step):
entropy = misc.get_entropy(client, step)
entropy = misc.get_entropy(session, step)
f.write(entropy)
client.close()
transport.close()
if __name__ == "__main__":

View File

@ -27,26 +27,29 @@ from trezorlib.client import TrezorClient
from trezorlib.misc import decrypt_keyvalue, encrypt_keyvalue
from trezorlib.tools import parse_path
from trezorlib.transport import get_transport
from trezorlib.ui import ClickUI
BIP32_PATH = parse_path("10016h/0")
def encrypt(type: str, domain: str, secret: str) -> str:
transport = get_transport()
client = TrezorClient(transport, ClickUI())
transport.open()
client = TrezorClient(transport)
session = client.get_seedless_session()
dom = type.upper() + ": " + domain
enc = encrypt_keyvalue(client, BIP32_PATH, dom, secret.encode(), False, True)
client.close()
enc = encrypt_keyvalue(session, BIP32_PATH, dom, secret.encode(), False, True)
transport.close()
return enc.hex()
def decrypt(type: str, domain: str, secret: bytes) -> bytes:
transport = get_transport()
client = TrezorClient(transport, ClickUI())
transport.open()
client = TrezorClient(transport)
session = client.get_seedless_session()
dom = type.upper() + ": " + domain
dec = decrypt_keyvalue(client, BIP32_PATH, dom, secret, False, True)
client.close()
dec = decrypt_keyvalue(session, BIP32_PATH, dom, secret, False, True)
transport.close()
return dec

View File

@ -56,7 +56,7 @@ def pin_input_flow(client: Client, old_pin: str, new_pin: str):
if __name__ == "__main__":
wirelink = get_device()
client = Client(wirelink)
client.open()
session = client.get_seedless_session()
i = 0
@ -76,10 +76,12 @@ if __name__ == "__main__":
# change PIN
new_pin = "".join(random.choices(string.digits, k=random.randint(6, 10)))
client.set_input_flow(pin_input_flow(client, last_pin, new_pin))
session.set_input_flow(pin_input_flow(client, last_pin, new_pin))
device.change_pin(client)
client.set_input_flow(None)
session.set_input_flow(None)
last_pin = new_pin
print(f"iteration {i}")
i = i + 1
wirelink.close()

View File

@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Tuple
import pytest
from trezorlib import btc, device, exceptions, messages
from trezorlib.client import PASSPHRASE_ON_DEVICE
from trezorlib.debuglink import DebugLink, LayoutType
from trezorlib.protobuf import MessageType
from trezorlib.tools import parse_path
@ -66,8 +67,8 @@ def _center_button(debug: DebugLink) -> Tuple[int, int]:
def set_autolock_delay(device_handler: "BackgroundDeviceHandler", delay_ms: int):
debug = device_handler.debuglink()
device_handler.run(device.apply_settings, auto_lock_delay_ms=delay_ms) # type: ignore
device_handler.client.get_seedless_session().lock()
device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=delay_ms) # type: ignore
assert "PinKeyboard" in debug.read_layout().all_components()
@ -106,7 +107,7 @@ def test_autolock_interrupts_signing(device_handler: "BackgroundDeviceHandler"):
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
device_handler.run(btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) # type: ignore
device_handler.run_with_session(btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) # type: ignore
assert (
"1MJ2tj2ThBE62zXbBYA5ZaN3fdve5CPAz1"
@ -144,6 +145,10 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa
set_autolock_delay(device_handler, 10_000)
debug = device_handler.debuglink()
# Prepare session to use later
session = device_handler.client.get_session()
# try to sign a transaction
inp1 = messages.TxInputType(
address_n=parse_path("86h/0h/0h/0/0"),
@ -159,8 +164,8 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
device_handler.run(
btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET
device_handler.run_with_provided_session(
session, btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET
)
assert (
@ -190,14 +195,14 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa
def sleepy_filter(msg: MessageType) -> MessageType:
time.sleep(10.1)
device_handler.client.set_filter(messages.TxAck, None)
session.set_filter(messages.TxAck, None)
return msg
with device_handler.client:
device_handler.client.set_filter(messages.TxAck, sleepy_filter)
with session:
session.set_filter(messages.TxAck, sleepy_filter)
# confirm transaction
if debug.layout_type is LayoutType.Bolt:
debug.click(debug.screen_buttons.ok())
debug.click(debug.screen_buttons.ok(), hold_ms=1000)
elif debug.layout_type is LayoutType.Delizia:
debug.click(debug.screen_buttons.tap_to_confirm())
elif debug.layout_type is LayoutType.Caesar:
@ -206,7 +211,6 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa
signatures, tx = device_handler.result()
assert len(signatures) == 1
assert tx
assert device_handler.features().unlocked is False
@ -216,8 +220,9 @@ def test_autolock_passphrase_keyboard(device_handler: "BackgroundDeviceHandler")
debug = device_handler.debuglink()
# get address
device_handler.run(common.get_test_address) # type: ignore
session = device_handler.client.get_session(passphrase=PASSPHRASE_ON_DEVICE)
device_handler.run_with_provided_session(session, common.get_test_address) # type: ignore
assert "PassphraseKeyboard" in debug.read_layout().all_components()
if debug.layout_type is LayoutType.Caesar:
@ -253,8 +258,8 @@ def test_autolock_interrupts_passphrase(device_handler: "BackgroundDeviceHandler
debug = device_handler.debuglink()
# get address
device_handler.run(common.get_test_address) # type: ignore
session = device_handler.client.get_session(passphrase=PASSPHRASE_ON_DEVICE)
device_handler.run_with_provided_session(session, common.get_test_address) # type: ignore
assert "PassphraseKeyboard" in debug.read_layout().all_components()
if debug.layout_type is LayoutType.Caesar:
@ -293,7 +298,7 @@ def test_dryrun_locks_at_number_of_words(device_handler: "BackgroundDeviceHandle
set_autolock_delay(device_handler, 10_000)
debug = device_handler.debuglink()
device_handler.run(device.recover, type=messages.RecoveryType.DryRun)
device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun)
layout = unlock_dry_run(debug)
assert TR.recovery__num_of_words in debug.read_layout().text_content()
@ -326,7 +331,7 @@ def test_dryrun_locks_at_word_entry(device_handler: "BackgroundDeviceHandler"):
set_autolock_delay(device_handler, 10_000)
debug = device_handler.debuglink()
device_handler.run(device.recover, type=messages.RecoveryType.DryRun)
device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun)
unlock_dry_run(debug)
@ -353,7 +358,7 @@ def test_dryrun_enter_word_slowly(device_handler: "BackgroundDeviceHandler"):
set_autolock_delay(device_handler, 10_000)
debug = device_handler.debuglink()
device_handler.run(device.recover, type=messages.RecoveryType.DryRun)
device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun)
unlock_dry_run(debug)
@ -418,7 +423,11 @@ def test_autolock_does_not_interrupt_preauthorized(
debug = device_handler.debuglink()
device_handler.run(
# Prepare session to use later
session = device_handler.client.get_session()
device_handler.run_with_provided_session(
session,
btc.authorize_coinjoin,
coordinator="www.example.com",
max_rounds=2,
@ -532,14 +541,15 @@ def test_autolock_does_not_interrupt_preauthorized(
def sleepy_filter(msg: MessageType) -> MessageType:
time.sleep(10.1)
device_handler.client.set_filter(messages.SignTx, None)
session.set_filter(messages.SignTx, None)
return msg
with device_handler.client:
with session:
# Start DoPreauthorized flow when device is unlocked. Wait 10s before
# delivering SignTx, by that time autolock timer should have fired.
device_handler.client.set_filter(messages.SignTx, sleepy_filter)
device_handler.run(
session.set_filter(messages.SignTx, sleepy_filter)
device_handler.run_with_provided_session(
session,
btc.sign_tx,
"Testnet",
inputs,

View File

@ -52,7 +52,9 @@ def test_backup_slip39_custom(
assert features.initialized is False
device_handler.run(
session = device_handler.client.get_seedless_session()
device_handler.run_with_provided_session(
session,
device.setup,
strength=128,
backup_type=messages.BackupType.Slip39_Basic,
@ -71,7 +73,7 @@ def test_backup_slip39_custom(
# retrieve the result to check that it's not a TrezorFailure exception
device_handler.result()
device_handler.run(
device_handler.run_with_session(
device.backup,
group_threshold=group_threshold,
groups=[(share_threshold, share_count)],

View File

@ -19,7 +19,7 @@ from typing import TYPE_CHECKING
import pytest
from trezorlib import models
from trezorlib import messages, models
from trezorlib.debuglink import LayoutType
from .. import common
@ -34,6 +34,9 @@ PIN4 = "1234"
@pytest.mark.setup_client(pin=PIN4)
def test_hold_to_lock(device_handler: "BackgroundDeviceHandler"):
debug = device_handler.debuglink()
session = device_handler.client.get_seedless_session()
session.call(messages.LockDevice())
session.refresh_features()
short_duration = {
models.T1B1: 500,
@ -59,22 +62,25 @@ def test_hold_to_lock(device_handler: "BackgroundDeviceHandler"):
assert device_handler.features().unlocked is False
# unlock with message
device_handler.run(common.get_test_address)
device_handler.run_with_session(common.get_test_address)
assert "PinKeyboard" in debug.read_layout().all_components()
debug.input("1234")
assert device_handler.result()
session.refresh_features()
assert device_handler.features().unlocked is True
# short touch
hold(short_duration)
time.sleep(0.5) # so that the homescreen appears again (hacky)
session.refresh_features()
assert device_handler.features().unlocked is True
# lock
hold(lock_duration)
session.refresh_features()
assert device_handler.features().unlocked is False
# unlock by touching
@ -86,8 +92,10 @@ def test_hold_to_lock(device_handler: "BackgroundDeviceHandler"):
assert "PinKeyboard" in layout.all_components()
debug.input("1234")
session.refresh_features()
assert device_handler.features().unlocked is True
# lock
hold(lock_duration)
session.refresh_features()
assert device_handler.features().unlocked is False

View File

@ -73,7 +73,7 @@ def prepare_passphrase_dialogue(
device_handler: "BackgroundDeviceHandler", address: Optional[str] = None
) -> Generator["DebugLink", None, None]:
debug = device_handler.debuglink()
device_handler.run(get_test_address) # type: ignore
device_handler.run_with_session(get_test_address) # type: ignore
assert debug.read_layout().main_component() == "PassphraseKeyboard"
# Resetting the category as it could have been changed by previous tests

View File

@ -91,7 +91,7 @@ def prepare_passphrase_dialogue(
device_handler: "BackgroundDeviceHandler", address: Optional[str] = None
) -> Generator["DebugLink", None, None]:
debug = device_handler.debuglink()
device_handler.run(get_test_address) # type: ignore
device_handler.run_with_session(get_test_address) # type: ignore
layout = debug.read_layout()
assert "PassphraseKeyboard" in layout.all_components()
assert layout.passphrase() == ""

View File

@ -90,17 +90,19 @@ def prepare(
tap = False
device_handler.client.get_seedless_session().lock()
# Setup according to the wanted situation
if situation == Situation.PIN_INPUT:
# Any action triggering the PIN dialogue
device_handler.run(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore
device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore
tap = True
if situation == Situation.PIN_INPUT_CANCEL:
# Any action triggering the PIN dialogue
device_handler.run(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore
device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore
elif situation == Situation.PIN_SETUP:
# Set new PIN
device_handler.run(device.change_pin) # type: ignore
device_handler.run_with_session(device.change_pin) # type: ignore
assert (
TR.pin__turn_on in debug.read_layout().text_content()
or TR.pin__info in debug.read_layout().text_content()
@ -114,14 +116,14 @@ def prepare(
go_next(debug)
elif situation == Situation.PIN_CHANGE:
# Change PIN
device_handler.run(device.change_pin) # type: ignore
device_handler.run_with_session(device.change_pin) # type: ignore
_input_see_confirm(debug, old_pin)
assert TR.pin__change in debug.read_layout().text_content()
go_next(debug)
_input_see_confirm(debug, old_pin)
elif situation == Situation.WIPE_CODE_SETUP:
# Set wipe code
device_handler.run(device.change_wipe_code) # type: ignore
device_handler.run_with_session(device.change_wipe_code) # type: ignore
if old_pin:
_input_see_confirm(debug, old_pin)
assert TR.wipe_code__turn_on in debug.read_layout().text_content()

View File

@ -40,7 +40,7 @@ def prepare_recovery_and_evaluate(
features = device_handler.features()
debug = device_handler.debuglink()
assert features.initialized is False
device_handler.run(device.recover, pin_protection=False) # type: ignore
device_handler.run_with_session(device.recover, pin_protection=False) # type: ignore
yield debug
@ -58,7 +58,7 @@ def prepare_recovery_and_evaluate_cancel(
features = device_handler.features()
debug = device_handler.debuglink()
assert features.initialized is False
device_handler.run(device.recover, pin_protection=False) # type: ignore
device_handler.run_with_session(device.recover, pin_protection=False) # type: ignore
yield debug
@ -113,10 +113,11 @@ def test_recovery_cancel_issue4613(device_handler: "BackgroundDeviceHandler"):
debug = device_handler.debuglink()
# initiate and confirm the recovery
device_handler.run(device.recover, type=messages.RecoveryType.DryRun)
device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun)
recovery.confirm_recovery(debug, title="recovery__title_dry_run")
# select number of words
recovery.select_number_of_words(debug, num_of_words=12)
device_handler.client.transport.close()
# abort the process running the recovery from host
device_handler.kill_task()
@ -124,16 +125,20 @@ def test_recovery_cancel_issue4613(device_handler: "BackgroundDeviceHandler"):
# from the host side.
# Reopen client and debuglink, closed by kill_task
device_handler.client.open()
device_handler.client.transport.open()
debug = device_handler.debuglink()
# Ping the Trezor with an Initialize message (listed in DO_NOT_RESTART)
try:
features = device_handler.client.call(messages.Initialize())
features = device_handler.client.get_seedless_session().call(
messages.Initialize()
)
except exceptions.Cancelled:
# due to a related problem, the first call in this situation will return
# a Cancelled failure. This test does not care, we just retry.
features = device_handler.client.call(messages.Initialize())
features = device_handler.client.get_seedless_session().call(
messages.Initialize()
)
assert features.recovery_status == messages.RecoveryStatus.Recovery
# Trezor is sitting in recovery_homescreen now, waiting for the user to select

View File

@ -40,7 +40,7 @@ def test_repeated_backup(
assert features.initialized is False
device_handler.run(
device_handler.run_with_session(
device.setup,
strength=128,
backup_type=messages.BackupType.Slip39_Basic,
@ -93,7 +93,7 @@ def test_repeated_backup(
assert features.recovery_status == messages.RecoveryStatus.Nothing
# run recovery to unlock backup
device_handler.run(
device_handler.run_with_session(
device.recover,
type=messages.RecoveryType.UnlockRepeatedBackup,
)
@ -160,7 +160,7 @@ def test_repeated_backup(
assert features.recovery_status == messages.RecoveryStatus.Nothing
# try to unlock backup again...
device_handler.run(
device_handler.run_with_session(
device.recover,
type=messages.RecoveryType.UnlockRepeatedBackup,
)
@ -200,7 +200,7 @@ def test_repeated_backup(
assert features.recovery_status == messages.RecoveryStatus.Nothing
# try to unlock backup yet again...
device_handler.run(
device_handler.run_with_session(
device.recover,
type=messages.RecoveryType.UnlockRepeatedBackup,
)

View File

@ -39,7 +39,7 @@ def test_reset_bip39(device_handler: "BackgroundDeviceHandler"):
assert features.initialized is False
device_handler.run(
device_handler.run_with_session(
device.setup,
strength=128,
backup_type=messages.BackupType.Bip39,

View File

@ -50,7 +50,7 @@ def test_reset_slip39_advanced(
assert features.initialized is False
device_handler.run(
device_handler.run_with_session(
device.setup,
backup_type=messages.BackupType.Slip39_Advanced,
pin_protection=False,

View File

@ -46,7 +46,7 @@ def test_reset_slip39_basic(
assert features.initialized is False
device_handler.run(
device_handler.run_with_session(
device.setup,
strength=128,
backup_type=messages.BackupType.Slip39_Basic,

View File

@ -39,7 +39,7 @@ def prepare_tutorial_and_cancel_after_it(
device_handler: "BackgroundDeviceHandler", cancelled: bool = False
) -> Generator["DebugLink", None, None]:
debug = device_handler.debuglink()
device_handler.run(device.show_device_tutorial)
device_handler.run_with_session(device.show_device_tutorial)
yield debug

View File

@ -35,7 +35,7 @@ pytestmark = [
def test_tutorial_ignore_menu(device_handler: "BackgroundDeviceHandler"):
debug = device_handler.debuglink()
device_handler.run(device.show_device_tutorial)
device_handler.run_with_session(device.show_device_tutorial)
assert debug.read_layout().title() == TR.tutorial__welcome_safe5
debug.click(debug.screen_buttons.tap_to_confirm())
@ -55,7 +55,7 @@ def test_tutorial_ignore_menu(device_handler: "BackgroundDeviceHandler"):
def test_tutorial_menu_open_close(device_handler: "BackgroundDeviceHandler"):
debug = device_handler.debuglink()
device_handler.run(device.show_device_tutorial)
device_handler.run_with_session(device.show_device_tutorial)
assert debug.read_layout().title() == TR.tutorial__welcome_safe5
debug.click(debug.screen_buttons.tap_to_confirm())
@ -81,7 +81,7 @@ def test_tutorial_menu_open_close(device_handler: "BackgroundDeviceHandler"):
def test_tutorial_menu_exit(device_handler: "BackgroundDeviceHandler"):
debug = device_handler.debuglink()
device_handler.run(device.show_device_tutorial)
device_handler.run_with_session(device.show_device_tutorial)
assert debug.read_layout().title() == TR.tutorial__welcome_safe5
debug.click(debug.screen_buttons.tap_to_confirm())
@ -104,7 +104,7 @@ def test_tutorial_menu_exit(device_handler: "BackgroundDeviceHandler"):
def test_tutorial_menu_repeat(device_handler: "BackgroundDeviceHandler"):
debug = device_handler.debuglink()
device_handler.run(device.show_device_tutorial)
device_handler.run_with_session(device.show_device_tutorial)
assert debug.read_layout().title() == TR.tutorial__welcome_safe5
debug.click(debug.screen_buttons.tap_to_confirm())
@ -134,7 +134,7 @@ def test_tutorial_menu_repeat(device_handler: "BackgroundDeviceHandler"):
def test_tutorial_menu_funfact(device_handler: "BackgroundDeviceHandler"):
debug = device_handler.debuglink()
device_handler.run(device.show_device_tutorial)
device_handler.run_with_session(device.show_device_tutorial)
assert debug.read_layout().title() == TR.tutorial__welcome_safe5
debug.click(debug.screen_buttons.tap_to_confirm())

View File

@ -32,8 +32,8 @@ if TYPE_CHECKING:
from _pytest.mark.structures import MarkDecorator
from trezorlib.debuglink import DebugLink
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.messages import ButtonRequest
from trezorlib.transport.session import Session
PRIVATE_KEYS_DEV = [byte * 32 for byte in (b"\xdd", b"\xde", b"\xdf")]
@ -336,10 +336,10 @@ def check_pin_backoff_time(attempts: int, start: float) -> None:
assert got >= expected
def get_test_address(client: "Client") -> str:
def get_test_address(session: "Session") -> str:
"""Fetch a testnet address on a fixed path. Useful to make a pin/passphrase
protected call, or to identify the root secret (seed+passphrase)"""
return btc.get_address(client, "Testnet", TEST_ADDRESS_N)
return btc.get_address(session, "Testnet", TEST_ADDRESS_N)
def compact_size(n: int) -> bytes:
@ -378,5 +378,5 @@ def swipe_till_the_end(debug: "DebugLink", br: messages.ButtonRequest) -> None:
debug.swipe_up()
def is_core(client: "Client") -> bool:
return client.model is not models.T1B1
def is_core(session: "Session") -> bool:
return session.model is not models.T1B1

View File

@ -21,6 +21,7 @@ import os
import typing as t
from enum import IntEnum
from pathlib import Path
from time import sleep
import pytest
import xdist
@ -31,7 +32,8 @@ from trezorlib import debuglink, log, models
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.device import apply_settings
from trezorlib.device import wipe as wipe_device
from trezorlib.transport import enumerate_devices, get_transport, protocol
from trezorlib.transport import enumerate_devices, get_transport
from trezorlib.transport.thp.protocol_v1 import ProtocolV1Channel
# register rewrites before importing from local package
# so that we see details of failed asserts from this module
@ -49,6 +51,7 @@ if t.TYPE_CHECKING:
from _pytest.terminal import TerminalReporter
from trezorlib._internal.emulator import Emulator
from trezorlib.debuglink import SessionDebugWrapper
HERE = Path(__file__).resolve().parent
@ -78,7 +81,7 @@ def core_emulator(request: pytest.FixtureRequest) -> t.Iterator[Emulator]:
"""Fixture returning default core emulator with possibility of screen recording."""
with EmulatorWrapper("core", main_args=_emulator_wrapper_main_args()) as emu:
# Modifying emu.client to add screen recording (when --ui=test is used)
with ui_tests.screen_recording(emu.client, request) as _:
with ui_tests.screen_recording(emu.client, request, lambda: emu.client) as _:
yield emu
@ -127,7 +130,15 @@ def emulator(request: pytest.FixtureRequest) -> t.Generator["Emulator", None, No
@pytest.fixture(scope="session")
def _raw_client(request: pytest.FixtureRequest) -> Client:
def _raw_client(request: pytest.FixtureRequest) -> t.Generator[Client, None, None]:
client = _get_raw_client(request)
try:
yield client
finally:
client.close_transport()
def _get_raw_client(request: pytest.FixtureRequest) -> Client:
# In case tests run in parallel, each process has its own emulator/client.
# Requesting the emulator fixture only if relevant.
if request.session.config.getoption("control_emulators"):
@ -137,7 +148,7 @@ def _raw_client(request: pytest.FixtureRequest) -> Client:
interact = os.environ.get("INTERACT") == "1"
if not interact:
# prevent tests from getting stuck in case there is an USB packet loss
protocol._DEFAULT_READ_TIMEOUT = 50.0
ProtocolV1Channel._DEFAULT_READ_TIMEOUT = 50.0
path = os.environ.get("TREZOR_PATH")
if path:
@ -153,7 +164,7 @@ def _client_from_path(
) -> Client:
try:
transport = get_transport(path)
return Client(transport, auto_interact=not interact)
return Client(transport, auto_interact=not interact, open_transport=True)
except Exception as e:
request.session.shouldstop = "Failed to communicate with Trezor"
raise RuntimeError(f"Failed to open debuglink for {path}") from e
@ -162,10 +173,7 @@ def _client_from_path(
def _find_client(request: pytest.FixtureRequest, interact: bool) -> Client:
devices = enumerate_devices()
for device in devices:
try:
return Client(device, auto_interact=not interact)
except Exception:
pass
return Client(device, auto_interact=not interact, open_transport=True)
request.session.shouldstop = "Failed to communicate with Trezor"
raise RuntimeError("No debuggable device found")
@ -240,7 +248,7 @@ class ModelsFilter:
@pytest.fixture(scope="function")
def client(
def _client_unlocked(
request: pytest.FixtureRequest, _raw_client: Client
) -> t.Generator[Client, None, None]:
"""Client fixture.
@ -280,14 +288,14 @@ def client(
test_ui = request.config.getoption("ui")
_raw_client.reset_debug_features()
_raw_client.open()
try:
_raw_client.sync_responses()
_raw_client.init_device()
except Exception:
request.session.shouldstop = "Failed to communicate with Trezor"
pytest.fail("Failed to communicate with Trezor")
# _raw_client.reset_debug_features()
if isinstance(_raw_client.protocol, ProtocolV1Channel):
try:
_raw_client.sync_responses()
except Exception:
request.session.shouldstop = "Failed to communicate with Trezor"
pytest.fail("Failed to communicate with Trezor")
_raw_client._seedless_session = _raw_client.get_seedless_session(new_session=True)
# Resetting all the debug events to not be influenced by previous test
_raw_client.debug.reset_debug_events()
@ -300,13 +308,20 @@ def client(
should_format = sd_marker.kwargs.get("formatted", True)
_raw_client.debug.erase_sd_card(format=should_format)
wipe_device(_raw_client)
if _raw_client.is_invalidated:
_raw_client = _raw_client.get_new_client()
session = _raw_client.get_seedless_session()
wipe_device(session)
sleep(1.5) # Makes tests more stable (wait for wipe to finish)
if not _raw_client.features.bootloader_mode:
_raw_client.refresh_features()
# Load language again, as it got erased in wipe
if _raw_client.model is not models.T1B1:
lang = request.session.config.getoption("lang") or "en"
assert isinstance(lang, str)
translations.set_language(_raw_client, lang)
translations.set_language(_raw_client.get_seedless_session(), lang)
setup_params = dict(
uninitialized=False,
@ -324,32 +339,59 @@ def client(
use_passphrase = setup_params["passphrase"] is True or isinstance(
setup_params["passphrase"], str
)
if not setup_params["uninitialized"]:
session = _raw_client.get_seedless_session(new_session=True)
debuglink.load_device(
_raw_client,
session,
mnemonic=setup_params["mnemonic"], # type: ignore
pin=setup_params["pin"], # type: ignore
passphrase_protection=use_passphrase,
label="test",
needs_backup=setup_params["needs_backup"], # type: ignore
no_backup=setup_params["no_backup"], # type: ignore
_skip_init_device=True,
_skip_init_device=False,
)
_raw_client._setup_pin = setup_params["pin"]
if request.node.get_closest_marker("experimental"):
apply_settings(_raw_client, experimental_features=True)
apply_settings(session, experimental_features=True)
session.end()
if use_passphrase and isinstance(setup_params["passphrase"], str):
_raw_client.use_passphrase(setup_params["passphrase"])
yield _raw_client
_raw_client.lock(_refresh_features=False)
_raw_client.init_device(new_session=True)
with ui_tests.screen_recording(_raw_client, request):
yield _raw_client
@pytest.fixture(scope="function")
def client(
request: pytest.FixtureRequest, _client_unlocked: Client
) -> t.Generator[Client, None, None]:
_client_unlocked.lock()
with ui_tests.screen_recording(_client_unlocked, request):
yield _client_unlocked
_raw_client.close()
@pytest.fixture(scope="function")
def session(
request: pytest.FixtureRequest, _client_unlocked: Client
) -> t.Generator[SessionDebugWrapper, None, None]:
if bool(request.node.get_closest_marker("uninitialized_session")):
session = _client_unlocked.get_seedless_session()
else:
derive_cardano = bool(request.node.get_closest_marker("cardano"))
passphrase = ""
marker = request.node.get_closest_marker("setup_client")
if marker and isinstance(marker.kwargs.get("passphrase"), str):
passphrase = marker.kwargs["passphrase"]
if _client_unlocked._setup_pin is not None:
_client_unlocked.use_pin_sequence([_client_unlocked._setup_pin])
session = _client_unlocked.get_session(
derive_cardano=derive_cardano, passphrase=passphrase
)
if _client_unlocked._setup_pin is not None:
session.lock()
with ui_tests.screen_recording(_client_unlocked, request):
yield session
# Calling session.end() is not needed since the device gets wiped later anyway.
def _is_main_runner(session_or_request: pytest.Session | pytest.FixtureRequest) -> bool:
@ -467,6 +509,10 @@ def pytest_configure(config: "Config") -> None:
"markers",
'setup_client(mnemonic="all all all...", pin=None, passphrase=False, uninitialized=False): configure the client instance',
)
config.addinivalue_line(
"markers",
"uninitialized_session: use uninitialized session instance",
)
with open(os.path.join(os.path.dirname(__file__), "REGISTERED_MARKERS")) as f:
for line in f:
config.addinivalue_line("markers", line.strip())

View File

@ -6,12 +6,12 @@ from concurrent.futures import ThreadPoolExecutor
import typing_extensions as tx
from trezorlib.client import PASSPHRASE_ON_DEVICE
from trezorlib.messages import DebugWaitType
from trezorlib.transport import udp
if t.TYPE_CHECKING:
from trezorlib._internal.emulator import Emulator
from trezorlib.debuglink import DebugLink
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.messages import Features
@ -22,6 +22,10 @@ udp.SOCKET_TIMEOUT = 0.1
class NullUI:
@staticmethod
def clear(*args, **kwargs):
pass
@staticmethod
def button_request(code):
pass
@ -50,11 +54,29 @@ class BackgroundDeviceHandler:
self.client = client
self.client.ui = NullUI # type: ignore [NullUI is OK UI]
self.client.watch_layout(True)
self.client.debug.input_wait_type = DebugWaitType.CURRENT_LAYOUT
def run(
def run_with_session(
self,
function: t.Callable[tx.Concatenate["Client", P], t.Any],
function: t.Callable[tx.Concatenate["Session", P], t.Any],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""Runs some function that interacts with a device.
Makes sure the UI is updated before returning.
"""
if self.task is not None:
raise RuntimeError("Wait for previous task first")
# wait for the first UI change triggered by the task running in the background
session = self.client.get_session()
with self.debuglink().wait_for_layout_change():
self.task = self._pool.submit(function, session, *args, **kwargs)
def run_with_provided_session(
self,
session,
function: t.Callable[tx.Concatenate["Session", P], t.Any],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
@ -67,15 +89,13 @@ class BackgroundDeviceHandler:
# wait for the first UI change triggered by the task running in the background
with self.debuglink().wait_for_layout_change():
self.task = self._pool.submit(function, self.client, *args, **kwargs)
self.task = self._pool.submit(function, session, *args, **kwargs)
def kill_task(self) -> None:
if self.task is not None:
# Force close the client, which should raise an exception in a client
# waiting on IO. Does not work over Bridge, because bridge doesn't have
# a close() method.
while self.client.session_counter > 0:
self.client.close()
try:
self.task.result(timeout=1)
except Exception:
@ -99,7 +119,7 @@ class BackgroundDeviceHandler:
def features(self) -> "Features":
if self.task is not None:
raise RuntimeError("Cannot query features while task is running")
self.client.init_device()
self.client.refresh_features()
return self.client.features
def debuglink(self) -> "DebugLink":

View File

@ -17,7 +17,7 @@
import pytest
from trezorlib.binance import get_address
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.tools import parse_path
from ...input_flows import InputFlowShowAddressQRCode
@ -38,23 +38,23 @@ BINANCE_ADDRESS_TEST_VECTORS = [
@pytest.mark.parametrize("path, expected_address", BINANCE_ADDRESS_TEST_VECTORS)
def test_binance_get_address(client: Client, path: str, expected_address: str):
def test_binance_get_address(session: Session, path: str, expected_address: str):
# data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50
address = get_address(client, parse_path(path), show_display=True)
address = get_address(session, parse_path(path), show_display=True)
assert address == expected_address
@pytest.mark.parametrize("path, expected_address", BINANCE_ADDRESS_TEST_VECTORS)
def test_binance_get_address_chunkify_details(
client: Client, path: str, expected_address: str
session: Session, path: str, expected_address: str
):
# data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50
with client:
IF = InputFlowShowAddressQRCode(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowShowAddressQRCode(session.client)
session.set_input_flow(IF.get())
address = get_address(
client, parse_path(path), show_display=True, chunkify=True
session, parse_path(path), show_display=True, chunkify=True
)
assert address == expected_address

View File

@ -17,7 +17,7 @@
import pytest
from trezorlib import binance
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.tools import parse_path
from ...input_flows import InputFlowShowXpubQRCode
@ -31,11 +31,11 @@ BINANCE_PATH = parse_path("m/44h/714h/0h/0/0")
@pytest.mark.setup_client(
mnemonic="offer caution gift cross surge pretty orange during eye soldier popular holiday mention east eight office fashion ill parrot vault rent devote earth cousin"
)
def test_binance_get_public_key(client: Client):
with client:
IF = InputFlowShowXpubQRCode(client)
client.set_input_flow(IF.get())
sig = binance.get_public_key(client, BINANCE_PATH, show_display=True)
def test_binance_get_public_key(session: Session):
with session:
IF = InputFlowShowXpubQRCode(session.client)
session.set_input_flow(IF.get())
sig = binance.get_public_key(session, BINANCE_PATH, show_display=True)
assert (
sig.hex()
== "029729a52e4e3c2b4a4e52aa74033eedaf8ba1df5ab6d1f518fd69e67bbd309b0e"

View File

@ -17,7 +17,7 @@
import pytest
from trezorlib import binance
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.tools import parse_path
BINANCE_TEST_VECTORS = [
@ -110,10 +110,10 @@ BINANCE_TEST_VECTORS = [
@pytest.mark.parametrize("message, expected_response", BINANCE_TEST_VECTORS)
@pytest.mark.parametrize("chunkify", (True, False))
def test_binance_sign_message(
client: Client, chunkify: bool, message: dict, expected_response: dict
session: Session, chunkify: bool, message: dict, expected_response: dict
):
response = binance.sign_tx(
client, parse_path("m/44h/714h/0h/0/0"), message, chunkify=chunkify
session, parse_path("m/44h/714h/0h/0/0"), message, chunkify=chunkify
)
assert response.public_key.hex() == expected_response["public_key"]

View File

@ -4,6 +4,7 @@ from hashlib import sha256
from ecdsa import SECP256k1, SigningKey
from trezorlib import btc, messages
from trezorlib.transport.session import Session
from ...common import compact_size
@ -27,7 +28,12 @@ def hash_bytes_prefixed(hasher, data):
def make_payment_request(
client, recipient_name, outputs, change_addresses=None, memos=None, nonce=None
session: Session,
recipient_name,
outputs,
change_addresses=None,
memos=None,
nonce=None,
):
h_pr = sha256(b"SL\x00\x24")
@ -52,7 +58,7 @@ def make_payment_request(
hash_bytes_prefixed(h_pr, memo.text.encode())
elif isinstance(memo, RefundMemo):
address_resp = btc.get_authenticated_address(
client, "Testnet", memo.address_n
session, "Testnet", memo.address_n
)
msg_memo = messages.RefundMemo(
address=address_resp.address, mac=address_resp.mac
@ -63,7 +69,7 @@ def make_payment_request(
hash_bytes_prefixed(h_pr, address_resp.address.encode())
elif isinstance(memo, CoinPurchaseMemo):
address_resp = btc.get_authenticated_address(
client, memo.coin_name, memo.address_n
session, memo.coin_name, memo.address_n
)
msg_memo = messages.CoinPurchaseMemo(
coin_type=memo.slip44,

View File

@ -19,6 +19,7 @@ import time
import pytest
from trezorlib import btc, device, messages
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure
from trezorlib.tools import parse_path
@ -59,15 +60,15 @@ SLIP25_PATH = parse_path("m/10025h")
@pytest.mark.parametrize("chunkify", (True, False))
@pytest.mark.setup_client(pin=PIN)
def test_sign_tx(client: Client, chunkify: bool):
def test_sign_tx(session: Session, chunkify: bool):
# NOTE: FAKE input tx
assert session.features.unlocked is False
commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big")
with client:
client.use_pin_sequence([PIN])
with session:
session.client.use_pin_sequence([PIN])
btc.authorize_coinjoin(
client,
session,
coordinator="www.example.com",
max_rounds=2,
max_coordinator_fee_rate=500_000, # 0.5 %
@ -77,14 +78,14 @@ def test_sign_tx(client: Client, chunkify: bool):
script_type=messages.InputScriptType.SPENDTAPROOT,
)
client.call(messages.LockDevice())
session.call(messages.LockDevice())
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[messages.PreauthorizedRequest, messages.OwnershipProof]
)
btc.get_ownership_proof(
client,
session,
"Testnet",
parse_path("m/10025h/1h/0h/1h/1/0"),
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -93,12 +94,12 @@ def test_sign_tx(client: Client, chunkify: bool):
preauthorized=True,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[messages.PreauthorizedRequest, messages.OwnershipProof]
)
btc.get_ownership_proof(
client,
session,
"Testnet",
parse_path("m/10025h/1h/0h/1h/1/5"),
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -206,8 +207,8 @@ def test_sign_tx(client: Client, chunkify: bool):
no_fee_indices=[],
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
messages.PreauthorizedRequest(),
request_input(0),
@ -222,7 +223,7 @@ def test_sign_tx(client: Client, chunkify: bool):
]
)
signatures, serialized_tx = btc.sign_tx(
client,
session,
"Testnet",
inputs,
outputs,
@ -243,7 +244,7 @@ def test_sign_tx(client: Client, chunkify: bool):
# Test for a second time.
btc.sign_tx(
client,
session,
"Testnet",
inputs,
outputs,
@ -256,7 +257,7 @@ def test_sign_tx(client: Client, chunkify: bool):
# Test for a third time, number of rounds should be exceeded.
with pytest.raises(TrezorFailure, match="No preauthorized operation"):
btc.sign_tx(
client,
session,
"Testnet",
inputs,
outputs,
@ -267,7 +268,7 @@ def test_sign_tx(client: Client, chunkify: bool):
)
def test_sign_tx_large(client: Client):
def test_sign_tx_large(session: Session):
# NOTE: FAKE input tx
commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big")
@ -278,17 +279,16 @@ def test_sign_tx_large(client: Client):
output_denom = 10_000 # sats
max_expected_delay = 80 # seconds
with client:
btc.authorize_coinjoin(
client,
coordinator="www.example.com",
max_rounds=2,
max_coordinator_fee_rate=500_000, # 0.5 %
max_fee_per_kvbyte=3500,
n=parse_path("m/10025h/1h/0h/1h"),
coin_name="Testnet",
script_type=messages.InputScriptType.SPENDTAPROOT,
)
btc.authorize_coinjoin(
session,
coordinator="www.example.com",
max_rounds=2,
max_coordinator_fee_rate=500_000, # 0.5 %
max_fee_per_kvbyte=3500,
n=parse_path("m/10025h/1h/0h/1h"),
coin_name="Testnet",
script_type=messages.InputScriptType.SPENDTAPROOT,
)
# INPUTS.
@ -399,22 +399,21 @@ def test_sign_tx_large(client: Client):
)
start = time.time()
with client:
btc.sign_tx(
client,
"Testnet",
inputs,
outputs,
prev_txes=TX_CACHE_TESTNET,
coinjoin_request=coinjoin_req,
preauthorized=True,
serialize=False,
)
btc.sign_tx(
session,
"Testnet",
inputs,
outputs,
prev_txes=TX_CACHE_TESTNET,
coinjoin_request=coinjoin_req,
preauthorized=True,
serialize=False,
)
delay = time.time() - start
assert delay <= max_expected_delay
def test_sign_tx_spend(client: Client):
def test_sign_tx_spend(session: Session):
# NOTE: FAKE input tx
inputs = [
@ -446,15 +445,15 @@ def test_sign_tx_spend(client: Client):
# Ensure that Trezor refuses to spend from CoinJoin without user authorization.
with pytest.raises(TrezorFailure, match="Forbidden key path"):
_, serialized_tx = btc.sign_tx(
client,
session,
"Testnet",
inputs,
outputs,
prev_txes=TX_CACHE_TESTNET,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
messages.ButtonRequest(code=B.Other),
messages.UnlockedPathRequest,
@ -462,7 +461,7 @@ def test_sign_tx_spend(client: Client):
request_output(0),
request_output(1),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_output(0),
@ -472,7 +471,7 @@ def test_sign_tx_spend(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client,
session,
"Testnet",
inputs,
outputs,
@ -487,7 +486,7 @@ def test_sign_tx_spend(client: Client):
)
def test_sign_tx_migration(client: Client):
def test_sign_tx_migration(session: Session):
inputs = [
messages.TxInputType(
address_n=parse_path("m/84h/1h/3h/0/12"),
@ -520,15 +519,15 @@ def test_sign_tx_migration(client: Client):
# Ensure that Trezor refuses to receive to CoinJoin path without the user first authorizing access to CoinJoin paths.
with pytest.raises(TrezorFailure, match="Forbidden key path"):
_, serialized_tx = btc.sign_tx(
client,
session,
"Testnet",
inputs,
outputs,
prev_txes=TX_CACHE_TESTNET,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
messages.ButtonRequest(code=B.Other),
messages.UnlockedPathRequest,
@ -536,7 +535,7 @@ def test_sign_tx_migration(client: Client):
request_input(1),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(TXHASH_2cc3c1),
@ -558,7 +557,7 @@ def test_sign_tx_migration(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client,
session,
"Testnet",
inputs,
outputs,
@ -573,11 +572,11 @@ def test_sign_tx_migration(client: Client):
)
def test_wrong_coordinator(client: Client):
def test_wrong_coordinator(session: Session):
# Ensure that a preauthorized GetOwnershipProof fails if the commitment_data doesn't match the coordinator.
btc.authorize_coinjoin(
client,
session,
coordinator="www.example.com",
max_rounds=10,
max_coordinator_fee_rate=500_000, # 0.5 %
@ -589,7 +588,7 @@ def test_wrong_coordinator(client: Client):
with pytest.raises(TrezorFailure, match="Unauthorized operation"):
btc.get_ownership_proof(
client,
session,
"Testnet",
parse_path("m/10025h/1h/0h/1h/1/0"),
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -599,9 +598,9 @@ def test_wrong_coordinator(client: Client):
)
def test_wrong_account_type(client: Client):
def test_wrong_account_type(session: Session):
params = {
"client": client,
"session": session,
"coordinator": "www.example.com",
"max_rounds": 10,
"max_coordinator_fee_rate": 500_000, # 0.5 %
@ -625,11 +624,11 @@ def test_wrong_account_type(client: Client):
)
def test_cancel_authorization(client: Client):
def test_cancel_authorization(session: Session):
# Ensure that a preauthorized GetOwnershipProof fails if the commitment_data doesn't match the coordinator.
btc.authorize_coinjoin(
client,
session,
coordinator="www.example.com",
max_rounds=10,
max_coordinator_fee_rate=500_000, # 0.5 %
@ -639,11 +638,11 @@ def test_cancel_authorization(client: Client):
script_type=messages.InputScriptType.SPENDTAPROOT,
)
device.cancel_authorization(client)
device.cancel_authorization(session)
with pytest.raises(TrezorFailure, match="No preauthorized operation"):
btc.get_ownership_proof(
client,
session,
"Testnet",
parse_path("m/10025h/1h/0h/1h/1/0"),
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -653,35 +652,35 @@ def test_cancel_authorization(client: Client):
)
def test_get_public_key(client: Client):
def test_get_public_key(session: Session):
ACCOUNT_PATH = parse_path("m/10025h/1h/0h/1h")
EXPECTED_XPUB = "tpubDEMKm4M3S2Grx5DHTfbX9et5HQb9KhdjDCkUYdH9gvVofvPTE6yb2MH52P9uc4mx6eFohUmfN1f4hhHNK28GaZnWRXr3b8KkfFcySo1SmXU"
# Ensure that user cannot access SLIP-25 path without UnlockPath.
with pytest.raises(TrezorFailure, match="Forbidden key path"):
resp = btc.get_public_node(
client,
session,
ACCOUNT_PATH,
coin_name="Testnet",
script_type=messages.InputScriptType.SPENDTAPROOT,
)
# Get unlock path MAC.
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
messages.ButtonRequest(code=B.Other),
messages.UnlockedPathRequest,
messages.Failure(code=messages.FailureType.ActionCancelled),
]
)
unlock_path_mac = device.unlock_path(client, n=SLIP25_PATH)
unlock_path_mac = device.unlock_path(session, n=SLIP25_PATH)
# Ensure that UnlockPath fails with invalid MAC.
invalid_unlock_path_mac = bytes([unlock_path_mac[0] ^ 1]) + unlock_path_mac[1:]
with pytest.raises(TrezorFailure, match="Invalid MAC"):
resp = btc.get_public_node(
client,
session,
ACCOUNT_PATH,
coin_name="Testnet",
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -690,15 +689,15 @@ def test_get_public_key(client: Client):
)
# Ensure that user does not need to confirm access when path unlock is requested with MAC.
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
messages.UnlockedPathRequest,
messages.PublicKey,
]
)
resp = btc.get_public_node(
client,
session,
ACCOUNT_PATH,
coin_name="Testnet",
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -708,11 +707,12 @@ def test_get_public_key(client: Client):
assert resp.xpub == EXPECTED_XPUB
def test_get_address(client: Client):
def test_get_address(session: Session):
# Ensure that the SLIP-0025 external chain is inaccessible without user confirmation.
with pytest.raises(TrezorFailure, match="Forbidden key path"):
btc.get_address(
client,
session,
"Testnet",
parse_path("m/10025h/1h/0h/1h/0/0"),
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -720,20 +720,20 @@ def test_get_address(client: Client):
)
# Unlock CoinJoin path.
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
messages.ButtonRequest(code=B.Other),
messages.UnlockedPathRequest,
messages.Failure(code=messages.FailureType.ActionCancelled),
]
)
unlock_path_mac = device.unlock_path(client, SLIP25_PATH)
unlock_path_mac = device.unlock_path(session, SLIP25_PATH)
# Ensure that the SLIP-0025 external chain is accessible after user confirmation.
for chunkify in (True, False):
resp = btc.get_address(
client,
session,
"Testnet",
parse_path("m/10025h/1h/0h/1h/0/0"),
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -745,7 +745,7 @@ def test_get_address(client: Client):
assert resp == "tb1pl3y9gf7xk2ryvmav5ar66ra0d2hk7lhh9mmusx3qvn0n09kmaghqh32ru7"
resp = btc.get_address(
client,
session,
"Testnet",
parse_path("m/10025h/1h/0h/1h/0/1"),
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -758,7 +758,7 @@ def test_get_address(client: Client):
# Ensure that the SLIP-0025 internal chain is inaccessible even with user authorization.
with pytest.raises(TrezorFailure, match="Forbidden key path"):
btc.get_address(
client,
session,
"Testnet",
parse_path("m/10025h/1h/0h/1h/1/0"),
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -769,7 +769,7 @@ def test_get_address(client: Client):
with pytest.raises(TrezorFailure, match="Forbidden key path"):
btc.get_address(
client,
session,
"Testnet",
parse_path("m/10025h/1h/0h/1h/1/1"),
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -781,7 +781,7 @@ def test_get_address(client: Client):
# Ensure that another SLIP-0025 account is inaccessible with the same MAC.
with pytest.raises(TrezorFailure, match="Forbidden key path"):
btc.get_address(
client,
session,
"Testnet",
parse_path("m/10025h/1h/1h/1h/0/0"),
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -793,8 +793,10 @@ def test_get_address(client: Client):
def test_multisession_authorization(client: Client):
# Authorize CoinJoin with www.example1.com in session 1.
session1 = client.get_session()
btc.authorize_coinjoin(
client,
session1,
coordinator="www.example1.com",
max_rounds=10,
max_coordinator_fee_rate=500_000, # 0.5 %
@ -805,12 +807,11 @@ def test_multisession_authorization(client: Client):
)
# Open a second session.
session_id1 = client.session_id
client.init_device(new_session=True)
session2 = client.get_session()
# Authorize CoinJoin with www.example2.com in session 2.
btc.authorize_coinjoin(
client,
session2,
coordinator="www.example2.com",
max_rounds=10,
max_coordinator_fee_rate=500_000, # 0.5 %
@ -823,7 +824,7 @@ def test_multisession_authorization(client: Client):
# Requesting a preauthorized ownership proof for www.example1.com should fail in session 2.
with pytest.raises(TrezorFailure, match="Unauthorized operation"):
ownership_proof, _ = btc.get_ownership_proof(
client,
session2,
"Testnet",
parse_path("m/10025h/1h/0h/1h/1/0"),
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -834,7 +835,7 @@ def test_multisession_authorization(client: Client):
# Requesting a preauthorized ownership proof for www.example2.com should succeed in session 2.
ownership_proof, _ = btc.get_ownership_proof(
client,
session2,
"Testnet",
parse_path("m/10025h/1h/0h/1h/1/0"),
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -849,12 +850,10 @@ def test_multisession_authorization(client: Client):
)
# Switch back to the first session.
session_id2 = client.session_id
client.init_device(session_id=session_id1)
session1 = client.resume_session(session1)
# Requesting a preauthorized ownership proof for www.example1.com should succeed in session 1.
ownership_proof, _ = btc.get_ownership_proof(
client,
session1,
"Testnet",
parse_path("m/10025h/1h/0h/1h/1/0"),
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -871,7 +870,7 @@ def test_multisession_authorization(client: Client):
# Requesting a preauthorized ownership proof for www.example2.com should fail in session 1.
with pytest.raises(TrezorFailure, match="Unauthorized operation"):
ownership_proof, _ = btc.get_ownership_proof(
client,
session1,
"Testnet",
parse_path("m/10025h/1h/0h/1h/1/0"),
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -881,12 +880,12 @@ def test_multisession_authorization(client: Client):
)
# Cancel the authorization in session 1.
device.cancel_authorization(client)
device.cancel_authorization(session1)
# Requesting a preauthorized ownership proof should fail now.
with pytest.raises(TrezorFailure, match="No preauthorized operation"):
ownership_proof, _ = btc.get_ownership_proof(
client,
session1,
"Testnet",
parse_path("m/10025h/1h/0h/1h/1/0"),
script_type=messages.InputScriptType.SPENDTAPROOT,
@ -896,11 +895,10 @@ def test_multisession_authorization(client: Client):
)
# Switch to the second session.
client.init_device(session_id=session_id2)
session2 = client.resume_session(session2)
# Requesting a preauthorized ownership proof for www.example2.com should still succeed in session 2.
ownership_proof, _ = btc.get_ownership_proof(
client,
session2,
"Testnet",
parse_path("m/10025h/1h/0h/1h/1/0"),
script_type=messages.InputScriptType.SPENDTAPROOT,

View File

@ -17,7 +17,7 @@
import pytest
from trezorlib import btc, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.exceptions import TrezorFailure
from trezorlib.tools import H_, parse_path
@ -53,7 +53,7 @@ FAKE_TXHASH_203416 = bytes.fromhex(
pytestmark = pytest.mark.altcoin
def test_send_bch_change(client: Client):
def test_send_bch_change(session: Session):
inp1 = messages.TxInputType(
address_n=parse_path("m/44h/145h/0h/0/0"),
# bitcoincash:qr08q88p9etk89wgv05nwlrkm4l0urz4cyl36hh9sv
@ -72,14 +72,14 @@ def test_send_bch_change(client: Client):
amount=73_452,
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_output(0),
request_output(1),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(TXHASH_bc37c2),
@ -92,9 +92,9 @@ def test_send_bch_change(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API
session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API
)
# raise Exception(hexlify(serialized_tx))
assert_tx_matches(
serialized_tx,
hash_link="https://bch1.trezor.io/api/tx/502e8577b237b0152843a416f8f1ab0c63321b1be7a8cad7bf5c5c216fcf062c",
@ -102,7 +102,7 @@ def test_send_bch_change(client: Client):
)
def test_send_bch_nochange(client: Client):
def test_send_bch_nochange(session: Session):
inp1 = messages.TxInputType(
address_n=parse_path("m/44h/145h/0h/1/0"),
# bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw
@ -124,14 +124,14 @@ def test_send_bch_nochange(client: Client):
amount=1_934_960,
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_input(1),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(TXHASH_502e85),
@ -150,7 +150,7 @@ def test_send_bch_nochange(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API
session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API
)
assert_tx_matches(
@ -160,7 +160,7 @@ def test_send_bch_nochange(client: Client):
)
def test_send_bch_oldaddr(client: Client):
def test_send_bch_oldaddr(session: Session):
inp1 = messages.TxInputType(
address_n=parse_path("m/44h/145h/0h/1/0"),
# bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw
@ -182,14 +182,14 @@ def test_send_bch_oldaddr(client: Client):
amount=1_934_960,
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_input(1),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(TXHASH_502e85),
@ -208,7 +208,7 @@ def test_send_bch_oldaddr(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API
session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API
)
assert_tx_matches(
@ -218,7 +218,7 @@ def test_send_bch_oldaddr(client: Client):
)
def test_attack_change_input(client: Client):
def test_attack_change_input(session: Session):
# NOTE: fake input tx used
inp1 = messages.TxInputType(
@ -252,15 +252,15 @@ def test_attack_change_input(client: Client):
return msg
with client:
client.set_filter(messages.TxAck, attack_processor)
client.set_expected_responses(
with session:
session.set_filter(messages.TxAck, attack_processor)
session.set_expected_responses(
[
request_input(0),
request_output(0),
request_output(1),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(FAKE_TXHASH_bd32ff),
@ -271,16 +271,16 @@ def test_attack_change_input(client: Client):
]
)
with pytest.raises(TrezorFailure):
btc.sign_tx(client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API)
btc.sign_tx(session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API)
@pytest.mark.multisig
def test_send_bch_multisig_wrongchange(client: Client):
def test_send_bch_multisig_wrongchange(session: Session):
# NOTE: fake input tx used
nodes = [
btc.get_public_node(
client, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash"
session, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash"
).node
for i in range(1, 4)
]
@ -327,13 +327,13 @@ def test_send_bch_multisig_wrongchange(client: Client):
script_type=messages.OutputScriptType.PAYTOMULTISIG,
amount=23_000,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(FAKE_TXHASH_062fbd),
@ -346,7 +346,7 @@ def test_send_bch_multisig_wrongchange(client: Client):
]
)
(signatures1, serialized_tx) = btc.sign_tx(
client, "Bcash", [inp1], [out1], prev_txes=TX_API
session, "Bcash", [inp1], [out1], prev_txes=TX_API
)
assert (
signatures1[0].hex()
@ -359,12 +359,12 @@ def test_send_bch_multisig_wrongchange(client: Client):
@pytest.mark.multisig
def test_send_bch_multisig_change(client: Client):
def test_send_bch_multisig_change(session: Session):
# NOTE: fake input tx used
nodes = [
btc.get_public_node(
client, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash"
session, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash"
).node
for i in range(1, 4)
]
@ -395,13 +395,13 @@ def test_send_bch_multisig_change(client: Client):
script_type=messages.OutputScriptType.PAYTOMULTISIG,
amount=24_000,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
request_output(1),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
@ -415,7 +415,7 @@ def test_send_bch_multisig_change(client: Client):
]
)
(signatures1, serialized_tx) = btc.sign_tx(
client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API
session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API
)
assert (
@ -434,13 +434,13 @@ def test_send_bch_multisig_change(client: Client):
)
out2.address_n[2] = H_(1)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
request_output(1),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
@ -454,7 +454,7 @@ def test_send_bch_multisig_change(client: Client):
]
)
(signatures1, serialized_tx) = btc.sign_tx(
client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API
session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API
)
assert (
@ -468,7 +468,7 @@ def test_send_bch_multisig_change(client: Client):
@pytest.mark.models("core")
def test_send_bch_external_presigned(client: Client):
def test_send_bch_external_presigned(session: Session):
inp1 = messages.TxInputType(
# address_n=parse_path("44'/145'/0'/1/0"),
# bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw
@ -496,14 +496,14 @@ def test_send_bch_external_presigned(client: Client):
amount=1_934_960,
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_input(1),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(TXHASH_502e85),
@ -522,7 +522,7 @@ def test_send_bch_external_presigned(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API
session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API
)
assert_tx_matches(

View File

@ -17,7 +17,7 @@
import pytest
from trezorlib import btc, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.exceptions import TrezorFailure
from trezorlib.tools import H_, parse_path, tx_hash
@ -51,7 +51,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.models("t1b1", "t2t1")]
# All data taken from T1
def test_send_bitcoin_gold_change(client: Client):
def test_send_bitcoin_gold_change(session: Session):
# NOTE: fake input tx used
inp1 = messages.TxInputType(
@ -71,14 +71,14 @@ def test_send_bitcoin_gold_change(client: Client):
amount=1_252_382_934 - 1_896_050 - 1_000,
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_output(0),
request_output(1),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(FAKE_TXHASH_6f0398),
@ -92,7 +92,7 @@ def test_send_bitcoin_gold_change(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
)
assert (
@ -101,7 +101,7 @@ def test_send_bitcoin_gold_change(client: Client):
)
def test_send_bitcoin_gold_nochange(client: Client):
def test_send_bitcoin_gold_nochange(session: Session):
# NOTE: fake input tx used
inp1 = messages.TxInputType(
@ -124,14 +124,14 @@ def test_send_bitcoin_gold_nochange(client: Client):
amount=1_252_382_934 + 38_448_607 - 1_000,
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_input(1),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(FAKE_TXHASH_6f0398),
@ -150,7 +150,7 @@ def test_send_bitcoin_gold_nochange(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API
session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API
)
assert (
@ -159,7 +159,7 @@ def test_send_bitcoin_gold_nochange(client: Client):
)
def test_attack_change_input(client: Client):
def test_attack_change_input(session: Session):
# NOTE: fake input tx used
inp1 = messages.TxInputType(
@ -193,15 +193,15 @@ def test_attack_change_input(client: Client):
return msg
with client:
client.set_filter(messages.TxAck, attack_processor)
client.set_expected_responses(
with session:
session.set_filter(messages.TxAck, attack_processor)
session.set_expected_responses(
[
request_input(0),
request_output(0),
request_output(1),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(FAKE_TXHASH_6f0398),
@ -213,16 +213,16 @@ def test_attack_change_input(client: Client):
]
)
with pytest.raises(TrezorFailure):
btc.sign_tx(client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API)
btc.sign_tx(session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API)
@pytest.mark.multisig
def test_send_btg_multisig_change(client: Client):
def test_send_btg_multisig_change(session: Session):
# NOTE: fake input tx used
nodes = [
btc.get_public_node(
client, parse_path(f"m/48h/156h/{i}h/0h"), coin_name="Bgold"
session, parse_path(f"m/48h/156h/{i}h/0h"), coin_name="Bgold"
).node
for i in range(1, 4)
]
@ -254,13 +254,13 @@ def test_send_btg_multisig_change(client: Client):
script_type=messages.OutputScriptType.PAYTOMULTISIG,
amount=1_252_382_934 - 24_000 - 1_000,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
request_output(1),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
@ -275,7 +275,7 @@ def test_send_btg_multisig_change(client: Client):
]
)
signatures, serialized_tx = btc.sign_tx(
client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
)
assert (
@ -293,13 +293,13 @@ def test_send_btg_multisig_change(client: Client):
)
out2.address_n[2] = H_(1)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
request_output(1),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
@ -314,7 +314,7 @@ def test_send_btg_multisig_change(client: Client):
]
)
signatures, serialized_tx = btc.sign_tx(
client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
)
assert (
@ -327,7 +327,7 @@ def test_send_btg_multisig_change(client: Client):
)
def test_send_p2sh(client: Client):
def test_send_p2sh(session: Session):
# NOTE: fake input tx used
inp1 = messages.TxInputType(
@ -347,16 +347,16 @@ def test_send_p2sh(client: Client):
script_type=messages.OutputScriptType.PAYTOADDRESS,
amount=1_252_382_934 - 11_000 - 12_300_000,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
request_output(1),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(FAKE_TXHASH_db7239),
@ -371,7 +371,7 @@ def test_send_p2sh(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
)
assert (
@ -380,7 +380,7 @@ def test_send_p2sh(client: Client):
)
def test_send_p2sh_witness_change(client: Client):
def test_send_p2sh_witness_change(session: Session):
# NOTE: fake input tx used
inp1 = messages.TxInputType(
@ -400,13 +400,13 @@ def test_send_p2sh_witness_change(client: Client):
script_type=messages.OutputScriptType.PAYTOP2SHWITNESS,
amount=1_252_382_934 - 11_000 - 12_300_000,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
request_output(1),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
@ -422,7 +422,7 @@ def test_send_p2sh_witness_change(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
)
assert (
@ -432,12 +432,12 @@ def test_send_p2sh_witness_change(client: Client):
@pytest.mark.multisig
def test_send_multisig_1(client: Client):
def test_send_multisig_1(session: Session):
# NOTE: fake input tx used
nodes = [
btc.get_public_node(
client, parse_path(f"m/49h/156h/{i}h"), coin_name="Bgold"
session, parse_path(f"m/49h/156h/{i}h"), coin_name="Bgold"
).node
for i in range(1, 4)
]
@ -460,13 +460,13 @@ def test_send_multisig_1(client: Client):
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(FAKE_TXHASH_7f1f6b),
@ -479,17 +479,17 @@ def test_send_multisig_1(client: Client):
request_finished(),
]
)
signatures, _ = btc.sign_tx(client, "Bgold", [inp1], [out1], prev_txes=TX_API)
signatures, _ = btc.sign_tx(session, "Bgold", [inp1], [out1], prev_txes=TX_API)
# store signature
inp1.multisig.signatures[0] = signatures[0]
# sign with third key
inp1.address_n[2] = H_(3)
client.set_expected_responses(
session.set_expected_responses(
[
request_input(0),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(FAKE_TXHASH_7f1f6b),
@ -503,7 +503,7 @@ def test_send_multisig_1(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client, "Bgold", [inp1], [out1], prev_txes=TX_API
session, "Bgold", [inp1], [out1], prev_txes=TX_API
)
assert (
@ -512,7 +512,7 @@ def test_send_multisig_1(client: Client):
)
def test_send_mixed_inputs(client: Client):
def test_send_mixed_inputs(session: Session):
# NOTE: fake input tx used
# First is non-segwit, second is segwit.
@ -537,9 +537,9 @@ def test_send_mixed_inputs(client: Client):
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with client:
with session:
_, serialized_tx = btc.sign_tx(
client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API
session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API
)
assert (
@ -549,7 +549,7 @@ def test_send_mixed_inputs(client: Client):
@pytest.mark.models("core")
def test_send_btg_external_presigned(client: Client):
def test_send_btg_external_presigned(session: Session):
# NOTE: fake input tx used
inp1 = messages.TxInputType(
@ -577,14 +577,14 @@ def test_send_btg_external_presigned(client: Client):
amount=1_252_382_934 + 58_456 - 1_000,
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_input(1),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(FAKE_TXHASH_6f0398),
@ -603,7 +603,7 @@ def test_send_btg_external_presigned(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API
session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API
)
assert (

View File

@ -17,7 +17,7 @@
import pytest
from trezorlib import btc, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.tools import parse_path
from ...common import is_core
@ -43,7 +43,7 @@ TXHASH_15575a = bytes.fromhex(
pytestmark = [pytest.mark.altcoin, pytest.mark.models("t1b1", "t2t1")]
def test_send_dash(client: Client):
def test_send_dash(session: Session):
inp1 = messages.TxInputType(
address_n=parse_path("m/44h/5h/0h/0/0"),
# dash:XdTw4G5AWW4cogGd7ayybyBNDbuB45UpgH
@ -57,13 +57,13 @@ def test_send_dash(client: Client):
amount=999_999_000,
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(inp1.prev_hash),
@ -77,7 +77,9 @@ def test_send_dash(client: Client):
request_finished(),
]
)
_, serialized_tx = btc.sign_tx(client, "Dash", [inp1], [out1], prev_txes=TX_API)
_, serialized_tx = btc.sign_tx(
session, "Dash", [inp1], [out1], prev_txes=TX_API
)
assert (
serialized_tx.hex()
@ -85,7 +87,7 @@ def test_send_dash(client: Client):
)
def test_send_dash_dip2_input(client: Client):
def test_send_dash_dip2_input(session: Session):
inp1 = messages.TxInputType(
address_n=parse_path("m/44h/5h/0h/0/0"),
# dash:XdTw4G5AWW4cogGd7ayybyBNDbuB45UpgH
@ -104,14 +106,14 @@ def test_send_dash_dip2_input(client: Client):
amount=95_000_000,
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_output(0),
request_output(1),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(inp1.prev_hash),
@ -128,7 +130,7 @@ def test_send_dash_dip2_input(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client, "Dash", [inp1], [out1, out2], prev_txes=TX_API
session, "Dash", [inp1], [out1, out2], prev_txes=TX_API
)
assert (

View File

@ -17,7 +17,7 @@
import pytest
from trezorlib import btc, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.tools import parse_path
from ...common import is_core
@ -57,7 +57,7 @@ pytestmark = [
]
def test_send_decred(client: Client):
def test_send_decred(session: Session):
# NOTE: fake input tx used
inp1 = messages.TxInputType(
@ -76,13 +76,13 @@ def test_send_decred(client: Client):
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.FeeOverThreshold),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
@ -95,7 +95,7 @@ def test_send_decred(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client, "Decred Testnet", [inp1], [out1], prev_txes=TX_API
session, "Decred Testnet", [inp1], [out1], prev_txes=TX_API
)
assert (
@ -105,7 +105,7 @@ def test_send_decred(client: Client):
@pytest.mark.models("core")
def test_purchase_ticket_decred(client: Client):
def test_purchase_ticket_decred(session: Session):
# NOTE: fake input tx used
inp1 = messages.TxInputType(
@ -133,8 +133,8 @@ def test_purchase_ticket_decred(client: Client):
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_output(0),
@ -153,7 +153,7 @@ def test_purchase_ticket_decred(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client,
session,
"Decred Testnet",
[inp1],
[out1, out2, out3],
@ -168,7 +168,7 @@ def test_purchase_ticket_decred(client: Client):
@pytest.mark.models("core")
def test_spend_from_stake_generation_and_revocation_decred(client: Client):
def test_spend_from_stake_generation_and_revocation_decred(session: Session):
# NOTE: fake input tx used
inp1 = messages.TxInputType(
@ -197,14 +197,14 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client):
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_input(1),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(FAKE_TXHASH_8b6890),
@ -223,7 +223,7 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client, "Decred Testnet", [inp1, inp2], [out1], prev_txes=TX_API
session, "Decred Testnet", [inp1, inp2], [out1], prev_txes=TX_API
)
assert (
@ -232,7 +232,7 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client):
)
def test_send_decred_change(client: Client):
def test_send_decred_change(session: Session):
# NOTE: fake input tx used
inp1 = messages.TxInputType(
@ -278,15 +278,15 @@ def test_send_decred_change(client: Client):
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_input(1),
request_input(2),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
request_output(1),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
@ -311,7 +311,7 @@ def test_send_decred_change(client: Client):
]
)
_, serialized_tx = btc.sign_tx(
client,
session,
"Decred Testnet",
[inp1, inp2, inp3],
[out1, out2],
@ -325,12 +325,12 @@ def test_send_decred_change(client: Client):
@pytest.mark.multisig
def test_decred_multisig_change(client: Client):
def test_decred_multisig_change(session: Session):
# NOTE: fake input tx used
paths = [parse_path(f"m/48h/1h/{index}'/0'") for index in range(3)]
nodes = [
btc.get_public_node(client, address_n, coin_name="Decred Testnet").node
btc.get_public_node(session, address_n, coin_name="Decred Testnet").node
for address_n in paths
]
@ -384,15 +384,15 @@ def test_decred_multisig_change(client: Client):
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
request_input(0),
request_input(1),
request_output(0),
request_output(1),
messages.ButtonRequest(code=B.ConfirmOutput),
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
request_input(0),
request_meta(FAKE_TXHASH_9ac7d2),
@ -410,7 +410,7 @@ def test_decred_multisig_change(client: Client):
]
)
signature, serialized_tx = btc.sign_tx(
client,
session,
"Decred Testnet",
[inp1, inp2],
[out1, out2],

View File

@ -18,7 +18,7 @@ import pytest
from trezorlib import btc, messages, models
from trezorlib.cli import btc as btc_cli
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.tools import H_
from ...input_flows import InputFlowShowXpubQRCode
@ -165,14 +165,16 @@ def _address_n(purpose, coin, account, script_type):
@pytest.mark.parametrize(
"coin, account, purpose, script_type, descriptors", VECTORS_DESCRIPTORS
)
def test_descriptors(client: Client, coin, account, purpose, script_type, descriptors):
with client:
IF = InputFlowShowXpubQRCode(client)
client.set_input_flow(IF.get())
def test_descriptors(
session: Session, coin, account, purpose, script_type, descriptors
):
with session:
IF = InputFlowShowXpubQRCode(session.client)
session.set_input_flow(IF.get())
address_n = _address_n(purpose, coin, account, script_type)
res = btc.get_public_node(
client,
session,
_address_n(purpose, coin, account, script_type),
show_display=True,
coin_name=coin,
@ -187,13 +189,13 @@ def test_descriptors(client: Client, coin, account, purpose, script_type, descri
"coin, account, purpose, script_type, descriptors", VECTORS_DESCRIPTORS
)
def test_descriptors_trezorlib(
client: Client, coin, account, purpose, script_type, descriptors
session: Session, coin, account, purpose, script_type, descriptors
):
with client:
if client.model != models.T1B1:
IF = InputFlowShowXpubQRCode(client)
client.set_input_flow(IF.get())
with session:
if session.client.model != models.T1B1:
IF = InputFlowShowXpubQRCode(session.client)
session.set_input_flow(IF.get())
res = btc_cli._get_descriptor(
client, coin, account, purpose, script_type, show_display=True
session, coin, account, purpose, script_type, show_display=True
)
assert res == descriptors

View File

@ -17,7 +17,7 @@
import pytest
from trezorlib import btc, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.tools import parse_path
from ...tx_cache import TxCache
@ -30,7 +30,7 @@ TXHASH_8a34cc = bytes.fromhex(
@pytest.mark.altcoin
def test_spend_lelantus(client: Client):
def test_spend_lelantus(session: Session):
inp1 = messages.TxInputType(
# THgGLVqfzJcaxRVPWE5fd8YJ1GpVePq2Uk
address_n=parse_path("m/44h/1h/0h/0/4"),
@ -45,7 +45,7 @@ def test_spend_lelantus(client: Client):
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
_, serialized_tx = btc.sign_tx(
client, "Firo Testnet", [inp1], [out1], prev_txes=TX_API
session, "Firo Testnet", [inp1], [out1], prev_txes=TX_API
)
assert_tx_matches(
serialized_tx,

View File

@ -17,7 +17,7 @@
import pytest
from trezorlib import btc, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.tools import parse_path
TXHASH_33043a = bytes.fromhex(
@ -27,7 +27,7 @@ TXHASH_33043a = bytes.fromhex(
pytestmark = pytest.mark.altcoin
def test_send_p2tr(client: Client):
def test_send_p2tr(session: Session):
inp1 = messages.TxInputType(
# fc1prr07akly3xjtmggue0p04vghr8pdcgxrye2s00sahptwjeawxrkq2rxzr7
address_n=parse_path("m/86h/75h/0h/0/1"),
@ -42,7 +42,7 @@ def test_send_p2tr(client: Client):
amount=99_996_670_000,
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
_, serialized_tx = btc.sign_tx(client, "Fujicoin", [inp1], [out1])
_, serialized_tx = btc.sign_tx(session, "Fujicoin", [inp1], [out1])
# Transaction hex changed with fix #2085, all other details are the same as this tx:
# https://explorer.fujicoin.org/tx/a1c6a81f5e8023b17e6e3e51e2596d5b5e1d4914ea13c0c31cef90b3c3edee86
assert (

View File

@ -17,7 +17,7 @@
import pytest
from trezorlib import btc, device, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.exceptions import TrezorFailure
from trezorlib.messages import MultisigPubkeysOrder, SafetyCheckLevel
from trezorlib.tools import parse_path
@ -36,112 +36,112 @@ def getmultisig(chain, nr, xpubs):
)
def test_btc(client: Client):
def test_btc(session: Session):
assert (
btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0"))
btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"))
== "1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL"
)
assert (
btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/1"))
btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/1"))
== "1GWFxtwWmNVqotUPXLcKVL2mUKpshuJYo"
)
assert (
btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/1/0"))
btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/1/0"))
== "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE"
)
@pytest.mark.altcoin
def test_ltc(client: Client):
def test_ltc(session: Session):
assert (
btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/0/0"))
btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/0/0"))
== "LcubERmHD31PWup1fbozpKuiqjHZ4anxcL"
)
assert (
btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/0/1"))
btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/0/1"))
== "LVWBmHBkCGNjSPHucvL2PmnuRAJnucmRE6"
)
assert (
btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/1/0"))
btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/1/0"))
== "LWj6ApswZxay4cJEJES2sGe7fLMLRvvv8h"
)
def test_tbtc(client: Client):
def test_tbtc(session: Session):
assert (
btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/0/0"))
btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/0/0"))
== "mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q"
)
assert (
btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/0/1"))
btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/0/1"))
== "mopZWqZZyQc3F2Sy33cvDtJchSAMsnLi7b"
)
assert (
btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/1/0"))
btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/1/0"))
== "mm6kLYbGEL1tGe4ZA8xacfgRPdW1NLjCbZ"
)
@pytest.mark.altcoin
def test_bch(client: Client):
def test_bch(session: Session):
assert (
btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/0/0"))
btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/0/0"))
== "bitcoincash:qr08q88p9etk89wgv05nwlrkm4l0urz4cyl36hh9sv"
)
assert (
btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/0/1"))
btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/0/1"))
== "bitcoincash:qr23ajjfd9wd73l87j642puf8cad20lfmqdgwvpat4"
)
assert (
btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/1/0"))
btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/1/0"))
== "bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw"
)
@pytest.mark.altcoin
def test_grs(client: Client):
def test_grs(session: Session):
assert (
btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/0/0"))
btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/0/0"))
== "Fj62rBJi8LvbmWu2jzkaUX1NFXLEqDLoZM"
)
assert (
btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/1/0"))
btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/1/0"))
== "FmRaqvVBRrAp2Umfqx9V1ectZy8gw54QDN"
)
assert (
btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/1/1"))
btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/1/1"))
== "Fmhtxeh7YdCBkyQF7AQG4QnY8y3rJg89di"
)
@pytest.mark.altcoin
def test_tgrs(client: Client):
def test_tgrs(session: Session):
assert (
btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0"))
btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0"))
== "mvbu1Gdy8SUjTenqerxUaZyYjmvedc787y"
)
assert (
btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/0"))
btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/0"))
== "mm6kLYbGEL1tGe4ZA8xacfgRPdW1LMq8cN"
)
assert (
btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/1"))
btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/1"))
== "mjXZwmEi1z1MzveZrKUAo4DBgbdq6ZhGD6"
)
@pytest.mark.altcoin
def test_elements(client: Client):
def test_elements(session: Session):
assert (
btc.get_address(client, "Elements", parse_path("m/44h/1h/0h/0/0"))
btc.get_address(session, "Elements", parse_path("m/44h/1h/0h/0/0"))
== "2dpWh6jbhAowNsQ5agtFzi7j6nKscj6UnEr"
)
@pytest.mark.models("core")
def test_address_mac(client: Client):
def test_address_mac(session: Session):
resp = btc.get_authenticated_address(
client, "Bitcoin", parse_path("m/44h/0h/0h/1/0")
session, "Bitcoin", parse_path("m/44h/0h/0h/1/0")
)
assert resp.address == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE"
assert (
@ -150,7 +150,7 @@ def test_address_mac(client: Client):
)
resp = btc.get_authenticated_address(
client, "Testnet", parse_path("m/44h/1h/0h/1/0")
session, "Testnet", parse_path("m/44h/1h/0h/1/0")
)
assert resp.address == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1NLjCbZ"
assert (
@ -160,16 +160,16 @@ def test_address_mac(client: Client):
# Script type mismatch.
resp = btc.get_authenticated_address(
client, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=False
session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=False
)
assert resp.mac is None
@pytest.mark.models("core")
@pytest.mark.altcoin
def test_altcoin_address_mac(client: Client):
def test_altcoin_address_mac(session: Session):
resp = btc.get_authenticated_address(
client, "Litecoin", parse_path("m/44h/2h/0h/1/0")
session, "Litecoin", parse_path("m/44h/2h/0h/1/0")
)
assert resp.address == "LWj6ApswZxay4cJEJES2sGe7fLMLRvvv8h"
assert (
@ -178,7 +178,7 @@ def test_altcoin_address_mac(client: Client):
)
resp = btc.get_authenticated_address(
client, "Bcash", parse_path("m/44h/145h/0h/1/0")
session, "Bcash", parse_path("m/44h/145h/0h/1/0")
)
assert resp.address == "bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw"
assert (
@ -187,7 +187,7 @@ def test_altcoin_address_mac(client: Client):
)
resp = btc.get_authenticated_address(
client, "Groestlcoin", parse_path("m/44h/17h/0h/1/1")
session, "Groestlcoin", parse_path("m/44h/17h/0h/1/1")
)
assert resp.address == "Fmhtxeh7YdCBkyQF7AQG4QnY8y3rJg89di"
assert (
@ -197,9 +197,9 @@ def test_altcoin_address_mac(client: Client):
@pytest.mark.multisig
def test_multisig_pubkeys_order(client: Client):
xpub_internal = btc.get_public_node(client, parse_path("m/45h/0")).xpub
xpub_external = btc.get_public_node(client, parse_path("m/45h/1")).xpub
def test_multisig_pubkeys_order(session: Session):
xpub_internal = btc.get_public_node(session, parse_path("m/45h/0")).xpub
xpub_external = btc.get_public_node(session, parse_path("m/45h/1")).xpub
multisig_unsorted_1 = messages.MultisigRedeemScriptType(
nodes=[bip32.deserialize(xpub) for xpub in [xpub_external, xpub_internal]],
@ -238,45 +238,45 @@ def test_multisig_pubkeys_order(client: Client):
assert (
btc.get_address(
client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1
session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1
)
== address_unsorted_1
)
assert (
btc.get_address(
client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2
session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2
)
== address_unsorted_2
)
assert (
btc.get_address(
client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_1
session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_1
)
== address_unsorted_2
)
assert (
btc.get_address(
client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_2
session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_2
)
== address_unsorted_2
)
@pytest.mark.multisig
def test_multisig(client: Client):
def test_multisig(session: Session):
xpubs = []
for n in range(1, 4):
node = btc.get_public_node(client, parse_path(f"m/44h/0h/{n}h"))
node = btc.get_public_node(session, parse_path(f"m/44h/0h/{n}h"))
xpubs.append(node.xpub)
for nr in range(1, 4):
with client:
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
assert (
btc.get_address(
client,
session,
"Bitcoin",
parse_path(f"m/44h/0h/{nr}h/0/0"),
show_display=(nr == 1),
@ -286,7 +286,7 @@ def test_multisig(client: Client):
)
assert (
btc.get_address(
client,
session,
"Bitcoin",
parse_path(f"m/44h/0h/{nr}h/1/0"),
show_display=(nr == 1),
@ -298,11 +298,11 @@ def test_multisig(client: Client):
@pytest.mark.multisig
@pytest.mark.parametrize("show_display", (True, False))
def test_multisig_missing(client: Client, show_display):
def test_multisig_missing(session: Session, show_display):
# Use account numbers 1, 2 and 3 to create a valid multisig,
# but not containing the keys from account 0 used below.
nodes = [
btc.get_public_node(client, parse_path(f"m/44h/0h/{i}h")).node
btc.get_public_node(session, parse_path(f"m/44h/0h/{i}h")).node
for i in range(1, 4)
]
@ -321,12 +321,12 @@ def test_multisig_missing(client: Client, show_display):
)
for multisig in (multisig1, multisig2):
with client, pytest.raises(TrezorFailure):
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
with pytest.raises(TrezorFailure), session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
btc.get_address(
client,
session,
"Bitcoin",
parse_path("m/44h/0h/0h/0/0"),
show_display=show_display,
@ -336,22 +336,22 @@ def test_multisig_missing(client: Client, show_display):
@pytest.mark.altcoin
@pytest.mark.multisig
def test_bch_multisig(client: Client):
def test_bch_multisig(session: Session):
xpubs = []
for n in range(1, 4):
node = btc.get_public_node(
client, parse_path(f"m/44h/145h/{n}h"), coin_name="Bcash"
session, parse_path(f"m/44h/145h/{n}h"), coin_name="Bcash"
)
xpubs.append(node.xpub)
for nr in range(1, 4):
with client:
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
assert (
btc.get_address(
client,
session,
"Bcash",
parse_path(f"m/44h/145h/{nr}h/0/0"),
show_display=(nr == 1),
@ -361,7 +361,7 @@ def test_bch_multisig(client: Client):
)
assert (
btc.get_address(
client,
session,
"Bcash",
parse_path(f"m/44h/145h/{nr}h/1/0"),
show_display=(nr == 1),
@ -371,43 +371,43 @@ def test_bch_multisig(client: Client):
)
def test_public_ckd(client: Client):
node = btc.get_public_node(client, parse_path("m/44h/0h/0h")).node
node_sub1 = btc.get_public_node(client, parse_path("m/44h/0h/0h/1/0")).node
def test_public_ckd(session: Session):
node = btc.get_public_node(session, parse_path("m/44h/0h/0h")).node
node_sub1 = btc.get_public_node(session, parse_path("m/44h/0h/0h/1/0")).node
node_sub2 = bip32.public_ckd(node, [1, 0])
assert node_sub1.chain_code == node_sub2.chain_code
assert node_sub1.public_key == node_sub2.public_key
address1 = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/1/0"))
address1 = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/1/0"))
address2 = bip32.get_address(node_sub2, 0)
assert address2 == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE"
assert address1 == address2
def test_invalid_path(client: Client):
def test_invalid_path(session: Session):
with pytest.raises(TrezorFailure, match="Forbidden key path"):
# slip44 id mismatch
btc.get_address(
client, "Bitcoin", parse_path("m/44h/111h/0h/0/0"), show_display=True
session, "Bitcoin", parse_path("m/44h/111h/0h/0/0"), show_display=True
)
def test_unknown_path(client: Client):
def test_unknown_path(session: Session):
UNKNOWN_PATH = parse_path("m/44h/9h/0h/0/0")
with client:
client.set_expected_responses([messages.Failure])
with session:
session.set_expected_responses([messages.Failure])
with pytest.raises(TrezorFailure, match="Forbidden key path"):
# account number is too high
btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True)
btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True)
# disable safety checks
device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily)
device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily)
with client:
client.set_expected_responses(
with session:
session.set_expected_responses(
[
messages.ButtonRequest(
code=messages.ButtonRequestType.UnknownDerivationPath
@ -416,30 +416,30 @@ def test_unknown_path(client: Client):
messages.Address,
]
)
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
if is_core(session):
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
# try again with a warning
btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True)
btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True)
with client:
with session:
# no warning is displayed when the call is silent
client.set_expected_responses([messages.Address])
btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=False)
session.set_expected_responses([messages.Address])
btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=False)
@pytest.mark.altcoin
def test_crw(client: Client):
def test_crw(session: Session):
assert (
btc.get_address(client, "Crown", parse_path("m/44h/72h/0h/0/0"))
btc.get_address(session, "Crown", parse_path("m/44h/72h/0h/0/0"))
== "CRWYdvZM1yXMKQxeN3hRsAbwa7drfvTwys48"
)
@pytest.mark.multisig
def test_multisig_different_paths(client: Client):
def test_multisig_different_paths(session: Session):
nodes = [
btc.get_public_node(client, parse_path(f"m/45h/{i}"), coin_name="Bitcoin").node
btc.get_public_node(session, parse_path(f"m/45h/{i}"), coin_name="Bitcoin").node
for i in range(2)
]
@ -455,12 +455,12 @@ def test_multisig_different_paths(client: Client):
with pytest.raises(
Exception, match="Using different paths for different xpubs is not allowed"
):
with client:
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
btc.get_address(
client,
session,
"Bitcoin",
parse_path("m/45h/0/0/0"),
show_display=True,
@ -468,13 +468,13 @@ def test_multisig_different_paths(client: Client):
script_type=messages.InputScriptType.SPENDMULTISIG,
)
device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily)
with client:
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily)
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
btc.get_address(
client,
session,
"Bitcoin",
parse_path("m/45h/0/0/0"),
show_display=True,

View File

@ -17,7 +17,7 @@
import pytest
from trezorlib import btc, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.exceptions import TrezorFailure
from trezorlib.tools import parse_path
@ -25,10 +25,10 @@ from ...common import is_core
from ...input_flows import InputFlowConfirmAllWarnings
def test_show_segwit(client: Client):
def test_show_segwit(session: Session):
assert (
btc.get_address(
client,
session,
"Testnet",
parse_path("m/49h/1h/0h/1/0"),
True,
@ -39,7 +39,7 @@ def test_show_segwit(client: Client):
)
assert (
btc.get_address(
client,
session,
"Testnet",
parse_path("m/49h/1h/0h/0/0"),
False,
@ -50,7 +50,7 @@ def test_show_segwit(client: Client):
)
assert (
btc.get_address(
client,
session,
"Testnet",
parse_path("m/44h/1h/0h/0/0"),
False,
@ -61,7 +61,7 @@ def test_show_segwit(client: Client):
)
assert (
btc.get_address(
client,
session,
"Testnet",
parse_path("m/44h/1h/0h/0/0"),
False,
@ -73,14 +73,14 @@ def test_show_segwit(client: Client):
@pytest.mark.altcoin
def test_show_segwit_altcoin(client: Client):
with client:
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
def test_show_segwit_altcoin(session: Session):
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
assert (
btc.get_address(
client,
session,
"Groestlcoin Testnet",
parse_path("m/49h/1h/0h/1/0"),
True,
@ -91,7 +91,7 @@ def test_show_segwit_altcoin(client: Client):
)
assert (
btc.get_address(
client,
session,
"Groestlcoin Testnet",
parse_path("m/49h/1h/0h/0/0"),
True,
@ -102,7 +102,7 @@ def test_show_segwit_altcoin(client: Client):
)
assert (
btc.get_address(
client,
session,
"Groestlcoin Testnet",
parse_path("m/44h/1h/0h/0/0"),
True,
@ -113,7 +113,7 @@ def test_show_segwit_altcoin(client: Client):
)
assert (
btc.get_address(
client,
session,
"Groestlcoin Testnet",
parse_path("m/44h/1h/0h/0/0"),
True,
@ -124,7 +124,7 @@ def test_show_segwit_altcoin(client: Client):
)
assert (
btc.get_address(
client,
session,
"Elements",
parse_path("m/49h/1h/0h/0/0"),
True,
@ -136,10 +136,10 @@ def test_show_segwit_altcoin(client: Client):
@pytest.mark.multisig
def test_show_multisig_3(client: Client):
def test_show_multisig_3(session: Session):
nodes = [
btc.get_public_node(
client, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet"
session, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet"
).node
for i in range(1, 4)
]
@ -155,7 +155,7 @@ def test_show_multisig_3(client: Client):
for i in [1, 2, 3]:
assert (
btc.get_address(
client,
session,
"Testnet",
parse_path(f"m/49h/1h/{i}h/0/7"),
False,
@ -168,11 +168,11 @@ def test_show_multisig_3(client: Client):
@pytest.mark.multisig
@pytest.mark.parametrize("show_display", (True, False))
def test_multisig_missing(client: Client, show_display):
def test_multisig_missing(session: Session, show_display):
# Use account numbers 1, 2 and 3 to create a valid multisig,
# but not containing the keys from account 0 used below.
nodes = [
btc.get_public_node(client, parse_path(f"m/49h/0h/{i}h")).node
btc.get_public_node(session, parse_path(f"m/49h/0h/{i}h")).node
for i in range(1, 4)
]
@ -193,7 +193,7 @@ def test_multisig_missing(client: Client, show_display):
for multisig in (multisig1, multisig2):
with pytest.raises(TrezorFailure):
btc.get_address(
client,
session,
"Bitcoin",
parse_path("m/49h/0h/0h/0/0"),
show_display=show_display,

Some files were not shown because too many files have changed in this diff Show More