1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-27 08:38:07 +00:00

core/shamir: fix EMS vs MS

(cherry picked from commit cb94454618)
This commit is contained in:
Tomas Susanka 2019-08-12 12:27:02 +02:00
parent d4b1e256d6
commit 1b666804c0
5 changed files with 243 additions and 66 deletions

View File

@ -47,6 +47,7 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
# request external entropy and compute the master secret
entropy_ack = await ctx.call(EntropyRequest(), EntropyAck)
ext_entropy = entropy_ack.entropy
# For SLIP-39 this is the Encrypted Master Secret
secret = _compute_secret_from_entropy(int_entropy, ext_entropy, msg.strength)
if is_slip39_simple:
@ -76,7 +77,7 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
)
if is_slip39_simple:
storage.device.store_mnemonic_secret(
secret,
secret, # this is the EMS in SLIP-39 terminology
mnemonic.TYPE_SLIP39,
needs_backup=msg.skip_backup,
no_backup=msg.no_backup,
@ -97,7 +98,9 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
return Success(message="Initialized")
async def backup_slip39_wallet(ctx: wire.Context, secret: bytes) -> None:
async def backup_slip39_wallet(
ctx: wire.Context, encrypted_master_secret: bytes
) -> None:
# get number of shares
await layout.slip39_show_checklist_set_shares(ctx)
shares_count = await layout.slip39_prompt_number_of_shares(ctx)
@ -108,7 +111,11 @@ async def backup_slip39_wallet(ctx: wire.Context, secret: bytes) -> None:
# generate the mnemonics
mnemonics = slip39.generate_single_group_mnemonics_from_data(
secret, storage.device.get_slip39_identifier(), threshold, shares_count
encrypted_master_secret,
storage.device.get_slip39_identifier(),
threshold,
shares_count,
storage.device.get_slip39_iteration_exponent(),
)
# show and confirm individual shares

View File

@ -485,43 +485,34 @@ def generate_random_identifier() -> int:
def generate_single_group_mnemonics_from_data(
master_secret: bytes,
encrypted_master_secret: bytes,
identifier: int,
threshold: int,
count: int,
passphrase: bytes = b"",
iteration_exponent: int = DEFAULT_ITERATION_EXPONENT,
) -> List[str]:
return generate_mnemonics_from_data(
master_secret,
identifier,
1,
[(threshold, count)],
passphrase,
iteration_exponent,
encrypted_master_secret, identifier, 1, [(threshold, count)], iteration_exponent
)[0]
def generate_mnemonics_from_data(
master_secret: bytes,
encrypted_master_secret: bytes,
identifier: int,
group_threshold: int,
groups: List[Tuple[int, int]],
passphrase: bytes = b"",
iteration_exponent: int = DEFAULT_ITERATION_EXPONENT,
) -> List[List[str]]:
"""
Splits a master secret into mnemonic shares using Shamir's secret sharing scheme.
:param master_secret: The master secret to split.
:type master_secret: Array of bytes.
Splits an encrypted master secret into mnemonic shares using Shamir's secret sharing scheme.
:param encrypted_master_secret: The encrypted master secret to split.
:type encrypted_master_secret: Array of bytes.
:param int identifier
:param int group_threshold: The number of groups required to reconstruct the master secret.
:param groups: A list of (member_threshold, member_count) pairs for each group, where member_count
is the number of shares to generate for the group and member_threshold is the number of members required to
reconstruct the group secret.
:type groups: List of pairs of integers.
:param passphrase: The passphrase used to encrypt the master secret.
:type passphrase: Array of bytes.
:param int iteration_exponent: The iteration exponent.
:return: List of mnemonics.
:rtype: List of byte arrays.
@ -529,21 +520,16 @@ def generate_mnemonics_from_data(
:rtype: int.
"""
if len(master_secret) * 8 < _MIN_STRENGTH_BITS:
if len(encrypted_master_secret) * 8 < _MIN_STRENGTH_BITS:
raise ValueError(
"The length of the master secret ({} bytes) must be at least {} bytes.".format(
len(master_secret), bits_to_bytes(_MIN_STRENGTH_BITS)
"The length of the encrypted master secret ({} bytes) must be at least {} bytes.".format(
len(encrypted_master_secret), bits_to_bytes(_MIN_STRENGTH_BITS)
)
)
if len(master_secret) % 2 != 0:
if len(encrypted_master_secret) % 2 != 0:
raise ValueError(
"The length of the master secret in bytes must be an even number."
)
if not all(32 <= c <= 126 for c in passphrase):
raise ValueError(
"The passphrase must contain only printable ASCII characters (code points 32-126)."
"The length of the encrypted master secret in bytes must be an even number."
)
if group_threshold > len(groups):
@ -561,10 +547,6 @@ def generate_mnemonics_from_data(
"Creating multiple member shares with member threshold 1 is not allowed. Use 1-of-1 member sharing instead."
)
encrypted_master_secret = _encrypt(
master_secret, passphrase, iteration_exponent, identifier
)
group_shares = _split_secret(group_threshold, len(groups), encrypted_master_secret)
mnemonics = [] # type: List[List[str]]

View File

@ -22,7 +22,7 @@ def combinations(iterable, r):
yield tuple(pool[i] for i in indices)
class TestCryptoSlip39(unittest.TestCase):
MS = b"ABCDEFGHIJKLMNOP"
EMS = b"ABCDEFGHIJKLMNOP"
def test_basic_sharing_random(self):
ms = random.bytes(32)
@ -34,39 +34,28 @@ class TestCryptoSlip39(unittest.TestCase):
def test_basic_sharing_fixed(self):
generated_identifier = slip39.generate_random_identifier()
mnemonics = slip39.generate_mnemonics_from_data(self.MS, generated_identifier, 1, [(3, 5)])
mnemonics = slip39.generate_mnemonics_from_data(self.EMS, generated_identifier, 1, [(3, 5)])
mnemonics = mnemonics[0]
identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[:3])
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
self.assertEqual(ems, self.EMS)
self.assertEqual(generated_identifier, identifier)
self.assertEqual(slip39.combine_mnemonics(mnemonics[1:4])[2], ems)
with self.assertRaises(slip39.MnemonicError):
slip39.combine_mnemonics(mnemonics[1:3])
def test_passphrase(self):
identifier = slip39.generate_random_identifier()
mnemonics = slip39.generate_mnemonics_from_data(self.MS, identifier, 1, [(3, 5)], b"TREZOR")
mnemonics = mnemonics[0]
identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4])
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS)
self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
def test_iteration_exponent(self):
identifier = slip39.generate_random_identifier()
mnemonics = slip39.generate_mnemonics_from_data(self.MS, identifier, 1, [(3, 5)], b"TREZOR", 1)
mnemonics = slip39.generate_mnemonics_from_data(self.EMS, identifier, 1, [(3, 5)], 1)
mnemonics = mnemonics[0]
identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4])
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS)
self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
self.assertEqual(ems, self.EMS)
identifier = slip39.generate_random_identifier()
mnemonics = slip39.generate_mnemonics_from_data(self.MS, identifier, 1, [(3, 5)], b"TREZOR", 2)
mnemonics = slip39.generate_mnemonics_from_data(self.EMS, identifier, 1, [(3, 5)], 2)
mnemonics = mnemonics[0]
identifier, exponent, ems = slip39.combine_mnemonics(mnemonics[1:4])
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b"TREZOR"), self.MS)
self.assertNotEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
self.assertEqual(ems, self.EMS)
def test_group_sharing(self):
@ -75,7 +64,7 @@ class TestCryptoSlip39(unittest.TestCase):
member_thresholds = (3, 2, 2, 1)
identifier = slip39.generate_random_identifier()
mnemonics = slip39.generate_mnemonics_from_data(
self.MS, identifier, group_threshold, list(zip(member_thresholds, group_sizes))
self.EMS, identifier, group_threshold, list(zip(member_thresholds, group_sizes))
)
# Test all valid combinations of mnemonics.
@ -85,12 +74,12 @@ class TestCryptoSlip39(unittest.TestCase):
mnemonic_subset = list(group1_subset + group2_subset)
random.shuffle(mnemonic_subset)
identifier, exponent, ems = slip39.combine_mnemonics(mnemonic_subset)
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
self.assertEqual(ems, self.EMS)
# Minimal sets of mnemonics.
identifier, exponent, ems = slip39.combine_mnemonics([mnemonics[2][0], mnemonics[2][2], mnemonics[3][0]])
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
self.assertEqual(ems, self.EMS)
self.assertEqual(slip39.combine_mnemonics([mnemonics[2][3], mnemonics[3][0], mnemonics[2][4]])[2], ems)
# One complete group and one incomplete group out of two groups required.
@ -108,7 +97,7 @@ class TestCryptoSlip39(unittest.TestCase):
member_thresholds = (3, 2, 2, 1)
identifier = slip39.generate_random_identifier()
mnemonics = slip39.generate_mnemonics_from_data(
self.MS, identifier, group_threshold, list(zip(member_thresholds, group_sizes))
self.EMS, identifier, group_threshold, list(zip(member_thresholds, group_sizes))
)
# Test all valid combinations of mnemonics.
@ -117,14 +106,14 @@ class TestCryptoSlip39(unittest.TestCase):
mnemonic_subset = list(group_subset)
random.shuffle(mnemonic_subset)
identifier, exponent, ems = slip39.combine_mnemonics(mnemonic_subset)
self.assertEqual(slip39.decrypt(identifier, exponent, ems, b""), self.MS)
self.assertEqual(ems, self.EMS)
def test_all_groups_exist(self):
for group_threshold in (1, 2, 5):
identifier = slip39.generate_random_identifier()
mnemonics = slip39.generate_mnemonics_from_data(
self.MS, identifier, group_threshold, [(3, 5), (1, 1), (2, 3), (2, 5), (3, 5)]
self.EMS, identifier, group_threshold, [(3, 5), (1, 1), (2, 3), (2, 5), (3, 5)]
)
self.assertEqual(len(mnemonics), 5)
self.assertEqual(len(sum(mnemonics, [])), 19)
@ -134,31 +123,31 @@ class TestCryptoSlip39(unittest.TestCase):
identifier = slip39.generate_random_identifier()
# Short master secret.
with self.assertRaises(ValueError):
slip39.generate_mnemonics_from_data(self.MS[:14], identifier, 1, [(2, 3)])
slip39.generate_mnemonics_from_data(self.EMS[:14], identifier, 1, [(2, 3)])
# Odd length master secret.
with self.assertRaises(ValueError):
slip39.generate_mnemonics_from_data(self.MS + b"X", identifier,1, [(2, 3)])
slip39.generate_mnemonics_from_data(self.EMS + b"X", identifier,1, [(2, 3)])
# Group threshold exceeds number of groups.
with self.assertRaises(ValueError):
slip39.generate_mnemonics_from_data(self.MS, identifier, 3, [(3, 5), (2, 5)])
slip39.generate_mnemonics_from_data(self.EMS, identifier, 3, [(3, 5), (2, 5)])
# Invalid group threshold.
with self.assertRaises(ValueError):
slip39.generate_mnemonics_from_data(self.MS, identifier, 0, [(3, 5), (2, 5)])
slip39.generate_mnemonics_from_data(self.EMS, identifier, 0, [(3, 5), (2, 5)])
# Member threshold exceeds number of members.
with self.assertRaises(ValueError):
slip39.generate_mnemonics_from_data(self.MS, identifier, 2, [(3, 2), (2, 5)])
slip39.generate_mnemonics_from_data(self.EMS, identifier, 2, [(3, 2), (2, 5)])
# Invalid member threshold.
with self.assertRaises(ValueError):
slip39.generate_mnemonics_from_data(self.MS, identifier, 2, [(0, 2), (2, 5)])
slip39.generate_mnemonics_from_data(self.EMS, identifier, 2, [(0, 2), (2, 5)])
# Group with multiple members and threshold 1.
with self.assertRaises(ValueError):
slip39.generate_mnemonics_from_data(self.MS, identifier, 2, [(3, 5), (1, 3), (2, 5)])
slip39.generate_mnemonics_from_data(self.EMS, identifier, 2, [(3, 5), (1, 3), (2, 5)])
def test_vectors(self):

