core/typing: add annotations

pull/320/head
Jan Pochyla 5 years ago
parent 86e16bbf31
commit 758a1a2528

@ -72,7 +72,7 @@ pylint: ## run pylint on application sources and tests
pylint -E $(shell find src tests -name *.py) pylint -E $(shell find src tests -name *.py)
mypy: mypy:
mypy \ mypy --config-file ../setup.cfg \
src/main.py src/main.py
## code generation: ## code generation:

@ -7,7 +7,7 @@ CURVE = "ed25519"
SEED_NAMESPACE = [HARDENED | 44, HARDENED | 1815] SEED_NAMESPACE = [HARDENED | 44, HARDENED | 1815]
def boot(): def boot() -> None:
wire.add(MessageType.CardanoGetAddress, __name__, "get_address") wire.add(MessageType.CardanoGetAddress, __name__, "get_address")
wire.add(MessageType.CardanoGetPublicKey, __name__, "get_public_key") wire.add(MessageType.CardanoGetPublicKey, __name__, "get_public_key")
wire.add(MessageType.CardanoSignTx, __name__, "sign_tx") wire.add(MessageType.CardanoSignTx, __name__, "sign_tx")

@ -4,8 +4,8 @@ from trezor import log, wire
from trezor.crypto import base58, hashlib from trezor.crypto import base58, hashlib
from trezor.crypto.curve import ed25519 from trezor.crypto.curve import ed25519
from trezor.messages.CardanoSignedTx import CardanoSignedTx from trezor.messages.CardanoSignedTx import CardanoSignedTx
from trezor.messages.CardanoTxAck import CardanoTxAck
from trezor.messages.CardanoTxRequest import CardanoTxRequest from trezor.messages.CardanoTxRequest import CardanoTxRequest
from trezor.messages.MessageType import CardanoTxAck
from apps.cardano import CURVE, seed from apps.cardano import CURVE, seed
from apps.cardano.address import ( from apps.cardano.address import (

@ -1,4 +1,9 @@
def length(address_type): if False:
from typing import Tuple
from apps.common.coininfo import CoinType
def length(address_type: int) -> int:
if address_type <= 0xFF: if address_type <= 0xFF:
return 1 return 1
if address_type <= 0xFFFF: if address_type <= 0xFFFF:
@ -9,21 +14,21 @@ def length(address_type):
return 4 return 4
def tobytes(address_type: int): def tobytes(address_type: int) -> bytes:
return address_type.to_bytes(length(address_type), "big") return address_type.to_bytes(length(address_type), "big")
def check(address_type, raw_address): def check(address_type: int, raw_address: bytes) -> bool:
return raw_address.startswith(tobytes(address_type)) return raw_address.startswith(tobytes(address_type))
def strip(address_type, raw_address): def strip(address_type: int, raw_address: bytes) -> bytes:
if not check(address_type, raw_address): if not check(address_type, raw_address):
raise ValueError("Invalid address") raise ValueError("Invalid address")
return raw_address[length(address_type) :] return raw_address[length(address_type) :]
def split(coin, raw_address): def split(coin: CoinType, raw_address: bytes) -> Tuple[bytes, bytes]:
for f in ( for f in (
"address_type", "address_type",
"address_type_p2sh", "address_type_p2sh",

@ -2,12 +2,15 @@ from trezor.crypto import hashlib, hmac, random
from apps.common import storage from apps.common import storage
_cached_seed = None if False:
_cached_passphrase = None from typing import Optional
_cached_passphrase_fprint = b"\x00\x00\x00\x00"
_cached_seed = None # type: Optional[bytes]
_cached_passphrase = None # type: Optional[str]
_cached_passphrase_fprint = b"\x00\x00\x00\x00" # type: bytes
def get_state(prev_state: bytes = None, passphrase: str = None) -> bytes:
def get_state(prev_state: bytes = None, passphrase: str = None) -> Optional[bytes]:
if prev_state is None: if prev_state is None:
salt = random.bytes(32) # generate a random salt if no state provided salt = random.bytes(32) # generate a random salt if no state provided
else: else:
@ -29,34 +32,34 @@ def _compute_state(salt: bytes, passphrase: str) -> bytes:
return salt + state return salt + state
def get_seed(): def get_seed() -> Optional[bytes]:
return _cached_seed return _cached_seed
def get_passphrase(): def get_passphrase() -> Optional[str]:
return _cached_passphrase return _cached_passphrase
def get_passphrase_fprint(): def get_passphrase_fprint() -> bytes:
return _cached_passphrase_fprint return _cached_passphrase_fprint
def has_passphrase(): def has_passphrase() -> bool:
return _cached_passphrase is not None return _cached_passphrase is not None
def set_seed(seed): def set_seed(seed: Optional[bytes]) -> None:
global _cached_seed global _cached_seed
_cached_seed = seed _cached_seed = seed
def set_passphrase(passphrase): def set_passphrase(passphrase: Optional[str]) -> None:
global _cached_passphrase, _cached_passphrase_fprint global _cached_passphrase, _cached_passphrase_fprint
_cached_passphrase = passphrase _cached_passphrase = passphrase
_cached_passphrase_fprint = _compute_state(b"FPRINT", passphrase or "")[:4] _cached_passphrase_fprint = _compute_state(b"FPRINT", passphrase or "")[:4]
def clear(keep_passphrase: bool = False): def clear(keep_passphrase: bool = False) -> None:
set_seed(None) set_seed(None)
if not keep_passphrase: if not keep_passphrase:
set_passphrase(None) set_passphrase(None)

@ -6,7 +6,11 @@ import ustruct as struct
from micropython import const from micropython import const
from trezor import log from trezor import log
from trezor.utils import ensure
if False:
from typing import Any, Dict, Iterable, List, Tuple
Value = Any
_CBOR_TYPE_MASK = const(0xE0) _CBOR_TYPE_MASK = const(0xE0)
_CBOR_INFO_BITS = const(0x1F) _CBOR_INFO_BITS = const(0x1F)
@ -32,7 +36,7 @@ _CBOR_BREAK = const(0x1F)
_CBOR_RAW_TAG = const(0x18) _CBOR_RAW_TAG = const(0x18)
def _header(typ, l: int): def _header(typ: int, l: int) -> bytes:
if l < 24: if l < 24:
return struct.pack(">B", typ + l) return struct.pack(">B", typ + l)
elif l < 2 ** 8: elif l < 2 ** 8:
@ -47,7 +51,7 @@ def _header(typ, l: int):
raise NotImplementedError("Length %d not suppported" % l) raise NotImplementedError("Length %d not suppported" % l)
def _cbor_encode(value): def _cbor_encode(value: Value) -> Iterable[bytes]:
if isinstance(value, int): if isinstance(value, int):
if value >= 0: if value >= 0:
yield _header(_CBOR_UNSIGNED_INT, value) yield _header(_CBOR_UNSIGNED_INT, value)
@ -95,7 +99,7 @@ def _cbor_encode(value):
raise NotImplementedError raise NotImplementedError
def _read_length(cbor, aux): def _read_length(cbor: bytes, aux: int) -> Tuple[int, bytes]:
if aux < _CBOR_UINT8_FOLLOWS: if aux < _CBOR_UINT8_FOLLOWS:
return (aux, cbor) return (aux, cbor)
elif aux == _CBOR_UINT8_FOLLOWS: elif aux == _CBOR_UINT8_FOLLOWS:
@ -124,7 +128,7 @@ def _read_length(cbor, aux):
raise NotImplementedError("Length %d not suppported" % aux) raise NotImplementedError("Length %d not suppported" % aux)
def _cbor_decode(cbor): def _cbor_decode(cbor: bytes) -> Tuple[Value, bytes]:
fb = cbor[0] fb = cbor[0]
data = b"" data = b""
fb_type = fb & _CBOR_TYPE_MASK fb_type = fb & _CBOR_TYPE_MASK
@ -158,7 +162,7 @@ def _cbor_decode(cbor):
res.append(item) res.append(item)
return (res, data) return (res, data)
elif fb_type == _CBOR_MAP: elif fb_type == _CBOR_MAP:
res = {} res = {} # type: Dict[Value, Value]
if fb_aux == _CBOR_VAR_FOLLOWS: if fb_aux == _CBOR_VAR_FOLLOWS:
data = cbor[1:] data = cbor[1:]
while True: while True:
@ -201,36 +205,41 @@ def _cbor_decode(cbor):
class Tagged: class Tagged:
def __init__(self, tag, value): def __init__(self, tag: int, value: Value) -> None:
self.tag = tag self.tag = tag
self.value = value self.value = value
def __eq__(self, other): def __eq__(self, other: object) -> bool:
return self.tag == other.tag and self.value == other.value return (
isinstance(other, Tagged)
and self.tag == other.tag
and self.value == other.value
)
class Raw: class Raw:
def __init__(self, value): def __init__(self, value: Value):
self.value = value self.value = value
class IndefiniteLengthArray: class IndefiniteLengthArray:
def __init__(self, array): def __init__(self, array: List[Value]) -> None:
ensure(isinstance(array, list))
self.array = array self.array = array
def __eq__(self, other): def __eq__(self, other: object) -> bool:
if isinstance(other, IndefiniteLengthArray): if isinstance(other, IndefiniteLengthArray):
return self.array == other.array return self.array == other.array
else: elif isinstance(other, list):
return self.array == other return self.array == other
else:
return False
def encode(value): def encode(value: Value) -> bytes:
return b"".join(_cbor_encode(value)) return b"".join(_cbor_encode(value))
def decode(cbor: bytes): def decode(cbor: bytes) -> Value:
res, check = _cbor_decode(cbor) res, check = _cbor_decode(cbor)
if not (check == b""): if not (check == b""):
raise ValueError raise ValueError

@ -1,23 +1,30 @@
from trezor import wire from trezor import wire
from trezor.messages import ButtonRequestType, MessageType from trezor.messages import ButtonRequestType
from trezor.messages.ButtonAck import ButtonAck
from trezor.messages.ButtonRequest import ButtonRequest from trezor.messages.ButtonRequest import ButtonRequest
from trezor.ui.confirm import CONFIRMED, Confirm, HoldToConfirm from trezor.ui.confirm import CONFIRMED, Confirm, HoldToConfirm
if __debug__: if __debug__:
from apps.debug import confirm_signal from apps.debug import confirm_signal
if False:
from typing import Any
from trezor import ui
from trezor.ui.confirm import ButtonContent, ButtonStyleType
from trezor.ui.loader import LoaderStyleType
async def confirm( async def confirm(
ctx, ctx: wire.Context,
content, content: ui.Control,
code=ButtonRequestType.Other, code: int = ButtonRequestType.Other,
confirm=Confirm.DEFAULT_CONFIRM, confirm: ButtonContent = Confirm.DEFAULT_CONFIRM,
confirm_style=Confirm.DEFAULT_CONFIRM_STYLE, confirm_style: ButtonStyleType = Confirm.DEFAULT_CONFIRM_STYLE,
cancel=Confirm.DEFAULT_CANCEL, cancel: ButtonContent = Confirm.DEFAULT_CANCEL,
cancel_style=Confirm.DEFAULT_CANCEL_STYLE, cancel_style: ButtonStyleType = Confirm.DEFAULT_CANCEL_STYLE,
major_confirm=None, major_confirm: bool = False,
): ) -> bool:
await ctx.call(ButtonRequest(code=code), MessageType.ButtonAck) await ctx.call(ButtonRequest(code=code), ButtonAck)
if content.__class__.__name__ == "Paginated": if content.__class__.__name__ == "Paginated":
content.pages[-1] = Confirm( content.pages[-1] = Confirm(
@ -41,14 +48,14 @@ async def confirm(
async def hold_to_confirm( async def hold_to_confirm(
ctx, ctx: wire.Context,
content, content: ui.Control,
code=ButtonRequestType.Other, code: int = ButtonRequestType.Other,
confirm=HoldToConfirm.DEFAULT_CONFIRM, confirm: ButtonContent = Confirm.DEFAULT_CONFIRM,
confirm_style=HoldToConfirm.DEFAULT_CONFIRM_STYLE, confirm_style: ButtonStyleType = Confirm.DEFAULT_CONFIRM_STYLE,
loader_style=HoldToConfirm.DEFAULT_LOADER_STYLE, loader_style: LoaderStyleType = HoldToConfirm.DEFAULT_LOADER_STYLE,
): ) -> bool:
await ctx.call(ButtonRequest(code=code), MessageType.ButtonAck) await ctx.call(ButtonRequest(code=code), ButtonAck)
if content.__class__.__name__ == "Paginated": if content.__class__.__name__ == "Paginated":
content.pages[-1] = HoldToConfirm( content.pages[-1] = HoldToConfirm(
@ -64,13 +71,13 @@ async def hold_to_confirm(
return await ctx.wait(dialog) is CONFIRMED return await ctx.wait(dialog) is CONFIRMED
async def require_confirm(*args, **kwargs): async def require_confirm(*args: Any, **kwargs: Any) -> None:
confirmed = await confirm(*args, **kwargs) confirmed = await confirm(*args, **kwargs)
if not confirmed: if not confirmed:
raise wire.ActionCancelled("Cancelled") raise wire.ActionCancelled("Cancelled")
async def require_hold_to_confirm(*args, **kwargs): async def require_hold_to_confirm(*args: Any, **kwargs: Any) -> None:
confirmed = await hold_to_confirm(*args, **kwargs) confirmed = await hold_to_confirm(*args, **kwargs)
if not confirmed: if not confirmed:
raise wire.ActionCancelled("Cancelled") raise wire.ActionCancelled("Cancelled")

@ -12,10 +12,14 @@ from trezor.utils import chunks
from apps.common import HARDENED from apps.common import HARDENED
from apps.common.confirm import confirm, require_confirm from apps.common.confirm import confirm, require_confirm
if False:
from typing import Iterable
from trezor import wire
async def show_address( async def show_address(
ctx, address: str, desc: str = "Confirm address", network: str = None ctx: wire.Context, address: str, desc: str = "Confirm address", network: str = None
): ) -> bool:
text = Text(desc, ui.ICON_RECEIVE, ui.GREEN) text = Text(desc, ui.ICON_RECEIVE, ui.GREEN)
if network is not None: if network is not None:
text.normal("%s network" % network) text.normal("%s network" % network)
@ -30,7 +34,9 @@ async def show_address(
) )
async def show_qr(ctx, address: str, desc: str = "Confirm address"): async def show_qr(
ctx: wire.Context, address: str, desc: str = "Confirm address"
) -> bool:
QR_X = const(120) QR_X = const(120)
QR_Y = const(115) QR_Y = const(115)
QR_COEF = const(4) QR_COEF = const(4)
@ -47,19 +53,19 @@ async def show_qr(ctx, address: str, desc: str = "Confirm address"):
) )
async def show_pubkey(ctx, pubkey: bytes): async def show_pubkey(ctx: wire.Context, pubkey: bytes) -> None:
lines = chunks(hexlify(pubkey).decode(), 18) lines = chunks(hexlify(pubkey).decode(), 18)
text = Text("Confirm public key", ui.ICON_RECEIVE, ui.GREEN) text = Text("Confirm public key", ui.ICON_RECEIVE, ui.GREEN)
text.mono(*lines) text.mono(*lines)
return await require_confirm(ctx, text, ButtonRequestType.PublicKey) await require_confirm(ctx, text, ButtonRequestType.PublicKey)
def split_address(address: str): def split_address(address: str) -> Iterable[str]:
return chunks(address, 17) return chunks(address, 17)
def address_n_to_str(address_n: list) -> str: def address_n_to_str(address_n: list) -> str:
def path_item(i: int): def path_item(i: int) -> str:
if i & HARDENED: if i & HARDENED:
return str(i ^ HARDENED) + "'" return str(i ^ HARDENED) + "'"
else: else:

@ -8,13 +8,16 @@ from trezor.utils import consteq
from apps.common import storage from apps.common import storage
from apps.common.mnemonic import bip39, slip39 from apps.common.mnemonic import bip39, slip39
if False:
from typing import Any, Tuple
TYPE_BIP39 = const(0) TYPE_BIP39 = const(0)
TYPE_SLIP39 = const(1) TYPE_SLIP39 = const(1)
TYPES_WORD_COUNT = {12: bip39, 18: bip39, 24: bip39, 20: slip39, 33: slip39} TYPES_WORD_COUNT = {12: bip39, 18: bip39, 24: bip39, 20: slip39, 33: slip39}
def get() -> (bytes, int): def get() -> Tuple[bytes, int]:
mnemonic_secret = storage.device.get_mnemonic_secret() mnemonic_secret = storage.device.get_mnemonic_secret()
mnemonic_type = storage.device.get_mnemonic_type() or TYPE_BIP39 mnemonic_type = storage.device.get_mnemonic_type() or TYPE_BIP39
if mnemonic_type not in (TYPE_BIP39, TYPE_SLIP39): if mnemonic_type not in (TYPE_BIP39, TYPE_SLIP39):
@ -22,15 +25,16 @@ def get() -> (bytes, int):
return mnemonic_secret, mnemonic_type return mnemonic_secret, mnemonic_type
def get_seed(passphrase: str = "", progress_bar=True): def get_seed(passphrase: str = "", progress_bar: bool = True) -> bytes:
mnemonic_secret, mnemonic_type = get() mnemonic_secret, mnemonic_type = get()
if mnemonic_type == TYPE_BIP39: if mnemonic_type == TYPE_BIP39:
return bip39.get_seed(mnemonic_secret, passphrase, progress_bar) return bip39.get_seed(mnemonic_secret, passphrase, progress_bar)
elif mnemonic_type == TYPE_SLIP39: elif mnemonic_type == TYPE_SLIP39:
return slip39.get_seed(mnemonic_secret, passphrase, progress_bar) return slip39.get_seed(mnemonic_secret, passphrase, progress_bar)
raise ValueError("Unknown mnemonic type")
def dry_run(secret: bytes): def dry_run(secret: bytes) -> None:
digest_input = sha256(secret).digest() digest_input = sha256(secret).digest()
stored, _ = get() stored, _ = get()
digest_stored = sha256(stored).digest() digest_stored = sha256(stored).digest()
@ -42,11 +46,11 @@ def dry_run(secret: bytes):
) )
def module_from_words_count(count: int): def module_from_words_count(count: int) -> Any:
return TYPES_WORD_COUNT[count] return TYPES_WORD_COUNT[count]
def _start_progress(): def _start_progress() -> None:
workflow.closedefault() workflow.closedefault()
ui.backlight_fade(ui.BACKLIGHT_DIM) ui.backlight_fade(ui.BACKLIGHT_DIM)
ui.display.clear() ui.display.clear()
@ -55,11 +59,11 @@ def _start_progress():
ui.backlight_fade(ui.BACKLIGHT_NORMAL) ui.backlight_fade(ui.BACKLIGHT_NORMAL)
def _render_progress(progress: int, total: int): def _render_progress(progress: int, total: int) -> None:
p = 1000 * progress // total p = 1000 * progress // total
ui.display.loader(p, False, 18, ui.WHITE, ui.BG) ui.display.loader(p, False, 18, ui.WHITE, ui.BG)
ui.display.refresh() ui.display.refresh()
def _stop_progress(): def _stop_progress() -> None:
pass pass

@ -3,7 +3,7 @@ from trezor.crypto import bip39
from apps.common import mnemonic, storage from apps.common import mnemonic, storage
def get_type(): def get_type() -> int:
return mnemonic.TYPE_BIP39 return mnemonic.TYPE_BIP39
@ -23,20 +23,19 @@ def process_all(mnemonics: list) -> bytes:
return mnemonics[0].encode() return mnemonics[0].encode()
def store(secret: bytes, needs_backup: bool, no_backup: bool): def store(secret: bytes, needs_backup: bool, no_backup: bool) -> None:
storage.device.store_mnemonic_secret( storage.device.store_mnemonic_secret(
secret, mnemonic.TYPE_BIP39, needs_backup, no_backup secret, mnemonic.TYPE_BIP39, needs_backup, no_backup
) )
def get_seed(secret: bytes, passphrase: str, progress_bar=True): def get_seed(secret: bytes, passphrase: str, progress_bar: bool = True) -> bytes:
if progress_bar: if progress_bar:
mnemonic._start_progress() mnemonic._start_progress()
seed = bip39.seed(secret.decode(), passphrase, mnemonic._render_progress) seed = bip39.seed(secret.decode(), passphrase, mnemonic._render_progress)
mnemonic._stop_progress() mnemonic._stop_progress()
else: else:
seed = bip39.seed(secret.decode(), passphrase) seed = bip39.seed(secret.decode(), passphrase)
return seed return seed
@ -44,5 +43,5 @@ def get_mnemonic_threshold(mnemonic: str) -> int:
return 1 return 1
def check(secret: bytes): def check(secret: bytes) -> bool:
return bip39.check(secret) return bip39.check(secret)

@ -2,6 +2,9 @@ from trezor.crypto import slip39
from apps.common import mnemonic, storage from apps.common import mnemonic, storage
if False:
from typing import Optional
def generate_from_secret(master_secret: bytes, count: int, threshold: int) -> list: def generate_from_secret(master_secret: bytes, count: int, threshold: int) -> list:
""" """
@ -12,11 +15,11 @@ def generate_from_secret(master_secret: bytes, count: int, threshold: int) -> li
) )
def get_type(): def get_type() -> int:
return mnemonic.TYPE_SLIP39 return mnemonic.TYPE_SLIP39
def process_single(mnemonic: str) -> bytes: def process_single(mnemonic: str) -> Optional[bytes]:
""" """
Receives single mnemonic and processes it. Returns what is then stored in storage or Receives single mnemonic and processes it. Returns what is then stored in storage or
None if more shares are needed. None if more shares are needed.
@ -72,14 +75,16 @@ def process_all(mnemonics: list) -> bytes:
return secret return secret
def store(secret: bytes, needs_backup: bool, no_backup: bool): def store(secret: bytes, needs_backup: bool, no_backup: bool) -> None:
storage.device.store_mnemonic_secret( storage.device.store_mnemonic_secret(
secret, mnemonic.TYPE_SLIP39, needs_backup, no_backup secret, mnemonic.TYPE_SLIP39, needs_backup, no_backup
) )
storage.slip39.delete_progress() storage.slip39.delete_progress()
def get_seed(encrypted_master_secret: bytes, passphrase: str, progress_bar=True): def get_seed(
encrypted_master_secret: bytes, passphrase: str, progress_bar: bool = True
) -> bytes:
if progress_bar: if progress_bar:
mnemonic._start_progress() mnemonic._start_progress()
identifier = storage.slip39.get_identifier() identifier = storage.slip39.get_identifier()

@ -7,20 +7,32 @@ from trezor.ui.text import Text
from apps.common import HARDENED from apps.common import HARDENED
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
if False:
from typing import Any, Callable, List
from trezor import wire
from apps.common import seed
async def validate_path(ctx, validate_func, keychain, path, curve, **kwargs):
async def validate_path(
ctx: wire.Context,
validate_func: Callable[..., bool],
keychain: seed.Keychain,
path: List[int],
curve: str,
**kwargs: Any,
) -> None:
keychain.validate_path(path, curve) keychain.validate_path(path, curve)
if not validate_func(path, **kwargs): if not validate_func(path, **kwargs):
await show_path_warning(ctx, path) await show_path_warning(ctx, path)
async def show_path_warning(ctx, path: list): async def show_path_warning(ctx: wire.Context, path: List[int]) -> None:
text = Text("Confirm path", ui.ICON_WRONG, ui.RED) text = Text("Confirm path", ui.ICON_WRONG, ui.RED)
text.normal("Path") text.normal("Path")
text.mono(*break_address_n_to_lines(path)) text.mono(*break_address_n_to_lines(path))
text.normal("is unknown.") text.normal("is unknown.")
text.normal("Are you sure?") text.normal("Are you sure?")
return await require_confirm(ctx, text, ButtonRequestType.UnknownDerivationPath) await require_confirm(ctx, text, ButtonRequestType.UnknownDerivationPath)
def validate_path_for_get_public_key(path: list, slip44_id: int) -> bool: def validate_path_for_get_public_key(path: list, slip44_id: int) -> bool:
@ -53,7 +65,7 @@ def is_hardened(i: int) -> bool:
def break_address_n_to_lines(address_n: list) -> list: def break_address_n_to_lines(address_n: list) -> list:
def path_item(i: int): def path_item(i: int) -> str:
if i & HARDENED: if i & HARDENED:
return str(i ^ HARDENED) + "'" return str(i ^ HARDENED) + "'"
else: else:

@ -1,9 +1,12 @@
from micropython import const from micropython import const
from trezor import ui, wire from trezor import ui, wire
from trezor.messages import ButtonRequestType, MessageType, PassphraseSourceType from trezor.messages import ButtonRequestType, PassphraseSourceType
from trezor.messages.ButtonAck import ButtonAck
from trezor.messages.ButtonRequest import ButtonRequest from trezor.messages.ButtonRequest import ButtonRequest
from trezor.messages.PassphraseAck import PassphraseAck
from trezor.messages.PassphraseRequest import PassphraseRequest from trezor.messages.PassphraseRequest import PassphraseRequest
from trezor.messages.PassphraseStateAck import PassphraseStateAck
from trezor.messages.PassphraseStateRequest import PassphraseStateRequest from trezor.messages.PassphraseStateRequest import PassphraseStateRequest
from trezor.ui.passphrase import CANCELLED, PassphraseKeyboard, PassphraseSource from trezor.ui.passphrase import CANCELLED, PassphraseKeyboard, PassphraseSource
from trezor.ui.popup import Popup from trezor.ui.popup import Popup
@ -17,14 +20,14 @@ if __debug__:
_MAX_PASSPHRASE_LEN = const(50) _MAX_PASSPHRASE_LEN = const(50)
async def protect_by_passphrase(ctx) -> str: async def protect_by_passphrase(ctx: wire.Context) -> str:
if storage.device.has_passphrase(): if storage.device.has_passphrase():
return await request_passphrase(ctx) return await request_passphrase(ctx)
else: else:
return "" return ""
async def request_passphrase(ctx) -> str: async def request_passphrase(ctx: wire.Context) -> str:
source = storage.device.get_passphrase_source() source = storage.device.get_passphrase_source()
if source == PassphraseSourceType.ASK: if source == PassphraseSourceType.ASK:
source = await request_passphrase_source(ctx) source = await request_passphrase_source(ctx)
@ -36,9 +39,9 @@ async def request_passphrase(ctx) -> str:
return passphrase return passphrase
async def request_passphrase_source(ctx) -> int: async def request_passphrase_source(ctx: wire.Context) -> int:
req = ButtonRequest(code=ButtonRequestType.PassphraseType) req = ButtonRequest(code=ButtonRequestType.PassphraseType)
await ctx.call(req, MessageType.ButtonAck) await ctx.call(req, ButtonAck)
text = Text("Enter passphrase", ui.ICON_CONFIG) text = Text("Enter passphrase", ui.ICON_CONFIG)
text.normal("Where to enter your", "passphrase?") text.normal("Where to enter your", "passphrase?")
@ -47,14 +50,14 @@ async def request_passphrase_source(ctx) -> int:
return await ctx.wait(source) return await ctx.wait(source)
async def request_passphrase_ack(ctx, on_device: bool) -> str: async def request_passphrase_ack(ctx: wire.Context, on_device: bool) -> str:
if not on_device: if not on_device:
text = Text("Passphrase entry", ui.ICON_CONFIG) text = Text("Passphrase entry", ui.ICON_CONFIG)
text.normal("Please, type passphrase", "on connected host.") text.normal("Please, type passphrase", "on connected host.")
await Popup(text) await Popup(text)
req = PassphraseRequest(on_device=on_device) req = PassphraseRequest(on_device=on_device)
ack = await ctx.call(req, MessageType.PassphraseAck) ack = await ctx.call(req, PassphraseAck)
if on_device: if on_device:
if ack.passphrase is not None: if ack.passphrase is not None:
@ -74,6 +77,6 @@ async def request_passphrase_ack(ctx, on_device: bool) -> str:
state = cache.get_state(prev_state=ack.state, passphrase=passphrase) state = cache.get_state(prev_state=ack.state, passphrase=passphrase)
req = PassphraseStateRequest(state=state) req = PassphraseStateRequest(state=state)
ack = await ctx.call(req, MessageType.PassphraseStateAck, MessageType.Cancel) ack = await ctx.call(req, PassphraseStateAck)
return passphrase return passphrase

@ -4,7 +4,8 @@ from trezor.crypto import bip32
from apps.common import HARDENED, cache, mnemonic, storage from apps.common import HARDENED, cache, mnemonic, storage
from apps.common.request_passphrase import protect_by_passphrase from apps.common.request_passphrase import protect_by_passphrase
allow = list if False:
from typing import List, Optional
class Keychain: class Keychain:
@ -16,16 +17,16 @@ class Keychain:
def __init__(self, seed: bytes, namespaces: list): def __init__(self, seed: bytes, namespaces: list):
self.seed = seed self.seed = seed
self.namespaces = namespaces self.namespaces = namespaces
self.roots = [None] * len(namespaces) self.roots = [None] * len(namespaces) # type: List[Optional[bip32.HDNode]]
def __del__(self): def __del__(self) -> None:
for root in self.roots: for root in self.roots:
if root is not None: if root is not None:
root.__del__() root.__del__()
del self.roots del self.roots
del self.seed del self.seed
def validate_path(self, checked_path: list, checked_curve: str): def validate_path(self, checked_path: list, checked_curve: str) -> None:
for curve, *path in self.namespaces: for curve, *path in self.namespaces:
if path == checked_path[: len(path)] and curve == checked_curve: if path == checked_path[: len(path)] and curve == checked_curve:
if "ed25519" in curve and not _path_hardened(checked_path): if "ed25519" in curve and not _path_hardened(checked_path):

@ -5,8 +5,12 @@ from trezor.utils import HashWriter
from apps.wallet.sign_tx.writers import write_varint from apps.wallet.sign_tx.writers import write_varint
if False:
from typing import List
from apps.common.coininfo import CoinType
def message_digest(coin, message):
def message_digest(coin: CoinType, message: bytes) -> bytes:
if coin.decred: if coin.decred:
h = HashWriter(blake256()) h = HashWriter(blake256())
else: else:
@ -21,7 +25,7 @@ def message_digest(coin, message):
return ret return ret
def split_message(message): def split_message(message: bytes) -> List[str]:
try: try:
m = bytes(message).decode() m = bytes(message).decode()
words = m.split(" ") words = m.split(" ")

@ -1,5 +1,8 @@
from trezor.utils import ensure from trezor.utils import ensure
if False:
from trezor.utils import Writer
def empty_bytearray(preallocate: int) -> bytearray: def empty_bytearray(preallocate: int) -> bytearray:
""" """
@ -11,27 +14,27 @@ def empty_bytearray(preallocate: int) -> bytearray:
return b return b
def write_uint8(w: bytearray, n: int) -> int: def write_uint8(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFF) ensure(0 <= n <= 0xFF)
w.append(n) w.append(n)
return 1 return 1
def write_uint16_le(w: bytearray, n: int) -> int: def write_uint16_le(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFF) ensure(0 <= n <= 0xFFFF)
w.append(n & 0xFF) w.append(n & 0xFF)
w.append((n >> 8) & 0xFF) w.append((n >> 8) & 0xFF)
return 2 return 2
def write_uint16_be(w: bytearray, n: int): def write_uint16_be(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFF) ensure(0 <= n <= 0xFFFF)
w.append((n >> 8) & 0xFF) w.append((n >> 8) & 0xFF)
w.append(n & 0xFF) w.append(n & 0xFF)
return 2 return 2
def write_uint32_le(w: bytearray, n: int) -> int: def write_uint32_le(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFFFFFF) ensure(0 <= n <= 0xFFFFFFFF)
w.append(n & 0xFF) w.append(n & 0xFF)
w.append((n >> 8) & 0xFF) w.append((n >> 8) & 0xFF)
@ -40,7 +43,7 @@ def write_uint32_le(w: bytearray, n: int) -> int:
return 4 return 4
def write_uint32_be(w: bytearray, n: int) -> int: def write_uint32_be(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFFFFFF) ensure(0 <= n <= 0xFFFFFFFF)
w.append((n >> 24) & 0xFF) w.append((n >> 24) & 0xFF)
w.append((n >> 16) & 0xFF) w.append((n >> 16) & 0xFF)
@ -49,7 +52,7 @@ def write_uint32_be(w: bytearray, n: int) -> int:
return 4 return 4
def write_uint64_le(w: bytearray, n: int) -> int: def write_uint64_le(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFFFFFFFFFFFFFF) ensure(0 <= n <= 0xFFFFFFFFFFFFFFFF)
w.append(n & 0xFF) w.append(n & 0xFF)
w.append((n >> 8) & 0xFF) w.append((n >> 8) & 0xFF)
@ -62,7 +65,7 @@ def write_uint64_le(w: bytearray, n: int) -> int:
return 8 return 8
def write_uint64_be(w: bytearray, n: int) -> int: def write_uint64_be(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFFFFFFFFFFFFFF) ensure(0 <= n <= 0xFFFFFFFFFFFFFFFF)
w.append((n >> 56) & 0xFF) w.append((n >> 56) & 0xFF)
w.append((n >> 48) & 0xFF) w.append((n >> 48) & 0xFF)
@ -75,11 +78,11 @@ def write_uint64_be(w: bytearray, n: int) -> int:
return 8 return 8
def write_bytes(w: bytearray, b: bytes) -> int: def write_bytes(w: Writer, b: bytes) -> int:
w.extend(b) w.extend(b)
return len(b) return len(b)
def write_bytes_reversed(w: bytearray, b: bytes) -> int: def write_bytes_reversed(w: Writer, b: bytes) -> int:
w.extend(bytes(reversed(b))) w.extend(bytes(reversed(b)))
return len(b) return len(b)

@ -8,15 +8,24 @@ if __debug__:
from trezor.messages import MessageType from trezor.messages import MessageType
from trezor.wire import register, protobuf_workflow from trezor.wire import register, protobuf_workflow
reset_internal_entropy = None if False:
reset_current_words = None from typing import List, Optional
reset_word_index = None from trezor import wire
from trezor.messages.DebugLinkDecision import DebugLinkDecision
from trezor.messages.DebugLinkGetState import DebugLinkGetState
from trezor.messages.DebugLinkState import DebugLinkState
reset_internal_entropy = None # type: Optional[bytes]
reset_current_words = None # type: Optional[List[str]]
reset_word_index = None # type: Optional[int]
confirm_signal = loop.signal() confirm_signal = loop.signal()
swipe_signal = loop.signal() swipe_signal = loop.signal()
input_signal = loop.signal() input_signal = loop.signal()
async def dispatch_DebugLinkDecision(ctx, msg): async def dispatch_DebugLinkDecision(
ctx: wire.Context, msg: DebugLinkDecision
) -> None:
from trezor.ui import confirm, swipe from trezor.ui import confirm, swipe
if msg.yes_no is not None: if msg.yes_no is not None:
@ -26,7 +35,9 @@ if __debug__:
if msg.input is not None: if msg.input is not None:
input_signal.send(msg.input) input_signal.send(msg.input)
async def dispatch_DebugLinkGetState(ctx, msg): async def dispatch_DebugLinkGetState(
ctx: wire.Context, msg: DebugLinkGetState
) -> DebugLinkState:
from trezor.messages.DebugLinkState import DebugLinkState from trezor.messages.DebugLinkState import DebugLinkState
from apps.common import storage, mnemonic from apps.common import storage, mnemonic
@ -39,7 +50,7 @@ if __debug__:
m.reset_word = " ".join(reset_current_words) m.reset_word = " ".join(reset_current_words)
return m return m
def boot(): def boot() -> None:
# wipe storage when debug build is used on real hardware # wipe storage when debug build is used on real hardware
if not utils.EMULATOR: if not utils.EMULATOR:
config.wipe() config.wipe()

@ -6,7 +6,7 @@ from apps.common import HARDENED
CURVE = "secp256k1" CURVE = "secp256k1"
def boot(): def boot() -> None:
ns = [[CURVE, HARDENED | 44, HARDENED | 194]] ns = [[CURVE, HARDENED | 44, HARDENED | 194]]
wire.add(MessageType.EosGetPublicKey, __name__, "get_public_key", ns) wire.add(MessageType.EosGetPublicKey, __name__, "get_public_key", ns)

@ -1,13 +1,19 @@
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.messages.EosTxActionAck import EosTxActionAck
from trezor.messages.EosTxActionRequest import EosTxActionRequest from trezor.messages.EosTxActionRequest import EosTxActionRequest
from trezor.messages.MessageType import EosTxActionAck
from trezor.utils import HashWriter from trezor.utils import HashWriter
from apps.eos import helpers, writers from apps.eos import helpers, writers
from apps.eos.actions import layout from apps.eos.actions import layout
if False:
from trezor import wire
from trezor.utils import Writer
async def process_action(ctx, sha, action):
async def process_action(
ctx: wire.Context, sha: HashWriter, action: EosTxActionAck
) -> None:
name = helpers.eos_name_to_string(action.common.name) name = helpers.eos_name_to_string(action.common.name)
account = helpers.eos_name_to_string(action.common.account) account = helpers.eos_name_to_string(action.common.account)
@ -65,7 +71,9 @@ async def process_action(ctx, sha, action):
writers.write_bytes(sha, w) writers.write_bytes(sha, w)
async def process_unknown_action(ctx, w, action): async def process_unknown_action(
ctx: wire.Context, w: Writer, action: EosTxActionAck
) -> None:
checksum = HashWriter(sha256()) checksum = HashWriter(sha256())
writers.write_variant32(checksum, action.unknown.data_size) writers.write_variant32(checksum, action.unknown.data_size)
checksum.extend(action.unknown.data_chunk) checksum.extend(action.unknown.data_chunk)
@ -91,7 +99,7 @@ async def process_unknown_action(ctx, w, action):
await layout.confirm_action_unknown(ctx, action.common, checksum.get_digest()) await layout.confirm_action_unknown(ctx, action.common, checksum.get_digest())
def check_action(action, name, account): def check_action(action: EosTxActionAck, name: str, account: str) -> bool:
if account == "eosio": if account == "eosio":
if ( if (
(name == "buyram" and action.buy_ram is not None) (name == "buyram" and action.buy_ram is not None)

@ -2,22 +2,7 @@ from micropython import const
from ubinascii import hexlify from ubinascii import hexlify
from trezor import ui from trezor import ui
from trezor.messages import ( from trezor.messages import ButtonRequestType
ButtonRequestType,
EosActionBuyRam,
EosActionBuyRamBytes,
EosActionDelegate,
EosActionDeleteAuth,
EosActionLinkAuth,
EosActionNewAccount,
EosActionRefund,
EosActionSellRam,
EosActionTransfer,
EosActionUndelegate,
EosActionUnlinkAuth,
EosActionUpdateAuth,
EosActionVoteProducer,
)
from trezor.ui.scroll import Paginated from trezor.ui.scroll import Paginated
from trezor.ui.text import Text from trezor.ui.text import Text
from trezor.utils import chunks from trezor.utils import chunks
@ -26,6 +11,25 @@ from apps.eos import helpers
from apps.eos.get_public_key import _public_key_to_wif from apps.eos.get_public_key import _public_key_to_wif
from apps.eos.layout import require_confirm from apps.eos.layout import require_confirm
if False:
from typing import List
from trezor import wire
from trezor.messages.EosAuthorization import EosAuthorization
from trezor.messages.EosActionBuyRam import EosActionBuyRam
from trezor.messages.EosActionBuyRamBytes import EosActionBuyRamBytes
from trezor.messages.EosActionCommon import EosActionCommon
from trezor.messages.EosActionDelegate import EosActionDelegate
from trezor.messages.EosActionDeleteAuth import EosActionDeleteAuth
from trezor.messages.EosActionLinkAuth import EosActionLinkAuth
from trezor.messages.EosActionNewAccount import EosActionNewAccount
from trezor.messages.EosActionRefund import EosActionRefund
from trezor.messages.EosActionSellRam import EosActionSellRam
from trezor.messages.EosActionTransfer import EosActionTransfer
from trezor.messages.EosActionUndelegate import EosActionUndelegate
from trezor.messages.EosActionUnlinkAuth import EosActionUnlinkAuth
from trezor.messages.EosActionUpdateAuth import EosActionUpdateAuth
from trezor.messages.EosActionVoteProducer import EosActionVoteProducer
_LINE_LENGTH = const(17) _LINE_LENGTH = const(17)
_LINE_PLACEHOLDER = "{:<" + str(_LINE_LENGTH) + "}" _LINE_PLACEHOLDER = "{:<" + str(_LINE_LENGTH) + "}"
_FIRST_PAGE = const(0) _FIRST_PAGE = const(0)
@ -35,7 +39,9 @@ _FOUR_FIELDS_PER_PAGE = const(4)
_FIVE_FIELDS_PER_PAGE = const(5) _FIVE_FIELDS_PER_PAGE = const(5)
async def _require_confirm_paginated(ctx, header, fields, per_page): async def _require_confirm_paginated(
ctx: wire.Context, header: str, fields: List[str], per_page: int
) -> None:
pages = [] pages = []
for page in chunks(fields, per_page): for page in chunks(fields, per_page):
if header == "Arbitrary data": if header == "Arbitrary data":
@ -47,7 +53,7 @@ async def _require_confirm_paginated(ctx, header, fields, per_page):
await require_confirm(ctx, Paginated(pages), ButtonRequestType.ConfirmOutput) await require_confirm(ctx, Paginated(pages), ButtonRequestType.ConfirmOutput)
async def confirm_action_buyram(ctx, msg: EosActionBuyRam): async def confirm_action_buyram(ctx: wire.Context, msg: EosActionBuyRam) -> None:
text = "Buy RAM" text = "Buy RAM"
fields = [] fields = []
fields.append("Payer:") fields.append("Payer:")
@ -59,7 +65,9 @@ async def confirm_action_buyram(ctx, msg: EosActionBuyRam):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_buyrambytes(ctx, msg: EosActionBuyRamBytes): async def confirm_action_buyrambytes(
ctx: wire.Context, msg: EosActionBuyRamBytes
) -> None:
text = "Buy RAM" text = "Buy RAM"
fields = [] fields = []
fields.append("Payer:") fields.append("Payer:")
@ -71,7 +79,7 @@ async def confirm_action_buyrambytes(ctx, msg: EosActionBuyRamBytes):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_delegate(ctx, msg: EosActionDelegate): async def confirm_action_delegate(ctx: wire.Context, msg: EosActionDelegate) -> None:
text = "Delegate" text = "Delegate"
fields = [] fields = []
fields.append("Sender:") fields.append("Sender:")
@ -93,7 +101,7 @@ async def confirm_action_delegate(ctx, msg: EosActionDelegate):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_sellram(ctx, msg: EosActionSellRam): async def confirm_action_sellram(ctx: wire.Context, msg: EosActionSellRam) -> None:
text = "Sell RAM" text = "Sell RAM"
fields = [] fields = []
fields.append("Receiver:") fields.append("Receiver:")
@ -103,7 +111,9 @@ async def confirm_action_sellram(ctx, msg: EosActionSellRam):
await _require_confirm_paginated(ctx, text, fields, _TWO_FIELDS_PER_PAGE) await _require_confirm_paginated(ctx, text, fields, _TWO_FIELDS_PER_PAGE)
async def confirm_action_undelegate(ctx, msg: EosActionUndelegate): async def confirm_action_undelegate(
ctx: wire.Context, msg: EosActionUndelegate
) -> None:
text = "Undelegate" text = "Undelegate"
fields = [] fields = []
fields.append("Sender:") fields.append("Sender:")
@ -117,14 +127,16 @@ async def confirm_action_undelegate(ctx, msg: EosActionUndelegate):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_refund(ctx, msg: EosActionRefund): async def confirm_action_refund(ctx: wire.Context, msg: EosActionRefund) -> None:
text = Text("Refund", ui.ICON_CONFIRM, icon_color=ui.GREEN) text = Text("Refund", ui.ICON_CONFIRM, icon_color=ui.GREEN)
text.normal("Owner:") text.normal("Owner:")
text.normal(helpers.eos_name_to_string(msg.owner)) text.normal(helpers.eos_name_to_string(msg.owner))
await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput) await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput)
async def confirm_action_voteproducer(ctx, msg: EosActionVoteProducer): async def confirm_action_voteproducer(
ctx: wire.Context, msg: EosActionVoteProducer
) -> None:
if msg.proxy and not msg.producers: if msg.proxy and not msg.producers:
# PROXY # PROXY
text = Text("Vote for proxy", ui.ICON_CONFIRM, icon_color=ui.GREEN) text = Text("Vote for proxy", ui.ICON_CONFIRM, icon_color=ui.GREEN)
@ -151,7 +163,9 @@ async def confirm_action_voteproducer(ctx, msg: EosActionVoteProducer):
await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput) await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput)
async def confirm_action_transfer(ctx, msg: EosActionTransfer, account: str): async def confirm_action_transfer(
ctx: wire.Context, msg: EosActionTransfer, account: str
) -> None:
text = "Transfer" text = "Transfer"
fields = [] fields = []
fields.append("From:") fields.append("From:")
@ -170,7 +184,9 @@ async def confirm_action_transfer(ctx, msg: EosActionTransfer, account: str):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_updateauth(ctx, msg: EosActionUpdateAuth): async def confirm_action_updateauth(
ctx: wire.Context, msg: EosActionUpdateAuth
) -> None:
text = "Update Auth" text = "Update Auth"
fields = [] fields = []
fields.append("Account:") fields.append("Account:")
@ -183,7 +199,9 @@ async def confirm_action_updateauth(ctx, msg: EosActionUpdateAuth):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_deleteauth(ctx, msg: EosActionDeleteAuth): async def confirm_action_deleteauth(
ctx: wire.Context, msg: EosActionDeleteAuth
) -> None:
text = Text("Delete auth", ui.ICON_CONFIRM, icon_color=ui.GREEN) text = Text("Delete auth", ui.ICON_CONFIRM, icon_color=ui.GREEN)
text.normal("Account:") text.normal("Account:")
text.normal(helpers.eos_name_to_string(msg.account)) text.normal(helpers.eos_name_to_string(msg.account))
@ -192,7 +210,7 @@ async def confirm_action_deleteauth(ctx, msg: EosActionDeleteAuth):
await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput) await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput)
async def confirm_action_linkauth(ctx, msg: EosActionLinkAuth): async def confirm_action_linkauth(ctx: wire.Context, msg: EosActionLinkAuth) -> None:
text = "Link Auth" text = "Link Auth"
fields = [] fields = []
fields.append("Account:") fields.append("Account:")
@ -206,7 +224,9 @@ async def confirm_action_linkauth(ctx, msg: EosActionLinkAuth):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_unlinkauth(ctx, msg: EosActionUnlinkAuth): async def confirm_action_unlinkauth(
ctx: wire.Context, msg: EosActionUnlinkAuth
) -> None:
text = "Unlink Auth" text = "Unlink Auth"
fields = [] fields = []
fields.append("Account:") fields.append("Account:")
@ -218,7 +238,9 @@ async def confirm_action_unlinkauth(ctx, msg: EosActionUnlinkAuth):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_newaccount(ctx, msg: EosActionNewAccount): async def confirm_action_newaccount(
ctx: wire.Context, msg: EosActionNewAccount
) -> None:
text = "New Account" text = "New Account"
fields = [] fields = []
fields.append("Creator:") fields.append("Creator:")
@ -230,7 +252,9 @@ async def confirm_action_newaccount(ctx, msg: EosActionNewAccount):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_unknown(ctx, action, checksum): async def confirm_action_unknown(
ctx: wire.Context, action: EosActionCommon, checksum: bytes
) -> None:
text = "Arbitrary data" text = "Arbitrary data"
fields = [] fields = []
fields.append("Contract:") fields.append("Contract:")
@ -242,7 +266,7 @@ async def confirm_action_unknown(ctx, action, checksum):
await _require_confirm_paginated(ctx, text, fields, _FIVE_FIELDS_PER_PAGE) await _require_confirm_paginated(ctx, text, fields, _FIVE_FIELDS_PER_PAGE)
def authorization_fields(auth): def authorization_fields(auth: EosAuthorization) -> List[str]:
fields = [] fields = []
fields.append("Threshold:") fields.append("Threshold:")
@ -288,11 +312,9 @@ def authorization_fields(auth):
return fields return fields
def split_data(data): def split_data(data: str) -> List[str]:
temp_list = [] lines = []
len_left = len(data) while data:
while len_left > 0: lines.append("{} ".format(data[:_LINE_LENGTH]))
temp_list.append("{} ".format(data[:_LINE_LENGTH]))
data = data[_LINE_LENGTH:] data = data[_LINE_LENGTH:]
len_left = len(data) return lines
return temp_list

