mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-17 10:51:00 +00:00
382 lines
11 KiB
Python
382 lines
11 KiB
Python
import gc
|
|
import sys
|
|
from trezorutils import ( # noqa: F401
|
|
BITCOIN_ONLY,
|
|
EMULATOR,
|
|
INTERNAL_MODEL,
|
|
MODEL,
|
|
MODEL_FULL_NAME,
|
|
SCM_REVISION,
|
|
UI_LAYOUT,
|
|
USE_BACKLIGHT,
|
|
USE_OPTIGA,
|
|
USE_SD_CARD,
|
|
VERSION,
|
|
bootloader_locked,
|
|
check_firmware_header,
|
|
consteq,
|
|
firmware_hash,
|
|
firmware_vendor,
|
|
halt,
|
|
memcpy,
|
|
reboot_to_bootloader,
|
|
sd_hotswap_enabled,
|
|
unit_btconly,
|
|
unit_color,
|
|
)
|
|
from typing import TYPE_CHECKING
|
|
|
|
# Will get replaced by "True" / "False" in the build process.
# However, it needs to stay as an exported symbol for the unit tests.
# NOTE(review): presumably the build step matches this exact expression
# textually, so keep the right-hand side unchanged when editing.
MODEL_IS_T2B1: bool = INTERNAL_MODEL == "T2B1"
|
|
# Debug knobs. DISABLE_ANIMATION always exists and defaults to 0; LOG_MEMORY is
# only defined in debug builds (`__debug__`). On the emulator both can be set
# through environment variables.
DISABLE_ANIMATION = 0

if __debug__:
    if not EMULATOR:
        # Debug build on hardware: there is no environment to read from.
        LOG_MEMORY = 0
    else:
        import uos

        # A falsy getenv() result (unset or empty variable) is treated as "0".
        DISABLE_ANIMATION = int(uos.getenv("TREZOR_DISABLE_ANIMATION") or "0")
        LOG_MEMORY = int(uos.getenv("TREZOR_LOG_MEMORY") or "0")
|
|
if TYPE_CHECKING:
|
|
from typing import Any, Iterator, Protocol, Sequence, TypeVar
|
|
|
|
from trezor.protobuf import MessageType
|
|
|
|
|
|
def unimport_begin() -> set[str]:
    """Return a snapshot of the module names currently in `sys.modules`.

    The returned set is an independent copy. Pass it to `unimport_end()` to
    unload every module imported after this call.
    """
    loaded = set(sys.modules.keys())
    return loaded
|
|
|
|
def unimport_end(mods: set[str], collect: bool = True) -> None:
    """Unload every module present in `sys.modules` but not in `mods`.

    `mods` is a snapshot of module names, as returned by `unimport_begin()`.
    Each module imported since then is removed from `sys.modules`. For dotted
    names, the attribute that references the module is also removed from its
    parent package. If `collect` is true, `gc.collect()` runs afterwards.
    """
    # static check that the size of sys.modules never grows above value of
    # MICROPY_LOADED_MODULES_DICT_SIZE, so that the sys.modules dict is never
    # reallocated at run-time
    assert len(sys.modules) <= 160, "Please bump preallocated size in mpconfigport.h"

    # NOTE(review): entries are deleted from sys.modules while it is being
    # iterated. This relies on the target runtime (presumably MicroPython)
    # tolerating that; CPython raises RuntimeError here. Confirm before porting.
    for mod in sys.modules:  # pylint: disable=consider-using-dict-items
        if mod not in mods:
            # remove reference from sys.modules
            del sys.modules[mod]
            # remove reference from the parent module
            i = mod.rfind(".")
            if i < 0:
                # top-level module: there is no parent package to update
                continue
            path = mod[:i]
            name = mod[i + 1 :]
            try:
                delattr(sys.modules[path], name)
            except KeyError:
                # either path is not present in sys.modules, or module is not
                # referenced from the parent package. both is fine.
                # NOTE(review): CPython's delattr raises AttributeError for a
                # missing attribute; only KeyError is caught here.
                pass
    # collect removed modules
    if collect:
        gc.collect()
