mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-17 19:00:58 +00:00
272 lines
6.9 KiB
Python
272 lines
6.9 KiB
Python
import gc
|
|
import sys
|
|
from trezorutils import ( # noqa: F401
|
|
BITCOIN_ONLY,
|
|
EMULATOR,
|
|
GITREV,
|
|
MODEL,
|
|
VERSION_MAJOR,
|
|
VERSION_MINOR,
|
|
VERSION_PATCH,
|
|
consteq,
|
|
halt,
|
|
memcpy,
|
|
)
|
|
|
|
DISABLE_ANIMATION = 0
|
|
|
|
if __debug__:
|
|
if EMULATOR:
|
|
import uos
|
|
|
|
DISABLE_ANIMATION = int(uos.getenv("TREZOR_DISABLE_ANIMATION") or "0")
|
|
LOG_MEMORY = int(uos.getenv("TREZOR_LOG_MEMORY") or "0")
|
|
else:
|
|
LOG_MEMORY = 0
|
|
|
|
if False:
|
|
from typing import (
|
|
Any,
|
|
Iterable,
|
|
Iterator,
|
|
Optional,
|
|
Protocol,
|
|
Union,
|
|
TypeVar,
|
|
Sequence,
|
|
)
|
|
|
|
|
|
def unimport_begin() -> Iterable[str]:
|
|
return set(sys.modules)
|
|
|
|
|
|
def unimport_end(mods: Iterable[str]) -> None:
|
|
for mod in sys.modules:
|
|
if mod not in mods:
|
|
# remove reference from sys.modules
|
|
del sys.modules[mod]
|
|
# remove reference from the parent module
|
|
i = mod.rfind(".")
|
|
if i < 0:
|
|
continue
|
|
path = mod[:i]
|
|
name = mod[i + 1 :]
|
|
try:
|
|
delattr(sys.modules[path], name)
|
|
except KeyError:
|
|
# either path is not present in sys.modules, or module is not
|
|
# referenced from the parent package. both is fine.
|
|
pass
|
|
# collect removed modules
|
|
gc.collect()
|
|
|
|
|
|
def ensure(cond: bool, msg: str = None) -> None:
|
|
if not cond:
|
|
if msg is None:
|
|
raise AssertionError
|
|
else:
|
|
raise AssertionError(msg)
|
|
|
|
|
|
if False:
|
|
Chunkable = TypeVar("Chunkable", str, Sequence[Any])
|
|
|
|
|
|
def chunks(items: Chunkable, size: int) -> Iterator[Chunkable]:
|
|
for i in range(0, len(items), size):
|
|
yield items[i : i + size]
|
|
|
|
|
|
if False:
|
|
|
|
class HashContext(Protocol):
|
|
def update(self, buf: bytes) -> None:
|
|
...
|
|
|
|
def digest(self) -> bytes:
|
|
...
|
|
|
|
class Writer(Protocol):
|
|
def append(self, b: int) -> None:
|
|
...
|
|
|
|
def extend(self, buf: bytes) -> None:
|
|
...
|
|
|
|
|
|
class HashWriter:
|
|
def __init__(self, ctx: HashContext) -> None:
|
|
self.ctx = ctx
|
|
self.buf = bytearray(1) # used in append()
|
|
|
|
def append(self, b: int) -> None:
|
|
self.buf[0] = b
|
|
self.ctx.update(self.buf)
|
|
|
|
def extend(self, buf: bytes) -> None:
|
|
self.ctx.update(buf)
|
|
|
|
def write(self, buf: bytes) -> None: # alias for extend()
|
|
self.ctx.update(buf)
|
|
|
|
async def awrite(self, buf: bytes) -> int: # AsyncWriter interface
|
|
self.ctx.update(buf)
|
|
return len(buf)
|
|
|
|
def get_digest(self) -> bytes:
|
|
return self.ctx.digest()
|
|
|
|
|
|
if False:
|
|
BufferType = Union[bytearray, memoryview]
|
|
|
|
|
|
class BufferWriter:
|
|
"""Seekable and writeable view into a buffer."""
|
|
|
|
def __init__(self, buffer: BufferType) -> None:
|
|
self.buffer = buffer
|
|
self.offset = 0
|
|
|
|
def seek(self, offset: int) -> None:
|
|
"""Set current offset to `offset`.
|
|
|
|
If negative, set to zero. If longer than the buffer, set to end of buffer.
|
|
"""
|
|
offset = min(offset, len(self.buffer))
|
|
offset = max(offset, 0)
|
|
self.offset = offset
|
|
|
|
def write(self, src: bytes) -> int:
|
|
"""Write exactly `len(src)` bytes into buffer, or raise EOFError.
|
|
|
|
Returns number of bytes written.
|
|
"""
|
|
buffer = self.buffer
|
|
offset = self.offset
|
|
if len(src) > len(buffer) - offset:
|
|
raise EOFError
|
|
nwrite = memcpy(buffer, offset, src, 0)
|
|
self.offset += nwrite
|
|
return nwrite
|
|
|
|
|
|
class BufferReader:
|
|
"""Seekable and readable view into a buffer."""
|
|
|
|
def __init__(self, buffer: bytes) -> None:
|
|
self.buffer = buffer
|
|
self.offset = 0
|
|
|
|
def seek(self, offset: int) -> None:
|
|
"""Set current offset to `offset`.
|
|
|
|
If negative, set to zero. If longer than the buffer, set to end of buffer.
|
|
"""
|
|
offset = min(offset, len(self.buffer))
|
|
offset = max(offset, 0)
|
|
self.offset = offset
|
|
|
|
def readinto(self, dst: BufferType) -> int:
|
|
"""Read exactly `len(dst)` bytes into `dst`, or raise EOFError.
|
|
|
|
Returns number of bytes read.
|
|
"""
|
|
buffer = self.buffer
|
|
offset = self.offset
|
|
if len(dst) > len(buffer) - offset:
|
|
raise EOFError
|
|
nread = memcpy(dst, 0, buffer, offset)
|
|
self.offset += nread
|
|
return nread
|
|
|
|
def read(self, length: Optional[int] = None) -> bytes:
|
|
"""Read and return exactly `length` bytes, or raise EOFError.
|
|
|
|
If `length` is unspecified, reads all remaining data.
|
|
|
|
Note that this method makes a copy of the data. To avoid allocation, use
|
|
`readinto()`.
|
|
"""
|
|
if length is None:
|
|
ret = self.buffer[self.offset :]
|
|
self.offset = len(self.buffer)
|
|
elif length < 0:
|
|
raise ValueError
|
|
elif length <= self.remaining_count():
|
|
ret = self.buffer[self.offset : self.offset + length]
|
|
self.offset += length
|
|
else:
|
|
raise EOFError
|
|
return ret
|
|
|
|
def remaining_count(self) -> int:
|
|
"""Return the number of bytes remaining for reading."""
|
|
return len(self.buffer) - self.offset
|
|
|
|
def peek(self) -> int:
|
|
"""Peek the ordinal value of the next byte to be read."""
|
|
if self.offset >= len(self.buffer):
|
|
raise EOFError
|
|
return self.buffer[self.offset]
|
|
|
|
def get(self) -> int:
|
|
"""Read exactly one byte and return its ordinal value."""
|
|
if self.offset >= len(self.buffer):
|
|
raise EOFError
|
|
byte = self.buffer[self.offset]
|
|
self.offset += 1
|
|
return byte
|
|
|
|
|
|
def obj_eq(l: object, r: object) -> bool:
|
|
"""
|
|
Compares object contents, supports __slots__.
|
|
"""
|
|
if l.__class__ is not r.__class__:
|
|
return False
|
|
if not hasattr(l, "__slots__"):
|
|
return l.__dict__ == r.__dict__
|
|
if l.__slots__ is not r.__slots__:
|
|
return False
|
|
for slot in l.__slots__:
|
|
if getattr(l, slot, None) != getattr(r, slot, None):
|
|
return False
|
|
return True
|
|
|
|
|
|
def obj_repr(o: object) -> str:
|
|
"""
|
|
Returns a string representation of object, supports __slots__.
|
|
"""
|
|
if hasattr(o, "__slots__"):
|
|
d = {attr: getattr(o, attr, None) for attr in o.__slots__}
|
|
else:
|
|
d = o.__dict__
|
|
return "<%s: %s>" % (o.__class__.__name__, d)
|
|
|
|
|
|
def truncate_utf8(string: str, max_bytes: int) -> str:
|
|
"""Truncate the codepoints of a string so that its UTF-8 encoding is at most `max_bytes` in length."""
|
|
data = string.encode()
|
|
if len(data) <= max_bytes:
|
|
return string
|
|
|
|
# Find the starting position of the last codepoint in data[0 : max_bytes + 1].
|
|
i = max_bytes
|
|
while i >= 0 and data[i] & 0xC0 == 0x80:
|
|
i -= 1
|
|
|
|
return data[:i].decode()
|
|
|
|
|
|
def is_empty_iterator(i: Iterator) -> bool:
|
|
try:
|
|
next(i)
|
|
except StopIteration:
|
|
return True
|
|
else:
|
|
return False
|