diff --git a/src/apps/common/signtx.py b/src/apps/common/signtx.py index d0521de92b..1c04cda136 100644 --- a/src/apps/common/signtx.py +++ b/src/apps/common/signtx.py @@ -13,7 +13,7 @@ from trezor.messages.TxRequest import TxRequest from trezor.messages.RequestType import TXINPUT, TXOUTPUT, TXMETA, TXFINISHED from trezor.messages.TxRequestSerializedType import TxRequestSerializedType from trezor.messages.TxRequestDetailsType import TxRequestDetailsType -from trezor.messages import OutputScriptType, InputScriptType +from trezor.messages import OutputScriptType, InputScriptType, FailureType # Machine instructions @@ -26,23 +26,25 @@ class SigningError(ValueError): class UiConfirmOutput: - def __init__(self, output: TxOutputType): + def __init__(self, output: TxOutputType, coin: CoinType): self.output = output + self.coin = coin class UiConfirmTotal: - def __init__(self, total_out: int, fee: int): + def __init__(self, total_out: int, fee: int, coin: CoinType): self.total_out = total_out self.fee = fee + self.coin = coin -def confirm_output(output: TxOutputType): - yield UiConfirmOutput(output) +def confirm_output(output: TxOutputType, coin: CoinType): + return (yield UiConfirmOutput(output, coin)) -def confirm_total(total_out: int, fee: int): - yield UiConfirmTotal(total_out, fee) +def confirm_total(total_out: int, fee: int, coin: CoinType): + return (yield UiConfirmTotal(total_out, fee, coin)) def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None): @@ -125,16 +127,23 @@ async def sign_tx(tx: SignTx, root): txo = await request_tx_output(tx_req, o) if output_is_change(txo): if change_out != 0: - raise SigningError('Only one change output is valid') + raise SigningError(FailureType.Other, + 'Only one change output is valid') change_out = txo.amount txo_bin.amount = txo.amount txo_bin.script_pubkey = output_derive_script(txo, coin, root) write_tx_output(h_first, txo_bin) total_out += txo_bin.amount - await confirm_output(txo) + + if not output_is_change(txo) and not await confirm_output(txo, coin): + raise SigningError(FailureType.ActionCancelled, + 'Output cancelled') fee = total_in - total_out - await confirm_total(total_out, fee) + + if not await confirm_total(total_out, fee, coin): + raise SigningError(FailureType.ActionCancelled, + 'Total cancelled') # Phase 2 # - sign inputs @@ -187,7 +196,8 @@ async def sign_tx(tx: SignTx, root): # check the control digests if get_tx_hash(h_first, False) != get_tx_hash(h_second, False): - raise SigningError('Transaction has changed during signing') + raise SigningError(FailureType.Other, + 'Transaction has changed during signing') # compute the signature from the tx digest signature = ecdsa_sign(key_sign, get_tx_hash(h_sign, True)) @@ -262,7 +272,8 @@ async def get_prevtx_output_value(tx_req: TxRequest, prev_hash: bytes, prev_inde write_uint32(txh, tx_lock_time) if get_tx_hash(txh, True, True) != prev_hash: - raise SigningError('Encountered invalid prev_hash') + raise SigningError(FailureType.Other, + 'Encountered invalid prev_hash') return total_out @@ -285,29 +296,32 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes: ra = output_paytoaddress_extract_raw_address(o, coin, root) return script_paytoaddress_new(ra[1:]) else: - raise SigningError('Invalid output script type') + raise SigningError(FailureType.SyntaxError, + 'Invalid output script type') return def output_paytoaddress_extract_raw_address(o: TxOutputType, coin: CoinType, root) -> bytes: - o_address_n = getattr(o, 'address_n', None) - o_address = getattr(o, 'address', None) # TODO: dont encode/decode more then necessary # TODO: detect correct address type - if o_address_n is not None: - n = node_derive(root, o_address_n) - raw_address = base58.decode_check(n.address(coin.address_type)) - elif o_address: - raw_address = base58.decode_check(o_address) - if raw_address[0] != coin.address_type: - raise SigningError('Invalid address type') - else: - raise SigningError('Missing address') - return raw_address + address_n = getattr(o, 'address_n', None) + if address_n is not None: + node = node_derive(root, address_n) + address = node.address(coin.address_type) + return base58.decode_check(address) + address = getattr(o, 'address', None) + if address: + raw = base58.decode_check(address) + if raw[0] != coin.address_type: + raise SigningError(FailureType.SyntaxError, + 'Invalid address type') + return raw + raise SigningError(FailureType.SyntaxError, + 'Missing address') -def output_is_change(output: TxOutputType): - address_n = getattr(output, 'address_n', None) +def output_is_change(o: TxOutputType): + address_n = getattr(o, 'address_n', None) return bool(address_n) @@ -320,7 +334,8 @@ def input_derive_script_pre_sign(i: TxInputType, pubkey: bytes) -> bytes: if i_script_type == InputScriptType.SPENDADDRESS: return script_paytoaddress_new(ecdsa_hash_pubkey(pubkey)) else: - raise SigningError('Unknown input script type') + raise SigningError(FailureType.SyntaxError, + 'Unknown input script type') def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: bytes) -> bytes: @@ -328,7 +343,8 @@ def input_derive_script_post_sign(i: TxInputType, pubkey: bytes, signature: byte if i_script_type == InputScriptType.SPENDADDRESS: return script_spendaddress_new(pubkey, signature) else: - raise SigningError('Unknown input script type') + raise SigningError(FailureType.SyntaxError, + 'Unknown input script type') def node_derive(root, address_n: list): diff --git a/src/apps/wallet/layout_sign_tx.py b/src/apps/wallet/layout_sign_tx.py index 76215fdcd8..a51f76cfcf 100644 --- a/src/apps/wallet/layout_sign_tx.py +++ b/src/apps/wallet/layout_sign_tx.py @@ -2,12 +2,34 @@ from trezor.utils import unimport from trezor import wire -async def confirm_output(output): - return True +def format_amount(amount, coin): + return '%s %s' % (amount / 1e8, coin.coin_shortcut) -async def confirm_total(total_out, fee): - return True +async def confirm_output(session_id, output, coin): + from trezor import ui + from trezor.ui.text import Text + from trezor.messages.ButtonRequestType import ConfirmOutput + from ..common.confirm import confirm + + content = Text('Confirm output', ui.ICON_RESET, + ui.BOLD, format_amount(output.amount, coin), + ui.NORMAL, 'to', + ui.MONO, output.address[0:17], + ui.MONO, output.address[17:]) + return await confirm(session_id, content, ConfirmOutput) + + +async def confirm_total(session_id, total_out, fee, coin): + from trezor import ui + from trezor.ui.text import Text + from trezor.messages.ButtonRequestType import SignTx + from ..common.confirm import confirm + + content = Text('Confirm transaction', ui.ICON_RESET, + 'Sending: %s' % format_amount(total_out, coin), + 'Fee: %s' % format_amount(fee, coin)) + return await confirm(session_id, content, SignTx) @unimport @@ -33,9 +55,10 @@ async def layout_sign_tx(message, session_id): break res = await wire.reply_message(session_id, req, TxAck) elif isinstance(req, signtx.UiConfirmOutput): - res = await confirm_output(req.output) + res = await confirm_output(session_id, req.output, req.coin) elif isinstance(req, signtx.UiConfirmTotal): - res = await confirm_total(req.total_out, req.fee) + res = await confirm_total(session_id, req.total_out, req.fee, req.coin) else: + print(req) raise ValueError('Invalid signing instruction') return req