1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-07-06 06:42:33 +00:00

slip39: Remove ShamirMnemonic class. Use binary search to lookup words in wordlist.

This commit is contained in:
Andrew Kozlik 2019-04-14 20:49:14 +02:00
parent d0527997ee
commit cd08c6937b
2 changed files with 522 additions and 557 deletions

View File

@ -33,7 +33,6 @@ class MnemonicError(Exception):
pass pass
class ShamirMnemonic(object):
RADIX_BITS = 10 RADIX_BITS = 10
"""The length of the radix in bits.""" """The length of the radix in bits."""
@ -67,9 +66,7 @@ class ShamirMnemonic(object):
MIN_STRENGTH_BITS = 128 MIN_STRENGTH_BITS = 128
"""The minimum allowed entropy of the master secret.""" """The minimum allowed entropy of the master secret."""
MIN_MNEMONIC_LENGTH_WORDS = METADATA_LENGTH_WORDS + math.ceil( MIN_MNEMONIC_LENGTH_WORDS = METADATA_LENGTH_WORDS + math.ceil(MIN_STRENGTH_BITS / 10)
MIN_STRENGTH_BITS / 10
)
"""The minimum allowed length of the mnemonic in words.""" """The minimum allowed length of the mnemonic in words."""
MIN_ITERATION_COUNT = 10000 MIN_ITERATION_COUNT = 10000
@ -84,20 +81,22 @@ class ShamirMnemonic(object):
DIGEST_INDEX = 254 DIGEST_INDEX = 254
"""The index of the share containing the digest of the shared secret.""" """The index of the share containing the digest of the shared secret."""
def __init__(self):
# Load the word list.
if len(wordlist) != self.RADIX: def word_index(word):
raise ConfigurationError( lo = 0
"The wordlist should contain {} words, but it contains {} words.".format( hi = len(wordlist)
self.RADIX, len(wordlist) while hi - lo > 1:
) mid = (hi + lo) // 2
) if wordlist[mid] > word:
hi = mid
else:
lo = mid
if not wordlist[lo].startswith(word):
raise MnemonicError('Invalid mnemonic word "{}".'.format(word))
return lo
self.word_index_map = {word: i for i, word in enumerate(wordlist)}
@classmethod def _rs1024_polymod(values):
def _rs1024_polymod(cls, values):
GEN = ( GEN = (
0xE0E040, 0xE0E040,
0x1C1C080, 0x1C1C080,
@ -118,102 +117,93 @@ class ShamirMnemonic(object):
chk ^= GEN[i] if ((b >> i) & 1) else 0 chk ^= GEN[i] if ((b >> i) & 1) else 0
return chk return chk
@classmethod
def rs1024_create_checksum(cls, data): def rs1024_create_checksum(data):
values = ( values = tuple(CUSTOMIZATION_STRING) + data + CHECKSUM_LENGTH_WORDS * (0,)
tuple(cls.CUSTOMIZATION_STRING) + data + cls.CHECKSUM_LENGTH_WORDS * (0,) polymod = _rs1024_polymod(values) ^ 1
)
polymod = cls._rs1024_polymod(values) ^ 1
return tuple( return tuple(
(polymod >> 10 * i) & 1023 (polymod >> 10 * i) & 1023 for i in reversed(range(CHECKSUM_LENGTH_WORDS))
for i in reversed(range(cls.CHECKSUM_LENGTH_WORDS))
) )
@classmethod
def rs1024_verify_checksum(cls, data):
return cls._rs1024_polymod(tuple(cls.CUSTOMIZATION_STRING) + data) == 1
@staticmethod def rs1024_verify_checksum(data):
return _rs1024_polymod(tuple(CUSTOMIZATION_STRING) + data) == 1
def xor(a, b): def xor(a, b):
return bytes(x ^ y for x, y in zip(a, b)) return bytes(x ^ y for x, y in zip(a, b))
@classmethod
def _int_from_indices(cls, indices): def _int_from_indices(indices):
"""Converts a list of base 1024 indices in big endian order to an integer value.""" """Converts a list of base 1024 indices in big endian order to an integer value."""
value = 0 value = 0
for index in indices: for index in indices:
value = (value << cls.RADIX_BITS) + index value = (value << RADIX_BITS) + index
return value return value
@staticmethod
def _int_to_indices(value, length, bits): def _int_to_indices(value, length, bits):
"""Converts an integer value to indices in big endian order.""" """Converts an integer value to indices in big endian order."""
mask = (1 << bits) - 1 mask = (1 << bits) - 1
return ((value >> (i * bits)) & mask for i in reversed(range(length))) return ((value >> (i * bits)) & mask for i in reversed(range(length)))
def mnemonic_from_indices(self, indices):
def mnemonic_from_indices(indices):
return " ".join(wordlist[i] for i in indices) return " ".join(wordlist[i] for i in indices)
def mnemonic_to_indices(self, mnemonic):
try:
return (self.word_index_map[word.lower()] for word in mnemonic.split())
except KeyError as key_error:
raise MnemonicError("Invalid mnemonic word {}.".format(key_error)) from None
@classmethod def mnemonic_to_indices(mnemonic):
def _round_function(cls, i, passphrase, e, salt, r): return (word_index(word.lower()) for word in mnemonic.split())
def _round_function(i, passphrase, e, salt, r):
"""The round function used internally by the Feistel cipher.""" """The round function used internally by the Feistel cipher."""
return pbkdf2( return pbkdf2(
pbkdf2.HMAC_SHA256, pbkdf2.HMAC_SHA256,
bytes([i]) + passphrase, bytes([i]) + passphrase,
salt + r, salt + r,
(cls.MIN_ITERATION_COUNT << e) // cls.ROUND_COUNT, (MIN_ITERATION_COUNT << e) // ROUND_COUNT,
).key()[: len(r)] ).key()[: len(r)]
@classmethod
def _get_salt(cls, identifier): def _get_salt(identifier):
return cls.CUSTOMIZATION_STRING + identifier.to_bytes( return CUSTOMIZATION_STRING + identifier.to_bytes(
math.ceil(cls.ID_LENGTH_BITS / 8), "big" math.ceil(ID_LENGTH_BITS / 8), "big"
) )
@classmethod
def _encrypt(cls, master_secret, passphrase, iteration_exponent, identifier): def _encrypt(master_secret, passphrase, iteration_exponent, identifier):
l = master_secret[: len(master_secret) // 2] l = master_secret[: len(master_secret) // 2]
r = master_secret[len(master_secret) // 2 :] r = master_secret[len(master_secret) // 2 :]
salt = cls._get_salt(identifier) salt = _get_salt(identifier)
for i in range(cls.ROUND_COUNT): for i in range(ROUND_COUNT):
(l, r) = ( (l, r) = (
r, r,
cls.xor( xor(l, _round_function(i, passphrase, iteration_exponent, salt, r)),
l, cls._round_function(i, passphrase, iteration_exponent, salt, r)
),
) )
return r + l return r + l
@classmethod
def _decrypt( def _decrypt(encrypted_master_secret, passphrase, iteration_exponent, identifier):
cls, encrypted_master_secret, passphrase, iteration_exponent, identifier
):
l = encrypted_master_secret[: len(encrypted_master_secret) // 2] l = encrypted_master_secret[: len(encrypted_master_secret) // 2]
r = encrypted_master_secret[len(encrypted_master_secret) // 2 :] r = encrypted_master_secret[len(encrypted_master_secret) // 2 :]
salt = cls._get_salt(identifier) salt = _get_salt(identifier)
for i in reversed(range(cls.ROUND_COUNT)): for i in reversed(range(ROUND_COUNT)):
(l, r) = ( (l, r) = (
r, r,
cls.xor( xor(l, _round_function(i, passphrase, iteration_exponent, salt, r)),
l, cls._round_function(i, passphrase, iteration_exponent, salt, r)
),
) )
return r + l return r + l
@classmethod
def _create_digest(cls, random_data, shared_secret): def _create_digest(random_data, shared_secret):
return hmac.new(random_data, shared_secret, hashlib.sha256).digest()[ return hmac.new(random_data, shared_secret, hashlib.sha256).digest()[
: cls.DIGEST_LENGTH_BYTES :DIGEST_LENGTH_BYTES
] ]
def _split_secret(self, threshold, share_count, shared_secret):
assert 0 < threshold <= share_count <= self.MAX_SHARE_COUNT def _split_secret(threshold, share_count, shared_secret):
assert 0 < threshold <= share_count <= MAX_SHARE_COUNT
# If the threshold is 1, then the digest of the shared secret is not used. # If the threshold is 1, then the digest of the shared secret is not used.
if threshold == 1: if threshold == 1:
@ -221,23 +211,21 @@ class ShamirMnemonic(object):
random_share_count = threshold - 2 random_share_count = threshold - 2
if share_count > self.MAX_SHARE_COUNT: if share_count > MAX_SHARE_COUNT:
raise ValueError( raise ValueError(
"The requested number of shares ({}) must not exceed {}.".format( "The requested number of shares ({}) must not exceed {}.".format(
share_count, self.MAX_SHARE_COUNT share_count, MAX_SHARE_COUNT
) )
) )
shares = [ shares = [(i, random.bytes(len(shared_secret))) for i in range(random_share_count)]
(i, random.bytes(len(shared_secret))) for i in range(random_share_count)
]
random_part = random.bytes(len(shared_secret) - self.DIGEST_LENGTH_BYTES) random_part = random.bytes(len(shared_secret) - DIGEST_LENGTH_BYTES)
digest = self._create_digest(random_part, shared_secret) digest = _create_digest(random_part, shared_secret)
base_shares = shares + [ base_shares = shares + [
(self.DIGEST_INDEX, digest + random_part), (DIGEST_INDEX, digest + random_part),
(self.SECRET_INDEX, shared_secret), (SECRET_INDEX, shared_secret),
] ]
for i in range(random_share_count, share_count): for i in range(random_share_count, share_count):
@ -245,35 +233,32 @@ class ShamirMnemonic(object):
return shares return shares
def _recover_secret(self, threshold, shares):
shared_secret = shamir.interpolate(shares, self.SECRET_INDEX) def _recover_secret(threshold, shares):
shared_secret = shamir.interpolate(shares, SECRET_INDEX)
# If the threshold is 1, then the digest of the shared secret is not used. # If the threshold is 1, then the digest of the shared secret is not used.
if threshold != 1: if threshold != 1:
digest_share = shamir.interpolate(shares, self.DIGEST_INDEX) digest_share = shamir.interpolate(shares, DIGEST_INDEX)
digest = digest_share[: self.DIGEST_LENGTH_BYTES] digest = digest_share[:DIGEST_LENGTH_BYTES]
random_part = digest_share[self.DIGEST_LENGTH_BYTES :] random_part = digest_share[DIGEST_LENGTH_BYTES:]
if digest != self._create_digest(random_part, shared_secret): if digest != _create_digest(random_part, shared_secret):
raise MnemonicError("Invalid digest of the shared secret.") raise MnemonicError("Invalid digest of the shared secret.")
return shared_secret return shared_secret
@classmethod
def _group_prefix( def _group_prefix(
cls, identifier, iteration_exponent, group_index, group_threshold, group_count identifier, iteration_exponent, group_index, group_threshold, group_count
): ):
id_exp_int = (identifier << cls.ITERATION_EXP_LENGTH_BITS) + iteration_exponent id_exp_int = (identifier << ITERATION_EXP_LENGTH_BITS) + iteration_exponent
return tuple( return tuple(_int_to_indices(id_exp_int, ID_EXP_LENGTH_WORDS, RADIX_BITS)) + (
cls._int_to_indices(id_exp_int, cls.ID_EXP_LENGTH_WORDS, cls.RADIX_BITS) (group_index << 6) + ((group_threshold - 1) << 2) + ((group_count - 1) >> 2),
) + (
(group_index << 6)
+ ((group_threshold - 1) << 2)
+ ((group_count - 1) >> 2),
) )
def encode_mnemonic( def encode_mnemonic(
self,
identifier, identifier,
iteration_exponent, iteration_exponent,
group_index, group_index,
@ -299,77 +284,72 @@ class ShamirMnemonic(object):
""" """
# Convert the share value from bytes to wordlist indices. # Convert the share value from bytes to wordlist indices.
value_word_count = math.ceil(len(value) * 8 / self.RADIX_BITS) value_word_count = math.ceil(len(value) * 8 / RADIX_BITS)
value_int = int.from_bytes(value, "big") value_int = int.from_bytes(value, "big")
share_data = ( share_data = (
self._group_prefix( _group_prefix(
identifier, identifier, iteration_exponent, group_index, group_threshold, group_count
iteration_exponent,
group_index,
group_threshold,
group_count,
) )
+ ( + (
(((group_count - 1) & 3) << 8) (((group_count - 1) & 3) << 8)
+ (member_index << 4) + (member_index << 4)
+ (member_threshold - 1), + (member_threshold - 1),
) )
+ tuple(self._int_to_indices(value_int, value_word_count, self.RADIX_BITS)) + tuple(_int_to_indices(value_int, value_word_count, RADIX_BITS))
) )
checksum = self.rs1024_create_checksum(share_data) checksum = rs1024_create_checksum(share_data)
return self.mnemonic_from_indices(share_data + checksum) return mnemonic_from_indices(share_data + checksum)
def decode_mnemonic(self, mnemonic):
def decode_mnemonic(mnemonic):
"""Converts a share mnemonic to share data.""" """Converts a share mnemonic to share data."""
mnemonic_data = tuple(self.mnemonic_to_indices(mnemonic)) mnemonic_data = tuple(mnemonic_to_indices(mnemonic))
if len(mnemonic_data) < self.MIN_MNEMONIC_LENGTH_WORDS: if len(mnemonic_data) < MIN_MNEMONIC_LENGTH_WORDS:
raise MnemonicError( raise MnemonicError(
"Invalid mnemonic length. The length of each mnemonic must be at least {} words.".format( "Invalid mnemonic length. The length of each mnemonic must be at least {} words.".format(
self.MIN_MNEMONIC_LENGTH_WORDS MIN_MNEMONIC_LENGTH_WORDS
) )
) )
padding_len = (10 * (len(mnemonic_data) - self.METADATA_LENGTH_WORDS)) % 16 padding_len = (10 * (len(mnemonic_data) - METADATA_LENGTH_WORDS)) % 16
if padding_len > 8: if padding_len > 8:
raise MnemonicError("Invalid mnemonic length.") raise MnemonicError("Invalid mnemonic length.")
if not self.rs1024_verify_checksum(mnemonic_data): if not rs1024_verify_checksum(mnemonic_data):
raise MnemonicError( raise MnemonicError(
'Invalid mnemonic checksum for "{} ...".'.format( 'Invalid mnemonic checksum for "{} ...".'.format(
" ".join(mnemonic.split()[: self.ID_EXP_LENGTH_WORDS + 2]) " ".join(mnemonic.split()[: ID_EXP_LENGTH_WORDS + 2])
) )
) )
id_exp_int = self._int_from_indices(mnemonic_data[: self.ID_EXP_LENGTH_WORDS]) id_exp_int = _int_from_indices(mnemonic_data[:ID_EXP_LENGTH_WORDS])
identifier = id_exp_int >> self.ITERATION_EXP_LENGTH_BITS identifier = id_exp_int >> ITERATION_EXP_LENGTH_BITS
iteration_exponent = id_exp_int & ((1 << self.ITERATION_EXP_LENGTH_BITS) - 1) iteration_exponent = id_exp_int & ((1 << ITERATION_EXP_LENGTH_BITS) - 1)
tmp = self._int_from_indices( tmp = _int_from_indices(
mnemonic_data[self.ID_EXP_LENGTH_WORDS : self.ID_EXP_LENGTH_WORDS + 2] mnemonic_data[ID_EXP_LENGTH_WORDS : ID_EXP_LENGTH_WORDS + 2]
) )
group_index, group_threshold, group_count, member_index, member_threshold = self._int_to_indices( group_index, group_threshold, group_count, member_index, member_threshold = _int_to_indices(
tmp, 5, 4 tmp, 5, 4
) )
value_data = mnemonic_data[ value_data = mnemonic_data[ID_EXP_LENGTH_WORDS + 2 : -CHECKSUM_LENGTH_WORDS]
self.ID_EXP_LENGTH_WORDS + 2 : -self.CHECKSUM_LENGTH_WORDS
]
if group_count < group_threshold: if group_count < group_threshold:
raise MnemonicError( raise MnemonicError(
'Invalid mnemonic "{} ...". Group threshold cannot be greater than group count.'.format( 'Invalid mnemonic "{} ...". Group threshold cannot be greater than group count.'.format(
" ".join(mnemonic.split()[: self.ID_EXP_LENGTH_WORDS + 2]) " ".join(mnemonic.split()[: ID_EXP_LENGTH_WORDS + 2])
) )
) )
value_byte_count = (10 * len(value_data) - padding_len) // 8 value_byte_count = (10 * len(value_data) - padding_len) // 8
value_int = self._int_from_indices(value_data) value_int = _int_from_indices(value_data)
if value_data[0] >= 1 << (10 - padding_len): if value_data[0] >= 1 << (10 - padding_len):
raise MnemonicError( raise MnemonicError(
'Invalid mnemonic padding for "{} ...".'.format( 'Invalid mnemonic padding for "{} ...".'.format(
" ".join(mnemonic.split()[: self.ID_EXP_LENGTH_WORDS + 2]) " ".join(mnemonic.split()[: ID_EXP_LENGTH_WORDS + 2])
) )
) )
value = value_int.to_bytes(value_byte_count, "big") value = value_int.to_bytes(value_byte_count, "big")
@ -385,14 +365,15 @@ class ShamirMnemonic(object):
value, value,
) )
def _decode_mnemonics(self, mnemonics):
def _decode_mnemonics(mnemonics):
identifiers = set() identifiers = set()
iteration_exponents = set() iteration_exponents = set()
group_thresholds = set() group_thresholds = set()
group_counts = set() group_counts = set()
groups = {} # { group_index : [member_threshold, set_of_member_shares] } groups = {} # { group_index : [member_threshold, set_of_member_shares] }
for mnemonic in mnemonics: for mnemonic in mnemonics:
identifier, iteration_exponent, group_index, group_threshold, group_count, member_index, member_threshold, share_value = self.decode_mnemonic( identifier, iteration_exponent, group_index, group_threshold, group_count, member_index, member_threshold, share_value = decode_mnemonic(
mnemonic mnemonic
) )
identifiers.add(identifier) identifiers.add(identifier)
@ -409,7 +390,7 @@ class ShamirMnemonic(object):
if len(identifiers) != 1 or len(iteration_exponents) != 1: if len(identifiers) != 1 or len(iteration_exponents) != 1:
raise MnemonicError( raise MnemonicError(
"Invalid set of mnemonics. All mnemonics must begin with the same {} words.".format( "Invalid set of mnemonics. All mnemonics must begin with the same {} words.".format(
self.ID_EXP_LENGTH_WORDS ID_EXP_LENGTH_WORDS
) )
) )
@ -437,21 +418,16 @@ class ShamirMnemonic(object):
groups, groups,
) )
def _generate_random_identifier(self):
def _generate_random_identifier():
"""Returns a randomly generated integer in the range 0, ... , 2**ID_LENGTH_BITS - 1.""" """Returns a randomly generated integer in the range 0, ... , 2**ID_LENGTH_BITS - 1."""
identifier = int.from_bytes( identifier = int.from_bytes(random.bytes(math.ceil(ID_LENGTH_BITS / 8)), "big")
random.bytes(math.ceil(self.ID_LENGTH_BITS / 8)), "big" return identifier & ((1 << ID_LENGTH_BITS) - 1)
)
return identifier & ((1 << self.ID_LENGTH_BITS) - 1)
def generate_mnemonics( def generate_mnemonics(
self, group_threshold, groups, master_secret, passphrase=b"", iteration_exponent=0
group_threshold,
groups,
master_secret,
passphrase=b"",
iteration_exponent=0,
): ):
""" """
Splits a master secret into mnemonic shares using Shamir's secret sharing scheme. Splits a master secret into mnemonic shares using Shamir's secret sharing scheme.
@ -469,12 +445,12 @@ class ShamirMnemonic(object):
:rtype: List of byte arrays. :rtype: List of byte arrays.
""" """
identifier = self._generate_random_identifier() identifier = _generate_random_identifier()
if len(master_secret) * 8 < self.MIN_STRENGTH_BITS: if len(master_secret) * 8 < MIN_STRENGTH_BITS:
raise ValueError( raise ValueError(
"The length of the master secret ({} bytes) must be at least {} bytes.".format( "The length of the master secret ({} bytes) must be at least {} bytes.".format(
len(master_secret), math.ceil(self.MIN_STRENGTH_BITS / 8) len(master_secret), math.ceil(MIN_STRENGTH_BITS / 8)
) )
) )
@ -490,17 +466,15 @@ class ShamirMnemonic(object):
) )
) )
encrypted_master_secret = self._encrypt( encrypted_master_secret = _encrypt(
master_secret, passphrase, iteration_exponent, identifier master_secret, passphrase, iteration_exponent, identifier
) )
group_shares = self._split_secret( group_shares = _split_secret(group_threshold, len(groups), encrypted_master_secret)
group_threshold, len(groups), encrypted_master_secret
)
return [ return [
[ [
self.encode_mnemonic( encode_mnemonic(
identifier, identifier,
iteration_exponent, iteration_exponent,
group_index, group_index,
@ -510,7 +484,7 @@ class ShamirMnemonic(object):
member_threshold, member_threshold,
value, value,
) )
for member_index, value in self._split_secret( for member_index, value in _split_secret(
member_threshold, member_count, group_secret member_threshold, member_count, group_secret
) )
] ]
@ -519,13 +493,9 @@ class ShamirMnemonic(object):
) )
] ]
def generate_mnemonics_random( def generate_mnemonics_random(
self, group_threshold, groups, strength_bits=128, passphrase=b"", iteration_exponent=0
group_threshold,
groups,
strength_bits=128,
passphrase=b"",
iteration_exponent=0,
): ):
""" """
Generates a random master secret and splits it into mnemonic shares using Shamir's secret Generates a random master secret and splits it into mnemonic shares using Shamir's secret
@ -543,10 +513,10 @@ class ShamirMnemonic(object):
:rtype: List of byte arrays. :rtype: List of byte arrays.
""" """
if strength_bits < self.MIN_STRENGTH_BITS: if strength_bits < MIN_STRENGTH_BITS:
raise ValueError( raise ValueError(
"The requested strength of the master secret ({} bits) must be at least {} bits.".format( "The requested strength of the master secret ({} bits) must be at least {} bits.".format(
strength_bits, self.MIN_STRENGTH_BITS strength_bits, MIN_STRENGTH_BITS
) )
) )
@ -557,7 +527,7 @@ class ShamirMnemonic(object):
) )
) )
return self.generate_mnemonics( return generate_mnemonics(
group_threshold, group_threshold,
groups, groups,
random.bytes(strength_bits // 8), random.bytes(strength_bits // 8),
@ -565,7 +535,8 @@ class ShamirMnemonic(object):
iteration_exponent, iteration_exponent,
) )
def combine_mnemonics(self, mnemonics, passphrase=b""):
def combine_mnemonics(mnemonics, passphrase=b""):
""" """
Combines mnemonic shares to obtain the master secret which was previously split using Combines mnemonic shares to obtain the master secret which was previously split using
Shamir's secret sharing scheme. Shamir's secret sharing scheme.
@ -580,7 +551,7 @@ class ShamirMnemonic(object):
if not mnemonics: if not mnemonics:
raise MnemonicError("The list of mnemonics is empty.") raise MnemonicError("The list of mnemonics is empty.")
identifier, iteration_exponent, group_threshold, group_count, groups = self._decode_mnemonics( identifier, iteration_exponent, group_threshold, group_count, groups = _decode_mnemonics(
mnemonics mnemonics
) )
@ -602,26 +573,22 @@ class ShamirMnemonic(object):
if len(groups) < group_threshold: if len(groups) < group_threshold:
group_index, group = next(iter(bad_groups.items())) group_index, group = next(iter(bad_groups.items()))
prefix = self._group_prefix( prefix = _group_prefix(
identifier, identifier, iteration_exponent, group_index, group_threshold, group_count
iteration_exponent,
group_index,
group_threshold,
group_count,
) )
raise MnemonicError( raise MnemonicError(
'Insufficient number of mnemonics. At least {} mnemonics starting with "{} ..." are required.'.format( 'Insufficient number of mnemonics. At least {} mnemonics starting with "{} ..." are required.'.format(
group[0], self.mnemonic_from_indices(prefix) group[0], mnemonic_from_indices(prefix)
) )
) )
group_shares = [ group_shares = [
(group_index, self._recover_secret(group[0], list(group[1]))) (group_index, _recover_secret(group[0], list(group[1])))
for group_index, group in groups.items() for group_index, group in groups.items()
] ]
return self._decrypt( return _decrypt(
self._recover_secret(group_threshold, group_shares), _recover_secret(group_threshold, group_shares),
passphrase, passphrase,
iteration_exponent, iteration_exponent,
identifier, identifier,

View File

@ -6,68 +6,66 @@ from slip39_vectors import vectors
class TestCryptoSlip39(unittest.TestCase): class TestCryptoSlip39(unittest.TestCase):
MS = b"ABCDEFGHIJKLMNOP" MS = b"ABCDEFGHIJKLMNOP"
shamir = slip39.ShamirMnemonic()
def test_basic_sharing_random(self): def test_basic_sharing_random(self):
mnemonics = self.shamir.generate_mnemonics_random(1, [(3, 5)])[0] mnemonics = slip39.generate_mnemonics_random(1, [(3, 5)])[0]
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.shamir.combine_mnemonics(mnemonics)) self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4]), slip39.combine_mnemonics(mnemonics))
def test_basic_sharing_fixed(self): def test_basic_sharing_fixed(self):
mnemonics = self.shamir.generate_mnemonics(1, [(3, 5)], self.MS)[0] mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS)[0]
self.assertEqual(self.shamir.combine_mnemonics(mnemonics), self.MS) self.assertEqual(slip39.combine_mnemonics(mnemonics), self.MS)
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.MS) self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
with self.assertRaises(slip39.MnemonicError): with self.assertRaises(slip39.MnemonicError):
self.shamir.combine_mnemonics(mnemonics[1:3]) slip39.combine_mnemonics(mnemonics[1:3])
def test_passphrase(self): def test_passphrase(self):
mnemonics = self.shamir.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR")[0] mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR")[0]
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS) self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
self.assertNotEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.MS) self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
def test_iteration_exponent(self): def test_iteration_exponent(self):
mnemonics = self.shamir.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 1)[0] mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 1)[0]
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS) self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
self.assertNotEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.MS) self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
mnemonics = self.shamir.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 2)[0] mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 2)[0]
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS) self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
self.assertNotEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.MS) self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
def test_group_sharing(self): def test_group_sharing(self):
mnemonics = self.shamir.generate_mnemonics(2, [(3, 5), (2, 3), (2, 5), (1, 1)], self.MS) mnemonics = slip39.generate_mnemonics(2, [(3, 5), (2, 3), (2, 5), (1, 1)], self.MS)
# All mnemonics. # All mnemonics.
self.assertEqual(self.shamir.combine_mnemonics([mnemonic for group in mnemonics for mnemonic in group]), self.MS) self.assertEqual(slip39.combine_mnemonics([mnemonic for group in mnemonics for mnemonic in group]), self.MS)
# Minimal sets of mnemonics. # Minimal sets of mnemonics.
self.assertEqual(self.shamir.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]]), self.MS) self.assertEqual(slip39.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]]), self.MS)
self.assertEqual(self.shamir.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]]), self.MS) self.assertEqual(slip39.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]]), self.MS)
# Two complete groups and one incomplete group. # Two complete groups and one incomplete group.
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[0] + [mnemonics[1][1]] + mnemonics[2]), self.MS) self.assertEqual(slip39.combine_mnemonics(mnemonics[0] + [mnemonics[1][1]] + mnemonics[2]), self.MS)
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[0][1:4] + mnemonics[1][1:3] + mnemonics[2][2:4]), self.MS) self.assertEqual(slip39.combine_mnemonics(mnemonics[0][1:4] + mnemonics[1][1:3] + mnemonics[2][2:4]), self.MS)
# One complete group and one incomplete group out of two groups required. # One complete group and one incomplete group out of two groups required.
with self.assertRaises(slip39.MnemonicError): with self.assertRaises(slip39.MnemonicError):
self.shamir.combine_mnemonics(mnemonics[0][2:] + [mnemonics[1][0]]) slip39.combine_mnemonics(mnemonics[0][2:] + [mnemonics[1][0]])
# One group of two required. # One group of two required.
with self.assertRaises(slip39.MnemonicError): with self.assertRaises(slip39.MnemonicError):
self.shamir.combine_mnemonics(mnemonics[0][1:4]) slip39.combine_mnemonics(mnemonics[0][1:4])
def test_vectors(self): def test_vectors(self):
for mnemonics, secret in vectors: for mnemonics, secret in vectors:
if secret: if secret:
self.assertEqual(self.shamir.combine_mnemonics(mnemonics, b"TREZOR"), unhexlify(secret)) self.assertEqual(slip39.combine_mnemonics(mnemonics, b"TREZOR"), unhexlify(secret))
else: else:
with self.assertRaises(slip39.MnemonicError): with self.assertRaises(slip39.MnemonicError):
self.shamir.combine_mnemonics(mnemonics, b"TREZOR") slip39.combine_mnemonics(mnemonics, b"TREZOR")
def test_invalid_rs1024_checksum(self): def test_invalid_rs1024_checksum(self):
@ -75,7 +73,7 @@ class TestCryptoSlip39(unittest.TestCase):
"artist away academic academic dismiss spill unkind pencil lair sugar usher elegant paces sweater firm gravity deal body chest sugar" "artist away academic academic dismiss spill unkind pencil lair sugar usher elegant paces sweater firm gravity deal body chest sugar"
] ]
with self.assertRaises(slip39.MnemonicError): with self.assertRaises(slip39.MnemonicError):
self.shamir.combine_mnemonics(mnemonics) slip39.combine_mnemonics(mnemonics)
if __name__ == '__main__': if __name__ == '__main__':