all: disallow most RecoveryDevice fields in dry-run (fixes #666)

pull/723/head
matejcik 5 years ago committed by matejcik
parent b6d46e93e1
commit 34913a328a

@ -22,6 +22,12 @@ if False:
from trezor.messages.RecoveryDevice import RecoveryDevice
# List of RecoveryDevice fields that can be set when doing dry-run recovery.
# All except `dry_run` are allowed for T1 compatibility, but their values are ignored.
# If set, `enforce_wordlist` must be True, because we do not support non-enforcing.
DRY_RUN_ALLOWED_FIELDS = ("dry_run", "word_count", "enforce_wordlist", "type")
async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
"""
Recover BIP39/SLIP39 seed into empty device.
@ -29,7 +35,7 @@ async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
User starts the process here using the RecoveryDevice msg and then they can unplug
the device anytime and continue without a computer.
"""
_check_state(msg)
_validate(msg)
if storage.recovery.is_in_progress():
return await recovery_process(ctx)
@ -43,27 +49,26 @@ async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
await show_pin_invalid(ctx)
raise wire.PinInvalid("PIN invalid")
# set up pin if requested
if msg.pin_protection:
if msg.dry_run:
raise wire.ProcessError("Can't setup PIN during dry_run recovery.")
newpin = await request_pin_confirm(ctx, allow_cancel=False)
config.change_pin(pin_to_int(""), pin_to_int(newpin), None, None)
if msg.u2f_counter:
storage.device.set_u2f_counter(msg.u2f_counter)
storage.device.load_settings(
label=msg.label, use_passphrase=msg.passphrase_protection
)
if not msg.dry_run:
# set up pin if requested
if msg.pin_protection:
newpin = await request_pin_confirm(ctx, allow_cancel=False)
config.change_pin(pin_to_int(""), pin_to_int(newpin), None, None)
if msg.u2f_counter is not None:
storage.device.set_u2f_counter(msg.u2f_counter)
storage.device.load_settings(
label=msg.label, use_passphrase=msg.passphrase_protection
)
storage.recovery.set_in_progress(True)
if msg.dry_run:
storage.recovery.set_dry_run(msg.dry_run)
storage.recovery.set_dry_run(bool(msg.dry_run))
workflow.replace_default(recovery_homescreen)
return await recovery_process(ctx)
def _check_state(msg: RecoveryDevice) -> None:
def _validate(msg: RecoveryDevice) -> None:
if not msg.dry_run and storage.is_initialized():
raise wire.UnexpectedMessage("Already initialized")
if msg.dry_run and not storage.is_initialized():
@ -74,6 +79,14 @@ def _check_state(msg: RecoveryDevice) -> None:
"Value enforce_wordlist must be True, Trezor Core enforces words automatically."
)
if msg.dry_run:
# check that only allowed fields are set
for key, value in msg.__dict__.items():
if key not in DRY_RUN_ALLOWED_FIELDS and value is not None:
raise wire.ProcessError(
"Forbidden field set in dry-run: {}".format(key)
)
async def _continue_dialog(ctx: wire.Context, msg: RecoveryDevice) -> None:
if not msg.dry_run:

