mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-11 16:00:57 +00:00
signtx: add UI instructions, SigningError type
This commit is contained in:
parent
8109d8363c
commit
31e3aaa23b
@ -20,6 +20,31 @@ from trezor.messages import OutputScriptType, InputScriptType
|
||||
# ===
|
||||
|
||||
|
||||
class SigningError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class UiConfirmOutput:
|
||||
|
||||
def __init__(self, output: TxOutputType):
|
||||
self.output = output
|
||||
|
||||
|
||||
class UiConfirmTotal:
|
||||
|
||||
def __init__(self, total_out: int, fee: int):
|
||||
self.total_out = total_out
|
||||
self.fee = fee
|
||||
|
||||
|
||||
def confirm_output(output: TxOutputType):
|
||||
yield UiConfirmOutput(output)
|
||||
|
||||
|
||||
def confirm_total(total_out: int, fee: int):
|
||||
yield UiConfirmTotal(total_out, fee)
|
||||
|
||||
|
||||
def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
|
||||
tx_req.request_type = TXMETA
|
||||
tx_req.details.tx_hash = tx_hash
|
||||
@ -100,17 +125,16 @@ async def sign_tx(tx: SignTx, root):
|
||||
txo = await request_tx_output(tx_req, o)
|
||||
if output_is_change(txo):
|
||||
if change_out != 0:
|
||||
raise ValueError('Only one change output is valid')
|
||||
raise SigningError('Only one change output is valid')
|
||||
change_out = txo.amount
|
||||
txo_bin.amount = txo.amount
|
||||
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
|
||||
write_tx_output(h_first, txo_bin)
|
||||
total_out += txo_bin.amount
|
||||
# TODO: display output
|
||||
# TODO: confirm output
|
||||
await confirm_output(txo)
|
||||
|
||||
# TODO: check funds and tx fee
|
||||
# TODO: ask for confirmation
|
||||
fee = total_in - total_out
|
||||
await confirm_total(total_out, fee)
|
||||
|
||||
# Phase 2
|
||||
# - sign inputs
|
||||
@ -163,7 +187,7 @@ async def sign_tx(tx: SignTx, root):
|
||||
|
||||
# check the control digests
|
||||
if tx_hash_digest(h_first, False) != tx_hash_digest(h_second, False):
|
||||
raise ValueError('Transaction has changed during signing')
|
||||
raise SigningError('Transaction has changed during signing')
|
||||
|
||||
# compute the signature from the tx digest
|
||||
signature = ecdsa_sign(key_sign, tx_hash_digest(h_sign, True))
|
||||
@ -239,7 +263,7 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde
|
||||
|
||||
prev_hash_rev = bytes(reversed(prev_hash)) # TODO: improve performance
|
||||
if tx_hash_digest(txh, True) != prev_hash_rev:
|
||||
raise ValueError('Encountered invalid prev_hash')
|
||||
raise SigningError('Encountered invalid prev_hash')
|
||||
|
||||
return total_out
|
||||
|
||||
@ -260,7 +284,7 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
|
||||
ra = output_paytoaddress_extract_raw_address(o, coin, root)
|
||||
return script_paytoaddress_new(ra[1:])
|
||||
else:
|
||||
raise ValueError('Invalid output script type')
|
||||
raise SigningError('Invalid output script type')
|
||||
return
|
||||
|
||||
|
||||
@ -275,9 +299,9 @@ def output_paytoaddress_extract_raw_address(o: TxOutputType, coin: CoinType, roo
|
||||
elif o_address:
|
||||
raw_address = base58.decode_check(o_address)
|
||||
else:
|
||||
raise ValueError('Missing address')
|
||||
raise SigningError('Missing address')
|
||||
if raw_address[0] != coin.address_type:
|
||||
raise ValueError('Invalid address type')
|
||||
raise SigningError('Invalid address type')
|
||||
return raw_address
|
||||
|
||||
|
||||
@ -295,7 +319,7 @@ def input_derive_script_pre_sign(i: TxInputType, pubkey: bytes) -> bytes:
|
||||
if i_script_type == InputScriptType.SPENDADDRESS:
|
||||
return script_paytoaddress_new(ecdsa_hash_pubkey(pubkey))
|
||||
else:
|
||||
raise ValueError('Unknown input script type')
|
||||
raise SigningError('Unknown input script type')
|
||||
|
||||
|
||||
def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: bytes) -> bytes:
|
||||
@ -303,7 +327,7 @@ def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: byte
|
||||
if i_script_type == InputScriptType.SPENDADDRESS:
|
||||
return script_spendaddress_new(pubkey, signature)
|
||||
else:
|
||||
raise ValueError('Unknown input script type')
|
||||
raise SigningError('Unknown input script type')
|
||||
|
||||
|
||||
def node_derive(root, address_n: list):
|
||||
|
@ -12,7 +12,7 @@ from trezor.messages.TransactionType import TransactionType
|
||||
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
|
||||
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
|
||||
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
|
||||
from trezor.messages import OutputScriptType, InputScriptType
|
||||
from trezor.messages import OutputScriptType
|
||||
|
||||
from apps.common import signtx
|
||||
|
||||
@ -57,6 +57,10 @@ class TestSignTx(unittest.TestCase):
|
||||
TxAck(tx=TransactionType(bin_outputs=[pout1])),
|
||||
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
|
||||
TxAck(tx=TransactionType(outputs=[out1])),
|
||||
signtx.UiConfirmOutput(out1),
|
||||
True,
|
||||
signtx.UiConfirmTotal(380000, 10000),
|
||||
True,
|
||||
# ButtonRequest(code=ButtonRequest_ConfirmOutput),
|
||||
# ButtonRequest(code=ButtonRequest_SignTx),
|
||||
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
|
||||
@ -80,10 +84,18 @@ class TestSignTx(unittest.TestCase):
|
||||
|
||||
signer = signtx.sign_tx(tx, root)
|
||||
for request, response in chunks(messages, 2):
|
||||
self.assertEqual(signer.send(request), response)
|
||||
self.assertEqualEx(signer.send(request), response)
|
||||
with self.assertRaises(StopIteration):
|
||||
signer.send(None)
|
||||
|
||||
def assertEqualEx(self, a, b):
|
||||
# hack to avoid adding __eq__ to signtx.Ui* classes
|
||||
if ((isinstance(a, signtx.UiConfirmOutput) and isinstance(b, signtx.UiConfirmOutput)) or
|
||||
(isinstance(a, signtx.UiConfirmTotal) and isinstance(b, signtx.UiConfirmTotal))):
|
||||
return self.assertEqual(a.__dict__, b.__dict__)
|
||||
else:
|
||||
return self.assertEqual(a, b)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user