From 38b0361ded7fa13d90f88c0f8d26c030864d5707 Mon Sep 17 00:00:00 2001 From: Martin Milata Date: Tue, 21 May 2024 23:49:08 +0200 Subject: [PATCH] test(core/ui): fix bitcoin signtx for T3T1 --- python/src/trezorlib/debuglink.py | 34 +++- tests/common.py | 2 +- tests/device_tests/bitcoin/test_getaddress.py | 90 ++++++---- .../bitcoin/test_getaddress_segwit.py | 107 ++++++------ .../bitcoin/test_getaddress_show.py | 50 +++--- tests/device_tests/bitcoin/test_multisig.py | 4 + .../bitcoin/test_multisig_change.py | 7 + .../bitcoin/test_nonstandard_paths.py | 85 ++++++---- .../device_tests/bitcoin/test_signmessage.py | 10 +- .../bitcoin/test_signtx_invalid_path.py | 16 +- .../bitcoin/test_signtx_prevhash.py | 10 +- .../bitcoin/test_signtx_segwit_native.py | 13 ++ tests/input_flows.py | 159 ++++++++++++++++-- 13 files changed, 419 insertions(+), 168 deletions(-) diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 914ee5b664..f1fa328b8c 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -355,6 +355,12 @@ class LayoutContent(UnstructuredJSONReader): choice_obj.get(choice, {}).get("content", "") for choice in choice_keys ) + def footer(self) -> str: + footer = self.find_unique_object_with_key_and_value("component", "Footer") + if not footer: + return "" + return footer.get("description", "") + " " + footer.get("instruction", "") + def multipage_content(layouts: List[LayoutContent]) -> str: """Get overall content from multiple-page layout.""" @@ -804,6 +810,25 @@ class DebugUI: Generator[None, messages.ButtonRequest, None], object, None ] = None + def _default_input_flow(self, br: messages.ButtonRequest) -> None: + if br.code == messages.ButtonRequestType.PinEntry: + self.debuglink.input(self.get_pin()) + else: + # Paginating (going as further as possible) and pressing Yes + if br.pages is not None: + for _ in range(br.pages - 1): + self.debuglink.swipe_up(wait=True) + if self.debuglink.model is models.T3T1: + layout = self.debuglink.read_layout() + if "PromptScreen" in layout.all_components(): + self.debuglink.press_yes() + elif "swipe up" in layout.footer().lower(): + self.debuglink.swipe_up() + else: + self.debuglink.press_yes() + else: + self.debuglink.press_yes() + def button_request(self, br: messages.ButtonRequest) -> None: self.debuglink.take_t1_screenshot_if_relevant() @@ -814,14 +839,7 @@ class DebugUI: # recording their screens that way (as well as # possible swipes below). self.debuglink.save_current_screen_if_relevant(wait=True) - if br.code == messages.ButtonRequestType.PinEntry: - self.debuglink.input(self.get_pin()) - else: - # Paginating (going as further as possible) and pressing Yes - if br.pages is not None: - for _ in range(br.pages - 1): - self.debuglink.swipe_up(wait=True) - self.debuglink.press_yes() + self._default_input_flow(br) elif self.input_flow is self.INPUT_FLOW_DONE: raise AssertionError("input flow ended prematurely") else: diff --git a/tests/common.py b/tests/common.py index 35d6905c8a..889283bbab 100644 --- a/tests/common.py +++ b/tests/common.py @@ -344,4 +344,4 @@ def swipe_till_the_end(debug: "DebugLink", br: messages.ButtonRequest) -> None: def is_core(client: "Client") -> bool: - return client.model in (models.T2T1, models.T2B1, models.T3T1) + return client.model is not models.T1B1 diff --git a/tests/device_tests/bitcoin/test_getaddress.py b/tests/device_tests/bitcoin/test_getaddress.py index 4191e7e8fb..5a53e15d29 100644 --- a/tests/device_tests/bitcoin/test_getaddress.py +++ b/tests/device_tests/bitcoin/test_getaddress.py @@ -23,6 +23,8 @@ from trezorlib.messages import SafetyCheckLevel from trezorlib.tools import parse_path from ... import bip32 +from ...common import is_core +from ...input_flows import InputFlowConfirmAllWarnings def getmultisig(chain, nr, xpubs): @@ -202,26 +204,30 @@ def test_multisig(client: Client): xpubs.append(node.xpub) for nr in range(1, 4): - assert ( - btc.get_address( - client, - "Bitcoin", - parse_path(f"m/44h/0h/{nr}h/0/0"), - show_display=(nr == 1), - multisig=getmultisig(0, 0, xpubs=xpubs), + with client: + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) + assert ( + btc.get_address( + client, + "Bitcoin", + parse_path(f"m/44h/0h/{nr}h/0/0"), + show_display=(nr == 1), + multisig=getmultisig(0, 0, xpubs=xpubs), + ) + == "3Pdz86KtfJBuHLcSv4DysJo4aQfanTqCzG" ) - == "3Pdz86KtfJBuHLcSv4DysJo4aQfanTqCzG" - ) - assert ( - btc.get_address( - client, - "Bitcoin", - parse_path(f"m/44h/0h/{nr}h/1/0"), - show_display=(nr == 1), - multisig=getmultisig(1, 0, xpubs=xpubs), + assert ( + btc.get_address( + client, + "Bitcoin", + parse_path(f"m/44h/0h/{nr}h/1/0"), + show_display=(nr == 1), + multisig=getmultisig(1, 0, xpubs=xpubs), + ) + == "36gP3KVx1ooStZ9quZDXbAF3GCr42b2zzd" ) - == "36gP3KVx1ooStZ9quZDXbAF3GCr42b2zzd" - ) @pytest.mark.multisig @@ -254,7 +260,10 @@ def test_multisig_missing(client: Client, show_display): ) for multisig in (multisig1, multisig2): - with pytest.raises(TrezorFailure): + with client, pytest.raises(TrezorFailure): + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) btc.get_address( client, "Bitcoin", @@ -275,26 +284,30 @@ def test_bch_multisig(client: Client): xpubs.append(node.xpub) for nr in range(1, 4): - assert ( - btc.get_address( - client, - "Bcash", - parse_path(f"m/44h/145h/{nr}h/0/0"), - show_display=(nr == 1), - multisig=getmultisig(0, 0, xpubs=xpubs), + with client: + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) + assert ( + btc.get_address( + client, + "Bcash", + parse_path(f"m/44h/145h/{nr}h/0/0"), + show_display=(nr == 1), + multisig=getmultisig(0, 0, xpubs=xpubs), + ) + == "bitcoincash:pqguz4nqq64jhr5v3kvpq4dsjrkda75hwy86gq0qzw" ) - == "bitcoincash:pqguz4nqq64jhr5v3kvpq4dsjrkda75hwy86gq0qzw" - ) - assert ( - btc.get_address( - client, - "Bcash", - parse_path(f"m/44h/145h/{nr}h/1/0"), - show_display=(nr == 1), - multisig=getmultisig(1, 0, xpubs=xpubs), + assert ( + btc.get_address( + client, + "Bcash", + parse_path(f"m/44h/145h/{nr}h/1/0"), + show_display=(nr == 1), + multisig=getmultisig(1, 0, xpubs=xpubs), + ) + == "bitcoincash:pp6kcpkhua7789g2vyj0qfkcux3yvje7euhyhltn0a" ) - == "bitcoincash:pp6kcpkhua7789g2vyj0qfkcux3yvje7euhyhltn0a" - ) def test_public_ckd(client: Client): @@ -342,6 +355,9 @@ def test_unknown_path(client: Client): messages.Address, ] ) + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) # try again with a warning btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True) diff --git a/tests/device_tests/bitcoin/test_getaddress_segwit.py b/tests/device_tests/bitcoin/test_getaddress_segwit.py index 76e09a81f9..0958facda2 100644 --- a/tests/device_tests/bitcoin/test_getaddress_segwit.py +++ b/tests/device_tests/bitcoin/test_getaddress_segwit.py @@ -21,6 +21,9 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path +from ...common import is_core +from ...input_flows import InputFlowConfirmAllWarnings + def test_show_segwit(client: Client): assert ( @@ -71,61 +74,65 @@ def test_show_segwit(client: Client): @pytest.mark.altcoin def test_show_segwit_altcoin(client: Client): - assert ( - btc.get_address( - client, - "Groestlcoin Testnet", - parse_path("m/49h/1h/0h/1/0"), - True, - None, - script_type=messages.InputScriptType.SPENDP2SHWITNESS, + with client: + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) + assert ( + btc.get_address( + client, + "Groestlcoin Testnet", + parse_path("m/49h/1h/0h/1/0"), + True, + None, + script_type=messages.InputScriptType.SPENDP2SHWITNESS, + ) + == "2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7" ) - == "2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7" - ) - assert ( - btc.get_address( - client, - "Groestlcoin Testnet", - parse_path("m/49h/1h/0h/0/0"), - True, - None, - script_type=messages.InputScriptType.SPENDP2SHWITNESS, + assert ( + btc.get_address( + client, + "Groestlcoin Testnet", + parse_path("m/49h/1h/0h/0/0"), + True, + None, + script_type=messages.InputScriptType.SPENDP2SHWITNESS, + ) + == "2N4Q5FhU2497BryFfUgbqkAJE87aKDv3V3e" ) - == "2N4Q5FhU2497BryFfUgbqkAJE87aKDv3V3e" - ) - assert ( - btc.get_address( - client, - "Groestlcoin Testnet", - parse_path("m/44h/1h/0h/0/0"), - True, - None, - script_type=messages.InputScriptType.SPENDP2SHWITNESS, + assert ( + btc.get_address( + client, + "Groestlcoin Testnet", + parse_path("m/44h/1h/0h/0/0"), + True, + None, + script_type=messages.InputScriptType.SPENDP2SHWITNESS, + ) + == "2N6UeBoqYEEnybg4cReFYDammpsyDzLXvCT" ) - == "2N6UeBoqYEEnybg4cReFYDammpsyDzLXvCT" - ) - assert ( - btc.get_address( - client, - "Groestlcoin Testnet", - parse_path("m/44h/1h/0h/0/0"), - True, - None, - script_type=messages.InputScriptType.SPENDADDRESS, + assert ( + btc.get_address( + client, + "Groestlcoin Testnet", + parse_path("m/44h/1h/0h/0/0"), + True, + None, + script_type=messages.InputScriptType.SPENDADDRESS, + ) + == "mvbu1Gdy8SUjTenqerxUaZyYjmvedc787y" ) - == "mvbu1Gdy8SUjTenqerxUaZyYjmvedc787y" - ) - assert ( - btc.get_address( - client, - "Elements", - parse_path("m/49h/1h/0h/0/0"), - True, - None, - script_type=messages.InputScriptType.SPENDP2SHWITNESS, + assert ( + btc.get_address( + client, + "Elements", + parse_path("m/49h/1h/0h/0/0"), + True, + None, + script_type=messages.InputScriptType.SPENDP2SHWITNESS, + ) + == "XNW67ZQA9K3AuXPBWvJH4zN2y5QBDTwy2Z" ) - == "XNW67ZQA9K3AuXPBWvJH4zN2y5QBDTwy2Z" - ) @pytest.mark.multisig diff --git a/tests/device_tests/bitcoin/test_getaddress_show.py b/tests/device_tests/bitcoin/test_getaddress_show.py index 83d82d8adf..4760d485b4 100644 --- a/tests/device_tests/bitcoin/test_getaddress_show.py +++ b/tests/device_tests/bitcoin/test_getaddress_show.py @@ -20,7 +20,9 @@ from trezorlib import btc, messages, tools from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import Cancelled, TrezorFailure +from ...common import is_core from ...input_flows import ( + InputFlowConfirmAllWarnings, InputFlowShowAddressQRCode, InputFlowShowAddressQRCodeCancel, InputFlowShowMultisigXPUBs, @@ -148,17 +150,21 @@ def test_show_multisig_3(client: Client): ) for i in [1, 2, 3]: - assert ( - btc.get_address( - client, - "Bitcoin", - tools.parse_path(f"m/45h/0/0/{i}"), - show_display=True, - multisig=multisig, - script_type=messages.InputScriptType.SPENDMULTISIG, + with client: + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) + assert ( + btc.get_address( + client, + "Bitcoin", + tools.parse_path(f"m/45h/0/0/{i}"), + show_display=True, + multisig=multisig, + script_type=messages.InputScriptType.SPENDMULTISIG, + ) + == "35Q3tgZZfr9GhVpaqz7fbDK8WXV1V1KxfD" ) - == "35Q3tgZZfr9GhVpaqz7fbDK8WXV1V1KxfD" - ) VECTORS_MULTISIG = ( # script_type, bip48_type, address, xpubs, ignore_xpub_magic @@ -289,14 +295,18 @@ def test_show_multisig_15(client: Client): ) for i in range(15): - assert ( - btc.get_address( - client, - "Bitcoin", - tools.parse_path(f"m/45h/0/0/{i}"), - show_display=True, - multisig=multisig, - script_type=messages.InputScriptType.SPENDMULTISIG, + with client: + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) + assert ( + btc.get_address( + client, + "Bitcoin", + tools.parse_path(f"m/45h/0/0/{i}"), + show_display=True, + multisig=multisig, + script_type=messages.InputScriptType.SPENDMULTISIG, + ) + == "3GG78bp1hA3mu9xv1vZLXiENmeabmi7WKQ" ) - == "3GG78bp1hA3mu9xv1vZLXiENmeabmi7WKQ" - ) diff --git a/tests/device_tests/bitcoin/test_multisig.py b/tests/device_tests/bitcoin/test_multisig.py index 303bedf87f..927720bfb2 100644 --- a/tests/device_tests/bitcoin/test_multisig.py +++ b/tests/device_tests/bitcoin/test_multisig.py @@ -22,6 +22,7 @@ from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path from ...common import MNEMONIC12, is_core +from ...input_flows import InputFlowConfirmAllWarnings from ...tx_cache import TxCache from .signtx import ( assert_tx_matches, @@ -304,6 +305,9 @@ def test_attack_change_input(client: Client): # Transaction can be signed without the attack processor with client: + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) btc.sign_tx( client, "Testnet", diff --git a/tests/device_tests/bitcoin/test_multisig_change.py b/tests/device_tests/bitcoin/test_multisig_change.py index dbc095c544..e94eb60fd3 100644 --- a/tests/device_tests/bitcoin/test_multisig_change.py +++ b/tests/device_tests/bitcoin/test_multisig_change.py @@ -22,6 +22,7 @@ from trezorlib.tools import H_, parse_path from ... import bip32 from ...common import MNEMONIC12, is_core +from ...input_flows import InputFlowConfirmAllWarnings from ...tx_cache import TxCache from .signtx import request_finished, request_input, request_meta, request_output @@ -243,6 +244,9 @@ def test_external_internal(client: Client): client.set_expected_responses( _responses(client, INP1, INP2, change=2, foreign=True) ) + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( client, "Bitcoin", @@ -276,6 +280,9 @@ def test_internal_external(client: Client): client.set_expected_responses( _responses(client, INP1, INP2, change=1, foreign=True) ) + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( client, "Bitcoin", diff --git a/tests/device_tests/bitcoin/test_nonstandard_paths.py b/tests/device_tests/bitcoin/test_nonstandard_paths.py index 8aa0e038cf..96457da386 100644 --- a/tests/device_tests/bitcoin/test_nonstandard_paths.py +++ b/tests/device_tests/bitcoin/test_nonstandard_paths.py @@ -20,6 +20,8 @@ from trezorlib import btc, messages from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.tools import parse_path +from ...common import is_core +from ...input_flows import InputFlowConfirmAllWarnings from .signtx import forge_prevtx VECTORS = ( # path, script_types @@ -113,16 +115,20 @@ def test_getaddress( script_types: list[messages.InputScriptType], ): for script_type in script_types: - res = btc.get_address( - client, - "Bitcoin", - parse_path(path), - show_display=True, - script_type=script_type, - chunkify=chunkify, - ) + with client: + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) + res = btc.get_address( + client, + "Bitcoin", + parse_path(path), + show_display=True, + script_type=script_type, + chunkify=chunkify, + ) - assert res + assert res @pytest.mark.parametrize("path, script_types", VECTORS) @@ -130,15 +136,20 @@ def test_signmessage( client: Client, path: str, script_types: list[messages.InputScriptType] ): for script_type in script_types: - sig = btc.sign_message( - client, - coin_name="Bitcoin", - n=parse_path(path), - script_type=script_type, - message="This is an example of a signed message.", - ) + with client: + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) - assert sig.signature + sig = btc.sign_message( + client, + coin_name="Bitcoin", + n=parse_path(path), + script_type=script_type, + message="This is an example of a signed message.", + ) + + assert sig.signature @pytest.mark.parametrize("path, script_types", VECTORS) @@ -164,9 +175,13 @@ def test_signtx( script_type=messages.OutputScriptType.PAYTOADDRESS, ) - _, serialized_tx = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} - ) + with client: + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) + _, serialized_tx = btc.sign_tx( + client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} + ) assert serialized_tx.hex() @@ -187,14 +202,18 @@ def test_getaddress_multisig( ] multisig = messages.MultisigRedeemScriptType(pubkeys=pubs, m=2) - address = btc.get_address( - client, - "Bitcoin", - parse_path(paths[0]) + address_index, - show_display=True, - multisig=multisig, - script_type=messages.InputScriptType.SPENDMULTISIG, - ) + with client: + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) + address = btc.get_address( + client, + "Bitcoin", + parse_path(paths[0]) + address_index, + show_display=True, + multisig=multisig, + script_type=messages.InputScriptType.SPENDMULTISIG, + ) assert address @@ -242,8 +261,12 @@ def test_signtx_multisig(client: Client, paths: list[str], address_index: list[i script_type=messages.OutputScriptType.PAYTOADDRESS, ) - sig, _ = btc.sign_tx( - client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} - ) + with client: + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) + sig, _ = btc.sign_tx( + client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} + ) assert sig[0] diff --git a/tests/device_tests/bitcoin/test_signmessage.py b/tests/device_tests/bitcoin/test_signmessage.py index cc8ee41a3c..b5d47b64f4 100644 --- a/tests/device_tests/bitcoin/test_signmessage.py +++ b/tests/device_tests/bitcoin/test_signmessage.py @@ -24,7 +24,12 @@ from trezorlib.debuglink import message_filters from trezorlib.exceptions import Cancelled from trezorlib.tools import parse_path -from ...input_flows import InputFlowSignMessageInfo, InputFlowSignMessagePagination +from ...common import is_core +from ...input_flows import ( + InputFlowConfirmAllWarnings, + InputFlowSignMessageInfo, + InputFlowSignMessagePagination, +) S = messages.InputScriptType @@ -414,6 +419,9 @@ def test_signmessage_path_warning(client: Client): messages.MessageSignature, ] ) + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) btc.sign_message( client, coin_name="Bitcoin", diff --git a/tests/device_tests/bitcoin/test_signtx_invalid_path.py b/tests/device_tests/bitcoin/test_signtx_invalid_path.py index 69bf4efe40..5ef4ba0389 100644 --- a/tests/device_tests/bitcoin/test_signtx_invalid_path.py +++ b/tests/device_tests/bitcoin/test_signtx_invalid_path.py @@ -21,6 +21,8 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.tools import H_, parse_path +from ...common import is_core +from ...input_flows import InputFlowConfirmAllWarnings from .signtx import forge_prevtx, request_input B = messages.ButtonRequestType @@ -78,7 +80,12 @@ def test_invalid_path_prompt(client: Client): client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) - btc.sign_tx(client, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) + with client: + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) + + btc.sign_tx(client, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) # Bcash does have strong replay protection using SIGHASH_FORKID, @@ -99,7 +106,12 @@ def test_invalid_path_pass_forkid(client: Client): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - btc.sign_tx(client, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) + with client: + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) + + btc.sign_tx(client, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) def test_attack_path_segwit(client: Client): diff --git a/tests/device_tests/bitcoin/test_signtx_prevhash.py b/tests/device_tests/bitcoin/test_signtx_prevhash.py index 09a5be83ef..307823a9f3 100644 --- a/tests/device_tests/bitcoin/test_signtx_prevhash.py +++ b/tests/device_tests/bitcoin/test_signtx_prevhash.py @@ -8,6 +8,8 @@ from trezorlib import btc, messages, models, tools from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure +from ...common import is_core +from ...input_flows import InputFlowConfirmAllWarnings from .signtx import forge_prevtx # address at seed "all all all..." path m/44h/0h/0h/0/0 @@ -130,6 +132,9 @@ def test_invalid_prev_hash_attack(client: Client, prev_hash): with client, pytest.raises(TrezorFailure) as e: client.set_filter(messages.TxAck, attack_filter) + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES) # check that injection was performed @@ -163,6 +168,9 @@ def test_invalid_prev_hash_in_prevtx(client: Client, prev_hash): tx_hash = hash_tx(serialize_tx(prev_tx)) inp0.prev_hash = tx_hash - with pytest.raises(TrezorFailure) as e: + with client, pytest.raises(TrezorFailure) as e: + if client.model is not models.T1B1: + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) btc.sign_tx(client, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) _check_error_message(prev_hash, client.model, e.value.message) diff --git a/tests/device_tests/bitcoin/test_signtx_segwit_native.py b/tests/device_tests/bitcoin/test_signtx_segwit_native.py index ef350151c8..0c779c777e 100644 --- a/tests/device_tests/bitcoin/test_signtx_segwit_native.py +++ b/tests/device_tests/bitcoin/test_signtx_segwit_native.py @@ -22,6 +22,7 @@ from trezorlib.tools import H_, parse_path from ...bip32 import deserialize from ...common import is_core +from ...input_flows import InputFlowConfirmAllWarnings from ...tx_cache import TxCache from .signtx import ( assert_tx_matches, @@ -612,6 +613,9 @@ def test_send_multisig_3_change(client: Client): with client: client.set_expected_responses(expected_responses) + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) @@ -624,6 +628,9 @@ def test_send_multisig_3_change(client: Client): with client: client.set_expected_responses(expected_responses) + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) @@ -698,6 +705,9 @@ def test_send_multisig_4_change(client: Client): with client: client.set_expected_responses(expected_responses) + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) @@ -710,6 +720,9 @@ def test_send_multisig_4_change(client: Client): with client: client.set_expected_responses(expected_responses) + if is_core(client): + IF = InputFlowConfirmAllWarnings(client) + client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) diff --git a/tests/input_flows.py b/tests/input_flows.py index c3abebdf82..ec40b2f350 100644 --- a/tests/input_flows.py +++ b/tests/input_flows.py @@ -716,21 +716,20 @@ class InputFlowPaymentRequestDetails(InputFlowBase): yield # confirm first output assert self.outputs[0].address[:16] in self.text_content() # type: ignore - self.debug.press_yes() + self.debug.swipe_up() yield # confirm first output self.debug.wait_layout() - self.debug.press_yes() + self.debug.swipe_up() yield # confirm second output assert self.outputs[1].address[:16] in self.text_content() # type: ignore - self.debug.press_yes() + self.debug.swipe_up() yield # confirm second output self.debug.wait_layout() - self.debug.press_yes() + self.debug.swipe_up() yield # confirm transaction - self.debug.press_yes() - yield # confirm transaction + self.debug.swipe_up() self.debug.press_yes() @@ -747,7 +746,7 @@ class InputFlowSignTxHighFee(InputFlowBase): self.finished = True - def input_flow_common(self) -> BRGeneratorType: + def input_flow_tt(self) -> BRGeneratorType: screens = [ B.ConfirmOutput, B.ConfirmOutput, @@ -756,6 +755,31 @@ class InputFlowSignTxHighFee(InputFlowBase): ] yield from self.go_through_all_screens(screens) + def input_flow_tr(self) -> BRGeneratorType: + screens = [ + B.ConfirmOutput, + B.ConfirmOutput, + B.FeeOverThreshold, + B.SignTx, + ] + yield from self.go_through_all_screens(screens) + + def input_flow_t3t1(self) -> BRGeneratorType: + screens = [ + B.ConfirmOutput, + B.ConfirmOutput, + B.FeeOverThreshold, + B.SignTx, + ] + for expected in screens: + br = yield + assert br.code == expected + self.debug.swipe_up() + if br.code == B.SignTx: + self.debug.press_yes() + + self.finished = True + def sign_tx_go_to_info(client: Client) -> Generator[None, None, str]: yield # confirm output @@ -777,6 +801,43 @@ def sign_tx_go_to_info(client: Client) -> Generator[None, None, str]: return content +def sign_tx_go_to_info_t3t1( + client: Client, multi_account: bool = False +) -> Generator[None, None, str]: + yield # confirm output + client.debug.wait_layout() + client.debug.swipe_up() + yield # confirm output + client.debug.wait_layout() + client.debug.swipe_up() + + if multi_account: + yield + client.debug.wait_layout() + client.debug.swipe_up() + + yield # confirm transaction + client.debug.wait_layout() + client.debug.click(buttons.CORNER_BUTTON) + client.debug.synchronize_at("VerticalMenu") + client.debug.click(buttons.VERTICAL_MENU[0]) + + layout = client.debug.wait_layout() + content = layout.text_content() + + client.debug.click(buttons.CORNER_BUTTON) + client.debug.synchronize_at("VerticalMenu") + client.debug.click(buttons.VERTICAL_MENU[1]) + + layout = client.debug.wait_layout() + content += " " + layout.text_content() + + client.debug.click(buttons.CORNER_BUTTON) + client.debug.click(buttons.CORNER_BUTTON, wait=True) + + return content + + def sign_tx_go_to_info_tr( client: Client, ) -> Generator[None, None, str]: @@ -829,8 +890,9 @@ class InputFlowSignTxInformation(InputFlowBase): self.debug.press_yes() def input_flow_t3t1(self) -> BRGeneratorType: - content = yield from sign_tx_go_to_info(self.client) + content = yield from sign_tx_go_to_info_t3t1(self.client) self.assert_content(content, "confirm_total__sending_from_account") + self.debug.swipe_up() self.debug.press_yes() @@ -863,12 +925,9 @@ class InputFlowSignTxInformationMixed(InputFlowBase): self.debug.press_yes() def input_flow_t3t1(self) -> BRGeneratorType: - # multiple accounts warning - yield - self.debug.press_yes() - - content = yield from sign_tx_go_to_info(self.client) + content = yield from sign_tx_go_to_info_t3t1(self.client, multi_account=True) self.assert_content(content, "confirm_total__sending_from_account") + self.debug.swipe_up() self.debug.press_yes() @@ -885,8 +944,11 @@ class InputFlowSignTxInformationCancel(InputFlowBase): self.debug.press_left() def input_flow_t3t1(self) -> BRGeneratorType: - yield from sign_tx_go_to_info(self.client) - self.debug.press_no() + yield from sign_tx_go_to_info_t3t1(self.client) + self.debug.click(buttons.CORNER_BUTTON) + self.debug.click(buttons.VERTICAL_MENU[2]) + self.debug.synchronize_at("PromptScreen") + self.debug.click(buttons.TAP_TO_CONFIRM) class InputFlowSignTxInformationReplacement(InputFlowBase): @@ -984,6 +1046,30 @@ def lock_time_input_flow_tr( debug.press_yes() +def lock_time_input_flow_t3t1( + debug: DebugLink, + layout_assert_func: Callable[[DebugLink, messages.ButtonRequest], None], + double_confirm: bool = False, +) -> BRGeneratorType: + yield # confirm output + debug.wait_layout() + debug.swipe_up() + yield # confirm output + debug.wait_layout() + debug.swipe_up() + + br = yield # confirm locktime + layout_assert_func(debug, br) + debug.press_yes() + + yield # confirm transaction + debug.swipe_up() + debug.press_yes() + if double_confirm: + yield # confirm transaction + debug.press_yes() + + class InputFlowLockTimeBlockHeight(InputFlowBase): def __init__(self, client: Client, block_height: str): super().__init__(client) @@ -1003,7 +1089,7 @@ class InputFlowLockTimeBlockHeight(InputFlowBase): yield from lock_time_input_flow_tr(self.debug, self.assert_func) def input_flow_t3t1(self) -> BRGeneratorType: - yield from lock_time_input_flow_tt( + yield from lock_time_input_flow_t3t1( self.debug, self.assert_func, double_confirm=True ) @@ -1025,7 +1111,7 @@ class InputFlowLockTimeDatetime(InputFlowBase): yield from lock_time_input_flow_tr(self.debug, self.assert_func) def input_flow_t3t1(self) -> BRGeneratorType: - yield from lock_time_input_flow_tt(self.debug, self.assert_func) + yield from lock_time_input_flow_t3t1(self.debug, self.assert_func) class InputFlowEIP712ShowMore(InputFlowBase): @@ -2048,3 +2134,42 @@ class InputFlowResetSkipBackup(InputFlowBase): yield # Confirm skip backup TR.assert_in(self.text_content(), "backup__want_to_skip") self.debug.press_no() + + +class InputFlowConfirmAllWarnings(InputFlowBase): + def __init__(self, client: Client): + super().__init__(client) + + def input_flow_tt(self) -> BRGeneratorType: + br = yield + while True: + # wait for homescreen to go away + self.debug.wait_layout() + self.client.ui._default_input_flow(br) + br = yield + + def input_flow_tr(self) -> BRGeneratorType: + return self.input_flow_tt() + + def input_flow_t3t1(self) -> BRGeneratorType: + br = yield + while True: + # wait for homescreen to go away + # probably won't be needed after https://github.com/trezor/trezor-firmware/pull/3686 + self.debug.wait_layout() + # Paginating (going as further as possible) and pressing Yes + if br.pages is not None: + for _ in range(br.pages - 1): + self.debug.swipe_up(wait=True) + layout = self.debug.read_layout() + text = layout.text_content().lower() + # hi priority warning + if ("wrong derivation path" in text) or ("to a multisig" in text): + self.debug.click(buttons.CORNER_BUTTON, wait=True) + self.debug.synchronize_at("VerticalMenu") + self.debug.click(buttons.VERTICAL_MENU[1]) + elif "swipe up" in layout.footer().lower(): + self.debug.swipe_up() + else: + self.debug.press_yes() + br = yield