From 1003ed90830804738ea28833d786b9faf4630f42 Mon Sep 17 00:00:00 2001 From: Andrew Kozlik Date: Mon, 15 Apr 2019 18:13:29 +0200 Subject: [PATCH] slip39: Split decryption out from combine_mnemonics(). --- src/trezor/crypto/slip39.py | 17 ++++++-------- tests/test_trezor.crypto.slip39.py | 36 +++++++++++++++++------------- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/src/trezor/crypto/slip39.py b/src/trezor/crypto/slip39.py index cb682bf91..af5b3e435 100644 --- a/src/trezor/crypto/slip39.py +++ b/src/trezor/crypto/slip39.py @@ -184,7 +184,7 @@ def _encrypt(master_secret, passphrase, iteration_exponent, identifier): return r + l -def _decrypt(encrypted_master_secret, passphrase, iteration_exponent, identifier): +def decrypt(identifier, iteration_exponent, encrypted_master_secret, passphrase): l = encrypted_master_secret[: len(encrypted_master_secret) // 2] r = encrypted_master_secret[len(encrypted_master_secret) // 2 :] salt = _get_salt(identifier) @@ -536,16 +536,14 @@ def generate_mnemonics_random( ) -def combine_mnemonics(mnemonics, passphrase=b""): +def combine_mnemonics(mnemonics): """ Combines mnemonic shares to obtain the master secret which was previously split using Shamir's secret sharing scheme. :param mnemonics: List of mnemonics. :type mnemonics: List of byte arrays. - :param passphrase: The passphrase used to encrypt the master secret. - :type passphrase: Array of bytes. - :return: The master secret. - :rtype: Array of bytes. + :return: Identifier, iteration exponent, the encrypted master secret. + :rtype: Integer, integer, array of bytes. """ if not mnemonics: @@ -587,9 +585,8 @@ def combine_mnemonics(mnemonics, passphrase=b""): for group_index, group in groups.items() ] - return _decrypt( - _recover_secret(group_threshold, group_shares), - passphrase, - iteration_exponent, + return ( identifier, + iteration_exponent, + _recover_secret(group_threshold, group_shares), ) diff --git a/tests/test_trezor.crypto.slip39.py b/tests/test_trezor.crypto.slip39.py index 58fe42ab7..a00654e8c 100644 --- a/tests/test_trezor.crypto.slip39.py +++ b/tests/test_trezor.crypto.slip39.py @@ -14,41 +14,46 @@ class TestCryptoSlip39(unittest.TestCase): def test_basic_sharing_fixed(self): mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS)[0] - self.assertEqual(slip39.combine_mnemonics(mnemonics), self.MS) - self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS) + identifier, exponent, ems = slip39.combine_mnemonics(mnemonics) + self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS) + self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4])[2], ems) with self.assertRaises(slip39.MnemonicError): slip39.combine_mnemonics(mnemonics[1:3]) def test_passphrase(self): mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR")[0] - self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS) - self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS) + identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4]) + self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS) + self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS) def test_iteration_exponent(self): mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 1)[0] - self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS) - self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS) + identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4]) + self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS) + self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS) mnemonics = slip39.generate_mnemonics(1, [(3, 5)], self.MS, b"TREZOR", 2)[0] - self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4], b"TREZOR"), self.MS) - self.assertNotEqual(slip39.combine_mnemonics(mnemonics[1:4]), self.MS) + identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4]) + self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS) + self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS) def test_group_sharing(self): mnemonics = slip39.generate_mnemonics(2, [(3, 5), (2, 3), (2, 5), (1, 1)], self.MS) # All mnemonics. - self.assertEqual(slip39.combine_mnemonics([mnemonic for group in mnemonics for mnemonic in group]), self.MS) + identifier, exponent, ems = slip39.combine_mnemonics([mnemonic for group in mnemonics for mnemonic in group]) + self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS) # Minimal sets of mnemonics. - self.assertEqual(slip39.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]]), self.MS) - self.assertEqual(slip39.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]]), self.MS) + self.assertEqual(slip39.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]])[2], ems) + self.assertEqual(slip39.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]])[2], ems) # Two complete groups and one incomplete group. - self.assertEqual(slip39.combine_mnemonics(mnemonics[0] + [mnemonics[1][1]] + mnemonics[2]), self.MS) - self.assertEqual(slip39.combine_mnemonics(mnemonics[0][1:4] + mnemonics[1][1:3] + mnemonics[2][2:4]), self.MS) + self.assertEqual(slip39.combine_mnemonics(mnemonics[0] + [mnemonics[1][1]] + mnemonics[2])[2], ems) + self.assertEqual(slip39.combine_mnemonics(mnemonics[0][1:4] + mnemonics[1][1:3] + mnemonics[2][2:4])[2], ems) # One complete group and one incomplete group out of two groups required. with self.assertRaises(slip39.MnemonicError): @@ -62,10 +67,11 @@ class TestCryptoSlip39(unittest.TestCase): def test_vectors(self): for mnemonics, secret in vectors: if secret: - self.assertEqual(slip39.combine_mnemonics(mnemonics, b"TREZOR"), unhexlify(secret)) + identifier, exponent, ems = slip39.combine_mnemonics(mnemonics) + self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), unhexlify(secret)) else: with self.assertRaises(slip39.MnemonicError): - slip39.combine_mnemonics(mnemonics, b"TREZOR") + slip39.combine_mnemonics(mnemonics) if __name__ == '__main__':