1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-27 08:38:07 +00:00

core/shamir: fix EMS vs MS

(cherry picked from commit cb94454618)
This commit is contained in:
Tomas Susanka 2019-08-12 12:27:02 +02:00
parent d4b1e256d6
commit 1b666804c0
5 changed files with 243 additions and 66 deletions

View File

@ -47,6 +47,7 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
# request external entropy and compute the master secret # request external entropy and compute the master secret
entropy_ack = await ctx.call(EntropyRequest(), EntropyAck) entropy_ack = await ctx.call(EntropyRequest(), EntropyAck)
ext_entropy = entropy_ack.entropy ext_entropy = entropy_ack.entropy
# For SLIP-39 this is the Encrypted Master Secret
secret = _compute_secret_from_entropy(int_entropy, ext_entropy, msg.strength) secret = _compute_secret_from_entropy(int_entropy, ext_entropy, msg.strength)
if is_slip39_simple: if is_slip39_simple:
@ -76,7 +77,7 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
) )
if is_slip39_simple: if is_slip39_simple:
storage.device.store_mnemonic_secret( storage.device.store_mnemonic_secret(
secret, secret, # this is the EMS in SLIP-39 terminology
mnemonic.TYPE_SLIP39, mnemonic.TYPE_SLIP39,
needs_backup=msg.skip_backup, needs_backup=msg.skip_backup,
no_backup=msg.no_backup, no_backup=msg.no_backup,
@ -97,7 +98,9 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
return Success(message="Initialized") return Success(message="Initialized")
async def backup_slip39_wallet(ctx: wire.Context, secret: bytes) -> None: async def backup_slip39_wallet(
ctx: wire.Context, encrypted_master_secret: bytes
) -> None:
# get number of shares # get number of shares
await layout.slip39_show_checklist_set_shares(ctx) await layout.slip39_show_checklist_set_shares(ctx)
shares_count = await layout.slip39_prompt_number_of_shares(ctx) shares_count = await layout.slip39_prompt_number_of_shares(ctx)
@ -108,7 +111,11 @@ async def backup_slip39_wallet(ctx: wire.Context, secret: bytes) -> None:
# generate the mnemonics # generate the mnemonics
mnemonics = slip39.generate_single_group_mnemonics_from_data( mnemonics = slip39.generate_single_group_mnemonics_from_data(
secret, storage.device.get_slip39_identifier(), threshold, shares_count encrypted_master_secret,
storage.device.get_slip39_identifier(),
threshold,
shares_count,
storage.device.get_slip39_iteration_exponent(),
) )
# show and confirm individual shares # show and confirm individual shares

View File

@ -485,43 +485,34 @@ def generate_random_identifier() -> int:
def generate_single_group_mnemonics_from_data( def generate_single_group_mnemonics_from_data(
master_secret: bytes, encrypted_master_secret: bytes,
identifier: int, identifier: int,
threshold: int, threshold: int,
count: int, count: int,
passphrase: bytes = b"",
iteration_exponent: int = DEFAULT_ITERATION_EXPONENT, iteration_exponent: int = DEFAULT_ITERATION_EXPONENT,
) -> List[str]: ) -> List[str]:
return generate_mnemonics_from_data( return generate_mnemonics_from_data(
master_secret, encrypted_master_secret, identifier, 1, [(threshold, count)], iteration_exponent
identifier,
1,
[(threshold, count)],
passphrase,
iteration_exponent,
)[0] )[0]
def generate_mnemonics_from_data( def generate_mnemonics_from_data(
master_secret: bytes, encrypted_master_secret: bytes,
identifier: int, identifier: int,
group_threshold: int, group_threshold: int,
groups: List[Tuple[int, int]], groups: List[Tuple[int, int]],
passphrase: bytes = b"",
iteration_exponent: int = DEFAULT_ITERATION_EXPONENT, iteration_exponent: int = DEFAULT_ITERATION_EXPONENT,
) -> List[List[str]]: ) -> List[List[str]]:
""" """
Splits a master secret into mnemonic shares using Shamir's secret sharing scheme. Splits an encrypted master secret into mnemonic shares using Shamir's secret sharing scheme.
:param master_secret: The master secret to split. :param encrypted_master_secret: The encrypted master secret to split.
:type master_secret: Array of bytes. :type encrypted_master_secret: Array of bytes.
:param int identifier :param int identifier
:param int group_threshold: The number of groups required to reconstruct the master secret. :param int group_threshold: The number of groups required to reconstruct the master secret.
:param groups: A list of (member_threshold, member_count) pairs for each group, where member_count :param groups: A list of (member_threshold, member_count) pairs for each group, where member_count
is the number of shares to generate for the group and member_threshold is the number of members required to is the number of shares to generate for the group and member_threshold is the number of members required to
reconstruct the group secret. reconstruct the group secret.
:type groups: List of pairs of integers. :type groups: List of pairs of integers.
:param passphrase: The passphrase used to encrypt the master secret.
:type passphrase: Array of bytes.
:param int iteration_exponent: The iteration exponent. :param int iteration_exponent: The iteration exponent.
:return: List of mnemonics. :return: List of mnemonics.
:rtype: List of byte arrays. :rtype: List of byte arrays.
@ -529,21 +520,16 @@ def generate_mnemonics_from_data(
:rtype: int. :rtype: int.
""" """
if len(master_secret) * 8 < _MIN_STRENGTH_BITS: if len(encrypted_master_secret) * 8 < _MIN_STRENGTH_BITS:
raise ValueError( raise ValueError(
"The length of the master secret ({} bytes) must be at least {} bytes.".format( "The length of the encrypted master secret ({} bytes) must be at least {} bytes.".format(
len(master_secret), bits_to_bytes(_MIN_STRENGTH_BITS) len(encrypted_master_secret), bits_to_bytes(_MIN_STRENGTH_BITS)
) )
) )
if len(master_secret) % 2 != 0: if len(encrypted_master_secret) % 2 != 0:
raise ValueError( raise ValueError(
"The length of the master secret in bytes must be an even number." "The length of the encrypted master secret in bytes must be an even number."
)
if not all(32 <= c <= 126 for c in passphrase):
raise ValueError(
"The passphrase must contain only printable ASCII characters (code points 32-126)."
) )
if group_threshold > len(groups): if group_threshold > len(groups):
@ -561,10 +547,6 @@ def generate_mnemonics_from_data(
"Creating multiple member shares with member threshold 1 is not allowed. Use 1-of-1 member sharing instead." "Creating multiple member shares with member threshold 1 is not allowed. Use 1-of-1 member sharing instead."
) )
encrypted_master_secret = _encrypt(
master_secret, passphrase, iteration_exponent, identifier
)
group_shares = _split_secret(group_threshold, len(groups), encrypted_master_secret) group_shares = _split_secret(group_threshold, len(groups), encrypted_master_secret)
mnemonics = [] # type: List[List[str]] mnemonics = [] # type: List[List[str]]

View File

@ -22,7 +22,7 @@ def combinations(iterable, r):
yield tuple(pool[i] for i in indices) yield tuple(pool[i] for i in indices)
class TestCryptoSlip39(unittest.TestCase): class TestCryptoSlip39(unittest.TestCase):
MS = b"ABCDEFGHIJKLMNOP" EMS = b"ABCDEFGHIJKLMNOP"
def test_basic_sharing_random(self): def test_basic_sharing_random(self):
ms = random.bytes(32) ms = random.bytes(32)
@ -34,39 +34,28 @@ class TestCryptoSlip39(unittest.TestCase):
def test_basic_sharing_fixed(self): def test_basic_sharing_fixed(self):
generated_identifier = slip39.generate_random_identifier() generated_identifier = slip39.generate_random_identifier()
mnemonics = slip39.generate_mnemonics_from_data(self.MS, generated_identifier, 1, [(3, 5)]) mnemonics = slip39.generate_mnemonics_from_data(self.EMS, generated_identifier, 1, [(3, 5)])
mnemonics = mnemonics[0] mnemonics = mnemonics[0]
identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[:3]) identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[:3])
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS) self.assertEqual(ems, self.EMS)
self.assertEqual(generated_identifier, identifier) self.assertEqual(generated_identifier, identifier)
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4])[2], ems) self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4])[2], ems)
with self.assertRaises(slip39.MnemonicError): with self.assertRaises(slip39.MnemonicError):
slip39.combine_mnemonics(mnemonics[1:3]) slip39.combine_mnemonics(mnemonics[1:3])
def test_passphrase(self):
identifier = slip39.generate_random_identifier()
mnemonics = slip39.generate_mnemonics_from_data(self.MS, identifier, 1, [(3, 5)], b"TREZOR")
mnemonics = mnemonics[0]
identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4])
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS)
self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
def test_iteration_exponent(self): def test_iteration_exponent(self):
identifier = slip39.generate_random_identifier() identifier = slip39.generate_random_identifier()
mnemonics = slip39.generate_mnemonics_from_data(self.MS, identifier, 1, [(3, 5)], b"TREZOR", 1) mnemonics = slip39.generate_mnemonics_from_data(self.EMS, identifier, 1, [(3, 5)], 1)
mnemonics = mnemonics[0] mnemonics = mnemonics[0]
identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4]) identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4])
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS) self.assertEqual(ems, self.EMS)
self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
identifier = slip39.generate_random_identifier() identifier = slip39.generate_random_identifier()
mnemonics = slip39.generate_mnemonics_from_data(self.MS, identifier, 1, [(3, 5)], b"TREZOR", 2) mnemonics = slip39.generate_mnemonics_from_data(self.EMS, identifier, 1, [(3, 5)], 2)
mnemonics = mnemonics[0] mnemonics = mnemonics[0]
identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4]) identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4])
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS) self.assertEqual(ems, self.EMS)
self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
def test_group_sharing(self): def test_group_sharing(self):
@ -75,7 +64,7 @@ class TestCryptoSlip39(unittest.TestCase):
member_thresholds = (3, 2, 2, 1) member_thresholds = (3, 2, 2, 1)
identifier = slip39.generate_random_identifier() identifier = slip39.generate_random_identifier()
mnemonics = slip39.generate_mnemonics_from_data( mnemonics = slip39.generate_mnemonics_from_data(
self.MS, identifier, group_threshold, list(zip(member_thresholds, group_sizes)) self.EMS, identifier, group_threshold, list(zip(member_thresholds, group_sizes))
) )
# Test all valid combinations of mnemonics. # Test all valid combinations of mnemonics.
@ -85,12 +74,12 @@ class TestCryptoSlip39(unittest.TestCase):
mnemonic_subset = list(group1_subset + group2_subset) mnemonic_subset = list(group1_subset + group2_subset)
random.shuffle(mnemonic_subset) random.shuffle(mnemonic_subset)
identifier, exponent, ems = slip39.combine_mnemonics(mnemonic_subset) identifier, exponent, ems = slip39.combine_mnemonics(mnemonic_subset)
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS) self.assertEqual(ems, self.EMS)
# Minimal sets of mnemonics. # Minimal sets of mnemonics.
identifier, exponent, ems = slip39.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]]) identifier, exponent, ems = slip39.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]])
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS) self.assertEqual(ems, self.EMS)
self.assertEqual(slip39.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]])[2], ems) self.assertEqual(slip39.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]])[2], ems)
# One complete group and one incomplete group out of two groups required. # One complete group and one incomplete group out of two groups required.
@ -108,7 +97,7 @@ class TestCryptoSlip39(unittest.TestCase):
member_thresholds = (3, 2, 2, 1) member_thresholds = (3, 2, 2, 1)
identifier = slip39.generate_random_identifier() identifier = slip39.generate_random_identifier()
mnemonics = slip39.generate_mnemonics_from_data( mnemonics = slip39.generate_mnemonics_from_data(
self.MS, identifier, group_threshold, list(zip(member_thresholds, group_sizes)) self.EMS, identifier, group_threshold, list(zip(member_thresholds, group_sizes))
) )
# Test all valid combinations of mnemonics. # Test all valid combinations of mnemonics.
@ -117,14 +106,14 @@ class TestCryptoSlip39(unittest.TestCase):
mnemonic_subset = list(group_subset) mnemonic_subset = list(group_subset)
random.shuffle(mnemonic_subset) random.shuffle(mnemonic_subset)
identifier, exponent, ems = slip39.combine_mnemonics(mnemonic_subset) identifier, exponent, ems = slip39.combine_mnemonics(mnemonic_subset)
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS) self.assertEqual(ems, self.EMS)
def test_all_groups_exist(self): def test_all_groups_exist(self):
for group_threshold in (1, 2, 5): for group_threshold in (1, 2, 5):
identifier = slip39.generate_random_identifier() identifier = slip39.generate_random_identifier()
mnemonics = slip39.generate_mnemonics_from_data( mnemonics = slip39.generate_mnemonics_from_data(
self.MS, identifier, group_threshold, [(3, 5), (1, 1), (2, 3), (2, 5), (3, 5)] self.EMS, identifier, group_threshold, [(3, 5), (1, 1), (2, 3), (2, 5), (3, 5)]
) )
self.assertEqual(len(mnemonics), 5) self.assertEqual(len(mnemonics), 5)
self.assertEqual(len(sum(mnemonics, [])), 19) self.assertEqual(len(sum(mnemonics, [])), 19)
@ -134,31 +123,31 @@ class TestCryptoSlip39(unittest.TestCase):
identifier = slip39.generate_random_identifier() identifier = slip39.generate_random_identifier()
# Short master secret. # Short master secret.
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
slip39.generate_mnemonics_from_data(self.MS[:14], identifier, 1, [(2, 3)]) slip39.generate_mnemonics_from_data(self.EMS[:14], identifier, 1, [(2, 3)])
# Odd length master secret. # Odd length master secret.
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
slip39.generate_mnemonics_from_data(self.MS + b"X", identifier,1, [(2, 3)]) slip39.generate_mnemonics_from_data(self.EMS + b"X", identifier,1, [(2, 3)])
# Group threshold exceeds number of groups. # Group threshold exceeds number of groups.
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
slip39.generate_mnemonics_from_data(self.MS, identifier, 3, [(3, 5), (2, 5)]) slip39.generate_mnemonics_from_data(self.EMS, identifier, 3, [(3, 5), (2, 5)])
# Invalid group threshold. # Invalid group threshold.
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
slip39.generate_mnemonics_from_data(self.MS, identifier, 0, [(3, 5), (2, 5)]) slip39.generate_mnemonics_from_data(self.EMS, identifier, 0, [(3, 5), (2, 5)])
# Member threshold exceeds number of members. # Member threshold exceeds number of members.
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
slip39.generate_mnemonics_from_data(self.MS, identifier, 2, [(3, 2), (2, 5)]) slip39.generate_mnemonics_from_data(self.EMS, identifier, 2, [(3, 2), (2, 5)])
# Invalid member threshold. # Invalid member threshold.
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
slip39.generate_mnemonics_from_data(self.MS, identifier, 2, [(0, 2), (2, 5)]) slip39.generate_mnemonics_from_data(self.EMS, identifier, 2, [(0, 2), (2, 5)])
# Group with multiple members and threshold 1. # Group with multiple members and threshold 1.
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
slip39.generate_mnemonics_from_data(self.MS, identifier, 2, [(3, 5), (1, 3), (2, 5)]) slip39.generate_mnemonics_from_data(self.EMS, identifier, 2, [(3, 5), (1, 3), (2, 5)])
def test_vectors(self): def test_vectors(self):

