feat!(python): modify btc.sign_tx api to accept kwargs

Because we can't pass SignTx anymore because it has required fields and
the caller is not supposed to fill out those.

Instead you can send arbitrary kwargs that match signtx fields.

BREAKING CHANGE: argument `details: SignTx` is no longer accepted.
pull/1279/head
matejcik 4 years ago committed by matejcik
parent 244b264b47
commit 5ddf1dfafb

@ -15,40 +15,45 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from decimal import Decimal from decimal import Decimal
from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple
from . import exceptions, messages from . import exceptions, messages
from .tools import expect, normalize_nfc, session from .tools import expect, normalize_nfc, session
if TYPE_CHECKING:
from .client import TrezorClient
def from_json(json_dict): def from_json(json_dict):
def make_input(vin): def make_input(vin):
i = messages.TxInputType()
if "coinbase" in vin: if "coinbase" in vin:
i.prev_hash = b"\0" * 32 return messages.TxInputType(
i.prev_index = 0xFFFFFFFF # signed int -1 prev_hash=b"\0" * 32,
i.script_sig = bytes.fromhex(vin["coinbase"]) prev_index=0xFFFFFFFF, # signed int -1
i.sequence = vin["sequence"] script_sig=bytes.fromhex(vin["coinbase"]),
sequence=vin["sequence"],
)
else: else:
i.prev_hash = bytes.fromhex(vin["txid"]) return messages.TxInputType(
i.prev_index = vin["vout"] prev_hash=bytes.fromhex(vin["txid"]),
i.script_sig = bytes.fromhex(vin["scriptSig"]["hex"]) prev_index=vin["vout"],
i.sequence = vin["sequence"] script_sig=bytes.fromhex(vin["scriptSig"]["hex"]),
sequence=vin["sequence"],
return i )
def make_bin_output(vout): def make_bin_output(vout):
o = messages.TxOutputBinType() return messages.TxOutputBinType(
o.amount = int(Decimal(vout["value"]) * (10 ** 8)) amount=int(Decimal(vout["value"]) * (10 ** 8)),
o.script_pubkey = bytes.fromhex(vout["scriptPubKey"]["hex"]) script_pubkey=bytes.fromhex(vout["scriptPubKey"]["hex"]),
return o )
t = messages.TransactionType() return messages.TransactionType(
t.version = json_dict["version"] version=json_dict["version"],
t.lock_time = json_dict.get("locktime") lock_time=json_dict.get("locktime"),
t.inputs = [make_input(vin) for vin in json_dict["vin"]] inputs=[make_input(vin) for vin in json_dict["vin"]],
t.bin_outputs = [make_bin_output(vout) for vout in json_dict["vout"]] bin_outputs=[make_bin_output(vout) for vout in json_dict["vout"]],
return t )
@expect(messages.PublicKey) @expect(messages.PublicKey)
@ -173,24 +178,31 @@ def verify_message(client, coin_name, address, signature, message):
@session @session
def sign_tx( def sign_tx(
client, client: "TrezorClient",
coin_name, coin_name: str,
inputs, inputs: Sequence[messages.TxInputType],
outputs, outputs: Sequence[messages.TxOutputType],
details=None, prev_txes: Dict[bytes, messages.TransactionType],
prev_txes=None, preauthorized: bool = False,
preauthorized=False, **kwargs: Any,
): ) -> Tuple[Sequence[bytes], bytes]:
this_tx = messages.TransactionType(inputs=inputs, outputs=outputs) """Sign a Bitcoin-like transaction.
if details is None: Returns a list of signatures (one for each provided input) and the
signtx = messages.SignTx() network-serialized transaction.
else:
signtx = details In addition to the required arguments, it is possible to specify additional
transaction properties (version, lock time, expiry...). Each additional argument
signtx.coin_name = coin_name must correspond to a field in the `SignTx` data type. Note that some fields
signtx.inputs_count = len(inputs) (`inputs_count`, `outputs_count`, `coin_name`) will be inferred from the arguments
signtx.outputs_count = len(outputs) and cannot be overriden by kwargs.
"""
signtx = messages.SignTx(
coin_name=coin_name, inputs_count=len(inputs), outputs_count=len(outputs),
)
for name, value in kwargs.items():
if hasattr(signtx, name):
setattr(signtx, name, value)
if preauthorized: if preauthorized:
res = client.call(messages.DoPreauthorized()) res = client.call(messages.DoPreauthorized())
@ -203,7 +215,7 @@ def sign_tx(
signatures = [None] * len(inputs) signatures = [None] * len(inputs)
serialized_tx = b"" serialized_tx = b""
def copy_tx_meta(tx): def copy_tx_meta(tx: messages.TransactionType) -> messages.TransactionType:
tx_copy = messages.TransactionType(**tx) tx_copy = messages.TransactionType(**tx)
# clear fields # clear fields
tx_copy.inputs_cnt = len(tx.inputs) tx_copy.inputs_cnt = len(tx.inputs)
@ -215,6 +227,15 @@ def sign_tx(
tx_copy.extra_data = None tx_copy.extra_data = None
return tx_copy return tx_copy
this_tx = messages.TransactionType(
inputs=inputs,
outputs=outputs,
inputs_cnt=len(inputs),
outputs_cnt=len(outputs),
# pick either kw-provided or default value from the SignTx request
version=signtx.version,
)
R = messages.RequestType R = messages.RequestType
while isinstance(res, messages.TxRequest): while isinstance(res, messages.TxRequest):
# If there's some part of signed transaction, let's add it # If there's some part of signed transaction, let's add it

Loading…
Cancel
Save