1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-28 02:18:20 +00:00

refactor(test): refactor device test

This commit is contained in:
Ondřej Vejpustek 2024-11-12 14:32:26 +01:00
parent 78d522d650
commit f0ffebfc1d

View File

@ -14,6 +14,8 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import Optional
import pytest import pytest
from trezorlib import btc, messages from trezorlib import btc, messages
@ -143,30 +145,35 @@ def _responses(
client: Client, client: Client,
INP1: messages.TxInputType, INP1: messages.TxInputType,
INP2: messages.TxInputType, INP2: messages.TxInputType,
change: int = 0, change_indices: Optional[list[int]] = None,
foreign: bool = False, foreign_indices: Optional[list[int]] = None,
): ):
if change_indices is None:
change_indices = []
if foreign_indices is None:
foreign_indices = []
resp = [ resp = [
request_input(0), request_input(0),
request_input(1), request_input(1),
request_output(0), request_output(0),
] ]
if change != 1: if 1 in foreign_indices:
resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath))
if 1 not in change_indices:
resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) resp.append(messages.ButtonRequest(code=B.ConfirmOutput))
if is_core(client): if is_core(client):
resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) resp.append(messages.ButtonRequest(code=B.ConfirmOutput))
elif foreign:
resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath))
resp.append(request_output(1)) resp.append(request_output(1))
if change != 2: if 2 in foreign_indices:
resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath))
if 2 not in change_indices:
resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) resp.append(messages.ButtonRequest(code=B.ConfirmOutput))
if is_core(client): if is_core(client):
resp.append(messages.ButtonRequest(code=B.ConfirmOutput)) resp.append(messages.ButtonRequest(code=B.ConfirmOutput))
elif foreign:
resp.append(messages.ButtonRequest(code=B.UnknownDerivationPath))
resp += [ resp += [
messages.ButtonRequest(code=B.SignTx), messages.ButtonRequest(code=B.SignTx),
@ -242,7 +249,7 @@ def test_external_internal(client: Client):
with client: with client:
client.set_expected_responses( client.set_expected_responses(
_responses(client, INP1, INP2, change=2, foreign=True) _responses(client, INP1, INP2, change_indices=[2], foreign_indices=[2])
) )
if is_core(client): if is_core(client):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(client)
@ -278,7 +285,7 @@ def test_internal_external(client: Client):
with client: with client:
client.set_expected_responses( client.set_expected_responses(
_responses(client, INP1, INP2, change=1, foreign=True) _responses(client, INP1, INP2, change_indices=[1], foreign_indices=[1])
) )
if is_core(client): if is_core(client):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(client)
@ -352,7 +359,9 @@ def test_multisig_change_match_first(client: Client):
) )
with client: with client:
client.set_expected_responses(_responses(client, INP1, INP2, change=1)) client.set_expected_responses(
_responses(client, INP1, INP2, change_indices=[1])
)
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
client, client,
"Bitcoin", "Bitcoin",
@ -391,7 +400,9 @@ def test_multisig_change_match_second(client: Client):
) )
with client: with client:
client.set_expected_responses(_responses(client, INP1, INP2, change=2)) client.set_expected_responses(
_responses(client, INP1, INP2, change_indices=[2])
)
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
client, client,
"Bitcoin", "Bitcoin",
@ -408,7 +419,7 @@ def test_multisig_change_match_second(client: Client):
# inputs match, change mismatches (second tries to be change but isn't) # inputs match, change mismatches (second tries to be change but isn't)
def test_multisig_mismatch_change(client: Client): def test_multisig_mismatch_multisig_change(client: Client):
multisig_out2 = messages.MultisigRedeemScriptType( multisig_out2 = messages.MultisigRedeemScriptType(
nodes=[NODE_EXT1, NODE_INT, NODE_EXT3], nodes=[NODE_EXT1, NODE_INT, NODE_EXT3],
address_n=[1, 0], address_n=[1, 0],