1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-22 07:28:10 +00:00

feat(core): Implement extend_mnemonics() for SLIP-39.

[no changelog]
This commit is contained in:
Andrew Kozlik 2024-10-04 21:29:42 +02:00 committed by Andrew Kozlik
parent d71d9e9c34
commit 911fe2a526
2 changed files with 87 additions and 3 deletions

View File

@ -247,6 +247,69 @@ def split_ems(
return mnemonics
def extend_mnemonics(
share_count: int, # The number of shares to create.
mnemonics: list[str], # A threshold set of the old mnemonics.
) -> list[str]:
"""
Extends a set of mnemonics to the desired share_count, while maintaining the threshold. This,
for example, allows extending a 2-of-2 backup to 2-of-3, where the first two shares remain the
same. It also allows reconstructing lost shares by providing any threshold number of shares and
requesting the original share_count. The current implementation is limited to Slip39_Basic,
i.e. single group.
It is not possible to tell how many shares the user originally created, so if share_count is
less than the original number of shares, then this function will return the first share_count
shares.
"""
if not mnemonics:
raise MnemonicError("The list of mnemonics is empty.")
(
identifier,
extendable,
iteration_exponent,
group_threshold,
group_count,
groups,
) = _decode_mnemonics(mnemonics)
if group_threshold != 1 or group_count != 1 or len(groups) != 1:
raise MnemonicError("Extending advanced backups is not supported.")
threshold = groups[0][0]
shares = groups[0][1]
if len(shares) != threshold:
raise MnemonicError(
f"Wrong number of mnemonics. Expected {threshold} mnemonics, but {len(shares)} were provided."
)
if threshold == 1 and share_count > 1:
raise ValueError(
"Creating multiple member shares with member threshold 1 is not allowed. Use 1-of-1 member sharing instead."
)
shares = _extend_shares(share_count, list(shares))
mnemonics = []
for index, value in shares:
mnemonics.append(
_encode_mnemonic(
identifier,
extendable,
iteration_exponent,
group_index=0,
group_threshold=1,
group_count=1,
member_index=index,
member_threshold=threshold,
value=value,
)
)
return mnemonics
def recover_ems(mnemonics: list[str]) -> tuple[int, bool, int, bytes]:
"""
Combines mnemonic shares to obtain the encrypted master secret which was previously
@ -457,9 +520,7 @@ def _create_digest(random_data: bytes, shared_secret: bytes) -> bytes:
return hmac(hmac.SHA256, random_data, shared_secret).digest()[:_DIGEST_LENGTH_BYTES]
def _split_secret(
threshold: int, share_count: int, shared_secret: bytes
) -> list[tuple[int, bytes]]:
def _check_parameters(threshold: int, share_count: int) -> None:
if threshold < 1:
raise ValueError(
f"The requested threshold ({threshold}) must be a positive integer."
@ -475,6 +536,12 @@ def _split_secret(
f"The requested number of shares ({share_count}) must not exceed {MAX_SHARE_COUNT}."
)
def _split_secret(
threshold: int, share_count: int, shared_secret: bytes
) -> list[tuple[int, bytes]]:
_check_parameters(threshold, share_count)
# If the threshold is 1, then the digest of the shared secret is not used.
if threshold == 1:
return [(i, shared_secret) for i in range(share_count)]
@ -499,6 +566,13 @@ def _split_secret(
return shares
def _extend_shares(
share_count: int, old_shares: list[tuple[int, bytes]]
) -> list[tuple[int, bytes]]:
_check_parameters(len(old_shares), share_count)
return [(i, shamir.interpolate(old_shares, i)) for i in range(share_count)]
def _recover_secret(threshold: int, shares: list[tuple[int, bytes]]) -> bytes:
# If the threshold is 1, then the digest of the shared secret is not used.
if threshold == 1:

View File

@ -37,6 +37,16 @@ class TestCryptoSlip39(unittest.TestCase):
slip39.recover_ems(mnemonics[:3]), slip39.recover_ems(mnemonics[2:])
)
def test_basic_sharing_extend(self):
identifier = slip39.generate_random_identifier()
for extendable in (False, True):
mnemonics = slip39.split_ems(1, [(2, 3)], identifier, extendable, 1, self.EMS)
mnemonics = mnemonics[0]
extended_mnemonics = slip39.extend_mnemonics(4, mnemonics[1:])
self.assertEqual(mnemonics, extended_mnemonics[:3])
for i in range(3):
self.assertEqual(slip39.recover_ems([extended_mnemonics[3], mnemonics[i]])[3], self.EMS)
def test_basic_sharing_fixed(self):
for extendable in (False, True):
generated_identifier = slip39.generate_random_identifier()