diff --git a/src/apps/wallet/sign_tx/__init__.py b/src/apps/wallet/sign_tx/__init__.py index 65ea76446..daa6d1367 100644 --- a/src/apps/wallet/sign_tx/__init__.py +++ b/src/apps/wallet/sign_tx/__init__.py @@ -3,7 +3,7 @@ from trezor.messages.wire_types import TxAck from trezor.messages.TxRequest import TxRequest from trezor.messages.RequestType import TXFINISHED from apps.common import seed -from apps.wallet.sign_tx.helpers import UiConfirmOutput, UiConfirmTotal, UiConfirmFeeOverThreshold +from apps.wallet.sign_tx.helpers import UiConfirmOutput, UiConfirmTotal, UiConfirmFeeOverThreshold, UiConfirmForeignAddress @ui.layout @@ -41,6 +41,8 @@ async def sign_tx(ctx, msg): elif isinstance(req, UiConfirmFeeOverThreshold): res = await layout.confirm_feeoverthreshold(ctx, req.fee, req.coin) progress.report_init() + elif isinstance(req, UiConfirmForeignAddress): + res = await layout.confirm_foreign_address(ctx, req.address_n, req.coin) else: raise TypeError('Invalid signing instruction') return req diff --git a/src/apps/wallet/sign_tx/helpers.py b/src/apps/wallet/sign_tx/helpers.py index d9023ad60..81689331a 100644 --- a/src/apps/wallet/sign_tx/helpers.py +++ b/src/apps/wallet/sign_tx/helpers.py @@ -35,6 +35,13 @@ class UiConfirmFeeOverThreshold: self.coin = coin +class UiConfirmForeignAddress: + + def __init__(self, address_n: list, coin: CoinInfo): + self.address_n = address_n + self.coin = coin + + def confirm_output(output: TxOutputType, coin: CoinInfo): return (yield UiConfirmOutput(output, coin)) @@ -47,6 +54,10 @@ def confirm_feeoverthreshold(fee: int, coin: CoinInfo): return (yield UiConfirmFeeOverThreshold(fee, coin)) +def confirm_foreign_address(address_n: list, coin: CoinInfo): + return (yield UiConfirmForeignAddress(address_n, coin)) + + def request_tx_meta(tx_req: TxRequest, tx_hash: bytes=None): tx_req.request_type = TXMETA tx_req.details.tx_hash = tx_hash diff --git a/src/apps/wallet/sign_tx/layout.py b/src/apps/wallet/sign_tx/layout.py index b4e4e9bda..b7e12979e 100644 --- a/src/apps/wallet/sign_tx/layout.py +++ b/src/apps/wallet/sign_tx/layout.py @@ -4,6 +4,7 @@ from trezor.utils import chunks, format_amount from trezor.ui.text import Text from trezor.messages import ButtonRequestType from trezor.messages import OutputScriptType +from apps.common import coins from apps.common.confirm import confirm from apps.common.confirm import hold_to_confirm @@ -53,3 +54,12 @@ async def confirm_feeoverthreshold(ctx, fee, coin): 'Continue?', icon_color=ui.GREEN) return await confirm(ctx, content, ButtonRequestType.FeeOverThreshold) + + +async def confirm_foreign_address(ctx, address_n, coin): + content = Text('Confirm sending', ui.ICON_SEND, + 'Trying to spend', + 'coins from another chain.', + 'Continue?', icon_color=ui.RED) + + return await confirm(ctx, content, ButtonRequestType.SignTx) diff --git a/src/apps/wallet/sign_tx/signing.py b/src/apps/wallet/sign_tx/signing.py index bd4e2b180..f2cc0c689 100644 --- a/src/apps/wallet/sign_tx/signing.py +++ b/src/apps/wallet/sign_tx/signing.py @@ -86,6 +86,9 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode): hash143.add_prevouts(txi) # all inputs are included (non-segwit as well) hash143.add_sequence(txi) + if not address_n_matches_coin(txi.address_n, coin): + await confirm_foreign_address(txi.address_n, coin) + if txi.multisig: multifp.add(txi.multisig) @@ -607,12 +610,19 @@ def input_check_wallet_path(txi: TxInputType, wallet_path: list) -> list: 'Transaction has changed during signing') -def node_derive(root: bip32.HDNode, address_n: list): +def node_derive(root: bip32.HDNode, address_n: list) -> bip32.HDNode: node = root.clone() node.derive_path(address_n) return node +def address_n_matches_coin(address_n: list, coin: CoinInfo) -> bool: + bip44 = const(44 | 0x80000000) + if len(address_n) < 2 or address_n[0] != bip44 or address_n[1] == coin.slip44 | 0x80000000: + return True # path is not BIP44 or matches the coin + return False # path is BIP44 and does not match the coin + + def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes: sig = secp256k1.sign(node.private_key(), digest) sigder = der.encode_seq((sig[1:33], sig[33:65]))