slip39: Split decryption out from combine_mnemonics().

pull/85/head
Andrew Kozlik 5 years ago
parent 04dcfea901
commit 1003ed9083

@ -184,7 +184,7 @@ def _encrypt(master_secret, passphrase, iteration_exponent, identifier):
return r + l return r + l
def _decrypt(encrypted_master_secret, passphrase, iteration_exponent, identifier): def decrypt(identifier, iteration_exponent, encrypted_master_secret, passphrase):
l = encrypted_master_secret[: len(encrypted_master_secret) // 2] l = encrypted_master_secret[: len(encrypted_master_secret) // 2]
r = encrypted_master_secret[len(encrypted_master_secret) // 2 :] r = encrypted_master_secret[len(encrypted_master_secret) // 2 :]
salt = _get_salt(identifier) salt = _get_salt(identifier)
@ -536,16 +536,14 @@ def generate_mnemonics_random(
) )
def combine_mnemonics(mnemonics, passphrase=b""): def combine_mnemonics(mnemonics):
""" """
Combines mnemonic shares to obtain the master secret which was previously split using Combines mnemonic shares to obtain the master secret which was previously split using
Shamir's secret sharing scheme. Shamir's secret sharing scheme.
:param mnemonics: List of mnemonics. :param mnemonics: List of mnemonics.
:type mnemonics: List of byte arrays. :type mnemonics: List of byte arrays.
:param passphrase: The passphrase used to encrypt the master secret. :return: Identifier, iteration exponent, the encrypted master secret.
:type passphrase: Array of bytes. :rtype: Integer, integer, array of bytes.
:return: The master secret.
:rtype: Array of bytes.
""" """
if not mnemonics: if not mnemonics:
@ -587,9 +585,8 @@ def combine_mnemonics(mnemonics, passphrase=b""):
for group_index, group in groups.items() for group_index, group in groups.items()
] ]
return _decrypt( return (
_recover_secret(group_threshold, group_shares),
passphrase,
iteration_exponent,
identifier, identifier,
iteration_exponent,
_recover_secret(group_threshold, group_shares),
) )

@ -14,41 +14,46 @@ class TestCryptoSlip39(unittest.TestCase):
def test_basic_sharing_fixed(self): def test_basic_sharing_fixed(self):
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS)[0] mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS)[0]
self.assertEqual(slip39.combine_mnemonics(mnemonics), self.MS) identifier, exponent, ems = slip39.combine_mnemonics(mnemonics)
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS) self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4])[2], ems)
with self.assertRaises(slip39.MnemonicError): with self.assertRaises(slip39.MnemonicError):
slip39.combine_mnemonics(mnemonics[1:3]) slip39.combine_mnemonics(mnemonics[1:3])
def test_passphrase(self): def test_passphrase(self):
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR")[0] mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR")[0]
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS) identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4])
self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS) self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS)
self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
def test_iteration_exponent(self): def test_iteration_exponent(self):
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 1)[0] mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 1)[0]
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS) identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4])
self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS) self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS)
self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 2)[0] mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 2)[0]
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS) identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4])
self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS) self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS)
self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
def test_group_sharing(self): def test_group_sharing(self):
mnemonics = slip39.generate_mnemonics(2, [(3, 5), (2, 3), (2, 5), (1, 1)], self.MS) mnemonics = slip39.generate_mnemonics(2, [(3, 5), (2, 3), (2, 5), (1, 1)], self.MS)
# All mnemonics. # All mnemonics.
self.assertEqual(slip39.combine_mnemonics([mnemonic for group in mnemonics for mnemonic in group]), self.MS) identifier, exponent, ems = slip39.combine_mnemonics([mnemonic for group in mnemonics for mnemonic in group])
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
# Minimal sets of mnemonics. # Minimal sets of mnemonics.
self.assertEqual(slip39.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]]), self.MS) self.assertEqual(slip39.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]])[2], ems)
self.assertEqual(slip39.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]]), self.MS) self.assertEqual(slip39.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]])[2], ems)
# Two complete groups and one incomplete group. # Two complete groups and one incomplete group.
self.assertEqual(slip39.combine_mnemonics(mnemonics[0] + [mnemonics[1][1]] + mnemonics[2]), self.MS) self.assertEqual(slip39.combine_mnemonics(mnemonics[0] + [mnemonics[1][1]] + mnemonics[2])[2], ems)
self.assertEqual(slip39.combine_mnemonics(mnemonics[0][1:4] + mnemonics[1][1:3] + mnemonics[2][2:4]), self.MS) self.assertEqual(slip39.combine_mnemonics(mnemonics[0][1:4] + mnemonics[1][1:3] + mnemonics[2][2:4])[2], ems)
# One complete group and one incomplete group out of two groups required. # One complete group and one incomplete group out of two groups required.
with self.assertRaises(slip39.MnemonicError): with self.assertRaises(slip39.MnemonicError):
@ -62,10 +67,11 @@ class TestCryptoSlip39(unittest.TestCase):
def test_vectors(self): def test_vectors(self):
for mnemonics, secret in vectors: for mnemonics, secret in vectors:
if secret: if secret:
self.assertEqual(slip39.combine_mnemonics(mnemonics, b"TREZOR"), unhexlify(secret)) identifier, exponent, ems = slip39.combine_mnemonics(mnemonics)
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), unhexlify(secret))
else: else:
with self.assertRaises(slip39.MnemonicError): with self.assertRaises(slip39.MnemonicError):
slip39.combine_mnemonics(mnemonics, b"TREZOR") slip39.combine_mnemonics(mnemonics)
if __name__ == '__main__': if __name__ == '__main__':

Loading…
Cancel
Save