View File

@ -155,11 +155,14 @@ class TestMsgResetDeviceT2(TrezorTest):
def validate_mnemonics(mnemonics, threshold, expected_secret):
# We expect these combinations to recreate the secret properly
for test_group in combinations(mnemonics, threshold):
secret = shamir.combine_mnemonics(test_group)
assert secret == expected_secret
# TODO: HOTFIX, we should fix this properly by modifying and unifying the python-shamir-mnemonic API
ms = shamir.combine_mnemonics(test_group)
identifier, iteration_exponent, _, _, _ = shamir._decode_mnemonics(test_group)
ems = shamir._encrypt(ms, b"", iteration_exponent, identifier)
assert ems == expected_secret
# We expect these combinations to raise MnemonicError
for test_group in combinations(mnemonics, threshold - 1):
with pytest.raises(
MnemonicError, match=r".*Expected {} mnemonics.*".format(threshold)
):
secret = shamir.combine_mnemonics(test_group)
shamir.combine_mnemonics(test_group)

View File

@ -0,0 +1,196 @@
import pytest
from trezorlib import btc, device, messages
from trezorlib.messages import ButtonRequestType as B, ResetDeviceBackupType
from trezorlib.tools import parse_path
@pytest.mark.skip_t1
def test_reset_recovery(client):
mnemonics = reset(client)
address_before = btc.get_address(client, "Bitcoin", parse_path("44'/0'/0'/0/0"))
for share_subset in ((0, 1, 2), (4, 3, 2), (2, 1, 3)):
# TODO: change the above to itertools.combinations(mnemonics, 3)
device.wipe(client)
selected_mnemonics = [mnemonics[i] for i in share_subset]
recover(client, selected_mnemonics)
address_after = btc.get_address(client, "Bitcoin", parse_path("44'/0'/0'/0/0"))
assert address_before == address_after
def reset(client, strength=128):
all_mnemonics = []
def input_flow():
# Confirm Reset
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Backup your seed
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Confirm warning
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# shares info
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Set & Confirm number of shares
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# threshold info
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Set & confirm threshold value
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# Confirm show seeds
btn_code = yield
assert btn_code == B.ResetDevice
client.debug.press_yes()
# show & confirm shares
for h in range(5):
words = []
btn_code = yield
assert btn_code == B.Other
# mnemonic phrases
# 20 word over 6 pages for strength 128, 33 words over 9 pages for strength 256
for i in range(6):
words.extend(client.debug.read_reset_word().split())
if i < 5:
client.debug.swipe_down()
else:
# last page is confirmation
client.debug.press_yes()
# check share
for _ in range(3):
index = client.debug.read_reset_word_pos()
client.debug.input(words[index])
all_mnemonics.extend([" ".join(words)])
# Confirm continue to next share
btn_code = yield
assert btn_code == B.Success
client.debug.press_yes()
# safety warning
btn_code = yield
assert btn_code == B.Success
client.debug.press_yes()
with client:
client.set_expected_responses(
[
messages.ButtonRequest(code=B.ResetDevice),
messages.EntropyRequest(),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.ResetDevice),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Other),
messages.ButtonRequest(code=B.Success),
messages.ButtonRequest(code=B.Success),
messages.Success(),
messages.Features(),
]
)
client.set_input_flow(input_flow)
# No PIN, no passphrase, don't display random
device.reset(
client,
display_random=False,
strength=strength,
passphrase_protection=False,
pin_protection=False,
label="test",
language="english",
backup_type=ResetDeviceBackupType.Slip39_Single_Group,
)
client.set_input_flow(None)
# Check if device is properly initialized
assert client.features.initialized is True
assert client.features.needs_backup is False
assert client.features.pin_protection is False
assert client.features.passphrase_protection is False
return all_mnemonics
def recover(client, shares):
debug = client.debug
def input_flow():
yield # Confirm Recovery
debug.press_yes()
# run recovery flow
yield from enter_all_shares(debug, shares)
with client:
client.set_input_flow(input_flow)
ret = device.recover(client, pin_protection=False, label="label")
client.set_input_flow(None)
# Workflow successfully ended
assert ret == messages.Success(message="Device recovered")
assert client.features.pin_protection is False
assert client.features.passphrase_protection is False
# TODO: let's merge this with test_msg_recoverydevice_shamir.py
def enter_all_shares(debug, shares):
word_count = len(shares[0].split(" "))
# Homescreen - proceed to word number selection
yield
debug.press_yes()
# Input word number
code = yield
assert code == B.MnemonicWordCount
debug.input(str(word_count))
# Homescreen - proceed to share entry
yield
debug.press_yes()
# Enter shares
for share in shares:
code = yield
assert code == B.MnemonicInput
# Enter mnemonic words
for word in share.split(" "):
debug.input(word)
# Homescreen - continue
# or Homescreen - confirm success
yield
debug.press_yes()