diff --git a/core/src/trezor/crypto/slip39.py b/core/src/trezor/crypto/slip39.py index 3197f11285..6f23e5ada9 100644 --- a/core/src/trezor/crypto/slip39.py +++ b/core/src/trezor/crypto/slip39.py @@ -247,6 +247,69 @@ def split_ems( return mnemonics +def extend_mnemonics( + share_count: int, # The number of shares to create. + mnemonics: list[str], # A threshold set of the old mnemonics. +) -> list[str]: + """ + Extends a set of mnemonics to the desired share_count, while maintaining the threshold. This, + for example, allows extending a 2-of-2 backup to 2-of-3, where the first two shares remain the + same. It also allows reconstructing lost shares by providing any threshold number of shares and + requesting the original share_count. The current implementation is limited to Slip39_Basic, + i.e. single group. + + It is not possible to tell how many shares the user originally created, so if share_count is + less than the original number of shares, then this function will return the first share_count + shares. + """ + + if not mnemonics: + raise MnemonicError("The list of mnemonics is empty.") + + ( + identifier, + extendable, + iteration_exponent, + group_threshold, + group_count, + groups, + ) = _decode_mnemonics(mnemonics) + + if group_threshold != 1 or group_count != 1 or len(groups) != 1: + raise MnemonicError("Extending advanced backups is not supported.") + + threshold = groups[0][0] + shares = groups[0][1] + if len(shares) != threshold: + raise MnemonicError( + f"Wrong number of mnemonics. Expected {threshold} mnemonics, but {len(shares)} were provided." + ) + + if threshold == 1 and share_count > 1: + raise ValueError( + "Creating multiple member shares with member threshold 1 is not allowed. Use 1-of-1 member sharing instead." + ) + + shares = _extend_shares(share_count, list(shares)) + + mnemonics = [] + for index, value in shares: + mnemonics.append( + _encode_mnemonic( + identifier, + extendable, + iteration_exponent, + group_index=0, + group_threshold=1, + group_count=1, + member_index=index, + member_threshold=threshold, + value=value, + ) + ) + return mnemonics + + def recover_ems(mnemonics: list[str]) -> tuple[int, bool, int, bytes]: """ Combines mnemonic shares to obtain the encrypted master secret which was previously @@ -457,9 +520,7 @@ def _create_digest(random_data: bytes, shared_secret: bytes) -> bytes: return hmac(hmac.SHA256, random_data, shared_secret).digest()[:_DIGEST_LENGTH_BYTES] -def _split_secret( - threshold: int, share_count: int, shared_secret: bytes -) -> list[tuple[int, bytes]]: +def _check_parameters(threshold: int, share_count: int) -> None: if threshold < 1: raise ValueError( f"The requested threshold ({threshold}) must be a positive integer." @@ -475,6 +536,12 @@ def _split_secret( f"The requested number of shares ({share_count}) must not exceed {MAX_SHARE_COUNT}." ) + +def _split_secret( + threshold: int, share_count: int, shared_secret: bytes +) -> list[tuple[int, bytes]]: + _check_parameters(threshold, share_count) + # If the threshold is 1, then the digest of the shared secret is not used. if threshold == 1: return [(i, shared_secret) for i in range(share_count)] @@ -499,6 +566,13 @@ def _split_secret( return shares +def _extend_shares( + share_count: int, old_shares: list[tuple[int, bytes]] +) -> list[tuple[int, bytes]]: + _check_parameters(len(old_shares), share_count) + return [(i, shamir.interpolate(old_shares, i)) for i in range(share_count)] + + def _recover_secret(threshold: int, shares: list[tuple[int, bytes]]) -> bytes: # If the threshold is 1, then the digest of the shared secret is not used. if threshold == 1: diff --git a/core/tests/test_trezor.crypto.slip39.py b/core/tests/test_trezor.crypto.slip39.py index 7b99184ae7..26367957f8 100644 --- a/core/tests/test_trezor.crypto.slip39.py +++ b/core/tests/test_trezor.crypto.slip39.py @@ -37,6 +37,16 @@ class TestCryptoSlip39(unittest.TestCase): slip39.recover_ems(mnemonics[:3]), slip39.recover_ems(mnemonics[2:]) ) + def test_basic_sharing_extend(self): + identifier = slip39.generate_random_identifier() + for extendable in (False, True): + mnemonics = slip39.split_ems(1, [(2, 3)], identifier, extendable, 1, self.EMS) + mnemonics = mnemonics[0] + extended_mnemonics = slip39.extend_mnemonics(4, mnemonics[1:]) + self.assertEqual(mnemonics, extended_mnemonics[:3]) + for i in range(3): + self.assertEqual(slip39.recover_ems([extended_mnemonics[3], mnemonics[i]])[3], self.EMS) + def test_basic_sharing_fixed(self): for extendable in (False, True): generated_identifier = slip39.generate_random_identifier()