chore(core): decrease common size by 5200 bytes

pull/2633/head
grdddj 2 years ago committed by matejcik
parent 5e7cc8b692
commit 3711fd0f19

@ -1,9 +1,6 @@
from typing import TYPE_CHECKING
from trezor import utils, wire
from trezor.crypto import hashlib, hmac
from .writers import write_bytes_unchecked, write_compact_size, write_uint32_le
from trezor import utils
if TYPE_CHECKING:
from apps.common.keychain import Keychain
@ -14,12 +11,18 @@ _ADDRESS_MAC_KEY_PATH = [b"SLIP-0024", b"Address MAC key"]
def check_address_mac(
address: str, mac: bytes, slip44: int, keychain: Keychain
) -> None:
from trezor import wire
from trezor.crypto import hashlib
expected_mac = get_address_mac(address, slip44, keychain)
if len(mac) != hashlib.sha256.digest_size or not utils.consteq(expected_mac, mac):
raise wire.DataError("Invalid address MAC.")
def get_address_mac(address: str, slip44: int, keychain: Keychain) -> bytes:
from trezor.crypto import hmac
from .writers import write_bytes_unchecked, write_compact_size, write_uint32_le
# k = Key(m/"SLIP-0024"/"Address MAC key")
node = keychain.derive_slip21(_ADDRESS_MAC_KEY_PATH)

