From 5ddf1dfafb4c8199f0ac627e9b336fddbdccd4e6 Mon Sep 17 00:00:00 2001 From: matejcik Date: Tue, 15 Sep 2020 13:04:08 +0200 Subject: [PATCH] feat!(python): modify btc.sign_tx api to accept kwargs Because we can't pass SignTx anymore because it has required fields and the caller is not supposed to fill out those. Instead you can send arbitrary kwargs that match signtx fields. BREAKING CHANGE: argument `details: SignTx` is no longer accepted. --- python/src/trezorlib/btc.py | 101 ++++++++++++++++++++++-------------- 1 file changed, 61 insertions(+), 40 deletions(-) diff --git a/python/src/trezorlib/btc.py b/python/src/trezorlib/btc.py index 7dcbaa238..b15241693 100644 --- a/python/src/trezorlib/btc.py +++ b/python/src/trezorlib/btc.py @@ -15,40 +15,45 @@ # If not, see . from decimal import Decimal +from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple from . import exceptions, messages from .tools import expect, normalize_nfc, session +if TYPE_CHECKING: + from .client import TrezorClient + def from_json(json_dict): def make_input(vin): - i = messages.TxInputType() if "coinbase" in vin: - i.prev_hash = b"\0" * 32 - i.prev_index = 0xFFFFFFFF # signed int -1 - i.script_sig = bytes.fromhex(vin["coinbase"]) - i.sequence = vin["sequence"] + return messages.TxInputType( + prev_hash=b"\0" * 32, + prev_index=0xFFFFFFFF, # signed int -1 + script_sig=bytes.fromhex(vin["coinbase"]), + sequence=vin["sequence"], + ) else: - i.prev_hash = bytes.fromhex(vin["txid"]) - i.prev_index = vin["vout"] - i.script_sig = bytes.fromhex(vin["scriptSig"]["hex"]) - i.sequence = vin["sequence"] - - return i + return messages.TxInputType( + prev_hash=bytes.fromhex(vin["txid"]), + prev_index=vin["vout"], + script_sig=bytes.fromhex(vin["scriptSig"]["hex"]), + sequence=vin["sequence"], + ) def make_bin_output(vout): - o = messages.TxOutputBinType() - o.amount = int(Decimal(vout["value"]) * (10 ** 8)) - o.script_pubkey = bytes.fromhex(vout["scriptPubKey"]["hex"]) - return o + return messages.TxOutputBinType( + amount=int(Decimal(vout["value"]) * (10 ** 8)), + script_pubkey=bytes.fromhex(vout["scriptPubKey"]["hex"]), + ) - t = messages.TransactionType() - t.version = json_dict["version"] - t.lock_time = json_dict.get("locktime") - t.inputs = [make_input(vin) for vin in json_dict["vin"]] - t.bin_outputs = [make_bin_output(vout) for vout in json_dict["vout"]] - return t + return messages.TransactionType( + version=json_dict["version"], + lock_time=json_dict.get("locktime"), + inputs=[make_input(vin) for vin in json_dict["vin"]], + bin_outputs=[make_bin_output(vout) for vout in json_dict["vout"]], + ) @expect(messages.PublicKey) @@ -173,24 +178,31 @@ def verify_message(client, coin_name, address, signature, message): @session def sign_tx( - client, - coin_name, - inputs, - outputs, - details=None, - prev_txes=None, - preauthorized=False, -): - this_tx = messages.TransactionType(inputs=inputs, outputs=outputs) - - if details is None: - signtx = messages.SignTx() - else: - signtx = details - - signtx.coin_name = coin_name - signtx.inputs_count = len(inputs) - signtx.outputs_count = len(outputs) + client: "TrezorClient", + coin_name: str, + inputs: Sequence[messages.TxInputType], + outputs: Sequence[messages.TxOutputType], + prev_txes: Dict[bytes, messages.TransactionType], + preauthorized: bool = False, + **kwargs: Any, +) -> Tuple[Sequence[bytes], bytes]: + """Sign a Bitcoin-like transaction. + + Returns a list of signatures (one for each provided input) and the + network-serialized transaction. + + In addition to the required arguments, it is possible to specify additional + transaction properties (version, lock time, expiry...). Each additional argument + must correspond to a field in the `SignTx` data type. Note that some fields + (`inputs_count`, `outputs_count`, `coin_name`) will be inferred from the arguments + and cannot be overriden by kwargs. + """ + signtx = messages.SignTx( + coin_name=coin_name, inputs_count=len(inputs), outputs_count=len(outputs), + ) + for name, value in kwargs.items(): + if hasattr(signtx, name): + setattr(signtx, name, value) if preauthorized: res = client.call(messages.DoPreauthorized()) @@ -203,7 +215,7 @@ def sign_tx( signatures = [None] * len(inputs) serialized_tx = b"" - def copy_tx_meta(tx): + def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType: tx_copy = messages.TransactionType(**tx) # clear fields tx_copy.inputs_cnt = len(tx.inputs) @@ -215,6 +227,15 @@ def sign_tx( tx_copy.extra_data = None return tx_copy + this_tx = messages.TransactionType( + inputs=inputs, + outputs=outputs, + inputs_cnt=len(inputs), + outputs_cnt=len(outputs), + # pick either kw-provided or default value from the SignTx request + version=signtx.version, + ) + R = messages.RequestType while isinstance(res, messages.TxRequest): # If there's some part of signed transaction, let's add it