mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-18 04:18:10 +00:00
chore(core): decrease crypto size by 60 bytes
This commit is contained in:
parent
8bb73ffebe
commit
bd7513f2df
@ -2,8 +2,6 @@
|
|||||||
# https://github.com/micropython/micropython-lib/blob/master/base64/base64.py
|
# https://github.com/micropython/micropython-lib/blob/master/base64/base64.py
|
||||||
#
|
#
|
||||||
|
|
||||||
from ubinascii import unhexlify
|
|
||||||
from ustruct import unpack
|
|
||||||
|
|
||||||
_b32alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
|
_b32alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"
|
||||||
|
|
||||||
@ -12,6 +10,8 @@ _b32rev = {ord(v): k for k, v in enumerate(_b32alphabet)}
|
|||||||
|
|
||||||
|
|
||||||
def encode(s: bytes) -> str:
|
def encode(s: bytes) -> str:
|
||||||
|
from ustruct import unpack
|
||||||
|
|
||||||
quanta, leftover = divmod(len(s), 5)
|
quanta, leftover = divmod(len(s), 5)
|
||||||
# Pad the last quantum with zero bits if necessary
|
# Pad the last quantum with zero bits if necessary
|
||||||
if leftover:
|
if leftover:
|
||||||
@ -53,6 +53,8 @@ def encode(s: bytes) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def decode(s: str) -> bytes:
|
def decode(s: str) -> bytes:
|
||||||
|
from ubinascii import unhexlify
|
||||||
|
|
||||||
data = s.encode()
|
data = s.encode()
|
||||||
_, leftover = divmod(len(data), 8)
|
_, leftover = divmod(len(data), 8)
|
||||||
if leftover:
|
if leftover:
|
||||||
|
@ -20,9 +20,11 @@
|
|||||||
|
|
||||||
"""Reference implementation for Bech32/Bech32m and segwit addresses."""
|
"""Reference implementation for Bech32/Bech32m and segwit addresses."""
|
||||||
|
|
||||||
|
from micropython import const
|
||||||
|
from trezorcrypto import bech32
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from micropython import const
|
bech32_decode = bech32.decode # reexported
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -67,7 +69,7 @@ def bech32_hrp_expand(hrp: str) -> list[int]:
|
|||||||
return [ord(x) >> 5 for x in hrp] + [0] + [ord(x) & 31 for x in hrp]
|
return [ord(x) >> 5 for x in hrp] + [0] + [ord(x) & 31 for x in hrp]
|
||||||
|
|
||||||
|
|
||||||
def bech32_create_checksum(hrp: str, data: list[int], spec: Encoding) -> list[int]:
|
def _bech32_create_checksum(hrp: str, data: list[int], spec: Encoding) -> list[int]:
|
||||||
"""Compute the checksum values given HRP and data."""
|
"""Compute the checksum values given HRP and data."""
|
||||||
values = bech32_hrp_expand(hrp) + data
|
values = bech32_hrp_expand(hrp) + data
|
||||||
const = _BECH32M_CONST if spec == Encoding.BECH32M else 1
|
const = _BECH32M_CONST if spec == Encoding.BECH32M else 1
|
||||||
@ -77,7 +79,7 @@ def bech32_create_checksum(hrp: str, data: list[int], spec: Encoding) -> list[in
|
|||||||
|
|
||||||
def bech32_encode(hrp: str, data: list[int], spec: Encoding) -> str:
|
def bech32_encode(hrp: str, data: list[int], spec: Encoding) -> str:
|
||||||
"""Compute a Bech32 string given HRP and data values."""
|
"""Compute a Bech32 string given HRP and data values."""
|
||||||
combined = data + bech32_create_checksum(hrp, data, spec)
|
combined = data + _bech32_create_checksum(hrp, data, spec)
|
||||||
return hrp + "1" + "".join([CHARSET[d] for d in combined])
|
return hrp + "1" + "".join([CHARSET[d] for d in combined])
|
||||||
|
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ def prefix_expand(prefix: str) -> list[int]:
|
|||||||
return [ord(x) & 0x1F for x in prefix] + [0]
|
return [ord(x) & 0x1F for x in prefix] + [0]
|
||||||
|
|
||||||
|
|
||||||
def calculate_checksum(prefix: str, payload: list[int]) -> list[int]:
|
def _calculate_checksum(prefix: str, payload: list[int]) -> list[int]:
|
||||||
poly = cashaddr_polymod(prefix_expand(prefix) + payload + [0, 0, 0, 0, 0, 0, 0, 0])
|
poly = cashaddr_polymod(prefix_expand(prefix) + payload + [0, 0, 0, 0, 0, 0, 0, 0])
|
||||||
out = []
|
out = []
|
||||||
for i in range(8):
|
for i in range(8):
|
||||||
@ -58,18 +58,14 @@ def calculate_checksum(prefix: str, payload: list[int]) -> list[int]:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def verify_checksum(prefix: str, payload: list[int]) -> bool:
|
def _b32decode(inputs: str) -> list[int]:
|
||||||
return cashaddr_polymod(prefix_expand(prefix) + payload) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def b32decode(inputs: str) -> list[int]:
|
|
||||||
out = []
|
out = []
|
||||||
for letter in inputs:
|
for letter in inputs:
|
||||||
out.append(CHARSET.find(letter))
|
out.append(CHARSET.find(letter))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def b32encode(inputs: list[int]) -> str:
|
def _b32encode(inputs: list[int]) -> str:
|
||||||
out = ""
|
out = ""
|
||||||
for char_code in inputs:
|
for char_code in inputs:
|
||||||
out += CHARSET[char_code]
|
out += CHARSET[char_code]
|
||||||
@ -79,14 +75,18 @@ def b32encode(inputs: list[int]) -> str:
|
|||||||
def encode(prefix: str, version: int, payload_bytes: bytes) -> str:
|
def encode(prefix: str, version: int, payload_bytes: bytes) -> str:
|
||||||
payload_bytes = bytes([version]) + payload_bytes
|
payload_bytes = bytes([version]) + payload_bytes
|
||||||
payload = convertbits(payload_bytes, 8, 5)
|
payload = convertbits(payload_bytes, 8, 5)
|
||||||
checksum = calculate_checksum(prefix, payload)
|
checksum = _calculate_checksum(prefix, payload)
|
||||||
return prefix + ":" + b32encode(payload + checksum)
|
return prefix + ":" + _b32encode(payload + checksum)
|
||||||
|
|
||||||
|
|
||||||
def decode(prefix: str, addr: str) -> tuple[int, bytes]:
|
def decode(prefix: str, addr: str) -> tuple[int, bytes]:
|
||||||
addr = addr.lower()
|
addr = addr.lower()
|
||||||
decoded = b32decode(addr)
|
decoded = _b32decode(addr)
|
||||||
if not verify_checksum(prefix, decoded):
|
|
||||||
|
# verify_checksum
|
||||||
|
checksum_verified = cashaddr_polymod(prefix_expand(prefix) + decoded) == 0
|
||||||
|
if not checksum_verified:
|
||||||
raise ValueError("Bad cashaddr checksum")
|
raise ValueError("Bad cashaddr checksum")
|
||||||
|
|
||||||
data = bytes(convertbits(decoded, 5, 8))
|
data = bytes(convertbits(decoded, 5, 8))
|
||||||
return data[0], data[1:-6]
|
return data[0], data[1:-6]
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
from micropython import const
|
from micropython import const
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor.utils import BufferReader, empty_bytearray
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezor.utils import Writer
|
from trezor.utils import Writer
|
||||||
|
from trezor.utils import BufferReader
|
||||||
|
|
||||||
# Maximum length of a DER-encoded secp256k1 or secp256p1 signature.
|
# Maximum length of a DER-encoded secp256k1 or secp256p1 signature.
|
||||||
_MAX_DER_SIGNATURE_LENGTH = const(72)
|
_MAX_DER_SIGNATURE_LENGTH = const(72)
|
||||||
@ -41,7 +40,7 @@ def read_length(r: BufferReader) -> int:
|
|||||||
return n
|
return n
|
||||||
|
|
||||||
|
|
||||||
def write_int(w: Writer, number: bytes) -> None:
|
def _write_int(w: Writer, number: bytes) -> None:
|
||||||
i = 0
|
i = 0
|
||||||
while i < len(number) and number[i] == 0:
|
while i < len(number) and number[i] == 0:
|
||||||
i += 1
|
i += 1
|
||||||
@ -57,7 +56,9 @@ def write_int(w: Writer, number: bytes) -> None:
|
|||||||
w.extend(memoryview(number)[i:])
|
w.extend(memoryview(number)[i:])
|
||||||
|
|
||||||
|
|
||||||
def read_int(r: BufferReader) -> memoryview:
|
def _read_int(r: BufferReader) -> memoryview:
|
||||||
|
peek = r.peek # local_cache_attribute
|
||||||
|
|
||||||
if r.get() != 0x02:
|
if r.get() != 0x02:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
@ -65,32 +66,36 @@ def read_int(r: BufferReader) -> memoryview:
|
|||||||
if n == 0:
|
if n == 0:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
if r.peek() & 0x80:
|
if peek() & 0x80:
|
||||||
raise ValueError # negative integer
|
raise ValueError # negative integer
|
||||||
|
|
||||||
if r.peek() == 0x00 and n > 1:
|
if peek() == 0x00 and n > 1:
|
||||||
r.get()
|
r.get()
|
||||||
n -= 1
|
n -= 1
|
||||||
if r.peek() & 0x80 == 0x00:
|
if peek() & 0x80 == 0x00:
|
||||||
raise ValueError # excessive zero-padding
|
raise ValueError # excessive zero-padding
|
||||||
|
|
||||||
if r.peek() == 0x00:
|
if peek() == 0x00:
|
||||||
raise ValueError # excessive zero-padding
|
raise ValueError # excessive zero-padding
|
||||||
|
|
||||||
return r.read_memoryview(n)
|
return r.read_memoryview(n)
|
||||||
|
|
||||||
|
|
||||||
def encode_seq(seq: tuple[bytes, ...]) -> bytes:
|
def encode_seq(seq: tuple[bytes, ...]) -> bytes:
|
||||||
|
from trezor.utils import empty_bytearray
|
||||||
|
|
||||||
# Preallocate space for a signature, which is all that this function ever encodes.
|
# Preallocate space for a signature, which is all that this function ever encodes.
|
||||||
buffer = empty_bytearray(_MAX_DER_SIGNATURE_LENGTH)
|
buffer = empty_bytearray(_MAX_DER_SIGNATURE_LENGTH)
|
||||||
buffer.append(0x30)
|
buffer.append(0x30)
|
||||||
for i in seq:
|
for i in seq:
|
||||||
write_int(buffer, i)
|
_write_int(buffer, i)
|
||||||
buffer[1:1] = encode_length(len(buffer) - 1)
|
buffer[1:1] = encode_length(len(buffer) - 1)
|
||||||
return buffer
|
return buffer
|
||||||
|
|
||||||
|
|
||||||
def decode_seq(data: memoryview) -> list[memoryview]:
|
def decode_seq(data: memoryview) -> list[memoryview]:
|
||||||
|
from trezor.utils import BufferReader
|
||||||
|
|
||||||
r = BufferReader(data)
|
r = BufferReader(data)
|
||||||
|
|
||||||
if r.get() != 0x30:
|
if r.get() != 0x30:
|
||||||
@ -100,7 +105,7 @@ def decode_seq(data: memoryview) -> list[memoryview]:
|
|||||||
seq = []
|
seq = []
|
||||||
end = r.offset + n
|
end = r.offset + n
|
||||||
while r.offset < end:
|
while r.offset < end:
|
||||||
i = read_int(r)
|
i = _read_int(r)
|
||||||
seq.append(i)
|
seq.append(i)
|
||||||
|
|
||||||
if r.offset != end or r.remaining_count():
|
if r.offset != end or r.remaining_count():
|
||||||
|
@ -80,12 +80,12 @@ def length(item: RLPItem) -> int:
|
|||||||
return header_length(item_length, data) + item_length
|
return header_length(item_length, data) + item_length
|
||||||
|
|
||||||
|
|
||||||
def write_string(w: Writer, string: bytes) -> None:
|
def _write_string(w: Writer, string: bytes) -> None:
|
||||||
write_header(w, len(string), STRING_HEADER_BYTE, string)
|
write_header(w, len(string), STRING_HEADER_BYTE, string)
|
||||||
w.extend(string)
|
w.extend(string)
|
||||||
|
|
||||||
|
|
||||||
def write_list(w: Writer, lst: RLPList) -> None:
|
def _write_list(w: Writer, lst: RLPList) -> None:
|
||||||
payload_length = sum(length(item) for item in lst)
|
payload_length = sum(length(item) for item in lst)
|
||||||
write_header(w, payload_length, LIST_HEADER_BYTE)
|
write_header(w, payload_length, LIST_HEADER_BYTE)
|
||||||
for item in lst:
|
for item in lst:
|
||||||
@ -94,10 +94,10 @@ def write_list(w: Writer, lst: RLPList) -> None:
|
|||||||
|
|
||||||
def write(w: Writer, item: RLPItem) -> None:
|
def write(w: Writer, item: RLPItem) -> None:
|
||||||
if isinstance(item, int):
|
if isinstance(item, int):
|
||||||
write_string(w, int_to_bytes(item))
|
_write_string(w, int_to_bytes(item))
|
||||||
elif isinstance(item, (bytes, bytearray)):
|
elif isinstance(item, (bytes, bytearray)):
|
||||||
write_string(w, item)
|
_write_string(w, item)
|
||||||
elif isinstance(item, list):
|
elif isinstance(item, list):
|
||||||
write_list(w, item)
|
_write_list(w, item)
|
||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
@ -34,7 +34,7 @@ from micropython import const
|
|||||||
from trezorcrypto import shamir, slip39
|
from trezorcrypto import shamir, slip39
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor.crypto import hmac, pbkdf2, random
|
from trezor.crypto import random
|
||||||
from trezor.errors import MnemonicError
|
from trezor.errors import MnemonicError
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -416,6 +416,8 @@ def _rs1024_verify_checksum(data: Indices) -> bool:
|
|||||||
|
|
||||||
def _round_function(i: int, passphrase: bytes, e: int, salt: bytes, r: bytes) -> bytes:
|
def _round_function(i: int, passphrase: bytes, e: int, salt: bytes, r: bytes) -> bytes:
|
||||||
"""The round function used internally by the Feistel cipher."""
|
"""The round function used internally by the Feistel cipher."""
|
||||||
|
from trezor.crypto import pbkdf2
|
||||||
|
|
||||||
return pbkdf2(
|
return pbkdf2(
|
||||||
pbkdf2.HMAC_SHA256,
|
pbkdf2.HMAC_SHA256,
|
||||||
bytes([i]) + passphrase,
|
bytes([i]) + passphrase,
|
||||||
@ -431,6 +433,8 @@ def _get_salt(identifier: int) -> bytes:
|
|||||||
|
|
||||||
|
|
||||||
def _create_digest(random_data: bytes, shared_secret: bytes) -> bytes:
|
def _create_digest(random_data: bytes, shared_secret: bytes) -> bytes:
|
||||||
|
from trezor.crypto import hmac
|
||||||
|
|
||||||
return hmac(hmac.SHA256, random_data, shared_secret).digest()[:_DIGEST_LENGTH_BYTES]
|
return hmac(hmac.SHA256, random_data, shared_secret).digest()[:_DIGEST_LENGTH_BYTES]
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,38 +3,6 @@ from trezor.crypto import slip39, random
|
|||||||
from slip39_vectors import vectors
|
from slip39_vectors import vectors
|
||||||
|
|
||||||
|
|
||||||
# NOTE: moved into tests not to occupy flash space
|
|
||||||
# in firmware binary, when it is not used in production
|
|
||||||
def _rs1024_error_index(data: tuple[int, ...]) -> int | None:
|
|
||||||
"""
|
|
||||||
Returns the index where an error possibly occurred.
|
|
||||||
"""
|
|
||||||
GEN = (
|
|
||||||
0x91F_9F87,
|
|
||||||
0x122F_1F07,
|
|
||||||
0x244E_1E07,
|
|
||||||
0x81C_1C07,
|
|
||||||
0x1028_1C0E,
|
|
||||||
0x2040_1C1C,
|
|
||||||
0x10_3838,
|
|
||||||
0x20_7070,
|
|
||||||
0x40_E0E0,
|
|
||||||
0x81_C1C0,
|
|
||||||
)
|
|
||||||
chk = slip39._rs1024_polymod(tuple(slip39._CUSTOMIZATION_STRING) + data) ^ 1
|
|
||||||
if chk == 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
for i in reversed(range(len(data))):
|
|
||||||
b = chk & 0x3FF
|
|
||||||
chk >>= 10
|
|
||||||
if chk == 0:
|
|
||||||
return i
|
|
||||||
for j in range(10):
|
|
||||||
chk ^= GEN[j] if ((b >> j) & 1) else 0
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def combinations(iterable, r):
|
def combinations(iterable, r):
|
||||||
# Taken from https://docs.python.org/3.7/library/itertools.html#itertools.combinations
|
# Taken from https://docs.python.org/3.7/library/itertools.html#itertools.combinations
|
||||||
pool = tuple(iterable)
|
pool = tuple(iterable)
|
||||||
@ -186,19 +154,5 @@ class TestCryptoSlip39(unittest.TestCase):
|
|||||||
slip39.recover_ems(mnemonics)
|
slip39.recover_ems(mnemonics)
|
||||||
|
|
||||||
|
|
||||||
def test_error_location(self):
|
|
||||||
mnemonics = [
|
|
||||||
"duckling enlarge academic academic agency result length solution fridge kidney coal piece deal husband erode duke ajar critical decision keyboard",
|
|
||||||
"theory painting academic academic armed sweater year military elder discuss acne wildlife boring employer fused large satoshi bundle carbon diagnose anatomy hamster leaves tracks paces beyond phantom capital marvel lips brave detect luck",
|
|
||||||
]
|
|
||||||
for mnemonic in mnemonics:
|
|
||||||
data = tuple(slip39._mnemonic_to_indices(mnemonic))
|
|
||||||
self.assertEqual(_rs1024_error_index(data), None)
|
|
||||||
for i in range(len(data)):
|
|
||||||
for _ in range(50):
|
|
||||||
error_data = error_data = data[:i] + (data[i] ^ (random.uniform(1023) + 1), ) + data[i + 1:]
|
|
||||||
self.assertEqual(_rs1024_error_index(error_data), i)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user