1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00
trezor-firmware/core/tests/test_trezor.crypto.slip39.py

185 lines
7.4 KiB
Python

from common import * # isort:skip
from slip39_vectors import vectors
from trezor.crypto import random, slip39
def combinations(iterable, r):
    """Yield every r-length tuple of elements from ``iterable``.

    Elements are picked by position, so the tuples come out in lexicographic
    order of the chosen indices and duplicate values in the input are treated
    as distinct. Nothing is yielded when ``r`` exceeds the number of elements.
    """
    # Adapted from https://docs.python.org/3.7/library/itertools.html#itertools.combinations
    items = tuple(iterable)
    total = len(items)
    if r > total:
        return

    # positions[k] is the index into ``items`` of the k-th chosen element.
    positions = list(range(r))
    yield tuple(items[p] for p in positions)

    while True:
        # Scan from the right for the first slot that has not yet reached
        # the largest index it is allowed to hold (slot + total - r).
        pivot = r - 1
        while pivot >= 0 and positions[pivot] == pivot + total - r:
            pivot -= 1
        if pivot < 0:
            # Every slot is at its maximum: all combinations were produced.
            return

        # Advance that slot and reset everything to its right to the
        # smallest strictly increasing run that follows it.
        positions[pivot] += 1
        for slot in range(pivot + 1, r):
            positions[slot] = positions[slot - 1] + 1
        yield tuple(items[p] for p in positions)
