1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-27 23:02:02 +00:00

trezorlib: finalize BTC API changes

- drop set_tx_api method and its usage from trezorctl
- drop _prepare_sign_tx which is not used anymore
- adapt trezorctl to new signing API
- make trezorctl signing smarter, ahead of moving it elsewhere
This commit is contained in:
matejcik 2018-11-02 16:34:44 +01:00
parent 620e48e4d0
commit c269d67cde
3 changed files with 38 additions and 43 deletions

View File

@ -21,6 +21,7 @@
# along with this library. If not, see <http://www.gnu.org/licenses/>. # along with this library. If not, see <http://www.gnu.org/licenses/>.
import base64 import base64
import decimal
import json import json
import os import os
import sys import sys
@ -766,6 +767,7 @@ def get_public_node(connect, coin, address, curve, script_type, show_display):
def sign_tx(connect, coin): def sign_tx(connect, coin):
client = connect() client = connect()
if coin in coins.tx_api: if coin in coins.tx_api:
coin_data = coins.by_name[coin]
txapi = coins.tx_api[coin] txapi = coins.tx_api[coin]
else: else:
click.echo('Coin "%s" is not recognized.' % coin, err=True) click.echo('Coin "%s" is not recognized.' % coin, err=True)
@ -774,8 +776,6 @@ def sign_tx(connect, coin):
) )
sys.exit(1) sys.exit(1)
client.set_tx_api(txapi)
def default_script_type(address_n): def default_script_type(address_n):
script_type = "address" script_type = "address"
@ -791,6 +791,7 @@ def sign_tx(connect, coin):
return bytes.fromhex(txid), int(vout) return bytes.fromhex(txid), int(vout)
inputs = [] inputs = []
txes = {}
while True: while True:
click.echo() click.echo()
prev = click.prompt( prev = click.prompt(
@ -800,6 +801,14 @@ def sign_tx(connect, coin):
break break
prev_hash, prev_index = prev prev_hash, prev_index = prev
address_n = click.prompt("BIP-32 path to derive the key", type=tools.parse_path) address_n = click.prompt("BIP-32 path to derive the key", type=tools.parse_path)
try:
tx = txapi[prev_hash]
txes[prev_hash] = tx
amount = tx.bin_outputs[prev_index].amount
click.echo("Prefilling input amount: {}".format(amount))
except Exception as e:
print(e)
click.echo("Failed to fetch transation. This might bite you later.")
amount = click.prompt("Input amount (satoshis)", type=int, default=0) amount = click.prompt("Input amount (satoshis)", type=int, default=0)
sequence = click.prompt( sequence = click.prompt(
"Sequence Number to use (RBF opt-in enabled by default)", "Sequence Number to use (RBF opt-in enabled by default)",
@ -825,14 +834,14 @@ def sign_tx(connect, coin):
script_type=script_type, script_type=script_type,
sequence=sequence, sequence=sequence,
) )
if txapi.bip115: if coin_data["bip115"]:
prev_output = txapi.get_tx(prev_hash.hex()).bin_outputs[prev_index] prev_output = txapi.get_tx(prev_hash.hex()).bin_outputs[prev_index]
new_input.prev_block_hash_bip115 = prev_output.block_hash new_input.prev_block_hash_bip115 = prev_output.block_hash
new_input.prev_block_height_bip115 = prev_output.block_height new_input.prev_block_height_bip115 = prev_output.block_height
inputs.append(new_input) inputs.append(new_input)
if txapi.bip115: if coin_data["bip115"]:
current_block_height = txapi.current_height() current_block_height = txapi.current_height()
# Zencash recommendation for the better protection # Zencash recommendation for the better protection
block_height = current_block_height - 300 block_height = current_block_height - 300
@ -878,14 +887,14 @@ def sign_tx(connect, coin):
) )
) )
tx_version = click.prompt("Transaction version", type=int, default=2) signtx = proto.SignTx()
tx_locktime = click.prompt("Transaction locktime", type=int, default=0) signtx.version = click.prompt("Transaction version", type=int, default=2)
tx_timestamp = click.prompt( signtx.lock_time = click.prompt("Transaction locktime", type=int, default=0)
"Transaction timestamp (Capricoin)", type=int, default=None if coin == "Capricoin":
) signtx.timestamp = click.prompt("Transaction timestamp", type=int)
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
client, coin, inputs, outputs, tx_version, tx_locktime, timestamp=tx_timestamp client, coin, inputs, outputs, details=signtx, prev_txes=txes
) )
client.close() client.close()

