1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-18 12:28:09 +00:00

test(core/ui): fix bitcoin signtx for T3T1

This commit is contained in:
Martin Milata 2024-05-21 23:49:08 +02:00
parent 198693259d
commit 38b0361ded
13 changed files with 419 additions and 168 deletions

View File

@ -355,6 +355,12 @@ class LayoutContent(UnstructuredJSONReader):
choice_obj.get(choice, {}).get("content", "") for choice in choice_keys choice_obj.get(choice, {}).get("content", "") for choice in choice_keys
) )
def footer(self) -> str:
footer = self.find_unique_object_with_key_and_value("component", "Footer")
if not footer:
return ""
return footer.get("description", "") + " " + footer.get("instruction", "")
def multipage_content(layouts: List[LayoutContent]) -> str: def multipage_content(layouts: List[LayoutContent]) -> str:
"""Get overall content from multiple-page layout.""" """Get overall content from multiple-page layout."""
@ -804,6 +810,25 @@ class DebugUI:
Generator[None, messages.ButtonRequest, None], object, None Generator[None, messages.ButtonRequest, None], object, None
] = None ] = None
def _default_input_flow(self, br: messages.ButtonRequest) -> None:
if br.code == messages.ButtonRequestType.PinEntry:
self.debuglink.input(self.get_pin())
else:
# Paginating (going as further as possible) and pressing Yes
if br.pages is not None:
for _ in range(br.pages - 1):
self.debuglink.swipe_up(wait=True)
if self.debuglink.model is models.T3T1:
layout = self.debuglink.read_layout()
if "PromptScreen" in layout.all_components():
self.debuglink.press_yes()
elif "swipe up" in layout.footer().lower():
self.debuglink.swipe_up()
else:
self.debuglink.press_yes()
else:
self.debuglink.press_yes()
def button_request(self, br: messages.ButtonRequest) -> None: def button_request(self, br: messages.ButtonRequest) -> None:
self.debuglink.take_t1_screenshot_if_relevant() self.debuglink.take_t1_screenshot_if_relevant()
@ -814,14 +839,7 @@ class DebugUI:
# recording their screens that way (as well as # recording their screens that way (as well as
# possible swipes below). # possible swipes below).
self.debuglink.save_current_screen_if_relevant(wait=True) self.debuglink.save_current_screen_if_relevant(wait=True)
if br.code == messages.ButtonRequestType.PinEntry: self._default_input_flow(br)
self.debuglink.input(self.get_pin())
else:
# Paginating (going as further as possible) and pressing Yes
if br.pages is not None:
for _ in range(br.pages - 1):
self.debuglink.swipe_up(wait=True)
self.debuglink.press_yes()
elif self.input_flow is self.INPUT_FLOW_DONE: elif self.input_flow is self.INPUT_FLOW_DONE:
raise AssertionError("input flow ended prematurely") raise AssertionError("input flow ended prematurely")
else: else:

View File

@ -344,4 +344,4 @@ def swipe_till_the_end(debug: "DebugLink", br: messages.ButtonRequest) -> None:
def is_core(client: "Client") -> bool: def is_core(client: "Client") -> bool:
return client.model in (models.T2T1, models.T2B1, models.T3T1) return client.model is not models.T1B1

View File

