1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-16 17:42:02 +00:00

slip39: Remove ShamirMnemonic class. Use binary search to lookup words in wordlist.

This commit is contained in:
Andrew Kozlik 2019-04-14 20:49:14 +02:00
parent d0527997ee
commit cd08c6937b
2 changed files with 522 additions and 557 deletions

View File

@ -33,7 +33,6 @@ class MnemonicError(Exception):
pass
class ShamirMnemonic(object):
RADIX_BITS = 10
"""The length of the radix in bits."""
@ -67,9 +66,7 @@ class ShamirMnemonic(object):
MIN_STRENGTH_BITS = 128
"""The minimum allowed entropy of the master secret."""
MIN_MNEMONIC_LENGTH_WORDS = METADATA_LENGTH_WORDS + math.ceil(
MIN_STRENGTH_BITS / 10
)
MIN_MNEMONIC_LENGTH_WORDS = METADATA_LENGTH_WORDS + math.ceil(MIN_STRENGTH_BITS / 10)
"""The minimum allowed length of the mnemonic in words."""
MIN_ITERATION_COUNT = 10000
@ -84,20 +81,22 @@ class ShamirMnemonic(object):
DIGEST_INDEX = 254
"""The index of the share containing the digest of the shared secret."""
def __init__(self):
# Load the word list.
if len(wordlist) != self.RADIX:
raise ConfigurationError(
"The wordlist should contain {} words, but it contains {} words.".format(
self.RADIX, len(wordlist)
)
)
def word_index(word):
lo = 0
hi = len(wordlist)
while hi - lo > 1:
mid = (hi + lo) // 2
if wordlist[mid] > word:
hi = mid
else:
lo = mid
if not wordlist[lo].startswith(word):
raise MnemonicError('Invalid mnemonic word "{}".'.format(word))
return lo
self.word_index_map = {word: i for i, word in enumerate(wordlist)}
@classmethod
def _rs1024_polymod(cls, values):
def _rs1024_polymod(values):
GEN = (
0xE0E040,
0x1C1C080,
@ -118,102 +117,93 @@ class ShamirMnemonic(object):
chk ^= GEN[i] if ((b >> i) & 1) else 0
return chk
@classmethod
def rs1024_create_checksum(cls, data):
values = (
tuple(cls.CUSTOMIZATION_STRING) + data + cls.CHECKSUM_LENGTH_WORDS * (0,)
)
polymod = cls._rs1024_polymod(values) ^ 1
def rs1024_create_checksum(data):
values = tuple(CUSTOMIZATION_STRING) + data + CHECKSUM_LENGTH_WORDS * (0,)
polymod = _rs1024_polymod(values) ^ 1
return tuple(
(polymod >> 10 * i) & 1023
for i in reversed(range(cls.CHECKSUM_LENGTH_WORDS))
(polymod >> 10 * i) & 1023 for i in reversed(range(CHECKSUM_LENGTH_WORDS))
)
@classmethod
def rs1024_verify_checksum(cls, data):
return cls._rs1024_polymod(tuple(cls.CUSTOMIZATION_STRING) + data) == 1
@staticmethod
def rs1024_verify_checksum(data):
return _rs1024_polymod(tuple(CUSTOMIZATION_STRING) + data) == 1
def xor(a, b):
return bytes(x ^ y for x, y in zip(a, b))
@classmethod
def _int_from_indices(cls, indices):
def _int_from_indices(indices):
"""Converts a list of base 1024 indices in big endian order to an integer value."""
value = 0
for index in indices:
value = (value << cls.RADIX_BITS) + index
value = (value << RADIX_BITS) + index
return value
@staticmethod
def _int_to_indices(value, length, bits):
"""Converts an integer value to indices in big endian order."""
mask = (1 << bits) - 1
return ((value >> (i * bits)) & mask for i in reversed(range(length)))
def mnemonic_from_indices(self, indices):
def mnemonic_from_indices(indices):
return " ".join(wordlist[i] for i in indices)
def mnemonic_to_indices(self, mnemonic):
try:
return (self.word_index_map[word.lower()] for word in mnemonic.split())
except KeyError as key_error:
raise MnemonicError("Invalid mnemonic word {}.".format(key_error)) from None
@classmethod
def _round_function(cls, i, passphrase, e, salt, r):
def mnemonic_to_indices(mnemonic):
return (word_index(word.lower()) for word in mnemonic.split())
def _round_function(i, passphrase, e, salt, r):
"""The round function used internally by the Feistel cipher."""
return pbkdf2(
pbkdf2.HMAC_SHA256,
bytes([i]) + passphrase,
salt + r,
(cls.MIN_ITERATION_COUNT << e) // cls.ROUND_COUNT,
(MIN_ITERATION_COUNT << e) // ROUND_COUNT,
).key()[: len(r)]
@classmethod
def _get_salt(cls, identifier):
return cls.CUSTOMIZATION_STRING + identifier.to_bytes(
math.ceil(cls.ID_LENGTH_BITS / 8), "big"
def _get_salt(identifier):
return CUSTOMIZATION_STRING + identifier.to_bytes(
math.ceil(ID_LENGTH_BITS / 8), "big"
)
@classmethod
def _encrypt(cls, master_secret, passphrase, iteration_exponent, identifier):
def _encrypt(master_secret, passphrase, iteration_exponent, identifier):
l = master_secret[: len(master_secret) // 2]
r = master_secret[len(master_secret) // 2 :]
salt = cls._get_salt(identifier)
for i in range(cls.ROUND_COUNT):
salt = _get_salt(identifier)
for i in range(ROUND_COUNT):
(l, r) = (
r,
cls.xor(
l, cls._round_function(i, passphrase, iteration_exponent, salt, r)
),
xor(l, _round_function(i, passphrase, iteration_exponent, salt, r)),
)
return r + l
@classmethod
def _decrypt(
cls, encrypted_master_secret, passphrase, iteration_exponent, identifier
):
def _decrypt(encrypted_master_secret, passphrase, iteration_exponent, identifier):
l = encrypted_master_secret[: len(encrypted_master_secret) // 2]
r = encrypted_master_secret[len(encrypted_master_secret) // 2 :]
salt = cls._get_salt(identifier)
for i in reversed(range(cls.ROUND_COUNT)):
salt = _get_salt(identifier)
for i in reversed(range(ROUND_COUNT)):
(l, r) = (
r,
cls.xor(
l, cls._round_function(i, passphrase, iteration_exponent, salt, r)
),
xor(l, _round_function(i, passphrase, iteration_exponent, salt, r)),
)
return r + l
@classmethod
def _create_digest(cls, random_data, shared_secret):
def _create_digest(random_data, shared_secret):
return hmac.new(random_data, shared_secret, hashlib.sha256).digest()[
: cls.DIGEST_LENGTH_BYTES
:DIGEST_LENGTH_BYTES
]
def _split_secret(self, threshold, share_count, shared_secret):
assert 0 < threshold <= share_count <= self.MAX_SHARE_COUNT
def _split_secret(threshold, share_count, shared_secret):
assert 0 < threshold <= share_count <= MAX_SHARE_COUNT
# If the threshold is 1, then the digest of the shared secret is not used.
if threshold == 1:
@ -221,23 +211,21 @@ class ShamirMnemonic(object):
random_share_count = threshold - 2
if share_count > self.MAX_SHARE_COUNT:
if share_count > MAX_SHARE_COUNT:
raise ValueError(
"The requested number of shares ({}) must not exceed {}.".format(
share_count, self.MAX_SHARE_COUNT
share_count, MAX_SHARE_COUNT
)
)
shares = [
(i, random.bytes(len(shared_secret))) for i in range(random_share_count)
]
shares = [(i, random.bytes(len(shared_secret))) for i in range(random_share_count)]
random_part = random.bytes(len(shared_secret) - self.DIGEST_LENGTH_BYTES)
digest = self._create_digest(random_part, shared_secret)
random_part = random.bytes(len(shared_secret) - DIGEST_LENGTH_BYTES)
digest = _create_digest(random_part, shared_secret)
base_shares = shares + [
(self.DIGEST_INDEX, digest + random_part),
(self.SECRET_INDEX, shared_secret),
(DIGEST_INDEX, digest + random_part),
(SECRET_INDEX, shared_secret),
]
for i in range(random_share_count, share_count):
@ -245,35 +233,32 @@ class ShamirMnemonic(object):
return shares
def _recover_secret(self, threshold, shares):
shared_secret = shamir.interpolate(shares, self.SECRET_INDEX)
def _recover_secret(threshold, shares):
shared_secret = shamir.interpolate(shares, SECRET_INDEX)
# If the threshold is 1, then the digest of the shared secret is not used.
if threshold != 1:
digest_share = shamir.interpolate(shares, self.DIGEST_INDEX)
digest = digest_share[: self.DIGEST_LENGTH_BYTES]
random_part = digest_share[self.DIGEST_LENGTH_BYTES :]
digest_share = shamir.interpolate(shares, DIGEST_INDEX)
digest = digest_share[:DIGEST_LENGTH_BYTES]
random_part = digest_share[DIGEST_LENGTH_BYTES:]
if digest != self._create_digest(random_part, shared_secret):
if digest != _create_digest(random_part, shared_secret):
raise MnemonicError("Invalid digest of the shared secret.")
return shared_secret
@classmethod
def _group_prefix(
cls, identifier, iteration_exponent, group_index, group_threshold, group_count
identifier, iteration_exponent, group_index, group_threshold, group_count
):
id_exp_int = (identifier << cls.ITERATION_EXP_LENGTH_BITS) + iteration_exponent
return tuple(
cls._int_to_indices(id_exp_int, cls.ID_EXP_LENGTH_WORDS, cls.RADIX_BITS)
) + (
(group_index << 6)
+ ((group_threshold - 1) << 2)
+ ((group_count - 1) >> 2),
id_exp_int = (identifier << ITERATION_EXP_LENGTH_BITS) + iteration_exponent
return tuple(_int_to_indices(id_exp_int, ID_EXP_LENGTH_WORDS, RADIX_BITS)) + (
(group_index << 6) + ((group_threshold - 1) << 2) + ((group_count - 1) >> 2),
)
def encode_mnemonic(
self,
identifier,
iteration_exponent,
group_index,
@ -299,77 +284,72 @@ class ShamirMnemonic(object):
"""
# Convert the share value from bytes to wordlist indices.
value_word_count = math.ceil(len(value) * 8 / self.RADIX_BITS)
value_word_count = math.ceil(len(value) * 8 / RADIX_BITS)
value_int = int.from_bytes(value, "big")
share_data = (
self._group_prefix(
identifier,
iteration_exponent,
group_index,
group_threshold,
group_count,
_group_prefix(
identifier, iteration_exponent, group_index, group_threshold, group_count
)
+ (
(((group_count - 1) & 3) << 8)
+ (member_index << 4)
+ (member_threshold - 1),
)
+ tuple(self._int_to_indices(value_int, value_word_count, self.RADIX_BITS))
+ tuple(_int_to_indices(value_int, value_word_count, RADIX_BITS))
)
checksum = self.rs1024_create_checksum(share_data)
checksum = rs1024_create_checksum(share_data)
return self.mnemonic_from_indices(share_data + checksum)
return mnemonic_from_indices(share_data + checksum)
def decode_mnemonic(self, mnemonic):
def decode_mnemonic(mnemonic):
"""Converts a share mnemonic to share data."""
mnemonic_data = tuple(self.mnemonic_to_indices(mnemonic))
mnemonic_data = tuple(mnemonic_to_indices(mnemonic))
if len(mnemonic_data) < self.MIN_MNEMONIC_LENGTH_WORDS:
if len(mnemonic_data) < MIN_MNEMONIC_LENGTH_WORDS:
raise MnemonicError(
"Invalid mnemonic length. The length of each mnemonic must be at least {} words.".format(
self.MIN_MNEMONIC_LENGTH_WORDS
MIN_MNEMONIC_LENGTH_WORDS
)
)
padding_len = (10 * (len(mnemonic_data) - self.METADATA_LENGTH_WORDS)) % 16
padding_len = (10 * (len(mnemonic_data) - METADATA_LENGTH_WORDS)) % 16
if padding_len > 8:
raise MnemonicError("Invalid mnemonic length.")
if not self.rs1024_verify_checksum(mnemonic_data):
if not rs1024_verify_checksum(mnemonic_data):
raise MnemonicError(
'Invalid mnemonic checksum for "{} ...".'.format(
" ".join(mnemonic.split()[: self.ID_EXP_LENGTH_WORDS + 2])
" ".join(mnemonic.split()[: ID_EXP_LENGTH_WORDS + 2])
)
)
id_exp_int = self._int_from_indices(mnemonic_data[: self.ID_EXP_LENGTH_WORDS])
identifier = id_exp_int >> self.ITERATION_EXP_LENGTH_BITS
iteration_exponent = id_exp_int & ((1 << self.ITERATION_EXP_LENGTH_BITS) - 1)
tmp = self._int_from_indices(
mnemonic_data[self.ID_EXP_LENGTH_WORDS : self.ID_EXP_LENGTH_WORDS + 2]
id_exp_int = _int_from_indices(mnemonic_data[:ID_EXP_LENGTH_WORDS])
identifier = id_exp_int >> ITERATION_EXP_LENGTH_BITS
iteration_exponent = id_exp_int & ((1 << ITERATION_EXP_LENGTH_BITS) - 1)
tmp = _int_from_indices(
mnemonic_data[ID_EXP_LENGTH_WORDS : ID_EXP_LENGTH_WORDS + 2]
)
group_index, group_threshold, group_count, member_index, member_threshold = self._int_to_indices(
group_index, group_threshold, group_count, member_index, member_threshold = _int_to_indices(
tmp, 5, 4
)
value_data = mnemonic_data[
self.ID_EXP_LENGTH_WORDS + 2 : -self.CHECKSUM_LENGTH_WORDS
]
value_data = mnemonic_data[ID_EXP_LENGTH_WORDS + 2 : -CHECKSUM_LENGTH_WORDS]
if group_count < group_threshold:
raise MnemonicError(
'Invalid mnemonic "{} ...". Group threshold cannot be greater than group count.'.format(
" ".join(mnemonic.split()[: self.ID_EXP_LENGTH_WORDS + 2])
" ".join(mnemonic.split()[: ID_EXP_LENGTH_WORDS + 2])
)
)
value_byte_count = (10 * len(value_data) - padding_len) // 8
value_int = self._int_from_indices(value_data)
value_int = _int_from_indices(value_data)
if value_data[0] >= 1 << (10 - padding_len):
raise MnemonicError(
'Invalid mnemonic padding for "{} ...".'.format(
" ".join(mnemonic.split()[: self.ID_EXP_LENGTH_WORDS + 2])
" ".join(mnemonic.split()[: ID_EXP_LENGTH_WORDS + 2])
)
)
value = value_int.to_bytes(value_byte_count, "big")
@ -385,14 +365,15 @@ class ShamirMnemonic(object):
value,
)
def _decode_mnemonics(self, mnemonics):
def _decode_mnemonics(mnemonics):
identifiers = set()
iteration_exponents = set()
group_thresholds = set()
group_counts = set()
groups = {} # { group_index : [member_threshold, set_of_member_shares] }
for mnemonic in mnemonics:
identifier, iteration_exponent, group_index, group_threshold, group_count, member_index, member_threshold, share_value = self.decode_mnemonic(
identifier, iteration_exponent, group_index, group_threshold, group_count, member_index, member_threshold, share_value = decode_mnemonic(
mnemonic
)
identifiers.add(identifier)
@ -409,7 +390,7 @@ class ShamirMnemonic(object):
if len(identifiers) != 1 or len(iteration_exponents) != 1:
raise MnemonicError(
"Invalid set of mnemonics. All mnemonics must begin with the same {} words.".format(
self.ID_EXP_LENGTH_WORDS
ID_EXP_LENGTH_WORDS
)
)
@ -437,21 +418,16 @@ class ShamirMnemonic(object):
groups,
)
def _generate_random_identifier(self):
def _generate_random_identifier():
"""Returns a randomly generated integer in the range 0, ... , 2**ID_LENGTH_BITS - 1."""
identifier = int.from_bytes(
random.bytes(math.ceil(self.ID_LENGTH_BITS / 8)), "big"
)
return identifier & ((1 << self.ID_LENGTH_BITS) - 1)
identifier = int.from_bytes(random.bytes(math.ceil(ID_LENGTH_BITS / 8)), "big")
return identifier & ((1 << ID_LENGTH_BITS) - 1)
def generate_mnemonics(
self,
group_threshold,
groups,
master_secret,
passphrase=b"",
iteration_exponent=0,
group_threshold, groups, master_secret, passphrase=b"", iteration_exponent=0
):
"""
Splits a master secret into mnemonic shares using Shamir's secret sharing scheme.
@ -469,12 +445,12 @@ class ShamirMnemonic(object):
:rtype: List of byte arrays.
"""
identifier = self._generate_random_identifier()
identifier = _generate_random_identifier()
if len(master_secret) * 8 < self.MIN_STRENGTH_BITS:
if len(master_secret) * 8 < MIN_STRENGTH_BITS:
raise ValueError(
"The length of the master secret ({} bytes) must be at least {} bytes.".format(
len(master_secret), math.ceil(self.MIN_STRENGTH_BITS / 8)
len(master_secret), math.ceil(MIN_STRENGTH_BITS / 8)
)
)
@ -490,17 +466,15 @@ class ShamirMnemonic(object):
)
)
encrypted_master_secret = self._encrypt(
encrypted_master_secret = _encrypt(
master_secret, passphrase, iteration_exponent, identifier
)
group_shares = self._split_secret(
group_threshold, len(groups), encrypted_master_secret
)
group_shares = _split_secret(group_threshold, len(groups), encrypted_master_secret)
return [
[
self.encode_mnemonic(
encode_mnemonic(
identifier,
iteration_exponent,
group_index,
@ -510,7 +484,7 @@ class ShamirMnemonic(object):
member_threshold,
value,
)
for member_index, value in self._split_secret(
for member_index, value in _split_secret(
member_threshold, member_count, group_secret
)
]
@ -519,13 +493,9 @@ class ShamirMnemonic(object):
)
]
def generate_mnemonics_random(
self,
group_threshold,
groups,
strength_bits=128,
passphrase=b"",
iteration_exponent=0,
group_threshold, groups, strength_bits=128, passphrase=b"", iteration_exponent=0
):
"""
Generates a random master secret and splits it into mnemonic shares using Shamir's secret
@ -543,10 +513,10 @@ class ShamirMnemonic(object):
:rtype: List of byte arrays.
"""
if strength_bits < self.MIN_STRENGTH_BITS:
if strength_bits < MIN_STRENGTH_BITS:
raise ValueError(
"The requested strength of the master secret ({} bits) must be at least {} bits.".format(
strength_bits, self.MIN_STRENGTH_BITS
strength_bits, MIN_STRENGTH_BITS
)
)
@ -557,7 +527,7 @@ class ShamirMnemonic(object):
)
)
return self.generate_mnemonics(
return generate_mnemonics(
group_threshold,
groups,
random.bytes(strength_bits // 8),
@ -565,7 +535,8 @@ class ShamirMnemonic(object):
iteration_exponent,
)
def combine_mnemonics(self, mnemonics, passphrase=b""):
def combine_mnemonics(mnemonics, passphrase=b""):
"""
Combines mnemonic shares to obtain the master secret which was previously split using
Shamir's secret sharing scheme.
@ -580,7 +551,7 @@ class ShamirMnemonic(object):
if not mnemonics:
raise MnemonicError("The list of mnemonics is empty.")
identifier, iteration_exponent, group_threshold, group_count, groups = self._decode_mnemonics(
identifier, iteration_exponent, group_threshold, group_count, groups = _decode_mnemonics(
mnemonics
)
@ -602,26 +573,22 @@ class ShamirMnemonic(object):
if len(groups) < group_threshold:
group_index, group = next(iter(bad_groups.items()))
prefix = self._group_prefix(
identifier,
iteration_exponent,
group_index,
group_threshold,
group_count,
prefix = _group_prefix(
identifier, iteration_exponent, group_index, group_threshold, group_count
)
raise MnemonicError(
'Insufficient number of mnemonics. At least {} mnemonics starting with "{} ..." are required.'.format(
group[0], self.mnemonic_from_indices(prefix)
group[0], mnemonic_from_indices(prefix)
)
)
group_shares = [
(group_index, self._recover_secret(group[0], list(group[1])))
(group_index, _recover_secret(group[0], list(group[1])))
for group_index, group in groups.items()
]
return self._decrypt(
self._recover_secret(group_threshold, group_shares),
return _decrypt(
_recover_secret(group_threshold, group_shares),
passphrase,
iteration_exponent,
identifier,

View File

@ -6,68 +6,66 @@ from slip39_vectors import vectors
class TestCryptoSlip39(unittest.TestCase):
MS = b"ABCDEFGHIJKLMNOP"
shamir = slip39.ShamirMnemonic()
def test_basic_sharing_random(self):
mnemonics = self.shamir.generate_mnemonics_random(1, [(3, 5)])[0]
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.shamir.combine_mnemonics(mnemonics))
mnemonics = slip39.generate_mnemonics_random(1, [(3, 5)])[0]
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4]), slip39.combine_mnemonics(mnemonics))
def test_basic_sharing_fixed(self):
mnemonics = self.shamir.generate_mnemonics(1, [(3, 5)], self.MS)[0]
self.assertEqual(self.shamir.combine_mnemonics(mnemonics), self.MS)
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.MS)
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS)[0]
self.assertEqual(slip39.combine_mnemonics(mnemonics), self.MS)
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
with self.assertRaises(slip39.MnemonicError):
self.shamir.combine_mnemonics(mnemonics[1:3])
slip39.combine_mnemonics(mnemonics[1:3])
def test_passphrase(self):
mnemonics = self.shamir.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR")[0]
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
self.assertNotEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.MS)
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR")[0]
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
def test_iteration_exponent(self):
mnemonics = self.shamir.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 1)[0]
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
self.assertNotEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.MS)
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 1)[0]
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
mnemonics = self.shamir.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 2)[0]
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
self.assertNotEqual(self.shamir.combine_mnemonics(mnemonics[1:4]), self.MS)
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 2)[0]
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
def test_group_sharing(self):
mnemonics = self.shamir.generate_mnemonics(2, [(3, 5), (2, 3), (2, 5), (1, 1)], self.MS)
mnemonics = slip39.generate_mnemonics(2, [(3, 5), (2, 3), (2, 5), (1, 1)], self.MS)
# All mnemonics.
self.assertEqual(self.shamir.combine_mnemonics([mnemonic for group in mnemonics for mnemonic in group]), self.MS)
self.assertEqual(slip39.combine_mnemonics([mnemonic for group in mnemonics for mnemonic in group]), self.MS)
# Minimal sets of mnemonics.
self.assertEqual(self.shamir.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]]), self.MS)
self.assertEqual(self.shamir.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]]), self.MS)
self.assertEqual(slip39.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]]), self.MS)
self.assertEqual(slip39.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]]), self.MS)
# Two complete groups and one incomplete group.
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[0] + [mnemonics[1][1]] + mnemonics[2]), self.MS)
self.assertEqual(self.shamir.combine_mnemonics(mnemonics[0][1:4] + mnemonics[1][1:3] + mnemonics[2][2:4]), self.MS)
self.assertEqual(slip39.combine_mnemonics(mnemonics[0] + [mnemonics[1][1]] + mnemonics[2]), self.MS)
self.assertEqual(slip39.combine_mnemonics(mnemonics[0][1:4] + mnemonics[1][1:3] + mnemonics[2][2:4]), self.MS)
# One complete group and one incomplete group out of two groups required.
with self.assertRaises(slip39.MnemonicError):
self.shamir.combine_mnemonics(mnemonics[0][2:] + [mnemonics[1][0]])
slip39.combine_mnemonics(mnemonics[0][2:] + [mnemonics[1][0]])
# One group of two required.
with self.assertRaises(slip39.MnemonicError):
self.shamir.combine_mnemonics(mnemonics[0][1:4])
slip39.combine_mnemonics(mnemonics[0][1:4])
def test_vectors(self):
for mnemonics, secret in vectors:
if secret:
self.assertEqual(self.shamir.combine_mnemonics(mnemonics, b"TREZOR"), unhexlify(secret))
self.assertEqual(slip39.combine_mnemonics(mnemonics, b"TREZOR"), unhexlify(secret))
else:
with self.assertRaises(slip39.MnemonicError):
self.shamir.combine_mnemonics(mnemonics, b"TREZOR")
slip39.combine_mnemonics(mnemonics, b"TREZOR")
def test_invalid_rs1024_checksum(self):
@ -75,7 +73,7 @@ class TestCryptoSlip39(unittest.TestCase):
"artist away academic academic dismiss spill unkind pencil lair sugar usher elegant paces sweater firm gravity deal body chest sugar"
]
with self.assertRaises(slip39.MnemonicError):
self.shamir.combine_mnemonics(mnemonics)
slip39.combine_mnemonics(mnemonics)
if __name__ == '__main__':