class TestCryptoSlip39(unittest.TestCase):
    """Split/recover round-trip and failure tests for ``trezor.crypto.slip39``.

    Most scenarios are executed once per value of the ``extendable`` flag.
    Groups are passed to ``split_ems`` as ``(member_threshold, group_size)``
    pairs.
    """

    # Fixed 16-byte EMS payload handed to split_ems by the deterministic tests.
    EMS = b"ABCDEFGHIJKLMNOP"

    def test_basic_sharing_random(self):
        """Two different 3-share subsets of a 3-of-5 split recover the same result."""
        ems = random.bytes(32)
        identifier = slip39.generate_random_identifier()
        for extendable in (False, True):
            shares = slip39.split_ems(1, [(3, 5)], identifier, extendable, 1, ems)[0]
            self.assertEqual(
                slip39.recover_ems(shares[:3]), slip39.recover_ems(shares[2:])
            )

    def test_basic_sharing_fixed(self):
        """A 3-of-5 split of the fixed EMS recovers the EMS and the identifier."""
        for extendable in (False, True):
            generated_identifier = slip39.generate_random_identifier()
            shares = slip39.split_ems(
                1, [(3, 5)], generated_identifier, extendable, 1, self.EMS
            )[0]
            identifier, _, _, ems = slip39.recover_ems(shares[:3])
            self.assertEqual(ems, self.EMS)
            self.assertEqual(generated_identifier, identifier)
            # A different set of three shares must give the same EMS.
            self.assertEqual(slip39.recover_ems(shares[1:4])[3], ems)
            # Two shares are below the member threshold of three.
            with self.assertRaises(slip39.MnemonicError):
                slip39.recover_ems(shares[1:3])

    def test_iteration_exponent(self):
        """The EMS and the exponent survive a round trip for exponents 1 and 2."""
        # NOTE(review): the fifth positional argument of split_ems and the third
        # element returned by recover_ems are taken to be the iteration exponent
        # (test_vectors forwards that element to decrypt as ``exponent``) --
        # confirm against the slip39 module.
        for extendable in (False, True):
            for exponent in (1, 2):
                identifier = slip39.generate_random_identifier()
                shares = slip39.split_ems(
                    1, [(3, 5)], identifier, extendable, exponent, self.EMS
                )[0]
                # Unpack into fresh names so the loop variables stay intact.
                _, _, recovered_exponent, ems = slip39.recover_ems(shares[1:4])
                self.assertEqual(ems, self.EMS)
                self.assertEqual(recovered_exponent, exponent)

    def test_group_sharing(self):
        """With a group threshold of 2, exactly two satisfied groups are required."""
        group_threshold = 2
        group_sizes = (5, 3, 5, 1)
        member_thresholds = (3, 2, 2, 1)
        for extendable in (False, True):
            identifier = slip39.generate_random_identifier()
            mnemonics = slip39.split_ems(
                group_threshold,
                list(zip(member_thresholds, group_sizes)),
                identifier,
                extendable,
                1,
                self.EMS,
            )

            # Test all valid combinations of mnemonics: every pair of groups,
            # and within each group every subset of exactly threshold size.
            for groups in combinations(
                zip(mnemonics, member_thresholds), group_threshold
            ):
                for group1_subset in combinations(groups[0][0], groups[0][1]):
                    for group2_subset in combinations(groups[1][0], groups[1][1]):
                        mnemonic_subset = list(group1_subset + group2_subset)
                        # Recovery must not depend on the order of the shares.
                        random.shuffle(mnemonic_subset)
                        self.assertEqual(
                            slip39.recover_ems(mnemonic_subset)[3], self.EMS
                        )

            # Minimal sets of mnemonics.
            self.assertEqual(
                slip39.recover_ems(
                    [mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]]
                )[3],
                self.EMS,
            )
            self.assertEqual(
                slip39.recover_ems(
                    [mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]]
                )[3],
                self.EMS,
            )

            # One complete group and one incomplete group out of two groups required.
            with self.assertRaises(slip39.MnemonicError):
                slip39.recover_ems(mnemonics[0][2:] + [mnemonics[1][0]])

            # One group of two required.
            with self.assertRaises(slip39.MnemonicError):
                slip39.recover_ems(mnemonics[0][1:4])

    def test_group_sharing_threshold_1(self):
        """With a group threshold of 1, any single satisfied group recovers the EMS."""
        group_threshold = 1
        group_sizes = (5, 3, 5, 1)
        member_thresholds = (3, 2, 2, 1)
        for extendable in (False, True):
            identifier = slip39.generate_random_identifier()
            mnemonics = slip39.split_ems(
                group_threshold,
                list(zip(member_thresholds, group_sizes)),
                identifier,
                extendable,
                1,
                self.EMS,
            )

            # Test all valid combinations of mnemonics.
            for group, threshold in zip(mnemonics, member_thresholds):
                for group_subset in combinations(group, threshold):
                    mnemonic_subset = list(group_subset)
                    # Recovery must not depend on the order of the shares.
                    random.shuffle(mnemonic_subset)
                    self.assertEqual(
                        slip39.recover_ems(mnemonic_subset)[3], self.EMS
                    )

    def test_all_groups_exist(self):
        """split_ems returns shares for every group regardless of the group threshold."""
        for extendable in (False, True):
            for group_threshold in (1, 2, 5):
                identifier = slip39.generate_random_identifier()
                mnemonics = slip39.split_ems(
                    group_threshold,
                    [(3, 5), (1, 1), (2, 3), (2, 5), (3, 5)],
                    identifier,
                    extendable,
                    1,
                    self.EMS,
                )
                self.assertEqual(len(mnemonics), 5)
                # 5 + 1 + 3 + 5 + 5 shares in total.
                self.assertEqual(sum(len(group) for group in mnemonics), 19)

    def test_invalid_sharing(self):
        """split_ems rejects inconsistent threshold/size parameters with ValueError."""
        for extendable in (False, True):
            identifier = slip39.generate_random_identifier()

            # Group threshold exceeds number of groups.
            with self.assertRaises(ValueError):
                slip39.split_ems(
                    3, [(3, 5), (2, 5)], identifier, extendable, 1, self.EMS
                )

            # Invalid group threshold.
            with self.assertRaises(ValueError):
                slip39.split_ems(
                    0, [(3, 5), (2, 5)], identifier, extendable, 1, self.EMS
                )

            # Member threshold exceeds number of members.
            with self.assertRaises(ValueError):
                slip39.split_ems(
                    2, [(3, 2), (2, 5)], identifier, extendable, 1, self.EMS
                )

            # Invalid member threshold.
            with self.assertRaises(ValueError):
                slip39.split_ems(
                    2, [(0, 2), (2, 5)], identifier, extendable, 1, self.EMS
                )

            # Group with multiple members and threshold 1.
            with self.assertRaises(ValueError):
                slip39.split_ems(
                    2, [(3, 5), (1, 3), (2, 5)], identifier, extendable, 1, self.EMS
                )

    def test_vectors(self):
        """Check recovery and decryption against the imported test vectors.

        A vector with a non-empty secret must recover and decrypt (passphrase
        ``b"TREZOR"``) to that secret; a vector with an empty secret must be
        rejected with ``MnemonicError``.
        """
        for mnemonics, secret in vectors:
            if secret:
                identifier, extendable, exponent, ems = slip39.recover_ems(mnemonics)
                self.assertEqual(
                    slip39.decrypt(ems, b"TREZOR", exponent, identifier, extendable),
                    unhexlify(secret),
                )
            else:
                with self.assertRaises(slip39.MnemonicError):
                    slip39.recover_ems(mnemonics)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()