View File

@ -72,7 +72,22 @@ def verify_message(client, coin_name, address, signature, message):
@session @session
def sign_tx(client, coin_name, inputs, outputs, details=None, prev_txes=None): def sign_tx(client, coin_name, inputs, outputs, details=None, prev_txes=None):
my_tx = messages.TransactionType(inputs=inputs, outputs=outputs) # set up a transactions dict
txes = {None: messages.TransactionType(inputs=inputs, outputs=outputs)}
# preload all relevant transactions ahead of time
for inp in inputs:
if inp.script_type not in (
messages.InputScriptType.SPENDP2SHWITNESS,
messages.InputScriptType.SPENDWITNESS,
messages.InputScriptType.EXTERNAL,
):
try:
prev_tx = prev_txes[inp.prev_hash]
except Exception as e:
raise ValueError("Could not retrieve prev_tx") from e
if not isinstance(prev_tx, messages.TransactionType):
raise ValueError("Invalid value for prev_tx") from None
txes[inp.prev_hash] = prev_tx
if details is None: if details is None:
signtx = messages.SignTx() signtx = messages.SignTx()
@ -120,10 +135,7 @@ def sign_tx(client, coin_name, inputs, outputs, details=None, prev_txes=None):
break break
# Device asked for one more information, let's process it. # Device asked for one more information, let's process it.
if not res.details.tx_hash: current_tx = txes[res.details.tx_hash]
current_tx = my_tx
else:
current_tx = prev_txes[res.details.tx_hash]
if res.request_type == R.TXMETA: if res.request_type == R.TXMETA:
msg = copy_tx_meta(current_tx) msg = copy_tx_meta(current_tx)

View File

@ -169,10 +169,9 @@ class ProtocolMixin(object):
super(ProtocolMixin, self).__init__(*args, **kwargs) super(ProtocolMixin, self).__init__(*args, **kwargs)
self.state = state self.state = state
self.init_device() self.init_device()
self.tx_api = None
def set_tx_api(self, tx_api): def set_tx_api(self, tx_api):
self.tx_api = tx_api warnings.warn("set_tx_api is deprecated, use new arguments to sign_tx")
def init_device(self): def init_device(self):
resp = self.call(proto.Initialize(state=self.state)) resp = self.call(proto.Initialize(state=self.state))
@ -211,31 +210,6 @@ class ProtocolMixin(object):
def get_device_id(self): def get_device_id(self):
return self.features.device_id return self.features.device_id
def _prepare_sign_tx(self, inputs, outputs):
tx = proto.TransactionType()
tx.inputs = inputs
tx.outputs = outputs
txes = {None: tx}
for inp in inputs:
if inp.prev_hash in txes:
continue
if inp.script_type in (
proto.InputScriptType.SPENDP2SHWITNESS,
proto.InputScriptType.SPENDWITNESS,
):
continue
if not self.tx_api:
raise RuntimeError("TX_API not defined")
prev_tx = self.tx_api.get_tx(inp.prev_hash.hex())
txes[inp.prev_hash] = prev_tx
return txes
@tools.expect(proto.Success, field="message") @tools.expect(proto.Success, field="message")
def clear_session(self): def clear_session(self):
return self.call(proto.ClearSession()) return self.call(proto.ClearSession())