mirror of https://github.com/trezor/trezor-firmware.git
slip39: Remove ShamirMnemonic class. Use binary search to lookup words in wordlist.
parent d0527997ee
commit cd08c6937b
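The interesting part of this change is the word lookup: instead of building a word_index_map dict in ShamirMnemonic.__init__, the module now bisects the sorted wordlist directly. The snippet below is a standalone sketch of that lookup for illustration only; the six-word list and the local MnemonicError stub are stand-ins for the real 1024-word SLIP-39 wordlist and the module's own exception class.

# Standalone sketch of the binary-search lookup introduced by this commit.
# The six-word list below is an illustrative stand-in for the real, sorted
# 1024-word SLIP-39 wordlist, and MnemonicError stands in for the module's
# own exception class.

class MnemonicError(Exception):
    pass


wordlist = ["academic", "acid", "acne", "acquire", "acrobat", "activity"]


def word_index(word):
    # Bisect for the last entry that is <= word; the wordlist must be sorted.
    lo = 0
    hi = len(wordlist)
    while hi - lo > 1:
        mid = (hi + lo) // 2
        if wordlist[mid] > word:
            hi = mid
        else:
            lo = mid
    # Reject anything that did not land on a matching entry.
    if not wordlist[lo].startswith(word):
        raise MnemonicError('Invalid mnemonic word "{}".'.format(word))
    return lo


if __name__ == "__main__":
    assert word_index("academic") == 0
    assert word_index("acrobat") == 4
    try:
        word_index("zebra")
    except MnemonicError:
        print("unknown word rejected")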
@@ -33,7 +33,6 @@ class MnemonicError(Exception):
     pass
 
 
-class ShamirMnemonic(object):
 RADIX_BITS = 10
 """The length of the radix in bits."""
 
@@ -67,9 +66,7 @@ class ShamirMnemonic(object):
 MIN_STRENGTH_BITS = 128
 """The minimum allowed entropy of the master secret."""
 
-    MIN_MNEMONIC_LENGTH_WORDS = METADATA_LENGTH_WORDS + math.ceil(
-        MIN_STRENGTH_BITS / 10
-    )
+MIN_MNEMONIC_LENGTH_WORDS = METADATA_LENGTH_WORDS + math.ceil(MIN_STRENGTH_BITS / 10)
 """The minimum allowed length of the mnemonic in words."""
 
 MIN_ITERATION_COUNT = 10000
@@ -84,20 +81,22 @@ class ShamirMnemonic(object):
 DIGEST_INDEX = 254
 """The index of the share containing the digest of the shared secret."""
 
-    def __init__(self):
-        # Load the word list.
 
-        if len(wordlist) != self.RADIX:
-            raise ConfigurationError(
-                "The wordlist should contain {} words, but it contains {} words.".format(
-                    self.RADIX, len(wordlist)
-                )
-            )
+def word_index(word):
+    lo = 0
+    hi = len(wordlist)
+    while hi - lo > 1:
+        mid = (hi + lo) // 2
+        if wordlist[mid] > word:
+            hi = mid
+        else:
+            lo = mid
+    if not wordlist[lo].startswith(word):
+        raise MnemonicError('Invalid mnemonic word "{}".'.format(word))
+    return lo
 
-        self.word_index_map = {word: i for i, word in enumerate(wordlist)}
 
-    @classmethod
-    def _rs1024_polymod(cls, values):
+def _rs1024_polymod(values):
     GEN = (
         0xE0E040,
         0x1C1C080,
@@ -118,102 +117,93 @@ class ShamirMnemonic(object):
             chk ^= GEN[i] if ((b >> i) & 1) else 0
     return chk
 
-    @classmethod
-    def rs1024_create_checksum(cls, data):
-        values = (
-            tuple(cls.CUSTOMIZATION_STRING) + data + cls.CHECKSUM_LENGTH_WORDS * (0,)
-        )
-        polymod = cls._rs1024_polymod(values) ^ 1
+def rs1024_create_checksum(data):
+    values = tuple(CUSTOMIZATION_STRING) + data + CHECKSUM_LENGTH_WORDS * (0,)
+    polymod = _rs1024_polymod(values) ^ 1
     return tuple(
-            (polymod >> 10 * i) & 1023
-            for i in reversed(range(cls.CHECKSUM_LENGTH_WORDS))
+        (polymod >> 10 * i) & 1023 for i in reversed(range(CHECKSUM_LENGTH_WORDS))
     )
 
-    @classmethod
-    def rs1024_verify_checksum(cls, data):
-        return cls._rs1024_polymod(tuple(cls.CUSTOMIZATION_STRING) + data) == 1
 
-    @staticmethod
+def rs1024_verify_checksum(data):
+    return _rs1024_polymod(tuple(CUSTOMIZATION_STRING) + data) == 1
 
 
 def xor(a, b):
     return bytes(x ^ y for x, y in zip(a, b))
 
-    @classmethod
-    def _int_from_indices(cls, indices):
+def _int_from_indices(indices):
     """Converts a list of base 1024 indices in big endian order to an integer value."""
     value = 0
     for index in indices:
-            value = (value << cls.RADIX_BITS) + index
+        value = (value << RADIX_BITS) + index
     return value
 
-    @staticmethod
 def _int_to_indices(value, length, bits):
     """Converts an integer value to indices in big endian order."""
     mask = (1 << bits) - 1
     return ((value >> (i * bits)) & mask for i in reversed(range(length)))
 
-    def mnemonic_from_indices(self, indices):
+def mnemonic_from_indices(indices):
     return " ".join(wordlist[i] for i in indices)
 
-    def mnemonic_to_indices(self, mnemonic):
-        try:
-            return (self.word_index_map[word.lower()] for word in mnemonic.split())
-        except KeyError as key_error:
-            raise MnemonicError("Invalid mnemonic word {}.".format(key_error)) from None
 
-    @classmethod
-    def _round_function(cls, i, passphrase, e, salt, r):
+def mnemonic_to_indices(mnemonic):
+    return (word_index(word.lower()) for word in mnemonic.split())
 
 
+def _round_function(i, passphrase, e, salt, r):
     """The round function used internally by the Feistel cipher."""
     return pbkdf2(
         pbkdf2.HMAC_SHA256,
         bytes([i]) + passphrase,
         salt + r,
-            (cls.MIN_ITERATION_COUNT << e) // cls.ROUND_COUNT,
+        (MIN_ITERATION_COUNT << e) // ROUND_COUNT,
     ).key()[: len(r)]
 
-    @classmethod
-    def _get_salt(cls, identifier):
-        return cls.CUSTOMIZATION_STRING + identifier.to_bytes(
-            math.ceil(cls.ID_LENGTH_BITS / 8), "big"
+def _get_salt(identifier):
+    return CUSTOMIZATION_STRING + identifier.to_bytes(
+        math.ceil(ID_LENGTH_BITS / 8), "big"
     )
 
-    @classmethod
-    def _encrypt(cls, master_secret, passphrase, iteration_exponent, identifier):
+def _encrypt(master_secret, passphrase, iteration_exponent, identifier):
     l = master_secret[: len(master_secret) // 2]
     r = master_secret[len(master_secret) // 2 :]
-        salt = cls._get_salt(identifier)
-        for i in range(cls.ROUND_COUNT):
+    salt = _get_salt(identifier)
+    for i in range(ROUND_COUNT):
         (l, r) = (
             r,
-                cls.xor(
-                    l, cls._round_function(i, passphrase, iteration_exponent, salt, r)
-                ),
+            xor(l, _round_function(i, passphrase, iteration_exponent, salt, r)),
         )
     return r + l
 
-    @classmethod
-    def _decrypt(
-        cls, encrypted_master_secret, passphrase, iteration_exponent, identifier
-    ):
+def _decrypt(encrypted_master_secret, passphrase, iteration_exponent, identifier):
     l = encrypted_master_secret[: len(encrypted_master_secret) // 2]
     r = encrypted_master_secret[len(encrypted_master_secret) // 2 :]
-        salt = cls._get_salt(identifier)
-        for i in reversed(range(cls.ROUND_COUNT)):
+    salt = _get_salt(identifier)
+    for i in reversed(range(ROUND_COUNT)):
         (l, r) = (
             r,
-                cls.xor(
-                    l, cls._round_function(i, passphrase, iteration_exponent, salt, r)
-                ),
+            xor(l, _round_function(i, passphrase, iteration_exponent, salt, r)),
        )
    return r + l
 
-    @classmethod
-    def _create_digest(cls, random_data, shared_secret):
+def _create_digest(random_data, shared_secret):
     return hmac.new(random_data, shared_secret, hashlib.sha256).digest()[
-        : cls.DIGEST_LENGTH_BYTES
+        :DIGEST_LENGTH_BYTES
     ]
 
-    def _split_secret(self, threshold, share_count, shared_secret):
-        assert 0 < threshold <= share_count <= self.MAX_SHARE_COUNT
+def _split_secret(threshold, share_count, shared_secret):
+    assert 0 < threshold <= share_count <= MAX_SHARE_COUNT
 
     # If the threshold is 1, then the digest of the shared secret is not used.
     if threshold == 1:
@@ -221,23 +211,21 @@ class ShamirMnemonic(object):
 
     random_share_count = threshold - 2
 
-        if share_count > self.MAX_SHARE_COUNT:
+    if share_count > MAX_SHARE_COUNT:
         raise ValueError(
             "The requested number of shares ({}) must not exceed {}.".format(
-                    share_count, self.MAX_SHARE_COUNT
+                share_count, MAX_SHARE_COUNT
             )
         )
 
-        shares = [
-            (i, random.bytes(len(shared_secret))) for i in range(random_share_count)
-        ]
+    shares = [(i, random.bytes(len(shared_secret))) for i in range(random_share_count)]
 
-        random_part = random.bytes(len(shared_secret) - self.DIGEST_LENGTH_BYTES)
-        digest = self._create_digest(random_part, shared_secret)
+    random_part = random.bytes(len(shared_secret) - DIGEST_LENGTH_BYTES)
+    digest = _create_digest(random_part, shared_secret)
 
     base_shares = shares + [
-            (self.DIGEST_INDEX, digest + random_part),
-            (self.SECRET_INDEX, shared_secret),
+        (DIGEST_INDEX, digest + random_part),
+        (SECRET_INDEX, shared_secret),
     ]
 
     for i in range(random_share_count, share_count):
@@ -245,35 +233,32 @@ class ShamirMnemonic(object):
 
     return shares
 
-    def _recover_secret(self, threshold, shares):
-        shared_secret = shamir.interpolate(shares, self.SECRET_INDEX)
+def _recover_secret(threshold, shares):
+    shared_secret = shamir.interpolate(shares, SECRET_INDEX)
 
     # If the threshold is 1, then the digest of the shared secret is not used.
     if threshold != 1:
-            digest_share = shamir.interpolate(shares, self.DIGEST_INDEX)
-            digest = digest_share[: self.DIGEST_LENGTH_BYTES]
-            random_part = digest_share[self.DIGEST_LENGTH_BYTES :]
+        digest_share = shamir.interpolate(shares, DIGEST_INDEX)
+        digest = digest_share[:DIGEST_LENGTH_BYTES]
+        random_part = digest_share[DIGEST_LENGTH_BYTES:]
 
-            if digest != self._create_digest(random_part, shared_secret):
+        if digest != _create_digest(random_part, shared_secret):
             raise MnemonicError("Invalid digest of the shared secret.")
 
     return shared_secret
 
-    @classmethod
 def _group_prefix(
-        cls, identifier, iteration_exponent, group_index, group_threshold, group_count
+    identifier, iteration_exponent, group_index, group_threshold, group_count
 ):
-        id_exp_int = (identifier << cls.ITERATION_EXP_LENGTH_BITS) + iteration_exponent
-        return tuple(
-            cls._int_to_indices(id_exp_int, cls.ID_EXP_LENGTH_WORDS, cls.RADIX_BITS)
-        ) + (
-            (group_index << 6)
-            + ((group_threshold - 1) << 2)
-            + ((group_count - 1) >> 2),
+    id_exp_int = (identifier << ITERATION_EXP_LENGTH_BITS) + iteration_exponent
+    return tuple(_int_to_indices(id_exp_int, ID_EXP_LENGTH_WORDS, RADIX_BITS)) + (
+        (group_index << 6) + ((group_threshold - 1) << 2) + ((group_count - 1) >> 2),
     )
 
 
 def encode_mnemonic(
-        self,
     identifier,
     iteration_exponent,
     group_index,
@@ -299,77 +284,72 @@ class ShamirMnemonic(object):
     """
 
     # Convert the share value from bytes to wordlist indices.
-        value_word_count = math.ceil(len(value) * 8 / self.RADIX_BITS)
+    value_word_count = math.ceil(len(value) * 8 / RADIX_BITS)
     value_int = int.from_bytes(value, "big")
 
     share_data = (
-            self._group_prefix(
-                identifier,
-                iteration_exponent,
-                group_index,
-                group_threshold,
-                group_count,
+        _group_prefix(
+            identifier, iteration_exponent, group_index, group_threshold, group_count
         )
         + (
             (((group_count - 1) & 3) << 8)
             + (member_index << 4)
             + (member_threshold - 1),
         )
-            + tuple(self._int_to_indices(value_int, value_word_count, self.RADIX_BITS))
+        + tuple(_int_to_indices(value_int, value_word_count, RADIX_BITS))
     )
-        checksum = self.rs1024_create_checksum(share_data)
+    checksum = rs1024_create_checksum(share_data)
 
-        return self.mnemonic_from_indices(share_data + checksum)
+    return mnemonic_from_indices(share_data + checksum)
 
-    def decode_mnemonic(self, mnemonic):
+def decode_mnemonic(mnemonic):
     """Converts a share mnemonic to share data."""
 
-        mnemonic_data = tuple(self.mnemonic_to_indices(mnemonic))
+    mnemonic_data = tuple(mnemonic_to_indices(mnemonic))
 
-        if len(mnemonic_data) < self.MIN_MNEMONIC_LENGTH_WORDS:
+    if len(mnemonic_data) < MIN_MNEMONIC_LENGTH_WORDS:
         raise MnemonicError(
             "Invalid mnemonic length. The length of each mnemonic must be at least {} words.".format(
-                    self.MIN_MNEMONIC_LENGTH_WORDS
+                MIN_MNEMONIC_LENGTH_WORDS
             )
         )
 
-        padding_len = (10 * (len(mnemonic_data) - self.METADATA_LENGTH_WORDS)) % 16
+    padding_len = (10 * (len(mnemonic_data) - METADATA_LENGTH_WORDS)) % 16
     if padding_len > 8:
         raise MnemonicError("Invalid mnemonic length.")
 
-        if not self.rs1024_verify_checksum(mnemonic_data):
+    if not rs1024_verify_checksum(mnemonic_data):
         raise MnemonicError(
             'Invalid mnemonic checksum for "{} ...".'.format(
-                    " ".join(mnemonic.split()[: self.ID_EXP_LENGTH_WORDS + 2])
+                " ".join(mnemonic.split()[: ID_EXP_LENGTH_WORDS + 2])
             )
         )
 
-        id_exp_int = self._int_from_indices(mnemonic_data[: self.ID_EXP_LENGTH_WORDS])
-        identifier = id_exp_int >> self.ITERATION_EXP_LENGTH_BITS
-        iteration_exponent = id_exp_int & ((1 << self.ITERATION_EXP_LENGTH_BITS) - 1)
-        tmp = self._int_from_indices(
-            mnemonic_data[self.ID_EXP_LENGTH_WORDS : self.ID_EXP_LENGTH_WORDS + 2]
+    id_exp_int = _int_from_indices(mnemonic_data[:ID_EXP_LENGTH_WORDS])
+    identifier = id_exp_int >> ITERATION_EXP_LENGTH_BITS
+    iteration_exponent = id_exp_int & ((1 << ITERATION_EXP_LENGTH_BITS) - 1)
+    tmp = _int_from_indices(
+        mnemonic_data[ID_EXP_LENGTH_WORDS : ID_EXP_LENGTH_WORDS + 2]
     )
-        group_index, group_threshold, group_count, member_index, member_threshold = self._int_to_indices(
+    group_index, group_threshold, group_count, member_index, member_threshold = _int_to_indices(
         tmp, 5, 4
     )
-        value_data = mnemonic_data[
-            self.ID_EXP_LENGTH_WORDS + 2 : -self.CHECKSUM_LENGTH_WORDS
-        ]
+    value_data = mnemonic_data[ID_EXP_LENGTH_WORDS + 2 : -CHECKSUM_LENGTH_WORDS]
 
     if group_count < group_threshold:
         raise MnemonicError(
             'Invalid mnemonic "{} ...". Group threshold cannot be greater than group count.'.format(
-                    " ".join(mnemonic.split()[: self.ID_EXP_LENGTH_WORDS + 2])
+                " ".join(mnemonic.split()[: ID_EXP_LENGTH_WORDS + 2])
            )
        )
 
     value_byte_count = (10 * len(value_data) - padding_len) // 8
-        value_int = self._int_from_indices(value_data)
+    value_int = _int_from_indices(value_data)
     if value_data[0] >= 1 << (10 - padding_len):
         raise MnemonicError(
             'Invalid mnemonic padding for "{} ...".'.format(
-                    " ".join(mnemonic.split()[: self.ID_EXP_LENGTH_WORDS + 2])
+                " ".join(mnemonic.split()[: ID_EXP_LENGTH_WORDS + 2])
             )
         )
     value = value_int.to_bytes(value_byte_count, "big")
@@ -385,14 +365,15 @@ class ShamirMnemonic(object):
         value,
     )
 
-    def _decode_mnemonics(self, mnemonics):
+def _decode_mnemonics(mnemonics):
     identifiers = set()
     iteration_exponents = set()
     group_thresholds = set()
     group_counts = set()
     groups = {}  # { group_index : [member_threshold, set_of_member_shares] }
     for mnemonic in mnemonics:
-            identifier, iteration_exponent, group_index, group_threshold, group_count, member_index, member_threshold, share_value = self.decode_mnemonic(
+        identifier, iteration_exponent, group_index, group_threshold, group_count, member_index, member_threshold, share_value = decode_mnemonic(
             mnemonic
         )
         identifiers.add(identifier)
@@ -409,7 +390,7 @@ class ShamirMnemonic(object):
     if len(identifiers) != 1 or len(iteration_exponents) != 1:
         raise MnemonicError(
             "Invalid set of mnemonics. All mnemonics must begin with the same {} words.".format(
-                    self.ID_EXP_LENGTH_WORDS
+                ID_EXP_LENGTH_WORDS
            )
        )
 
@@ -437,21 +418,16 @@ class ShamirMnemonic(object):
         groups,
     )
 
-    def _generate_random_identifier(self):
+def _generate_random_identifier():
     """Returns a randomly generated integer in the range 0, ... , 2**ID_LENGTH_BITS - 1."""
 
-        identifier = int.from_bytes(
-            random.bytes(math.ceil(self.ID_LENGTH_BITS / 8)), "big"
-        )
-        return identifier & ((1 << self.ID_LENGTH_BITS) - 1)
+    identifier = int.from_bytes(random.bytes(math.ceil(ID_LENGTH_BITS / 8)), "big")
+    return identifier & ((1 << ID_LENGTH_BITS) - 1)
 
 
 def generate_mnemonics(
-        self,
-        group_threshold,
-        groups,
-        master_secret,
-        passphrase=b"",
-        iteration_exponent=0,
+    group_threshold, groups, master_secret, passphrase=b"", iteration_exponent=0
 ):
     """
     Splits a master secret into mnemonic shares using Shamir's secret sharing scheme.
@@ -469,12 +445,12 @@ class ShamirMnemonic(object):
     :rtype: List of byte arrays.
     """
 
-        identifier = self._generate_random_identifier()
+    identifier = _generate_random_identifier()
 
-        if len(master_secret) * 8 < self.MIN_STRENGTH_BITS:
+    if len(master_secret) * 8 < MIN_STRENGTH_BITS:
         raise ValueError(
             "The length of the master secret ({} bytes) must be at least {} bytes.".format(
-                    len(master_secret), math.ceil(self.MIN_STRENGTH_BITS / 8)
+                len(master_secret), math.ceil(MIN_STRENGTH_BITS / 8)
             )
         )
 
@@ -490,17 +466,15 @@ class ShamirMnemonic(object):
             )
         )
 
-        encrypted_master_secret = self._encrypt(
+    encrypted_master_secret = _encrypt(
         master_secret, passphrase, iteration_exponent, identifier
     )
 
-        group_shares = self._split_secret(
-            group_threshold, len(groups), encrypted_master_secret
-        )
+    group_shares = _split_secret(group_threshold, len(groups), encrypted_master_secret)
 
     return [
         [
-                self.encode_mnemonic(
+            encode_mnemonic(
                 identifier,
                 iteration_exponent,
                 group_index,
@@ -510,7 +484,7 @@ class ShamirMnemonic(object):
                 member_threshold,
                 value,
             )
-                for member_index, value in self._split_secret(
+            for member_index, value in _split_secret(
                 member_threshold, member_count, group_secret
             )
         ]
@@ -519,13 +493,9 @@ class ShamirMnemonic(object):
         )
     ]
 
 
 def generate_mnemonics_random(
-        self,
-        group_threshold,
-        groups,
-        strength_bits=128,
-        passphrase=b"",
-        iteration_exponent=0,
+    group_threshold, groups, strength_bits=128, passphrase=b"", iteration_exponent=0
 ):
     """
     Generates a random master secret and splits it into mnemonic shares using Shamir's secret
@@ -543,10 +513,10 @@ class ShamirMnemonic(object):
     :rtype: List of byte arrays.
     """
 
-        if strength_bits < self.MIN_STRENGTH_BITS:
+    if strength_bits < MIN_STRENGTH_BITS:
         raise ValueError(
             "The requested strength of the master secret ({} bits) must be at least {} bits.".format(
-                    strength_bits, self.MIN_STRENGTH_BITS
+                strength_bits, MIN_STRENGTH_BITS
            )
        )
 
@@ -557,7 +527,7 @@ class ShamirMnemonic(object):
         )
     )
 
-        return self.generate_mnemonics(
+    return generate_mnemonics(
         group_threshold,
         groups,
         random.bytes(strength_bits // 8),
@@ -565,7 +535,8 @@ class ShamirMnemonic(object):
         iteration_exponent,
     )
 
-    def combine_mnemonics(self, mnemonics, passphrase=b""):
+def combine_mnemonics(mnemonics, passphrase=b""):
     """
     Combines mnemonic shares to obtain the master secret which was previously split using
     Shamir's secret sharing scheme.
@@ -580,7 +551,7 @@ class ShamirMnemonic(object):
     if not mnemonics:
         raise MnemonicError("The list of mnemonics is empty.")
 
-        identifier, iteration_exponent, group_threshold, group_count, groups = self._decode_mnemonics(
+    identifier, iteration_exponent, group_threshold, group_count, groups = _decode_mnemonics(
         mnemonics
     )
 
@@ -602,26 +573,22 @@ class ShamirMnemonic(object):
 
     if len(groups) < group_threshold:
         group_index, group = next(iter(bad_groups.items()))
-            prefix = self._group_prefix(
-                identifier,
-                iteration_exponent,
-                group_index,
-                group_threshold,
-                group_count,
+        prefix = _group_prefix(
+            identifier, iteration_exponent, group_index, group_threshold, group_count
         )
         raise MnemonicError(
             'Insufficient number of mnemonics. At least {} mnemonics starting with "{} ..." are required.'.format(
-                    group[0], self.mnemonic_from_indices(prefix)
+                group[0], mnemonic_from_indices(prefix)
            )
        )
 
     group_shares = [
-            (group_index, self._recover_secret(group[0], list(group[1])))
+        (group_index, _recover_secret(group[0], list(group[1])))
         for group_index, group in groups.items()
     ]
 
-        return self._decrypt(
-            self._recover_secret(group_threshold, group_shares),
+    return _decrypt(
+        _recover_secret(group_threshold, group_shares),
         passphrase,
         iteration_exponent,
         identifier,

@@ -6,68 +6,66 @@ from slip39_vectors import vectors
 
 
 class TestCryptoSlip39(unittest.TestCase):
     MS = b"ABCDEFGHIJKLMNOP"
-    shamir = slip39.ShamirMnemonic()
 
     def test_basic_sharing_random(self):
-        mnemonics = self.shamir.generate_mnemonics_random(1, [(3, 5)])[0]
-        self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.shamir.combine_mnemonics(mnemonics))
+        mnemonics = slip39.generate_mnemonics_random(1, [(3, 5)])[0]
+        self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4]), slip39.combine_mnemonics(mnemonics))
 
     def test_basic_sharing_fixed(self):
-        mnemonics = self.shamir.generate_mnemonics(1, [(3, 5)], self.MS)[0]
-        self.assertEqual(self.shamir.combine_mnemonics(mnemonics), self.MS)
-        self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.MS)
+        mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS)[0]
+        self.assertEqual(slip39.combine_mnemonics(mnemonics), self.MS)
+        self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
         with self.assertRaises(slip39.MnemonicError):
-            self.shamir.combine_mnemonics(mnemonics[1:3])
+            slip39.combine_mnemonics(mnemonics[1:3])
 
     def test_passphrase(self):
-        mnemonics = self.shamir.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR")[0]
-        self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
-        self.assertNotEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.MS)
+        mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR")[0]
+        self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
+        self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
 
     def test_iteration_exponent(self):
-        mnemonics = self.shamir.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 1)[0]
-        self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
-        self.assertNotEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.MS)
+        mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 1)[0]
+        self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
+        self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
 
-        mnemonics = self.shamir.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 2)[0]
-        self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
-        self.assertNotEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.MS)
+        mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 2)[0]
+        self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
+        self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
 
     def test_group_sharing(self):
-        mnemonics = self.shamir.generate_mnemonics(2, [(3, 5), (2, 3), (2, 5), (1, 1)], self.MS)
+        mnemonics = slip39.generate_mnemonics(2, [(3, 5), (2, 3), (2, 5), (1, 1)], self.MS)
 
         # All mnemonics.
-        self.assertEqual(self.shamir.combine_mnemonics([mnemonic for group in mnemonics for mnemonic in group]), self.MS)
+        self.assertEqual(slip39.combine_mnemonics([mnemonic for group in mnemonics for mnemonic in group]), self.MS)
 
         # Minimal sets of mnemonics.
-        self.assertEqual(self.shamir.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]]), self.MS)
-        self.assertEqual(self.shamir.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]]), self.MS)
+        self.assertEqual(slip39.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]]), self.MS)
+        self.assertEqual(slip39.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]]), self.MS)
 
         # Two complete groups and one incomplete group.
-        self.assertEqual(self.shamir.combine_mnemonics(mnemonics[0] + [mnemonics[1][1]] + mnemonics[2]), self.MS)
-        self.assertEqual(self.shamir.combine_mnemonics(mnemonics[0][1:4] + mnemonics[1][1:3] + mnemonics[2][2:4]), self.MS)
+        self.assertEqual(slip39.combine_mnemonics(mnemonics[0] + [mnemonics[1][1]] + mnemonics[2]), self.MS)
+        self.assertEqual(slip39.combine_mnemonics(mnemonics[0][1:4] + mnemonics[1][1:3] + mnemonics[2][2:4]), self.MS)
 
         # One complete group and one incomplete group out of two groups required.
         with self.assertRaises(slip39.MnemonicError):
-            self.shamir.combine_mnemonics(mnemonics[0][2:] + [mnemonics[1][0]])
+            slip39.combine_mnemonics(mnemonics[0][2:] + [mnemonics[1][0]])
 
         # One group of two required.
         with self.assertRaises(slip39.MnemonicError):
-            self.shamir.combine_mnemonics(mnemonics[0][1:4])
+            slip39.combine_mnemonics(mnemonics[0][1:4])
 
     def test_vectors(self):
         for mnemonics, secret in vectors:
             if secret:
-                self.assertEqual(self.shamir.combine_mnemonics(mnemonics, b"TREZOR"), unhexlify(secret))
+                self.assertEqual(slip39.combine_mnemonics(mnemonics, b"TREZOR"), unhexlify(secret))
             else:
                 with self.assertRaises(slip39.MnemonicError):
-                    self.shamir.combine_mnemonics(mnemonics, b"TREZOR")
+                    slip39.combine_mnemonics(mnemonics, b"TREZOR")
 
     def test_invalid_rs1024_checksum(self):
@@ -75,7 +73,7 @@ class TestCryptoSlip39(unittest.TestCase):
             "artist away academic academic dismiss spill unkind pencil lair sugar usher elegant paces sweater firm gravity deal body chest sugar"
         ]
         with self.assertRaises(slip39.MnemonicError):
-            self.shamir.combine_mnemonics(mnemonics)
+            slip39.combine_mnemonics(mnemonics)
 
 
 if __name__ == '__main__':
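The test changes above also show the new calling convention: the per-class self.shamir instance is gone and the tests call the module-level functions directly. A minimal usage sketch along the same lines follows; the import path is an assumption (the diff only references the module as slip39), and the parameters simply mirror the tests.

# Usage sketch of the module-level API after this commit, mirroring the tests
# above.  The import path below is assumed; the diff only shows the module
# being referenced as `slip39`.
from trezor.crypto import slip39

master_secret = b"ABCDEFGHIJKLMNOP"

# One group that requires any 3 of its 5 shares.
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], master_secret)[0]

# Any three shares recover the master secret.
assert slip39.combine_mnemonics(mnemonics[1:4]) == master_secret

# Fewer shares than the threshold raise slip39.MnemonicError.
try:
    slip39.combine_mnemonics(mnemonics[:2])
except slip39.MnemonicError:
    print("not enough shares")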
|