|
|
|
|
class unimport:
    """Context manager that unloads every module imported inside its body.

    On entry it snapshots `sys.modules` with `unimport_begin()`. On exit it
    unloads the new modules with `unimport_end()`, drops the snapshot and
    runs `gc.collect()`. Exceptions from the body are not suppressed.
    """

    def __init__(self) -> None:
        # snapshot taken by __enter__; None while the manager is not active
        self.mods: set[str] | None = None

    def __enter__(self) -> None:
        self.mods = unimport_begin()

    def __exit__(self, _exc_type: Any, _exc_value: Any, _tb: Any) -> None:
        snapshot = self.mods
        assert snapshot is not None
        # skip the collection inside unimport_end; a single collect runs
        # below, after the snapshot itself has been released
        unimport_end(snapshot, collect=False)
        snapshot.clear()
        self.mods = None
        gc.collect()
|
|
|
|
def presize_module(modname: str, size: int) -> None:
    """Pre-grow the namespace dict of the already-imported module `modname`.

    Some modules (e.g. `trezor`) end up holding more names than their own source
    defines, because submodules are later inserted into their namespace. Calling
    this with the expected number of entries makes the module's dict
    preallocated to that size up front.

    It works by inserting `size` placeholder attributes and then deleting them
    again, so the visible namespace is unchanged afterwards. Raises KeyError if
    `modname` is not present in `sys.modules`.
    """
    target = sys.modules[modname]
    # grow: add `size` throwaway attributes
    for n in range(size):
        setattr(target, f"___PRESIZE_MODULE_{n}", None)
    # shrink back: drop them again, leaving the original names only
    for n in range(size):
        delattr(target, f"___PRESIZE_MODULE_{n}")
|
|
|
|
if __debug__:

    def mem_dump(filename: str) -> None:
        """Print the loaded module names followed by MicroPython memory info.

        On the emulator, a detailed memory dump is also written to `filename`
        via `trezorutils.meminfo`. On hardware, `filename` is not used.
        """
        from micropython import mem_info

        loaded = sys.modules
        print(f"### sysmodules ({len(loaded)}):")
        for name in loaded:
            print("*", name)

        if not EMULATOR:
            # verbose variant of the memory summary
            mem_info(True)
            return

        from trezorutils import meminfo

        print("### dumping to", filename)
        meminfo(filename)
        mem_info()
|
|
|
|
def ensure(cond: bool, msg: str | None = None) -> None:
|
|
if not cond:
|
|
if msg is None:
|
|
raise AssertionError
|
|
else:
|
|
raise AssertionError(msg)
|
|
|
|
|
|
if TYPE_CHECKING:
    # Type-checker-only constrained TypeVar: chunks() hands back slices of the
    # same type (str or a sequence) as the value it was given.
    Chunkable = TypeVar("Chunkable", str, Sequence[Any])
|
|
|
|
def chunks(items: Chunkable, size: int) -> Iterator[Chunkable]:
    """Lazily split `items` into consecutive slices of length `size`.

    The final slice is shorter when len(items) is not a multiple of `size`.
    Each slice is produced by slicing `items`, so it has the same type.
    """
    for start in range(0, len(items), size):
        stop = start + size
        yield items[start:stop]
|
|
|
|
if TYPE_CHECKING:
    # Structural (Protocol) types used only for static type checking.
    # Parameter names with a leading double underscore mark the parameters as
    # positional-only for type checkers, so implementations may name them freely.

    class HashContext(Protocol):
        """An incremental hash: feed data with update(), finalize with digest()."""

        def update(self, __buf: bytes) -> None: ...

        def digest(self) -> bytes: ...

    class HashContextInitable(HashContext, Protocol):
        """A HashContext whose constructor optionally accepts initial data."""

        def __init__(  # pylint: disable=super-init-not-called
            self, __data: bytes | None = None
        ) -> None: ...

    class Writer(Protocol):
        """A byte sink accepting single byte values (append) and byte strings (extend)."""

        def append(self, __b: int) -> None: ...

        def extend(self, __buf: bytes) -> None: ...
|
|
|
|
if False:  # noqa

    class DebugHashContextWrapper:
        """
        Use this wrapper to debug hashing operations. When digest() is called,
        it will log all of the data that was provided to update().

        Example usage:
        self.h_prevouts = HashWriter(DebugHashContextWrapper(sha256()))
        """

        def __init__(self, ctx: HashContext) -> None:
            self.ctx = ctx
            # hex of every update() chunk so far, each followed by a space
            self.data = ""

        def update(self, data: bytes) -> None:
            from ubinascii import hexlify

            self.ctx.update(data)
            chunk_hex = hexlify(data).decode()
            self.data += chunk_hex + " "

        def digest(self) -> bytes:
            from ubinascii import hexlify

            from trezor import log

            result = self.ctx.digest()
            ctx_name = self.ctx.__class__.__name__
            result_hex = hexlify(result).decode()
            log.debug(
                __name__,
                "%s hash: %s, data: %s",
                ctx_name,
                result_hex,
                self.data,
            )
            return result
