1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-04-25 11:39:02 +00:00

chore(core): decrease common size by 5200 bytes

This commit is contained in:
grdddj 2022-09-19 11:17:36 +02:00 committed by matejcik
parent 5e7cc8b692
commit 3711fd0f19
18 changed files with 1773 additions and 1764 deletions

View File

@ -1,9 +1,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import utils, wire from trezor import utils
from trezor.crypto import hashlib, hmac
from .writers import write_bytes_unchecked, write_compact_size, write_uint32_le
if TYPE_CHECKING: if TYPE_CHECKING:
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
@ -14,12 +11,18 @@ _ADDRESS_MAC_KEY_PATH = [b"SLIP-0024", b"Address MAC key"]
def check_address_mac( def check_address_mac(
address: str, mac: bytes, slip44: int, keychain: Keychain address: str, mac: bytes, slip44: int, keychain: Keychain
) -> None: ) -> None:
from trezor import wire
from trezor.crypto import hashlib
expected_mac = get_address_mac(address, slip44, keychain) expected_mac = get_address_mac(address, slip44, keychain)
if len(mac) != hashlib.sha256.digest_size or not utils.consteq(expected_mac, mac): if len(mac) != hashlib.sha256.digest_size or not utils.consteq(expected_mac, mac):
raise wire.DataError("Invalid address MAC.") raise wire.DataError("Invalid address MAC.")
def get_address_mac(address: str, slip44: int, keychain: Keychain) -> bytes: def get_address_mac(address: str, slip44: int, keychain: Keychain) -> bytes:
from trezor.crypto import hmac
from .writers import write_bytes_unchecked, write_compact_size, write_uint32_le
# k = Key(m/"SLIP-0024"/"Address MAC key") # k = Key(m/"SLIP-0024"/"Address MAC key")
node = keychain.derive_slip21(_ADDRESS_MAC_KEY_PATH) node = keychain.derive_slip21(_ADDRESS_MAC_KEY_PATH)

View File

@ -1,45 +1,53 @@
from typing import Iterable from typing import Iterable
import storage.cache import storage.cache as storage_cache
from trezor import protobuf from trezor import protobuf
from trezor.enums import MessageType from trezor.enums import MessageType
from trezor.utils import ensure
WIRE_TYPES: dict[int, tuple[int, ...]] = { WIRE_TYPES: dict[int, tuple[int, ...]] = {
MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof), MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof),
} }
APP_COMMON_AUTHORIZATION_DATA = (
storage_cache.APP_COMMON_AUTHORIZATION_DATA
) # global_import_cache
APP_COMMON_AUTHORIZATION_TYPE = (
storage_cache.APP_COMMON_AUTHORIZATION_TYPE
) # global_import_cache
def is_set() -> bool: def is_set() -> bool:
return bool(storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)) return bool(storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE))
def set(auth_message: protobuf.MessageType) -> None: def set(auth_message: protobuf.MessageType) -> None:
from trezor.utils import ensure
buffer = protobuf.dump_message_buffer(auth_message) buffer = protobuf.dump_message_buffer(auth_message)
# only wire-level messages can be stored as authorization # only wire-level messages can be stored as authorization
# (because only wire-level messages have wire_type, which we use as identifier) # (because only wire-level messages have wire_type, which we use as identifier)
ensure(auth_message.MESSAGE_WIRE_TYPE is not None) ensure(auth_message.MESSAGE_WIRE_TYPE is not None)
assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too
storage.cache.set( storage_cache.set(
storage.cache.APP_COMMON_AUTHORIZATION_TYPE, APP_COMMON_AUTHORIZATION_TYPE,
auth_message.MESSAGE_WIRE_TYPE.to_bytes(2, "big"), auth_message.MESSAGE_WIRE_TYPE.to_bytes(2, "big"),
) )
storage.cache.set(storage.cache.APP_COMMON_AUTHORIZATION_DATA, buffer) storage_cache.set(APP_COMMON_AUTHORIZATION_DATA, buffer)
def get() -> protobuf.MessageType | None: def get() -> protobuf.MessageType | None:
stored_auth_type = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE) stored_auth_type = storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE)
if not stored_auth_type: if not stored_auth_type:
return None return None
msg_wire_type = int.from_bytes(stored_auth_type, "big") msg_wire_type = int.from_bytes(stored_auth_type, "big")
buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA, b"") buffer = storage_cache.get(APP_COMMON_AUTHORIZATION_DATA, b"")
return protobuf.load_message_buffer(buffer, msg_wire_type) return protobuf.load_message_buffer(buffer, msg_wire_type)
def get_wire_types() -> Iterable[int]: def get_wire_types() -> Iterable[int]:
stored_auth_type = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE) stored_auth_type = storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE)
if stored_auth_type is None: if stored_auth_type is None:
return () return ()
@ -48,5 +56,5 @@ def get_wire_types() -> Iterable[int]:
def clear() -> None: def clear() -> None:
storage.cache.delete(storage.cache.APP_COMMON_AUTHORIZATION_TYPE) storage_cache.delete(APP_COMMON_AUTHORIZATION_TYPE)
storage.cache.delete(storage.cache.APP_COMMON_AUTHORIZATION_DATA) storage_cache.delete(APP_COMMON_AUTHORIZATION_DATA)

View File

