1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-04 11:51:50 +00:00

btc: refactor and cleanup sign_tx api & flow

This commit is contained in:
matejcik 2018-10-30 14:56:57 +01:00
parent ea675f1e58
commit e5e0759dc8

View File

@ -1,18 +1,18 @@
from . import messages as proto from . import messages, tools
from .tools import CallException, expect, normalize_nfc, session from .tools import CallException, expect, normalize_nfc, session
@expect(proto.PublicKey) @expect(messages.PublicKey)
def get_public_node( def get_public_node(
client, client,
n, n,
ecdsa_curve_name=None, ecdsa_curve_name=None,
show_display=False, show_display=False,
coin_name=None, coin_name=None,
script_type=proto.InputScriptType.SPENDADDRESS, script_type=messages.InputScriptType.SPENDADDRESS,
): ):
return client.call( return client.call(
proto.GetPublicKey( messages.GetPublicKey(
address_n=n, address_n=n,
ecdsa_curve_name=ecdsa_curve_name, ecdsa_curve_name=ecdsa_curve_name,
show_display=show_display, show_display=show_display,
@ -22,18 +22,17 @@ def get_public_node(
) )
@expect(proto.Address, field="address") @expect(messages.Address, field="address")
def get_address( def get_address(
client, client,
coin_name, coin_name,
n, n,
show_display=False, show_display=False,
multisig=None, multisig=None,
script_type=proto.InputScriptType.SPENDADDRESS, script_type=messages.InputScriptType.SPENDADDRESS,
): ):
if multisig:
return client.call( return client.call(
proto.GetAddress( messages.GetAddress(
address_n=n, address_n=n,
coin_name=coin_name, coin_name=coin_name,
show_display=show_display, show_display=show_display,
@ -41,24 +40,15 @@ def get_address(
script_type=script_type, script_type=script_type,
) )
) )
else:
return client.call(
proto.GetAddress(
address_n=n,
coin_name=coin_name,
show_display=show_display,
script_type=script_type,
)
)
@expect(proto.MessageSignature) @expect(messages.MessageSignature)
def sign_message( def sign_message(
client, coin_name, n, message, script_type=proto.InputScriptType.SPENDADDRESS client, coin_name, n, message, script_type=messages.InputScriptType.SPENDADDRESS
): ):
message = normalize_nfc(message) message = normalize_nfc(message)
return client.call( return client.call(
proto.SignMessage( messages.SignMessage(
coin_name=coin_name, address_n=n, message=message, script_type=script_type coin_name=coin_name, address_n=n, message=message, script_type=script_type
) )
) )
@ -68,7 +58,7 @@ def verify_message(client, coin_name, address, signature, message):
message = normalize_nfc(message) message = normalize_nfc(message)
try: try:
resp = client.call( resp = client.call(
proto.VerifyMessage( messages.VerifyMessage(
address=address, address=address,
signature=signature, signature=signature,
message=message, message=message,
@ -77,147 +67,95 @@ def verify_message(client, coin_name, address, signature, message):
) )
except CallException as e: except CallException as e:
resp = e resp = e
return isinstance(resp, proto.Success) return isinstance(resp, messages.Success)
@session @session
def sign_tx( def sign_tx(client, coin_name, inputs, outputs, details=None, prev_txes=None):
client, my_tx = messages.TransactionType(inputs=inputs, outputs=outputs)
coin_name,
inputs,
outputs,
version=None,
lock_time=None,
expiry=None,
overwintered=None,
version_group_id=None,
debug_processor=None,
timestamp=None,
):
# start = time.time()
txes = client._prepare_sign_tx(inputs, outputs)
# Prepare and send initial message if details is None:
tx = proto.SignTx() signtx = messages.SignTx()
tx.inputs_count = len(inputs) else:
tx.outputs_count = len(outputs) signtx = details
tx.coin_name = coin_name
if version is not None: signtx.coin_name = coin_name
tx.version = version signtx.inputs_count = len(inputs)
if lock_time is not None: signtx.outputs_count = len(outputs)
tx.lock_time = lock_time
if expiry is not None: res = client.call(signtx)
tx.expiry = expiry
if overwintered is not None:
tx.overwintered = overwintered
if version_group_id is not None:
tx.version_group_id = version_group_id
if timestamp is not None:
tx.timestamp = timestamp
res = client.call(tx)
# Prepare structure for signatures # Prepare structure for signatures
signatures = [None] * len(inputs) signatures = [None] * len(inputs)
serialized_tx = b"" serialized_tx = b""
counter = 0 def copy_tx_meta(tx):
while True: tx_copy = messages.TransactionType()
counter += 1 tx_copy.CopyFrom(tx)
# clear fields
if isinstance(res, proto.Failure): tx_copy.inputs_cnt = len(tx.inputs)
raise CallException("Signing failed") tx_copy.inputs = []
tx_copy.outputs_cnt = len(tx.bin_outputs or tx.outputs)
if not isinstance(res, proto.TxRequest): tx_copy.outputs = []
raise CallException("Unexpected message") tx_copy.bin_outputs = []
tx_copy.extra_data_len = len(tx.extra_data or b"")
tx_copy.extra_data = None
return tx_copy
R = messages.RequestType
while isinstance(res, messages.TxRequest):
# If there's some part of signed transaction, let's add it # If there's some part of signed transaction, let's add it
if res.serialized and res.serialized.serialized_tx: if res.serialized:
# log("RECEIVED PART OF SERIALIZED TX (%d BYTES)" % len(res.serialized.serialized_tx)) if res.serialized.serialized_tx:
serialized_tx += res.serialized.serialized_tx serialized_tx += res.serialized.serialized_tx
if res.serialized and res.serialized.signature_index is not None: if res.serialized.signature_index is not None:
if signatures[res.serialized.signature_index] is not None: idx = res.serialized.signature_index
raise ValueError( sig = res.serialized.signature
"Signature for index %d already filled" if signatures[idx] is not None:
% res.serialized.signature_index raise ValueError("Signature for index %d already filled" % idx)
) signatures[idx] = sig
signatures[res.serialized.signature_index] = res.serialized.signature
if res.request_type == proto.RequestType.TXFINISHED: if res.request_type == R.TXFINISHED:
# Device didn't ask for more information, finish workflow
break break
# Device asked for one more information, let's process it. # Device asked for one more information, let's process it.
if not res.details.tx_hash: if not res.details.tx_hash:
current_tx = txes[None] current_tx = my_tx
else: else:
current_tx = txes[bytes(res.details.tx_hash)] current_tx = prev_txes[res.details.tx_hash]
if res.request_type == proto.RequestType.TXMETA: if res.request_type == R.TXMETA:
msg = proto.TransactionType() msg = copy_tx_meta(current_tx)
msg.version = current_tx.version res = client.call(messages.TxAck(tx=msg))
msg.lock_time = current_tx.lock_time
msg.inputs_cnt = len(current_tx.inputs)
msg.timestamp = current_tx.timestamp
if res.details.tx_hash:
msg.outputs_cnt = len(current_tx.bin_outputs)
else:
msg.outputs_cnt = len(current_tx.outputs)
msg.extra_data_len = (
len(current_tx.extra_data) if current_tx.extra_data else 0
)
res = client.call(proto.TxAck(tx=msg))
continue
elif res.request_type == proto.RequestType.TXINPUT: elif res.request_type == R.TXINPUT:
msg = proto.TransactionType() msg = messages.TransactionType()
msg.inputs = [current_tx.inputs[res.details.request_index]] msg.inputs = [current_tx.inputs[res.details.request_index]]
if debug_processor is not None: res = client.call(messages.TxAck(tx=msg))
# msg needs to be deep copied so when it's modified
# the other messages stay intact
from copy import deepcopy
msg = deepcopy(msg) elif res.request_type == R.TXOUTPUT:
# If debug_processor function is provided, msg = messages.TransactionType()
# pass thru it the request and prepared response.
# This is useful for tests, see test_msg_signtx
msg = debug_processor(res, msg)
res = client.call(proto.TxAck(tx=msg))
continue
elif res.request_type == proto.RequestType.TXOUTPUT:
msg = proto.TransactionType()
if res.details.tx_hash: if res.details.tx_hash:
msg.bin_outputs = [current_tx.bin_outputs[res.details.request_index]] msg.bin_outputs = [current_tx.bin_outputs[res.details.request_index]]
else: else:
msg.outputs = [current_tx.outputs[res.details.request_index]] msg.outputs = [current_tx.outputs[res.details.request_index]]
if debug_processor is not None: res = client.call(messages.TxAck(tx=msg))
# msg needs to be deep copied so when it's modified
# the other messages stay intact
from copy import deepcopy
msg = deepcopy(msg) elif res.request_type == R.TXEXTRADATA:
# If debug_processor function is provided,
# pass thru it the request and prepared response.
# This is useful for tests, see test_msg_signtx
msg = debug_processor(res, msg)
res = client.call(proto.TxAck(tx=msg))
continue
elif res.request_type == proto.RequestType.TXEXTRADATA:
o, l = res.details.extra_data_offset, res.details.extra_data_len o, l = res.details.extra_data_offset, res.details.extra_data_len
msg = proto.TransactionType() msg = messages.TransactionType()
msg.extra_data = current_tx.extra_data[o : o + l] msg.extra_data = current_tx.extra_data[o : o + l]
res = client.call(proto.TxAck(tx=msg)) res = client.call(messages.TxAck(tx=msg))
continue
if isinstance(res, messages.Failure):
raise CallException("Signing failed")
if not isinstance(res, messages.TxRequest):
raise CallException("Unexpected message")
if None in signatures: if None in signatures:
raise RuntimeError("Some signatures are missing!") raise RuntimeError("Some signatures are missing!")
# log("SIGNED IN %.03f SECONDS, CALLED %d MESSAGES, %d BYTES" % return signatures, serialized_tx
# (time.time() - start, counter, len(serialized_tx)))
return (signatures, serialized_tx)