all: disallow most RecoveryDevice fields in dry-run (fixes #666)

pull/723/head
matejcik 5 years ago committed by matejcik
parent b6d46e93e1
commit 34913a328a

@ -22,6 +22,12 @@ if False:
from trezor.messages.RecoveryDevice import RecoveryDevice from trezor.messages.RecoveryDevice import RecoveryDevice
# List of RecoveryDevice fields that can be set when doing dry-run recovery.
# All except `dry_run` are allowed for T1 compatibility, but their values are ignored.
# If set, `enforce_wordlist` must be True, because we do not support non-enforcing.
DRY_RUN_ALLOWED_FIELDS = ("dry_run", "word_count", "enforce_wordlist", "type")
async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success: async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
""" """
Recover BIP39/SLIP39 seed into empty device. Recover BIP39/SLIP39 seed into empty device.
@ -29,7 +35,7 @@ async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
User starts the process here using the RecoveryDevice msg and then they can unplug User starts the process here using the RecoveryDevice msg and then they can unplug
the device anytime and continue without a computer. the device anytime and continue without a computer.
""" """
_check_state(msg) _validate(msg)
if storage.recovery.is_in_progress(): if storage.recovery.is_in_progress():
return await recovery_process(ctx) return await recovery_process(ctx)
@ -43,27 +49,26 @@ async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
await show_pin_invalid(ctx) await show_pin_invalid(ctx)
raise wire.PinInvalid("PIN invalid") raise wire.PinInvalid("PIN invalid")
# set up pin if requested if not msg.dry_run:
if msg.pin_protection: # set up pin if requested
if msg.dry_run: if msg.pin_protection:
raise wire.ProcessError("Can't setup PIN during dry_run recovery.") newpin = await request_pin_confirm(ctx, allow_cancel=False)
newpin = await request_pin_confirm(ctx, allow_cancel=False) config.change_pin(pin_to_int(""), pin_to_int(newpin), None, None)
config.change_pin(pin_to_int(""), pin_to_int(newpin), None, None)
if msg.u2f_counter is not None:
if msg.u2f_counter: storage.device.set_u2f_counter(msg.u2f_counter)
storage.device.set_u2f_counter(msg.u2f_counter) storage.device.load_settings(
storage.device.load_settings( label=msg.label, use_passphrase=msg.passphrase_protection
label=msg.label, use_passphrase=msg.passphrase_protection )
)
storage.recovery.set_in_progress(True) storage.recovery.set_in_progress(True)
if msg.dry_run: storage.recovery.set_dry_run(bool(msg.dry_run))
storage.recovery.set_dry_run(msg.dry_run)
workflow.replace_default(recovery_homescreen) workflow.replace_default(recovery_homescreen)
return await recovery_process(ctx) return await recovery_process(ctx)
def _check_state(msg: RecoveryDevice) -> None: def _validate(msg: RecoveryDevice) -> None:
if not msg.dry_run and storage.is_initialized(): if not msg.dry_run and storage.is_initialized():
raise wire.UnexpectedMessage("Already initialized") raise wire.UnexpectedMessage("Already initialized")
if msg.dry_run and not storage.is_initialized(): if msg.dry_run and not storage.is_initialized():
@ -74,6 +79,14 @@ def _check_state(msg: RecoveryDevice) -> None:
"Value enforce_wordlist must be True, Trezor Core enforces words automatically." "Value enforce_wordlist must be True, Trezor Core enforces words automatically."
) )
if msg.dry_run:
# check that only allowed fields are set
for key, value in msg.__dict__.items():
if key not in DRY_RUN_ALLOWED_FIELDS and value is not None:
raise wire.ProcessError(
"Forbidden field set in dry-run: {}".format(key)
)
async def _continue_dialog(ctx: wire.Context, msg: RecoveryDevice) -> None: async def _continue_dialog(ctx: wire.Context, msg: RecoveryDevice) -> None:
if not msg.dry_run: if not msg.dry_run:

