diff --git a/tests/device_tests/bitcoin/test_multisig_change.py b/tests/device_tests/bitcoin/test_multisig_change.py index e94eb60fd3..5c9da6d64a 100644 --- a/tests/device_tests/bitcoin/test_multisig_change.py +++ b/tests/device_tests/bitcoin/test_multisig_change.py @@ -14,6 +14,8 @@ # You should have received a copy of the License along with this library. # If not, see . +from typing import Optional + import pytest from trezorlib import btc, messages @@ -143,30 +145,35 @@ def _responses( client: Client, INP1: messages.TxInputType, INP2: messages.TxInputType, - change: int = 0, - foreign: bool = False, + change_indices: Optional[list[int]] = None, + foreign_indices: Optional[list[int]] = None, ): + if change_indices is None: + change_indices = [] + if foreign_indices is None: + foreign_indices = [] + resp = [ request_input(0), request_input(1), request_output(0), ] - if change != 1: + if 1 in foreign_indices: + resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath)) + if 1 not in change_indices: resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) if is_core(client): resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) - elif foreign: - resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath)) resp.append(request_output(1)) - if change != 2: + if 2 in foreign_indices: + resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath)) + if 2 not in change_indices: resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) if is_core(client): resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) - elif foreign: - resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath)) resp += [ messages.ButtonRequest(code=B.SignTx), @@ -242,7 +249,7 @@ def test_external_internal(client: Client): with client: client.set_expected_responses( - _responses(client, INP1, INP2, change=2, foreign=True) + _responses(client, INP1, INP2, change_indices=[2], foreign_indices=[2]) ) if is_core(client): IF = InputFlowConfirmAllWarnings(client) @@ -278,7 +285,7 @@ def test_internal_external(client: Client): with client: client.set_expected_responses( - _responses(client, INP1, INP2, change=1, foreign=True) + _responses(client, INP1, INP2, change_indices=[1], foreign_indices=[1]) ) if is_core(client): IF = InputFlowConfirmAllWarnings(client) @@ -352,7 +359,9 @@ def test_multisig_change_match_first(client: Client): ) with client: - client.set_expected_responses(_responses(client, INP1, INP2, change=1)) + client.set_expected_responses( + _responses(client, INP1, INP2, change_indices=[1]) + ) _, serialized_tx = btc.sign_tx( client, "Bitcoin", @@ -391,7 +400,9 @@ def test_multisig_change_match_second(client: Client): ) with client: - client.set_expected_responses(_responses(client, INP1, INP2, change=2)) + client.set_expected_responses( + _responses(client, INP1, INP2, change_indices=[2]) + ) _, serialized_tx = btc.sign_tx( client, "Bitcoin", @@ -408,7 +419,7 @@ def test_multisig_change_match_second(client: Client): # inputs match, change mismatches (second tries to be change but isn't) -def test_multisig_mismatch_change(client: Client): +def test_multisig_mismatch_multisig_change(client: Client): multisig_out2 = messages.MultisigRedeemScriptType( nodes=[NODE_EXT1, NODE_INT, NODE_EXT3], address_n=[1, 0],