1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-11 16:00:57 +00:00

signtx: add UI instructions, SigningError type

This commit is contained in:
Jan Pochyla 2016-11-10 15:31:07 +01:00
parent 8109d8363c
commit 31e3aaa23b
2 changed files with 50 additions and 14 deletions

View File

@ -20,6 +20,31 @@ from trezor.messages import OutputScriptType, InputScriptType
# ===
class SigningError(ValueError):
pass
class UiConfirmOutput:
def __init__(self, output: TxOutputType):
self.output = output
class UiConfirmTotal:
def __init__(self, total_out: int, fee: int):
self.total_out = total_out
self.fee = fee
def confirm_output(output: TxOutputType):
yield UiConfirmOutput(output)
def confirm_total(total_out: int, fee: int):
yield UiConfirmTotal(total_out, fee)
def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
tx_req.request_type = TXMETA
tx_req.details.tx_hash = tx_hash
@ -100,17 +125,16 @@ async def sign_tx(tx: SignTx, root):
txo = await request_tx_output(tx_req, o)
if output_is_change(txo):
if change_out != 0:
raise ValueError('Only one change output is valid')
raise SigningError('Only one change output is valid')
change_out = txo.amount
txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
write_tx_output(h_first, txo_bin)
total_out += txo_bin.amount
# TODO: display output
# TODO: confirm output
await confirm_output(txo)
# TODO: check funds and tx fee
# TODO: ask for confirmation
fee = total_in - total_out
await confirm_total(total_out, fee)
# Phase 2
# - sign inputs
@ -163,7 +187,7 @@ async def sign_tx(tx: SignTx, root):
# check the control digests
if tx_hash_digest(h_first, False) != tx_hash_digest(h_second, False):
raise ValueError('Transaction has changed during signing')
raise SigningError('Transaction has changed during signing')
# compute the signature from the tx digest
signature = ecdsa_sign(key_sign, tx_hash_digest(h_sign, True))
@ -239,7 +263,7 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde
prev_hash_rev = bytes(reversed(prev_hash)) # TODO: improve performance
if tx_hash_digest(txh, True) != prev_hash_rev:
raise ValueError('Encountered invalid prev_hash')
raise SigningError('Encountered invalid prev_hash')
return total_out
@ -260,7 +284,7 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
ra = output_paytoaddress_extract_raw_address(o, coin, root)
return script_paytoaddress_new(ra[1:])
else:
raise ValueError('Invalid output script type')
raise SigningError('Invalid output script type')
return
@ -275,9 +299,9 @@ def output_paytoaddress_extract_raw_address(o: TxOutputType, coin: CoinType, roo
elif o_address:
raw_address = base58.decode_check(o_address)
else:
raise ValueError('Missing address')
raise SigningError('Missing address')
if raw_address[0] != coin.address_type:
raise ValueError('Invalid address type')
raise SigningError('Invalid address type')
return raw_address
@ -295,7 +319,7 @@ def input_derive_script_pre_sign(i: TxInputType, pubkey: bytes) -> bytes:
if i_script_type == InputScriptType.SPENDADDRESS:
return script_paytoaddress_new(ecdsa_hash_pubkey(pubkey))
else:
raise ValueError('Unknown input script type')
raise SigningError('Unknown input script type')
def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: bytes) -> bytes:
@ -303,7 +327,7 @@ def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: byte
if i_script_type == InputScriptType.SPENDADDRESS:
return script_spendaddress_new(pubkey, signature)
else:
raise ValueError('Unknown input script type')
raise SigningError('Unknown input script type')
def node_derive(root, address_n: list):

View File

@ -12,7 +12,7 @@ from trezor.messages.TransactionType import TransactionType
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages import OutputScriptType, InputScriptType
from trezor.messages import OutputScriptType
from apps.common import signtx
@ -57,6 +57,10 @@ class TestSignTx(unittest.TestCase):
TxAck(tx=TransactionType(bin_outputs=[pout1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])),
signtx.UiConfirmOutput(out1),
True,
signtx.UiConfirmTotal(380000, 10000),
True,
# ButtonRequest(code=ButtonRequest_ConfirmOutput),
# ButtonRequest(code=ButtonRequest_SignTx),
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
@ -80,10 +84,18 @@ class TestSignTx(unittest.TestCase):
signer = signtx.sign_tx(tx, root)
for request, response in chunks(messages, 2):
self.assertEqual(signer.send(request), response)
self.assertEqualEx(signer.send(request), response)
with self.assertRaises(StopIteration):
signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to signtx.Ui* classes
if ((isinstance(a, signtx.UiConfirmOutput) and isinstance(b, signtx.UiConfirmOutput)) or
(isinstance(a, signtx.UiConfirmTotal) and isinstance(b, signtx.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__':
unittest.main()