|
|
|
|
class HashWriter:
    """Writer-style adapter that forwards every byte written to a hash context."""

    def __init__(self, ctx: HashContext) -> None:
        self.ctx = ctx
        # one-byte scratch buffer reused by append(), presumably so that
        # appending a single byte does not allocate a new object each time
        self.buf = bytearray(1)

    def append(self, b: int) -> None:
        """Hash the single byte value `b`."""
        scratch = self.buf
        scratch[0] = b
        self.ctx.update(scratch)

    def extend(self, buf: bytes) -> None:
        """Hash all bytes of `buf`."""
        self.ctx.update(buf)

    def write(self, buf: bytes) -> None:
        """Same as extend(); provided for stream-style callers."""
        self.ctx.update(buf)

    def get_digest(self) -> bytes:
        """Return `ctx.digest()` for everything written so far."""
        return self.ctx.digest()
|
|
|
|
if TYPE_CHECKING:
    # Type-checker-only alias for destination buffers that BufferReader.readinto()
    # copies into.
    BufferType = bytearray | memoryview
|
|
|
|
class BufferReader:
|
|
"""Seekable and readable view into a buffer."""
|
|
|
|
def __init__(self, buffer: bytes | memoryview) -> None:
|
|
if isinstance(buffer, memoryview):
|
|
self.buffer = buffer
|
|
else:
|
|
self.buffer = memoryview(buffer)
|
|
self.offset = 0
|
|
|
|
def seek(self, offset: int) -> None:
|
|
"""Set current offset to `offset`.
|
|
|
|
If negative, set to zero. If longer than the buffer, set to end of buffer.
|
|
"""
|
|
offset = min(offset, len(self.buffer))
|
|
offset = max(offset, 0)
|
|
self.offset = offset
|
|
|
|
def readinto(self, dst: BufferType) -> int:
|
|
"""Read exactly `len(dst)` bytes into `dst`, or raise EOFError.
|
|
|
|
Returns number of bytes read.
|
|
"""
|
|
buffer = self.buffer
|
|
offset = self.offset
|
|
if len(dst) > len(buffer) - offset:
|
|
raise EOFError
|
|
nread = memcpy(dst, 0, buffer, offset)
|
|
self.offset += nread
|
|
return nread
|
|
|
|
def read(self, length: int | None = None) -> bytes:
|
|
"""Read and return exactly `length` bytes, or raise EOFError.
|
|
|
|
If `length` is unspecified, reads all remaining data.
|
|
|
|
Note that this method makes a copy of the data. To avoid allocation, use
|
|
`readinto()`. To avoid copying use `read_memoryview()`.
|
|
"""
|
|
return bytes(self.read_memoryview(length))
|
|
|
|
def read_memoryview(self, length: int | None = None) -> memoryview:
|
|
"""Read and return a memoryview of exactly `length` bytes, or raise
|
|
EOFError.
|
|
|
|
If `length` is unspecified, reads all remaining data.
|
|
"""
|
|
if length is None:
|
|
ret = self.buffer[self.offset :]
|
|
self.offset = len(self.buffer)
|
|
elif length < 0:
|
|
raise ValueError
|
|
elif length <= self.remaining_count():
|
|
ret = self.buffer[self.offset : self.offset + length]
|
|
self.offset += length
|
|
else:
|
|
raise EOFError
|
|
return ret
|
|
|
|
def remaining_count(self) -> int:
|
|
"""Return the number of bytes remaining for reading."""
|
|
return len(self.buffer) - self.offset
|
|
|
|
def peek(self) -> int:
|
|
"""Peek the ordinal value of the next byte to be read."""
|
|
if self.offset >= len(self.buffer):
|
|
raise EOFError
|
|
return self.buffer[self.offset]
|
|
|
|
def get(self) -> int:
|
|
"""Read exactly one byte and return its ordinal value."""
|
|
if self.offset >= len(self.buffer):
|
|
raise EOFError
|
|
byte = self.buffer[self.offset]
|
|
self.offset += 1
|
|
return byte
|
|
|
|
|
|
def obj_eq(self: Any, __o: Any) -> bool:
    """Compare two objects by contents: same class and equal instance `__dict__`.

    Presumably meant to be installed as an `__eq__` implementation (note the
    `self` parameter). Classes that use `__slots__` are not supported; this is
    asserted.
    """
    if self.__class__ is __o.__class__:
        assert not hasattr(self, "__slots__")
        return self.__dict__ == __o.__dict__
    return False
