1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-14 17:31:04 +00:00

all: disallow most RecoveryDevice fields in dry-run (fixes #666)

This commit is contained in:
matejcik 2019-11-19 16:42:41 +01:00 committed by matejcik
parent b6d46e93e1
commit 34913a328a
4 changed files with 242 additions and 87 deletions

View File

@ -22,6 +22,12 @@ if False:
from trezor.messages.RecoveryDevice import RecoveryDevice from trezor.messages.RecoveryDevice import RecoveryDevice
# List of RecoveryDevice fields that can be set when doing dry-run recovery.
# All except `dry_run` are allowed for T1 compatibility, but their values are ignored.
# If set, `enforce_wordlist` must be True, because we do not support non-enforcing.
DRY_RUN_ALLOWED_FIELDS = ("dry_run", "word_count", "enforce_wordlist", "type")
async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success: async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
""" """
Recover BIP39/SLIP39 seed into empty device. Recover BIP39/SLIP39 seed into empty device.
@ -29,7 +35,7 @@ async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
User starts the process here using the RecoveryDevice msg and then they can unplug User starts the process here using the RecoveryDevice msg and then they can unplug
the device anytime and continue without a computer. the device anytime and continue without a computer.
""" """
_check_state(msg) _validate(msg)
if storage.recovery.is_in_progress(): if storage.recovery.is_in_progress():
return await recovery_process(ctx) return await recovery_process(ctx)
@ -43,27 +49,26 @@ async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
await show_pin_invalid(ctx) await show_pin_invalid(ctx)
raise wire.PinInvalid("PIN invalid") raise wire.PinInvalid("PIN invalid")
if not msg.dry_run:
# set up pin if requested # set up pin if requested
if msg.pin_protection: if msg.pin_protection:
if msg.dry_run:
raise wire.ProcessError("Can't setup PIN during dry_run recovery.")
newpin = await request_pin_confirm(ctx, allow_cancel=False) newpin = await request_pin_confirm(ctx, allow_cancel=False)
config.change_pin(pin_to_int(""), pin_to_int(newpin), None, None) config.change_pin(pin_to_int(""), pin_to_int(newpin), None, None)
if msg.u2f_counter: if msg.u2f_counter is not None:
storage.device.set_u2f_counter(msg.u2f_counter) storage.device.set_u2f_counter(msg.u2f_counter)
storage.device.load_settings( storage.device.load_settings(
label=msg.label, use_passphrase=msg.passphrase_protection label=msg.label, use_passphrase=msg.passphrase_protection
) )
storage.recovery.set_in_progress(True) storage.recovery.set_in_progress(True)
if msg.dry_run: storage.recovery.set_dry_run(bool(msg.dry_run))
storage.recovery.set_dry_run(msg.dry_run)
workflow.replace_default(recovery_homescreen) workflow.replace_default(recovery_homescreen)
return await recovery_process(ctx) return await recovery_process(ctx)
def _check_state(msg: RecoveryDevice) -> None: def _validate(msg: RecoveryDevice) -> None:
if not msg.dry_run and storage.is_initialized(): if not msg.dry_run and storage.is_initialized():
raise wire.UnexpectedMessage("Already initialized") raise wire.UnexpectedMessage("Already initialized")
if msg.dry_run and not storage.is_initialized(): if msg.dry_run and not storage.is_initialized():
@ -74,6 +79,14 @@ def _check_state(msg: RecoveryDevice) -> None:
"Value enforce_wordlist must be True, Trezor Core enforces words automatically." "Value enforce_wordlist must be True, Trezor Core enforces words automatically."
) )
if msg.dry_run:
# check that only allowed fields are set
for key, value in msg.__dict__.items():
if key not in DRY_RUN_ALLOWED_FIELDS and value is not None:
raise wire.ProcessError(
"Forbidden field set in dry-run: {}".format(key)
)
async def _continue_dialog(ctx: wire.Context, msg: RecoveryDevice) -> None: async def _continue_dialog(ctx: wire.Context, msg: RecoveryDevice) -> None:
if not msg.dry_run: if not msg.dry_run:

View File