View File

@ -155,11 +155,14 @@ class TestMsgResetDeviceT2(TrezorTest):
def validate_mnemonics(mnemonics, threshold, expected_secret): def validate_mnemonics(mnemonics, threshold, expected_secret):
# We expect these combinations to recreate the secret properly # We expect these combinations to recreate the secret properly
for test_group in combinations(mnemonics, threshold): for test_group in combinations(mnemonics, threshold):
secret = shamir.combine_mnemonics(test_group) # TODO: HOTFIX, we should fix this properly by modifying and unifying the python-shamir-mnemonic API
assert secret == expected_secret ms = shamir.combine_mnemonics(test_group)
identifier, iteration_exponent, _, _, _ = shamir._decode_mnemonics(test_group)
ems = shamir._encrypt(ms, b"", iteration_exponent, identifier)
assert ems == expected_secret
# We expect these combinations to raise MnemonicError # We expect these combinations to raise MnemonicError
for test_group in combinations(mnemonics, threshold - 1): for test_group in combinations(mnemonics, threshold - 1):
with pytest.raises( with pytest.raises(
MnemonicError, match=r".*Expected {} mnemonics.*".format(threshold) MnemonicError, match=r".*Expected {} mnemonics.*".format(threshold)
): ):
secret = shamir.combine_mnemonics(test_group) shamir.combine_mnemonics(test_group)

