diff --git a/core/src/apps/management/recovery_device/__init__.py b/core/src/apps/management/recovery_device/__init__.py index 1eef93181..515140b8c 100644 --- a/core/src/apps/management/recovery_device/__init__.py +++ b/core/src/apps/management/recovery_device/__init__.py @@ -22,6 +22,12 @@ if False: from trezor.messages.RecoveryDevice import RecoveryDevice +# List of RecoveryDevice fields that can be set when doing dry-run recovery. +# All except `dry_run` are allowed for T1 compatibility, but their values are ignored. +# If set, `enforce_wordlist` must be True, because we do not support non-enforcing. +DRY_RUN_ALLOWED_FIELDS = ("dry_run", "word_count", "enforce_wordlist", "type") + + async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success: """ Recover BIP39/SLIP39 seed into empty device. @@ -29,7 +35,7 @@ async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success: User starts the process here using the RecoveryDevice msg and then they can unplug the device anytime and continue without a computer. """ - _check_state(msg) + _validate(msg) if storage.recovery.is_in_progress(): return await recovery_process(ctx) @@ -43,27 +49,26 @@ async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success: await show_pin_invalid(ctx) raise wire.PinInvalid("PIN invalid") - # set up pin if requested - if msg.pin_protection: - if msg.dry_run: - raise wire.ProcessError("Can't setup PIN during dry_run recovery.") - newpin = await request_pin_confirm(ctx, allow_cancel=False) - config.change_pin(pin_to_int(""), pin_to_int(newpin), None, None) - - if msg.u2f_counter: - storage.device.set_u2f_counter(msg.u2f_counter) - storage.device.load_settings( - label=msg.label, use_passphrase=msg.passphrase_protection - ) + if not msg.dry_run: + # set up pin if requested + if msg.pin_protection: + newpin = await request_pin_confirm(ctx, allow_cancel=False) + config.change_pin(pin_to_int(""), pin_to_int(newpin), None, None) + + if msg.u2f_counter is not None: + storage.device.set_u2f_counter(msg.u2f_counter) + storage.device.load_settings( + label=msg.label, use_passphrase=msg.passphrase_protection + ) + storage.recovery.set_in_progress(True) - if msg.dry_run: - storage.recovery.set_dry_run(msg.dry_run) + storage.recovery.set_dry_run(bool(msg.dry_run)) workflow.replace_default(recovery_homescreen) return await recovery_process(ctx) -def _check_state(msg: RecoveryDevice) -> None: +def _validate(msg: RecoveryDevice) -> None: if not msg.dry_run and storage.is_initialized(): raise wire.UnexpectedMessage("Already initialized") if msg.dry_run and not storage.is_initialized(): @@ -74,6 +79,14 @@ def _check_state(msg: RecoveryDevice) -> None: "Value enforce_wordlist must be True, Trezor Core enforces words automatically." ) + if msg.dry_run: + # check that only allowed fields are set + for key, value in msg.__dict__.items(): + if key not in DRY_RUN_ALLOWED_FIELDS and value is not None: + raise wire.ProcessError( + "Forbidden field set in dry-run: {}".format(key) + ) + async def _continue_dialog(ctx: wire.Context, msg: RecoveryDevice) -> None: if not msg.dry_run: diff --git a/legacy/firmware/fsm_msg_common.h b/legacy/firmware/fsm_msg_common.h index ebe0116eb..d84b47abb 100644 --- a/legacy/firmware/fsm_msg_common.h +++ b/legacy/firmware/fsm_msg_common.h @@ -403,6 +403,12 @@ void fsm_msgRecoveryDevice(const RecoveryDevice *msg) { const bool dry_run = msg->has_dry_run ? msg->dry_run : false; if (!dry_run) { CHECK_NOT_INITIALIZED + } else { + CHECK_INITIALIZED + CHECK_PARAM(!msg->has_passphrase_protection && !msg->has_pin_protection && + !msg->has_language && !msg->has_label && + !msg->has_u2f_counter, + _("Forbidden field set in dry-run")) } CHECK_PARAM(!msg->has_word_count || msg->word_count == 12 || diff --git a/python/src/trezorlib/device.py b/python/src/trezorlib/device.py index 823d1c94b..51decd4e3 100644 --- a/python/src/trezorlib/device.py +++ b/python/src/trezorlib/device.py @@ -18,7 +18,7 @@ import os import time import warnings -from . import messages as proto +from . import messages from .exceptions import Cancelled from .tools import expect, session from .transport import enumerate_devices, get_transport @@ -44,7 +44,7 @@ class TrezorDevice: return get_transport(path, prefix_search=False) -@expect(proto.Success, field="message") +@expect(messages.Success, field="message") def apply_settings( client, label=None, @@ -55,7 +55,7 @@ def apply_settings( auto_lock_delay_ms=None, display_rotation=None, ): - settings = proto.ApplySettings() + settings = messages.ApplySettings() if label is not None: settings.label = label if language: @@ -76,30 +76,30 @@ def apply_settings( return out -@expect(proto.Success, field="message") +@expect(messages.Success, field="message") def apply_flags(client, flags): - out = client.call(proto.ApplyFlags(flags=flags)) + out = client.call(messages.ApplyFlags(flags=flags)) client.init_device() # Reload Features return out -@expect(proto.Success, field="message") +@expect(messages.Success, field="message") def change_pin(client, remove=False): - ret = client.call(proto.ChangePin(remove=remove)) + ret = client.call(messages.ChangePin(remove=remove)) client.init_device() # Re-read features return ret -@expect(proto.Success, field="message") +@expect(messages.Success, field="message") def sd_protect(client, operation): - ret = client.call(proto.SdProtect(operation=operation)) + ret = client.call(messages.SdProtect(operation=operation)) client.init_device() return ret -@expect(proto.Success, field="message") +@expect(messages.Success, field="message") def wipe(client): - ret = client.call(proto.WipeDevice()) + ret = client.call(messages.WipeDevice()) client.init_device() return ret @@ -112,7 +112,7 @@ def recover( label=None, language="english", input_callback=None, - type=proto.RecoveryDeviceType.ScrambledWords, + type=messages.RecoveryDeviceType.ScrambledWords, dry_run=False, u2f_counter=None, ): @@ -130,32 +130,32 @@ def recover( if u2f_counter is None: u2f_counter = int(time.time()) - res = client.call( - proto.RecoveryDevice( - word_count=word_count, - passphrase_protection=bool(passphrase_protection), - pin_protection=bool(pin_protection), - label=label, - language=language, - enforce_wordlist=True, - type=type, - dry_run=dry_run, - u2f_counter=u2f_counter, - ) + msg = messages.RecoveryDevice( + word_count=word_count, enforce_wordlist=True, type=type, dry_run=dry_run ) - while isinstance(res, proto.WordRequest): + if not dry_run: + # set additional parameters + msg.passphrase_protection = passphrase_protection + msg.pin_protection = pin_protection + msg.label = label + msg.language = language + msg.u2f_counter = u2f_counter + + res = client.call(msg) + + while isinstance(res, messages.WordRequest): try: inp = input_callback(res.type) - res = client.call(proto.WordAck(word=inp)) + res = client.call(messages.WordAck(word=inp)) except Cancelled: - res = client.call(proto.Cancel()) + res = client.call(messages.Cancel()) client.init_device() return res -@expect(proto.Success, field="message") +@expect(messages.Success, field="message") @session def reset( client, @@ -168,7 +168,7 @@ def reset( u2f_counter=0, skip_backup=False, no_backup=False, - backup_type=proto.BackupType.Bip39, + backup_type=messages.BackupType.Bip39, ): if client.features.initialized: raise RuntimeError( @@ -182,7 +182,7 @@ def reset( strength = 128 # Begin with device reset workflow - msg = proto.ResetDevice( + msg = messages.ResetDevice( display_random=bool(display_random), strength=strength, passphrase_protection=bool(passphrase_protection), @@ -196,17 +196,17 @@ def reset( ) resp = client.call(msg) - if not isinstance(resp, proto.EntropyRequest): + if not isinstance(resp, messages.EntropyRequest): raise RuntimeError("Invalid response, expected EntropyRequest") external_entropy = os.urandom(32) # LOG.debug("Computer generated entropy: " + external_entropy.hex()) - ret = client.call(proto.EntropyAck(entropy=external_entropy)) + ret = client.call(messages.EntropyAck(entropy=external_entropy)) client.init_device() return ret -@expect(proto.Success, field="message") +@expect(messages.Success, field="message") def backup(client): - ret = client.call(proto.BackupDevice()) + ret = client.call(messages.BackupDevice()) return ret diff --git a/tests/device_tests/test_msg_recoverydevice_bip39_dryrun.py b/tests/device_tests/test_msg_recoverydevice_bip39_dryrun.py index f651ac420..746b469b0 100644 --- a/tests/device_tests/test_msg_recoverydevice_bip39_dryrun.py +++ b/tests/device_tests/test_msg_recoverydevice_bip39_dryrun.py @@ -16,54 +16,190 @@ import pytest -from trezorlib import messages as proto +from trezorlib import device, exceptions, messages, protobuf +from .. import buttons from ..common import MNEMONIC12 +def do_recover_legacy(client, mnemonic, **kwargs): + def input_callback(_): + word, pos = client.debug.read_recovery_word() + if pos != 0: + word = mnemonic[pos - 1] + mnemonic[pos - 1] = None + assert word is not None + + return word + + ret = device.recover( + client, + dry_run=True, + word_count=len(mnemonic), + type=messages.RecoveryDeviceType.ScrambledWords, + input_callback=input_callback, + **kwargs + ) + # if the call succeeded, check that all words have been used + assert all(m is None for m in mnemonic) + return ret + + +def do_recover_core(client, mnemonic, **kwargs): + def input_flow(): + yield + layout = client.debug.wait_layout() + assert "check the recovery seed" in layout.text + client.debug.click(buttons.OK) + + yield + layout = client.debug.wait_layout() + assert "Select number of words" in layout.text + client.debug.click(buttons.OK) + + yield + layout = client.debug.wait_layout() + assert layout.text == "WordSelector" + # click the number + word_option_offset = 6 + word_options = (12, 18, 20, 24, 33) + index = word_option_offset + word_options.index(len(mnemonic)) + client.debug.click(buttons.grid34(index % 3, index // 3)) + + yield + layout = client.debug.wait_layout() + assert "Enter recovery seed" in layout.text + client.debug.click(buttons.OK) + + yield + for word in mnemonic: + client.debug.input(word) + + yield + client.debug.click(buttons.OK) + + with client: + client.set_input_flow(input_flow) + return device.recover(client, dry_run=True, **kwargs) + + +def do_recover(client, mnemonic): + if client.features.model == "1": + return do_recover_legacy(client, mnemonic) + else: + return do_recover_core(client, mnemonic) + + +@pytest.mark.setup_client(mnemonic=MNEMONIC12) +def test_dry_run(client): + ret = do_recover(client, MNEMONIC12.split(" ")) + assert isinstance(ret, messages.Success) + + +@pytest.mark.setup_client(mnemonic=MNEMONIC12) +def test_seed_mismatch(client): + with pytest.raises(exceptions.TrezorFailure) as exc: + do_recover(client, ["all"] * 12) + assert "does not match the one in the device" in exc.value.failure.message + + @pytest.mark.skip_t2 -class TestMsgRecoverydeviceDryrun: - def recovery_loop(self, client, mnemonic, result): - ret = client.call_raw( - proto.RecoveryDevice( - word_count=12, - passphrase_protection=False, - pin_protection=False, - label="label", - language="english", - enforce_wordlist=True, - dry_run=True, - ) - ) - - fakes = 0 - for _ in range(int(12 * 2)): - assert isinstance(ret, proto.WordRequest) - (word, pos) = client.debug.read_recovery_word() - - if pos != 0: - ret = client.call_raw(proto.WordAck(word=mnemonic[pos - 1])) - mnemonic[pos - 1] = None - else: - ret = client.call_raw(proto.WordAck(word=word)) - fakes += 1 - - print(mnemonic) - - assert isinstance(ret, proto.ButtonRequest) - client.debug.press_yes() - - ret = client.call_raw(proto.ButtonAck()) - assert isinstance(ret, result) - - def test_correct_notsame(self, client): - mnemonic = MNEMONIC12.split(" ") - self.recovery_loop(client, mnemonic, proto.Failure) - - def test_correct_same(self, client): - mnemonic = ["all"] * 12 - self.recovery_loop(client, mnemonic, proto.Success) - - def test_incorrect(self, client): - mnemonic = ["stick"] * 12 - self.recovery_loop(client, mnemonic, proto.Failure) +def test_invalid_seed_t1(client): + with pytest.raises(exceptions.TrezorFailure) as exc: + do_recover(client, ["stick"] * 12) + assert "Invalid seed" in exc.value.failure.message + + +@pytest.mark.skip_t1 +def test_invalid_seed_core(client): + def input_flow(): + yield + layout = client.debug.wait_layout() + assert "check the recovery seed" in layout.text + client.debug.click(buttons.OK) + + yield + layout = client.debug.wait_layout() + assert "Select number of words" in layout.text + client.debug.click(buttons.OK) + + yield + layout = client.debug.wait_layout() + assert layout.text == "WordSelector" + # select 12 words + client.debug.click(buttons.grid34(0, 2)) + + yield + layout = client.debug.wait_layout() + assert "Enter recovery seed" in layout.text + client.debug.click(buttons.OK) + + yield + for _ in range(12): + client.debug.input("stick") + + code = yield + layout = client.debug.wait_layout() + assert code == messages.ButtonRequestType.Warning + assert "invalid recovery seed" in layout.text + client.debug.click(buttons.OK) + + yield + # retry screen + layout = client.debug.wait_layout() + assert "Select number of words" in layout.text + client.debug.click(buttons.CANCEL) + + yield + layout = client.debug.wait_layout() + assert "abort" in layout.text + client.debug.click(buttons.OK) + + with client: + client.set_input_flow(input_flow) + with pytest.raises(exceptions.Cancelled): + return device.recover(client, dry_run=True) + + +@pytest.mark.setup_client(uninitialized=True) +def test_uninitialized(client): + with pytest.raises(exceptions.TrezorFailure) as exc: + do_recover(client, ["all"] * 12) + assert "not initialized" in exc.value.failure.message + + +DRY_RUN_ALLOWED_FIELDS = ("dry_run", "word_count", "enforce_wordlist", "type") + + +def _make_bad_params(): + """Generate a list of field names that must NOT be set on a dry-run message, + and default values of the appropriate type. + """ + for fname, ftype, _ in messages.RecoveryDevice.get_fields().values(): + if fname in DRY_RUN_ALLOWED_FIELDS: + continue + + if ftype is protobuf.UVarintType: + yield fname, 1 + elif ftype is protobuf.BoolType: + yield fname, True + elif ftype is protobuf.UnicodeType: + yield fname, "test" + else: + # Someone added a field to RecoveryDevice of a type that has no assigned + # default value. This test must be fixed. + raise RuntimeError("unknown field in RecoveryDevice") + + +@pytest.mark.parametrize("field_name, field_value", _make_bad_params()) +def test_bad_parameters(client, field_name, field_value): + msg = messages.RecoveryDevice( + dry_run=True, + word_count=12, + enforce_wordlist=True, + type=messages.RecoveryDeviceType.ScrambledWords, + ) + setattr(msg, field_name, field_value) + with pytest.raises(exceptions.TrezorFailure) as exc: + client.call(msg) + assert "Forbidden field set in dry-run" in exc.value.failure.message