@ -403,6 +403,12 @@ void fsm_msgRecoveryDevice(const RecoveryDevice *msg) {
const bool dry_run = msg->has_dry_run ? msg->dry_run : false;
if (!dry_run) {
CHECK_NOT_INITIALIZED
} else {
CHECK_INITIALIZED
CHECK_PARAM(!msg->has_passphrase_protection && !msg->has_pin_protection &&
!msg->has_language && !msg->has_label &&
!msg->has_u2f_counter,
_("Forbidden field set in dry-run"))
}
CHECK_PARAM(!msg->has_word_count || msg->word_count == 12 ||

@ -18,7 +18,7 @@ import os
import time
import warnings
from . import messages as proto
from . import messages
from .exceptions import Cancelled
from .tools import expect, session
from .transport import enumerate_devices, get_transport
@ -44,7 +44,7 @@ class TrezorDevice:
return get_transport(path, prefix_search=False)
@expect(proto.Success, field="message")
@expect(messages.Success, field="message")
def apply_settings(
client,
label=None,
@ -55,7 +55,7 @@ def apply_settings(
auto_lock_delay_ms=None,
display_rotation=None,
):
settings = proto.ApplySettings()
settings = messages.ApplySettings()
if label is not None:
settings.label = label
if language:
@ -76,30 +76,30 @@ def apply_settings(
return out
@expect(proto.Success, field="message")
@expect(messages.Success, field="message")
def apply_flags(client, flags):
out = client.call(proto.ApplyFlags(flags=flags))
out = client.call(messages.ApplyFlags(flags=flags))
client.init_device() # Reload Features
return out
@expect(proto.Success, field="message")
@expect(messages.Success, field="message")
def change_pin(client, remove=False):
ret = client.call(proto.ChangePin(remove=remove))
ret = client.call(messages.ChangePin(remove=remove))
client.init_device() # Re-read features
return ret
@expect(proto.Success, field="message")
@expect(messages.Success, field="message")
def sd_protect(client, operation):
ret = client.call(proto.SdProtect(operation=operation))
ret = client.call(messages.SdProtect(operation=operation))
client.init_device()
return ret
@expect(proto.Success, field="message")
@expect(messages.Success, field="message")
def wipe(client):
ret = client.call(proto.WipeDevice())
ret = client.call(messages.WipeDevice())
client.init_device()
return ret
@ -112,7 +112,7 @@ def recover(
label=None,
language="english",
input_callback=None,
type=proto.RecoveryDeviceType.ScrambledWords,
type=messages.RecoveryDeviceType.ScrambledWords,
dry_run=False,
u2f_counter=None,
):
@ -130,32 +130,32 @@ def recover(
if u2f_counter is None:
u2f_counter = int(time.time())
res = client.call(
proto.RecoveryDevice(
word_count=word_count,
passphrase_protection=bool(passphrase_protection),
pin_protection=bool(pin_protection),
label=label,
language=language,
enforce_wordlist=True,
type=type,
dry_run=dry_run,
u2f_counter=u2f_counter,
)
msg = messages.RecoveryDevice(
word_count=word_count, enforce_wordlist=True, type=type, dry_run=dry_run
)
while isinstance(res, proto.WordRequest):
if not dry_run:
# set additional parameters
msg.passphrase_protection = passphrase_protection
msg.pin_protection = pin_protection
msg.label = label
msg.language = language
msg.u2f_counter = u2f_counter
res = client.call(msg)
while isinstance(res, messages.WordRequest):
try:
inp = input_callback(res.type)
res = client.call(proto.WordAck(word=inp))
res = client.call(messages.WordAck(word=inp))
except Cancelled:
res = client.call(proto.Cancel())
res = client.call(messages.Cancel())
client.init_device()
return res
@expect(proto.Success, field="message")
@expect(messages.Success, field="message")
@session
def reset(
client,
@ -168,7 +168,7 @@ def reset(
u2f_counter=0,
skip_backup=False,
no_backup=False,
backup_type=proto.BackupType.Bip39,
backup_type=messages.BackupType.Bip39,
):
if client.features.initialized:
raise RuntimeError(
@ -182,7 +182,7 @@ def reset(
strength = 128
# Begin with device reset workflow
msg = proto.ResetDevice(
msg = messages.ResetDevice(
display_random=bool(display_random),
strength=strength,
passphrase_protection=bool(passphrase_protection),
@ -196,17 +196,17 @@ def reset(
)
resp = client.call(msg)
if not isinstance(resp, proto.EntropyRequest):
if not isinstance(resp, messages.EntropyRequest):
raise RuntimeError("Invalid response, expected EntropyRequest")
external_entropy = os.urandom(32)
# LOG.debug("Computer generated entropy: " + external_entropy.hex())
ret = client.call(proto.EntropyAck(entropy=external_entropy))
ret = client.call(messages.EntropyAck(entropy=external_entropy))
client.init_device()
return ret
@expect(proto.Success, field="message")
@expect(messages.Success, field="message")
def backup(client):
ret = client.call(proto.BackupDevice())
ret = client.call(messages.BackupDevice())
return ret

@ -16,54 +16,190 @@
import pytest
from trezorlib import messages as proto
from trezorlib import device, exceptions, messages, protobuf
from .. import buttons
from ..common import MNEMONIC12
def do_recover_legacy(client, mnemonic, **kwargs):
def input_callback(_):
word, pos = client.debug.read_recovery_word()
if pos != 0:
word = mnemonic[pos - 1]
mnemonic[pos - 1] = None
assert word is not None
return word
ret = device.recover(
client,
dry_run=True,
word_count=len(mnemonic),
type=messages.RecoveryDeviceType.ScrambledWords,
input_callback=input_callback,
**kwargs
)
# if the call succeeded, check that all words have been used
assert all(m is None for m in mnemonic)
return ret
def do_recover_core(client, mnemonic, **kwargs):
def input_flow():
yield
layout = client.debug.wait_layout()
assert "check the recovery seed" in layout.text
client.debug.click(buttons.OK)
yield
layout = client.debug.wait_layout()
assert "Select number of words" in layout.text
client.debug.click(buttons.OK)
yield
layout = client.debug.wait_layout()
assert layout.text == "WordSelector"
# click the number
word_option_offset = 6
word_options = (12, 18, 20, 24, 33)
index = word_option_offset + word_options.index(len(mnemonic))
client.debug.click(buttons.grid34(index % 3, index // 3))
yield
layout = client.debug.wait_layout()
assert "Enter recovery seed" in layout.text
client.debug.click(buttons.OK)
yield
for word in mnemonic:
client.debug.input(word)
yield
client.debug.click(buttons.OK)
with client:
client.set_input_flow(input_flow)
return device.recover(client, dry_run=True, **kwargs)
def do_recover(client, mnemonic):
if client.features.model == "1":
return do_recover_legacy(client, mnemonic)
else:
return do_recover_core(client, mnemonic)
@pytest.mark.setup_client(mnemonic=MNEMONIC12)
def test_dry_run(client):
ret = do_recover(client, MNEMONIC12.split(" "))
assert isinstance(ret, messages.Success)
@pytest.mark.setup_client(mnemonic=MNEMONIC12)
def test_seed_mismatch(client):
with pytest.raises(exceptions.TrezorFailure) as exc:
do_recover(client, ["all"] * 12)
assert "does not match the one in the device" in exc.value.failure.message
@pytest.mark.skip_t2
class TestMsgRecoverydeviceDryrun:
def recovery_loop(self, client, mnemonic, result):
ret = client.call_raw(
proto.RecoveryDevice(
word_count=12,
passphrase_protection=False,
pin_protection=False,
label="label",
language="english",
enforce_wordlist=True,
dry_run=True,
)
)
fakes = 0
for _ in range(int(12 * 2)):
assert isinstance(ret, proto.WordRequest)
(word, pos) = client.debug.read_recovery_word()
if pos != 0:
ret = client.call_raw(proto.WordAck(word=mnemonic[pos - 1]))
mnemonic[pos - 1] = None
else:
ret = client.call_raw(proto.WordAck(word=word))
fakes += 1
print(mnemonic)
assert isinstance(ret, proto.ButtonRequest)
client.debug.press_yes()
ret = client.call_raw(proto.ButtonAck())
assert isinstance(ret, result)
def test_correct_notsame(self, client):
mnemonic = MNEMONIC12.split(" ")
self.recovery_loop(client, mnemonic, proto.Failure)
def test_correct_same(self, client):
mnemonic = ["all"] * 12
self.recovery_loop(client, mnemonic, proto.Success)
def test_incorrect(self, client):
mnemonic = ["stick"] * 12
self.recovery_loop(client, mnemonic, proto.Failure)
def test_invalid_seed_t1(client):
with pytest.raises(exceptions.TrezorFailure) as exc:
do_recover(client, ["stick"] * 12)
assert "Invalid seed" in exc.value.failure.message
@pytest.mark.skip_t1
def test_invalid_seed_core(client):
def input_flow():
yield
layout = client.debug.wait_layout()
assert "check the recovery seed" in layout.text
client.debug.click(buttons.OK)
yield
layout = client.debug.wait_layout()
assert "Select number of words" in layout.text
client.debug.click(buttons.OK)
yield
layout = client.debug.wait_layout()
assert layout.text == "WordSelector"
# select 12 words
client.debug.click(buttons.grid34(0, 2))
yield
layout = client.debug.wait_layout()
assert "Enter recovery seed" in layout.text
client.debug.click(buttons.OK)
yield
for _ in range(12):
client.debug.input("stick")
code = yield
layout = client.debug.wait_layout()
assert code == messages.ButtonRequestType.Warning
assert "invalid recovery seed" in layout.text
client.debug.click(buttons.OK)
yield
# retry screen
layout = client.debug.wait_layout()
assert "Select number of words" in layout.text
client.debug.click(buttons.CANCEL)
yield
layout = client.debug.wait_layout()
assert "abort" in layout.text
client.debug.click(buttons.OK)
with client:
client.set_input_flow(input_flow)
with pytest.raises(exceptions.Cancelled):
return device.recover(client, dry_run=True)
@pytest.mark.setup_client(uninitialized=True)
def test_uninitialized(client):
with pytest.raises(exceptions.TrezorFailure) as exc:
do_recover(client, ["all"] * 12)
assert "not initialized" in exc.value.failure.message
DRY_RUN_ALLOWED_FIELDS = ("dry_run", "word_count", "enforce_wordlist", "type")
def _make_bad_params():
"""Generate a list of field names that must NOT be set on a dry-run message,
and default values of the appropriate type.
"""
for fname, ftype, _ in messages.RecoveryDevice.get_fields().values():
if fname in DRY_RUN_ALLOWED_FIELDS:
continue
if ftype is protobuf.UVarintType:
yield fname, 1
elif ftype is protobuf.BoolType:
yield fname, True
elif ftype is protobuf.UnicodeType:
yield fname, "test"
else:
# Someone added a field to RecoveryDevice of a type that has no assigned
# default value. This test must be fixed.
raise RuntimeError("unknown field in RecoveryDevice")
@pytest.mark.parametrize("field_name, field_value", _make_bad_params())
def test_bad_parameters(client, field_name, field_value):
msg = messages.RecoveryDevice(
dry_run=True,
word_count=12,
enforce_wordlist=True,
type=messages.RecoveryDeviceType.ScrambledWords,
)
setattr(msg, field_name, field_value)
with pytest.raises(exceptions.TrezorFailure) as exc:
client.call(msg)
assert "Forbidden field set in dry-run" in exc.value.failure.message

Loading…
Cancel
Save