1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-17 10:51:00 +00:00

chore(core): decrease crypto size by 60 bytes

This commit is contained in:
grdddj 2022-09-21 11:16:38 +02:00 committed by matejcik
parent 8bb73ffebe
commit bd7513f2df
7 changed files with 45 additions and 78 deletions

View File

@ -2,8 +2,6 @@
# https://github.com/micropython/micropython-lib/blob/master/base64/base64.py
#
from ubinascii import unhexlify
from ustruct import unpack
_b32alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
@ -12,6 +10,8 @@ _b32rev = {ord(v): k for k, v in enumerate(_b32alphabet)}
def encode(s: bytes) -> str:
from ustruct import unpack
quanta, leftover = divmod(len(s), 5)
# Pad the last quantum with zero bits if necessary
if leftover:
@ -53,6 +53,8 @@ def encode(s: bytes) -> str:
def decode(s: str) -> bytes:
from ubinascii import unhexlify
data = s.encode()
_, leftover = divmod(len(data), 8)
if leftover:

View File

@ -20,9 +20,11 @@
"""Reference implementation for Bech32/Bech32m and segwit addresses."""
from micropython import const
from trezorcrypto import bech32
from typing import TYPE_CHECKING
from micropython import const
bech32_decode = bech32.decode # reexported
if TYPE_CHECKING:
@ -67,7 +69,7 @@ def bech32_hrp_expand(hrp: str) -> list[int]:
return [ord(x) >> 5 for x in hrp] + [0] + [ord(x) & 31 for x in hrp]
def bech32_create_checksum(hrp: str, data: list[int], spec: Encoding) -> list[int]:
def _bech32_create_checksum(hrp: str, data: list[int], spec: Encoding) -> list[int]:
"""Compute the checksum values given HRP and data."""
values = bech32_hrp_expand(hrp) + data
const = _BECH32M_CONST if spec == Encoding.BECH32M else 1
@ -77,7 +79,7 @@ def bech32_create_checksum(hrp: str, data: list[int], spec: Encoding) -> list[in
def bech32_encode(hrp: str, data: list[int], spec: Encoding) -> str:
"""Compute a Bech32 string given HRP and data values."""
combined = data + bech32_create_checksum(hrp, data, spec)
combined = data + _bech32_create_checksum(hrp, data, spec)
return hrp + "1" + "".join([CHARSET[d] for d in combined])

View File

@ -50,7 +50,7 @@ def prefix_expand(prefix: str) -> list[int]:
return [ord(x) & 0x1F for x in prefix] + [0]
def calculate_checksum(prefix: str, payload: list[int]) -> list[int]:
def _calculate_checksum(prefix: str, payload: list[int]) -> list[int]:
poly = cashaddr_polymod(prefix_expand(prefix) + payload + [0, 0, 0, 0, 0, 0, 0, 0])
out = []
for i in range(8):
@ -58,18 +58,14 @@ def calculate_checksum(prefix: str, payload: list[int]) -> list[int]:
return out
def verify_checksum(prefix: str, payload: list[int]) -> bool:
return cashaddr_polymod(prefix_expand(prefix) + payload) == 0
def b32decode(inputs: str) -> list[int]:
def _b32decode(inputs: str) -> list[int]:
out = []
for letter in inputs:
out.append(CHARSET.find(letter))
return out
def b32encode(inputs: list[int]) -> str:
def _b32encode(inputs: list[int]) -> str:
out = ""
for char_code in inputs:
out += CHARSET[char_code]
@ -79,14 +75,18 @@ def b32encode(inputs: list[int]) -> str:
def encode(prefix: str, version: int, payload_bytes: bytes) -> str:
payload_bytes = bytes([version]) + payload_bytes
payload = convertbits(payload_bytes, 8, 5)
checksum = calculate_checksum(prefix, payload)
return prefix + ":" + b32encode(payload + checksum)
checksum = _calculate_checksum(prefix, payload)
return prefix + ":" + _b32encode(payload + checksum)
def decode(prefix: str, addr: str) -> tuple[int, bytes]:
addr = addr.lower()
decoded = b32decode(addr)
if not verify_checksum(prefix, decoded):
decoded = _b32decode(addr)
# verify_checksum
checksum_verified = cashaddr_polymod(prefix_expand(prefix) + decoded) == 0
if not checksum_verified:
raise ValueError("Bad cashaddr checksum")
data = bytes(convertbits(decoded, 5, 8))
return data[0], data[1:-6]

View File

@ -1,10 +1,9 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor.utils import BufferReader, empty_bytearray
if TYPE_CHECKING:
from trezor.utils import Writer
from trezor.utils import BufferReader
# Maximum length of a DER-encoded secp256k1 or secp256p1 signature.
_MAX_DER_SIGNATURE_LENGTH = const(72)
@ -41,7 +40,7 @@ def read_length(r: BufferReader) -> int:
return n
def write_int(w: Writer, number: bytes) -> None:
def _write_int(w: Writer, number: bytes) -> None:
i = 0
while i < len(number) and number[i] == 0:
i += 1
@ -57,7 +56,9 @@ def write_int(w: Writer, number: bytes) -> None:
w.extend(memoryview(number)[i:])
def read_int(r: BufferReader) -> memoryview:
def _read_int(r: BufferReader) -> memoryview:
peek = r.peek # local_cache_attribute
if r.get() != 0x02:
raise ValueError
@ -65,32 +66,36 @@ def read_int(r: BufferReader) -> memoryview:
if n == 0:
raise ValueError
if r.peek() & 0x80:
if peek() & 0x80:
raise ValueError # negative integer
if r.peek() == 0x00 and n > 1:
if peek() == 0x00 and n > 1:
r.get()
n -= 1
if r.peek() & 0x80 == 0x00:
if peek() & 0x80 == 0x00:
raise ValueError # excessive zero-padding
if r.peek() == 0x00:
if peek() == 0x00:
raise ValueError # excessive zero-padding
return r.read_memoryview(n)
def encode_seq(seq: tuple[bytes, ...]) -> bytes:
from trezor.utils import empty_bytearray
# Preallocate space for a signature, which is all that this function ever encodes.
buffer = empty_bytearray(_MAX_DER_SIGNATURE_LENGTH)
buffer.append(0x30)
for i in seq:
write_int(buffer, i)
_write_int(buffer, i)
buffer[1:1] = encode_length(len(buffer) - 1)
return buffer
def decode_seq(data: memoryview) -> list[memoryview]:
from trezor.utils import BufferReader
r = BufferReader(data)
if r.get() != 0x30:
@ -100,7 +105,7 @@ def decode_seq(data: memoryview) -> list[memoryview]:
seq = []
end = r.offset + n
while r.offset < end:
i = read_int(r)
i = _read_int(r)
seq.append(i)
if r.offset != end or r.remaining_count():

View File

@ -80,12 +80,12 @@ def length(item: RLPItem) -> int:
return header_length(item_length, data) + item_length
def write_string(w: Writer, string: bytes) -> None:
def _write_string(w: Writer, string: bytes) -> None:
write_header(w, len(string), STRING_HEADER_BYTE, string)
w.extend(string)
def write_list(w: Writer, lst: RLPList) -> None:
def _write_list(w: Writer, lst: RLPList) -> None:
payload_length = sum(length(item) for item in lst)
write_header(w, payload_length, LIST_HEADER_BYTE)
for item in lst:
@ -94,10 +94,10 @@ def write_list(w: Writer, lst: RLPList) -> None:
def write(w: Writer, item: RLPItem) -> None:
if isinstance(item, int):
write_string(w, int_to_bytes(item))
_write_string(w, int_to_bytes(item))
elif isinstance(item, (bytes, bytearray)):
write_string(w, item)
_write_string(w, item)
elif isinstance(item, list):
write_list(w, item)
_write_list(w, item)
else:
raise TypeError

View File

@ -34,7 +34,7 @@ from micropython import const
from trezorcrypto import shamir, slip39
from typing import TYPE_CHECKING
from trezor.crypto import hmac, pbkdf2, random
from trezor.crypto import random
from trezor.errors import MnemonicError
if TYPE_CHECKING:
@ -416,6 +416,8 @@ def _rs1024_verify_checksum(data: Indices) -> bool:
def _round_function(i: int, passphrase: bytes, e: int, salt: bytes, r: bytes) -> bytes:
"""The round function used internally by the Feistel cipher."""
from trezor.crypto import pbkdf2
return pbkdf2(
pbkdf2.HMAC_SHA256,
bytes([i]) + passphrase,
@ -431,6 +433,8 @@ def _get_salt(identifier: int) -> bytes:
def _create_digest(random_data: bytes, shared_secret: bytes) -> bytes:
from trezor.crypto import hmac
return hmac(hmac.SHA256, random_data, shared_secret).digest()[:_DIGEST_LENGTH_BYTES]

View File

@ -3,38 +3,6 @@ from trezor.crypto import slip39, random
from slip39_vectors import vectors
# NOTE: moved into tests not to occupy flash space
# in firmware binary, when it is not used in production
def _rs1024_error_index(data: tuple[int, ...]) -> int | None:
"""
Returns the index where an error possibly occurred.
"""
GEN = (
0x91F_9F87,
0x122F_1F07,
0x244E_1E07,
0x81C_1C07,
0x1028_1C0E,
0x2040_1C1C,
0x10_3838,
0x20_7070,
0x40_E0E0,
0x81_C1C0,
)
chk = slip39._rs1024_polymod(tuple(slip39._CUSTOMIZATION_STRING) + data) ^ 1
if chk == 0:
return None
for i in reversed(range(len(data))):
b = chk & 0x3FF
chk >>= 10
if chk == 0:
return i
for j in range(10):
chk ^= GEN[j] if ((b >> j) & 1) else 0
return None
def combinations(iterable, r):
# Taken from https://docs.python.org/3.7/library/itertools.html#itertools.combinations
pool = tuple(iterable)
@ -186,19 +154,5 @@ class TestCryptoSlip39(unittest.TestCase):
slip39.recover_ems(mnemonics)
def test_error_location(self):
mnemonics = [
"duckling enlarge academic academic agency result length solution fridge kidney coal piece deal husband erode duke ajar critical decision keyboard",
"theory painting academic academic armed sweater year military elder discuss acne wildlife boring employer fused large satoshi bundle carbon diagnose anatomy hamster leaves tracks paces beyond phantom capital marvel lips brave detect luck",
]
for mnemonic in mnemonics:
data = tuple(slip39._mnemonic_to_indices(mnemonic))
self.assertEqual(_rs1024_error_index(data), None)
for i in range(len(data)):
for _ in range(50):
error_data = error_data = data[:i] + (data[i] ^ (random.uniform(1023) + 1), ) + data[i + 1:]
self.assertEqual(_rs1024_error_index(error_data), i)
if __name__ == '__main__':
unittest.main()