diff --git a/core/src/trezor/crypto/slip39.py b/core/src/trezor/crypto/slip39.py index a2c9d06b8..9c67d410e 100644 --- a/core/src/trezor/crypto/slip39.py +++ b/core/src/trezor/crypto/slip39.py @@ -209,13 +209,19 @@ def _create_digest(random_data, shared_secret): def _split_secret(threshold, share_count, shared_secret): - assert 0 < threshold <= share_count <= MAX_SHARE_COUNT - - # If the threshold is 1, then the digest of the shared secret is not used. - if threshold == 1: - return [(i, shared_secret) for i in range(share_count)] + if threshold < 1: + raise ValueError( + "The requested threshold ({}) must be a positive integer.".format( + threshold + ) + ) - random_share_count = threshold - 2 + if threshold > share_count: + raise ValueError( + "The requested threshold ({}) must not exceed the number of shares ({}).".format( + threshold, share_count + ) + ) if share_count > MAX_SHARE_COUNT: raise ValueError( @@ -224,6 +230,12 @@ def _split_secret(threshold, share_count, shared_secret): ) ) + # If the threshold is 1, then the digest of the shared secret is not used. + if threshold == 1: + return [(0, shared_secret)] + + random_share_count = threshold - 2 + shares = [(i, random.bytes(len(shared_secret))) for i in range(random_share_count)] random_part = random.bytes(len(shared_secret) - DIGEST_LENGTH_BYTES) @@ -241,16 +253,17 @@ def _split_secret(threshold, share_count, shared_secret): def _recover_secret(threshold, shares): - shared_secret = shamir.interpolate(shares, SECRET_INDEX) - # If the threshold is 1, then the digest of the shared secret is not used. - if threshold != 1: - digest_share = shamir.interpolate(shares, DIGEST_INDEX) - digest = digest_share[:DIGEST_LENGTH_BYTES] - random_part = digest_share[DIGEST_LENGTH_BYTES:] + if threshold == 1: + return shares[0][1] - if digest != _create_digest(random_part, shared_secret): - raise MnemonicError("Invalid digest of the shared secret.") + shared_secret = shamir.interpolate(shares, SECRET_INDEX) + digest_share = shamir.interpolate(shares, DIGEST_INDEX) + digest = digest_share[:DIGEST_LENGTH_BYTES] + random_part = digest_share[DIGEST_LENGTH_BYTES:] + + if digest != _create_digest(random_part, shared_secret): + raise MnemonicError("Invalid digest of the shared secret.") return shared_secret @@ -465,6 +478,11 @@ def generate_mnemonics( "The length of the master secret in bytes must be an even number." ) + if not all(32 <= c <= 126 for c in passphrase): + raise ValueError( + "The passphrase must contain only printable ASCII characters (code points 32-126)." + ) + if group_threshold > len(groups): raise ValueError( "The requested group threshold ({}) must not exceed the number of groups ({}).".format(