View File

@ -0,0 +1,196 @@
import pytest
from trezorlib import btc, device, messages
from trezorlib.messages import ButtonRequestType as B, ResetDeviceBackupType
from trezorlib.tools import parse_path
@pytest.mark.skip_t1
def test_reset_recovery(client):
mnemonics = reset(client)
address_before = btc.get_address(client, "Bitcoin", parse_path("44'/0'/0'/0/0"))
for share_subset in ((0, 1, 2), (4, 3, 2), (2, 1, 3)):
# TODO: change the above to itertools.combinations(mnemonics, 3)
device.wipe(client)
selected_mnemonics = [mnemonics[i] for i in share_subset]
recover(client, selected_mnemonics)
address_after = btc.get_address(client, "Bitcoin", parse_path("44'/0'/0'/0/0"))
assert address_before == address_after
def reset(client, strength=128):
all_mnemonics = []
def input_flow():
# Confirm Reset
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Backup your seed
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Confirm warning
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# shares info
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Set & Confirm number of shares
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# threshold info
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Set & confirm threshold value
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Confirm show seeds
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# show & confirm shares
for h in range(5):
words = []
btn_code = yield
assert btn_code == B.Other
# mnemonic phrases
# 20 word over 6 pages for strength 128, 33 words over 9 pages for strength 256
for i in range(6):
words.extend(client.debug.read_reset_word().split())
if i < 5:
client.debug.swipe_down()
else:
# last page is confirmation
client.debug.press_yes()
# check share
for _ in range(3):
index = client.debug.read_reset_word_pos()
client.debug.input(words[index])
all_mnemonics.extend([" ".join(words)])
# Confirm continue to next share
btn_code = yield
assert btn_code == B.Success
client.debug.press_yes()
# safety warning
btn_code = yield
assert btn_code == B.Success
client.debug.press_yes()
with client:
client.set_expected_responses(
[
messages.ButtonRequest(code=B.ResetDevice),
messages.EntropyRequest(),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Success),
messages.Success(),
messages.Features(),
]
)
client.set_input_flow(input_flow)
# No PIN, no passphrase, don't display random
device.reset(
client,
display_random=False,
strength=strength,
passphrase_protection=False,
pin_protection=False,
label="test",
language="english",
backup_type=ResetDeviceBackupType.Slip39_Single_Group,
)
client.set_input_flow(None)
# Check if device is properly initialized
assert client.features.initialized is True
assert client.features.needs_backup is False
assert client.features.pin_protection is False
assert client.features.passphrase_protection is False
return all_mnemonics
def recover(client, shares):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
# run recovery flow
yield from enter_all_shares(debug, shares)
with client:
client.set_input_flow(input_flow)
ret = device.recover(client, pin_protection=False, label="label")
client.set_input_flow(None)
# Workflow successfully ended
assert ret == messages.Success(message="Device recovered")
assert client.features.pin_protection is False
assert client.features.passphrase_protection is False
# TODO: let's merge this with test_msg_recoverydevice_shamir.py
def enter_all_shares(debug, shares):
word_count = len(shares[0].split(" "))
# Homescreen - proceed to word number selection
yield
debug.press_yes()
# Input word number
code = yield
assert code == B.MnemonicWordCount
debug.input(str(word_count))
# Homescreen - proceed to share entry
yield
debug.press_yes()
# Enter shares
for share in shares:
code = yield
assert code == B.MnemonicInput
# Enter mnemonic words
for word in share.split(" "):
debug.input(word)
# Homescreen - continue
# or Homescreen - confirm success
yield
debug.press_yes()