From 34d3dbdeb13307f2d1d3dd1f0241dfd587a2393c Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Mon, 16 Sep 2024 14:36:40 +0200 Subject: [PATCH] wip trezorlib add btc commands --- python/src/trezorlib/btc.py | 39 ++++++++------- python/src/trezorlib/cli/btc.py | 49 ++++++++++--------- .../src/trezorlib/transport/new/transport.py | 18 +++++-- 3 files changed, 59 insertions(+), 47 deletions(-) diff --git a/python/src/trezorlib/btc.py b/python/src/trezorlib/btc.py index a71ead2ad..ae8e88904 100644 --- a/python/src/trezorlib/btc.py +++ b/python/src/trezorlib/btc.py @@ -23,7 +23,8 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple from typing_extensions import Protocol, TypedDict from . import exceptions, messages -from .tools import expect, prepare_message_bytes, session +from .tools import expect, prepare_message_bytes +from .transport.new.session import Session if TYPE_CHECKING: from .client import TrezorClient @@ -105,7 +106,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType: @expect(messages.PublicKey) def get_public_node( - client: "TrezorClient", + session: "Session", n: "Address", ecdsa_curve_name: Optional[str] = None, show_display: bool = False, @@ -116,13 +117,13 @@ def get_public_node( unlock_path_mac: Optional[bytes] = None, ) -> "MessageType": if unlock_path: - res = client.call( + res = session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) ) if not isinstance(res, messages.UnlockedPathRequest): raise exceptions.TrezorException("Unexpected message") - return client.call( + return session.call( messages.GetPublicKey( address_n=n, ecdsa_curve_name=ecdsa_curve_name, @@ -141,7 +142,7 @@ def get_address(*args: Any, **kwargs: Any): @expect(messages.Address) def get_authenticated_address( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", show_display: bool = False, @@ -153,13 +154,13 @@ def get_authenticated_address( chunkify: bool = False, ) -> "MessageType": if unlock_path: - res = client.call( + res = session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) ) if not isinstance(res, messages.UnlockedPathRequest): raise exceptions.TrezorException("Unexpected message") - return client.call( + return session.call( messages.GetAddress( address_n=n, coin_name=coin_name, @@ -172,6 +173,7 @@ def get_authenticated_address( ) +# TODO this is used by tests only @expect(messages.OwnershipId, field="ownership_id", ret_type=bytes) def get_ownership_id( client: "TrezorClient", @@ -190,6 +192,7 @@ def get_ownership_id( ) +# TODO this is used by tests only def get_ownership_proof( client: "TrezorClient", coin_name: str, @@ -226,7 +229,7 @@ def get_ownership_proof( @expect(messages.MessageSignature) def sign_message( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", message: AnyStr, @@ -234,7 +237,7 @@ def sign_message( no_script_type: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.SignMessage( coin_name=coin_name, address_n=n, @@ -247,7 +250,7 @@ def sign_message( def verify_message( - client: "TrezorClient", + session: "Session", coin_name: str, address: str, signature: bytes, @@ -255,7 +258,7 @@ def verify_message( chunkify: bool = False, ) -> bool: try: - resp = client.call( + resp = session.call( messages.VerifyMessage( address=address, signature=signature, @@ -269,9 +272,9 @@ def verify_message( return isinstance(resp, messages.Success) -@session +# @session def sign_tx( - client: "TrezorClient", + session: "Session", coin_name: str, inputs: Sequence[messages.TxInputType], outputs: Sequence[messages.TxOutputType], @@ -319,17 +322,17 @@ def sign_tx( setattr(signtx, name, value) if unlock_path: - res = client.call( + res = session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) ) if not isinstance(res, messages.UnlockedPathRequest): raise exceptions.TrezorException("Unexpected message") elif preauthorized: - res = client.call(messages.DoPreauthorized()) + res = session.call(messages.DoPreauthorized()) if not isinstance(res, messages.PreauthorizedRequest): raise exceptions.TrezorException("Unexpected message") - res = client.call(signtx) + res = session.call(signtx) # Prepare structure for signatures signatures: List[Optional[bytes]] = [None] * len(inputs) @@ -388,7 +391,7 @@ def sign_tx( if res.request_type == R.TXPAYMENTREQ: assert res.details.request_index is not None msg = payment_reqs[res.details.request_index] - res = client.call(msg) + res = session.call(msg) else: msg = messages.TransactionType() if res.request_type == R.TXMETA: @@ -418,7 +421,7 @@ def sign_tx( f"Unknown request type - {res.request_type}." ) - res = client.call(messages.TxAck(tx=msg)) + res = session.call(messages.TxAck(tx=msg)) if not isinstance(res, messages.TxRequest): raise exceptions.TrezorException("Unexpected message") diff --git a/python/src/trezorlib/cli/btc.py b/python/src/trezorlib/cli/btc.py index dde59a6bc..26a4a7077 100644 --- a/python/src/trezorlib/cli/btc.py +++ b/python/src/trezorlib/cli/btc.py @@ -13,6 +13,7 @@ # # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations import base64 import json @@ -22,10 +23,10 @@ import click import construct as c from .. import btc, messages, protobuf, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.new.session import Session PURPOSE_BIP44 = 44 PURPOSE_BIP48 = 48 @@ -168,15 +169,15 @@ def cli() -> None: default=2, ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", coin: str, address: str, - script_type: Optional[messages.InputScriptType], + script_type: messages.InputScriptType | None, show_display: bool, multisig_xpub: List[str], - multisig_threshold: Optional[int], + multisig_threshold: int | None, multisig_suffix_length: int, chunkify: bool, ) -> str: @@ -220,7 +221,7 @@ def get_address( multisig = None return btc.get_address( - client, + session, coin, address_n, show_display, @@ -237,9 +238,9 @@ def get_address( @click.option("-e", "--curve") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_public_node( - client: "TrezorClient", + session: "Session", coin: str, address: str, curve: Optional[str], @@ -251,7 +252,7 @@ def get_public_node( if script_type is None: script_type = guess_script_type_from_path(address_n) result = btc.get_public_node( - client, + session, address_n, ecdsa_curve_name=curve, show_display=show_display, @@ -277,7 +278,7 @@ def _append_descriptor_checksum(desc: str) -> str: def _get_descriptor( - client: "TrezorClient", + session: "Session", coin: Optional[str], account: int, purpose: Optional[int], @@ -311,7 +312,7 @@ def _get_descriptor( n = tools.parse_path(path) pub = btc.get_public_node( - client, + session, n, show_display=show_display, coin_name=coin, @@ -348,9 +349,9 @@ def _get_descriptor( @click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE)) @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_descriptor( - client: "TrezorClient", + session: "Session", coin: Optional[str], account: int, account_type: Optional[int], @@ -360,7 +361,7 @@ def get_descriptor( """Get descriptor of given account.""" try: return _get_descriptor( - client, coin, account, account_type, script_type, show_display + session, coin, account, account_type, script_type, show_display ) except ValueError as e: raise click.ClickException(str(e)) @@ -375,8 +376,8 @@ def get_descriptor( @click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) @click.argument("json_file", type=click.File()) -@with_client -def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: +@with_session +def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None: """Sign transaction. Transaction data must be provided in a JSON file. See `transaction-format.md` for @@ -401,7 +402,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: } _, serialized_tx = btc.sign_tx( - client, + session, coin, inputs, outputs, @@ -432,9 +433,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: ) @click.option("-C", "--chunkify", is_flag=True) @click.argument("message") -@with_client +@with_session def sign_message( - client: "TrezorClient", + session: "Session", coin: str, address: str, message: str, @@ -447,7 +448,7 @@ def sign_message( if script_type is None: script_type = guess_script_type_from_path(address_n) res = btc.sign_message( - client, + session, coin, address_n, message, @@ -468,9 +469,9 @@ def sign_message( @click.argument("address") @click.argument("signature") @click.argument("message") -@with_client +@with_session def verify_message( - client: "TrezorClient", + session: "Session", coin: str, address: str, signature: str, @@ -480,7 +481,7 @@ def verify_message( """Verify message.""" signature_bytes = base64.b64decode(signature) return btc.verify_message( - client, coin, address, signature_bytes, message, chunkify=chunkify + session, coin, address, signature_bytes, message, chunkify=chunkify ) diff --git a/python/src/trezorlib/transport/new/transport.py b/python/src/trezorlib/transport/new/transport.py index 87c550227..6401a98db 100644 --- a/python/src/trezorlib/transport/new/transport.py +++ b/python/src/trezorlib/transport/new/transport.py @@ -36,14 +36,22 @@ class NewTransport: raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") - def get_path(self) -> str: ... + def get_path(self) -> str: + raise NotImplementedError - def open(self) -> None: ... + def open(self) -> None: + raise NotImplementedError - def close(self) -> None: ... + def close(self) -> None: + raise NotImplementedError - def write_chunk(self, chunk: bytes) -> None: ... + def write_chunk(self, chunk: bytes) -> None: + raise NotImplementedError - def read_chunk(self) -> bytes: ... + def read_chunk(self) -> bytes: + raise NotImplementedError + + def find_debug(self: "T") -> "T": + raise NotImplementedError CHUNK_SIZE: t.ClassVar[int]