mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-04-25 11:39:02 +00:00
chore(core): decrease common size by 5200 bytes
This commit is contained in:
parent
5e7cc8b692
commit
3711fd0f19
@ -1,9 +1,6 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor import utils, wire
|
from trezor import utils
|
||||||
from trezor.crypto import hashlib, hmac
|
|
||||||
|
|
||||||
from .writers import write_bytes_unchecked, write_compact_size, write_uint32_le
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from apps.common.keychain import Keychain
|
from apps.common.keychain import Keychain
|
||||||
@ -14,12 +11,18 @@ _ADDRESS_MAC_KEY_PATH = [b"SLIP-0024", b"Address MAC key"]
|
|||||||
def check_address_mac(
|
def check_address_mac(
|
||||||
address: str, mac: bytes, slip44: int, keychain: Keychain
|
address: str, mac: bytes, slip44: int, keychain: Keychain
|
||||||
) -> None:
|
) -> None:
|
||||||
|
from trezor import wire
|
||||||
|
from trezor.crypto import hashlib
|
||||||
|
|
||||||
expected_mac = get_address_mac(address, slip44, keychain)
|
expected_mac = get_address_mac(address, slip44, keychain)
|
||||||
if len(mac) != hashlib.sha256.digest_size or not utils.consteq(expected_mac, mac):
|
if len(mac) != hashlib.sha256.digest_size or not utils.consteq(expected_mac, mac):
|
||||||
raise wire.DataError("Invalid address MAC.")
|
raise wire.DataError("Invalid address MAC.")
|
||||||
|
|
||||||
|
|
||||||
def get_address_mac(address: str, slip44: int, keychain: Keychain) -> bytes:
|
def get_address_mac(address: str, slip44: int, keychain: Keychain) -> bytes:
|
||||||
|
from trezor.crypto import hmac
|
||||||
|
from .writers import write_bytes_unchecked, write_compact_size, write_uint32_le
|
||||||
|
|
||||||
# k = Key(m/"SLIP-0024"/"Address MAC key")
|
# k = Key(m/"SLIP-0024"/"Address MAC key")
|
||||||
node = keychain.derive_slip21(_ADDRESS_MAC_KEY_PATH)
|
node = keychain.derive_slip21(_ADDRESS_MAC_KEY_PATH)
|
||||||
|
|
||||||
|
@ -1,45 +1,53 @@
|
|||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
||||||
import storage.cache
|
import storage.cache as storage_cache
|
||||||
from trezor import protobuf
|
from trezor import protobuf
|
||||||
from trezor.enums import MessageType
|
from trezor.enums import MessageType
|
||||||
from trezor.utils import ensure
|
|
||||||
|
|
||||||
WIRE_TYPES: dict[int, tuple[int, ...]] = {
|
WIRE_TYPES: dict[int, tuple[int, ...]] = {
|
||||||
MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof),
|
MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
APP_COMMON_AUTHORIZATION_DATA = (
|
||||||
|
storage_cache.APP_COMMON_AUTHORIZATION_DATA
|
||||||
|
) # global_import_cache
|
||||||
|
APP_COMMON_AUTHORIZATION_TYPE = (
|
||||||
|
storage_cache.APP_COMMON_AUTHORIZATION_TYPE
|
||||||
|
) # global_import_cache
|
||||||
|
|
||||||
|
|
||||||
def is_set() -> bool:
|
def is_set() -> bool:
|
||||||
return bool(storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE))
|
return bool(storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE))
|
||||||
|
|
||||||
|
|
||||||
def set(auth_message: protobuf.MessageType) -> None:
|
def set(auth_message: protobuf.MessageType) -> None:
|
||||||
|
from trezor.utils import ensure
|
||||||
|
|
||||||
buffer = protobuf.dump_message_buffer(auth_message)
|
buffer = protobuf.dump_message_buffer(auth_message)
|
||||||
|
|
||||||
# only wire-level messages can be stored as authorization
|
# only wire-level messages can be stored as authorization
|
||||||
# (because only wire-level messages have wire_type, which we use as identifier)
|
# (because only wire-level messages have wire_type, which we use as identifier)
|
||||||
ensure(auth_message.MESSAGE_WIRE_TYPE is not None)
|
ensure(auth_message.MESSAGE_WIRE_TYPE is not None)
|
||||||
assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too
|
assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too
|
||||||
storage.cache.set(
|
storage_cache.set(
|
||||||
storage.cache.APP_COMMON_AUTHORIZATION_TYPE,
|
APP_COMMON_AUTHORIZATION_TYPE,
|
||||||
auth_message.MESSAGE_WIRE_TYPE.to_bytes(2, "big"),
|
auth_message.MESSAGE_WIRE_TYPE.to_bytes(2, "big"),
|
||||||
)
|
)
|
||||||
storage.cache.set(storage.cache.APP_COMMON_AUTHORIZATION_DATA, buffer)
|
storage_cache.set(APP_COMMON_AUTHORIZATION_DATA, buffer)
|
||||||
|
|
||||||
|
|
||||||
def get() -> protobuf.MessageType | None:
|
def get() -> protobuf.MessageType | None:
|
||||||
stored_auth_type = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
|
stored_auth_type = storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE)
|
||||||
if not stored_auth_type:
|
if not stored_auth_type:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
msg_wire_type = int.from_bytes(stored_auth_type, "big")
|
msg_wire_type = int.from_bytes(stored_auth_type, "big")
|
||||||
buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA, b"")
|
buffer = storage_cache.get(APP_COMMON_AUTHORIZATION_DATA, b"")
|
||||||
return protobuf.load_message_buffer(buffer, msg_wire_type)
|
return protobuf.load_message_buffer(buffer, msg_wire_type)
|
||||||
|
|
||||||
|
|
||||||
def get_wire_types() -> Iterable[int]:
|
def get_wire_types() -> Iterable[int]:
|
||||||
stored_auth_type = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
|
stored_auth_type = storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE)
|
||||||
if stored_auth_type is None:
|
if stored_auth_type is None:
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
@ -48,5 +56,5 @@ def get_wire_types() -> Iterable[int]:
|
|||||||
|
|
||||||
|
|
||||||
def clear() -> None:
|
def clear() -> None:
|
||||||
storage.cache.delete(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
|
storage_cache.delete(APP_COMMON_AUTHORIZATION_TYPE)
|
||||||
storage.cache.delete(storage.cache.APP_COMMON_AUTHORIZATION_DATA)
|
storage_cache.delete(APP_COMMON_AUTHORIZATION_DATA)
|
||||||
|
@ -2,16 +2,14 @@
|
|||||||
Minimalistic CBOR implementation, supports only what we need in cardano.
|
Minimalistic CBOR implementation, supports only what we need in cardano.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import ustruct as struct
|
|
||||||
from micropython import const
|
from micropython import const
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor import log, utils
|
from trezor import log
|
||||||
|
|
||||||
from . import readers
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Any, Generic, Iterator, TypeVar
|
from typing import Any, Generic, Iterator, TypeVar
|
||||||
|
from trezor.utils import BufferReader
|
||||||
|
|
||||||
K = TypeVar("K")
|
K = TypeVar("K")
|
||||||
V = TypeVar("V")
|
V = TypeVar("V")
|
||||||
@ -48,16 +46,18 @@ _CBOR_RAW_TAG = const(0x18)
|
|||||||
|
|
||||||
|
|
||||||
def _header(typ: int, l: int) -> bytes:
|
def _header(typ: int, l: int) -> bytes:
|
||||||
|
from ustruct import pack
|
||||||
|
|
||||||
if l < 24:
|
if l < 24:
|
||||||
return struct.pack(">B", typ + l)
|
return pack(">B", typ + l)
|
||||||
elif l < 2**8:
|
elif l < 2**8:
|
||||||
return struct.pack(">BB", typ + 24, l)
|
return pack(">BB", typ + 24, l)
|
||||||
elif l < 2**16:
|
elif l < 2**16:
|
||||||
return struct.pack(">BH", typ + 25, l)
|
return pack(">BH", typ + 25, l)
|
||||||
elif l < 2**32:
|
elif l < 2**32:
|
||||||
return struct.pack(">BI", typ + 26, l)
|
return pack(">BI", typ + 26, l)
|
||||||
elif l < 2**64:
|
elif l < 2**64:
|
||||||
return struct.pack(">BQ", typ + 27, l)
|
return pack(">BQ", typ + 27, l)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError # Length not supported
|
raise NotImplementedError # Length not supported
|
||||||
|
|
||||||
@ -117,7 +117,9 @@ def _cbor_encode(value: Value) -> Iterator[bytes]:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
def _read_length(r: utils.BufferReader, aux: int) -> int:
|
def _read_length(r: BufferReader, aux: int) -> int:
|
||||||
|
from . import readers
|
||||||
|
|
||||||
if aux < _CBOR_UINT8_FOLLOWS:
|
if aux < _CBOR_UINT8_FOLLOWS:
|
||||||
return aux
|
return aux
|
||||||
elif aux == _CBOR_UINT8_FOLLOWS:
|
elif aux == _CBOR_UINT8_FOLLOWS:
|
||||||
@ -132,7 +134,7 @@ def _read_length(r: utils.BufferReader, aux: int) -> int:
|
|||||||
raise NotImplementedError # Length not supported
|
raise NotImplementedError # Length not supported
|
||||||
|
|
||||||
|
|
||||||
def _cbor_decode(r: utils.BufferReader) -> Value:
|
def _cbor_decode(r: BufferReader) -> Value:
|
||||||
fb = r.get()
|
fb = r.get()
|
||||||
fb_type = fb & _CBOR_TYPE_MASK
|
fb_type = fb & _CBOR_TYPE_MASK
|
||||||
fb_aux = fb & _CBOR_INFO_BITS
|
fb_aux = fb & _CBOR_INFO_BITS
|
||||||
@ -220,6 +222,7 @@ class Tagged:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: this seems to be unused - is checked against, but is never created???
|
||||||
class Raw:
|
class Raw:
|
||||||
def __init__(self, value: Value):
|
def __init__(self, value: Value):
|
||||||
self.value = value
|
self.value = value
|
||||||
@ -272,7 +275,9 @@ def encode_streamed(value: Value) -> Iterator[bytes]:
|
|||||||
|
|
||||||
|
|
||||||
def decode(cbor: bytes, offset: int = 0) -> Value:
|
def decode(cbor: bytes, offset: int = 0) -> Value:
|
||||||
r = utils.BufferReader(cbor)
|
from trezor.utils import BufferReader
|
||||||
|
|
||||||
|
r = BufferReader(cbor)
|
||||||
r.seek(offset)
|
r.seek(offset)
|
||||||
res = _cbor_decode(r)
|
res = _cbor_decode(r)
|
||||||
if r.remaining_count():
|
if r.remaining_count():
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,9 @@
|
|||||||
# generated from coininfo.py.mako
|
# generated from coininfo.py.mako
|
||||||
# (by running `make templates` in `core`)
|
# (by running `make templates` in `core`)
|
||||||
# do not edit manually!
|
# do not edit manually!
|
||||||
|
|
||||||
|
# NOTE: using positional arguments saves 4500 bytes of flash size
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from trezor import utils
|
from trezor import utils
|
||||||
@ -142,7 +145,7 @@ def by_name(name: str) -> CoinInfo:
|
|||||||
if name == ${black_repr(coin["coin_name"])}:
|
if name == ${black_repr(coin["coin_name"])}:
|
||||||
return CoinInfo(
|
return CoinInfo(
|
||||||
% for attr, func in ATTRIBUTES:
|
% for attr, func in ATTRIBUTES:
|
||||||
${attr}=${func(coin[attr])},
|
${func(coin[attr])}, # ${attr}
|
||||||
% endfor
|
% endfor
|
||||||
)
|
)
|
||||||
% endfor
|
% endfor
|
||||||
@ -151,7 +154,7 @@ def by_name(name: str) -> CoinInfo:
|
|||||||
if name == ${black_repr(coin["coin_name"])}:
|
if name == ${black_repr(coin["coin_name"])}:
|
||||||
return CoinInfo(
|
return CoinInfo(
|
||||||
% for attr, func in ATTRIBUTES:
|
% for attr, func in ATTRIBUTES:
|
||||||
${attr}=${func(coin[attr])},
|
${func(coin[attr])}, # ${attr}
|
||||||
% endfor
|
% endfor
|
||||||
)
|
)
|
||||||
% endfor
|
% endfor
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
import sys
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor import wire
|
|
||||||
from trezor.crypto import bip32
|
from trezor.crypto import bip32
|
||||||
|
from trezor.wire import DataError
|
||||||
|
|
||||||
from . import paths, safety_checks
|
from . import paths, safety_checks
|
||||||
from .seed import Slip21Node, get_seed
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -18,6 +16,8 @@ if TYPE_CHECKING:
|
|||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
from trezor.protobuf import MessageType
|
from trezor.protobuf import MessageType
|
||||||
|
from trezor.wire import Context
|
||||||
|
from .seed import Slip21Node
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
@ -36,15 +36,15 @@ if TYPE_CHECKING:
|
|||||||
MsgIn = TypeVar("MsgIn", bound=MessageType)
|
MsgIn = TypeVar("MsgIn", bound=MessageType)
|
||||||
MsgOut = TypeVar("MsgOut", bound=MessageType)
|
MsgOut = TypeVar("MsgOut", bound=MessageType)
|
||||||
|
|
||||||
Handler = Callable[[wire.Context, MsgIn], Awaitable[MsgOut]]
|
Handler = Callable[[Context, MsgIn], Awaitable[MsgOut]]
|
||||||
HandlerWithKeychain = Callable[[wire.Context, MsgIn, "Keychain"], Awaitable[MsgOut]]
|
HandlerWithKeychain = Callable[[Context, MsgIn, "Keychain"], Awaitable[MsgOut]]
|
||||||
|
|
||||||
class Deletable(Protocol):
|
class Deletable(Protocol):
|
||||||
def __del__(self) -> None:
|
def __del__(self) -> None:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
FORBIDDEN_KEY_PATH = wire.DataError("Forbidden key path")
|
FORBIDDEN_KEY_PATH = DataError("Forbidden key path")
|
||||||
|
|
||||||
|
|
||||||
class LRUCache:
|
class LRUCache:
|
||||||
@ -54,13 +54,15 @@ class LRUCache:
|
|||||||
self.cache: dict[Any, Deletable] = {}
|
self.cache: dict[Any, Deletable] = {}
|
||||||
|
|
||||||
def insert(self, key: Any, value: Deletable) -> None:
|
def insert(self, key: Any, value: Deletable) -> None:
|
||||||
if key in self.cache_keys:
|
cache_keys = self.cache_keys # local_cache_attribute
|
||||||
self.cache_keys.remove(key)
|
|
||||||
self.cache_keys.insert(0, key)
|
if key in cache_keys:
|
||||||
|
cache_keys.remove(key)
|
||||||
|
cache_keys.insert(0, key)
|
||||||
self.cache[key] = value
|
self.cache[key] = value
|
||||||
|
|
||||||
if len(self.cache_keys) > self.size:
|
if len(cache_keys) > self.size:
|
||||||
dropped_key = self.cache_keys.pop()
|
dropped_key = cache_keys.pop()
|
||||||
self.cache[dropped_key].__del__()
|
self.cache[dropped_key].__del__()
|
||||||
del self.cache[dropped_key]
|
del self.cache[dropped_key]
|
||||||
|
|
||||||
@ -103,7 +105,7 @@ class Keychain:
|
|||||||
|
|
||||||
def verify_path(self, path: paths.Bip32Path) -> None:
|
def verify_path(self, path: paths.Bip32Path) -> None:
|
||||||
if "ed25519" in self.curve and not paths.path_is_hardened(path):
|
if "ed25519" in self.curve and not paths.path_is_hardened(path):
|
||||||
raise wire.DataError("Non-hardened paths unsupported on Ed25519")
|
raise DataError("Non-hardened paths unsupported on Ed25519")
|
||||||
|
|
||||||
if not safety_checks.is_strict():
|
if not safety_checks.is_strict():
|
||||||
return
|
return
|
||||||
@ -137,8 +139,8 @@ class Keychain:
|
|||||||
if self._root_fingerprint is None:
|
if self._root_fingerprint is None:
|
||||||
# derive m/0' to obtain root_fingerprint
|
# derive m/0' to obtain root_fingerprint
|
||||||
n = self._derive_with_cache(
|
n = self._derive_with_cache(
|
||||||
prefix_len=0,
|
0,
|
||||||
path=[0 | paths.HARDENED],
|
[0 | paths.HARDENED],
|
||||||
new_root=lambda: bip32.from_seed(self.seed, self.curve),
|
new_root=lambda: bip32.from_seed(self.seed, self.curve),
|
||||||
)
|
)
|
||||||
self._root_fingerprint = n.fingerprint()
|
self._root_fingerprint = n.fingerprint()
|
||||||
@ -147,20 +149,22 @@ class Keychain:
|
|||||||
def derive(self, path: paths.Bip32Path) -> bip32.HDNode:
|
def derive(self, path: paths.Bip32Path) -> bip32.HDNode:
|
||||||
self.verify_path(path)
|
self.verify_path(path)
|
||||||
return self._derive_with_cache(
|
return self._derive_with_cache(
|
||||||
prefix_len=3,
|
3,
|
||||||
path=path,
|
path,
|
||||||
new_root=lambda: bip32.from_seed(self.seed, self.curve),
|
new_root=lambda: bip32.from_seed(self.seed, self.curve),
|
||||||
)
|
)
|
||||||
|
|
||||||
def derive_slip21(self, path: paths.Slip21Path) -> Slip21Node:
|
def derive_slip21(self, path: paths.Slip21Path) -> Slip21Node:
|
||||||
|
from .seed import Slip21Node
|
||||||
|
|
||||||
if safety_checks.is_strict() and not any(
|
if safety_checks.is_strict() and not any(
|
||||||
ns == path[: len(ns)] for ns in self.slip21_namespaces
|
ns == path[: len(ns)] for ns in self.slip21_namespaces
|
||||||
):
|
):
|
||||||
raise FORBIDDEN_KEY_PATH
|
raise FORBIDDEN_KEY_PATH
|
||||||
|
|
||||||
return self._derive_with_cache(
|
return self._derive_with_cache(
|
||||||
prefix_len=1,
|
1,
|
||||||
path=path,
|
path,
|
||||||
new_root=lambda: Slip21Node(seed=self.seed),
|
new_root=lambda: Slip21Node(seed=self.seed),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -172,11 +176,13 @@ class Keychain:
|
|||||||
|
|
||||||
|
|
||||||
async def get_keychain(
|
async def get_keychain(
|
||||||
ctx: wire.Context,
|
ctx: Context,
|
||||||
curve: str,
|
curve: str,
|
||||||
schemas: Iterable[paths.PathSchemaType],
|
schemas: Iterable[paths.PathSchemaType],
|
||||||
slip21_namespaces: Iterable[paths.Slip21Path] = (),
|
slip21_namespaces: Iterable[paths.Slip21Path] = (),
|
||||||
) -> Keychain:
|
) -> Keychain:
|
||||||
|
from .seed import get_seed
|
||||||
|
|
||||||
seed = await get_seed(ctx)
|
seed = await get_seed(ctx)
|
||||||
keychain = Keychain(seed, curve, schemas, slip21_namespaces)
|
keychain = Keychain(seed, curve, schemas, slip21_namespaces)
|
||||||
return keychain
|
return keychain
|
||||||
@ -191,18 +197,15 @@ def with_slip44_keychain(
|
|||||||
if not patterns:
|
if not patterns:
|
||||||
raise ValueError # specify a pattern
|
raise ValueError # specify a pattern
|
||||||
|
|
||||||
if allow_testnet:
|
slip_44_ids = (slip44_id, 1) if allow_testnet else slip44_id
|
||||||
slip44_ids: int | tuple[int, int] = (slip44_id, 1)
|
|
||||||
else:
|
|
||||||
slip44_ids = slip44_id
|
|
||||||
|
|
||||||
schemas = []
|
schemas = []
|
||||||
for pattern in patterns:
|
for pattern in patterns:
|
||||||
schemas.append(paths.PathSchema.parse(pattern=pattern, slip44_id=slip44_ids))
|
schemas.append(paths.PathSchema.parse(pattern, slip_44_ids))
|
||||||
schemas = [s.copy() for s in schemas]
|
schemas = [s.copy() for s in schemas]
|
||||||
|
|
||||||
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
||||||
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
|
async def wrapper(ctx: Context, msg: MsgIn) -> MsgOut:
|
||||||
keychain = await get_keychain(ctx, curve, schemas)
|
keychain = await get_keychain(ctx, curve, schemas)
|
||||||
with keychain:
|
with keychain:
|
||||||
return await func(ctx, msg, keychain)
|
return await func(ctx, msg, keychain)
|
||||||
@ -215,6 +218,8 @@ def with_slip44_keychain(
|
|||||||
def auto_keychain(
|
def auto_keychain(
|
||||||
modname: str, allow_testnet: bool = True
|
modname: str, allow_testnet: bool = True
|
||||||
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
|
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
|
||||||
|
import sys
|
||||||
|
|
||||||
rdot = modname.rfind(".")
|
rdot = modname.rfind(".")
|
||||||
parent_modname = modname[:rdot]
|
parent_modname = modname[:rdot]
|
||||||
parent_module = sys.modules[parent_modname]
|
parent_module = sys.modules[parent_modname]
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
import storage.device
|
from typing import TYPE_CHECKING
|
||||||
from trezor import ui, utils, workflow
|
|
||||||
from trezor.enums import BackupType
|
import storage.device as storage_device
|
||||||
|
from trezor import utils
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from trezor.enums import BackupType
|
||||||
|
|
||||||
|
|
||||||
def get() -> tuple[bytes | None, BackupType]:
|
def get() -> tuple[bytes | None, BackupType]:
|
||||||
@ -8,11 +12,11 @@ def get() -> tuple[bytes | None, BackupType]:
|
|||||||
|
|
||||||
|
|
||||||
def get_secret() -> bytes | None:
|
def get_secret() -> bytes | None:
|
||||||
return storage.device.get_mnemonic_secret()
|
return storage_device.get_mnemonic_secret()
|
||||||
|
|
||||||
|
|
||||||
def get_type() -> BackupType:
|
def get_type() -> BackupType:
|
||||||
return storage.device.get_backup_type()
|
return storage_device.get_backup_type()
|
||||||
|
|
||||||
|
|
||||||
def is_bip39() -> bool:
|
def is_bip39() -> bool:
|
||||||
@ -20,6 +24,8 @@ def is_bip39() -> bool:
|
|||||||
If False then SLIP-39 (either Basic or Advanced).
|
If False then SLIP-39 (either Basic or Advanced).
|
||||||
Other invalid values are checked directly in storage.
|
Other invalid values are checked directly in storage.
|
||||||
"""
|
"""
|
||||||
|
from trezor.enums import BackupType
|
||||||
|
|
||||||
return get_type() == BackupType.Bip39
|
return get_type() == BackupType.Bip39
|
||||||
|
|
||||||
|
|
||||||
@ -41,8 +47,8 @@ def get_seed(passphrase: str = "", progress_bar: bool = True) -> bytes:
|
|||||||
else: # SLIP-39
|
else: # SLIP-39
|
||||||
from trezor.crypto import slip39
|
from trezor.crypto import slip39
|
||||||
|
|
||||||
identifier = storage.device.get_slip39_identifier()
|
identifier = storage_device.get_slip39_identifier()
|
||||||
iteration_exponent = storage.device.get_slip39_iteration_exponent()
|
iteration_exponent = storage_device.get_slip39_iteration_exponent()
|
||||||
if identifier is None or iteration_exponent is None:
|
if identifier is None or iteration_exponent is None:
|
||||||
# Identifier or exponent expected but not found
|
# Identifier or exponent expected but not found
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
@ -84,6 +90,7 @@ if not utils.BITCOIN_ONLY:
|
|||||||
|
|
||||||
|
|
||||||
def _start_progress() -> None:
|
def _start_progress() -> None:
|
||||||
|
from trezor import workflow
|
||||||
from trezor.ui.layouts import draw_simple_text
|
from trezor.ui.layouts import draw_simple_text
|
||||||
|
|
||||||
# Because we are drawing to the screen manually, without a layout, we
|
# Because we are drawing to the screen manually, without a layout, we
|
||||||
@ -93,6 +100,8 @@ def _start_progress() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _render_progress(progress: int, total: int) -> None:
|
def _render_progress(progress: int, total: int) -> None:
|
||||||
|
from trezor import ui
|
||||||
|
|
||||||
p = 1000 * progress // total
|
p = 1000 * progress // total
|
||||||
ui.display.loader(p, False, 18, ui.WHITE, ui.BG)
|
ui.display.loader(p, False, 18, ui.WHITE, ui.BG)
|
||||||
ui.refresh()
|
ui.refresh()
|
||||||
|
@ -1,66 +1,72 @@
|
|||||||
from micropython import const
|
from micropython import const
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import storage.device
|
import storage.device as storage_device
|
||||||
from trezor import wire, workflow
|
from trezor.wire import DataError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from trezor.wire import Context
|
||||||
|
|
||||||
_MAX_PASSPHRASE_LEN = const(50)
|
_MAX_PASSPHRASE_LEN = const(50)
|
||||||
|
|
||||||
|
|
||||||
def is_enabled() -> bool:
|
def is_enabled() -> bool:
|
||||||
return storage.device.is_passphrase_enabled()
|
return storage_device.is_passphrase_enabled()
|
||||||
|
|
||||||
|
|
||||||
async def get(ctx: wire.Context) -> str:
|
async def get(ctx: Context) -> str:
|
||||||
if is_enabled():
|
from trezor import workflow
|
||||||
return await _request_from_user(ctx)
|
|
||||||
else:
|
if not is_enabled():
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
async def _request_from_user(ctx: wire.Context) -> str:
|
|
||||||
workflow.close_others() # request exclusive UI access
|
|
||||||
if storage.device.get_passphrase_always_on_device():
|
|
||||||
from trezor.ui.layouts import request_passphrase_on_device
|
|
||||||
|
|
||||||
passphrase = await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN)
|
|
||||||
else:
|
else:
|
||||||
passphrase = await _request_on_host(ctx)
|
workflow.close_others() # request exclusive UI access
|
||||||
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
|
if storage_device.get_passphrase_always_on_device():
|
||||||
raise wire.DataError(
|
from trezor.ui.layouts import request_passphrase_on_device
|
||||||
f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes"
|
|
||||||
)
|
|
||||||
|
|
||||||
return passphrase
|
passphrase = await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN)
|
||||||
|
else:
|
||||||
|
passphrase = await _request_on_host(ctx)
|
||||||
|
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
|
||||||
|
raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes")
|
||||||
|
|
||||||
|
return passphrase
|
||||||
|
|
||||||
|
|
||||||
async def _request_on_host(ctx: wire.Context) -> str:
|
async def _request_on_host(ctx: Context) -> str:
|
||||||
from trezor.messages import PassphraseAck, PassphraseRequest
|
from trezor.messages import PassphraseAck, PassphraseRequest
|
||||||
|
from trezor.ui.layouts import draw_simple_text
|
||||||
|
|
||||||
_entry_dialog()
|
# _entry_dialog
|
||||||
|
draw_simple_text(
|
||||||
|
"Passphrase entry", "Please type your\npassphrase on the\nconnected host."
|
||||||
|
)
|
||||||
|
|
||||||
request = PassphraseRequest()
|
request = PassphraseRequest()
|
||||||
ack = await ctx.call(request, PassphraseAck)
|
ack = await ctx.call(request, PassphraseAck)
|
||||||
|
passphrase = ack.passphrase # local_cache_attribute
|
||||||
|
|
||||||
if ack.on_device:
|
if ack.on_device:
|
||||||
from trezor.ui.layouts import request_passphrase_on_device
|
from trezor.ui.layouts import request_passphrase_on_device
|
||||||
|
|
||||||
if ack.passphrase is not None:
|
if passphrase is not None:
|
||||||
raise wire.DataError("Passphrase provided when it should not be")
|
raise DataError("Passphrase provided when it should not be")
|
||||||
return await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN)
|
return await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN)
|
||||||
|
|
||||||
if ack.passphrase is None:
|
if passphrase is None:
|
||||||
raise wire.DataError(
|
raise DataError(
|
||||||
"Passphrase not provided and on_device is False. Use empty string to set an empty passphrase."
|
"Passphrase not provided and on_device is False. Use empty string to set an empty passphrase."
|
||||||
)
|
)
|
||||||
|
|
||||||
# non-empty passphrase
|
# non-empty passphrase
|
||||||
if ack.passphrase:
|
if passphrase:
|
||||||
from trezor import ui
|
from trezor import ui
|
||||||
from trezor.ui.layouts import confirm_action, confirm_blob
|
from trezor.ui.layouts import confirm_action, confirm_blob
|
||||||
|
|
||||||
await confirm_action(
|
await confirm_action(
|
||||||
ctx,
|
ctx,
|
||||||
"passphrase_host1",
|
"passphrase_host1",
|
||||||
title="Hidden wallet",
|
"Hidden wallet",
|
||||||
description="Access hidden wallet?\n\nNext screen will show\nthe passphrase!",
|
description="Access hidden wallet?\n\nNext screen will show\nthe passphrase!",
|
||||||
icon=ui.ICON_CONFIG,
|
icon=ui.ICON_CONFIG,
|
||||||
)
|
)
|
||||||
@ -68,19 +74,11 @@ async def _request_on_host(ctx: wire.Context) -> str:
|
|||||||
await confirm_blob(
|
await confirm_blob(
|
||||||
ctx,
|
ctx,
|
||||||
"passphrase_host2",
|
"passphrase_host2",
|
||||||
title="Hidden wallet",
|
"Hidden wallet",
|
||||||
description="Use this passphrase?\n",
|
passphrase,
|
||||||
data=ack.passphrase,
|
"Use this passphrase?\n",
|
||||||
icon=ui.ICON_CONFIG,
|
icon=ui.ICON_CONFIG,
|
||||||
icon_color=ui.ORANGE_ICON,
|
icon_color=ui.ORANGE_ICON,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ack.passphrase
|
return passphrase
|
||||||
|
|
||||||
|
|
||||||
def _entry_dialog() -> None:
|
|
||||||
from trezor.ui.layouts import draw_simple_text
|
|
||||||
|
|
||||||
draw_simple_text(
|
|
||||||
"Passphrase entry", "Please type your\npassphrase on the\nconnected host."
|
|
||||||
)
|
|
||||||
|
@ -197,23 +197,24 @@ class PathSchema:
|
|||||||
|
|
||||||
# optionally replace a keyword
|
# optionally replace a keyword
|
||||||
component = cls.REPLACEMENTS.get(component, component)
|
component = cls.REPLACEMENTS.get(component, component)
|
||||||
|
append = schema.append # local_cache_attribute
|
||||||
|
|
||||||
if "-" in component:
|
if "-" in component:
|
||||||
# parse as a range
|
# parse as a range
|
||||||
a, b = [parse(s) for s in component.split("-", 1)]
|
a, b = [parse(s) for s in component.split("-", 1)]
|
||||||
schema.append(Interval(a, b))
|
append(Interval(a, b))
|
||||||
|
|
||||||
elif "," in component:
|
elif "," in component:
|
||||||
# parse as a list of values
|
# parse as a list of values
|
||||||
schema.append(set(parse(s) for s in component.split(",")))
|
append(set(parse(s) for s in component.split(",")))
|
||||||
|
|
||||||
elif component == "coin_type":
|
elif component == "coin_type":
|
||||||
# substitute SLIP-44 ids
|
# substitute SLIP-44 ids
|
||||||
schema.append(set(parse(s) for s in slip44_id))
|
append(set(parse(s) for s in slip44_id))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# plain constant
|
# plain constant
|
||||||
schema.append((parse(component),))
|
append((parse(component),))
|
||||||
|
|
||||||
return cls(schema, trailing_components, compact=True)
|
return cls(schema, trailing_components, compact=True)
|
||||||
|
|
||||||
@ -258,18 +259,19 @@ class PathSchema:
|
|||||||
path. If the restriction results in a never-matching schema, then False
|
path. If the restriction results in a never-matching schema, then False
|
||||||
is returned.
|
is returned.
|
||||||
"""
|
"""
|
||||||
|
schema = self.schema # local_cache_attribute
|
||||||
|
|
||||||
for i, value in enumerate(path):
|
for i, value in enumerate(path):
|
||||||
if i < len(self.schema):
|
if i < len(schema):
|
||||||
# Ensure that the path is a prefix of the schema.
|
# Ensure that the path is a prefix of the schema.
|
||||||
if value not in self.schema[i]:
|
if value not in schema[i]:
|
||||||
self.set_never_matching()
|
self.set_never_matching()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Restrict the schema component if there are multiple choices.
|
# Restrict the schema component if there are multiple choices.
|
||||||
component = self.schema[i]
|
component = schema[i]
|
||||||
if not isinstance(component, tuple) or len(component) != 1:
|
if not isinstance(component, tuple) or len(component) != 1:
|
||||||
self.schema[i] = (value,)
|
schema[i] = (value,)
|
||||||
else:
|
else:
|
||||||
# The path is longer than the schema. We need to restrict the
|
# The path is longer than the schema. We need to restrict the
|
||||||
# trailing components.
|
# trailing components.
|
||||||
@ -278,7 +280,7 @@ class PathSchema:
|
|||||||
self.set_never_matching()
|
self.set_never_matching()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self.schema.append((value,))
|
schema.append((value,))
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -286,6 +288,7 @@ class PathSchema:
|
|||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
components = ["m"]
|
components = ["m"]
|
||||||
|
append = components.append # local_cache_attribute
|
||||||
|
|
||||||
def unharden(item: int) -> int:
|
def unharden(item: int) -> int:
|
||||||
return item ^ (item & HARDENED)
|
return item ^ (item & HARDENED)
|
||||||
@ -294,7 +297,7 @@ class PathSchema:
|
|||||||
if isinstance(component, Interval):
|
if isinstance(component, Interval):
|
||||||
a, b = component.min, component.max
|
a, b = component.min, component.max
|
||||||
prime = "'" if a & HARDENED else ""
|
prime = "'" if a & HARDENED else ""
|
||||||
components.append(f"[{unharden(a)}-{unharden(b)}]{prime}")
|
append(f"[{unharden(a)}-{unharden(b)}]{prime}")
|
||||||
else:
|
else:
|
||||||
# typechecker thinks component is a Contanier but we're using it
|
# typechecker thinks component is a Contanier but we're using it
|
||||||
# as a Collection.
|
# as a Collection.
|
||||||
@ -307,15 +310,15 @@ class PathSchema:
|
|||||||
component_str = "[" + component_str + "]"
|
component_str = "[" + component_str + "]"
|
||||||
if next(iter(collection)) & HARDENED:
|
if next(iter(collection)) & HARDENED:
|
||||||
component_str += "'"
|
component_str += "'"
|
||||||
components.append(component_str)
|
append(component_str)
|
||||||
|
|
||||||
if self.trailing_components:
|
if self.trailing_components:
|
||||||
for key, val in self.WILDCARD_RANGES.items():
|
for key, val in self.WILDCARD_RANGES.items():
|
||||||
if self.trailing_components is val:
|
if self.trailing_components is val:
|
||||||
components.append(key)
|
append(key)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
components.append("???")
|
append("???")
|
||||||
|
|
||||||
return "<schema:" + "/".join(components) + ">"
|
return "<schema:" + "/".join(components) + ">"
|
||||||
|
|
||||||
@ -362,7 +365,7 @@ def path_is_hardened(address_n: Bip32Path) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def address_n_to_str(address_n: Iterable[int]) -> str:
|
def address_n_to_str(address_n: Iterable[int]) -> str:
|
||||||
def path_item(i: int) -> str:
|
def _path_item(i: int) -> str:
|
||||||
if i & HARDENED:
|
if i & HARDENED:
|
||||||
return str(i ^ HARDENED) + "'"
|
return str(i ^ HARDENED) + "'"
|
||||||
else:
|
else:
|
||||||
@ -371,4 +374,4 @@ def address_n_to_str(address_n: Iterable[int]) -> str:
|
|||||||
if not address_n:
|
if not address_n:
|
||||||
return "m"
|
return "m"
|
||||||
|
|
||||||
return "m/" + "/".join(path_item(i) for i in address_n)
|
return "m/" + "/".join(_path_item(i) for i in address_n)
|
||||||
|
@ -1,27 +1,32 @@
|
|||||||
from trezor.utils import BufferReader
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from trezor.utils import BufferReader
|
||||||
|
|
||||||
|
|
||||||
def read_compact_size(r: BufferReader) -> int:
|
def read_compact_size(r: BufferReader) -> int:
|
||||||
prefix = r.get()
|
get = r.get # local_cache_attribute
|
||||||
|
|
||||||
|
prefix = get()
|
||||||
if prefix < 253:
|
if prefix < 253:
|
||||||
n = prefix
|
n = prefix
|
||||||
elif prefix == 253:
|
elif prefix == 253:
|
||||||
n = r.get()
|
n = get()
|
||||||
n += r.get() << 8
|
n += get() << 8
|
||||||
elif prefix == 254:
|
elif prefix == 254:
|
||||||
n = r.get()
|
n = get()
|
||||||
n += r.get() << 8
|
n += get() << 8
|
||||||
n += r.get() << 16
|
n += get() << 16
|
||||||
n += r.get() << 24
|
n += get() << 24
|
||||||
elif prefix == 255:
|
elif prefix == 255:
|
||||||
n = r.get()
|
n = get()
|
||||||
n += r.get() << 8
|
n += get() << 8
|
||||||
n += r.get() << 16
|
n += get() << 16
|
||||||
n += r.get() << 24
|
n += get() << 24
|
||||||
n += r.get() << 32
|
n += get() << 32
|
||||||
n += r.get() << 40
|
n += get() << 40
|
||||||
n += r.get() << 48
|
n += get() << 48
|
||||||
n += r.get() << 56
|
n += get() << 56
|
||||||
else:
|
else:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
return n
|
return n
|
||||||
|
@ -1,20 +1,25 @@
|
|||||||
import utime
|
import utime
|
||||||
from typing import Any, NoReturn
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import storage.cache
|
import storage.cache as storage_cache
|
||||||
import storage.sd_salt
|
|
||||||
from trezor import config, wire
|
from trezor import config, wire
|
||||||
|
|
||||||
from .sdcard import SdCardUnavailable, request_sd_salt
|
from .sdcard import request_sd_salt
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Any, NoReturn
|
||||||
|
from trezor.wire import Context, GenericContext
|
||||||
|
|
||||||
|
|
||||||
def can_lock_device() -> bool:
|
def can_lock_device() -> bool:
|
||||||
"""Return True if the device has a PIN set or SD-protect enabled."""
|
"""Return True if the device has a PIN set or SD-protect enabled."""
|
||||||
|
import storage.sd_salt
|
||||||
|
|
||||||
return config.has_pin() or storage.sd_salt.is_enabled()
|
return config.has_pin() or storage.sd_salt.is_enabled()
|
||||||
|
|
||||||
|
|
||||||
async def request_pin(
|
async def request_pin(
|
||||||
ctx: wire.GenericContext,
|
ctx: GenericContext,
|
||||||
prompt: str = "Enter your PIN",
|
prompt: str = "Enter your PIN",
|
||||||
attempts_remaining: int | None = None,
|
attempts_remaining: int | None = None,
|
||||||
allow_cancel: bool = True,
|
allow_cancel: bool = True,
|
||||||
@ -24,26 +29,26 @@ async def request_pin(
|
|||||||
return await request_pin_on_device(ctx, prompt, attempts_remaining, allow_cancel)
|
return await request_pin_on_device(ctx, prompt, attempts_remaining, allow_cancel)
|
||||||
|
|
||||||
|
|
||||||
async def request_pin_confirm(ctx: wire.Context, *args: Any, **kwargs: Any) -> str:
|
async def request_pin_confirm(ctx: Context, *args: Any, **kwargs: Any) -> str:
|
||||||
while True:
|
while True:
|
||||||
pin1 = await request_pin(ctx, "Enter new PIN", *args, **kwargs)
|
pin1 = await request_pin(ctx, "Enter new PIN", *args, **kwargs)
|
||||||
pin2 = await request_pin(ctx, "Re-enter new PIN", *args, **kwargs)
|
pin2 = await request_pin(ctx, "Re-enter new PIN", *args, **kwargs)
|
||||||
if pin1 == pin2:
|
if pin1 == pin2:
|
||||||
return pin1
|
return pin1
|
||||||
await pin_mismatch()
|
await _pin_mismatch()
|
||||||
|
|
||||||
|
|
||||||
async def pin_mismatch() -> None:
|
async def _pin_mismatch() -> None:
|
||||||
from trezor.ui.layouts import show_popup
|
from trezor.ui.layouts import show_popup
|
||||||
|
|
||||||
await show_popup(
|
await show_popup(
|
||||||
title="PIN mismatch",
|
"PIN mismatch",
|
||||||
description="The PINs you entered\ndo not match.\n\nPlease try again.",
|
"The PINs you entered\ndo not match.\n\nPlease try again.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def request_pin_and_sd_salt(
|
async def request_pin_and_sd_salt(
|
||||||
ctx: wire.Context, prompt: str = "Enter your PIN", allow_cancel: bool = True
|
ctx: Context, prompt: str = "Enter your PIN", allow_cancel: bool = True
|
||||||
) -> tuple[str, bytearray | None]:
|
) -> tuple[str, bytearray | None]:
|
||||||
if config.has_pin():
|
if config.has_pin():
|
||||||
pin = await request_pin(ctx, prompt, config.get_pin_rem(), allow_cancel)
|
pin = await request_pin(ctx, prompt, config.get_pin_rem(), allow_cancel)
|
||||||
@ -58,21 +63,23 @@ async def request_pin_and_sd_salt(
|
|||||||
|
|
||||||
def _set_last_unlock_time() -> None:
|
def _set_last_unlock_time() -> None:
|
||||||
now = utime.ticks_ms()
|
now = utime.ticks_ms()
|
||||||
storage.cache.set_int(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
|
storage_cache.set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
|
||||||
|
|
||||||
|
|
||||||
def _get_last_unlock_time() -> int:
|
|
||||||
return storage.cache.get_int(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK) or 0
|
|
||||||
|
|
||||||
|
|
||||||
async def verify_user_pin(
|
async def verify_user_pin(
|
||||||
ctx: wire.GenericContext = wire.DUMMY_CONTEXT,
|
ctx: GenericContext = wire.DUMMY_CONTEXT,
|
||||||
prompt: str = "Enter your PIN",
|
prompt: str = "Enter your PIN",
|
||||||
allow_cancel: bool = True,
|
allow_cancel: bool = True,
|
||||||
retry: bool = True,
|
retry: bool = True,
|
||||||
cache_time_ms: int = 0,
|
cache_time_ms: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
last_unlock = _get_last_unlock_time()
|
from .sdcard import SdCardUnavailable
|
||||||
|
|
||||||
|
# _get_last_unlock_time
|
||||||
|
last_unlock = int.from_bytes(
|
||||||
|
storage_cache.get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
cache_time_ms
|
cache_time_ms
|
||||||
and last_unlock
|
and last_unlock
|
||||||
@ -112,28 +119,28 @@ async def verify_user_pin(
|
|||||||
raise wire.PinInvalid
|
raise wire.PinInvalid
|
||||||
|
|
||||||
|
|
||||||
async def error_pin_invalid(ctx: wire.Context) -> NoReturn:
|
async def error_pin_invalid(ctx: Context) -> NoReturn:
|
||||||
from trezor.ui.layouts import show_error_and_raise
|
from trezor.ui.layouts import show_error_and_raise
|
||||||
|
|
||||||
await show_error_and_raise(
|
await show_error_and_raise(
|
||||||
ctx,
|
ctx,
|
||||||
"warning_wrong_pin",
|
"warning_wrong_pin",
|
||||||
header="Wrong PIN",
|
"The PIN you entered is invalid.",
|
||||||
content="The PIN you entered is invalid.",
|
"Wrong PIN", # header
|
||||||
red=True,
|
red=True,
|
||||||
exc=wire.PinInvalid,
|
exc=wire.PinInvalid,
|
||||||
)
|
)
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
|
|
||||||
async def error_pin_matches_wipe_code(ctx: wire.Context) -> NoReturn:
|
async def error_pin_matches_wipe_code(ctx: Context) -> NoReturn:
|
||||||
from trezor.ui.layouts import show_error_and_raise
|
from trezor.ui.layouts import show_error_and_raise
|
||||||
|
|
||||||
await show_error_and_raise(
|
await show_error_and_raise(
|
||||||
ctx,
|
ctx,
|
||||||
"warning_invalid_new_pin",
|
"warning_invalid_new_pin",
|
||||||
header="Invalid PIN",
|
"The new PIN must be different from your\nwipe code.",
|
||||||
content="The new PIN must be different from your\nwipe code.",
|
"Invalid PIN", # header
|
||||||
red=True,
|
red=True,
|
||||||
exc=wire.PinInvalid,
|
exc=wire.PinInvalid,
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import storage.cache
|
import storage.cache as storage_cache
|
||||||
import storage.device
|
import storage.device as storage_device
|
||||||
from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
||||||
from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT
|
from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT
|
||||||
from trezor.enums import SafetyCheckLevel
|
from trezor.enums import SafetyCheckLevel
|
||||||
@ -9,11 +9,11 @@ def read_setting() -> SafetyCheckLevel:
|
|||||||
"""
|
"""
|
||||||
Returns the effective safety check level.
|
Returns the effective safety check level.
|
||||||
"""
|
"""
|
||||||
temporary_safety_check_level = storage.cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
temporary_safety_check_level = storage_cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||||
if temporary_safety_check_level:
|
if temporary_safety_check_level:
|
||||||
return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum]
|
return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum]
|
||||||
else:
|
else:
|
||||||
stored = storage.device.safety_check_level()
|
stored = storage_device.safety_check_level()
|
||||||
if stored == SAFETY_CHECK_LEVEL_STRICT:
|
if stored == SAFETY_CHECK_LEVEL_STRICT:
|
||||||
return SafetyCheckLevel.Strict
|
return SafetyCheckLevel.Strict
|
||||||
elif stored == SAFETY_CHECK_LEVEL_PROMPT:
|
elif stored == SAFETY_CHECK_LEVEL_PROMPT:
|
||||||
@ -27,14 +27,14 @@ def apply_setting(level: SafetyCheckLevel) -> None:
|
|||||||
Changes the safety level settings.
|
Changes the safety level settings.
|
||||||
"""
|
"""
|
||||||
if level == SafetyCheckLevel.Strict:
|
if level == SafetyCheckLevel.Strict:
|
||||||
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||||
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
|
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
|
||||||
elif level == SafetyCheckLevel.PromptAlways:
|
elif level == SafetyCheckLevel.PromptAlways:
|
||||||
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||||
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
|
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
|
||||||
elif level == SafetyCheckLevel.PromptTemporarily:
|
elif level == SafetyCheckLevel.PromptTemporarily:
|
||||||
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
|
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
|
||||||
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
|
storage_cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown SafetyCheckLevel")
|
raise ValueError("Unknown SafetyCheckLevel")
|
||||||
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import storage.sd_salt
|
|
||||||
from storage.sd_salt import SD_CARD_HOT_SWAPPABLE
|
from storage.sd_salt import SD_CARD_HOT_SWAPPABLE
|
||||||
from trezor import io, sdcard, ui, wire
|
from trezor import io, ui, wire
|
||||||
from trezor.ui.layouts import confirm_action, show_error_and_raise
|
from trezor.ui.layouts import confirm_action, show_error_and_raise
|
||||||
|
|
||||||
|
|
||||||
@ -14,8 +13,8 @@ async def _confirm_retry_wrong_card(ctx: wire.GenericContext) -> None:
|
|||||||
ctx,
|
ctx,
|
||||||
"warning_wrong_sd",
|
"warning_wrong_sd",
|
||||||
"SD card protection",
|
"SD card protection",
|
||||||
action="Wrong SD card.",
|
"Wrong SD card.",
|
||||||
description="Please insert the correct SD card for this device.",
|
"Please insert the correct SD card for this device.",
|
||||||
verb="Retry",
|
verb="Retry",
|
||||||
verb_cancel="Abort",
|
verb_cancel="Abort",
|
||||||
icon=ui.ICON_WRONG,
|
icon=ui.ICON_WRONG,
|
||||||
@ -26,9 +25,9 @@ async def _confirm_retry_wrong_card(ctx: wire.GenericContext) -> None:
|
|||||||
await show_error_and_raise(
|
await show_error_and_raise(
|
||||||
ctx,
|
ctx,
|
||||||
"warning_wrong_sd",
|
"warning_wrong_sd",
|
||||||
header="SD card protection",
|
"Please unplug the\ndevice and insert the correct SD card.",
|
||||||
subheader="Wrong SD card.",
|
"SD card protection",
|
||||||
content="Please unplug the\ndevice and insert the correct SD card.",
|
"Wrong SD card.",
|
||||||
exc=SdCardUnavailable("Wrong SD card."),
|
exc=SdCardUnavailable("Wrong SD card."),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -39,8 +38,8 @@ async def _confirm_retry_insert_card(ctx: wire.GenericContext) -> None:
|
|||||||
ctx,
|
ctx,
|
||||||
"warning_no_sd",
|
"warning_no_sd",
|
||||||
"SD card protection",
|
"SD card protection",
|
||||||
action="SD card required.",
|
"SD card required.",
|
||||||
description="Please insert your SD card.",
|
"Please insert your SD card.",
|
||||||
verb="Retry",
|
verb="Retry",
|
||||||
verb_cancel="Abort",
|
verb_cancel="Abort",
|
||||||
icon=ui.ICON_WRONG,
|
icon=ui.ICON_WRONG,
|
||||||
@ -51,9 +50,9 @@ async def _confirm_retry_insert_card(ctx: wire.GenericContext) -> None:
|
|||||||
await show_error_and_raise(
|
await show_error_and_raise(
|
||||||
ctx,
|
ctx,
|
||||||
"warning_no_sd",
|
"warning_no_sd",
|
||||||
header="SD card protection",
|
"Please unplug the\ndevice and insert your SD card.",
|
||||||
subheader="SD card required.",
|
"SD card protection",
|
||||||
content="Please unplug the\ndevice and insert your SD card.",
|
"SD card required.",
|
||||||
exc=SdCardUnavailable("SD card required."),
|
exc=SdCardUnavailable("SD card required."),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -64,8 +63,8 @@ async def _confirm_format_card(ctx: wire.GenericContext) -> None:
|
|||||||
ctx,
|
ctx,
|
||||||
"warning_format_sd",
|
"warning_format_sd",
|
||||||
"SD card error",
|
"SD card error",
|
||||||
action="Unknown filesystem.",
|
"Unknown filesystem.",
|
||||||
description="Use a different card or format the SD card to the FAT32 filesystem.",
|
"Use a different card or format the SD card to the FAT32 filesystem.",
|
||||||
icon=ui.ICON_WRONG,
|
icon=ui.ICON_WRONG,
|
||||||
icon_color=ui.RED,
|
icon_color=ui.RED,
|
||||||
verb="Format",
|
verb="Format",
|
||||||
@ -79,8 +78,8 @@ async def _confirm_format_card(ctx: wire.GenericContext) -> None:
|
|||||||
ctx,
|
ctx,
|
||||||
"confirm_format_sd",
|
"confirm_format_sd",
|
||||||
"Format SD card",
|
"Format SD card",
|
||||||
action="All data on the SD card will be lost.",
|
"All data on the SD card will be lost.",
|
||||||
description="Do you really want to format the SD card?",
|
"Do you really want to format the SD card?",
|
||||||
reverse=True,
|
reverse=True,
|
||||||
verb="Format SD card",
|
verb="Format SD card",
|
||||||
icon=ui.ICON_WIPE,
|
icon=ui.ICON_WIPE,
|
||||||
@ -99,8 +98,8 @@ async def confirm_retry_sd(
|
|||||||
ctx,
|
ctx,
|
||||||
"warning_sd_retry",
|
"warning_sd_retry",
|
||||||
"SD card problem",
|
"SD card problem",
|
||||||
action=None,
|
None,
|
||||||
description="There was a problem accessing the SD card.",
|
"There was a problem accessing the SD card.",
|
||||||
icon=ui.ICON_WRONG,
|
icon=ui.ICON_WRONG,
|
||||||
icon_color=ui.RED,
|
icon_color=ui.RED,
|
||||||
verb="Retry",
|
verb="Retry",
|
||||||
@ -121,18 +120,20 @@ async def ensure_sdcard(
|
|||||||
filesystem, and allows the user to format the card if a filesystem cannot be
|
filesystem, and allows the user to format the card if a filesystem cannot be
|
||||||
mounted.
|
mounted.
|
||||||
"""
|
"""
|
||||||
|
from trezor import sdcard
|
||||||
|
|
||||||
while not sdcard.is_present():
|
while not sdcard.is_present():
|
||||||
await _confirm_retry_insert_card(ctx)
|
await _confirm_retry_insert_card(ctx)
|
||||||
|
|
||||||
if not ensure_filesystem:
|
if not ensure_filesystem:
|
||||||
return
|
return
|
||||||
|
fatfs = io.fatfs # local_cache_attribute
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
with sdcard.filesystem(mounted=False):
|
with sdcard.filesystem(mounted=False):
|
||||||
io.fatfs.mount()
|
fatfs.mount()
|
||||||
except io.fatfs.NoFilesystem:
|
except fatfs.NoFilesystem:
|
||||||
# card not formatted. proceed out of the except clause
|
# card not formatted. proceed out of the except clause
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
@ -143,9 +144,9 @@ async def ensure_sdcard(
|
|||||||
|
|
||||||
# Proceed to formatting. Failure is caught by the outside OSError handler
|
# Proceed to formatting. Failure is caught by the outside OSError handler
|
||||||
with sdcard.filesystem(mounted=False):
|
with sdcard.filesystem(mounted=False):
|
||||||
io.fatfs.mkfs()
|
fatfs.mkfs()
|
||||||
io.fatfs.mount()
|
fatfs.mount()
|
||||||
io.fatfs.setlabel("TREZOR")
|
fatfs.setlabel("TREZOR")
|
||||||
|
|
||||||
# format and mount succeeded
|
# format and mount succeeded
|
||||||
return
|
return
|
||||||
@ -158,14 +159,16 @@ async def ensure_sdcard(
|
|||||||
async def request_sd_salt(
|
async def request_sd_salt(
|
||||||
ctx: wire.GenericContext = wire.DUMMY_CONTEXT,
|
ctx: wire.GenericContext = wire.DUMMY_CONTEXT,
|
||||||
) -> bytearray | None:
|
) -> bytearray | None:
|
||||||
if not storage.sd_salt.is_enabled():
|
import storage.sd_salt as storage_sd_salt
|
||||||
|
|
||||||
|
if not storage_sd_salt.is_enabled():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
await ensure_sdcard(ctx, ensure_filesystem=False)
|
await ensure_sdcard(ctx, ensure_filesystem=False)
|
||||||
try:
|
try:
|
||||||
return storage.sd_salt.load_sd_salt()
|
return storage_sd_salt.load_sd_salt()
|
||||||
except (storage.sd_salt.WrongSdCard, io.fatfs.NoFilesystem):
|
except (storage_sd_salt.WrongSdCard, io.fatfs.NoFilesystem):
|
||||||
await _confirm_retry_wrong_card(ctx)
|
await _confirm_retry_wrong_card(ctx)
|
||||||
except OSError:
|
except OSError:
|
||||||
# Generic problem with loading the SD salt (hardware problem, or we could
|
# Generic problem with loading the SD salt (hardware problem, or we could
|
||||||
|
@ -1,14 +1,17 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from storage import cache, device
|
import storage.cache as storage_cache
|
||||||
from trezor import utils, wire
|
import storage.device as storage_device
|
||||||
from trezor.crypto import bip32, hmac
|
from trezor import utils
|
||||||
|
from trezor.crypto import hmac
|
||||||
|
|
||||||
from . import mnemonic
|
from . import mnemonic
|
||||||
from .passphrase import get as get_passphrase
|
from .passphrase import get as get_passphrase
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .paths import Bip32Path, Slip21Path
|
from .paths import Bip32Path, Slip21Path
|
||||||
|
from trezor.wire import Context
|
||||||
|
from trezor.crypto import bip32
|
||||||
|
|
||||||
|
|
||||||
class Slip21Node:
|
class Slip21Node:
|
||||||
@ -47,14 +50,16 @@ if not utils.BITCOIN_ONLY:
|
|||||||
# We want to derive both the normal seed and the Cardano seed together, AND
|
# We want to derive both the normal seed and the Cardano seed together, AND
|
||||||
# expose a method for Cardano to do the same
|
# expose a method for Cardano to do the same
|
||||||
|
|
||||||
async def derive_and_store_roots(ctx: wire.Context) -> None:
|
async def derive_and_store_roots(ctx: Context) -> None:
|
||||||
if not device.is_initialized():
|
from trezor import wire
|
||||||
|
|
||||||
|
if not storage_device.is_initialized():
|
||||||
raise wire.NotInitialized("Device is not initialized")
|
raise wire.NotInitialized("Device is not initialized")
|
||||||
|
|
||||||
need_seed = not cache.is_set(cache.APP_COMMON_SEED)
|
need_seed = not storage_cache.is_set(storage_cache.APP_COMMON_SEED)
|
||||||
need_cardano_secret = cache.get(
|
need_cardano_secret = storage_cache.get(
|
||||||
cache.APP_COMMON_DERIVE_CARDANO
|
storage_cache.APP_COMMON_DERIVE_CARDANO
|
||||||
) and not cache.is_set(cache.APP_CARDANO_ICARUS_SECRET)
|
) and not storage_cache.is_set(storage_cache.APP_CARDANO_ICARUS_SECRET)
|
||||||
|
|
||||||
if not need_seed and not need_cardano_secret:
|
if not need_seed and not need_cardano_secret:
|
||||||
return
|
return
|
||||||
@ -63,17 +68,17 @@ if not utils.BITCOIN_ONLY:
|
|||||||
|
|
||||||
if need_seed:
|
if need_seed:
|
||||||
common_seed = mnemonic.get_seed(passphrase)
|
common_seed = mnemonic.get_seed(passphrase)
|
||||||
cache.set(cache.APP_COMMON_SEED, common_seed)
|
storage_cache.set(storage_cache.APP_COMMON_SEED, common_seed)
|
||||||
|
|
||||||
if need_cardano_secret:
|
if need_cardano_secret:
|
||||||
from apps.cardano.seed import derive_and_store_secrets
|
from apps.cardano.seed import derive_and_store_secrets
|
||||||
|
|
||||||
derive_and_store_secrets(passphrase)
|
derive_and_store_secrets(passphrase)
|
||||||
|
|
||||||
@cache.stored_async(cache.APP_COMMON_SEED)
|
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
|
||||||
async def get_seed(ctx: wire.Context) -> bytes:
|
async def get_seed(ctx: Context) -> bytes:
|
||||||
await derive_and_store_roots(ctx)
|
await derive_and_store_roots(ctx)
|
||||||
common_seed = cache.get(cache.APP_COMMON_SEED)
|
common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED)
|
||||||
assert common_seed is not None
|
assert common_seed is not None
|
||||||
return common_seed
|
return common_seed
|
||||||
|
|
||||||
@ -81,15 +86,15 @@ else:
|
|||||||
# === Bitcoin-only variant ===
|
# === Bitcoin-only variant ===
|
||||||
# We use the simple version of `get_seed` that never needs to derive anything else.
|
# We use the simple version of `get_seed` that never needs to derive anything else.
|
||||||
|
|
||||||
@cache.stored_async(cache.APP_COMMON_SEED)
|
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
|
||||||
async def get_seed(ctx: wire.Context) -> bytes:
|
async def get_seed(ctx: Context) -> bytes:
|
||||||
passphrase = await get_passphrase(ctx)
|
passphrase = await get_passphrase(ctx)
|
||||||
return mnemonic.get_seed(passphrase)
|
return mnemonic.get_seed(passphrase)
|
||||||
|
|
||||||
|
|
||||||
@cache.stored(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
|
@storage_cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
|
||||||
def _get_seed_without_passphrase() -> bytes:
|
def _get_seed_without_passphrase() -> bytes:
|
||||||
if not device.is_initialized():
|
if not storage_device.is_initialized():
|
||||||
raise Exception("Device is not initialized")
|
raise Exception("Device is not initialized")
|
||||||
return mnemonic.get_seed(progress_bar=False)
|
return mnemonic.get_seed(progress_bar=False)
|
||||||
|
|
||||||
@ -97,6 +102,8 @@ def _get_seed_without_passphrase() -> bytes:
|
|||||||
def derive_node_without_passphrase(
|
def derive_node_without_passphrase(
|
||||||
path: Bip32Path, curve_name: str = "secp256k1"
|
path: Bip32Path, curve_name: str = "secp256k1"
|
||||||
) -> bip32.HDNode:
|
) -> bip32.HDNode:
|
||||||
|
from trezor.crypto import bip32
|
||||||
|
|
||||||
seed = _get_seed_without_passphrase()
|
seed = _get_seed_without_passphrase()
|
||||||
node = bip32.from_seed(seed, curve_name)
|
node = bip32.from_seed(seed, curve_name)
|
||||||
node.derive_path(path)
|
node.derive_path(path)
|
||||||
|
@ -1,16 +1,15 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from ubinascii import hexlify
|
|
||||||
|
|
||||||
from trezor import utils, wire
|
|
||||||
from trezor.crypto.hashlib import blake256, sha256
|
|
||||||
|
|
||||||
from apps.common.writers import write_compact_size
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from apps.common.coininfo import CoinInfo
|
from apps.common.coininfo import CoinInfo
|
||||||
|
|
||||||
|
|
||||||
def message_digest(coin: CoinInfo, message: bytes) -> bytes:
|
def message_digest(coin: CoinInfo, message: bytes) -> bytes:
|
||||||
|
from trezor import utils, wire
|
||||||
|
from trezor.crypto.hashlib import blake256, sha256
|
||||||
|
|
||||||
|
from apps.common.writers import write_compact_size
|
||||||
|
|
||||||
if not utils.BITCOIN_ONLY and coin.decred:
|
if not utils.BITCOIN_ONLY and coin.decred:
|
||||||
h = utils.HashWriter(blake256())
|
h = utils.HashWriter(blake256())
|
||||||
else:
|
else:
|
||||||
@ -28,6 +27,8 @@ def message_digest(coin: CoinInfo, message: bytes) -> bytes:
|
|||||||
|
|
||||||
|
|
||||||
def decode_message(message: bytes) -> str:
|
def decode_message(message: bytes) -> str:
|
||||||
|
from ubinascii import hexlify
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return bytes(message).decode()
|
return bytes(message).decode()
|
||||||
except UnicodeError:
|
except UnicodeError:
|
||||||
|
@ -6,61 +6,38 @@ if TYPE_CHECKING:
|
|||||||
from trezor.utils import Writer
|
from trezor.utils import Writer
|
||||||
|
|
||||||
|
|
||||||
|
def _write_uint(w: Writer, n: int, bits: int, bigendian: bool) -> int:
|
||||||
|
ensure(0 <= n <= 2**bits - 1, "overflow")
|
||||||
|
shifts = range(0, bits, 8)
|
||||||
|
if bigendian:
|
||||||
|
shifts = reversed(shifts)
|
||||||
|
for num in shifts:
|
||||||
|
w.append((n >> num) & 0xFF)
|
||||||
|
return bits // 8
|
||||||
|
|
||||||
|
|
||||||
def write_uint8(w: Writer, n: int) -> int:
|
def write_uint8(w: Writer, n: int) -> int:
|
||||||
ensure(0 <= n <= 0xFF)
|
return _write_uint(w, n, 8, False)
|
||||||
w.append(n)
|
|
||||||
return 1
|
|
||||||
|
|
||||||
|
|
||||||
def write_uint16_le(w: Writer, n: int) -> int:
|
def write_uint16_le(w: Writer, n: int) -> int:
|
||||||
ensure(0 <= n <= 0xFFFF)
|
return _write_uint(w, n, 16, False)
|
||||||
w.append(n & 0xFF)
|
|
||||||
w.append((n >> 8) & 0xFF)
|
|
||||||
return 2
|
|
||||||
|
|
||||||
|
|
||||||
def write_uint32_le(w: Writer, n: int) -> int:
|
def write_uint32_le(w: Writer, n: int) -> int:
|
||||||
ensure(0 <= n <= 0xFFFF_FFFF)
|
return _write_uint(w, n, 32, False)
|
||||||
w.append(n & 0xFF)
|
|
||||||
w.append((n >> 8) & 0xFF)
|
|
||||||
w.append((n >> 16) & 0xFF)
|
|
||||||
w.append((n >> 24) & 0xFF)
|
|
||||||
return 4
|
|
||||||
|
|
||||||
|
|
||||||
def write_uint32_be(w: Writer, n: int) -> int:
|
def write_uint32_be(w: Writer, n: int) -> int:
|
||||||
ensure(0 <= n <= 0xFFFF_FFFF)
|
return _write_uint(w, n, 32, True)
|
||||||
w.append((n >> 24) & 0xFF)
|
|
||||||
w.append((n >> 16) & 0xFF)
|
|
||||||
w.append((n >> 8) & 0xFF)
|
|
||||||
w.append(n & 0xFF)
|
|
||||||
return 4
|
|
||||||
|
|
||||||
|
|
||||||
def write_uint64_le(w: Writer, n: int) -> int:
|
def write_uint64_le(w: Writer, n: int) -> int:
|
||||||
ensure(0 <= n <= 0xFFFF_FFFF_FFFF_FFFF)
|
return _write_uint(w, n, 64, False)
|
||||||
w.append(n & 0xFF)
|
|
||||||
w.append((n >> 8) & 0xFF)
|
|
||||||
w.append((n >> 16) & 0xFF)
|
|
||||||
w.append((n >> 24) & 0xFF)
|
|
||||||
w.append((n >> 32) & 0xFF)
|
|
||||||
w.append((n >> 40) & 0xFF)
|
|
||||||
w.append((n >> 48) & 0xFF)
|
|
||||||
w.append((n >> 56) & 0xFF)
|
|
||||||
return 8
|
|
||||||
|
|
||||||
|
|
||||||
def write_uint64_be(w: Writer, n: int) -> int:
|
def write_uint64_be(w: Writer, n: int) -> int:
|
||||||
ensure(0 <= n <= 0xFFFF_FFFF_FFFF_FFFF)
|
return _write_uint(w, n, 64, True)
|
||||||
w.append((n >> 56) & 0xFF)
|
|
||||||
w.append((n >> 48) & 0xFF)
|
|
||||||
w.append((n >> 40) & 0xFF)
|
|
||||||
w.append((n >> 32) & 0xFF)
|
|
||||||
w.append((n >> 24) & 0xFF)
|
|
||||||
w.append((n >> 16) & 0xFF)
|
|
||||||
w.append((n >> 8) & 0xFF)
|
|
||||||
w.append(n & 0xFF)
|
|
||||||
return 8
|
|
||||||
|
|
||||||
|
|
||||||
def write_bytes_unchecked(w: Writer, b: bytes | memoryview) -> int:
|
def write_bytes_unchecked(w: Writer, b: bytes | memoryview) -> int:
|
||||||
@ -82,16 +59,18 @@ def write_bytes_reversed(w: Writer, b: bytes, length: int) -> int:
|
|||||||
|
|
||||||
def write_compact_size(w: Writer, n: int) -> None:
|
def write_compact_size(w: Writer, n: int) -> None:
|
||||||
ensure(0 <= n <= 0xFFFF_FFFF)
|
ensure(0 <= n <= 0xFFFF_FFFF)
|
||||||
|
append = w.append # local_cache_attribute
|
||||||
|
|
||||||
if n < 253:
|
if n < 253:
|
||||||
w.append(n & 0xFF)
|
append(n & 0xFF)
|
||||||
elif n < 0x1_0000:
|
elif n < 0x1_0000:
|
||||||
w.append(253)
|
append(253)
|
||||||
write_uint16_le(w, n)
|
write_uint16_le(w, n)
|
||||||
elif n < 0x1_0000_0000:
|
elif n < 0x1_0000_0000:
|
||||||
w.append(254)
|
append(254)
|
||||||
write_uint32_le(w, n)
|
write_uint32_le(w, n)
|
||||||
else:
|
else:
|
||||||
w.append(255)
|
append(255)
|
||||||
write_uint64_le(w, n)
|
write_uint64_le(w, n)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
import math
|
|
||||||
|
|
||||||
from common import *
|
from common import *
|
||||||
|
|
||||||
from apps.common.cbor import (
|
from apps.common.cbor import (
|
||||||
@ -12,43 +10,8 @@ from apps.common.cbor import (
|
|||||||
decode,
|
decode,
|
||||||
encode,
|
encode,
|
||||||
encode_streamed,
|
encode_streamed,
|
||||||
utils
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE: moved into tests not to occupy flash space
|
|
||||||
# in firmware binary, when it is not used in production
|
|
||||||
def encode_chunked(value: "Value", max_chunk_size: int) -> "Iterator[bytes]":
|
|
||||||
"""
|
|
||||||
Returns the encoded value as an iterable of chunks of a given size,
|
|
||||||
removing the need to reserve a continuous chunk of memory for the
|
|
||||||
full serialized representation of the value.
|
|
||||||
"""
|
|
||||||
if max_chunk_size <= 0:
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
chunks = encode_streamed(value)
|
|
||||||
|
|
||||||
chunk_buffer = utils.empty_bytearray(max_chunk_size)
|
|
||||||
try:
|
|
||||||
current_chunk_view = utils.BufferReader(next(chunks))
|
|
||||||
while True:
|
|
||||||
num_bytes_to_write = min(
|
|
||||||
current_chunk_view.remaining_count(),
|
|
||||||
max_chunk_size - len(chunk_buffer),
|
|
||||||
)
|
|
||||||
chunk_buffer.extend(current_chunk_view.read(num_bytes_to_write))
|
|
||||||
|
|
||||||
if len(chunk_buffer) >= max_chunk_size:
|
|
||||||
yield chunk_buffer
|
|
||||||
chunk_buffer[:] = bytes()
|
|
||||||
|
|
||||||
if current_chunk_view.remaining_count() == 0:
|
|
||||||
current_chunk_view = utils.BufferReader(next(chunks))
|
|
||||||
except StopIteration:
|
|
||||||
if len(chunk_buffer) > 0:
|
|
||||||
yield chunk_buffer
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TestCardanoCbor(unittest.TestCase):
|
class TestCardanoCbor(unittest.TestCase):
|
||||||
def test_create_array_header(self):
|
def test_create_array_header(self):
|
||||||
@ -211,43 +174,6 @@ class TestCardanoCbor(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(b''.join(encoded_streamed), encoded)
|
self.assertEqual(b''.join(encoded_streamed), encoded)
|
||||||
|
|
||||||
def test_encode_chunked(self):
|
|
||||||
large_dict = {i: i for i in range(100)}
|
|
||||||
encoded = encode(large_dict)
|
|
||||||
|
|
||||||
encoded_len = len(encoded)
|
|
||||||
assert encoded_len == 354
|
|
||||||
|
|
||||||
arbitrary_encoded_len_factor = 59
|
|
||||||
arbitrary_power_of_two = 64
|
|
||||||
larger_than_encoded_len = encoded_len + 1
|
|
||||||
|
|
||||||
for max_chunk_size in [
|
|
||||||
1,
|
|
||||||
10,
|
|
||||||
arbitrary_encoded_len_factor,
|
|
||||||
arbitrary_power_of_two,
|
|
||||||
encoded_len,
|
|
||||||
larger_than_encoded_len
|
|
||||||
]:
|
|
||||||
encoded_chunks = [
|
|
||||||
bytes(chunk) for chunk in encode_chunked(large_dict, max_chunk_size)
|
|
||||||
]
|
|
||||||
|
|
||||||
expected_number_of_chunks = math.ceil(len(encoded) / max_chunk_size)
|
|
||||||
self.assertEqual(len(encoded_chunks), expected_number_of_chunks)
|
|
||||||
|
|
||||||
# all chunks except the last should be of chunk_size
|
|
||||||
for i in range(len(encoded_chunks) - 1):
|
|
||||||
self.assertEqual(len(encoded_chunks[i]), max_chunk_size)
|
|
||||||
|
|
||||||
# last chunk should contain the remaining bytes or the whole chunk
|
|
||||||
remaining_bytes = len(encoded) % max_chunk_size
|
|
||||||
expected_last_chunk_size = remaining_bytes if remaining_bytes > 0 else max_chunk_size
|
|
||||||
self.assertEqual(len(encoded_chunks[-1]), expected_last_chunk_size)
|
|
||||||
|
|
||||||
self.assertEqual(b''.join(encoded_chunks), encoded)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
44
core/tests/test_apps.common.writers.py
Normal file
44
core/tests/test_apps.common.writers.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
from common import *
|
||||||
|
|
||||||
|
import apps.common.writers as writers
|
||||||
|
|
||||||
|
|
||||||
|
class TestSeed(unittest.TestCase):
|
||||||
|
def test_write_uint8(self):
|
||||||
|
buf = bytearray()
|
||||||
|
writers.write_uint8(buf, 0x12)
|
||||||
|
self.assertEqual(buf, b"\x12")
|
||||||
|
|
||||||
|
def test_write_uint16_le(self):
|
||||||
|
buf = bytearray()
|
||||||
|
writers.write_uint16_le(buf, 0x1234)
|
||||||
|
self.assertEqual(buf, b"\x34\x12")
|
||||||
|
|
||||||
|
def test_write_uint16_le_overflow(self):
|
||||||
|
buf = bytearray()
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
writers.write_uint16_le(buf, 0x12345678)
|
||||||
|
|
||||||
|
def test_write_uint32_le(self):
|
||||||
|
buf = bytearray()
|
||||||
|
writers.write_uint32_le(buf, 0x12345678)
|
||||||
|
self.assertEqual(buf, b"\x78\x56\x34\x12")
|
||||||
|
|
||||||
|
def test_write_uint64_le(self):
|
||||||
|
buf = bytearray()
|
||||||
|
writers.write_uint64_le(buf, 0x1234567890abcdef)
|
||||||
|
self.assertEqual(buf, b"\xef\xcd\xab\x90\x78\x56\x34\x12")
|
||||||
|
|
||||||
|
def test_write_uint32_be(self):
|
||||||
|
buf = bytearray()
|
||||||
|
writers.write_uint32_be(buf, 0x12345678)
|
||||||
|
self.assertEqual(buf, b"\x12\x34\x56\x78")
|
||||||
|
|
||||||
|
def test_write_uint64_be(self):
|
||||||
|
buf = bytearray()
|
||||||
|
writers.write_uint64_be(buf, 0x1234567890abcdef)
|
||||||
|
self.assertEqual(buf, b"\x12\x34\x56\x78\x90\xab\xcd\xef")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue
Block a user