@ -403,6 +403,12 @@ void fsm_msgRecoveryDevice(const RecoveryDevice *msg) {
const bool dry_run = msg->has_dry_run ? msg->dry_run : false; const bool dry_run = msg->has_dry_run ? msg->dry_run : false;
if (!dry_run) { if (!dry_run) {
CHECK_NOT_INITIALIZED CHECK_NOT_INITIALIZED
} else {
CHECK_INITIALIZED
CHECK_PARAM(!msg->has_passphrase_protection && !msg->has_pin_protection &&
!msg->has_language && !msg->has_label &&
!msg->has_u2f_counter,
_("Forbidden field set in dry-run"))
} }
CHECK_PARAM(!msg->has_word_count || msg->word_count == 12 || CHECK_PARAM(!msg->has_word_count || msg->word_count == 12 ||

@ -18,7 +18,7 @@ import os
import time import time
import warnings import warnings
from . import messages as proto from . import messages
from .exceptions import Cancelled from .exceptions import Cancelled
from .tools import expect, session from .tools import expect, session
from .transport import enumerate_devices, get_transport from .transport import enumerate_devices, get_transport
@ -44,7 +44,7 @@ class TrezorDevice:
return get_transport(path, prefix_search=False) return get_transport(path, prefix_search=False)
@expect(proto.Success, field="message") @expect(messages.Success, field="message")
def apply_settings( def apply_settings(
client, client,
label=None, label=None,
@ -55,7 +55,7 @@ def apply_settings(
auto_lock_delay_ms=None, auto_lock_delay_ms=None,
display_rotation=None, display_rotation=None,
): ):
settings = proto.ApplySettings() settings = messages.ApplySettings()
if label is not None: if label is not None:
settings.label = label settings.label = label
if language: if language:
@ -76,30 +76,30 @@ def apply_settings(
return out return out
@expect(proto.Success, field="message") @expect(messages.Success, field="message")
def apply_flags(client, flags): def apply_flags(client, flags):
out = client.call(proto.ApplyFlags(flags=flags)) out = client.call(messages.ApplyFlags(flags=flags))
client.init_device() # Reload Features client.init_device() # Reload Features
return out return out
@expect(proto.Success, field="message") @expect(messages.Success, field="message")
def change_pin(client, remove=False): def change_pin(client, remove=False):
ret = client.call(proto.ChangePin(remove=remove)) ret = client.call(messages.ChangePin(remove=remove))
client.init_device() # Re-read features client.init_device() # Re-read features
return ret return ret
@expect(proto.Success, field="message") @expect(messages.Success, field="message")
def sd_protect(client, operation): def sd_protect(client, operation):
ret = client.call(proto.SdProtect(operation=operation)) ret = client.call(messages.SdProtect(operation=operation))
client.init_device() client.init_device()
return ret return ret
@expect(proto.Success, field="message") @expect(messages.Success, field="message")
def wipe(client): def wipe(client):
ret = client.call(proto.WipeDevice()) ret = client.call(messages.WipeDevice())
client.init_device() client.init_device()
return ret return ret
@ -112,7 +112,7 @@ def recover(
label=None, label=None,
language="english", language="english",
input_callback=None, input_callback=None,
type=proto.RecoveryDeviceType.ScrambledWords, type=messages.RecoveryDeviceType.ScrambledWords,
dry_run=False, dry_run=False,
u2f_counter=None, u2f_counter=None,
): ):
@ -130,32 +130,32 @@ def recover(
if u2f_counter is None: if u2f_counter is None:
u2f_counter = int(time.time()) u2f_counter = int(time.time())
res = client.call( msg = messages.RecoveryDevice(
proto.RecoveryDevice( word_count=word_count, enforce_wordlist=True, type=type, dry_run=dry_run
word_count=word_count,
passphrase_protection=bool(passphrase_protection),
pin_protection=bool(pin_protection),
label=label,
language=language,
enforce_wordlist=True,
type=type,
dry_run=dry_run,
u2f_counter=u2f_counter,
)
) )
while isinstance(res, proto.WordRequest): if not dry_run:
# set additional parameters
msg.passphrase_protection = passphrase_protection
msg.pin_protection = pin_protection
msg.label = label
msg.language = language
msg.u2f_counter = u2f_counter
res = client.call(msg)
while isinstance(res, messages.WordRequest):
try: try:
inp = input_callback(res.type) inp = input_callback(res.type)
res = client.call(proto.WordAck(word=inp)) res = client.call(messages.WordAck(word=inp))
except Cancelled: except Cancelled:
res = client.call(proto.Cancel()) res = client.call(messages.Cancel())
client.init_device() client.init_device()
return res return res
@expect(proto.Success, field="message") @expect(messages.Success, field="message")
@session @session
def reset( def reset(
client, client,
@ -168,7 +168,7 @@ def reset(
u2f_counter=0, u2f_counter=0,
skip_backup=False, skip_backup=False,
no_backup=False, no_backup=False,
backup_type=proto.BackupType.Bip39, backup_type=messages.BackupType.Bip39,
): ):
if client.features.initialized: if client.features.initialized:
raise RuntimeError( raise RuntimeError(
@ -182,7 +182,7 @@ def reset(
strength = 128 strength = 128
# Begin with device reset workflow # Begin with device reset workflow
msg = proto.ResetDevice( msg = messages.ResetDevice(
display_random=bool(display_random), display_random=bool(display_random),
strength=strength, strength=strength,
passphrase_protection=bool(passphrase_protection), passphrase_protection=bool(passphrase_protection),
@ -196,17 +196,17 @@ def reset(
) )
resp = client.call(msg) resp = client.call(msg)
if not isinstance(resp, proto.EntropyRequest): if not isinstance(resp, messages.EntropyRequest):
raise RuntimeError("Invalid response, expected EntropyRequest") raise RuntimeError("Invalid response, expected EntropyRequest")
external_entropy = os.urandom(32) external_entropy = os.urandom(32)
# LOG.debug("Computer generated entropy: " + external_entropy.hex()) # LOG.debug("Computer generated entropy: " + external_entropy.hex())
ret = client.call(proto.EntropyAck(entropy=external_entropy)) ret = client.call(messages.EntropyAck(entropy=external_entropy))
client.init_device() client.init_device()
return ret return ret
@expect(proto.Success, field="message") @expect(messages.Success, field="message")
def backup(client): def backup(client):
ret = client.call(proto.BackupDevice()) ret = client.call(messages.BackupDevice())
return ret return ret

@ -16,54 +16,190 @@
import pytest import pytest
from trezorlib import messages as proto from trezorlib import device, exceptions, messages, protobuf
from .. import buttons
from ..common import MNEMONIC12 from ..common import MNEMONIC12
def do_recover_legacy(client, mnemonic, **kwargs):
def input_callback(_):
word, pos = client.debug.read_recovery_word()
if pos != 0:
word = mnemonic[pos - 1]
mnemonic[pos - 1] = None
assert word is not None
return word
ret = device.recover(
client,
dry_run=True,
word_count=len(mnemonic),
type=messages.RecoveryDeviceType.ScrambledWords,
input_callback=input_callback,
**kwargs
)
# if the call succeeded, check that all words have been used
assert all(m is None for m in mnemonic)
return ret
def do_recover_core(client, mnemonic, **kwargs):
def input_flow():
yield
layout = client.debug.wait_layout()
assert "check the recovery seed" in layout.text
client.debug.click(buttons.OK)
yield
layout = client.debug.wait_layout()
assert "Select number of words" in layout.text
client.debug.click(buttons.OK)
yield
layout = client.debug.wait_layout()
assert layout.text == "WordSelector"
# click the number
word_option_offset = 6
word_options = (12, 18, 20, 24, 33)
index = word_option_offset + word_options.index(len(mnemonic))
client.debug.click(buttons.grid34(index % 3, index // 3))
yield
layout = client.debug.wait_layout()
assert "Enter recovery seed" in layout.text
client.debug.click(buttons.OK)
yield
for word in mnemonic:
client.debug.input(word)
yield
client.debug.click(buttons.OK)
with client:
client.set_input_flow(input_flow)
return device.recover(client, dry_run=True, **kwargs)
def do_recover(client, mnemonic):
if client.features.model == "1":
return do_recover_legacy(client, mnemonic)
else:
return do_recover_core(client, mnemonic)
@pytest.mark.setup_client(mnemonic=MNEMONIC12)
def test_dry_run(client):
ret = do_recover(client, MNEMONIC12.split(" "))
assert isinstance(ret, messages.Success)
@pytest.mark.setup_client(mnemonic=MNEMONIC12)
def test_seed_mismatch(client):
with pytest.raises(exceptions.TrezorFailure) as exc:
do_recover(client, ["all"] * 12)
assert "does not match the one in the device" in exc.value.failure.message
@pytest.mark.skip_t2 @pytest.mark.skip_t2
class TestMsgRecoverydeviceDryrun: def test_invalid_seed_t1(client):
def recovery_loop(self, client, mnemonic, result): with pytest.raises(exceptions.TrezorFailure) as exc:
ret = client.call_raw( do_recover(client, ["stick"] * 12)
proto.RecoveryDevice( assert "Invalid seed" in exc.value.failure.message
word_count=12,
passphrase_protection=False,
pin_protection=False, @pytest.mark.skip_t1
label="label", def test_invalid_seed_core(client):
language="english", def input_flow():
enforce_wordlist=True, yield
dry_run=True, layout = client.debug.wait_layout()
) assert "check the recovery seed" in layout.text
) client.debug.click(buttons.OK)
fakes = 0 yield
for _ in range(int(12 * 2)): layout = client.debug.wait_layout()
assert isinstance(ret, proto.WordRequest) assert "Select number of words" in layout.text
(word, pos) = client.debug.read_recovery_word() client.debug.click(buttons.OK)
if pos != 0: yield
ret = client.call_raw(proto.WordAck(word=mnemonic[pos - 1])) layout = client.debug.wait_layout()
mnemonic[pos - 1] = None assert layout.text == "WordSelector"
else: # select 12 words
ret = client.call_raw(proto.WordAck(word=word)) client.debug.click(buttons.grid34(0, 2))
fakes += 1
yield
print(mnemonic) layout = client.debug.wait_layout()
assert "Enter recovery seed" in layout.text
assert isinstance(ret, proto.ButtonRequest) client.debug.click(buttons.OK)
client.debug.press_yes()
yield
ret = client.call_raw(proto.ButtonAck()) for _ in range(12):
assert isinstance(ret, result) client.debug.input("stick")
def test_correct_notsame(self, client): code = yield
mnemonic = MNEMONIC12.split(" ") layout = client.debug.wait_layout()
self.recovery_loop(client, mnemonic, proto.Failure) assert code == messages.ButtonRequestType.Warning
assert "invalid recovery seed" in layout.text
def test_correct_same(self, client): client.debug.click(buttons.OK)
mnemonic = ["all"] * 12
self.recovery_loop(client, mnemonic, proto.Success) yield
# retry screen
def test_incorrect(self, client): layout = client.debug.wait_layout()
mnemonic = ["stick"] * 12 assert "Select number of words" in layout.text
self.recovery_loop(client, mnemonic, proto.Failure) client.debug.click(buttons.CANCEL)
yield
layout = client.debug.wait_layout()
assert "abort" in layout.text
client.debug.click(buttons.OK)
with client:
client.set_input_flow(input_flow)
with pytest.raises(exceptions.Cancelled):
return device.recover(client, dry_run=True)
@pytest.mark.setup_client(uninitialized=True)
def test_uninitialized(client):
with pytest.raises(exceptions.TrezorFailure) as exc:
do_recover(client, ["all"] * 12)
assert "not initialized" in exc.value.failure.message
DRY_RUN_ALLOWED_FIELDS = ("dry_run", "word_count", "enforce_wordlist", "type")
def _make_bad_params():
"""Generate a list of field names that must NOT be set on a dry-run message,
and default values of the appropriate type.
"""
for fname, ftype, _ in messages.RecoveryDevice.get_fields().values():
if fname in DRY_RUN_ALLOWED_FIELDS:
continue
if ftype is protobuf.UVarintType:
yield fname, 1
elif ftype is protobuf.BoolType:
yield fname, True
elif ftype is protobuf.UnicodeType:
yield fname, "test"
else:
# Someone added a field to RecoveryDevice of a type that has no assigned
# default value. This test must be fixed.
raise RuntimeError("unknown field in RecoveryDevice")
@pytest.mark.parametrize("field_name, field_value", _make_bad_params())
def test_bad_parameters(client, field_name, field_value):
msg = messages.RecoveryDevice(
dry_run=True,
word_count=12,
enforce_wordlist=True,
type=messages.RecoveryDeviceType.ScrambledWords,
)
setattr(msg, field_name, field_value)
with pytest.raises(exceptions.TrezorFailure) as exc:
client.call(msg)
assert "Forbidden field set in dry-run" in exc.value.failure.message

Loading…
Cancel
Save