|
|
|
|
def obj_repr(self: Any) -> str:
    """Render `self` as "<ClassName: {instance __dict__}>".

    Classes that use `__slots__` are not supported; this is asserted.
    """
    assert not hasattr(self, "__slots__")
    class_name = self.__class__.__name__
    attrs = self.__dict__
    return f"<{class_name}: {attrs}>"
|
|
|
|
def truncate_utf8(string: str, max_bytes: int) -> str:
    """Truncate the codepoints of a string so that its UTF-8 encoding is at most `max_bytes` in length.

    Only whole codepoints are kept. A multi-byte character that would straddle
    the limit is dropped entirely. A string that already fits is returned
    unchanged, and a non-positive `max_bytes` yields "".
    """
    data = string.encode()
    if len(data) <= max_bytes:
        return string
    if max_bytes <= 0:
        # Nothing fits. Handled explicitly because a negative max_bytes would
        # otherwise be used as a negative index/slice bound below, which counts
        # from the end of `data` and keeps almost the whole string.
        return ""

    # Here 0 < max_bytes < len(data), so data[max_bytes] exists: it is the first
    # byte past the limit. Walk back to the lead byte of the codepoint that
    # contains it (UTF-8 continuation bytes have the form 0b10xxxxxx) and cut
    # just before that codepoint.
    i = max_bytes
    while i > 0 and data[i] & 0xC0 == 0x80:
        i -= 1

    return data[:i].decode()
|
|
|
|
def is_empty_iterator(i: Iterator) -> bool:
    """Return True if `i` is already exhausted.

    Note: when the iterator is not empty, this consumes (and discards) its next
    item. Exceptions other than StopIteration propagate.
    """
    try:
        next(i)
    except StopIteration:
        return True
    return False
|
|
|
|
def empty_bytearray(preallocate: int) -> bytearray:
    """Return an empty bytearray that was first allocated with `preallocate` bytes.

    The intent is that appending up to `preallocate` bytes afterwards does not
    trigger a reallocation, which helps code that grows a buffer in small steps.
    """
    result = bytearray(preallocate)
    # Drop the logical contents via slice assignment; the allocated storage is
    # meant to be kept for later appends.
    result[:] = bytes()
    return result
|
|
|
|
if __debug__:

    def dump_protobuf_lines(msg: MessageType, line_start: str = "") -> Iterator[str]:
        """Yield a line-by-line, human-readable rendering of a protobuf message.

        Empty messages render as a single "Name {}" line. Otherwise a header
        line is followed by one line per field and a closing "}". Nested
        messages and non-empty lists of messages are rendered recursively.
        `line_start` is prepended to the header line only.
        """
        fields = msg.__dict__
        header = line_start + msg.MESSAGE_NAME
        if not fields:
            yield header + " {}"
            return

        yield header + " {"
        # NOTE(review): a value counts as a nested message only if it has exactly
        # the same runtime type as `msg`; presumably all protobuf messages share
        # one runtime type here — confirm.
        # NOTE(review): nested lines and list items both get a single-space
        # prefix, so list items are not indented past their `key: [` line —
        # confirm this is the intended width.
        msg_type = type(msg)
        for key, val in fields.items():
            if type(val) is msg_type:
                # nested message: render it recursively, labelled with its key
                for line in dump_protobuf_lines(val, line_start=key + ": "):
                    yield " " + line
            elif val and isinstance(val, list) and type(val[0]) is msg_type:
                # non-empty list of protobuf messages
                yield f" {key}: ["
                for item in val:
                    for line in dump_protobuf_lines(item):
                        yield " " + line
                yield " ]"
            else:
                yield f" {key}: {repr(val)}"

        yield "}"

    def dump_protobuf(msg: MessageType) -> str:
        """Return the full rendering from dump_protobuf_lines() as one string."""
        rendered = dump_protobuf_lines(msg)
        return "\n".join(rendered)