@ -23,6 +23,8 @@ from trezorlib.messages import SafetyCheckLevel
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from ... import bip32 from ... import bip32
from ...common import is_core
from ...input_flows import InputFlowConfirmAllWarnings
def getmultisig(chain, nr, xpubs): def getmultisig(chain, nr, xpubs):
@ -202,26 +204,30 @@ def test_multisig(client: Client):
xpubs.append(node.xpub) xpubs.append(node.xpub)
for nr in range(1, 4): for nr in range(1, 4):
assert ( with client:
btc.get_address( if is_core(client):
client, IF = InputFlowConfirmAllWarnings(client)
"Bitcoin", client.set_input_flow(IF.get())
parse_path(f"m/44h/0h/{nr}h/0/0"), assert (
show_display=(nr == 1), btc.get_address(
multisig=getmultisig(0, 0, xpubs=xpubs), client,
"Bitcoin",
parse_path(f"m/44h/0h/{nr}h/0/0"),
show_display=(nr == 1),
multisig=getmultisig(0, 0, xpubs=xpubs),
)
== "3Pdz86KtfJBuHLcSv4DysJo4aQfanTqCzG"
) )
== "3Pdz86KtfJBuHLcSv4DysJo4aQfanTqCzG" assert (
) btc.get_address(
assert ( client,
btc.get_address( "Bitcoin",
client, parse_path(f"m/44h/0h/{nr}h/1/0"),
"Bitcoin", show_display=(nr == 1),
parse_path(f"m/44h/0h/{nr}h/1/0"), multisig=getmultisig(1, 0, xpubs=xpubs),
show_display=(nr == 1), )
multisig=getmultisig(1, 0, xpubs=xpubs), == "36gP3KVx1ooStZ9quZDXbAF3GCr42b2zzd"
) )
== "36gP3KVx1ooStZ9quZDXbAF3GCr42b2zzd"
)
@pytest.mark.multisig @pytest.mark.multisig
@ -254,7 +260,10 @@ def test_multisig_missing(client: Client, show_display):
) )
for multisig in (multisig1, multisig2): for multisig in (multisig1, multisig2):
with pytest.raises(TrezorFailure): with client, pytest.raises(TrezorFailure):
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
btc.get_address( btc.get_address(
client, client,
"Bitcoin", "Bitcoin",
@ -275,26 +284,30 @@ def test_bch_multisig(client: Client):
xpubs.append(node.xpub) xpubs.append(node.xpub)
for nr in range(1, 4): for nr in range(1, 4):
assert ( with client:
btc.get_address( if is_core(client):
client, IF = InputFlowConfirmAllWarnings(client)
"Bcash", client.set_input_flow(IF.get())
parse_path(f"m/44h/145h/{nr}h/0/0"), assert (
show_display=(nr == 1), btc.get_address(
multisig=getmultisig(0, 0, xpubs=xpubs), client,
"Bcash",
parse_path(f"m/44h/145h/{nr}h/0/0"),
show_display=(nr == 1),
multisig=getmultisig(0, 0, xpubs=xpubs),
)
== "bitcoincash:pqguz4nqq64jhr5v3kvpq4dsjrkda75hwy86gq0qzw"
) )
== "bitcoincash:pqguz4nqq64jhr5v3kvpq4dsjrkda75hwy86gq0qzw" assert (
) btc.get_address(
assert ( client,
btc.get_address( "Bcash",
client, parse_path(f"m/44h/145h/{nr}h/1/0"),
"Bcash", show_display=(nr == 1),
parse_path(f"m/44h/145h/{nr}h/1/0"), multisig=getmultisig(1, 0, xpubs=xpubs),
show_display=(nr == 1), )
multisig=getmultisig(1, 0, xpubs=xpubs), == "bitcoincash:pp6kcpkhua7789g2vyj0qfkcux3yvje7euhyhltn0a"
) )
== "bitcoincash:pp6kcpkhua7789g2vyj0qfkcux3yvje7euhyhltn0a"
)
def test_public_ckd(client: Client): def test_public_ckd(client: Client):
@ -342,6 +355,9 @@ def test_unknown_path(client: Client):
messages.Address, messages.Address,
] ]
) )
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
# try again with a warning # try again with a warning
btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True) btc.get_address(client, "Bitcoin", UNKNOWN_PATH, show_display=True)

View File

