wip trezorlib add btc commands

M1nd3r/thp-improved
M1nd3r 4 days ago
parent 6a65d62353
commit 34d3dbdeb1

@ -23,7 +23,8 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple
from typing_extensions import Protocol, TypedDict from typing_extensions import Protocol, TypedDict
from . import exceptions, messages from . import exceptions, messages
from .tools import expect, prepare_message_bytes, session from .tools import expect, prepare_message_bytes
from .transport.new.session import Session
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
@ -105,7 +106,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType:
@expect(messages.PublicKey) @expect(messages.PublicKey)
def get_public_node( def get_public_node(
client: "TrezorClient", session: "Session",
n: "Address", n: "Address",
ecdsa_curve_name: Optional[str] = None, ecdsa_curve_name: Optional[str] = None,
show_display: bool = False, show_display: bool = False,
@ -116,13 +117,13 @@ def get_public_node(
unlock_path_mac: Optional[bytes] = None, unlock_path_mac: Optional[bytes] = None,
) -> "MessageType": ) -> "MessageType":
if unlock_path: if unlock_path:
res = client.call( res = session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
) )
if not isinstance(res, messages.UnlockedPathRequest): if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message") raise exceptions.TrezorException("Unexpected message")
return client.call( return session.call(
messages.GetPublicKey( messages.GetPublicKey(
address_n=n, address_n=n,
ecdsa_curve_name=ecdsa_curve_name, ecdsa_curve_name=ecdsa_curve_name,
@ -141,7 +142,7 @@ def get_address(*args: Any, **kwargs: Any):
@expect(messages.Address) @expect(messages.Address)
def get_authenticated_address( def get_authenticated_address(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
n: "Address", n: "Address",
show_display: bool = False, show_display: bool = False,
@ -153,13 +154,13 @@ def get_authenticated_address(
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
if unlock_path: if unlock_path:
res = client.call( res = session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
) )
if not isinstance(res, messages.UnlockedPathRequest): if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message") raise exceptions.TrezorException("Unexpected message")
return client.call( return session.call(
messages.GetAddress( messages.GetAddress(
address_n=n, address_n=n,
coin_name=coin_name, coin_name=coin_name,
@ -172,6 +173,7 @@ def get_authenticated_address(
) )
# TODO this is used by tests only
@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes) @expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
def get_ownership_id( def get_ownership_id(
client: "TrezorClient", client: "TrezorClient",
@ -190,6 +192,7 @@ def get_ownership_id(
) )
# TODO this is used by tests only
def get_ownership_proof( def get_ownership_proof(
client: "TrezorClient", client: "TrezorClient",
coin_name: str, coin_name: str,
@ -226,7 +229,7 @@ def get_ownership_proof(
@expect(messages.MessageSignature) @expect(messages.MessageSignature)
def sign_message( def sign_message(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
n: "Address", n: "Address",
message: AnyStr, message: AnyStr,
@ -234,7 +237,7 @@ def sign_message(
no_script_type: bool = False, no_script_type: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> "MessageType":
return client.call( return session.call(
messages.SignMessage( messages.SignMessage(
coin_name=coin_name, coin_name=coin_name,
address_n=n, address_n=n,
@ -247,7 +250,7 @@ def sign_message(
def verify_message( def verify_message(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
address: str, address: str,
signature: bytes, signature: bytes,
@ -255,7 +258,7 @@ def verify_message(
chunkify: bool = False, chunkify: bool = False,
) -> bool: ) -> bool:
try: try:
resp = client.call( resp = session.call(
messages.VerifyMessage( messages.VerifyMessage(
address=address, address=address,
signature=signature, signature=signature,
@ -269,9 +272,9 @@ def verify_message(
return isinstance(resp, messages.Success) return isinstance(resp, messages.Success)
@session # @session
def sign_tx( def sign_tx(
client: "TrezorClient", session: "Session",
coin_name: str, coin_name: str,
inputs: Sequence[messages.TxInputType], inputs: Sequence[messages.TxInputType],
outputs: Sequence[messages.TxOutputType], outputs: Sequence[messages.TxOutputType],
@ -319,17 +322,17 @@ def sign_tx(
setattr(signtx, name, value) setattr(signtx, name, value)
if unlock_path: if unlock_path:
res = client.call( res = session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
) )
if not isinstance(res, messages.UnlockedPathRequest): if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message") raise exceptions.TrezorException("Unexpected message")
elif preauthorized: elif preauthorized:
res = client.call(messages.DoPreauthorized()) res = session.call(messages.DoPreauthorized())
if not isinstance(res, messages.PreauthorizedRequest): if not isinstance(res, messages.PreauthorizedRequest):
raise exceptions.TrezorException("Unexpected message") raise exceptions.TrezorException("Unexpected message")
res = client.call(signtx) res = session.call(signtx)
# Prepare structure for signatures # Prepare structure for signatures
signatures: List[Optional[bytes]] = [None] * len(inputs) signatures: List[Optional[bytes]] = [None] * len(inputs)
@ -388,7 +391,7 @@ def sign_tx(
if res.request_type == R.TXPAYMENTREQ: if res.request_type == R.TXPAYMENTREQ:
assert res.details.request_index is not None assert res.details.request_index is not None
msg = payment_reqs[res.details.request_index] msg = payment_reqs[res.details.request_index]
res = client.call(msg) res = session.call(msg)
else: else:
msg = messages.TransactionType() msg = messages.TransactionType()
if res.request_type == R.TXMETA: if res.request_type == R.TXMETA:
@ -418,7 +421,7 @@ def sign_tx(
f"Unknown request type - {res.request_type}." f"Unknown request type - {res.request_type}."
) )
res = client.call(messages.TxAck(tx=msg)) res = session.call(messages.TxAck(tx=msg))
if not isinstance(res, messages.TxRequest): if not isinstance(res, messages.TxRequest):
raise exceptions.TrezorException("Unexpected message") raise exceptions.TrezorException("Unexpected message")

@ -13,6 +13,7 @@
# #
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import base64 import base64
import json import json
@ -22,10 +23,10 @@ import click
import construct as c import construct as c
from .. import btc, messages, protobuf, tools from .. import btc, messages, protobuf, tools
from . import ChoiceType, with_client from . import ChoiceType, with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.new.session import Session
PURPOSE_BIP44 = 44 PURPOSE_BIP44 = 44
PURPOSE_BIP48 = 48 PURPOSE_BIP48 = 48
@ -168,15 +169,15 @@ def cli() -> None:
default=2, default=2,
) )
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@with_client @with_session
def get_address( def get_address(
client: "TrezorClient", session: "Session",
coin: str, coin: str,
address: str, address: str,
script_type: Optional[messages.InputScriptType], script_type: messages.InputScriptType | None,
show_display: bool, show_display: bool,
multisig_xpub: List[str], multisig_xpub: List[str],
multisig_threshold: Optional[int], multisig_threshold: int | None,
multisig_suffix_length: int, multisig_suffix_length: int,
chunkify: bool, chunkify: bool,
) -> str: ) -> str:
@ -220,7 +221,7 @@ def get_address(
multisig = None multisig = None
return btc.get_address( return btc.get_address(
client, session,
coin, coin,
address_n, address_n,
show_display, show_display,
@ -237,9 +238,9 @@ def get_address(
@click.option("-e", "--curve") @click.option("-e", "--curve")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_public_node( def get_public_node(
client: "TrezorClient", session: "Session",
coin: str, coin: str,
address: str, address: str,
curve: Optional[str], curve: Optional[str],
@ -251,7 +252,7 @@ def get_public_node(
if script_type is None: if script_type is None:
script_type = guess_script_type_from_path(address_n) script_type = guess_script_type_from_path(address_n)
result = btc.get_public_node( result = btc.get_public_node(
client, session,
address_n, address_n,
ecdsa_curve_name=curve, ecdsa_curve_name=curve,
show_display=show_display, show_display=show_display,
@ -277,7 +278,7 @@ def _append_descriptor_checksum(desc: str) -> str:
def _get_descriptor( def _get_descriptor(
client: "TrezorClient", session: "Session",
coin: Optional[str], coin: Optional[str],
account: int, account: int,
purpose: Optional[int], purpose: Optional[int],
@ -311,7 +312,7 @@ def _get_descriptor(
n = tools.parse_path(path) n = tools.parse_path(path)
pub = btc.get_public_node( pub = btc.get_public_node(
client, session,
n, n,
show_display=show_display, show_display=show_display,
coin_name=coin, coin_name=coin,
@ -348,9 +349,9 @@ def _get_descriptor(
@click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE)) @click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE))
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
@click.option("-d", "--show-display", is_flag=True) @click.option("-d", "--show-display", is_flag=True)
@with_client @with_session
def get_descriptor( def get_descriptor(
client: "TrezorClient", session: "Session",
coin: Optional[str], coin: Optional[str],
account: int, account: int,
account_type: Optional[int], account_type: Optional[int],
@ -360,7 +361,7 @@ def get_descriptor(
"""Get descriptor of given account.""" """Get descriptor of given account."""
try: try:
return _get_descriptor( return _get_descriptor(
client, coin, account, account_type, script_type, show_display session, coin, account, account_type, script_type, show_display
) )
except ValueError as e: except ValueError as e:
raise click.ClickException(str(e)) raise click.ClickException(str(e))
@ -375,8 +376,8 @@ def get_descriptor(
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False) @click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@click.argument("json_file", type=click.File()) @click.argument("json_file", type=click.File())
@with_client @with_session
def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None:
"""Sign transaction. """Sign transaction.
Transaction data must be provided in a JSON file. See `transaction-format.md` for Transaction data must be provided in a JSON file. See `transaction-format.md` for
@ -401,7 +402,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
} }
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
client, session,
coin, coin,
inputs, inputs,
outputs, outputs,
@ -432,9 +433,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
) )
@click.option("-C", "--chunkify", is_flag=True) @click.option("-C", "--chunkify", is_flag=True)
@click.argument("message") @click.argument("message")
@with_client @with_session
def sign_message( def sign_message(
client: "TrezorClient", session: "Session",
coin: str, coin: str,
address: str, address: str,
message: str, message: str,
@ -447,7 +448,7 @@ def sign_message(
if script_type is None: if script_type is None:
script_type = guess_script_type_from_path(address_n) script_type = guess_script_type_from_path(address_n)
res = btc.sign_message( res = btc.sign_message(
client, session,
coin, coin,
address_n, address_n,
message, message,
@ -468,9 +469,9 @@ def sign_message(
@click.argument("address") @click.argument("address")
@click.argument("signature") @click.argument("signature")
@click.argument("message") @click.argument("message")
@with_client @with_session
def verify_message( def verify_message(
client: "TrezorClient", session: "Session",
coin: str, coin: str,
address: str, address: str,
signature: str, signature: str,
@ -480,7 +481,7 @@ def verify_message(
"""Verify message.""" """Verify message."""
signature_bytes = base64.b64decode(signature) signature_bytes = base64.b64decode(signature)
return btc.verify_message( return btc.verify_message(
client, coin, address, signature_bytes, message, chunkify=chunkify session, coin, address, signature_bytes, message, chunkify=chunkify
) )

@ -36,14 +36,22 @@ class NewTransport:
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
def get_path(self) -> str: ... def get_path(self) -> str:
raise NotImplementedError
def open(self) -> None: ... def open(self) -> None:
raise NotImplementedError
def close(self) -> None: ... def close(self) -> None:
raise NotImplementedError
def write_chunk(self, chunk: bytes) -> None: ... def write_chunk(self, chunk: bytes) -> None:
raise NotImplementedError
def read_chunk(self) -> bytes: ... def read_chunk(self) -> bytes:
raise NotImplementedError
def find_debug(self: "T") -> "T":
raise NotImplementedError
CHUNK_SIZE: t.ClassVar[int] CHUNK_SIZE: t.ClassVar[int]

Loading…
Cancel
Save