mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-28 00:01:31 +00:00
all: disallow most RecoveryDevice fields in dry-run (fixes #666)
This commit is contained in:
parent
b6d46e93e1
commit
34913a328a
@ -22,6 +22,12 @@ if False:
|
||||
from trezor.messages.RecoveryDevice import RecoveryDevice
|
||||
|
||||
|
||||
# List of RecoveryDevice fields that can be set when doing dry-run recovery.
|
||||
# All except `dry_run` are allowed for T1 compatibility, but their values are ignored.
|
||||
# If set, `enforce_wordlist` must be True, because we do not support non-enforcing.
|
||||
DRY_RUN_ALLOWED_FIELDS = ("dry_run", "word_count", "enforce_wordlist", "type")
|
||||
|
||||
|
||||
async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
|
||||
"""
|
||||
Recover BIP39/SLIP39 seed into empty device.
|
||||
@ -29,7 +35,7 @@ async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
|
||||
User starts the process here using the RecoveryDevice msg and then they can unplug
|
||||
the device anytime and continue without a computer.
|
||||
"""
|
||||
_check_state(msg)
|
||||
_validate(msg)
|
||||
|
||||
if storage.recovery.is_in_progress():
|
||||
return await recovery_process(ctx)
|
||||
@ -43,27 +49,26 @@ async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
|
||||
await show_pin_invalid(ctx)
|
||||
raise wire.PinInvalid("PIN invalid")
|
||||
|
||||
if not msg.dry_run:
|
||||
# set up pin if requested
|
||||
if msg.pin_protection:
|
||||
if msg.dry_run:
|
||||
raise wire.ProcessError("Can't setup PIN during dry_run recovery.")
|
||||
newpin = await request_pin_confirm(ctx, allow_cancel=False)
|
||||
config.change_pin(pin_to_int(""), pin_to_int(newpin), None, None)
|
||||
|
||||
if msg.u2f_counter:
|
||||
if msg.u2f_counter is not None:
|
||||
storage.device.set_u2f_counter(msg.u2f_counter)
|
||||
storage.device.load_settings(
|
||||
label=msg.label, use_passphrase=msg.passphrase_protection
|
||||
)
|
||||
|
||||
storage.recovery.set_in_progress(True)
|
||||
if msg.dry_run:
|
||||
storage.recovery.set_dry_run(msg.dry_run)
|
||||
storage.recovery.set_dry_run(bool(msg.dry_run))
|
||||
|
||||
workflow.replace_default(recovery_homescreen)
|
||||
return await recovery_process(ctx)
|
||||
|
||||
|
||||
def _check_state(msg: RecoveryDevice) -> None:
|
||||
def _validate(msg: RecoveryDevice) -> None:
|
||||
if not msg.dry_run and storage.is_initialized():
|
||||
raise wire.UnexpectedMessage("Already initialized")
|
||||
if msg.dry_run and not storage.is_initialized():
|
||||
@ -74,6 +79,14 @@ def _check_state(msg: RecoveryDevice) -> None:
|
||||
"Value enforce_wordlist must be True, Trezor Core enforces words automatically."
|
||||
)
|
||||
|
||||
if msg.dry_run:
|
||||
# check that only allowed fields are set
|
||||
for key, value in msg.__dict__.items():
|
||||
if key not in DRY_RUN_ALLOWED_FIELDS and value is not None:
|
||||
raise wire.ProcessError(
|
||||
"Forbidden field set in dry-run: {}".format(key)
|
||||
)
|
||||
|
||||
|
||||
async def _continue_dialog(ctx: wire.Context, msg: RecoveryDevice) -> None:
|
||||
if not msg.dry_run:
|
||||
|
@ -403,6 +403,12 @@ void fsm_msgRecoveryDevice(const RecoveryDevice *msg) {
|
||||
const bool dry_run = msg->has_dry_run ? msg->dry_run : false;
|
||||
if (!dry_run) {
|
||||
CHECK_NOT_INITIALIZED
|
||||
} else {
|
||||
CHECK_INITIALIZED
|
||||
CHECK_PARAM(!msg->has_passphrase_protection && !msg->has_pin_protection &&
|
||||
!msg->has_language && !msg->has_label &&
|
||||
!msg->has_u2f_counter,
|
||||
_("Forbidden field set in dry-run"))
|
||||
}
|
||||
|
||||
CHECK_PARAM(!msg->has_word_count || msg->word_count == 12 ||
|
||||
|
@ -18,7 +18,7 @@ import os
|
||||
import time
|
||||
import warnings
|
||||
|
||||
from . import messages as proto
|
||||
from . import messages
|
||||
from .exceptions import Cancelled
|
||||
from .tools import expect, session
|
||||
from .transport import enumerate_devices, get_transport
|
||||
@ -44,7 +44,7 @@ class TrezorDevice:
|
||||
return get_transport(path, prefix_search=False)
|
||||
|
||||
|
||||
@expect(proto.Success, field="message")
|
||||
@expect(messages.Success, field="message")
|
||||
def apply_settings(
|
||||
client,
|
||||
label=None,
|
||||
@ -55,7 +55,7 @@ def apply_settings(
|
||||
auto_lock_delay_ms=None,
|
||||
display_rotation=None,
|
||||
):
|
||||
settings = proto.ApplySettings()
|
||||
settings = messages.ApplySettings()
|
||||
if label is not None:
|
||||
settings.label = label
|
||||
if language:
|
||||
@ -76,30 +76,30 @@ def apply_settings(
|
||||
return out
|
||||
|
||||
|
||||
@expect(proto.Success, field="message")
|
||||
@expect(messages.Success, field="message")
|
||||
def apply_flags(client, flags):
|
||||
out = client.call(proto.ApplyFlags(flags=flags))
|
||||
out = client.call(messages.ApplyFlags(flags=flags))
|
||||
client.init_device() # Reload Features
|
||||
return out
|
||||
|
||||
|
||||
@expect(proto.Success, field="message")
|
||||
@expect(messages.Success, field="message")
|
||||
def change_pin(client, remove=False):
|
||||
ret = client.call(proto.ChangePin(remove=remove))
|
||||
ret = client.call(messages.ChangePin(remove=remove))
|
||||
client.init_device() # Re-read features
|
||||
return ret
|
||||
|
||||
|
||||
@expect(proto.Success, field="message")
|
||||
@expect(messages.Success, field="message")
|
||||
def sd_protect(client, operation):
|
||||
ret = client.call(proto.SdProtect(operation=operation))
|
||||
ret = client.call(messages.SdProtect(operation=operation))
|
||||
client.init_device()
|
||||
return ret
|
||||
|
||||
|
||||
@expect(proto.Success, field="message")
|
||||
@expect(messages.Success, field="message")
|
||||
def wipe(client):
|
||||
ret = client.call(proto.WipeDevice())
|
||||
ret = client.call(messages.WipeDevice())
|
||||
client.init_device()
|
||||
return ret
|
||||
|
||||
@ -112,7 +112,7 @@ def recover(
|
||||
label=None,
|
||||
language="english",
|
||||
input_callback=None,
|
||||
type=proto.RecoveryDeviceType.ScrambledWords,
|
||||
type=messages.RecoveryDeviceType.ScrambledWords,
|
||||
dry_run=False,
|
||||
u2f_counter=None,
|
||||
):
|
||||
@ -130,32 +130,32 @@ def recover(
|
||||
if u2f_counter is None:
|
||||
u2f_counter = int(time.time())
|
||||
|
||||
res = client.call(
|
||||
proto.RecoveryDevice(
|
||||
word_count=word_count,
|
||||
passphrase_protection=bool(passphrase_protection),
|
||||
pin_protection=bool(pin_protection),
|
||||
label=label,
|
||||
language=language,
|
||||
enforce_wordlist=True,
|
||||
type=type,
|
||||
dry_run=dry_run,
|
||||
u2f_counter=u2f_counter,
|
||||
)
|
||||
msg = messages.RecoveryDevice(
|
||||
word_count=word_count, enforce_wordlist=True, type=type, dry_run=dry_run
|
||||
)
|
||||
|
||||
while isinstance(res, proto.WordRequest):
|
||||
if not dry_run:
|
||||
# set additional parameters
|
||||
msg.passphrase_protection = passphrase_protection
|
||||
msg.pin_protection = pin_protection
|
||||
msg.label = label
|
||||
msg.language = language
|
||||
msg.u2f_counter = u2f_counter
|
||||
|
||||
res = client.call(msg)
|
||||
|
||||
while isinstance(res, messages.WordRequest):
|
||||
try:
|
||||
inp = input_callback(res.type)
|
||||
res = client.call(proto.WordAck(word=inp))
|
||||
res = client.call(messages.WordAck(word=inp))
|
||||
except Cancelled:
|
||||
res = client.call(proto.Cancel())
|
||||
res = client.call(messages.Cancel())
|
||||
|
||||
client.init_device()
|
||||
return res
|
||||
|
||||
|
||||
@expect(proto.Success, field="message")
|
||||
@expect(messages.Success, field="message")
|
||||
@session
|
||||
def reset(
|
||||
client,
|
||||
@ -168,7 +168,7 @@ def reset(
|
||||
u2f_counter=0,
|
||||
skip_backup=False,
|
||||
no_backup=False,
|
||||
backup_type=proto.BackupType.Bip39,
|
||||
backup_type=messages.BackupType.Bip39,
|
||||
):
|
||||
if client.features.initialized:
|
||||
raise RuntimeError(
|
||||
@ -182,7 +182,7 @@ def reset(
|
||||
strength = 128
|
||||
|
||||
# Begin with device reset workflow
|
||||
msg = proto.ResetDevice(
|
||||
msg = messages.ResetDevice(
|
||||
display_random=bool(display_random),
|
||||
strength=strength,
|
||||
passphrase_protection=bool(passphrase_protection),
|
||||
@ -196,17 +196,17 @@ def reset(
|
||||
)
|
||||
|
||||
resp = client.call(msg)
|
||||
if not isinstance(resp, proto.EntropyRequest):
|
||||
if not isinstance(resp, messages.EntropyRequest):
|
||||
raise RuntimeError("Invalid response, expected EntropyRequest")
|
||||
|
||||
external_entropy = os.urandom(32)
|
||||
# LOG.debug("Computer generated entropy: " + external_entropy.hex())
|
||||
ret = client.call(proto.EntropyAck(entropy=external_entropy))
|
||||
ret = client.call(messages.EntropyAck(entropy=external_entropy))
|
||||
client.init_device()
|
||||
return ret
|
||||
|
||||
|
||||
@expect(proto.Success, field="message")
|
||||
@expect(messages.Success, field="message")
|
||||
def backup(client):
|
||||
ret = client.call(proto.BackupDevice())
|
||||
ret = client.call(messages.BackupDevice())
|
||||
return ret
|
||||
|
@ -16,54 +16,190 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from trezorlib import messages as proto
|
||||
from trezorlib import device, exceptions, messages, protobuf
|
||||
|
||||
from .. import buttons
|
||||
from ..common import MNEMONIC12
|
||||
|
||||
|
||||
@pytest.mark.skip_t2
|
||||
class TestMsgRecoverydeviceDryrun:
|
||||
def recovery_loop(self, client, mnemonic, result):
|
||||
ret = client.call_raw(
|
||||
proto.RecoveryDevice(
|
||||
word_count=12,
|
||||
passphrase_protection=False,
|
||||
pin_protection=False,
|
||||
label="label",
|
||||
language="english",
|
||||
enforce_wordlist=True,
|
||||
dry_run=True,
|
||||
)
|
||||
)
|
||||
|
||||
fakes = 0
|
||||
for _ in range(int(12 * 2)):
|
||||
assert isinstance(ret, proto.WordRequest)
|
||||
(word, pos) = client.debug.read_recovery_word()
|
||||
|
||||
def do_recover_legacy(client, mnemonic, **kwargs):
|
||||
def input_callback(_):
|
||||
word, pos = client.debug.read_recovery_word()
|
||||
if pos != 0:
|
||||
ret = client.call_raw(proto.WordAck(word=mnemonic[pos - 1]))
|
||||
word = mnemonic[pos - 1]
|
||||
mnemonic[pos - 1] = None
|
||||
assert word is not None
|
||||
|
||||
return word
|
||||
|
||||
ret = device.recover(
|
||||
client,
|
||||
dry_run=True,
|
||||
word_count=len(mnemonic),
|
||||
type=messages.RecoveryDeviceType.ScrambledWords,
|
||||
input_callback=input_callback,
|
||||
**kwargs
|
||||
)
|
||||
# if the call succeeded, check that all words have been used
|
||||
assert all(m is None for m in mnemonic)
|
||||
return ret
|
||||
|
||||
|
||||
def do_recover_core(client, mnemonic, **kwargs):
|
||||
def input_flow():
|
||||
yield
|
||||
layout = client.debug.wait_layout()
|
||||
assert "check the recovery seed" in layout.text
|
||||
client.debug.click(buttons.OK)
|
||||
|
||||
yield
|
||||
layout = client.debug.wait_layout()
|
||||
assert "Select number of words" in layout.text
|
||||
client.debug.click(buttons.OK)
|
||||
|
||||
yield
|
||||
layout = client.debug.wait_layout()
|
||||
assert layout.text == "WordSelector"
|
||||
# click the number
|
||||
word_option_offset = 6
|
||||
word_options = (12, 18, 20, 24, 33)
|
||||
index = word_option_offset + word_options.index(len(mnemonic))
|
||||
client.debug.click(buttons.grid34(index % 3, index // 3))
|
||||
|
||||
yield
|
||||
layout = client.debug.wait_layout()
|
||||
assert "Enter recovery seed" in layout.text
|
||||
client.debug.click(buttons.OK)
|
||||
|
||||
yield
|
||||
for word in mnemonic:
|
||||
client.debug.input(word)
|
||||
|
||||
yield
|
||||
client.debug.click(buttons.OK)
|
||||
|
||||
with client:
|
||||
client.set_input_flow(input_flow)
|
||||
return device.recover(client, dry_run=True, **kwargs)
|
||||
|
||||
|
||||
def do_recover(client, mnemonic):
|
||||
if client.features.model == "1":
|
||||
return do_recover_legacy(client, mnemonic)
|
||||
else:
|
||||
ret = client.call_raw(proto.WordAck(word=word))
|
||||
fakes += 1
|
||||
return do_recover_core(client, mnemonic)
|
||||
|
||||
print(mnemonic)
|
||||
|
||||
assert isinstance(ret, proto.ButtonRequest)
|
||||
client.debug.press_yes()
|
||||
@pytest.mark.setup_client(mnemonic=MNEMONIC12)
|
||||
def test_dry_run(client):
|
||||
ret = do_recover(client, MNEMONIC12.split(" "))
|
||||
assert isinstance(ret, messages.Success)
|
||||
|
||||
ret = client.call_raw(proto.ButtonAck())
|
||||
assert isinstance(ret, result)
|
||||
|
||||
def test_correct_notsame(self, client):
|
||||
mnemonic = MNEMONIC12.split(" ")
|
||||
self.recovery_loop(client, mnemonic, proto.Failure)
|
||||
@pytest.mark.setup_client(mnemonic=MNEMONIC12)
|
||||
def test_seed_mismatch(client):
|
||||
with pytest.raises(exceptions.TrezorFailure) as exc:
|
||||
do_recover(client, ["all"] * 12)
|
||||
assert "does not match the one in the device" in exc.value.failure.message
|
||||
|
||||
def test_correct_same(self, client):
|
||||
mnemonic = ["all"] * 12
|
||||
self.recovery_loop(client, mnemonic, proto.Success)
|
||||
|
||||
def test_incorrect(self, client):
|
||||
mnemonic = ["stick"] * 12
|
||||
self.recovery_loop(client, mnemonic, proto.Failure)
|
||||
@pytest.mark.skip_t2
|
||||
def test_invalid_seed_t1(client):
|
||||
with pytest.raises(exceptions.TrezorFailure) as exc:
|
||||
do_recover(client, ["stick"] * 12)
|
||||
assert "Invalid seed" in exc.value.failure.message
|
||||
|
||||
|
||||
@pytest.mark.skip_t1
|
||||
def test_invalid_seed_core(client):
|
||||
def input_flow():
|
||||
yield
|
||||
layout = client.debug.wait_layout()
|
||||
assert "check the recovery seed" in layout.text
|
||||
client.debug.click(buttons.OK)
|
||||
|
||||
yield
|
||||
layout = client.debug.wait_layout()
|
||||
assert "Select number of words" in layout.text
|
||||
client.debug.click(buttons.OK)
|
||||
|
||||
yield
|
||||
layout = client.debug.wait_layout()
|
||||
assert layout.text == "WordSelector"
|
||||
# select 12 words
|
||||
client.debug.click(buttons.grid34(0, 2))
|
||||
|
||||
yield
|
||||
layout = client.debug.wait_layout()
|
||||
assert "Enter recovery seed" in layout.text
|
||||
client.debug.click(buttons.OK)
|
||||
|
||||
yield
|
||||
for _ in range(12):
|
||||
client.debug.input("stick")
|
||||
|
||||
code = yield
|
||||
layout = client.debug.wait_layout()
|
||||
assert code == messages.ButtonRequestType.Warning
|
||||
assert "invalid recovery seed" in layout.text
|
||||
client.debug.click(buttons.OK)
|
||||
|
||||
yield
|
||||
# retry screen
|
||||
layout = client.debug.wait_layout()
|
||||
assert "Select number of words" in layout.text
|
||||
client.debug.click(buttons.CANCEL)
|
||||
|
||||
yield
|
||||
layout = client.debug.wait_layout()
|
||||
assert "abort" in layout.text
|
||||
client.debug.click(buttons.OK)
|
||||
|
||||
with client:
|
||||
client.set_input_flow(input_flow)
|
||||
with pytest.raises(exceptions.Cancelled):
|
||||
return device.recover(client, dry_run=True)
|
||||
|
||||
|
||||
@pytest.mark.setup_client(uninitialized=True)
|
||||
def test_uninitialized(client):
|
||||
with pytest.raises(exceptions.TrezorFailure) as exc:
|
||||
do_recover(client, ["all"] * 12)
|
||||
assert "not initialized" in exc.value.failure.message
|
||||
|
||||
|
||||
DRY_RUN_ALLOWED_FIELDS = ("dry_run", "word_count", "enforce_wordlist", "type")
|
||||
|
||||
|
||||
def _make_bad_params():
|
||||
"""Generate a list of field names that must NOT be set on a dry-run message,
|
||||
and default values of the appropriate type.
|
||||
"""
|
||||
for fname, ftype, _ in messages.RecoveryDevice.get_fields().values():
|
||||
if fname in DRY_RUN_ALLOWED_FIELDS:
|
||||
continue
|
||||
|
||||
if ftype is protobuf.UVarintType:
|
||||
yield fname, 1
|
||||
elif ftype is protobuf.BoolType:
|
||||
yield fname, True
|
||||
elif ftype is protobuf.UnicodeType:
|
||||
yield fname, "test"
|
||||
else:
|
||||
# Someone added a field to RecoveryDevice of a type that has no assigned
|
||||
# default value. This test must be fixed.
|
||||
raise RuntimeError("unknown field in RecoveryDevice")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("field_name, field_value", _make_bad_params())
|
||||
def test_bad_parameters(client, field_name, field_value):
|
||||
msg = messages.RecoveryDevice(
|
||||
dry_run=True,
|
||||
word_count=12,
|
||||
enforce_wordlist=True,
|
||||
type=messages.RecoveryDeviceType.ScrambledWords,
|
||||
)
|
||||
setattr(msg, field_name, field_value)
|
||||
with pytest.raises(exceptions.TrezorFailure) as exc:
|
||||
client.call(msg)
|
||||
assert "Forbidden field set in dry-run" in exc.value.failure.message
|
||||
|
Loading…
Reference in New Issue
Block a user