@ -2,16 +2,14 @@
Minimalistic CBOR implementation, supports only what we need in cardano. Minimalistic CBOR implementation, supports only what we need in cardano.
""" """
import ustruct as struct
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import log, utils from trezor import log
from . import readers
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Generic, Iterator, TypeVar from typing import Any, Generic, Iterator, TypeVar
from trezor.utils import BufferReader
K = TypeVar("K") K = TypeVar("K")
V = TypeVar("V") V = TypeVar("V")
@ -48,16 +46,18 @@ _CBOR_RAW_TAG = const(0x18)
def _header(typ: int, l: int) -> bytes: def _header(typ: int, l: int) -> bytes:
from ustruct import pack
if l < 24: if l < 24:
return struct.pack(">B", typ + l) return pack(">B", typ + l)
elif l < 2**8: elif l < 2**8:
return struct.pack(">BB", typ + 24, l) return pack(">BB", typ + 24, l)
elif l < 2**16: elif l < 2**16:
return struct.pack(">BH", typ + 25, l) return pack(">BH", typ + 25, l)
elif l < 2**32: elif l < 2**32:
return struct.pack(">BI", typ + 26, l) return pack(">BI", typ + 26, l)
elif l < 2**64: elif l < 2**64:
return struct.pack(">BQ", typ + 27, l) return pack(">BQ", typ + 27, l)
else: else:
raise NotImplementedError # Length not supported raise NotImplementedError # Length not supported
@ -117,7 +117,9 @@ def _cbor_encode(value: Value) -> Iterator[bytes]:
raise NotImplementedError raise NotImplementedError
def _read_length(r: utils.BufferReader, aux: int) -> int: def _read_length(r: BufferReader, aux: int) -> int:
from . import readers
if aux < _CBOR_UINT8_FOLLOWS: if aux < _CBOR_UINT8_FOLLOWS:
return aux return aux
elif aux == _CBOR_UINT8_FOLLOWS: elif aux == _CBOR_UINT8_FOLLOWS:
@ -132,7 +134,7 @@ def _read_length(r: utils.BufferReader, aux: int) -> int:
raise NotImplementedError # Length not supported raise NotImplementedError # Length not supported
def _cbor_decode(r: utils.BufferReader) -> Value: def _cbor_decode(r: BufferReader) -> Value:
fb = r.get() fb = r.get()
fb_type = fb & _CBOR_TYPE_MASK fb_type = fb & _CBOR_TYPE_MASK
fb_aux = fb & _CBOR_INFO_BITS fb_aux = fb & _CBOR_INFO_BITS
@ -220,6 +222,7 @@ class Tagged:
) )
# TODO: this seems to be unused - is checked against, but is never created???
class Raw: class Raw:
def __init__(self, value: Value): def __init__(self, value: Value):
self.value = value self.value = value
@ -272,7 +275,9 @@ def encode_streamed(value: Value) -> Iterator[bytes]:
def decode(cbor: bytes, offset: int = 0) -> Value: def decode(cbor: bytes, offset: int = 0) -> Value:
r = utils.BufferReader(cbor) from trezor.utils import BufferReader
r = BufferReader(cbor)
r.seek(offset) r.seek(offset)
res = _cbor_decode(r) res = _cbor_decode(r)
if r.remaining_count(): if r.remaining_count():

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,9 @@
# generated from coininfo.py.mako # generated from coininfo.py.mako
# (by running `make templates` in `core`) # (by running `make templates` in `core`)
# do not edit manually! # do not edit manually!
# NOTE: using positional arguments saves 4500 bytes of flash size
from typing import Any from typing import Any
from trezor import utils from trezor import utils
@ -142,7 +145,7 @@ def by_name(name: str) -> CoinInfo:
if name == ${black_repr(coin["coin_name"])}: if name == ${black_repr(coin["coin_name"])}:
return CoinInfo( return CoinInfo(
% for attr, func in ATTRIBUTES: % for attr, func in ATTRIBUTES:
${attr}=${func(coin[attr])}, ${func(coin[attr])}, # ${attr}
% endfor % endfor
) )
% endfor % endfor
@ -151,7 +154,7 @@ def by_name(name: str) -> CoinInfo:
if name == ${black_repr(coin["coin_name"])}: if name == ${black_repr(coin["coin_name"])}:
return CoinInfo( return CoinInfo(
% for attr, func in ATTRIBUTES: % for attr, func in ATTRIBUTES:
${attr}=${func(coin[attr])}, ${func(coin[attr])}, # ${attr}
% endfor % endfor
) )
% endfor % endfor

View File

