From 31e3aaa23b17fd6260255f1c517e9b60cdd8b6f1 Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Thu, 10 Nov 2016 15:31:07 +0100 Subject: [PATCH] signtx: add UI instructions, SigningError type --- src/apps/common/signtx.py | 48 ++++++++++++++++++++++++-------- tests/test_apps_common.signtx.py | 16 +++++++++-- 2 files changed, 50 insertions(+), 14 deletions(-) diff --git a/src/apps/common/signtx.py b/src/apps/common/signtx.py index 4b4bf1622d..1821dcdbd3 100644 --- a/src/apps/common/signtx.py +++ b/src/apps/common/signtx.py @@ -20,6 +20,31 @@ from trezor.messages import OutputScriptType, InputScriptType # === +class SigningError(ValueError): + pass + + +class UiConfirmOutput: + + def __init__(self, output: TxOutputType): + self.output = output + + +class UiConfirmTotal: + + def __init__(self, total_out: int, fee: int): + self.total_out = total_out + self.fee = fee + + +def confirm_output(output: TxOutputType): + yield UiConfirmOutput(output) + + +def confirm_total(total_out: int, fee: int): + yield UiConfirmTotal(total_out, fee) + + def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None): tx_req.request_type = TXMETA tx_req.details.tx_hash = tx_hash @@ -100,17 +125,16 @@ async def sign_tx(tx: SignTx, root): txo = await request_tx_output(tx_req, o) if output_is_change(txo): if change_out != 0: - raise ValueError('Only one change output is valid') + raise SigningError('Only one change output is valid') change_out = txo.amount txo_bin.amount = txo.amount txo_bin.script_pubkey = output_derive_script(txo, coin, root) write_tx_output(h_first, txo_bin) total_out += txo_bin.amount - # TODO: display output - # TODO: confirm output + await confirm_output(txo) - # TODO: check funds and tx fee - # TODO: ask for confirmation + fee = total_in - total_out + await confirm_total(total_out, fee) # Phase 2 # - sign inputs @@ -163,7 +187,7 @@ async def sign_tx(tx: SignTx, root): # check the control digests if tx_hash_digest(h_first, False) != tx_hash_digest(h_second, False): - raise ValueError('Transaction has changed during signing') + raise SigningError('Transaction has changed during signing') # compute the signature from the tx digest signature = ecdsa_sign(key_sign, tx_hash_digest(h_sign, True)) @@ -239,7 +263,7 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde prev_hash_rev = bytes(reversed(prev_hash)) # TODO: improve performance if tx_hash_digest(txh, True) != prev_hash_rev: - raise ValueError('Encountered invalid prev_hash') + raise SigningError('Encountered invalid prev_hash') return total_out @@ -260,7 +284,7 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes: ra = output_paytoaddress_extract_raw_address(o, coin, root) return script_paytoaddress_new(ra[1:]) else: - raise ValueError('Invalid output script type') + raise SigningError('Invalid output script type') return @@ -275,9 +299,9 @@ def output_paytoaddress_extract_raw_address(o: TxOutputType, coin: CoinType, roo elif o_address: raw_address = base58.decode_check(o_address) else: - raise ValueError('Missing address') + raise SigningError('Missing address') if raw_address[0] != coin.address_type: - raise ValueError('Invalid address type') + raise SigningError('Invalid address type') return raw_address @@ -295,7 +319,7 @@ def input_derive_script_pre_sign(i: TxInputType, pubkey: bytes) -> bytes: if i_script_type == InputScriptType.SPENDADDRESS: return script_paytoaddress_new(ecdsa_hash_pubkey(pubkey)) else: - raise ValueError('Unknown input script type') + raise SigningError('Unknown input script type') def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: bytes) -> bytes: @@ -303,7 +327,7 @@ def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: byte if i_script_type == InputScriptType.SPENDADDRESS: return script_spendaddress_new(pubkey, signature) else: - raise ValueError('Unknown input script type') + raise SigningError('Unknown input script type') def node_derive(root, address_n: list): diff --git a/tests/test_apps_common.signtx.py b/tests/test_apps_common.signtx.py index 64f4bb3a86..1669131eb9 100644 --- a/tests/test_apps_common.signtx.py +++ b/tests/test_apps_common.signtx.py @@ -12,7 +12,7 @@ from trezor.messages.TransactionType import TransactionType from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED from trezor.messages.TxRequestDetailsType import TxRequestDetailsType from trezor.messages.TxRequestSerializedType import TxRequestSerializedType -from trezor.messages import OutputScriptType, InputScriptType +from trezor.messages import OutputScriptType from apps.common import signtx @@ -57,6 +57,10 @@ class TestSignTx(unittest.TestCase): TxAck(tx=TransactionType(bin_outputs=[pout1])), TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), TxAck(tx=TransactionType(outputs=[out1])), + signtx.UiConfirmOutput(out1), + True, + signtx.UiConfirmTotal(380000, 10000), + True, # ButtonRequest(code=ButtonRequest_ConfirmOutput), # ButtonRequest(code=ButtonRequest_SignTx), TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), @@ -80,10 +84,18 @@ class TestSignTx(unittest.TestCase): signer = signtx.sign_tx(tx, root) for request, response in chunks(messages, 2): - self.assertEqual(signer.send(request), response) + self.assertEqualEx(signer.send(request), response) with self.assertRaises(StopIteration): signer.send(None) + def assertEqualEx(self, a, b): + # hack to avoid adding __eq__ to signtx.Ui* classes + if ((isinstance(a, signtx.UiConfirmOutput) and isinstance(b, signtx.UiConfirmOutput)) or + (isinstance(a, signtx.UiConfirmTotal) and isinstance(b, signtx.UiConfirmTotal))): + return self.assertEqual(a.__dict__, b.__dict__) + else: + return self.assertEqual(a, b) + if __name__ == '__main__': unittest.main()