@ -21,6 +21,9 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from ...common import is_core
from ...input_flows import InputFlowConfirmAllWarnings
def test_show_segwit(client: Client): def test_show_segwit(client: Client):
assert ( assert (
@ -71,61 +74,65 @@ def test_show_segwit(client: Client):
@pytest.mark.altcoin @pytest.mark.altcoin
def test_show_segwit_altcoin(client: Client): def test_show_segwit_altcoin(client: Client):
assert ( with client:
btc.get_address( if is_core(client):
client, IF = InputFlowConfirmAllWarnings(client)
"Groestlcoin Testnet", client.set_input_flow(IF.get())
parse_path("m/49h/1h/0h/1/0"), assert (
True, btc.get_address(
None, client,
script_type=messages.InputScriptType.SPENDP2SHWITNESS, "Groestlcoin Testnet",
parse_path("m/49h/1h/0h/1/0"),
True,
None,
script_type=messages.InputScriptType.SPENDP2SHWITNESS,
)
== "2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7"
) )
== "2N1LGaGg836mqSQqiuUBLfcyGBhyZYBtBZ7" assert (
) btc.get_address(
assert ( client,
btc.get_address( "Groestlcoin Testnet",
client, parse_path("m/49h/1h/0h/0/0"),
"Groestlcoin Testnet", True,
parse_path("m/49h/1h/0h/0/0"), None,
True, script_type=messages.InputScriptType.SPENDP2SHWITNESS,
None, )
script_type=messages.InputScriptType.SPENDP2SHWITNESS, == "2N4Q5FhU2497BryFfUgbqkAJE87aKDv3V3e"
) )
== "2N4Q5FhU2497BryFfUgbqkAJE87aKDv3V3e" assert (
) btc.get_address(
assert ( client,
btc.get_address( "Groestlcoin Testnet",
client, parse_path("m/44h/1h/0h/0/0"),
"Groestlcoin Testnet", True,
parse_path("m/44h/1h/0h/0/0"), None,
True, script_type=messages.InputScriptType.SPENDP2SHWITNESS,
None, )
script_type=messages.InputScriptType.SPENDP2SHWITNESS, == "2N6UeBoqYEEnybg4cReFYDammpsyDzLXvCT"
) )
== "2N6UeBoqYEEnybg4cReFYDammpsyDzLXvCT" assert (
) btc.get_address(
assert ( client,
btc.get_address( "Groestlcoin Testnet",
client, parse_path("m/44h/1h/0h/0/0"),
"Groestlcoin Testnet", True,
parse_path("m/44h/1h/0h/0/0"), None,
True, script_type=messages.InputScriptType.SPENDADDRESS,
None, )
script_type=messages.InputScriptType.SPENDADDRESS, == "mvbu1Gdy8SUjTenqerxUaZyYjmvedc787y"
) )
== "mvbu1Gdy8SUjTenqerxUaZyYjmvedc787y" assert (
) btc.get_address(
assert ( client,
btc.get_address( "Elements",
client, parse_path("m/49h/1h/0h/0/0"),
"Elements", True,
parse_path("m/49h/1h/0h/0/0"), None,
True, script_type=messages.InputScriptType.SPENDP2SHWITNESS,
None, )
script_type=messages.InputScriptType.SPENDP2SHWITNESS, == "XNW67ZQA9K3AuXPBWvJH4zN2y5QBDTwy2Z"
) )
== "XNW67ZQA9K3AuXPBWvJH4zN2y5QBDTwy2Z"
)
@pytest.mark.multisig @pytest.mark.multisig

View File

@ -20,7 +20,9 @@ from trezorlib import btc, messages, tools
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import Cancelled, TrezorFailure from trezorlib.exceptions import Cancelled, TrezorFailure
from ...common import is_core
from ...input_flows import ( from ...input_flows import (
InputFlowConfirmAllWarnings,
InputFlowShowAddressQRCode, InputFlowShowAddressQRCode,
InputFlowShowAddressQRCodeCancel, InputFlowShowAddressQRCodeCancel,
InputFlowShowMultisigXPUBs, InputFlowShowMultisigXPUBs,
@ -148,17 +150,21 @@ def test_show_multisig_3(client: Client):
) )
for i in [1, 2, 3]: for i in [1, 2, 3]:
assert ( with client:
btc.get_address( if is_core(client):
client, IF = InputFlowConfirmAllWarnings(client)
"Bitcoin", client.set_input_flow(IF.get())
tools.parse_path(f"m/45h/0/0/{i}"), assert (
show_display=True, btc.get_address(
multisig=multisig, client,
script_type=messages.InputScriptType.SPENDMULTISIG, "Bitcoin",
tools.parse_path(f"m/45h/0/0/{i}"),
show_display=True,
multisig=multisig,
script_type=messages.InputScriptType.SPENDMULTISIG,
)
== "35Q3tgZZfr9GhVpaqz7fbDK8WXV1V1KxfD"
) )
== "35Q3tgZZfr9GhVpaqz7fbDK8WXV1V1KxfD"
)
VECTORS_MULTISIG = ( # script_type, bip48_type, address, xpubs, ignore_xpub_magic VECTORS_MULTISIG = ( # script_type, bip48_type, address, xpubs, ignore_xpub_magic
@ -289,14 +295,18 @@ def test_show_multisig_15(client: Client):
) )
for i in range(15): for i in range(15):
assert ( with client:
btc.get_address( if is_core(client):
client, IF = InputFlowConfirmAllWarnings(client)
"Bitcoin", client.set_input_flow(IF.get())
tools.parse_path(f"m/45h/0/0/{i}"), assert (
show_display=True, btc.get_address(
multisig=multisig, client,
script_type=messages.InputScriptType.SPENDMULTISIG, "Bitcoin",
tools.parse_path(f"m/45h/0/0/{i}"),
show_display=True,
multisig=multisig,
script_type=messages.InputScriptType.SPENDMULTISIG,
)
== "3GG78bp1hA3mu9xv1vZLXiENmeabmi7WKQ"
) )
== "3GG78bp1hA3mu9xv1vZLXiENmeabmi7WKQ"
)

View File

@ -22,6 +22,7 @@ from trezorlib.exceptions import TrezorFailure
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from ...common import MNEMONIC12, is_core from ...common import MNEMONIC12, is_core
from ...input_flows import InputFlowConfirmAllWarnings
from ...tx_cache import TxCache from ...tx_cache import TxCache
from .signtx import ( from .signtx import (
assert_tx_matches, assert_tx_matches,
@ -304,6 +305,9 @@ def test_attack_change_input(client: Client):
# Transaction can be signed without the attack processor # Transaction can be signed without the attack processor
with client: with client:
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
client, client,
"Testnet", "Testnet",

View File

@ -22,6 +22,7 @@ from trezorlib.tools import H_, parse_path
from ... import bip32 from ... import bip32
from ...common import MNEMONIC12, is_core from ...common import MNEMONIC12, is_core
from ...input_flows import InputFlowConfirmAllWarnings
from ...tx_cache import TxCache from ...tx_cache import TxCache
from .signtx import request_finished, request_input, request_meta, request_output from .signtx import request_finished, request_input, request_meta, request_output
@ -243,6 +244,9 @@ def test_external_internal(client: Client):
client.set_expected_responses( client.set_expected_responses(
_responses(client, INP1, INP2, change=2, foreign=True) _responses(client, INP1, INP2, change=2, foreign=True)
) )
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
client, client,
"Bitcoin", "Bitcoin",
@ -276,6 +280,9 @@ def test_internal_external(client: Client):
client.set_expected_responses( client.set_expected_responses(
_responses(client, INP1, INP2, change=1, foreign=True) _responses(client, INP1, INP2, change=1, foreign=True)
) )
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
client, client,
"Bitcoin", "Bitcoin",

View File

@ -20,6 +20,8 @@ from trezorlib import btc, messages
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from ...common import is_core
from ...input_flows import InputFlowConfirmAllWarnings
from .signtx import forge_prevtx from .signtx import forge_prevtx
VECTORS = ( # path, script_types VECTORS = ( # path, script_types
@ -113,16 +115,20 @@ def test_getaddress(
script_types: list[messages.InputScriptType], script_types: list[messages.InputScriptType],
): ):
for script_type in script_types: for script_type in script_types:
res = btc.get_address( with client:
client, if is_core(client):
"Bitcoin", IF = InputFlowConfirmAllWarnings(client)
parse_path(path), client.set_input_flow(IF.get())
show_display=True, res = btc.get_address(
script_type=script_type, client,
chunkify=chunkify, "Bitcoin",
) parse_path(path),
show_display=True,
script_type=script_type,
chunkify=chunkify,
)
assert res assert res
@pytest.mark.parametrize("path, script_types", VECTORS) @pytest.mark.parametrize("path, script_types", VECTORS)
@ -130,15 +136,20 @@ def test_signmessage(
client: Client, path: str, script_types: list[messages.InputScriptType] client: Client, path: str, script_types: list[messages.InputScriptType]
): ):
for script_type in script_types: for script_type in script_types:
sig = btc.sign_message( with client:
client, if is_core(client):
coin_name="Bitcoin", IF = InputFlowConfirmAllWarnings(client)
n=parse_path(path), client.set_input_flow(IF.get())
script_type=script_type,
message="This is an example of a signed message.",
)
assert sig.signature sig = btc.sign_message(
client,
coin_name="Bitcoin",
n=parse_path(path),
script_type=script_type,
message="This is an example of a signed message.",
)
assert sig.signature
@pytest.mark.parametrize("path, script_types", VECTORS) @pytest.mark.parametrize("path, script_types", VECTORS)
@ -164,9 +175,13 @@ def test_signtx(
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
_, serialized_tx = btc.sign_tx( with client:
client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} if is_core(client):
) IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx(
client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx}
)
assert serialized_tx.hex() assert serialized_tx.hex()
@ -187,14 +202,18 @@ def test_getaddress_multisig(
] ]
multisig = messages.MultisigRedeemScriptType(pubkeys=pubs, m=2) multisig = messages.MultisigRedeemScriptType(pubkeys=pubs, m=2)
address = btc.get_address( with client:
client, if is_core(client):
"Bitcoin", IF = InputFlowConfirmAllWarnings(client)
parse_path(paths[0]) + address_index, client.set_input_flow(IF.get())
show_display=True, address = btc.get_address(
multisig=multisig, client,
script_type=messages.InputScriptType.SPENDMULTISIG, "Bitcoin",
) parse_path(paths[0]) + address_index,
show_display=True,
multisig=multisig,
script_type=messages.InputScriptType.SPENDMULTISIG,
)
assert address assert address
@ -242,8 +261,12 @@ def test_signtx_multisig(client: Client, paths: list[str], address_index: list[i
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
sig, _ = btc.sign_tx( with client:
client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} if is_core(client):
) IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
sig, _ = btc.sign_tx(
client, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx}
)
assert sig[0] assert sig[0]