@ -1,11 +1,9 @@
import sys
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire
from trezor.crypto import bip32 from trezor.crypto import bip32
from trezor.wire import DataError
from . import paths, safety_checks from . import paths, safety_checks
from .seed import Slip21Node, get_seed
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import ( from typing import (
@ -18,6 +16,8 @@ if TYPE_CHECKING:
from typing_extensions import Protocol from typing_extensions import Protocol
from trezor.protobuf import MessageType from trezor.protobuf import MessageType
from trezor.wire import Context
from .seed import Slip21Node
T = TypeVar("T") T = TypeVar("T")
@ -36,15 +36,15 @@ if TYPE_CHECKING:
MsgIn = TypeVar("MsgIn", bound=MessageType) MsgIn = TypeVar("MsgIn", bound=MessageType)
MsgOut = TypeVar("MsgOut", bound=MessageType) MsgOut = TypeVar("MsgOut", bound=MessageType)
Handler = Callable[[wire.Context, MsgIn], Awaitable[MsgOut]] Handler = Callable[[Context, MsgIn], Awaitable[MsgOut]]
HandlerWithKeychain = Callable[[wire.Context, MsgIn, "Keychain"], Awaitable[MsgOut]] HandlerWithKeychain = Callable[[Context, MsgIn, "Keychain"], Awaitable[MsgOut]]
class Deletable(Protocol): class Deletable(Protocol):
def __del__(self) -> None: def __del__(self) -> None:
... ...
FORBIDDEN_KEY_PATH = wire.DataError("Forbidden key path") FORBIDDEN_KEY_PATH = DataError("Forbidden key path")
class LRUCache: class LRUCache:
@ -54,13 +54,15 @@ class LRUCache:
self.cache: dict[Any, Deletable] = {} self.cache: dict[Any, Deletable] = {}
def insert(self, key: Any, value: Deletable) -> None: def insert(self, key: Any, value: Deletable) -> None:
if key in self.cache_keys: cache_keys = self.cache_keys # local_cache_attribute
self.cache_keys.remove(key)
self.cache_keys.insert(0, key) if key in cache_keys:
cache_keys.remove(key)
cache_keys.insert(0, key)
self.cache[key] = value self.cache[key] = value
if len(self.cache_keys) > self.size: if len(cache_keys) > self.size:
dropped_key = self.cache_keys.pop() dropped_key = cache_keys.pop()
self.cache[dropped_key].__del__() self.cache[dropped_key].__del__()
del self.cache[dropped_key] del self.cache[dropped_key]
@ -103,7 +105,7 @@ class Keychain:
def verify_path(self, path: paths.Bip32Path) -> None: def verify_path(self, path: paths.Bip32Path) -> None:
if "ed25519" in self.curve and not paths.path_is_hardened(path): if "ed25519" in self.curve and not paths.path_is_hardened(path):
raise wire.DataError("Non-hardened paths unsupported on Ed25519") raise DataError("Non-hardened paths unsupported on Ed25519")
if not safety_checks.is_strict(): if not safety_checks.is_strict():
return return
@ -137,8 +139,8 @@ class Keychain:
if self._root_fingerprint is None: if self._root_fingerprint is None:
# derive m/0' to obtain root_fingerprint # derive m/0' to obtain root_fingerprint
n = self._derive_with_cache( n = self._derive_with_cache(
prefix_len=0, 0,
path=[0 | paths.HARDENED], [0 | paths.HARDENED],
new_root=lambda: bip32.from_seed(self.seed, self.curve), new_root=lambda: bip32.from_seed(self.seed, self.curve),
) )
self._root_fingerprint = n.fingerprint() self._root_fingerprint = n.fingerprint()
@ -147,20 +149,22 @@ class Keychain:
def derive(self, path: paths.Bip32Path) -> bip32.HDNode: def derive(self, path: paths.Bip32Path) -> bip32.HDNode:
self.verify_path(path) self.verify_path(path)
return self._derive_with_cache( return self._derive_with_cache(
prefix_len=3, 3,
path=path, path,
new_root=lambda: bip32.from_seed(self.seed, self.curve), new_root=lambda: bip32.from_seed(self.seed, self.curve),
) )
def derive_slip21(self, path: paths.Slip21Path) -> Slip21Node: def derive_slip21(self, path: paths.Slip21Path) -> Slip21Node:
from .seed import Slip21Node
if safety_checks.is_strict() and not any( if safety_checks.is_strict() and not any(
ns == path[: len(ns)] for ns in self.slip21_namespaces ns == path[: len(ns)] for ns in self.slip21_namespaces
): ):
raise FORBIDDEN_KEY_PATH raise FORBIDDEN_KEY_PATH
return self._derive_with_cache( return self._derive_with_cache(
prefix_len=1, 1,
path=path, path,
new_root=lambda: Slip21Node(seed=self.seed), new_root=lambda: Slip21Node(seed=self.seed),
) )
@ -172,11 +176,13 @@ class Keychain:
async def get_keychain( async def get_keychain(
ctx: wire.Context, ctx: Context,
curve: str, curve: str,
schemas: Iterable[paths.PathSchemaType], schemas: Iterable[paths.PathSchemaType],
slip21_namespaces: Iterable[paths.Slip21Path] = (), slip21_namespaces: Iterable[paths.Slip21Path] = (),
) -> Keychain: ) -> Keychain:
from .seed import get_seed
seed = await get_seed(ctx) seed = await get_seed(ctx)
keychain = Keychain(seed, curve, schemas, slip21_namespaces) keychain = Keychain(seed, curve, schemas, slip21_namespaces)
return keychain return keychain
@ -191,18 +197,15 @@ def with_slip44_keychain(
if not patterns: if not patterns:
raise ValueError # specify a pattern raise ValueError # specify a pattern
if allow_testnet: slip_44_ids = (slip44_id, 1) if allow_testnet else slip44_id
slip44_ids: int | tuple[int, int] = (slip44_id, 1)
else:
slip44_ids = slip44_id
schemas = [] schemas = []
for pattern in patterns: for pattern in patterns:
schemas.append(paths.PathSchema.parse(pattern=pattern, slip44_id=slip44_ids)) schemas.append(paths.PathSchema.parse(pattern, slip_44_ids))
schemas = [s.copy() for s in schemas] schemas = [s.copy() for s in schemas]
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut: async def wrapper(ctx: Context, msg: MsgIn) -> MsgOut:
keychain = await get_keychain(ctx, curve, schemas) keychain = await get_keychain(ctx, curve, schemas)
with keychain: with keychain:
return await func(ctx, msg, keychain) return await func(ctx, msg, keychain)
@ -215,6 +218,8 @@ def with_slip44_keychain(
def auto_keychain( def auto_keychain(
modname: str, allow_testnet: bool = True modname: str, allow_testnet: bool = True
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]: ) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
import sys
rdot = modname.rfind(".") rdot = modname.rfind(".")
parent_modname = modname[:rdot] parent_modname = modname[:rdot]
parent_module = sys.modules[parent_modname] parent_module = sys.modules[parent_modname]

View File

@ -1,6 +1,10 @@
import storage.device from typing import TYPE_CHECKING
from trezor import ui, utils, workflow
from trezor.enums import BackupType import storage.device as storage_device
from trezor import utils
if TYPE_CHECKING:
from trezor.enums import BackupType
def get() -> tuple[bytes | None, BackupType]: def get() -> tuple[bytes | None, BackupType]:
@ -8,11 +12,11 @@ def get() -> tuple[bytes | None, BackupType]:
def get_secret() -> bytes | None: def get_secret() -> bytes | None:
return storage.device.get_mnemonic_secret() return storage_device.get_mnemonic_secret()
def get_type() -> BackupType: def get_type() -> BackupType:
return storage.device.get_backup_type() return storage_device.get_backup_type()
def is_bip39() -> bool: def is_bip39() -> bool:
@ -20,6 +24,8 @@ def is_bip39() -> bool:
If False then SLIP-39 (either Basic or Advanced). If False then SLIP-39 (either Basic or Advanced).
Other invalid values are checked directly in storage. Other invalid values are checked directly in storage.
""" """
from trezor.enums import BackupType
return get_type() == BackupType.Bip39 return get_type() == BackupType.Bip39
@ -41,8 +47,8 @@ def get_seed(passphrase: str = "", progress_bar: bool = True) -> bytes:
else: # SLIP-39 else: # SLIP-39
from trezor.crypto import slip39 from trezor.crypto import slip39
identifier = storage.device.get_slip39_identifier() identifier = storage_device.get_slip39_identifier()
iteration_exponent = storage.device.get_slip39_iteration_exponent() iteration_exponent = storage_device.get_slip39_iteration_exponent()
if identifier is None or iteration_exponent is None: if identifier is None or iteration_exponent is None:
# Identifier or exponent expected but not found # Identifier or exponent expected but not found
raise RuntimeError raise RuntimeError
@ -84,6 +90,7 @@ if not utils.BITCOIN_ONLY:
def _start_progress() -> None: def _start_progress() -> None:
from trezor import workflow
from trezor.ui.layouts import draw_simple_text from trezor.ui.layouts import draw_simple_text
# Because we are drawing to the screen manually, without a layout, we # Because we are drawing to the screen manually, without a layout, we
@ -93,6 +100,8 @@ def _start_progress() -> None:
def _render_progress(progress: int, total: int) -> None: def _render_progress(progress: int, total: int) -> None:
from trezor import ui
p = 1000 * progress // total p = 1000 * progress // total
ui.display.loader(p, False, 18, ui.WHITE, ui.BG) ui.display.loader(p, False, 18, ui.WHITE, ui.BG)
ui.refresh() ui.refresh()

View File

@ -1,66 +1,72 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING
import storage.device import storage.device as storage_device
from trezor import wire, workflow from trezor.wire import DataError
if TYPE_CHECKING:
from trezor.wire import Context
_MAX_PASSPHRASE_LEN = const(50) _MAX_PASSPHRASE_LEN = const(50)
def is_enabled() -> bool: def is_enabled() -> bool:
return storage.device.is_passphrase_enabled() return storage_device.is_passphrase_enabled()
async def get(ctx: wire.Context) -> str: async def get(ctx: Context) -> str:
if is_enabled(): from trezor import workflow
return await _request_from_user(ctx)
else: if not is_enabled():
return "" return ""
async def _request_from_user(ctx: wire.Context) -> str:
workflow.close_others() # request exclusive UI access
if storage.device.get_passphrase_always_on_device():
from trezor.ui.layouts import request_passphrase_on_device
passphrase = await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN)
else: else:
passphrase = await _request_on_host(ctx) workflow.close_others() # request exclusive UI access
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN: if storage_device.get_passphrase_always_on_device():
raise wire.DataError( from trezor.ui.layouts import request_passphrase_on_device
f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes"
)
return passphrase passphrase = await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN)
else:
passphrase = await _request_on_host(ctx)
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes")
return passphrase
async def _request_on_host(ctx: wire.Context) -> str: async def _request_on_host(ctx: Context) -> str:
from trezor.messages import PassphraseAck, PassphraseRequest from trezor.messages import PassphraseAck, PassphraseRequest
from trezor.ui.layouts import draw_simple_text
_entry_dialog() # _entry_dialog
draw_simple_text(
"Passphrase entry", "Please type your\npassphrase on the\nconnected host."
)
request = PassphraseRequest() request = PassphraseRequest()
ack = await ctx.call(request, PassphraseAck) ack = await ctx.call(request, PassphraseAck)
passphrase = ack.passphrase # local_cache_attribute
if ack.on_device: if ack.on_device:
from trezor.ui.layouts import request_passphrase_on_device from trezor.ui.layouts import request_passphrase_on_device
if ack.passphrase is not None: if passphrase is not None:
raise wire.DataError("Passphrase provided when it should not be") raise DataError("Passphrase provided when it should not be")
return await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN) return await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN)
if ack.passphrase is None: if passphrase is None:
raise wire.DataError( raise DataError(
"Passphrase not provided and on_device is False. Use empty string to set an empty passphrase." "Passphrase not provided and on_device is False. Use empty string to set an empty passphrase."
) )
# non-empty passphrase # non-empty passphrase
if ack.passphrase: if passphrase:
from trezor import ui from trezor import ui
from trezor.ui.layouts import confirm_action, confirm_blob from trezor.ui.layouts import confirm_action, confirm_blob
await confirm_action( await confirm_action(
ctx, ctx,
"passphrase_host1", "passphrase_host1",
title="Hidden wallet", "Hidden wallet",
description="Access hidden wallet?\n\nNext screen will show\nthe passphrase!", description="Access hidden wallet?\n\nNext screen will show\nthe passphrase!",
icon=ui.ICON_CONFIG, icon=ui.ICON_CONFIG,
) )
@ -68,19 +74,11 @@ async def _request_on_host(ctx: wire.Context) -> str:
await confirm_blob( await confirm_blob(
ctx, ctx,
"passphrase_host2", "passphrase_host2",
title="Hidden wallet", "Hidden wallet",
description="Use this passphrase?\n", passphrase,
data=ack.passphrase, "Use this passphrase?\n",
icon=ui.ICON_CONFIG, icon=ui.ICON_CONFIG,
icon_color=ui.ORANGE_ICON, icon_color=ui.ORANGE_ICON,
) )
return ack.passphrase return passphrase
def _entry_dialog() -> None:
from trezor.ui.layouts import draw_simple_text
draw_simple_text(
"Passphrase entry", "Please type your\npassphrase on the\nconnected host."
)

