diff --git a/python/src/trezorlib/btc.py b/python/src/trezorlib/btc.py index 7dcbaa238..b15241693 100644 --- a/python/src/trezorlib/btc.py +++ b/python/src/trezorlib/btc.py @@ -15,40 +15,45 @@ # If not, see . from decimal import Decimal +from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple from . import exceptions, messages from .tools import expect, normalize_nfc, session +if TYPE_CHECKING: + from .client import TrezorClient + def from_json(json_dict): def make_input(vin): - i = messages.TxInputType() if "coinbase" in vin: - i.prev_hash = b"\0" * 32 - i.prev_index = 0xFFFFFFFF # signed int -1 - i.script_sig = bytes.fromhex(vin["coinbase"]) - i.sequence = vin["sequence"] + return messages.TxInputType( + prev_hash=b"\0" * 32, + prev_index=0xFFFFFFFF, # signed int -1 + script_sig=bytes.fromhex(vin["coinbase"]), + sequence=vin["sequence"], + ) else: - i.prev_hash = bytes.fromhex(vin["txid"]) - i.prev_index = vin["vout"] - i.script_sig = bytes.fromhex(vin["scriptSig"]["hex"]) - i.sequence = vin["sequence"] - - return i + return messages.TxInputType( + prev_hash=bytes.fromhex(vin["txid"]), + prev_index=vin["vout"], + script_sig=bytes.fromhex(vin["scriptSig"]["hex"]), + sequence=vin["sequence"], + ) def make_bin_output(vout): - o = messages.TxOutputBinType() - o.amount = int(Decimal(vout["value"]) * (10 ** 8)) - o.script_pubkey = bytes.fromhex(vout["scriptPubKey"]["hex"]) - return o + return messages.TxOutputBinType( + amount=int(Decimal(vout["value"]) * (10 ** 8)), + script_pubkey=bytes.fromhex(vout["scriptPubKey"]["hex"]), + ) - t = messages.TransactionType() - t.version = json_dict["version"] - t.lock_time = json_dict.get("locktime") - t.inputs = [make_input(vin) for vin in json_dict["vin"]] - t.bin_outputs = [make_bin_output(vout) for vout in json_dict["vout"]] - return t + return messages.TransactionType( + version=json_dict["version"], + lock_time=json_dict.get("locktime"), + inputs=[make_input(vin) for vin in json_dict["vin"]], + bin_outputs=[make_bin_output(vout) for vout in json_dict["vout"]], + ) @expect(messages.PublicKey) @@ -173,24 +178,31 @@ def verify_message(client, coin_name, address, signature, message): @session def sign_tx( - client, - coin_name, - inputs, - outputs, - details=None, - prev_txes=None, - preauthorized=False, -): - this_tx = messages.TransactionType(inputs=inputs, outputs=outputs) - - if details is None: - signtx = messages.SignTx() - else: - signtx = details - - signtx.coin_name = coin_name - signtx.inputs_count = len(inputs) - signtx.outputs_count = len(outputs) + client: "TrezorClient", + coin_name: str, + inputs: Sequence[messages.TxInputType], + outputs: Sequence[messages.TxOutputType], + prev_txes: Dict[bytes, messages.TransactionType], + preauthorized: bool = False, + **kwargs: Any, +) -> Tuple[Sequence[bytes], bytes]: + """Sign a Bitcoin-like transaction. + + Returns a list of signatures (one for each provided input) and the + network-serialized transaction. + + In addition to the required arguments, it is possible to specify additional + transaction properties (version, lock time, expiry...). Each additional argument + must correspond to a field in the `SignTx` data type. Note that some fields + (`inputs_count`, `outputs_count`, `coin_name`) will be inferred from the arguments + and cannot be overriden by kwargs. + """ + signtx = messages.SignTx( + coin_name=coin_name, inputs_count=len(inputs), outputs_count=len(outputs), + ) + for name, value in kwargs.items(): + if hasattr(signtx, name): + setattr(signtx, name, value) if preauthorized: res = client.call(messages.DoPreauthorized()) @@ -203,7 +215,7 @@ def sign_tx( signatures = [None] * len(inputs) serialized_tx = b"" - def copy_tx_meta(tx): + def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType: tx_copy = messages.TransactionType(**tx) # clear fields tx_copy.inputs_cnt = len(tx.inputs) @@ -215,6 +227,15 @@ def sign_tx( tx_copy.extra_data = None return tx_copy + this_tx = messages.TransactionType( + inputs=inputs, + outputs=outputs, + inputs_cnt=len(inputs), + outputs_cnt=len(outputs), + # pick either kw-provided or default value from the SignTx request + version=signtx.version, + ) + R = messages.RequestType while isinstance(res, messages.TxRequest): # If there's some part of signed transaction, let's add it