View File

@ -24,7 +24,12 @@ from trezorlib.debuglink import message_filters
from trezorlib.exceptions import Cancelled from trezorlib.exceptions import Cancelled
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from ...input_flows import InputFlowSignMessageInfo, InputFlowSignMessagePagination from ...common import is_core
from ...input_flows import (
InputFlowConfirmAllWarnings,
InputFlowSignMessageInfo,
InputFlowSignMessagePagination,
)
S = messages.InputScriptType S = messages.InputScriptType
@ -414,6 +419,9 @@ def test_signmessage_path_warning(client: Client):
messages.MessageSignature, messages.MessageSignature,
] ]
) )
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
btc.sign_message( btc.sign_message(
client, client,
coin_name="Bitcoin", coin_name="Bitcoin",

View File

@ -21,6 +21,8 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from trezorlib.tools import H_, parse_path from trezorlib.tools import H_, parse_path
from ...common import is_core
from ...input_flows import InputFlowConfirmAllWarnings
from .signtx import forge_prevtx, request_input from .signtx import forge_prevtx, request_input
B = messages.ButtonRequestType B = messages.ButtonRequestType
@ -78,7 +80,12 @@ def test_invalid_path_prompt(client: Client):
client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily client, safety_checks=messages.SafetyCheckLevel.PromptTemporarily
) )
btc.sign_tx(client, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) with client:
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
btc.sign_tx(client, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES)
# Bcash does have strong replay protection using SIGHASH_FORKID, # Bcash does have strong replay protection using SIGHASH_FORKID,
@ -99,7 +106,12 @@ def test_invalid_path_pass_forkid(client: Client):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
btc.sign_tx(client, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) with client:
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
btc.sign_tx(client, "Bcash", [inp1], [out1], prev_txes=PREV_TXES)
def test_attack_path_segwit(client: Client): def test_attack_path_segwit(client: Client):

