1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-16 11:28:14 +00:00

slip39: Split decryption out from combine_mnemonics().

This commit is contained in:
Andrew Kozlik 2019-04-15 18:13:29 +02:00
parent 04dcfea901
commit 1003ed9083
2 changed files with 28 additions and 25 deletions

View File

@ -184,7 +184,7 @@ def _encrypt(master_secret, passphrase, iteration_exponent, identifier):
return r + l
def _decrypt(encrypted_master_secret, passphrase, iteration_exponent, identifier):
def decrypt(identifier, iteration_exponent, encrypted_master_secret, passphrase):
l = encrypted_master_secret[: len(encrypted_master_secret) // 2]
r = encrypted_master_secret[len(encrypted_master_secret) // 2 :]
salt = _get_salt(identifier)
@ -536,16 +536,14 @@ def generate_mnemonics_random(
)
def combine_mnemonics(mnemonics, passphrase=b""):
def combine_mnemonics(mnemonics):
"""
Combines mnemonic shares to obtain the master secret which was previously split using
Shamir's secret sharing scheme.
:param mnemonics: List of mnemonics.
:type mnemonics: List of byte arrays.
:param passphrase: The passphrase used to encrypt the master secret.
:type passphrase: Array of bytes.
:return: The master secret.
:rtype: Array of bytes.
:return: Identifier, iteration exponent, the encrypted master secret.
:rtype: Integer, integer, array of bytes.
"""
if not mnemonics:
@ -587,9 +585,8 @@ def combine_mnemonics(mnemonics, passphrase=b""):
for group_index, group in groups.items()
]
return _decrypt(
_recover_secret(group_threshold, group_shares),
passphrase,
iteration_exponent,
return (
identifier,
iteration_exponent,
_recover_secret(group_threshold, group_shares),
)

View File

@ -14,41 +14,46 @@ class TestCryptoSlip39(unittest.TestCase):
def test_basic_sharing_fixed(self):
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS)[0]
self.assertEqual(slip39.combine_mnemonics(mnemonics), self.MS)
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
identifier, exponent, ems = slip39.combine_mnemonics(mnemonics)
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4])[2], ems)
with self.assertRaises(slip39.MnemonicError):
slip39.combine_mnemonics(mnemonics[1:3])
def test_passphrase(self):
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR")[0]
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4])
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS)
self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
def test_iteration_exponent(self):
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 1)[0]
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4])
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS)
self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 2)[0]
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS)
self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS)
identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4])
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS)
self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
def test_group_sharing(self):
mnemonics = slip39.generate_mnemonics(2, [(3, 5), (2, 3), (2, 5), (1, 1)], self.MS)
# All mnemonics.
self.assertEqual(slip39.combine_mnemonics([mnemonic for group in mnemonics for mnemonic in group]), self.MS)
identifier, exponent, ems = slip39.combine_mnemonics([mnemonic for group in mnemonics for mnemonic in group])
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
# Minimal sets of mnemonics.
self.assertEqual(slip39.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]]), self.MS)
self.assertEqual(slip39.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]]), self.MS)
self.assertEqual(slip39.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]])[2], ems)
self.assertEqual(slip39.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]])[2], ems)
# Two complete groups and one incomplete group.
self.assertEqual(slip39.combine_mnemonics(mnemonics[0] + [mnemonics[1][1]] + mnemonics[2]), self.MS)
self.assertEqual(slip39.combine_mnemonics(mnemonics[0][1:4] + mnemonics[1][1:3] + mnemonics[2][2:4]), self.MS)
self.assertEqual(slip39.combine_mnemonics(mnemonics[0] + [mnemonics[1][1]] + mnemonics[2])[2], ems)
self.assertEqual(slip39.combine_mnemonics(mnemonics[0][1:4] + mnemonics[1][1:3] + mnemonics[2][2:4])[2], ems)
# One complete group and one incomplete group out of two groups required.
with self.assertRaises(slip39.MnemonicError):
@ -62,10 +67,11 @@ class TestCryptoSlip39(unittest.TestCase):
def test_vectors(self):
for mnemonics, secret in vectors:
if secret:
self.assertEqual(slip39.combine_mnemonics(mnemonics, b"TREZOR"), unhexlify(secret))
identifier, exponent, ems = slip39.combine_mnemonics(mnemonics)
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), unhexlify(secret))
else:
with self.assertRaises(slip39.MnemonicError):
slip39.combine_mnemonics(mnemonics, b"TREZOR")
slip39.combine_mnemonics(mnemonics)
if __name__ == '__main__':