mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-16 17:42:02 +00:00
slip39: Remove ShamirMnemonic class. Use binary search to lookup words in wordlist.
This commit is contained in:
parent
d0527997ee
commit
cd08c6937b
@ -33,7 +33,6 @@ class MnemonicError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ShamirMnemonic(object):
|
||||
RADIX_BITS = 10
|
||||
"""The length of the radix in bits."""
|
||||
|
||||
@ -67,9 +66,7 @@ class ShamirMnemonic(object):
|
||||
MIN_STRENGTH_BITS = 128
|
||||
"""The minimum allowed entropy of the master secret."""
|
||||
|
||||
MIN_MNEMONIC_LENGTH_WORDS = METADATA_LENGTH_WORDS + math.ceil(
|
||||
MIN_STRENGTH_BITS / 10
|
||||
)
|
||||
MIN_MNEMONIC_LENGTH_WORDS = METADATA_LENGTH_WORDS + math.ceil(MIN_STRENGTH_BITS / 10)
|
||||
"""The minimum allowed length of the mnemonic in words."""
|
||||
|
||||
MIN_ITERATION_COUNT = 10000
|
||||
@ -84,20 +81,22 @@ class ShamirMnemonic(object):
|
||||
DIGEST_INDEX = 254
|
||||
"""The index of the share containing the digest of the shared secret."""
|
||||
|
||||
def __init__(self):
|
||||
# Load the word list.
|
||||
|
||||
if len(wordlist) != self.RADIX:
|
||||
raise ConfigurationError(
|
||||
"The wordlist should contain {} words, but it contains {} words.".format(
|
||||
self.RADIX, len(wordlist)
|
||||
)
|
||||
)
|
||||
def word_index(word):
|
||||
lo = 0
|
||||
hi = len(wordlist)
|
||||
while hi - lo > 1:
|
||||
mid = (hi + lo) // 2
|
||||
if wordlist[mid] > word:
|
||||
hi = mid
|
||||
else:
|
||||
lo = mid
|
||||
if not wordlist[lo].startswith(word):
|
||||
raise MnemonicError('Invalid mnemonic word "{}".'.format(word))
|
||||
return lo
|
||||
|
||||
self.word_index_map = {word: i for i, word in enumerate(wordlist)}
|
||||
|
||||
@classmethod
|
||||
def _rs1024_polymod(cls, values):
|
||||
def _rs1024_polymod(values):
|
||||
GEN = (
|
||||
0xE0E040,
|
||||
0x1C1C080,
|
||||
@ -118,102 +117,93 @@ class ShamirMnemonic(object):
|
||||
chk ^= GEN[i] if ((b >> i) & 1) else 0
|
||||
return chk
|
||||
|
||||
@classmethod
|
||||
def rs1024_create_checksum(cls, data):
|
||||
values = (
|
||||
tuple(cls.CUSTOMIZATION_STRING) + data + cls.CHECKSUM_LENGTH_WORDS * (0,)
|
||||
)
|
||||
polymod = cls._rs1024_polymod(values) ^ 1
|
||||
|
||||
def rs1024_create_checksum(data):
|
||||
values = tuple(CUSTOMIZATION_STRING) + data + CHECKSUM_LENGTH_WORDS * (0,)
|
||||
polymod = _rs1024_polymod(values) ^ 1
|
||||
return tuple(
|
||||
(polymod >> 10 * i) & 1023
|
||||
for i in reversed(range(cls.CHECKSUM_LENGTH_WORDS))
|
||||
(polymod >> 10 * i) & 1023 for i in reversed(range(CHECKSUM_LENGTH_WORDS))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def rs1024_verify_checksum(cls, data):
|
||||
return cls._rs1024_polymod(tuple(cls.CUSTOMIZATION_STRING) + data) == 1
|
||||
|
||||
@staticmethod
|
||||
def rs1024_verify_checksum(data):
|
||||
return _rs1024_polymod(tuple(CUSTOMIZATION_STRING) + data) == 1
|
||||
|
||||
|
||||
def xor(a, b):
|
||||
return bytes(x ^ y for x, y in zip(a, b))
|
||||
|
||||
@classmethod
|
||||
def _int_from_indices(cls, indices):
|
||||
|
||||
def _int_from_indices(indices):
|
||||
"""Converts a list of base 1024 indices in big endian order to an integer value."""
|
||||
value = 0
|
||||
for index in indices:
|
||||
value = (value << cls.RADIX_BITS) + index
|
||||
value = (value << RADIX_BITS) + index
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
|
||||
def _int_to_indices(value, length, bits):
|
||||
"""Converts an integer value to indices in big endian order."""
|
||||
mask = (1 << bits) - 1
|
||||
return ((value >> (i * bits)) & mask for i in reversed(range(length)))
|
||||
|
||||
def mnemonic_from_indices(self, indices):
|
||||
|
||||
def mnemonic_from_indices(indices):
|
||||
return " ".join(wordlist[i] for i in indices)
|
||||
|
||||
def mnemonic_to_indices(self, mnemonic):
|
||||
try:
|
||||
return (self.word_index_map[word.lower()] for word in mnemonic.split())
|
||||
except KeyError as key_error:
|
||||
raise MnemonicError("Invalid mnemonic word {}.".format(key_error)) from None
|
||||
|
||||
@classmethod
|
||||
def _round_function(cls, i, passphrase, e, salt, r):
|
||||
def mnemonic_to_indices(mnemonic):
|
||||
return (word_index(word.lower()) for word in mnemonic.split())
|
||||
|
||||
|
||||
def _round_function(i, passphrase, e, salt, r):
|
||||
"""The round function used internally by the Feistel cipher."""
|
||||
return pbkdf2(
|
||||
pbkdf2.HMAC_SHA256,
|
||||
bytes([i]) + passphrase,
|
||||
salt + r,
|
||||
(cls.MIN_ITERATION_COUNT << e) // cls.ROUND_COUNT,
|
||||
(MIN_ITERATION_COUNT << e) // ROUND_COUNT,
|
||||
).key()[: len(r)]
|
||||
|
||||
@classmethod
|
||||
def _get_salt(cls, identifier):
|
||||
return cls.CUSTOMIZATION_STRING + identifier.to_bytes(
|
||||
math.ceil(cls.ID_LENGTH_BITS / 8), "big"
|
||||
|
||||
def _get_salt(identifier):
|
||||
return CUSTOMIZATION_STRING + identifier.to_bytes(
|
||||
math.ceil(ID_LENGTH_BITS / 8), "big"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _encrypt(cls, master_secret, passphrase, iteration_exponent, identifier):
|
||||
|
||||
def _encrypt(master_secret, passphrase, iteration_exponent, identifier):
|
||||
l = master_secret[: len(master_secret) // 2]
|
||||
r = master_secret[len(master_secret) // 2 :]
|
||||
salt = cls._get_salt(identifier)
|
||||
for i in range(cls.ROUND_COUNT):
|
||||
salt = _get_salt(identifier)
|
||||
for i in range(ROUND_COUNT):
|
||||
(l, r) = (
|
||||
r,
|
||||
cls.xor(
|
||||
l, cls._round_function(i, passphrase, iteration_exponent, salt, r)
|
||||
),
|
||||
xor(l, _round_function(i, passphrase, iteration_exponent, salt, r)),
|
||||
)
|
||||
return r + l
|
||||
|
||||
@classmethod
|
||||
def _decrypt(
|
||||
cls, encrypted_master_secret, passphrase, iteration_exponent, identifier
|
||||
):
|
||||
|
||||
def _decrypt(encrypted_master_secret, passphrase, iteration_exponent, identifier):
|
||||
l = encrypted_master_secret[: len(encrypted_master_secret) // 2]
|
||||
r = encrypted_master_secret[len(encrypted_master_secret) // 2 :]
|
||||
salt = cls._get_salt(identifier)
|
||||
for i in reversed(range(cls.ROUND_COUNT)):
|
||||
salt = _get_salt(identifier)
|
||||
for i in reversed(range(ROUND_COUNT)):
|
||||
(l, r) = (
|
||||
r,
|
||||
cls.xor(
|
||||
l, cls._round_function(i, passphrase, iteration_exponent, salt, r)
|
||||
),
|
||||
xor(l, _round_function(i, passphrase, iteration_exponent, salt, r)),
|
||||
)
|
||||
return r + l
|
||||
|
||||
@classmethod
|
||||
def _create_digest(cls, random_data, shared_secret):
|
||||
|
||||
def _create_digest(random_data, shared_secret):
|
||||
return hmac.new(random_data, shared_secret, hashlib.sha256).digest()[
|
||||
: cls.DIGEST_LENGTH_BYTES
|
||||
:DIGEST_LENGTH_BYTES
|
||||
]
|
||||
|
||||
def _split_secret(self, threshold, share_count, shared_secret):
|
||||
assert 0 < threshold <= share_count <= self.MAX_SHARE_COUNT
|
||||
|
||||
def _split_secret(threshold, share_count, shared_secret):
|
||||
assert 0 < threshold <= share_count <= MAX_SHARE_COUNT
|
||||
|
||||
# If the threshold is 1, then the digest of the shared secret is not used.
|
||||
if threshold == 1:
|
||||
@ -221,23 +211,21 @@ class ShamirMnemonic(object):
|
||||
|
||||
random_share_count = threshold - 2
|
||||
|
||||
if share_count > self.MAX_SHARE_COUNT:
|
||||
if share_count > MAX_SHARE_COUNT:
|
||||
raise ValueError(
|
||||
"The requested number of shares ({}) must not exceed {}.".format(
|
||||
share_count, self.MAX_SHARE_COUNT
|
||||
share_count, MAX_SHARE_COUNT
|
||||
)
|
||||
)
|
||||
|
||||
shares = [
|
||||
(i, random.bytes(len(shared_secret))) for i in range(random_share_count)
|
||||
]
|
||||
shares = [(i, random.bytes(len(shared_secret))) for i in range(random_share_count)]
|
||||
|
||||
random_part = random.bytes(len(shared_secret) - self.DIGEST_LENGTH_BYTES)
|
||||
digest = self._create_digest(random_part, shared_secret)
|
||||
random_part = random.bytes(len(shared_secret) - DIGEST_LENGTH_BYTES)
|
||||
digest = _create_digest(random_part, shared_secret)
|
||||
|
||||
base_shares = shares + [
|
||||
(self.DIGEST_INDEX, digest + random_part),
|
||||
(self.SECRET_INDEX, shared_secret),
|
||||
(DIGEST_INDEX, digest + random_part),
|
||||
(SECRET_INDEX, shared_secret),
|
||||
]
|
||||
|
||||
for i in range(random_share_count, share_count):
|
||||
@ -245,35 +233,32 @@ class ShamirMnemonic(object):
|
||||
|
||||
return shares
|
||||
|
||||
def _recover_secret(self, threshold, shares):
|
||||
shared_secret = shamir.interpolate(shares, self.SECRET_INDEX)
|
||||
|
||||
def _recover_secret(threshold, shares):
|
||||
shared_secret = shamir.interpolate(shares, SECRET_INDEX)
|
||||
|
||||
# If the threshold is 1, then the digest of the shared secret is not used.
|
||||
if threshold != 1:
|
||||
digest_share = shamir.interpolate(shares, self.DIGEST_INDEX)
|
||||
digest = digest_share[: self.DIGEST_LENGTH_BYTES]
|
||||
random_part = digest_share[self.DIGEST_LENGTH_BYTES :]
|
||||
digest_share = shamir.interpolate(shares, DIGEST_INDEX)
|
||||
digest = digest_share[:DIGEST_LENGTH_BYTES]
|
||||
random_part = digest_share[DIGEST_LENGTH_BYTES:]
|
||||
|
||||
if digest != self._create_digest(random_part, shared_secret):
|
||||
if digest != _create_digest(random_part, shared_secret):
|
||||
raise MnemonicError("Invalid digest of the shared secret.")
|
||||
|
||||
return shared_secret
|
||||
|
||||
@classmethod
|
||||
|
||||
def _group_prefix(
|
||||
cls, identifier, iteration_exponent, group_index, group_threshold, group_count
|
||||
identifier, iteration_exponent, group_index, group_threshold, group_count
|
||||
):
|
||||
id_exp_int = (identifier << cls.ITERATION_EXP_LENGTH_BITS) + iteration_exponent
|
||||
return tuple(
|
||||
cls._int_to_indices(id_exp_int, cls.ID_EXP_LENGTH_WORDS, cls.RADIX_BITS)
|
||||
) + (
|
||||
(group_index << 6)
|
||||
+ ((group_threshold - 1) << 2)
|
||||
+ ((group_count - 1) >> 2),
|
||||
id_exp_int = (identifier << ITERATION_EXP_LENGTH_BITS) + iteration_exponent
|
||||
return tuple(_int_to_indices(id_exp_int, ID_EXP_LENGTH_WORDS, RADIX_BITS)) + (
|
||||
(group_index << 6) + ((group_threshold - 1) << 2) + ((group_count - 1) >> 2),
|
||||
)
|
||||
|
||||
|
||||
def encode_mnemonic(
|
||||
self,
|
||||
identifier,
|
||||
iteration_exponent,
|
||||
group_index,
|
||||
@ -299,77 +284,72 @@ class ShamirMnemonic(object):
|
||||
"""
|
||||
|
||||
# Convert the share value from bytes to wordlist indices.
|
||||
value_word_count = math.ceil(len(value) * 8 / self.RADIX_BITS)
|
||||
value_word_count = math.ceil(len(value) * 8 / RADIX_BITS)
|
||||
value_int = int.from_bytes(value, "big")
|
||||
|
||||
share_data = (
|
||||
self._group_prefix(
|
||||
identifier,
|
||||
iteration_exponent,
|
||||
group_index,
|
||||
group_threshold,
|
||||
group_count,
|
||||
_group_prefix(
|
||||
identifier, iteration_exponent, group_index, group_threshold, group_count
|
||||
)
|
||||
+ (
|
||||
(((group_count - 1) & 3) << 8)
|
||||
+ (member_index << 4)
|
||||
+ (member_threshold - 1),
|
||||
)
|
||||
+ tuple(self._int_to_indices(value_int, value_word_count, self.RADIX_BITS))
|
||||
+ tuple(_int_to_indices(value_int, value_word_count, RADIX_BITS))
|
||||
)
|
||||
checksum = self.rs1024_create_checksum(share_data)
|
||||
checksum = rs1024_create_checksum(share_data)
|
||||
|
||||
return self.mnemonic_from_indices(share_data + checksum)
|
||||
return mnemonic_from_indices(share_data + checksum)
|
||||
|
||||
def decode_mnemonic(self, mnemonic):
|
||||
|
||||
def decode_mnemonic(mnemonic):
|
||||
"""Converts a share mnemonic to share data."""
|
||||
|
||||
mnemonic_data = tuple(self.mnemonic_to_indices(mnemonic))
|
||||
mnemonic_data = tuple(mnemonic_to_indices(mnemonic))
|
||||
|
||||
if len(mnemonic_data) < self.MIN_MNEMONIC_LENGTH_WORDS:
|
||||
if len(mnemonic_data) < MIN_MNEMONIC_LENGTH_WORDS:
|
||||
raise MnemonicError(
|
||||
"Invalid mnemonic length. The length of each mnemonic must be at least {} words.".format(
|
||||
self.MIN_MNEMONIC_LENGTH_WORDS
|
||||
MIN_MNEMONIC_LENGTH_WORDS
|
||||
)
|
||||
)
|
||||
|
||||
padding_len = (10 * (len(mnemonic_data) - self.METADATA_LENGTH_WORDS)) % 16
|
||||
padding_len = (10 * (len(mnemonic_data) - METADATA_LENGTH_WORDS)) % 16
|
||||
if padding_len > 8:
|
||||
raise MnemonicError("Invalid mnemonic length.")
|
||||
|
||||
if not self.rs1024_verify_checksum(mnemonic_data):
|
||||
if not rs1024_verify_checksum(mnemonic_data):
|
||||
raise MnemonicError(
|
||||
'Invalid mnemonic checksum for "{} ...".'.format(
|
||||
" ".join(mnemonic.split()[: self.ID_EXP_LENGTH_WORDS + 2])
|
||||
" ".join(mnemonic.split()[: ID_EXP_LENGTH_WORDS + 2])
|
||||
)
|
||||
)
|
||||
|
||||
id_exp_int = self._int_from_indices(mnemonic_data[: self.ID_EXP_LENGTH_WORDS])
|
||||
identifier = id_exp_int >> self.ITERATION_EXP_LENGTH_BITS
|
||||
iteration_exponent = id_exp_int & ((1 << self.ITERATION_EXP_LENGTH_BITS) - 1)
|
||||
tmp = self._int_from_indices(
|
||||
mnemonic_data[self.ID_EXP_LENGTH_WORDS : self.ID_EXP_LENGTH_WORDS + 2]
|
||||
id_exp_int = _int_from_indices(mnemonic_data[:ID_EXP_LENGTH_WORDS])
|
||||
identifier = id_exp_int >> ITERATION_EXP_LENGTH_BITS
|
||||
iteration_exponent = id_exp_int & ((1 << ITERATION_EXP_LENGTH_BITS) - 1)
|
||||
tmp = _int_from_indices(
|
||||
mnemonic_data[ID_EXP_LENGTH_WORDS : ID_EXP_LENGTH_WORDS + 2]
|
||||
)
|
||||
group_index, group_threshold, group_count, member_index, member_threshold = self._int_to_indices(
|
||||
group_index, group_threshold, group_count, member_index, member_threshold = _int_to_indices(
|
||||
tmp, 5, 4
|
||||
)
|
||||
value_data = mnemonic_data[
|
||||
self.ID_EXP_LENGTH_WORDS + 2 : -self.CHECKSUM_LENGTH_WORDS
|
||||
]
|
||||
value_data = mnemonic_data[ID_EXP_LENGTH_WORDS + 2 : -CHECKSUM_LENGTH_WORDS]
|
||||
|
||||
if group_count < group_threshold:
|
||||
raise MnemonicError(
|
||||
'Invalid mnemonic "{} ...". Group threshold cannot be greater than group count.'.format(
|
||||
" ".join(mnemonic.split()[: self.ID_EXP_LENGTH_WORDS + 2])
|
||||
" ".join(mnemonic.split()[: ID_EXP_LENGTH_WORDS + 2])
|
||||
)
|
||||
)
|
||||
|
||||
value_byte_count = (10 * len(value_data) - padding_len) // 8
|
||||
value_int = self._int_from_indices(value_data)
|
||||
value_int = _int_from_indices(value_data)
|
||||
if value_data[0] >= 1 << (10 - padding_len):
|
||||
raise MnemonicError(
|
||||
'Invalid mnemonic padding for "{} ...".'.format(
|
||||
" ".join(mnemonic.split()[: self.ID_EXP_LENGTH_WORDS + 2])
|
||||
" ".join(mnemonic.split()[: ID_EXP_LENGTH_WORDS + 2])
|
||||
)
|
||||
)
|
||||
value = value_int.to_bytes(value_byte_count, "big")
|
||||
@ -385,14 +365,15 @@ class ShamirMnemonic(object):
|
||||
value,
|
||||
)
|
||||
|
||||
def _decode_mnemonics(self, mnemonics):
|
||||
|
||||
def _decode_mnemonics(mnemonics):
|
||||
identifiers = set()
|
||||
iteration_exponents = set()
|
||||
group_thresholds = set()
|
||||
group_counts = set()
|
||||
groups = {} # { group_index : [member_threshold, set_of_member_shares] }
|
||||
for mnemonic in mnemonics:
|
||||
identifier, iteration_exponent, group_index, group_threshold, group_count, member_index, member_threshold, share_value = self.decode_mnemonic(
|
||||
identifier, iteration_exponent, group_index, group_threshold, group_count, member_index, member_threshold, share_value = decode_mnemonic(
|
||||
mnemonic
|
||||
)
|
||||
identifiers.add(identifier)
|
||||
@ -409,7 +390,7 @@ class ShamirMnemonic(object):
|
||||
if len(identifiers) != 1 or len(iteration_exponents) != 1:
|
||||
raise MnemonicError(
|
||||
"Invalid set of mnemonics. All mnemonics must begin with the same {} words.".format(
|
||||
self.ID_EXP_LENGTH_WORDS
|
||||
ID_EXP_LENGTH_WORDS
|
||||
)
|
||||
)
|
||||
|
||||
@ -437,21 +418,16 @@ class ShamirMnemonic(object):
|
||||
groups,
|
||||
)
|
||||
|
||||
def _generate_random_identifier(self):
|
||||
|
||||
def _generate_random_identifier():
|
||||
"""Returns a randomly generated integer in the range 0, ... , 2**ID_LENGTH_BITS - 1."""
|
||||
|
||||
identifier = int.from_bytes(
|
||||
random.bytes(math.ceil(self.ID_LENGTH_BITS / 8)), "big"
|
||||
)
|
||||
return identifier & ((1 << self.ID_LENGTH_BITS) - 1)
|
||||
identifier = int.from_bytes(random.bytes(math.ceil(ID_LENGTH_BITS / 8)), "big")
|
||||
return identifier & ((1 << ID_LENGTH_BITS) - 1)
|
||||
|
||||
|
||||
def generate_mnemonics(
|
||||
self,
|
||||
group_threshold,
|
||||
groups,
|
||||
master_secret,
|
||||
passphrase=b"",
|
||||
iteration_exponent=0,
|
||||
group_threshold, groups, master_secret, passphrase=b"", iteration_exponent=0
|
||||
):
|
||||
"""
|
||||
Splits a master secret into mnemonic shares using Shamir's secret sharing scheme.
|
||||
@ -469,12 +445,12 @@ class ShamirMnemonic(object):
|
||||
:rtype: List of byte arrays.
|
||||
"""
|
||||
|
||||
identifier = self._generate_random_identifier()
|
||||
identifier = _generate_random_identifier()
|
||||
|
||||
if len(master_secret) * 8 < self.MIN_STRENGTH_BITS:
|
||||
if len(master_secret) * 8 < MIN_STRENGTH_BITS:
|
||||
raise ValueError(
|
||||
"The length of the master secret ({} bytes) must be at least {} bytes.".format(
|
||||
len(master_secret), math.ceil(self.MIN_STRENGTH_BITS / 8)
|
||||
len(master_secret), math.ceil(MIN_STRENGTH_BITS / 8)
|
||||
)
|
||||
)
|
||||
|
||||
@ -490,17 +466,15 @@ class ShamirMnemonic(object):
|
||||
)
|
||||
)
|
||||
|
||||
encrypted_master_secret = self._encrypt(
|
||||
encrypted_master_secret = _encrypt(
|
||||
master_secret, passphrase, iteration_exponent, identifier
|
||||
)
|
||||
|
||||
group_shares = self._split_secret(
|
||||
group_threshold, len(groups), encrypted_master_secret
|
||||
)
|
||||
group_shares = _split_secret(group_threshold, len(groups), encrypted_master_secret)
|
||||
|
||||
return [
|
||||
[
|
||||
self.encode_mnemonic(
|
||||
encode_mnemonic(
|
||||
identifier,
|
||||
iteration_exponent,
|
||||
group_index,
|
||||
@ -510,7 +484,7 @@ class ShamirMnemonic(object):
|
||||
member_threshold,
|
||||
value,
|
||||
)
|
||||
for member_index, value in self._split_secret(
|
||||
for member_index, value in _split_secret(
|
||||
member_threshold, member_count, group_secret
|
||||
)
|
||||
]
|
||||
@ -519,13 +493,9 @@ class ShamirMnemonic(object):
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def generate_mnemonics_random(
|
||||
self,
|
||||
group_threshold,
|
||||
groups,
|
||||
strength_bits=128,
|
||||
passphrase=b"",
|
||||
iteration_exponent=0,
|
||||
group_threshold, groups, strength_bits=128, passphrase=b"", iteration_exponent=0
|
||||
):
|
||||
"""
|
||||
Generates a random master secret and splits it into mnemonic shares using Shamir's secret
|
||||
@ -543,10 +513,10 @@ class ShamirMnemonic(object):
|
||||
:rtype: List of byte arrays.
|
||||
"""
|
||||
|
||||
if strength_bits < self.MIN_STRENGTH_BITS:
|
||||
if strength_bits < MIN_STRENGTH_BITS:
|
||||
raise ValueError(
|
||||
"The requested strength of the master secret ({} bits) must be at least {} bits.".format(
|
||||
strength_bits, self.MIN_STRENGTH_BITS
|
||||
strength_bits, MIN_STRENGTH_BITS
|
||||
)
|
||||
)
|
||||
|
||||
@ -557,7 +527,7 @@ class ShamirMnemonic(object):
|
||||
)
|
||||
)
|
||||
|
||||
return self.generate_mnemonics(
|
||||
return generate_mnemonics(
|
||||
group_threshold,
|
||||
groups,
|
||||
random.bytes(strength_bits // 8),
|
||||
@ -565,7 +535,8 @@ class ShamirMnemonic(object):
|
||||
iteration_exponent,
|
||||
)
|
||||
|
||||
def combine_mnemonics(self, mnemonics, passphrase=b""):
|
||||
|
||||
def combine_mnemonics(mnemonics, passphrase=b""):
|
||||
"""
|
||||
Combines mnemonic shares to obtain the master secret which was previously split using
|
||||
Shamir's secret sharing scheme.
|
||||
@ -580,7 +551,7 @@ class ShamirMnemonic(object):
|
||||
if not mnemonics:
|
||||
raise MnemonicError("The list of mnemonics is empty.")
|
||||
|
||||
identifier, iteration_exponent, group_threshold, group_count, groups = self._decode_mnemonics(
|
||||
identifier, iteration_exponent, group_threshold, group_count, groups = _decode_mnemonics(
|
||||
mnemonics
|
||||
)
|
||||
|
||||
@ -602,26 +573,22 @@ class ShamirMnemonic(object):
|
||||
|
||||
if len(groups) < group_threshold:
|
||||
group_index, group = next(iter(bad_groups.items()))
|
||||
prefix = self._group_prefix(
|
||||
identifier,
|
||||
iteration_exponent,
|
||||
group_index,
|
||||
group_threshold,
|
||||
group_count,
|
||||
prefix = _group_prefix(
|
||||
identifier, iteration_exponent, group_index, group_threshold, group_count
|
||||
)
|
||||
raise MnemonicError(
|
||||
'Insufficient number of mnemonics. At least {} mnemonics starting with "{} ..." are required.'.format(
|
||||
group[0], self.mnemonic_from_indices(prefix)
|
||||
group[0], mnemonic_from_indices(prefix)
|
||||
)
|
||||
)
|
||||
|
||||
group_shares = [
|
||||
(group_index, self._recover_secret(group[0], list(group[1])))
|
||||
(group_index, _recover_secret(group[0], list(group[1])))
|
||||
for group_index, group in groups.items()
|
||||
]
|
||||
|
||||
return self._decrypt(
|
||||
self._recover_secret(group_threshold, group_shares),
|
||||
return _decrypt(
|
||||
_recover_secret(group_threshold, group_shares),
|
||||
passphrase,
|
||||
iteration_exponent,
|
||||
identifier,
|
||||
|
@ -6,68 +6,66 @@ from slip39_vectors import vectors
|
||||
|
||||
class TestCryptoSlip39(unittest.TestCase):
|
||||
MS = b"ABCDEFGHIJKLMNOP"
|
||||
shamir = slip39.ShamirMnemonic()
|
||||
|
||||
|
||||
def test_basic_sharing_random(self):
|
||||
mnemonics = self.shamir.generate_mnemonics_random(1, [(3, 5)])[0]
|
||||
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.shamir.combine_mnemonics(mnemonics))
|
||||
mnemonics = slip39.generate_mnemonics_random(1, [(3, 5)])[0]
|
||||
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4]), slip39.combine_mnemonics(mnemonics))
|
||||
|
||||
|
||||
def test_basic_sharing_fixed(self):
|
||||
mnemonics = self.shamir.generate_mnemonics(1, [(3, 5)], self.MS)[0]
|
||||
self.assertEqual(self.shamir.combine_mnemonics(mnemonics), self.MS)
|
||||
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.MS)
|
||||
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS)[0]
|
||||
self.assertEqual(slip39.combine_mnemonics(mnemonics), self.MS)
|
||||
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
|
||||
with self.assertRaises(slip39.MnemonicError):
|
||||
self.shamir.combine_mnemonics(mnemonics[1:3])
|
||||
slip39.combine_mnemonics(mnemonics[1:3])
|
||||
|
||||
|
||||
def test_passphrase(self):
|
||||
mnemonics = self.shamir.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR")[0]
|
||||
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
|
||||
self.assertNotEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.MS)
|
||||
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR")[0]
|
||||
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
|
||||
self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
|
||||
|
||||
|
||||
def test_iteration_exponent(self):
|
||||
mnemonics = self.shamir.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 1)[0]
|
||||
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
|
||||
self.assertNotEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.MS)
|
||||
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 1)[0]
|
||||
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
|
||||
self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
|
||||
|
||||
mnemonics = self.shamir.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 2)[0]
|
||||
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
|
||||
self.assertNotEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.MS)
|
||||
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 2)[0]
|
||||
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
|
||||
self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
|
||||
|
||||
|
||||
def test_group_sharing(self):
|
||||
mnemonics = self.shamir.generate_mnemonics(2, [(3, 5), (2, 3), (2, 5), (1, 1)], self.MS)
|
||||
mnemonics = slip39.generate_mnemonics(2, [(3, 5), (2, 3), (2, 5), (1, 1)], self.MS)
|
||||
|
||||
# All mnemonics.
|
||||
self.assertEqual(self.shamir.combine_mnemonics([mnemonic for group in mnemonics for mnemonic in group]), self.MS)
|
||||
self.assertEqual(slip39.combine_mnemonics([mnemonic for group in mnemonics for mnemonic in group]), self.MS)
|
||||
|
||||
# Minimal sets of mnemonics.
|
||||
self.assertEqual(self.shamir.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]]), self.MS)
|
||||
self.assertEqual(self.shamir.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]]), self.MS)
|
||||
self.assertEqual(slip39.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]]), self.MS)
|
||||
self.assertEqual(slip39.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]]), self.MS)
|
||||
|
||||
# Two complete groups and one incomplete group.
|
||||
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[0] + [mnemonics[1][1]] + mnemonics[2]), self.MS)
|
||||
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[0][1:4] + mnemonics[1][1:3] + mnemonics[2][2:4]), self.MS)
|
||||
self.assertEqual(slip39.combine_mnemonics(mnemonics[0] + [mnemonics[1][1]] + mnemonics[2]), self.MS)
|
||||
self.assertEqual(slip39.combine_mnemonics(mnemonics[0][1:4] + mnemonics[1][1:3] + mnemonics[2][2:4]), self.MS)
|
||||
|
||||
# One complete group and one incomplete group out of two groups required.
|
||||
with self.assertRaises(slip39.MnemonicError):
|
||||
self.shamir.combine_mnemonics(mnemonics[0][2:] + [mnemonics[1][0]])
|
||||
slip39.combine_mnemonics(mnemonics[0][2:] + [mnemonics[1][0]])
|
||||
|
||||
# One group of two required.
|
||||
with self.assertRaises(slip39.MnemonicError):
|
||||
self.shamir.combine_mnemonics(mnemonics[0][1:4])
|
||||
slip39.combine_mnemonics(mnemonics[0][1:4])
|
||||
|
||||
|
||||
def test_vectors(self):
|
||||
for mnemonics, secret in vectors:
|
||||
if secret:
|
||||
self.assertEqual(self.shamir.combine_mnemonics(mnemonics, b"TREZOR"), unhexlify(secret))
|
||||
self.assertEqual(slip39.combine_mnemonics(mnemonics, b"TREZOR"), unhexlify(secret))
|
||||
else:
|
||||
with self.assertRaises(slip39.MnemonicError):
|
||||
self.shamir.combine_mnemonics(mnemonics, b"TREZOR")
|
||||
slip39.combine_mnemonics(mnemonics, b"TREZOR")
|
||||
|
||||
|
||||
def test_invalid_rs1024_checksum(self):
|
||||
@ -75,7 +73,7 @@ class TestCryptoSlip39(unittest.TestCase):
|
||||
"artist away academic academic dismiss spill unkind pencil lair sugar usher elegant paces sweater firm gravity deal body chest sugar"
|
||||
]
|
||||
with self.assertRaises(slip39.MnemonicError):
|
||||
self.shamir.combine_mnemonics(mnemonics)
|
||||
slip39.combine_mnemonics(mnemonics)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user