1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-25 15:58:08 +00:00

signtx: add UI instructions, SigningError type

This commit is contained in:
Jan Pochyla 2016-11-10 15:31:07 +01:00
parent 8109d8363c
commit 31e3aaa23b
2 changed files with 50 additions and 14 deletions

View File

@ -20,6 +20,31 @@ from trezor.messages import OutputScriptType, InputScriptType
# === # ===
class SigningError(ValueError):
pass
class UiConfirmOutput:
def __init__(self, output: TxOutputType):
self.output = output
class UiConfirmTotal:
def __init__(self, total_out: int, fee: int):
self.total_out = total_out
self.fee = fee
def confirm_output(output: TxOutputType):
yield UiConfirmOutput(output)
def confirm_total(total_out: int, fee: int):
yield UiConfirmTotal(total_out, fee)
def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None): def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None):
tx_req.request_type = TXMETA tx_req.request_type = TXMETA
tx_req.details.tx_hash = tx_hash tx_req.details.tx_hash = tx_hash
@ -100,17 +125,16 @@ async def sign_tx(tx: SignTx, root):
txo = await request_tx_output(tx_req, o) txo = await request_tx_output(tx_req, o)
if output_is_change(txo): if output_is_change(txo):
if change_out != 0: if change_out != 0:
raise ValueError('Only one change output is valid') raise SigningError('Only one change output is valid')
change_out = txo.amount change_out = txo.amount
txo_bin.amount = txo.amount txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root) txo_bin.script_pubkey = output_derive_script(txo, coin, root)
write_tx_output(h_first, txo_bin) write_tx_output(h_first, txo_bin)
total_out += txo_bin.amount total_out += txo_bin.amount
# TODO: display output await confirm_output(txo)
# TODO: confirm output
# TODO: check funds and tx fee fee = total_in - total_out
# TODO: ask for confirmation await confirm_total(total_out, fee)
# Phase 2 # Phase 2
# - sign inputs # - sign inputs
@ -163,7 +187,7 @@ async def sign_tx(tx: SignTx, root):
# check the control digests # check the control digests
if tx_hash_digest(h_first, False) != tx_hash_digest(h_second, False): if tx_hash_digest(h_first, False) != tx_hash_digest(h_second, False):
raise ValueError('Transaction has changed during signing') raise SigningError('Transaction has changed during signing')
# compute the signature from the tx digest # compute the signature from the tx digest
signature = ecdsa_sign(key_sign, tx_hash_digest(h_sign, True)) signature = ecdsa_sign(key_sign, tx_hash_digest(h_sign, True))
@ -239,7 +263,7 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde
prev_hash_rev = bytes(reversed(prev_hash)) # TODO: improve performance prev_hash_rev = bytes(reversed(prev_hash)) # TODO: improve performance
if tx_hash_digest(txh, True) != prev_hash_rev: if tx_hash_digest(txh, True) != prev_hash_rev:
raise ValueError('Encountered invalid prev_hash') raise SigningError('Encountered invalid prev_hash')
return total_out return total_out
@ -260,7 +284,7 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
ra = output_paytoaddress_extract_raw_address(o, coin, root) ra = output_paytoaddress_extract_raw_address(o, coin, root)
return script_paytoaddress_new(ra[1:]) return script_paytoaddress_new(ra[1:])
else: else:
raise ValueError('Invalid output script type') raise SigningError('Invalid output script type')
return return
@ -275,9 +299,9 @@ def output_paytoaddress_extract_raw_address(o: TxOutputType, coin: CoinType, roo
elif o_address: elif o_address:
raw_address = base58.decode_check(o_address) raw_address = base58.decode_check(o_address)
else: else:
raise ValueError('Missing address') raise SigningError('Missing address')
if raw_address[0] != coin.address_type: if raw_address[0] != coin.address_type:
raise ValueError('Invalid address type') raise SigningError('Invalid address type')
return raw_address return raw_address
@ -295,7 +319,7 @@ def input_derive_script_pre_sign(i: TxInputType, pubkey: bytes) -> bytes:
if i_script_type == InputScriptType.SPENDADDRESS: if i_script_type == InputScriptType.SPENDADDRESS:
return script_paytoaddress_new(ecdsa_hash_pubkey(pubkey)) return script_paytoaddress_new(ecdsa_hash_pubkey(pubkey))
else: else:
raise ValueError('Unknown input script type') raise SigningError('Unknown input script type')
def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: bytes) -> bytes: def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: bytes) -> bytes:
@ -303,7 +327,7 @@ def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: byte
if i_script_type == InputScriptType.SPENDADDRESS: if i_script_type == InputScriptType.SPENDADDRESS:
return script_spendaddress_new(pubkey, signature) return script_spendaddress_new(pubkey, signature)
else: else:
raise ValueError('Unknown input script type') raise SigningError('Unknown input script type')
def node_derive(root, address_n: list): def node_derive(root, address_n: list):

View File

@ -12,7 +12,7 @@ from trezor.messages.TransactionType import TransactionType
from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED
from trezor.messages.TxRequestDetailsType import TxRequestDetailsType from trezor.messages.TxRequestDetailsType import TxRequestDetailsType
from trezor.messages.TxRequestSerializedType import TxRequestSerializedType from trezor.messages.TxRequestSerializedType import TxRequestSerializedType
from trezor.messages import OutputScriptType, InputScriptType from trezor.messages import OutputScriptType
from apps.common import signtx from apps.common import signtx
@ -57,6 +57,10 @@ class TestSignTx(unittest.TestCase):
TxAck(tx=TransactionType(bin_outputs=[pout1])), TxAck(tx=TransactionType(bin_outputs=[pout1])),
TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), TxRequest(request_type=TXOUTPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
TxAck(tx=TransactionType(outputs=[out1])), TxAck(tx=TransactionType(outputs=[out1])),
signtx.UiConfirmOutput(out1),
True,
signtx.UiConfirmTotal(380000, 10000),
True,
# ButtonRequest(code=ButtonRequest_ConfirmOutput), # ButtonRequest(code=ButtonRequest_ConfirmOutput),
# ButtonRequest(code=ButtonRequest_SignTx), # ButtonRequest(code=ButtonRequest_SignTx),
TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None), TxRequest(request_type=TXINPUT, details=TxRequestDetailsType(request_index=0, tx_hash=None), serialized=None),
@ -80,10 +84,18 @@ class TestSignTx(unittest.TestCase):
signer = signtx.sign_tx(tx, root) signer = signtx.sign_tx(tx, root)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
self.assertEqual(signer.send(request), response) self.assertEqualEx(signer.send(request), response)
with self.assertRaises(StopIteration): with self.assertRaises(StopIteration):
signer.send(None) signer.send(None)
def assertEqualEx(self, a, b):
# hack to avoid adding __eq__ to signtx.Ui* classes
if ((isinstance(a, signtx.UiConfirmOutput) and isinstance(b, signtx.UiConfirmOutput)) or
(isinstance(a, signtx.UiConfirmTotal) and isinstance(b, signtx.UiConfirmTotal))):
return self.assertEqual(a.__dict__, b.__dict__)
else:
return self.assertEqual(a, b)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()