View File

@ -197,23 +197,24 @@ class PathSchema:
# optionally replace a keyword # optionally replace a keyword
component = cls.REPLACEMENTS.get(component, component) component = cls.REPLACEMENTS.get(component, component)
append = schema.append # local_cache_attribute
if "-" in component: if "-" in component:
# parse as a range # parse as a range
a, b = [parse(s) for s in component.split("-", 1)] a, b = [parse(s) for s in component.split("-", 1)]
schema.append(Interval(a, b)) append(Interval(a, b))
elif "," in component: elif "," in component:
# parse as a list of values # parse as a list of values
schema.append(set(parse(s) for s in component.split(","))) append(set(parse(s) for s in component.split(",")))
elif component == "coin_type": elif component == "coin_type":
# substitute SLIP-44 ids # substitute SLIP-44 ids
schema.append(set(parse(s) for s in slip44_id)) append(set(parse(s) for s in slip44_id))
else: else:
# plain constant # plain constant
schema.append((parse(component),)) append((parse(component),))
return cls(schema, trailing_components, compact=True) return cls(schema, trailing_components, compact=True)
@ -258,18 +259,19 @@ class PathSchema:
path. If the restriction results in a never-matching schema, then False path. If the restriction results in a never-matching schema, then False
is returned. is returned.
""" """
schema = self.schema # local_cache_attribute
for i, value in enumerate(path): for i, value in enumerate(path):
if i < len(self.schema): if i < len(schema):
# Ensure that the path is a prefix of the schema. # Ensure that the path is a prefix of the schema.
if value not in self.schema[i]: if value not in schema[i]:
self.set_never_matching() self.set_never_matching()
return False return False
# Restrict the schema component if there are multiple choices. # Restrict the schema component if there are multiple choices.
component = self.schema[i] component = schema[i]
if not isinstance(component, tuple) or len(component) != 1: if not isinstance(component, tuple) or len(component) != 1:
self.schema[i] = (value,) schema[i] = (value,)
else: else:
# The path is longer than the schema. We need to restrict the # The path is longer than the schema. We need to restrict the
# trailing components. # trailing components.
@ -278,7 +280,7 @@ class PathSchema:
self.set_never_matching() self.set_never_matching()
return False return False
self.schema.append((value,)) schema.append((value,))
return True return True
@ -286,6 +288,7 @@ class PathSchema:
def __repr__(self) -> str: def __repr__(self) -> str:
components = ["m"] components = ["m"]
append = components.append # local_cache_attribute
def unharden(item: int) -> int: def unharden(item: int) -> int:
return item ^ (item & HARDENED) return item ^ (item & HARDENED)
@ -294,7 +297,7 @@ class PathSchema:
if isinstance(component, Interval): if isinstance(component, Interval):
a, b = component.min, component.max a, b = component.min, component.max
prime = "'" if a & HARDENED else "" prime = "'" if a & HARDENED else ""
components.append(f"[{unharden(a)}-{unharden(b)}]{prime}") append(f"[{unharden(a)}-{unharden(b)}]{prime}")
else: else:
# typechecker thinks component is a Contanier but we're using it # typechecker thinks component is a Contanier but we're using it
# as a Collection. # as a Collection.
@ -307,15 +310,15 @@ class PathSchema:
component_str = "[" + component_str + "]" component_str = "[" + component_str + "]"
if next(iter(collection)) & HARDENED: if next(iter(collection)) & HARDENED:
component_str += "'" component_str += "'"
components.append(component_str) append(component_str)
if self.trailing_components: if self.trailing_components:
for key, val in self.WILDCARD_RANGES.items(): for key, val in self.WILDCARD_RANGES.items():
if self.trailing_components is val: if self.trailing_components is val:
components.append(key) append(key)
break break
else: else:
components.append("???") append("???")
return "<schema:" + "/".join(components) + ">" return "<schema:" + "/".join(components) + ">"
@ -362,7 +365,7 @@ def path_is_hardened(address_n: Bip32Path) -> bool:
def address_n_to_str(address_n: Iterable[int]) -> str: def address_n_to_str(address_n: Iterable[int]) -> str:
def path_item(i: int) -> str: def _path_item(i: int) -> str:
if i & HARDENED: if i & HARDENED:
return str(i ^ HARDENED) + "'" return str(i ^ HARDENED) + "'"
else: else:
@ -371,4 +374,4 @@ def address_n_to_str(address_n: Iterable[int]) -> str:
if not address_n: if not address_n:
return "m" return "m"
return "m/" + "/".join(path_item(i) for i in address_n) return "m/" + "/".join(_path_item(i) for i in address_n)

View File

