mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-25 14:50:57 +00:00
btc: refactor and cleanup sign_tx api & flow
This commit is contained in:
parent
ea675f1e58
commit
e5e0759dc8
198
trezorlib/btc.py
198
trezorlib/btc.py
@ -1,18 +1,18 @@
|
||||
from . import messages as proto
|
||||
from . import messages, tools
|
||||
from .tools import CallException, expect, normalize_nfc, session
|
||||
|
||||
|
||||
@expect(proto.PublicKey)
|
||||
@expect(messages.PublicKey)
|
||||
def get_public_node(
|
||||
client,
|
||||
n,
|
||||
ecdsa_curve_name=None,
|
||||
show_display=False,
|
||||
coin_name=None,
|
||||
script_type=proto.InputScriptType.SPENDADDRESS,
|
||||
script_type=messages.InputScriptType.SPENDADDRESS,
|
||||
):
|
||||
return client.call(
|
||||
proto.GetPublicKey(
|
||||
messages.GetPublicKey(
|
||||
address_n=n,
|
||||
ecdsa_curve_name=ecdsa_curve_name,
|
||||
show_display=show_display,
|
||||
@ -22,18 +22,17 @@ def get_public_node(
|
||||
)
|
||||
|
||||
|
||||
@expect(proto.Address, field="address")
|
||||
@expect(messages.Address, field="address")
|
||||
def get_address(
|
||||
client,
|
||||
coin_name,
|
||||
n,
|
||||
show_display=False,
|
||||
multisig=None,
|
||||
script_type=proto.InputScriptType.SPENDADDRESS,
|
||||
script_type=messages.InputScriptType.SPENDADDRESS,
|
||||
):
|
||||
if multisig:
|
||||
return client.call(
|
||||
proto.GetAddress(
|
||||
messages.GetAddress(
|
||||
address_n=n,
|
||||
coin_name=coin_name,
|
||||
show_display=show_display,
|
||||
@ -41,24 +40,15 @@ def get_address(
|
||||
script_type=script_type,
|
||||
)
|
||||
)
|
||||
else:
|
||||
return client.call(
|
||||
proto.GetAddress(
|
||||
address_n=n,
|
||||
coin_name=coin_name,
|
||||
show_display=show_display,
|
||||
script_type=script_type,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@expect(proto.MessageSignature)
|
||||
@expect(messages.MessageSignature)
|
||||
def sign_message(
|
||||
client, coin_name, n, message, script_type=proto.InputScriptType.SPENDADDRESS
|
||||
client, coin_name, n, message, script_type=messages.InputScriptType.SPENDADDRESS
|
||||
):
|
||||
message = normalize_nfc(message)
|
||||
return client.call(
|
||||
proto.SignMessage(
|
||||
messages.SignMessage(
|
||||
coin_name=coin_name, address_n=n, message=message, script_type=script_type
|
||||
)
|
||||
)
|
||||
@ -68,7 +58,7 @@ def verify_message(client, coin_name, address, signature, message):
|
||||
message = normalize_nfc(message)
|
||||
try:
|
||||
resp = client.call(
|
||||
proto.VerifyMessage(
|
||||
messages.VerifyMessage(
|
||||
address=address,
|
||||
signature=signature,
|
||||
message=message,
|
||||
@ -77,147 +67,95 @@ def verify_message(client, coin_name, address, signature, message):
|
||||
)
|
||||
except CallException as e:
|
||||
resp = e
|
||||
return isinstance(resp, proto.Success)
|
||||
return isinstance(resp, messages.Success)
|
||||
|
||||
|
||||
@session
|
||||
def sign_tx(
|
||||
client,
|
||||
coin_name,
|
||||
inputs,
|
||||
outputs,
|
||||
version=None,
|
||||
lock_time=None,
|
||||
expiry=None,
|
||||
overwintered=None,
|
||||
version_group_id=None,
|
||||
debug_processor=None,
|
||||
timestamp=None,
|
||||
):
|
||||
# start = time.time()
|
||||
txes = client._prepare_sign_tx(inputs, outputs)
|
||||
def sign_tx(client, coin_name, inputs, outputs, details=None, prev_txes=None):
|
||||
my_tx = messages.TransactionType(inputs=inputs, outputs=outputs)
|
||||
|
||||
# Prepare and send initial message
|
||||
tx = proto.SignTx()
|
||||
tx.inputs_count = len(inputs)
|
||||
tx.outputs_count = len(outputs)
|
||||
tx.coin_name = coin_name
|
||||
if version is not None:
|
||||
tx.version = version
|
||||
if lock_time is not None:
|
||||
tx.lock_time = lock_time
|
||||
if expiry is not None:
|
||||
tx.expiry = expiry
|
||||
if overwintered is not None:
|
||||
tx.overwintered = overwintered
|
||||
if version_group_id is not None:
|
||||
tx.version_group_id = version_group_id
|
||||
if timestamp is not None:
|
||||
tx.timestamp = timestamp
|
||||
res = client.call(tx)
|
||||
if details is None:
|
||||
signtx = messages.SignTx()
|
||||
else:
|
||||
signtx = details
|
||||
|
||||
signtx.coin_name = coin_name
|
||||
signtx.inputs_count = len(inputs)
|
||||
signtx.outputs_count = len(outputs)
|
||||
|
||||
res = client.call(signtx)
|
||||
|
||||
# Prepare structure for signatures
|
||||
signatures = [None] * len(inputs)
|
||||
serialized_tx = b""
|
||||
|
||||
counter = 0
|
||||
while True:
|
||||
counter += 1
|
||||
|
||||
if isinstance(res, proto.Failure):
|
||||
raise CallException("Signing failed")
|
||||
|
||||
if not isinstance(res, proto.TxRequest):
|
||||
raise CallException("Unexpected message")
|
||||
def copy_tx_meta(tx):
|
||||
tx_copy = messages.TransactionType()
|
||||
tx_copy.CopyFrom(tx)
|
||||
# clear fields
|
||||
tx_copy.inputs_cnt = len(tx.inputs)
|
||||
tx_copy.inputs = []
|
||||
tx_copy.outputs_cnt = len(tx.bin_outputs or tx.outputs)
|
||||
tx_copy.outputs = []
|
||||
tx_copy.bin_outputs = []
|
||||
tx_copy.extra_data_len = len(tx.extra_data or b"")
|
||||
tx_copy.extra_data = None
|
||||
return tx_copy
|
||||
|
||||
R = messages.RequestType
|
||||
while isinstance(res, messages.TxRequest):
|
||||
# If there's some part of signed transaction, let's add it
|
||||
if res.serialized and res.serialized.serialized_tx:
|
||||
# log("RECEIVED PART OF SERIALIZED TX (%d BYTES)" % len(res.serialized.serialized_tx))
|
||||
if res.serialized:
|
||||
if res.serialized.serialized_tx:
|
||||
serialized_tx += res.serialized.serialized_tx
|
||||
|
||||
if res.serialized and res.serialized.signature_index is not None:
|
||||
if signatures[res.serialized.signature_index] is not None:
|
||||
raise ValueError(
|
||||
"Signature for index %d already filled"
|
||||
% res.serialized.signature_index
|
||||
)
|
||||
signatures[res.serialized.signature_index] = res.serialized.signature
|
||||
if res.serialized.signature_index is not None:
|
||||
idx = res.serialized.signature_index
|
||||
sig = res.serialized.signature
|
||||
if signatures[idx] is not None:
|
||||
raise ValueError("Signature for index %d already filled" % idx)
|
||||
signatures[idx] = sig
|
||||
|
||||
if res.request_type == proto.RequestType.TXFINISHED:
|
||||
# Device didn't ask for more information, finish workflow
|
||||
if res.request_type == R.TXFINISHED:
|
||||
break
|
||||
|
||||
# Device asked for one more information, let's process it.
|
||||
if not res.details.tx_hash:
|
||||
current_tx = txes[None]
|
||||
current_tx = my_tx
|
||||
else:
|
||||
current_tx = txes[bytes(res.details.tx_hash)]
|
||||
current_tx = prev_txes[res.details.tx_hash]
|
||||
|
||||
if res.request_type == proto.RequestType.TXMETA:
|
||||
msg = proto.TransactionType()
|
||||
msg.version = current_tx.version
|
||||
msg.lock_time = current_tx.lock_time
|
||||
msg.inputs_cnt = len(current_tx.inputs)
|
||||
msg.timestamp = current_tx.timestamp
|
||||
if res.details.tx_hash:
|
||||
msg.outputs_cnt = len(current_tx.bin_outputs)
|
||||
else:
|
||||
msg.outputs_cnt = len(current_tx.outputs)
|
||||
msg.extra_data_len = (
|
||||
len(current_tx.extra_data) if current_tx.extra_data else 0
|
||||
)
|
||||
res = client.call(proto.TxAck(tx=msg))
|
||||
continue
|
||||
if res.request_type == R.TXMETA:
|
||||
msg = copy_tx_meta(current_tx)
|
||||
res = client.call(messages.TxAck(tx=msg))
|
||||
|
||||
elif res.request_type == proto.RequestType.TXINPUT:
|
||||
msg = proto.TransactionType()
|
||||
elif res.request_type == R.TXINPUT:
|
||||
msg = messages.TransactionType()
|
||||
msg.inputs = [current_tx.inputs[res.details.request_index]]
|
||||
if debug_processor is not None:
|
||||
# msg needs to be deep copied so when it's modified
|
||||
# the other messages stay intact
|
||||
from copy import deepcopy
|
||||
res = client.call(messages.TxAck(tx=msg))
|
||||
|
||||
msg = deepcopy(msg)
|
||||
# If debug_processor function is provided,
|
||||
# pass thru it the request and prepared response.
|
||||
# This is useful for tests, see test_msg_signtx
|
||||
msg = debug_processor(res, msg)
|
||||
|
||||
res = client.call(proto.TxAck(tx=msg))
|
||||
continue
|
||||
|
||||
elif res.request_type == proto.RequestType.TXOUTPUT:
|
||||
msg = proto.TransactionType()
|
||||
elif res.request_type == R.TXOUTPUT:
|
||||
msg = messages.TransactionType()
|
||||
if res.details.tx_hash:
|
||||
msg.bin_outputs = [current_tx.bin_outputs[res.details.request_index]]
|
||||
else:
|
||||
msg.outputs = [current_tx.outputs[res.details.request_index]]
|
||||
|
||||
if debug_processor is not None:
|
||||
# msg needs to be deep copied so when it's modified
|
||||
# the other messages stay intact
|
||||
from copy import deepcopy
|
||||
res = client.call(messages.TxAck(tx=msg))
|
||||
|
||||
msg = deepcopy(msg)
|
||||
# If debug_processor function is provided,
|
||||
# pass thru it the request and prepared response.
|
||||
# This is useful for tests, see test_msg_signtx
|
||||
msg = debug_processor(res, msg)
|
||||
|
||||
res = client.call(proto.TxAck(tx=msg))
|
||||
continue
|
||||
|
||||
elif res.request_type == proto.RequestType.TXEXTRADATA:
|
||||
elif res.request_type == R.TXEXTRADATA:
|
||||
o, l = res.details.extra_data_offset, res.details.extra_data_len
|
||||
msg = proto.TransactionType()
|
||||
msg = messages.TransactionType()
|
||||
msg.extra_data = current_tx.extra_data[o : o + l]
|
||||
res = client.call(proto.TxAck(tx=msg))
|
||||
continue
|
||||
res = client.call(messages.TxAck(tx=msg))
|
||||
|
||||
if isinstance(res, messages.Failure):
|
||||
raise CallException("Signing failed")
|
||||
|
||||
if not isinstance(res, messages.TxRequest):
|
||||
raise CallException("Unexpected message")
|
||||
|
||||
if None in signatures:
|
||||
raise RuntimeError("Some signatures are missing!")
|
||||
|
||||
# log("SIGNED IN %.03f SECONDS, CALLED %d MESSAGES, %d BYTES" %
|
||||
# (time.time() - start, counter, len(serialized_tx)))
|
||||
|
||||
return (signatures, serialized_tx)
|
||||
return signatures, serialized_tx
|
||||
|
Loading…
Reference in New Issue
Block a user