From 1b666804c079b20e0e9681c7fff7c29b0954f98d Mon Sep 17 00:00:00 2001 From: Tomas Susanka Date: Mon, 12 Aug 2019 12:27:02 +0200 Subject: [PATCH] core/shamir: fix EMS vs MS (cherry picked from commit cb9445461835e861e3e6254b3ad1089c3c61895c) --- core/src/apps/management/reset_device.py | 13 +- core/src/trezor/crypto/slip39.py | 40 +--- core/tests/test_trezor.crypto.slip39.py | 51 ++--- .../test_msg_resetdevice_shamir.py | 9 +- .../test_shamir_reset_recovery.py | 196 ++++++++++++++++++ 5 files changed, 243 insertions(+), 66 deletions(-) create mode 100644 tests/device_tests/test_shamir_reset_recovery.py diff --git a/core/src/apps/management/reset_device.py b/core/src/apps/management/reset_device.py index 081a6bcd2..2c2cc6440 100644 --- a/core/src/apps/management/reset_device.py +++ b/core/src/apps/management/reset_device.py @@ -47,6 +47,7 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success: # request external entropy and compute the master secret entropy_ack = await ctx.call(EntropyRequest(), EntropyAck) ext_entropy = entropy_ack.entropy + # For SLIP-39 this is the Encrypted Master Secret secret = _compute_secret_from_entropy(int_entropy, ext_entropy, msg.strength) if is_slip39_simple: @@ -76,7 +77,7 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success: ) if is_slip39_simple: storage.device.store_mnemonic_secret( - secret, + secret, # this is the EMS in SLIP-39 terminology mnemonic.TYPE_SLIP39, needs_backup=msg.skip_backup, no_backup=msg.no_backup, @@ -97,7 +98,9 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success: return Success(message="Initialized") -async def backup_slip39_wallet(ctx: wire.Context, secret: bytes) -> None: +async def backup_slip39_wallet( + ctx: wire.Context, encrypted_master_secret: bytes +) -> None: # get number of shares await layout.slip39_show_checklist_set_shares(ctx) shares_count = await layout.slip39_prompt_number_of_shares(ctx) @@ -108,7 +111,11 @@ async def backup_slip39_wallet(ctx: wire.Context, secret: bytes) -> None: # generate the mnemonics mnemonics = slip39.generate_single_group_mnemonics_from_data( - secret, storage.device.get_slip39_identifier(), threshold, shares_count + encrypted_master_secret, + storage.device.get_slip39_identifier(), + threshold, + shares_count, + storage.device.get_slip39_iteration_exponent(), ) # show and confirm individual shares diff --git a/core/src/trezor/crypto/slip39.py b/core/src/trezor/crypto/slip39.py index 5fe73232e..b3eb215ab 100644 --- a/core/src/trezor/crypto/slip39.py +++ b/core/src/trezor/crypto/slip39.py @@ -485,43 +485,34 @@ def generate_random_identifier() -> int: def generate_single_group_mnemonics_from_data( - master_secret: bytes, + encrypted_master_secret: bytes, identifier: int, threshold: int, count: int, - passphrase: bytes = b"", iteration_exponent: int = DEFAULT_ITERATION_EXPONENT, ) -> List[str]: return generate_mnemonics_from_data( - master_secret, - identifier, - 1, - [(threshold, count)], - passphrase, - iteration_exponent, + encrypted_master_secret, identifier, 1, [(threshold, count)], iteration_exponent )[0] def generate_mnemonics_from_data( - master_secret: bytes, + encrypted_master_secret: bytes, identifier: int, group_threshold: int, groups: List[Tuple[int, int]], - passphrase: bytes = b"", iteration_exponent: int = DEFAULT_ITERATION_EXPONENT, ) -> List[List[str]]: """ - Splits a master secret into mnemonic shares using Shamir's secret sharing scheme. - :param master_secret: The master secret to split. - :type master_secret: Array of bytes. + Splits an encrypted master secret into mnemonic shares using Shamir's secret sharing scheme. + :param encrypted_master_secret: The encrypted master secret to split. + :type encrypted_master_secret: Array of bytes. :param int identifier :param int group_threshold: The number of groups required to reconstruct the master secret. :param groups: A list of (member_threshold, member_count) pairs for each group, where member_count is the number of shares to generate for the group and member_threshold is the number of members required to reconstruct the group secret. :type groups: List of pairs of integers. - :param passphrase: The passphrase used to encrypt the master secret. - :type passphrase: Array of bytes. :param int iteration_exponent: The iteration exponent. :return: List of mnemonics. :rtype: List of byte arrays. @@ -529,21 +520,16 @@ def generate_mnemonics_from_data( :rtype: int. """ - if len(master_secret) * 8 < _MIN_STRENGTH_BITS: + if len(encrypted_master_secret) * 8 < _MIN_STRENGTH_BITS: raise ValueError( - "The length of the master secret ({} bytes) must be at least {} bytes.".format( - len(master_secret), bits_to_bytes(_MIN_STRENGTH_BITS) + "The length of the encrypted master secret ({} bytes) must be at least {} bytes.".format( + len(encrypted_master_secret), bits_to_bytes(_MIN_STRENGTH_BITS) ) ) - if len(master_secret) % 2 != 0: + if len(encrypted_master_secret) % 2 != 0: raise ValueError( - "The length of the master secret in bytes must be an even number." - ) - - if not all(32 <= c <= 126 for c in passphrase): - raise ValueError( - "The passphrase must contain only printable ASCII characters (code points 32-126)." + "The length of the encrypted master secret in bytes must be an even number." ) if group_threshold > len(groups): @@ -561,10 +547,6 @@ def generate_mnemonics_from_data( "Creating multiple member shares with member threshold 1 is not allowed. Use 1-of-1 member sharing instead." ) - encrypted_master_secret = _encrypt( - master_secret, passphrase, iteration_exponent, identifier - ) - group_shares = _split_secret(group_threshold, len(groups), encrypted_master_secret) mnemonics = [] # type: List[List[str]] diff --git a/core/tests/test_trezor.crypto.slip39.py b/core/tests/test_trezor.crypto.slip39.py index 59c43719c..b4e40a4d9 100644 --- a/core/tests/test_trezor.crypto.slip39.py +++ b/core/tests/test_trezor.crypto.slip39.py @@ -22,7 +22,7 @@ def combinations(iterable, r): yield tuple(pool[i] for i in indices) class TestCryptoSlip39(unittest.TestCase): - MS = b"ABCDEFGHIJKLMNOP" + EMS = b"ABCDEFGHIJKLMNOP" def test_basic_sharing_random(self): ms = random.bytes(32) @@ -34,39 +34,28 @@ class TestCryptoSlip39(unittest.TestCase): def test_basic_sharing_fixed(self): generated_identifier = slip39.generate_random_identifier() - mnemonics = slip39.generate_mnemonics_from_data(self.MS, generated_identifier, 1, [(3, 5)]) + mnemonics = slip39.generate_mnemonics_from_data(self.EMS, generated_identifier, 1, [(3, 5)]) mnemonics = mnemonics[0] identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[:3]) - self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS) + self.assertEqual(ems, self.EMS) self.assertEqual(generated_identifier, identifier) self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4])[2], ems) with self.assertRaises(slip39.MnemonicError): slip39.combine_mnemonics(mnemonics[1:3]) - def test_passphrase(self): - identifier = slip39.generate_random_identifier() - mnemonics = slip39.generate_mnemonics_from_data(self.MS, identifier, 1, [(3, 5)], b"TREZOR") - mnemonics = mnemonics[0] - identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4]) - self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS) - self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS) - - def test_iteration_exponent(self): identifier = slip39.generate_random_identifier() - mnemonics = slip39.generate_mnemonics_from_data(self.MS, identifier, 1, [(3, 5)], b"TREZOR", 1) + mnemonics = slip39.generate_mnemonics_from_data(self.EMS, identifier, 1, [(3, 5)], 1) mnemonics = mnemonics[0] identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4]) - self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS) - self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS) + self.assertEqual(ems, self.EMS) identifier = slip39.generate_random_identifier() - mnemonics = slip39.generate_mnemonics_from_data(self.MS, identifier, 1, [(3, 5)], b"TREZOR", 2) + mnemonics = slip39.generate_mnemonics_from_data(self.EMS, identifier, 1, [(3, 5)], 2) mnemonics = mnemonics[0] identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4]) - self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS) - self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS) + self.assertEqual(ems, self.EMS) def test_group_sharing(self): @@ -75,7 +64,7 @@ class TestCryptoSlip39(unittest.TestCase): member_thresholds = (3, 2, 2, 1) identifier = slip39.generate_random_identifier() mnemonics = slip39.generate_mnemonics_from_data( - self.MS, identifier, group_threshold, list(zip(member_thresholds, group_sizes)) + self.EMS, identifier, group_threshold, list(zip(member_thresholds, group_sizes)) ) # Test all valid combinations of mnemonics. @@ -85,12 +74,12 @@ class TestCryptoSlip39(unittest.TestCase): mnemonic_subset = list(group1_subset + group2_subset) random.shuffle(mnemonic_subset) identifier, exponent, ems = slip39.combine_mnemonics(mnemonic_subset) - self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS) + self.assertEqual(ems, self.EMS) # Minimal sets of mnemonics. identifier, exponent, ems = slip39.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]]) - self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS) + self.assertEqual(ems, self.EMS) self.assertEqual(slip39.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]])[2], ems) # One complete group and one incomplete group out of two groups required. @@ -108,7 +97,7 @@ class TestCryptoSlip39(unittest.TestCase): member_thresholds = (3, 2, 2, 1) identifier = slip39.generate_random_identifier() mnemonics = slip39.generate_mnemonics_from_data( - self.MS, identifier, group_threshold, list(zip(member_thresholds, group_sizes)) + self.EMS, identifier, group_threshold, list(zip(member_thresholds, group_sizes)) ) # Test all valid combinations of mnemonics. @@ -117,14 +106,14 @@ class TestCryptoSlip39(unittest.TestCase): mnemonic_subset = list(group_subset) random.shuffle(mnemonic_subset) identifier, exponent, ems = slip39.combine_mnemonics(mnemonic_subset) - self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS) + self.assertEqual(ems, self.EMS) def test_all_groups_exist(self): for group_threshold in (1, 2, 5): identifier = slip39.generate_random_identifier() mnemonics = slip39.generate_mnemonics_from_data( - self.MS, identifier, group_threshold, [(3, 5), (1, 1), (2, 3), (2, 5), (3, 5)] + self.EMS, identifier, group_threshold, [(3, 5), (1, 1), (2, 3), (2, 5), (3, 5)] ) self.assertEqual(len(mnemonics), 5) self.assertEqual(len(sum(mnemonics, [])), 19) @@ -134,31 +123,31 @@ class TestCryptoSlip39(unittest.TestCase): identifier = slip39.generate_random_identifier() # Short master secret. with self.assertRaises(ValueError): - slip39.generate_mnemonics_from_data(self.MS[:14], identifier, 1, [(2, 3)]) + slip39.generate_mnemonics_from_data(self.EMS[:14], identifier, 1, [(2, 3)]) # Odd length master secret. with self.assertRaises(ValueError): - slip39.generate_mnemonics_from_data(self.MS + b"X", identifier,1, [(2, 3)]) + slip39.generate_mnemonics_from_data(self.EMS + b"X", identifier,1, [(2, 3)]) # Group threshold exceeds number of groups. with self.assertRaises(ValueError): - slip39.generate_mnemonics_from_data(self.MS, identifier, 3, [(3, 5), (2, 5)]) + slip39.generate_mnemonics_from_data(self.EMS, identifier, 3, [(3, 5), (2, 5)]) # Invalid group threshold. with self.assertRaises(ValueError): - slip39.generate_mnemonics_from_data(self.MS, identifier, 0, [(3, 5), (2, 5)]) + slip39.generate_mnemonics_from_data(self.EMS, identifier, 0, [(3, 5), (2, 5)]) # Member threshold exceeds number of members. with self.assertRaises(ValueError): - slip39.generate_mnemonics_from_data(self.MS, identifier, 2, [(3, 2), (2, 5)]) + slip39.generate_mnemonics_from_data(self.EMS, identifier, 2, [(3, 2), (2, 5)]) # Invalid member threshold. with self.assertRaises(ValueError): - slip39.generate_mnemonics_from_data(self.MS, identifier, 2, [(0, 2), (2, 5)]) + slip39.generate_mnemonics_from_data(self.EMS, identifier, 2, [(0, 2), (2, 5)]) # Group with multiple members and threshold 1. with self.assertRaises(ValueError): - slip39.generate_mnemonics_from_data(self.MS, identifier, 2, [(3, 5), (1, 3), (2, 5)]) + slip39.generate_mnemonics_from_data(self.EMS, identifier, 2, [(3, 5), (1, 3), (2, 5)]) def test_vectors(self): diff --git a/tests/device_tests/test_msg_resetdevice_shamir.py b/tests/device_tests/test_msg_resetdevice_shamir.py index 11481e7ec..017c31df0 100644 --- a/tests/device_tests/test_msg_resetdevice_shamir.py +++ b/tests/device_tests/test_msg_resetdevice_shamir.py @@ -155,11 +155,14 @@ class TestMsgResetDeviceT2(TrezorTest): def validate_mnemonics(mnemonics, threshold, expected_secret): # We expect these combinations to recreate the secret properly for test_group in combinations(mnemonics, threshold): - secret = shamir.combine_mnemonics(test_group) - assert secret == expected_secret + # TODO: HOTFIX, we should fix this properly by modifying and unifying the python-shamir-mnemonic API + ms = shamir.combine_mnemonics(test_group) + identifier, iteration_exponent, _, _, _ = shamir._decode_mnemonics(test_group) + ems = shamir._encrypt(ms, b"", iteration_exponent, identifier) + assert ems == expected_secret # We expect these combinations to raise MnemonicError for test_group in combinations(mnemonics, threshold - 1): with pytest.raises( MnemonicError, match=r".*Expected {} mnemonics.*".format(threshold) ): - secret = shamir.combine_mnemonics(test_group) + shamir.combine_mnemonics(test_group) diff --git a/tests/device_tests/test_shamir_reset_recovery.py b/tests/device_tests/test_shamir_reset_recovery.py new file mode 100644 index 000000000..3c326a003 --- /dev/null +++ b/tests/device_tests/test_shamir_reset_recovery.py @@ -0,0 +1,196 @@ +import pytest + +from trezorlib import btc, device, messages +from trezorlib.messages import ButtonRequestType as B, ResetDeviceBackupType +from trezorlib.tools import parse_path + + +@pytest.mark.skip_t1 +def test_reset_recovery(client): + mnemonics = reset(client) + address_before = btc.get_address(client, "Bitcoin", parse_path("44'/0'/0'/0/0")) + + for share_subset in ((0, 1, 2), (4, 3, 2), (2, 1, 3)): + # TODO: change the above to itertools.combinations(mnemonics, 3) + device.wipe(client) + selected_mnemonics = [mnemonics[i] for i in share_subset] + recover(client, selected_mnemonics) + address_after = btc.get_address(client, "Bitcoin", parse_path("44'/0'/0'/0/0")) + assert address_before == address_after + + +def reset(client, strength=128): + all_mnemonics = [] + + def input_flow(): + # Confirm Reset + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # Backup your seed + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # Confirm warning + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # shares info + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # Set & Confirm number of shares + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # threshold info + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # Set & confirm threshold value + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # Confirm show seeds + btn_code = yield + assert btn_code == B.ResetDevice + client.debug.press_yes() + + # show & confirm shares + for h in range(5): + words = [] + btn_code = yield + assert btn_code == B.Other + + # mnemonic phrases + # 20 word over 6 pages for strength 128, 33 words over 9 pages for strength 256 + for i in range(6): + words.extend(client.debug.read_reset_word().split()) + if i < 5: + client.debug.swipe_down() + else: + # last page is confirmation + client.debug.press_yes() + + # check share + for _ in range(3): + index = client.debug.read_reset_word_pos() + client.debug.input(words[index]) + + all_mnemonics.extend([" ".join(words)]) + + # Confirm continue to next share + btn_code = yield + assert btn_code == B.Success + client.debug.press_yes() + + # safety warning + btn_code = yield + assert btn_code == B.Success + client.debug.press_yes() + + with client: + client.set_expected_responses( + [ + messages.ButtonRequest(code=B.ResetDevice), + messages.EntropyRequest(), + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest(code=B.ResetDevice), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Other), + messages.ButtonRequest(code=B.Success), + messages.ButtonRequest(code=B.Success), + messages.Success(), + messages.Features(), + ] + ) + client.set_input_flow(input_flow) + + # No PIN, no passphrase, don't display random + device.reset( + client, + display_random=False, + strength=strength, + passphrase_protection=False, + pin_protection=False, + label="test", + language="english", + backup_type=ResetDeviceBackupType.Slip39_Single_Group, + ) + client.set_input_flow(None) + + # Check if device is properly initialized + assert client.features.initialized is True + assert client.features.needs_backup is False + assert client.features.pin_protection is False + assert client.features.passphrase_protection is False + + return all_mnemonics + + +def recover(client, shares): + debug = client.debug + + def input_flow(): + yield # Confirm Recovery + debug.press_yes() + # run recovery flow + yield from enter_all_shares(debug, shares) + + with client: + client.set_input_flow(input_flow) + ret = device.recover(client, pin_protection=False, label="label") + + client.set_input_flow(None) + + # Workflow successfully ended + assert ret == messages.Success(message="Device recovered") + assert client.features.pin_protection is False + assert client.features.passphrase_protection is False + + +# TODO: let's merge this with test_msg_recoverydevice_shamir.py +def enter_all_shares(debug, shares): + word_count = len(shares[0].split(" ")) + + # Homescreen - proceed to word number selection + yield + debug.press_yes() + # Input word number + code = yield + assert code == B.MnemonicWordCount + debug.input(str(word_count)) + # Homescreen - proceed to share entry + yield + debug.press_yes() + # Enter shares + for share in shares: + code = yield + assert code == B.MnemonicInput + # Enter mnemonic words + for word in share.split(" "): + debug.input(word) + + # Homescreen - continue + # or Homescreen - confirm success + yield + debug.press_yes()