@ -8,6 +8,11 @@ from apps.eos import CURVE
from apps.eos.helpers import base58_encode, validate_full_path from apps.eos.helpers import base58_encode, validate_full_path
from apps.eos.layout import require_get_public_key from apps.eos.layout import require_get_public_key
if False:
from typing import Tuple
from trezor.crypto import bip32
from apps.common import seed
def _public_key_to_wif(pub_key: bytes) -> str: def _public_key_to_wif(pub_key: bytes) -> str:
if pub_key[0] == 0x04 and len(pub_key) == 65: if pub_key[0] == 0x04 and len(pub_key) == 65:
@ -20,14 +25,16 @@ def _public_key_to_wif(pub_key: bytes) -> str:
return base58_encode("EOS", "", compressed_pub_key) return base58_encode("EOS", "", compressed_pub_key)
def _get_public_key(node): def _get_public_key(node: bip32.HDNode) -> Tuple[str, bytes]:
seckey = node.private_key() seckey = node.private_key()
public_key = secp256k1.publickey(seckey, True) public_key = secp256k1.publickey(seckey, True)
wif = _public_key_to_wif(public_key) wif = _public_key_to_wif(public_key)
return wif, public_key return wif, public_key
async def get_public_key(ctx, msg: EosGetPublicKey, keychain): async def get_public_key(
ctx: wire.Context, msg: EosGetPublicKey, keychain: seed.Keychain
) -> EosPublicKey:
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)

@ -1,5 +1,5 @@
from trezor.crypto import base58 from trezor.crypto import base58
from trezor.messages import EosAsset from trezor.messages.EosAsset import EosAsset
from apps.common import HARDENED from apps.common import HARDENED
@ -12,7 +12,7 @@ def base58_encode(prefix: str, sig_prefix: str, data: bytes) -> str:
return prefix + b58 return prefix + b58
def eos_name_to_string(value) -> str: def eos_name_to_string(value: int) -> str:
charmap = ".12345abcdefghijklmnopqrstuvwxyz" charmap = ".12345abcdefghijklmnopqrstuvwxyz"
tmp = value tmp = value
string = "" string = ""

@ -1,19 +1,19 @@
from trezor import ui from trezor import ui, wire
from trezor.messages import ButtonRequestType from trezor.messages import ButtonRequestType
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
async def require_get_public_key(ctx, public_key): async def require_get_public_key(ctx: wire.Context, public_key: str) -> None:
text = Text("Confirm public key", ui.ICON_RECEIVE, ui.GREEN) text = Text("Confirm public key", ui.ICON_RECEIVE, ui.GREEN)
text.normal(public_key) text.normal(public_key)
return await require_confirm(ctx, text, ButtonRequestType.PublicKey) await require_confirm(ctx, text, ButtonRequestType.PublicKey)
async def require_sign_tx(ctx, num_actions): async def require_sign_tx(ctx: wire.Context, num_actions: int) -> None:
text = Text("Sign transaction", ui.ICON_SEND, ui.GREEN) text = Text("Sign transaction", ui.ICON_SEND, ui.GREEN)
text.normal("You are about") text.normal("You are about")
text.normal("to sign {}".format(num_actions)) text.normal("to sign {}".format(num_actions))
text.normal("action(s).") text.normal("action(s).")
return await require_confirm(ctx, text, ButtonRequestType.SignTx) await require_confirm(ctx, text, ButtonRequestType.SignTx)

@ -3,8 +3,8 @@ from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.messages.EosSignedTx import EosSignedTx from trezor.messages.EosSignedTx import EosSignedTx
from trezor.messages.EosSignTx import EosSignTx from trezor.messages.EosSignTx import EosSignTx
from trezor.messages.EosTxActionAck import EosTxActionAck
from trezor.messages.EosTxActionRequest import EosTxActionRequest from trezor.messages.EosTxActionRequest import EosTxActionRequest
from trezor.messages.MessageType import EosTxActionAck
from trezor.utils import HashWriter from trezor.utils import HashWriter
from apps.common import paths from apps.common import paths
@ -13,8 +13,13 @@ from apps.eos.actions import process_action
from apps.eos.helpers import base58_encode, validate_full_path from apps.eos.helpers import base58_encode, validate_full_path
from apps.eos.layout import require_sign_tx from apps.eos.layout import require_sign_tx
if False:
from apps.common import seed
async def sign_tx(ctx, msg: EosSignTx, keychain):
async def sign_tx(
ctx: wire.Context, msg: EosSignTx, keychain: seed.Keychain
) -> EosSignedTx:
if msg.chain_id is None: if msg.chain_id is None:
raise wire.DataError("No chain id") raise wire.DataError("No chain id")
if msg.header is None: if msg.header is None:
@ -39,7 +44,7 @@ async def sign_tx(ctx, msg: EosSignTx, keychain):
return EosSignedTx(signature=base58_encode("SIG_", "K1", signature)) return EosSignedTx(signature=base58_encode("SIG_", "K1", signature))
async def _init(ctx, sha, msg): async def _init(ctx: wire.Context, sha: HashWriter, msg: EosSignTx) -> None:
writers.write_bytes(sha, msg.chain_id) writers.write_bytes(sha, msg.chain_id)
writers.write_header(sha, msg.header) writers.write_header(sha, msg.header)
writers.write_variant32(sha, 0) writers.write_variant32(sha, 0)
@ -48,7 +53,7 @@ async def _init(ctx, sha, msg):
await require_sign_tx(ctx, msg.num_actions) await require_sign_tx(ctx, msg.num_actions)
async def _actions(ctx, sha, num_actions: int): async def _actions(ctx: wire.Context, sha: HashWriter, num_actions: int) -> None:
for i in range(num_actions): for i in range(num_actions):
action = await ctx.call(EosTxActionRequest(), EosTxActionAck) action = await ctx.call(EosTxActionRequest(), EosTxActionAck)
await process_action(ctx, sha, action) await process_action(ctx, sha, action)