@ -1,45 +1,53 @@
from typing import Iterable
import storage.cache
import storage.cache as storage_cache
from trezor import protobuf
from trezor.enums import MessageType
from trezor.utils import ensure
WIRE_TYPES: dict[int, tuple[int, ...]] = {
MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof),
}
APP_COMMON_AUTHORIZATION_DATA = (
storage_cache.APP_COMMON_AUTHORIZATION_DATA
) # global_import_cache
APP_COMMON_AUTHORIZATION_TYPE = (
storage_cache.APP_COMMON_AUTHORIZATION_TYPE
) # global_import_cache
def is_set() -> bool:
return bool(storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE))
return bool(storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE))
def set(auth_message: protobuf.MessageType) -> None:
from trezor.utils import ensure
buffer = protobuf.dump_message_buffer(auth_message)
# only wire-level messages can be stored as authorization
# (because only wire-level messages have wire_type, which we use as identifier)
ensure(auth_message.MESSAGE_WIRE_TYPE is not None)
assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too
storage.cache.set(
storage.cache.APP_COMMON_AUTHORIZATION_TYPE,
storage_cache.set(
APP_COMMON_AUTHORIZATION_TYPE,
auth_message.MESSAGE_WIRE_TYPE.to_bytes(2, "big"),
)
storage.cache.set(storage.cache.APP_COMMON_AUTHORIZATION_DATA, buffer)
storage_cache.set(APP_COMMON_AUTHORIZATION_DATA, buffer)
def get() -> protobuf.MessageType | None:
stored_auth_type = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
stored_auth_type = storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE)
if not stored_auth_type:
return None
msg_wire_type = int.from_bytes(stored_auth_type, "big")
buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA, b"")
buffer = storage_cache.get(APP_COMMON_AUTHORIZATION_DATA, b"")
return protobuf.load_message_buffer(buffer, msg_wire_type)
def get_wire_types() -> Iterable[int]:
stored_auth_type = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
stored_auth_type = storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE)
if stored_auth_type is None:
return ()
@ -48,5 +56,5 @@ def get_wire_types() -> Iterable[int]:
def clear() -> None:
storage.cache.delete(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
storage.cache.delete(storage.cache.APP_COMMON_AUTHORIZATION_DATA)
storage_cache.delete(APP_COMMON_AUTHORIZATION_TYPE)
storage_cache.delete(APP_COMMON_AUTHORIZATION_DATA)

@ -2,16 +2,14 @@
Minimalistic CBOR implementation, supports only what we need in cardano.
"""
import ustruct as struct
from micropython import const
from typing import TYPE_CHECKING
from trezor import log, utils
from . import readers
from trezor import log
if TYPE_CHECKING:
from typing import Any, Generic, Iterator, TypeVar
from trezor.utils import BufferReader
K = TypeVar("K")
V = TypeVar("V")
@ -48,16 +46,18 @@ _CBOR_RAW_TAG = const(0x18)
def _header(typ: int, l: int) -> bytes:
from ustruct import pack
if l < 24:
return struct.pack(">B", typ + l)
return pack(">B", typ + l)
elif l < 2**8:
return struct.pack(">BB", typ + 24, l)
return pack(">BB", typ + 24, l)
elif l < 2**16:
return struct.pack(">BH", typ + 25, l)
return pack(">BH", typ + 25, l)
elif l < 2**32:
return struct.pack(">BI", typ + 26, l)
return pack(">BI", typ + 26, l)
elif l < 2**64:
return struct.pack(">BQ", typ + 27, l)
return pack(">BQ", typ + 27, l)
else:
raise NotImplementedError # Length not supported
@ -117,7 +117,9 @@ def _cbor_encode(value: Value) -> Iterator[bytes]:
raise NotImplementedError
def _read_length(r: utils.BufferReader, aux: int) -> int:
def _read_length(r: BufferReader, aux: int) -> int:
from . import readers
if aux < _CBOR_UINT8_FOLLOWS:
return aux
elif aux == _CBOR_UINT8_FOLLOWS:
@ -132,7 +134,7 @@ def _read_length(r: utils.BufferReader, aux: int) -> int:
raise NotImplementedError # Length not supported
def _cbor_decode(r: utils.BufferReader) -> Value:
def _cbor_decode(r: BufferReader) -> Value:
fb = r.get()
fb_type = fb & _CBOR_TYPE_MASK
fb_aux = fb & _CBOR_INFO_BITS
@ -220,6 +222,7 @@ class Tagged:
)
# TODO: this seems to be unused - is checked against, but is never created???
class Raw:
def __init__(self, value: Value):
self.value = value
@ -272,7 +275,9 @@ def encode_streamed(value: Value) -> Iterator[bytes]:
def decode(cbor: bytes, offset: int = 0) -> Value:
r = utils.BufferReader(cbor)
from trezor.utils import BufferReader
r = BufferReader(cbor)
r.seek(offset)
res = _cbor_decode(r)
if r.remaining_count():

File diff suppressed because it is too large Load Diff

@ -1,6 +1,9 @@
# generated from coininfo.py.mako
# (by running `make templates` in `core`)
# do not edit manually!
# NOTE: using positional arguments saves 4500 bytes of flash size
from typing import Any
from trezor import utils
@ -142,7 +145,7 @@ def by_name(name: str) -> CoinInfo:
if name == ${black_repr(coin["coin_name"])}:
return CoinInfo(
% for attr, func in ATTRIBUTES:
${attr}=${func(coin[attr])},
${func(coin[attr])}, # ${attr}
% endfor
)
% endfor
@ -151,7 +154,7 @@ def by_name(name: str) -> CoinInfo:
if name == ${black_repr(coin["coin_name"])}:
return CoinInfo(
% for attr, func in ATTRIBUTES:
${attr}=${func(coin[attr])},
${func(coin[attr])}, # ${attr}
% endfor
)
% endfor

@ -1,11 +1,9 @@
import sys
from typing import TYPE_CHECKING
from trezor import wire
from trezor.crypto import bip32
from trezor.wire import DataError
from . import paths, safety_checks
from .seed import Slip21Node, get_seed
if TYPE_CHECKING:
from typing import (
@ -18,6 +16,8 @@ if TYPE_CHECKING:
from typing_extensions import Protocol
from trezor.protobuf import MessageType
from trezor.wire import Context
from .seed import Slip21Node
T = TypeVar("T")
@ -36,15 +36,15 @@ if TYPE_CHECKING:
MsgIn = TypeVar("MsgIn", bound=MessageType)
MsgOut = TypeVar("MsgOut", bound=MessageType)
Handler = Callable[[wire.Context, MsgIn], Awaitable[MsgOut]]
HandlerWithKeychain = Callable[[wire.Context, MsgIn, "Keychain"], Awaitable[MsgOut]]
Handler = Callable[[Context, MsgIn], Awaitable[MsgOut]]
HandlerWithKeychain = Callable[[Context, MsgIn, "Keychain"], Awaitable[MsgOut]]
class Deletable(Protocol):
def __del__(self) -> None:
...
FORBIDDEN_KEY_PATH = wire.DataError("Forbidden key path")
FORBIDDEN_KEY_PATH = DataError("Forbidden key path")
class LRUCache:
@ -54,13 +54,15 @@ class LRUCache:
self.cache: dict[Any, Deletable] = {}
def insert(self, key: Any, value: Deletable) -> None:
if key in self.cache_keys:
self.cache_keys.remove(key)
self.cache_keys.insert(0, key)
cache_keys = self.cache_keys # local_cache_attribute
if key in cache_keys:
cache_keys.remove(key)
cache_keys.insert(0, key)
self.cache[key] = value
if len(self.cache_keys) > self.size:
dropped_key = self.cache_keys.pop()
if len(cache_keys) > self.size:
dropped_key = cache_keys.pop()
self.cache[dropped_key].__del__()
del self.cache[dropped_key]
@ -103,7 +105,7 @@ class Keychain:
def verify_path(self, path: paths.Bip32Path) -> None:
if "ed25519" in self.curve and not paths.path_is_hardened(path):
raise wire.DataError("Non-hardened paths unsupported on Ed25519")
raise DataError("Non-hardened paths unsupported on Ed25519")
if not safety_checks.is_strict():
return
@ -137,8 +139,8 @@ class Keychain:
if self._root_fingerprint is None:
# derive m/0' to obtain root_fingerprint
n = self._derive_with_cache(
prefix_len=0,
path=[0 | paths.HARDENED],
0,
[0 | paths.HARDENED],
new_root=lambda: bip32.from_seed(self.seed, self.curve),
)
self._root_fingerprint = n.fingerprint()
@ -147,20 +149,22 @@ class Keychain:
def derive(self, path: paths.Bip32Path) -> bip32.HDNode:
self.verify_path(path)
return self._derive_with_cache(
prefix_len=3,
path=path,
3,
path,
new_root=lambda: bip32.from_seed(self.seed, self.curve),
)
def derive_slip21(self, path: paths.Slip21Path) -> Slip21Node:
from .seed import Slip21Node
if safety_checks.is_strict() and not any(
ns == path[: len(ns)] for ns in self.slip21_namespaces
):
raise FORBIDDEN_KEY_PATH
return self._derive_with_cache(
prefix_len=1,
path=path,
1,
path,
new_root=lambda: Slip21Node(seed=self.seed),
)
@ -172,11 +176,13 @@ class Keychain:
async def get_keychain(
ctx: wire.Context,
ctx: Context,
curve: str,
schemas: Iterable[paths.PathSchemaType],
slip21_namespaces: Iterable[paths.Slip21Path] = (),
) -> Keychain:
from .seed import get_seed
seed = await get_seed(ctx)
keychain = Keychain(seed, curve, schemas, slip21_namespaces)
return keychain
@ -191,18 +197,15 @@ def with_slip44_keychain(
if not patterns:
raise ValueError # specify a pattern
if allow_testnet:
slip44_ids: int | tuple[int, int] = (slip44_id, 1)
else:
slip44_ids = slip44_id
slip_44_ids = (slip44_id, 1) if allow_testnet else slip44_id
schemas = []
for pattern in patterns:
schemas.append(paths.PathSchema.parse(pattern=pattern, slip44_id=slip44_ids))
schemas.append(paths.PathSchema.parse(pattern, slip_44_ids))
schemas = [s.copy() for s in schemas]
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
async def wrapper(ctx: Context, msg: MsgIn) -> MsgOut:
keychain = await get_keychain(ctx, curve, schemas)
with keychain:
return await func(ctx, msg, keychain)
@ -215,6 +218,8 @@ def with_slip44_keychain(
def auto_keychain(
modname: str, allow_testnet: bool = True
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
import sys
rdot = modname.rfind(".")
parent_modname = modname[:rdot]
parent_module = sys.modules[parent_modname]

@ -1,6 +1,10 @@
import storage.device
from trezor import ui, utils, workflow
from trezor.enums import BackupType
from typing import TYPE_CHECKING
import storage.device as storage_device
from trezor import utils
if TYPE_CHECKING:
from trezor.enums import BackupType
def get() -> tuple[bytes | None, BackupType]:
@ -8,11 +12,11 @@ def get() -> tuple[bytes | None, BackupType]:
def get_secret() -> bytes | None:
return storage.device.get_mnemonic_secret()
return storage_device.get_mnemonic_secret()
def get_type() -> BackupType:
return storage.device.get_backup_type()
return storage_device.get_backup_type()
def is_bip39() -> bool:
@ -20,6 +24,8 @@ def is_bip39() -> bool:
If False then SLIP-39 (either Basic or Advanced).
Other invalid values are checked directly in storage.
"""
from trezor.enums import BackupType
return get_type() == BackupType.Bip39
@ -41,8 +47,8 @@ def get_seed(passphrase: str = "", progress_bar: bool = True) -> bytes:
else: # SLIP-39
from trezor.crypto import slip39
identifier = storage.device.get_slip39_identifier()
iteration_exponent = storage.device.get_slip39_iteration_exponent()
identifier = storage_device.get_slip39_identifier()
iteration_exponent = storage_device.get_slip39_iteration_exponent()
if identifier is None or iteration_exponent is None:
# Identifier or exponent expected but not found
raise RuntimeError
@ -84,6 +90,7 @@ if not utils.BITCOIN_ONLY:
def _start_progress() -> None:
from trezor import workflow
from trezor.ui.layouts import draw_simple_text
# Because we are drawing to the screen manually, without a layout, we
@ -93,6 +100,8 @@ def _start_progress() -> None:
def _render_progress(progress: int, total: int) -> None:
from trezor import ui
p = 1000 * progress // total
ui.display.loader(p, False, 18, ui.WHITE, ui.BG)
ui.refresh()

@ -1,66 +1,72 @@
from micropython import const
from typing import TYPE_CHECKING
import storage.device
from trezor import wire, workflow
import storage.device as storage_device
from trezor.wire import DataError
if TYPE_CHECKING:
from trezor.wire import Context
_MAX_PASSPHRASE_LEN = const(50)
def is_enabled() -> bool:
return storage.device.is_passphrase_enabled()
async def get(ctx: wire.Context) -> str:
if is_enabled():
return await _request_from_user(ctx)
else:
return ""
return storage_device.is_passphrase_enabled()
async def _request_from_user(ctx: wire.Context) -> str:
workflow.close_others() # request exclusive UI access
if storage.device.get_passphrase_always_on_device():
from trezor.ui.layouts import request_passphrase_on_device
async def get(ctx: Context) -> str:
from trezor import workflow
passphrase = await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN)
if not is_enabled():
return ""
else:
passphrase = await _request_on_host(ctx)
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
raise wire.DataError(
f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes"
)
workflow.close_others() # request exclusive UI access
if storage_device.get_passphrase_always_on_device():
from trezor.ui.layouts import request_passphrase_on_device
return passphrase
passphrase = await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN)
else:
passphrase = await _request_on_host(ctx)
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes")
return passphrase
async def _request_on_host(ctx: wire.Context) -> str:
async def _request_on_host(ctx: Context) -> str:
from trezor.messages import PassphraseAck, PassphraseRequest
from trezor.ui.layouts import draw_simple_text
_entry_dialog()
# _entry_dialog
draw_simple_text(
"Passphrase entry", "Please type your\npassphrase on the\nconnected host."
)
request = PassphraseRequest()
ack = await ctx.call(request, PassphraseAck)
passphrase = ack.passphrase # local_cache_attribute
if ack.on_device:
from trezor.ui.layouts import request_passphrase_on_device
if ack.passphrase is not None:
raise wire.DataError("Passphrase provided when it should not be")
if passphrase is not None:
raise DataError("Passphrase provided when it should not be")
return await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN)
if ack.passphrase is None:
raise wire.DataError(
if passphrase is None:
raise DataError(
"Passphrase not provided and on_device is False. Use empty string to set an empty passphrase."
)
# non-empty passphrase
if ack.passphrase:
if passphrase:
from trezor import ui
from trezor.ui.layouts import confirm_action, confirm_blob
await confirm_action(
ctx,
"passphrase_host1",
title="Hidden wallet",
"Hidden wallet",
description="Access hidden wallet?\n\nNext screen will show\nthe passphrase!",
icon=ui.ICON_CONFIG,
)
@ -68,19 +74,11 @@ async def _request_on_host(ctx: wire.Context) -> str:
await confirm_blob(
ctx,
"passphrase_host2",
title="Hidden wallet",
description="Use this passphrase?\n",
data=ack.passphrase,
"Hidden wallet",
passphrase,
"Use this passphrase?\n",
icon=ui.ICON_CONFIG,
icon_color=ui.ORANGE_ICON,
)
return ack.passphrase
def _entry_dialog() -> None:
from trezor.ui.layouts import draw_simple_text
draw_simple_text(
"Passphrase entry", "Please type your\npassphrase on the\nconnected host."
)
return passphrase

@ -197,23 +197,24 @@ class PathSchema:
# optionally replace a keyword
component = cls.REPLACEMENTS.get(component, component)
append = schema.append # local_cache_attribute
if "-" in component:
# parse as a range
a, b = [parse(s) for s in component.split("-", 1)]
schema.append(Interval(a, b))
append(Interval(a, b))
elif "," in component:
# parse as a list of values
schema.append(set(parse(s) for s in component.split(",")))
append(set(parse(s) for s in component.split(",")))
elif component == "coin_type":
# substitute SLIP-44 ids
schema.append(set(parse(s) for s in slip44_id))
append(set(parse(s) for s in slip44_id))
else:
# plain constant
schema.append((parse(component),))
append((parse(component),))
return cls(schema, trailing_components, compact=True)
@ -258,18 +259,19 @@ class PathSchema:
path. If the restriction results in a never-matching schema, then False
is returned.
"""
schema = self.schema # local_cache_attribute
for i, value in enumerate(path):
if i < len(self.schema):
if i < len(schema):
# Ensure that the path is a prefix of the schema.
if value not in self.schema[i]:
if value not in schema[i]:
self.set_never_matching()
return False
# Restrict the schema component if there are multiple choices.
component = self.schema[i]
component = schema[i]
if not isinstance(component, tuple) or len(component) != 1:
self.schema[i] = (value,)
schema[i] = (value,)
else:
# The path is longer than the schema. We need to restrict the
# trailing components.
@ -278,7 +280,7 @@ class PathSchema:
self.set_never_matching()
return False
self.schema.append((value,))
schema.append((value,))
return True
@ -286,6 +288,7 @@ class PathSchema:
def __repr__(self) -> str:
components = ["m"]
append = components.append # local_cache_attribute
def unharden(item: int) -> int:
return item ^ (item & HARDENED)
@ -294,7 +297,7 @@ class PathSchema:
if isinstance(component, Interval):
a, b = component.min, component.max
prime = "'" if a & HARDENED else ""
components.append(f"[{unharden(a)}-{unharden(b)}]{prime}")
append(f"[{unharden(a)}-{unharden(b)}]{prime}")
else:
# typechecker thinks component is a Contanier but we're using it
# as a Collection.
@ -307,15 +310,15 @@ class PathSchema:
component_str = "[" + component_str + "]"
if next(iter(collection)) & HARDENED:
component_str += "'"
components.append(component_str)
append(component_str)
if self.trailing_components:
for key, val in self.WILDCARD_RANGES.items():
if self.trailing_components is val:
components.append(key)
append(key)
break
else:
components.append("???")
append("???")
return "<schema:" + "/".join(components) + ">"
@ -362,7 +365,7 @@ def path_is_hardened(address_n: Bip32Path) -> bool:
def address_n_to_str(address_n: Iterable[int]) -> str:
def path_item(i: int) -> str:
def _path_item(i: int) -> str:
if i & HARDENED:
return str(i ^ HARDENED) + "'"
else:
@ -371,4 +374,4 @@ def address_n_to_str(address_n: Iterable[int]) -> str:
if not address_n:
return "m"
return "m/" + "/".join(path_item(i) for i in address_n)
return "m/" + "/".join(_path_item(i) for i in address_n)

@ -1,27 +1,32 @@
from trezor.utils import BufferReader
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.utils import BufferReader
def read_compact_size(r: BufferReader) -> int:
prefix = r.get()
get = r.get # local_cache_attribute
prefix = get()
if prefix < 253:
n = prefix
elif prefix == 253:
n = r.get()
n += r.get() << 8
n = get()
n += get() << 8
elif prefix == 254:
n = r.get()
n += r.get() << 8
n += r.get() << 16
n += r.get() << 24
n = get()
n += get() << 8
n += get() << 16
n += get() << 24
elif prefix == 255:
n = r.get()
n += r.get() << 8
n += r.get() << 16
n += r.get() << 24
n += r.get() << 32
n += r.get() << 40
n += r.get() << 48
n += r.get() << 56
n = get()
n += get() << 8
n += get() << 16
n += get() << 24
n += get() << 32
n += get() << 40
n += get() << 48
n += get() << 56
else:
raise ValueError
return n

@ -1,20 +1,25 @@
import utime
from typing import Any, NoReturn
from typing import TYPE_CHECKING
import storage.cache
import storage.sd_salt
import storage.cache as storage_cache
from trezor import config, wire
from .sdcard import SdCardUnavailable, request_sd_salt
from .sdcard import request_sd_salt
if TYPE_CHECKING:
from typing import Any, NoReturn
from trezor.wire import Context, GenericContext
def can_lock_device() -> bool:
"""Return True if the device has a PIN set or SD-protect enabled."""
import storage.sd_salt
return config.has_pin() or storage.sd_salt.is_enabled()
async def request_pin(
ctx: wire.GenericContext,
ctx: GenericContext,
prompt: str = "Enter your PIN",
attempts_remaining: int | None = None,
allow_cancel: bool = True,
@ -24,26 +29,26 @@ async def request_pin(
return await request_pin_on_device(ctx, prompt, attempts_remaining, allow_cancel)
async def request_pin_confirm(ctx: wire.Context, *args: Any, **kwargs: Any) -> str:
async def request_pin_confirm(ctx: Context, *args: Any, **kwargs: Any) -> str:
while True:
pin1 = await request_pin(ctx, "Enter new PIN", *args, **kwargs)
pin2 = await request_pin(ctx, "Re-enter new PIN", *args, **kwargs)
if pin1 == pin2:
return pin1
await pin_mismatch()
await _pin_mismatch()
async def pin_mismatch() -> None:
async def _pin_mismatch() -> None:
from trezor.ui.layouts import show_popup
await show_popup(
title="PIN mismatch",
description="The PINs you entered\ndo not match.\n\nPlease try again.",
"PIN mismatch",
"The PINs you entered\ndo not match.\n\nPlease try again.",
)
async def request_pin_and_sd_salt(
ctx: wire.Context, prompt: str = "Enter your PIN", allow_cancel: bool = True
ctx: Context, prompt: str = "Enter your PIN", allow_cancel: bool = True
) -> tuple[str, bytearray | None]:
if config.has_pin():
pin = await request_pin(ctx, prompt, config.get_pin_rem(), allow_cancel)
@ -58,21 +63,23 @@ async def request_pin_and_sd_salt(
def _set_last_unlock_time() -> None:
now = utime.ticks_ms()
storage.cache.set_int(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
def _get_last_unlock_time() -> int:
return storage.cache.get_int(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK) or 0
storage_cache.set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
async def verify_user_pin(
ctx: wire.GenericContext = wire.DUMMY_CONTEXT,
ctx: GenericContext = wire.DUMMY_CONTEXT,
prompt: str = "Enter your PIN",
allow_cancel: bool = True,
retry: bool = True,
cache_time_ms: int = 0,
) -> None:
last_unlock = _get_last_unlock_time()
from .sdcard import SdCardUnavailable
# _get_last_unlock_time
last_unlock = int.from_bytes(
storage_cache.get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
)
if (
cache_time_ms
and last_unlock
@ -112,28 +119,28 @@ async def verify_user_pin(
raise wire.PinInvalid
async def error_pin_invalid(ctx: wire.Context) -> NoReturn:
async def error_pin_invalid(ctx: Context) -> NoReturn:
from trezor.ui.layouts import show_error_and_raise
await show_error_and_raise(
ctx,
"warning_wrong_pin",
header="Wrong PIN",
content="The PIN you entered is invalid.",
"The PIN you entered is invalid.",
"Wrong PIN", # header
red=True,
exc=wire.PinInvalid,
)
assert False
async def error_pin_matches_wipe_code(ctx: wire.Context) -> NoReturn:
async def error_pin_matches_wipe_code(ctx: Context) -> NoReturn:
from trezor.ui.layouts import show_error_and_raise
await show_error_and_raise(
ctx,
"warning_invalid_new_pin",
header="Invalid PIN",
content="The new PIN must be different from your\nwipe code.",
"The new PIN must be different from your\nwipe code.",
"Invalid PIN", # header
red=True,
exc=wire.PinInvalid,
)

@ -1,5 +1,5 @@
import storage.cache
import storage.device
import storage.cache as storage_cache
import storage.device as storage_device
from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY
from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT
from trezor.enums import SafetyCheckLevel
@ -9,11 +9,11 @@ def read_setting() -> SafetyCheckLevel:
"""
Returns the effective safety check level.
"""
temporary_safety_check_level = storage.cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
temporary_safety_check_level = storage_cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
if temporary_safety_check_level:
return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum]
else:
stored = storage.device.safety_check_level()
stored = storage_device.safety_check_level()
if stored == SAFETY_CHECK_LEVEL_STRICT:
return SafetyCheckLevel.Strict
elif stored == SAFETY_CHECK_LEVEL_PROMPT:
@ -27,14 +27,14 @@ def apply_setting(level: SafetyCheckLevel) -> None:
Changes the safety level settings.
"""
if level == SafetyCheckLevel.Strict:
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
elif level == SafetyCheckLevel.PromptAlways:
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
elif level == SafetyCheckLevel.PromptTemporarily:
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
storage_cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
else:
raise ValueError("Unknown SafetyCheckLevel")

@ -1,6 +1,5 @@
import storage.sd_salt
from storage.sd_salt import SD_CARD_HOT_SWAPPABLE
from trezor import io, sdcard, ui, wire
from trezor import io, ui, wire
from trezor.ui.layouts import confirm_action, show_error_and_raise
@ -14,8 +13,8 @@ async def _confirm_retry_wrong_card(ctx: wire.GenericContext) -> None:
ctx,
"warning_wrong_sd",
"SD card protection",
action="Wrong SD card.",
description="Please insert the correct SD card for this device.",
"Wrong SD card.",
"Please insert the correct SD card for this device.",
verb="Retry",
verb_cancel="Abort",
icon=ui.ICON_WRONG,
@ -26,9 +25,9 @@ async def _confirm_retry_wrong_card(ctx: wire.GenericContext) -> None:
await show_error_and_raise(
ctx,
"warning_wrong_sd",
header="SD card protection",
subheader="Wrong SD card.",
content="Please unplug the\ndevice and insert the correct SD card.",
"Please unplug the\ndevice and insert the correct SD card.",
"SD card protection",
"Wrong SD card.",
exc=SdCardUnavailable("Wrong SD card."),
)
@ -39,8 +38,8 @@ async def _confirm_retry_insert_card(ctx: wire.GenericContext) -> None:
ctx,
"warning_no_sd",
"SD card protection",
action="SD card required.",
description="Please insert your SD card.",
"SD card required.",
"Please insert your SD card.",
verb="Retry",
verb_cancel="Abort",
icon=ui.ICON_WRONG,
@ -51,9 +50,9 @@ async def _confirm_retry_insert_card(ctx: wire.GenericContext) -> None:
await show_error_and_raise(
ctx,
"warning_no_sd",
header="SD card protection",
subheader="SD card required.",
content="Please unplug the\ndevice and insert your SD card.",
"Please unplug the\ndevice and insert your SD card.",
"SD card protection",
"SD card required.",
exc=SdCardUnavailable("SD card required."),
)
@ -64,8 +63,8 @@ async def _confirm_format_card(ctx: wire.GenericContext) -> None:
ctx,
"warning_format_sd",
"SD card error",
action="Unknown filesystem.",
description="Use a different card or format the SD card to the FAT32 filesystem.",
"Unknown filesystem.",
"Use a different card or format the SD card to the FAT32 filesystem.",
icon=ui.ICON_WRONG,
icon_color=ui.RED,
verb="Format",
@ -79,8 +78,8 @@ async def _confirm_format_card(ctx: wire.GenericContext) -> None:
ctx,
"confirm_format_sd",
"Format SD card",
action="All data on the SD card will be lost.",
description="Do you really want to format the SD card?",
"All data on the SD card will be lost.",
"Do you really want to format the SD card?",
reverse=True,
verb="Format SD card",
icon=ui.ICON_WIPE,
@ -99,8 +98,8 @@ async def confirm_retry_sd(
ctx,
"warning_sd_retry",
"SD card problem",
action=None,
description="There was a problem accessing the SD card.",
None,
"There was a problem accessing the SD card.",
icon=ui.ICON_WRONG,
icon_color=ui.RED,
verb="Retry",
@ -121,18 +120,20 @@ async def ensure_sdcard(
filesystem, and allows the user to format the card if a filesystem cannot be
mounted.
"""
from trezor import sdcard
while not sdcard.is_present():
await _confirm_retry_insert_card(ctx)
if not ensure_filesystem:
return
fatfs = io.fatfs # local_cache_attribute
while True:
try:
try:
with sdcard.filesystem(mounted=False):
io.fatfs.mount()
except io.fatfs.NoFilesystem:
fatfs.mount()
except fatfs.NoFilesystem:
# card not formatted. proceed out of the except clause
pass
else:
@ -143,9 +144,9 @@ async def ensure_sdcard(
# Proceed to formatting. Failure is caught by the outside OSError handler
with sdcard.filesystem(mounted=False):
io.fatfs.mkfs()
io.fatfs.mount()
io.fatfs.setlabel("TREZOR")
fatfs.mkfs()
fatfs.mount()
fatfs.setlabel("TREZOR")
# format and mount succeeded
return
@ -158,14 +159,16 @@ async def ensure_sdcard(
async def request_sd_salt(
ctx: wire.GenericContext = wire.DUMMY_CONTEXT,
) -> bytearray | None:
if not storage.sd_salt.is_enabled():
import storage.sd_salt as storage_sd_salt
if not storage_sd_salt.is_enabled():
return None
while True:
await ensure_sdcard(ctx, ensure_filesystem=False)
try:
return storage.sd_salt.load_sd_salt()
except (storage.sd_salt.WrongSdCard, io.fatfs.NoFilesystem):
return storage_sd_salt.load_sd_salt()
except (storage_sd_salt.WrongSdCard, io.fatfs.NoFilesystem):
await _confirm_retry_wrong_card(ctx)
except OSError:
# Generic problem with loading the SD salt (hardware problem, or we could

@ -1,14 +1,17 @@
from typing import TYPE_CHECKING
from storage import cache, device
from trezor import utils, wire
from trezor.crypto import bip32, hmac
import storage.cache as storage_cache
import storage.device as storage_device
from trezor import utils
from trezor.crypto import hmac
from . import mnemonic
from .passphrase import get as get_passphrase
if TYPE_CHECKING:
from .paths import Bip32Path, Slip21Path
from trezor.wire import Context
from trezor.crypto import bip32
class Slip21Node:
@ -47,14 +50,16 @@ if not utils.BITCOIN_ONLY:
# We want to derive both the normal seed and the Cardano seed together, AND
# expose a method for Cardano to do the same
async def derive_and_store_roots(ctx: wire.Context) -> None:
if not device.is_initialized():
async def derive_and_store_roots(ctx: Context) -> None:
from trezor import wire
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
need_seed = not cache.is_set(cache.APP_COMMON_SEED)
need_cardano_secret = cache.get(
cache.APP_COMMON_DERIVE_CARDANO
) and not cache.is_set(cache.APP_CARDANO_ICARUS_SECRET)
need_seed = not storage_cache.is_set(storage_cache.APP_COMMON_SEED)
need_cardano_secret = storage_cache.get(
storage_cache.APP_COMMON_DERIVE_CARDANO
) and not storage_cache.is_set(storage_cache.APP_CARDANO_ICARUS_SECRET)
if not need_seed and not need_cardano_secret:
return
@ -63,17 +68,17 @@ if not utils.BITCOIN_ONLY:
if need_seed:
common_seed = mnemonic.get_seed(passphrase)
cache.set(cache.APP_COMMON_SEED, common_seed)
storage_cache.set(storage_cache.APP_COMMON_SEED, common_seed)
if need_cardano_secret:
from apps.cardano.seed import derive_and_store_secrets
derive_and_store_secrets(passphrase)
@cache.stored_async(cache.APP_COMMON_SEED)
async def get_seed(ctx: wire.Context) -> bytes:
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
async def get_seed(ctx: Context) -> bytes:
await derive_and_store_roots(ctx)
common_seed = cache.get(cache.APP_COMMON_SEED)
common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED)
assert common_seed is not None
return common_seed
@ -81,15 +86,15 @@ else:
# === Bitcoin-only variant ===
# We use the simple version of `get_seed` that never needs to derive anything else.
@cache.stored_async(cache.APP_COMMON_SEED)
async def get_seed(ctx: wire.Context) -> bytes:
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
async def get_seed(ctx: Context) -> bytes:
passphrase = await get_passphrase(ctx)
return mnemonic.get_seed(passphrase)
@cache.stored(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
@storage_cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
def _get_seed_without_passphrase() -> bytes:
if not device.is_initialized():
if not storage_device.is_initialized():
raise Exception("Device is not initialized")
return mnemonic.get_seed(progress_bar=False)
@ -97,6 +102,8 @@ def _get_seed_without_passphrase() -> bytes:
def derive_node_without_passphrase(
path: Bip32Path, curve_name: str = "secp256k1"
) -> bip32.HDNode:
from trezor.crypto import bip32
seed = _get_seed_without_passphrase()
node = bip32.from_seed(seed, curve_name)
node.derive_path(path)

@ -1,16 +1,15 @@
from typing import TYPE_CHECKING
from ubinascii import hexlify
from trezor import utils, wire
from trezor.crypto.hashlib import blake256, sha256
from apps.common.writers import write_compact_size
if TYPE_CHECKING:
from apps.common.coininfo import CoinInfo
def message_digest(coin: CoinInfo, message: bytes) -> bytes:
from trezor import utils, wire
from trezor.crypto.hashlib import blake256, sha256
from apps.common.writers import write_compact_size
if not utils.BITCOIN_ONLY and coin.decred:
h = utils.HashWriter(blake256())
else:
@ -28,6 +27,8 @@ def message_digest(coin: CoinInfo, message: bytes) -> bytes:
def decode_message(message: bytes) -> str:
from ubinascii import hexlify
try:
return bytes(message).decode()
except UnicodeError:

@ -6,61 +6,38 @@ if TYPE_CHECKING:
from trezor.utils import Writer
def _write_uint(w: Writer, n: int, bits: int, bigendian: bool) -> int:
ensure(0 <= n <= 2**bits - 1, "overflow")
shifts = range(0, bits, 8)
if bigendian:
shifts = reversed(shifts)
for num in shifts:
w.append((n >> num) & 0xFF)
return bits // 8
def write_uint8(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFF)
w.append(n)
return 1
return _write_uint(w, n, 8, False)
def write_uint16_le(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFF)
w.append(n & 0xFF)
w.append((n >> 8) & 0xFF)
return 2
return _write_uint(w, n, 16, False)
def write_uint32_le(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFF_FFFF)
w.append(n & 0xFF)
w.append((n >> 8) & 0xFF)
w.append((n >> 16) & 0xFF)
w.append((n >> 24) & 0xFF)
return 4
return _write_uint(w, n, 32, False)
def write_uint32_be(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFF_FFFF)
w.append((n >> 24) & 0xFF)
w.append((n >> 16) & 0xFF)
w.append((n >> 8) & 0xFF)
w.append(n & 0xFF)
return 4
return _write_uint(w, n, 32, True)
def write_uint64_le(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFF_FFFF_FFFF_FFFF)
w.append(n & 0xFF)
w.append((n >> 8) & 0xFF)
w.append((n >> 16) & 0xFF)
w.append((n >> 24) & 0xFF)
w.append((n >> 32) & 0xFF)
w.append((n >> 40) & 0xFF)
w.append((n >> 48) & 0xFF)
w.append((n >> 56) & 0xFF)
return 8
return _write_uint(w, n, 64, False)
def write_uint64_be(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFF_FFFF_FFFF_FFFF)
w.append((n >> 56) & 0xFF)
w.append((n >> 48) & 0xFF)
w.append((n >> 40) & 0xFF)
w.append((n >> 32) & 0xFF)
w.append((n >> 24) & 0xFF)
w.append((n >> 16) & 0xFF)
w.append((n >> 8) & 0xFF)
w.append(n & 0xFF)
return 8
return _write_uint(w, n, 64, True)
def write_bytes_unchecked(w: Writer, b: bytes | memoryview) -> int:
@ -82,16 +59,18 @@ def write_bytes_reversed(w: Writer, b: bytes, length: int) -> int:
def write_compact_size(w: Writer, n: int) -> None:
ensure(0 <= n <= 0xFFFF_FFFF)
append = w.append # local_cache_attribute
if n < 253:
w.append(n & 0xFF)
append(n & 0xFF)
elif n < 0x1_0000:
w.append(253)
append(253)
write_uint16_le(w, n)
elif n < 0x1_0000_0000:
w.append(254)
append(254)
write_uint32_le(w, n)
else:
w.append(255)
append(255)
write_uint64_le(w, n)

@ -1,5 +1,3 @@
import math
from common import *
from apps.common.cbor import (
@ -12,43 +10,8 @@ from apps.common.cbor import (
decode,
encode,
encode_streamed,
utils
)
# NOTE: moved into tests not to occupy flash space
# in firmware binary, when it is not used in production
def encode_chunked(value: "Value", max_chunk_size: int) -> "Iterator[bytes]":
"""
Returns the encoded value as an iterable of chunks of a given size,
removing the need to reserve a continuous chunk of memory for the
full serialized representation of the value.
"""
if max_chunk_size <= 0:
raise ValueError
chunks = encode_streamed(value)
chunk_buffer = utils.empty_bytearray(max_chunk_size)
try:
current_chunk_view = utils.BufferReader(next(chunks))
while True:
num_bytes_to_write = min(
current_chunk_view.remaining_count(),
max_chunk_size - len(chunk_buffer),
)
chunk_buffer.extend(current_chunk_view.read(num_bytes_to_write))
if len(chunk_buffer) >= max_chunk_size:
yield chunk_buffer
chunk_buffer[:] = bytes()
if current_chunk_view.remaining_count() == 0:
current_chunk_view = utils.BufferReader(next(chunks))
except StopIteration:
if len(chunk_buffer) > 0:
yield chunk_buffer
class TestCardanoCbor(unittest.TestCase):
def test_create_array_header(self):
@ -211,43 +174,6 @@ class TestCardanoCbor(unittest.TestCase):
self.assertEqual(b''.join(encoded_streamed), encoded)
def test_encode_chunked(self):
large_dict = {i: i for i in range(100)}
encoded = encode(large_dict)
encoded_len = len(encoded)
assert encoded_len == 354
arbitrary_encoded_len_factor = 59
arbitrary_power_of_two = 64
larger_than_encoded_len = encoded_len + 1
for max_chunk_size in [
1,
10,
arbitrary_encoded_len_factor,
arbitrary_power_of_two,
encoded_len,
larger_than_encoded_len
]:
encoded_chunks = [
bytes(chunk) for chunk in encode_chunked(large_dict, max_chunk_size)
]
expected_number_of_chunks = math.ceil(len(encoded) / max_chunk_size)
self.assertEqual(len(encoded_chunks), expected_number_of_chunks)
# all chunks except the last should be of chunk_size
for i in range(len(encoded_chunks) - 1):
self.assertEqual(len(encoded_chunks[i]), max_chunk_size)
# last chunk should contain the remaining bytes or the whole chunk
remaining_bytes = len(encoded) % max_chunk_size
expected_last_chunk_size = remaining_bytes if remaining_bytes > 0 else max_chunk_size
self.assertEqual(len(encoded_chunks[-1]), expected_last_chunk_size)
self.assertEqual(b''.join(encoded_chunks), encoded)
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,44 @@
from common import *
import apps.common.writers as writers
class TestSeed(unittest.TestCase):
def test_write_uint8(self):
buf = bytearray()
writers.write_uint8(buf, 0x12)
self.assertEqual(buf, b"\x12")
def test_write_uint16_le(self):
buf = bytearray()
writers.write_uint16_le(buf, 0x1234)
self.assertEqual(buf, b"\x34\x12")
def test_write_uint16_le_overflow(self):
buf = bytearray()
with self.assertRaises(AssertionError):
writers.write_uint16_le(buf, 0x12345678)
def test_write_uint32_le(self):
buf = bytearray()
writers.write_uint32_le(buf, 0x12345678)
self.assertEqual(buf, b"\x78\x56\x34\x12")
def test_write_uint64_le(self):
buf = bytearray()
writers.write_uint64_le(buf, 0x1234567890abcdef)
self.assertEqual(buf, b"\xef\xcd\xab\x90\x78\x56\x34\x12")
def test_write_uint32_be(self):
buf = bytearray()
writers.write_uint32_be(buf, 0x12345678)
self.assertEqual(buf, b"\x12\x34\x56\x78")
def test_write_uint64_be(self):
buf = bytearray()
writers.write_uint64_be(buf, 0x1234567890abcdef)
self.assertEqual(buf, b"\x12\x34\x56\x78\x90\xab\xcd\xef")
if __name__ == "__main__":
unittest.main()
Loading…
Cancel
Save