@ -1,27 +1,32 @@
from trezor.utils import BufferReader from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.utils import BufferReader
def read_compact_size(r: BufferReader) -> int: def read_compact_size(r: BufferReader) -> int:
prefix = r.get() get = r.get # local_cache_attribute
prefix = get()
if prefix < 253: if prefix < 253:
n = prefix n = prefix
elif prefix == 253: elif prefix == 253:
n = r.get() n = get()
n += r.get() << 8 n += get() << 8
elif prefix == 254: elif prefix == 254:
n = r.get() n = get()
n += r.get() << 8 n += get() << 8
n += r.get() << 16 n += get() << 16
n += r.get() << 24 n += get() << 24
elif prefix == 255: elif prefix == 255:
n = r.get() n = get()
n += r.get() << 8 n += get() << 8
n += r.get() << 16 n += get() << 16
n += r.get() << 24 n += get() << 24
n += r.get() << 32 n += get() << 32
n += r.get() << 40 n += get() << 40
n += r.get() << 48 n += get() << 48
n += r.get() << 56 n += get() << 56
else: else:
raise ValueError raise ValueError
return n return n

View File

@ -1,20 +1,25 @@
import utime import utime
from typing import Any, NoReturn from typing import TYPE_CHECKING
import storage.cache import storage.cache as storage_cache
import storage.sd_salt
from trezor import config, wire from trezor import config, wire
from .sdcard import SdCardUnavailable, request_sd_salt from .sdcard import request_sd_salt
if TYPE_CHECKING:
from typing import Any, NoReturn
from trezor.wire import Context, GenericContext
def can_lock_device() -> bool: def can_lock_device() -> bool:
"""Return True if the device has a PIN set or SD-protect enabled.""" """Return True if the device has a PIN set or SD-protect enabled."""
import storage.sd_salt
return config.has_pin() or storage.sd_salt.is_enabled() return config.has_pin() or storage.sd_salt.is_enabled()
async def request_pin( async def request_pin(
ctx: wire.GenericContext, ctx: GenericContext,
prompt: str = "Enter your PIN", prompt: str = "Enter your PIN",
attempts_remaining: int | None = None, attempts_remaining: int | None = None,
allow_cancel: bool = True, allow_cancel: bool = True,
@ -24,26 +29,26 @@ async def request_pin(
return await request_pin_on_device(ctx, prompt, attempts_remaining, allow_cancel) return await request_pin_on_device(ctx, prompt, attempts_remaining, allow_cancel)
async def request_pin_confirm(ctx: wire.Context, *args: Any, **kwargs: Any) -> str: async def request_pin_confirm(ctx: Context, *args: Any, **kwargs: Any) -> str:
while True: while True:
pin1 = await request_pin(ctx, "Enter new PIN", *args, **kwargs) pin1 = await request_pin(ctx, "Enter new PIN", *args, **kwargs)
pin2 = await request_pin(ctx, "Re-enter new PIN", *args, **kwargs) pin2 = await request_pin(ctx, "Re-enter new PIN", *args, **kwargs)
if pin1 == pin2: if pin1 == pin2:
return pin1 return pin1
await pin_mismatch() await _pin_mismatch()
async def pin_mismatch() -> None: async def _pin_mismatch() -> None:
from trezor.ui.layouts import show_popup from trezor.ui.layouts import show_popup
await show_popup( await show_popup(
title="PIN mismatch", "PIN mismatch",
description="The PINs you entered\ndo not match.\n\nPlease try again.", "The PINs you entered\ndo not match.\n\nPlease try again.",
) )
async def request_pin_and_sd_salt( async def request_pin_and_sd_salt(
ctx: wire.Context, prompt: str = "Enter your PIN", allow_cancel: bool = True ctx: Context, prompt: str = "Enter your PIN", allow_cancel: bool = True
) -> tuple[str, bytearray | None]: ) -> tuple[str, bytearray | None]:
if config.has_pin(): if config.has_pin():
pin = await request_pin(ctx, prompt, config.get_pin_rem(), allow_cancel) pin = await request_pin(ctx, prompt, config.get_pin_rem(), allow_cancel)
@ -58,21 +63,23 @@ async def request_pin_and_sd_salt(
def _set_last_unlock_time() -> None: def _set_last_unlock_time() -> None:
now = utime.ticks_ms() now = utime.ticks_ms()
storage.cache.set_int(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now) storage_cache.set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
def _get_last_unlock_time() -> int:
return storage.cache.get_int(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK) or 0
async def verify_user_pin( async def verify_user_pin(
ctx: wire.GenericContext = wire.DUMMY_CONTEXT, ctx: GenericContext = wire.DUMMY_CONTEXT,
prompt: str = "Enter your PIN", prompt: str = "Enter your PIN",
allow_cancel: bool = True, allow_cancel: bool = True,
retry: bool = True, retry: bool = True,
cache_time_ms: int = 0, cache_time_ms: int = 0,
) -> None: ) -> None:
last_unlock = _get_last_unlock_time() from .sdcard import SdCardUnavailable
# _get_last_unlock_time
last_unlock = int.from_bytes(
storage_cache.get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
)
if ( if (
cache_time_ms cache_time_ms
and last_unlock and last_unlock
@ -112,28 +119,28 @@ async def verify_user_pin(
raise wire.PinInvalid raise wire.PinInvalid
async def error_pin_invalid(ctx: wire.Context) -> NoReturn: async def error_pin_invalid(ctx: Context) -> NoReturn:
from trezor.ui.layouts import show_error_and_raise from trezor.ui.layouts import show_error_and_raise
await show_error_and_raise( await show_error_and_raise(
ctx, ctx,
"warning_wrong_pin", "warning_wrong_pin",
header="Wrong PIN", "The PIN you entered is invalid.",
content="The PIN you entered is invalid.", "Wrong PIN", # header
red=True, red=True,
exc=wire.PinInvalid, exc=wire.PinInvalid,
) )
assert False assert False
async def error_pin_matches_wipe_code(ctx: wire.Context) -> NoReturn: async def error_pin_matches_wipe_code(ctx: Context) -> NoReturn:
from trezor.ui.layouts import show_error_and_raise from trezor.ui.layouts import show_error_and_raise
await show_error_and_raise( await show_error_and_raise(
ctx, ctx,
"warning_invalid_new_pin", "warning_invalid_new_pin",
header="Invalid PIN", "The new PIN must be different from your\nwipe code.",
content="The new PIN must be different from your\nwipe code.", "Invalid PIN", # header
red=True, red=True,
exc=wire.PinInvalid, exc=wire.PinInvalid,
) )

View File

@ -1,5 +1,5 @@
import storage.cache import storage.cache as storage_cache
import storage.device import storage.device as storage_device
from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY
from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT
from trezor.enums import SafetyCheckLevel from trezor.enums import SafetyCheckLevel
@ -9,11 +9,11 @@ def read_setting() -> SafetyCheckLevel:
""" """
Returns the effective safety check level. Returns the effective safety check level.
""" """
temporary_safety_check_level = storage.cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY) temporary_safety_check_level = storage_cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
if temporary_safety_check_level: if temporary_safety_check_level:
return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum] return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum]
else: else:
stored = storage.device.safety_check_level() stored = storage_device.safety_check_level()
if stored == SAFETY_CHECK_LEVEL_STRICT: if stored == SAFETY_CHECK_LEVEL_STRICT:
return SafetyCheckLevel.Strict return SafetyCheckLevel.Strict
elif stored == SAFETY_CHECK_LEVEL_PROMPT: elif stored == SAFETY_CHECK_LEVEL_PROMPT:
@ -27,14 +27,14 @@ def apply_setting(level: SafetyCheckLevel) -> None:
Changes the safety level settings. Changes the safety level settings.
""" """
if level == SafetyCheckLevel.Strict: if level == SafetyCheckLevel.Strict:
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
elif level == SafetyCheckLevel.PromptAlways: elif level == SafetyCheckLevel.PromptAlways:
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT) storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
elif level == SafetyCheckLevel.PromptTemporarily: elif level == SafetyCheckLevel.PromptTemporarily:
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big")) storage_cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
else: else:
raise ValueError("Unknown SafetyCheckLevel") raise ValueError("Unknown SafetyCheckLevel")

View File

@ -1,6 +1,5 @@
import storage.sd_salt
from storage.sd_salt import SD_CARD_HOT_SWAPPABLE from storage.sd_salt import SD_CARD_HOT_SWAPPABLE
from trezor import io, sdcard, ui, wire from trezor import io, ui, wire
from trezor.ui.layouts import confirm_action, show_error_and_raise from trezor.ui.layouts import confirm_action, show_error_and_raise
@ -14,8 +13,8 @@ async def _confirm_retry_wrong_card(ctx: wire.GenericContext) -> None:
ctx, ctx,
"warning_wrong_sd", "warning_wrong_sd",
"SD card protection", "SD card protection",
action="Wrong SD card.", "Wrong SD card.",
description="Please insert the correct SD card for this device.", "Please insert the correct SD card for this device.",
verb="Retry", verb="Retry",
verb_cancel="Abort", verb_cancel="Abort",
icon=ui.ICON_WRONG, icon=ui.ICON_WRONG,
@ -26,9 +25,9 @@ async def _confirm_retry_wrong_card(ctx: wire.GenericContext) -> None:
await show_error_and_raise( await show_error_and_raise(
ctx, ctx,
"warning_wrong_sd", "warning_wrong_sd",
header="SD card protection", "Please unplug the\ndevice and insert the correct SD card.",
subheader="Wrong SD card.", "SD card protection",
content="Please unplug the\ndevice and insert the correct SD card.", "Wrong SD card.",
exc=SdCardUnavailable("Wrong SD card."), exc=SdCardUnavailable("Wrong SD card."),
) )
@ -39,8 +38,8 @@ async def _confirm_retry_insert_card(ctx: wire.GenericContext) -> None:
ctx, ctx,
"warning_no_sd", "warning_no_sd",
"SD card protection", "SD card protection",
action="SD card required.", "SD card required.",
description="Please insert your SD card.", "Please insert your SD card.",
verb="Retry", verb="Retry",
verb_cancel="Abort", verb_cancel="Abort",
icon=ui.ICON_WRONG, icon=ui.ICON_WRONG,
@ -51,9 +50,9 @@ async def _confirm_retry_insert_card(ctx: wire.GenericContext) -> None:
await show_error_and_raise( await show_error_and_raise(
ctx, ctx,
"warning_no_sd", "warning_no_sd",
header="SD card protection", "Please unplug the\ndevice and insert your SD card.",
subheader="SD card required.", "SD card protection",
content="Please unplug the\ndevice and insert your SD card.", "SD card required.",
exc=SdCardUnavailable("SD card required."), exc=SdCardUnavailable("SD card required."),
) )
@ -64,8 +63,8 @@ async def _confirm_format_card(ctx: wire.GenericContext) -> None:
ctx, ctx,
"warning_format_sd", "warning_format_sd",
"SD card error", "SD card error",
action="Unknown filesystem.", "Unknown filesystem.",
description="Use a different card or format the SD card to the FAT32 filesystem.", "Use a different card or format the SD card to the FAT32 filesystem.",
icon=ui.ICON_WRONG, icon=ui.ICON_WRONG,
icon_color=ui.RED, icon_color=ui.RED,
verb="Format", verb="Format",
@ -79,8 +78,8 @@ async def _confirm_format_card(ctx: wire.GenericContext) -> None:
ctx, ctx,
"confirm_format_sd", "confirm_format_sd",
"Format SD card", "Format SD card",
action="All data on the SD card will be lost.", "All data on the SD card will be lost.",
description="Do you really want to format the SD card?", "Do you really want to format the SD card?",
reverse=True, reverse=True,
verb="Format SD card", verb="Format SD card",
icon=ui.ICON_WIPE, icon=ui.ICON_WIPE,
@ -99,8 +98,8 @@ async def confirm_retry_sd(
ctx, ctx,
"warning_sd_retry", "warning_sd_retry",
"SD card problem", "SD card problem",
action=None, None,
description="There was a problem accessing the SD card.", "There was a problem accessing the SD card.",
icon=ui.ICON_WRONG, icon=ui.ICON_WRONG,
icon_color=ui.RED, icon_color=ui.RED,
verb="Retry", verb="Retry",
@ -121,18 +120,20 @@ async def ensure_sdcard(
filesystem, and allows the user to format the card if a filesystem cannot be filesystem, and allows the user to format the card if a filesystem cannot be
mounted. mounted.
""" """
from trezor import sdcard
while not sdcard.is_present(): while not sdcard.is_present():
await _confirm_retry_insert_card(ctx) await _confirm_retry_insert_card(ctx)
if not ensure_filesystem: if not ensure_filesystem:
return return
fatfs = io.fatfs # local_cache_attribute
while True: while True:
try: try:
try: try:
with sdcard.filesystem(mounted=False): with sdcard.filesystem(mounted=False):
io.fatfs.mount() fatfs.mount()
except io.fatfs.NoFilesystem: except fatfs.NoFilesystem:
# card not formatted. proceed out of the except clause # card not formatted. proceed out of the except clause
pass pass
else: else:
@ -143,9 +144,9 @@ async def ensure_sdcard(
# Proceed to formatting. Failure is caught by the outside OSError handler # Proceed to formatting. Failure is caught by the outside OSError handler
with sdcard.filesystem(mounted=False): with sdcard.filesystem(mounted=False):
io.fatfs.mkfs() fatfs.mkfs()
io.fatfs.mount() fatfs.mount()
io.fatfs.setlabel("TREZOR") fatfs.setlabel("TREZOR")
# format and mount succeeded # format and mount succeeded
return return
@ -158,14 +159,16 @@ async def ensure_sdcard(
async def request_sd_salt( async def request_sd_salt(
ctx: wire.GenericContext = wire.DUMMY_CONTEXT, ctx: wire.GenericContext = wire.DUMMY_CONTEXT,
) -> bytearray | None: ) -> bytearray | None:
if not storage.sd_salt.is_enabled(): import storage.sd_salt as storage_sd_salt
if not storage_sd_salt.is_enabled():
return None return None
while True: while True:
await ensure_sdcard(ctx, ensure_filesystem=False) await ensure_sdcard(ctx, ensure_filesystem=False)
try: try:
return storage.sd_salt.load_sd_salt() return storage_sd_salt.load_sd_salt()
except (storage.sd_salt.WrongSdCard, io.fatfs.NoFilesystem): except (storage_sd_salt.WrongSdCard, io.fatfs.NoFilesystem):
await _confirm_retry_wrong_card(ctx) await _confirm_retry_wrong_card(ctx)
except OSError: except OSError:
# Generic problem with loading the SD salt (hardware problem, or we could # Generic problem with loading the SD salt (hardware problem, or we could

View File

@ -1,14 +1,17 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from storage import cache, device import storage.cache as storage_cache
from trezor import utils, wire import storage.device as storage_device
from trezor.crypto import bip32, hmac from trezor import utils
from trezor.crypto import hmac
from . import mnemonic from . import mnemonic
from .passphrase import get as get_passphrase from .passphrase import get as get_passphrase
if TYPE_CHECKING: if TYPE_CHECKING:
from .paths import Bip32Path, Slip21Path from .paths import Bip32Path, Slip21Path
from trezor.wire import Context
from trezor.crypto import bip32
class Slip21Node: class Slip21Node:
@ -47,14 +50,16 @@ if not utils.BITCOIN_ONLY:
# We want to derive both the normal seed and the Cardano seed together, AND # We want to derive both the normal seed and the Cardano seed together, AND
# expose a method for Cardano to do the same # expose a method for Cardano to do the same
async def derive_and_store_roots(ctx: wire.Context) -> None: async def derive_and_store_roots(ctx: Context) -> None:
if not device.is_initialized(): from trezor import wire
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
need_seed = not cache.is_set(cache.APP_COMMON_SEED) need_seed = not storage_cache.is_set(storage_cache.APP_COMMON_SEED)
need_cardano_secret = cache.get( need_cardano_secret = storage_cache.get(
cache.APP_COMMON_DERIVE_CARDANO storage_cache.APP_COMMON_DERIVE_CARDANO
) and not cache.is_set(cache.APP_CARDANO_ICARUS_SECRET) ) and not storage_cache.is_set(storage_cache.APP_CARDANO_ICARUS_SECRET)
if not need_seed and not need_cardano_secret: if not need_seed and not need_cardano_secret:
return return
@ -63,17 +68,17 @@ if not utils.BITCOIN_ONLY:
if need_seed: if need_seed:
common_seed = mnemonic.get_seed(passphrase) common_seed = mnemonic.get_seed(passphrase)
cache.set(cache.APP_COMMON_SEED, common_seed) storage_cache.set(storage_cache.APP_COMMON_SEED, common_seed)
if need_cardano_secret: if need_cardano_secret:
from apps.cardano.seed import derive_and_store_secrets from apps.cardano.seed import derive_and_store_secrets
derive_and_store_secrets(passphrase) derive_and_store_secrets(passphrase)
@cache.stored_async(cache.APP_COMMON_SEED) @storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
async def get_seed(ctx: wire.Context) -> bytes: async def get_seed(ctx: Context) -> bytes:
await derive_and_store_roots(ctx) await derive_and_store_roots(ctx)
common_seed = cache.get(cache.APP_COMMON_SEED) common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED)
assert common_seed is not None assert common_seed is not None
return common_seed return common_seed
@ -81,15 +86,15 @@ else:
# === Bitcoin-only variant === # === Bitcoin-only variant ===
# We use the simple version of `get_seed` that never needs to derive anything else. # We use the simple version of `get_seed` that never needs to derive anything else.
@cache.stored_async(cache.APP_COMMON_SEED) @storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
async def get_seed(ctx: wire.Context) -> bytes: async def get_seed(ctx: Context) -> bytes:
passphrase = await get_passphrase(ctx) passphrase = await get_passphrase(ctx)
return mnemonic.get_seed(passphrase) return mnemonic.get_seed(passphrase)
@cache.stored(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE) @storage_cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
def _get_seed_without_passphrase() -> bytes: def _get_seed_without_passphrase() -> bytes:
if not device.is_initialized(): if not storage_device.is_initialized():
raise Exception("Device is not initialized") raise Exception("Device is not initialized")
return mnemonic.get_seed(progress_bar=False) return mnemonic.get_seed(progress_bar=False)
@ -97,6 +102,8 @@ def _get_seed_without_passphrase() -> bytes:
def derive_node_without_passphrase( def derive_node_without_passphrase(
path: Bip32Path, curve_name: str = "secp256k1" path: Bip32Path, curve_name: str = "secp256k1"
) -> bip32.HDNode: ) -> bip32.HDNode:
from trezor.crypto import bip32
seed = _get_seed_without_passphrase() seed = _get_seed_without_passphrase()
node = bip32.from_seed(seed, curve_name) node = bip32.from_seed(seed, curve_name)
node.derive_path(path) node.derive_path(path)

View File

@ -1,16 +1,15 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ubinascii import hexlify
from trezor import utils, wire
from trezor.crypto.hashlib import blake256, sha256
from apps.common.writers import write_compact_size
if TYPE_CHECKING: if TYPE_CHECKING:
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
def message_digest(coin: CoinInfo, message: bytes) -> bytes: def message_digest(coin: CoinInfo, message: bytes) -> bytes:
from trezor import utils, wire
from trezor.crypto.hashlib import blake256, sha256
from apps.common.writers import write_compact_size
if not utils.BITCOIN_ONLY and coin.decred: if not utils.BITCOIN_ONLY and coin.decred:
h = utils.HashWriter(blake256()) h = utils.HashWriter(blake256())
else: else:
@ -28,6 +27,8 @@ def message_digest(coin: CoinInfo, message: bytes) -> bytes:
def decode_message(message: bytes) -> str: def decode_message(message: bytes) -> str:
from ubinascii import hexlify
try: try:
return bytes(message).decode() return bytes(message).decode()
except UnicodeError: except UnicodeError:

View File

@ -6,61 +6,38 @@ if TYPE_CHECKING:
from trezor.utils import Writer from trezor.utils import Writer
def _write_uint(w: Writer, n: int, bits: int, bigendian: bool) -> int:
ensure(0 <= n <= 2**bits - 1, "overflow")
shifts = range(0, bits, 8)
if bigendian:
shifts = reversed(shifts)
for num in shifts:
w.append((n >> num) & 0xFF)
return bits // 8
def write_uint8(w: Writer, n: int) -> int: def write_uint8(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFF) return _write_uint(w, n, 8, False)
w.append(n)
return 1
def write_uint16_le(w: Writer, n: int) -> int: def write_uint16_le(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFF) return _write_uint(w, n, 16, False)
w.append(n & 0xFF)
w.append((n >> 8) & 0xFF)
return 2
def write_uint32_le(w: Writer, n: int) -> int: def write_uint32_le(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFF_FFFF) return _write_uint(w, n, 32, False)
w.append(n & 0xFF)
w.append((n >> 8) & 0xFF)
w.append((n >> 16) & 0xFF)
w.append((n >> 24) & 0xFF)
return 4
def write_uint32_be(w: Writer, n: int) -> int: def write_uint32_be(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFF_FFFF) return _write_uint(w, n, 32, True)
w.append((n >> 24) & 0xFF)
w.append((n >> 16) & 0xFF)
w.append((n >> 8) & 0xFF)
w.append(n & 0xFF)
return 4
def write_uint64_le(w: Writer, n: int) -> int: def write_uint64_le(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFF_FFFF_FFFF_FFFF) return _write_uint(w, n, 64, False)
w.append(n & 0xFF)
w.append((n >> 8) & 0xFF)
w.append((n >> 16) & 0xFF)
w.append((n >> 24) & 0xFF)
w.append((n >> 32) & 0xFF)
w.append((n >> 40) & 0xFF)
w.append((n >> 48) & 0xFF)
w.append((n >> 56) & 0xFF)
return 8
def write_uint64_be(w: Writer, n: int) -> int: def write_uint64_be(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFF_FFFF_FFFF_FFFF) return _write_uint(w, n, 64, True)
w.append((n >> 56) & 0xFF)
w.append((n >> 48) & 0xFF)
w.append((n >> 40) & 0xFF)
w.append((n >> 32) & 0xFF)
w.append((n >> 24) & 0xFF)
w.append((n >> 16) & 0xFF)
w.append((n >> 8) & 0xFF)
w.append(n & 0xFF)
return 8
def write_bytes_unchecked(w: Writer, b: bytes | memoryview) -> int: def write_bytes_unchecked(w: Writer, b: bytes | memoryview) -> int:
@ -82,16 +59,18 @@ def write_bytes_reversed(w: Writer, b: bytes, length: int) -> int:
def write_compact_size(w: Writer, n: int) -> None: def write_compact_size(w: Writer, n: int) -> None:
ensure(0 <= n <= 0xFFFF_FFFF) ensure(0 <= n <= 0xFFFF_FFFF)
append = w.append # local_cache_attribute
if n < 253: if n < 253:
w.append(n & 0xFF) append(n & 0xFF)
elif n < 0x1_0000: elif n < 0x1_0000:
w.append(253) append(253)
write_uint16_le(w, n) write_uint16_le(w, n)
elif n < 0x1_0000_0000: elif n < 0x1_0000_0000:
w.append(254) append(254)
write_uint32_le(w, n) write_uint32_le(w, n)
else: else:
w.append(255) append(255)
write_uint64_le(w, n) write_uint64_le(w, n)

View File

@ -1,5 +1,3 @@
import math
from common import * from common import *
from apps.common.cbor import ( from apps.common.cbor import (
@ -12,43 +10,8 @@ from apps.common.cbor import (
decode, decode,
encode, encode,
encode_streamed, encode_streamed,
utils
) )
# NOTE: moved into tests not to occupy flash space
# in firmware binary, when it is not used in production
def encode_chunked(value: "Value", max_chunk_size: int) -> "Iterator[bytes]":
"""
Returns the encoded value as an iterable of chunks of a given size,
removing the need to reserve a continuous chunk of memory for the
full serialized representation of the value.
"""
if max_chunk_size <= 0:
raise ValueError
chunks = encode_streamed(value)
chunk_buffer = utils.empty_bytearray(max_chunk_size)
try:
current_chunk_view = utils.BufferReader(next(chunks))
while True:
num_bytes_to_write = min(
current_chunk_view.remaining_count(),
max_chunk_size - len(chunk_buffer),
)
chunk_buffer.extend(current_chunk_view.read(num_bytes_to_write))
if len(chunk_buffer) >= max_chunk_size:
yield chunk_buffer
chunk_buffer[:] = bytes()
if current_chunk_view.remaining_count() == 0:
current_chunk_view = utils.BufferReader(next(chunks))
except StopIteration:
if len(chunk_buffer) > 0:
yield chunk_buffer
class TestCardanoCbor(unittest.TestCase): class TestCardanoCbor(unittest.TestCase):
def test_create_array_header(self): def test_create_array_header(self):
@ -211,43 +174,6 @@ class TestCardanoCbor(unittest.TestCase):
self.assertEqual(b''.join(encoded_streamed), encoded) self.assertEqual(b''.join(encoded_streamed), encoded)
def test_encode_chunked(self):
large_dict = {i: i for i in range(100)}
encoded = encode(large_dict)
encoded_len = len(encoded)
assert encoded_len == 354
arbitrary_encoded_len_factor = 59
arbitrary_power_of_two = 64
larger_than_encoded_len = encoded_len + 1
for max_chunk_size in [
1,
10,
arbitrary_encoded_len_factor,
arbitrary_power_of_two,
encoded_len,
larger_than_encoded_len
]:
encoded_chunks = [
bytes(chunk) for chunk in encode_chunked(large_dict, max_chunk_size)
]
expected_number_of_chunks = math.ceil(len(encoded) / max_chunk_size)
self.assertEqual(len(encoded_chunks), expected_number_of_chunks)
# all chunks except the last should be of chunk_size
for i in range(len(encoded_chunks) - 1):
self.assertEqual(len(encoded_chunks[i]), max_chunk_size)
# last chunk should contain the remaining bytes or the whole chunk
remaining_bytes = len(encoded) % max_chunk_size
expected_last_chunk_size = remaining_bytes if remaining_bytes > 0 else max_chunk_size
self.assertEqual(len(encoded_chunks[-1]), expected_last_chunk_size)
self.assertEqual(b''.join(encoded_chunks), encoded)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -0,0 +1,44 @@
from common import *
import apps.common.writers as writers
class TestSeed(unittest.TestCase):
def test_write_uint8(self):
buf = bytearray()
writers.write_uint8(buf, 0x12)
self.assertEqual(buf, b"\x12")
def test_write_uint16_le(self):
buf = bytearray()
writers.write_uint16_le(buf, 0x1234)
self.assertEqual(buf, b"\x34\x12")
def test_write_uint16_le_overflow(self):
buf = bytearray()
with self.assertRaises(AssertionError):
writers.write_uint16_le(buf, 0x12345678)
def test_write_uint32_le(self):
buf = bytearray()
writers.write_uint32_le(buf, 0x12345678)
self.assertEqual(buf, b"\x78\x56\x34\x12")
def test_write_uint64_le(self):
buf = bytearray()
writers.write_uint64_le(buf, 0x1234567890abcdef)
self.assertEqual(buf, b"\xef\xcd\xab\x90\x78\x56\x34\x12")
def test_write_uint32_be(self):
buf = bytearray()
writers.write_uint32_be(buf, 0x12345678)
self.assertEqual(buf, b"\x12\x34\x56\x78")
def test_write_uint64_be(self):
buf = bytearray()
writers.write_uint64_be(buf, 0x1234567890abcdef)
self.assertEqual(buf, b"\x12\x34\x56\x78\x90\xab\xcd\xef")
if __name__ == "__main__":
unittest.main()