1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-28 17:18:29 +00:00

core: use const in trezor.crypto.slip39

This commit is contained in:
Jan Pochyla 2019-05-28 11:12:27 +02:00
parent b89d1db0e4
commit 09da577fd9

View File

@ -18,13 +18,14 @@
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# #
from micropython import const
from trezor.crypto import hashlib, hmac, pbkdf2, random from trezor.crypto import hashlib, hmac, pbkdf2, random
from trezor.crypto.slip39_wordlist import wordlist from trezor.crypto.slip39_wordlist import wordlist
from trezorcrypto import shamir from trezorcrypto import shamir
_RADIX_BITS = const(10)
class MnemonicError(Exception): """The length of the radix in bits."""
pass
def bits_to_bytes(n): def bits_to_bytes(n):
@ -32,62 +33,63 @@ def bits_to_bytes(n):
def bits_to_words(n): def bits_to_words(n):
return (n + RADIX_BITS - 1) // RADIX_BITS return (n + _RADIX_BITS - 1) // _RADIX_BITS
RADIX_BITS = 10 _RADIX = 2 ** _RADIX_BITS
"""The length of the radix in bits."""
RADIX = 2 ** RADIX_BITS
"""The number of words in the wordlist.""" """The number of words in the wordlist."""
ID_LENGTH_BITS = 15 _ID_LENGTH_BITS = const(15)
"""The length of the random identifier in bits.""" """The length of the random identifier in bits."""
ITERATION_EXP_LENGTH_BITS = 5 _ITERATION_EXP_LENGTH_BITS = const(5)
"""The length of the iteration exponent in bits.""" """The length of the iteration exponent in bits."""
ID_EXP_LENGTH_WORDS = bits_to_words(ID_LENGTH_BITS + ITERATION_EXP_LENGTH_BITS) _ID_EXP_LENGTH_WORDS = bits_to_words(_ID_LENGTH_BITS + _ITERATION_EXP_LENGTH_BITS)
"""The length of the random identifier and iteration exponent in words.""" """The length of the random identifier and iteration exponent in words."""
MAX_SHARE_COUNT = 16 _MAX_SHARE_COUNT = const(16)
"""The maximum number of shares that can be created.""" """The maximum number of shares that can be created."""
CHECKSUM_LENGTH_WORDS = 3 _CHECKSUM_LENGTH_WORDS = const(3)
"""The length of the RS1024 checksum in words.""" """The length of the RS1024 checksum in words."""
DIGEST_LENGTH_BYTES = 4 _DIGEST_LENGTH_BYTES = const(4)
"""The length of the digest of the shared secret in bytes.""" """The length of the digest of the shared secret in bytes."""
CUSTOMIZATION_STRING = b"shamir" _CUSTOMIZATION_STRING = b"shamir"
"""The customization string used in the RS1024 checksum and in the PBKDF2 salt.""" """The customization string used in the RS1024 checksum and in the PBKDF2 salt."""
METADATA_LENGTH_WORDS = ID_EXP_LENGTH_WORDS + 2 + CHECKSUM_LENGTH_WORDS _METADATA_LENGTH_WORDS = _ID_EXP_LENGTH_WORDS + 2 + _CHECKSUM_LENGTH_WORDS
"""The length of the mnemonic in words without the share value.""" """The length of the mnemonic in words without the share value."""
MIN_STRENGTH_BITS = 128 _MIN_STRENGTH_BITS = const(128)
"""The minimum allowed entropy of the master secret.""" """The minimum allowed entropy of the master secret."""
MIN_MNEMONIC_LENGTH_WORDS = METADATA_LENGTH_WORDS + bits_to_words(MIN_STRENGTH_BITS) _MIN_MNEMONIC_LENGTH_WORDS = _METADATA_LENGTH_WORDS + bits_to_words(_MIN_STRENGTH_BITS)
"""The minimum allowed length of the mnemonic in words.""" """The minimum allowed length of the mnemonic in words."""
BASE_ITERATION_COUNT = 10000 _BASE_ITERATION_COUNT = const(10000)
"""The minimum number of iterations to use in PBKDF2.""" """The minimum number of iterations to use in PBKDF2."""
ROUND_COUNT = 4 _ROUND_COUNT = const(4)
"""The number of rounds to use in the Feistel cipher.""" """The number of rounds to use in the Feistel cipher."""
SECRET_INDEX = 255 _SECRET_INDEX = const(255)
"""The index of the share containing the shared secret.""" """The index of the share containing the shared secret."""
DIGEST_INDEX = 254 _DIGEST_INDEX = const(254)
"""The index of the share containing the digest of the shared secret.""" """The index of the share containing the digest of the shared secret."""
class MnemonicError(Exception):
pass
def word_index(word): def word_index(word):
word = word + " " * (8 - len(word)) word = word + " " * (8 - len(word))
lo = 0 lo = 0
hi = RADIX hi = _RADIX
while hi - lo > 1: while hi - lo > 1:
mid = (hi + lo) // 2 mid = (hi + lo) // 2
if wordlist[mid * 8 : mid * 8 + 8] > word: if wordlist[mid * 8 : mid * 8 + 8] > word:
@ -122,15 +124,15 @@ def _rs1024_polymod(values):
def rs1024_create_checksum(data): def rs1024_create_checksum(data):
values = tuple(CUSTOMIZATION_STRING) + data + CHECKSUM_LENGTH_WORDS * (0,) values = tuple(_CUSTOMIZATION_STRING) + data + _CHECKSUM_LENGTH_WORDS * (0,)
polymod = _rs1024_polymod(values) ^ 1 polymod = _rs1024_polymod(values) ^ 1
return tuple( return tuple(
(polymod >> 10 * i) & 1023 for i in reversed(range(CHECKSUM_LENGTH_WORDS)) (polymod >> 10 * i) & 1023 for i in reversed(range(_CHECKSUM_LENGTH_WORDS))
) )
def rs1024_verify_checksum(data): def rs1024_verify_checksum(data):
return _rs1024_polymod(tuple(CUSTOMIZATION_STRING) + data) == 1 return _rs1024_polymod(tuple(_CUSTOMIZATION_STRING) + data) == 1
def xor(a, b): def xor(a, b):
@ -141,7 +143,7 @@ def _int_from_indices(indices):
"""Converts a list of base 1024 indices in big endian order to an integer value.""" """Converts a list of base 1024 indices in big endian order to an integer value."""
value = 0 value = 0
for index in indices: for index in indices:
value = (value << RADIX_BITS) + index value = (value << _RADIX_BITS) + index
return value return value
@ -165,13 +167,13 @@ def _round_function(i, passphrase, e, salt, r):
pbkdf2.HMAC_SHA256, pbkdf2.HMAC_SHA256,
bytes([i]) + passphrase, bytes([i]) + passphrase,
salt + r, salt + r,
(BASE_ITERATION_COUNT << e) // ROUND_COUNT, (_BASE_ITERATION_COUNT << e) // _ROUND_COUNT,
).key()[: len(r)] ).key()[: len(r)]
def _get_salt(identifier): def _get_salt(identifier):
return CUSTOMIZATION_STRING + identifier.to_bytes( return _CUSTOMIZATION_STRING + identifier.to_bytes(
bits_to_bytes(ID_LENGTH_BITS), "big" bits_to_bytes(_ID_LENGTH_BITS), "big"
) )
@ -179,7 +181,7 @@ def _encrypt(master_secret, passphrase, iteration_exponent, identifier):
l = master_secret[: len(master_secret) // 2] l = master_secret[: len(master_secret) // 2]
r = master_secret[len(master_secret) // 2 :] r = master_secret[len(master_secret) // 2 :]
salt = _get_salt(identifier) salt = _get_salt(identifier)
for i in range(ROUND_COUNT): for i in range(_ROUND_COUNT):
(l, r) = ( (l, r) = (
r, r,
xor(l, _round_function(i, passphrase, iteration_exponent, salt, r)), xor(l, _round_function(i, passphrase, iteration_exponent, salt, r)),
@ -191,7 +193,7 @@ def decrypt(identifier, iteration_exponent, encrypted_master_secret, passphrase)
l = encrypted_master_secret[: len(encrypted_master_secret) // 2] l = encrypted_master_secret[: len(encrypted_master_secret) // 2]
r = encrypted_master_secret[len(encrypted_master_secret) // 2 :] r = encrypted_master_secret[len(encrypted_master_secret) // 2 :]
salt = _get_salt(identifier) salt = _get_salt(identifier)
for i in reversed(range(ROUND_COUNT)): for i in reversed(range(_ROUND_COUNT)):
(l, r) = ( (l, r) = (
r, r,
xor(l, _round_function(i, passphrase, iteration_exponent, salt, r)), xor(l, _round_function(i, passphrase, iteration_exponent, salt, r)),
@ -201,7 +203,7 @@ def decrypt(identifier, iteration_exponent, encrypted_master_secret, passphrase)
def _create_digest(random_data, shared_secret): def _create_digest(random_data, shared_secret):
return hmac.new(random_data, shared_secret, hashlib.sha256).digest()[ return hmac.new(random_data, shared_secret, hashlib.sha256).digest()[
:DIGEST_LENGTH_BYTES :_DIGEST_LENGTH_BYTES
] ]
@ -218,10 +220,10 @@ def _split_secret(threshold, share_count, shared_secret):
) )
) )
if share_count > MAX_SHARE_COUNT: if share_count > _MAX_SHARE_COUNT:
raise ValueError( raise ValueError(
"The requested number of shares ({}) must not exceed {}.".format( "The requested number of shares ({}) must not exceed {}.".format(
share_count, MAX_SHARE_COUNT share_count, _MAX_SHARE_COUNT
) )
) )
@ -233,12 +235,12 @@ def _split_secret(threshold, share_count, shared_secret):
shares = [(i, random.bytes(len(shared_secret))) for i in range(random_share_count)] shares = [(i, random.bytes(len(shared_secret))) for i in range(random_share_count)]
random_part = random.bytes(len(shared_secret) - DIGEST_LENGTH_BYTES) random_part = random.bytes(len(shared_secret) - _DIGEST_LENGTH_BYTES)
digest = _create_digest(random_part, shared_secret) digest = _create_digest(random_part, shared_secret)
base_shares = shares + [ base_shares = shares + [
(DIGEST_INDEX, digest + random_part), (_DIGEST_INDEX, digest + random_part),
(SECRET_INDEX, shared_secret), (_SECRET_INDEX, shared_secret),
] ]
for i in range(random_share_count, share_count): for i in range(random_share_count, share_count):
@ -252,10 +254,10 @@ def _recover_secret(threshold, shares):
if threshold == 1: if threshold == 1:
return shares[0][1] return shares[0][1]
shared_secret = shamir.interpolate(shares, SECRET_INDEX) shared_secret = shamir.interpolate(shares, _SECRET_INDEX)
digest_share = shamir.interpolate(shares, DIGEST_INDEX) digest_share = shamir.interpolate(shares, _DIGEST_INDEX)
digest = digest_share[:DIGEST_LENGTH_BYTES] digest = digest_share[:_DIGEST_LENGTH_BYTES]
random_part = digest_share[DIGEST_LENGTH_BYTES:] random_part = digest_share[_DIGEST_LENGTH_BYTES:]
if digest != _create_digest(random_part, shared_secret): if digest != _create_digest(random_part, shared_secret):
raise MnemonicError("Invalid digest of the shared secret.") raise MnemonicError("Invalid digest of the shared secret.")
@ -266,8 +268,8 @@ def _recover_secret(threshold, shares):
def _group_prefix( def _group_prefix(
identifier, iteration_exponent, group_index, group_threshold, group_count identifier, iteration_exponent, group_index, group_threshold, group_count
): ):
id_exp_int = (identifier << ITERATION_EXP_LENGTH_BITS) + iteration_exponent id_exp_int = (identifier << _ITERATION_EXP_LENGTH_BITS) + iteration_exponent
return tuple(_int_to_indices(id_exp_int, ID_EXP_LENGTH_WORDS, RADIX_BITS)) + ( return tuple(_int_to_indices(id_exp_int, _ID_EXP_LENGTH_WORDS, _RADIX_BITS)) + (
(group_index << 6) + ((group_threshold - 1) << 2) + ((group_count - 1) >> 2), (group_index << 6) + ((group_threshold - 1) << 2) + ((group_count - 1) >> 2),
) )
@ -310,7 +312,7 @@ def encode_mnemonic(
+ (member_index << 4) + (member_index << 4)
+ (member_threshold - 1), + (member_threshold - 1),
) )
+ tuple(_int_to_indices(value_int, value_word_count, RADIX_BITS)) + tuple(_int_to_indices(value_int, value_word_count, _RADIX_BITS))
) )
checksum = rs1024_create_checksum(share_data) checksum = rs1024_create_checksum(share_data)
@ -322,48 +324,48 @@ def decode_mnemonic(mnemonic):
mnemonic_data = tuple(mnemonic_to_indices(mnemonic)) mnemonic_data = tuple(mnemonic_to_indices(mnemonic))
if len(mnemonic_data) < MIN_MNEMONIC_LENGTH_WORDS: if len(mnemonic_data) < _MIN_MNEMONIC_LENGTH_WORDS:
raise MnemonicError( raise MnemonicError(
"Invalid mnemonic length. The length of each mnemonic must be at least {} words.".format( "Invalid mnemonic length. The length of each mnemonic must be at least {} words.".format(
MIN_MNEMONIC_LENGTH_WORDS _MIN_MNEMONIC_LENGTH_WORDS
) )
) )
padding_len = (RADIX_BITS * (len(mnemonic_data) - METADATA_LENGTH_WORDS)) % 16 padding_len = (_RADIX_BITS * (len(mnemonic_data) - _METADATA_LENGTH_WORDS)) % 16
if padding_len > 8: if padding_len > 8:
raise MnemonicError("Invalid mnemonic length.") raise MnemonicError("Invalid mnemonic length.")
if not rs1024_verify_checksum(mnemonic_data): if not rs1024_verify_checksum(mnemonic_data):
raise MnemonicError( raise MnemonicError(
'Invalid mnemonic checksum for "{} ...".'.format( 'Invalid mnemonic checksum for "{} ...".'.format(
" ".join(mnemonic.split()[: ID_EXP_LENGTH_WORDS + 2]) " ".join(mnemonic.split()[: _ID_EXP_LENGTH_WORDS + 2])
) )
) )
id_exp_int = _int_from_indices(mnemonic_data[:ID_EXP_LENGTH_WORDS]) id_exp_int = _int_from_indices(mnemonic_data[:_ID_EXP_LENGTH_WORDS])
identifier = id_exp_int >> ITERATION_EXP_LENGTH_BITS identifier = id_exp_int >> _ITERATION_EXP_LENGTH_BITS
iteration_exponent = id_exp_int & ((1 << ITERATION_EXP_LENGTH_BITS) - 1) iteration_exponent = id_exp_int & ((1 << _ITERATION_EXP_LENGTH_BITS) - 1)
tmp = _int_from_indices( tmp = _int_from_indices(
mnemonic_data[ID_EXP_LENGTH_WORDS : ID_EXP_LENGTH_WORDS + 2] mnemonic_data[_ID_EXP_LENGTH_WORDS : _ID_EXP_LENGTH_WORDS + 2]
) )
group_index, group_threshold, group_count, member_index, member_threshold = _int_to_indices( group_index, group_threshold, group_count, member_index, member_threshold = _int_to_indices(
tmp, 5, 4 tmp, 5, 4
) )
value_data = mnemonic_data[ID_EXP_LENGTH_WORDS + 2 : -CHECKSUM_LENGTH_WORDS] value_data = mnemonic_data[_ID_EXP_LENGTH_WORDS + 2 : -_CHECKSUM_LENGTH_WORDS]
if group_count < group_threshold: if group_count < group_threshold:
raise MnemonicError( raise MnemonicError(
'Invalid mnemonic "{} ...". Group threshold cannot be greater than group count.'.format( 'Invalid mnemonic "{} ...". Group threshold cannot be greater than group count.'.format(
" ".join(mnemonic.split()[: ID_EXP_LENGTH_WORDS + 2]) " ".join(mnemonic.split()[: _ID_EXP_LENGTH_WORDS + 2])
) )
) )
value_byte_count = bits_to_bytes(RADIX_BITS * len(value_data) - padding_len) value_byte_count = bits_to_bytes(_RADIX_BITS * len(value_data) - padding_len)
value_int = _int_from_indices(value_data) value_int = _int_from_indices(value_data)
if value_data[0] >= 1 << (RADIX_BITS - padding_len): if value_data[0] >= 1 << (_RADIX_BITS - padding_len):
raise MnemonicError( raise MnemonicError(
'Invalid mnemonic padding for "{} ...".'.format( 'Invalid mnemonic padding for "{} ...".'.format(
" ".join(mnemonic.split()[: ID_EXP_LENGTH_WORDS + 2]) " ".join(mnemonic.split()[: _ID_EXP_LENGTH_WORDS + 2])
) )
) )
value = value_int.to_bytes(value_byte_count, "big") value = value_int.to_bytes(value_byte_count, "big")
@ -404,7 +406,7 @@ def _decode_mnemonics(mnemonics):
if len(identifiers) != 1 or len(iteration_exponents) != 1: if len(identifiers) != 1 or len(iteration_exponents) != 1:
raise MnemonicError( raise MnemonicError(
"Invalid set of mnemonics. All mnemonics must begin with the same {} words.".format( "Invalid set of mnemonics. All mnemonics must begin with the same {} words.".format(
ID_EXP_LENGTH_WORDS _ID_EXP_LENGTH_WORDS
) )
) )
@ -434,10 +436,10 @@ def _decode_mnemonics(mnemonics):
def _generate_random_identifier(): def _generate_random_identifier():
"""Returns a randomly generated integer in the range 0, ... , 2**ID_LENGTH_BITS - 1.""" """Returns a randomly generated integer in the range 0, ... , 2**_ID_LENGTH_BITS - 1."""
identifier = int.from_bytes(random.bytes(bits_to_bytes(ID_LENGTH_BITS)), "big") identifier = int.from_bytes(random.bytes(bits_to_bytes(_ID_LENGTH_BITS)), "big")
return identifier & ((1 << ID_LENGTH_BITS) - 1) return identifier & ((1 << _ID_LENGTH_BITS) - 1)
def generate_mnemonics( def generate_mnemonics(
@ -461,10 +463,10 @@ def generate_mnemonics(
identifier = _generate_random_identifier() identifier = _generate_random_identifier()
if len(master_secret) * 8 < MIN_STRENGTH_BITS: if len(master_secret) * 8 < _MIN_STRENGTH_BITS:
raise ValueError( raise ValueError(
"The length of the master secret ({} bytes) must be at least {} bytes.".format( "The length of the master secret ({} bytes) must be at least {} bytes.".format(
len(master_secret), bits_to_bytes(MIN_STRENGTH_BITS) len(master_secret), bits_to_bytes(_MIN_STRENGTH_BITS)
) )
) )
@ -540,10 +542,10 @@ def generate_mnemonics_random(
:rtype: List of byte arrays. :rtype: List of byte arrays.
""" """
if strength_bits < MIN_STRENGTH_BITS: if strength_bits < _MIN_STRENGTH_BITS:
raise ValueError( raise ValueError(
"The requested strength of the master secret ({} bits) must be at least {} bits.".format( "The requested strength of the master secret ({} bits) must be at least {} bits.".format(
strength_bits, MIN_STRENGTH_BITS strength_bits, _MIN_STRENGTH_BITS
) )
) )