@ -1,23 +1,3 @@
from trezor.messages import (
EosActionBuyRam,
EosActionBuyRamBytes,
EosActionCommon,
EosActionDelegate,
EosActionDeleteAuth,
EosActionLinkAuth,
EosActionNewAccount,
EosActionRefund,
EosActionSellRam,
EosActionTransfer,
EosActionUndelegate,
EosActionUpdateAuth,
EosActionVoteProducer,
EosAsset,
EosAuthorization,
EosTxHeader,
)
from trezor.utils import HashWriter
from apps.common.writers import ( from apps.common.writers import (
write_bytes, write_bytes,
write_uint8, write_uint8,
@ -26,8 +6,27 @@ from apps.common.writers import (
write_uint64_le, write_uint64_le,
) )
if False:
def write_auth(w: bytearray, auth: EosAuthorization) -> int: from trezor.messages.EosActionBuyRam import EosActionBuyRam
from trezor.messages.EosActionBuyRamBytes import EosActionBuyRamBytes
from trezor.messages.EosActionCommon import EosActionCommon
from trezor.messages.EosActionDelegate import EosActionDelegate
from trezor.messages.EosActionDeleteAuth import EosActionDeleteAuth
from trezor.messages.EosActionLinkAuth import EosActionLinkAuth
from trezor.messages.EosActionNewAccount import EosActionNewAccount
from trezor.messages.EosActionRefund import EosActionRefund
from trezor.messages.EosActionSellRam import EosActionSellRam
from trezor.messages.EosActionTransfer import EosActionTransfer
from trezor.messages.EosActionUndelegate import EosActionUndelegate
from trezor.messages.EosActionUpdateAuth import EosActionUpdateAuth
from trezor.messages.EosActionVoteProducer import EosActionVoteProducer
from trezor.messages.EosAsset import EosAsset
from trezor.messages.EosAuthorization import EosAuthorization
from trezor.messages.EosTxHeader import EosTxHeader
from trezor.utils import Writer
def write_auth(w: Writer, auth: EosAuthorization) -> None:
write_uint32_le(w, auth.threshold) write_uint32_le(w, auth.threshold)
write_variant32(w, len(auth.keys)) write_variant32(w, len(auth.keys))
for key in auth.keys: for key in auth.keys:
@ -47,7 +46,7 @@ def write_auth(w: bytearray, auth: EosAuthorization) -> int:
write_uint16_le(w, wait.weight) write_uint16_le(w, wait.weight)
def write_header(hasher: HashWriter, header: EosTxHeader): def write_header(hasher: Writer, header: EosTxHeader) -> None:
write_uint32_le(hasher, header.expiration) write_uint32_le(hasher, header.expiration)
write_uint16_le(hasher, header.ref_block_num) write_uint16_le(hasher, header.ref_block_num)
write_uint32_le(hasher, header.ref_block_prefix) write_uint32_le(hasher, header.ref_block_prefix)
@ -56,7 +55,7 @@ def write_header(hasher: HashWriter, header: EosTxHeader):
write_variant32(hasher, header.delay_sec) write_variant32(hasher, header.delay_sec)
def write_action_transfer(w: bytearray, msg: EosActionTransfer): def write_action_transfer(w: Writer, msg: EosActionTransfer) -> None:
write_uint64_le(w, msg.sender) write_uint64_le(w, msg.sender)
write_uint64_le(w, msg.receiver) write_uint64_le(w, msg.receiver)
write_asset(w, msg.quantity) write_asset(w, msg.quantity)
@ -64,24 +63,24 @@ def write_action_transfer(w: bytearray, msg: EosActionTransfer):
write_bytes(w, msg.memo) write_bytes(w, msg.memo)
def write_action_buyram(w: bytearray, msg: EosActionBuyRam): def write_action_buyram(w: Writer, msg: EosActionBuyRam) -> None:
write_uint64_le(w, msg.payer) write_uint64_le(w, msg.payer)
write_uint64_le(w, msg.receiver) write_uint64_le(w, msg.receiver)
write_asset(w, msg.quantity) write_asset(w, msg.quantity)
def write_action_buyrambytes(w: bytearray, msg: EosActionBuyRamBytes): def write_action_buyrambytes(w: Writer, msg: EosActionBuyRamBytes) -> None:
write_uint64_le(w, msg.payer) write_uint64_le(w, msg.payer)
write_uint64_le(w, msg.receiver) write_uint64_le(w, msg.receiver)
write_uint32_le(w, msg.bytes) write_uint32_le(w, msg.bytes)
def write_action_sellram(w: bytearray, msg: EosActionSellRam): def write_action_sellram(w: Writer, msg: EosActionSellRam) -> None:
write_uint64_le(w, msg.account) write_uint64_le(w, msg.account)
write_uint64_le(w, msg.bytes) write_uint64_le(w, msg.bytes)
def write_action_delegate(w: bytearray, msg: EosActionDelegate): def write_action_delegate(w: Writer, msg: EosActionDelegate) -> None:
write_uint64_le(w, msg.sender) write_uint64_le(w, msg.sender)
write_uint64_le(w, msg.receiver) write_uint64_le(w, msg.receiver)
write_asset(w, msg.net_quantity) write_asset(w, msg.net_quantity)
@ -89,18 +88,18 @@ def write_action_delegate(w: bytearray, msg: EosActionDelegate):
write_uint8(w, 1 if msg.transfer else 0) write_uint8(w, 1 if msg.transfer else 0)
def write_action_undelegate(w: bytearray, msg: EosActionUndelegate): def write_action_undelegate(w: Writer, msg: EosActionUndelegate) -> None:
write_uint64_le(w, msg.sender) write_uint64_le(w, msg.sender)
write_uint64_le(w, msg.receiver) write_uint64_le(w, msg.receiver)
write_asset(w, msg.net_quantity) write_asset(w, msg.net_quantity)
write_asset(w, msg.cpu_quantity) write_asset(w, msg.cpu_quantity)
def write_action_refund(w: bytearray, msg: EosActionRefund): def write_action_refund(w: Writer, msg: EosActionRefund) -> None:
write_uint64_le(w, msg.owner) write_uint64_le(w, msg.owner)
def write_action_voteproducer(w: bytearray, msg: EosActionVoteProducer): def write_action_voteproducer(w: Writer, msg: EosActionVoteProducer) -> None:
write_uint64_le(w, msg.voter) write_uint64_le(w, msg.voter)
write_uint64_le(w, msg.proxy) write_uint64_le(w, msg.proxy)
write_variant32(w, len(msg.producers)) write_variant32(w, len(msg.producers))
@ -108,61 +107,59 @@ def write_action_voteproducer(w: bytearray, msg: EosActionVoteProducer):
write_uint64_le(w, producer) write_uint64_le(w, producer)
def write_action_updateauth(w: bytearray, msg: EosActionUpdateAuth): def write_action_updateauth(w: Writer, msg: EosActionUpdateAuth) -> None:
write_uint64_le(w, msg.account) write_uint64_le(w, msg.account)
write_uint64_le(w, msg.permission) write_uint64_le(w, msg.permission)
write_uint64_le(w, msg.parent) write_uint64_le(w, msg.parent)
write_auth(w, msg.auth) write_auth(w, msg.auth)
def write_action_deleteauth(w: bytearray, msg: EosActionDeleteAuth): def write_action_deleteauth(w: Writer, msg: EosActionDeleteAuth) -> None:
write_uint64_le(w, msg.account) write_uint64_le(w, msg.account)
write_uint64_le(w, msg.permission) write_uint64_le(w, msg.permission)
def write_action_linkauth(w: bytearray, msg: EosActionLinkAuth): def write_action_linkauth(w: Writer, msg: EosActionLinkAuth) -> None:
write_uint64_le(w, msg.account) write_uint64_le(w, msg.account)
write_uint64_le(w, msg.code) write_uint64_le(w, msg.code)
write_uint64_le(w, msg.type) write_uint64_le(w, msg.type)
write_uint64_le(w, msg.requirement) write_uint64_le(w, msg.requirement)
def write_action_unlinkauth(w: bytearray, msg: EosActionLinkAuth): def write_action_unlinkauth(w: Writer, msg: EosActionLinkAuth) -> None:
write_uint64_le(w, msg.account) write_uint64_le(w, msg.account)
write_uint64_le(w, msg.code) write_uint64_le(w, msg.code)
write_uint64_le(w, msg.type) write_uint64_le(w, msg.type)
def write_action_newaccount(w: bytearray, msg: EosActionNewAccount): def write_action_newaccount(w: Writer, msg: EosActionNewAccount) -> None:
write_uint64_le(w, msg.creator) write_uint64_le(w, msg.creator)
write_uint64_le(w, msg.name) write_uint64_le(w, msg.name)
write_auth(w, msg.owner) write_auth(w, msg.owner)
write_auth(w, msg.active) write_auth(w, msg.active)
def write_action_common(hasher: HashWriter, msg: EosActionCommon): def write_action_common(w: Writer, msg: EosActionCommon) -> None:
write_uint64_le(hasher, msg.account) write_uint64_le(w, msg.account)
write_uint64_le(hasher, msg.name) write_uint64_le(w, msg.name)
write_variant32(hasher, len(msg.authorization)) write_variant32(w, len(msg.authorization))
for authorization in msg.authorization: for authorization in msg.authorization:
write_uint64_le(hasher, authorization.actor) write_uint64_le(w, authorization.actor)
write_uint64_le(hasher, authorization.permission) write_uint64_le(w, authorization.permission)
def write_asset(w: bytearray, asset: EosAsset) -> int: def write_asset(w: Writer, asset: EosAsset) -> None:
write_uint64_le(w, asset.amount) write_uint64_le(w, asset.amount)
write_uint64_le(w, asset.symbol) write_uint64_le(w, asset.symbol)
def write_variant32(w: bytearray, value: int) -> int: def write_variant32(w: Writer, value: int) -> None:
variant = bytearray() variant = bytearray()
while True: while True:
b = value & 0x7F b = value & 0x7F
value >>= 7 value >>= 7
b |= (value > 0) << 7 b |= (value > 0) << 7
variant.append(b) variant.append(b)
if value == 0: if value == 0:
break break
write_bytes(w, bytes(variant)) write_bytes(w, bytes(variant))

@ -7,7 +7,7 @@ from apps.ethereum.networks import all_slip44_ids_hardened
CURVE = "secp256k1" CURVE = "secp256k1"
def boot(): def boot() -> None:
ns = [] ns = []
for i in all_slip44_ids_hardened(): for i in all_slip44_ids_hardened():
ns.append([CURVE, HARDENED | 44, i]) ns.append([CURVE, HARDENED | 44, i])

@ -3,8 +3,8 @@ from trezor.crypto import rlp
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256 from trezor.crypto.hashlib import sha3_256
from trezor.messages.EthereumSignTx import EthereumSignTx from trezor.messages.EthereumSignTx import EthereumSignTx
from trezor.messages.EthereumTxAck import EthereumTxAck
from trezor.messages.EthereumTxRequest import EthereumTxRequest from trezor.messages.EthereumTxRequest import EthereumTxRequest
from trezor.messages.MessageType import EthereumTxAck
from trezor.utils import HashWriter from trezor.utils import HashWriter
from apps.common import paths from apps.common import paths

@ -6,8 +6,16 @@ from trezor.wire import protobuf_workflow, register
from apps.common import cache, storage from apps.common import cache, storage
if False:
from typing import NoReturn
from trezor.messages.Initialize import Initialize
from trezor.messages.GetFeatures import GetFeatures
from trezor.messages.Cancel import Cancel
from trezor.messages.ClearSession import ClearSession
from trezor.messages.Ping import Ping
def get_features():
def get_features() -> Features:
f = Features() f = Features()
f.vendor = "trezor.io" f.vendor = "trezor.io"
f.language = "english" f.language = "english"
@ -30,7 +38,7 @@ def get_features():
return f return f
async def handle_Initialize(ctx, msg): async def handle_Initialize(ctx: wire.Context, msg: Initialize) -> Features:
if msg.state is None or msg.state != cache.get_state(prev_state=bytes(msg.state)): if msg.state is None or msg.state != cache.get_state(prev_state=bytes(msg.state)):
cache.clear() cache.clear()
if msg.skip_passphrase: if msg.skip_passphrase:
@ -38,20 +46,20 @@ async def handle_Initialize(ctx, msg):
return get_features() return get_features()
async def handle_GetFeatures(ctx, msg): async def handle_GetFeatures(ctx: wire.Context, msg: GetFeatures) -> Features:
return get_features() return get_features()
async def handle_Cancel(ctx, msg): async def handle_Cancel(ctx: wire.Context, msg: Cancel) -> NoReturn:
raise wire.ActionCancelled("Cancelled") raise wire.ActionCancelled("Cancelled")
async def handle_ClearSession(ctx, msg): async def handle_ClearSession(ctx: wire.Context, msg: ClearSession) -> Success:
cache.clear(keep_passphrase=True) cache.clear(keep_passphrase=True)
return Success(message="Session cleared") return Success(message="Session cleared")
async def handle_Ping(ctx, msg): async def handle_Ping(ctx: wire.Context, msg: Ping) -> Success:
if msg.button_protection: if msg.button_protection:
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from trezor.messages.ButtonRequestType import ProtectCall from trezor.messages.ButtonRequestType import ProtectCall
@ -65,7 +73,7 @@ async def handle_Ping(ctx, msg):
return Success(message=msg.message) return Success(message=msg.message)
def boot(): def boot() -> None:
register(MessageType.Initialize, protobuf_workflow, handle_Initialize) register(MessageType.Initialize, protobuf_workflow, handle_Initialize)
register(MessageType.GetFeatures, protobuf_workflow, handle_GetFeatures) register(MessageType.GetFeatures, protobuf_workflow, handle_GetFeatures)
register(MessageType.Cancel, protobuf_workflow, handle_Cancel) register(MessageType.Cancel, protobuf_workflow, handle_Cancel)

@ -3,7 +3,7 @@ from trezor import config, io, loop, res, ui
from apps.common import storage from apps.common import storage
async def homescreen(): async def homescreen() -> None:
# render homescreen in dimmed mode and fade back in # render homescreen in dimmed mode and fade back in
ui.backlight_fade(ui.BACKLIGHT_DIM) ui.backlight_fade(ui.BACKLIGHT_DIM)
display_homescreen() display_homescreen()
@ -15,7 +15,7 @@ async def homescreen():
await touch await touch
def display_homescreen(): def display_homescreen() -> None:
image = None image = None
if storage.slip39.is_in_progress(): if storage.slip39.is_in_progress():
label = "Waiting for other shares" label = "Waiting for other shares"
@ -44,13 +44,13 @@ def display_homescreen():
ui.display.text_center(ui.WIDTH // 2, 220, label, ui.BOLD, ui.FG, ui.BG) ui.display.text_center(ui.WIDTH // 2, 220, label, ui.BOLD, ui.FG, ui.BG)
def _warn(message: str): def _warn(message: str) -> None:
ui.display.bar(0, 0, ui.WIDTH, 30, ui.YELLOW) ui.display.bar(0, 0, ui.WIDTH, 30, ui.YELLOW)
ui.display.text_center(ui.WIDTH // 2, 22, message, ui.BOLD, ui.BLACK, ui.YELLOW) ui.display.text_center(ui.WIDTH // 2, 22, message, ui.BOLD, ui.BLACK, ui.YELLOW)
ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG) ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG)
def _err(message: str): def _err(message: str) -> None:
ui.display.bar(0, 0, ui.WIDTH, 30, ui.RED) ui.display.bar(0, 0, ui.WIDTH, 30, ui.RED)
ui.display.text_center(ui.WIDTH // 2, 22, message, ui.BOLD, ui.WHITE, ui.RED) ui.display.text_center(ui.WIDTH // 2, 22, message, ui.BOLD, ui.WHITE, ui.RED)
ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG) ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG)

@ -6,7 +6,7 @@ from apps.common import HARDENED
CURVE = "ed25519" CURVE = "ed25519"
def boot(): def boot() -> None:
ns = [[CURVE, HARDENED | 44, HARDENED | 134]] ns = [[CURVE, HARDENED | 44, HARDENED | 134]]
wire.add(MessageType.LiskGetPublicKey, __name__, "get_public_key", ns) wire.add(MessageType.LiskGetPublicKey, __name__, "get_public_key", ns)
wire.add(MessageType.LiskGetAddress, __name__, "get_address", ns) wire.add(MessageType.LiskGetAddress, __name__, "get_address", ns)

@ -2,7 +2,7 @@ from trezor import wire
from trezor.messages import MessageType from trezor.messages import MessageType
def boot(): def boot() -> None:
# only enable LoadDevice in debug builds # only enable LoadDevice in debug builds
if __debug__: if __debug__:
wire.add(MessageType.LoadDevice, __name__, "load_device") wire.add(MessageType.LoadDevice, __name__, "load_device")

@ -1,5 +1,6 @@
from trezor import config, ui, wire from trezor import config, ui, wire
from trezor.messages import ButtonRequestType, MessageType from trezor.messages import ButtonRequestType
from trezor.messages.ButtonAck import ButtonAck
from trezor.messages.ButtonRequest import ButtonRequest from trezor.messages.ButtonRequest import ButtonRequest
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.pin import pin_to_int from trezor.pin import pin_to_int
@ -74,9 +75,7 @@ async def request_pin_confirm(ctx, *args, **kwargs):
async def request_pin_ack(ctx, *args, **kwargs): async def request_pin_ack(ctx, *args, **kwargs):
try: try:
await ctx.call( await ctx.call(ButtonRequest(code=ButtonRequestType.Other), ButtonAck)
ButtonRequest(code=ButtonRequestType.Other), MessageType.ButtonAck
)
return await ctx.wait(request_pin(*args, **kwargs)) return await ctx.wait(request_pin(*args, **kwargs))
except PinCancelled: except PinCancelled:
raise wire.ActionCancelled("Cancelled") raise wire.ActionCancelled("Cancelled")

@ -1,8 +1,8 @@
from trezor import config, ui, wire from trezor import config, ui, wire
from trezor.crypto import slip39 from trezor.crypto import slip39
from trezor.messages import ButtonRequestType from trezor.messages import ButtonRequestType
from trezor.messages.ButtonAck import ButtonAck
from trezor.messages.ButtonRequest import ButtonRequest from trezor.messages.ButtonRequest import ButtonRequest
from trezor.messages.MessageType import ButtonAck
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.pin import pin_to_int from trezor.pin import pin_to_int
from trezor.ui.info import InfoConfirm from trezor.ui.info import InfoConfirm
@ -20,8 +20,11 @@ from apps.management.change_pin import request_pin_ack, request_pin_confirm
if __debug__: if __debug__:
from apps.debug import confirm_signal, input_signal from apps.debug import confirm_signal, input_signal
if False:
from trezor.messages.RecoveryDevice import RecoveryDevice
async def recovery_device(ctx, msg):
async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
""" """
Recover BIP39/SLIP39 seed into empty device. Recover BIP39/SLIP39 seed into empty device.
@ -116,7 +119,7 @@ async def recovery_device(ctx, msg):
return Success(message="Device recovered") return Success(message="Device recovered")
async def request_wordcount(ctx, title: str) -> int: async def request_wordcount(ctx: wire.Context, title: str) -> int:
await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicWordCount), ButtonAck) await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicWordCount), ButtonAck)
text = Text(title, ui.ICON_RECOVERY) text = Text(title, ui.ICON_RECOVERY)
@ -131,7 +134,7 @@ async def request_wordcount(ctx, title: str) -> int:
return count return count
async def request_mnemonic(ctx, count: int, slip39: bool) -> str: async def request_mnemonic(ctx: wire.Context, count: int, slip39: bool) -> str:
await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicInput), ButtonAck) await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicInput), ButtonAck)
words = [] words = []
@ -149,7 +152,7 @@ async def request_mnemonic(ctx, count: int, slip39: bool) -> str:
return " ".join(words) return " ".join(words)
async def show_keyboard_info(ctx): async def show_keyboard_info(ctx: wire.Context) -> None:
await ctx.call(ButtonRequest(code=ButtonRequestType.Other), ButtonAck) await ctx.call(ButtonRequest(code=ButtonRequestType.Other), ButtonAck)
info = InfoConfirm( info = InfoConfirm(
@ -174,7 +177,9 @@ async def show_success(ctx):
) )
async def show_remaining_slip39_mnemonics(ctx, title, remaining: int): async def show_remaining_slip39_mnemonics(
ctx: wire.Context, title: str, remaining: int
) -> None:
text = Text(title, ui.ICON_RECOVERY) text = Text(title, ui.ICON_RECOVERY)
text.bold("Good job!") text.bold("Good job!")
text.normal("Enter %s more recovery " % remaining) text.normal("Enter %s more recovery " % remaining)

@ -1,6 +1,6 @@
from trezor import config, wire from trezor import config, wire
from trezor.crypto import bip39, hashlib, random, slip39 from trezor.crypto import bip39, hashlib, random, slip39
from trezor.messages import MessageType from trezor.messages.EntropyAck import EntropyAck
from trezor.messages.EntropyRequest import EntropyRequest from trezor.messages.EntropyRequest import EntropyRequest
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.pin import pin_to_int from trezor.pin import pin_to_int
@ -12,8 +12,11 @@ from apps.management.common import layout
if __debug__: if __debug__:
from apps import debug from apps import debug
if False:
from trezor.messages.ResetDevice import ResetDevice
async def reset_device(ctx, msg):
async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
# validate parameters and device state # validate parameters and device state
_validate_reset_device(msg) _validate_reset_device(msg)
@ -34,7 +37,7 @@ async def reset_device(ctx, msg):
await layout.show_internal_entropy(ctx, int_entropy) await layout.show_internal_entropy(ctx, int_entropy)
# request external entropy and compute the master secret # request external entropy and compute the master secret
entropy_ack = await ctx.call(EntropyRequest(), MessageType.EntropyAck) entropy_ack = await ctx.call(EntropyRequest(), EntropyAck)
ext_entropy = entropy_ack.entropy ext_entropy = entropy_ack.entropy
secret = _compute_secret_from_entropy(int_entropy, ext_entropy, msg.strength) secret = _compute_secret_from_entropy(int_entropy, ext_entropy, msg.strength)
@ -84,7 +87,7 @@ async def reset_device(ctx, msg):
return Success(message="Initialized") return Success(message="Initialized")
async def backup_slip39_wallet(ctx, secret: bytes): async def backup_slip39_wallet(ctx: wire.Context, secret: bytes) -> None:
# get number of shares # get number of shares
await layout.slip39_show_checklist_set_shares(ctx) await layout.slip39_show_checklist_set_shares(ctx)
shares_count = await layout.slip39_prompt_number_of_shares(ctx) shares_count = await layout.slip39_prompt_number_of_shares(ctx)
@ -101,12 +104,12 @@ async def backup_slip39_wallet(ctx, secret: bytes):
await layout.slip39_show_and_confirm_shares(ctx, mnemonics) await layout.slip39_show_and_confirm_shares(ctx, mnemonics)
async def backup_bip39_wallet(ctx, secret: bytes): async def backup_bip39_wallet(ctx: wire.Context, secret: bytes) -> None:
mnemonic = bip39.from_data(secret) mnemonic = bip39.from_data(secret)
await layout.bip39_show_and_confirm_mnemonic(ctx, mnemonic) await layout.bip39_show_and_confirm_mnemonic(ctx, mnemonic)
def _validate_reset_device(msg): def _validate_reset_device(msg: ResetDevice) -> None:
if msg.strength not in (128, 256): if msg.strength not in (128, 256):
if msg.slip39: if msg.slip39:
raise wire.ProcessError("Invalid strength (has to be 128 or 256 bits)") raise wire.ProcessError("Invalid strength (has to be 128 or 256 bits)")

@ -7,7 +7,7 @@ CURVE = "ed25519"
_LIVE_REFRESH_TOKEN = None # live-refresh permission token _LIVE_REFRESH_TOKEN = None # live-refresh permission token
def boot(): def boot() -> None:
ns = [[CURVE, HARDENED | 44, HARDENED | 128]] ns = [[CURVE, HARDENED | 44, HARDENED | 128]]
wire.add(MessageType.MoneroGetAddress, __name__, "get_address", ns) wire.add(MessageType.MoneroGetAddress, __name__, "get_address", ns)
wire.add(MessageType.MoneroGetWatchKey, __name__, "get_watch_only", ns) wire.add(MessageType.MoneroGetWatchKey, __name__, "get_watch_only", ns)

@ -1,11 +1,14 @@
import gc import gc
from trezor import log, wire from trezor import log, wire
from trezor.messages import MessageType
from trezor.messages.MoneroExportedKeyImage import MoneroExportedKeyImage from trezor.messages.MoneroExportedKeyImage import MoneroExportedKeyImage
from trezor.messages.MoneroKeyImageExportInitAck import MoneroKeyImageExportInitAck from trezor.messages.MoneroKeyImageExportInitAck import MoneroKeyImageExportInitAck
from trezor.messages.MoneroKeyImageSyncFinalAck import MoneroKeyImageSyncFinalAck from trezor.messages.MoneroKeyImageSyncFinalAck import MoneroKeyImageSyncFinalAck
from trezor.messages.MoneroKeyImageSyncFinalRequest import (
MoneroKeyImageSyncFinalRequest,
)
from trezor.messages.MoneroKeyImageSyncStepAck import MoneroKeyImageSyncStepAck from trezor.messages.MoneroKeyImageSyncStepAck import MoneroKeyImageSyncStepAck
from trezor.messages.MoneroKeyImageSyncStepRequest import MoneroKeyImageSyncStepRequest
from apps.common import paths from apps.common import paths
from apps.monero import CURVE, misc from apps.monero import CURVE, misc
@ -18,19 +21,12 @@ async def key_image_sync(ctx, msg, keychain):
state = KeyImageSync() state = KeyImageSync()
res = await _init_step(state, ctx, msg, keychain) res = await _init_step(state, ctx, msg, keychain)
while True: while state.current_output + 1 < state.num_outputs:
msg = await ctx.call( msg = await ctx.call(res, MoneroKeyImageSyncStepRequest)
res, res = await _sync_step(state, ctx, msg)
MessageType.MoneroKeyImageSyncStepRequest,
MessageType.MoneroKeyImageSyncFinalRequest,
)
del res
if msg.MESSAGE_WIRE_TYPE == MessageType.MoneroKeyImageSyncStepRequest:
res = await _sync_step(state, ctx, msg)
else:
res = await _final_step(state, ctx)
break
gc.collect() gc.collect()
msg = await ctx.call(res, MoneroKeyImageSyncFinalRequest)
res = await _final_step(state, ctx)
return res return res