View File

@ -8,6 +8,8 @@ from trezorlib import btc, messages, models, tools
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from ...common import is_core
from ...input_flows import InputFlowConfirmAllWarnings
from .signtx import forge_prevtx from .signtx import forge_prevtx
# address at seed "all all all..." path m/44h/0h/0h/0/0 # address at seed "all all all..." path m/44h/0h/0h/0/0
@ -130,6 +132,9 @@ def test_invalid_prev_hash_attack(client: Client, prev_hash):
with client, pytest.raises(TrezorFailure) as e: with client, pytest.raises(TrezorFailure) as e:
client.set_filter(messages.TxAck, attack_filter) client.set_filter(messages.TxAck, attack_filter)
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES) btc.sign_tx(client, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES)
# check that injection was performed # check that injection was performed
@ -163,6 +168,9 @@ def test_invalid_prev_hash_in_prevtx(client: Client, prev_hash):
tx_hash = hash_tx(serialize_tx(prev_tx)) tx_hash = hash_tx(serialize_tx(prev_tx))
inp0.prev_hash = tx_hash inp0.prev_hash = tx_hash
with pytest.raises(TrezorFailure) as e: with client, pytest.raises(TrezorFailure) as e:
if client.model is not models.T1B1:
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
btc.sign_tx(client, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) btc.sign_tx(client, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx})
_check_error_message(prev_hash, client.model, e.value.message) _check_error_message(prev_hash, client.model, e.value.message)

