mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-04-21 17:49:02 +00:00
Merge 0cb657a83a
into 52f5593f28
This commit is contained in:
commit
806b2ec0ba
5
.github/workflows/legacy.yml
vendored
5
.github/workflows/legacy.yml
vendored
@ -113,6 +113,7 @@ jobs:
|
||||
name: Device test
|
||||
runs-on: ubuntu-latest
|
||||
needs: legacy_emu
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
matrix:
|
||||
coins: [universal, btconly]
|
||||
@ -120,6 +121,7 @@ jobs:
|
||||
env:
|
||||
EMULATOR: 1
|
||||
TREZOR_PYTEST_SKIP_ALTCOINS: ${{ matrix.coins == 'universal' && '0' || '1' }}
|
||||
TESTOPTS: "--timeout 120"
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@ -148,6 +150,7 @@ jobs:
|
||||
name: Upgrade test
|
||||
runs-on: ubuntu-latest
|
||||
needs: legacy_emu
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
matrix:
|
||||
asan: ${{ fromJSON(github.event_name == 'schedule' && '["noasan", "asan"]' || '["noasan"]') }}
|
||||
@ -164,7 +167,7 @@ jobs:
|
||||
- run: chmod +x legacy/firmware/*.elf
|
||||
- uses: ./.github/actions/environment
|
||||
- run: nix-shell --run "tests/download_emulators.sh"
|
||||
- run: nix-shell --run "poetry run pytest tests/upgrade_tests"
|
||||
- run: nix-shell --run "poetry run pytest --timeout 120 tests/upgrade_tests"
|
||||
|
||||
legacy_hwi_test:
|
||||
name: HWI test
|
||||
|
@ -288,9 +288,10 @@ def cli(
|
||||
label = "Emulator"
|
||||
|
||||
assert emulator.client is not None
|
||||
trezorlib.device.wipe(emulator.client)
|
||||
trezorlib.device.wipe(emulator.client.get_seedless_session())
|
||||
|
||||
trezorlib.debuglink.load_device(
|
||||
emulator.client,
|
||||
emulator.client.get_seedless_session(),
|
||||
mnemonics,
|
||||
pin=None,
|
||||
passphrase_protection=False,
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
import binascii
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.transport_hid import HidTransport
|
||||
from trezorlib.transport.hid import HidTransport
|
||||
|
||||
devices = HidTransport.enumerate()
|
||||
if len(devices) > 0:
|
||||
|
1
python/.changelog.d/4577.changed
Normal file
1
python/.changelog.d/4577.changed
Normal file
@ -0,0 +1 @@
|
||||
Changed trezorlib to session-based. Changes also affect trezorctl, python tools, and tests.
|
@ -14,4 +14,4 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
__version__ = "0.13.11"
|
||||
__version__ = "0.14.0"
|
||||
|
@ -23,6 +23,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, TextIO, Union, cast
|
||||
|
||||
from ..debuglink import TrezorClientDebugLink
|
||||
from ..transport import Transport
|
||||
from ..transport.udp import UdpTransport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -103,6 +104,8 @@ class Emulator:
|
||||
"""
|
||||
if self._client is None:
|
||||
raise RuntimeError
|
||||
if self._client.is_invalidated:
|
||||
self._client = self._client.get_new_client()
|
||||
return self._client
|
||||
|
||||
def make_args(self) -> List[str]:
|
||||
@ -116,13 +119,12 @@ class Emulator:
|
||||
|
||||
def wait_until_ready(self, timeout: float = EMULATOR_WAIT_TIME) -> None:
|
||||
assert self.process is not None, "Emulator not started"
|
||||
transport = self._get_transport()
|
||||
transport.open()
|
||||
self.transport.open()
|
||||
LOG.info("Waiting for emulator to come up...")
|
||||
start = time.monotonic()
|
||||
try:
|
||||
while True:
|
||||
if transport._ping():
|
||||
if self.transport.ping():
|
||||
break
|
||||
if self.process.poll() is not None:
|
||||
raise RuntimeError("Emulator process died")
|
||||
@ -133,7 +135,7 @@ class Emulator:
|
||||
|
||||
time.sleep(0.1)
|
||||
finally:
|
||||
transport.close()
|
||||
self.transport.close()
|
||||
|
||||
LOG.info(f"Emulator ready after {time.monotonic() - start:.3f} seconds")
|
||||
|
||||
@ -164,7 +166,11 @@ class Emulator:
|
||||
env=env,
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
def start(
|
||||
self,
|
||||
transport: Optional[UdpTransport] = None,
|
||||
debug_transport: Optional[Transport] = None,
|
||||
) -> None:
|
||||
if self.process:
|
||||
if self.process.poll() is not None:
|
||||
# process has died, stop and start again
|
||||
@ -174,6 +180,7 @@ class Emulator:
|
||||
# process is running, no need to start again
|
||||
return
|
||||
|
||||
self.transport = transport or self._get_transport()
|
||||
self.process = self.launch_process()
|
||||
_RUNNING_PIDS.add(self.process)
|
||||
try:
|
||||
@ -187,15 +194,16 @@ class Emulator:
|
||||
(self.profile_dir / "trezor.pid").write_text(str(self.process.pid) + "\n")
|
||||
(self.profile_dir / "trezor.port").write_text(str(self.port) + "\n")
|
||||
|
||||
transport = self._get_transport()
|
||||
self._client = TrezorClientDebugLink(
|
||||
transport, auto_interact=self.auto_interact
|
||||
self.transport,
|
||||
auto_interact=self.auto_interact,
|
||||
open_transport=True,
|
||||
debug_transport=debug_transport,
|
||||
)
|
||||
self._client.open()
|
||||
|
||||
def stop(self) -> None:
|
||||
if self._client:
|
||||
self._client.close()
|
||||
self._client.close_transport()
|
||||
self._client = None
|
||||
|
||||
if self.process:
|
||||
@ -219,8 +227,9 @@ class Emulator:
|
||||
# preserving the recording directory between restarts
|
||||
self.restart_amount += 1
|
||||
prev_screenshot_dir = self.client.debug.screenshot_recording_dir
|
||||
debug_transport = self.client.debug.transport
|
||||
self.stop()
|
||||
self.start()
|
||||
self.start(transport=self.transport, debug_transport=debug_transport)
|
||||
if prev_screenshot_dir:
|
||||
self.client.debug.start_recording(
|
||||
prev_screenshot_dir, refresh_index=self.restart_amount
|
||||
|
@ -10,7 +10,7 @@ from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ec, utils
|
||||
|
||||
from . import device
|
||||
from .client import TrezorClient
|
||||
from .transport.session import Session
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -349,7 +349,7 @@ def verify_authentication_response(
|
||||
|
||||
|
||||
def authenticate_device(
|
||||
client: TrezorClient,
|
||||
session: Session,
|
||||
challenge: bytes | None = None,
|
||||
*,
|
||||
whitelist: t.Collection[bytes] | None = None,
|
||||
@ -359,7 +359,7 @@ def authenticate_device(
|
||||
if challenge is None:
|
||||
challenge = secrets.token_bytes(16)
|
||||
|
||||
resp = device.authenticate(client, challenge)
|
||||
resp = device.authenticate(session, challenge)
|
||||
|
||||
return verify_authentication_response(
|
||||
challenge,
|
||||
|
@ -19,16 +19,16 @@ from typing import TYPE_CHECKING
|
||||
from . import messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def list_names(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
) -> messages.BenchmarkNames:
|
||||
return client.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames)
|
||||
return session.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames)
|
||||
|
||||
|
||||
def run(client: "TrezorClient", name: str) -> messages.BenchmarkResult:
|
||||
return client.call(
|
||||
def run(session: "Session", name: str) -> messages.BenchmarkResult:
|
||||
return session.call(
|
||||
messages.BenchmarkRun(name=name), expect=messages.BenchmarkResult
|
||||
)
|
||||
|
@ -18,20 +18,19 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from . import messages
|
||||
from .protobuf import dict_to_proto
|
||||
from .tools import session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.BinanceGetAddress(
|
||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||
),
|
||||
@ -40,17 +39,16 @@ def get_address(
|
||||
|
||||
|
||||
def get_public_key(
|
||||
client: "TrezorClient", address_n: "Address", show_display: bool = False
|
||||
session: "Session", address_n: "Address", show_display: bool = False
|
||||
) -> bytes:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display),
|
||||
expect=messages.BinancePublicKey,
|
||||
).public_key
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx(
|
||||
client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False
|
||||
session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False
|
||||
) -> messages.BinanceSignedTx:
|
||||
msg = tx_json["msgs"][0]
|
||||
tx_msg = tx_json.copy()
|
||||
@ -59,7 +57,7 @@ def sign_tx(
|
||||
tx_msg["chunkify"] = chunkify
|
||||
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
|
||||
|
||||
client.call(envelope, expect=messages.BinanceTxRequest)
|
||||
session.call(envelope, expect=messages.BinanceTxRequest)
|
||||
|
||||
if "refid" in msg:
|
||||
msg = dict_to_proto(messages.BinanceCancelMsg, msg)
|
||||
@ -70,4 +68,4 @@ def sign_tx(
|
||||
else:
|
||||
raise ValueError("can not determine msg type")
|
||||
|
||||
return client.call(msg, expect=messages.BinanceSignedTx)
|
||||
return session.call(msg, expect=messages.BinanceSignedTx)
|
||||
|
@ -25,11 +25,11 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple
|
||||
from typing_extensions import Protocol, TypedDict
|
||||
|
||||
from . import exceptions, messages
|
||||
from .tools import _return_success, prepare_message_bytes, session
|
||||
from .tools import _return_success, prepare_message_bytes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
class ScriptSig(TypedDict):
|
||||
asm: str
|
||||
@ -105,7 +105,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType:
|
||||
|
||||
|
||||
def get_public_node(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
ecdsa_curve_name: Optional[str] = None,
|
||||
show_display: bool = False,
|
||||
@ -116,12 +116,12 @@ def get_public_node(
|
||||
unlock_path_mac: Optional[bytes] = None,
|
||||
) -> messages.PublicKey:
|
||||
if unlock_path:
|
||||
client.call(
|
||||
session.call(
|
||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
|
||||
expect=messages.UnlockedPathRequest,
|
||||
)
|
||||
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.GetPublicKey(
|
||||
address_n=n,
|
||||
ecdsa_curve_name=ecdsa_curve_name,
|
||||
@ -139,7 +139,7 @@ def get_address(*args: Any, **kwargs: Any) -> str:
|
||||
|
||||
|
||||
def get_authenticated_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
show_display: bool = False,
|
||||
@ -151,12 +151,12 @@ def get_authenticated_address(
|
||||
chunkify: bool = False,
|
||||
) -> messages.Address:
|
||||
if unlock_path:
|
||||
client.call(
|
||||
session.call(
|
||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
|
||||
expect=messages.UnlockedPathRequest,
|
||||
)
|
||||
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.GetAddress(
|
||||
address_n=n,
|
||||
coin_name=coin_name,
|
||||
@ -171,13 +171,13 @@ def get_authenticated_address(
|
||||
|
||||
|
||||
def get_ownership_id(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||
) -> bytes:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.GetOwnershipId(
|
||||
address_n=n,
|
||||
coin_name=coin_name,
|
||||
@ -189,7 +189,7 @@ def get_ownership_id(
|
||||
|
||||
|
||||
def get_ownership_proof(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
multisig: Optional[messages.MultisigRedeemScriptType] = None,
|
||||
@ -200,9 +200,9 @@ def get_ownership_proof(
|
||||
preauthorized: bool = False,
|
||||
) -> Tuple[bytes, bytes]:
|
||||
if preauthorized:
|
||||
client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
|
||||
session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
|
||||
|
||||
res = client.call(
|
||||
res = session.call(
|
||||
messages.GetOwnershipProof(
|
||||
address_n=n,
|
||||
coin_name=coin_name,
|
||||
@ -219,7 +219,7 @@ def get_ownership_proof(
|
||||
|
||||
|
||||
def sign_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
n: "Address",
|
||||
message: AnyStr,
|
||||
@ -227,7 +227,7 @@ def sign_message(
|
||||
no_script_type: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> messages.MessageSignature:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.SignMessage(
|
||||
coin_name=coin_name,
|
||||
address_n=n,
|
||||
@ -241,7 +241,7 @@ def sign_message(
|
||||
|
||||
|
||||
def verify_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
address: str,
|
||||
signature: bytes,
|
||||
@ -249,7 +249,7 @@ def verify_message(
|
||||
chunkify: bool = False,
|
||||
) -> bool:
|
||||
try:
|
||||
client.call(
|
||||
session.call(
|
||||
messages.VerifyMessage(
|
||||
address=address,
|
||||
signature=signature,
|
||||
@ -264,9 +264,8 @@ def verify_message(
|
||||
return False
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin_name: str,
|
||||
inputs: Sequence[messages.TxInputType],
|
||||
outputs: Sequence[messages.TxOutputType],
|
||||
@ -314,14 +313,14 @@ def sign_tx(
|
||||
setattr(signtx, name, value)
|
||||
|
||||
if unlock_path:
|
||||
client.call(
|
||||
session.call(
|
||||
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
|
||||
expect=messages.UnlockedPathRequest,
|
||||
)
|
||||
elif preauthorized:
|
||||
client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
|
||||
session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
|
||||
|
||||
res = client.call(signtx, expect=messages.TxRequest)
|
||||
res = session.call(signtx, expect=messages.TxRequest)
|
||||
|
||||
# Prepare structure for signatures
|
||||
signatures: List[Optional[bytes]] = [None] * len(inputs)
|
||||
@ -380,7 +379,7 @@ def sign_tx(
|
||||
if res.request_type == R.TXPAYMENTREQ:
|
||||
assert res.details.request_index is not None
|
||||
msg = payment_reqs[res.details.request_index]
|
||||
res = client.call(msg, expect=messages.TxRequest)
|
||||
res = session.call(msg, expect=messages.TxRequest)
|
||||
else:
|
||||
msg = messages.TransactionType()
|
||||
if res.request_type == R.TXMETA:
|
||||
@ -410,7 +409,7 @@ def sign_tx(
|
||||
f"Unknown request type - {res.request_type}."
|
||||
)
|
||||
|
||||
res = client.call(messages.TxAck(tx=msg), expect=messages.TxRequest)
|
||||
res = session.call(messages.TxAck(tx=msg), expect=messages.TxRequest)
|
||||
|
||||
for i, sig in zip(inputs, signatures):
|
||||
if i.script_type != messages.InputScriptType.EXTERNAL and sig is None:
|
||||
@ -420,7 +419,7 @@ def sign_tx(
|
||||
|
||||
|
||||
def authorize_coinjoin(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coordinator: str,
|
||||
max_rounds: int,
|
||||
max_coordinator_fee_rate: int,
|
||||
@ -429,7 +428,7 @@ def authorize_coinjoin(
|
||||
coin_name: str,
|
||||
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
|
||||
) -> str | None:
|
||||
resp = client.call(
|
||||
resp = session.call(
|
||||
messages.AuthorizeCoinJoin(
|
||||
coordinator=coordinator,
|
||||
max_rounds=max_rounds,
|
||||
|
@ -35,7 +35,7 @@ from . import messages as m
|
||||
from . import tools
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .transport.session import Session
|
||||
|
||||
PROTOCOL_MAGICS = {
|
||||
"mainnet": 764824073,
|
||||
@ -818,7 +818,7 @@ def _get_collateral_inputs_items(
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_parameters: m.CardanoAddressParametersType,
|
||||
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
|
||||
network_id: int = NETWORK_IDS["mainnet"],
|
||||
@ -826,7 +826,7 @@ def get_address(
|
||||
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
return client.call(
|
||||
return session.call(
|
||||
m.CardanoGetAddress(
|
||||
address_parameters=address_parameters,
|
||||
protocol_magic=protocol_magic,
|
||||
@ -840,12 +840,12 @@ def get_address(
|
||||
|
||||
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: List[int],
|
||||
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
|
||||
show_display: bool = False,
|
||||
) -> m.CardanoPublicKey:
|
||||
return client.call(
|
||||
return session.call(
|
||||
m.CardanoGetPublicKey(
|
||||
address_n=address_n,
|
||||
derivation_type=derivation_type,
|
||||
@ -856,12 +856,12 @@ def get_public_key(
|
||||
|
||||
|
||||
def get_native_script_hash(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
native_script: m.CardanoNativeScript,
|
||||
display_format: m.CardanoNativeScriptHashDisplayFormat = m.CardanoNativeScriptHashDisplayFormat.HIDE,
|
||||
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
|
||||
) -> m.CardanoNativeScriptHash:
|
||||
return client.call(
|
||||
return session.call(
|
||||
m.CardanoGetNativeScriptHash(
|
||||
script=native_script,
|
||||
display_format=display_format,
|
||||
@ -872,7 +872,7 @@ def get_native_script_hash(
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
signing_mode: m.CardanoTxSigningMode,
|
||||
inputs: List[InputWithPath],
|
||||
outputs: List[OutputWithData],
|
||||
@ -907,7 +907,7 @@ def sign_tx(
|
||||
signing_mode,
|
||||
)
|
||||
|
||||
response = client.call(
|
||||
response = session.call(
|
||||
m.CardanoSignTxInit(
|
||||
signing_mode=signing_mode,
|
||||
inputs_count=len(inputs),
|
||||
@ -942,12 +942,12 @@ def sign_tx(
|
||||
_get_certificates_items(certificates),
|
||||
withdrawals,
|
||||
):
|
||||
response = client.call(tx_item, expect=m.CardanoTxItemAck)
|
||||
response = session.call(tx_item, expect=m.CardanoTxItemAck)
|
||||
|
||||
sign_tx_response: Dict[str, Any] = {}
|
||||
|
||||
if auxiliary_data is not None:
|
||||
auxiliary_data_supplement = client.call(
|
||||
auxiliary_data_supplement = session.call(
|
||||
auxiliary_data, expect=m.CardanoTxAuxiliaryDataSupplement
|
||||
)
|
||||
if (
|
||||
@ -958,25 +958,25 @@ def sign_tx(
|
||||
auxiliary_data_supplement.__dict__
|
||||
)
|
||||
|
||||
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck)
|
||||
response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck)
|
||||
|
||||
for tx_item in chain(
|
||||
_get_mint_items(mint),
|
||||
_get_collateral_inputs_items(collateral_inputs),
|
||||
required_signers,
|
||||
):
|
||||
response = client.call(tx_item, expect=m.CardanoTxItemAck)
|
||||
response = session.call(tx_item, expect=m.CardanoTxItemAck)
|
||||
|
||||
if collateral_return is not None:
|
||||
for tx_item in _get_output_items(collateral_return):
|
||||
response = client.call(tx_item, expect=m.CardanoTxItemAck)
|
||||
response = session.call(tx_item, expect=m.CardanoTxItemAck)
|
||||
|
||||
for reference_input in reference_inputs:
|
||||
response = client.call(reference_input, expect=m.CardanoTxItemAck)
|
||||
response = session.call(reference_input, expect=m.CardanoTxItemAck)
|
||||
|
||||
sign_tx_response["witnesses"] = []
|
||||
for witness_request in witness_requests:
|
||||
response = client.call(witness_request, expect=m.CardanoTxWitnessResponse)
|
||||
response = session.call(witness_request, expect=m.CardanoTxWitnessResponse)
|
||||
sign_tx_response["witnesses"].append(
|
||||
{
|
||||
"type": response.type,
|
||||
@ -986,9 +986,9 @@ def sign_tx(
|
||||
}
|
||||
)
|
||||
|
||||
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash)
|
||||
response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash)
|
||||
sign_tx_response["tx_hash"] = response.tx_hash
|
||||
|
||||
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished)
|
||||
response = session.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished)
|
||||
|
||||
return sign_tx_response
|
||||
|
@ -14,33 +14,44 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import typing as t
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
||||
|
||||
import click
|
||||
|
||||
from .. import exceptions, transport
|
||||
from ..client import TrezorClient
|
||||
from ..ui import ClickUI, ScriptUI
|
||||
from .. import exceptions, transport, ui
|
||||
from ..client import PASSPHRASE_ON_DEVICE, ProtocolVersion, TrezorClient
|
||||
from ..messages import Capability
|
||||
from ..transport import Transport
|
||||
from ..transport.session import Session, SessionV1
|
||||
|
||||
if TYPE_CHECKING:
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
_TRANSPORT: Transport | None = None
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
# Needed to enforce a return value from decorators
|
||||
# More details: https://www.python.org/dev/peps/pep-0612/
|
||||
from typing import TypeVar
|
||||
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
from ..transport import Transport
|
||||
from ..ui import TrezorClientUI
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
R = t.TypeVar("R")
|
||||
FuncWithSession = t.Callable[Concatenate[Session, P], R]
|
||||
|
||||
|
||||
class ChoiceType(click.Choice):
|
||||
def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None:
|
||||
|
||||
def __init__(
|
||||
self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True
|
||||
) -> None:
|
||||
super().__init__(list(typemap.keys()))
|
||||
self.case_sensitive = case_sensitive
|
||||
if case_sensitive:
|
||||
@ -48,7 +59,7 @@ class ChoiceType(click.Choice):
|
||||
else:
|
||||
self.typemap = {k.lower(): v for k, v in typemap.items()}
|
||||
|
||||
def convert(self, value: Any, param: Any, ctx: click.Context) -> Any:
|
||||
def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any:
|
||||
if value in self.typemap.values():
|
||||
return value
|
||||
value = super().convert(value, param, ctx)
|
||||
@ -57,11 +68,52 @@ class ChoiceType(click.Choice):
|
||||
return self.typemap[value]
|
||||
|
||||
|
||||
def get_passphrase(
|
||||
available_on_device: bool, passphrase_on_host: bool
|
||||
) -> t.Union[str, object]:
|
||||
if available_on_device and not passphrase_on_host:
|
||||
return PASSPHRASE_ON_DEVICE
|
||||
|
||||
env_passphrase = os.getenv("PASSPHRASE")
|
||||
if env_passphrase is not None:
|
||||
ui.echo("Passphrase required. Using PASSPHRASE environment variable.")
|
||||
return env_passphrase
|
||||
|
||||
while True:
|
||||
try:
|
||||
passphrase = ui.prompt(
|
||||
"Passphrase required",
|
||||
hide_input=True,
|
||||
default="",
|
||||
show_default=False,
|
||||
)
|
||||
# In case user sees the input on the screen, we do not need confirmation
|
||||
if not ui.CAN_HANDLE_HIDDEN_INPUT:
|
||||
return passphrase
|
||||
second = ui.prompt(
|
||||
"Confirm your passphrase",
|
||||
hide_input=True,
|
||||
default="",
|
||||
show_default=False,
|
||||
)
|
||||
if passphrase == second:
|
||||
return passphrase
|
||||
else:
|
||||
ui.echo("Passphrase did not match. Please try again.")
|
||||
except click.Abort:
|
||||
raise exceptions.Cancelled from None
|
||||
|
||||
|
||||
def get_client(transport: Transport) -> TrezorClient:
|
||||
return TrezorClient(transport)
|
||||
|
||||
|
||||
class TrezorConnection:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
session_id: Optional[bytes],
|
||||
session_id: bytes | None,
|
||||
passphrase_on_host: bool,
|
||||
script: bool,
|
||||
) -> None:
|
||||
@ -70,31 +122,95 @@ class TrezorConnection:
|
||||
self.passphrase_on_host = passphrase_on_host
|
||||
self.script = script
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
derive_cardano: bool = False,
|
||||
empty_passphrase: bool = False,
|
||||
must_resume: bool = False,
|
||||
) -> Session:
|
||||
client = self.get_client()
|
||||
if must_resume and self.session_id is None:
|
||||
click.echo("Failed to resume session - no session id provided")
|
||||
raise RuntimeError("Failed to resume session - no session id provided")
|
||||
|
||||
# Try resume session from id
|
||||
if self.session_id is not None:
|
||||
if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
|
||||
session = SessionV1.resume_from_id(
|
||||
client=client, session_id=self.session_id
|
||||
)
|
||||
else:
|
||||
raise Exception("Unsupported client protocol", client.protocol_version)
|
||||
if must_resume:
|
||||
if session.id != self.session_id or session.id is None:
|
||||
click.echo("Failed to resume session")
|
||||
env_var = os.environ.get("TREZOR_SESSION_ID")
|
||||
if env_var and bytes.fromhex(env_var) == self.session_id:
|
||||
click.echo(
|
||||
"Session-id stored in TREZOR_SESSION_ID is no longer valid. Call 'unset TREZOR_SESSION_ID' to clear it."
|
||||
)
|
||||
raise exceptions.FailedSessionResumption()
|
||||
return session
|
||||
|
||||
features = client.protocol.get_features()
|
||||
|
||||
passphrase_protection = features.passphrase_protection
|
||||
if passphrase_protection is None:
|
||||
raise RuntimeError("Device is locked")
|
||||
|
||||
if not passphrase_protection:
|
||||
return client.get_session(derive_cardano=derive_cardano)
|
||||
|
||||
if empty_passphrase:
|
||||
passphrase = ""
|
||||
elif self.script:
|
||||
passphrase = None
|
||||
else:
|
||||
available_on_device = Capability.PassphraseEntry in features.capabilities
|
||||
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
|
||||
session = client.get_session(
|
||||
passphrase=passphrase, derive_cardano=derive_cardano, should_derive=True
|
||||
)
|
||||
return session
|
||||
|
||||
def get_transport(self) -> "Transport":
|
||||
global _TRANSPORT
|
||||
if _TRANSPORT is not None:
|
||||
return _TRANSPORT
|
||||
|
||||
try:
|
||||
# look for transport without prefix search
|
||||
return transport.get_transport(self.path, prefix_search=False)
|
||||
_TRANSPORT = transport.get_transport(self.path, prefix_search=False)
|
||||
except Exception:
|
||||
# most likely not found. try again below.
|
||||
pass
|
||||
|
||||
# look for transport with prefix search
|
||||
# if this fails, we want the exception to bubble up to the caller
|
||||
return transport.get_transport(self.path, prefix_search=True)
|
||||
if not _TRANSPORT:
|
||||
_TRANSPORT = transport.get_transport(self.path, prefix_search=True)
|
||||
|
||||
def get_ui(self) -> "TrezorClientUI":
|
||||
if self.script:
|
||||
# It is alright to return just the class object instead of instance,
|
||||
# as the ScriptUI class object itself is the implementation of TrezorClientUI
|
||||
# (ScriptUI is just a set of staticmethods)
|
||||
return ScriptUI
|
||||
else:
|
||||
return ClickUI(passphrase_on_host=self.passphrase_on_host)
|
||||
_TRANSPORT.open()
|
||||
atexit.register(_TRANSPORT.close)
|
||||
return _TRANSPORT
|
||||
|
||||
def get_client(self) -> TrezorClient:
|
||||
transport = self.get_transport()
|
||||
ui = self.get_ui()
|
||||
return TrezorClient(transport, ui=ui, session_id=self.session_id)
|
||||
client = get_client(self.get_transport())
|
||||
if self.script:
|
||||
client.button_callback = ui.ScriptUI.button_request
|
||||
client.passphrase_callback = ui.ScriptUI.get_passphrase
|
||||
client.pin_callback = ui.ScriptUI.get_pin
|
||||
else:
|
||||
click_ui = ui.ClickUI()
|
||||
client.button_callback = click_ui.button_request
|
||||
client.passphrase_callback = click_ui.get_passphrase
|
||||
client.pin_callback = click_ui.get_pin
|
||||
return client
|
||||
|
||||
def get_seedless_session(self) -> Session:
|
||||
client = self.get_client()
|
||||
seedless_session = client.get_seedless_session()
|
||||
return seedless_session
|
||||
|
||||
@contextmanager
|
||||
def client_context(self):
|
||||
@ -127,36 +243,98 @@ class TrezorConnection:
|
||||
raise click.ClickException(str(e)) from e
|
||||
# other exceptions may cause a traceback
|
||||
|
||||
@contextmanager
|
||||
def session_context(
|
||||
self,
|
||||
empty_passphrase: bool = False,
|
||||
derive_cardano: bool = False,
|
||||
seedless: bool = False,
|
||||
must_resume: bool = False,
|
||||
):
|
||||
"""Get a session instance as a context manager. Handle errors in a manner
|
||||
appropriate for end-users.
|
||||
|
||||
def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]":
|
||||
"""Wrap a Click command in `with obj.client_context() as client`.
|
||||
Usage:
|
||||
>>> with obj.session_context() as session:
|
||||
>>> do_your_actions_here()
|
||||
"""
|
||||
try:
|
||||
if seedless:
|
||||
session = self.get_seedless_session()
|
||||
else:
|
||||
session = self.get_session(
|
||||
derive_cardano=derive_cardano,
|
||||
empty_passphrase=empty_passphrase,
|
||||
must_resume=must_resume,
|
||||
)
|
||||
except transport.DeviceIsBusy:
|
||||
click.echo("Device is in use by another process.")
|
||||
sys.exit(1)
|
||||
except exceptions.FailedSessionResumption:
|
||||
sys.exit(1)
|
||||
except Exception:
|
||||
click.echo("Failed to find a Trezor device.")
|
||||
if self.path is not None:
|
||||
click.echo(f"Using path: {self.path}")
|
||||
sys.exit(1)
|
||||
|
||||
Sessions are handled transparently. The user is warned when session did not resume
|
||||
cleanly. The session is closed after the command completes - unless the session
|
||||
was resumed, in which case it should remain open.
|
||||
try:
|
||||
yield session
|
||||
except exceptions.Cancelled:
|
||||
# handle cancel action
|
||||
click.echo("Action was cancelled.")
|
||||
sys.exit(1)
|
||||
except exceptions.TrezorException as e:
|
||||
# handle any Trezor-sent exceptions as user-readable
|
||||
raise click.ClickException(str(e)) from e
|
||||
# other exceptions may cause a traceback
|
||||
|
||||
|
||||
def with_session(
|
||||
func: "t.Callable[Concatenate[Session, P], R]|None" = None,
|
||||
*,
|
||||
empty_passphrase: bool = False,
|
||||
derive_cardano: bool = False,
|
||||
seedless: bool = False,
|
||||
must_resume: bool = False,
|
||||
) -> t.Callable[[FuncWithSession], t.Callable[P, R]]:
|
||||
"""Provides a Click command with parameter `session=obj.get_session(...)`
|
||||
based on the parameters provided.
|
||||
|
||||
If default parameters are ok, this decorator can be used without parentheses.
|
||||
"""
|
||||
|
||||
@click.pass_obj
|
||||
@functools.wraps(func)
|
||||
def trezorctl_command_with_client(
|
||||
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||
) -> "R":
|
||||
with obj.client_context() as client:
|
||||
session_was_resumed = obj.session_id == client.session_id
|
||||
if not session_was_resumed and obj.session_id is not None:
|
||||
# tried to resume but failed
|
||||
click.echo("Warning: failed to resume session.", err=True)
|
||||
def decorator(
|
||||
func: FuncWithSession,
|
||||
) -> "t.Callable[P, R]":
|
||||
|
||||
try:
|
||||
return func(client, *args, **kwargs)
|
||||
finally:
|
||||
if not session_was_resumed:
|
||||
try:
|
||||
client.end_session()
|
||||
except Exception:
|
||||
pass
|
||||
@click.pass_obj
|
||||
@functools.wraps(func)
|
||||
def function_with_session(
|
||||
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
|
||||
) -> "R":
|
||||
is_resume_mandatory = must_resume or obj.session_id is not None
|
||||
|
||||
return trezorctl_command_with_client
|
||||
with obj.session_context(
|
||||
empty_passphrase=empty_passphrase,
|
||||
derive_cardano=derive_cardano,
|
||||
seedless=seedless,
|
||||
must_resume=is_resume_mandatory,
|
||||
) as session:
|
||||
try:
|
||||
return func(session, *args, **kwargs)
|
||||
|
||||
finally:
|
||||
if not is_resume_mandatory:
|
||||
session.end()
|
||||
|
||||
return function_with_session
|
||||
|
||||
# If the decorator @get_session is used without parentheses
|
||||
if func and callable(func):
|
||||
return decorator(func) # type: ignore [Function return type]
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class AliasedGroup(click.Group):
|
||||
@ -188,14 +366,14 @@ class AliasedGroup(click.Group):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
aliases: Optional[Dict[str, click.Command]] = None,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
aliases: t.Dict[str, click.Command] | None = None,
|
||||
*args: t.Any,
|
||||
**kwargs: t.Any,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.aliases = aliases or {}
|
||||
|
||||
def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]:
|
||||
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
|
||||
cmd_name = cmd_name.replace("_", "-")
|
||||
# try to look up the real name
|
||||
cmd = super().get_command(ctx, cmd_name)
|
||||
|
@ -20,17 +20,15 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
import click
|
||||
|
||||
from .. import benchmark
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
|
||||
def list_names_patern(
|
||||
client: "TrezorClient", pattern: Optional[str] = None
|
||||
) -> List[str]:
|
||||
names = list(benchmark.list_names(client).names)
|
||||
def list_names_patern(session: "Session", pattern: Optional[str] = None) -> List[str]:
|
||||
names = list(benchmark.list_names(session).names)
|
||||
if pattern is None:
|
||||
return names
|
||||
return [name for name in names if fnmatch(name, pattern)]
|
||||
@ -43,10 +41,10 @@ def cli() -> None:
|
||||
|
||||
@cli.command()
|
||||
@click.argument("pattern", required=False)
|
||||
@with_client
|
||||
def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
|
||||
@with_session(empty_passphrase=True)
|
||||
def list_names(session: "Session", pattern: Optional[str] = None) -> None:
|
||||
"""List names of all supported benchmarks"""
|
||||
names = list_names_patern(client, pattern)
|
||||
names = list_names_patern(session, pattern)
|
||||
if len(names) == 0:
|
||||
click.echo("No benchmark satisfies the pattern.")
|
||||
else:
|
||||
@ -56,13 +54,13 @@ def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
|
||||
|
||||
@cli.command()
|
||||
@click.argument("pattern", required=False)
|
||||
@with_client
|
||||
def run(client: "TrezorClient", pattern: Optional[str]) -> None:
|
||||
@with_session(empty_passphrase=True)
|
||||
def run(session: "Session", pattern: Optional[str]) -> None:
|
||||
"""Run benchmark"""
|
||||
names = list_names_patern(client, pattern)
|
||||
names = list_names_patern(session, pattern)
|
||||
if len(names) == 0:
|
||||
click.echo("No benchmark satisfies the pattern.")
|
||||
else:
|
||||
for name in names:
|
||||
result = benchmark.run(client, name)
|
||||
result = benchmark.run(session, name)
|
||||
click.echo(f"{name}: {result.value} {result.unit}")
|
||||
|
@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
|
||||
import click
|
||||
|
||||
from .. import binance, tools
|
||||
from . import with_client
|
||||
from ..transport.session import Session
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .. import messages
|
||||
from ..client import TrezorClient
|
||||
|
||||
|
||||
PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0"
|
||||
@ -39,23 +39,23 @@ def cli() -> None:
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
||||
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||
) -> str:
|
||||
"""Get Binance address for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
return binance.get_address(client, address_n, show_display, chunkify)
|
||||
return binance.get_address(session, address_n, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||
@with_session
|
||||
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
|
||||
"""Get Binance public key."""
|
||||
address_n = tools.parse_path(address)
|
||||
return binance.get_public_key(client, address_n, show_display).hex()
|
||||
return binance.get_public_key(session, address_n, show_display).hex()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_tx(
|
||||
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
|
||||
session: "Session", address: str, file: TextIO, chunkify: bool
|
||||
) -> "messages.BinanceSignedTx":
|
||||
"""Sign Binance transaction.
|
||||
|
||||
Transaction must be provided as a JSON file.
|
||||
"""
|
||||
address_n = tools.parse_path(address)
|
||||
return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
|
||||
return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify)
|
||||
|
@ -13,6 +13,7 @@
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
@ -22,10 +23,10 @@ import click
|
||||
import construct as c
|
||||
|
||||
from .. import btc, messages, protobuf, tools
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PURPOSE_BIP44 = 44
|
||||
PURPOSE_BIP48 = 48
|
||||
@ -174,15 +175,15 @@ def cli() -> None:
|
||||
help="Sort pubkeys lexicographically using BIP-67",
|
||||
)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: str,
|
||||
address: str,
|
||||
script_type: Optional[messages.InputScriptType],
|
||||
script_type: messages.InputScriptType | None,
|
||||
show_display: bool,
|
||||
multisig_xpub: List[str],
|
||||
multisig_threshold: Optional[int],
|
||||
multisig_threshold: int | None,
|
||||
multisig_suffix_length: int,
|
||||
multisig_sort_pubkeys: bool,
|
||||
chunkify: bool,
|
||||
@ -235,7 +236,7 @@ def get_address(
|
||||
multisig = None
|
||||
|
||||
return btc.get_address(
|
||||
client,
|
||||
session,
|
||||
coin,
|
||||
address_n,
|
||||
show_display,
|
||||
@ -252,9 +253,9 @@ def get_address(
|
||||
@click.option("-e", "--curve")
|
||||
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_public_node(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: str,
|
||||
address: str,
|
||||
curve: Optional[str],
|
||||
@ -266,7 +267,7 @@ def get_public_node(
|
||||
if script_type is None:
|
||||
script_type = guess_script_type_from_path(address_n)
|
||||
result = btc.get_public_node(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
ecdsa_curve_name=curve,
|
||||
show_display=show_display,
|
||||
@ -292,7 +293,7 @@ def _append_descriptor_checksum(desc: str) -> str:
|
||||
|
||||
|
||||
def _get_descriptor(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: Optional[str],
|
||||
account: int,
|
||||
purpose: Optional[int],
|
||||
@ -326,7 +327,7 @@ def _get_descriptor(
|
||||
|
||||
n = tools.parse_path(path)
|
||||
pub = btc.get_public_node(
|
||||
client,
|
||||
session,
|
||||
n,
|
||||
show_display=show_display,
|
||||
coin_name=coin,
|
||||
@ -363,9 +364,9 @@ def _get_descriptor(
|
||||
@click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE))
|
||||
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_descriptor(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: Optional[str],
|
||||
account: int,
|
||||
account_type: Optional[int],
|
||||
@ -375,7 +376,7 @@ def get_descriptor(
|
||||
"""Get descriptor of given account."""
|
||||
try:
|
||||
return _get_descriptor(
|
||||
client, coin, account, account_type, script_type, show_display
|
||||
session, coin, account, account_type, script_type, show_display
|
||||
)
|
||||
except ValueError as e:
|
||||
raise click.ClickException(str(e))
|
||||
@ -390,8 +391,8 @@ def get_descriptor(
|
||||
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@click.argument("json_file", type=click.File())
|
||||
@with_client
|
||||
def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
|
||||
@with_session
|
||||
def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None:
|
||||
"""Sign transaction.
|
||||
|
||||
Transaction data must be provided in a JSON file. See `transaction-format.md` for
|
||||
@ -416,7 +417,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
|
||||
}
|
||||
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client,
|
||||
session,
|
||||
coin,
|
||||
inputs,
|
||||
outputs,
|
||||
@ -447,9 +448,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
|
||||
)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@click.argument("message")
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: str,
|
||||
address: str,
|
||||
message: str,
|
||||
@ -462,7 +463,7 @@ def sign_message(
|
||||
if script_type is None:
|
||||
script_type = guess_script_type_from_path(address_n)
|
||||
res = btc.sign_message(
|
||||
client,
|
||||
session,
|
||||
coin,
|
||||
address_n,
|
||||
message,
|
||||
@ -483,9 +484,9 @@ def sign_message(
|
||||
@click.argument("address")
|
||||
@click.argument("signature")
|
||||
@click.argument("message")
|
||||
@with_client
|
||||
@with_session
|
||||
def verify_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
coin: str,
|
||||
address: str,
|
||||
signature: str,
|
||||
@ -495,7 +496,7 @@ def verify_message(
|
||||
"""Verify message."""
|
||||
signature_bytes = base64.b64decode(signature)
|
||||
return btc.verify_message(
|
||||
client, coin, address, signature_bytes, message, chunkify=chunkify
|
||||
session, coin, address, signature_bytes, message, chunkify=chunkify
|
||||
)
|
||||
|
||||
|
||||
|
@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
|
||||
import click
|
||||
|
||||
from .. import cardano, messages, tools
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0"
|
||||
|
||||
@ -62,9 +62,9 @@ def cli() -> None:
|
||||
@click.option("-i", "--include-network-id", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@click.option("-T", "--tag-cbor-sets", is_flag=True)
|
||||
@with_client
|
||||
@with_session(derive_cardano=True)
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
file: TextIO,
|
||||
signing_mode: messages.CardanoTxSigningMode,
|
||||
protocol_magic: int,
|
||||
@ -123,9 +123,8 @@ def sign_tx(
|
||||
for p in transaction["additional_witness_requests"]
|
||||
]
|
||||
|
||||
client.init_device(derive_cardano=True)
|
||||
sign_tx_response = cardano.sign_tx(
|
||||
client,
|
||||
session,
|
||||
signing_mode,
|
||||
inputs,
|
||||
outputs,
|
||||
@ -209,9 +208,9 @@ def sign_tx(
|
||||
default=messages.CardanoDerivationType.ICARUS,
|
||||
)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session(derive_cardano=True)
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
address_type: messages.CardanoAddressType,
|
||||
staking_address: str,
|
||||
@ -262,9 +261,8 @@ def get_address(
|
||||
script_staking_hash_bytes,
|
||||
)
|
||||
|
||||
client.init_device(derive_cardano=True)
|
||||
return cardano.get_address(
|
||||
client,
|
||||
session,
|
||||
address_parameters,
|
||||
protocol_magic,
|
||||
network_id,
|
||||
@ -283,18 +281,17 @@ def get_address(
|
||||
default=messages.CardanoDerivationType.ICARUS,
|
||||
)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
@with_session(derive_cardano=True)
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
derivation_type: messages.CardanoDerivationType,
|
||||
show_display: bool,
|
||||
) -> messages.CardanoPublicKey:
|
||||
"""Get Cardano public key."""
|
||||
address_n = tools.parse_path(address)
|
||||
client.init_device(derive_cardano=True)
|
||||
return cardano.get_public_key(
|
||||
client, address_n, derivation_type=derivation_type, show_display=show_display
|
||||
session, address_n, derivation_type=derivation_type, show_display=show_display
|
||||
)
|
||||
|
||||
|
||||
@ -312,9 +309,9 @@ def get_public_key(
|
||||
type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}),
|
||||
default=messages.CardanoDerivationType.ICARUS,
|
||||
)
|
||||
@with_client
|
||||
@with_session(derive_cardano=True)
|
||||
def get_native_script_hash(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
file: TextIO,
|
||||
display_format: messages.CardanoNativeScriptHashDisplayFormat,
|
||||
derivation_type: messages.CardanoDerivationType,
|
||||
@ -323,7 +320,6 @@ def get_native_script_hash(
|
||||
native_script_json = json.load(file)
|
||||
native_script = cardano.parse_native_script(native_script_json)
|
||||
|
||||
client.init_device(derive_cardano=True)
|
||||
return cardano.get_native_script_hash(
|
||||
client, native_script, display_format, derivation_type=derivation_type
|
||||
session, native_script, display_format, derivation_type=derivation_type
|
||||
)
|
||||
|
@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Tuple
|
||||
import click
|
||||
|
||||
from .. import misc, tools
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
|
||||
PROMPT_TYPE = ChoiceType(
|
||||
@ -42,10 +42,10 @@ def cli() -> None:
|
||||
|
||||
@cli.command()
|
||||
@click.argument("size", type=int)
|
||||
@with_client
|
||||
def get_entropy(client: "TrezorClient", size: int) -> str:
|
||||
@with_session(empty_passphrase=True)
|
||||
def get_entropy(session: "Session", size: int) -> str:
|
||||
"""Get random bytes from device."""
|
||||
return misc.get_entropy(client, size).hex()
|
||||
return misc.get_entropy(session, size).hex()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str:
|
||||
)
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
@with_client
|
||||
@with_session(empty_passphrase=True)
|
||||
def encrypt_keyvalue(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
key: str,
|
||||
value: str,
|
||||
@ -75,7 +75,7 @@ def encrypt_keyvalue(
|
||||
ask_on_encrypt, ask_on_decrypt = prompt
|
||||
address_n = tools.parse_path(address)
|
||||
return misc.encrypt_keyvalue(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
key,
|
||||
value.encode(),
|
||||
@ -91,9 +91,9 @@ def encrypt_keyvalue(
|
||||
)
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
@with_client
|
||||
@with_session(empty_passphrase=True)
|
||||
def decrypt_keyvalue(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
key: str,
|
||||
value: str,
|
||||
@ -112,7 +112,7 @@ def decrypt_keyvalue(
|
||||
ask_on_encrypt, ask_on_decrypt = prompt
|
||||
address_n = tools.parse_path(address)
|
||||
return misc.decrypt_keyvalue(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
key,
|
||||
bytes.fromhex(value),
|
||||
|
@ -18,13 +18,12 @@ from typing import TYPE_CHECKING, Union
|
||||
|
||||
import click
|
||||
|
||||
from .. import mapping, messages, protobuf
|
||||
from ..client import TrezorClient
|
||||
from ..debuglink import TrezorClientDebugLink
|
||||
from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max
|
||||
from ..debuglink import prodtest_t1 as debuglink_prodtest_t1
|
||||
from ..debuglink import record_screen
|
||||
from . import with_client
|
||||
from ..transport.session import Session
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import TrezorConnection
|
||||
@ -35,53 +34,6 @@ def cli() -> None:
|
||||
"""Miscellaneous debug features."""
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("message_name_or_type")
|
||||
@click.argument("hex_data")
|
||||
@click.pass_obj
|
||||
def send_bytes(
|
||||
obj: "TrezorConnection", message_name_or_type: str, hex_data: str
|
||||
) -> None:
|
||||
"""Send raw bytes to Trezor.
|
||||
|
||||
Message type and message data must be specified separately, due to how message
|
||||
chunking works on the transport level. Message length is calculated and sent
|
||||
automatically, and it is currently impossible to explicitly specify invalid length.
|
||||
|
||||
MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum,
|
||||
in which case the value of that enum is used.
|
||||
"""
|
||||
if message_name_or_type.isdigit():
|
||||
message_type = int(message_name_or_type)
|
||||
else:
|
||||
message_type = getattr(messages.MessageType, message_name_or_type)
|
||||
|
||||
if not isinstance(message_type, int):
|
||||
raise click.ClickException("Invalid message type.")
|
||||
|
||||
try:
|
||||
message_data = bytes.fromhex(hex_data)
|
||||
except Exception as e:
|
||||
raise click.ClickException("Invalid hex data.") from e
|
||||
|
||||
transport = obj.get_transport()
|
||||
transport.begin_session()
|
||||
transport.write(message_type, message_data)
|
||||
|
||||
response_type, response_data = transport.read()
|
||||
transport.end_session()
|
||||
|
||||
click.echo(f"Response type: {response_type}")
|
||||
click.echo(f"Response data: {response_data.hex()}")
|
||||
|
||||
try:
|
||||
msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
|
||||
click.echo("Parsed message:")
|
||||
click.echo(protobuf.format_message(msg))
|
||||
except Exception as e:
|
||||
click.echo(f"Could not parse response: {e}")
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("directory", required=False)
|
||||
@click.option("-s", "--stop", is_flag=True, help="Stop the recording")
|
||||
@ -100,23 +52,22 @@ def record_screen_from_connection(
|
||||
"""Record screen helper to transform TrezorConnection into TrezorClientDebugLink."""
|
||||
transport = obj.get_transport()
|
||||
debug_client = TrezorClientDebugLink(transport, auto_interact=False)
|
||||
debug_client.open()
|
||||
record_screen(debug_client, directory, report_func=click.echo)
|
||||
debug_client.close()
|
||||
debug_client.close_transport()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def prodtest_t1(client: "TrezorClient") -> None:
|
||||
@with_session(seedless=True)
|
||||
def prodtest_t1(session: "Session") -> None:
|
||||
"""Perform a prodtest on Model One.
|
||||
|
||||
Only available on PRODTEST firmware and on T1B1. Formerly named self-test.
|
||||
"""
|
||||
debuglink_prodtest_t1(client)
|
||||
debuglink_prodtest_t1(session)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def optiga_set_sec_max(client: "TrezorClient") -> None:
|
||||
@with_session(seedless=True)
|
||||
def optiga_set_sec_max(session: "Session") -> None:
|
||||
"""Set Optiga's security event counter to maximum."""
|
||||
debuglink_optiga_set_sec_max(client)
|
||||
debuglink_optiga_set_sec_max(session)
|
||||
|
@ -25,10 +25,10 @@ import requests
|
||||
|
||||
from .. import authentication, debuglink, device, exceptions, messages, ui
|
||||
from ..tools import format_path
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_session
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
from . import TrezorConnection
|
||||
|
||||
RECOVERY_DEVICE_INPUT_METHOD = {
|
||||
@ -64,17 +64,18 @@ def cli() -> None:
|
||||
help="Wipe device in bootloader mode. This also erases the firmware.",
|
||||
is_flag=True,
|
||||
)
|
||||
@with_client
|
||||
def wipe(client: "TrezorClient", bootloader: bool) -> None:
|
||||
@with_session(seedless=True)
|
||||
def wipe(session: "Session", bootloader: bool) -> None:
|
||||
"""Reset device to factory defaults and remove all private data."""
|
||||
features = session.features
|
||||
if bootloader:
|
||||
if not client.features.bootloader_mode:
|
||||
if not features.bootloader_mode:
|
||||
click.echo("Please switch your device to bootloader mode.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
click.echo("Wiping user data and firmware!")
|
||||
else:
|
||||
if client.features.bootloader_mode:
|
||||
if features.bootloader_mode:
|
||||
click.echo(
|
||||
"Your device is in bootloader mode. This operation would also erase firmware."
|
||||
)
|
||||
@ -86,7 +87,11 @@ def wipe(client: "TrezorClient", bootloader: bool) -> None:
|
||||
else:
|
||||
click.echo("Wiping user data!")
|
||||
|
||||
device.wipe(client)
|
||||
try:
|
||||
device.wipe(session)
|
||||
except exceptions.TrezorFailure as e:
|
||||
click.echo("Action failed: {} {}".format(*e.args))
|
||||
sys.exit(3)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -99,9 +104,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> None:
|
||||
@click.option("-a", "--academic", is_flag=True)
|
||||
@click.option("-b", "--needs-backup", is_flag=True)
|
||||
@click.option("-n", "--no-backup", is_flag=True)
|
||||
@with_client
|
||||
@with_session(seedless=True)
|
||||
def load(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
mnemonic: t.Sequence[str],
|
||||
pin: str,
|
||||
passphrase_protection: bool,
|
||||
@ -132,7 +137,7 @@ def load(
|
||||
|
||||
try:
|
||||
debuglink.load_device(
|
||||
client,
|
||||
session,
|
||||
mnemonic=list(mnemonic),
|
||||
pin=pin,
|
||||
passphrase_protection=passphrase_protection,
|
||||
@ -167,9 +172,9 @@ def load(
|
||||
)
|
||||
@click.option("-d", "--dry-run", is_flag=True)
|
||||
@click.option("-b", "--unlock-repeated-backup", is_flag=True)
|
||||
@with_client
|
||||
@with_session(seedless=True)
|
||||
def recover(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
words: str,
|
||||
expand: bool,
|
||||
pin_protection: bool,
|
||||
@ -197,7 +202,7 @@ def recover(
|
||||
type = messages.RecoveryType.UnlockRepeatedBackup
|
||||
|
||||
device.recover(
|
||||
client,
|
||||
session,
|
||||
word_count=int(words),
|
||||
passphrase_protection=passphrase_protection,
|
||||
pin_protection=pin_protection,
|
||||
@ -219,9 +224,9 @@ def recover(
|
||||
@click.option("-n", "--no-backup", is_flag=True)
|
||||
@click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE))
|
||||
@click.option("-e", "--entropy-check-count", type=click.IntRange(0))
|
||||
@with_client
|
||||
@with_session(seedless=True)
|
||||
def setup(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
strength: int | None,
|
||||
passphrase_protection: bool,
|
||||
pin_protection: bool,
|
||||
@ -241,10 +246,10 @@ def setup(
|
||||
if (
|
||||
backup_type
|
||||
in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable)
|
||||
and messages.Capability.Shamir not in client.features.capabilities
|
||||
and messages.Capability.Shamir not in session.features.capabilities
|
||||
) or (
|
||||
backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable)
|
||||
and messages.Capability.ShamirGroups not in client.features.capabilities
|
||||
and messages.Capability.ShamirGroups not in session.features.capabilities
|
||||
):
|
||||
click.echo(
|
||||
"WARNING: Your Trezor device does not indicate support for the requested\n"
|
||||
@ -252,7 +257,7 @@ def setup(
|
||||
)
|
||||
|
||||
path_xpubs = device.setup(
|
||||
client,
|
||||
session,
|
||||
strength=strength,
|
||||
passphrase_protection=passphrase_protection,
|
||||
pin_protection=pin_protection,
|
||||
@ -273,22 +278,21 @@ def setup(
|
||||
@cli.command()
|
||||
@click.option("-t", "--group-threshold", type=int)
|
||||
@click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N")
|
||||
@with_client
|
||||
@with_session(seedless=True)
|
||||
def backup(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
group_threshold: int | None = None,
|
||||
groups: t.Sequence[tuple[int, int]] = (),
|
||||
) -> None:
|
||||
"""Perform device seed backup."""
|
||||
device.backup(client, group_threshold, groups)
|
||||
|
||||
device.backup(session, group_threshold, groups)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS))
|
||||
@with_client
|
||||
def sd_protect(
|
||||
client: "TrezorClient", operation: messages.SdProtectOperationType
|
||||
) -> None:
|
||||
@with_session(seedless=True)
|
||||
def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> None:
|
||||
"""Secure the device with SD card protection.
|
||||
|
||||
When SD card protection is enabled, a randomly generated secret is stored
|
||||
@ -302,33 +306,33 @@ def sd_protect(
|
||||
off - Remove SD card secret protection.
|
||||
refresh - Replace the current SD card secret with a new one.
|
||||
"""
|
||||
if client.features.model == "1":
|
||||
if session.features.model == "1":
|
||||
raise click.ClickException("Trezor One does not support SD card protection.")
|
||||
device.sd_protect(client, operation)
|
||||
device.sd_protect(session, operation)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.pass_obj
|
||||
def reboot_to_bootloader(obj: "TrezorConnection") -> None:
|
||||
"""Reboot device into bootloader mode."""
|
||||
# avoid using @with_client because it closes the session afterwards,
|
||||
# avoid using @with_session because it closes the session afterwards,
|
||||
# which triggers double prompt on device
|
||||
with obj.client_context() as client:
|
||||
device.reboot_to_bootloader(client)
|
||||
device.reboot_to_bootloader(client.get_seedless_session())
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def tutorial(client: "TrezorClient") -> None:
|
||||
@with_session(seedless=True)
|
||||
def tutorial(session: "Session") -> None:
|
||||
"""Show on-device tutorial."""
|
||||
device.show_device_tutorial(client)
|
||||
device.show_device_tutorial(session)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def unlock_bootloader(client: "TrezorClient") -> None:
|
||||
@with_session(seedless=True)
|
||||
def unlock_bootloader(session: "Session") -> None:
|
||||
"""Unlocks bootloader. Irreversible."""
|
||||
device.unlock_bootloader(client)
|
||||
device.unlock_bootloader(session)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -339,12 +343,11 @@ def unlock_bootloader(client: "TrezorClient") -> None:
|
||||
type=int,
|
||||
help="Dialog expiry in seconds.",
|
||||
)
|
||||
@with_client
|
||||
def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> None:
|
||||
@with_session(seedless=True)
|
||||
def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> None:
|
||||
"""Show a "Do not disconnect" dialog."""
|
||||
if enable is False:
|
||||
device.set_busy(client, None)
|
||||
return
|
||||
device.set_busy(session, None)
|
||||
|
||||
if expiry is None:
|
||||
raise click.ClickException("Missing option '-e' / '--expiry'.")
|
||||
@ -354,7 +357,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) ->
|
||||
f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer."
|
||||
)
|
||||
|
||||
device.set_busy(client, expiry * 1000)
|
||||
device.set_busy(session, expiry * 1000)
|
||||
|
||||
|
||||
PUBKEY_WHITELIST_URL_TEMPLATE = (
|
||||
@ -374,9 +377,9 @@ PUBKEY_WHITELIST_URL_TEMPLATE = (
|
||||
is_flag=True,
|
||||
help="Do not check intermediate certificates against the whitelist.",
|
||||
)
|
||||
@with_client
|
||||
@with_session(seedless=True)
|
||||
def authenticate(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
hex_challenge: str | None,
|
||||
root: t.BinaryIO | None,
|
||||
raw: bool | None,
|
||||
@ -397,7 +400,7 @@ def authenticate(
|
||||
challenge = bytes.fromhex(hex_challenge)
|
||||
|
||||
if raw:
|
||||
msg = device.authenticate(client, challenge)
|
||||
msg = device.authenticate(session, challenge)
|
||||
|
||||
click.echo(f"Challenge: {hex_challenge}")
|
||||
click.echo(f"Signature of challenge: {msg.signature.hex()}")
|
||||
@ -436,14 +439,14 @@ def authenticate(
|
||||
else:
|
||||
whitelist_json = requests.get(
|
||||
PUBKEY_WHITELIST_URL_TEMPLATE.format(
|
||||
model=client.model.internal_name.lower()
|
||||
model=session.model.internal_name.lower()
|
||||
)
|
||||
).json()
|
||||
whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]]
|
||||
|
||||
try:
|
||||
authentication.authenticate_device(
|
||||
client, challenge, root_pubkey=root_bytes, whitelist=whitelist
|
||||
session, challenge, root_pubkey=root_bytes, whitelist=whitelist
|
||||
)
|
||||
except authentication.DeviceNotAuthentic:
|
||||
click.echo("Device is not authentic.")
|
||||
|
@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO
|
||||
import click
|
||||
|
||||
from .. import eos, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .. import messages
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0"
|
||||
|
||||
@ -37,11 +37,11 @@ def cli() -> None:
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||
@with_session
|
||||
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
|
||||
"""Get Eos public key in base58 encoding."""
|
||||
address_n = tools.parse_path(address)
|
||||
res = eos.get_public_key(client, address_n, show_display)
|
||||
res = eos.get_public_key(session, address_n, show_display)
|
||||
return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}"
|
||||
|
||||
|
||||
@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_transaction(
|
||||
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
|
||||
session: "Session", address: str, file: TextIO, chunkify: bool
|
||||
) -> "messages.EosSignedTx":
|
||||
"""Sign EOS transaction."""
|
||||
tx_json = json.load(file)
|
||||
|
||||
address_n = tools.parse_path(address)
|
||||
return eos.sign_tx(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
tx_json["transaction"],
|
||||
tx_json["chain_id"],
|
||||
|
@ -26,14 +26,14 @@ import click
|
||||
|
||||
from .. import _rlp, definitions, ethereum, tools
|
||||
from ..messages import EthereumDefinitions
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import web3
|
||||
from eth_typing import ChecksumAddress # noqa: I900
|
||||
from web3.types import Wei
|
||||
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0"
|
||||
|
||||
@ -268,24 +268,24 @@ def cli(
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
||||
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||
) -> str:
|
||||
"""Get Ethereum address in hex encoding."""
|
||||
address_n = tools.parse_path(address)
|
||||
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
|
||||
return ethereum.get_address(client, address_n, show_display, network, chunkify)
|
||||
return ethereum.get_address(session, address_n, show_display, network, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict:
|
||||
@with_session
|
||||
def get_public_node(session: "Session", address: str, show_display: bool) -> dict:
|
||||
"""Get Ethereum public node of given path."""
|
||||
address_n = tools.parse_path(address)
|
||||
result = ethereum.get_public_node(client, address_n, show_display=show_display)
|
||||
result = ethereum.get_public_node(session, address_n, show_display=show_display)
|
||||
return {
|
||||
"node": {
|
||||
"depth": result.node.depth,
|
||||
@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) ->
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@click.argument("to_address")
|
||||
@click.argument("amount", callback=_amount_to_int)
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
chain_id: int,
|
||||
address: str,
|
||||
amount: int,
|
||||
@ -400,7 +400,7 @@ def sign_tx(
|
||||
encoded_network = DEFINITIONS_SOURCE.get_network(chain_id)
|
||||
address_n = tools.parse_path(address)
|
||||
from_address = ethereum.get_address(
|
||||
client, address_n, encoded_network=encoded_network
|
||||
session, address_n, encoded_network=encoded_network
|
||||
)
|
||||
|
||||
if token:
|
||||
@ -446,7 +446,7 @@ def sign_tx(
|
||||
assert max_gas_fee is not None
|
||||
assert max_priority_fee is not None
|
||||
sig = ethereum.sign_tx_eip1559(
|
||||
client,
|
||||
session,
|
||||
n=address_n,
|
||||
nonce=nonce,
|
||||
gas_limit=gas_limit,
|
||||
@ -465,7 +465,7 @@ def sign_tx(
|
||||
gas_price = _get_web3().eth.gas_price
|
||||
assert gas_price is not None
|
||||
sig = ethereum.sign_tx(
|
||||
client,
|
||||
session,
|
||||
n=address_n,
|
||||
tx_type=tx_type,
|
||||
nonce=nonce,
|
||||
@ -526,14 +526,14 @@ def sign_tx(
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@click.argument("message")
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_message(
|
||||
client: "TrezorClient", address: str, message: str, chunkify: bool
|
||||
session: "Session", address: str, message: str, chunkify: bool
|
||||
) -> Dict[str, str]:
|
||||
"""Sign message with Ethereum address."""
|
||||
address_n = tools.parse_path(address)
|
||||
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
|
||||
ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify)
|
||||
ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify)
|
||||
output = {
|
||||
"message": message,
|
||||
"address": ret.address,
|
||||
@ -550,9 +550,9 @@ def sign_message(
|
||||
help="Be compatible with Metamask's signTypedData_v4 implementation",
|
||||
)
|
||||
@click.argument("file", type=click.File("r"))
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_typed_data(
|
||||
client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO
|
||||
session: "Session", address: str, metamask_v4_compat: bool, file: TextIO
|
||||
) -> Dict[str, str]:
|
||||
"""Sign typed data (EIP-712) with Ethereum address.
|
||||
|
||||
@ -565,7 +565,7 @@ def sign_typed_data(
|
||||
defs = EthereumDefinitions(encoded_network=network)
|
||||
data = json.loads(file.read())
|
||||
ret = ethereum.sign_typed_data(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
data,
|
||||
metamask_v4_compat=metamask_v4_compat,
|
||||
@ -583,9 +583,9 @@ def sign_typed_data(
|
||||
@click.argument("address")
|
||||
@click.argument("signature")
|
||||
@click.argument("message")
|
||||
@with_client
|
||||
@with_session
|
||||
def verify_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
signature: str,
|
||||
message: str,
|
||||
@ -594,7 +594,7 @@ def verify_message(
|
||||
"""Verify message signed with Ethereum address."""
|
||||
signature_bytes = ethereum.decode_hex(signature)
|
||||
return ethereum.verify_message(
|
||||
client, address, signature_bytes, message, chunkify=chunkify
|
||||
session, address, signature_bytes, message, chunkify=chunkify
|
||||
)
|
||||
|
||||
|
||||
@ -602,9 +602,9 @@ def verify_message(
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.argument("domain_hash_hex")
|
||||
@click.argument("message_hash_hex")
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_typed_data_hash(
|
||||
client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str
|
||||
session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Sign hash of typed data (EIP-712) with Ethereum address.
|
||||
@ -618,7 +618,7 @@ def sign_typed_data_hash(
|
||||
message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None
|
||||
network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE)
|
||||
ret = ethereum.sign_typed_data_hash(
|
||||
client, address_n, domain_hash, message_hash, network
|
||||
session, address_n, domain_hash, message_hash, network
|
||||
)
|
||||
output = {
|
||||
"domain_hash": domain_hash_hex,
|
||||
|
@ -19,10 +19,10 @@ from typing import TYPE_CHECKING
|
||||
import click
|
||||
|
||||
from .. import fido
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"}
|
||||
|
||||
@ -40,10 +40,10 @@ def credentials() -> None:
|
||||
|
||||
|
||||
@credentials.command(name="list")
|
||||
@with_client
|
||||
def credentials_list(client: "TrezorClient") -> None:
|
||||
@with_session(empty_passphrase=True)
|
||||
def credentials_list(session: "Session") -> None:
|
||||
"""List all resident credentials on the device."""
|
||||
creds = fido.list_credentials(client)
|
||||
creds = fido.list_credentials(session)
|
||||
for cred in creds:
|
||||
click.echo("")
|
||||
click.echo(f"WebAuthn credential at index {cred.index}:")
|
||||
@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None:
|
||||
|
||||
@credentials.command(name="add")
|
||||
@click.argument("hex_credential_id")
|
||||
@with_client
|
||||
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> None:
|
||||
@with_session(empty_passphrase=True)
|
||||
def credentials_add(session: "Session", hex_credential_id: str) -> None:
|
||||
"""Add the credential with the given ID as a resident credential.
|
||||
|
||||
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
|
||||
"""
|
||||
fido.add_credential(client, bytes.fromhex(hex_credential_id))
|
||||
fido.add_credential(session, bytes.fromhex(hex_credential_id))
|
||||
|
||||
|
||||
@credentials.command(name="remove")
|
||||
@click.option(
|
||||
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
|
||||
)
|
||||
@with_client
|
||||
def credentials_remove(client: "TrezorClient", index: int) -> None:
|
||||
@with_session(empty_passphrase=True)
|
||||
def credentials_remove(session: "Session", index: int) -> None:
|
||||
"""Remove the resident credential at the given index."""
|
||||
fido.remove_credential(client, index)
|
||||
fido.remove_credential(session, index)
|
||||
|
||||
|
||||
#
|
||||
@ -110,19 +110,19 @@ def counter() -> None:
|
||||
|
||||
@counter.command(name="set")
|
||||
@click.argument("counter", type=int)
|
||||
@with_client
|
||||
def counter_set(client: "TrezorClient", counter: int) -> None:
|
||||
@with_session(empty_passphrase=True)
|
||||
def counter_set(session: "Session", counter: int) -> None:
|
||||
"""Set FIDO/U2F counter value."""
|
||||
fido.set_counter(client, counter)
|
||||
fido.set_counter(session, counter)
|
||||
|
||||
|
||||
@counter.command(name="get-next")
|
||||
@with_client
|
||||
def counter_get_next(client: "TrezorClient") -> int:
|
||||
@with_session(empty_passphrase=True)
|
||||
def counter_get_next(session: "Session") -> int:
|
||||
"""Get-and-increase value of FIDO/U2F counter.
|
||||
|
||||
FIDO counter value cannot be read directly. On each U2F exchange, the counter value
|
||||
is returned and atomically increased. This command performs the same operation
|
||||
and returns the counter value.
|
||||
"""
|
||||
return fido.get_next_counter(client)
|
||||
return fido.get_next_counter(session)
|
||||
|
@ -37,10 +37,11 @@ import requests
|
||||
from .. import device, exceptions, firmware, messages, models
|
||||
from ..firmware import models as fw_models
|
||||
from ..models import TrezorModel
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
from . import TrezorConnection
|
||||
|
||||
MODEL_CHOICE = ChoiceType(
|
||||
@ -75,9 +76,9 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool:
|
||||
This is the case from bootloader version 1.8.0, and also holds for firmware version
|
||||
1.8.0 because that installs the appropriate bootloader.
|
||||
"""
|
||||
f = client.features
|
||||
version = (f.major_version, f.minor_version, f.patch_version)
|
||||
bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0)
|
||||
features = client.features
|
||||
version = client.version
|
||||
bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0)
|
||||
return bootloader_onev2
|
||||
|
||||
|
||||
@ -307,25 +308,26 @@ def find_best_firmware_version(
|
||||
If the specified version is not found, prints the closest available version
|
||||
(higher than the specified one, if existing).
|
||||
"""
|
||||
features = client.features
|
||||
model = client.model
|
||||
|
||||
if bitcoin_only is None:
|
||||
bitcoin_only = _should_use_bitcoin_only(client.features)
|
||||
bitcoin_only = _should_use_bitcoin_only(features)
|
||||
|
||||
def version_str(version: Iterable[int]) -> str:
|
||||
return ".".join(map(str, version))
|
||||
|
||||
f = client.features
|
||||
|
||||
releases = get_all_firmware_releases(client.model, bitcoin_only, beta)
|
||||
releases = get_all_firmware_releases(model, bitcoin_only, beta)
|
||||
highest_version = releases[0]["version"]
|
||||
|
||||
if version:
|
||||
want_version = [int(x) for x in version.split(".")]
|
||||
if len(want_version) != 3:
|
||||
click.echo("Please use the 'X.Y.Z' version format.")
|
||||
if want_version[0] != f.major_version:
|
||||
if want_version[0] != features.major_version:
|
||||
click.echo(
|
||||
f"Warning: Trezor {client.model.name} firmware version should be "
|
||||
f"{f.major_version}.X.Y (requested: {version})"
|
||||
f"Warning: Trezor {model.name} firmware version should be "
|
||||
f"{features.major_version}.X.Y (requested: {version})"
|
||||
)
|
||||
else:
|
||||
want_version = highest_version
|
||||
@ -360,8 +362,8 @@ def find_best_firmware_version(
|
||||
# to the newer one, in that case update to the minimal
|
||||
# compatible version first
|
||||
# Choosing the version key to compare based on (not) being in BL mode
|
||||
client_version = [f.major_version, f.minor_version, f.patch_version]
|
||||
if f.bootloader_mode:
|
||||
client_version = client.version
|
||||
if features.bootloader_mode:
|
||||
key_to_compare = "min_bootloader_version"
|
||||
else:
|
||||
key_to_compare = "min_firmware_version"
|
||||
@ -454,11 +456,11 @@ def extract_embedded_fw(
|
||||
|
||||
|
||||
def upload_firmware_into_device(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
firmware_data: bytes,
|
||||
) -> None:
|
||||
"""Perform the final act of loading the firmware into Trezor."""
|
||||
f = client.features
|
||||
f = session.features
|
||||
try:
|
||||
if f.major_version == 1 and f.firmware_present is not False:
|
||||
# Trezor One does not send ButtonRequest
|
||||
@ -468,7 +470,7 @@ def upload_firmware_into_device(
|
||||
with click.progressbar(
|
||||
label="Uploading", length=len(firmware_data), show_eta=False
|
||||
) as bar:
|
||||
firmware.update(client, firmware_data, bar.update)
|
||||
firmware.update(session, firmware_data, bar.update)
|
||||
except exceptions.Cancelled:
|
||||
click.echo("Update aborted on device.")
|
||||
except exceptions.TrezorException as e:
|
||||
@ -661,6 +663,7 @@ def update(
|
||||
against data.trezor.io information, if available.
|
||||
"""
|
||||
with obj.client_context() as client:
|
||||
seedless_session = client.get_seedless_session()
|
||||
if sum(bool(x) for x in (filename, url, version)) > 1:
|
||||
click.echo("You can use only one of: filename, url, version.")
|
||||
sys.exit(1)
|
||||
@ -716,7 +719,7 @@ def update(
|
||||
if _is_strict_update(client, firmware_data):
|
||||
header_size = _get_firmware_header_size(firmware_data)
|
||||
device.reboot_to_bootloader(
|
||||
client,
|
||||
seedless_session,
|
||||
boot_command=messages.BootCommand.INSTALL_UPGRADE,
|
||||
firmware_header=firmware_data[:header_size],
|
||||
language_data=language_data,
|
||||
@ -726,7 +729,7 @@ def update(
|
||||
click.echo(
|
||||
"WARNING: Seamless installation not possible, language data will not be uploaded."
|
||||
)
|
||||
device.reboot_to_bootloader(client)
|
||||
device.reboot_to_bootloader(seedless_session)
|
||||
|
||||
click.echo("Waiting for bootloader...")
|
||||
while True:
|
||||
@ -742,13 +745,15 @@ def update(
|
||||
click.echo("Please switch your device to bootloader mode.")
|
||||
sys.exit(1)
|
||||
|
||||
upload_firmware_into_device(client=client, firmware_data=firmware_data)
|
||||
upload_firmware_into_device(
|
||||
session=client.get_seedless_session(), firmware_data=firmware_data
|
||||
)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("hex_challenge", required=False)
|
||||
@with_client
|
||||
def get_hash(client: "TrezorClient", hex_challenge: Optional[str]) -> str:
|
||||
@with_session(seedless=True)
|
||||
def get_hash(session: "Session", hex_challenge: Optional[str]) -> str:
|
||||
"""Get a hash of the installed firmware combined with the optional challenge."""
|
||||
challenge = bytes.fromhex(hex_challenge) if hex_challenge else None
|
||||
return firmware.get_hash(client, challenge).hex()
|
||||
return firmware.get_hash(session, challenge).hex()
|
||||
|
@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Dict
|
||||
import click
|
||||
|
||||
from .. import messages, monero, tools
|
||||
from . import ChoiceType, with_client
|
||||
from . import ChoiceType, with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h"
|
||||
|
||||
@ -42,9 +42,9 @@ def cli() -> None:
|
||||
default=messages.MoneroNetworkType.MAINNET,
|
||||
)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
show_display: bool,
|
||||
network_type: messages.MoneroNetworkType,
|
||||
@ -52,7 +52,7 @@ def get_address(
|
||||
) -> bytes:
|
||||
"""Get Monero address for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
return monero.get_address(client, address_n, show_display, network_type, chunkify)
|
||||
return monero.get_address(session, address_n, show_display, network_type, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -63,13 +63,13 @@ def get_address(
|
||||
type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}),
|
||||
default=messages.MoneroNetworkType.MAINNET,
|
||||
)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_watch_key(
|
||||
client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType
|
||||
session: "Session", address: str, network_type: messages.MoneroNetworkType
|
||||
) -> Dict[str, str]:
|
||||
"""Get Monero watch key for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
res = monero.get_watch_key(client, address_n, network_type)
|
||||
res = monero.get_watch_key(session, address_n, network_type)
|
||||
# TODO: could be made required in MoneroWatchKey
|
||||
assert res.address is not None
|
||||
assert res.watch_key is not None
|
||||
|
@ -21,10 +21,10 @@ import click
|
||||
import requests
|
||||
|
||||
from .. import nem, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h"
|
||||
|
||||
@ -39,9 +39,9 @@ def cli() -> None:
|
||||
@click.option("-N", "--network", type=int, default=0x68)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
network: int,
|
||||
show_display: bool,
|
||||
@ -49,7 +49,7 @@ def get_address(
|
||||
) -> str:
|
||||
"""Get NEM address for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
return nem.get_address(client, address_n, network, show_display, chunkify)
|
||||
return nem.get_address(session, address_n, network, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -58,9 +58,9 @@ def get_address(
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-b", "--broadcast", help="NIS to announce transaction to")
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
file: TextIO,
|
||||
broadcast: Optional[str],
|
||||
@ -71,7 +71,7 @@ def sign_tx(
|
||||
Transaction file is expected in the NIS (RequestPrepareAnnounce) format.
|
||||
"""
|
||||
address_n = tools.parse_path(address)
|
||||
transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify)
|
||||
transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify)
|
||||
|
||||
payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()}
|
||||
|
||||
|
@ -22,10 +22,10 @@ import typing as t
|
||||
import click
|
||||
|
||||
from .. import messages, nostr, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
|
||||
PATH_TEMPLATE = "m/44h/1237h/{}h/0/0"
|
||||
@ -38,9 +38,9 @@ def cli() -> None:
|
||||
|
||||
@cli.command()
|
||||
@click.option("-a", "--account", default=0, help="Account index")
|
||||
@with_client
|
||||
@with_session
|
||||
def get_pubkey(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
account: int,
|
||||
) -> str:
|
||||
"""Return the pubkey derived by the given path."""
|
||||
@ -48,7 +48,7 @@ def get_pubkey(
|
||||
address_n = tools.parse_path(PATH_TEMPLATE.format(account))
|
||||
|
||||
return nostr.get_pubkey(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
).hex()
|
||||
|
||||
@ -56,9 +56,9 @@ def get_pubkey(
|
||||
@cli.command()
|
||||
@click.option("-a", "--account", default=0, help="Account index")
|
||||
@click.argument("event")
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_event(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
account: int,
|
||||
event: str,
|
||||
) -> dict[str, str]:
|
||||
@ -69,7 +69,7 @@ def sign_event(
|
||||
address_n = tools.parse_path(PATH_TEMPLATE.format(account))
|
||||
|
||||
res = nostr.sign_event(
|
||||
client,
|
||||
session,
|
||||
messages.NostrSignEvent(
|
||||
address_n=address_n,
|
||||
created_at=event_json["created_at"],
|
||||
|
@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
|
||||
import click
|
||||
|
||||
from .. import ripple, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0"
|
||||
|
||||
@ -37,13 +37,13 @@ def cli() -> None:
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
||||
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||
) -> str:
|
||||
"""Get Ripple address"""
|
||||
address_n = tools.parse_path(address)
|
||||
return ripple.get_address(client, address_n, show_display, chunkify)
|
||||
return ripple.get_address(session, address_n, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -51,13 +51,13 @@ def get_address(
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None:
|
||||
@with_session
|
||||
def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None:
|
||||
"""Sign Ripple transaction"""
|
||||
address_n = tools.parse_path(address)
|
||||
msg = ripple.create_sign_tx_msg(json.load(file))
|
||||
|
||||
result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify)
|
||||
result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify)
|
||||
click.echo("Signature:")
|
||||
click.echo(result.signature.hex())
|
||||
click.echo()
|
||||
|
@ -24,10 +24,10 @@ import click
|
||||
import requests
|
||||
|
||||
from .. import device, messages, toif
|
||||
from . import AliasedGroup, ChoiceType, with_client
|
||||
from . import AliasedGroup, ChoiceType, with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
@ -190,18 +190,18 @@ def cli() -> None:
|
||||
@cli.command()
|
||||
@click.option("-r", "--remove", is_flag=True, hidden=True)
|
||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
|
||||
@with_client
|
||||
def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None:
|
||||
@with_session(seedless=True)
|
||||
def pin(session: "Session", enable: Optional[bool], remove: bool) -> None:
|
||||
"""Set, change or remove PIN."""
|
||||
# Remove argument is there for backwards compatibility
|
||||
device.change_pin(client, remove=_should_remove(enable, remove))
|
||||
device.change_pin(session, remove=_should_remove(enable, remove))
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-r", "--remove", is_flag=True, hidden=True)
|
||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
|
||||
@with_client
|
||||
def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None:
|
||||
@with_session(seedless=True)
|
||||
def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> None:
|
||||
"""Set or remove the wipe code.
|
||||
|
||||
The wipe code functions as a "self-destruct PIN". If the wipe code is ever
|
||||
@ -209,32 +209,32 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> N
|
||||
removed and the device will be reset to factory defaults.
|
||||
"""
|
||||
# Remove argument is there for backwards compatibility
|
||||
device.change_wipe_code(client, remove=_should_remove(enable, remove))
|
||||
device.change_wipe_code(session, remove=_should_remove(enable, remove))
|
||||
|
||||
|
||||
@cli.command()
|
||||
# keep the deprecated -l/--label option, make it do nothing
|
||||
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.argument("label")
|
||||
@with_client
|
||||
def label(client: "TrezorClient", label: str) -> None:
|
||||
@with_session(seedless=True)
|
||||
def label(session: "Session", label: str) -> None:
|
||||
"""Set new device label."""
|
||||
device.apply_settings(client, label=label)
|
||||
device.apply_settings(session, label=label)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def brightness(client: "TrezorClient") -> None:
|
||||
@with_session(seedless=True)
|
||||
def brightness(session: "Session") -> None:
|
||||
"""Set display brightness."""
|
||||
device.set_brightness(client)
|
||||
device.set_brightness(session)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
|
||||
@with_client
|
||||
def haptic_feedback(client: "TrezorClient", enable: bool) -> None:
|
||||
@with_session(seedless=True)
|
||||
def haptic_feedback(session: "Session", enable: bool) -> None:
|
||||
"""Enable or disable haptic feedback."""
|
||||
device.apply_settings(client, haptic_feedback=enable)
|
||||
device.apply_settings(session, haptic_feedback=enable)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -243,9 +243,9 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> None:
|
||||
"-r", "--remove", is_flag=True, default=False, help="Switch back to english."
|
||||
)
|
||||
@click.option("-d/-D", "--display/--no-display", default=None)
|
||||
@with_client
|
||||
@with_session(seedless=True)
|
||||
def language(
|
||||
client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None
|
||||
session: "Session", path_or_url: str | None, remove: bool, display: bool | None
|
||||
) -> None:
|
||||
"""Set new language with translations."""
|
||||
if remove != (path_or_url is None):
|
||||
@ -269,30 +269,28 @@ def language(
|
||||
raise click.ClickException(
|
||||
f"Failed to load translations from {path_or_url}"
|
||||
) from None
|
||||
device.change_language(client, language_data=language_data, show_display=display)
|
||||
device.change_language(session, language_data=language_data, show_display=display)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("rotation", type=ChoiceType(ROTATION))
|
||||
@with_client
|
||||
def display_rotation(
|
||||
client: "TrezorClient", rotation: messages.DisplayRotation
|
||||
) -> None:
|
||||
@with_session(seedless=True)
|
||||
def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> None:
|
||||
"""Set display rotation.
|
||||
|
||||
Configure display rotation for Trezor Model T. The options are
|
||||
north, east, south or west.
|
||||
"""
|
||||
device.apply_settings(client, display_rotation=rotation)
|
||||
device.apply_settings(session, display_rotation=rotation)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("delay", type=str)
|
||||
@with_client
|
||||
def auto_lock_delay(client: "TrezorClient", delay: str) -> None:
|
||||
@with_session(seedless=True)
|
||||
def auto_lock_delay(session: "Session", delay: str) -> None:
|
||||
"""Set auto-lock delay (in seconds)."""
|
||||
|
||||
if not client.features.pin_protection:
|
||||
if not session.features.pin_protection:
|
||||
raise click.ClickException("Set up a PIN first")
|
||||
|
||||
value, unit = delay[:-1], delay[-1:]
|
||||
@ -301,13 +299,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> None:
|
||||
seconds = float(value) * units[unit]
|
||||
else:
|
||||
seconds = float(delay) # assume seconds if no unit is specified
|
||||
device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000))
|
||||
device.apply_settings(session, auto_lock_delay_ms=int(seconds * 1000))
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("flags")
|
||||
@with_client
|
||||
def flags(client: "TrezorClient", flags: str) -> None:
|
||||
@with_session(seedless=True)
|
||||
def flags(session: "Session", flags: str) -> None:
|
||||
"""Set device flags."""
|
||||
if flags.lower().startswith("0b"):
|
||||
flags_int = int(flags, 2)
|
||||
@ -315,7 +313,7 @@ def flags(client: "TrezorClient", flags: str) -> None:
|
||||
flags_int = int(flags, 16)
|
||||
else:
|
||||
flags_int = int(flags)
|
||||
device.apply_flags(client, flags=flags_int)
|
||||
device.apply_flags(session, flags=flags_int)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -324,8 +322,8 @@ def flags(client: "TrezorClient", flags: str) -> None:
|
||||
"-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False
|
||||
)
|
||||
@click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)")
|
||||
@with_client
|
||||
def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
|
||||
@with_session(seedless=True)
|
||||
def homescreen(session: "Session", filename: str, quality: int) -> None:
|
||||
"""Set new homescreen.
|
||||
|
||||
To revert to default homescreen, use 'trezorctl set homescreen default'
|
||||
@ -337,39 +335,39 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
|
||||
if not path.exists() or not path.is_file():
|
||||
raise click.ClickException("Cannot open file")
|
||||
|
||||
if client.features.model == "1":
|
||||
if session.features.model == "1":
|
||||
img = image_to_t1(path)
|
||||
else:
|
||||
if client.features.homescreen_format == messages.HomescreenFormat.Jpeg:
|
||||
if session.features.homescreen_format == messages.HomescreenFormat.Jpeg:
|
||||
width = (
|
||||
client.features.homescreen_width
|
||||
if client.features.homescreen_width is not None
|
||||
session.features.homescreen_width
|
||||
if session.features.homescreen_width is not None
|
||||
else 240
|
||||
)
|
||||
height = (
|
||||
client.features.homescreen_height
|
||||
if client.features.homescreen_height is not None
|
||||
session.features.homescreen_height
|
||||
if session.features.homescreen_height is not None
|
||||
else 240
|
||||
)
|
||||
img = image_to_jpeg(path, width, height, quality)
|
||||
elif client.features.homescreen_format == messages.HomescreenFormat.ToiG:
|
||||
width = client.features.homescreen_width
|
||||
height = client.features.homescreen_height
|
||||
elif session.features.homescreen_format == messages.HomescreenFormat.ToiG:
|
||||
width = session.features.homescreen_width
|
||||
height = session.features.homescreen_height
|
||||
if width is None or height is None:
|
||||
raise click.ClickException("Device did not report homescreen size.")
|
||||
img = image_to_toif(path, width, height, True)
|
||||
elif (
|
||||
client.features.homescreen_format == messages.HomescreenFormat.Toif
|
||||
or client.features.homescreen_format is None
|
||||
session.features.homescreen_format == messages.HomescreenFormat.Toif
|
||||
or session.features.homescreen_format is None
|
||||
):
|
||||
width = (
|
||||
client.features.homescreen_width
|
||||
if client.features.homescreen_width is not None
|
||||
session.features.homescreen_width
|
||||
if session.features.homescreen_width is not None
|
||||
else 144
|
||||
)
|
||||
height = (
|
||||
client.features.homescreen_height
|
||||
if client.features.homescreen_height is not None
|
||||
session.features.homescreen_height
|
||||
if session.features.homescreen_height is not None
|
||||
else 144
|
||||
)
|
||||
img = image_to_toif(path, width, height, False)
|
||||
@ -379,7 +377,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
|
||||
"Unknown image format requested by the device."
|
||||
)
|
||||
|
||||
device.apply_settings(client, homescreen=img)
|
||||
device.apply_settings(session, homescreen=img)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -387,9 +385,9 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
|
||||
"--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.'
|
||||
)
|
||||
@click.argument("level", type=ChoiceType(SAFETY_LEVELS))
|
||||
@with_client
|
||||
@with_session(seedless=True)
|
||||
def safety_checks(
|
||||
client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel
|
||||
session: "Session", always: bool, level: messages.SafetyCheckLevel
|
||||
) -> None:
|
||||
"""Set safety check level.
|
||||
|
||||
@ -402,18 +400,18 @@ def safety_checks(
|
||||
"""
|
||||
if always and level == messages.SafetyCheckLevel.PromptTemporarily:
|
||||
level = messages.SafetyCheckLevel.PromptAlways
|
||||
device.apply_settings(client, safety_checks=level)
|
||||
device.apply_settings(session, safety_checks=level)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
|
||||
@with_client
|
||||
def experimental_features(client: "TrezorClient", enable: bool) -> None:
|
||||
@with_session(seedless=True)
|
||||
def experimental_features(session: "Session", enable: bool) -> None:
|
||||
"""Enable or disable experimental message types.
|
||||
|
||||
This is a developer feature. Use with caution.
|
||||
"""
|
||||
device.apply_settings(client, experimental_features=enable)
|
||||
device.apply_settings(session, experimental_features=enable)
|
||||
|
||||
|
||||
#
|
||||
@ -436,25 +434,25 @@ passphrase = cast(AliasedGroup, passphrase_main)
|
||||
|
||||
@passphrase.command(name="on")
|
||||
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
|
||||
@with_client
|
||||
def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> None:
|
||||
@with_session(seedless=True)
|
||||
def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> None:
|
||||
"""Enable passphrase."""
|
||||
if client.features.passphrase_protection is not True:
|
||||
if session.features.passphrase_protection is not True:
|
||||
use_passphrase = True
|
||||
else:
|
||||
use_passphrase = None
|
||||
device.apply_settings(
|
||||
client,
|
||||
session,
|
||||
use_passphrase=use_passphrase,
|
||||
passphrase_always_on_device=force_on_device,
|
||||
)
|
||||
|
||||
|
||||
@passphrase.command(name="off")
|
||||
@with_client
|
||||
def passphrase_off(client: "TrezorClient") -> None:
|
||||
@with_session(seedless=True)
|
||||
def passphrase_off(session: "Session") -> None:
|
||||
"""Disable passphrase."""
|
||||
device.apply_settings(client, use_passphrase=False)
|
||||
device.apply_settings(session, use_passphrase=False)
|
||||
|
||||
|
||||
# Registering the aliases for backwards compatibility
|
||||
@ -467,10 +465,10 @@ passphrase.aliases = {
|
||||
|
||||
@passphrase.command(name="hide")
|
||||
@click.argument("hide", type=ChoiceType({"on": True, "off": False}))
|
||||
@with_client
|
||||
def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> None:
|
||||
@with_session(seedless=True)
|
||||
def hide_passphrase_from_host(session: "Session", hide: bool) -> None:
|
||||
"""Enable or disable hiding passphrase coming from host.
|
||||
|
||||
This is a developer feature. Use with caution.
|
||||
"""
|
||||
device.apply_settings(client, hide_passphrase_from_host=hide)
|
||||
device.apply_settings(session, hide_passphrase_from_host=hide)
|
||||
|
@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Optional, TextIO
|
||||
import click
|
||||
|
||||
from .. import messages, solana, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h"
|
||||
DEFAULT_PATH = "m/44h/501h/0h/0h"
|
||||
@ -21,40 +21,40 @@ def cli() -> None:
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
show_display: bool,
|
||||
) -> bytes:
|
||||
"""Get Solana public key."""
|
||||
address_n = tools.parse_path(address)
|
||||
return solana.get_public_key(client, address_n, show_display)
|
||||
return solana.get_public_key(session, address_n, show_display)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
show_display: bool,
|
||||
chunkify: bool,
|
||||
) -> str:
|
||||
"""Get Solana address."""
|
||||
address_n = tools.parse_path(address)
|
||||
return solana.get_address(client, address_n, show_display, chunkify)
|
||||
return solana.get_address(session, address_n, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("serialized_tx", type=str)
|
||||
@click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP)
|
||||
@click.option("-a", "--additional-information-file", type=click.File("r"))
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
serialized_tx: str,
|
||||
additional_information_file: Optional[TextIO],
|
||||
@ -78,7 +78,7 @@ def sign_tx(
|
||||
)
|
||||
|
||||
return solana.sign_tx(
|
||||
client,
|
||||
session,
|
||||
address_n,
|
||||
bytes.fromhex(serialized_tx),
|
||||
additional_information,
|
||||
|
@ -21,10 +21,10 @@ from typing import TYPE_CHECKING
|
||||
import click
|
||||
|
||||
from .. import stellar, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
try:
|
||||
from stellar_sdk import (
|
||||
@ -52,13 +52,13 @@ def cli() -> None:
|
||||
)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
||||
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||
) -> str:
|
||||
"""Get Stellar public address."""
|
||||
address_n = tools.parse_path(address)
|
||||
return stellar.get_address(client, address_n, show_display, chunkify)
|
||||
return stellar.get_address(session, address_n, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -77,9 +77,9 @@ def get_address(
|
||||
help="Network passphrase (blank for public network).",
|
||||
)
|
||||
@click.argument("b64envelope")
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_transaction(
|
||||
client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str
|
||||
session: "Session", b64envelope: str, address: str, network_passphrase: str
|
||||
) -> bytes:
|
||||
"""Sign a base64-encoded transaction envelope.
|
||||
|
||||
@ -109,6 +109,6 @@ def sign_transaction(
|
||||
|
||||
address_n = tools.parse_path(address)
|
||||
tx, operations = stellar.from_envelope(envelope)
|
||||
resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase)
|
||||
resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase)
|
||||
|
||||
return base64.b64encode(resp.signature)
|
||||
|
@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO
|
||||
import click
|
||||
|
||||
from .. import messages, protobuf, tezos, tools
|
||||
from . import with_client
|
||||
from . import with_session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h"
|
||||
|
||||
@ -37,23 +37,23 @@ def cli() -> None:
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def get_address(
|
||||
client: "TrezorClient", address: str, show_display: bool, chunkify: bool
|
||||
session: "Session", address: str, show_display: bool, chunkify: bool
|
||||
) -> str:
|
||||
"""Get Tezos address for specified path."""
|
||||
address_n = tools.parse_path(address)
|
||||
return tezos.get_address(client, address_n, show_display, chunkify)
|
||||
return tezos.get_address(session, address_n, show_display, chunkify)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-d", "--show-display", is_flag=True)
|
||||
@with_client
|
||||
def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str:
|
||||
@with_session
|
||||
def get_public_key(session: "Session", address: str, show_display: bool) -> str:
|
||||
"""Get Tezos public key."""
|
||||
address_n = tools.parse_path(address)
|
||||
return tezos.get_public_key(client, address_n, show_display)
|
||||
return tezos.get_public_key(session, address_n, show_display)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -61,11 +61,11 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) ->
|
||||
@click.option("-n", "--address", required=True, help=PATH_HELP)
|
||||
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
|
||||
@click.option("-C", "--chunkify", is_flag=True)
|
||||
@with_client
|
||||
@with_session
|
||||
def sign_tx(
|
||||
client: "TrezorClient", address: str, file: TextIO, chunkify: bool
|
||||
session: "Session", address: str, file: TextIO, chunkify: bool
|
||||
) -> messages.TezosSignedTx:
|
||||
"""Sign Tezos transaction."""
|
||||
address_n = tools.parse_path(address)
|
||||
msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file))
|
||||
return tezos.sign_tx(client, address_n, msg, chunkify=chunkify)
|
||||
return tezos.sign_tx(session, address_n, msg, chunkify=chunkify)
|
||||
|
@ -24,9 +24,9 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca
|
||||
|
||||
import click
|
||||
|
||||
from .. import __version__, log, messages, protobuf, ui
|
||||
from ..client import TrezorClient
|
||||
from .. import __version__, log, messages, protobuf
|
||||
from ..transport import DeviceIsBusy, enumerate_devices
|
||||
from ..transport.session import Session
|
||||
from ..transport.udp import UdpTransport
|
||||
from . import (
|
||||
AliasedGroup,
|
||||
@ -50,7 +50,7 @@ from . import (
|
||||
solana,
|
||||
stellar,
|
||||
tezos,
|
||||
with_client,
|
||||
with_session,
|
||||
)
|
||||
|
||||
F = TypeVar("F", bound=Callable)
|
||||
@ -286,18 +286,24 @@ def format_device_name(features: messages.Features) -> str:
|
||||
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
|
||||
"""List connected Trezor devices."""
|
||||
if no_resolve:
|
||||
return enumerate_devices()
|
||||
for d in enumerate_devices():
|
||||
click.echo(d.get_path())
|
||||
return
|
||||
|
||||
from . import get_client
|
||||
|
||||
for transport in enumerate_devices():
|
||||
try:
|
||||
client = TrezorClient(transport, ui=ui.ClickUI())
|
||||
client = get_client(transport)
|
||||
transport.open()
|
||||
description = format_device_name(client.features)
|
||||
client.end_session()
|
||||
except DeviceIsBusy:
|
||||
description = "Device is in use by another process"
|
||||
except Exception:
|
||||
description = "Failed to read details"
|
||||
click.echo(f"{transport} - {description}")
|
||||
except Exception as e:
|
||||
description = "Failed to read details " + str(type(e))
|
||||
finally:
|
||||
transport.close()
|
||||
click.echo(f"{transport.get_path()} - {description}")
|
||||
return None
|
||||
|
||||
|
||||
@ -315,23 +321,21 @@ def version() -> str:
|
||||
@cli.command()
|
||||
@click.argument("message")
|
||||
@click.option("-b", "--button-protection", is_flag=True)
|
||||
@with_client
|
||||
def ping(client: "TrezorClient", message: str, button_protection: bool) -> str:
|
||||
@with_session(empty_passphrase=True)
|
||||
def ping(session: "Session", message: str, button_protection: bool) -> str:
|
||||
"""Send ping message."""
|
||||
return client.ping(message, button_protection=button_protection)
|
||||
return session.ping(message, button_protection)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.pass_obj
|
||||
def get_session(obj: TrezorConnection) -> str:
|
||||
@click.option("-c", "derive_cardano", is_flag=True, help="Derive Cardano session.")
|
||||
def get_session(obj: TrezorConnection, derive_cardano: bool = False) -> str:
|
||||
"""Get a session ID for subsequent commands.
|
||||
|
||||
Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with
|
||||
`trezorctl -s SESSION_ID`, or set it to an environment variable `TREZOR_SESSION_ID`,
|
||||
to avoid having to enter passphrase for subsequent commands.
|
||||
|
||||
The session ID is valid until another client starts using Trezor, until the next
|
||||
get-session call, or until Trezor is disconnected.
|
||||
"""
|
||||
# make sure session is not resumed
|
||||
obj.session_id = None
|
||||
@ -342,25 +346,26 @@ def get_session(obj: TrezorConnection) -> str:
|
||||
"Upgrade your firmware to enable session support."
|
||||
)
|
||||
|
||||
client.ensure_unlocked()
|
||||
if client.session_id is None:
|
||||
raise click.ClickException("Passphrase not enabled or firmware too old.")
|
||||
else:
|
||||
return client.session_id.hex()
|
||||
session = obj.get_session(derive_cardano=derive_cardano)
|
||||
if session.id is None:
|
||||
raise click.ClickException("Passphrase not enabled or firmware too old.")
|
||||
else:
|
||||
return session.id.hex()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def clear_session(client: "TrezorClient") -> None:
|
||||
@with_session(must_resume=True, empty_passphrase=True)
|
||||
def clear_session(session: "Session") -> None:
|
||||
"""Clear session (remove cached PIN, passphrase, etc.)."""
|
||||
return client.clear_session()
|
||||
session.call(messages.LockDevice())
|
||||
session.end()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@with_client
|
||||
def get_features(client: "TrezorClient") -> messages.Features:
|
||||
@with_session(seedless=True)
|
||||
def get_features(session: "Session") -> messages.Features:
|
||||
"""Retrieve device features and settings."""
|
||||
return client.features
|
||||
return session.features
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
@ -13,28 +13,23 @@
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import typing as t
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
|
||||
|
||||
from mnemonic import Mnemonic
|
||||
from enum import IntEnum
|
||||
|
||||
from . import exceptions, mapping, messages, models
|
||||
from .log import DUMP_BYTES
|
||||
from .messages import Capability
|
||||
from .protobuf import MessageType
|
||||
from .tools import parse_path, session
|
||||
from .mapping import ProtobufMapping
|
||||
from .tools import parse_path
|
||||
from .transport import Transport, get_transport
|
||||
from .transport.thp.protocol_and_channel import Channel
|
||||
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .transport import Transport
|
||||
from .ui import TrezorClientUI
|
||||
|
||||
UI = TypeVar("UI", bound="TrezorClientUI")
|
||||
MT = TypeVar("MT", bound=MessageType)
|
||||
if t.TYPE_CHECKING:
|
||||
from .transport.session import Session
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -51,330 +46,170 @@ Or visit https://suite.trezor.io/
|
||||
""".strip()
|
||||
|
||||
|
||||
def get_default_client(
|
||||
path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any
|
||||
) -> "TrezorClient":
|
||||
"""Get a client for a connected Trezor device.
|
||||
|
||||
Returns a TrezorClient instance with minimum fuss.
|
||||
|
||||
If path is specified, does a prefix-search for the specified device. Otherwise, uses
|
||||
the value of TREZOR_PATH env variable, or finds first connected Trezor.
|
||||
If no UI is supplied, instantiates the default CLI UI.
|
||||
"""
|
||||
from .transport import get_transport
|
||||
from .ui import ClickUI
|
||||
|
||||
if path is None:
|
||||
path = os.getenv("TREZOR_PATH")
|
||||
|
||||
transport = get_transport(path, prefix_search=True)
|
||||
if ui is None:
|
||||
ui = ClickUI()
|
||||
|
||||
return TrezorClient(transport, ui, **kwargs)
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrezorClient(Generic[UI]):
|
||||
"""Trezor client, a connection to a Trezor device.
|
||||
class ProtocolVersion(IntEnum):
|
||||
UNKNOWN = 0x00
|
||||
PROTOCOL_V1 = 0x01 # Codec
|
||||
PROTOCOL_V2 = 0x02 # THP
|
||||
|
||||
This class allows you to manage connection state, send and receive protobuf
|
||||
messages, handle user interactions, and perform some generic tasks
|
||||
(send a cancel message, initialize or clear a session, ping the device).
|
||||
"""
|
||||
|
||||
model: models.TrezorModel
|
||||
transport: "Transport"
|
||||
session_id: Optional[bytes]
|
||||
ui: UI
|
||||
features: messages.Features
|
||||
class TrezorClient:
|
||||
button_callback: t.Callable[[Session, messages.ButtonRequest], t.Any] | None = None
|
||||
passphrase_callback: (
|
||||
t.Callable[[Session, messages.PassphraseRequest], t.Any] | None
|
||||
) = None
|
||||
pin_callback: t.Callable[[Session, messages.PinMatrixRequest], t.Any] | None = None
|
||||
|
||||
_seedless_session: Session | None = None
|
||||
_features: messages.Features | None = None
|
||||
_protocol_version: int
|
||||
_setup_pin: str | None = None # Should be used only by conftest
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: "Transport",
|
||||
ui: UI,
|
||||
session_id: Optional[bytes] = None,
|
||||
derive_cardano: Optional[bool] = None,
|
||||
model: Optional[models.TrezorModel] = None,
|
||||
_init_device: bool = True,
|
||||
transport: Transport,
|
||||
protobuf_mapping: ProtobufMapping | None = None,
|
||||
protocol: Channel | None = None,
|
||||
) -> None:
|
||||
"""Create a TrezorClient instance.
|
||||
|
||||
You have to provide a `transport`, i.e., a raw connection to the device. You can
|
||||
use `trezorlib.transport.get_transport` to find one.
|
||||
|
||||
You have to provide a UI implementation for the three kinds of interaction:
|
||||
- button request (notify the user that their interaction is needed)
|
||||
- PIN request (on T1, ask the user to input numbers for a PIN matrix)
|
||||
- passphrase request (ask the user to enter a passphrase) See `trezorlib.ui` for
|
||||
details.
|
||||
|
||||
You can supply a `session_id` you might have saved in the previous session. If
|
||||
you do, the user might not need to enter their passphrase again.
|
||||
|
||||
You can provide Trezor model information. If not provided, it is detected from
|
||||
the model name reported at initialization time.
|
||||
|
||||
By default, the instance will open a connection to the Trezor device, send an
|
||||
`Initialize` message, set up the `features` field from the response, and connect
|
||||
to a session. By specifying `_init_device=False`, this step is skipped. Notably,
|
||||
this means that `client.features` is unset. Use `client.init_device()` or
|
||||
`client.refresh_features()` to fix that, otherwise A LOT OF THINGS will break.
|
||||
Only use this if you are _sure_ that you know what you are doing. This feature
|
||||
might be removed at any time.
|
||||
"""
|
||||
LOG.info(f"creating client instance for device: {transport.get_path()}")
|
||||
# Here, self.model could be set to None. Unless _init_device is False, it will
|
||||
# get correctly reconfigured as part of the init_device flow.
|
||||
self.model = model # type: ignore ["None" is incompatible with "TrezorModel"]
|
||||
if self.model:
|
||||
self.mapping = self.model.default_mapping
|
||||
else:
|
||||
self.mapping = mapping.DEFAULT_MAPPING
|
||||
Transport needs to be opened before calling a method (or accessing
|
||||
an attribute) for the first time. It should be closed after you're
|
||||
done using the client.
|
||||
"""
|
||||
self._is_invalidated: bool = False
|
||||
self.transport = transport
|
||||
self.ui = ui
|
||||
self.session_counter = 0
|
||||
self.session_id = session_id
|
||||
if _init_device:
|
||||
self.init_device(session_id=session_id, derive_cardano=derive_cardano)
|
||||
|
||||
def open(self) -> None:
|
||||
if self.session_counter == 0:
|
||||
self.transport.begin_session()
|
||||
self.session_counter += 1
|
||||
|
||||
def close(self) -> None:
|
||||
self.session_counter = max(self.session_counter - 1, 0)
|
||||
if self.session_counter == 0:
|
||||
# TODO call EndSession here?
|
||||
self.transport.end_session()
|
||||
|
||||
def cancel(self) -> None:
|
||||
self._raw_write(messages.Cancel())
|
||||
|
||||
def call_raw(self, msg: MessageType) -> MessageType:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
self._raw_write(msg)
|
||||
return self._raw_read()
|
||||
|
||||
def _raw_write(self, msg: MessageType) -> None:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
LOG.debug(
|
||||
f"sending message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
msg_type, msg_bytes = self.mapping.encode(msg)
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
self.transport.write(msg_type, msg_bytes)
|
||||
|
||||
def _raw_read(self) -> MessageType:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
msg_type, msg_bytes = self.transport.read()
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
msg = self.mapping.decode(msg_type, msg_bytes)
|
||||
LOG.debug(
|
||||
f"received message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
return msg
|
||||
|
||||
def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType:
|
||||
try:
|
||||
pin = self.ui.get_pin(msg.type)
|
||||
except exceptions.Cancelled:
|
||||
self.call_raw(messages.Cancel())
|
||||
raise
|
||||
|
||||
if any(d not in "123456789" for d in pin) or not (
|
||||
1 <= len(pin) <= MAX_PIN_LENGTH
|
||||
):
|
||||
self.call_raw(messages.Cancel())
|
||||
raise ValueError("Invalid PIN provided")
|
||||
|
||||
resp = self.call_raw(messages.PinMatrixAck(pin=pin))
|
||||
if isinstance(resp, messages.Failure) and resp.code in (
|
||||
messages.FailureType.PinInvalid,
|
||||
messages.FailureType.PinCancelled,
|
||||
messages.FailureType.PinExpected,
|
||||
):
|
||||
raise exceptions.PinException(resp.code, resp.message)
|
||||
if protobuf_mapping is None:
|
||||
self.mapping = mapping.DEFAULT_MAPPING
|
||||
else:
|
||||
return resp
|
||||
self.mapping = protobuf_mapping
|
||||
if protocol is None:
|
||||
self.protocol = self._get_protocol()
|
||||
else:
|
||||
self.protocol = protocol
|
||||
self.protocol.mapping = self.mapping
|
||||
|
||||
def _callback_passphrase(self, msg: messages.PassphraseRequest) -> MessageType:
|
||||
available_on_device = Capability.PassphraseEntry in self.features.capabilities
|
||||
if isinstance(self.protocol, ProtocolV1Channel):
|
||||
self._protocol_version = ProtocolVersion.PROTOCOL_V1
|
||||
else:
|
||||
self._protocol_version = ProtocolVersion.UNKNOWN
|
||||
|
||||
def send_passphrase(
|
||||
passphrase: Optional[str] = None, on_device: Optional[bool] = None
|
||||
) -> MessageType:
|
||||
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
||||
resp = self.call_raw(msg)
|
||||
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
||||
self.session_id = resp.state
|
||||
resp = self.call_raw(messages.Deprecated_PassphraseStateAck())
|
||||
return resp
|
||||
@classmethod
|
||||
def resume(
|
||||
cls,
|
||||
transport: Transport,
|
||||
protobuf_mapping: ProtobufMapping | None = None,
|
||||
) -> TrezorClient:
|
||||
if protobuf_mapping is None:
|
||||
protobuf_mapping = mapping.DEFAULT_MAPPING
|
||||
protocol = ProtocolV1Channel(transport, protobuf_mapping)
|
||||
return TrezorClient(transport, protobuf_mapping, protocol)
|
||||
|
||||
# short-circuit old style entry
|
||||
if msg._on_device is True:
|
||||
return send_passphrase(None, None)
|
||||
|
||||
try:
|
||||
passphrase = self.ui.get_passphrase(available_on_device=available_on_device)
|
||||
except exceptions.Cancelled:
|
||||
self.call_raw(messages.Cancel())
|
||||
raise
|
||||
|
||||
if passphrase is PASSPHRASE_ON_DEVICE:
|
||||
if not available_on_device:
|
||||
self.call_raw(messages.Cancel())
|
||||
raise RuntimeError("Device is not capable of entering passphrase")
|
||||
else:
|
||||
return send_passphrase(on_device=True)
|
||||
|
||||
# else process host-entered passphrase
|
||||
if not isinstance(passphrase, str):
|
||||
raise RuntimeError("Passphrase must be a str")
|
||||
passphrase = Mnemonic.normalize_string(passphrase)
|
||||
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
|
||||
self.call_raw(messages.Cancel())
|
||||
raise ValueError("Passphrase too long")
|
||||
|
||||
return send_passphrase(passphrase, on_device=False)
|
||||
|
||||
def _callback_button(self, msg: messages.ButtonRequest) -> MessageType:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
# do this raw - send ButtonAck first, notify UI later
|
||||
self._raw_write(messages.ButtonAck())
|
||||
self.ui.button_request(msg)
|
||||
return self._raw_read()
|
||||
|
||||
@session
|
||||
def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT:
|
||||
self.check_firmware_version()
|
||||
resp = self.call_raw(msg)
|
||||
while True:
|
||||
if isinstance(resp, messages.PinMatrixRequest):
|
||||
resp = self._callback_pin(resp)
|
||||
elif isinstance(resp, messages.PassphraseRequest):
|
||||
resp = self._callback_passphrase(resp)
|
||||
elif isinstance(resp, messages.ButtonRequest):
|
||||
resp = self._callback_button(resp)
|
||||
elif isinstance(resp, messages.Failure):
|
||||
if resp.code == messages.FailureType.ActionCancelled:
|
||||
raise exceptions.Cancelled
|
||||
raise exceptions.TrezorFailure(resp)
|
||||
elif not isinstance(resp, expect):
|
||||
raise exceptions.UnexpectedMessageError(expect, resp)
|
||||
else:
|
||||
return resp
|
||||
|
||||
def _refresh_features(self, features: messages.Features) -> None:
|
||||
"""Update internal fields based on passed-in Features message."""
|
||||
|
||||
if not self.model:
|
||||
self.model = models.detect(features)
|
||||
|
||||
if features.vendor not in self.model.vendors:
|
||||
raise exceptions.TrezorException(f"Unrecognized vendor: {features.vendor}")
|
||||
|
||||
self.features = features
|
||||
self.version = (
|
||||
self.features.major_version,
|
||||
self.features.minor_version,
|
||||
self.features.patch_version,
|
||||
)
|
||||
self.check_firmware_version(warn_only=True)
|
||||
if self.features.session_id is not None:
|
||||
self.session_id = self.features.session_id
|
||||
self.features.session_id = None
|
||||
|
||||
@session
|
||||
def refresh_features(self) -> messages.Features:
|
||||
"""Reload features from the device.
|
||||
|
||||
Should be called after changing settings or performing operations that affect
|
||||
device state.
|
||||
"""
|
||||
resp = self.call_raw(messages.GetFeatures())
|
||||
if not isinstance(resp, messages.Features):
|
||||
raise exceptions.TrezorException("Unexpected response to GetFeatures")
|
||||
self._refresh_features(resp)
|
||||
return resp
|
||||
|
||||
@session
|
||||
def init_device(
|
||||
def get_session(
|
||||
self,
|
||||
*,
|
||||
session_id: Optional[bytes] = None,
|
||||
new_session: bool = False,
|
||||
derive_cardano: Optional[bool] = None,
|
||||
) -> Optional[bytes]:
|
||||
"""Initialize the device and return a session ID.
|
||||
|
||||
You can optionally specify a session ID. If the session still exists on the
|
||||
device, the same session ID will be returned and the session is resumed.
|
||||
Otherwise a different session ID is returned.
|
||||
|
||||
Specify `new_session=True` to open a fresh session. Since firmware version
|
||||
1.9.0/2.3.0, the previous session will remain cached on the device, and can be
|
||||
resumed by calling `init_device` again with the appropriate session ID.
|
||||
|
||||
If neither `new_session` nor `session_id` is specified, the current session ID
|
||||
will be reused. If no session ID was cached, a new session ID will be allocated
|
||||
and returned.
|
||||
|
||||
# Version notes:
|
||||
|
||||
Trezor One older than 1.9.0 does not have session management. Optional arguments
|
||||
have no effect and the function returns None
|
||||
|
||||
Trezor T older than 2.3.0 does not have session cache. Requesting a new session
|
||||
will overwrite the old one. In addition, this function will always return None.
|
||||
A valid session_id can be obtained from the `session_id` attribute, but only
|
||||
after a passphrase-protected call is performed. You can use the following code:
|
||||
|
||||
>>> client.init_device()
|
||||
>>> client.ensure_unlocked()
|
||||
>>> valid_session_id = client.session_id
|
||||
passphrase: str | object | None = None,
|
||||
derive_cardano: bool = False,
|
||||
session_id: bytes | None = None,
|
||||
should_derive: bool = True,
|
||||
) -> Session:
|
||||
"""
|
||||
if new_session:
|
||||
self.session_id = None
|
||||
elif session_id is not None:
|
||||
self.session_id = session_id
|
||||
Returns initialized session (with derived seed).
|
||||
|
||||
resp = self.call_raw(
|
||||
messages.Initialize(
|
||||
session_id=self.session_id,
|
||||
Will fail if the device is not initialized
|
||||
"""
|
||||
from .transport.session import SessionV1, derive_seed
|
||||
|
||||
if isinstance(self.protocol, ProtocolV1Channel):
|
||||
session = SessionV1.new(
|
||||
self,
|
||||
derive_cardano=derive_cardano,
|
||||
session_id=session_id,
|
||||
)
|
||||
if should_derive:
|
||||
if isinstance(passphrase, str):
|
||||
temporary = self.passphrase_callback
|
||||
self.passphrase_callback = get_callback_passphrase_v1(
|
||||
passphrase=passphrase
|
||||
)
|
||||
derive_seed(session)
|
||||
self.passphrase_callback = temporary
|
||||
elif passphrase is PASSPHRASE_ON_DEVICE:
|
||||
derive_seed(session)
|
||||
|
||||
return session
|
||||
raise NotImplementedError
|
||||
|
||||
def resume_session(self, session: Session) -> Session:
|
||||
"""
|
||||
Note: this function potentially modifies the input session.
|
||||
"""
|
||||
from .transport.session import SessionV1
|
||||
|
||||
if isinstance(session, SessionV1):
|
||||
session.init_session()
|
||||
return session
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_seedless_session(self, new_session: bool = False) -> Session:
|
||||
from .transport.session import SessionV1
|
||||
|
||||
if not new_session and self._seedless_session is not None:
|
||||
return self._seedless_session
|
||||
if isinstance(self.protocol, ProtocolV1Channel):
|
||||
self._seedless_session = SessionV1.new(client=self, derive_cardano=False)
|
||||
assert self._seedless_session is not None
|
||||
return self._seedless_session
|
||||
|
||||
def invalidate(self) -> None:
|
||||
self._is_invalidated = True
|
||||
|
||||
@property
|
||||
def features(self) -> messages.Features:
|
||||
if self._features is None:
|
||||
self._features = self.protocol.get_features()
|
||||
self.check_firmware_version(warn_only=True)
|
||||
assert self._features is not None
|
||||
return self._features
|
||||
|
||||
@property
|
||||
def protocol_version(self) -> int:
|
||||
return self._protocol_version
|
||||
|
||||
@property
|
||||
def model(self) -> models.TrezorModel:
|
||||
model = models.detect(self.features)
|
||||
if self.features.vendor not in model.vendors:
|
||||
raise exceptions.TrezorException(
|
||||
f"Unrecognized vendor: {self.features.vendor}"
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def version(self) -> tuple[int, int, int]:
|
||||
f = self.features
|
||||
ver = (
|
||||
f.major_version,
|
||||
f.minor_version,
|
||||
f.patch_version,
|
||||
)
|
||||
if isinstance(resp, messages.Failure):
|
||||
# can happen if `derive_cardano` does not match the current session
|
||||
raise exceptions.TrezorFailure(resp)
|
||||
if not isinstance(resp, messages.Features):
|
||||
raise exceptions.TrezorException("Unexpected response to Initialize")
|
||||
return ver
|
||||
|
||||
if self.session_id is not None and resp.session_id == self.session_id:
|
||||
LOG.info("Successfully resumed session")
|
||||
elif session_id is not None:
|
||||
LOG.info("Failed to resume session")
|
||||
@property
|
||||
def is_invalidated(self) -> bool:
|
||||
return self._is_invalidated
|
||||
|
||||
# TT < 2.3.0 compatibility:
|
||||
# _refresh_features will clear out the session_id field. We want this function
|
||||
# to return its value, so that callers can rely on it being either a valid
|
||||
# session_id, or None if we can't do that.
|
||||
# Older TT FW does not report session_id in Features and self.session_id might
|
||||
# be invalid because TT will not allocate a session_id until a passphrase
|
||||
# exchange happens.
|
||||
reported_session_id = resp.session_id
|
||||
self._refresh_features(resp)
|
||||
return reported_session_id
|
||||
def refresh_features(self) -> messages.Features:
|
||||
self.protocol.update_features()
|
||||
self._features = self.protocol.get_features()
|
||||
self.check_firmware_version(warn_only=True)
|
||||
return self._features
|
||||
|
||||
def _get_protocol(self) -> Channel:
|
||||
protocol = ProtocolV1Channel(self.transport, mapping.DEFAULT_MAPPING)
|
||||
return protocol
|
||||
|
||||
def is_outdated(self) -> bool:
|
||||
if self.features.bootloader_mode:
|
||||
@ -388,101 +223,35 @@ class TrezorClient(Generic[UI]):
|
||||
else:
|
||||
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
|
||||
|
||||
def ping(self, msg: str, button_protection: bool = False) -> str:
|
||||
# We would like ping to work on any valid TrezorClient instance, but
|
||||
# due to the protection modes, we need to go through self.call, and that will
|
||||
# raise an exception if the firmware is too old.
|
||||
# So we short-circuit the simplest variant of ping with call_raw.
|
||||
if not button_protection:
|
||||
# XXX this should be: `with self:`
|
||||
try:
|
||||
self.open()
|
||||
resp = self.call_raw(messages.Ping(message=msg))
|
||||
if isinstance(resp, messages.ButtonRequest):
|
||||
# device is PIN-locked.
|
||||
# respond and hope for the best
|
||||
resp = self._callback_button(resp)
|
||||
resp = messages.Success.ensure_isinstance(resp)
|
||||
assert resp.message is not None
|
||||
return resp.message
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
resp = self.call(
|
||||
messages.Ping(message=msg, button_protection=button_protection),
|
||||
expect=messages.Success,
|
||||
)
|
||||
assert resp.message is not None
|
||||
return resp.message
|
||||
def get_default_client(
|
||||
path: t.Optional[str] = None,
|
||||
**kwargs: t.Any,
|
||||
) -> "TrezorClient":
|
||||
"""Get a client for a connected Trezor device.
|
||||
|
||||
def get_device_id(self) -> Optional[str]:
|
||||
return self.features.device_id
|
||||
Returns a TrezorClient instance with minimum fuss.
|
||||
|
||||
@session
|
||||
def lock(self, *, _refresh_features: bool = True) -> None:
|
||||
"""Lock the device.
|
||||
Transport is opened and should be closed after you're done with the client.
|
||||
|
||||
If the device does not have a PIN configured, this will do nothing.
|
||||
Otherwise, a lock screen will be shown and the device will prompt for PIN
|
||||
before further actions.
|
||||
If path is specified, does a prefix-search for the specified device. Otherwise, uses
|
||||
the value of TREZOR_PATH env variable, or finds first connected Trezor.
|
||||
"""
|
||||
|
||||
This call does _not_ invalidate passphrase cache. If passphrase is in use,
|
||||
the device will not prompt for it after unlocking.
|
||||
if path is None:
|
||||
path = os.getenv("TREZOR_PATH")
|
||||
|
||||
To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate
|
||||
passphrase cache, use `clear_session()`.
|
||||
"""
|
||||
# Private argument _refresh_features can be used internally to avoid
|
||||
# refreshing in cases where we will refresh soon anyway. This is used
|
||||
# in TrezorClient.clear_session()
|
||||
self.call(messages.LockDevice())
|
||||
if _refresh_features:
|
||||
self.refresh_features()
|
||||
transport = get_transport(path, prefix_search=True)
|
||||
transport.open()
|
||||
|
||||
@session
|
||||
def ensure_unlocked(self) -> None:
|
||||
"""Ensure the device is unlocked and a passphrase is cached.
|
||||
return TrezorClient(transport, **kwargs)
|
||||
|
||||
If the device is locked, this will prompt for PIN. If passphrase is enabled
|
||||
and no passphrase is cached for the current session, the device will also
|
||||
prompt for passphrase.
|
||||
|
||||
After calling this method, further actions on the device will not prompt for
|
||||
PIN or passphrase until the device is locked or the session becomes invalid.
|
||||
"""
|
||||
from .btc import get_address
|
||||
def get_callback_passphrase_v1(
|
||||
passphrase: str = "",
|
||||
) -> t.Callable[[Session, t.Any], t.Any] | None:
|
||||
|
||||
get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
|
||||
self.refresh_features()
|
||||
def _callback_passphrase_v1(session: Session, msg: t.Any) -> t.Any:
|
||||
return session.call(messages.PassphraseAck(passphrase=passphrase))
|
||||
|
||||
def end_session(self) -> None:
|
||||
"""Close the current session and clear cached passphrase.
|
||||
|
||||
The session will become invalid until `init_device()` is called again.
|
||||
If passphrase is enabled, further actions will prompt for it again.
|
||||
|
||||
This is a no-op in bootloader mode, as it does not support session management.
|
||||
"""
|
||||
# since: 2.3.4, 1.9.4
|
||||
try:
|
||||
if not self.features.bootloader_mode:
|
||||
self.call(messages.EndSession())
|
||||
except exceptions.TrezorFailure:
|
||||
# A failure most likely means that the FW version does not support
|
||||
# the EndSession call. We ignore the failure and clear the local session_id.
|
||||
# The client-side end result is identical.
|
||||
pass
|
||||
self.session_id = None
|
||||
|
||||
@session
|
||||
def clear_session(self) -> None:
|
||||
"""Lock the device and present a fresh session.
|
||||
|
||||
The current session will be invalidated and a new one will be started. If the
|
||||
device has PIN enabled, it will become locked.
|
||||
|
||||
Equivalent to calling `lock()`, `end_session()` and `init_device()`.
|
||||
"""
|
||||
self.lock(_refresh_features=False)
|
||||
self.end_session()
|
||||
self.init_device(new_session=True)
|
||||
return _callback_passphrase_v1
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -28,16 +28,10 @@ from slip10 import SLIP10
|
||||
|
||||
from . import messages
|
||||
from .exceptions import Cancelled, TrezorException
|
||||
from .tools import (
|
||||
Address,
|
||||
_deprecation_retval_helper,
|
||||
_return_success,
|
||||
parse_path,
|
||||
session,
|
||||
)
|
||||
from .tools import Address, _deprecation_retval_helper, _return_success, parse_path
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
RECOVERY_BACK = "\x08" # backspace character, sent literally
|
||||
@ -46,9 +40,8 @@ SLIP39_EXTENDABLE_MIN_VERSION = (2, 7, 1)
|
||||
ENTROPY_CHECK_MIN_VERSION = (2, 8, 7)
|
||||
|
||||
|
||||
@session
|
||||
def apply_settings(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
label: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
use_passphrase: Optional[bool] = None,
|
||||
@ -79,13 +72,13 @@ def apply_settings(
|
||||
haptic_feedback=haptic_feedback,
|
||||
)
|
||||
|
||||
out = client.call(settings, expect=messages.Success)
|
||||
client.refresh_features()
|
||||
out = session.call(settings, expect=messages.Success)
|
||||
session.refresh_features()
|
||||
return _return_success(out)
|
||||
|
||||
|
||||
def _send_language_data(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
request: "messages.TranslationDataRequest",
|
||||
language_data: bytes,
|
||||
) -> None:
|
||||
@ -95,69 +88,61 @@ def _send_language_data(
|
||||
data_length = response.data_length
|
||||
data_offset = response.data_offset
|
||||
chunk = language_data[data_offset : data_offset + data_length]
|
||||
response = client.call(messages.TranslationDataAck(data_chunk=chunk))
|
||||
response = session.call(messages.TranslationDataAck(data_chunk=chunk))
|
||||
|
||||
|
||||
@session
|
||||
def change_language(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
language_data: bytes,
|
||||
show_display: bool | None = None,
|
||||
) -> str | None:
|
||||
data_length = len(language_data)
|
||||
msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)
|
||||
|
||||
response = client.call(msg)
|
||||
response = session.call(msg)
|
||||
if data_length > 0:
|
||||
response = messages.TranslationDataRequest.ensure_isinstance(response)
|
||||
_send_language_data(client, response, language_data)
|
||||
_send_language_data(session, response, language_data)
|
||||
else:
|
||||
messages.Success.ensure_isinstance(response)
|
||||
client.refresh_features() # changing the language in features
|
||||
session.refresh_features() # changing the language in features
|
||||
return _return_success(messages.Success(message="Language changed."))
|
||||
|
||||
|
||||
@session
|
||||
def apply_flags(client: "TrezorClient", flags: int) -> str | None:
|
||||
out = client.call(messages.ApplyFlags(flags=flags), expect=messages.Success)
|
||||
client.refresh_features()
|
||||
def apply_flags(session: "Session", flags: int) -> str | None:
|
||||
out = session.call(messages.ApplyFlags(flags=flags), expect=messages.Success)
|
||||
session.refresh_features()
|
||||
return _return_success(out)
|
||||
|
||||
|
||||
@session
|
||||
def change_pin(client: "TrezorClient", remove: bool = False) -> str | None:
|
||||
ret = client.call(messages.ChangePin(remove=remove), expect=messages.Success)
|
||||
client.refresh_features()
|
||||
def change_pin(session: "Session", remove: bool = False) -> str | None:
|
||||
ret = session.call(messages.ChangePin(remove=remove), expect=messages.Success)
|
||||
session.refresh_features()
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
@session
|
||||
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> str | None:
|
||||
ret = client.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success)
|
||||
client.refresh_features()
|
||||
def change_wipe_code(session: "Session", remove: bool = False) -> str | None:
|
||||
ret = session.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success)
|
||||
session.refresh_features()
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
@session
|
||||
def sd_protect(
|
||||
client: "TrezorClient", operation: messages.SdProtectOperationType
|
||||
session: "Session", operation: messages.SdProtectOperationType
|
||||
) -> str | None:
|
||||
ret = client.call(messages.SdProtect(operation=operation), expect=messages.Success)
|
||||
client.refresh_features()
|
||||
ret = session.call(messages.SdProtect(operation=operation), expect=messages.Success)
|
||||
session.refresh_features()
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
@session
|
||||
def wipe(client: "TrezorClient") -> str | None:
|
||||
ret = client.call(messages.WipeDevice(), expect=messages.Success)
|
||||
if not client.features.bootloader_mode:
|
||||
client.init_device()
|
||||
def wipe(session: "Session") -> str | None:
|
||||
ret = session.call(messages.WipeDevice(), expect=messages.Success)
|
||||
session.invalidate()
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
@session
|
||||
def recover(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
word_count: int = 24,
|
||||
passphrase_protection: bool = False,
|
||||
pin_protection: bool = True,
|
||||
@ -193,13 +178,13 @@ def recover(
|
||||
if type is None:
|
||||
type = messages.RecoveryType.NormalRecovery
|
||||
|
||||
if client.features.model == "1" and input_callback is None:
|
||||
if session.features.model == "1" and input_callback is None:
|
||||
raise RuntimeError("Input callback required for Trezor One")
|
||||
|
||||
if word_count not in (12, 18, 24):
|
||||
raise ValueError("Invalid word count. Use 12/18/24")
|
||||
|
||||
if client.features.initialized and type == messages.RecoveryType.NormalRecovery:
|
||||
if session.features.initialized and type == messages.RecoveryType.NormalRecovery:
|
||||
raise RuntimeError(
|
||||
"Device already initialized. Call device.wipe() and try again."
|
||||
)
|
||||
@ -221,20 +206,20 @@ def recover(
|
||||
msg.label = label
|
||||
msg.u2f_counter = u2f_counter
|
||||
|
||||
res = client.call(msg)
|
||||
res = session.call(msg)
|
||||
|
||||
while isinstance(res, messages.WordRequest):
|
||||
try:
|
||||
assert input_callback is not None
|
||||
inp = input_callback(res.type)
|
||||
res = client.call(messages.WordAck(word=inp))
|
||||
res = session.call(messages.WordAck(word=inp))
|
||||
except Cancelled:
|
||||
res = client.call(messages.Cancel())
|
||||
res = session.call(messages.Cancel())
|
||||
|
||||
# check that the result is a Success
|
||||
res = messages.Success.ensure_isinstance(res)
|
||||
# reinitialize the device
|
||||
client.init_device()
|
||||
session.refresh_features()
|
||||
|
||||
return _deprecation_retval_helper(res)
|
||||
|
||||
@ -280,7 +265,7 @@ def _seed_from_entropy(
|
||||
|
||||
|
||||
def reset(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
display_random: bool = False,
|
||||
strength: Optional[int] = None,
|
||||
passphrase_protection: bool = False,
|
||||
@ -313,7 +298,7 @@ def reset(
|
||||
)
|
||||
|
||||
setup(
|
||||
client,
|
||||
session,
|
||||
strength=strength,
|
||||
passphrase_protection=passphrase_protection,
|
||||
pin_protection=pin_protection,
|
||||
@ -331,9 +316,8 @@ def _get_external_entropy() -> bytes:
|
||||
return secrets.token_bytes(32)
|
||||
|
||||
|
||||
@session
|
||||
def setup(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
*,
|
||||
strength: Optional[int] = None,
|
||||
passphrase_protection: bool = True,
|
||||
@ -388,19 +372,19 @@ def setup(
|
||||
check.
|
||||
"""
|
||||
|
||||
if client.features.initialized:
|
||||
if session.features.initialized:
|
||||
raise RuntimeError(
|
||||
"Device is initialized already. Call wipe_device() and try again."
|
||||
)
|
||||
|
||||
if strength is None:
|
||||
if client.features.model == "1":
|
||||
if session.features.model == "1":
|
||||
strength = 256
|
||||
else:
|
||||
strength = 128
|
||||
|
||||
if backup_type is None:
|
||||
if client.version < SLIP39_EXTENDABLE_MIN_VERSION:
|
||||
if session.version < SLIP39_EXTENDABLE_MIN_VERSION:
|
||||
# includes Trezor One 1.x.x
|
||||
backup_type = messages.BackupType.Bip39
|
||||
else:
|
||||
@ -411,7 +395,7 @@ def setup(
|
||||
paths = [parse_path("m/84h/0h/0h"), parse_path("m/44h/60h/0h")]
|
||||
|
||||
if entropy_check_count is None:
|
||||
if client.version < ENTROPY_CHECK_MIN_VERSION:
|
||||
if session.version < ENTROPY_CHECK_MIN_VERSION:
|
||||
# includes Trezor One 1.x.x
|
||||
entropy_check_count = 0
|
||||
else:
|
||||
@ -431,18 +415,18 @@ def setup(
|
||||
)
|
||||
if entropy_check_count > 0:
|
||||
xpubs = _reset_with_entropycheck(
|
||||
client, msg, entropy_check_count, paths, _get_entropy
|
||||
session, msg, entropy_check_count, paths, _get_entropy
|
||||
)
|
||||
else:
|
||||
_reset_no_entropycheck(client, msg, _get_entropy)
|
||||
_reset_no_entropycheck(session, msg, _get_entropy)
|
||||
xpubs = []
|
||||
|
||||
client.init_device()
|
||||
session.refresh_features()
|
||||
return xpubs
|
||||
|
||||
|
||||
def _reset_no_entropycheck(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
msg: messages.ResetDevice,
|
||||
get_entropy: Callable[[], bytes],
|
||||
) -> None:
|
||||
@ -454,12 +438,12 @@ def _reset_no_entropycheck(
|
||||
<< Success
|
||||
"""
|
||||
assert msg.entropy_check is False
|
||||
client.call(msg, expect=messages.EntropyRequest)
|
||||
client.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success)
|
||||
session.call(msg, expect=messages.EntropyRequest)
|
||||
session.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success)
|
||||
|
||||
|
||||
def _reset_with_entropycheck(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
reset_msg: messages.ResetDevice,
|
||||
entropy_check_count: int,
|
||||
paths: Iterable[Address],
|
||||
@ -495,7 +479,7 @@ def _reset_with_entropycheck(
|
||||
def get_xpubs() -> list[tuple[Address, str]]:
|
||||
xpubs = []
|
||||
for path in paths:
|
||||
resp = client.call(
|
||||
resp = session.call(
|
||||
messages.GetPublicKey(address_n=path), expect=messages.PublicKey
|
||||
)
|
||||
xpubs.append((path, resp.xpub))
|
||||
@ -524,13 +508,13 @@ def _reset_with_entropycheck(
|
||||
raise TrezorException("Invalid XPUB in entropy check")
|
||||
|
||||
xpubs = []
|
||||
resp = client.call(reset_msg, expect=messages.EntropyRequest)
|
||||
resp = session.call(reset_msg, expect=messages.EntropyRequest)
|
||||
entropy_commitment = resp.entropy_commitment
|
||||
|
||||
while True:
|
||||
# provide external entropy for this round
|
||||
external_entropy = get_entropy()
|
||||
client.call(
|
||||
session.call(
|
||||
messages.EntropyAck(entropy=external_entropy),
|
||||
expect=messages.EntropyCheckReady,
|
||||
)
|
||||
@ -540,7 +524,7 @@ def _reset_with_entropycheck(
|
||||
|
||||
if entropy_check_count <= 0:
|
||||
# last round, wait for a Success and exit the loop
|
||||
client.call(
|
||||
session.call(
|
||||
messages.EntropyCheckContinue(finish=True),
|
||||
expect=messages.Success,
|
||||
)
|
||||
@ -549,7 +533,7 @@ def _reset_with_entropycheck(
|
||||
entropy_check_count -= 1
|
||||
|
||||
# Next round starts.
|
||||
resp = client.call(
|
||||
resp = session.call(
|
||||
messages.EntropyCheckContinue(finish=False),
|
||||
expect=messages.EntropyRequest,
|
||||
)
|
||||
@ -570,13 +554,12 @@ def _reset_with_entropycheck(
|
||||
return xpubs
|
||||
|
||||
|
||||
@session
|
||||
def backup(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
group_threshold: Optional[int] = None,
|
||||
groups: Iterable[tuple[int, int]] = (),
|
||||
) -> str | None:
|
||||
ret = client.call(
|
||||
ret = session.call(
|
||||
messages.BackupDevice(
|
||||
group_threshold=group_threshold,
|
||||
groups=[
|
||||
@ -586,37 +569,36 @@ def backup(
|
||||
),
|
||||
expect=messages.Success,
|
||||
)
|
||||
client.refresh_features()
|
||||
session.refresh_features()
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
def cancel_authorization(client: "TrezorClient") -> str | None:
|
||||
ret = client.call(messages.CancelAuthorization(), expect=messages.Success)
|
||||
def cancel_authorization(session: "Session") -> str | None:
|
||||
ret = session.call(messages.CancelAuthorization(), expect=messages.Success)
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
def unlock_path(client: "TrezorClient", n: "Address") -> bytes:
|
||||
resp = client.call(
|
||||
def unlock_path(session: "Session", n: "Address") -> bytes:
|
||||
resp = session.call(
|
||||
messages.UnlockPath(address_n=n), expect=messages.UnlockedPathRequest
|
||||
)
|
||||
|
||||
# Cancel the UnlockPath workflow now that we have the authentication code.
|
||||
try:
|
||||
client.call(messages.Cancel())
|
||||
session.call(messages.Cancel())
|
||||
except Cancelled:
|
||||
return resp.mac
|
||||
else:
|
||||
raise TrezorException("Unexpected response in UnlockPath flow")
|
||||
|
||||
|
||||
@session
|
||||
def reboot_to_bootloader(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT,
|
||||
firmware_header: Optional[bytes] = None,
|
||||
language_data: bytes = b"",
|
||||
) -> str | None:
|
||||
response = client.call(
|
||||
response = session.call(
|
||||
messages.RebootToBootloader(
|
||||
boot_command=boot_command,
|
||||
firmware_header=firmware_header,
|
||||
@ -624,43 +606,38 @@ def reboot_to_bootloader(
|
||||
)
|
||||
)
|
||||
if isinstance(response, messages.TranslationDataRequest):
|
||||
response = _send_language_data(client, response, language_data)
|
||||
response = _send_language_data(session, response, language_data)
|
||||
return _return_success(messages.Success(message=""))
|
||||
|
||||
|
||||
@session
|
||||
def show_device_tutorial(client: "TrezorClient") -> str | None:
|
||||
ret = client.call(messages.ShowDeviceTutorial(), expect=messages.Success)
|
||||
def show_device_tutorial(session: "Session") -> str | None:
|
||||
ret = session.call(messages.ShowDeviceTutorial(), expect=messages.Success)
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
@session
|
||||
def unlock_bootloader(client: "TrezorClient") -> str | None:
|
||||
ret = client.call(messages.UnlockBootloader(), expect=messages.Success)
|
||||
def unlock_bootloader(session: "Session") -> str | None:
|
||||
ret = session.call(messages.UnlockBootloader(), expect=messages.Success)
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
@session
|
||||
def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> str | None:
|
||||
def set_busy(session: "Session", expiry_ms: Optional[int]) -> str | None:
|
||||
"""Sets or clears the busy state of the device.
|
||||
|
||||
In the busy state the device shows a "Do not disconnect" message instead of the homescreen.
|
||||
Setting `expiry_ms=None` clears the busy state.
|
||||
"""
|
||||
ret = client.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success)
|
||||
client.refresh_features()
|
||||
ret = session.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success)
|
||||
session.refresh_features()
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
def authenticate(
|
||||
client: "TrezorClient", challenge: bytes
|
||||
) -> messages.AuthenticityProof:
|
||||
return client.call(
|
||||
def authenticate(session: "Session", challenge: bytes) -> messages.AuthenticityProof:
|
||||
return session.call(
|
||||
messages.AuthenticateDevice(challenge=challenge),
|
||||
expect=messages.AuthenticityProof,
|
||||
)
|
||||
|
||||
|
||||
def set_brightness(client: "TrezorClient", value: Optional[int] = None) -> str | None:
|
||||
ret = client.call(messages.SetBrightness(value=value), expect=messages.Success)
|
||||
def set_brightness(session: "Session", value: Optional[int] = None) -> str | None:
|
||||
ret = session.call(messages.SetBrightness(value=value), expect=messages.Success)
|
||||
return _return_success(ret)
|
||||
|
@ -18,11 +18,11 @@ from datetime import datetime
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
|
||||
from . import exceptions, messages
|
||||
from .tools import b58decode, session
|
||||
from .tools import b58decode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def name_to_number(name: str) -> int:
|
||||
@ -319,17 +319,16 @@ def parse_transaction_json(
|
||||
|
||||
|
||||
def get_public_key(
|
||||
client: "TrezorClient", n: "Address", show_display: bool = False
|
||||
session: "Session", n: "Address", show_display: bool = False
|
||||
) -> messages.EosPublicKey:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.EosGetPublicKey(address_n=n, show_display=show_display),
|
||||
expect=messages.EosPublicKey,
|
||||
)
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: "Address",
|
||||
transaction: dict,
|
||||
chain_id: str,
|
||||
@ -345,11 +344,11 @@ def sign_tx(
|
||||
chunkify=chunkify,
|
||||
)
|
||||
|
||||
response = client.call(msg)
|
||||
response = session.call(msg)
|
||||
|
||||
try:
|
||||
while isinstance(response, messages.EosTxActionRequest):
|
||||
response = client.call(actions.pop(0))
|
||||
response = session.call(actions.pop(0))
|
||||
except IndexError:
|
||||
# pop from empty list
|
||||
raise exceptions.TrezorException(
|
||||
|
@ -18,11 +18,11 @@ import re
|
||||
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
|
||||
|
||||
from . import definitions, exceptions, messages
|
||||
from .tools import prepare_message_bytes, session, unharden
|
||||
from .tools import prepare_message_bytes, unharden
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def int_to_big_endian(value: int) -> bytes:
|
||||
@ -161,13 +161,13 @@ def network_from_address_n(
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
show_display: bool = False,
|
||||
encoded_network: Optional[bytes] = None,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
resp = client.call(
|
||||
resp = session.call(
|
||||
messages.EthereumGetAddress(
|
||||
address_n=n,
|
||||
show_display=show_display,
|
||||
@ -181,17 +181,16 @@ def get_address(
|
||||
|
||||
|
||||
def get_public_node(
|
||||
client: "TrezorClient", n: "Address", show_display: bool = False
|
||||
session: "Session", n: "Address", show_display: bool = False
|
||||
) -> messages.EthereumPublicKey:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.EthereumGetPublicKey(address_n=n, show_display=show_display),
|
||||
expect=messages.EthereumPublicKey,
|
||||
)
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
nonce: int,
|
||||
gas_price: int,
|
||||
@ -227,13 +226,13 @@ def sign_tx(
|
||||
data, chunk = data[1024:], data[:1024]
|
||||
msg.data_initial_chunk = chunk
|
||||
|
||||
response = client.call(msg)
|
||||
response = session.call(msg)
|
||||
assert isinstance(response, messages.EthereumTxRequest)
|
||||
|
||||
while response.data_length is not None:
|
||||
data_length = response.data_length
|
||||
data, chunk = data[data_length:], data[:data_length]
|
||||
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||
response = session.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||
assert isinstance(response, messages.EthereumTxRequest)
|
||||
|
||||
assert response.signature_v is not None
|
||||
@ -248,9 +247,8 @@ def sign_tx(
|
||||
return response.signature_v, response.signature_r, response.signature_s
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx_eip1559(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
*,
|
||||
nonce: int,
|
||||
@ -283,13 +281,13 @@ def sign_tx_eip1559(
|
||||
chunkify=chunkify,
|
||||
)
|
||||
|
||||
response = client.call(msg)
|
||||
response = session.call(msg)
|
||||
assert isinstance(response, messages.EthereumTxRequest)
|
||||
|
||||
while response.data_length is not None:
|
||||
data_length = response.data_length
|
||||
data, chunk = data[data_length:], data[:data_length]
|
||||
response = client.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||
response = session.call(messages.EthereumTxAck(data_chunk=chunk))
|
||||
assert isinstance(response, messages.EthereumTxRequest)
|
||||
|
||||
assert response.signature_v is not None
|
||||
@ -299,13 +297,13 @@ def sign_tx_eip1559(
|
||||
|
||||
|
||||
def sign_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
message: AnyStr,
|
||||
encoded_network: Optional[bytes] = None,
|
||||
chunkify: bool = False,
|
||||
) -> messages.EthereumMessageSignature:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.EthereumSignMessage(
|
||||
address_n=n,
|
||||
message=prepare_message_bytes(message),
|
||||
@ -317,7 +315,7 @@ def sign_message(
|
||||
|
||||
|
||||
def sign_typed_data(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
data: Dict[str, Any],
|
||||
*,
|
||||
@ -333,7 +331,7 @@ def sign_typed_data(
|
||||
metamask_v4_compat=metamask_v4_compat,
|
||||
definitions=definitions,
|
||||
)
|
||||
response = client.call(request)
|
||||
response = session.call(request)
|
||||
|
||||
# Sending all the types
|
||||
while isinstance(response, messages.EthereumTypedDataStructRequest):
|
||||
@ -349,7 +347,7 @@ def sign_typed_data(
|
||||
members.append(struct_member)
|
||||
|
||||
request = messages.EthereumTypedDataStructAck(members=members)
|
||||
response = client.call(request)
|
||||
response = session.call(request)
|
||||
|
||||
# Sending the whole message that should be signed
|
||||
while isinstance(response, messages.EthereumTypedDataValueRequest):
|
||||
@ -362,7 +360,7 @@ def sign_typed_data(
|
||||
member_typename = data["primaryType"]
|
||||
member_data = data["message"]
|
||||
else:
|
||||
client.cancel()
|
||||
session.cancel()
|
||||
raise exceptions.TrezorException("Root index can only be 0 or 1")
|
||||
|
||||
# It can be asking for a nested structure (the member path being [X, Y, Z, ...])
|
||||
@ -385,20 +383,20 @@ def sign_typed_data(
|
||||
encoded_data = encode_data(member_data, member_typename)
|
||||
|
||||
request = messages.EthereumTypedDataValueAck(value=encoded_data)
|
||||
response = client.call(request)
|
||||
response = session.call(request)
|
||||
|
||||
return messages.EthereumTypedDataSignature.ensure_isinstance(response)
|
||||
|
||||
|
||||
def verify_message(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address: str,
|
||||
signature: bytes,
|
||||
message: AnyStr,
|
||||
chunkify: bool = False,
|
||||
) -> bool:
|
||||
try:
|
||||
client.call(
|
||||
session.call(
|
||||
messages.EthereumVerifyMessage(
|
||||
address=address,
|
||||
signature=signature,
|
||||
@ -413,13 +411,13 @@ def verify_message(
|
||||
|
||||
|
||||
def sign_typed_data_hash(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
domain_hash: bytes,
|
||||
message_hash: Optional[bytes],
|
||||
encoded_network: Optional[bytes] = None,
|
||||
) -> messages.EthereumTypedDataSignature:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.EthereumSignTypedHash(
|
||||
address_n=n,
|
||||
domain_separator_hash=domain_hash,
|
||||
|
@ -85,3 +85,10 @@ class UnexpectedMessageError(TrezorException):
|
||||
self.expected = expected
|
||||
self.actual = actual
|
||||
super().__init__(f"Expected {expected.__name__} but Trezor sent {actual}")
|
||||
|
||||
|
||||
class FailedSessionResumption(TrezorException):
|
||||
"""Provided session_id is not valid / session cannot be resumed.
|
||||
|
||||
Raised when `trezorctl -s <sesssion_id>` is used or `TREZOR_SESSION_ID = <session_id>`
|
||||
is set and resumption of session with the `session_id` fails."""
|
||||
|
@ -22,37 +22,37 @@ from . import messages
|
||||
from .tools import _return_success
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def list_credentials(client: "TrezorClient") -> Sequence[messages.WebAuthnCredential]:
|
||||
return client.call(
|
||||
def list_credentials(session: "Session") -> Sequence[messages.WebAuthnCredential]:
|
||||
return session.call(
|
||||
messages.WebAuthnListResidentCredentials(), expect=messages.WebAuthnCredentials
|
||||
).credentials
|
||||
|
||||
|
||||
def add_credential(client: "TrezorClient", credential_id: bytes) -> str | None:
|
||||
ret = client.call(
|
||||
def add_credential(session: "Session", credential_id: bytes) -> str | None:
|
||||
ret = session.call(
|
||||
messages.WebAuthnAddResidentCredential(credential_id=credential_id),
|
||||
expect=messages.Success,
|
||||
)
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
def remove_credential(client: "TrezorClient", index: int) -> str | None:
|
||||
ret = client.call(
|
||||
def remove_credential(session: "Session", index: int) -> str | None:
|
||||
ret = session.call(
|
||||
messages.WebAuthnRemoveResidentCredential(index=index), expect=messages.Success
|
||||
)
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
def set_counter(client: "TrezorClient", u2f_counter: int) -> str | None:
|
||||
ret = client.call(
|
||||
def set_counter(session: "Session", u2f_counter: int) -> str | None:
|
||||
ret = session.call(
|
||||
messages.SetU2FCounter(u2f_counter=u2f_counter), expect=messages.Success
|
||||
)
|
||||
return _return_success(ret)
|
||||
|
||||
|
||||
def get_next_counter(client: "TrezorClient") -> int:
|
||||
ret = client.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter)
|
||||
def get_next_counter(session: "Session") -> int:
|
||||
ret = session.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter)
|
||||
return ret.u2f_counter
|
||||
|
@ -22,7 +22,6 @@ from hashlib import blake2s
|
||||
from typing_extensions import Protocol, TypeGuard
|
||||
|
||||
from .. import messages
|
||||
from ..tools import session
|
||||
from .core import VendorFirmware
|
||||
from .legacy import LegacyFirmware, LegacyV2Firmware
|
||||
from .models import Model
|
||||
@ -41,7 +40,7 @@ if True:
|
||||
from .vendor import * # noqa: F401, F403
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
from ..transport.session import Session
|
||||
|
||||
T = t.TypeVar("T", bound="FirmwareType")
|
||||
|
||||
@ -77,20 +76,19 @@ def is_onev2(fw: FirmwareType) -> TypeGuard[LegacyFirmware]:
|
||||
# ====== Client functions ====== #
|
||||
|
||||
|
||||
@session
|
||||
def update(
|
||||
client: TrezorClient,
|
||||
session: Session,
|
||||
data: bytes,
|
||||
progress_update: t.Callable[[int], t.Any] = lambda _: None,
|
||||
):
|
||||
if client.features.bootloader_mode is False:
|
||||
if session.features.bootloader_mode is False:
|
||||
raise RuntimeError("Device must be in bootloader mode")
|
||||
|
||||
resp = client.call(messages.FirmwareErase(length=len(data)))
|
||||
resp = session.call(messages.FirmwareErase(length=len(data)))
|
||||
|
||||
# TREZORv1 method
|
||||
if isinstance(resp, messages.Success):
|
||||
resp = client.call(messages.FirmwareUpload(payload=data))
|
||||
resp = session.call(messages.FirmwareUpload(payload=data))
|
||||
progress_update(len(data))
|
||||
if isinstance(resp, messages.Success):
|
||||
return
|
||||
@ -102,7 +100,7 @@ def update(
|
||||
length = resp.length
|
||||
payload = data[resp.offset : resp.offset + length]
|
||||
digest = blake2s(payload).digest()
|
||||
resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest))
|
||||
resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest))
|
||||
progress_update(length)
|
||||
|
||||
if isinstance(resp, messages.Success):
|
||||
@ -111,7 +109,7 @@ def update(
|
||||
raise RuntimeError(f"Unexpected message {resp}")
|
||||
|
||||
|
||||
def get_hash(client: TrezorClient, challenge: bytes | None) -> bytes:
|
||||
return client.call(
|
||||
def get_hash(session: Session, challenge: bytes | None) -> bytes:
|
||||
return session.call(
|
||||
messages.GetFirmwareHash(challenge=challenge), expect=messages.FirmwareHash
|
||||
).hash
|
||||
|
@ -85,6 +85,7 @@ class ProtobufMapping:
|
||||
mapping = cls()
|
||||
|
||||
message_types = getattr(module, "MessageType")
|
||||
|
||||
for entry in message_types:
|
||||
msg_class = getattr(module, entry.name, None)
|
||||
if msg_class is None:
|
||||
|
@ -19,22 +19,22 @@ from typing import TYPE_CHECKING, Optional
|
||||
from . import messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def get_entropy(client: "TrezorClient", size: int) -> bytes:
|
||||
return client.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy
|
||||
def get_entropy(session: "Session", size: int) -> bytes:
|
||||
return session.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy
|
||||
|
||||
|
||||
def sign_identity(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
identity: messages.IdentityType,
|
||||
challenge_hidden: bytes,
|
||||
challenge_visual: str,
|
||||
ecdsa_curve_name: Optional[str] = None,
|
||||
) -> messages.SignedIdentity:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.SignIdentity(
|
||||
identity=identity,
|
||||
challenge_hidden=challenge_hidden,
|
||||
@ -46,12 +46,12 @@ def sign_identity(
|
||||
|
||||
|
||||
def get_ecdh_session_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
identity: messages.IdentityType,
|
||||
peer_public_key: bytes,
|
||||
ecdsa_curve_name: Optional[str] = None,
|
||||
) -> messages.ECDHSessionKey:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.GetECDHSessionKey(
|
||||
identity=identity,
|
||||
peer_public_key=peer_public_key,
|
||||
@ -62,7 +62,7 @@ def get_ecdh_session_key(
|
||||
|
||||
|
||||
def encrypt_keyvalue(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
key: str,
|
||||
value: bytes,
|
||||
@ -70,7 +70,7 @@ def encrypt_keyvalue(
|
||||
ask_on_decrypt: bool = True,
|
||||
iv: bytes = b"",
|
||||
) -> bytes:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.CipherKeyValue(
|
||||
address_n=n,
|
||||
key=key,
|
||||
@ -85,7 +85,7 @@ def encrypt_keyvalue(
|
||||
|
||||
|
||||
def decrypt_keyvalue(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
key: str,
|
||||
value: bytes,
|
||||
@ -93,7 +93,7 @@ def decrypt_keyvalue(
|
||||
ask_on_decrypt: bool = True,
|
||||
iv: bytes = b"",
|
||||
) -> bytes:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.CipherKeyValue(
|
||||
address_n=n,
|
||||
key=key,
|
||||
@ -107,5 +107,5 @@ def decrypt_keyvalue(
|
||||
).value
|
||||
|
||||
|
||||
def get_nonce(client: "TrezorClient") -> bytes:
|
||||
return client.call(messages.GetNonce(), expect=messages.Nonce).nonce
|
||||
def get_nonce(session: "Session") -> bytes:
|
||||
return session.call(messages.GetNonce(), expect=messages.Nonce).nonce
|
||||
|
@ -19,8 +19,8 @@ from typing import TYPE_CHECKING
|
||||
from . import messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
# MAINNET = 0
|
||||
@ -30,13 +30,13 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
show_display: bool = False,
|
||||
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
|
||||
chunkify: bool = False,
|
||||
) -> bytes:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.MoneroGetAddress(
|
||||
address_n=n,
|
||||
show_display=show_display,
|
||||
@ -48,11 +48,11 @@ def get_address(
|
||||
|
||||
|
||||
def get_watch_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
|
||||
) -> messages.MoneroWatchKey:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.MoneroGetWatchKey(address_n=n, network_type=network_type),
|
||||
expect=messages.MoneroWatchKey,
|
||||
)
|
||||
|
@ -20,8 +20,8 @@ from typing import TYPE_CHECKING
|
||||
from . import exceptions, messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
TYPE_TRANSACTION_TRANSFER = 0x0101
|
||||
TYPE_IMPORTANCE_TRANSFER = 0x0801
|
||||
@ -195,13 +195,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
n: "Address",
|
||||
network: int,
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.NEMGetAddress(
|
||||
address_n=n, network=network, show_display=show_display, chunkify=chunkify
|
||||
),
|
||||
@ -210,7 +210,7 @@ def get_address(
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False
|
||||
session: "Session", n: "Address", transaction: dict, chunkify: bool = False
|
||||
) -> messages.NEMSignedTx:
|
||||
try:
|
||||
msg = create_sign_tx(transaction, chunkify=chunkify)
|
||||
@ -219,4 +219,4 @@ def sign_tx(
|
||||
|
||||
assert msg.transaction is not None
|
||||
msg.transaction.address_n = n
|
||||
return client.call(msg, expect=messages.NEMSignedTx)
|
||||
return session.call(msg, expect=messages.NEMSignedTx)
|
||||
|
@ -20,12 +20,12 @@ from typing import TYPE_CHECKING
|
||||
from . import messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def get_pubkey(client: "TrezorClient", n: "Address") -> bytes:
|
||||
return client.call(
|
||||
def get_pubkey(session: "Session", n: "Address") -> bytes:
|
||||
return session.call(
|
||||
messages.NostrGetPubkey(
|
||||
address_n=n,
|
||||
),
|
||||
@ -34,7 +34,7 @@ def get_pubkey(client: "TrezorClient", n: "Address") -> bytes:
|
||||
|
||||
|
||||
def sign_event(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
sign_event: messages.NostrSignEvent,
|
||||
) -> messages.NostrEventSignature:
|
||||
return client.call(sign_event, expect=messages.NostrEventSignature)
|
||||
return session.call(sign_event, expect=messages.NostrEventSignature)
|
||||
|
@ -21,20 +21,20 @@ from .protobuf import dict_to_proto
|
||||
from .tools import dict_from_camelcase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
|
||||
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.RippleGetAddress(
|
||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||
),
|
||||
@ -43,14 +43,14 @@ def get_address(
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
msg: messages.RippleSignTx,
|
||||
chunkify: bool = False,
|
||||
) -> messages.RippleSignedTx:
|
||||
msg.address_n = address_n
|
||||
msg.chunkify = chunkify
|
||||
return client.call(msg, expect=messages.RippleSignedTx)
|
||||
return session.call(msg, expect=messages.RippleSignedTx)
|
||||
|
||||
|
||||
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:
|
||||
|
@ -3,27 +3,27 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
from . import messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: List[int],
|
||||
show_display: bool,
|
||||
) -> bytes:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display),
|
||||
expect=messages.SolanaPublicKey,
|
||||
).public_key
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: List[int],
|
||||
show_display: bool,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.SolanaGetAddress(
|
||||
address_n=address_n,
|
||||
show_display=show_display,
|
||||
@ -34,12 +34,12 @@ def get_address(
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: List[int],
|
||||
serialized_tx: bytes,
|
||||
additional_info: Optional[messages.SolanaTxAdditionalInfo],
|
||||
) -> bytes:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.SolanaSignTx(
|
||||
address_n=address_n,
|
||||
serialized_tx=serialized_tx,
|
||||
|
@ -20,8 +20,8 @@ from typing import TYPE_CHECKING, List, Tuple, Union
|
||||
from . import exceptions, messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
StellarMessageType = Union[
|
||||
messages.StellarAccountMergeOp,
|
||||
@ -322,12 +322,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.StellarGetAddress(
|
||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||
),
|
||||
@ -336,7 +336,7 @@ def get_address(
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
tx: messages.StellarSignTx,
|
||||
operations: List["StellarMessageType"],
|
||||
address_n: "Address",
|
||||
@ -352,10 +352,10 @@ def sign_tx(
|
||||
# 3. Receive a StellarTxOpRequest message
|
||||
# 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message
|
||||
# 5. The final message received will be StellarSignedTx which is returned from this method
|
||||
resp = client.call(tx)
|
||||
resp = session.call(tx)
|
||||
try:
|
||||
while isinstance(resp, messages.StellarTxOpRequest):
|
||||
resp = client.call(operations.pop(0))
|
||||
resp = session.call(operations.pop(0))
|
||||
except IndexError:
|
||||
# pop from empty list
|
||||
raise exceptions.TrezorException(
|
||||
|
@ -19,17 +19,17 @@ from typing import TYPE_CHECKING
|
||||
from . import messages
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import TrezorClient
|
||||
from .tools import Address
|
||||
from .transport.session import Session
|
||||
|
||||
|
||||
def get_address(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.TezosGetAddress(
|
||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||
),
|
||||
@ -38,12 +38,12 @@ def get_address(
|
||||
|
||||
|
||||
def get_public_key(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
show_display: bool = False,
|
||||
chunkify: bool = False,
|
||||
) -> str:
|
||||
return client.call(
|
||||
return session.call(
|
||||
messages.TezosGetPublicKey(
|
||||
address_n=address_n, show_display=show_display, chunkify=chunkify
|
||||
),
|
||||
@ -52,11 +52,11 @@ def get_public_key(
|
||||
|
||||
|
||||
def sign_tx(
|
||||
client: "TrezorClient",
|
||||
session: "Session",
|
||||
address_n: "Address",
|
||||
sign_tx_msg: messages.TezosSignTx,
|
||||
chunkify: bool = False,
|
||||
) -> messages.TezosSignedTx:
|
||||
sign_tx_msg.address_n = address_n
|
||||
sign_tx_msg.chunkify = chunkify
|
||||
return client.call(sign_tx_msg, expect=messages.TezosSignedTx)
|
||||
return session.call(sign_tx_msg, expect=messages.TezosSignedTx)
|
||||
|
@ -45,7 +45,7 @@ if TYPE_CHECKING:
|
||||
# More details: https://www.python.org/dev/peps/pep-0612/
|
||||
from typing import TypeVar
|
||||
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from . import client
|
||||
from .messages import Success
|
||||
@ -389,23 +389,6 @@ def _return_success(msg: "Success") -> str | None:
|
||||
return _deprecation_retval_helper(msg.message, stacklevel=1)
|
||||
|
||||
|
||||
def session(
|
||||
f: "Callable[Concatenate[TrezorClient, P], R]",
|
||||
) -> "Callable[Concatenate[TrezorClient, P], R]":
|
||||
# Decorator wraps a BaseClient method
|
||||
# with session activation / deactivation
|
||||
@functools.wraps(f)
|
||||
def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R":
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
client.open()
|
||||
try:
|
||||
return f(client, *args, **kwargs)
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
return wrapped_f
|
||||
|
||||
|
||||
# de-camelcasifier
|
||||
# https://stackoverflow.com/a/1176023/222189
|
||||
|
||||
|
@ -17,14 +17,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Iterable, Sequence, Tuple, TypeVar
|
||||
import typing as t
|
||||
|
||||
from ..exceptions import TrezorException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if t.TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
|
||||
T = TypeVar("T", bound="Transport")
|
||||
T = t.TypeVar("T", bound="Transport")
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -34,7 +35,7 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
|
||||
""".strip()
|
||||
|
||||
|
||||
MessagePayload = Tuple[int, bytes]
|
||||
MessagePayload = t.Tuple[int, bytes]
|
||||
|
||||
|
||||
class TransportException(TrezorException):
|
||||
@ -50,72 +51,57 @@ class Timeout(TransportException):
|
||||
|
||||
|
||||
class Transport:
|
||||
"""Raw connection to a Trezor device.
|
||||
|
||||
Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB
|
||||
or USB-HID connection, or UDP socket of listening emulator(s).
|
||||
It can also enumerate devices available over this communication link, and return
|
||||
them as instances.
|
||||
|
||||
Transport instance is a thing that:
|
||||
- can be identified and requested by a string URI-like path
|
||||
- can open and close sessions, which enclose related operations
|
||||
- can read and write protobuf messages
|
||||
|
||||
You need to implement a new Transport subclass if you invent a new way to connect
|
||||
a Trezor device to a computer.
|
||||
"""
|
||||
|
||||
PATH_PREFIX: str
|
||||
ENABLED = False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.get_path()
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls: t.Type[T], models: t.Iterable[TrezorModel] | None = None
|
||||
) -> t.Iterable[T]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls: t.Type[T], path: str, prefix_search: bool = False) -> T:
|
||||
for device in cls.enumerate():
|
||||
|
||||
if device.get_path() == path:
|
||||
return device
|
||||
|
||||
if prefix_search and device.get_path().startswith(path):
|
||||
return device
|
||||
|
||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||
|
||||
def get_path(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def begin_session(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def end_session(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def find_debug(self: T) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls: type[T], models: Iterable[TrezorModel] | None = None
|
||||
) -> Iterable[T]:
|
||||
def open(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls: type[T], path: str, prefix_search: bool = False) -> T:
|
||||
for device in cls.enumerate():
|
||||
if (
|
||||
path is None
|
||||
or device.get_path() == path
|
||||
or (prefix_search and device.get_path().startswith(path))
|
||||
):
|
||||
return device
|
||||
def close(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def read_chunk(self, timeout: float | None = None) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
def ping(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
CHUNK_SIZE: t.ClassVar[int | None]
|
||||
|
||||
|
||||
def all_transports() -> Iterable[type["Transport"]]:
|
||||
def all_transports() -> t.Iterable[t.Type["Transport"]]:
|
||||
from .bridge import BridgeTransport
|
||||
from .hid import HidTransport
|
||||
from .udp import UdpTransport
|
||||
from .webusb import WebUsbTransport
|
||||
|
||||
transports: Tuple[type["Transport"], ...] = (
|
||||
transports: t.Tuple[t.Type["Transport"], ...] = (
|
||||
BridgeTransport,
|
||||
HidTransport,
|
||||
UdpTransport,
|
||||
@ -125,9 +111,9 @@ def all_transports() -> Iterable[type["Transport"]]:
|
||||
|
||||
|
||||
def enumerate_devices(
|
||||
models: Iterable[TrezorModel] | None = None,
|
||||
) -> Sequence[Transport]:
|
||||
devices: list[Transport] = []
|
||||
models: t.Iterable[TrezorModel] | None = None,
|
||||
) -> t.Sequence[Transport]:
|
||||
devices: t.List[Transport] = []
|
||||
for transport in all_transports():
|
||||
name = transport.__name__
|
||||
try:
|
||||
|
@ -17,16 +17,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
from typing import TYPE_CHECKING, Any, Iterable
|
||||
import typing as t
|
||||
|
||||
import requests
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
|
||||
from . import DeviceIsBusy, Transport, TransportException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if t.TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -58,10 +57,13 @@ def call_bridge(
|
||||
return r
|
||||
|
||||
|
||||
def is_legacy_bridge() -> bool:
|
||||
def get_bridge_version() -> t.Tuple[int, ...]:
|
||||
config = call_bridge("configure").json()
|
||||
version_tuple = tuple(map(int, config["version"].split(".")))
|
||||
return version_tuple < TREZORD_VERSION_MODERN
|
||||
return tuple(map(int, config["version"].split(".")))
|
||||
|
||||
|
||||
def is_legacy_bridge() -> bool:
|
||||
return get_bridge_version() < TREZORD_VERSION_MODERN
|
||||
|
||||
|
||||
class BridgeHandle:
|
||||
@ -115,15 +117,15 @@ class BridgeTransport(Transport):
|
||||
|
||||
PATH_PREFIX = "bridge"
|
||||
ENABLED: bool = True
|
||||
CHUNK_SIZE = None
|
||||
|
||||
def __init__(
|
||||
self, device: dict[str, Any], legacy: bool, debug: bool = False
|
||||
self, device: dict[str, t.Any], legacy: bool, debug: bool = False
|
||||
) -> None:
|
||||
if legacy and debug:
|
||||
raise TransportException("Debugging not supported on legacy Bridge")
|
||||
|
||||
self.device = device
|
||||
self.session: str | None = None
|
||||
self.session: str | None = device["session"]
|
||||
self.debug = debug
|
||||
self.legacy = legacy
|
||||
|
||||
@ -154,8 +156,8 @@ class BridgeTransport(Transport):
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, _models: Iterable[TrezorModel] | None = None
|
||||
) -> Iterable["BridgeTransport"]:
|
||||
cls, _models: t.Iterable[TrezorModel] | None = None
|
||||
) -> t.Iterable["BridgeTransport"]:
|
||||
try:
|
||||
legacy = is_legacy_bridge()
|
||||
return [
|
||||
@ -164,7 +166,7 @@ class BridgeTransport(Transport):
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def begin_session(self) -> None:
|
||||
def open(self) -> None:
|
||||
try:
|
||||
data = self._call("acquire/" + self.device["path"])
|
||||
except BridgeException as e:
|
||||
@ -173,18 +175,17 @@ class BridgeTransport(Transport):
|
||||
raise
|
||||
self.session = data.json()["session"]
|
||||
|
||||
def end_session(self) -> None:
|
||||
def close(self) -> None:
|
||||
if not self.session:
|
||||
return
|
||||
self._call("release")
|
||||
self.session = None
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
self.handle.write_buf(header + message_data)
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
self.handle.write_buf(chunk)
|
||||
|
||||
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||
data = self.handle.read_buf(timeout=timeout)
|
||||
headerlen = struct.calcsize(">HL")
|
||||
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
|
||||
return msg_type, data[headerlen : headerlen + datalen]
|
||||
def read_chunk(self, timeout: float | None = None) -> bytes:
|
||||
return self.handle.read_buf(timeout=timeout)
|
||||
|
||||
def ping(self) -> bool:
|
||||
return self.session is not None
|
||||
|
@ -19,12 +19,11 @@ from __future__ import annotations
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, Iterable
|
||||
import typing as t
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from ..models import TREZOR_ONE, TrezorModel
|
||||
from . import UDEV_RULES_STR, Timeout, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
from . import UDEV_RULES_STR, Timeout, Transport, TransportException
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -37,23 +36,61 @@ except Exception as e:
|
||||
HID_IMPORTED = False
|
||||
|
||||
|
||||
HidDevice = Dict[str, Any]
|
||||
HidDeviceHandle = Any
|
||||
HidDevice = t.Dict[str, t.Any]
|
||||
HidDeviceHandle = t.Any
|
||||
|
||||
|
||||
class HidHandle:
|
||||
def __init__(
|
||||
self, path: bytes, serial: str, probe_hid_version: bool = False
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.serial = serial
|
||||
class HidTransport(Transport):
|
||||
"""
|
||||
HidTransport implements transport over USB HID interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "hid"
|
||||
ENABLED = HID_IMPORTED
|
||||
|
||||
def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None:
|
||||
self.device = device
|
||||
self.device_path = device["path"]
|
||||
self.device_serial_number = device["serial_number"]
|
||||
self.handle: HidDeviceHandle = None
|
||||
self.hid_version = None if probe_hid_version else 2
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False
|
||||
) -> t.Iterable["HidTransport"]:
|
||||
if models is None:
|
||||
models = {TREZOR_ONE}
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
|
||||
devices: t.List["HidTransport"] = []
|
||||
for dev in hid.enumerate(0, 0):
|
||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if debug:
|
||||
if not is_debuglink(dev):
|
||||
continue
|
||||
else:
|
||||
if not is_wirelink(dev):
|
||||
continue
|
||||
devices.append(HidTransport(dev))
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "HidTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||
return debug
|
||||
raise TransportException("Debug HID device not found")
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = hid.device()
|
||||
try:
|
||||
self.handle.open_path(self.path)
|
||||
self.handle.open_path(self.device_path)
|
||||
except (IOError, OSError) as e:
|
||||
if sys.platform.startswith("linux"):
|
||||
e.args = e.args + (UDEV_RULES_STR,)
|
||||
@ -64,11 +101,11 @@ class HidHandle:
|
||||
# and we wouldn't even know.
|
||||
# So we check that the serial matches what we expect.
|
||||
serial = self.handle.get_serial_number_string()
|
||||
if serial != self.serial:
|
||||
if serial != self.device_serial_number:
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
raise TransportException(
|
||||
f"Unexpected device {serial} on path {self.path.decode()}"
|
||||
f"Unexpected device {serial} on path {self.device_path.decode()}"
|
||||
)
|
||||
|
||||
self.handle.set_nonblocking(True)
|
||||
@ -79,7 +116,7 @@ class HidHandle:
|
||||
def close(self) -> None:
|
||||
if self.handle is not None:
|
||||
# reload serial, because device.wipe() can reset it
|
||||
self.serial = self.handle.get_serial_number_string()
|
||||
self.device_serial_number = self.handle.get_serial_number_string()
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
|
||||
@ -119,52 +156,8 @@ class HidHandle:
|
||||
return 1
|
||||
raise TransportException("Unknown HID version")
|
||||
|
||||
|
||||
class HidTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
HidTransport implements transport over USB HID interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "hid"
|
||||
ENABLED = HID_IMPORTED
|
||||
|
||||
def __init__(self, device: HidDevice) -> None:
|
||||
self.device = device
|
||||
self.handle = HidHandle(device["path"], device["serial_number"])
|
||||
|
||||
super().__init__(protocol=ProtocolV1(self.handle))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Iterable[TrezorModel] | None = None, debug: bool = False
|
||||
) -> Iterable[HidTransport]:
|
||||
if models is None:
|
||||
models = {TREZOR_ONE}
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
|
||||
devices: list[HidTransport] = []
|
||||
for dev in hid.enumerate(0, 0):
|
||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if debug:
|
||||
if not is_debuglink(dev):
|
||||
continue
|
||||
else:
|
||||
if not is_wirelink(dev):
|
||||
continue
|
||||
devices.append(HidTransport(dev))
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> HidTransport:
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||
return debug
|
||||
raise TransportException("Debug HID device not found")
|
||||
def ping(self) -> bool:
|
||||
return self.handle is not None
|
||||
|
||||
|
||||
def is_wirelink(dev: HidDevice) -> bool:
|
||||
|
@ -1,179 +0,0 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
# as published by the Free Software Foundation.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from typing_extensions import Protocol as StructuralType
|
||||
|
||||
from . import MessagePayload, Timeout, Transport
|
||||
|
||||
REPLEN = 64
|
||||
|
||||
V2_FIRST_CHUNK = 0x01
|
||||
V2_NEXT_CHUNK = 0x02
|
||||
V2_BEGIN_SESSION = 0x03
|
||||
V2_END_SESSION = 0x04
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_READ_TIMEOUT: float | None = None
|
||||
|
||||
|
||||
class Handle(StructuralType):
|
||||
"""PEP 544 structural type for Handle functionality.
|
||||
(called a "Protocol" in the proposed PEP, name which is impractical here)
|
||||
|
||||
Handle is a "physical" layer for a protocol.
|
||||
It can open/close a connection and read/write bare data in 64-byte chunks.
|
||||
|
||||
Functionally we gain nothing from making this an (abstract) base class for handle
|
||||
implementations, so this definition is for type hinting purposes only. You can,
|
||||
but don't have to, inherit from it.
|
||||
"""
|
||||
|
||||
def open(self) -> None: ...
|
||||
|
||||
def close(self) -> None: ...
|
||||
|
||||
def read_chunk(self, timeout: float | None = None) -> bytes: ...
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None: ...
|
||||
|
||||
|
||||
class Protocol:
|
||||
"""Wire protocol that can communicate with a Trezor device, given a Handle.
|
||||
|
||||
A Protocol implements the part of the Transport API that relates to communicating
|
||||
logical messages over a physical layer. It is a thing that can:
|
||||
- open and close sessions,
|
||||
- send and receive protobuf messages,
|
||||
given the ability to:
|
||||
- open and close physical connections,
|
||||
- and send and receive binary chunks.
|
||||
|
||||
For now, the class also handles session counting and opening the underlying Handle.
|
||||
This will probably be removed in the future.
|
||||
|
||||
We will need a new Protocol class if we change the way a Trezor device encapsulates
|
||||
its messages.
|
||||
"""
|
||||
|
||||
def __init__(self, handle: Handle) -> None:
|
||||
self.handle = handle
|
||||
self.session_counter = 0
|
||||
|
||||
# XXX we might be able to remove this now that TrezorClient does session handling
|
||||
def begin_session(self) -> None:
|
||||
if self.session_counter == 0:
|
||||
self.handle.open()
|
||||
try:
|
||||
# Drop queued responses to old requests
|
||||
while True:
|
||||
msg = self.handle.read_chunk(timeout=0.1)
|
||||
LOG.warning("ignored: %s", msg)
|
||||
except Timeout:
|
||||
pass
|
||||
|
||||
self.session_counter += 1
|
||||
|
||||
def end_session(self) -> None:
|
||||
self.session_counter = max(self.session_counter - 1, 0)
|
||||
if self.session_counter == 0:
|
||||
self.handle.close()
|
||||
|
||||
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ProtocolBasedTransport(Transport):
|
||||
"""Transport that implements its communications through a Protocol.
|
||||
|
||||
Intended as a base class for implementations that proxy their communication
|
||||
operations to a Protocol.
|
||||
"""
|
||||
|
||||
def __init__(self, protocol: Protocol) -> None:
|
||||
self.protocol = protocol
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
self.protocol.write(message_type, message_data)
|
||||
|
||||
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||
return self.protocol.read(timeout=timeout)
|
||||
|
||||
def begin_session(self) -> None:
|
||||
self.protocol.begin_session()
|
||||
|
||||
def end_session(self) -> None:
|
||||
self.protocol.end_session()
|
||||
|
||||
|
||||
class ProtocolV1(Protocol):
|
||||
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
|
||||
Does not understand sessions.
|
||||
"""
|
||||
|
||||
HEADER_LEN = struct.calcsize(">HL")
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
buffer = bytearray(b"##" + header + message_data)
|
||||
|
||||
while buffer:
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b"?" + buffer[: REPLEN - 1]
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
self.handle.write_chunk(chunk)
|
||||
buffer = buffer[63:]
|
||||
|
||||
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||
if timeout is None:
|
||||
timeout = _DEFAULT_READ_TIMEOUT
|
||||
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first(timeout=timeout)
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next(timeout=timeout))
|
||||
|
||||
return msg_type, buffer[:datalen]
|
||||
|
||||
def read_first(self, timeout: float | None = None) -> tuple[int, int, bytes]:
|
||||
chunk = self.handle.read_chunk(timeout=timeout)
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
||||
try:
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||
except Exception:
|
||||
raise RuntimeError(f"Cannot parse header: {chunk.hex()}")
|
||||
|
||||
data = chunk[3 + self.HEADER_LEN :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self, timeout: float | None = None) -> bytes:
|
||||
chunk = self.handle.read_chunk(timeout=timeout)
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
||||
return chunk[1:]
|
174
python/src/trezorlib/transport/session.py
Normal file
174
python/src/trezorlib/transport/session.py
Normal file
@ -0,0 +1,174 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
from .. import exceptions, messages, models
|
||||
from ..protobuf import MessageType
|
||||
from .thp.protocol_v1 import ProtocolV1Channel
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
MT = t.TypeVar("MT", bound=MessageType)
|
||||
|
||||
|
||||
class Session:
|
||||
def __init__(self, client: TrezorClient, id: bytes) -> None:
|
||||
self.client = client
|
||||
self._id = id
|
||||
|
||||
def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT:
|
||||
self.client.check_firmware_version()
|
||||
resp = self.call_raw(msg)
|
||||
|
||||
while True:
|
||||
if isinstance(resp, messages.PinMatrixRequest):
|
||||
if self.client.pin_callback is None:
|
||||
raise NotImplementedError("Missing pin_callback")
|
||||
resp = self.client.pin_callback(self, resp)
|
||||
elif isinstance(resp, messages.PassphraseRequest):
|
||||
if self.client.passphrase_callback is None:
|
||||
raise NotImplementedError("Missing passphrase_callback")
|
||||
resp = self.client.passphrase_callback(self, resp)
|
||||
elif isinstance(resp, messages.ButtonRequest):
|
||||
resp = (self.client.button_callback or default_button_callback)(
|
||||
self, resp
|
||||
)
|
||||
elif isinstance(resp, messages.Failure):
|
||||
if resp.code == messages.FailureType.ActionCancelled:
|
||||
raise exceptions.Cancelled
|
||||
raise exceptions.TrezorFailure(resp)
|
||||
elif not isinstance(resp, expect):
|
||||
raise exceptions.UnexpectedMessageError(expect, resp)
|
||||
else:
|
||||
return resp
|
||||
|
||||
def call_raw(self, msg: t.Any) -> t.Any:
|
||||
self._write(msg)
|
||||
return self._read()
|
||||
|
||||
def _write(self, msg: t.Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def _read(self) -> t.Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def refresh_features(self) -> None:
|
||||
self.client.refresh_features()
|
||||
|
||||
def end(self) -> t.Any:
|
||||
return self.call(messages.EndSession())
|
||||
|
||||
def cancel(self) -> None:
|
||||
self._write(messages.Cancel())
|
||||
|
||||
def ping(self, message: str, button_protection: bool | None = None) -> str:
|
||||
# We would like ping to work on any valid TrezorClient instance, but
|
||||
# due to the protection modes, we need to go through self.call, and that will
|
||||
# raise an exception if the firmware is too old.
|
||||
# So we short-circuit the simplest variant of ping with call_raw.
|
||||
if not button_protection:
|
||||
resp = self.call_raw(messages.Ping(message=message))
|
||||
if isinstance(resp, messages.ButtonRequest):
|
||||
# device is PIN-locked.
|
||||
# respond and hope for the best
|
||||
resp = (self.client.button_callback or default_button_callback)(
|
||||
self, resp
|
||||
)
|
||||
resp = messages.Success.ensure_isinstance(resp)
|
||||
assert resp.message is not None
|
||||
return resp.message
|
||||
|
||||
resp = self.call(
|
||||
messages.Ping(message=message, button_protection=button_protection),
|
||||
expect=messages.Success,
|
||||
)
|
||||
assert resp.message is not None
|
||||
return resp.message
|
||||
|
||||
def invalidate(self) -> None:
|
||||
self.client.invalidate()
|
||||
|
||||
@property
|
||||
def features(self) -> messages.Features:
|
||||
return self.client.features
|
||||
|
||||
@property
|
||||
def model(self) -> models.TrezorModel:
|
||||
return self.client.model
|
||||
|
||||
@property
|
||||
def version(self) -> t.Tuple[int, int, int]:
|
||||
return self.client.version
|
||||
|
||||
@property
|
||||
def id(self) -> bytes:
|
||||
return self._id
|
||||
|
||||
@id.setter
|
||||
def id(self, value: bytes) -> None:
|
||||
if not isinstance(value, bytes):
|
||||
raise ValueError("id must be of type bytes")
|
||||
self._id = value
|
||||
|
||||
|
||||
class SessionV1(Session):
|
||||
derive_cardano: bool | None = False
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls,
|
||||
client: TrezorClient,
|
||||
derive_cardano: bool = False,
|
||||
session_id: bytes | None = None,
|
||||
) -> SessionV1:
|
||||
assert isinstance(client.protocol, ProtocolV1Channel)
|
||||
session = SessionV1(client, id=session_id or b"")
|
||||
session.derive_cardano = derive_cardano
|
||||
session.init_session(session.derive_cardano)
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1:
|
||||
assert isinstance(client.protocol, ProtocolV1Channel)
|
||||
session = SessionV1(client, session_id)
|
||||
session.init_session()
|
||||
return session
|
||||
|
||||
def _write(self, msg: t.Any) -> None:
|
||||
if t.TYPE_CHECKING:
|
||||
assert isinstance(self.client.protocol, ProtocolV1Channel)
|
||||
self.client.protocol.write(msg)
|
||||
|
||||
def _read(self) -> t.Any:
|
||||
if t.TYPE_CHECKING:
|
||||
assert isinstance(self.client.protocol, ProtocolV1Channel)
|
||||
return self.client.protocol.read()
|
||||
|
||||
def init_session(self, derive_cardano: bool | None = None) -> None:
|
||||
if self.id == b"":
|
||||
session_id = None
|
||||
else:
|
||||
session_id = self.id
|
||||
resp: messages.Features = self.call_raw(
|
||||
messages.Initialize(session_id=session_id, derive_cardano=derive_cardano)
|
||||
)
|
||||
assert isinstance(resp, messages.Features)
|
||||
if resp.session_id is not None:
|
||||
self.id = resp.session_id
|
||||
|
||||
|
||||
def default_button_callback(session: Session, msg: t.Any) -> t.Any:
|
||||
return session.call_raw(messages.ButtonAck())
|
||||
|
||||
|
||||
def derive_seed(session: Session) -> None:
|
||||
|
||||
from ..btc import get_address
|
||||
from ..client import PASSPHRASE_TEST_PATH
|
||||
|
||||
get_address(session, "Testnet", PASSPHRASE_TEST_PATH)
|
||||
session.refresh_features()
|
26
python/src/trezorlib/transport/thp/protocol_and_channel.py
Normal file
26
python/src/trezorlib/transport/thp/protocol_and_channel.py
Normal file
@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from ... import messages
|
||||
from ...mapping import ProtobufMapping
|
||||
from .. import Transport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Channel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: Transport,
|
||||
mapping: ProtobufMapping,
|
||||
) -> None:
|
||||
self.transport = transport
|
||||
self.mapping = mapping
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_features(self) -> None:
|
||||
raise NotImplementedError
|
128
python/src/trezorlib/transport/thp/protocol_v1.py
Normal file
128
python/src/trezorlib/transport/thp/protocol_v1.py
Normal file
@ -0,0 +1,128 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2025 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
# as published by the Free Software Foundation.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
import typing as t
|
||||
|
||||
from ... import exceptions, messages
|
||||
from ...log import DUMP_BYTES
|
||||
from .protocol_and_channel import Channel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtocolV1Channel(Channel):
|
||||
_DEFAULT_READ_TIMEOUT: t.ClassVar[float | None] = None
|
||||
HEADER_LEN: t.ClassVar[int] = struct.calcsize(">HL")
|
||||
_features: messages.Features | None = None
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
if self._features is None:
|
||||
self.update_features()
|
||||
assert self._features is not None
|
||||
return self._features
|
||||
|
||||
def update_features(self) -> None:
|
||||
self.write(messages.GetFeatures())
|
||||
resp = self.read()
|
||||
if not isinstance(resp, messages.Features):
|
||||
raise exceptions.TrezorException("Unexpected response to GetFeatures")
|
||||
self._features = resp
|
||||
|
||||
def read(self, timeout: float | None = None) -> t.Any:
|
||||
msg_type, msg_bytes = self._read(timeout=timeout)
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
msg = self.mapping.decode(msg_type, msg_bytes)
|
||||
LOG.debug(
|
||||
f"received message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
return msg
|
||||
|
||||
def write(self, msg: t.Any) -> None:
|
||||
LOG.debug(
|
||||
f"sending message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
msg_type, msg_bytes = self.mapping.encode(msg)
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
self._write(msg_type, msg_bytes)
|
||||
|
||||
def _write(self, message_type: int, message_data: bytes) -> None:
|
||||
chunk_size = self.transport.CHUNK_SIZE
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
|
||||
if chunk_size is None:
|
||||
self.transport.write_chunk(header + message_data)
|
||||
return
|
||||
|
||||
buffer = bytearray(b"##" + header + message_data)
|
||||
while buffer:
|
||||
# Report ID, data padded to (chunk_size - 1) bytes
|
||||
chunk = b"?" + buffer[: chunk_size - 1]
|
||||
chunk = chunk.ljust(chunk_size, b"\x00")
|
||||
self.transport.write_chunk(chunk)
|
||||
buffer = buffer[chunk_size - 1 :]
|
||||
|
||||
def _read(self, timeout: float | None = None) -> t.Tuple[int, bytes]:
|
||||
if timeout is None:
|
||||
timeout = self._DEFAULT_READ_TIMEOUT
|
||||
|
||||
if self.transport.CHUNK_SIZE is None:
|
||||
return self.read_chunkless(timeout=timeout)
|
||||
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first(timeout=timeout)
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next(timeout=timeout))
|
||||
|
||||
return msg_type, buffer[:datalen]
|
||||
|
||||
def read_chunkless(self, timeout: float | None = None) -> t.Tuple[int, bytes]:
|
||||
data = self.transport.read_chunk(timeout=timeout)
|
||||
msg_type, datalen = struct.unpack(">HL", data[: self.HEADER_LEN])
|
||||
return msg_type, data[self.HEADER_LEN : self.HEADER_LEN + datalen]
|
||||
|
||||
def read_first(self, timeout: float | None = None) -> t.Tuple[int, int, bytes]:
|
||||
chunk = self.transport.read_chunk(timeout=timeout)
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
||||
try:
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||
except Exception:
|
||||
raise RuntimeError(f"Cannot parse header: {chunk.hex()}")
|
||||
|
||||
data = chunk[3 + self.HEADER_LEN :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self, timeout: float | None = None) -> bytes:
|
||||
chunk = self.transport.read_chunk(timeout=timeout)
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
||||
return chunk[1:]
|
@ -19,11 +19,10 @@ from __future__ import annotations
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
from typing import TYPE_CHECKING, Iterable, Tuple
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import Timeout, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
from . import Timeout, Transport, TransportException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
@ -33,12 +32,13 @@ SOCKET_TIMEOUT = 0.1
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UdpTransport(ProtocolBasedTransport):
|
||||
class UdpTransport(Transport):
|
||||
|
||||
DEFAULT_HOST = "127.0.0.1"
|
||||
DEFAULT_PORT = 21324
|
||||
PATH_PREFIX = "udp"
|
||||
ENABLED: bool = True
|
||||
CHUNK_SIZE = 64
|
||||
|
||||
def __init__(self, device: str | None = None) -> None:
|
||||
if not device:
|
||||
@ -48,24 +48,17 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
devparts = device.split(":")
|
||||
host = devparts[0]
|
||||
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
|
||||
self.device = (host, port)
|
||||
self.device: Tuple[str, int] = (host, port)
|
||||
|
||||
self.socket: socket.socket | None = None
|
||||
|
||||
super().__init__(protocol=ProtocolV1(self))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||
|
||||
def find_debug(self) -> "UdpTransport":
|
||||
host, port = self.device
|
||||
return UdpTransport(f"{host}:{port + 1}")
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def _try_path(cls, path: str) -> "UdpTransport":
|
||||
d = cls(path)
|
||||
try:
|
||||
d.open()
|
||||
if d._ping():
|
||||
if d.ping():
|
||||
return d
|
||||
else:
|
||||
raise TransportException(
|
||||
@ -99,20 +92,8 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
assert prefix_search # otherwise we would have raised above
|
||||
return super().find_by_path(path, prefix_search)
|
||||
|
||||
def wait_until_ready(self, timeout: float = 10) -> None:
|
||||
try:
|
||||
self.open()
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if self._ping():
|
||||
break
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed >= timeout:
|
||||
raise Timeout("Timed out waiting for connection.")
|
||||
|
||||
time.sleep(0.05)
|
||||
finally:
|
||||
self.close()
|
||||
def get_path(self) -> str:
|
||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||
|
||||
def open(self) -> None:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
@ -124,17 +105,6 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
|
||||
def _ping(self) -> bool:
|
||||
"""Test if the device is listening."""
|
||||
assert self.socket is not None
|
||||
resp = None
|
||||
try:
|
||||
self.socket.sendall(b"PINGPING")
|
||||
resp = self.socket.recv(8)
|
||||
except Exception:
|
||||
pass
|
||||
return resp == b"PONGPONG"
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
assert self.socket is not None
|
||||
if len(chunk) != 64:
|
||||
@ -156,3 +126,33 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
if len(chunk) != 64:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
return bytearray(chunk)
|
||||
|
||||
def find_debug(self) -> "UdpTransport":
|
||||
host, port = self.device
|
||||
return UdpTransport(f"{host}:{port + 1}")
|
||||
|
||||
def wait_until_ready(self, timeout: float = 10) -> None:
|
||||
try:
|
||||
self.open()
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if self.ping():
|
||||
break
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed >= timeout:
|
||||
raise Timeout("Timed out waiting for connection.")
|
||||
|
||||
time.sleep(0.05)
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
def ping(self) -> bool:
|
||||
"""Test if the device is listening."""
|
||||
assert self.socket is not None
|
||||
resp = None
|
||||
try:
|
||||
self.socket.sendall(b"PINGPING")
|
||||
resp = self.socket.recv(8)
|
||||
except Exception:
|
||||
pass
|
||||
return resp == b"PONGPONG"
|
||||
|
@ -20,14 +20,11 @@ import atexit
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable
|
||||
|
||||
from typing_extensions import Self
|
||||
from typing import Iterable, List
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from ..models import TREZORS, TrezorModel
|
||||
from . import UDEV_RULES_STR, DeviceIsBusy, Timeout, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
from . import UDEV_RULES_STR, DeviceIsBusy, Timeout, Transport, TransportException
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -48,14 +45,70 @@ USB_COMM_TIMEOUT_MS = 300
|
||||
WEBUSB_CHUNK_SIZE = 64
|
||||
|
||||
|
||||
class WebUsbHandle:
|
||||
def __init__(self, device: usb1.USBDevice, debug: bool = False) -> None:
|
||||
class WebUsbTransport(Transport):
|
||||
"""
|
||||
WebUsbTransport implements transport over WebUSB interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "webusb"
|
||||
ENABLED = USB_IMPORTED
|
||||
context = None
|
||||
CHUNK_SIZE = 64
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: "usb1.USBDevice",
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
|
||||
self.device = device
|
||||
self.debug = debug
|
||||
|
||||
self.interface = DEBUG_INTERFACE if debug else INTERFACE
|
||||
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
|
||||
self.count = 0
|
||||
self.handle: usb1.USBDeviceHandle | None = None
|
||||
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False
|
||||
) -> Iterable["WebUsbTransport"]:
|
||||
if cls.context is None:
|
||||
cls.context = usb1.USBContext()
|
||||
cls.context.open()
|
||||
atexit.register(cls.context.close)
|
||||
|
||||
if models is None:
|
||||
models = TREZORS
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
devices: List["WebUsbTransport"] = []
|
||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if not is_vendor_class(dev):
|
||||
continue
|
||||
if usb_reset:
|
||||
handle = dev.open()
|
||||
handle.resetDevice()
|
||||
handle.close()
|
||||
continue
|
||||
try:
|
||||
# workaround for issue #223:
|
||||
# on certain combinations of Windows USB drivers and libusb versions,
|
||||
# Trezor is returned twice (possibly because Windows know it as both
|
||||
# a HID and a WebUSB device), and one of the returned devices is
|
||||
# non-functional.
|
||||
dev.getProduct()
|
||||
devices.append(WebUsbTransport(dev))
|
||||
except usb1.USBErrorNotSupported:
|
||||
pass
|
||||
return devices
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = self.device.open()
|
||||
if self.handle is None:
|
||||
@ -68,6 +121,8 @@ class WebUsbHandle:
|
||||
self.handle.claimInterface(self.interface)
|
||||
except usb1.USBErrorAccess as e:
|
||||
raise DeviceIsBusy(self.device) from e
|
||||
except usb1.USBErrorBusy as e:
|
||||
raise DeviceIsBusy(self.device) from e
|
||||
|
||||
def close(self) -> None:
|
||||
if self.handle is not None:
|
||||
@ -119,76 +174,13 @@ class WebUsbHandle:
|
||||
except Exception as e:
|
||||
raise TransportException(f"USB read failed: {e}") from e
|
||||
|
||||
|
||||
class WebUsbTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
WebUsbTransport implements transport over WebUSB interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "webusb"
|
||||
ENABLED = USB_IMPORTED
|
||||
context = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: usb1.USBDevice,
|
||||
handle: WebUsbHandle | None = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
if handle is None:
|
||||
handle = WebUsbHandle(device, debug)
|
||||
|
||||
self.device = device
|
||||
self.handle = handle
|
||||
self.debug = debug
|
||||
|
||||
super().__init__(protocol=ProtocolV1(handle))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls,
|
||||
models: Iterable[TrezorModel] | None = None,
|
||||
usb_reset: bool = False,
|
||||
) -> Iterable[WebUsbTransport]:
|
||||
if cls.context is None:
|
||||
cls.context = usb1.USBContext()
|
||||
cls.context.open()
|
||||
atexit.register(cls.context.close)
|
||||
|
||||
if models is None:
|
||||
models = TREZORS
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
devices: list[WebUsbTransport] = []
|
||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if not is_vendor_class(dev):
|
||||
continue
|
||||
try:
|
||||
# workaround for issue #223:
|
||||
# on certain combinations of Windows USB drivers and libusb versions,
|
||||
# Trezor is returned twice (possibly because Windows know it as both
|
||||
# a HID and a WebUSB device), and one of the returned devices is
|
||||
# non-functional.
|
||||
dev.getProduct()
|
||||
devices.append(WebUsbTransport(dev))
|
||||
except usb1.USBErrorNotSupported:
|
||||
pass
|
||||
except usb1.USBErrorPipe:
|
||||
if usb_reset:
|
||||
handle = dev.open()
|
||||
handle.resetDevice()
|
||||
handle.close()
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> Self:
|
||||
def find_debug(self) -> "WebUsbTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
return self.__class__(self.device, debug=True)
|
||||
|
||||
def ping(self) -> bool:
|
||||
return self.handle is not None
|
||||
|
||||
|
||||
def is_vendor_class(dev: usb1.USBDevice) -> bool:
|
||||
configurationId = 0
|
||||
|
@ -16,16 +16,16 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Callable, Optional, Union
|
||||
import typing as t
|
||||
|
||||
import click
|
||||
from mnemonic import Mnemonic
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from . import device, messages
|
||||
from .client import MAX_PIN_LENGTH, PASSPHRASE_ON_DEVICE
|
||||
from .exceptions import Cancelled
|
||||
from .messages import PinMatrixRequestType, WordRequestType
|
||||
from .client import MAX_PIN_LENGTH
|
||||
from .exceptions import Cancelled, PinException
|
||||
from .messages import Capability, PinMatrixRequestType, WordRequestType
|
||||
from .transport.session import Session
|
||||
|
||||
PIN_MATRIX_DESCRIPTION = """
|
||||
Use the numeric keypad or lowercase letters to describe number positions.
|
||||
@ -62,19 +62,11 @@ WIPE_CODE_CONFIRM = PinMatrixRequestType.WipeCodeSecond
|
||||
CAN_HANDLE_HIDDEN_INPUT = sys.stdin and sys.stdin.isatty()
|
||||
|
||||
|
||||
class TrezorClientUI(Protocol):
|
||||
def button_request(self, br: messages.ButtonRequest) -> None: ...
|
||||
|
||||
def get_pin(self, code: Optional[PinMatrixRequestType]) -> str: ...
|
||||
|
||||
def get_passphrase(self, available_on_device: bool) -> Union[str, object]: ...
|
||||
|
||||
|
||||
def echo(*args: Any, **kwargs: Any) -> None:
|
||||
def echo(*args: t.Any, **kwargs: t.Any) -> None:
|
||||
return click.echo(*args, err=True, **kwargs)
|
||||
|
||||
|
||||
def prompt(text: str, *, hide_input: bool = False, **kwargs: Any) -> Any:
|
||||
def prompt(text: str, *, hide_input: bool = False, **kwargs: t.Any) -> t.Any:
|
||||
# Disallowing hidden input and warning user when it would cause issues
|
||||
if not CAN_HANDLE_HIDDEN_INPUT and hide_input:
|
||||
hide_input = False
|
||||
@ -99,14 +91,16 @@ class ClickUI:
|
||||
|
||||
return "Please confirm action on your Trezor device."
|
||||
|
||||
def button_request(self, br: messages.ButtonRequest) -> None:
|
||||
def button_request(self, session: Session, br: messages.ButtonRequest) -> t.Any:
|
||||
prompt = self._prompt_for_button(br)
|
||||
if prompt != self.last_prompt_shown:
|
||||
echo(prompt)
|
||||
if not self.always_prompt:
|
||||
self.last_prompt_shown = prompt
|
||||
return session.call_raw(messages.ButtonAck())
|
||||
|
||||
def get_pin(self, code: Optional[PinMatrixRequestType] = None) -> str:
|
||||
def get_pin(self, session: Session, request: messages.PinMatrixRequest) -> t.Any:
|
||||
code = request.type
|
||||
if code == PIN_CURRENT:
|
||||
desc = "current PIN"
|
||||
elif code == PIN_NEW:
|
||||
@ -129,6 +123,7 @@ class ClickUI:
|
||||
try:
|
||||
pin = prompt(f"Please enter {desc}", hide_input=True)
|
||||
except click.Abort:
|
||||
session.call_raw(messages.Cancel())
|
||||
raise Cancelled from None
|
||||
|
||||
# translate letters to numbers if letters were used
|
||||
@ -142,16 +137,33 @@ class ClickUI:
|
||||
elif len(pin) > MAX_PIN_LENGTH:
|
||||
echo(f"The value must be at most {MAX_PIN_LENGTH} digits in length.")
|
||||
else:
|
||||
return pin
|
||||
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
|
||||
if isinstance(resp, messages.Failure) and resp.code in (
|
||||
messages.FailureType.PinInvalid,
|
||||
messages.FailureType.PinCancelled,
|
||||
messages.FailureType.PinExpected,
|
||||
):
|
||||
raise PinException(resp.code, resp.message)
|
||||
else:
|
||||
return resp
|
||||
|
||||
def get_passphrase(self, available_on_device: bool) -> Union[str, object]:
|
||||
def get_passphrase(
|
||||
self, session: Session, request: messages.PassphraseRequest
|
||||
) -> t.Any:
|
||||
available_on_device = (
|
||||
Capability.PassphraseEntry in session.features.capabilities
|
||||
)
|
||||
if available_on_device and not self.passphrase_on_host:
|
||||
return PASSPHRASE_ON_DEVICE
|
||||
return session.call_raw(
|
||||
messages.PassphraseAck(passphrase=None, on_device=True)
|
||||
)
|
||||
|
||||
env_passphrase = os.getenv("PASSPHRASE")
|
||||
if env_passphrase is not None:
|
||||
echo("Passphrase required. Using PASSPHRASE environment variable.")
|
||||
return env_passphrase
|
||||
return session.call_raw(
|
||||
messages.PassphraseAck(passphrase=env_passphrase, on_device=False)
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
@ -163,7 +175,7 @@ class ClickUI:
|
||||
)
|
||||
# In case user sees the input on the screen, we do not need confirmation
|
||||
if not CAN_HANDLE_HIDDEN_INPUT:
|
||||
return passphrase
|
||||
break
|
||||
second = prompt(
|
||||
"Confirm your passphrase",
|
||||
hide_input=True,
|
||||
@ -171,12 +183,16 @@ class ClickUI:
|
||||
show_default=False,
|
||||
)
|
||||
if passphrase == second:
|
||||
return passphrase
|
||||
break
|
||||
else:
|
||||
echo("Passphrase did not match. Please try again.")
|
||||
except click.Abort:
|
||||
raise Cancelled from None
|
||||
|
||||
return session.call_raw(
|
||||
messages.PassphraseAck(passphrase=passphrase, on_device=False)
|
||||
)
|
||||
|
||||
|
||||
class ScriptUI:
|
||||
"""Interface to be used by scripts, not directly by user.
|
||||
@ -190,13 +206,14 @@ class ScriptUI:
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def button_request(br: messages.ButtonRequest) -> None:
|
||||
# TODO: send name={br.name} when it will be supported
|
||||
def button_request(session: Session, br: messages.ButtonRequest) -> t.Any:
|
||||
code = br.code.name if br.code else None
|
||||
print(f"?BUTTON code={code} pages={br.pages}")
|
||||
print(f"?BUTTON code={code} pages={br.pages} name={br.name}")
|
||||
return session.call_raw(messages.ButtonAck())
|
||||
|
||||
@staticmethod
|
||||
def get_pin(code: Optional[PinMatrixRequestType] = None) -> str:
|
||||
def get_pin(session: Session, request: messages.PinMatrixRequest) -> t.Any:
|
||||
code = request.type
|
||||
if code is None:
|
||||
print("?PIN")
|
||||
else:
|
||||
@ -208,10 +225,22 @@ class ScriptUI:
|
||||
elif not pin.startswith(":"):
|
||||
raise RuntimeError("Sent PIN must start with ':'")
|
||||
else:
|
||||
return pin[1:]
|
||||
pin = pin[1:]
|
||||
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
|
||||
if isinstance(resp, messages.Failure) and resp.code in (
|
||||
messages.FailureType.PinInvalid,
|
||||
messages.FailureType.PinCancelled,
|
||||
messages.FailureType.PinExpected,
|
||||
):
|
||||
raise PinException(resp.code, resp.message)
|
||||
else:
|
||||
return resp
|
||||
|
||||
@staticmethod
|
||||
def get_passphrase(available_on_device: bool) -> Union[str, object]:
|
||||
def get_passphrase(session: Session, request: messages.PassphraseRequest) -> t.Any:
|
||||
available_on_device = (
|
||||
Capability.PassphraseEntry in session.features.capabilities
|
||||
)
|
||||
if available_on_device:
|
||||
print("?PASSPHRASE available_on_device")
|
||||
else:
|
||||
@ -221,16 +250,21 @@ class ScriptUI:
|
||||
if passphrase == "CANCEL":
|
||||
raise Cancelled from None
|
||||
elif passphrase == "ON_DEVICE":
|
||||
return PASSPHRASE_ON_DEVICE
|
||||
return session.call_raw(
|
||||
messages.PassphraseAck(passphrase=None, on_device=True)
|
||||
)
|
||||
elif not passphrase.startswith(":"):
|
||||
raise RuntimeError("Sent passphrase must start with ':'")
|
||||
else:
|
||||
return passphrase[1:]
|
||||
passphrase = passphrase[1:]
|
||||
return session.call_raw(
|
||||
messages.PassphraseAck(passphrase=passphrase, on_device=False)
|
||||
)
|
||||
|
||||
|
||||
def mnemonic_words(
|
||||
expand: bool = False, language: str = "english"
|
||||
) -> Callable[[WordRequestType], str]:
|
||||
) -> t.Callable[[WordRequestType], str]:
|
||||
if expand:
|
||||
wordlist = Mnemonic(language).wordlist
|
||||
else:
|
||||
|
@ -35,7 +35,6 @@ import trezorlib.misc
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.tools import Address
|
||||
from trezorlib.transport import enumerate_devices
|
||||
from trezorlib.ui import ClickUI
|
||||
|
||||
version_tuple = tuple(map(int, trezorlib.__version__.split(".")))
|
||||
if not (0, 11) <= version_tuple < (0, 14):
|
||||
@ -71,16 +70,18 @@ def choose_device(devices: Sequence["Transport"]) -> "Transport":
|
||||
sys.stderr.write("Available devices:\n")
|
||||
for d in devices:
|
||||
try:
|
||||
client = TrezorClient(d, ui=ClickUI())
|
||||
d.open()
|
||||
client = TrezorClient(d)
|
||||
except IOError:
|
||||
sys.stderr.write("[-] <device is currently in use>\n")
|
||||
continue
|
||||
|
||||
if client.features.label:
|
||||
sys.stderr.write(f"[{i}] {client.features.label}\n")
|
||||
else:
|
||||
sys.stderr.write(f"[{i}] <no label>\n")
|
||||
client.close()
|
||||
if client.features.label:
|
||||
sys.stderr.write(f"[{i}] {client.features.label}\n")
|
||||
else:
|
||||
sys.stderr.write(f"[{i}] <no label>\n")
|
||||
finally:
|
||||
d.close()
|
||||
i += 1
|
||||
|
||||
sys.stderr.write("----------------------------\n")
|
||||
@ -106,7 +107,9 @@ def main() -> None:
|
||||
|
||||
devices = wait_for_devices()
|
||||
transport = choose_device(devices)
|
||||
client = TrezorClient(transport, ui=ClickUI())
|
||||
transport.open()
|
||||
client = TrezorClient(transport)
|
||||
session = client.get_seedless_session()
|
||||
|
||||
rootdir = os.environ["encfs_root"] # Read "man encfs" for more
|
||||
passw_file = os.path.join(rootdir, "password.dat")
|
||||
@ -120,7 +123,7 @@ def main() -> None:
|
||||
sys.stderr.write("Computer asked Trezor for new strong password.\n")
|
||||
|
||||
# 32 bytes, good for AES
|
||||
trezor_entropy = trezorlib.misc.get_entropy(client, 32)
|
||||
trezor_entropy = trezorlib.misc.get_entropy(session, 32)
|
||||
urandom_entropy = os.urandom(32)
|
||||
passw = hashlib.sha256(trezor_entropy + urandom_entropy).digest()
|
||||
|
||||
@ -129,7 +132,7 @@ def main() -> None:
|
||||
|
||||
bip32_path = Address([10, 0])
|
||||
passw_encrypted = trezorlib.misc.encrypt_keyvalue(
|
||||
client, bip32_path, label, passw, False, True
|
||||
session, bip32_path, label, passw, False, True
|
||||
)
|
||||
|
||||
data = {
|
||||
@ -144,13 +147,14 @@ def main() -> None:
|
||||
data = json.load(open(passw_file, "r"))
|
||||
|
||||
passw = trezorlib.misc.decrypt_keyvalue(
|
||||
client,
|
||||
session,
|
||||
data["bip32_path"],
|
||||
data["label"],
|
||||
bytes.fromhex(data["password_encrypted_hex"]),
|
||||
False,
|
||||
True,
|
||||
)
|
||||
transport.close()
|
||||
|
||||
print(passw)
|
||||
|
||||
|
@ -24,15 +24,19 @@ from trezorlib.tools import parse_path
|
||||
def main() -> None:
|
||||
# Use first connected device
|
||||
client = get_default_client()
|
||||
session = client.get_session()
|
||||
|
||||
# Print out Trezor's features and settings
|
||||
print(client.features)
|
||||
print(session.features)
|
||||
|
||||
# Get the first address of first BIP44 account
|
||||
bip32_path = parse_path("44h/0h/0h/0/0")
|
||||
address = btc.get_address(client, "Bitcoin", bip32_path, True)
|
||||
address = btc.get_address(session, "Bitcoin", bip32_path, True)
|
||||
print("Bitcoin address:", address)
|
||||
|
||||
# Release underlying transport (USB/BLE/UDP)
|
||||
client.transport.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -62,6 +62,8 @@ def main() -> None:
|
||||
sectoraddrs[sector] + offset, content[offset : offset + step], flash=True
|
||||
)
|
||||
|
||||
debug.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -58,6 +58,7 @@ def main() -> None:
|
||||
f.write(mem)
|
||||
|
||||
f.close()
|
||||
debug.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -39,6 +39,7 @@ def find_debug() -> DebugLink:
|
||||
def main() -> None:
|
||||
debug = find_debug()
|
||||
debug.memory_write(int(sys.argv[1], 16), bytes.fromhex(sys.argv[2]), flash=True)
|
||||
debug.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -26,23 +26,24 @@ from urllib.parse import urlparse
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from trezorlib import misc, ui
|
||||
from trezorlib import misc
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.tools import parse_path
|
||||
from trezorlib.transport import get_transport
|
||||
from trezorlib.transport.session import Session
|
||||
|
||||
# Return path by BIP-32
|
||||
BIP32_PATH = parse_path("10016h/0")
|
||||
|
||||
|
||||
# Deriving master key
|
||||
def getMasterKey(client: TrezorClient) -> str:
|
||||
def getMasterKey(session: Session) -> str:
|
||||
bip32_path = BIP32_PATH
|
||||
ENC_KEY = "Activate TREZOR Password Manager?"
|
||||
ENC_VALUE = bytes.fromhex(
|
||||
"2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee2d650551248d792eabf628f451200d7f51cb63e46aadcbb1038aacb05e8c8aee"
|
||||
)
|
||||
key = misc.encrypt_keyvalue(client, bip32_path, ENC_KEY, ENC_VALUE, True, True)
|
||||
key = misc.encrypt_keyvalue(session, bip32_path, ENC_KEY, ENC_VALUE, True, True)
|
||||
return key.hex()
|
||||
|
||||
|
||||
@ -101,7 +102,7 @@ def decryptEntryValue(nonce: str, val: bytes) -> dict:
|
||||
|
||||
|
||||
# Decrypt give entry nonce
|
||||
def getDecryptedNonce(client: TrezorClient, entry: dict) -> str:
|
||||
def getDecryptedNonce(session: Session, entry: dict) -> str:
|
||||
print()
|
||||
print("Waiting for Trezor input ...")
|
||||
print()
|
||||
@ -117,7 +118,7 @@ def getDecryptedNonce(client: TrezorClient, entry: dict) -> str:
|
||||
ENC_KEY = f"Unlock {item} for user {entry['username']}?"
|
||||
ENC_VALUE = entry["nonce"]
|
||||
decrypted_nonce = misc.decrypt_keyvalue(
|
||||
client, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True
|
||||
session, BIP32_PATH, ENC_KEY, bytes.fromhex(ENC_VALUE), False, True
|
||||
)
|
||||
return decrypted_nonce.hex()
|
||||
|
||||
@ -144,13 +145,15 @@ def main() -> None:
|
||||
print(e)
|
||||
return
|
||||
|
||||
client = TrezorClient(transport=transport, ui=ui.ClickUI())
|
||||
transport.open()
|
||||
client = TrezorClient(transport=transport)
|
||||
session = client.get_seedless_session()
|
||||
|
||||
print()
|
||||
print("Confirm operation on Trezor")
|
||||
print()
|
||||
|
||||
masterKey = getMasterKey(client)
|
||||
masterKey = getMasterKey(session)
|
||||
# print('master key:', masterKey)
|
||||
|
||||
fileName = getFileEncKey(masterKey)[0]
|
||||
@ -173,7 +176,7 @@ def main() -> None:
|
||||
entry_id = input("Select entry number to decrypt: ")
|
||||
entry_id = str(entry_id)
|
||||
|
||||
plain_nonce = getDecryptedNonce(client, entries[entry_id])
|
||||
plain_nonce = getDecryptedNonce(session, entries[entry_id])
|
||||
|
||||
pwdArr = entries[entry_id]["password"]["data"]
|
||||
pwdHex = "".join([hex(x)[2:].zfill(2) for x in pwdArr])
|
||||
@ -183,6 +186,8 @@ def main() -> None:
|
||||
safeNoteHex = "".join([hex(x)[2:].zfill(2) for x in safeNoteArr])
|
||||
print("safe_note:", decryptEntryValue(plain_nonce, bytes.fromhex(safeNoteHex)))
|
||||
|
||||
client.transport.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -36,12 +36,14 @@ import click
|
||||
from bottle import post, request, response, run
|
||||
|
||||
import trezorlib.mapping
|
||||
import trezorlib.messages
|
||||
import trezorlib.models
|
||||
import trezorlib.transport
|
||||
import trezorlib.transport.session as transport_session
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.protobuf import format_message
|
||||
from trezorlib.transport.bridge import BridgeTransport
|
||||
from trezorlib.ui import TrezorClientUI
|
||||
from trezorlib.transport.thp.protocol_v1 import ProtocolV1Channel
|
||||
|
||||
# ignore bridge. we are the bridge
|
||||
BridgeTransport.ENABLED = False
|
||||
@ -59,15 +61,18 @@ logging.basicConfig(
|
||||
LOG = logging.getLogger()
|
||||
|
||||
|
||||
class SilentUI(TrezorClientUI):
|
||||
def get_pin(self, _code: t.Any) -> str:
|
||||
return ""
|
||||
def pin_callback(
|
||||
session: transport_session.Session, request: trezorlib.messages.PinMatrixRequest
|
||||
) -> t.Any:
|
||||
return session.call_raw(trezorlib.messages.PinMatrixAck(pin=""))
|
||||
|
||||
def get_passphrase(self) -> str:
|
||||
return ""
|
||||
|
||||
def button_request(self, _br: t.Any) -> None:
|
||||
pass
|
||||
def passphrase_callback(
|
||||
session: transport_session.Session, request: trezorlib.messages.PassphraseRequest
|
||||
) -> t.Any:
|
||||
return session.call_raw(
|
||||
trezorlib.messages.PassphraseAck(passphrase="", on_device=False)
|
||||
)
|
||||
|
||||
|
||||
class Session:
|
||||
@ -102,10 +107,16 @@ class Transport:
|
||||
self.path = transport.get_path()
|
||||
self.session: Session | None = None
|
||||
self.transport = transport
|
||||
self.protocol = ProtocolV1Channel(transport, trezorlib.mapping.DEFAULT_MAPPING)
|
||||
|
||||
client = TrezorClient(transport, ui=SilentUI())
|
||||
transport.open()
|
||||
client = TrezorClient(transport)
|
||||
client.pin_callback = pin_callback
|
||||
client.passphrase_callback = passphrase_callback
|
||||
self.model = client.model
|
||||
client.end_session()
|
||||
|
||||
client.get_seedless_session().end()
|
||||
transport.close()
|
||||
|
||||
def acquire(self, sid: str) -> str:
|
||||
if self.session_id() != sid:
|
||||
@ -114,11 +125,11 @@ class Transport:
|
||||
self.session.release()
|
||||
|
||||
self.session = Session(self)
|
||||
self.transport.begin_session()
|
||||
self.transport.open()
|
||||
return self.session.id
|
||||
|
||||
def release(self) -> None:
|
||||
self.transport.end_session()
|
||||
self.transport.close()
|
||||
self.session = None
|
||||
|
||||
def session_id(self) -> str | None:
|
||||
@ -139,10 +150,10 @@ class Transport:
|
||||
}
|
||||
|
||||
def write(self, msg_id: int, data: bytes) -> None:
|
||||
self.transport.write(msg_id, data)
|
||||
self.protocol._write(msg_id, data)
|
||||
|
||||
def read(self) -> tuple[int, bytes]:
|
||||
return self.transport.read()
|
||||
return self.protocol._read()
|
||||
|
||||
@classmethod
|
||||
def find(cls, path: str) -> Transport | None:
|
||||
|
@ -7,14 +7,17 @@
|
||||
import io
|
||||
import sys
|
||||
|
||||
from trezorlib import misc, ui
|
||||
from trezorlib import misc
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.transport import get_transport
|
||||
|
||||
|
||||
def main() -> None:
|
||||
try:
|
||||
client = TrezorClient(get_transport(), ui=ui.ClickUI())
|
||||
transport = get_transport()
|
||||
transport.open()
|
||||
client = TrezorClient(transport)
|
||||
session = client.get_seedless_session()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return
|
||||
@ -25,10 +28,10 @@ def main() -> None:
|
||||
|
||||
with io.open(arg1, "wb") as f:
|
||||
for _ in range(0, arg2, step):
|
||||
entropy = misc.get_entropy(client, step)
|
||||
entropy = misc.get_entropy(session, step)
|
||||
f.write(entropy)
|
||||
|
||||
client.close()
|
||||
transport.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -27,26 +27,29 @@ from trezorlib.client import TrezorClient
|
||||
from trezorlib.misc import decrypt_keyvalue, encrypt_keyvalue
|
||||
from trezorlib.tools import parse_path
|
||||
from trezorlib.transport import get_transport
|
||||
from trezorlib.ui import ClickUI
|
||||
|
||||
BIP32_PATH = parse_path("10016h/0")
|
||||
|
||||
|
||||
def encrypt(type: str, domain: str, secret: str) -> str:
|
||||
transport = get_transport()
|
||||
client = TrezorClient(transport, ClickUI())
|
||||
transport.open()
|
||||
client = TrezorClient(transport)
|
||||
session = client.get_seedless_session()
|
||||
dom = type.upper() + ": " + domain
|
||||
enc = encrypt_keyvalue(client, BIP32_PATH, dom, secret.encode(), False, True)
|
||||
client.close()
|
||||
enc = encrypt_keyvalue(session, BIP32_PATH, dom, secret.encode(), False, True)
|
||||
transport.close()
|
||||
return enc.hex()
|
||||
|
||||
|
||||
def decrypt(type: str, domain: str, secret: bytes) -> bytes:
|
||||
transport = get_transport()
|
||||
client = TrezorClient(transport, ClickUI())
|
||||
transport.open()
|
||||
client = TrezorClient(transport)
|
||||
session = client.get_seedless_session()
|
||||
dom = type.upper() + ": " + domain
|
||||
dec = decrypt_keyvalue(client, BIP32_PATH, dom, secret, False, True)
|
||||
client.close()
|
||||
dec = decrypt_keyvalue(session, BIP32_PATH, dom, secret, False, True)
|
||||
transport.close()
|
||||
return dec
|
||||
|
||||
|
||||
|
@ -56,7 +56,7 @@ def pin_input_flow(client: Client, old_pin: str, new_pin: str):
|
||||
if __name__ == "__main__":
|
||||
wirelink = get_device()
|
||||
client = Client(wirelink)
|
||||
client.open()
|
||||
session = client.get_seedless_session()
|
||||
|
||||
i = 0
|
||||
|
||||
@ -76,10 +76,12 @@ if __name__ == "__main__":
|
||||
|
||||
# change PIN
|
||||
new_pin = "".join(random.choices(string.digits, k=random.randint(6, 10)))
|
||||
client.set_input_flow(pin_input_flow(client, last_pin, new_pin))
|
||||
session.set_input_flow(pin_input_flow(client, last_pin, new_pin))
|
||||
device.change_pin(client)
|
||||
client.set_input_flow(None)
|
||||
session.set_input_flow(None)
|
||||
last_pin = new_pin
|
||||
|
||||
print(f"iteration {i}")
|
||||
i = i + 1
|
||||
|
||||
wirelink.close()
|
||||
|
@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Tuple
|
||||
import pytest
|
||||
|
||||
from trezorlib import btc, device, exceptions, messages
|
||||
from trezorlib.client import PASSPHRASE_ON_DEVICE
|
||||
from trezorlib.debuglink import DebugLink, LayoutType
|
||||
from trezorlib.protobuf import MessageType
|
||||
from trezorlib.tools import parse_path
|
||||
@ -66,8 +67,8 @@ def _center_button(debug: DebugLink) -> Tuple[int, int]:
|
||||
|
||||
def set_autolock_delay(device_handler: "BackgroundDeviceHandler", delay_ms: int):
|
||||
debug = device_handler.debuglink()
|
||||
|
||||
device_handler.run(device.apply_settings, auto_lock_delay_ms=delay_ms) # type: ignore
|
||||
device_handler.client.get_seedless_session().lock()
|
||||
device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=delay_ms) # type: ignore
|
||||
|
||||
assert "PinKeyboard" in debug.read_layout().all_components()
|
||||
|
||||
@ -106,7 +107,7 @@ def test_autolock_interrupts_signing(device_handler: "BackgroundDeviceHandler"):
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
|
||||
device_handler.run(btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) # type: ignore
|
||||
device_handler.run_with_session(btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) # type: ignore
|
||||
|
||||
assert (
|
||||
"1MJ2tj2ThBE62zXbBYA5ZaN3fdve5CPAz1"
|
||||
@ -144,6 +145,10 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa
|
||||
set_autolock_delay(device_handler, 10_000)
|
||||
|
||||
debug = device_handler.debuglink()
|
||||
|
||||
# Prepare session to use later
|
||||
session = device_handler.client.get_session()
|
||||
|
||||
# try to sign a transaction
|
||||
inp1 = messages.TxInputType(
|
||||
address_n=parse_path("86h/0h/0h/0/0"),
|
||||
@ -159,8 +164,8 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
|
||||
device_handler.run(
|
||||
btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET
|
||||
device_handler.run_with_provided_session(
|
||||
session, btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET
|
||||
)
|
||||
|
||||
assert (
|
||||
@ -190,14 +195,14 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa
|
||||
|
||||
def sleepy_filter(msg: MessageType) -> MessageType:
|
||||
time.sleep(10.1)
|
||||
device_handler.client.set_filter(messages.TxAck, None)
|
||||
session.set_filter(messages.TxAck, None)
|
||||
return msg
|
||||
|
||||
with device_handler.client:
|
||||
device_handler.client.set_filter(messages.TxAck, sleepy_filter)
|
||||
with session:
|
||||
session.set_filter(messages.TxAck, sleepy_filter)
|
||||
# confirm transaction
|
||||
if debug.layout_type is LayoutType.Bolt:
|
||||
debug.click(debug.screen_buttons.ok())
|
||||
debug.click(debug.screen_buttons.ok(), hold_ms=1000)
|
||||
elif debug.layout_type is LayoutType.Delizia:
|
||||
debug.click(debug.screen_buttons.tap_to_confirm())
|
||||
elif debug.layout_type is LayoutType.Caesar:
|
||||
@ -206,7 +211,6 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa
|
||||
signatures, tx = device_handler.result()
|
||||
assert len(signatures) == 1
|
||||
assert tx
|
||||
|
||||
assert device_handler.features().unlocked is False
|
||||
|
||||
|
||||
@ -216,8 +220,9 @@ def test_autolock_passphrase_keyboard(device_handler: "BackgroundDeviceHandler")
|
||||
debug = device_handler.debuglink()
|
||||
|
||||
# get address
|
||||
device_handler.run(common.get_test_address) # type: ignore
|
||||
session = device_handler.client.get_session(passphrase=PASSPHRASE_ON_DEVICE)
|
||||
|
||||
device_handler.run_with_provided_session(session, common.get_test_address) # type: ignore
|
||||
assert "PassphraseKeyboard" in debug.read_layout().all_components()
|
||||
|
||||
if debug.layout_type is LayoutType.Caesar:
|
||||
@ -253,8 +258,8 @@ def test_autolock_interrupts_passphrase(device_handler: "BackgroundDeviceHandler
|
||||
debug = device_handler.debuglink()
|
||||
|
||||
# get address
|
||||
device_handler.run(common.get_test_address) # type: ignore
|
||||
|
||||
session = device_handler.client.get_session(passphrase=PASSPHRASE_ON_DEVICE)
|
||||
device_handler.run_with_provided_session(session, common.get_test_address) # type: ignore
|
||||
assert "PassphraseKeyboard" in debug.read_layout().all_components()
|
||||
|
||||
if debug.layout_type is LayoutType.Caesar:
|
||||
@ -293,7 +298,7 @@ def test_dryrun_locks_at_number_of_words(device_handler: "BackgroundDeviceHandle
|
||||
set_autolock_delay(device_handler, 10_000)
|
||||
debug = device_handler.debuglink()
|
||||
|
||||
device_handler.run(device.recover, type=messages.RecoveryType.DryRun)
|
||||
device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun)
|
||||
|
||||
layout = unlock_dry_run(debug)
|
||||
assert TR.recovery__num_of_words in debug.read_layout().text_content()
|
||||
@ -326,7 +331,7 @@ def test_dryrun_locks_at_word_entry(device_handler: "BackgroundDeviceHandler"):
|
||||
set_autolock_delay(device_handler, 10_000)
|
||||
debug = device_handler.debuglink()
|
||||
|
||||
device_handler.run(device.recover, type=messages.RecoveryType.DryRun)
|
||||
device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun)
|
||||
|
||||
unlock_dry_run(debug)
|
||||
|
||||
@ -353,7 +358,7 @@ def test_dryrun_enter_word_slowly(device_handler: "BackgroundDeviceHandler"):
|
||||
set_autolock_delay(device_handler, 10_000)
|
||||
debug = device_handler.debuglink()
|
||||
|
||||
device_handler.run(device.recover, type=messages.RecoveryType.DryRun)
|
||||
device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun)
|
||||
|
||||
unlock_dry_run(debug)
|
||||
|
||||
@ -418,7 +423,11 @@ def test_autolock_does_not_interrupt_preauthorized(
|
||||
|
||||
debug = device_handler.debuglink()
|
||||
|
||||
device_handler.run(
|
||||
# Prepare session to use later
|
||||
session = device_handler.client.get_session()
|
||||
|
||||
device_handler.run_with_provided_session(
|
||||
session,
|
||||
btc.authorize_coinjoin,
|
||||
coordinator="www.example.com",
|
||||
max_rounds=2,
|
||||
@ -532,14 +541,15 @@ def test_autolock_does_not_interrupt_preauthorized(
|
||||
|
||||
def sleepy_filter(msg: MessageType) -> MessageType:
|
||||
time.sleep(10.1)
|
||||
device_handler.client.set_filter(messages.SignTx, None)
|
||||
session.set_filter(messages.SignTx, None)
|
||||
return msg
|
||||
|
||||
with device_handler.client:
|
||||
with session:
|
||||
# Start DoPreauthorized flow when device is unlocked. Wait 10s before
|
||||
# delivering SignTx, by that time autolock timer should have fired.
|
||||
device_handler.client.set_filter(messages.SignTx, sleepy_filter)
|
||||
device_handler.run(
|
||||
session.set_filter(messages.SignTx, sleepy_filter)
|
||||
device_handler.run_with_provided_session(
|
||||
session,
|
||||
btc.sign_tx,
|
||||
"Testnet",
|
||||
inputs,
|
||||
|
@ -52,7 +52,9 @@ def test_backup_slip39_custom(
|
||||
|
||||
assert features.initialized is False
|
||||
|
||||
device_handler.run(
|
||||
session = device_handler.client.get_seedless_session()
|
||||
device_handler.run_with_provided_session(
|
||||
session,
|
||||
device.setup,
|
||||
strength=128,
|
||||
backup_type=messages.BackupType.Slip39_Basic,
|
||||
@ -71,7 +73,7 @@ def test_backup_slip39_custom(
|
||||
# retrieve the result to check that it's not a TrezorFailure exception
|
||||
device_handler.result()
|
||||
|
||||
device_handler.run(
|
||||
device_handler.run_with_session(
|
||||
device.backup,
|
||||
group_threshold=group_threshold,
|
||||
groups=[(share_threshold, share_count)],
|
||||
|
@ -19,7 +19,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from trezorlib import models
|
||||
from trezorlib import messages, models
|
||||
from trezorlib.debuglink import LayoutType
|
||||
|
||||
from .. import common
|
||||
@ -34,6 +34,9 @@ PIN4 = "1234"
|
||||
@pytest.mark.setup_client(pin=PIN4)
|
||||
def test_hold_to_lock(device_handler: "BackgroundDeviceHandler"):
|
||||
debug = device_handler.debuglink()
|
||||
session = device_handler.client.get_seedless_session()
|
||||
session.call(messages.LockDevice())
|
||||
session.refresh_features()
|
||||
|
||||
short_duration = {
|
||||
models.T1B1: 500,
|
||||
@ -59,22 +62,25 @@ def test_hold_to_lock(device_handler: "BackgroundDeviceHandler"):
|
||||
assert device_handler.features().unlocked is False
|
||||
|
||||
# unlock with message
|
||||
device_handler.run(common.get_test_address)
|
||||
device_handler.run_with_session(common.get_test_address)
|
||||
|
||||
assert "PinKeyboard" in debug.read_layout().all_components()
|
||||
debug.input("1234")
|
||||
assert device_handler.result()
|
||||
|
||||
session.refresh_features()
|
||||
assert device_handler.features().unlocked is True
|
||||
|
||||
# short touch
|
||||
hold(short_duration)
|
||||
|
||||
time.sleep(0.5) # so that the homescreen appears again (hacky)
|
||||
session.refresh_features()
|
||||
assert device_handler.features().unlocked is True
|
||||
|
||||
# lock
|
||||
hold(lock_duration)
|
||||
session.refresh_features()
|
||||
assert device_handler.features().unlocked is False
|
||||
|
||||
# unlock by touching
|
||||
@ -86,8 +92,10 @@ def test_hold_to_lock(device_handler: "BackgroundDeviceHandler"):
|
||||
assert "PinKeyboard" in layout.all_components()
|
||||
debug.input("1234")
|
||||
|
||||
session.refresh_features()
|
||||
assert device_handler.features().unlocked is True
|
||||
|
||||
# lock
|
||||
hold(lock_duration)
|
||||
session.refresh_features()
|
||||
assert device_handler.features().unlocked is False
|
||||
|
@ -73,7 +73,7 @@ def prepare_passphrase_dialogue(
|
||||
device_handler: "BackgroundDeviceHandler", address: Optional[str] = None
|
||||
) -> Generator["DebugLink", None, None]:
|
||||
debug = device_handler.debuglink()
|
||||
device_handler.run(get_test_address) # type: ignore
|
||||
device_handler.run_with_session(get_test_address) # type: ignore
|
||||
assert debug.read_layout().main_component() == "PassphraseKeyboard"
|
||||
|
||||
# Resetting the category as it could have been changed by previous tests
|
||||
|
@ -91,7 +91,7 @@ def prepare_passphrase_dialogue(
|
||||
device_handler: "BackgroundDeviceHandler", address: Optional[str] = None
|
||||
) -> Generator["DebugLink", None, None]:
|
||||
debug = device_handler.debuglink()
|
||||
device_handler.run(get_test_address) # type: ignore
|
||||
device_handler.run_with_session(get_test_address) # type: ignore
|
||||
layout = debug.read_layout()
|
||||
assert "PassphraseKeyboard" in layout.all_components()
|
||||
assert layout.passphrase() == ""
|
||||
|
@ -90,17 +90,19 @@ def prepare(
|
||||
|
||||
tap = False
|
||||
|
||||
device_handler.client.get_seedless_session().lock()
|
||||
|
||||
# Setup according to the wanted situation
|
||||
if situation == Situation.PIN_INPUT:
|
||||
# Any action triggering the PIN dialogue
|
||||
device_handler.run(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore
|
||||
device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore
|
||||
tap = True
|
||||
if situation == Situation.PIN_INPUT_CANCEL:
|
||||
# Any action triggering the PIN dialogue
|
||||
device_handler.run(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore
|
||||
device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore
|
||||
elif situation == Situation.PIN_SETUP:
|
||||
# Set new PIN
|
||||
device_handler.run(device.change_pin) # type: ignore
|
||||
device_handler.run_with_session(device.change_pin) # type: ignore
|
||||
assert (
|
||||
TR.pin__turn_on in debug.read_layout().text_content()
|
||||
or TR.pin__info in debug.read_layout().text_content()
|
||||
@ -114,14 +116,14 @@ def prepare(
|
||||
go_next(debug)
|
||||
elif situation == Situation.PIN_CHANGE:
|
||||
# Change PIN
|
||||
device_handler.run(device.change_pin) # type: ignore
|
||||
device_handler.run_with_session(device.change_pin) # type: ignore
|
||||
_input_see_confirm(debug, old_pin)
|
||||
assert TR.pin__change in debug.read_layout().text_content()
|
||||
go_next(debug)
|
||||
_input_see_confirm(debug, old_pin)
|
||||
elif situation == Situation.WIPE_CODE_SETUP:
|
||||
# Set wipe code
|
||||
device_handler.run(device.change_wipe_code) # type: ignore
|
||||
device_handler.run_with_session(device.change_wipe_code) # type: ignore
|
||||
if old_pin:
|
||||
_input_see_confirm(debug, old_pin)
|
||||
assert TR.wipe_code__turn_on in debug.read_layout().text_content()
|
||||
|
@ -40,7 +40,7 @@ def prepare_recovery_and_evaluate(
|
||||
features = device_handler.features()
|
||||
debug = device_handler.debuglink()
|
||||
assert features.initialized is False
|
||||
device_handler.run(device.recover, pin_protection=False) # type: ignore
|
||||
device_handler.run_with_session(device.recover, pin_protection=False) # type: ignore
|
||||
|
||||
yield debug
|
||||
|
||||
@ -58,7 +58,7 @@ def prepare_recovery_and_evaluate_cancel(
|
||||
features = device_handler.features()
|
||||
debug = device_handler.debuglink()
|
||||
assert features.initialized is False
|
||||
device_handler.run(device.recover, pin_protection=False) # type: ignore
|
||||
device_handler.run_with_session(device.recover, pin_protection=False) # type: ignore
|
||||
|
||||
yield debug
|
||||
|
||||
@ -113,10 +113,11 @@ def test_recovery_cancel_issue4613(device_handler: "BackgroundDeviceHandler"):
|
||||
debug = device_handler.debuglink()
|
||||
|
||||
# initiate and confirm the recovery
|
||||
device_handler.run(device.recover, type=messages.RecoveryType.DryRun)
|
||||
device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun)
|
||||
recovery.confirm_recovery(debug, title="recovery__title_dry_run")
|
||||
# select number of words
|
||||
recovery.select_number_of_words(debug, num_of_words=12)
|
||||
device_handler.client.transport.close()
|
||||
# abort the process running the recovery from host
|
||||
device_handler.kill_task()
|
||||
|
||||
@ -124,16 +125,20 @@ def test_recovery_cancel_issue4613(device_handler: "BackgroundDeviceHandler"):
|
||||
# from the host side.
|
||||
|
||||
# Reopen client and debuglink, closed by kill_task
|
||||
device_handler.client.open()
|
||||
device_handler.client.transport.open()
|
||||
debug = device_handler.debuglink()
|
||||
|
||||
# Ping the Trezor with an Initialize message (listed in DO_NOT_RESTART)
|
||||
try:
|
||||
features = device_handler.client.call(messages.Initialize())
|
||||
features = device_handler.client.get_seedless_session().call(
|
||||
messages.Initialize()
|
||||
)
|
||||
except exceptions.Cancelled:
|
||||
# due to a related problem, the first call in this situation will return
|
||||
# a Cancelled failure. This test does not care, we just retry.
|
||||
features = device_handler.client.call(messages.Initialize())
|
||||
features = device_handler.client.get_seedless_session().call(
|
||||
messages.Initialize()
|
||||
)
|
||||
|
||||
assert features.recovery_status == messages.RecoveryStatus.Recovery
|
||||
# Trezor is sitting in recovery_homescreen now, waiting for the user to select
|
||||
|
@ -40,7 +40,7 @@ def test_repeated_backup(
|
||||
|
||||
assert features.initialized is False
|
||||
|
||||
device_handler.run(
|
||||
device_handler.run_with_session(
|
||||
device.setup,
|
||||
strength=128,
|
||||
backup_type=messages.BackupType.Slip39_Basic,
|
||||
@ -93,7 +93,7 @@ def test_repeated_backup(
|
||||
assert features.recovery_status == messages.RecoveryStatus.Nothing
|
||||
|
||||
# run recovery to unlock backup
|
||||
device_handler.run(
|
||||
device_handler.run_with_session(
|
||||
device.recover,
|
||||
type=messages.RecoveryType.UnlockRepeatedBackup,
|
||||
)
|
||||
@ -160,7 +160,7 @@ def test_repeated_backup(
|
||||
assert features.recovery_status == messages.RecoveryStatus.Nothing
|
||||
|
||||
# try to unlock backup again...
|
||||
device_handler.run(
|
||||
device_handler.run_with_session(
|
||||
device.recover,
|
||||
type=messages.RecoveryType.UnlockRepeatedBackup,
|
||||
)
|
||||
@ -200,7 +200,7 @@ def test_repeated_backup(
|
||||
assert features.recovery_status == messages.RecoveryStatus.Nothing
|
||||
|
||||
# try to unlock backup yet again...
|
||||
device_handler.run(
|
||||
device_handler.run_with_session(
|
||||
device.recover,
|
||||
type=messages.RecoveryType.UnlockRepeatedBackup,
|
||||
)
|
||||
|
@ -39,7 +39,7 @@ def test_reset_bip39(device_handler: "BackgroundDeviceHandler"):
|
||||
|
||||
assert features.initialized is False
|
||||
|
||||
device_handler.run(
|
||||
device_handler.run_with_session(
|
||||
device.setup,
|
||||
strength=128,
|
||||
backup_type=messages.BackupType.Bip39,
|
||||
|
@ -50,7 +50,7 @@ def test_reset_slip39_advanced(
|
||||
|
||||
assert features.initialized is False
|
||||
|
||||
device_handler.run(
|
||||
device_handler.run_with_session(
|
||||
device.setup,
|
||||
backup_type=messages.BackupType.Slip39_Advanced,
|
||||
pin_protection=False,
|
||||
|
@ -46,7 +46,7 @@ def test_reset_slip39_basic(
|
||||
|
||||
assert features.initialized is False
|
||||
|
||||
device_handler.run(
|
||||
device_handler.run_with_session(
|
||||
device.setup,
|
||||
strength=128,
|
||||
backup_type=messages.BackupType.Slip39_Basic,
|
||||
|
@ -39,7 +39,7 @@ def prepare_tutorial_and_cancel_after_it(
|
||||
device_handler: "BackgroundDeviceHandler", cancelled: bool = False
|
||||
) -> Generator["DebugLink", None, None]:
|
||||
debug = device_handler.debuglink()
|
||||
device_handler.run(device.show_device_tutorial)
|
||||
device_handler.run_with_session(device.show_device_tutorial)
|
||||
|
||||
yield debug
|
||||
|
||||
|
@ -35,7 +35,7 @@ pytestmark = [
|
||||
|
||||
def test_tutorial_ignore_menu(device_handler: "BackgroundDeviceHandler"):
|
||||
debug = device_handler.debuglink()
|
||||
device_handler.run(device.show_device_tutorial)
|
||||
device_handler.run_with_session(device.show_device_tutorial)
|
||||
|
||||
assert debug.read_layout().title() == TR.tutorial__welcome_safe5
|
||||
debug.click(debug.screen_buttons.tap_to_confirm())
|
||||
@ -55,7 +55,7 @@ def test_tutorial_ignore_menu(device_handler: "BackgroundDeviceHandler"):
|
||||
|
||||
def test_tutorial_menu_open_close(device_handler: "BackgroundDeviceHandler"):
|
||||
debug = device_handler.debuglink()
|
||||
device_handler.run(device.show_device_tutorial)
|
||||
device_handler.run_with_session(device.show_device_tutorial)
|
||||
|
||||
assert debug.read_layout().title() == TR.tutorial__welcome_safe5
|
||||
debug.click(debug.screen_buttons.tap_to_confirm())
|
||||
@ -81,7 +81,7 @@ def test_tutorial_menu_open_close(device_handler: "BackgroundDeviceHandler"):
|
||||
|
||||
def test_tutorial_menu_exit(device_handler: "BackgroundDeviceHandler"):
|
||||
debug = device_handler.debuglink()
|
||||
device_handler.run(device.show_device_tutorial)
|
||||
device_handler.run_with_session(device.show_device_tutorial)
|
||||
|
||||
assert debug.read_layout().title() == TR.tutorial__welcome_safe5
|
||||
debug.click(debug.screen_buttons.tap_to_confirm())
|
||||
@ -104,7 +104,7 @@ def test_tutorial_menu_exit(device_handler: "BackgroundDeviceHandler"):
|
||||
|
||||
def test_tutorial_menu_repeat(device_handler: "BackgroundDeviceHandler"):
|
||||
debug = device_handler.debuglink()
|
||||
device_handler.run(device.show_device_tutorial)
|
||||
device_handler.run_with_session(device.show_device_tutorial)
|
||||
|
||||
assert debug.read_layout().title() == TR.tutorial__welcome_safe5
|
||||
debug.click(debug.screen_buttons.tap_to_confirm())
|
||||
@ -134,7 +134,7 @@ def test_tutorial_menu_repeat(device_handler: "BackgroundDeviceHandler"):
|
||||
|
||||
def test_tutorial_menu_funfact(device_handler: "BackgroundDeviceHandler"):
|
||||
debug = device_handler.debuglink()
|
||||
device_handler.run(device.show_device_tutorial)
|
||||
device_handler.run_with_session(device.show_device_tutorial)
|
||||
|
||||
assert debug.read_layout().title() == TR.tutorial__welcome_safe5
|
||||
debug.click(debug.screen_buttons.tap_to_confirm())
|
||||
|
@ -32,8 +32,8 @@ if TYPE_CHECKING:
|
||||
from _pytest.mark.structures import MarkDecorator
|
||||
|
||||
from trezorlib.debuglink import DebugLink
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.messages import ButtonRequest
|
||||
from trezorlib.transport.session import Session
|
||||
|
||||
PRIVATE_KEYS_DEV = [byte * 32 for byte in (b"\xdd", b"\xde", b"\xdf")]
|
||||
|
||||
@ -336,10 +336,10 @@ def check_pin_backoff_time(attempts: int, start: float) -> None:
|
||||
assert got >= expected
|
||||
|
||||
|
||||
def get_test_address(client: "Client") -> str:
|
||||
def get_test_address(session: "Session") -> str:
|
||||
"""Fetch a testnet address on a fixed path. Useful to make a pin/passphrase
|
||||
protected call, or to identify the root secret (seed+passphrase)"""
|
||||
return btc.get_address(client, "Testnet", TEST_ADDRESS_N)
|
||||
return btc.get_address(session, "Testnet", TEST_ADDRESS_N)
|
||||
|
||||
|
||||
def compact_size(n: int) -> bytes:
|
||||
@ -378,5 +378,5 @@ def swipe_till_the_end(debug: "DebugLink", br: messages.ButtonRequest) -> None:
|
||||
debug.swipe_up()
|
||||
|
||||
|
||||
def is_core(client: "Client") -> bool:
|
||||
return client.model is not models.T1B1
|
||||
def is_core(session: "Session") -> bool:
|
||||
return session.model is not models.T1B1
|
||||
|
@ -21,6 +21,7 @@ import os
|
||||
import typing as t
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
import xdist
|
||||
@ -31,7 +32,8 @@ from trezorlib import debuglink, log, models
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.device import apply_settings
|
||||
from trezorlib.device import wipe as wipe_device
|
||||
from trezorlib.transport import enumerate_devices, get_transport, protocol
|
||||
from trezorlib.transport import enumerate_devices, get_transport
|
||||
from trezorlib.transport.thp.protocol_v1 import ProtocolV1Channel
|
||||
|
||||
# register rewrites before importing from local package
|
||||
# so that we see details of failed asserts from this module
|
||||
@ -49,6 +51,7 @@ if t.TYPE_CHECKING:
|
||||
from _pytest.terminal import TerminalReporter
|
||||
|
||||
from trezorlib._internal.emulator import Emulator
|
||||
from trezorlib.debuglink import SessionDebugWrapper
|
||||
|
||||
|
||||
HERE = Path(__file__).resolve().parent
|
||||
@ -78,7 +81,7 @@ def core_emulator(request: pytest.FixtureRequest) -> t.Iterator[Emulator]:
|
||||
"""Fixture returning default core emulator with possibility of screen recording."""
|
||||
with EmulatorWrapper("core", main_args=_emulator_wrapper_main_args()) as emu:
|
||||
# Modifying emu.client to add screen recording (when --ui=test is used)
|
||||
with ui_tests.screen_recording(emu.client, request) as _:
|
||||
with ui_tests.screen_recording(emu.client, request, lambda: emu.client) as _:
|
||||
yield emu
|
||||
|
||||
|
||||
@ -127,7 +130,15 @@ def emulator(request: pytest.FixtureRequest) -> t.Generator["Emulator", None, No
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _raw_client(request: pytest.FixtureRequest) -> Client:
|
||||
def _raw_client(request: pytest.FixtureRequest) -> t.Generator[Client, None, None]:
|
||||
client = _get_raw_client(request)
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
client.close_transport()
|
||||
|
||||
|
||||
def _get_raw_client(request: pytest.FixtureRequest) -> Client:
|
||||
# In case tests run in parallel, each process has its own emulator/client.
|
||||
# Requesting the emulator fixture only if relevant.
|
||||
if request.session.config.getoption("control_emulators"):
|
||||
@ -137,7 +148,7 @@ def _raw_client(request: pytest.FixtureRequest) -> Client:
|
||||
interact = os.environ.get("INTERACT") == "1"
|
||||
if not interact:
|
||||
# prevent tests from getting stuck in case there is an USB packet loss
|
||||
protocol._DEFAULT_READ_TIMEOUT = 50.0
|
||||
ProtocolV1Channel._DEFAULT_READ_TIMEOUT = 50.0
|
||||
|
||||
path = os.environ.get("TREZOR_PATH")
|
||||
if path:
|
||||
@ -153,7 +164,7 @@ def _client_from_path(
|
||||
) -> Client:
|
||||
try:
|
||||
transport = get_transport(path)
|
||||
return Client(transport, auto_interact=not interact)
|
||||
return Client(transport, auto_interact=not interact, open_transport=True)
|
||||
except Exception as e:
|
||||
request.session.shouldstop = "Failed to communicate with Trezor"
|
||||
raise RuntimeError(f"Failed to open debuglink for {path}") from e
|
||||
@ -162,10 +173,7 @@ def _client_from_path(
|
||||
def _find_client(request: pytest.FixtureRequest, interact: bool) -> Client:
|
||||
devices = enumerate_devices()
|
||||
for device in devices:
|
||||
try:
|
||||
return Client(device, auto_interact=not interact)
|
||||
except Exception:
|
||||
pass
|
||||
return Client(device, auto_interact=not interact, open_transport=True)
|
||||
|
||||
request.session.shouldstop = "Failed to communicate with Trezor"
|
||||
raise RuntimeError("No debuggable device found")
|
||||
@ -240,7 +248,7 @@ class ModelsFilter:
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(
|
||||
def _client_unlocked(
|
||||
request: pytest.FixtureRequest, _raw_client: Client
|
||||
) -> t.Generator[Client, None, None]:
|
||||
"""Client fixture.
|
||||
@ -280,14 +288,14 @@ def client(
|
||||
|
||||
test_ui = request.config.getoption("ui")
|
||||
|
||||
_raw_client.reset_debug_features()
|
||||
_raw_client.open()
|
||||
try:
|
||||
_raw_client.sync_responses()
|
||||
_raw_client.init_device()
|
||||
except Exception:
|
||||
request.session.shouldstop = "Failed to communicate with Trezor"
|
||||
pytest.fail("Failed to communicate with Trezor")
|
||||
# _raw_client.reset_debug_features()
|
||||
if isinstance(_raw_client.protocol, ProtocolV1Channel):
|
||||
try:
|
||||
_raw_client.sync_responses()
|
||||
except Exception:
|
||||
request.session.shouldstop = "Failed to communicate with Trezor"
|
||||
pytest.fail("Failed to communicate with Trezor")
|
||||
_raw_client._seedless_session = _raw_client.get_seedless_session(new_session=True)
|
||||
|
||||
# Resetting all the debug events to not be influenced by previous test
|
||||
_raw_client.debug.reset_debug_events()
|
||||
@ -300,13 +308,20 @@ def client(
|
||||
should_format = sd_marker.kwargs.get("formatted", True)
|
||||
_raw_client.debug.erase_sd_card(format=should_format)
|
||||
|
||||
wipe_device(_raw_client)
|
||||
if _raw_client.is_invalidated:
|
||||
_raw_client = _raw_client.get_new_client()
|
||||
session = _raw_client.get_seedless_session()
|
||||
wipe_device(session)
|
||||
sleep(1.5) # Makes tests more stable (wait for wipe to finish)
|
||||
|
||||
if not _raw_client.features.bootloader_mode:
|
||||
_raw_client.refresh_features()
|
||||
|
||||
# Load language again, as it got erased in wipe
|
||||
if _raw_client.model is not models.T1B1:
|
||||
lang = request.session.config.getoption("lang") or "en"
|
||||
assert isinstance(lang, str)
|
||||
translations.set_language(_raw_client, lang)
|
||||
translations.set_language(_raw_client.get_seedless_session(), lang)
|
||||
|
||||
setup_params = dict(
|
||||
uninitialized=False,
|
||||
@ -324,32 +339,59 @@ def client(
|
||||
use_passphrase = setup_params["passphrase"] is True or isinstance(
|
||||
setup_params["passphrase"], str
|
||||
)
|
||||
|
||||
if not setup_params["uninitialized"]:
|
||||
session = _raw_client.get_seedless_session(new_session=True)
|
||||
debuglink.load_device(
|
||||
_raw_client,
|
||||
session,
|
||||
mnemonic=setup_params["mnemonic"], # type: ignore
|
||||
pin=setup_params["pin"], # type: ignore
|
||||
passphrase_protection=use_passphrase,
|
||||
label="test",
|
||||
needs_backup=setup_params["needs_backup"], # type: ignore
|
||||
no_backup=setup_params["no_backup"], # type: ignore
|
||||
_skip_init_device=True,
|
||||
_skip_init_device=False,
|
||||
)
|
||||
_raw_client._setup_pin = setup_params["pin"]
|
||||
|
||||
if request.node.get_closest_marker("experimental"):
|
||||
apply_settings(_raw_client, experimental_features=True)
|
||||
apply_settings(session, experimental_features=True)
|
||||
session.end()
|
||||
|
||||
if use_passphrase and isinstance(setup_params["passphrase"], str):
|
||||
_raw_client.use_passphrase(setup_params["passphrase"])
|
||||
yield _raw_client
|
||||
|
||||
_raw_client.lock(_refresh_features=False)
|
||||
_raw_client.init_device(new_session=True)
|
||||
|
||||
with ui_tests.screen_recording(_raw_client, request):
|
||||
yield _raw_client
|
||||
@pytest.fixture(scope="function")
|
||||
def client(
|
||||
request: pytest.FixtureRequest, _client_unlocked: Client
|
||||
) -> t.Generator[Client, None, None]:
|
||||
_client_unlocked.lock()
|
||||
with ui_tests.screen_recording(_client_unlocked, request):
|
||||
yield _client_unlocked
|
||||
|
||||
_raw_client.close()
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def session(
|
||||
request: pytest.FixtureRequest, _client_unlocked: Client
|
||||
) -> t.Generator[SessionDebugWrapper, None, None]:
|
||||
if bool(request.node.get_closest_marker("uninitialized_session")):
|
||||
session = _client_unlocked.get_seedless_session()
|
||||
else:
|
||||
derive_cardano = bool(request.node.get_closest_marker("cardano"))
|
||||
passphrase = ""
|
||||
marker = request.node.get_closest_marker("setup_client")
|
||||
if marker and isinstance(marker.kwargs.get("passphrase"), str):
|
||||
passphrase = marker.kwargs["passphrase"]
|
||||
if _client_unlocked._setup_pin is not None:
|
||||
_client_unlocked.use_pin_sequence([_client_unlocked._setup_pin])
|
||||
session = _client_unlocked.get_session(
|
||||
derive_cardano=derive_cardano, passphrase=passphrase
|
||||
)
|
||||
|
||||
if _client_unlocked._setup_pin is not None:
|
||||
session.lock()
|
||||
with ui_tests.screen_recording(_client_unlocked, request):
|
||||
yield session
|
||||
# Calling session.end() is not needed since the device gets wiped later anyway.
|
||||
|
||||
|
||||
def _is_main_runner(session_or_request: pytest.Session | pytest.FixtureRequest) -> bool:
|
||||
@ -467,6 +509,10 @@ def pytest_configure(config: "Config") -> None:
|
||||
"markers",
|
||||
'setup_client(mnemonic="all all all...", pin=None, passphrase=False, uninitialized=False): configure the client instance',
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
"uninitialized_session: use uninitialized session instance",
|
||||
)
|
||||
with open(os.path.join(os.path.dirname(__file__), "REGISTERED_MARKERS")) as f:
|
||||
for line in f:
|
||||
config.addinivalue_line("markers", line.strip())
|
||||
|
@ -6,12 +6,12 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
import typing_extensions as tx
|
||||
|
||||
from trezorlib.client import PASSPHRASE_ON_DEVICE
|
||||
from trezorlib.messages import DebugWaitType
|
||||
from trezorlib.transport import udp
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from trezorlib._internal.emulator import Emulator
|
||||
from trezorlib.debuglink import DebugLink
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.messages import Features
|
||||
|
||||
@ -22,6 +22,10 @@ udp.SOCKET_TIMEOUT = 0.1
|
||||
|
||||
|
||||
class NullUI:
|
||||
@staticmethod
|
||||
def clear(*args, **kwargs):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def button_request(code):
|
||||
pass
|
||||
@ -50,11 +54,29 @@ class BackgroundDeviceHandler:
|
||||
self.client = client
|
||||
self.client.ui = NullUI # type: ignore [NullUI is OK UI]
|
||||
self.client.watch_layout(True)
|
||||
self.client.debug.input_wait_type = DebugWaitType.CURRENT_LAYOUT
|
||||
|
||||
def run(
|
||||
def run_with_session(
|
||||
self,
|
||||
function: t.Callable[tx.Concatenate["Client", P], t.Any],
|
||||
function: t.Callable[tx.Concatenate["Session", P], t.Any],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
"""Runs some function that interacts with a device.
|
||||
|
||||
Makes sure the UI is updated before returning.
|
||||
"""
|
||||
if self.task is not None:
|
||||
raise RuntimeError("Wait for previous task first")
|
||||
|
||||
# wait for the first UI change triggered by the task running in the background
|
||||
session = self.client.get_session()
|
||||
with self.debuglink().wait_for_layout_change():
|
||||
self.task = self._pool.submit(function, session, *args, **kwargs)
|
||||
|
||||
def run_with_provided_session(
|
||||
self,
|
||||
session,
|
||||
function: t.Callable[tx.Concatenate["Session", P], t.Any],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
@ -67,15 +89,13 @@ class BackgroundDeviceHandler:
|
||||
|
||||
# wait for the first UI change triggered by the task running in the background
|
||||
with self.debuglink().wait_for_layout_change():
|
||||
self.task = self._pool.submit(function, self.client, *args, **kwargs)
|
||||
self.task = self._pool.submit(function, session, *args, **kwargs)
|
||||
|
||||
def kill_task(self) -> None:
|
||||
if self.task is not None:
|
||||
# Force close the client, which should raise an exception in a client
|
||||
# waiting on IO. Does not work over Bridge, because bridge doesn't have
|
||||
# a close() method.
|
||||
while self.client.session_counter > 0:
|
||||
self.client.close()
|
||||
try:
|
||||
self.task.result(timeout=1)
|
||||
except Exception:
|
||||
@ -99,7 +119,7 @@ class BackgroundDeviceHandler:
|
||||
def features(self) -> "Features":
|
||||
if self.task is not None:
|
||||
raise RuntimeError("Cannot query features while task is running")
|
||||
self.client.init_device()
|
||||
self.client.refresh_features()
|
||||
return self.client.features
|
||||
|
||||
def debuglink(self) -> "DebugLink":
|
||||
|
@ -17,7 +17,7 @@
|
||||
import pytest
|
||||
|
||||
from trezorlib.binance import get_address
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.tools import parse_path
|
||||
|
||||
from ...input_flows import InputFlowShowAddressQRCode
|
||||
@ -38,23 +38,23 @@ BINANCE_ADDRESS_TEST_VECTORS = [
|
||||
|
||||
|
||||
@pytest.mark.parametrize("path, expected_address", BINANCE_ADDRESS_TEST_VECTORS)
|
||||
def test_binance_get_address(client: Client, path: str, expected_address: str):
|
||||
def test_binance_get_address(session: Session, path: str, expected_address: str):
|
||||
# data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50
|
||||
|
||||
address = get_address(client, parse_path(path), show_display=True)
|
||||
address = get_address(session, parse_path(path), show_display=True)
|
||||
assert address == expected_address
|
||||
|
||||
|
||||
@pytest.mark.parametrize("path, expected_address", BINANCE_ADDRESS_TEST_VECTORS)
|
||||
def test_binance_get_address_chunkify_details(
|
||||
client: Client, path: str, expected_address: str
|
||||
session: Session, path: str, expected_address: str
|
||||
):
|
||||
# data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50
|
||||
|
||||
with client:
|
||||
IF = InputFlowShowAddressQRCode(client)
|
||||
client.set_input_flow(IF.get())
|
||||
with session:
|
||||
IF = InputFlowShowAddressQRCode(session.client)
|
||||
session.set_input_flow(IF.get())
|
||||
address = get_address(
|
||||
client, parse_path(path), show_display=True, chunkify=True
|
||||
session, parse_path(path), show_display=True, chunkify=True
|
||||
)
|
||||
assert address == expected_address
|
||||
|
@ -17,7 +17,7 @@
|
||||
import pytest
|
||||
|
||||
from trezorlib import binance
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.tools import parse_path
|
||||
|
||||
from ...input_flows import InputFlowShowXpubQRCode
|
||||
@ -31,11 +31,11 @@ BINANCE_PATH = parse_path("m/44h/714h/0h/0/0")
|
||||
@pytest.mark.setup_client(
|
||||
mnemonic="offer caution gift cross surge pretty orange during eye soldier popular holiday mention east eight office fashion ill parrot vault rent devote earth cousin"
|
||||
)
|
||||
def test_binance_get_public_key(client: Client):
|
||||
with client:
|
||||
IF = InputFlowShowXpubQRCode(client)
|
||||
client.set_input_flow(IF.get())
|
||||
sig = binance.get_public_key(client, BINANCE_PATH, show_display=True)
|
||||
def test_binance_get_public_key(session: Session):
|
||||
with session:
|
||||
IF = InputFlowShowXpubQRCode(session.client)
|
||||
session.set_input_flow(IF.get())
|
||||
sig = binance.get_public_key(session, BINANCE_PATH, show_display=True)
|
||||
assert (
|
||||
sig.hex()
|
||||
== "029729a52e4e3c2b4a4e52aa74033eedaf8ba1df5ab6d1f518fd69e67bbd309b0e"
|
||||
|
@ -17,7 +17,7 @@
|
||||
import pytest
|
||||
|
||||
from trezorlib import binance
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.tools import parse_path
|
||||
|
||||
BINANCE_TEST_VECTORS = [
|
||||
@ -110,10 +110,10 @@ BINANCE_TEST_VECTORS = [
|
||||
@pytest.mark.parametrize("message, expected_response", BINANCE_TEST_VECTORS)
|
||||
@pytest.mark.parametrize("chunkify", (True, False))
|
||||
def test_binance_sign_message(
|
||||
client: Client, chunkify: bool, message: dict, expected_response: dict
|
||||
session: Session, chunkify: bool, message: dict, expected_response: dict
|
||||
):
|
||||
response = binance.sign_tx(
|
||||
client, parse_path("m/44h/714h/0h/0/0"), message, chunkify=chunkify
|
||||
session, parse_path("m/44h/714h/0h/0/0"), message, chunkify=chunkify
|
||||
)
|
||||
|
||||
assert response.public_key.hex() == expected_response["public_key"]
|
||||
|
@ -4,6 +4,7 @@ from hashlib import sha256
|
||||
from ecdsa import SECP256k1, SigningKey
|
||||
|
||||
from trezorlib import btc, messages
|
||||
from trezorlib.transport.session import Session
|
||||
|
||||
from ...common import compact_size
|
||||
|
||||
@ -27,7 +28,12 @@ def hash_bytes_prefixed(hasher, data):
|
||||
|
||||
|
||||
def make_payment_request(
|
||||
client, recipient_name, outputs, change_addresses=None, memos=None, nonce=None
|
||||
session: Session,
|
||||
recipient_name,
|
||||
outputs,
|
||||
change_addresses=None,
|
||||
memos=None,
|
||||
nonce=None,
|
||||
):
|
||||
h_pr = sha256(b"SL\x00\x24")
|
||||
|
||||
@ -52,7 +58,7 @@ def make_payment_request(
|
||||
hash_bytes_prefixed(h_pr, memo.text.encode())
|
||||
elif isinstance(memo, RefundMemo):
|
||||
address_resp = btc.get_authenticated_address(
|
||||
client, "Testnet", memo.address_n
|
||||
session, "Testnet", memo.address_n
|
||||
)
|
||||
msg_memo = messages.RefundMemo(
|
||||
address=address_resp.address, mac=address_resp.mac
|
||||
@ -63,7 +69,7 @@ def make_payment_request(
|
||||
hash_bytes_prefixed(h_pr, address_resp.address.encode())
|
||||
elif isinstance(memo, CoinPurchaseMemo):
|
||||
address_resp = btc.get_authenticated_address(
|
||||
client, memo.coin_name, memo.address_n
|
||||
session, memo.coin_name, memo.address_n
|
||||
)
|
||||
msg_memo = messages.CoinPurchaseMemo(
|
||||
coin_type=memo.slip44,
|
||||
|
@ -19,6 +19,7 @@ import time
|
||||
import pytest
|
||||
|
||||
from trezorlib import btc, device, messages
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.exceptions import TrezorFailure
|
||||
from trezorlib.tools import parse_path
|
||||
@ -59,15 +60,15 @@ SLIP25_PATH = parse_path("m/10025h")
|
||||
|
||||
@pytest.mark.parametrize("chunkify", (True, False))
|
||||
@pytest.mark.setup_client(pin=PIN)
|
||||
def test_sign_tx(client: Client, chunkify: bool):
|
||||
def test_sign_tx(session: Session, chunkify: bool):
|
||||
# NOTE: FAKE input tx
|
||||
|
||||
assert session.features.unlocked is False
|
||||
commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big")
|
||||
|
||||
with client:
|
||||
client.use_pin_sequence([PIN])
|
||||
with session:
|
||||
session.client.use_pin_sequence([PIN])
|
||||
btc.authorize_coinjoin(
|
||||
client,
|
||||
session,
|
||||
coordinator="www.example.com",
|
||||
max_rounds=2,
|
||||
max_coordinator_fee_rate=500_000, # 0.5 %
|
||||
@ -77,14 +78,14 @@ def test_sign_tx(client: Client, chunkify: bool):
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
)
|
||||
|
||||
client.call(messages.LockDevice())
|
||||
session.call(messages.LockDevice())
|
||||
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[messages.PreauthorizedRequest, messages.OwnershipProof]
|
||||
)
|
||||
btc.get_ownership_proof(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
parse_path("m/10025h/1h/0h/1h/1/0"),
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -93,12 +94,12 @@ def test_sign_tx(client: Client, chunkify: bool):
|
||||
preauthorized=True,
|
||||
)
|
||||
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[messages.PreauthorizedRequest, messages.OwnershipProof]
|
||||
)
|
||||
btc.get_ownership_proof(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
parse_path("m/10025h/1h/0h/1h/1/5"),
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -206,8 +207,8 @@ def test_sign_tx(client: Client, chunkify: bool):
|
||||
no_fee_indices=[],
|
||||
)
|
||||
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
messages.PreauthorizedRequest(),
|
||||
request_input(0),
|
||||
@ -222,7 +223,7 @@ def test_sign_tx(client: Client, chunkify: bool):
|
||||
]
|
||||
)
|
||||
signatures, serialized_tx = btc.sign_tx(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
inputs,
|
||||
outputs,
|
||||
@ -243,7 +244,7 @@ def test_sign_tx(client: Client, chunkify: bool):
|
||||
|
||||
# Test for a second time.
|
||||
btc.sign_tx(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
inputs,
|
||||
outputs,
|
||||
@ -256,7 +257,7 @@ def test_sign_tx(client: Client, chunkify: bool):
|
||||
# Test for a third time, number of rounds should be exceeded.
|
||||
with pytest.raises(TrezorFailure, match="No preauthorized operation"):
|
||||
btc.sign_tx(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
inputs,
|
||||
outputs,
|
||||
@ -267,7 +268,7 @@ def test_sign_tx(client: Client, chunkify: bool):
|
||||
)
|
||||
|
||||
|
||||
def test_sign_tx_large(client: Client):
|
||||
def test_sign_tx_large(session: Session):
|
||||
# NOTE: FAKE input tx
|
||||
|
||||
commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big")
|
||||
@ -278,17 +279,16 @@ def test_sign_tx_large(client: Client):
|
||||
output_denom = 10_000 # sats
|
||||
max_expected_delay = 80 # seconds
|
||||
|
||||
with client:
|
||||
btc.authorize_coinjoin(
|
||||
client,
|
||||
coordinator="www.example.com",
|
||||
max_rounds=2,
|
||||
max_coordinator_fee_rate=500_000, # 0.5 %
|
||||
max_fee_per_kvbyte=3500,
|
||||
n=parse_path("m/10025h/1h/0h/1h"),
|
||||
coin_name="Testnet",
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
)
|
||||
btc.authorize_coinjoin(
|
||||
session,
|
||||
coordinator="www.example.com",
|
||||
max_rounds=2,
|
||||
max_coordinator_fee_rate=500_000, # 0.5 %
|
||||
max_fee_per_kvbyte=3500,
|
||||
n=parse_path("m/10025h/1h/0h/1h"),
|
||||
coin_name="Testnet",
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
)
|
||||
|
||||
# INPUTS.
|
||||
|
||||
@ -399,22 +399,21 @@ def test_sign_tx_large(client: Client):
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
with client:
|
||||
btc.sign_tx(
|
||||
client,
|
||||
"Testnet",
|
||||
inputs,
|
||||
outputs,
|
||||
prev_txes=TX_CACHE_TESTNET,
|
||||
coinjoin_request=coinjoin_req,
|
||||
preauthorized=True,
|
||||
serialize=False,
|
||||
)
|
||||
btc.sign_tx(
|
||||
session,
|
||||
"Testnet",
|
||||
inputs,
|
||||
outputs,
|
||||
prev_txes=TX_CACHE_TESTNET,
|
||||
coinjoin_request=coinjoin_req,
|
||||
preauthorized=True,
|
||||
serialize=False,
|
||||
)
|
||||
delay = time.time() - start
|
||||
assert delay <= max_expected_delay
|
||||
|
||||
|
||||
def test_sign_tx_spend(client: Client):
|
||||
def test_sign_tx_spend(session: Session):
|
||||
# NOTE: FAKE input tx
|
||||
|
||||
inputs = [
|
||||
@ -446,15 +445,15 @@ def test_sign_tx_spend(client: Client):
|
||||
# Ensure that Trezor refuses to spend from CoinJoin without user authorization.
|
||||
with pytest.raises(TrezorFailure, match="Forbidden key path"):
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
inputs,
|
||||
outputs,
|
||||
prev_txes=TX_CACHE_TESTNET,
|
||||
)
|
||||
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
messages.ButtonRequest(code=B.Other),
|
||||
messages.UnlockedPathRequest,
|
||||
@ -462,7 +461,7 @@ def test_sign_tx_spend(client: Client):
|
||||
request_output(0),
|
||||
request_output(1),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
@ -472,7 +471,7 @@ def test_sign_tx_spend(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
inputs,
|
||||
outputs,
|
||||
@ -487,7 +486,7 @@ def test_sign_tx_spend(client: Client):
|
||||
)
|
||||
|
||||
|
||||
def test_sign_tx_migration(client: Client):
|
||||
def test_sign_tx_migration(session: Session):
|
||||
inputs = [
|
||||
messages.TxInputType(
|
||||
address_n=parse_path("m/84h/1h/3h/0/12"),
|
||||
@ -520,15 +519,15 @@ def test_sign_tx_migration(client: Client):
|
||||
# Ensure that Trezor refuses to receive to CoinJoin path without the user first authorizing access to CoinJoin paths.
|
||||
with pytest.raises(TrezorFailure, match="Forbidden key path"):
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
inputs,
|
||||
outputs,
|
||||
prev_txes=TX_CACHE_TESTNET,
|
||||
)
|
||||
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
messages.ButtonRequest(code=B.Other),
|
||||
messages.UnlockedPathRequest,
|
||||
@ -536,7 +535,7 @@ def test_sign_tx_migration(client: Client):
|
||||
request_input(1),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(TXHASH_2cc3c1),
|
||||
@ -558,7 +557,7 @@ def test_sign_tx_migration(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
inputs,
|
||||
outputs,
|
||||
@ -573,11 +572,11 @@ def test_sign_tx_migration(client: Client):
|
||||
)
|
||||
|
||||
|
||||
def test_wrong_coordinator(client: Client):
|
||||
def test_wrong_coordinator(session: Session):
|
||||
# Ensure that a preauthorized GetOwnershipProof fails if the commitment_data doesn't match the coordinator.
|
||||
|
||||
btc.authorize_coinjoin(
|
||||
client,
|
||||
session,
|
||||
coordinator="www.example.com",
|
||||
max_rounds=10,
|
||||
max_coordinator_fee_rate=500_000, # 0.5 %
|
||||
@ -589,7 +588,7 @@ def test_wrong_coordinator(client: Client):
|
||||
|
||||
with pytest.raises(TrezorFailure, match="Unauthorized operation"):
|
||||
btc.get_ownership_proof(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
parse_path("m/10025h/1h/0h/1h/1/0"),
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -599,9 +598,9 @@ def test_wrong_coordinator(client: Client):
|
||||
)
|
||||
|
||||
|
||||
def test_wrong_account_type(client: Client):
|
||||
def test_wrong_account_type(session: Session):
|
||||
params = {
|
||||
"client": client,
|
||||
"session": session,
|
||||
"coordinator": "www.example.com",
|
||||
"max_rounds": 10,
|
||||
"max_coordinator_fee_rate": 500_000, # 0.5 %
|
||||
@ -625,11 +624,11 @@ def test_wrong_account_type(client: Client):
|
||||
)
|
||||
|
||||
|
||||
def test_cancel_authorization(client: Client):
|
||||
def test_cancel_authorization(session: Session):
|
||||
# Ensure that a preauthorized GetOwnershipProof fails if the commitment_data doesn't match the coordinator.
|
||||
|
||||
btc.authorize_coinjoin(
|
||||
client,
|
||||
session,
|
||||
coordinator="www.example.com",
|
||||
max_rounds=10,
|
||||
max_coordinator_fee_rate=500_000, # 0.5 %
|
||||
@ -639,11 +638,11 @@ def test_cancel_authorization(client: Client):
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
)
|
||||
|
||||
device.cancel_authorization(client)
|
||||
device.cancel_authorization(session)
|
||||
|
||||
with pytest.raises(TrezorFailure, match="No preauthorized operation"):
|
||||
btc.get_ownership_proof(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
parse_path("m/10025h/1h/0h/1h/1/0"),
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -653,35 +652,35 @@ def test_cancel_authorization(client: Client):
|
||||
)
|
||||
|
||||
|
||||
def test_get_public_key(client: Client):
|
||||
def test_get_public_key(session: Session):
|
||||
ACCOUNT_PATH = parse_path("m/10025h/1h/0h/1h")
|
||||
EXPECTED_XPUB = "tpubDEMKm4M3S2Grx5DHTfbX9et5HQb9KhdjDCkUYdH9gvVofvPTE6yb2MH52P9uc4mx6eFohUmfN1f4hhHNK28GaZnWRXr3b8KkfFcySo1SmXU"
|
||||
|
||||
# Ensure that user cannot access SLIP-25 path without UnlockPath.
|
||||
with pytest.raises(TrezorFailure, match="Forbidden key path"):
|
||||
resp = btc.get_public_node(
|
||||
client,
|
||||
session,
|
||||
ACCOUNT_PATH,
|
||||
coin_name="Testnet",
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
)
|
||||
|
||||
# Get unlock path MAC.
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
messages.ButtonRequest(code=B.Other),
|
||||
messages.UnlockedPathRequest,
|
||||
messages.Failure(code=messages.FailureType.ActionCancelled),
|
||||
]
|
||||
)
|
||||
unlock_path_mac = device.unlock_path(client, n=SLIP25_PATH)
|
||||
unlock_path_mac = device.unlock_path(session, n=SLIP25_PATH)
|
||||
|
||||
# Ensure that UnlockPath fails with invalid MAC.
|
||||
invalid_unlock_path_mac = bytes([unlock_path_mac[0] ^ 1]) + unlock_path_mac[1:]
|
||||
with pytest.raises(TrezorFailure, match="Invalid MAC"):
|
||||
resp = btc.get_public_node(
|
||||
client,
|
||||
session,
|
||||
ACCOUNT_PATH,
|
||||
coin_name="Testnet",
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -690,15 +689,15 @@ def test_get_public_key(client: Client):
|
||||
)
|
||||
|
||||
# Ensure that user does not need to confirm access when path unlock is requested with MAC.
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
messages.UnlockedPathRequest,
|
||||
messages.PublicKey,
|
||||
]
|
||||
)
|
||||
resp = btc.get_public_node(
|
||||
client,
|
||||
session,
|
||||
ACCOUNT_PATH,
|
||||
coin_name="Testnet",
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -708,11 +707,12 @@ def test_get_public_key(client: Client):
|
||||
assert resp.xpub == EXPECTED_XPUB
|
||||
|
||||
|
||||
def test_get_address(client: Client):
|
||||
def test_get_address(session: Session):
|
||||
|
||||
# Ensure that the SLIP-0025 external chain is inaccessible without user confirmation.
|
||||
with pytest.raises(TrezorFailure, match="Forbidden key path"):
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
parse_path("m/10025h/1h/0h/1h/0/0"),
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -720,20 +720,20 @@ def test_get_address(client: Client):
|
||||
)
|
||||
|
||||
# Unlock CoinJoin path.
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
messages.ButtonRequest(code=B.Other),
|
||||
messages.UnlockedPathRequest,
|
||||
messages.Failure(code=messages.FailureType.ActionCancelled),
|
||||
]
|
||||
)
|
||||
unlock_path_mac = device.unlock_path(client, SLIP25_PATH)
|
||||
unlock_path_mac = device.unlock_path(session, SLIP25_PATH)
|
||||
|
||||
# Ensure that the SLIP-0025 external chain is accessible after user confirmation.
|
||||
for chunkify in (True, False):
|
||||
resp = btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
parse_path("m/10025h/1h/0h/1h/0/0"),
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -745,7 +745,7 @@ def test_get_address(client: Client):
|
||||
assert resp == "tb1pl3y9gf7xk2ryvmav5ar66ra0d2hk7lhh9mmusx3qvn0n09kmaghqh32ru7"
|
||||
|
||||
resp = btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
parse_path("m/10025h/1h/0h/1h/0/1"),
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -758,7 +758,7 @@ def test_get_address(client: Client):
|
||||
# Ensure that the SLIP-0025 internal chain is inaccessible even with user authorization.
|
||||
with pytest.raises(TrezorFailure, match="Forbidden key path"):
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
parse_path("m/10025h/1h/0h/1h/1/0"),
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -769,7 +769,7 @@ def test_get_address(client: Client):
|
||||
|
||||
with pytest.raises(TrezorFailure, match="Forbidden key path"):
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
parse_path("m/10025h/1h/0h/1h/1/1"),
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -781,7 +781,7 @@ def test_get_address(client: Client):
|
||||
# Ensure that another SLIP-0025 account is inaccessible with the same MAC.
|
||||
with pytest.raises(TrezorFailure, match="Forbidden key path"):
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
parse_path("m/10025h/1h/1h/1h/0/0"),
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -793,8 +793,10 @@ def test_get_address(client: Client):
|
||||
|
||||
def test_multisession_authorization(client: Client):
|
||||
# Authorize CoinJoin with www.example1.com in session 1.
|
||||
session1 = client.get_session()
|
||||
|
||||
btc.authorize_coinjoin(
|
||||
client,
|
||||
session1,
|
||||
coordinator="www.example1.com",
|
||||
max_rounds=10,
|
||||
max_coordinator_fee_rate=500_000, # 0.5 %
|
||||
@ -805,12 +807,11 @@ def test_multisession_authorization(client: Client):
|
||||
)
|
||||
|
||||
# Open a second session.
|
||||
session_id1 = client.session_id
|
||||
client.init_device(new_session=True)
|
||||
session2 = client.get_session()
|
||||
|
||||
# Authorize CoinJoin with www.example2.com in session 2.
|
||||
btc.authorize_coinjoin(
|
||||
client,
|
||||
session2,
|
||||
coordinator="www.example2.com",
|
||||
max_rounds=10,
|
||||
max_coordinator_fee_rate=500_000, # 0.5 %
|
||||
@ -823,7 +824,7 @@ def test_multisession_authorization(client: Client):
|
||||
# Requesting a preauthorized ownership proof for www.example1.com should fail in session 2.
|
||||
with pytest.raises(TrezorFailure, match="Unauthorized operation"):
|
||||
ownership_proof, _ = btc.get_ownership_proof(
|
||||
client,
|
||||
session2,
|
||||
"Testnet",
|
||||
parse_path("m/10025h/1h/0h/1h/1/0"),
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -834,7 +835,7 @@ def test_multisession_authorization(client: Client):
|
||||
|
||||
# Requesting a preauthorized ownership proof for www.example2.com should succeed in session 2.
|
||||
ownership_proof, _ = btc.get_ownership_proof(
|
||||
client,
|
||||
session2,
|
||||
"Testnet",
|
||||
parse_path("m/10025h/1h/0h/1h/1/0"),
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -849,12 +850,10 @@ def test_multisession_authorization(client: Client):
|
||||
)
|
||||
|
||||
# Switch back to the first session.
|
||||
session_id2 = client.session_id
|
||||
client.init_device(session_id=session_id1)
|
||||
|
||||
session1 = client.resume_session(session1)
|
||||
# Requesting a preauthorized ownership proof for www.example1.com should succeed in session 1.
|
||||
ownership_proof, _ = btc.get_ownership_proof(
|
||||
client,
|
||||
session1,
|
||||
"Testnet",
|
||||
parse_path("m/10025h/1h/0h/1h/1/0"),
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -871,7 +870,7 @@ def test_multisession_authorization(client: Client):
|
||||
# Requesting a preauthorized ownership proof for www.example2.com should fail in session 1.
|
||||
with pytest.raises(TrezorFailure, match="Unauthorized operation"):
|
||||
ownership_proof, _ = btc.get_ownership_proof(
|
||||
client,
|
||||
session1,
|
||||
"Testnet",
|
||||
parse_path("m/10025h/1h/0h/1h/1/0"),
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -881,12 +880,12 @@ def test_multisession_authorization(client: Client):
|
||||
)
|
||||
|
||||
# Cancel the authorization in session 1.
|
||||
device.cancel_authorization(client)
|
||||
device.cancel_authorization(session1)
|
||||
|
||||
# Requesting a preauthorized ownership proof should fail now.
|
||||
with pytest.raises(TrezorFailure, match="No preauthorized operation"):
|
||||
ownership_proof, _ = btc.get_ownership_proof(
|
||||
client,
|
||||
session1,
|
||||
"Testnet",
|
||||
parse_path("m/10025h/1h/0h/1h/1/0"),
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
@ -896,11 +895,10 @@ def test_multisession_authorization(client: Client):
|
||||
)
|
||||
|
||||
# Switch to the second session.
|
||||
client.init_device(session_id=session_id2)
|
||||
|
||||
session2 = client.resume_session(session2)
|
||||
# Requesting a preauthorized ownership proof for www.example2.com should still succeed in session 2.
|
||||
ownership_proof, _ = btc.get_ownership_proof(
|
||||
client,
|
||||
session2,
|
||||
"Testnet",
|
||||
parse_path("m/10025h/1h/0h/1h/1/0"),
|
||||
script_type=messages.InputScriptType.SPENDTAPROOT,
|
||||
|
@ -17,7 +17,7 @@
|
||||
import pytest
|
||||
|
||||
from trezorlib import btc, messages
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.exceptions import TrezorFailure
|
||||
from trezorlib.tools import H_, parse_path
|
||||
|
||||
@ -53,7 +53,7 @@ FAKE_TXHASH_203416 = bytes.fromhex(
|
||||
pytestmark = pytest.mark.altcoin
|
||||
|
||||
|
||||
def test_send_bch_change(client: Client):
|
||||
def test_send_bch_change(session: Session):
|
||||
inp1 = messages.TxInputType(
|
||||
address_n=parse_path("m/44h/145h/0h/0/0"),
|
||||
# bitcoincash:qr08q88p9etk89wgv05nwlrkm4l0urz4cyl36hh9sv
|
||||
@ -72,14 +72,14 @@ def test_send_bch_change(client: Client):
|
||||
amount=73_452,
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
request_output(1),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(TXHASH_bc37c2),
|
||||
@ -92,9 +92,9 @@ def test_send_bch_change(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API
|
||||
session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API
|
||||
)
|
||||
|
||||
# raise Exception(hexlify(serialized_tx))
|
||||
assert_tx_matches(
|
||||
serialized_tx,
|
||||
hash_link="https://bch1.trezor.io/api/tx/502e8577b237b0152843a416f8f1ab0c63321b1be7a8cad7bf5c5c216fcf062c",
|
||||
@ -102,7 +102,7 @@ def test_send_bch_change(client: Client):
|
||||
)
|
||||
|
||||
|
||||
def test_send_bch_nochange(client: Client):
|
||||
def test_send_bch_nochange(session: Session):
|
||||
inp1 = messages.TxInputType(
|
||||
address_n=parse_path("m/44h/145h/0h/1/0"),
|
||||
# bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw
|
||||
@ -124,14 +124,14 @@ def test_send_bch_nochange(client: Client):
|
||||
amount=1_934_960,
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_input(1),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(TXHASH_502e85),
|
||||
@ -150,7 +150,7 @@ def test_send_bch_nochange(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API
|
||||
session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert_tx_matches(
|
||||
@ -160,7 +160,7 @@ def test_send_bch_nochange(client: Client):
|
||||
)
|
||||
|
||||
|
||||
def test_send_bch_oldaddr(client: Client):
|
||||
def test_send_bch_oldaddr(session: Session):
|
||||
inp1 = messages.TxInputType(
|
||||
address_n=parse_path("m/44h/145h/0h/1/0"),
|
||||
# bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw
|
||||
@ -182,14 +182,14 @@ def test_send_bch_oldaddr(client: Client):
|
||||
amount=1_934_960,
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_input(1),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(TXHASH_502e85),
|
||||
@ -208,7 +208,7 @@ def test_send_bch_oldaddr(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API
|
||||
session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert_tx_matches(
|
||||
@ -218,7 +218,7 @@ def test_send_bch_oldaddr(client: Client):
|
||||
)
|
||||
|
||||
|
||||
def test_attack_change_input(client: Client):
|
||||
def test_attack_change_input(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
|
||||
inp1 = messages.TxInputType(
|
||||
@ -252,15 +252,15 @@ def test_attack_change_input(client: Client):
|
||||
|
||||
return msg
|
||||
|
||||
with client:
|
||||
client.set_filter(messages.TxAck, attack_processor)
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_filter(messages.TxAck, attack_processor)
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
request_output(1),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(FAKE_TXHASH_bd32ff),
|
||||
@ -271,16 +271,16 @@ def test_attack_change_input(client: Client):
|
||||
]
|
||||
)
|
||||
with pytest.raises(TrezorFailure):
|
||||
btc.sign_tx(client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API)
|
||||
btc.sign_tx(session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API)
|
||||
|
||||
|
||||
@pytest.mark.multisig
|
||||
def test_send_bch_multisig_wrongchange(client: Client):
|
||||
def test_send_bch_multisig_wrongchange(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
|
||||
nodes = [
|
||||
btc.get_public_node(
|
||||
client, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash"
|
||||
session, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash"
|
||||
).node
|
||||
for i in range(1, 4)
|
||||
]
|
||||
@ -327,13 +327,13 @@ def test_send_bch_multisig_wrongchange(client: Client):
|
||||
script_type=messages.OutputScriptType.PAYTOMULTISIG,
|
||||
amount=23_000,
|
||||
)
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(FAKE_TXHASH_062fbd),
|
||||
@ -346,7 +346,7 @@ def test_send_bch_multisig_wrongchange(client: Client):
|
||||
]
|
||||
)
|
||||
(signatures1, serialized_tx) = btc.sign_tx(
|
||||
client, "Bcash", [inp1], [out1], prev_txes=TX_API
|
||||
session, "Bcash", [inp1], [out1], prev_txes=TX_API
|
||||
)
|
||||
assert (
|
||||
signatures1[0].hex()
|
||||
@ -359,12 +359,12 @@ def test_send_bch_multisig_wrongchange(client: Client):
|
||||
|
||||
|
||||
@pytest.mark.multisig
|
||||
def test_send_bch_multisig_change(client: Client):
|
||||
def test_send_bch_multisig_change(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
|
||||
nodes = [
|
||||
btc.get_public_node(
|
||||
client, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash"
|
||||
session, parse_path(f"m/48h/145h/{i}h/0h"), coin_name="Bcash"
|
||||
).node
|
||||
for i in range(1, 4)
|
||||
]
|
||||
@ -395,13 +395,13 @@ def test_send_bch_multisig_change(client: Client):
|
||||
script_type=messages.OutputScriptType.PAYTOMULTISIG,
|
||||
amount=24_000,
|
||||
)
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
request_output(1),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
@ -415,7 +415,7 @@ def test_send_bch_multisig_change(client: Client):
|
||||
]
|
||||
)
|
||||
(signatures1, serialized_tx) = btc.sign_tx(
|
||||
client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API
|
||||
session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert (
|
||||
@ -434,13 +434,13 @@ def test_send_bch_multisig_change(client: Client):
|
||||
)
|
||||
out2.address_n[2] = H_(1)
|
||||
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
request_output(1),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
@ -454,7 +454,7 @@ def test_send_bch_multisig_change(client: Client):
|
||||
]
|
||||
)
|
||||
(signatures1, serialized_tx) = btc.sign_tx(
|
||||
client, "Bcash", [inp1], [out1, out2], prev_txes=TX_API
|
||||
session, "Bcash", [inp1], [out1, out2], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert (
|
||||
@ -468,7 +468,7 @@ def test_send_bch_multisig_change(client: Client):
|
||||
|
||||
|
||||
@pytest.mark.models("core")
|
||||
def test_send_bch_external_presigned(client: Client):
|
||||
def test_send_bch_external_presigned(session: Session):
|
||||
inp1 = messages.TxInputType(
|
||||
# address_n=parse_path("44'/145'/0'/1/0"),
|
||||
# bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw
|
||||
@ -496,14 +496,14 @@ def test_send_bch_external_presigned(client: Client):
|
||||
amount=1_934_960,
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_input(1),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(TXHASH_502e85),
|
||||
@ -522,7 +522,7 @@ def test_send_bch_external_presigned(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API
|
||||
session, "Bcash", [inp1, inp2], [out1], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert_tx_matches(
|
||||
|
@ -17,7 +17,7 @@
|
||||
import pytest
|
||||
|
||||
from trezorlib import btc, messages
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.exceptions import TrezorFailure
|
||||
from trezorlib.tools import H_, parse_path, tx_hash
|
||||
|
||||
@ -51,7 +51,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.models("t1b1", "t2t1")]
|
||||
|
||||
|
||||
# All data taken from T1
|
||||
def test_send_bitcoin_gold_change(client: Client):
|
||||
def test_send_bitcoin_gold_change(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
|
||||
inp1 = messages.TxInputType(
|
||||
@ -71,14 +71,14 @@ def test_send_bitcoin_gold_change(client: Client):
|
||||
amount=1_252_382_934 - 1_896_050 - 1_000,
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
request_output(1),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(FAKE_TXHASH_6f0398),
|
||||
@ -92,7 +92,7 @@ def test_send_bitcoin_gold_change(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
|
||||
session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert (
|
||||
@ -101,7 +101,7 @@ def test_send_bitcoin_gold_change(client: Client):
|
||||
)
|
||||
|
||||
|
||||
def test_send_bitcoin_gold_nochange(client: Client):
|
||||
def test_send_bitcoin_gold_nochange(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
|
||||
inp1 = messages.TxInputType(
|
||||
@ -124,14 +124,14 @@ def test_send_bitcoin_gold_nochange(client: Client):
|
||||
amount=1_252_382_934 + 38_448_607 - 1_000,
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_input(1),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(FAKE_TXHASH_6f0398),
|
||||
@ -150,7 +150,7 @@ def test_send_bitcoin_gold_nochange(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API
|
||||
session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert (
|
||||
@ -159,7 +159,7 @@ def test_send_bitcoin_gold_nochange(client: Client):
|
||||
)
|
||||
|
||||
|
||||
def test_attack_change_input(client: Client):
|
||||
def test_attack_change_input(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
|
||||
inp1 = messages.TxInputType(
|
||||
@ -193,15 +193,15 @@ def test_attack_change_input(client: Client):
|
||||
|
||||
return msg
|
||||
|
||||
with client:
|
||||
client.set_filter(messages.TxAck, attack_processor)
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_filter(messages.TxAck, attack_processor)
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
request_output(1),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(FAKE_TXHASH_6f0398),
|
||||
@ -213,16 +213,16 @@ def test_attack_change_input(client: Client):
|
||||
]
|
||||
)
|
||||
with pytest.raises(TrezorFailure):
|
||||
btc.sign_tx(client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API)
|
||||
btc.sign_tx(session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API)
|
||||
|
||||
|
||||
@pytest.mark.multisig
|
||||
def test_send_btg_multisig_change(client: Client):
|
||||
def test_send_btg_multisig_change(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
|
||||
nodes = [
|
||||
btc.get_public_node(
|
||||
client, parse_path(f"m/48h/156h/{i}h/0h"), coin_name="Bgold"
|
||||
session, parse_path(f"m/48h/156h/{i}h/0h"), coin_name="Bgold"
|
||||
).node
|
||||
for i in range(1, 4)
|
||||
]
|
||||
@ -254,13 +254,13 @@ def test_send_btg_multisig_change(client: Client):
|
||||
script_type=messages.OutputScriptType.PAYTOMULTISIG,
|
||||
amount=1_252_382_934 - 24_000 - 1_000,
|
||||
)
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
request_output(1),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
@ -275,7 +275,7 @@ def test_send_btg_multisig_change(client: Client):
|
||||
]
|
||||
)
|
||||
signatures, serialized_tx = btc.sign_tx(
|
||||
client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
|
||||
session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert (
|
||||
@ -293,13 +293,13 @@ def test_send_btg_multisig_change(client: Client):
|
||||
)
|
||||
out2.address_n[2] = H_(1)
|
||||
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
request_output(1),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
@ -314,7 +314,7 @@ def test_send_btg_multisig_change(client: Client):
|
||||
]
|
||||
)
|
||||
signatures, serialized_tx = btc.sign_tx(
|
||||
client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
|
||||
session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert (
|
||||
@ -327,7 +327,7 @@ def test_send_btg_multisig_change(client: Client):
|
||||
)
|
||||
|
||||
|
||||
def test_send_p2sh(client: Client):
|
||||
def test_send_p2sh(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
|
||||
inp1 = messages.TxInputType(
|
||||
@ -347,16 +347,16 @@ def test_send_p2sh(client: Client):
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
amount=1_252_382_934 - 11_000 - 12_300_000,
|
||||
)
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
request_output(1),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(FAKE_TXHASH_db7239),
|
||||
@ -371,7 +371,7 @@ def test_send_p2sh(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
|
||||
session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert (
|
||||
@ -380,7 +380,7 @@ def test_send_p2sh(client: Client):
|
||||
)
|
||||
|
||||
|
||||
def test_send_p2sh_witness_change(client: Client):
|
||||
def test_send_p2sh_witness_change(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
|
||||
inp1 = messages.TxInputType(
|
||||
@ -400,13 +400,13 @@ def test_send_p2sh_witness_change(client: Client):
|
||||
script_type=messages.OutputScriptType.PAYTOP2SHWITNESS,
|
||||
amount=1_252_382_934 - 11_000 - 12_300_000,
|
||||
)
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
request_output(1),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
@ -422,7 +422,7 @@ def test_send_p2sh_witness_change(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
|
||||
session, "Bgold", [inp1], [out1, out2], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert (
|
||||
@ -432,12 +432,12 @@ def test_send_p2sh_witness_change(client: Client):
|
||||
|
||||
|
||||
@pytest.mark.multisig
|
||||
def test_send_multisig_1(client: Client):
|
||||
def test_send_multisig_1(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
|
||||
nodes = [
|
||||
btc.get_public_node(
|
||||
client, parse_path(f"m/49h/156h/{i}h"), coin_name="Bgold"
|
||||
session, parse_path(f"m/49h/156h/{i}h"), coin_name="Bgold"
|
||||
).node
|
||||
for i in range(1, 4)
|
||||
]
|
||||
@ -460,13 +460,13 @@ def test_send_multisig_1(client: Client):
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(FAKE_TXHASH_7f1f6b),
|
||||
@ -479,17 +479,17 @@ def test_send_multisig_1(client: Client):
|
||||
request_finished(),
|
||||
]
|
||||
)
|
||||
signatures, _ = btc.sign_tx(client, "Bgold", [inp1], [out1], prev_txes=TX_API)
|
||||
signatures, _ = btc.sign_tx(session, "Bgold", [inp1], [out1], prev_txes=TX_API)
|
||||
# store signature
|
||||
inp1.multisig.signatures[0] = signatures[0]
|
||||
# sign with third key
|
||||
inp1.address_n[2] = H_(3)
|
||||
client.set_expected_responses(
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(FAKE_TXHASH_7f1f6b),
|
||||
@ -503,7 +503,7 @@ def test_send_multisig_1(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client, "Bgold", [inp1], [out1], prev_txes=TX_API
|
||||
session, "Bgold", [inp1], [out1], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert (
|
||||
@ -512,7 +512,7 @@ def test_send_multisig_1(client: Client):
|
||||
)
|
||||
|
||||
|
||||
def test_send_mixed_inputs(client: Client):
|
||||
def test_send_mixed_inputs(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
# First is non-segwit, second is segwit.
|
||||
|
||||
@ -537,9 +537,9 @@ def test_send_mixed_inputs(client: Client):
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
|
||||
with client:
|
||||
with session:
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API
|
||||
session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert (
|
||||
@ -549,7 +549,7 @@ def test_send_mixed_inputs(client: Client):
|
||||
|
||||
|
||||
@pytest.mark.models("core")
|
||||
def test_send_btg_external_presigned(client: Client):
|
||||
def test_send_btg_external_presigned(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
|
||||
inp1 = messages.TxInputType(
|
||||
@ -577,14 +577,14 @@ def test_send_btg_external_presigned(client: Client):
|
||||
amount=1_252_382_934 + 58_456 - 1_000,
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_input(1),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(FAKE_TXHASH_6f0398),
|
||||
@ -603,7 +603,7 @@ def test_send_btg_external_presigned(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API
|
||||
session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert (
|
||||
|
@ -17,7 +17,7 @@
|
||||
import pytest
|
||||
|
||||
from trezorlib import btc, messages
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.tools import parse_path
|
||||
|
||||
from ...common import is_core
|
||||
@ -43,7 +43,7 @@ TXHASH_15575a = bytes.fromhex(
|
||||
pytestmark = [pytest.mark.altcoin, pytest.mark.models("t1b1", "t2t1")]
|
||||
|
||||
|
||||
def test_send_dash(client: Client):
|
||||
def test_send_dash(session: Session):
|
||||
inp1 = messages.TxInputType(
|
||||
address_n=parse_path("m/44h/5h/0h/0/0"),
|
||||
# dash:XdTw4G5AWW4cogGd7ayybyBNDbuB45UpgH
|
||||
@ -57,13 +57,13 @@ def test_send_dash(client: Client):
|
||||
amount=999_999_000,
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(inp1.prev_hash),
|
||||
@ -77,7 +77,9 @@ def test_send_dash(client: Client):
|
||||
request_finished(),
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(client, "Dash", [inp1], [out1], prev_txes=TX_API)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
session, "Dash", [inp1], [out1], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert (
|
||||
serialized_tx.hex()
|
||||
@ -85,7 +87,7 @@ def test_send_dash(client: Client):
|
||||
)
|
||||
|
||||
|
||||
def test_send_dash_dip2_input(client: Client):
|
||||
def test_send_dash_dip2_input(session: Session):
|
||||
inp1 = messages.TxInputType(
|
||||
address_n=parse_path("m/44h/5h/0h/0/0"),
|
||||
# dash:XdTw4G5AWW4cogGd7ayybyBNDbuB45UpgH
|
||||
@ -104,14 +106,14 @@ def test_send_dash_dip2_input(client: Client):
|
||||
amount=95_000_000,
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
request_output(1),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(inp1.prev_hash),
|
||||
@ -128,7 +130,7 @@ def test_send_dash_dip2_input(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client, "Dash", [inp1], [out1, out2], prev_txes=TX_API
|
||||
session, "Dash", [inp1], [out1, out2], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert (
|
||||
|
@ -17,7 +17,7 @@
|
||||
import pytest
|
||||
|
||||
from trezorlib import btc, messages
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.tools import parse_path
|
||||
|
||||
from ...common import is_core
|
||||
@ -57,7 +57,7 @@ pytestmark = [
|
||||
]
|
||||
|
||||
|
||||
def test_send_decred(client: Client):
|
||||
def test_send_decred(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
|
||||
inp1 = messages.TxInputType(
|
||||
@ -76,13 +76,13 @@ def test_send_decred(client: Client):
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.FeeOverThreshold),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
@ -95,7 +95,7 @@ def test_send_decred(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client, "Decred Testnet", [inp1], [out1], prev_txes=TX_API
|
||||
session, "Decred Testnet", [inp1], [out1], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert (
|
||||
@ -105,7 +105,7 @@ def test_send_decred(client: Client):
|
||||
|
||||
|
||||
@pytest.mark.models("core")
|
||||
def test_purchase_ticket_decred(client: Client):
|
||||
def test_purchase_ticket_decred(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
|
||||
inp1 = messages.TxInputType(
|
||||
@ -133,8 +133,8 @@ def test_purchase_ticket_decred(client: Client):
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_output(0),
|
||||
@ -153,7 +153,7 @@ def test_purchase_ticket_decred(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client,
|
||||
session,
|
||||
"Decred Testnet",
|
||||
[inp1],
|
||||
[out1, out2, out3],
|
||||
@ -168,7 +168,7 @@ def test_purchase_ticket_decred(client: Client):
|
||||
|
||||
|
||||
@pytest.mark.models("core")
|
||||
def test_spend_from_stake_generation_and_revocation_decred(client: Client):
|
||||
def test_spend_from_stake_generation_and_revocation_decred(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
|
||||
inp1 = messages.TxInputType(
|
||||
@ -197,14 +197,14 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client):
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_input(1),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(FAKE_TXHASH_8b6890),
|
||||
@ -223,7 +223,7 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client, "Decred Testnet", [inp1, inp2], [out1], prev_txes=TX_API
|
||||
session, "Decred Testnet", [inp1, inp2], [out1], prev_txes=TX_API
|
||||
)
|
||||
|
||||
assert (
|
||||
@ -232,7 +232,7 @@ def test_spend_from_stake_generation_and_revocation_decred(client: Client):
|
||||
)
|
||||
|
||||
|
||||
def test_send_decred_change(client: Client):
|
||||
def test_send_decred_change(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
|
||||
inp1 = messages.TxInputType(
|
||||
@ -278,15 +278,15 @@ def test_send_decred_change(client: Client):
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_input(1),
|
||||
request_input(2),
|
||||
request_output(0),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
request_output(1),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
@ -311,7 +311,7 @@ def test_send_decred_change(client: Client):
|
||||
]
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client,
|
||||
session,
|
||||
"Decred Testnet",
|
||||
[inp1, inp2, inp3],
|
||||
[out1, out2],
|
||||
@ -325,12 +325,12 @@ def test_send_decred_change(client: Client):
|
||||
|
||||
|
||||
@pytest.mark.multisig
|
||||
def test_decred_multisig_change(client: Client):
|
||||
def test_decred_multisig_change(session: Session):
|
||||
# NOTE: fake input tx used
|
||||
|
||||
paths = [parse_path(f"m/48h/1h/{index}'/0'") for index in range(3)]
|
||||
nodes = [
|
||||
btc.get_public_node(client, address_n, coin_name="Decred Testnet").node
|
||||
btc.get_public_node(session, address_n, coin_name="Decred Testnet").node
|
||||
for address_n in paths
|
||||
]
|
||||
|
||||
@ -384,15 +384,15 @@ def test_decred_multisig_change(client: Client):
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
request_input(0),
|
||||
request_input(1),
|
||||
request_output(0),
|
||||
request_output(1),
|
||||
messages.ButtonRequest(code=B.ConfirmOutput),
|
||||
(is_core(client), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
(is_core(session), messages.ButtonRequest(code=B.ConfirmOutput)),
|
||||
messages.ButtonRequest(code=B.SignTx),
|
||||
request_input(0),
|
||||
request_meta(FAKE_TXHASH_9ac7d2),
|
||||
@ -410,7 +410,7 @@ def test_decred_multisig_change(client: Client):
|
||||
]
|
||||
)
|
||||
signature, serialized_tx = btc.sign_tx(
|
||||
client,
|
||||
session,
|
||||
"Decred Testnet",
|
||||
[inp1, inp2],
|
||||
[out1, out2],
|
||||
|
@ -18,7 +18,7 @@ import pytest
|
||||
|
||||
from trezorlib import btc, messages, models
|
||||
from trezorlib.cli import btc as btc_cli
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.tools import H_
|
||||
|
||||
from ...input_flows import InputFlowShowXpubQRCode
|
||||
@ -165,14 +165,16 @@ def _address_n(purpose, coin, account, script_type):
|
||||
@pytest.mark.parametrize(
|
||||
"coin, account, purpose, script_type, descriptors", VECTORS_DESCRIPTORS
|
||||
)
|
||||
def test_descriptors(client: Client, coin, account, purpose, script_type, descriptors):
|
||||
with client:
|
||||
IF = InputFlowShowXpubQRCode(client)
|
||||
client.set_input_flow(IF.get())
|
||||
def test_descriptors(
|
||||
session: Session, coin, account, purpose, script_type, descriptors
|
||||
):
|
||||
with session:
|
||||
IF = InputFlowShowXpubQRCode(session.client)
|
||||
session.set_input_flow(IF.get())
|
||||
|
||||
address_n = _address_n(purpose, coin, account, script_type)
|
||||
res = btc.get_public_node(
|
||||
client,
|
||||
session,
|
||||
_address_n(purpose, coin, account, script_type),
|
||||
show_display=True,
|
||||
coin_name=coin,
|
||||
@ -187,13 +189,13 @@ def test_descriptors(client: Client, coin, account, purpose, script_type, descri
|
||||
"coin, account, purpose, script_type, descriptors", VECTORS_DESCRIPTORS
|
||||
)
|
||||
def test_descriptors_trezorlib(
|
||||
client: Client, coin, account, purpose, script_type, descriptors
|
||||
session: Session, coin, account, purpose, script_type, descriptors
|
||||
):
|
||||
with client:
|
||||
if client.model != models.T1B1:
|
||||
IF = InputFlowShowXpubQRCode(client)
|
||||
client.set_input_flow(IF.get())
|
||||
with session:
|
||||
if session.client.model != models.T1B1:
|
||||
IF = InputFlowShowXpubQRCode(session.client)
|
||||
session.set_input_flow(IF.get())
|
||||
res = btc_cli._get_descriptor(
|
||||
client, coin, account, purpose, script_type, show_display=True
|
||||
session, coin, account, purpose, script_type, show_display=True
|
||||
)
|
||||
assert res == descriptors
|
||||
|
@ -17,7 +17,7 @@
|
||||
import pytest
|
||||
|
||||
from trezorlib import btc, messages
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.tools import parse_path
|
||||
|
||||
from ...tx_cache import TxCache
|
||||
@ -30,7 +30,7 @@ TXHASH_8a34cc = bytes.fromhex(
|
||||
|
||||
|
||||
@pytest.mark.altcoin
|
||||
def test_spend_lelantus(client: Client):
|
||||
def test_spend_lelantus(session: Session):
|
||||
inp1 = messages.TxInputType(
|
||||
# THgGLVqfzJcaxRVPWE5fd8YJ1GpVePq2Uk
|
||||
address_n=parse_path("m/44h/1h/0h/0/4"),
|
||||
@ -45,7 +45,7 @@ def test_spend_lelantus(client: Client):
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(
|
||||
client, "Firo Testnet", [inp1], [out1], prev_txes=TX_API
|
||||
session, "Firo Testnet", [inp1], [out1], prev_txes=TX_API
|
||||
)
|
||||
assert_tx_matches(
|
||||
serialized_tx,
|
||||
|
@ -17,7 +17,7 @@
|
||||
import pytest
|
||||
|
||||
from trezorlib import btc, messages
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.tools import parse_path
|
||||
|
||||
TXHASH_33043a = bytes.fromhex(
|
||||
@ -27,7 +27,7 @@ TXHASH_33043a = bytes.fromhex(
|
||||
pytestmark = pytest.mark.altcoin
|
||||
|
||||
|
||||
def test_send_p2tr(client: Client):
|
||||
def test_send_p2tr(session: Session):
|
||||
inp1 = messages.TxInputType(
|
||||
# fc1prr07akly3xjtmggue0p04vghr8pdcgxrye2s00sahptwjeawxrkq2rxzr7
|
||||
address_n=parse_path("m/86h/75h/0h/0/1"),
|
||||
@ -42,7 +42,7 @@ def test_send_p2tr(client: Client):
|
||||
amount=99_996_670_000,
|
||||
script_type=messages.OutputScriptType.PAYTOADDRESS,
|
||||
)
|
||||
_, serialized_tx = btc.sign_tx(client, "Fujicoin", [inp1], [out1])
|
||||
_, serialized_tx = btc.sign_tx(session, "Fujicoin", [inp1], [out1])
|
||||
# Transaction hex changed with fix #2085, all other details are the same as this tx:
|
||||
# https://explorer.fujicoin.org/tx/a1c6a81f5e8023b17e6e3e51e2596d5b5e1d4914ea13c0c31cef90b3c3edee86
|
||||
assert (
|
||||
|
@ -17,7 +17,7 @@
|
||||
import pytest
|
||||
|
||||
from trezorlib import btc, device, messages
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.exceptions import TrezorFailure
|
||||
from trezorlib.messages import MultisigPubkeysOrder, SafetyCheckLevel
|
||||
from trezorlib.tools import parse_path
|
||||
@ -36,112 +36,112 @@ def getmultisig(chain, nr, xpubs):
|
||||
)
|
||||
|
||||
|
||||
def test_btc(client: Client):
|
||||
def test_btc(session: Session):
|
||||
assert (
|
||||
btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/0"))
|
||||
btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"))
|
||||
== "1JAd7XCBzGudGpJQSDSfpmJhiygtLQWaGL"
|
||||
)
|
||||
assert (
|
||||
btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/0/1"))
|
||||
btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/0/1"))
|
||||
== "1GWFxtwWmNVqotUPXLcKVL2mUKpshuJYo"
|
||||
)
|
||||
assert (
|
||||
btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/1/0"))
|
||||
btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/1/0"))
|
||||
== "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.altcoin
|
||||
def test_ltc(client: Client):
|
||||
def test_ltc(session: Session):
|
||||
assert (
|
||||
btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/0/0"))
|
||||
btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/0/0"))
|
||||
== "LcubERmHD31PWup1fbozpKuiqjHZ4anxcL"
|
||||
)
|
||||
assert (
|
||||
btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/0/1"))
|
||||
btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/0/1"))
|
||||
== "LVWBmHBkCGNjSPHucvL2PmnuRAJnucmRE6"
|
||||
)
|
||||
assert (
|
||||
btc.get_address(client, "Litecoin", parse_path("m/44h/2h/0h/1/0"))
|
||||
btc.get_address(session, "Litecoin", parse_path("m/44h/2h/0h/1/0"))
|
||||
== "LWj6ApswZxay4cJEJES2sGe7fLMLRvvv8h"
|
||||
)
|
||||
|
||||
|
||||
def test_tbtc(client: Client):
|
||||
def test_tbtc(session: Session):
|
||||
assert (
|
||||
btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/0/0"))
|
||||
btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/0/0"))
|
||||
== "mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q"
|
||||
)
|
||||
assert (
|
||||
btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/0/1"))
|
||||
btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/0/1"))
|
||||
== "mopZWqZZyQc3F2Sy33cvDtJchSAMsnLi7b"
|
||||
)
|
||||
assert (
|
||||
btc.get_address(client, "Testnet", parse_path("m/44h/1h/0h/1/0"))
|
||||
btc.get_address(session, "Testnet", parse_path("m/44h/1h/0h/1/0"))
|
||||
== "mm6kLYbGEL1tGe4ZA8xacfgRPdW1NLjCbZ"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.altcoin
|
||||
def test_bch(client: Client):
|
||||
def test_bch(session: Session):
|
||||
assert (
|
||||
btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/0/0"))
|
||||
btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/0/0"))
|
||||
== "bitcoincash:qr08q88p9etk89wgv05nwlrkm4l0urz4cyl36hh9sv"
|
||||
)
|
||||
assert (
|
||||
btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/0/1"))
|
||||
btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/0/1"))
|
||||
== "bitcoincash:qr23ajjfd9wd73l87j642puf8cad20lfmqdgwvpat4"
|
||||
)
|
||||
assert (
|
||||
btc.get_address(client, "Bcash", parse_path("m/44h/145h/0h/1/0"))
|
||||
btc.get_address(session, "Bcash", parse_path("m/44h/145h/0h/1/0"))
|
||||
== "bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.altcoin
|
||||
def test_grs(client: Client):
|
||||
def test_grs(session: Session):
|
||||
assert (
|
||||
btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/0/0"))
|
||||
btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/0/0"))
|
||||
== "Fj62rBJi8LvbmWu2jzkaUX1NFXLEqDLoZM"
|
||||
)
|
||||
assert (
|
||||
btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/1/0"))
|
||||
btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/1/0"))
|
||||
== "FmRaqvVBRrAp2Umfqx9V1ectZy8gw54QDN"
|
||||
)
|
||||
assert (
|
||||
btc.get_address(client, "Groestlcoin", parse_path("m/44h/17h/0h/1/1"))
|
||||
btc.get_address(session, "Groestlcoin", parse_path("m/44h/17h/0h/1/1"))
|
||||
== "Fmhtxeh7YdCBkyQF7AQG4QnY8y3rJg89di"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.altcoin
|
||||
def test_tgrs(client: Client):
|
||||
def test_tgrs(session: Session):
|
||||
assert (
|
||||
btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0"))
|
||||
btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/0/0"))
|
||||
== "mvbu1Gdy8SUjTenqerxUaZyYjmvedc787y"
|
||||
)
|
||||
assert (
|
||||
btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/0"))
|
||||
btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/0"))
|
||||
== "mm6kLYbGEL1tGe4ZA8xacfgRPdW1LMq8cN"
|
||||
)
|
||||
assert (
|
||||
btc.get_address(client, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/1"))
|
||||
btc.get_address(session, "Groestlcoin Testnet", parse_path("m/44h/1h/0h/1/1"))
|
||||
== "mjXZwmEi1z1MzveZrKUAo4DBgbdq6ZhGD6"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.altcoin
|
||||
def test_elements(client: Client):
|
||||
def test_elements(session: Session):
|
||||
assert (
|
||||
btc.get_address(client, "Elements", parse_path("m/44h/1h/0h/0/0"))
|
||||
btc.get_address(session, "Elements", parse_path("m/44h/1h/0h/0/0"))
|
||||
== "2dpWh6jbhAowNsQ5agtFzi7j6nKscj6UnEr"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.models("core")
|
||||
def test_address_mac(client: Client):
|
||||
def test_address_mac(session: Session):
|
||||
resp = btc.get_authenticated_address(
|
||||
client, "Bitcoin", parse_path("m/44h/0h/0h/1/0")
|
||||
session, "Bitcoin", parse_path("m/44h/0h/0h/1/0")
|
||||
)
|
||||
assert resp.address == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE"
|
||||
assert (
|
||||
@ -150,7 +150,7 @@ def test_address_mac(client: Client):
|
||||
)
|
||||
|
||||
resp = btc.get_authenticated_address(
|
||||
client, "Testnet", parse_path("m/44h/1h/0h/1/0")
|
||||
session, "Testnet", parse_path("m/44h/1h/0h/1/0")
|
||||
)
|
||||
assert resp.address == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1NLjCbZ"
|
||||
assert (
|
||||
@ -160,16 +160,16 @@ def test_address_mac(client: Client):
|
||||
|
||||
# Script type mismatch.
|
||||
resp = btc.get_authenticated_address(
|
||||
client, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=False
|
||||
session, "Bitcoin", parse_path("m/84h/0h/0h/0/0"), show_display=False
|
||||
)
|
||||
assert resp.mac is None
|
||||
|
||||
|
||||
@pytest.mark.models("core")
|
||||
@pytest.mark.altcoin
|
||||
def test_altcoin_address_mac(client: Client):
|
||||
def test_altcoin_address_mac(session: Session):
|
||||
resp = btc.get_authenticated_address(
|
||||
client, "Litecoin", parse_path("m/44h/2h/0h/1/0")
|
||||
session, "Litecoin", parse_path("m/44h/2h/0h/1/0")
|
||||
)
|
||||
assert resp.address == "LWj6ApswZxay4cJEJES2sGe7fLMLRvvv8h"
|
||||
assert (
|
||||
@ -178,7 +178,7 @@ def test_altcoin_address_mac(client: Client):
|
||||
)
|
||||
|
||||
resp = btc.get_authenticated_address(
|
||||
client, "Bcash", parse_path("m/44h/145h/0h/1/0")
|
||||
session, "Bcash", parse_path("m/44h/145h/0h/1/0")
|
||||
)
|
||||
assert resp.address == "bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw"
|
||||
assert (
|
||||
@ -187,7 +187,7 @@ def test_altcoin_address_mac(client: Client):
|
||||
)
|
||||
|
||||
resp = btc.get_authenticated_address(
|
||||
client, "Groestlcoin", parse_path("m/44h/17h/0h/1/1")
|
||||
session, "Groestlcoin", parse_path("m/44h/17h/0h/1/1")
|
||||
)
|
||||
assert resp.address == "Fmhtxeh7YdCBkyQF7AQG4QnY8y3rJg89di"
|
||||
assert (
|
||||
@ -197,9 +197,9 @@ def test_altcoin_address_mac(client: Client):
|
||||
|
||||
|
||||
@pytest.mark.multisig
|
||||
def test_multisig_pubkeys_order(client: Client):
|
||||
xpub_internal = btc.get_public_node(client, parse_path("m/45h/0")).xpub
|
||||
xpub_external = btc.get_public_node(client, parse_path("m/45h/1")).xpub
|
||||
def test_multisig_pubkeys_order(session: Session):
|
||||
xpub_internal = btc.get_public_node(session, parse_path("m/45h/0")).xpub
|
||||
xpub_external = btc.get_public_node(session, parse_path("m/45h/1")).xpub
|
||||
|
||||
multisig_unsorted_1 = messages.MultisigRedeemScriptType(
|
||||
nodes=[bip32.deserialize(xpub) for xpub in [xpub_external, xpub_internal]],
|
||||
@ -238,45 +238,45 @@ def test_multisig_pubkeys_order(client: Client):
|
||||
|
||||
assert (
|
||||
btc.get_address(
|
||||
client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1
|
||||
session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_1
|
||||
)
|
||||
== address_unsorted_1
|
||||
)
|
||||
assert (
|
||||
btc.get_address(
|
||||
client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2
|
||||
session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_unsorted_2
|
||||
)
|
||||
== address_unsorted_2
|
||||
)
|
||||
assert (
|
||||
btc.get_address(
|
||||
client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_1
|
||||
session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_1
|
||||
)
|
||||
== address_unsorted_2
|
||||
)
|
||||
assert (
|
||||
btc.get_address(
|
||||
client, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_2
|
||||
session, "Bitcoin", parse_path("m/45h/0/0/0"), multisig=multisig_sorted_2
|
||||
)
|
||||
== address_unsorted_2
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.multisig
|
||||
def test_multisig(client: Client):
|
||||
def test_multisig(session: Session):
|
||||
xpubs = []
|
||||
for n in range(1, 4):
|
||||
node = btc.get_public_node(client, parse_path(f"m/44h/0h/{n}h"))
|
||||
node = btc.get_public_node(session, parse_path(f"m/44h/0h/{n}h"))
|
||||
xpubs.append(node.xpub)
|
||||
|
||||
for nr in range(1, 4):
|
||||
with client:
|
||||
if is_core(client):
|
||||
IF = InputFlowConfirmAllWarnings(client)
|
||||
client.set_input_flow(IF.get())
|
||||
with session:
|
||||
if is_core(session):
|
||||
IF = InputFlowConfirmAllWarnings(session.client)
|
||||
session.set_input_flow(IF.get())
|
||||
assert (
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Bitcoin",
|
||||
parse_path(f"m/44h/0h/{nr}h/0/0"),
|
||||
show_display=(nr == 1),
|
||||
@ -286,7 +286,7 @@ def test_multisig(client: Client):
|
||||
)
|
||||
assert (
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Bitcoin",
|
||||
parse_path(f"m/44h/0h/{nr}h/1/0"),
|
||||
show_display=(nr == 1),
|
||||
@ -298,11 +298,11 @@ def test_multisig(client: Client):
|
||||
|
||||
@pytest.mark.multisig
|
||||
@pytest.mark.parametrize("show_display", (True, False))
|
||||
def test_multisig_missing(client: Client, show_display):
|
||||
def test_multisig_missing(session: Session, show_display):
|
||||
# Use account numbers 1, 2 and 3 to create a valid multisig,
|
||||
# but not containing the keys from account 0 used below.
|
||||
nodes = [
|
||||
btc.get_public_node(client, parse_path(f"m/44h/0h/{i}h")).node
|
||||
btc.get_public_node(session, parse_path(f"m/44h/0h/{i}h")).node
|
||||
for i in range(1, 4)
|
||||
]
|
||||
|
||||
@ -321,12 +321,12 @@ def test_multisig_missing(client: Client, show_display):
|
||||
)
|
||||
|
||||
for multisig in (multisig1, multisig2):
|
||||
with client, pytest.raises(TrezorFailure):
|
||||
if is_core(client):
|
||||
IF = InputFlowConfirmAllWarnings(client)
|
||||
client.set_input_flow(IF.get())
|
||||
with pytest.raises(TrezorFailure), session:
|
||||
if is_core(session):
|
||||
IF = InputFlowConfirmAllWarnings(session.client)
|
||||
session.set_input_flow(IF.get())
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Bitcoin",
|
||||
parse_path("m/44h/0h/0h/0/0"),
|
||||
show_display=show_display,
|
||||
@ -336,22 +336,22 @@ def test_multisig_missing(client: Client, show_display):
|
||||
|
||||
@pytest.mark.altcoin
|
||||
@pytest.mark.multisig
|
||||
def test_bch_multisig(client: Client):
|
||||
def test_bch_multisig(session: Session):
|
||||
xpubs = []
|
||||
for n in range(1, 4):
|
||||
node = btc.get_public_node(
|
||||
client, parse_path(f"m/44h/145h/{n}h"), coin_name="Bcash"
|
||||
session, parse_path(f"m/44h/145h/{n}h"), coin_name="Bcash"
|
||||
)
|
||||
xpubs.append(node.xpub)
|
||||
|
||||
for nr in range(1, 4):
|
||||
with client:
|
||||
if is_core(client):
|
||||
IF = InputFlowConfirmAllWarnings(client)
|
||||
client.set_input_flow(IF.get())
|
||||
with session:
|
||||
if is_core(session):
|
||||
IF = InputFlowConfirmAllWarnings(session.client)
|
||||
session.set_input_flow(IF.get())
|
||||
assert (
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Bcash",
|
||||
parse_path(f"m/44h/145h/{nr}h/0/0"),
|
||||
show_display=(nr == 1),
|
||||
@ -361,7 +361,7 @@ def test_bch_multisig(client: Client):
|
||||
)
|
||||
assert (
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Bcash",
|
||||
parse_path(f"m/44h/145h/{nr}h/1/0"),
|
||||
show_display=(nr == 1),
|
||||
@ -371,43 +371,43 @@ def test_bch_multisig(client: Client):
|
||||
)
|
||||
|
||||
|
||||
def test_public_ckd(client: Client):
|
||||
node = btc.get_public_node(client, parse_path("m/44h/0h/0h")).node
|
||||
node_sub1 = btc.get_public_node(client, parse_path("m/44h/0h/0h/1/0")).node
|
||||
def test_public_ckd(session: Session):
|
||||
node = btc.get_public_node(session, parse_path("m/44h/0h/0h")).node
|
||||
node_sub1 = btc.get_public_node(session, parse_path("m/44h/0h/0h/1/0")).node
|
||||
node_sub2 = bip32.public_ckd(node, [1, 0])
|
||||
|
||||
assert node_sub1.chain_code == node_sub2.chain_code
|
||||
assert node_sub1.public_key == node_sub2.public_key
|
||||
|
||||
address1 = btc.get_address(client, "Bitcoin", parse_path("m/44h/0h/0h/1/0"))
|
||||
address1 = btc.get_address(session, "Bitcoin", parse_path("m/44h/0h/0h/1/0"))
|
||||
address2 = bip32.get_address(node_sub2, 0)
|
||||
|
||||
assert address2 == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE"
|
||||
assert address1 == address2
|
||||
|
||||
|
||||
def test_invalid_path(client: Client):
|
||||
def test_invalid_path(session: Session):
|
||||
with pytest.raises(TrezorFailure, match="Forbidden key path"):
|
||||
# slip44 id mismatch
|
||||
btc.get_address(
|
||||
client, "Bitcoin", parse_path("m/44h/111h/0h/0/0"), show_display=True
|
||||
session, "Bitcoin", parse_path("m/44h/111h/0h/0/0"), show_display=True
|
||||
)
|
||||
|
||||
|
||||
def test_unknown_path(client: Client):
|
||||
def test_unknown_path(session: Session):
|
||||
UNKNOWN_PATH = parse_path("m/44h/9h/0h/0/0")
|
||||
with client:
|
||||
client.set_expected_responses([messages.Failure])
|
||||
with session:
|
||||
session.set_expected_responses([messages.Failure])
|
||||
|
||||
with pytest.raises(TrezorFailure, match="Forbidden key path"):
|
||||
# account number is too high
|
||||
btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True)
|
||||
btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True)
|
||||
|
||||
# disable safety checks
|
||||
device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily)
|
||||
device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily)
|
||||
|
||||
with client:
|
||||
client.set_expected_responses(
|
||||
with session:
|
||||
session.set_expected_responses(
|
||||
[
|
||||
messages.ButtonRequest(
|
||||
code=messages.ButtonRequestType.UnknownDerivationPath
|
||||
@ -416,30 +416,30 @@ def test_unknown_path(client: Client):
|
||||
messages.Address,
|
||||
]
|
||||
)
|
||||
if is_core(client):
|
||||
IF = InputFlowConfirmAllWarnings(client)
|
||||
client.set_input_flow(IF.get())
|
||||
if is_core(session):
|
||||
IF = InputFlowConfirmAllWarnings(session.client)
|
||||
session.set_input_flow(IF.get())
|
||||
# try again with a warning
|
||||
btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True)
|
||||
btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True)
|
||||
|
||||
with client:
|
||||
with session:
|
||||
# no warning is displayed when the call is silent
|
||||
client.set_expected_responses([messages.Address])
|
||||
btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=False)
|
||||
session.set_expected_responses([messages.Address])
|
||||
btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=False)
|
||||
|
||||
|
||||
@pytest.mark.altcoin
|
||||
def test_crw(client: Client):
|
||||
def test_crw(session: Session):
|
||||
assert (
|
||||
btc.get_address(client, "Crown", parse_path("m/44h/72h/0h/0/0"))
|
||||
btc.get_address(session, "Crown", parse_path("m/44h/72h/0h/0/0"))
|
||||
== "CRWYdvZM1yXMKQxeN3hRsAbwa7drfvTwys48"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.multisig
|
||||
def test_multisig_different_paths(client: Client):
|
||||
def test_multisig_different_paths(session: Session):
|
||||
nodes = [
|
||||
btc.get_public_node(client, parse_path(f"m/45h/{i}"), coin_name="Bitcoin").node
|
||||
btc.get_public_node(session, parse_path(f"m/45h/{i}"), coin_name="Bitcoin").node
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
@ -455,12 +455,12 @@ def test_multisig_different_paths(client: Client):
|
||||
with pytest.raises(
|
||||
Exception, match="Using different paths for different xpubs is not allowed"
|
||||
):
|
||||
with client:
|
||||
if is_core(client):
|
||||
IF = InputFlowConfirmAllWarnings(client)
|
||||
client.set_input_flow(IF.get())
|
||||
with session:
|
||||
if is_core(session):
|
||||
IF = InputFlowConfirmAllWarnings(session.client)
|
||||
session.set_input_flow(IF.get())
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Bitcoin",
|
||||
parse_path("m/45h/0/0/0"),
|
||||
show_display=True,
|
||||
@ -468,13 +468,13 @@ def test_multisig_different_paths(client: Client):
|
||||
script_type=messages.InputScriptType.SPENDMULTISIG,
|
||||
)
|
||||
|
||||
device.apply_settings(client, safety_checks=SafetyCheckLevel.PromptTemporarily)
|
||||
with client:
|
||||
if is_core(client):
|
||||
IF = InputFlowConfirmAllWarnings(client)
|
||||
client.set_input_flow(IF.get())
|
||||
device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily)
|
||||
with session:
|
||||
if is_core(session):
|
||||
IF = InputFlowConfirmAllWarnings(session.client)
|
||||
session.set_input_flow(IF.get())
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Bitcoin",
|
||||
parse_path("m/45h/0/0/0"),
|
||||
show_display=True,
|
||||
|
@ -17,7 +17,7 @@
|
||||
import pytest
|
||||
|
||||
from trezorlib import btc, messages
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.exceptions import TrezorFailure
|
||||
from trezorlib.tools import parse_path
|
||||
|
||||
@ -25,10 +25,10 @@ from ...common import is_core
|
||||
from ...input_flows import InputFlowConfirmAllWarnings
|
||||
|
||||
|
||||
def test_show_segwit(client: Client):
|
||||
def test_show_segwit(session: Session):
|
||||
assert (
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
parse_path("m/49h/1h/0h/1/0"),
|
||||
True,
|
||||
@ -39,7 +39,7 @@ def test_show_segwit(client: Client):
|
||||
)
|
||||
assert (
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
parse_path("m/49h/1h/0h/0/0"),
|
||||
False,
|
||||
@ -50,7 +50,7 @@ def test_show_segwit(client: Client):
|
||||
)
|
||||
assert (
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
parse_path("m/44h/1h/0h/0/0"),
|
||||
False,
|
||||
@ -61,7 +61,7 @@ def test_show_segwit(client: Client):
|
||||
)
|
||||
assert (
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
parse_path("m/44h/1h/0h/0/0"),
|
||||
False,
|
||||
@ -73,14 +73,14 @@ def test_show_segwit(client: Client):
|
||||
|
||||
|
||||
@pytest.mark.altcoin
|
||||
def test_show_segwit_altcoin(client: Client):
|
||||
with client:
|
||||
if is_core(client):
|
||||
IF = InputFlowConfirmAllWarnings(client)
|
||||
client.set_input_flow(IF.get())
|
||||
def test_show_segwit_altcoin(session: Session):
|
||||
with session:
|
||||
if is_core(session):
|
||||
IF = InputFlowConfirmAllWarnings(session.client)
|
||||
session.set_input_flow(IF.get())
|
||||
assert (
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Groestlcoin Testnet",
|
||||
parse_path("m/49h/1h/0h/1/0"),
|
||||
True,
|
||||
@ -91,7 +91,7 @@ def test_show_segwit_altcoin(client: Client):
|
||||
)
|
||||
assert (
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Groestlcoin Testnet",
|
||||
parse_path("m/49h/1h/0h/0/0"),
|
||||
True,
|
||||
@ -102,7 +102,7 @@ def test_show_segwit_altcoin(client: Client):
|
||||
)
|
||||
assert (
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Groestlcoin Testnet",
|
||||
parse_path("m/44h/1h/0h/0/0"),
|
||||
True,
|
||||
@ -113,7 +113,7 @@ def test_show_segwit_altcoin(client: Client):
|
||||
)
|
||||
assert (
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Groestlcoin Testnet",
|
||||
parse_path("m/44h/1h/0h/0/0"),
|
||||
True,
|
||||
@ -124,7 +124,7 @@ def test_show_segwit_altcoin(client: Client):
|
||||
)
|
||||
assert (
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Elements",
|
||||
parse_path("m/49h/1h/0h/0/0"),
|
||||
True,
|
||||
@ -136,10 +136,10 @@ def test_show_segwit_altcoin(client: Client):
|
||||
|
||||
|
||||
@pytest.mark.multisig
|
||||
def test_show_multisig_3(client: Client):
|
||||
def test_show_multisig_3(session: Session):
|
||||
nodes = [
|
||||
btc.get_public_node(
|
||||
client, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet"
|
||||
session, parse_path(f"m/49h/1h/{i}h"), coin_name="Testnet"
|
||||
).node
|
||||
for i in range(1, 4)
|
||||
]
|
||||
@ -155,7 +155,7 @@ def test_show_multisig_3(client: Client):
|
||||
for i in [1, 2, 3]:
|
||||
assert (
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Testnet",
|
||||
parse_path(f"m/49h/1h/{i}h/0/7"),
|
||||
False,
|
||||
@ -168,11 +168,11 @@ def test_show_multisig_3(client: Client):
|
||||
|
||||
@pytest.mark.multisig
|
||||
@pytest.mark.parametrize("show_display", (True, False))
|
||||
def test_multisig_missing(client: Client, show_display):
|
||||
def test_multisig_missing(session: Session, show_display):
|
||||
# Use account numbers 1, 2 and 3 to create a valid multisig,
|
||||
# but not containing the keys from account 0 used below.
|
||||
nodes = [
|
||||
btc.get_public_node(client, parse_path(f"m/49h/0h/{i}h")).node
|
||||
btc.get_public_node(session, parse_path(f"m/49h/0h/{i}h")).node
|
||||
for i in range(1, 4)
|
||||
]
|
||||
|
||||
@ -193,7 +193,7 @@ def test_multisig_missing(client: Client, show_display):
|
||||
for multisig in (multisig1, multisig2):
|
||||
with pytest.raises(TrezorFailure):
|
||||
btc.get_address(
|
||||
client,
|
||||
session,
|
||||
"Bitcoin",
|
||||
parse_path("m/49h/0h/0h/0/0"),
|
||||
show_display=show_display,
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user