1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-10 01:30:19 +00:00

wip trezorlib add btc commands

This commit is contained in:
M1nd3r 2024-09-16 14:36:40 +02:00
parent 6a65d62353
commit 34d3dbdeb1
3 changed files with 59 additions and 47 deletions

View File

@ -23,7 +23,8 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple
from typing_extensions import Protocol, TypedDict
from . import exceptions, messages
from .tools import expect, prepare_message_bytes, session
from .tools import expect, prepare_message_bytes
from .transport.new.session import Session
if TYPE_CHECKING:
from .client import TrezorClient
@ -105,7 +106,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType:
@expect(messages.PublicKey)
def get_public_node(
client: "TrezorClient",
session: "Session",
n: "Address",
ecdsa_curve_name: Optional[str] = None,
show_display: bool = False,
@ -116,13 +117,13 @@ def get_public_node(
unlock_path_mac: Optional[bytes] = None,
) -> "MessageType":
if unlock_path:
res = client.call(
res = session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
)
if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message")
return client.call(
return session.call(
messages.GetPublicKey(
address_n=n,
ecdsa_curve_name=ecdsa_curve_name,
@ -141,7 +142,7 @@ def get_address(*args: Any, **kwargs: Any):
@expect(messages.Address)
def get_authenticated_address(
client: "TrezorClient",
session: "Session",
coin_name: str,
n: "Address",
show_display: bool = False,
@ -153,13 +154,13 @@ def get_authenticated_address(
chunkify: bool = False,
) -> "MessageType":
if unlock_path:
res = client.call(
res = session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
)
if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message")
return client.call(
return session.call(
messages.GetAddress(
address_n=n,
coin_name=coin_name,
@ -172,6 +173,7 @@ def get_authenticated_address(
)
# TODO this is used by tests only
@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
def get_ownership_id(
client: "TrezorClient",
@ -190,6 +192,7 @@ def get_ownership_id(
)
# TODO this is used by tests only
def get_ownership_proof(
client: "TrezorClient",
coin_name: str,
@ -226,7 +229,7 @@ def get_ownership_proof(
@expect(messages.MessageSignature)
def sign_message(
client: "TrezorClient",
session: "Session",
coin_name: str,
n: "Address",
message: AnyStr,
@ -234,7 +237,7 @@ def sign_message(
no_script_type: bool = False,
chunkify: bool = False,
) -> "MessageType":
return client.call(
return session.call(
messages.SignMessage(
coin_name=coin_name,
address_n=n,
@ -247,7 +250,7 @@ def sign_message(
def verify_message(
client: "TrezorClient",
session: "Session",
coin_name: str,
address: str,
signature: bytes,
@ -255,7 +258,7 @@ def verify_message(
chunkify: bool = False,
) -> bool:
try:
resp = client.call(
resp = session.call(
messages.VerifyMessage(
address=address,
signature=signature,
@ -269,9 +272,9 @@ def verify_message(
return isinstance(resp, messages.Success)
@session
# @session
def sign_tx(
client: "TrezorClient",
session: "Session",
coin_name: str,
inputs: Sequence[messages.TxInputType],
outputs: Sequence[messages.TxOutputType],
@ -319,17 +322,17 @@ def sign_tx(
setattr(signtx, name, value)
if unlock_path:
res = client.call(
res = session.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
)
if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message")
elif preauthorized:
res = client.call(messages.DoPreauthorized())
res = session.call(messages.DoPreauthorized())
if not isinstance(res, messages.PreauthorizedRequest):
raise exceptions.TrezorException("Unexpected message")
res = client.call(signtx)
res = session.call(signtx)
# Prepare structure for signatures
signatures: List[Optional[bytes]] = [None] * len(inputs)
@ -388,7 +391,7 @@ def sign_tx(
if res.request_type == R.TXPAYMENTREQ:
assert res.details.request_index is not None
msg = payment_reqs[res.details.request_index]
res = client.call(msg)
res = session.call(msg)
else:
msg = messages.TransactionType()
if res.request_type == R.TXMETA:
@ -418,7 +421,7 @@ def sign_tx(
f"Unknown request type - {res.request_type}."
)
res = client.call(messages.TxAck(tx=msg))
res = session.call(messages.TxAck(tx=msg))
if not isinstance(res, messages.TxRequest):
raise exceptions.TrezorException("Unexpected message")

View File

@ -13,6 +13,7 @@
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import base64
import json
@ -22,10 +23,10 @@ import click
import construct as c
from .. import btc, messages, protobuf, tools
from . import ChoiceType, with_client
from . import ChoiceType, with_session
if TYPE_CHECKING:
from ..client import TrezorClient
from ..transport.new.session import Session
PURPOSE_BIP44 = 44
PURPOSE_BIP48 = 48
@ -168,15 +169,15 @@ def cli() -> None:
default=2,
)
@click.option("-C", "--chunkify", is_flag=True)
@with_client
@with_session
def get_address(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
script_type: Optional[messages.InputScriptType],
script_type: messages.InputScriptType | None,
show_display: bool,
multisig_xpub: List[str],
multisig_threshold: Optional[int],
multisig_threshold: int | None,
multisig_suffix_length: int,
chunkify: bool,
) -> str:
@ -220,7 +221,7 @@ def get_address(
multisig = None
return btc.get_address(
client,
session,
coin,
address_n,
show_display,
@ -237,9 +238,9 @@ def get_address(
@click.option("-e", "--curve")
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session
def get_public_node(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
curve: Optional[str],
@ -251,7 +252,7 @@ def get_public_node(
if script_type is None:
script_type = guess_script_type_from_path(address_n)
result = btc.get_public_node(
client,
session,
address_n,
ecdsa_curve_name=curve,
show_display=show_display,
@ -277,7 +278,7 @@ def _append_descriptor_checksum(desc: str) -> str:
def _get_descriptor(
client: "TrezorClient",
session: "Session",
coin: Optional[str],
account: int,
purpose: Optional[int],
@ -311,7 +312,7 @@ def _get_descriptor(
n = tools.parse_path(path)
pub = btc.get_public_node(
client,
session,
n,
show_display=show_display,
coin_name=coin,
@ -348,9 +349,9 @@ def _get_descriptor(
@click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE))
@click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS))
@click.option("-d", "--show-display", is_flag=True)
@with_client
@with_session
def get_descriptor(
client: "TrezorClient",
session: "Session",
coin: Optional[str],
account: int,
account_type: Optional[int],
@ -360,7 +361,7 @@ def get_descriptor(
"""Get descriptor of given account."""
try:
return _get_descriptor(
client, coin, account, account_type, script_type, show_display
session, coin, account, account_type, script_type, show_display
)
except ValueError as e:
raise click.ClickException(str(e))
@ -375,8 +376,8 @@ def get_descriptor(
@click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False)
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("json_file", type=click.File())
@with_client
def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
@with_session
def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None:
"""Sign transaction.
Transaction data must be provided in a JSON file. See `transaction-format.md` for
@ -401,7 +402,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
}
_, serialized_tx = btc.sign_tx(
client,
session,
coin,
inputs,
outputs,
@ -432,9 +433,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None:
)
@click.option("-C", "--chunkify", is_flag=True)
@click.argument("message")
@with_client
@with_session
def sign_message(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
message: str,
@ -447,7 +448,7 @@ def sign_message(
if script_type is None:
script_type = guess_script_type_from_path(address_n)
res = btc.sign_message(
client,
session,
coin,
address_n,
message,
@ -468,9 +469,9 @@ def sign_message(
@click.argument("address")
@click.argument("signature")
@click.argument("message")
@with_client
@with_session
def verify_message(
client: "TrezorClient",
session: "Session",
coin: str,
address: str,
signature: str,
@ -480,7 +481,7 @@ def verify_message(
"""Verify message."""
signature_bytes = base64.b64decode(signature)
return btc.verify_message(
client, coin, address, signature_bytes, message, chunkify=chunkify
session, coin, address, signature_bytes, message, chunkify=chunkify
)

View File

@ -36,14 +36,22 @@ class NewTransport:
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
def get_path(self) -> str: ...
def get_path(self) -> str:
raise NotImplementedError
def open(self) -> None: ...
def open(self) -> None:
raise NotImplementedError
def close(self) -> None: ...
def close(self) -> None:
raise NotImplementedError
def write_chunk(self, chunk: bytes) -> None: ...
def write_chunk(self, chunk: bytes) -> None:
raise NotImplementedError
def read_chunk(self) -> bytes: ...
def read_chunk(self) -> bytes:
raise NotImplementedError
def find_debug(self: "T") -> "T":
raise NotImplementedError
CHUNK_SIZE: t.ClassVar[int]