@ -1,5 +1,6 @@
from trezor import loop, ui, utils from trezor import loop, ui, utils
from trezor.messages import ButtonRequestType, MessageType from trezor.messages import ButtonRequestType
from trezor.messages.ButtonAck import ButtonAck
from trezor.messages.ButtonRequest import ButtonRequest from trezor.messages.ButtonRequest import ButtonRequest
from trezor.ui.text import Text from trezor.ui.text import Text
@ -27,9 +28,7 @@ async def naive_pagination(
paginated = PaginatedWithButtons(pages, one_by_one=True) paginated = PaginatedWithButtons(pages, one_by_one=True)
while True: while True:
await ctx.call( await ctx.call(ButtonRequest(code=ButtonRequestType.SignTx), ButtonAck)
ButtonRequest(code=ButtonRequestType.SignTx), MessageType.ButtonAck
)
if __debug__: if __debug__:
result = await loop.spawn(paginated, confirm_signal) result = await loop.spawn(paginated, confirm_signal)
else: else:

@ -21,7 +21,7 @@ async def live_refresh(ctx, msg: MoneroLiveRefreshStartRequest, keychain):
res = await _init_step(state, ctx, msg, keychain) res = await _init_step(state, ctx, msg, keychain)
while True: while True:
msg = await ctx.call( msg = await ctx.call_any(
res, res,
MessageType.MoneroLiveRefreshStepRequest, MessageType.MoneroLiveRefreshStepRequest,
MessageType.MoneroLiveRefreshFinalRequest, MessageType.MoneroLiveRefreshFinalRequest,

@ -26,7 +26,7 @@ async def sign_tx(ctx, received_msg, keychain):
del (result_msg, received_msg) del (result_msg, received_msg)
utils.unimport_end(mods) utils.unimport_end(mods)
received_msg = await ctx.read(accept_msgs) received_msg = await ctx.read_any(accept_msgs)
utils.unimport_end(mods) utils.unimport_end(mods)
return result_msg return result_msg

@ -6,7 +6,7 @@ from apps.common import HARDENED
CURVE = "ed25519-keccak" CURVE = "ed25519-keccak"
def boot(): def boot() -> None:
ns = [[CURVE, HARDENED | 44, HARDENED | 43], [CURVE, HARDENED | 44, HARDENED | 1]] ns = [[CURVE, HARDENED | 44, HARDENED | 43], [CURVE, HARDENED | 44, HARDENED | 1]]
wire.add(MessageType.NEMGetAddress, __name__, "get_address", ns) wire.add(MessageType.NEMGetAddress, __name__, "get_address", ns)
wire.add(MessageType.NEMSignTx, __name__, "sign_tx", ns) wire.add(MessageType.NEMSignTx, __name__, "sign_tx", ns)

@ -6,7 +6,7 @@ from apps.common import HARDENED
CURVE = "secp256k1" CURVE = "secp256k1"
def boot(): def boot() -> None:
ns = [[CURVE, HARDENED | 44, HARDENED | 144]] ns = [[CURVE, HARDENED | 44, HARDENED | 144]]
wire.add(MessageType.RippleGetAddress, __name__, "get_address", ns) wire.add(MessageType.RippleGetAddress, __name__, "get_address", ns)
wire.add(MessageType.RippleSignTx, __name__, "sign_tx", ns) wire.add(MessageType.RippleSignTx, __name__, "sign_tx", ns)

@ -6,7 +6,7 @@ from apps.common import HARDENED
CURVE = "ed25519" CURVE = "ed25519"
def boot(): def boot() -> None:
ns = [[CURVE, HARDENED | 44, HARDENED | 148]] ns = [[CURVE, HARDENED | 44, HARDENED | 148]]
wire.add(MessageType.StellarGetAddress, __name__, "get_address", ns) wire.add(MessageType.StellarGetAddress, __name__, "get_address", ns)
wire.add(MessageType.StellarSignTx, __name__, "sign_tx", ns) wire.add(MessageType.StellarSignTx, __name__, "sign_tx", ns)

@ -77,7 +77,7 @@ def _timebounds(w: bytearray, start: int, end: int):
async def _operations(ctx, w: bytearray, num_operations: int): async def _operations(ctx, w: bytearray, num_operations: int):
writers.write_uint32(w, num_operations) writers.write_uint32(w, num_operations)
for i in range(num_operations): for i in range(num_operations):
op = await ctx.call(StellarTxOpRequest(), *consts.op_wire_types) op = await ctx.call_any(StellarTxOpRequest(), *consts.op_wire_types)
await process_operation(ctx, w, op) await process_operation(ctx, w, op)

@ -6,7 +6,7 @@ from apps.common import HARDENED
CURVE = "ed25519" CURVE = "ed25519"
def boot(): def boot() -> None:
ns = [[CURVE, HARDENED | 44, HARDENED | 1729]] ns = [[CURVE, HARDENED | 44, HARDENED | 1729]]
wire.add(MessageType.TezosGetAddress, __name__, "get_address", ns) wire.add(MessageType.TezosGetAddress, __name__, "get_address", ns)
wire.add(MessageType.TezosSignTx, __name__, "sign_tx", ns) wire.add(MessageType.TezosSignTx, __name__, "sign_tx", ns)

@ -2,7 +2,7 @@ from trezor import wire
from trezor.messages import MessageType from trezor.messages import MessageType
def boot(): def boot() -> None:
ns = [ ns = [
["curve25519"], ["curve25519"],
["ed25519"], ["ed25519"],

@ -1,6 +1,6 @@
from trezor import utils, wire from trezor import utils, wire
from trezor.messages.MessageType import TxAck
from trezor.messages.RequestType import TXFINISHED from trezor.messages.RequestType import TXFINISHED
from trezor.messages.TxAck import TxAck
from trezor.messages.TxRequest import TxRequest from trezor.messages.TxRequest import TxRequest
from apps.common import paths from apps.common import paths

@ -5,7 +5,7 @@ from apps.common import storage
from apps.common.request_pin import request_pin from apps.common.request_pin import request_pin
async def bootscreen(): async def bootscreen() -> None:
ui.display.orientation(storage.device.get_rotation()) ui.display.orientation(storage.device.get_rotation())
while True: while True:
try: try:
@ -27,7 +27,7 @@ async def bootscreen():
log.exception(__name__, e) log.exception(__name__, e)
async def lockscreen(): async def lockscreen() -> None:
label = storage.device.get_label() label = storage.device.get_label()
image = storage.device.get_homescreen() image = storage.device.get_homescreen()
if not label: if not label:

@ -1,32 +1,31 @@
''' """
Extremely minimal streaming codec for a subset of protobuf. Supports uint32, Extremely minimal streaming codec for a subset of protobuf. Supports uint32,
bytes, string, embedded message and repeated fields. bytes, string, embedded message and repeated fields.
"""
For de-serializing (loading) protobuf types, object with `AsyncReader` from micropython import const
interface is required:
>>> class AsyncReader: if False:
>>> async def areadinto(self, buffer): from typing import Any, Dict, List, Type, TypeVar
>>> """ from typing_extensions import Protocol
>>> Reads `len(buffer)` bytes into `buffer`, or raises `EOFError`.
>>> """
For serializing (dumping) protobuf types, object with `AsyncWriter` interface is class AsyncReader(Protocol):
required: async def areadinto(self, buf: bytearray) -> int:
"""
Reads `len(buf)` bytes into `buf`, or raises `EOFError`.
"""
>>> class AsyncWriter: class AsyncWriter(Protocol):
>>> async def awrite(self, buffer): async def awrite(self, buf: bytes) -> int:
>>> """ """
>>> Writes all bytes from `buffer`, or raises `EOFError`. Writes all bytes from `buf`, or raises `EOFError`.
>>> """ """
'''
from micropython import const
_UVARINT_BUFFER = bytearray(1) _UVARINT_BUFFER = bytearray(1)
async def load_uvarint(reader): async def load_uvarint(reader: AsyncReader) -> int:
buffer = _UVARINT_BUFFER buffer = _UVARINT_BUFFER
result = 0 result = 0
shift = 0 shift = 0
@ -39,11 +38,11 @@ async def load_uvarint(reader):
return result return result
async def dump_uvarint(writer, n): async def dump_uvarint(writer: AsyncWriter, n: int) -> None:
if n < 0: if n < 0:
raise ValueError("Cannot dump signed value, convert it to unsigned first.") raise ValueError("Cannot dump signed value, convert it to unsigned first.")
buffer = _UVARINT_BUFFER buffer = _UVARINT_BUFFER
shifted = True shifted = 1
while shifted: while shifted:
shifted = n >> 7 shifted = n >> 7
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00) buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
@ -51,7 +50,7 @@ async def dump_uvarint(writer, n):
n = shifted n = shifted
def count_uvarint(n): def count_uvarint(n: int) -> int:
if n < 0: if n < 0:
raise ValueError("Cannot dump signed value, convert it to unsigned first.") raise ValueError("Cannot dump signed value, convert it to unsigned first.")
if n <= 0x7F: if n <= 0x7F:
@ -95,14 +94,14 @@ def count_uvarint(n):
# So we have to branch on whether the number is negative. # So we have to branch on whether the number is negative.
def sint_to_uint(sint): def sint_to_uint(sint: int) -> int:
res = sint << 1 res = sint << 1
if sint < 0: if sint < 0:
res = ~res res = ~res
return res return res
def uint_to_sint(uint): def uint_to_sint(uint: int) -> int:
sign = uint & 1 sign = uint & 1
res = uint >> 1 res = uint >> 1
if sign: if sign:
@ -133,27 +132,31 @@ class UnicodeType:
class MessageType: class MessageType:
WIRE_TYPE = 2 WIRE_TYPE = 2
# Type id for the wire codec.
# Technically, not every protobuf message has this.
MESSAGE_WIRE_TYPE = -1
@classmethod @classmethod
def get_fields(cls): def get_fields(cls) -> Dict:
return {} return {}
def __init__(self, **kwargs): def __init__(self, **kwargs: Any) -> None:
for kw in kwargs: for kw in kwargs:
setattr(self, kw, kwargs[kw]) setattr(self, kw, kwargs[kw])
def __eq__(self, rhs): def __eq__(self, rhs: Any) -> bool:
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__ return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
def __repr__(self): def __repr__(self) -> str:
return "<%s>" % self.__class__.__name__ return "<%s>" % self.__class__.__name__
class LimitedReader: class LimitedReader:
def __init__(self, reader, limit): def __init__(self, reader: AsyncReader, limit: int) -> None:
self.reader = reader self.reader = reader
self.limit = limit self.limit = limit
async def areadinto(self, buf): async def areadinto(self, buf: bytearray) -> int:
if self.limit < len(buf): if self.limit < len(buf):
raise EOFError raise EOFError
else: else:
@ -162,20 +165,15 @@ class LimitedReader:
return nread return nread
class CountingWriter:
def __init__(self):
self.size = 0
async def awrite(self, buf):
nwritten = len(buf)
self.size += nwritten
return nwritten
FLAG_REPEATED = const(1) FLAG_REPEATED = const(1)
if False:
LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType)
async def load_message(reader, msg_type): async def load_message(
reader: AsyncReader, msg_type: Type[LoadedMessageType]
) -> LoadedMessageType:
fields = msg_type.get_fields() fields = msg_type.get_fields()
msg = msg_type() msg = msg_type()
@ -239,7 +237,9 @@ async def load_message(reader, msg_type):
return msg return msg
async def dump_message(writer, msg, fields=None): async def dump_message(
writer: AsyncWriter, msg: MessageType, fields: Dict = None
) -> None:
repvalue = [0] repvalue = [0]
if fields is None: if fields is None:
@ -297,7 +297,7 @@ async def dump_message(writer, msg, fields=None):
raise TypeError raise TypeError
def count_message(msg, fields=None): def count_message(msg: MessageType, fields: Dict = None) -> int:
nbytes = 0 nbytes = 0
repvalue = [0] repvalue = [0]
@ -361,7 +361,7 @@ def count_message(msg, fields=None):
return nbytes return nbytes
def _count_bytes_list(svalue): def _count_bytes_list(svalue: List[bytes]) -> int:
res = 0 res = 0
for x in svalue: for x in svalue:
res += len(x) res += len(x)

@ -1,5 +1,3 @@
from gc import collect
from trezorcrypto import ( # noqa: F401 from trezorcrypto import ( # noqa: F401
aes, aes,
bip32, bip32,
@ -12,18 +10,3 @@ from trezorcrypto import ( # noqa: F401
random, random,
rfc6979, rfc6979,
) )
class SecureContext:
def __init__(self):
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
for k in self.__dict__:
o = getattr(self, k)
if hasattr(o, "__del__"):
o.__del__()
collect()

@ -1,5 +1,23 @@
if False:
from typing import Protocol, Type
class HashContext(Protocol):
digest_size = -1 # type: int
block_size = -1 # type: int
def __init__(self, data: bytes = None) -> None:
...
def update(self, data: bytes) -> None:
...
def digest(self) -> bytes:
...
class Hmac: class Hmac:
def __init__(self, key, msg, digestmod): def __init__(self, key: bytes, msg: bytes, digestmod: Type[HashContext]):
self.digestmod = digestmod self.digestmod = digestmod
self.inner = digestmod() self.inner = digestmod()
self.digest_size = self.inner.digest_size self.digest_size = self.inner.digest_size
@ -28,7 +46,7 @@ class Hmac:
return outer.digest() return outer.digest()
def new(key, msg, digestmod) -> Hmac: def new(key: bytes, msg: bytes, digestmod: Type[HashContext]) -> Hmac:
""" """
Creates a HMAC context object. Creates a HMAC context object.
""" """

@ -23,6 +23,11 @@ from micropython import const
from trezor.crypto import hashlib, hmac, pbkdf2, random from trezor.crypto import hashlib, hmac, pbkdf2, random
from trezorcrypto import shamir, slip39 from trezorcrypto import shamir, slip39
if False:
from typing import Dict, Iterable, List, Optional, Set, Tuple
Indices = Tuple[int, ...]
KEYBOARD_FULL_MASK = const(0x1FF) KEYBOARD_FULL_MASK = const(0x1FF)
"""All buttons are allowed. 9-bit bitmap all set to 1.""" """All buttons are allowed. 9-bit bitmap all set to 1."""
@ -35,7 +40,7 @@ def compute_mask(prefix: str) -> int:
def button_sequence_to_word(prefix: str) -> str: def button_sequence_to_word(prefix: str) -> str:
if not prefix: if not prefix:
return KEYBOARD_FULL_MASK return ""
return slip39.button_sequence_to_word(int(prefix)) return slip39.button_sequence_to_word(int(prefix))
@ -43,11 +48,11 @@ _RADIX_BITS = const(10)
"""The length of the radix in bits.""" """The length of the radix in bits."""
def bits_to_bytes(n): def bits_to_bytes(n: int) -> int:
return (n + 7) // 8 return (n + 7) // 8
def bits_to_words(n): def bits_to_words(n: int) -> int:
return (n + _RADIX_BITS - 1) // _RADIX_BITS return (n + _RADIX_BITS - 1) // _RADIX_BITS
@ -103,7 +108,7 @@ class MnemonicError(Exception):
pass pass
def _rs1024_polymod(values): def _rs1024_polymod(values: Indices) -> int:
GEN = ( GEN = (
0xE0E040, 0xE0E040,
0x1C1C080, 0x1C1C080,
@ -125,7 +130,7 @@ def _rs1024_polymod(values):
return chk return chk
def rs1024_create_checksum(data): def rs1024_create_checksum(data: Indices) -> Indices:
values = tuple(_CUSTOMIZATION_STRING) + data + _CHECKSUM_LENGTH_WORDS * (0,) values = tuple(_CUSTOMIZATION_STRING) + data + _CHECKSUM_LENGTH_WORDS * (0,)
polymod = _rs1024_polymod(values) ^ 1 polymod = _rs1024_polymod(values) ^ 1
return tuple( return tuple(
@ -133,11 +138,11 @@ def rs1024_create_checksum(data):
) )
def rs1024_verify_checksum(data): def rs1024_verify_checksum(data: Indices) -> bool:
return _rs1024_polymod(tuple(_CUSTOMIZATION_STRING) + data) == 1 return _rs1024_polymod(tuple(_CUSTOMIZATION_STRING) + data) == 1
def rs1024_error_index(data): def rs1024_error_index(data: Indices) -> Optional[int]:
GEN = ( GEN = (
0x91F9F87, 0x91F9F87,
0x122F1F07, 0x122F1F07,
@ -164,11 +169,11 @@ def rs1024_error_index(data):
return None return None
def xor(a, b): def xor(a: bytes, b: bytes) -> bytes:
return bytes(x ^ y for x, y in zip(a, b)) return bytes(x ^ y for x, y in zip(a, b))
def _int_from_indices(indices): def _int_from_indices(indices: Indices) -> int:
"""Converts a list of base 1024 indices in big endian order to an integer value.""" """Converts a list of base 1024 indices in big endian order to an integer value."""
value = 0 value = 0
for index in indices: for index in indices:
@ -176,21 +181,21 @@ def _int_from_indices(indices):
return value return value
def _int_to_indices(value, length, bits): def _int_to_indices(value: int, length: int, bits: int) -> Iterable[int]:
"""Converts an integer value to indices in big endian order.""" """Converts an integer value to indices in big endian order."""
mask = (1 << bits) - 1 mask = (1 << bits) - 1
return ((value >> (i * bits)) & mask for i in reversed(range(length))) return ((value >> (i * bits)) & mask for i in reversed(range(length)))
def mnemonic_from_indices(indices): def mnemonic_from_indices(indices: Indices) -> str:
return " ".join(slip39.get_word(i) for i in indices) return " ".join(slip39.get_word(i) for i in indices)
def mnemonic_to_indices(mnemonic): def mnemonic_to_indices(mnemonic: str) -> Iterable[int]:
return (slip39.word_index(word.lower()) for word in mnemonic.split()) return (slip39.word_index(word.lower()) for word in mnemonic.split())
def _round_function(i, passphrase, e, salt, r): def _round_function(i: int, passphrase: bytes, e: int, salt: bytes, r: bytes) -> bytes:
"""The round function used internally by the Feistel cipher.""" """The round function used internally by the Feistel cipher."""
return pbkdf2( return pbkdf2(
pbkdf2.HMAC_SHA256, pbkdf2.HMAC_SHA256,
@ -200,13 +205,15 @@ def _round_function(i, passphrase, e, salt, r):
).key()[: len(r)] ).key()[: len(r)]
def _get_salt(identifier): def _get_salt(identifier: int) -> bytes:
return _CUSTOMIZATION_STRING + identifier.to_bytes( return _CUSTOMIZATION_STRING + identifier.to_bytes(
bits_to_bytes(_ID_LENGTH_BITS), "big" bits_to_bytes(_ID_LENGTH_BITS), "big"
) )
def _encrypt(master_secret, passphrase, iteration_exponent, identifier): def _encrypt(
master_secret: bytes, passphrase: bytes, iteration_exponent: int, identifier: int
) -> bytes:
l = master_secret[: len(master_secret) // 2] l = master_secret[: len(master_secret) // 2]
r = master_secret[len(master_secret) // 2 :] r = master_secret[len(master_secret) // 2 :]
salt = _get_salt(identifier) salt = _get_salt(identifier)
@ -218,7 +225,12 @@ def _encrypt(master_secret, passphrase, iteration_exponent, identifier):
return r + l return r + l
def decrypt(identifier, iteration_exponent, encrypted_master_secret, passphrase): def decrypt(
identifier: int,
iteration_exponent: int,
encrypted_master_secret: bytes,
passphrase: bytes,
) -> bytes:
l = encrypted_master_secret[: len(encrypted_master_secret) // 2] l = encrypted_master_secret[: len(encrypted_master_secret) // 2]
r = encrypted_master_secret[len(encrypted_master_secret) // 2 :] r = encrypted_master_secret[len(encrypted_master_secret) // 2 :]
salt = _get_salt(identifier) salt = _get_salt(identifier)
@ -230,13 +242,15 @@ def decrypt(identifier, iteration_exponent, encrypted_master_secret, passphrase)
return r + l return r + l
def _create_digest(random_data, shared_secret): def _create_digest(random_data: bytes, shared_secret: bytes) -> bytes:
return hmac.new(random_data, shared_secret, hashlib.sha256).digest()[ return hmac.new(random_data, shared_secret, hashlib.sha256).digest()[
:_DIGEST_LENGTH_BYTES :_DIGEST_LENGTH_BYTES
] ]
def _split_secret(threshold, share_count, shared_secret): def _split_secret(
threshold: int, share_count: int, shared_secret: bytes
) -> List[Tuple[int, bytes]]:
if threshold < 1: if threshold < 1:
raise ValueError( raise ValueError(
"The requested threshold ({}) must be a positive integer.".format(threshold) "The requested threshold ({}) must be a positive integer.".format(threshold)
@ -278,7 +292,7 @@ def _split_secret(threshold, share_count, shared_secret):
return shares return shares
def _recover_secret(threshold, shares): def _recover_secret(threshold: int, shares: List[Tuple[int, bytes]]) -> bytes:
# If the threshold is 1, then the digest of the shared secret is not used. # If the threshold is 1, then the digest of the shared secret is not used.
if threshold == 1: if threshold == 1:
return shares[0][1] return shares[0][1]
@ -295,8 +309,12 @@ def _recover_secret(threshold, shares):
def _group_prefix( def _group_prefix(
identifier, iteration_exponent, group_index, group_threshold, group_count identifier: int,
): iteration_exponent: int,
group_index: int,
group_threshold: int,
group_count: int,
) -> Indices:
id_exp_int = (identifier << _ITERATION_EXP_LENGTH_BITS) + iteration_exponent id_exp_int = (identifier << _ITERATION_EXP_LENGTH_BITS) + iteration_exponent
return tuple(_int_to_indices(id_exp_int, _ID_EXP_LENGTH_WORDS, _RADIX_BITS)) + ( return tuple(_int_to_indices(id_exp_int, _ID_EXP_LENGTH_WORDS, _RADIX_BITS)) + (
(group_index << 6) + ((group_threshold - 1) << 2) + ((group_count - 1) >> 2), (group_index << 6) + ((group_threshold - 1) << 2) + ((group_count - 1) >> 2),
@ -304,15 +322,15 @@ def _group_prefix(
def encode_mnemonic( def encode_mnemonic(
identifier, identifier: int,
iteration_exponent, iteration_exponent: int,
group_index, group_index: int,
group_threshold, group_threshold: int,
group_count, group_count: int,
member_index, member_index: int,
member_threshold, member_threshold: int,
value, value: bytes,
): ) -> str:
""" """
Converts share data to a share mnemonic. Converts share data to a share mnemonic.
:param int identifier: The random identifier. :param int identifier: The random identifier.
@ -348,7 +366,7 @@ def encode_mnemonic(
return mnemonic_from_indices(share_data + checksum) return mnemonic_from_indices(share_data + checksum)
def decode_mnemonic(mnemonic): def decode_mnemonic(mnemonic: str) -> Tuple[int, int, int, int, int, int, int, bytes]:
"""Converts a share mnemonic to share data.""" """Converts a share mnemonic to share data."""
mnemonic_data = tuple(mnemonic_to_indices(mnemonic)) mnemonic_data = tuple(mnemonic_to_indices(mnemonic))
@ -401,12 +419,20 @@ def decode_mnemonic(mnemonic):
) )
def _decode_mnemonics(mnemonics): if False:
MnemonicGroups = Dict[int, Tuple[int, Set[Tuple[int, bytes]]]]
def _decode_mnemonics(
mnemonics: List[str]
) -> Tuple[int, int, int, int, MnemonicGroups]:
identifiers = set() identifiers = set()
iteration_exponents = set() iteration_exponents = set()
group_thresholds = set() group_thresholds = set()
group_counts = set() group_counts = set()
groups = {} # { group_index : [member_threshold, set_of_member_shares] }
# { group_index : [member_threshold, set_of_member_shares] }
groups = {} # type: MnemonicGroups
for mnemonic in mnemonics: for mnemonic in mnemonics:
identifier, iteration_exponent, group_index, group_threshold, group_count, member_index, member_threshold, share_value = decode_mnemonic( identifier, iteration_exponent, group_index, group_threshold, group_count, member_index, member_threshold, share_value = decode_mnemonic(
mnemonic mnemonic
@ -415,7 +441,7 @@ def _decode_mnemonics(mnemonics):
iteration_exponents.add(iteration_exponent) iteration_exponents.add(iteration_exponent)
group_thresholds.add(group_threshold) group_thresholds.add(group_threshold)
group_counts.add(group_count) group_counts.add(group_count)
group = groups.setdefault(group_index, [member_threshold, set()]) group = groups.setdefault(group_index, (member_threshold, set()))
if group[0] != member_threshold: if group[0] != member_threshold:
raise MnemonicError( raise MnemonicError(
"Invalid set of mnemonics. All mnemonics in a group must have the same member threshold." "Invalid set of mnemonics. All mnemonics in a group must have the same member threshold."
@ -462,13 +488,13 @@ def generate_random_identifier() -> int:
def generate_single_group_mnemonics_from_data( def generate_single_group_mnemonics_from_data(
master_secret, master_secret: bytes,
identifier, identifier: int,
threshold, threshold: int,
count, count: int,
passphrase=b"", passphrase: bytes = b"",
iteration_exponent=DEFAULT_ITERATION_EXPONENT, iteration_exponent: int = DEFAULT_ITERATION_EXPONENT,
) -> list: ) -> List[str]:
return generate_mnemonics_from_data( return generate_mnemonics_from_data(
master_secret, master_secret,
identifier, identifier,
@ -480,13 +506,13 @@ def generate_single_group_mnemonics_from_data(
def generate_mnemonics_from_data( def generate_mnemonics_from_data(
master_secret, master_secret: bytes,
identifier, identifier: int,
group_threshold, group_threshold: int,
groups, groups: List[Tuple[int, int]],
passphrase=b"", passphrase: bytes = b"",
iteration_exponent=DEFAULT_ITERATION_EXPONENT, iteration_exponent: int = DEFAULT_ITERATION_EXPONENT,
) -> list: ) -> List[List[str]]:
""" """
Splits a master secret into mnemonic shares using Shamir's secret sharing scheme. Splits a master secret into mnemonic shares using Shamir's secret sharing scheme.
:param master_secret: The master secret to split. :param master_secret: The master secret to split.
@ -544,7 +570,7 @@ def generate_mnemonics_from_data(
group_shares = _split_secret(group_threshold, len(groups), encrypted_master_secret) group_shares = _split_secret(group_threshold, len(groups), encrypted_master_secret)
mnemonics = [] mnemonics = [] # type: List[List[str]]
for (member_threshold, member_count), (group_index, group_secret) in zip( for (member_threshold, member_count), (group_index, group_secret) in zip(
groups, group_shares groups, group_shares
): ):
@ -568,7 +594,7 @@ def generate_mnemonics_from_data(
return mnemonics return mnemonics
def combine_mnemonics(mnemonics): def combine_mnemonics(mnemonics: List[str]) -> Tuple[int, int, bytes]:
""" """
Combines mnemonic shares to obtain the master secret which was previously split using Combines mnemonic shares to obtain the master secret which was previously split using
Shamir's secret sharing scheme. Shamir's secret sharing scheme.

@ -2,6 +2,9 @@ import sys
import utime import utime
from micropython import const from micropython import const
if False:
from typing import Any
NOTSET = const(0) NOTSET = const(0)
DEBUG = const(10) DEBUG = const(10)
INFO = const(20) INFO = const(20)
@ -21,7 +24,7 @@ level = DEBUG
color = True color = True
def _log(name, mlevel, msg, *args): def _log(name: str, mlevel: int, msg: str, *args: Any) -> None:
if __debug__ and mlevel >= level: if __debug__ and mlevel >= level:
if color: if color:
fmt = ( fmt = (
@ -35,26 +38,26 @@ def _log(name, mlevel, msg, *args):
print(fmt % ((utime.ticks_us(), name, _leveldict[mlevel][0]) + args)) print(fmt % ((utime.ticks_us(), name, _leveldict[mlevel][0]) + args))
def debug(name, msg, *args): def debug(name: str, msg: str, *args: Any) -> None:
_log(name, DEBUG, msg, *args) _log(name, DEBUG, msg, *args)
def info(name, msg, *args): def info(name: str, msg: str, *args: Any) -> None:
_log(name, INFO, msg, *args) _log(name, INFO, msg, *args)
def warning(name, msg, *args): def warning(name: str, msg: str, *args: Any) -> None:
_log(name, WARNING, msg, *args) _log(name, WARNING, msg, *args)
def error(name, msg, *args): def error(name: str, msg: str, *args: Any) -> None:
_log(name, ERROR, msg, *args) _log(name, ERROR, msg, *args)
def exception(name, exc): def critical(name: str, msg: str, *args: Any) -> None:
_log(name, ERROR, "exception:") _log(name, CRITICAL, msg, *args)
sys.print_exception(exc)
def critical(name, msg, *args): def exception(name: str, exc: BaseException) -> None:
_log(name, CRITICAL, msg, *args) _log(name, ERROR, "exception:")
sys.print_exception(exc)

@ -13,12 +13,33 @@ from micropython import const
from trezor import io, log from trezor import io, log
after_step_hook = None # function, called after each task step if False:
from typing import (
_QUEUE_SIZE = const(64) # maximum number of scheduled tasks Any,
_queue = utimeq.utimeq(_QUEUE_SIZE) Awaitable,
_paused = {} Callable,
_finalizers = {} Coroutine,
Dict,
Generator,
List,
Optional,
Set,
)
Task = Coroutine
Finalizer = Callable[[Task, Any], None]
# function to call after every task step
after_step_hook = None # type: Optional[Callable[[], None]]
# tasks scheduled for execution in the future
_queue = utimeq.utimeq(64)
# tasks paused on I/O
_paused = {} # type: Dict[int, Set[Task]]
# functions to execute after a task is finished
_finalizers = {} # type: Dict[int, Finalizer]
if __debug__: if __debug__:
# for performance stats # for performance stats
@ -29,7 +50,9 @@ if __debug__:
log_delay_rb = array.array("i", [0] * log_delay_rb_len) log_delay_rb = array.array("i", [0] * log_delay_rb_len)
def schedule(task, value=None, deadline=None, finalizer=None): def schedule(
task: Task, value: Any = None, deadline: int = None, finalizer: Finalizer = None
) -> None:
""" """
Schedule task to be executed with `value` on given `deadline` (in Schedule task to be executed with `value` on given `deadline` (in
microseconds). Does not start the event loop itself, see `run`. microseconds). Does not start the event loop itself, see `run`.
@ -41,20 +64,20 @@ def schedule(task, value=None, deadline=None, finalizer=None):
_queue.push(deadline, task, value) _queue.push(deadline, task, value)
def pause(task, iface): def pause(task: Task, iface: int) -> None:
tasks = _paused.get(iface, None) tasks = _paused.get(iface, None)
if tasks is None: if tasks is None:
tasks = _paused[iface] = set() tasks = _paused[iface] = set()
tasks.add(task) tasks.add(task)
def finalize(task, value): def finalize(task: Task, value: Any) -> None:
fn = _finalizers.pop(id(task), None) fn = _finalizers.pop(id(task), None)
if fn is not None: if fn is not None:
fn(task, value) fn(task, value)
def close(task): def close(task: Task) -> None:
for iface in _paused: for iface in _paused:
_paused[iface].discard(task) _paused[iface].discard(task)
_queue.discard(task) _queue.discard(task)
@ -62,7 +85,7 @@ def close(task):
finalize(task, GeneratorExit()) finalize(task, GeneratorExit())
def run(): def run() -> None:
""" """
Loop forever, stepping through scheduled tasks and awaiting I/O events Loop forever, stepping through scheduled tasks and awaiting I/O events
inbetween. Use `schedule` first to add a coroutine to the task queue. inbetween. Use `schedule` first to add a coroutine to the task queue.
@ -98,13 +121,17 @@ def run():
# timeout occurred, run the first scheduled task # timeout occurred, run the first scheduled task
if _queue: if _queue:
_queue.pop(task_entry) _queue.pop(task_entry)
_step(task_entry[1], task_entry[2]) _step(task_entry[1], task_entry[2]) # type: ignore
# error: Argument 1 to "_step" has incompatible type "int"; expected "Coroutine[Any, Any, Any]"
# rationale: We use untyped lists here, because that is what the C API supports.
def _step(task, value): def _step(task: Task, value: Any) -> None:
try: try:
if isinstance(value, BaseException): if isinstance(value, BaseException):
result = task.throw(value) result = task.throw(value) # type: ignore
# error: Argument 1 to "throw" of "Coroutine" has incompatible type "Exception"; expected "Type[BaseException]"
# rationale: In micropython, generator.throw() accepts the exception object directly.
else: else:
result = task.send(value) result = task.send(value)
except StopIteration as e: # as e: except StopIteration as e: # as e:
@ -133,10 +160,16 @@ class Syscall:
scheduler, they do so through instances of a class derived from `Syscall`. scheduler, they do so through instances of a class derived from `Syscall`.
""" """
def __iter__(self): def __iter__(self) -> Task: # type: ignore
# support `yield from` or `await` on syscalls # support `yield from` or `await` on syscalls
return (yield self) return (yield self)
def __await__(self) -> Generator:
return self.__iter__() # type: ignore
def handle(self, task: Task) -> None:
pass
class sleep(Syscall): class sleep(Syscall):
""" """
@ -150,10 +183,10 @@ class sleep(Syscall):
>>> print('missed by %d us', utime.ticks_diff(utime.ticks_us(), planned)) >>> print('missed by %d us', utime.ticks_diff(utime.ticks_us(), planned))
""" """
def __init__(self, delay_us): def __init__(self, delay_us: int) -> None:
self.delay_us = delay_us self.delay_us = delay_us
def handle(self, task): def handle(self, task: Task) -> None:
deadline = utime.ticks_add(utime.ticks_us(), self.delay_us) deadline = utime.ticks_add(utime.ticks_us(), self.delay_us)
schedule(task, deadline, deadline) schedule(task, deadline, deadline)
@ -170,14 +203,14 @@ class wait(Syscall):
>>> event, x, y = await loop.wait(io.TOUCH) # await touch event >>> event, x, y = await loop.wait(io.TOUCH) # await touch event
""" """
def __init__(self, msg_iface): def __init__(self, msg_iface: int) -> None:
self.msg_iface = msg_iface self.msg_iface = msg_iface
def handle(self, task): def handle(self, task: Task) -> None:
pause(task, self.msg_iface) pause(task, self.msg_iface)
_NO_VALUE = () _NO_VALUE = object()
class signal(Syscall): class signal(Syscall):
@ -196,28 +229,28 @@ class signal(Syscall):
>>> # prints in the next iteration of the event loop >>> # prints in the next iteration of the event loop
""" """
def __init__(self): def __init__(self) -> None:
self.reset() self.reset()
def reset(self): def reset(self) -> None:
self.value = _NO_VALUE self.value = _NO_VALUE
self.task = None self.task = None # type: Optional[Task]
def handle(self, task): def handle(self, task: Task) -> None:
self.task = task self.task = task
self._deliver() self._deliver()
def send(self, value): def send(self, value: Any) -> None:
self.value = value self.value = value
self._deliver() self._deliver()
def _deliver(self): def _deliver(self) -> None:
if self.task is not None and self.value is not _NO_VALUE: if self.task is not None and self.value is not _NO_VALUE:
schedule(self.task, self.value) schedule(self.task, self.value)
self.task = None self.task = None
self.value = _NO_VALUE self.value = _NO_VALUE
def __iter__(self): def __iter__(self) -> Task: # type: ignore
try: try:
return (yield self) return (yield self)
except: # noqa: E722 except: # noqa: E722
@ -253,14 +286,13 @@ class spawn(Syscall):
`spawn.__iter__` for explanation. Always use `await`. `spawn.__iter__` for explanation. Always use `await`.
""" """
def __init__(self, *children, exit_others=True): def __init__(self, *children: Awaitable, exit_others: bool = True) -> None:
self.children = children self.children = children
self.exit_others = exit_others self.exit_others = exit_others
self.scheduled = [] # list of scheduled tasks self.finished = [] # type: List[Awaitable] # children that finished
self.finished = [] # list of children that finished self.scheduled = [] # type: List[Task] # scheduled wrapper tasks
self.callback = None
def handle(self, task): def handle(self, task: Task) -> None:
finalizer = self._finish finalizer = self._finish
scheduled = self.scheduled scheduled = self.scheduled
finished = self.finished finished = self.finished
@ -273,16 +305,17 @@ class spawn(Syscall):
if isinstance(child, _type_gen): if isinstance(child, _type_gen):
child_task = child child_task = child
else: else:
child_task = iter(child) child_task = iter(child) # type: ignore
schedule(child_task, None, None, finalizer) schedule(child_task, None, None, finalizer) # type: ignore
scheduled.append(child_task) scheduled.append(child_task) # type: ignore
# TODO: document the types here
def exit(self, except_for=None): def exit(self, except_for: Task = None) -> None:
for task in self.scheduled: for task in self.scheduled:
if task != except_for: if task != except_for:
close(task) close(task)
def _finish(self, task, result): def _finish(self, task: Task, result: Any) -> None:
if not self.finished: if not self.finished:
for index, child_task in enumerate(self.scheduled): for index, child_task in enumerate(self.scheduled):
if child_task is task: if child_task is task:
@ -293,7 +326,7 @@ class spawn(Syscall):
self.exit(task) self.exit(task)
schedule(self.callback, result) schedule(self.callback, result)
def __iter__(self): def __iter__(self) -> Task: # type: ignore
try: try:
return (yield self) return (yield self)
except: # noqa: E722 except: # noqa: E722

@ -1,17 +1,17 @@
try: try:
from .resources import resdata from .resources import resdata
except ImportError: except ImportError:
resdata = None resdata = {}
def load(name): def load(name: str) -> bytes:
""" """
Loads resource of a given name as bytes. Loads resource of a given name as bytes.
""" """
return resdata[name] return resdata[name]
def gettext(message): def gettext(message: str) -> str:
""" """
Returns localized string. This function is aliased to _. Returns localized string. This function is aliased to _.
""" """

@ -5,12 +5,18 @@ from trezorui import Display
from trezor import io, loop, res, utils, workflow from trezor import io, loop, res, utils, workflow
if False:
from typing import Any, Generator, Iterable, Tuple, TypeVar
Pos = Tuple[int, int]
Area = Tuple[int, int, int, int]
display = Display() display = Display()
# in debug mode, display an indicator in top right corner # in debug mode, display an indicator in top right corner
if __debug__: if __debug__:
def debug_display_refresh(): def debug_display_refresh() -> None:
display.bar(Display.WIDTH - 8, 0, 8, 8, 0xF800) display.bar(Display.WIDTH - 8, 0, 8, 8, 0xF800)
display.refresh() display.refresh()
if utils.SAVE_SCREEN: if utils.SAVE_SCREEN:
@ -59,25 +65,25 @@ from trezor.ui import style # isort:skip
from trezor.ui.style import * # isort:skip # noqa: F401,F403 from trezor.ui.style import * # isort:skip # noqa: F401,F403
def pulse(delay: int): def pulse(delay: int) -> float:
# normalize sin from interval -1:1 to 0:1 # normalize sin from interval -1:1 to 0:1
return 0.5 + 0.5 * math.sin(utime.ticks_us() / delay) return 0.5 + 0.5 * math.sin(utime.ticks_us() / delay)
async def click() -> tuple: async def click() -> Pos:
touch = loop.wait(io.TOUCH) touch = loop.wait(io.TOUCH)
while True: while True:
ev, *pos = yield touch ev, *pos = await touch
if ev == io.TOUCH_START: if ev == io.TOUCH_START:
break break
while True: while True:
ev, *pos = yield touch ev, *pos = await touch
if ev == io.TOUCH_END: if ev == io.TOUCH_END:
break break
return pos return pos # type: ignore
def backlight_fade(val: int, delay: int = 14000, step: int = 15): def backlight_fade(val: int, delay: int = 14000, step: int = 15) -> None:
if __debug__: if __debug__:
if utils.DISABLE_FADE: if utils.DISABLE_FADE:
display.backlight(val) display.backlight(val)
@ -96,7 +102,7 @@ def header(
fg: int = style.FG, fg: int = style.FG,
bg: int = style.BG, bg: int = style.BG,
ifg: int = style.GREEN, ifg: int = style.GREEN,
): ) -> None:
if icon is not None: if icon is not None:
display.icon(14, 15, res.load(icon), ifg, bg) display.icon(14, 15, res.load(icon), ifg, bg)
display.text(44, 35, title, BOLD, fg, bg) display.text(44, 35, title, BOLD, fg, bg)
@ -113,7 +119,7 @@ def grid(
cells_x: int = 1, cells_x: int = 1,
cells_y: int = 1, cells_y: int = 1,
spacing: int = 0, spacing: int = 0,
): ) -> Area:
w = (end_x - start_x) // n_x w = (end_x - start_x) // n_x
h = (end_y - start_y) // n_y h = (end_y - start_y) // n_y
x = (i % n_x) * w x = (i % n_x) * w
@ -121,7 +127,7 @@ def grid(
return (x + start_x, y + start_y, (w - spacing) * cells_x, (h - spacing) * cells_y) return (x + start_x, y + start_y, (w - spacing) * cells_x, (h - spacing) * cells_y)
def in_area(area: tuple, x: int, y: int) -> bool: def in_area(area: Area, x: int, y: int) -> bool:
ax, ay, aw, ah = area ax, ay, aw, ah = area
return ax <= x <= ax + aw and ay <= y <= ay + ah return ax <= x <= ax + aw and ay <= y <= ay + ah
@ -132,7 +138,7 @@ REPAINT = const(-256)
class Control: class Control:
def dispatch(self, event, x, y): def dispatch(self, event: int, x: int, y: int) -> None:
if event is RENDER: if event is RENDER:
self.on_render() self.on_render()
elif event is io.TOUCH_START: elif event is io.TOUCH_START:
@ -144,16 +150,16 @@ class Control:
elif event is REPAINT: elif event is REPAINT:
self.repaint = True self.repaint = True
def on_render(self): def on_render(self) -> None:
pass pass
def on_touch_start(self, x, y): def on_touch_start(self, x: int, y: int) -> None:
pass pass
def on_touch_move(self, x, y): def on_touch_move(self, x: int, y: int) -> None:
pass pass
def on_touch_end(self, x, y): def on_touch_end(self, x: int, y: int) -> None:
pass pass
@ -164,8 +170,12 @@ class LayoutCancelled(Exception):
pass pass
if False:
ResultValue = TypeVar("ResultValue")
class Result(Exception): class Result(Exception):
def __init__(self, value): def __init__(self, value: ResultValue) -> None:
self.value = value self.value = value
@ -173,7 +183,7 @@ class Layout(Control):
""" """
""" """
async def __iter__(self): async def __iter__(self) -> ResultValue:
value = None value = None
try: try:
if workflow.layout_signal.task is not None: if workflow.layout_signal.task is not None:
@ -188,17 +198,20 @@ class Layout(Control):
workflow.onlayoutclose(self) workflow.onlayoutclose(self)
return value return value
def create_tasks(self): def __await__(self) -> Generator[Any, Any, ResultValue]:
return self.__iter__() # type: ignore
def create_tasks(self) -> Iterable[loop.Task]:
return self.handle_input(), self.handle_rendering() return self.handle_input(), self.handle_rendering()
def handle_input(self): def handle_input(self) -> loop.Task: # type: ignore
touch = loop.wait(io.TOUCH) touch = loop.wait(io.TOUCH)
while True: while True:
event, x, y = yield touch event, x, y = yield touch
self.dispatch(event, x, y) self.dispatch(event, x, y)
self.dispatch(RENDER, 0, 0) self.dispatch(RENDER, 0, 0)
def handle_rendering(self): def handle_rendering(self) -> loop.Task: # type: ignore
backlight_fade(style.BACKLIGHT_DIM) backlight_fade(style.BACKLIGHT_DIM)
display.clear() display.clear()
self.dispatch(RENDER, 0, 0) self.dispatch(RENDER, 0, 0)

@ -3,6 +3,9 @@ from micropython import const
from trezor import ui from trezor import ui
from trezor.ui import display, in_area from trezor.ui import display, in_area
if False:
from typing import Type, Union
class ButtonDefault: class ButtonDefault:
class normal: class normal:
@ -12,14 +15,14 @@ class ButtonDefault:
border_color = ui.BG border_color = ui.BG
radius = ui.RADIUS radius = ui.RADIUS
class active: class active(normal):
bg_color = ui.FG bg_color = ui.FG
fg_color = ui.BLACKISH fg_color = ui.BLACKISH
text_style = ui.BOLD text_style = ui.BOLD
border_color = ui.FG border_color = ui.FG
radius = ui.RADIUS radius = ui.RADIUS
class disabled: class disabled(normal):
bg_color = ui.BG bg_color = ui.BG
fg_color = ui.GREY fg_color = ui.GREY
text_style = ui.NORMAL text_style = ui.NORMAL
@ -38,7 +41,7 @@ class ButtonMono(ButtonDefault):
text_style = ui.MONO text_style = ui.MONO
class ButtonMonoDark: class ButtonMonoDark(ButtonDefault):
class normal: class normal:
bg_color = ui.DARK_BLACK bg_color = ui.DARK_BLACK
fg_color = ui.DARK_WHITE fg_color = ui.DARK_WHITE
@ -46,14 +49,14 @@ class ButtonMonoDark:
border_color = ui.BG border_color = ui.BG
radius = ui.RADIUS radius = ui.RADIUS
class active: class active(normal):
bg_color = ui.FG bg_color = ui.FG
fg_color = ui.DARK_BLACK fg_color = ui.DARK_BLACK
text_style = ui.MONO text_style = ui.MONO
border_color = ui.FG border_color = ui.FG
radius = ui.RADIUS radius = ui.RADIUS
class disabled: class disabled(normal):
bg_color = ui.DARK_BLACK bg_color = ui.DARK_BLACK
fg_color = ui.GREY fg_color = ui.GREY
text_style = ui.MONO text_style = ui.MONO
@ -98,6 +101,12 @@ class ButtonMonoConfirm(ButtonDefault):
text_style = ui.MONO text_style = ui.MONO
if False:
ButtonContent = Union[str, bytes]
ButtonStyleType = Type[ButtonDefault]
ButtonStyleStateType = Type[ButtonDefault.normal]
# button states # button states
_INITIAL = const(0) _INITIAL = const(0)
_PRESSED = const(1) _PRESSED = const(1)
@ -110,39 +119,53 @@ _BORDER = const(4) # border size in pixels
class Button(ui.Control): class Button(ui.Control):
def __init__(self, area, content, style=ButtonDefault): def __init__(
self,
area: ui.Area,
content: ButtonContent,
style: ButtonStyleType = ButtonDefault,
) -> None:
if isinstance(content, str):
self.text = content
self.icon = b""
elif isinstance(content, bytes):
self.icon = content
self.text = ""
else:
raise TypeError
self.area = area self.area = area
self.content = content
self.normal_style = style.normal self.normal_style = style.normal
self.active_style = style.active self.active_style = style.active
self.disabled_style = style.disabled self.disabled_style = style.disabled
self.state = _INITIAL self.state = _INITIAL
self.repaint = True self.repaint = True
def enable(self): def enable(self) -> None:
if self.state is not _INITIAL: if self.state is not _INITIAL:
self.state = _INITIAL self.state = _INITIAL
self.repaint = True self.repaint = True
def disable(self): def disable(self) -> None:
if self.state is not _DISABLED: if self.state is not _DISABLED:
self.state = _DISABLED self.state = _DISABLED
self.repaint = True self.repaint = True
def on_render(self): def on_render(self) -> None:
if self.repaint: if self.repaint:
if self.state is _DISABLED: if self.state is _INITIAL or self.state is _RELEASED:
s = self.normal_style
elif self.state is _DISABLED:
s = self.disabled_style s = self.disabled_style
elif self.state is _PRESSED: elif self.state is _PRESSED:
s = self.active_style s = self.active_style
else:
s = self.normal_style
ax, ay, aw, ah = self.area ax, ay, aw, ah = self.area
self.render_background(s, ax, ay, aw, ah) self.render_background(s, ax, ay, aw, ah)
self.render_content(s, ax, ay, aw, ah) self.render_content(s, ax, ay, aw, ah)
self.repaint = False self.repaint = False
def render_background(self, s, ax, ay, aw, ah): def render_background(
self, s: ButtonStyleStateType, ax: int, ay: int, aw: int, ah: int
) -> None:
radius = s.radius radius = s.radius
bg_color = s.bg_color bg_color = s.bg_color
border_color = s.border_color border_color = s.border_color
@ -162,16 +185,21 @@ class Button(ui.Control):
radius, radius,
) )
def render_content(self, s, ax, ay, aw, ah): def render_content(
self, s: ButtonStyleStateType, ax: int, ay: int, aw: int, ah: int
) -> None:
tx = ax + aw // 2 tx = ax + aw // 2
ty = ay + ah // 2 + 8 ty = ay + ah // 2 + 8
t = self.content t = self.text
if isinstance(t, str): if t:
display.text_center(tx, ty, t, s.text_style, s.fg_color, s.bg_color) display.text_center(tx, ty, t, s.text_style, s.fg_color, s.bg_color)
elif isinstance(t, bytes): return
display.icon(tx - _ICON // 2, ty - _ICON, t, s.fg_color, s.bg_color) i = self.icon
if i:
display.icon(tx - _ICON // 2, ty - _ICON, i, s.fg_color, s.bg_color)
return
def on_touch_start(self, x, y): def on_touch_start(self, x: int, y: int) -> None:
if self.state is _DISABLED: if self.state is _DISABLED:
return return
if in_area(self.area, x, y): if in_area(self.area, x, y):
@ -179,7 +207,7 @@ class Button(ui.Control):
self.repaint = True self.repaint = True
self.on_press_start() self.on_press_start()
def on_touch_move(self, x, y): def on_touch_move(self, x: int, y: int) -> None:
if self.state is _DISABLED: if self.state is _DISABLED:
return return
if in_area(self.area, x, y): if in_area(self.area, x, y):
@ -193,7 +221,7 @@ class Button(ui.Control):
self.repaint = True self.repaint = True
self.on_press_end() self.on_press_end()
def on_touch_end(self, x, y): def on_touch_end(self, x: int, y: int) -> None:
state = self.state state = self.state
if state is not _INITIAL and state is not _DISABLED: if state is not _INITIAL and state is not _DISABLED:
self.state = _INITIAL self.state = _INITIAL
@ -203,11 +231,11 @@ class Button(ui.Control):
self.on_press_end() self.on_press_end()
self.on_click() self.on_click()
def on_press_start(self): def on_press_start(self) -> None:
pass pass
def on_press_end(self): def on_press_end(self) -> None:
pass pass
def on_click(self): def on_click(self) -> None:
pass pass

@ -3,32 +3,37 @@ from micropython import const
from trezor import res, ui from trezor import res, ui
from trezor.ui.text import TEXT_HEADER_HEIGHT, TEXT_LINE_HEIGHT from trezor.ui.text import TEXT_HEADER_HEIGHT, TEXT_LINE_HEIGHT
if False:
from typing import Iterable, List, Union
ChecklistItem = Union[str, Iterable[str]]
_CHECKLIST_MAX_LINES = const(5) _CHECKLIST_MAX_LINES = const(5)
_CHECKLIST_OFFSET_X = const(24) _CHECKLIST_OFFSET_X = const(24)
_CHECKLIST_OFFSET_X_ICON = const(0) _CHECKLIST_OFFSET_X_ICON = const(0)
class Checklist(ui.Control): class Checklist(ui.Control):
def __init__(self, title, icon): def __init__(self, title: str, icon: str) -> None:
self.title = title self.title = title
self.icon = icon self.icon = icon
self.items = [] self.items = [] # type: List[ChecklistItem]
self.active = 0 self.active = 0
self.repaint = True self.repaint = True
def add(self, choice): def add(self, item: ChecklistItem) -> None:
self.items.append(choice) self.items.append(item)
def select(self, active): def select(self, active: int) -> None:
self.active = active self.active = active
def on_render(self): def on_render(self) -> None:
if self.repaint: if self.repaint:
ui.header(self.title, self.icon) ui.header(self.title, self.icon)
self.render_items() self.render_items()
self.repaint = False self.repaint = False
def render_items(self): def render_items(self) -> None:
offset_x = _CHECKLIST_OFFSET_X offset_x = _CHECKLIST_OFFSET_X
offset_y = TEXT_HEADER_HEIGHT + TEXT_LINE_HEIGHT offset_y = TEXT_HEADER_HEIGHT + TEXT_LINE_HEIGHT
bg = ui.BG bg = ui.BG

@ -2,6 +2,10 @@ from trezor import res, ui
from trezor.ui.button import Button, ButtonCancel, ButtonConfirm from trezor.ui.button import Button, ButtonCancel, ButtonConfirm
from trezor.ui.loader import Loader, LoaderDefault from trezor.ui.loader import Loader, LoaderDefault
if False:
from trezor.ui.button import ButtonContent, ButtonStyleType
from trezor.ui.loader import LoaderStyleType
CONFIRMED = object() CONFIRMED = object()
CANCELLED = object() CANCELLED = object()
@ -14,13 +18,13 @@ class Confirm(ui.Layout):
def __init__( def __init__(
self, self,
content, content: ui.Control,
confirm=DEFAULT_CONFIRM, confirm: ButtonContent = DEFAULT_CONFIRM,
confirm_style=DEFAULT_CONFIRM_STYLE, confirm_style: ButtonStyleType = DEFAULT_CONFIRM_STYLE,
cancel=DEFAULT_CANCEL, cancel: ButtonContent = DEFAULT_CANCEL,
cancel_style=DEFAULT_CANCEL_STYLE, cancel_style: ButtonStyleType = DEFAULT_CANCEL_STYLE,
major_confirm=False, major_confirm: bool = False,
): ) -> None:
self.content = content self.content = content
if confirm is not None: if confirm is not None:
@ -31,7 +35,7 @@ class Confirm(ui.Layout):
else: else:
area = ui.grid(9, n_x=2) area = ui.grid(9, n_x=2)
self.confirm = Button(area, confirm, confirm_style) self.confirm = Button(area, confirm, confirm_style)
self.confirm.on_click = self.on_confirm self.confirm.on_click = self.on_confirm # type: ignore
else: else:
self.confirm = None self.confirm = None
@ -43,21 +47,21 @@ class Confirm(ui.Layout):
else: else:
area = ui.grid(8, n_x=2) area = ui.grid(8, n_x=2)
self.cancel = Button(area, cancel, cancel_style) self.cancel = Button(area, cancel, cancel_style)
self.cancel.on_click = self.on_cancel self.cancel.on_click = self.on_cancel # type: ignore
else: else:
self.cancel = None self.cancel = None
def dispatch(self, event, x, y): def dispatch(self, event: int, x: int, y: int) -> None:
self.content.dispatch(event, x, y) self.content.dispatch(event, x, y)
if self.confirm is not None: if self.confirm is not None:
self.confirm.dispatch(event, x, y) self.confirm.dispatch(event, x, y)
if self.cancel is not None: if self.cancel is not None:
self.cancel.dispatch(event, x, y) self.cancel.dispatch(event, x, y)
def on_confirm(self): def on_confirm(self) -> None:
raise ui.Result(CONFIRMED) raise ui.Result(CONFIRMED)
def on_cancel(self): def on_cancel(self) -> None:
raise ui.Result(CANCELLED) raise ui.Result(CANCELLED)
@ -68,44 +72,44 @@ class HoldToConfirm(ui.Layout):
def __init__( def __init__(
self, self,
content, content: ui.Control,
confirm=DEFAULT_CONFIRM, confirm: str = DEFAULT_CONFIRM,
confirm_style=DEFAULT_CONFIRM_STYLE, confirm_style: ButtonStyleType = DEFAULT_CONFIRM_STYLE,
loader_style=DEFAULT_LOADER_STYLE, loader_style: LoaderStyleType = DEFAULT_LOADER_STYLE,
): ):
self.content = content self.content = content
self.loader = Loader(loader_style) self.loader = Loader(loader_style)
self.loader.on_start = self._on_loader_start self.loader.on_start = self._on_loader_start # type: ignore
self.button = Button(ui.grid(4, n_x=1), confirm, confirm_style) self.button = Button(ui.grid(4, n_x=1), confirm, confirm_style)
self.button.on_press_start = self._on_press_start self.button.on_press_start = self._on_press_start # type: ignore
self.button.on_press_end = self._on_press_end self.button.on_press_end = self._on_press_end # type: ignore
self.button.on_click = self._on_click self.button.on_click = self._on_click # type: ignore
def _on_press_start(self): def _on_press_start(self) -> None:
self.loader.start() self.loader.start()
def _on_press_end(self): def _on_press_end(self) -> None:
self.loader.stop() self.loader.stop()
def _on_loader_start(self): def _on_loader_start(self) -> None:
# Loader has either started growing, or returned to the 0-position. # Loader has either started growing, or returned to the 0-position.
# In the first case we need to clear the content leftovers, in the latter # In the first case we need to clear the content leftovers, in the latter
# we need to render the content again. # we need to render the content again.
ui.display.bar(0, 0, ui.WIDTH, ui.HEIGHT - 60, ui.BG) ui.display.bar(0, 0, ui.WIDTH, ui.HEIGHT - 60, ui.BG)
self.content.dispatch(ui.REPAINT, 0, 0) self.content.dispatch(ui.REPAINT, 0, 0)
def _on_click(self): def _on_click(self) -> None:
if self.loader.elapsed_ms() >= self.loader.target_ms: if self.loader.elapsed_ms() >= self.loader.target_ms:
self.on_confirm() self.on_confirm()
def dispatch(self, event, x, y): def dispatch(self, event: int, x: int, y: int) -> None:
if self.loader.start_ms is not None: if self.loader.start_ms is not None:
self.loader.dispatch(event, x, y) self.loader.dispatch(event, x, y)
else: else:
self.content.dispatch(event, x, y) self.content.dispatch(event, x, y)
self.button.dispatch(event, x, y) self.button.dispatch(event, x, y)
def on_confirm(self): def on_confirm(self) -> None:
raise ui.Result(CONFIRMED) raise ui.Result(CONFIRMED)

@ -2,9 +2,9 @@ from trezor import ui
class Container(ui.Control): class Container(ui.Control):
def __init__(self, *children): def __init__(self, *children: ui.Control):
self.children = children self.children = children
def dispatch(self, event, x, y): def dispatch(self, event: int, x: int, y: int) -> None:
for child in self.children: for child in self.children:
child.dispatch(event, x, y) child.dispatch(event, x, y)

@ -3,6 +3,10 @@ from trezor.ui.button import Button, ButtonConfirm
from trezor.ui.confirm import CONFIRMED from trezor.ui.confirm import CONFIRMED
from trezor.ui.text import TEXT_LINE_HEIGHT, TEXT_MARGIN_LEFT, render_text from trezor.ui.text import TEXT_LINE_HEIGHT, TEXT_MARGIN_LEFT, render_text
if False:
from typing import Type
from trezor.ui.button import ButtonContent
class DefaultInfoConfirm: class DefaultInfoConfirm:
@ -17,26 +21,35 @@ class DefaultInfoConfirm:
border_color = ui.BLACKISH border_color = ui.BLACKISH
if False:
InfoConfirmStyleType = Type[DefaultInfoConfirm]
class InfoConfirm(ui.Layout): class InfoConfirm(ui.Layout):
DEFAULT_CONFIRM = res.load(ui.ICON_CONFIRM) DEFAULT_CONFIRM = res.load(ui.ICON_CONFIRM)
DEFAULT_STYLE = DefaultInfoConfirm DEFAULT_STYLE = DefaultInfoConfirm
def __init__(self, text, confirm=DEFAULT_CONFIRM, style=DEFAULT_STYLE): def __init__(
self,
text: str,
confirm: ButtonContent = DEFAULT_CONFIRM,
style: InfoConfirmStyleType = DEFAULT_STYLE,
) -> None:
self.text = text.split() self.text = text.split()
self.style = style self.style = style
panel_area = ui.grid(0, n_x=1, n_y=1) panel_area = ui.grid(0, n_x=1, n_y=1)
self.panel_area = panel_area self.panel_area = panel_area
confirm_area = ui.grid(4, n_x=1) confirm_area = ui.grid(4, n_x=1)
self.confirm = Button(confirm_area, confirm, style.button) self.confirm = Button(confirm_area, confirm, style.button)
self.confirm.on_click = self.on_confirm self.confirm.on_click = self.on_confirm # type: ignore
self.repaint = True self.repaint = True
def dispatch(self, event, x, y): def dispatch(self, event: int, x: int, y: int) -> None:
if event == ui.RENDER: if event == ui.RENDER:
self.on_render() self.on_render()
self.confirm.dispatch(event, x, y) self.confirm.dispatch(event, x, y)
def on_render(self): def on_render(self) -> None:
if self.repaint: if self.repaint:
x, y, w, h = self.panel_area x, y, w, h = self.panel_area
fg_color = self.style.fg_color fg_color = self.style.fg_color
@ -46,7 +59,7 @@ class InfoConfirm(ui.Layout):
ui.display.bar_radius(x, y, w, h, bg_color, ui.BG, ui.RADIUS) ui.display.bar_radius(x, y, w, h, bg_color, ui.BG, ui.RADIUS)
# render the info text # render the info text
render_text( render_text( # type: ignore
self.text, self.text,
new_lines=False, new_lines=False,
max_lines=6, max_lines=6,
@ -59,5 +72,5 @@ class InfoConfirm(ui.Layout):
self.repaint = False self.repaint = False
def on_confirm(self): def on_confirm(self) -> None:
raise ui.Result(CONFIRMED) raise ui.Result(CONFIRMED)

@ -4,22 +4,25 @@ from micropython import const
from trezor import res, ui from trezor import res, ui
from trezor.ui import display from trezor.ui import display
if False:
from typing import Optional, Type
class LoaderDefault: class LoaderDefault:
class normal: class normal:
bg_color = ui.BG bg_color = ui.BG
fg_color = ui.GREEN fg_color = ui.GREEN
icon = None icon = None # type: Optional[str]
icon_fg_color = None icon_fg_color = None # type: Optional[int]
class active: class active(normal):
bg_color = ui.BG bg_color = ui.BG
fg_color = ui.GREEN fg_color = ui.GREEN
icon = ui.ICON_CHECK icon = ui.ICON_CHECK
icon_fg_color = ui.WHITE icon_fg_color = ui.WHITE
class LoaderDanger: class LoaderDanger(LoaderDefault):
class normal(LoaderDefault.normal): class normal(LoaderDefault.normal):
fg_color = ui.RED fg_color = ui.RED
@ -27,31 +30,35 @@ class LoaderDanger:
fg_color = ui.RED fg_color = ui.RED
if False:
LoaderStyleType = Type[LoaderDefault]
_TARGET_MS = const(1000) _TARGET_MS = const(1000)
class Loader(ui.Control): class Loader(ui.Control):
def __init__(self, style=LoaderDefault): def __init__(self, style: LoaderStyleType = LoaderDefault) -> None:
self.normal_style = style.normal self.normal_style = style.normal
self.active_style = style.active self.active_style = style.active
self.target_ms = _TARGET_MS self.target_ms = _TARGET_MS
self.start_ms = None self.start_ms = None
self.stop_ms = None self.stop_ms = None
def start(self): def start(self) -> None:
self.start_ms = utime.ticks_ms() self.start_ms = utime.ticks_ms()
self.stop_ms = None self.stop_ms = None
self.on_start() self.on_start()
def stop(self): def stop(self) -> None:
self.stop_ms = utime.ticks_ms() self.stop_ms = utime.ticks_ms()
def elapsed_ms(self): def elapsed_ms(self) -> int:
if self.start_ms is None: if self.start_ms is None:
return 0 return 0
return utime.ticks_ms() - self.start_ms return utime.ticks_ms() - self.start_ms
def on_render(self): def on_render(self) -> None:
target = self.target_ms target = self.target_ms
start = self.start_ms start = self.start_ms
stop = self.stop_ms stop = self.stop_ms
@ -60,10 +67,10 @@ class Loader(ui.Control):
r = min(now - start, target) r = min(now - start, target)
else: else:
r = max(stop - start + (stop - now) * 2, 0) r = max(stop - start + (stop - now) * 2, 0)
if r == target: if r != target:
s = self.active_style
else:
s = self.normal_style s = self.normal_style
else:
s = self.active_style
Y = const(-24) Y = const(-24)
@ -80,7 +87,7 @@ class Loader(ui.Control):
if r == target: if r == target:
self.on_finish() self.on_finish()
def on_start(self): def on_start(self) -> None:
pass pass
def on_finish(self): def on_finish(self):

@ -3,6 +3,10 @@ from trezor.crypto import bip39
from trezor.ui import display from trezor.ui import display
from trezor.ui.button import Button, ButtonClear, ButtonMono, ButtonMonoConfirm from trezor.ui.button import Button, ButtonClear, ButtonMono, ButtonMonoConfirm
if False:
from typing import Optional
from trezor.ui.button import ButtonContent, ButtonStyleStateType
def compute_mask(text: str) -> int: def compute_mask(text: str) -> int:
mask = 0 mask = 0
@ -15,44 +19,47 @@ def compute_mask(text: str) -> int:
class KeyButton(Button): class KeyButton(Button):
def __init__(self, area, content, keyboard): def __init__(
self, area: ui.Area, content: ButtonContent, keyboard: "Bip39Keyboard"
):
self.keyboard = keyboard self.keyboard = keyboard
super().__init__(area, content) super().__init__(area, content)
def on_click(self): def on_click(self) -> None:
self.keyboard.on_key_click(self) self.keyboard.on_key_click(self)
class InputButton(Button): class InputButton(Button):
def __init__(self, area, content, word): def __init__(self, area: ui.Area, text: str, word: str) -> None:
super().__init__(area, content) super().__init__(area, text)
self.word = word self.word = word
self.pending = False # should we draw the pending marker? self.pending = False
self.icon = None # rendered icon
self.disable() self.disable()
def edit(self, content, word, pending): def edit(self, text: str, word: str, pending: bool) -> None:
self.word = word self.word = word
self.content = content self.text = text
self.pending = pending self.pending = pending
self.repaint = True self.repaint = True
if word: if word:
if content == word: # confirm button if text == word: # confirm button
self.enable() self.enable()
self.normal_style = ButtonMonoConfirm.normal self.normal_style = ButtonMonoConfirm.normal
self.active_style = ButtonMonoConfirm.active self.active_style = ButtonMonoConfirm.active
self.icon = ui.ICON_CONFIRM self.icon = res.load(ui.ICON_CONFIRM)
else: # auto-complete button else: # auto-complete button
self.enable() self.enable()
self.normal_style = ButtonMono.normal self.normal_style = ButtonMono.normal
self.active_style = ButtonMono.active self.active_style = ButtonMono.active
self.icon = ui.ICON_CLICK self.icon = res.load(ui.ICON_CLICK)
else: # disabled button else: # disabled button
self.disabled_style = ButtonMono.disabled self.disabled_style = ButtonMono.disabled
self.disable() self.disable()
self.icon = None self.icon = b""
def render_content(self, s, ax, ay, aw, ah): def render_content(
self, s: ButtonStyleStateType, ax: int, ay: int, aw: int, ah: int
) -> None:
text_style = s.text_style text_style = s.text_style
fg_color = s.fg_color fg_color = s.fg_color
bg_color = s.bg_color bg_color = s.bg_color
@ -61,29 +68,29 @@ class InputButton(Button):
ty = ay + ah // 2 + 8 # y-offset of the content ty = ay + ah // 2 + 8 # y-offset of the content
# entered content # entered content
display.text(tx, ty, self.content, text_style, fg_color, bg_color) display.text(tx, ty, self.text, text_style, fg_color, bg_color)
# word suggestion # word suggestion
suggested_word = self.word[len(self.content) :] suggested_word = self.word[len(self.text) :]
width = display.text_width(self.content, text_style) width = display.text_width(self.text, text_style)
display.text(tx + width, ty, suggested_word, text_style, ui.GREY, bg_color) display.text(tx + width, ty, suggested_word, text_style, ui.GREY, bg_color)
if self.pending: if self.pending:
pw = display.text_width(self.content[-1:], text_style) pw = display.text_width(self.text[-1:], text_style)
px = tx + width - pw px = tx + width - pw
display.bar(px, ty + 2, pw + 1, 3, fg_color) display.bar(px, ty + 2, pw + 1, 3, fg_color)
if self.icon: if self.icon:
ix = ax + aw - 16 * 2 ix = ax + aw - 16 * 2
iy = ty - 16 iy = ty - 16
display.icon(ix, iy, res.load(self.icon), fg_color, bg_color) display.icon(ix, iy, self.icon, fg_color, bg_color)
class Prompt(ui.Control): class Prompt(ui.Control):
def __init__(self, prompt): def __init__(self, prompt: str) -> None:
self.prompt = prompt self.prompt = prompt
self.repaint = True self.repaint = True
def on_render(self): def on_render(self) -> None:
if self.repaint: if self.repaint:
display.bar(0, 8, ui.WIDTH, 60, ui.BG) display.bar(0, 8, ui.WIDTH, 60, ui.BG)
display.text(20, 40, self.prompt, ui.BOLD, ui.GREY, ui.BG) display.text(20, 40, self.prompt, ui.BOLD, ui.GREY, ui.BG)
@ -91,15 +98,15 @@ class Prompt(ui.Control):
class Bip39Keyboard(ui.Layout): class Bip39Keyboard(ui.Layout):
def __init__(self, prompt): def __init__(self, prompt: str) -> None:
self.prompt = Prompt(prompt) self.prompt = Prompt(prompt)
icon_back = res.load(ui.ICON_BACK) icon_back = res.load(ui.ICON_BACK)
self.back = Button(ui.grid(0, n_x=3, n_y=4), icon_back, ButtonClear) self.back = Button(ui.grid(0, n_x=3, n_y=4), icon_back, ButtonClear)
self.back.on_click = self.on_back_click self.back.on_click = self.on_back_click # type: ignore
self.input = InputButton(ui.grid(1, n_x=3, n_y=4, cells_x=2), "", "") self.input = InputButton(ui.grid(1, n_x=3, n_y=4, cells_x=2), "", "")
self.input.on_click = self.on_input_click self.input.on_click = self.on_input_click # type: ignore
self.keys = [ self.keys = [
KeyButton(ui.grid(i + 3, n_y=4), k, self) KeyButton(ui.grid(i + 3, n_y=4), k, self)
@ -107,82 +114,82 @@ class Bip39Keyboard(ui.Layout):
("abc", "def", "ghi", "jkl", "mno", "pqr", "stu", "vwx", "yz") ("abc", "def", "ghi", "jkl", "mno", "pqr", "stu", "vwx", "yz")
) )
] ]
self.pending_button = None self.pending_button = None # type: Optional[Button]
self.pending_index = 0 self.pending_index = 0
def dispatch(self, event: int, x: int, y: int): def dispatch(self, event: int, x: int, y: int) -> None:
for btn in self.keys: for btn in self.keys:
btn.dispatch(event, x, y) btn.dispatch(event, x, y)
if self.input.content: if self.input.text:
self.input.dispatch(event, x, y) self.input.dispatch(event, x, y)
self.back.dispatch(event, x, y) self.back.dispatch(event, x, y)
else: else:
self.prompt.dispatch(event, x, y) self.prompt.dispatch(event, x, y)
def on_back_click(self): def on_back_click(self) -> None:
# Backspace was clicked, let's delete the last character of input. # Backspace was clicked, let's delete the last character of input.
self.edit(self.input.content[:-1]) self.edit(self.input.text[:-1])
def on_input_click(self): def on_input_click(self) -> None:
# Input button was clicked. If the content matches the suggested word, # Input button was clicked. If the content matches the suggested word,
# let's confirm it, otherwise just auto-complete. # let's confirm it, otherwise just auto-complete.
content = self.input.content text = self.input.text
word = self.input.word word = self.input.word
if word and word == content: if word and word == text:
self.edit("") self.edit("")
self.on_confirm(word) self.on_confirm(word)
else: else:
self.edit(word) self.edit(word)
def on_key_click(self, btn: Button): def on_key_click(self, btn: Button) -> None:
# Key button was clicked. If this button is pending, let's cycle the # Key button was clicked. If this button is pending, let's cycle the
# pending character in input. If not, let's just append the first # pending character in input. If not, let's just append the first
# character. # character.
if self.pending_button is btn: if self.pending_button is btn:
index = (self.pending_index + 1) % len(btn.content) index = (self.pending_index + 1) % len(btn.text)
content = self.input.content[:-1] + btn.content[index] text = self.input.text[:-1] + btn.text[index]
else: else:
index = 0 index = 0
content = self.input.content + btn.content[0] text = self.input.text + btn.text[0]
self.edit(content, btn, index) self.edit(text, btn, index)
def on_timeout(self): def on_timeout(self) -> None:
# Timeout occurred. If we can auto-complete current input, let's just # Timeout occurred. If we can auto-complete current input, let's just
# reset the pending marker. If not, input is invalid, let's backspace # reset the pending marker. If not, input is invalid, let's backspace
# the last character. # the last character.
if self.input.word: if self.input.word:
self.edit(self.input.content) self.edit(self.input.text)
else: else:
self.edit(self.input.content[:-1]) self.edit(self.input.text[:-1])
def on_confirm(self, word): def on_confirm(self, word: str) -> None:
# Word was confirmed by the user. # Word was confirmed by the user.
raise ui.Result(word) raise ui.Result(word)
def edit(self, content: str, button: KeyButton = None, index: int = 0): def edit(self, text: str, button: Button = None, index: int = 0) -> None:
self.pending_button = button self.pending_button = button
self.pending_index = index self.pending_index = index
# find the completions # find the completions
pending = button is not None pending = button is not None
word = bip39.find_word(content) or "" word = bip39.find_word(text) or ""
mask = bip39.complete_word(content) mask = bip39.complete_word(text)
# modify the input state # modify the input state
self.input.edit(content, word, pending) self.input.edit(text, word, pending)
# enable or disable key buttons # enable or disable key buttons
for btn in self.keys: for btn in self.keys:
if btn is button or compute_mask(btn.content) & mask: if btn is button or compute_mask(btn.text) & mask:
btn.enable() btn.enable()
else: else:
btn.disable() btn.disable()
# invalidate the prompt if we display it next frame # invalidate the prompt if we display it next frame
if not self.input.content: if not self.input.text:
self.prompt.repaint = True self.prompt.repaint = True
async def handle_input(self): async def handle_input(self) -> None:
touch = loop.wait(io.TOUCH) touch = loop.wait(io.TOUCH)
timeout = loop.sleep(1000 * 1000 * 1) timeout = loop.sleep(1000 * 1000 * 1)
spawn_touch = loop.spawn(touch) spawn_touch = loop.spawn(touch)

@ -3,44 +3,61 @@ from trezor.crypto import slip39
from trezor.ui import display from trezor.ui import display
from trezor.ui.button import Button, ButtonClear, ButtonMono, ButtonMonoConfirm from trezor.ui.button import Button, ButtonClear, ButtonMono, ButtonMonoConfirm
if False:
from typing import Optional
from trezor.ui.button import ButtonContent, ButtonStyleStateType
class KeyButton(Button): class KeyButton(Button):
def __init__(self, area, content, keyboard, index): def __init__(
self,
area: ui.Area,
content: ButtonContent,
keyboard: "Slip39Keyboard",
index: int,
):
self.keyboard = keyboard self.keyboard = keyboard
self.index = index self.index = index
super().__init__(area, content) super().__init__(area, content)
def on_click(self): def on_click(self) -> None:
self.keyboard.on_key_click(self) self.keyboard.on_key_click(self)
class InputButton(Button): class InputButton(Button):
def __init__(self, area, keyboard): def __init__(self, area: ui.Area, keyboard: "Slip39Keyboard") -> None:
super().__init__(area, "") super().__init__(area, "")
self.word = "" self.word = ""
self.pending_button = None self.pending_button = None # type: Optional[Button]
self.pending_index = None self.pending_index = None # type: Optional[int]
self.icon = None # rendered icon
self.keyboard = keyboard self.keyboard = keyboard
self.disable() self.disable()
def edit(self, content, word, pending_button, pending_index): def edit(
self,
text: str,
word: str,
pending_button: Optional[Button],
pending_index: Optional[int],
) -> None:
self.word = word self.word = word
self.content = content self.text = text
self.pending_button = pending_button self.pending_button = pending_button
self.pending_index = pending_index self.pending_index = pending_index
self.repaint = True self.repaint = True
if word: if word: # confirm button
self.enable() self.enable()
self.normal_style = ButtonMonoConfirm.normal self.normal_style = ButtonMonoConfirm.normal
self.active_style = ButtonMonoConfirm.active self.active_style = ButtonMonoConfirm.active
self.icon = ui.ICON_CONFIRM self.icon = res.load(ui.ICON_CONFIRM)
else: # disabled button else: # disabled button
self.disabled_style = ButtonMono.normal self.disabled_style = ButtonMono.disabled
self.disable() self.disable()
self.icon = None self.icon = b""
def render_content(self, s, ax, ay, aw, ah): def render_content(
self, s: ButtonStyleStateType, ax: int, ay: int, aw: int, ah: int
) -> None:
text_style = s.text_style text_style = s.text_style
fg_color = s.fg_color fg_color = s.fg_color
bg_color = s.bg_color bg_color = s.bg_color
@ -49,11 +66,11 @@ class InputButton(Button):
ty = ay + ah // 2 + 8 # y-offset of the content ty = ay + ah // 2 + 8 # y-offset of the content
if not self.keyboard.is_input_final(): if not self.keyboard.is_input_final():
to_display = len(self.content) * "*" pending_button = self.pending_button
if self.pending_button: pending_index = self.pending_index
to_display = ( to_display = len(self.text) * "*"
to_display[:-1] + self.pending_button.content[self.pending_index] if pending_button and pending_index is not None:
) to_display = to_display[:-1] + pending_button.text[pending_index]
else: else:
to_display = self.word to_display = self.word
@ -61,22 +78,22 @@ class InputButton(Button):
if self.pending_button and not self.keyboard.is_input_final(): if self.pending_button and not self.keyboard.is_input_final():
width = display.text_width(to_display, text_style) width = display.text_width(to_display, text_style)
pw = display.text_width(self.content[-1:], text_style) pw = display.text_width(self.text[-1:], text_style)
px = tx + width - pw px = tx + width - pw
display.bar(px, ty + 2, pw + 1, 3, fg_color) display.bar(px, ty + 2, pw + 1, 3, fg_color)
if self.icon: if self.icon:
ix = ax + aw - 16 * 2 ix = ax + aw - 16 * 2
iy = ty - 16 iy = ty - 16
display.icon(ix, iy, res.load(self.icon), fg_color, bg_color) display.icon(ix, iy, self.icon, fg_color, bg_color)
class Prompt(ui.Control): class Prompt(ui.Control):
def __init__(self, prompt): def __init__(self, prompt: str) -> None:
self.prompt = prompt self.prompt = prompt
self.repaint = True self.repaint = True
def on_render(self): def on_render(self) -> None:
if self.repaint: if self.repaint:
display.bar(0, 8, ui.WIDTH, 60, ui.BG) display.bar(0, 8, ui.WIDTH, 60, ui.BG)
display.text(20, 40, self.prompt, ui.BOLD, ui.GREY, ui.BG) display.text(20, 40, self.prompt, ui.BOLD, ui.GREY, ui.BG)
@ -84,15 +101,15 @@ class Prompt(ui.Control):
class Slip39Keyboard(ui.Layout): class Slip39Keyboard(ui.Layout):
def __init__(self, prompt): def __init__(self, prompt: str) -> None:
self.prompt = Prompt(prompt) self.prompt = Prompt(prompt)
icon_back = res.load(ui.ICON_BACK) icon_back = res.load(ui.ICON_BACK)
self.back = Button(ui.grid(0, n_x=3, n_y=4), icon_back, ButtonClear) self.back = Button(ui.grid(0, n_x=3, n_y=4), icon_back, ButtonClear)
self.back.on_click = self.on_back_click self.back.on_click = self.on_back_click # type: ignore
self.input = InputButton(ui.grid(1, n_x=3, n_y=4, cells_x=2), self) self.input = InputButton(ui.grid(1, n_x=3, n_y=4, cells_x=2), self)
self.input.on_click = self.on_input_click self.input.on_click = self.on_input_click # type: ignore
self.keys = [ self.keys = [
KeyButton(ui.grid(i + 3, n_y=4), k, self, i + 1) KeyButton(ui.grid(i + 3, n_y=4), k, self, i + 1)
@ -100,26 +117,26 @@ class Slip39Keyboard(ui.Layout):
("ab", "cd", "ef", "ghij", "klm", "nopq", "rs", "tuv", "wxyz") ("ab", "cd", "ef", "ghij", "klm", "nopq", "rs", "tuv", "wxyz")
) )
] ]
self.pending_button = None self.pending_button = None # type: Optional[Button]
self.pending_index = 0 self.pending_index = 0
self.button_sequence = "" self.button_sequence = ""
self.mask = slip39.KEYBOARD_FULL_MASK self.mask = slip39.KEYBOARD_FULL_MASK
def dispatch(self, event: int, x: int, y: int): def dispatch(self, event: int, x: int, y: int) -> None:
for btn in self.keys: for btn in self.keys:
btn.dispatch(event, x, y) btn.dispatch(event, x, y)
if self.input.content: if self.input.text:
self.input.dispatch(event, x, y) self.input.dispatch(event, x, y)
self.back.dispatch(event, x, y) self.back.dispatch(event, x, y)
else: else:
self.prompt.dispatch(event, x, y) self.prompt.dispatch(event, x, y)
def on_back_click(self): def on_back_click(self) -> None:
# Backspace was clicked, let's delete the last character of input. # Backspace was clicked, let's delete the last character of input.
self.button_sequence = self.button_sequence[:-1] self.button_sequence = self.button_sequence[:-1]
self.edit() self.edit()
def on_input_click(self): def on_input_click(self) -> None:
# Input button was clicked. If the content matches the suggested word, # Input button was clicked. If the content matches the suggested word,
# let's confirm it, otherwise just auto-complete. # let's confirm it, otherwise just auto-complete.
result = self.input.word result = self.input.word
@ -128,26 +145,26 @@ class Slip39Keyboard(ui.Layout):
self.edit() self.edit()
self.on_confirm(result) self.on_confirm(result)
def on_key_click(self, btn: KeyButton): def on_key_click(self, btn: KeyButton) -> None:
# Key button was clicked. If this button is pending, let's cycle the # Key button was clicked. If this button is pending, let's cycle the
# pending character in input. If not, let's just append the first # pending character in input. If not, let's just append the first
# character. # character.
if self.pending_button is btn: if self.pending_button is btn:
index = (self.pending_index + 1) % len(btn.content) index = (self.pending_index + 1) % len(btn.text)
else: else:
index = 0 index = 0
self.button_sequence += str(btn.index) self.button_sequence += str(btn.index)
self.edit(btn, index) self.edit(btn, index)
def on_timeout(self): def on_timeout(self) -> None:
# Timeout occurred. Let's redraw to draw asterisks. # Timeout occurred. Let's redraw to draw asterisks.
self.edit() self.edit()
def on_confirm(self, word): def on_confirm(self, word: str) -> None:
# Word was confirmed by the user. # Word was confirmed by the user.
raise ui.Result(word) raise ui.Result(word)
def edit(self, button: KeyButton = None, index: int = 0): def edit(self, button: Button = None, index: int = 0) -> None:
self.pending_button = button self.pending_button = button
self.pending_index = index self.pending_index = index
@ -172,7 +189,7 @@ class Slip39Keyboard(ui.Layout):
btn.disable() btn.disable()
# invalidate the prompt if we display it next frame # invalidate the prompt if we display it next frame
if not self.input.content: if not self.input.text:
self.prompt.repaint = True self.prompt.repaint = True
def is_input_final(self) -> bool: def is_input_final(self) -> bool:
@ -182,7 +199,7 @@ class Slip39Keyboard(ui.Layout):
def check_mask(self, index: int) -> bool: def check_mask(self, index: int) -> bool:
return bool((1 << (index - 1)) & self.mask) return bool((1 << (index - 1)) & self.mask)
async def handle_input(self): async def handle_input(self) -> None:
touch = loop.wait(io.TOUCH) touch = loop.wait(io.TOUCH)
timeout = loop.sleep(1000 * 1000 * 1) timeout = loop.sleep(1000 * 1000 * 1)
spawn_touch = loop.spawn(touch) spawn_touch = loop.spawn(touch)

@ -6,6 +6,10 @@ from trezor.ui import display
from trezor.ui.button import Button, ButtonClear, ButtonConfirm from trezor.ui.button import Button, ButtonClear, ButtonConfirm
from trezor.ui.swipe import SWIPE_HORIZONTAL, SWIPE_LEFT, Swipe from trezor.ui.swipe import SWIPE_HORIZONTAL, SWIPE_LEFT, Swipe
if False:
from typing import List, Iterable, Optional
from trezor.ui.button import ButtonContent, ButtonStyleStateType
SPACE = res.load(ui.ICON_SPACE) SPACE = res.load(ui.ICON_SPACE)
KEYBOARD_KEYS = ( KEYBOARD_KEYS = (
@ -16,13 +20,13 @@ KEYBOARD_KEYS = (
) )
def digit_area(i): def digit_area(i: int) -> ui.Area:
if i == 9: # 0-position if i == 9: # 0-position
i = 10 # display it in the middle i = 10 # display it in the middle
return ui.grid(i + 3) # skip the first line return ui.grid(i + 3) # skip the first line
def render_scrollbar(page): def render_scrollbar(page: int) -> None:
BBOX = const(240) BBOX = const(240)
SIZE = const(8) SIZE = const(8)
pages = len(KEYBOARD_KEYS) pages = len(KEYBOARD_KEYS)
@ -43,42 +47,50 @@ def render_scrollbar(page):
class KeyButton(Button): class KeyButton(Button):
def __init__(self, area, content, keyboard): def __init__(
self, area: ui.Area, content: ButtonContent, keyboard: "PassphraseKeyboard"
) -> None:
self.keyboard = keyboard self.keyboard = keyboard
super().__init__(area, content) super().__init__(area, content)
def on_click(self): def on_click(self) -> None:
self.keyboard.on_key_click(self) self.keyboard.on_key_click(self)
def get_text_content(self): def get_text_content(self) -> str:
if self.content is SPACE: if self.text:
return self.text
elif self.icon is SPACE:
return " " return " "
else: else:
return self.content raise TypeError
def key_buttons(keys, keyboard): def key_buttons(
keys: Iterable[ButtonContent], keyboard: "PassphraseKeyboard"
) -> List[KeyButton]:
return [KeyButton(digit_area(i), k, keyboard) for i, k in enumerate(keys)] return [KeyButton(digit_area(i), k, keyboard) for i, k in enumerate(keys)]
class Input(Button): class Input(Button):
def __init__(self, area, content): def __init__(self, area: ui.Area, text: str) -> None:
super().__init__(area, content) super().__init__(area, text)
self.pending = False self.pending = False
self.disable() self.disable()
def edit(self, content, pending): def edit(self, text: str, pending: bool) -> None:
self.content = content self.text = text
self.pending = pending self.pending = pending
self.repaint = True self.repaint = True
def render_content(self, s, ax, ay, aw, ah): def render_content(
self, s: ButtonStyleStateType, ax: int, ay: int, aw: int, ah: int
) -> None:
text_style = s.text_style text_style = s.text_style
fg_color = s.fg_color fg_color = s.fg_color
bg_color = s.bg_color bg_color = s.bg_color
p = self.pending # should we draw the pending marker? p = self.pending # should we draw the pending marker?
t = self.content # input content t = self.text # input content
tx = ax + 24 # x-offset of the content tx = ax + 24 # x-offset of the content
ty = ay + ah // 2 + 8 # y-offset of the content ty = ay + ah // 2 + 8 # y-offset of the content
@ -98,16 +110,16 @@ class Input(Button):
cx = tx + width + 1 cx = tx + width + 1
display.bar(cx, ty - 18, 2, 22, fg_color) display.bar(cx, ty - 18, 2, 22, fg_color)
def on_click(self): def on_click(self) -> None:
pass pass
class Prompt(ui.Control): class Prompt(ui.Control):
def __init__(self, text): def __init__(self, text: str) -> None:
self.text = text self.text = text
self.repaint = True self.repaint = True
def on_render(self): def on_render(self) -> None:
if self.repaint: if self.repaint:
display.bar(0, 0, ui.WIDTH, 48, ui.BG) display.bar(0, 0, ui.WIDTH, 48, ui.BG)
display.text_center(ui.WIDTH // 2, 32, self.text, ui.BOLD, ui.GREY, ui.BG) display.text_center(ui.WIDTH // 2, 32, self.text, ui.BOLD, ui.GREY, ui.BG)
@ -118,7 +130,7 @@ CANCELLED = object()
class PassphraseKeyboard(ui.Layout): class PassphraseKeyboard(ui.Layout):
def __init__(self, prompt, max_length, page=1): def __init__(self, prompt: str, max_length: int, page: int = 1) -> None:
self.prompt = Prompt(prompt) self.prompt = Prompt(prompt)
self.max_length = max_length self.max_length = max_length
self.page = page self.page = page
@ -126,18 +138,18 @@ class PassphraseKeyboard(ui.Layout):
self.input = Input(ui.grid(0, n_x=1, n_y=6), "") self.input = Input(ui.grid(0, n_x=1, n_y=6), "")
self.back = Button(ui.grid(12), res.load(ui.ICON_BACK), ButtonClear) self.back = Button(ui.grid(12), res.load(ui.ICON_BACK), ButtonClear)
self.back.on_click = self.on_back_click self.back.on_click = self.on_back_click # type: ignore
self.back.disable() self.back.disable()
self.done = Button(ui.grid(14), res.load(ui.ICON_CONFIRM), ButtonConfirm) self.done = Button(ui.grid(14), res.load(ui.ICON_CONFIRM), ButtonConfirm)
self.done.on_click = self.on_confirm self.done.on_click = self.on_confirm # type: ignore
self.keys = key_buttons(KEYBOARD_KEYS[self.page], self) self.keys = key_buttons(KEYBOARD_KEYS[self.page], self)
self.pending_button = None self.pending_button = None # type: Optional[KeyButton]
self.pending_index = 0 self.pending_index = 0
def dispatch(self, event, x, y): def dispatch(self, event: int, x: int, y: int) -> None:
if self.input.content: if self.input.text:
self.input.dispatch(event, x, y) self.input.dispatch(event, x, y)
else: else:
self.prompt.dispatch(event, x, y) self.prompt.dispatch(event, x, y)
@ -149,37 +161,37 @@ class PassphraseKeyboard(ui.Layout):
if event == ui.RENDER: if event == ui.RENDER:
render_scrollbar(self.page) render_scrollbar(self.page)
def on_back_click(self): def on_back_click(self) -> None:
# Backspace was clicked. If we have any content in the input, let's delete # Backspace was clicked. If we have any content in the input, let's delete
# the last character. Otherwise cancel. # the last character. Otherwise cancel.
content = self.input.content text = self.input.text
if content: if text:
self.edit(content[:-1]) self.edit(text[:-1])
else: else:
self.on_cancel() self.on_cancel()
def on_key_click(self, button: KeyButton): def on_key_click(self, button: KeyButton) -> None:
# Key button was clicked. If this button is pending, let's cycle the # Key button was clicked. If this button is pending, let's cycle the
# pending character in input. If not, let's just append the first # pending character in input. If not, let's just append the first
# character. # character.
button_text = button.get_text_content() button_text = button.get_text_content()
if self.pending_button is button: if self.pending_button is button:
index = (self.pending_index + 1) % len(button_text) index = (self.pending_index + 1) % len(button_text)
prefix = self.input.content[:-1] prefix = self.input.text[:-1]
else: else:
index = 0 index = 0
prefix = self.input.content prefix = self.input.text
if len(button_text) > 1: if len(button_text) > 1:
self.edit(prefix + button_text[index], button, index) self.edit(prefix + button_text[index], button, index)
else: else:
self.edit(prefix + button_text[index]) self.edit(prefix + button_text[index])
def on_timeout(self): def on_timeout(self) -> None:
# Timeout occurred, let's just reset the pending marker. # Timeout occurred, let's just reset the pending marker.
self.edit(self.input.content) self.edit(self.input.text)
def edit(self, content: str, button: Button = None, index: int = 0): def edit(self, text: str, button: KeyButton = None, index: int = 0) -> None:
if len(content) > self.max_length: if len(text) > self.max_length:
return return
self.pending_button = button self.pending_button = button
@ -187,15 +199,15 @@ class PassphraseKeyboard(ui.Layout):
# modify the input state # modify the input state
pending = button is not None pending = button is not None
self.input.edit(content, pending) self.input.edit(text, pending)
if content: if text:
self.back.enable() self.back.enable()
else: else:
self.back.disable() self.back.disable()
self.prompt.repaint = True self.prompt.repaint = True
async def handle_input(self): async def handle_input(self) -> None:
touch = loop.wait(io.TOUCH) touch = loop.wait(io.TOUCH)
timeout = loop.sleep(1000 * 1000 * 1) timeout = loop.sleep(1000 * 1000 * 1)
spawn_touch = loop.spawn(touch) spawn_touch = loop.spawn(touch)
@ -214,7 +226,7 @@ class PassphraseKeyboard(ui.Layout):
else: else:
self.on_timeout() self.on_timeout()
async def handle_paging(self): async def handle_paging(self) -> None:
swipe = await Swipe(SWIPE_HORIZONTAL) swipe = await Swipe(SWIPE_HORIZONTAL)
if swipe == SWIPE_LEFT: if swipe == SWIPE_LEFT:
self.page = (self.page + 1) % len(KEYBOARD_KEYS) self.page = (self.page + 1) % len(KEYBOARD_KEYS)
@ -226,33 +238,33 @@ class PassphraseKeyboard(ui.Layout):
self.input.repaint = True self.input.repaint = True
self.prompt.repaint = True self.prompt.repaint = True
def on_cancel(self): def on_cancel(self) -> None:
raise ui.Result(CANCELLED) raise ui.Result(CANCELLED)
def on_confirm(self): def on_confirm(self) -> None:
raise ui.Result(self.input.content) raise ui.Result(self.input.text)
def create_tasks(self): def create_tasks(self) -> Iterable[loop.Task]:
return self.handle_input(), self.handle_rendering(), self.handle_paging() return self.handle_input(), self.handle_rendering(), self.handle_paging()
class PassphraseSource(ui.Layout): class PassphraseSource(ui.Layout):
def __init__(self, content): def __init__(self, content: ui.Control) -> None:
self.content = content self.content = content
self.device = Button(ui.grid(8, n_y=4, n_x=4, cells_x=4), "Device") self.device = Button(ui.grid(8, n_y=4, n_x=4, cells_x=4), "Device")
self.device.on_click = self.on_device self.device.on_click = self.on_device # type: ignore
self.host = Button(ui.grid(12, n_y=4, n_x=4, cells_x=4), "Host") self.host = Button(ui.grid(12, n_y=4, n_x=4, cells_x=4), "Host")
self.host.on_click = self.on_host self.host.on_click = self.on_host # type: ignore
def dispatch(self, event, x, y): def dispatch(self, event: int, x: int, y: int) -> None:
self.content.dispatch(event, x, y) self.content.dispatch(event, x, y)
self.device.dispatch(event, x, y) self.device.dispatch(event, x, y)
self.host.dispatch(event, x, y) self.host.dispatch(event, x, y)
def on_device(self): def on_device(self) -> None:
raise ui.Result(PassphraseSourceType.DEVICE) raise ui.Result(PassphraseSourceType.DEVICE)
def on_host(self): def on_host(self) -> None:
raise ui.Result(PassphraseSourceType.HOST) raise ui.Result(PassphraseSourceType.HOST)

@ -11,14 +11,17 @@ from trezor.ui.button import (
ButtonMono, ButtonMono,
) )
if False:
from typing import Iterable
def digit_area(i):
def digit_area(i: int) -> ui.Area:
if i == 9: # 0-position if i == 9: # 0-position
i = 10 # display it in the middle i = 10 # display it in the middle
return ui.grid(i + 3) # skip the first line return ui.grid(i + 3) # skip the first line
def generate_digits(): def generate_digits() -> Iterable[int]:
digits = list(range(0, 10)) # 0-9 digits = list(range(0, 10)) # 0-9
random.shuffle(digits) random.shuffle(digits)
# We lay out the buttons top-left to bottom-right, but the order # We lay out the buttons top-left to bottom-right, but the order
@ -27,13 +30,13 @@ def generate_digits():
class PinInput(ui.Control): class PinInput(ui.Control):
def __init__(self, prompt, subprompt, pin): def __init__(self, prompt: str, subprompt: str, pin: str) -> None:
self.prompt = prompt self.prompt = prompt
self.subprompt = subprompt self.subprompt = subprompt
self.pin = pin self.pin = pin
self.repaint = True self.repaint = True
def on_render(self): def on_render(self) -> None:
if self.repaint: if self.repaint:
if self.pin: if self.pin:
self.render_pin() self.render_pin()
@ -41,7 +44,7 @@ class PinInput(ui.Control):
self.render_prompt() self.render_prompt()
self.repaint = False self.repaint = False
def render_pin(self): def render_pin(self) -> None:
display.bar(0, 0, ui.WIDTH, 50, ui.BG) display.bar(0, 0, ui.WIDTH, 50, ui.BG)
count = len(self.pin) count = len(self.pin)
BOX_WIDTH = const(240) BOX_WIDTH = const(240)
@ -54,7 +57,7 @@ class PinInput(ui.Control):
render_x + i * PADDING, RENDER_Y, DOT_SIZE, DOT_SIZE, ui.GREY, ui.BG, 4 render_x + i * PADDING, RENDER_Y, DOT_SIZE, DOT_SIZE, ui.GREY, ui.BG, 4
) )
def render_prompt(self): def render_prompt(self) -> None:
display.bar(0, 0, ui.WIDTH, 50, ui.BG) display.bar(0, 0, ui.WIDTH, 50, ui.BG)
if self.subprompt: if self.subprompt:
display.text_center(ui.WIDTH // 2, 20, self.prompt, ui.BOLD, ui.GREY, ui.BG) display.text_center(ui.WIDTH // 2, 20, self.prompt, ui.BOLD, ui.GREY, ui.BG)
@ -66,35 +69,37 @@ class PinInput(ui.Control):
class PinButton(Button): class PinButton(Button):
def __init__(self, index, digit, matrix): def __init__(self, index: int, digit: int, dialog: "PinDialog"):
self.matrix = matrix self.dialog = dialog
super().__init__(digit_area(index), str(digit), ButtonMono) super().__init__(digit_area(index), str(digit), ButtonMono)
def on_click(self): def on_click(self) -> None:
self.matrix.assign(self.matrix.input.pin + self.content) self.dialog.assign(self.dialog.input.pin + self.text)
CANCELLED = object() CANCELLED = object()
class PinDialog(ui.Layout): class PinDialog(ui.Layout):
def __init__(self, prompt, subprompt, allow_cancel=True, maxlength=9): def __init__(
self, prompt: str, subprompt: str, allow_cancel: bool = True, maxlength: int = 9
) -> None:
self.maxlength = maxlength self.maxlength = maxlength
self.input = PinInput(prompt, subprompt, "") self.input = PinInput(prompt, subprompt, "")
icon_confirm = res.load(ui.ICON_CONFIRM) icon_confirm = res.load(ui.ICON_CONFIRM)
self.confirm_button = Button(ui.grid(14), icon_confirm, ButtonConfirm) self.confirm_button = Button(ui.grid(14), icon_confirm, ButtonConfirm)
self.confirm_button.on_click = self.on_confirm self.confirm_button.on_click = self.on_confirm # type: ignore
self.confirm_button.disable() self.confirm_button.disable()
icon_back = res.load(ui.ICON_BACK) icon_back = res.load(ui.ICON_BACK)
self.reset_button = Button(ui.grid(12), icon_back, ButtonClear) self.reset_button = Button(ui.grid(12), icon_back, ButtonClear)
self.reset_button.on_click = self.on_reset self.reset_button.on_click = self.on_reset # type: ignore
if allow_cancel: if allow_cancel:
icon_lock = res.load(ui.ICON_LOCK) icon_lock = res.load(ui.ICON_LOCK)
self.cancel_button = Button(ui.grid(12), icon_lock, ButtonCancel) self.cancel_button = Button(ui.grid(12), icon_lock, ButtonCancel)
self.cancel_button.on_click = self.on_cancel self.cancel_button.on_click = self.on_cancel # type: ignore
else: else:
self.cancel_button = Button(ui.grid(12), "") self.cancel_button = Button(ui.grid(12), "")
self.cancel_button.disable() self.cancel_button.disable()
@ -103,7 +108,7 @@ class PinDialog(ui.Layout):
PinButton(i, d, self) for i, d in enumerate(generate_digits()) PinButton(i, d, self) for i, d in enumerate(generate_digits())
] ]
def dispatch(self, event, x, y): def dispatch(self, event: int, x: int, y: int) -> None:
self.input.dispatch(event, x, y) self.input.dispatch(event, x, y)
if self.input.pin: if self.input.pin:
self.reset_button.dispatch(event, x, y) self.reset_button.dispatch(event, x, y)
@ -113,7 +118,7 @@ class PinDialog(ui.Layout):
for btn in self.pin_buttons: for btn in self.pin_buttons:
btn.dispatch(event, x, y) btn.dispatch(event, x, y)
def assign(self, pin): def assign(self, pin: str) -> None:
if len(pin) > self.maxlength: if len(pin) > self.maxlength:
return return
for btn in self.pin_buttons: for btn in self.pin_buttons:
@ -132,12 +137,12 @@ class PinDialog(ui.Layout):
self.input.pin = pin self.input.pin = pin
self.input.repaint = True self.input.repaint = True
def on_reset(self): def on_reset(self) -> None:
self.assign("") self.assign("")
def on_cancel(self): def on_cancel(self) -> None:
raise ui.Result(CANCELLED) raise ui.Result(CANCELLED)
def on_confirm(self): def on_confirm(self) -> None:
if self.input.pin: if self.input.pin:
raise ui.Result(self.input.pin) raise ui.Result(self.input.pin)

@ -1,17 +1,20 @@
from trezor import loop, ui from trezor import loop, ui
if False:
from typing import Iterable
class Popup(ui.Layout): class Popup(ui.Layout):
def __init__(self, content, time_ms=0): def __init__(self, content: ui.Control, time_ms: int = 0) -> None:
self.content = content self.content = content
self.time_ms = time_ms self.time_ms = time_ms
def dispatch(self, event, x, y): def dispatch(self, event: int, x: int, y: int) -> None:
self.content.dispatch(event, x, y) self.content.dispatch(event, x, y)
def create_tasks(self): def create_tasks(self) -> Iterable[loop.Task]:
return self.handle_input(), self.handle_rendering(), self.handle_timeout() return self.handle_input(), self.handle_rendering(), self.handle_timeout()
def handle_timeout(self): def handle_timeout(self) -> loop.Task: # type: ignore
yield loop.sleep(self.time_ms * 1000) yield loop.sleep(self.time_ms * 1000)
raise ui.Result(None) raise ui.Result(None)

@ -2,11 +2,11 @@ from trezor import ui
class Qr(ui.Control): class Qr(ui.Control):
def __init__(self, data, x, y, scale): def __init__(self, data: bytes, x: int, y: int, scale: int):
self.data = data self.data = data
self.x = x self.x = x
self.y = y self.y = y
self.scale = scale self.scale = scale
def on_render(self): def on_render(self) -> None:
ui.display.qrcode(self.x, self.y, self.data, self.scale) ui.display.qrcode(self.x, self.y, self.data, self.scale)

@ -8,8 +8,11 @@ from trezor.ui.swipe import SWIPE_DOWN, SWIPE_UP, SWIPE_VERTICAL, Swipe
if __debug__: if __debug__:
from apps.debug import swipe_signal from apps.debug import swipe_signal
if False:
from typing import Iterable, Sequence
def render_scrollbar(pages: int, page: int):
def render_scrollbar(pages: int, page: int) -> None:
BBOX = const(220) BBOX = const(220)
SIZE = const(8) SIZE = const(8)
@ -28,7 +31,7 @@ def render_scrollbar(pages: int, page: int):
ui.display.bar_radius(X, Y + i * padding, SIZE, SIZE, fg, ui.BG, 4) ui.display.bar_radius(X, Y + i * padding, SIZE, SIZE, fg, ui.BG, 4)
def render_swipe_icon(): def render_swipe_icon() -> None:
DRAW_DELAY = const(200000) DRAW_DELAY = const(200000)
icon = res.load(ui.ICON_SWIPE) icon = res.load(ui.ICON_SWIPE)
@ -37,18 +40,20 @@ def render_swipe_icon():
ui.display.icon(70, 205, icon, c, ui.BG) ui.display.icon(70, 205, icon, c, ui.BG)
def render_swipe_text(): def render_swipe_text() -> None:
ui.display.text_center(130, 220, "Swipe", ui.BOLD, ui.GREY, ui.BG) ui.display.text_center(130, 220, "Swipe", ui.BOLD, ui.GREY, ui.BG)
class Paginated(ui.Layout): class Paginated(ui.Layout):
def __init__(self, pages, page=0, one_by_one=False): def __init__(
self, pages: Sequence[ui.Control], page: int = 0, one_by_one: bool = False
):
self.pages = pages self.pages = pages
self.page = page self.page = page
self.one_by_one = one_by_one self.one_by_one = one_by_one
self.repaint = True self.repaint = True
def dispatch(self, event, x, y): def dispatch(self, event: int, x: int, y: int) -> None:
pages = self.pages pages = self.pages
page = self.page page = self.page
pages[page].dispatch(event, x, y) pages[page].dispatch(event, x, y)
@ -63,7 +68,7 @@ class Paginated(ui.Layout):
render_scrollbar(length, page) render_scrollbar(length, page)
self.repaint = False self.repaint = False
async def handle_paging(self): async def handle_paging(self) -> None:
if self.page == 0: if self.page == 0:
directions = SWIPE_UP directions = SWIPE_UP
elif self.page == len(self.pages) - 1: elif self.page == len(self.pages) - 1:
@ -86,21 +91,33 @@ class Paginated(ui.Layout):
self.on_change() self.on_change()
def create_tasks(self): def create_tasks(self) -> Iterable[loop.Task]:
return self.handle_input(), self.handle_rendering(), self.handle_paging() return self.handle_input(), self.handle_rendering(), self.handle_paging()
def on_change(self): def on_change(self) -> None:
if self.one_by_one: if self.one_by_one:
raise ui.Result(self.page) raise ui.Result(self.page)
class PageWithButtons(ui.Control): class PageWithButtons(ui.Control):
def __init__(self, content, paginated, index, count): def __init__(
self,
content: ui.Control,
paginated: "PaginatedWithButtons",
index: int,
count: int,
) -> None:
self.content = content self.content = content
self.paginated = paginated self.paginated = paginated
self.index = index self.index = index
self.count = count self.count = count
# somewhere in the middle, we can go up or down
left = res.load(ui.ICON_BACK)
left_style = ButtonDefault
right = res.load(ui.ICON_CLICK)
right_style = ButtonDefault
if self.index == 0: if self.index == 0:
# first page, we can cancel or go down # first page, we can cancel or go down
left = res.load(ui.ICON_CANCEL) left = res.load(ui.ICON_CANCEL)
@ -113,31 +130,25 @@ class PageWithButtons(ui.Control):
left_style = ButtonDefault left_style = ButtonDefault
right = res.load(ui.ICON_CONFIRM) right = res.load(ui.ICON_CONFIRM)
right_style = ButtonConfirm right_style = ButtonConfirm
else:
# somewhere in the middle, we can go up or down
left = res.load(ui.ICON_BACK)
left_style = ButtonDefault
right = res.load(ui.ICON_CLICK)
right_style = ButtonDefault
self.left = Button(ui.grid(8, n_x=2), left, left_style) self.left = Button(ui.grid(8, n_x=2), left, left_style)
self.left.on_click = self.on_left self.left.on_click = self.on_left # type: ignore
self.right = Button(ui.grid(9, n_x=2), right, right_style) self.right = Button(ui.grid(9, n_x=2), right, right_style)
self.right.on_click = self.on_right self.right.on_click = self.on_right # type: ignore
def dispatch(self, event, x, y): def dispatch(self, event: int, x: int, y: int) -> None:
self.content.dispatch(event, x, y) self.content.dispatch(event, x, y)
self.left.dispatch(event, x, y) self.left.dispatch(event, x, y)
self.right.dispatch(event, x, y) self.right.dispatch(event, x, y)
def on_left(self): def on_left(self) -> None:
if self.index == 0: if self.index == 0:
self.paginated.on_cancel() self.paginated.on_cancel()
else: else:
self.paginated.on_up() self.paginated.on_up()
def on_right(self): def on_right(self) -> None:
if self.index == self.count - 1: if self.index == self.count - 1:
self.paginated.on_confirm() self.paginated.on_confirm()
else: else:
@ -145,36 +156,38 @@ class PageWithButtons(ui.Control):
class PaginatedWithButtons(ui.Layout): class PaginatedWithButtons(ui.Layout):
def __init__(self, pages, page=0, one_by_one=False): def __init__(
self, pages: Sequence[ui.Control], page: int = 0, one_by_one: bool = False
) -> None:
self.pages = [ self.pages = [
PageWithButtons(p, self, i, len(pages)) for i, p in enumerate(pages) PageWithButtons(p, self, i, len(pages)) for i, p in enumerate(pages)
] ]
self.page = page self.page = page
self.one_by_one = one_by_one self.one_by_one = one_by_one
def dispatch(self, event, x, y): def dispatch(self, event: int, x: int, y: int) -> None:
pages = self.pages pages = self.pages
page = self.page page = self.page
pages[page].dispatch(event, x, y) pages[page].dispatch(event, x, y)
if event is ui.RENDER: if event is ui.RENDER:
render_scrollbar(len(pages), page) render_scrollbar(len(pages), page)
def on_up(self): def on_up(self) -> None:
self.page = max(self.page - 1, 0) self.page = max(self.page - 1, 0)
self.pages[self.page].dispatch(ui.REPAINT, 0, 0) self.pages[self.page].dispatch(ui.REPAINT, 0, 0)
self.on_change() self.on_change()
def on_down(self): def on_down(self) -> None:
self.page = min(self.page + 1, len(self.pages) - 1) self.page = min(self.page + 1, len(self.pages) - 1)
self.pages[self.page].dispatch(ui.REPAINT, 0, 0) self.pages[self.page].dispatch(ui.REPAINT, 0, 0)
self.on_change() self.on_change()
def on_confirm(self): def on_confirm(self) -> None:
raise ui.Result(CONFIRMED) raise ui.Result(CONFIRMED)
def on_cancel(self): def on_cancel(self) -> None:
raise ui.Result(CANCELLED) raise ui.Result(CANCELLED)
def on_change(self): def on_change(self) -> None:
if self.one_by_one: if self.one_by_one:
raise ui.Result(self.page) raise ui.Result(self.page)

@ -4,31 +4,31 @@ from trezor.ui.text import LABEL_CENTER, Label
class NumInput(ui.Control): class NumInput(ui.Control):
def __init__(self, count=5, max_count=16, min_count=1): def __init__(self, count: int = 5, max_count: int = 16, min_count: int = 1) -> None:
self.count = count self.count = count
self.max_count = max_count self.max_count = max_count
self.min_count = min_count self.min_count = min_count
self.minus = Button(ui.grid(3), "-") self.minus = Button(ui.grid(3), "-")
self.minus.on_click = self.on_minus self.minus.on_click = self.on_minus # type: ignore
self.plus = Button(ui.grid(5), "+") self.plus = Button(ui.grid(5), "+")
self.plus.on_click = self.on_plus self.plus.on_click = self.on_plus # type: ignore
self.text = Label(ui.grid(4), "", LABEL_CENTER, ui.BOLD) self.text = Label(ui.grid(4), "", LABEL_CENTER, ui.BOLD)
self.edit(count) self.edit(count)
def dispatch(self, event, x, y): def dispatch(self, event: int, x: int, y: int) -> None:
self.minus.dispatch(event, x, y) self.minus.dispatch(event, x, y)
self.plus.dispatch(event, x, y) self.plus.dispatch(event, x, y)
self.text.dispatch(event, x, y) self.text.dispatch(event, x, y)
def on_minus(self): def on_minus(self) -> None:
self.edit(self.count - 1) self.edit(self.count - 1)
def on_plus(self): def on_plus(self) -> None:
self.edit(self.count + 1) self.edit(self.count + 1)
def edit(self, count): def edit(self, count: int) -> None:
count = max(count, self.min_count) count = max(count, self.min_count)
count = min(count, self.max_count) count = min(count, self.max_count)
if self.count != count: if self.count != count:
@ -45,5 +45,5 @@ class NumInput(ui.Control):
else: else:
self.plus.enable() self.plus.enable()
def on_change(self, count): def on_change(self, count: int) -> None:
pass pass

@ -2,6 +2,9 @@ from micropython import const
from trezor import io, loop, ui from trezor import io, loop, ui
if False:
from typing import Generator
SWIPE_UP = const(0x01) SWIPE_UP = const(0x01)
SWIPE_DOWN = const(0x02) SWIPE_DOWN = const(0x02)
SWIPE_LEFT = const(0x04) SWIPE_LEFT = const(0x04)
@ -15,24 +18,26 @@ _SWIPE_TRESHOLD = const(30)
class Swipe(ui.Control): class Swipe(ui.Control):
def __init__(self, directions=SWIPE_ALL, area=None): def __init__(self, directions: int = SWIPE_ALL, area: ui.Area = None) -> None:
if area is None: if area is None:
area = (0, 0, ui.WIDTH, ui.HEIGHT) area = (0, 0, ui.WIDTH, ui.HEIGHT)
self.area = area self.area = area
self.directions = directions self.directions = directions
self.start_x = None self.started = False
self.start_y = None self.start_x = 0
self.light_origin = None self.start_y = 0
self.light_origin = ui.BACKLIGHT_NORMAL
self.light_target = ui.BACKLIGHT_NONE self.light_target = ui.BACKLIGHT_NONE
def on_touch_start(self, x, y): def on_touch_start(self, x: int, y: int) -> None:
if ui.in_area(self.area, x, y): if ui.in_area(self.area, x, y):
self.start_x = x self.start_x = x
self.start_y = y self.start_y = y
self.light_origin = ui.BACKLIGHT_NORMAL self.light_origin = ui.BACKLIGHT_NORMAL
self.started = True
def on_touch_move(self, x, y): def on_touch_move(self, x: int, y: int) -> None:
if self.start_x is None: if not self.started:
return # not started in our area return # not started in our area
dirs = self.directions dirs = self.directions
@ -61,8 +66,8 @@ class Swipe(ui.Control):
) )
) )
def on_touch_end(self, x, y): def on_touch_end(self, x: int, y: int) -> None:
if self.start_x is None: if not self.started:
return # not started in our area return # not started in our area
dirs = self.directions dirs = self.directions
@ -93,13 +98,15 @@ class Swipe(ui.Control):
# no swipe detected, reset the state # no swipe detected, reset the state
ui.display.backlight(self.light_origin) ui.display.backlight(self.light_origin)
self.start_x = None self.started = False
self.start_y = None
def on_swipe(self, swipe): def on_swipe(self, swipe: int) -> None:
raise ui.Result(swipe) raise ui.Result(swipe)
def __iter__(self): def __await__(self) -> Generator:
return self.__iter__() # type: ignore
def __iter__(self) -> loop.Task: # type: ignore
try: try:
touch = loop.wait(io.TOUCH) touch = loop.wait(io.TOUCH)
while True: while True:

@ -2,6 +2,9 @@ from micropython import const
from trezor import ui from trezor import ui
if False:
from typing import List, Union
TEXT_HEADER_HEIGHT = const(48) TEXT_HEADER_HEIGHT = const(48)
TEXT_LINE_HEIGHT = const(26) TEXT_LINE_HEIGHT = const(26)
TEXT_LINE_HEIGHT_HALF = const(13) TEXT_LINE_HEIGHT_HALF = const(13)
@ -12,9 +15,12 @@ TEXT_MAX_LINES = const(5)
BR = const(-256) BR = const(-256)
BR_HALF = const(-257) BR_HALF = const(-257)
if False:
TextContent = Union[str, int]
def render_text( def render_text(
words: list, words: List[TextContent],
new_lines: bool, new_lines: bool,
max_lines: int, max_lines: int,
font: int = ui.NORMAL, font: int = ui.NORMAL,
@ -128,32 +134,32 @@ class Text(ui.Control):
self.icon_color = icon_color self.icon_color = icon_color
self.max_lines = max_lines self.max_lines = max_lines
self.new_lines = new_lines self.new_lines = new_lines
self.content = [] self.content = [] # type: List[Union[str, int]]
self.repaint = True self.repaint = True
def normal(self, *content): def normal(self, *content: TextContent) -> None:
self.content.append(ui.NORMAL) self.content.append(ui.NORMAL)
self.content.extend(content) self.content.extend(content)
def bold(self, *content): def bold(self, *content: TextContent) -> None:
self.content.append(ui.BOLD) self.content.append(ui.BOLD)
self.content.extend(content) self.content.extend(content)
def mono(self, *content): def mono(self, *content: TextContent) -> None:
self.content.append(ui.MONO) self.content.append(ui.MONO)
self.content.extend(content) self.content.extend(content)
def mono_bold(self, *content): def mono_bold(self, *content: TextContent) -> None:
self.content.append(ui.MONO_BOLD) self.content.append(ui.MONO_BOLD)
self.content.extend(content) self.content.extend(content)
def br(self): def br(self) -> None:
self.content.append(BR) self.content.append(BR)
def br_half(self): def br_half(self) -> None:
self.content.append(BR_HALF) self.content.append(BR_HALF)
def on_render(self): def on_render(self) -> None:
if self.repaint: if self.repaint:
ui.header( ui.header(
self.header_text, self.header_text,
@ -172,21 +178,27 @@ LABEL_RIGHT = const(2)
class Label(ui.Control): class Label(ui.Control):
def __init__(self, area, content, align=LABEL_LEFT, style=ui.NORMAL): def __init__(
self,
area: ui.Area,
content: str,
align: int = LABEL_LEFT,
style: int = ui.NORMAL,
) -> None:
self.area = area self.area = area
self.content = content self.content = content
self.align = align self.align = align
self.style = style self.style = style
self.repaint = True self.repaint = True
def on_render(self): def on_render(self) -> None:
if self.repaint: if self.repaint:
align = self.align align = self.align
ax, ay, aw, ah = self.area ax, ay, aw, ah = self.area
tx = ax + aw // 2 tx = ax + aw // 2
ty = ay + ah // 2 + 8 ty = ay + ah // 2 + 8
if align is LABEL_LEFT: if align is LABEL_LEFT:
ui.display.text_left(tx, ty, self.content, self.style, ui.FG, ui.BG, aw) ui.display.text(tx, ty, self.content, self.style, ui.FG, ui.BG, aw)
elif align is LABEL_CENTER: elif align is LABEL_CENTER:
ui.display.text_center( ui.display.text_center(
tx, ty, self.content, self.style, ui.FG, ui.BG, aw tx, ty, self.content, self.style, ui.FG, ui.BG, aw

@ -5,20 +5,20 @@ from trezor.ui.button import Button
class WordSelector(ui.Layout): class WordSelector(ui.Layout):
def __init__(self, content): def __init__(self, content: ui.Control) -> None:
self.content = content self.content = content
self.w12 = Button(ui.grid(6, n_y=4), "12") self.w12 = Button(ui.grid(6, n_y=4), "12")
self.w12.on_click = self.on_w12 self.w12.on_click = self.on_w12 # type: ignore
self.w18 = Button(ui.grid(7, n_y=4), "18") self.w18 = Button(ui.grid(7, n_y=4), "18")
self.w18.on_click = self.on_w18 self.w18.on_click = self.on_w18 # type: ignore
self.w20 = Button(ui.grid(8, n_y=4), "20") self.w20 = Button(ui.grid(8, n_y=4), "20")
self.w20.on_click = self.on_w20 self.w20.on_click = self.on_w20 # type: ignore
self.w24 = Button(ui.grid(9, n_y=4), "24") self.w24 = Button(ui.grid(9, n_y=4), "24")
self.w24.on_click = self.on_w24 self.w24.on_click = self.on_w24 # type: ignore
self.w33 = Button(ui.grid(10, n_y=4), "33") self.w33 = Button(ui.grid(10, n_y=4), "33")
self.w33.on_click = self.on_w33 self.w33.on_click = self.on_w33 # type: ignore
def dispatch(self, event, x, y): def dispatch(self, event: int, x: int, y: int) -> None:
self.content.dispatch(event, x, y) self.content.dispatch(event, x, y)
self.w12.dispatch(event, x, y) self.w12.dispatch(event, x, y)
self.w18.dispatch(event, x, y) self.w18.dispatch(event, x, y)
@ -26,17 +26,17 @@ class WordSelector(ui.Layout):
self.w24.dispatch(event, x, y) self.w24.dispatch(event, x, y)
self.w33.dispatch(event, x, y) self.w33.dispatch(event, x, y)
def on_w12(self): def on_w12(self) -> None:
raise ui.Result(12) raise ui.Result(12)
def on_w18(self): def on_w18(self) -> None:
raise ui.Result(18) raise ui.Result(18)
def on_w20(self): def on_w20(self) -> None:
raise ui.Result(20) raise ui.Result(20)
def on_w24(self): def on_w24(self) -> None:
raise ui.Result(24) raise ui.Result(24)
def on_w33(self): def on_w33(self) -> None:
raise ui.Result(33) raise ui.Result(33)

@ -27,12 +27,15 @@ if __debug__:
SAVE_SCREEN = 0 SAVE_SCREEN = 0
LOG_MEMORY = 0 LOG_MEMORY = 0
if False:
from typing import Iterable, Iterator, Protocol, List, TypeVar
def unimport_begin():
def unimport_begin() -> Iterable[str]:
return set(sys.modules) return set(sys.modules)
def unimport_end(mods): def unimport_end(mods: Iterable[str]) -> None:
for mod in sys.modules: for mod in sys.modules:
if mod not in mods: if mod not in mods:
# remove reference from sys.modules # remove reference from sys.modules
@ -53,7 +56,7 @@ def unimport_end(mods):
gc.collect() gc.collect()
def ensure(cond, msg=None): def ensure(cond: bool, msg: str = None) -> None:
if not cond: if not cond:
if msg is None: if msg is None:
raise AssertionError raise AssertionError
@ -61,48 +64,71 @@ def ensure(cond, msg=None):
raise AssertionError(msg) raise AssertionError(msg)
def chunks(items, size): if False:
Chunked = TypeVar("Chunked")
def chunks(items: List[Chunked], size: int) -> Iterator[List[Chunked]]:
for i in range(0, len(items), size): for i in range(0, len(items), size):
yield items[i : i + size] yield items[i : i + size]
def format_amount(amount, decimals): def format_amount(amount: int, decimals: int) -> str:
d = pow(10, decimals) d = pow(10, decimals)
amount = ("%d.%0*d" % (amount // d, decimals, amount % d)).rstrip("0") s = ("%d.%0*d" % (amount // d, decimals, amount % d)).rstrip("0").rstrip(".")
if amount.endswith("."): return s
amount = amount[:-1]
return amount
def format_ordinal(number): def format_ordinal(number: int) -> str:
return str(number) + {1: "st", 2: "nd", 3: "rd"}.get( return str(number) + {1: "st", 2: "nd", 3: "rd"}.get(
4 if 10 <= number % 100 < 20 else number % 10, "th" 4 if 10 <= number % 100 < 20 else number % 10, "th"
) )
if False:
class HashContext(Protocol):
def update(self, buf: bytes) -> None:
...
def digest(self) -> bytes:
...
class Writer(Protocol):
def append(self, b: int) -> None:
...
def extend(self, buf: bytes) -> None:
...
def write(self, buf: bytes) -> None:
...
class HashWriter: class HashWriter:
def __init__(self, ctx): def __init__(self, ctx: HashContext) -> None:
self.ctx = ctx self.ctx = ctx
self.buf = bytearray(1) # used in append() self.buf = bytearray(1) # used in append()
def extend(self, buf: bytearray): def append(self, b: int) -> None:
self.ctx.update(buf) self.buf[0] = b
self.ctx.update(self.buf)
def write(self, buf: bytearray): # alias for extend() def extend(self, buf: bytes) -> None:
self.ctx.update(buf) self.ctx.update(buf)
async def awrite(self, buf: bytearray): # AsyncWriter interface def write(self, buf: bytes) -> None: # alias for extend()
return self.ctx.update(buf) self.ctx.update(buf)
def append(self, b: int): async def awrite(self, buf: bytes) -> int: # AsyncWriter interface
self.buf[0] = b self.ctx.update(buf)
self.ctx.update(self.buf) return len(buf)
def get_digest(self) -> bytes: def get_digest(self) -> bytes:
return self.ctx.digest() return self.ctx.digest()
def obj_eq(l, r): def obj_eq(l: object, r: object) -> bool:
""" """
Compares object contents, supports __slots__. Compares object contents, supports __slots__.
""" """
@ -118,7 +144,7 @@ def obj_eq(l, r):
return True return True
def obj_repr(o): def obj_repr(o: object) -> str:
""" """
Returns a string representation of object, supports __slots__. Returns a string representation of object, supports __slots__.
""" """

@ -7,11 +7,28 @@ from trezor.wire.errors import Error
# import all errors into namespace, so that `wire.Error` is available elsewhere # import all errors into namespace, so that `wire.Error` is available elsewhere
from trezor.wire.errors import * # isort:skip # noqa: F401,F403 from trezor.wire.errors import * # isort:skip # noqa: F401,F403
if False:
from typing import (
Any,
Awaitable,
Dict,
Callable,
Iterable,
List,
Optional,
Tuple,
Type,
)
from trezorio import WireInterface
from protobuf import LoadedMessageType, MessageType
Handler = Callable[..., loop.Task]
workflow_handlers = {} workflow_handlers = {} # type: Dict[int, Tuple[Handler, Iterable]]
def add(mtype, pkgname, modname, namespace=None): def add(mtype: int, pkgname: str, modname: str, namespace: List = None) -> None:
"""Shortcut for registering a dynamically-imported Protobuf workflow.""" """Shortcut for registering a dynamically-imported Protobuf workflow."""
if namespace is not None: if namespace is not None:
register( register(
@ -27,7 +44,7 @@ def add(mtype, pkgname, modname, namespace=None):
register(mtype, protobuf_workflow, import_workflow, pkgname, modname) register(mtype, protobuf_workflow, import_workflow, pkgname, modname)
def register(mtype, handler, *args): def register(mtype: int, handler: Handler, *args: Any) -> None:
"""Register `handler` to get scheduled after `mtype` message is received.""" """Register `handler` to get scheduled after `mtype` message is received."""
if isinstance(mtype, type) and issubclass(mtype, protobuf.MessageType): if isinstance(mtype, type) and issubclass(mtype, protobuf.MessageType):
mtype = mtype.MESSAGE_WIRE_TYPE mtype = mtype.MESSAGE_WIRE_TYPE
@ -36,54 +53,75 @@ def register(mtype, handler, *args):
workflow_handlers[mtype] = (handler, args) workflow_handlers[mtype] = (handler, args)
def setup(iface): def setup(iface: WireInterface) -> None:
"""Initialize the wire stack on passed USB interface.""" """Initialize the wire stack on passed USB interface."""
loop.schedule(session_handler(iface, codec_v1.SESSION_ID)) loop.schedule(session_handler(iface, codec_v1.SESSION_ID))
class Context: class Context:
def __init__(self, iface, sid): def __init__(self, iface: WireInterface, sid: int) -> None:
self.iface = iface self.iface = iface
self.sid = sid self.sid = sid
async def call(self, msg, *types): async def call(
""" self, msg: MessageType, exptype: Type[LoadedMessageType]
Reply with `msg` and wait for one of `types`. See `self.write()` and ) -> LoadedMessageType:
`self.read()`.
"""
await self.write(msg) await self.write(msg)
del msg del msg
return await self.read(types) return await self.read(exptype)
async def read(self, types): async def call_any(self, msg: MessageType, *allowed_types: int) -> MessageType:
""" await self.write(msg)
Wait for incoming message on this wire context and return it. Raises del msg
`UnexpectedMessageError` if the message type does not match one of return await self.read_any(allowed_types)
`types`; and caller should always make sure to re-raise it.
""" async def read(
reader = self.getreader() self, exptype: Optional[Type[LoadedMessageType]]
) -> LoadedMessageType:
reader = self.make_reader()
if __debug__: if __debug__:
log.debug( log.debug(
__name__, "%s:%x read: %s", self.iface.iface_num(), self.sid, types __name__, "%s:%x read: %s", self.iface.iface_num(), self.sid, exptype
) )
await reader.aopen() # wait for the message header await reader.aopen() # wait for the message header
# if we got a message with unexpected type, raise the reader via # if we got a message with unexpected type, raise the reader via
# `UnexpectedMessageError` and let the session handler deal with it # `UnexpectedMessageError` and let the session handler deal with it
if reader.type not in types: if exptype is None or reader.type != exptype.MESSAGE_WIRE_TYPE:
raise UnexpectedMessageError(reader) raise UnexpectedMessageError(reader)
# look up the protobuf class and parse the message # parse the message and return it
pbtype = messages.get_type(reader.type) return await protobuf.load_message(reader, exptype)
return await protobuf.load_message(reader, pbtype)
async def write(self, msg): async def read_any(self, allowed_types: Iterable[int]) -> MessageType:
""" reader = self.make_reader()
Write a protobuf message to this wire context.
""" if __debug__:
writer = self.getwriter() log.debug(
__name__,
"%s:%x read: %s",
self.iface.iface_num(),
self.sid,
allowed_types,
)
await reader.aopen() # wait for the message header
# if we got a message with unexpected type, raise the reader via
# `UnexpectedMessageError` and let the session handler deal with it
if reader.type not in allowed_types:
raise UnexpectedMessageError(reader)
# find the protobuf type
exptype = messages.get_type(reader.type)
# parse the message and return it
return await protobuf.load_message(reader, exptype)
async def write(self, msg: protobuf.MessageType) -> None:
writer = self.make_writer()
if __debug__: if __debug__:
log.debug( log.debug(
@ -99,35 +137,35 @@ class Context:
await protobuf.dump_message(writer, msg, fields) await protobuf.dump_message(writer, msg, fields)
await writer.aclose() await writer.aclose()
def wait(self, *tasks): def wait(self, *tasks: Awaitable) -> Any:
""" """
Wait until one of the passed tasks finishes, and return the result, Wait until one of the passed tasks finishes, and return the result,
while servicing the wire context. If a message comes until one of the while servicing the wire context. If a message comes until one of the
tasks ends, `UnexpectedMessageError` is raised. tasks ends, `UnexpectedMessageError` is raised.
""" """
return loop.spawn(self.read(()), *tasks) return loop.spawn(self.read(None), *tasks)
def getreader(self): def make_reader(self) -> codec_v1.Reader:
return codec_v1.Reader(self.iface) return codec_v1.Reader(self.iface)
def getwriter(self): def make_writer(self) -> codec_v1.Writer:
return codec_v1.Writer(self.iface) return codec_v1.Writer(self.iface)
class UnexpectedMessageError(Exception): class UnexpectedMessageError(Exception):
def __init__(self, reader): def __init__(self, reader: codec_v1.Reader) -> None:
super().__init__() super().__init__()
self.reader = reader self.reader = reader
async def session_handler(iface, sid): async def session_handler(iface: WireInterface, sid: int) -> None:
reader = None reader = None
ctx = Context(iface, sid) ctx = Context(iface, sid)
while True: while True:
try: try:
# wait for new message, if needed, and find handler # wait for new message, if needed, and find handler
if not reader: if not reader:
reader = ctx.getreader() reader = ctx.make_reader()
await reader.aopen() await reader.aopen()
try: try:
handler, args = workflow_handlers[reader.type] handler, args = workflow_handlers[reader.type]
@ -160,7 +198,9 @@ async def session_handler(iface, sid):
reader = None reader = None
async def protobuf_workflow(ctx, reader, handler, *args): async def protobuf_workflow(
ctx: Context, reader: codec_v1.Reader, handler: Handler, *args: Any
) -> None:
from trezor.messages.Failure import Failure from trezor.messages.Failure import Failure
req = await protobuf.load_message(reader, messages.get_type(reader.type)) req = await protobuf.load_message(reader, messages.get_type(reader.type))
@ -185,7 +225,13 @@ async def protobuf_workflow(ctx, reader, handler, *args):
await ctx.write(res) await ctx.write(res)
async def keychain_workflow(ctx, req, namespace, handler, *args): async def keychain_workflow(
ctx: Context,
req: protobuf.MessageType,
namespace: List,
handler: Handler,
*args: Any
) -> Any:
from apps.common import seed from apps.common import seed
keychain = await seed.get_keychain(ctx, namespace) keychain = await seed.get_keychain(ctx, namespace)
@ -196,22 +242,28 @@ async def keychain_workflow(ctx, req, namespace, handler, *args):
keychain.__del__() keychain.__del__()
def import_workflow(ctx, req, pkgname, modname, *args): def import_workflow(
ctx: Context, req: protobuf.MessageType, pkgname: str, modname: str, *args: Any
) -> Any:
modpath = "%s.%s" % (pkgname, modname) modpath = "%s.%s" % (pkgname, modname)
module = __import__(modpath, None, None, (modname,), 0) module = __import__(modpath, None, None, (modname,), 0) # type: ignore
handler = getattr(module, modname) handler = getattr(module, modname)
return handler(ctx, req, *args) return handler(ctx, req, *args)
async def unexpected_msg(ctx, reader): async def unexpected_msg(ctx: Context, reader: codec_v1.Reader) -> None:
from trezor.messages.Failure import Failure from trezor.messages.Failure import Failure
# receive the message and throw it away # receive the message and throw it away
while reader.size > 0: await read_full_msg(reader)
buf = bytearray(reader.size)
await reader.areadinto(buf)
# respond with an unknown message error # respond with an unknown message error
await ctx.write( await ctx.write(
Failure(code=FailureType.UnexpectedMessage, message="Unexpected message") Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
) )
async def read_full_msg(reader: codec_v1.Reader) -> None:
while reader.size > 0:
buf = bytearray(reader.size)
await reader.areadinto(buf)

@ -3,6 +3,9 @@ from micropython import const
from trezor import io, loop, utils from trezor import io, loop, utils
if False:
from trezorio import WireInterface
_REP_LEN = const(64) _REP_LEN = const(64)
_REP_MARKER = const(63) # ord('?') _REP_MARKER = const(63) # ord('?')
@ -12,6 +15,7 @@ _REP_INIT_DATA = const(9) # offset of data in the initial report
_REP_CONT_DATA = const(1) # offset of data in the continuation report _REP_CONT_DATA = const(1) # offset of data in the continuation report
SESSION_ID = const(0) SESSION_ID = const(0)
INVALID_TYPE = const(-1)
class Reader: class Reader:
@ -20,17 +24,14 @@ class Reader:
async-file-like interface. async-file-like interface.
""" """
def __init__(self, iface): def __init__(self, iface: WireInterface) -> None:
self.iface = iface self.iface = iface
self.type = None self.type = INVALID_TYPE
self.size = None self.size = 0
self.data = None
self.ofs = 0 self.ofs = 0
self.data = bytes()
def __repr__(self): async def aopen(self) -> None:
return "<ReaderV1: type=%d size=%dB>" % (self.type, self.size)
async def aopen(self):
""" """
Begin the message transmission by waiting for initial V2 message report Begin the message transmission by waiting for initial V2 message report
on this session. `self.type` and `self.size` are initialized and on this session. `self.type` and `self.size` are initialized and
@ -53,7 +54,7 @@ class Reader:
self.data = report[_REP_INIT_DATA : _REP_INIT_DATA + msize] self.data = report[_REP_INIT_DATA : _REP_INIT_DATA + msize]
self.ofs = 0 self.ofs = 0
async def areadinto(self, buf): async def areadinto(self, buf: bytearray) -> int:
""" """
Read exactly `len(buf)` bytes into `buf`, waiting for additional Read exactly `len(buf)` bytes into `buf`, waiting for additional
reports, if needed. Raises `EOFError` if end-of-message is encountered reports, if needed. Raises `EOFError` if end-of-message is encountered
@ -91,17 +92,14 @@ class Writer:
async-file-like interface. async-file-like interface.
""" """
def __init__(self, iface): def __init__(self, iface: WireInterface):
self.iface = iface self.iface = iface
self.type = None self.type = INVALID_TYPE
self.size = None self.size = 0
self.data = bytearray(_REP_LEN)
self.ofs = 0 self.ofs = 0
self.data = bytearray(_REP_LEN)
def __repr__(self): def setheader(self, mtype: int, msize: int) -> None:
return "<WriterV1: type=%d size=%dB>" % (self.type, self.size)
def setheader(self, mtype, msize):
""" """
Reset the writer state and load the message header with passed type and Reset the writer state and load the message header with passed type and
total message size. total message size.
@ -113,7 +111,7 @@ class Writer:
) )
self.ofs = _REP_INIT_DATA self.ofs = _REP_INIT_DATA
async def awrite(self, buf): async def awrite(self, buf: bytes) -> int:
""" """
Encode and write every byte from `buf`. Does not need to be called in Encode and write every byte from `buf`. Does not need to be called in
case message has zero length. Raises `EOFError` if the length of `buf` case message has zero length. Raises `EOFError` if the length of `buf`
@ -142,7 +140,7 @@ class Writer:
return nwritten return nwritten
async def aclose(self): async def aclose(self) -> None:
"""Flush and close the message transmission.""" """Flush and close the message transmission."""
if self.ofs != _REP_CONT_DATA: if self.ofs != _REP_CONT_DATA:
# we didn't write anything or last write() wasn't report-aligned, # we didn't write anything or last write() wasn't report-aligned,

@ -2,72 +2,72 @@ from trezor.messages import FailureType
class Error(Exception): class Error(Exception):
def __init__(self, code, message): def __init__(self, code: int, message: str) -> None:
super().__init__() super().__init__()
self.code = code self.code = code
self.message = message self.message = message
class UnexpectedMessage(Error): class UnexpectedMessage(Error):
def __init__(self, message): def __init__(self, message: str) -> None:
super().__init__(FailureType.UnexpectedMessage, message) super().__init__(FailureType.UnexpectedMessage, message)
class ButtonExpected(Error): class ButtonExpected(Error):
def __init__(self, message): def __init__(self, message: str) -> None:
super().__init__(FailureType.ButtonExpected, message) super().__init__(FailureType.ButtonExpected, message)
class DataError(Error): class DataError(Error):
def __init__(self, message): def __init__(self, message: str) -> None:
super().__init__(FailureType.DataError, message) super().__init__(FailureType.DataError, message)
class ActionCancelled(Error): class ActionCancelled(Error):
def __init__(self, message): def __init__(self, message: str) -> None:
super().__init__(FailureType.ActionCancelled, message) super().__init__(FailureType.ActionCancelled, message)
class PinExpected(Error): class PinExpected(Error):
def __init__(self, message): def __init__(self, message: str) -> None:
super().__init__(FailureType.PinExpected, message) super().__init__(FailureType.PinExpected, message)
class PinCancelled(Error): class PinCancelled(Error):
def __init__(self, message): def __init__(self, message: str) -> None:
super().__init__(FailureType.PinCancelled, message) super().__init__(FailureType.PinCancelled, message)
class PinInvalid(Error): class PinInvalid(Error):
def __init__(self, message): def __init__(self, message: str) -> None:
super().__init__(FailureType.PinInvalid, message) super().__init__(FailureType.PinInvalid, message)
class InvalidSignature(Error): class InvalidSignature(Error):
def __init__(self, message): def __init__(self, message: str) -> None:
super().__init__(FailureType.InvalidSignature, message) super().__init__(FailureType.InvalidSignature, message)
class ProcessError(Error): class ProcessError(Error):
def __init__(self, message): def __init__(self, message: str) -> None:
super().__init__(FailureType.ProcessError, message) super().__init__(FailureType.ProcessError, message)
class NotEnoughFunds(Error): class NotEnoughFunds(Error):
def __init__(self, message): def __init__(self, message: str) -> None:
super().__init__(FailureType.NotEnoughFunds, message) super().__init__(FailureType.NotEnoughFunds, message)
class NotInitialized(Error): class NotInitialized(Error):
def __init__(self, message): def __init__(self, message: str) -> None:
super().__init__(FailureType.NotInitialized, message) super().__init__(FailureType.NotInitialized, message)
class PinMismatch(Error): class PinMismatch(Error):
def __init__(self, message): def __init__(self, message: str) -> None:
super().__init__(FailureType.PinMismatch, message) super().__init__(FailureType.PinMismatch, message)
class FirmwareError(Error): class FirmwareError(Error):
def __init__(self, message): def __init__(self, message: str) -> None:
super().__init__(FailureType.FirmwareError, message) super().__init__(FailureType.FirmwareError, message)

@ -1,17 +1,21 @@
from trezor import loop from trezor import loop
workflows = [] if False:
layouts = [] from trezor import ui
from typing import List, Callable, Optional
workflows = [] # type: List[loop.Task]
layouts = [] # type: List[ui.Layout]
layout_signal = loop.signal() layout_signal = loop.signal()
default = None default = None # type: Optional[loop.Task]
default_layout = None default_layout = None # type: Optional[Callable[[], loop.Task]]
def onstart(w): def onstart(w: loop.Task) -> None:
workflows.append(w) workflows.append(w)
def onclose(w): def onclose(w: loop.Task) -> None:
workflows.remove(w) workflows.remove(w)
if not layouts and default_layout: if not layouts and default_layout:
startdefault(default_layout) startdefault(default_layout)
@ -24,7 +28,7 @@ def onclose(w):
micropython.mem_info() micropython.mem_info()
def closedefault(): def closedefault() -> None:
global default global default
if default: if default:
@ -32,7 +36,7 @@ def closedefault():
default = None default = None
def startdefault(layout): def startdefault(layout: Callable[[], loop.Task]) -> None:
global default global default
global default_layout global default_layout
@ -42,18 +46,19 @@ def startdefault(layout):
loop.schedule(default) loop.schedule(default)
def restartdefault(): def restartdefault() -> None:
global default_layout global default_layout
d = default_layout
closedefault() closedefault()
startdefault(d) if default_layout:
startdefault(default_layout)
def onlayoutstart(l): def onlayoutstart(l: ui.Layout) -> None:
closedefault() closedefault()
layouts.append(l) layouts.append(l)
def onlayoutclose(l): def onlayoutclose(l: ui.Layout) -> None:
if l in layouts: if l in layouts:
layouts.remove(l) layouts.remove(l)

@ -33,15 +33,16 @@ addopts = --strict
xfail_strict = true xfail_strict = true
[mypy] [mypy]
mypy_path = mocks,mocks/generated mypy_path = src,mocks,mocks/generated
warn_unused_configs = True check_untyped_defs = True
disallow_subclassing_any = True disallow_subclassing_any = True
disallow_untyped_calls = True disallow_untyped_calls = True
disallow_untyped_decorators = True
disallow_untyped_defs = True disallow_untyped_defs = True
disallow_incomplete_defs = True disallow_incomplete_defs = True
check_untyped_defs = True namespace_packages = True
# no_implicit_optional = True # no_implicit_optional = True
warn_redundant_casts = True warn_redundant_casts = True
warn_return_any = True warn_return_any = True
warn_unused_configs = True
warn_unused_ignores = True warn_unused_ignores = True
disallow_untyped_decorators = True

Loading…
Cancel
Save