@ -403,6 +403,12 @@ void fsm_msgRecoveryDevice(const RecoveryDevice *msg) {
const bool dry_run = msg->has_dry_run ? msg->dry_run : false; const bool dry_run = msg->has_dry_run ? msg->dry_run : false;
if (!dry_run) { if (!dry_run) {
CHECK_NOT_INITIALIZED CHECK_NOT_INITIALIZED
} else {
CHECK_INITIALIZED
CHECK_PARAM(!msg->has_passphrase_protection && !msg->has_pin_protection &&
!msg->has_language && !msg->has_label &&
!msg->has_u2f_counter,
_("Forbidden field set in dry-run"))
} }
CHECK_PARAM(!msg->has_word_count || msg->word_count == 12 || CHECK_PARAM(!msg->has_word_count || msg->word_count == 12 ||

View File

@ -18,7 +18,7 @@ import os
import time import time
import warnings import warnings
from . import messages as proto from . import messages
from .exceptions import Cancelled from .exceptions import Cancelled
from .tools import expect, session from .tools import expect, session
from .transport import enumerate_devices, get_transport from .transport import enumerate_devices, get_transport
@ -44,7 +44,7 @@ class TrezorDevice:
return get_transport(path, prefix_search=False) return get_transport(path, prefix_search=False)
@expect(proto.Success, field="message") @expect(messages.Success, field="message")
def apply_settings( def apply_settings(
client, client,
label=None, label=None,
@ -55,7 +55,7 @@ def apply_settings(
auto_lock_delay_ms=None, auto_lock_delay_ms=None,
display_rotation=None, display_rotation=None,
): ):
settings = proto.ApplySettings() settings = messages.ApplySettings()
if label is not None: if label is not None:
settings.label = label settings.label = label
if language: if language:
@ -76,30 +76,30 @@ def apply_settings(
return out return out
@expect(proto.Success, field="message") @expect(messages.Success, field="message")
def apply_flags(client, flags): def apply_flags(client, flags):
out = client.call(proto.ApplyFlags(flags=flags)) out = client.call(messages.ApplyFlags(flags=flags))
client.init_device() # Reload Features client.init_device() # Reload Features
return out return out
@expect(proto.Success, field="message") @expect(messages.Success, field="message")
def change_pin(client, remove=False): def change_pin(client, remove=False):
ret = client.call(proto.ChangePin(remove=remove)) ret = client.call(messages.ChangePin(remove=remove))
client.init_device() # Re-read features client.init_device() # Re-read features
return ret return ret
@expect(proto.Success, field="message") @expect(messages.Success, field="message")
def sd_protect(client, operation): def sd_protect(client, operation):
ret = client.call(proto.SdProtect(operation=operation)) ret = client.call(messages.SdProtect(operation=operation))
client.init_device() client.init_device()
return ret return ret
@expect(proto.Success, field="message") @expect(messages.Success, field="message")
def wipe(client): def wipe(client):
ret = client.call(proto.WipeDevice()) ret = client.call(messages.WipeDevice())
client.init_device() client.init_device()
return ret return ret
@ -112,7 +112,7 @@ def recover(
label=None, label=None,
language="english", language="english",
input_callback=None, input_callback=None,
type=proto.RecoveryDeviceType.ScrambledWords, type=messages.RecoveryDeviceType.ScrambledWords,
dry_run=False, dry_run=False,
u2f_counter=None, u2f_counter=None,
): ):
@ -130,32 +130,32 @@ def recover(
if u2f_counter is None: if u2f_counter is None:
u2f_counter = int(time.time()) u2f_counter = int(time.time())
res = client.call( msg = messages.RecoveryDevice(
proto.RecoveryDevice( word_count=word_count, enforce_wordlist=True, type=type, dry_run=dry_run
word_count=word_count,
passphrase_protection=bool(passphrase_protection),
pin_protection=bool(pin_protection),
label=label,
language=language,
enforce_wordlist=True,
type=type,
dry_run=dry_run,
u2f_counter=u2f_counter,
)
) )
while isinstance(res, proto.WordRequest): if not dry_run:
# set additional parameters
msg.passphrase_protection = passphrase_protection
msg.pin_protection = pin_protection
msg.label = label
msg.language = language
msg.u2f_counter = u2f_counter
res = client.call(msg)
while isinstance(res, messages.WordRequest):
try: try:
inp = input_callback(res.type) inp = input_callback(res.type)
res = client.call(proto.WordAck(word=inp)) res = client.call(messages.WordAck(word=inp))
except Cancelled: except Cancelled:
res = client.call(proto.Cancel()) res = client.call(messages.Cancel())
client.init_device() client.init_device()
return res return res
@expect(proto.Success, field="message") @expect(messages.Success, field="message")
@session @session
def reset( def reset(
client, client,
@ -168,7 +168,7 @@ def reset(
u2f_counter=0, u2f_counter=0,
skip_backup=False, skip_backup=False,
no_backup=False, no_backup=False,
backup_type=proto.BackupType.Bip39, backup_type=messages.BackupType.Bip39,
): ):
if client.features.initialized: if client.features.initialized:
raise RuntimeError( raise RuntimeError(
@ -182,7 +182,7 @@ def reset(
strength = 128 strength = 128
# Begin with device reset workflow # Begin with device reset workflow
msg = proto.ResetDevice( msg = messages.ResetDevice(
display_random=bool(display_random), display_random=bool(display_random),
strength=strength, strength=strength,
passphrase_protection=bool(passphrase_protection), passphrase_protection=bool(passphrase_protection),
@ -196,17 +196,17 @@ def reset(
) )
resp = client.call(msg) resp = client.call(msg)
if not isinstance(resp, proto.EntropyRequest): if not isinstance(resp, messages.EntropyRequest):
raise RuntimeError("Invalid response, expected EntropyRequest") raise RuntimeError("Invalid response, expected EntropyRequest")
external_entropy = os.urandom(32) external_entropy = os.urandom(32)
# LOG.debug("Computer generated entropy: " + external_entropy.hex()) # LOG.debug("Computer generated entropy: " + external_entropy.hex())
ret = client.call(proto.EntropyAck(entropy=external_entropy)) ret = client.call(messages.EntropyAck(entropy=external_entropy))
client.init_device() client.init_device()
return ret return ret
@expect(proto.Success, field="message") @expect(messages.Success, field="message")
def backup(client): def backup(client):
ret = client.call(proto.BackupDevice()) ret = client.call(messages.BackupDevice())
return ret return ret

View File

@ -16,54 +16,190 @@
import pytest import pytest
from trezorlib import messages as proto from trezorlib import device, exceptions, messages, protobuf
from .. import buttons
from ..common import MNEMONIC12 from ..common import MNEMONIC12
@pytest.mark.skip_t2 def do_recover_legacy(client, mnemonic, **kwargs):
class TestMsgRecoverydeviceDryrun: def input_callback(_):
def recovery_loop(self, client, mnemonic, result): word, pos = client.debug.read_recovery_word()
ret = client.call_raw(
proto.RecoveryDevice(
word_count=12,
passphrase_protection=False,
pin_protection=False,
label="label",
language="english",
enforce_wordlist=True,
dry_run=True,
)
)
fakes = 0
for _ in range(int(12 * 2)):
assert isinstance(ret, proto.WordRequest)
(word, pos) = client.debug.read_recovery_word()
if pos != 0: if pos != 0:
ret = client.call_raw(proto.WordAck(word=mnemonic[pos - 1])) word = mnemonic[pos - 1]
mnemonic[pos - 1] = None mnemonic[pos - 1] = None
assert word is not None
return word
ret = device.recover(
client,
dry_run=True,
word_count=len(mnemonic),
type=messages.RecoveryDeviceType.ScrambledWords,
input_callback=input_callback,
**kwargs
)
# if the call succeeded, check that all words have been used
assert all(m is None for m in mnemonic)
return ret
def do_recover_core(client, mnemonic, **kwargs):
def input_flow():
yield
layout = client.debug.wait_layout()
assert "check the recovery seed" in layout.text
client.debug.click(buttons.OK)
yield
layout = client.debug.wait_layout()
assert "Select number of words" in layout.text
client.debug.click(buttons.OK)
yield
layout = client.debug.wait_layout()
assert layout.text == "WordSelector"
# click the number
word_option_offset = 6
word_options = (12, 18, 20, 24, 33)
index = word_option_offset + word_options.index(len(mnemonic))
client.debug.click(buttons.grid34(index % 3, index // 3))
yield
layout = client.debug.wait_layout()
assert "Enter recovery seed" in layout.text
client.debug.click(buttons.OK)
yield
for word in mnemonic:
client.debug.input(word)
yield
client.debug.click(buttons.OK)
with client:
client.set_input_flow(input_flow)
return device.recover(client, dry_run=True, **kwargs)
def do_recover(client, mnemonic):
if client.features.model == "1":
return do_recover_legacy(client, mnemonic)
else: else:
ret = client.call_raw(proto.WordAck(word=word)) return do_recover_core(client, mnemonic)
fakes += 1
print(mnemonic)
assert isinstance(ret, proto.ButtonRequest) @pytest.mark.setup_client(mnemonic=MNEMONIC12)
client.debug.press_yes() def test_dry_run(client):
ret = do_recover(client, MNEMONIC12.split(" "))
assert isinstance(ret, messages.Success)
ret = client.call_raw(proto.ButtonAck())
assert isinstance(ret, result)
def test_correct_notsame(self, client): @pytest.mark.setup_client(mnemonic=MNEMONIC12)
mnemonic = MNEMONIC12.split(" ") def test_seed_mismatch(client):
self.recovery_loop(client, mnemonic, proto.Failure) with pytest.raises(exceptions.TrezorFailure) as exc:
do_recover(client, ["all"] * 12)
assert "does not match the one in the device" in exc.value.failure.message
def test_correct_same(self, client):
mnemonic = ["all"] * 12
self.recovery_loop(client, mnemonic, proto.Success)
def test_incorrect(self, client): @pytest.mark.skip_t2
mnemonic = ["stick"] * 12 def test_invalid_seed_t1(client):
self.recovery_loop(client, mnemonic, proto.Failure) with pytest.raises(exceptions.TrezorFailure) as exc:
do_recover(client, ["stick"] * 12)
assert "Invalid seed" in exc.value.failure.message
@pytest.mark.skip_t1
def test_invalid_seed_core(client):
def input_flow():
yield
layout = client.debug.wait_layout()
assert "check the recovery seed" in layout.text
client.debug.click(buttons.OK)
yield
layout = client.debug.wait_layout()
assert "Select number of words" in layout.text
client.debug.click(buttons.OK)
yield
layout = client.debug.wait_layout()
assert layout.text == "WordSelector"
# select 12 words
client.debug.click(buttons.grid34(0, 2))
yield
layout = client.debug.wait_layout()
assert "Enter recovery seed" in layout.text
client.debug.click(buttons.OK)
yield
for _ in range(12):
client.debug.input("stick")
code = yield
layout = client.debug.wait_layout()
assert code == messages.ButtonRequestType.Warning
assert "invalid recovery seed" in layout.text
client.debug.click(buttons.OK)
yield
# retry screen
layout = client.debug.wait_layout()
assert "Select number of words" in layout.text
client.debug.click(buttons.CANCEL)
yield
layout = client.debug.wait_layout()
assert "abort" in layout.text
client.debug.click(buttons.OK)
with client:
client.set_input_flow(input_flow)
with pytest.raises(exceptions.Cancelled):
return device.recover(client, dry_run=True)
@pytest.mark.setup_client(uninitialized=True)
def test_uninitialized(client):
with pytest.raises(exceptions.TrezorFailure) as exc:
do_recover(client, ["all"] * 12)
assert "not initialized" in exc.value.failure.message
DRY_RUN_ALLOWED_FIELDS = ("dry_run", "word_count", "enforce_wordlist", "type")
def _make_bad_params():
"""Generate a list of field names that must NOT be set on a dry-run message,
and default values of the appropriate type.
"""
for fname, ftype, _ in messages.RecoveryDevice.get_fields().values():
if fname in DRY_RUN_ALLOWED_FIELDS:
continue
if ftype is protobuf.UVarintType:
yield fname, 1
elif ftype is protobuf.BoolType:
yield fname, True
elif ftype is protobuf.UnicodeType:
yield fname, "test"
else:
# Someone added a field to RecoveryDevice of a type that has no assigned
# default value. This test must be fixed.
raise RuntimeError("unknown field in RecoveryDevice")
@pytest.mark.parametrize("field_name, field_value", _make_bad_params())
def test_bad_parameters(client, field_name, field_value):
msg = messages.RecoveryDevice(
dry_run=True,
word_count=12,
enforce_wordlist=True,
type=messages.RecoveryDeviceType.ScrambledWords,
)
setattr(msg, field_name, field_value)
with pytest.raises(exceptions.TrezorFailure) as exc:
client.call(msg)
assert "Forbidden field set in dry-run" in exc.value.failure.message