View File

@ -22,6 +22,7 @@ from trezorlib.tools import H_, parse_path
from ...bip32 import deserialize from ...bip32 import deserialize
from ...common import is_core from ...common import is_core
from ...input_flows import InputFlowConfirmAllWarnings
from ...tx_cache import TxCache from ...tx_cache import TxCache
from .signtx import ( from .signtx import (
assert_tx_matches, assert_tx_matches,
@ -612,6 +613,9 @@ def test_send_multisig_3_change(client: Client):
with client: with client:
client.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
signatures, _ = btc.sign_tx( signatures, _ = btc.sign_tx(
client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
) )
@ -624,6 +628,9 @@ def test_send_multisig_3_change(client: Client):
with client: with client:
client.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
) )
@ -698,6 +705,9 @@ def test_send_multisig_4_change(client: Client):
with client: with client:
client.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
signatures, _ = btc.sign_tx( signatures, _ = btc.sign_tx(
client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
) )
@ -710,6 +720,9 @@ def test_send_multisig_4_change(client: Client):
with client: with client:
client.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
if is_core(client):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET client, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
) )

View File

@ -716,21 +716,20 @@ class InputFlowPaymentRequestDetails(InputFlowBase):
yield # confirm first output yield # confirm first output
assert self.outputs[0].address[:16] in self.text_content() # type: ignore assert self.outputs[0].address[:16] in self.text_content() # type: ignore
self.debug.press_yes() self.debug.swipe_up()
yield # confirm first output yield # confirm first output
self.debug.wait_layout() self.debug.wait_layout()
self.debug.press_yes() self.debug.swipe_up()
yield # confirm second output yield # confirm second output
assert self.outputs[1].address[:16] in self.text_content() # type: ignore assert self.outputs[1].address[:16] in self.text_content() # type: ignore
self.debug.press_yes() self.debug.swipe_up()
yield # confirm second output yield # confirm second output
self.debug.wait_layout() self.debug.wait_layout()
self.debug.press_yes() self.debug.swipe_up()
yield # confirm transaction yield # confirm transaction
self.debug.press_yes() self.debug.swipe_up()
yield # confirm transaction
self.debug.press_yes() self.debug.press_yes()
@ -747,7 +746,7 @@ class InputFlowSignTxHighFee(InputFlowBase):
self.finished = True self.finished = True
def input_flow_common(self) -> BRGeneratorType: def input_flow_tt(self) -> BRGeneratorType:
screens = [ screens = [
B.ConfirmOutput, B.ConfirmOutput,
B.ConfirmOutput, B.ConfirmOutput,
@ -756,6 +755,31 @@ class InputFlowSignTxHighFee(InputFlowBase):
] ]
yield from self.go_through_all_screens(screens) yield from self.go_through_all_screens(screens)
def input_flow_tr(self) -> BRGeneratorType:
screens = [
B.ConfirmOutput,
B.ConfirmOutput,
B.FeeOverThreshold,
B.SignTx,
]
yield from self.go_through_all_screens(screens)
def input_flow_t3t1(self) -> BRGeneratorType:
screens = [
B.ConfirmOutput,
B.ConfirmOutput,
B.FeeOverThreshold,
B.SignTx,
]
for expected in screens:
br = yield
assert br.code == expected
self.debug.swipe_up()
if br.code == B.SignTx:
self.debug.press_yes()
self.finished = True
def sign_tx_go_to_info(client: Client) -> Generator[None, None, str]: def sign_tx_go_to_info(client: Client) -> Generator[None, None, str]:
yield # confirm output yield # confirm output
@ -777,6 +801,43 @@ def sign_tx_go_to_info(client: Client) -> Generator[None, None, str]:
return content return content
def sign_tx_go_to_info_t3t1(
client: Client, multi_account: bool = False
) -> Generator[None, None, str]:
yield # confirm output
client.debug.wait_layout()
client.debug.swipe_up()
yield # confirm output
client.debug.wait_layout()
client.debug.swipe_up()
if multi_account:
yield
client.debug.wait_layout()
client.debug.swipe_up()
yield # confirm transaction
client.debug.wait_layout()
client.debug.click(buttons.CORNER_BUTTON)
client.debug.synchronize_at("VerticalMenu")
client.debug.click(buttons.VERTICAL_MENU[0])
layout = client.debug.wait_layout()
content = layout.text_content()
client.debug.click(buttons.CORNER_BUTTON)
client.debug.synchronize_at("VerticalMenu")
client.debug.click(buttons.VERTICAL_MENU[1])
layout = client.debug.wait_layout()
content += " " + layout.text_content()
client.debug.click(buttons.CORNER_BUTTON)
client.debug.click(buttons.CORNER_BUTTON, wait=True)
return content
def sign_tx_go_to_info_tr( def sign_tx_go_to_info_tr(
client: Client, client: Client,
) -> Generator[None, None, str]: ) -> Generator[None, None, str]:
@ -829,8 +890,9 @@ class InputFlowSignTxInformation(InputFlowBase):
self.debug.press_yes() self.debug.press_yes()
def input_flow_t3t1(self) -> BRGeneratorType: def input_flow_t3t1(self) -> BRGeneratorType:
content = yield from sign_tx_go_to_info(self.client) content = yield from sign_tx_go_to_info_t3t1(self.client)
self.assert_content(content, "confirm_total__sending_from_account") self.assert_content(content, "confirm_total__sending_from_account")
self.debug.swipe_up()
self.debug.press_yes() self.debug.press_yes()
@ -863,12 +925,9 @@ class InputFlowSignTxInformationMixed(InputFlowBase):
self.debug.press_yes() self.debug.press_yes()
def input_flow_t3t1(self) -> BRGeneratorType: def input_flow_t3t1(self) -> BRGeneratorType:
# multiple accounts warning content = yield from sign_tx_go_to_info_t3t1(self.client, multi_account=True)
yield
self.debug.press_yes()
content = yield from sign_tx_go_to_info(self.client)
self.assert_content(content, "confirm_total__sending_from_account") self.assert_content(content, "confirm_total__sending_from_account")
self.debug.swipe_up()
self.debug.press_yes() self.debug.press_yes()
@ -885,8 +944,11 @@ class InputFlowSignTxInformationCancel(InputFlowBase):
self.debug.press_left() self.debug.press_left()
def input_flow_t3t1(self) -> BRGeneratorType: def input_flow_t3t1(self) -> BRGeneratorType:
yield from sign_tx_go_to_info(self.client) yield from sign_tx_go_to_info_t3t1(self.client)
self.debug.press_no() self.debug.click(buttons.CORNER_BUTTON)
self.debug.click(buttons.VERTICAL_MENU[2])
self.debug.synchronize_at("PromptScreen")
self.debug.click(buttons.TAP_TO_CONFIRM)
class InputFlowSignTxInformationReplacement(InputFlowBase): class InputFlowSignTxInformationReplacement(InputFlowBase):
@ -984,6 +1046,30 @@ def lock_time_input_flow_tr(
debug.press_yes() debug.press_yes()
def lock_time_input_flow_t3t1(
debug: DebugLink,
layout_assert_func: Callable[[DebugLink, messages.ButtonRequest], None],
double_confirm: bool = False,
) -> BRGeneratorType:
yield # confirm output
debug.wait_layout()
debug.swipe_up()
yield # confirm output
debug.wait_layout()
debug.swipe_up()
br = yield # confirm locktime
layout_assert_func(debug, br)
debug.press_yes()
yield # confirm transaction
debug.swipe_up()
debug.press_yes()
if double_confirm:
yield # confirm transaction
debug.press_yes()
class InputFlowLockTimeBlockHeight(InputFlowBase): class InputFlowLockTimeBlockHeight(InputFlowBase):
def __init__(self, client: Client, block_height: str): def __init__(self, client: Client, block_height: str):
super().__init__(client) super().__init__(client)
@ -1003,7 +1089,7 @@ class InputFlowLockTimeBlockHeight(InputFlowBase):
yield from lock_time_input_flow_tr(self.debug, self.assert_func) yield from lock_time_input_flow_tr(self.debug, self.assert_func)
def input_flow_t3t1(self) -> BRGeneratorType: def input_flow_t3t1(self) -> BRGeneratorType:
yield from lock_time_input_flow_tt( yield from lock_time_input_flow_t3t1(
self.debug, self.assert_func, double_confirm=True self.debug, self.assert_func, double_confirm=True
) )
@ -1025,7 +1111,7 @@ class InputFlowLockTimeDatetime(InputFlowBase):
yield from lock_time_input_flow_tr(self.debug, self.assert_func) yield from lock_time_input_flow_tr(self.debug, self.assert_func)
def input_flow_t3t1(self) -> BRGeneratorType: def input_flow_t3t1(self) -> BRGeneratorType:
yield from lock_time_input_flow_tt(self.debug, self.assert_func) yield from lock_time_input_flow_t3t1(self.debug, self.assert_func)
class InputFlowEIP712ShowMore(InputFlowBase): class InputFlowEIP712ShowMore(InputFlowBase):
@ -2048,3 +2134,42 @@ class InputFlowResetSkipBackup(InputFlowBase):
yield # Confirm skip backup yield # Confirm skip backup
TR.assert_in(self.text_content(), "backup__want_to_skip") TR.assert_in(self.text_content(), "backup__want_to_skip")
self.debug.press_no() self.debug.press_no()
class InputFlowConfirmAllWarnings(InputFlowBase):
def __init__(self, client: Client):
super().__init__(client)
def input_flow_tt(self) -> BRGeneratorType:
br = yield
while True:
# wait for homescreen to go away
self.debug.wait_layout()
self.client.ui._default_input_flow(br)
br = yield
def input_flow_tr(self) -> BRGeneratorType:
return self.input_flow_tt()
def input_flow_t3t1(self) -> BRGeneratorType:
br = yield
while True:
# wait for homescreen to go away
# probably won't be needed after https://github.com/trezor/trezor-firmware/pull/3686
self.debug.wait_layout()
# Paginating (going as further as possible) and pressing Yes
if br.pages is not None:
for _ in range(br.pages - 1):
self.debug.swipe_up(wait=True)
layout = self.debug.read_layout()
text = layout.text_content().lower()
# hi priority warning
if ("wrong derivation path" in text) or ("to a multisig" in text):
self.debug.click(buttons.CORNER_BUTTON, wait=True)
self.debug.synchronize_at("VerticalMenu")
self.debug.click(buttons.VERTICAL_MENU[1])
elif "swipe up" in layout.footer().lower():
self.debug.swipe_up()
else:
self.debug.press_yes()
br = yield