diff --git a/trezorctl b/trezorctl index d864e6c0b5..6e8a39a470 100755 --- a/trezorctl +++ b/trezorctl @@ -21,6 +21,7 @@ # along with this library. If not, see . import base64 +import decimal import json import os import sys @@ -766,6 +767,7 @@ def get_public_node(connect, coin, address, curve, script_type, show_display): def sign_tx(connect, coin): client = connect() if coin in coins.tx_api: + coin_data = coins.by_name[coin] txapi = coins.tx_api[coin] else: click.echo('Coin "%s" is not recognized.' % coin, err=True) @@ -774,8 +776,6 @@ def sign_tx(connect, coin): ) sys.exit(1) - client.set_tx_api(txapi) - def default_script_type(address_n): script_type = "address" @@ -791,6 +791,7 @@ def sign_tx(connect, coin): return bytes.fromhex(txid), int(vout) inputs = [] + txes = {} while True: click.echo() prev = click.prompt( @@ -800,7 +801,15 @@ def sign_tx(connect, coin): break prev_hash, prev_index = prev address_n = click.prompt("BIP-32 path to derive the key", type=tools.parse_path) - amount = click.prompt("Input amount (satoshis)", type=int, default=0) + try: + tx = txapi[prev_hash] + txes[prev_hash] = tx + amount = tx.bin_outputs[prev_index].amount + click.echo("Prefilling input amount: {}".format(amount)) + except Exception as e: + print(e) + click.echo("Failed to fetch transation. This might bite you later.") + amount = click.prompt("Input amount (satoshis)", type=int, default=0) sequence = click.prompt( "Sequence Number to use (RBF opt-in enabled by default)", type=int, @@ -825,14 +834,14 @@ def sign_tx(connect, coin): script_type=script_type, sequence=sequence, ) - if txapi.bip115: + if coin_data["bip115"]: prev_output = txapi.get_tx(prev_hash.hex()).bin_outputs[prev_index] new_input.prev_block_hash_bip115 = prev_output.block_hash new_input.prev_block_height_bip115 = prev_output.block_height inputs.append(new_input) - if txapi.bip115: + if coin_data["bip115"]: current_block_height = txapi.current_height() # Zencash recommendation for the better protection block_height = current_block_height - 300 @@ -878,14 +887,14 @@ def sign_tx(connect, coin): ) ) - tx_version = click.prompt("Transaction version", type=int, default=2) - tx_locktime = click.prompt("Transaction locktime", type=int, default=0) - tx_timestamp = click.prompt( - "Transaction timestamp (Capricoin)", type=int, default=None - ) + signtx = proto.SignTx() + signtx.version = click.prompt("Transaction version", type=int, default=2) + signtx.lock_time = click.prompt("Transaction locktime", type=int, default=0) + if coin == "Capricoin": + signtx.timestamp = click.prompt("Transaction timestamp", type=int) _, serialized_tx = btc.sign_tx( - client, coin, inputs, outputs, tx_version, tx_locktime, timestamp=tx_timestamp + client, coin, inputs, outputs, details=signtx, prev_txes=txes ) client.close() diff --git a/trezorlib/btc.py b/trezorlib/btc.py index d2e2270a02..df4840d5d0 100644 --- a/trezorlib/btc.py +++ b/trezorlib/btc.py @@ -72,7 +72,22 @@ def verify_message(client, coin_name, address, signature, message): @session def sign_tx(client, coin_name, inputs, outputs, details=None, prev_txes=None): - my_tx = messages.TransactionType(inputs=inputs, outputs=outputs) + # set up a transactions dict + txes = {None: messages.TransactionType(inputs=inputs, outputs=outputs)} + # preload all relevant transactions ahead of time + for inp in inputs: + if inp.script_type not in ( + messages.InputScriptType.SPENDP2SHWITNESS, + messages.InputScriptType.SPENDWITNESS, + messages.InputScriptType.EXTERNAL, + ): + try: + prev_tx = prev_txes[inp.prev_hash] + except Exception as e: + raise ValueError("Could not retrieve prev_tx") from e + if not isinstance(prev_tx, messages.TransactionType): + raise ValueError("Invalid value for prev_tx") from None + txes[inp.prev_hash] = prev_tx if details is None: signtx = messages.SignTx() @@ -120,10 +135,7 @@ def sign_tx(client, coin_name, inputs, outputs, details=None, prev_txes=None): break # Device asked for one more information, let's process it. - if not res.details.tx_hash: - current_tx = my_tx - else: - current_tx = prev_txes[res.details.tx_hash] + current_tx = txes[res.details.tx_hash] if res.request_type == R.TXMETA: msg = copy_tx_meta(current_tx) diff --git a/trezorlib/client.py b/trezorlib/client.py index 9ea1b4bf64..42d76abbc3 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -169,10 +169,9 @@ class ProtocolMixin(object): super(ProtocolMixin, self).__init__(*args, **kwargs) self.state = state self.init_device() - self.tx_api = None def set_tx_api(self, tx_api): - self.tx_api = tx_api + warnings.warn("set_tx_api is deprecated, use new arguments to sign_tx") def init_device(self): resp = self.call(proto.Initialize(state=self.state)) @@ -211,31 +210,6 @@ class ProtocolMixin(object): def get_device_id(self): return self.features.device_id - def _prepare_sign_tx(self, inputs, outputs): - tx = proto.TransactionType() - tx.inputs = inputs - tx.outputs = outputs - - txes = {None: tx} - - for inp in inputs: - if inp.prev_hash in txes: - continue - - if inp.script_type in ( - proto.InputScriptType.SPENDP2SHWITNESS, - proto.InputScriptType.SPENDWITNESS, - ): - continue - - if not self.tx_api: - raise RuntimeError("TX_API not defined") - - prev_tx = self.tx_api.get_tx(inp.prev_hash.hex()) - txes[inp.prev_hash] = prev_tx - - return txes - @tools.expect(proto.Success, field="message") def clear_session(self): return self.call(proto.ClearSession())