chore(core): decrease crypto size by 60 bytes

pull/2633/head
grdddj 2 years ago committed by matejcik
parent 8bb73ffebe
commit bd7513f2df

@ -2,8 +2,6 @@
# https://github.com/micropython/micropython-lib/blob/master/base64/base64.py # https://github.com/micropython/micropython-lib/blob/master/base64/base64.py
# #
from ubinascii import unhexlify
from ustruct import unpack
_b32alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567" _b32alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
@ -12,6 +10,8 @@ _b32rev = {ord(v): k for k, v in enumerate(_b32alphabet)}
def encode(s: bytes) -> str: def encode(s: bytes) -> str:
from ustruct import unpack
quanta, leftover = divmod(len(s), 5) quanta, leftover = divmod(len(s), 5)
# Pad the last quantum with zero bits if necessary # Pad the last quantum with zero bits if necessary
if leftover: if leftover:
@ -53,6 +53,8 @@ def encode(s: bytes) -> str:
def decode(s: str) -> bytes: def decode(s: str) -> bytes:
from ubinascii import unhexlify
data = s.encode() data = s.encode()
_, leftover = divmod(len(data), 8) _, leftover = divmod(len(data), 8)
if leftover: if leftover:

@ -20,9 +20,11 @@
"""Reference implementation for Bech32/Bech32m and segwit addresses.""" """Reference implementation for Bech32/Bech32m and segwit addresses."""
from micropython import const
from trezorcrypto import bech32
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from micropython import const bech32_decode = bech32.decode # reexported
if TYPE_CHECKING: if TYPE_CHECKING:
@ -67,7 +69,7 @@ def bech32_hrp_expand(hrp: str) -> list[int]:
return [ord(x) >> 5 for x in hrp] + [0] + [ord(x) & 31 for x in hrp] return [ord(x) >> 5 for x in hrp] + [0] + [ord(x) & 31 for x in hrp]
def bech32_create_checksum(hrp: str, data: list[int], spec: Encoding) -> list[int]: def _bech32_create_checksum(hrp: str, data: list[int], spec: Encoding) -> list[int]:
"""Compute the checksum values given HRP and data.""" """Compute the checksum values given HRP and data."""
values = bech32_hrp_expand(hrp) + data values = bech32_hrp_expand(hrp) + data
const = _BECH32M_CONST if spec == Encoding.BECH32M else 1 const = _BECH32M_CONST if spec == Encoding.BECH32M else 1
@ -77,7 +79,7 @@ def bech32_create_checksum(hrp: str, data: list[int], spec: Encoding) -> list[in
def bech32_encode(hrp: str, data: list[int], spec: Encoding) -> str: def bech32_encode(hrp: str, data: list[int], spec: Encoding) -> str:
"""Compute a Bech32 string given HRP and data values.""" """Compute a Bech32 string given HRP and data values."""
combined = data + bech32_create_checksum(hrp, data, spec) combined = data + _bech32_create_checksum(hrp, data, spec)
return hrp + "1" + "".join([CHARSET[d] for d in combined]) return hrp + "1" + "".join([CHARSET[d] for d in combined])

@ -50,7 +50,7 @@ def prefix_expand(prefix: str) -> list[int]:
return [ord(x) & 0x1F for x in prefix] + [0] return [ord(x) & 0x1F for x in prefix] + [0]
def calculate_checksum(prefix: str, payload: list[int]) -> list[int]: def _calculate_checksum(prefix: str, payload: list[int]) -> list[int]:
poly = cashaddr_polymod(prefix_expand(prefix) + payload + [0, 0, 0, 0, 0, 0, 0, 0]) poly = cashaddr_polymod(prefix_expand(prefix) + payload + [0, 0, 0, 0, 0, 0, 0, 0])
out = [] out = []
for i in range(8): for i in range(8):
@ -58,18 +58,14 @@ def calculate_checksum(prefix: str, payload: list[int]) -> list[int]:
return out return out
def verify_checksum(prefix: str, payload: list[int]) -> bool: def _b32decode(inputs: str) -> list[int]:
return cashaddr_polymod(prefix_expand(prefix) + payload) == 0
def b32decode(inputs: str) -> list[int]:
out = [] out = []
for letter in inputs: for letter in inputs:
out.append(CHARSET.find(letter)) out.append(CHARSET.find(letter))
return out return out
def b32encode(inputs: list[int]) -> str: def _b32encode(inputs: list[int]) -> str:
out = "" out = ""
for char_code in inputs: for char_code in inputs:
out += CHARSET[char_code] out += CHARSET[char_code]
@ -79,14 +75,18 @@ def b32encode(inputs: list[int]) -> str:
def encode(prefix: str, version: int, payload_bytes: bytes) -> str: def encode(prefix: str, version: int, payload_bytes: bytes) -> str:
payload_bytes = bytes([version]) + payload_bytes payload_bytes = bytes([version]) + payload_bytes
payload = convertbits(payload_bytes, 8, 5) payload = convertbits(payload_bytes, 8, 5)
checksum = calculate_checksum(prefix, payload) checksum = _calculate_checksum(prefix, payload)
return prefix + ":" + b32encode(payload + checksum) return prefix + ":" + _b32encode(payload + checksum)
def decode(prefix: str, addr: str) -> tuple[int, bytes]: def decode(prefix: str, addr: str) -> tuple[int, bytes]:
addr = addr.lower() addr = addr.lower()
decoded = b32decode(addr) decoded = _b32decode(addr)
if not verify_checksum(prefix, decoded):
# verify_checksum
checksum_verified = cashaddr_polymod(prefix_expand(prefix) + decoded) == 0
if not checksum_verified:
raise ValueError("Bad cashaddr checksum") raise ValueError("Bad cashaddr checksum")
data = bytes(convertbits(decoded, 5, 8)) data = bytes(convertbits(decoded, 5, 8))
return data[0], data[1:-6] return data[0], data[1:-6]

@ -1,10 +1,9 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor.utils import BufferReader, empty_bytearray
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.utils import Writer from trezor.utils import Writer
from trezor.utils import BufferReader
# Maximum length of a DER-encoded secp256k1 or secp256p1 signature. # Maximum length of a DER-encoded secp256k1 or secp256p1 signature.
_MAX_DER_SIGNATURE_LENGTH = const(72) _MAX_DER_SIGNATURE_LENGTH = const(72)
@ -41,7 +40,7 @@ def read_length(r: BufferReader) -> int:
return n return n
def write_int(w: Writer, number: bytes) -> None: def _write_int(w: Writer, number: bytes) -> None:
i = 0 i = 0
while i < len(number) and number[i] == 0: while i < len(number) and number[i] == 0:
i += 1 i += 1
@ -57,7 +56,9 @@ def write_int(w: Writer, number: bytes) -> None:
w.extend(memoryview(number)[i:]) w.extend(memoryview(number)[i:])
def read_int(r: BufferReader) -> memoryview: def _read_int(r: BufferReader) -> memoryview:
peek = r.peek # local_cache_attribute
if r.get() != 0x02: if r.get() != 0x02:
raise ValueError raise ValueError
@ -65,32 +66,36 @@ def read_int(r: BufferReader) -> memoryview:
if n == 0: if n == 0:
raise ValueError raise ValueError
if r.peek() & 0x80: if peek() & 0x80:
raise ValueError # negative integer raise ValueError # negative integer
if r.peek() == 0x00 and n > 1: if peek() == 0x00 and n > 1:
r.get() r.get()
n -= 1 n -= 1
if r.peek() & 0x80 == 0x00: if peek() & 0x80 == 0x00:
raise ValueError # excessive zero-padding raise ValueError # excessive zero-padding
if r.peek() == 0x00: if peek() == 0x00:
raise ValueError # excessive zero-padding raise ValueError # excessive zero-padding
return r.read_memoryview(n) return r.read_memoryview(n)
def encode_seq(seq: tuple[bytes, ...]) -> bytes: def encode_seq(seq: tuple[bytes, ...]) -> bytes:
from trezor.utils import empty_bytearray
# Preallocate space for a signature, which is all that this function ever encodes. # Preallocate space for a signature, which is all that this function ever encodes.
buffer = empty_bytearray(_MAX_DER_SIGNATURE_LENGTH) buffer = empty_bytearray(_MAX_DER_SIGNATURE_LENGTH)
buffer.append(0x30) buffer.append(0x30)
for i in seq: for i in seq:
write_int(buffer, i) _write_int(buffer, i)
buffer[1:1] = encode_length(len(buffer) - 1) buffer[1:1] = encode_length(len(buffer) - 1)
return buffer return buffer
def decode_seq(data: memoryview) -> list[memoryview]: def decode_seq(data: memoryview) -> list[memoryview]:
from trezor.utils import BufferReader
r = BufferReader(data) r = BufferReader(data)
if r.get() != 0x30: if r.get() != 0x30:
@ -100,7 +105,7 @@ def decode_seq(data: memoryview) -> list[memoryview]:
seq = [] seq = []
end = r.offset + n end = r.offset + n
while r.offset < end: while r.offset < end:
i = read_int(r) i = _read_int(r)
seq.append(i) seq.append(i)
if r.offset != end or r.remaining_count(): if r.offset != end or r.remaining_count():

@ -80,12 +80,12 @@ def length(item: RLPItem) -> int:
return header_length(item_length, data) + item_length return header_length(item_length, data) + item_length
def write_string(w: Writer, string: bytes) -> None: def _write_string(w: Writer, string: bytes) -> None:
write_header(w, len(string), STRING_HEADER_BYTE, string) write_header(w, len(string), STRING_HEADER_BYTE, string)
w.extend(string) w.extend(string)
def write_list(w: Writer, lst: RLPList) -> None: def _write_list(w: Writer, lst: RLPList) -> None:
payload_length = sum(length(item) for item in lst) payload_length = sum(length(item) for item in lst)
write_header(w, payload_length, LIST_HEADER_BYTE) write_header(w, payload_length, LIST_HEADER_BYTE)
for item in lst: for item in lst:
@ -94,10 +94,10 @@ def write_list(w: Writer, lst: RLPList) -> None:
def write(w: Writer, item: RLPItem) -> None: def write(w: Writer, item: RLPItem) -> None:
if isinstance(item, int): if isinstance(item, int):
write_string(w, int_to_bytes(item)) _write_string(w, int_to_bytes(item))
elif isinstance(item, (bytes, bytearray)): elif isinstance(item, (bytes, bytearray)):
write_string(w, item) _write_string(w, item)
elif isinstance(item, list): elif isinstance(item, list):
write_list(w, item) _write_list(w, item)
else: else:
raise TypeError raise TypeError

@ -34,7 +34,7 @@ from micropython import const
from trezorcrypto import shamir, slip39 from trezorcrypto import shamir, slip39
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor.crypto import hmac, pbkdf2, random from trezor.crypto import random
from trezor.errors import MnemonicError from trezor.errors import MnemonicError
if TYPE_CHECKING: if TYPE_CHECKING:
@ -416,6 +416,8 @@ def _rs1024_verify_checksum(data: Indices) -> bool:
def _round_function(i: int, passphrase: bytes, e: int, salt: bytes, r: bytes) -> bytes: def _round_function(i: int, passphrase: bytes, e: int, salt: bytes, r: bytes) -> bytes:
"""The round function used internally by the Feistel cipher.""" """The round function used internally by the Feistel cipher."""
from trezor.crypto import pbkdf2
return pbkdf2( return pbkdf2(
pbkdf2.HMAC_SHA256, pbkdf2.HMAC_SHA256,
bytes([i]) + passphrase, bytes([i]) + passphrase,
@ -431,6 +433,8 @@ def _get_salt(identifier: int) -> bytes:
def _create_digest(random_data: bytes, shared_secret: bytes) -> bytes: def _create_digest(random_data: bytes, shared_secret: bytes) -> bytes:
from trezor.crypto import hmac
return hmac(hmac.SHA256, random_data, shared_secret).digest()[:_DIGEST_LENGTH_BYTES] return hmac(hmac.SHA256, random_data, shared_secret).digest()[:_DIGEST_LENGTH_BYTES]

@ -3,38 +3,6 @@ from trezor.crypto import slip39, random
from slip39_vectors import vectors from slip39_vectors import vectors
# NOTE: moved into tests not to occupy flash space
# in firmware binary, when it is not used in production
def _rs1024_error_index(data: tuple[int, ...]) -> int | None:
"""
Returns the index where an error possibly occurred.
"""
GEN = (
0x91F_9F87,
0x122F_1F07,
0x244E_1E07,
0x81C_1C07,
0x1028_1C0E,
0x2040_1C1C,
0x10_3838,
0x20_7070,
0x40_E0E0,
0x81_C1C0,
)
chk = slip39._rs1024_polymod(tuple(slip39._CUSTOMIZATION_STRING) + data) ^ 1
if chk == 0:
return None
for i in reversed(range(len(data))):
b = chk & 0x3FF
chk >>= 10
if chk == 0:
return i
for j in range(10):
chk ^= GEN[j] if ((b >> j) & 1) else 0
return None
def combinations(iterable, r): def combinations(iterable, r):
# Taken from https://docs.python.org/3.7/library/itertools.html#itertools.combinations # Taken from https://docs.python.org/3.7/library/itertools.html#itertools.combinations
pool = tuple(iterable) pool = tuple(iterable)
@ -186,19 +154,5 @@ class TestCryptoSlip39(unittest.TestCase):
slip39.recover_ems(mnemonics) slip39.recover_ems(mnemonics)
def test_error_location(self):
mnemonics = [
"duckling enlarge academic academic agency result length solution fridge kidney coal piece deal husband erode duke ajar critical decision keyboard",
"theory painting academic academic armed sweater year military elder discuss acne wildlife boring employer fused large satoshi bundle carbon diagnose anatomy hamster leaves tracks paces beyond phantom capital marvel lips brave detect luck",
]
for mnemonic in mnemonics:
data = tuple(slip39._mnemonic_to_indices(mnemonic))
self.assertEqual(_rs1024_error_index(data), None)
for i in range(len(data)):
for _ in range(50):
error_data = error_data = data[:i] + (data[i] ^ (random.uniform(1023) + 1), ) + data[i + 1:]
self.assertEqual(_rs1024_error_index(error_data), i)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

Loading…
Cancel
Save