fixup! feat(core): add ability to request backups with any number of groups/shares.

Ioan Bizău 3 weeks ago
parent a279d9bcbc
commit 39e10c48c0

@ -1,9 +1,14 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor.enums import BackupType
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import BackupDevice, Success from trezor.messages import BackupDevice, Success
BAK_T_BIP39 = BackupType.Bip39 # global_import_cache
async def backup_device(msg: BackupDevice) -> Success: async def backup_device(msg: BackupDevice) -> Success:
import storage.device as storage_device import storage.device as storage_device
from trezor import wire from trezor import wire
@ -11,7 +16,7 @@ async def backup_device(msg: BackupDevice) -> Success:
from apps.common import mnemonic from apps.common import mnemonic
from .reset_device import backup_seed, layout from .reset_device import backup_seed, backup_slip39_custom, layout
if not storage_device.is_initialized(): if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
@ -22,15 +27,26 @@ async def backup_device(msg: BackupDevice) -> Success:
if mnemonic_secret is None: if mnemonic_secret is None:
raise RuntimeError raise RuntimeError
group_threshold = msg.group_threshold
groups = [(g.member_threshold, g.member_count) for g in msg.groups]
if group_threshold is not None:
if group_threshold < 1:
raise wire.DataError("group_threshold must be a positive integer")
if len(groups) < group_threshold:
raise wire.DataError("Not enough groups provided for group_threshold")
if backup_type == BAK_T_BIP39:
raise wire.ProcessError("Expected SLIP39 backup")
elif len(groups) > 0:
raise wire.DataError("group_threshold is missing")
storage_device.set_unfinished_backup(True) storage_device.set_unfinished_backup(True)
storage_device.set_backed_up() storage_device.set_backed_up()
await backup_seed( if group_threshold is not None:
backup_type, await backup_slip39_custom(mnemonic_secret, group_threshold, groups)
mnemonic_secret, else:
msg.group_threshold, await backup_seed(backup_type, mnemonic_secret)
[(g.member_threshold, g.member_count) for g in msg.groups],
)
storage_device.set_unfinished_backup(False) storage_device.set_unfinished_backup(False)

