mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-17 21:22:10 +00:00
chore(core): decrease common size by 5200 bytes
This commit is contained in:
parent
5e7cc8b692
commit
3711fd0f19
@ -1,9 +1,6 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import utils, wire
|
||||
from trezor.crypto import hashlib, hmac
|
||||
|
||||
from .writers import write_bytes_unchecked, write_compact_size, write_uint32_le
|
||||
from trezor import utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from apps.common.keychain import Keychain
|
||||
@ -14,12 +11,18 @@ _ADDRESS_MAC_KEY_PATH = [b"SLIP-0024", b"Address MAC key"]
|
||||
def check_address_mac(
|
||||
address: str, mac: bytes, slip44: int, keychain: Keychain
|
||||
) -> None:
|
||||
from trezor import wire
|
||||
from trezor.crypto import hashlib
|
||||
|
||||
expected_mac = get_address_mac(address, slip44, keychain)
|
||||
if len(mac) != hashlib.sha256.digest_size or not utils.consteq(expected_mac, mac):
|
||||
raise wire.DataError("Invalid address MAC.")
|
||||
|
||||
|
||||
def get_address_mac(address: str, slip44: int, keychain: Keychain) -> bytes:
|
||||
from trezor.crypto import hmac
|
||||
from .writers import write_bytes_unchecked, write_compact_size, write_uint32_le
|
||||
|
||||
# k = Key(m/"SLIP-0024"/"Address MAC key")
|
||||
node = keychain.derive_slip21(_ADDRESS_MAC_KEY_PATH)
|
||||
|
||||
|
@ -1,45 +1,53 @@
|
||||
from typing import Iterable
|
||||
|
||||
import storage.cache
|
||||
import storage.cache as storage_cache
|
||||
from trezor import protobuf
|
||||
from trezor.enums import MessageType
|
||||
from trezor.utils import ensure
|
||||
|
||||
WIRE_TYPES: dict[int, tuple[int, ...]] = {
|
||||
MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof),
|
||||
}
|
||||
|
||||
APP_COMMON_AUTHORIZATION_DATA = (
|
||||
storage_cache.APP_COMMON_AUTHORIZATION_DATA
|
||||
) # global_import_cache
|
||||
APP_COMMON_AUTHORIZATION_TYPE = (
|
||||
storage_cache.APP_COMMON_AUTHORIZATION_TYPE
|
||||
) # global_import_cache
|
||||
|
||||
|
||||
def is_set() -> bool:
|
||||
return bool(storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE))
|
||||
return bool(storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE))
|
||||
|
||||
|
||||
def set(auth_message: protobuf.MessageType) -> None:
|
||||
from trezor.utils import ensure
|
||||
|
||||
buffer = protobuf.dump_message_buffer(auth_message)
|
||||
|
||||
# only wire-level messages can be stored as authorization
|
||||
# (because only wire-level messages have wire_type, which we use as identifier)
|
||||
ensure(auth_message.MESSAGE_WIRE_TYPE is not None)
|
||||
assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too
|
||||
storage.cache.set(
|
||||
storage.cache.APP_COMMON_AUTHORIZATION_TYPE,
|
||||
storage_cache.set(
|
||||
APP_COMMON_AUTHORIZATION_TYPE,
|
||||
auth_message.MESSAGE_WIRE_TYPE.to_bytes(2, "big"),
|
||||
)
|
||||
storage.cache.set(storage.cache.APP_COMMON_AUTHORIZATION_DATA, buffer)
|
||||
storage_cache.set(APP_COMMON_AUTHORIZATION_DATA, buffer)
|
||||
|
||||
|
||||
def get() -> protobuf.MessageType | None:
|
||||
stored_auth_type = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
|
||||
stored_auth_type = storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE)
|
||||
if not stored_auth_type:
|
||||
return None
|
||||
|
||||
msg_wire_type = int.from_bytes(stored_auth_type, "big")
|
||||
buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA, b"")
|
||||
buffer = storage_cache.get(APP_COMMON_AUTHORIZATION_DATA, b"")
|
||||
return protobuf.load_message_buffer(buffer, msg_wire_type)
|
||||
|
||||
|
||||
def get_wire_types() -> Iterable[int]:
|
||||
stored_auth_type = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
|
||||
stored_auth_type = storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE)
|
||||
if stored_auth_type is None:
|
||||
return ()
|
||||
|
||||
@ -48,5 +56,5 @@ def get_wire_types() -> Iterable[int]:
|
||||
|
||||
|
||||
def clear() -> None:
|
||||
storage.cache.delete(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
|
||||
storage.cache.delete(storage.cache.APP_COMMON_AUTHORIZATION_DATA)
|
||||
storage_cache.delete(APP_COMMON_AUTHORIZATION_TYPE)
|
||||
storage_cache.delete(APP_COMMON_AUTHORIZATION_DATA)
|
||||
|
@ -2,16 +2,14 @@
|
||||
Minimalistic CBOR implementation, supports only what we need in cardano.
|
||||
"""
|
||||
|
||||
import ustruct as struct
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import log, utils
|
||||
|
||||
from . import readers
|
||||
from trezor import log
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Generic, Iterator, TypeVar
|
||||
from trezor.utils import BufferReader
|
||||
|
||||
K = TypeVar("K")
|
||||
V = TypeVar("V")
|
||||
@ -48,16 +46,18 @@ _CBOR_RAW_TAG = const(0x18)
|
||||
|
||||
|
||||
def _header(typ: int, l: int) -> bytes:
|
||||
from ustruct import pack
|
||||
|
||||
if l < 24:
|
||||
return struct.pack(">B", typ + l)
|
||||
return pack(">B", typ + l)
|
||||
elif l < 2**8:
|
||||
return struct.pack(">BB", typ + 24, l)
|
||||
return pack(">BB", typ + 24, l)
|
||||
elif l < 2**16:
|
||||
return struct.pack(">BH", typ + 25, l)
|
||||
return pack(">BH", typ + 25, l)
|
||||
elif l < 2**32:
|
||||
return struct.pack(">BI", typ + 26, l)
|
||||
return pack(">BI", typ + 26, l)
|
||||
elif l < 2**64:
|
||||
return struct.pack(">BQ", typ + 27, l)
|
||||
return pack(">BQ", typ + 27, l)
|
||||
else:
|
||||
raise NotImplementedError # Length not supported
|
||||
|
||||
@ -117,7 +117,9 @@ def _cbor_encode(value: Value) -> Iterator[bytes]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _read_length(r: utils.BufferReader, aux: int) -> int:
|
||||
def _read_length(r: BufferReader, aux: int) -> int:
|
||||
from . import readers
|
||||
|
||||
if aux < _CBOR_UINT8_FOLLOWS:
|
||||
return aux
|
||||
elif aux == _CBOR_UINT8_FOLLOWS:
|
||||
@ -132,7 +134,7 @@ def _read_length(r: utils.BufferReader, aux: int) -> int:
|
||||
raise NotImplementedError # Length not supported
|
||||
|
||||
|
||||
def _cbor_decode(r: utils.BufferReader) -> Value:
|
||||
def _cbor_decode(r: BufferReader) -> Value:
|
||||
fb = r.get()
|
||||
fb_type = fb & _CBOR_TYPE_MASK
|
||||
fb_aux = fb & _CBOR_INFO_BITS
|
||||
@ -220,6 +222,7 @@ class Tagged:
|
||||
)
|
||||
|
||||
|
||||
# TODO: this seems to be unused - is checked against, but is never created???
|
||||
class Raw:
|
||||
def __init__(self, value: Value):
|
||||
self.value = value
|
||||
@ -272,7 +275,9 @@ def encode_streamed(value: Value) -> Iterator[bytes]:
|
||||
|
||||
|
||||
def decode(cbor: bytes, offset: int = 0) -> Value:
|
||||
r = utils.BufferReader(cbor)
|
||||
from trezor.utils import BufferReader
|
||||
|
||||
r = BufferReader(cbor)
|
||||
r.seek(offset)
|
||||
res = _cbor_decode(r)
|
||||
if r.remaining_count():
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,9 @@
|
||||
# generated from coininfo.py.mako
|
||||
# (by running `make templates` in `core`)
|
||||
# do not edit manually!
|
||||
|
||||
# NOTE: using positional arguments saves 4500 bytes of flash size
|
||||
|
||||
from typing import Any
|
||||
|
||||
from trezor import utils
|
||||
@ -142,7 +145,7 @@ def by_name(name: str) -> CoinInfo:
|
||||
if name == ${black_repr(coin["coin_name"])}:
|
||||
return CoinInfo(
|
||||
% for attr, func in ATTRIBUTES:
|
||||
${attr}=${func(coin[attr])},
|
||||
${func(coin[attr])}, # ${attr}
|
||||
% endfor
|
||||
)
|
||||
% endfor
|
||||
@ -151,7 +154,7 @@ def by_name(name: str) -> CoinInfo:
|
||||
if name == ${black_repr(coin["coin_name"])}:
|
||||
return CoinInfo(
|
||||
% for attr, func in ATTRIBUTES:
|
||||
${attr}=${func(coin[attr])},
|
||||
${func(coin[attr])}, # ${attr}
|
||||
% endfor
|
||||
)
|
||||
% endfor
|
||||
|
@ -1,11 +1,9 @@
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import wire
|
||||
from trezor.crypto import bip32
|
||||
from trezor.wire import DataError
|
||||
|
||||
from . import paths, safety_checks
|
||||
from .seed import Slip21Node, get_seed
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import (
|
||||
@ -18,6 +16,8 @@ if TYPE_CHECKING:
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from trezor.protobuf import MessageType
|
||||
from trezor.wire import Context
|
||||
from .seed import Slip21Node
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@ -36,15 +36,15 @@ if TYPE_CHECKING:
|
||||
MsgIn = TypeVar("MsgIn", bound=MessageType)
|
||||
MsgOut = TypeVar("MsgOut", bound=MessageType)
|
||||
|
||||
Handler = Callable[[wire.Context, MsgIn], Awaitable[MsgOut]]
|
||||
HandlerWithKeychain = Callable[[wire.Context, MsgIn, "Keychain"], Awaitable[MsgOut]]
|
||||
Handler = Callable[[Context, MsgIn], Awaitable[MsgOut]]
|
||||
HandlerWithKeychain = Callable[[Context, MsgIn, "Keychain"], Awaitable[MsgOut]]
|
||||
|
||||
class Deletable(Protocol):
|
||||
def __del__(self) -> None:
|
||||
...
|
||||
|
||||
|
||||
FORBIDDEN_KEY_PATH = wire.DataError("Forbidden key path")
|
||||
FORBIDDEN_KEY_PATH = DataError("Forbidden key path")
|
||||
|
||||
|
||||
class LRUCache:
|
||||
@ -54,13 +54,15 @@ class LRUCache:
|
||||
self.cache: dict[Any, Deletable] = {}
|
||||
|
||||
def insert(self, key: Any, value: Deletable) -> None:
|
||||
if key in self.cache_keys:
|
||||
self.cache_keys.remove(key)
|
||||
self.cache_keys.insert(0, key)
|
||||
cache_keys = self.cache_keys # local_cache_attribute
|
||||
|
||||
if key in cache_keys:
|
||||
cache_keys.remove(key)
|
||||
cache_keys.insert(0, key)
|
||||
self.cache[key] = value
|
||||
|
||||
if len(self.cache_keys) > self.size:
|
||||
dropped_key = self.cache_keys.pop()
|
||||
if len(cache_keys) > self.size:
|
||||
dropped_key = cache_keys.pop()
|
||||
self.cache[dropped_key].__del__()
|
||||
del self.cache[dropped_key]
|
||||
|
||||
@ -103,7 +105,7 @@ class Keychain:
|
||||
|
||||
def verify_path(self, path: paths.Bip32Path) -> None:
|
||||
if "ed25519" in self.curve and not paths.path_is_hardened(path):
|
||||
raise wire.DataError("Non-hardened paths unsupported on Ed25519")
|
||||
raise DataError("Non-hardened paths unsupported on Ed25519")
|
||||
|
||||
if not safety_checks.is_strict():
|
||||
return
|
||||
@ -137,8 +139,8 @@ class Keychain:
|
||||
if self._root_fingerprint is None:
|
||||
# derive m/0' to obtain root_fingerprint
|
||||
n = self._derive_with_cache(
|
||||
prefix_len=0,
|
||||
path=[0 | paths.HARDENED],
|
||||
0,
|
||||
[0 | paths.HARDENED],
|
||||
new_root=lambda: bip32.from_seed(self.seed, self.curve),
|
||||
)
|
||||
self._root_fingerprint = n.fingerprint()
|
||||
@ -147,20 +149,22 @@ class Keychain:
|
||||
def derive(self, path: paths.Bip32Path) -> bip32.HDNode:
|
||||
self.verify_path(path)
|
||||
return self._derive_with_cache(
|
||||
prefix_len=3,
|
||||
path=path,
|
||||
3,
|
||||
path,
|
||||
new_root=lambda: bip32.from_seed(self.seed, self.curve),
|
||||
)
|
||||
|
||||
def derive_slip21(self, path: paths.Slip21Path) -> Slip21Node:
|
||||
from .seed import Slip21Node
|
||||
|
||||
if safety_checks.is_strict() and not any(
|
||||
ns == path[: len(ns)] for ns in self.slip21_namespaces
|
||||
):
|
||||
raise FORBIDDEN_KEY_PATH
|
||||
|
||||
return self._derive_with_cache(
|
||||
prefix_len=1,
|
||||
path=path,
|
||||
1,
|
||||
path,
|
||||
new_root=lambda: Slip21Node(seed=self.seed),
|
||||
)
|
||||
|
||||
@ -172,11 +176,13 @@ class Keychain:
|
||||
|
||||
|
||||
async def get_keychain(
|
||||
ctx: wire.Context,
|
||||
ctx: Context,
|
||||
curve: str,
|
||||
schemas: Iterable[paths.PathSchemaType],
|
||||
slip21_namespaces: Iterable[paths.Slip21Path] = (),
|
||||
) -> Keychain:
|
||||
from .seed import get_seed
|
||||
|
||||
seed = await get_seed(ctx)
|
||||
keychain = Keychain(seed, curve, schemas, slip21_namespaces)
|
||||
return keychain
|
||||
@ -191,18 +197,15 @@ def with_slip44_keychain(
|
||||
if not patterns:
|
||||
raise ValueError # specify a pattern
|
||||
|
||||
if allow_testnet:
|
||||
slip44_ids: int | tuple[int, int] = (slip44_id, 1)
|
||||
else:
|
||||
slip44_ids = slip44_id
|
||||
slip_44_ids = (slip44_id, 1) if allow_testnet else slip44_id
|
||||
|
||||
schemas = []
|
||||
for pattern in patterns:
|
||||
schemas.append(paths.PathSchema.parse(pattern=pattern, slip44_id=slip44_ids))
|
||||
schemas.append(paths.PathSchema.parse(pattern, slip_44_ids))
|
||||
schemas = [s.copy() for s in schemas]
|
||||
|
||||
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
||||
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
|
||||
async def wrapper(ctx: Context, msg: MsgIn) -> MsgOut:
|
||||
keychain = await get_keychain(ctx, curve, schemas)
|
||||
with keychain:
|
||||
return await func(ctx, msg, keychain)
|
||||
@ -215,6 +218,8 @@ def with_slip44_keychain(
|
||||
def auto_keychain(
|
||||
modname: str, allow_testnet: bool = True
|
||||
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
|
||||
import sys
|
||||
|
||||
rdot = modname.rfind(".")
|
||||
parent_modname = modname[:rdot]
|
||||
parent_module = sys.modules[parent_modname]
|
||||
|
@ -1,6 +1,10 @@
|
||||
import storage.device
|
||||
from trezor import ui, utils, workflow
|
||||
from trezor.enums import BackupType
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import storage.device as storage_device
|
||||
from trezor import utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.enums import BackupType
|
||||
|
||||
|
||||
def get() -> tuple[bytes | None, BackupType]:
|
||||
@ -8,11 +12,11 @@ def get() -> tuple[bytes | None, BackupType]:
|
||||
|
||||
|
||||
def get_secret() -> bytes | None:
|
||||
return storage.device.get_mnemonic_secret()
|
||||
return storage_device.get_mnemonic_secret()
|
||||
|
||||
|
||||
def get_type() -> BackupType:
|
||||
return storage.device.get_backup_type()
|
||||
return storage_device.get_backup_type()
|
||||
|
||||
|
||||
def is_bip39() -> bool:
|
||||
@ -20,6 +24,8 @@ def is_bip39() -> bool:
|
||||
If False then SLIP-39 (either Basic or Advanced).
|
||||
Other invalid values are checked directly in storage.
|
||||
"""
|
||||
from trezor.enums import BackupType
|
||||
|
||||
return get_type() == BackupType.Bip39
|
||||
|
||||
|
||||
@ -41,8 +47,8 @@ def get_seed(passphrase: str = "", progress_bar: bool = True) -> bytes:
|
||||
else: # SLIP-39
|
||||
from trezor.crypto import slip39
|
||||
|
||||
identifier = storage.device.get_slip39_identifier()
|
||||
iteration_exponent = storage.device.get_slip39_iteration_exponent()
|
||||
identifier = storage_device.get_slip39_identifier()
|
||||
iteration_exponent = storage_device.get_slip39_iteration_exponent()
|
||||
if identifier is None or iteration_exponent is None:
|
||||
# Identifier or exponent expected but not found
|
||||
raise RuntimeError
|
||||
@ -84,6 +90,7 @@ if not utils.BITCOIN_ONLY:
|
||||
|
||||
|
||||
def _start_progress() -> None:
|
||||
from trezor import workflow
|
||||
from trezor.ui.layouts import draw_simple_text
|
||||
|
||||
# Because we are drawing to the screen manually, without a layout, we
|
||||
@ -93,6 +100,8 @@ def _start_progress() -> None:
|
||||
|
||||
|
||||
def _render_progress(progress: int, total: int) -> None:
|
||||
from trezor import ui
|
||||
|
||||
p = 1000 * progress // total
|
||||
ui.display.loader(p, False, 18, ui.WHITE, ui.BG)
|
||||
ui.refresh()
|
||||
|
@ -1,66 +1,72 @@
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import storage.device
|
||||
from trezor import wire, workflow
|
||||
import storage.device as storage_device
|
||||
from trezor.wire import DataError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.wire import Context
|
||||
|
||||
_MAX_PASSPHRASE_LEN = const(50)
|
||||
|
||||
|
||||
def is_enabled() -> bool:
|
||||
return storage.device.is_passphrase_enabled()
|
||||
return storage_device.is_passphrase_enabled()
|
||||
|
||||
|
||||
async def get(ctx: wire.Context) -> str:
|
||||
if is_enabled():
|
||||
return await _request_from_user(ctx)
|
||||
else:
|
||||
async def get(ctx: Context) -> str:
|
||||
from trezor import workflow
|
||||
|
||||
if not is_enabled():
|
||||
return ""
|
||||
|
||||
|
||||
async def _request_from_user(ctx: wire.Context) -> str:
|
||||
workflow.close_others() # request exclusive UI access
|
||||
if storage.device.get_passphrase_always_on_device():
|
||||
from trezor.ui.layouts import request_passphrase_on_device
|
||||
|
||||
passphrase = await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN)
|
||||
else:
|
||||
passphrase = await _request_on_host(ctx)
|
||||
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
|
||||
raise wire.DataError(
|
||||
f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes"
|
||||
)
|
||||
workflow.close_others() # request exclusive UI access
|
||||
if storage_device.get_passphrase_always_on_device():
|
||||
from trezor.ui.layouts import request_passphrase_on_device
|
||||
|
||||
return passphrase
|
||||
passphrase = await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN)
|
||||
else:
|
||||
passphrase = await _request_on_host(ctx)
|
||||
if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
|
||||
raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes")
|
||||
|
||||
return passphrase
|
||||
|
||||
|
||||
async def _request_on_host(ctx: wire.Context) -> str:
|
||||
async def _request_on_host(ctx: Context) -> str:
|
||||
from trezor.messages import PassphraseAck, PassphraseRequest
|
||||
from trezor.ui.layouts import draw_simple_text
|
||||
|
||||
_entry_dialog()
|
||||
# _entry_dialog
|
||||
draw_simple_text(
|
||||
"Passphrase entry", "Please type your\npassphrase on the\nconnected host."
|
||||
)
|
||||
|
||||
request = PassphraseRequest()
|
||||
ack = await ctx.call(request, PassphraseAck)
|
||||
passphrase = ack.passphrase # local_cache_attribute
|
||||
|
||||
if ack.on_device:
|
||||
from trezor.ui.layouts import request_passphrase_on_device
|
||||
|
||||
if ack.passphrase is not None:
|
||||
raise wire.DataError("Passphrase provided when it should not be")
|
||||
if passphrase is not None:
|
||||
raise DataError("Passphrase provided when it should not be")
|
||||
return await request_passphrase_on_device(ctx, _MAX_PASSPHRASE_LEN)
|
||||
|
||||
if ack.passphrase is None:
|
||||
raise wire.DataError(
|
||||
if passphrase is None:
|
||||
raise DataError(
|
||||
"Passphrase not provided and on_device is False. Use empty string to set an empty passphrase."
|
||||
)
|
||||
|
||||
# non-empty passphrase
|
||||
if ack.passphrase:
|
||||
if passphrase:
|
||||
from trezor import ui
|
||||
from trezor.ui.layouts import confirm_action, confirm_blob
|
||||
|
||||
await confirm_action(
|
||||
ctx,
|
||||
"passphrase_host1",
|
||||
title="Hidden wallet",
|
||||
"Hidden wallet",
|
||||
description="Access hidden wallet?\n\nNext screen will show\nthe passphrase!",
|
||||
icon=ui.ICON_CONFIG,
|
||||
)
|
||||
@ -68,19 +74,11 @@ async def _request_on_host(ctx: wire.Context) -> str:
|
||||
await confirm_blob(
|
||||
ctx,
|
||||
"passphrase_host2",
|
||||
title="Hidden wallet",
|
||||
description="Use this passphrase?\n",
|
||||
data=ack.passphrase,
|
||||
"Hidden wallet",
|
||||
passphrase,
|
||||
"Use this passphrase?\n",
|
||||
icon=ui.ICON_CONFIG,
|
||||
icon_color=ui.ORANGE_ICON,
|
||||
)
|
||||
|
||||
return ack.passphrase
|
||||
|
||||
|
||||
def _entry_dialog() -> None:
|
||||
from trezor.ui.layouts import draw_simple_text
|
||||
|
||||
draw_simple_text(
|
||||
"Passphrase entry", "Please type your\npassphrase on the\nconnected host."
|
||||
)
|
||||
return passphrase
|
||||
|
@ -197,23 +197,24 @@ class PathSchema:
|
||||
|
||||
# optionally replace a keyword
|
||||
component = cls.REPLACEMENTS.get(component, component)
|
||||
append = schema.append # local_cache_attribute
|
||||
|
||||
if "-" in component:
|
||||
# parse as a range
|
||||
a, b = [parse(s) for s in component.split("-", 1)]
|
||||
schema.append(Interval(a, b))
|
||||
append(Interval(a, b))
|
||||
|
||||
elif "," in component:
|
||||
# parse as a list of values
|
||||
schema.append(set(parse(s) for s in component.split(",")))
|
||||
append(set(parse(s) for s in component.split(",")))
|
||||
|
||||
elif component == "coin_type":
|
||||
# substitute SLIP-44 ids
|
||||
schema.append(set(parse(s) for s in slip44_id))
|
||||
append(set(parse(s) for s in slip44_id))
|
||||
|
||||
else:
|
||||
# plain constant
|
||||
schema.append((parse(component),))
|
||||
append((parse(component),))
|
||||
|
||||
return cls(schema, trailing_components, compact=True)
|
||||
|
||||
@ -258,18 +259,19 @@ class PathSchema:
|
||||
path. If the restriction results in a never-matching schema, then False
|
||||
is returned.
|
||||
"""
|
||||
schema = self.schema # local_cache_attribute
|
||||
|
||||
for i, value in enumerate(path):
|
||||
if i < len(self.schema):
|
||||
if i < len(schema):
|
||||
# Ensure that the path is a prefix of the schema.
|
||||
if value not in self.schema[i]:
|
||||
if value not in schema[i]:
|
||||
self.set_never_matching()
|
||||
return False
|
||||
|
||||
# Restrict the schema component if there are multiple choices.
|
||||
component = self.schema[i]
|
||||
component = schema[i]
|
||||
if not isinstance(component, tuple) or len(component) != 1:
|
||||
self.schema[i] = (value,)
|
||||
schema[i] = (value,)
|
||||
else:
|
||||
# The path is longer than the schema. We need to restrict the
|
||||
# trailing components.
|
||||
@ -278,7 +280,7 @@ class PathSchema:
|
||||
self.set_never_matching()
|
||||
return False
|
||||
|
||||
self.schema.append((value,))
|
||||
schema.append((value,))
|
||||
|
||||
return True
|
||||
|
||||
@ -286,6 +288,7 @@ class PathSchema:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
components = ["m"]
|
||||
append = components.append # local_cache_attribute
|
||||
|
||||
def unharden(item: int) -> int:
|
||||
return item ^ (item & HARDENED)
|
||||
@ -294,7 +297,7 @@ class PathSchema:
|
||||
if isinstance(component, Interval):
|
||||
a, b = component.min, component.max
|
||||
prime = "'" if a & HARDENED else ""
|
||||
components.append(f"[{unharden(a)}-{unharden(b)}]{prime}")
|
||||
append(f"[{unharden(a)}-{unharden(b)}]{prime}")
|
||||
else:
|
||||
# typechecker thinks component is a Contanier but we're using it
|
||||
# as a Collection.
|
||||
@ -307,15 +310,15 @@ class PathSchema:
|
||||
component_str = "[" + component_str + "]"
|
||||
if next(iter(collection)) & HARDENED:
|
||||
component_str += "'"
|
||||
components.append(component_str)
|
||||
append(component_str)
|
||||
|
||||
if self.trailing_components:
|
||||
for key, val in self.WILDCARD_RANGES.items():
|
||||
if self.trailing_components is val:
|
||||
components.append(key)
|
||||
append(key)
|
||||
break
|
||||
else:
|
||||
components.append("???")
|
||||
append("???")
|
||||
|
||||
return "<schema:" + "/".join(components) + ">"
|
||||
|
||||
@ -362,7 +365,7 @@ def path_is_hardened(address_n: Bip32Path) -> bool:
|
||||
|
||||
|
||||
def address_n_to_str(address_n: Iterable[int]) -> str:
|
||||
def path_item(i: int) -> str:
|
||||
def _path_item(i: int) -> str:
|
||||
if i & HARDENED:
|
||||
return str(i ^ HARDENED) + "'"
|
||||
else:
|
||||
@ -371,4 +374,4 @@ def address_n_to_str(address_n: Iterable[int]) -> str:
|
||||
if not address_n:
|
||||
return "m"
|
||||
|
||||
return "m/" + "/".join(path_item(i) for i in address_n)
|
||||
return "m/" + "/".join(_path_item(i) for i in address_n)
|
||||
|
@ -1,27 +1,32 @@
|
||||
from trezor.utils import BufferReader
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.utils import BufferReader
|
||||
|
||||
|
||||
def read_compact_size(r: BufferReader) -> int:
|
||||
prefix = r.get()
|
||||
get = r.get # local_cache_attribute
|
||||
|
||||
prefix = get()
|
||||
if prefix < 253:
|
||||
n = prefix
|
||||
elif prefix == 253:
|
||||
n = r.get()
|
||||
n += r.get() << 8
|
||||
n = get()
|
||||
n += get() << 8
|
||||
elif prefix == 254:
|
||||
n = r.get()
|
||||
n += r.get() << 8
|
||||
n += r.get() << 16
|
||||
n += r.get() << 24
|
||||
n = get()
|
||||
n += get() << 8
|
||||
n += get() << 16
|
||||
n += get() << 24
|
||||
elif prefix == 255:
|
||||
n = r.get()
|
||||
n += r.get() << 8
|
||||
n += r.get() << 16
|
||||
n += r.get() << 24
|
||||
n += r.get() << 32
|
||||
n += r.get() << 40
|
||||
n += r.get() << 48
|
||||
n += r.get() << 56
|
||||
n = get()
|
||||
n += get() << 8
|
||||
n += get() << 16
|
||||
n += get() << 24
|
||||
n += get() << 32
|
||||
n += get() << 40
|
||||
n += get() << 48
|
||||
n += get() << 56
|
||||
else:
|
||||
raise ValueError
|
||||
return n
|
||||
|
@ -1,20 +1,25 @@
|
||||
import utime
|
||||
from typing import Any, NoReturn
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import storage.cache
|
||||
import storage.sd_salt
|
||||
import storage.cache as storage_cache
|
||||
from trezor import config, wire
|
||||
|
||||
from .sdcard import SdCardUnavailable, request_sd_salt
|
||||
from .sdcard import request_sd_salt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, NoReturn
|
||||
from trezor.wire import Context, GenericContext
|
||||
|
||||
|
||||
def can_lock_device() -> bool:
|
||||
"""Return True if the device has a PIN set or SD-protect enabled."""
|
||||
import storage.sd_salt
|
||||
|
||||
return config.has_pin() or storage.sd_salt.is_enabled()
|
||||
|
||||
|
||||
async def request_pin(
|
||||
ctx: wire.GenericContext,
|
||||
ctx: GenericContext,
|
||||
prompt: str = "Enter your PIN",
|
||||
attempts_remaining: int | None = None,
|
||||
allow_cancel: bool = True,
|
||||
@ -24,26 +29,26 @@ async def request_pin(
|
||||
return await request_pin_on_device(ctx, prompt, attempts_remaining, allow_cancel)
|
||||
|
||||
|
||||
async def request_pin_confirm(ctx: wire.Context, *args: Any, **kwargs: Any) -> str:
|
||||
async def request_pin_confirm(ctx: Context, *args: Any, **kwargs: Any) -> str:
|
||||
while True:
|
||||
pin1 = await request_pin(ctx, "Enter new PIN", *args, **kwargs)
|
||||
pin2 = await request_pin(ctx, "Re-enter new PIN", *args, **kwargs)
|
||||
if pin1 == pin2:
|
||||
return pin1
|
||||
await pin_mismatch()
|
||||
await _pin_mismatch()
|
||||
|
||||
|
||||
async def pin_mismatch() -> None:
|
||||
async def _pin_mismatch() -> None:
|
||||
from trezor.ui.layouts import show_popup
|
||||
|
||||
await show_popup(
|
||||
title="PIN mismatch",
|
||||
description="The PINs you entered\ndo not match.\n\nPlease try again.",
|
||||
"PIN mismatch",
|
||||
"The PINs you entered\ndo not match.\n\nPlease try again.",
|
||||
)
|
||||
|
||||
|
||||
async def request_pin_and_sd_salt(
|
||||
ctx: wire.Context, prompt: str = "Enter your PIN", allow_cancel: bool = True
|
||||
ctx: Context, prompt: str = "Enter your PIN", allow_cancel: bool = True
|
||||
) -> tuple[str, bytearray | None]:
|
||||
if config.has_pin():
|
||||
pin = await request_pin(ctx, prompt, config.get_pin_rem(), allow_cancel)
|
||||
@ -58,21 +63,23 @@ async def request_pin_and_sd_salt(
|
||||
|
||||
def _set_last_unlock_time() -> None:
|
||||
now = utime.ticks_ms()
|
||||
storage.cache.set_int(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
|
||||
|
||||
|
||||
def _get_last_unlock_time() -> int:
|
||||
return storage.cache.get_int(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK) or 0
|
||||
storage_cache.set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
|
||||
|
||||
|
||||
async def verify_user_pin(
|
||||
ctx: wire.GenericContext = wire.DUMMY_CONTEXT,
|
||||
ctx: GenericContext = wire.DUMMY_CONTEXT,
|
||||
prompt: str = "Enter your PIN",
|
||||
allow_cancel: bool = True,
|
||||
retry: bool = True,
|
||||
cache_time_ms: int = 0,
|
||||
) -> None:
|
||||
last_unlock = _get_last_unlock_time()
|
||||
from .sdcard import SdCardUnavailable
|
||||
|
||||
# _get_last_unlock_time
|
||||
last_unlock = int.from_bytes(
|
||||
storage_cache.get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
|
||||
)
|
||||
|
||||
if (
|
||||
cache_time_ms
|
||||
and last_unlock
|
||||
@ -112,28 +119,28 @@ async def verify_user_pin(
|
||||
raise wire.PinInvalid
|
||||
|
||||
|
||||
async def error_pin_invalid(ctx: wire.Context) -> NoReturn:
|
||||
async def error_pin_invalid(ctx: Context) -> NoReturn:
|
||||
from trezor.ui.layouts import show_error_and_raise
|
||||
|
||||
await show_error_and_raise(
|
||||
ctx,
|
||||
"warning_wrong_pin",
|
||||
header="Wrong PIN",
|
||||
content="The PIN you entered is invalid.",
|
||||
"The PIN you entered is invalid.",
|
||||
"Wrong PIN", # header
|
||||
red=True,
|
||||
exc=wire.PinInvalid,
|
||||
)
|
||||
assert False
|
||||
|
||||
|
||||
async def error_pin_matches_wipe_code(ctx: wire.Context) -> NoReturn:
|
||||
async def error_pin_matches_wipe_code(ctx: Context) -> NoReturn:
|
||||
from trezor.ui.layouts import show_error_and_raise
|
||||
|
||||
await show_error_and_raise(
|
||||
ctx,
|
||||
"warning_invalid_new_pin",
|
||||
header="Invalid PIN",
|
||||
content="The new PIN must be different from your\nwipe code.",
|
||||
"The new PIN must be different from your\nwipe code.",
|
||||
"Invalid PIN", # header
|
||||
red=True,
|
||||
exc=wire.PinInvalid,
|
||||
)
|
||||
|
@ -1,5 +1,5 @@
|
||||
import storage.cache
|
||||
import storage.device
|
||||
import storage.cache as storage_cache
|
||||
import storage.device as storage_device
|
||||
from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
||||
from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT
|
||||
from trezor.enums import SafetyCheckLevel
|
||||
@ -9,11 +9,11 @@ def read_setting() -> SafetyCheckLevel:
|
||||
"""
|
||||
Returns the effective safety check level.
|
||||
"""
|
||||
temporary_safety_check_level = storage.cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||
temporary_safety_check_level = storage_cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||
if temporary_safety_check_level:
|
||||
return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum]
|
||||
else:
|
||||
stored = storage.device.safety_check_level()
|
||||
stored = storage_device.safety_check_level()
|
||||
if stored == SAFETY_CHECK_LEVEL_STRICT:
|
||||
return SafetyCheckLevel.Strict
|
||||
elif stored == SAFETY_CHECK_LEVEL_PROMPT:
|
||||
@ -27,14 +27,14 @@ def apply_setting(level: SafetyCheckLevel) -> None:
|
||||
Changes the safety level settings.
|
||||
"""
|
||||
if level == SafetyCheckLevel.Strict:
|
||||
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
|
||||
storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
|
||||
elif level == SafetyCheckLevel.PromptAlways:
|
||||
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
|
||||
storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
|
||||
elif level == SafetyCheckLevel.PromptTemporarily:
|
||||
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
|
||||
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
|
||||
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
|
||||
storage_cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
|
||||
else:
|
||||
raise ValueError("Unknown SafetyCheckLevel")
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
import storage.sd_salt
|
||||
from storage.sd_salt import SD_CARD_HOT_SWAPPABLE
|
||||
from trezor import io, sdcard, ui, wire
|
||||
from trezor import io, ui, wire
|
||||
from trezor.ui.layouts import confirm_action, show_error_and_raise
|
||||
|
||||
|
||||
@ -14,8 +13,8 @@ async def _confirm_retry_wrong_card(ctx: wire.GenericContext) -> None:
|
||||
ctx,
|
||||
"warning_wrong_sd",
|
||||
"SD card protection",
|
||||
action="Wrong SD card.",
|
||||
description="Please insert the correct SD card for this device.",
|
||||
"Wrong SD card.",
|
||||
"Please insert the correct SD card for this device.",
|
||||
verb="Retry",
|
||||
verb_cancel="Abort",
|
||||
icon=ui.ICON_WRONG,
|
||||
@ -26,9 +25,9 @@ async def _confirm_retry_wrong_card(ctx: wire.GenericContext) -> None:
|
||||
await show_error_and_raise(
|
||||
ctx,
|
||||
"warning_wrong_sd",
|
||||
header="SD card protection",
|
||||
subheader="Wrong SD card.",
|
||||
content="Please unplug the\ndevice and insert the correct SD card.",
|
||||
"Please unplug the\ndevice and insert the correct SD card.",
|
||||
"SD card protection",
|
||||
"Wrong SD card.",
|
||||
exc=SdCardUnavailable("Wrong SD card."),
|
||||
)
|
||||
|
||||
@ -39,8 +38,8 @@ async def _confirm_retry_insert_card(ctx: wire.GenericContext) -> None:
|
||||
ctx,
|
||||
"warning_no_sd",
|
||||
"SD card protection",
|
||||
action="SD card required.",
|
||||
description="Please insert your SD card.",
|
||||
"SD card required.",
|
||||
"Please insert your SD card.",
|
||||
verb="Retry",
|
||||
verb_cancel="Abort",
|
||||
icon=ui.ICON_WRONG,
|
||||
@ -51,9 +50,9 @@ async def _confirm_retry_insert_card(ctx: wire.GenericContext) -> None:
|
||||
await show_error_and_raise(
|
||||
ctx,
|
||||
"warning_no_sd",
|
||||
header="SD card protection",
|
||||
subheader="SD card required.",
|
||||
content="Please unplug the\ndevice and insert your SD card.",
|
||||
"Please unplug the\ndevice and insert your SD card.",
|
||||
"SD card protection",
|
||||
"SD card required.",
|
||||
exc=SdCardUnavailable("SD card required."),
|
||||
)
|
||||
|
||||
@ -64,8 +63,8 @@ async def _confirm_format_card(ctx: wire.GenericContext) -> None:
|
||||
ctx,
|
||||
"warning_format_sd",
|
||||
"SD card error",
|
||||
action="Unknown filesystem.",
|
||||
description="Use a different card or format the SD card to the FAT32 filesystem.",
|
||||
"Unknown filesystem.",
|
||||
"Use a different card or format the SD card to the FAT32 filesystem.",
|
||||
icon=ui.ICON_WRONG,
|
||||
icon_color=ui.RED,
|
||||
verb="Format",
|
||||
@ -79,8 +78,8 @@ async def _confirm_format_card(ctx: wire.GenericContext) -> None:
|
||||
ctx,
|
||||
"confirm_format_sd",
|
||||
"Format SD card",
|
||||
action="All data on the SD card will be lost.",
|
||||
description="Do you really want to format the SD card?",
|
||||
"All data on the SD card will be lost.",
|
||||
"Do you really want to format the SD card?",
|
||||
reverse=True,
|
||||
verb="Format SD card",
|
||||
icon=ui.ICON_WIPE,
|
||||
@ -99,8 +98,8 @@ async def confirm_retry_sd(
|
||||
ctx,
|
||||
"warning_sd_retry",
|
||||
"SD card problem",
|
||||
action=None,
|
||||
description="There was a problem accessing the SD card.",
|
||||
None,
|
||||
"There was a problem accessing the SD card.",
|
||||
icon=ui.ICON_WRONG,
|
||||
icon_color=ui.RED,
|
||||
verb="Retry",
|
||||
@ -121,18 +120,20 @@ async def ensure_sdcard(
|
||||
filesystem, and allows the user to format the card if a filesystem cannot be
|
||||
mounted.
|
||||
"""
|
||||
from trezor import sdcard
|
||||
|
||||
while not sdcard.is_present():
|
||||
await _confirm_retry_insert_card(ctx)
|
||||
|
||||
if not ensure_filesystem:
|
||||
return
|
||||
|
||||
fatfs = io.fatfs # local_cache_attribute
|
||||
while True:
|
||||
try:
|
||||
try:
|
||||
with sdcard.filesystem(mounted=False):
|
||||
io.fatfs.mount()
|
||||
except io.fatfs.NoFilesystem:
|
||||
fatfs.mount()
|
||||
except fatfs.NoFilesystem:
|
||||
# card not formatted. proceed out of the except clause
|
||||
pass
|
||||
else:
|
||||
@ -143,9 +144,9 @@ async def ensure_sdcard(
|
||||
|
||||
# Proceed to formatting. Failure is caught by the outside OSError handler
|
||||
with sdcard.filesystem(mounted=False):
|
||||
io.fatfs.mkfs()
|
||||
io.fatfs.mount()
|
||||
io.fatfs.setlabel("TREZOR")
|
||||
fatfs.mkfs()
|
||||
fatfs.mount()
|
||||
fatfs.setlabel("TREZOR")
|
||||
|
||||
# format and mount succeeded
|
||||
return
|
||||
@ -158,14 +159,16 @@ async def ensure_sdcard(
|
||||
async def request_sd_salt(
|
||||
ctx: wire.GenericContext = wire.DUMMY_CONTEXT,
|
||||
) -> bytearray | None:
|
||||
if not storage.sd_salt.is_enabled():
|
||||
import storage.sd_salt as storage_sd_salt
|
||||
|
||||
if not storage_sd_salt.is_enabled():
|
||||
return None
|
||||
|
||||
while True:
|
||||
await ensure_sdcard(ctx, ensure_filesystem=False)
|
||||
try:
|
||||
return storage.sd_salt.load_sd_salt()
|
||||
except (storage.sd_salt.WrongSdCard, io.fatfs.NoFilesystem):
|
||||
return storage_sd_salt.load_sd_salt()
|
||||
except (storage_sd_salt.WrongSdCard, io.fatfs.NoFilesystem):
|
||||
await _confirm_retry_wrong_card(ctx)
|
||||
except OSError:
|
||||
# Generic problem with loading the SD salt (hardware problem, or we could
|
||||
|
@ -1,14 +1,17 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage import cache, device
|
||||
from trezor import utils, wire
|
||||
from trezor.crypto import bip32, hmac
|
||||
import storage.cache as storage_cache
|
||||
import storage.device as storage_device
|
||||
from trezor import utils
|
||||
from trezor.crypto import hmac
|
||||
|
||||
from . import mnemonic
|
||||
from .passphrase import get as get_passphrase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .paths import Bip32Path, Slip21Path
|
||||
from trezor.wire import Context
|
||||
from trezor.crypto import bip32
|
||||
|
||||
|
||||
class Slip21Node:
|
||||
@ -47,14 +50,16 @@ if not utils.BITCOIN_ONLY:
|
||||
# We want to derive both the normal seed and the Cardano seed together, AND
|
||||
# expose a method for Cardano to do the same
|
||||
|
||||
async def derive_and_store_roots(ctx: wire.Context) -> None:
|
||||
if not device.is_initialized():
|
||||
async def derive_and_store_roots(ctx: Context) -> None:
|
||||
from trezor import wire
|
||||
|
||||
if not storage_device.is_initialized():
|
||||
raise wire.NotInitialized("Device is not initialized")
|
||||
|
||||
need_seed = not cache.is_set(cache.APP_COMMON_SEED)
|
||||
need_cardano_secret = cache.get(
|
||||
cache.APP_COMMON_DERIVE_CARDANO
|
||||
) and not cache.is_set(cache.APP_CARDANO_ICARUS_SECRET)
|
||||
need_seed = not storage_cache.is_set(storage_cache.APP_COMMON_SEED)
|
||||
need_cardano_secret = storage_cache.get(
|
||||
storage_cache.APP_COMMON_DERIVE_CARDANO
|
||||
) and not storage_cache.is_set(storage_cache.APP_CARDANO_ICARUS_SECRET)
|
||||
|
||||
if not need_seed and not need_cardano_secret:
|
||||
return
|
||||
@ -63,17 +68,17 @@ if not utils.BITCOIN_ONLY:
|
||||
|
||||
if need_seed:
|
||||
common_seed = mnemonic.get_seed(passphrase)
|
||||
cache.set(cache.APP_COMMON_SEED, common_seed)
|
||||
storage_cache.set(storage_cache.APP_COMMON_SEED, common_seed)
|
||||
|
||||
if need_cardano_secret:
|
||||
from apps.cardano.seed import derive_and_store_secrets
|
||||
|
||||
derive_and_store_secrets(passphrase)
|
||||
|
||||
@cache.stored_async(cache.APP_COMMON_SEED)
|
||||
async def get_seed(ctx: wire.Context) -> bytes:
|
||||
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
|
||||
async def get_seed(ctx: Context) -> bytes:
|
||||
await derive_and_store_roots(ctx)
|
||||
common_seed = cache.get(cache.APP_COMMON_SEED)
|
||||
common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED)
|
||||
assert common_seed is not None
|
||||
return common_seed
|
||||
|
||||
@ -81,15 +86,15 @@ else:
|
||||
# === Bitcoin-only variant ===
|
||||
# We use the simple version of `get_seed` that never needs to derive anything else.
|
||||
|
||||
@cache.stored_async(cache.APP_COMMON_SEED)
|
||||
async def get_seed(ctx: wire.Context) -> bytes:
|
||||
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
|
||||
async def get_seed(ctx: Context) -> bytes:
|
||||
passphrase = await get_passphrase(ctx)
|
||||
return mnemonic.get_seed(passphrase)
|
||||
|
||||
|
||||
@cache.stored(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
|
||||
@storage_cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
|
||||
def _get_seed_without_passphrase() -> bytes:
|
||||
if not device.is_initialized():
|
||||
if not storage_device.is_initialized():
|
||||
raise Exception("Device is not initialized")
|
||||
return mnemonic.get_seed(progress_bar=False)
|
||||
|
||||
@ -97,6 +102,8 @@ def _get_seed_without_passphrase() -> bytes:
|
||||
def derive_node_without_passphrase(
|
||||
path: Bip32Path, curve_name: str = "secp256k1"
|
||||
) -> bip32.HDNode:
|
||||
from trezor.crypto import bip32
|
||||
|
||||
seed = _get_seed_without_passphrase()
|
||||
node = bip32.from_seed(seed, curve_name)
|
||||
node.derive_path(path)
|
||||
|
@ -1,16 +1,15 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from ubinascii import hexlify
|
||||
|
||||
from trezor import utils, wire
|
||||
from trezor.crypto.hashlib import blake256, sha256
|
||||
|
||||
from apps.common.writers import write_compact_size
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from apps.common.coininfo import CoinInfo
|
||||
|
||||
|
||||
def message_digest(coin: CoinInfo, message: bytes) -> bytes:
|
||||
from trezor import utils, wire
|
||||
from trezor.crypto.hashlib import blake256, sha256
|
||||
|
||||
from apps.common.writers import write_compact_size
|
||||
|
||||
if not utils.BITCOIN_ONLY and coin.decred:
|
||||
h = utils.HashWriter(blake256())
|
||||
else:
|
||||
@ -28,6 +27,8 @@ def message_digest(coin: CoinInfo, message: bytes) -> bytes:
|
||||
|
||||
|
||||
def decode_message(message: bytes) -> str:
|
||||
from ubinascii import hexlify
|
||||
|
||||
try:
|
||||
return bytes(message).decode()
|
||||
except UnicodeError:
|
||||
|
@ -6,61 +6,38 @@ if TYPE_CHECKING:
|
||||
from trezor.utils import Writer
|
||||
|
||||
|
||||
def _write_uint(w: Writer, n: int, bits: int, bigendian: bool) -> int:
|
||||
ensure(0 <= n <= 2**bits - 1, "overflow")
|
||||
shifts = range(0, bits, 8)
|
||||
if bigendian:
|
||||
shifts = reversed(shifts)
|
||||
for num in shifts:
|
||||
w.append((n >> num) & 0xFF)
|
||||
return bits // 8
|
||||
|
||||
|
||||
def write_uint8(w: Writer, n: int) -> int:
|
||||
ensure(0 <= n <= 0xFF)
|
||||
w.append(n)
|
||||
return 1
|
||||
return _write_uint(w, n, 8, False)
|
||||
|
||||
|
||||
def write_uint16_le(w: Writer, n: int) -> int:
|
||||
ensure(0 <= n <= 0xFFFF)
|
||||
w.append(n & 0xFF)
|
||||
w.append((n >> 8) & 0xFF)
|
||||
return 2
|
||||
return _write_uint(w, n, 16, False)
|
||||
|
||||
|
||||
def write_uint32_le(w: Writer, n: int) -> int:
|
||||
ensure(0 <= n <= 0xFFFF_FFFF)
|
||||
w.append(n & 0xFF)
|
||||
w.append((n >> 8) & 0xFF)
|
||||
w.append((n >> 16) & 0xFF)
|
||||
w.append((n >> 24) & 0xFF)
|
||||
return 4
|
||||
return _write_uint(w, n, 32, False)
|
||||
|
||||
|
||||
def write_uint32_be(w: Writer, n: int) -> int:
|
||||
ensure(0 <= n <= 0xFFFF_FFFF)
|
||||
w.append((n >> 24) & 0xFF)
|
||||
w.append((n >> 16) & 0xFF)
|
||||
w.append((n >> 8) & 0xFF)
|
||||
w.append(n & 0xFF)
|
||||
return 4
|
||||
return _write_uint(w, n, 32, True)
|
||||
|
||||
|
||||
def write_uint64_le(w: Writer, n: int) -> int:
|
||||
ensure(0 <= n <= 0xFFFF_FFFF_FFFF_FFFF)
|
||||
w.append(n & 0xFF)
|
||||
w.append((n >> 8) & 0xFF)
|
||||
w.append((n >> 16) & 0xFF)
|
||||
w.append((n >> 24) & 0xFF)
|
||||
w.append((n >> 32) & 0xFF)
|
||||
w.append((n >> 40) & 0xFF)
|
||||
w.append((n >> 48) & 0xFF)
|
||||
w.append((n >> 56) & 0xFF)
|
||||
return 8
|
||||
return _write_uint(w, n, 64, False)
|
||||
|
||||
|
||||
def write_uint64_be(w: Writer, n: int) -> int:
|
||||
ensure(0 <= n <= 0xFFFF_FFFF_FFFF_FFFF)
|
||||
w.append((n >> 56) & 0xFF)
|
||||
w.append((n >> 48) & 0xFF)
|
||||
w.append((n >> 40) & 0xFF)
|
||||
w.append((n >> 32) & 0xFF)
|
||||
w.append((n >> 24) & 0xFF)
|
||||
w.append((n >> 16) & 0xFF)
|
||||
w.append((n >> 8) & 0xFF)
|
||||
w.append(n & 0xFF)
|
||||
return 8
|
||||
return _write_uint(w, n, 64, True)
|
||||
|
||||
|
||||
def write_bytes_unchecked(w: Writer, b: bytes | memoryview) -> int:
|
||||
@ -82,16 +59,18 @@ def write_bytes_reversed(w: Writer, b: bytes, length: int) -> int:
|
||||
|
||||
def write_compact_size(w: Writer, n: int) -> None:
|
||||
ensure(0 <= n <= 0xFFFF_FFFF)
|
||||
append = w.append # local_cache_attribute
|
||||
|
||||
if n < 253:
|
||||
w.append(n & 0xFF)
|
||||
append(n & 0xFF)
|
||||
elif n < 0x1_0000:
|
||||
w.append(253)
|
||||
append(253)
|
||||
write_uint16_le(w, n)
|
||||
elif n < 0x1_0000_0000:
|
||||
w.append(254)
|
||||
append(254)
|
||||
write_uint32_le(w, n)
|
||||
else:
|
||||
w.append(255)
|
||||
append(255)
|
||||
write_uint64_le(w, n)
|
||||
|
||||
|
||||
|
@ -1,5 +1,3 @@
|
||||
import math
|
||||
|
||||
from common import *
|
||||
|
||||
from apps.common.cbor import (
|
||||
@ -12,43 +10,8 @@ from apps.common.cbor import (
|
||||
decode,
|
||||
encode,
|
||||
encode_streamed,
|
||||
utils
|
||||
)
|
||||
|
||||
# NOTE: moved into tests not to occupy flash space
|
||||
# in firmware binary, when it is not used in production
|
||||
def encode_chunked(value: "Value", max_chunk_size: int) -> "Iterator[bytes]":
|
||||
"""
|
||||
Returns the encoded value as an iterable of chunks of a given size,
|
||||
removing the need to reserve a continuous chunk of memory for the
|
||||
full serialized representation of the value.
|
||||
"""
|
||||
if max_chunk_size <= 0:
|
||||
raise ValueError
|
||||
|
||||
chunks = encode_streamed(value)
|
||||
|
||||
chunk_buffer = utils.empty_bytearray(max_chunk_size)
|
||||
try:
|
||||
current_chunk_view = utils.BufferReader(next(chunks))
|
||||
while True:
|
||||
num_bytes_to_write = min(
|
||||
current_chunk_view.remaining_count(),
|
||||
max_chunk_size - len(chunk_buffer),
|
||||
)
|
||||
chunk_buffer.extend(current_chunk_view.read(num_bytes_to_write))
|
||||
|
||||
if len(chunk_buffer) >= max_chunk_size:
|
||||
yield chunk_buffer
|
||||
chunk_buffer[:] = bytes()
|
||||
|
||||
if current_chunk_view.remaining_count() == 0:
|
||||
current_chunk_view = utils.BufferReader(next(chunks))
|
||||
except StopIteration:
|
||||
if len(chunk_buffer) > 0:
|
||||
yield chunk_buffer
|
||||
|
||||
|
||||
|
||||
class TestCardanoCbor(unittest.TestCase):
|
||||
def test_create_array_header(self):
|
||||
@ -211,43 +174,6 @@ class TestCardanoCbor(unittest.TestCase):
|
||||
|
||||
self.assertEqual(b''.join(encoded_streamed), encoded)
|
||||
|
||||
def test_encode_chunked(self):
|
||||
large_dict = {i: i for i in range(100)}
|
||||
encoded = encode(large_dict)
|
||||
|
||||
encoded_len = len(encoded)
|
||||
assert encoded_len == 354
|
||||
|
||||
arbitrary_encoded_len_factor = 59
|
||||
arbitrary_power_of_two = 64
|
||||
larger_than_encoded_len = encoded_len + 1
|
||||
|
||||
for max_chunk_size in [
|
||||
1,
|
||||
10,
|
||||
arbitrary_encoded_len_factor,
|
||||
arbitrary_power_of_two,
|
||||
encoded_len,
|
||||
larger_than_encoded_len
|
||||
]:
|
||||
encoded_chunks = [
|
||||
bytes(chunk) for chunk in encode_chunked(large_dict, max_chunk_size)
|
||||
]
|
||||
|
||||
expected_number_of_chunks = math.ceil(len(encoded) / max_chunk_size)
|
||||
self.assertEqual(len(encoded_chunks), expected_number_of_chunks)
|
||||
|
||||
# all chunks except the last should be of chunk_size
|
||||
for i in range(len(encoded_chunks) - 1):
|
||||
self.assertEqual(len(encoded_chunks[i]), max_chunk_size)
|
||||
|
||||
# last chunk should contain the remaining bytes or the whole chunk
|
||||
remaining_bytes = len(encoded) % max_chunk_size
|
||||
expected_last_chunk_size = remaining_bytes if remaining_bytes > 0 else max_chunk_size
|
||||
self.assertEqual(len(encoded_chunks[-1]), expected_last_chunk_size)
|
||||
|
||||
self.assertEqual(b''.join(encoded_chunks), encoded)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
44
core/tests/test_apps.common.writers.py
Normal file
44
core/tests/test_apps.common.writers.py
Normal file
@ -0,0 +1,44 @@
|
||||
from common import *
|
||||
|
||||
import apps.common.writers as writers
|
||||
|
||||
|
||||
class TestSeed(unittest.TestCase):
|
||||
def test_write_uint8(self):
|
||||
buf = bytearray()
|
||||
writers.write_uint8(buf, 0x12)
|
||||
self.assertEqual(buf, b"\x12")
|
||||
|
||||
def test_write_uint16_le(self):
|
||||
buf = bytearray()
|
||||
writers.write_uint16_le(buf, 0x1234)
|
||||
self.assertEqual(buf, b"\x34\x12")
|
||||
|
||||
def test_write_uint16_le_overflow(self):
|
||||
buf = bytearray()
|
||||
with self.assertRaises(AssertionError):
|
||||
writers.write_uint16_le(buf, 0x12345678)
|
||||
|
||||
def test_write_uint32_le(self):
|
||||
buf = bytearray()
|
||||
writers.write_uint32_le(buf, 0x12345678)
|
||||
self.assertEqual(buf, b"\x78\x56\x34\x12")
|
||||
|
||||
def test_write_uint64_le(self):
|
||||
buf = bytearray()
|
||||
writers.write_uint64_le(buf, 0x1234567890abcdef)
|
||||
self.assertEqual(buf, b"\xef\xcd\xab\x90\x78\x56\x34\x12")
|
||||
|
||||
def test_write_uint32_be(self):
|
||||
buf = bytearray()
|
||||
writers.write_uint32_be(buf, 0x12345678)
|
||||
self.assertEqual(buf, b"\x12\x34\x56\x78")
|
||||
|
||||
def test_write_uint64_be(self):
|
||||
buf = bytearray()
|
||||
writers.write_uint64_be(buf, 0x1234567890abcdef)
|
||||
self.assertEqual(buf, b"\x12\x34\x56\x78\x90\xab\xcd\xef")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user