@ -124,7 +124,7 @@ async def _backup_slip39_basic(encrypted_master_secret: bytes) -> None:
share_threshold = await layout.slip39_prompt_threshold(share_count) share_threshold = await layout.slip39_prompt_threshold(share_count)
mnemonics = _get_slip39_mnemonics( mnemonics = _get_slip39_mnemonics(
encrypted_master_secret, group_threshold, [(share_threshold, share_count)] encrypted_master_secret, group_threshold, ((share_threshold, share_count),)
) )
# show and confirm individual shares # show and confirm individual shares
@ -155,11 +155,11 @@ async def _backup_slip39_advanced(encrypted_master_secret: bytes) -> None:
await layout.slip39_advanced_show_and_confirm_shares(mnemonics) await layout.slip39_advanced_show_and_confirm_shares(mnemonics)
async def _backup_slip39_custom( async def backup_slip39_custom(
encrypted_master_secret: bytes, encrypted_master_secret: bytes,
group_threshold: int, group_threshold: int,
groups: Sequence[tuple[int, int]], groups: Sequence[tuple[int, int]],
): ) -> None:
mnemonics = _get_slip39_mnemonics(encrypted_master_secret, group_threshold, groups) mnemonics = _get_slip39_mnemonics(encrypted_master_secret, group_threshold, groups)
# show and confirm individual shares # show and confirm individual shares
@ -241,20 +241,8 @@ def _compute_secret_from_entropy(
return secret return secret
async def backup_seed( async def backup_seed(backup_type: BackupType, mnemonic_secret: bytes) -> None:
backup_type: BackupType, if backup_type == BAK_T_SLIP39_BASIC:
mnemonic_secret: bytes,
group_threshold: int | None = None,
groups: Sequence[tuple[int, int]] = (),
) -> None:
# Either both should be defined or both should be missing: group_threshold, groups
assert (group_threshold is None) == (len(groups) == 0)
assert backup_type != BAK_T_BIP39 or group_threshold is None
if group_threshold is not None:
await _backup_slip39_custom(mnemonic_secret, group_threshold, groups)
elif backup_type == BAK_T_SLIP39_BASIC:
await _backup_slip39_basic(mnemonic_secret) await _backup_slip39_basic(mnemonic_secret)
elif backup_type == BAK_T_SLIP39_ADVANCED: elif backup_type == BAK_T_SLIP39_ADVANCED:
await _backup_slip39_advanced(mnemonic_secret) await _backup_slip39_advanced(mnemonic_secret)

@ -34,11 +34,13 @@ pytestmark = [pytest.mark.skip_t1b1]
"group_threshold, share_threshold, share_count", "group_threshold, share_threshold, share_count",
[ [
pytest.param(1, 1, 1, id="1of1"), pytest.param(1, 1, 1, id="1of1"),
pytest.param(1, 2, 3, id="2of3"),
pytest.param(1, 5, 5, id="5of5"),
], ],
) )
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
@WITH_MOCK_URANDOM @WITH_MOCK_URANDOM
def test_backup_slip39_single( def test_backup_slip39_custom(
device_handler: "BackgroundDeviceHandler", device_handler: "BackgroundDeviceHandler",
group_threshold: int, group_threshold: int,
share_threshold: int, share_threshold: int,
@ -76,6 +78,10 @@ def test_backup_slip39_single(
# confirm backup warning # confirm backup warning
reset.confirm_read(debug, middle_r=True) reset.confirm_read(debug, middle_r=True)
if share_count > 1:
# confirm shamir warning
reset.confirm_read(debug, middle_r=True)
all_words: list[str] = [] all_words: list[str] = []
for _ in range(share_count): for _ in range(share_count):
# read words # read words
@ -98,7 +104,7 @@ def test_backup_slip39_single(
secret = generate_entropy(128, internal_entropy, EXTERNAL_ENTROPY) secret = generate_entropy(128, internal_entropy, EXTERNAL_ENTROPY)
# validate that all combinations will result in the correct master secret # validate that all combinations will result in the correct master secret
reset.validate_mnemonics(all_words, secret) reset.validate_mnemonics(all_words[:share_threshold], secret)
assert device_handler.result() == "Seed successfully backed up" assert device_handler.result() == "Seed successfully backed up"
features = device_handler.features() features = device_handler.features()

@ -58,6 +58,8 @@ MNEMONIC_SLIP39_ADVANCED_33 = [
"wildlife deal beard romp alcohol space mild usual clothes union nuclear testify course research heat listen task location thank hospital slice smell failure fawn helpful priest ambition average recover lecture process dough stadium", "wildlife deal beard romp alcohol space mild usual clothes union nuclear testify course research heat listen task location thank hospital slice smell failure fawn helpful priest ambition average recover lecture process dough stadium",
"wildlife deal acrobat romp anxiety axis starting require metric flexible geology game drove editor edge screw helpful have huge holy making pitch unknown carve holiday numb glasses survive already tenant adapt goat fangs", "wildlife deal acrobat romp anxiety axis starting require metric flexible geology game drove editor edge screw helpful have huge holy making pitch unknown carve holiday numb glasses survive already tenant adapt goat fangs",
] ]
MNEMONIC_SLIP39_CUSTOM_20_1of1 = ["tolerate flexible academic academic average dwarf square home promise aspect temple cluster roster forward hand unfair tenant emperor ceramic element forget perfect knit adapt review usual formal receiver typical pleasure duke yield party"]
MNEMONIC_SLIP39_CUSTOM_20_SECRET = "3439316237393562383066633231636364663436366330666263393863386663"
# External entropy mocked as received from trezorlib. # External entropy mocked as received from trezorlib.
EXTERNAL_ENTROPY = b"zlutoucky kun upel divoke ody" * 2 EXTERNAL_ENTROPY = b"zlutoucky kun upel divoke ody" * 2
# fmt: on # fmt: on

@ -26,6 +26,8 @@ from ..common import (
MNEMONIC12, MNEMONIC12,
MNEMONIC_SLIP39_ADVANCED_20, MNEMONIC_SLIP39_ADVANCED_20,
MNEMONIC_SLIP39_BASIC_20_3of6, MNEMONIC_SLIP39_BASIC_20_3of6,
MNEMONIC_SLIP39_CUSTOM_20_1of1,
MNEMONIC_SLIP39_CUSTOM_20_SECRET,
) )
from ..input_flows import ( from ..input_flows import (
InputFlowBip39Backup, InputFlowBip39Backup,
@ -113,7 +115,7 @@ def test_backup_slip39_advanced(client: Client, click_info: bool):
@pytest.mark.skip_t1b1 @pytest.mark.skip_t1b1
@pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_ADVANCED_20) @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_CUSTOM_20_1of1[0])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"share_threshold,share_count", "share_threshold,share_count",
[(1, 1), (2, 2), (3, 5)], [(1, 1), (2, 2), (3, 5)],
@ -134,13 +136,9 @@ def test_backup_slip39_custom(client: Client, share_threshold, share_count):
assert client.features.needs_backup is False assert client.features.needs_backup is False
assert client.features.unfinished_backup is False assert client.features.unfinished_backup is False
assert client.features.no_backup is False assert client.features.no_backup is False
assert client.features.backup_type is messages.BackupType.Slip39_Advanced
expected_ms = shamir.combine_mnemonics(MNEMONIC_SLIP39_ADVANCED_20) assert len(IF.mnemonics) == share_count
actual_ms = shamir.combine_mnemonics( assert shamir.combine_mnemonics(IF.mnemonics[-share_threshold:]).hex() == MNEMONIC_SLIP39_CUSTOM_20_SECRET
IF.mnemonics[:3] + IF.mnemonics[5:8] + IF.mnemonics[10:13]
)
assert expected_ms == actual_ms
# we only test this with bip39 because the code path is always the same # we only test this with bip39 because the code path is always the same

Loading…
Cancel
Save