core/typing: add annotations

pull/320/head
Jan Pochyla 5 years ago
parent 86e16bbf31
commit 758a1a2528

@ -72,7 +72,7 @@ pylint: ## run pylint on application sources and tests
pylint -E $(shell find src tests -name *.py)
mypy:
mypy \
mypy --config-file ../setup.cfg \
src/main.py
## code generation:

@ -7,7 +7,7 @@ CURVE = "ed25519"
SEED_NAMESPACE = [HARDENED | 44, HARDENED | 1815]
def boot():
def boot() -> None:
wire.add(MessageType.CardanoGetAddress, __name__, "get_address")
wire.add(MessageType.CardanoGetPublicKey, __name__, "get_public_key")
wire.add(MessageType.CardanoSignTx, __name__, "sign_tx")

@ -4,8 +4,8 @@ from trezor import log, wire
from trezor.crypto import base58, hashlib
from trezor.crypto.curve import ed25519
from trezor.messages.CardanoSignedTx import CardanoSignedTx
from trezor.messages.CardanoTxAck import CardanoTxAck
from trezor.messages.CardanoTxRequest import CardanoTxRequest
from trezor.messages.MessageType import CardanoTxAck
from apps.cardano import CURVE, seed
from apps.cardano.address import (

@ -1,4 +1,9 @@
def length(address_type):
if False:
from typing import Tuple
from apps.common.coininfo import CoinType
def length(address_type: int) -> int:
if address_type <= 0xFF:
return 1
if address_type <= 0xFFFF:
@ -9,21 +14,21 @@ def length(address_type):
return 4
def tobytes(address_type: int):
def tobytes(address_type: int) -> bytes:
return address_type.to_bytes(length(address_type), "big")
def check(address_type, raw_address):
def check(address_type: int, raw_address: bytes) -> bool:
return raw_address.startswith(tobytes(address_type))
def strip(address_type, raw_address):
def strip(address_type: int, raw_address: bytes) -> bytes:
if not check(address_type, raw_address):
raise ValueError("Invalid address")
return raw_address[length(address_type) :]
def split(coin, raw_address):
def split(coin: CoinType, raw_address: bytes) -> Tuple[bytes, bytes]:
for f in (
"address_type",
"address_type_p2sh",

@ -2,12 +2,15 @@ from trezor.crypto import hashlib, hmac, random
from apps.common import storage
_cached_seed = None
_cached_passphrase = None
_cached_passphrase_fprint = b"\x00\x00\x00\x00"
if False:
from typing import Optional
_cached_seed = None # type: Optional[bytes]
_cached_passphrase = None # type: Optional[str]
_cached_passphrase_fprint = b"\x00\x00\x00\x00" # type: bytes
def get_state(prev_state: bytes = None, passphrase: str = None) -> bytes:
def get_state(prev_state: bytes = None, passphrase: str = None) -> Optional[bytes]:
if prev_state is None:
salt = random.bytes(32) # generate a random salt if no state provided
else:
@ -29,34 +32,34 @@ def _compute_state(salt: bytes, passphrase: str) -> bytes:
return salt + state
def get_seed():
def get_seed() -> Optional[bytes]:
return _cached_seed
def get_passphrase():
def get_passphrase() -> Optional[str]:
return _cached_passphrase
def get_passphrase_fprint():
def get_passphrase_fprint() -> bytes:
return _cached_passphrase_fprint
def has_passphrase():
def has_passphrase() -> bool:
return _cached_passphrase is not None
def set_seed(seed):
def set_seed(seed: Optional[bytes]) -> None:
global _cached_seed
_cached_seed = seed
def set_passphrase(passphrase):
def set_passphrase(passphrase: Optional[str]) -> None:
global _cached_passphrase, _cached_passphrase_fprint
_cached_passphrase = passphrase
_cached_passphrase_fprint = _compute_state(b"FPRINT", passphrase or "")[:4]
def clear(keep_passphrase: bool = False):
def clear(keep_passphrase: bool = False) -> None:
set_seed(None)
if not keep_passphrase:
set_passphrase(None)

@ -6,7 +6,11 @@ import ustruct as struct
from micropython import const
from trezor import log
from trezor.utils import ensure
if False:
from typing import Any, Dict, Iterable, List, Tuple
Value = Any
_CBOR_TYPE_MASK = const(0xE0)
_CBOR_INFO_BITS = const(0x1F)
@ -32,7 +36,7 @@ _CBOR_BREAK = const(0x1F)
_CBOR_RAW_TAG = const(0x18)
def _header(typ, l: int):
def _header(typ: int, l: int) -> bytes:
if l < 24:
return struct.pack(">B", typ + l)
elif l < 2 ** 8:
@ -47,7 +51,7 @@ def _header(typ, l: int):
raise NotImplementedError("Length %d not suppported" % l)
def _cbor_encode(value):
def _cbor_encode(value: Value) -> Iterable[bytes]:
if isinstance(value, int):
if value >= 0:
yield _header(_CBOR_UNSIGNED_INT, value)
@ -95,7 +99,7 @@ def _cbor_encode(value):
raise NotImplementedError
def _read_length(cbor, aux):
def _read_length(cbor: bytes, aux: int) -> Tuple[int, bytes]:
if aux < _CBOR_UINT8_FOLLOWS:
return (aux, cbor)
elif aux == _CBOR_UINT8_FOLLOWS:
@ -124,7 +128,7 @@ def _read_length(cbor, aux):
raise NotImplementedError("Length %d not suppported" % aux)
def _cbor_decode(cbor):
def _cbor_decode(cbor: bytes) -> Tuple[Value, bytes]:
fb = cbor[0]
data = b""
fb_type = fb & _CBOR_TYPE_MASK
@ -158,7 +162,7 @@ def _cbor_decode(cbor):
res.append(item)
return (res, data)
elif fb_type == _CBOR_MAP:
res = {}
res = {} # type: Dict[Value, Value]
if fb_aux == _CBOR_VAR_FOLLOWS:
data = cbor[1:]
while True:
@ -201,36 +205,41 @@ def _cbor_decode(cbor):
class Tagged:
def __init__(self, tag, value):
def __init__(self, tag: int, value: Value) -> None:
self.tag = tag
self.value = value
def __eq__(self, other):
return self.tag == other.tag and self.value == other.value
def __eq__(self, other: object) -> bool:
return (
isinstance(other, Tagged)
and self.tag == other.tag
and self.value == other.value
)
class Raw:
def __init__(self, value):
def __init__(self, value: Value):
self.value = value
class IndefiniteLengthArray:
def __init__(self, array):
ensure(isinstance(array, list))
def __init__(self, array: List[Value]) -> None:
self.array = array
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if isinstance(other, IndefiniteLengthArray):
return self.array == other.array
else:
elif isinstance(other, list):
return self.array == other
else:
return False
def encode(value):
def encode(value: Value) -> bytes:
return b"".join(_cbor_encode(value))
def decode(cbor: bytes):
def decode(cbor: bytes) -> Value:
res, check = _cbor_decode(cbor)
if not (check == b""):
raise ValueError

@ -1,23 +1,30 @@
from trezor import wire
from trezor.messages import ButtonRequestType, MessageType
from trezor.messages import ButtonRequestType
from trezor.messages.ButtonAck import ButtonAck
from trezor.messages.ButtonRequest import ButtonRequest
from trezor.ui.confirm import CONFIRMED, Confirm, HoldToConfirm
if __debug__:
from apps.debug import confirm_signal
if False:
from typing import Any
from trezor import ui
from trezor.ui.confirm import ButtonContent, ButtonStyleType
from trezor.ui.loader import LoaderStyleType
async def confirm(
ctx,
content,
code=ButtonRequestType.Other,
confirm=Confirm.DEFAULT_CONFIRM,
confirm_style=Confirm.DEFAULT_CONFIRM_STYLE,
cancel=Confirm.DEFAULT_CANCEL,
cancel_style=Confirm.DEFAULT_CANCEL_STYLE,
major_confirm=None,
):
await ctx.call(ButtonRequest(code=code), MessageType.ButtonAck)
ctx: wire.Context,
content: ui.Control,
code: int = ButtonRequestType.Other,
confirm: ButtonContent = Confirm.DEFAULT_CONFIRM,
confirm_style: ButtonStyleType = Confirm.DEFAULT_CONFIRM_STYLE,
cancel: ButtonContent = Confirm.DEFAULT_CANCEL,
cancel_style: ButtonStyleType = Confirm.DEFAULT_CANCEL_STYLE,
major_confirm: bool = False,
) -> bool:
await ctx.call(ButtonRequest(code=code), ButtonAck)
if content.__class__.__name__ == "Paginated":
content.pages[-1] = Confirm(
@ -41,14 +48,14 @@ async def confirm(
async def hold_to_confirm(
ctx,
content,
code=ButtonRequestType.Other,
confirm=HoldToConfirm.DEFAULT_CONFIRM,
confirm_style=HoldToConfirm.DEFAULT_CONFIRM_STYLE,
loader_style=HoldToConfirm.DEFAULT_LOADER_STYLE,
):
await ctx.call(ButtonRequest(code=code), MessageType.ButtonAck)
ctx: wire.Context,
content: ui.Control,
code: int = ButtonRequestType.Other,
confirm: ButtonContent = Confirm.DEFAULT_CONFIRM,
confirm_style: ButtonStyleType = Confirm.DEFAULT_CONFIRM_STYLE,
loader_style: LoaderStyleType = HoldToConfirm.DEFAULT_LOADER_STYLE,
) -> bool:
await ctx.call(ButtonRequest(code=code), ButtonAck)
if content.__class__.__name__ == "Paginated":
content.pages[-1] = HoldToConfirm(
@ -64,13 +71,13 @@ async def hold_to_confirm(
return await ctx.wait(dialog) is CONFIRMED
async def require_confirm(*args, **kwargs):
async def require_confirm(*args: Any, **kwargs: Any) -> None:
confirmed = await confirm(*args, **kwargs)
if not confirmed:
raise wire.ActionCancelled("Cancelled")
async def require_hold_to_confirm(*args, **kwargs):
async def require_hold_to_confirm(*args: Any, **kwargs: Any) -> None:
confirmed = await hold_to_confirm(*args, **kwargs)
if not confirmed:
raise wire.ActionCancelled("Cancelled")

@ -12,10 +12,14 @@ from trezor.utils import chunks
from apps.common import HARDENED
from apps.common.confirm import confirm, require_confirm
if False:
from typing import Iterable
from trezor import wire
async def show_address(
ctx, address: str, desc: str = "Confirm address", network: str = None
):
ctx: wire.Context, address: str, desc: str = "Confirm address", network: str = None
) -> bool:
text = Text(desc, ui.ICON_RECEIVE, ui.GREEN)
if network is not None:
text.normal("%s network" % network)
@ -30,7 +34,9 @@ async def show_address(
)
async def show_qr(ctx, address: str, desc: str = "Confirm address"):
async def show_qr(
ctx: wire.Context, address: str, desc: str = "Confirm address"
) -> bool:
QR_X = const(120)
QR_Y = const(115)
QR_COEF = const(4)
@ -47,19 +53,19 @@ async def show_qr(ctx, address: str, desc: str = "Confirm address"):
)
async def show_pubkey(ctx, pubkey: bytes):
async def show_pubkey(ctx: wire.Context, pubkey: bytes) -> None:
lines = chunks(hexlify(pubkey).decode(), 18)
text = Text("Confirm public key", ui.ICON_RECEIVE, ui.GREEN)
text.mono(*lines)
return await require_confirm(ctx, text, ButtonRequestType.PublicKey)
await require_confirm(ctx, text, ButtonRequestType.PublicKey)
def split_address(address: str):
def split_address(address: str) -> Iterable[str]:
return chunks(address, 17)
def address_n_to_str(address_n: list) -> str:
def path_item(i: int):
def path_item(i: int) -> str:
if i & HARDENED:
return str(i ^ HARDENED) + "'"
else:

@ -8,13 +8,16 @@ from trezor.utils import consteq
from apps.common import storage
from apps.common.mnemonic import bip39, slip39
if False:
from typing import Any, Tuple
TYPE_BIP39 = const(0)
TYPE_SLIP39 = const(1)
TYPES_WORD_COUNT = {12: bip39, 18: bip39, 24: bip39, 20: slip39, 33: slip39}
def get() -> (bytes, int):
def get() -> Tuple[bytes, int]:
mnemonic_secret = storage.device.get_mnemonic_secret()
mnemonic_type = storage.device.get_mnemonic_type() or TYPE_BIP39
if mnemonic_type not in (TYPE_BIP39, TYPE_SLIP39):
@ -22,15 +25,16 @@ def get() -> (bytes, int):
return mnemonic_secret, mnemonic_type
def get_seed(passphrase: str = "", progress_bar=True):
def get_seed(passphrase: str = "", progress_bar: bool = True) -> bytes:
mnemonic_secret, mnemonic_type = get()
if mnemonic_type == TYPE_BIP39:
return bip39.get_seed(mnemonic_secret, passphrase, progress_bar)
elif mnemonic_type == TYPE_SLIP39:
return slip39.get_seed(mnemonic_secret, passphrase, progress_bar)
raise ValueError("Unknown mnemonic type")
def dry_run(secret: bytes):
def dry_run(secret: bytes) -> None:
digest_input = sha256(secret).digest()
stored, _ = get()
digest_stored = sha256(stored).digest()
@ -42,11 +46,11 @@ def dry_run(secret: bytes):
)
def module_from_words_count(count: int):
def module_from_words_count(count: int) -> Any:
return TYPES_WORD_COUNT[count]
def _start_progress():
def _start_progress() -> None:
workflow.closedefault()
ui.backlight_fade(ui.BACKLIGHT_DIM)
ui.display.clear()
@ -55,11 +59,11 @@ def _start_progress():
ui.backlight_fade(ui.BACKLIGHT_NORMAL)
def _render_progress(progress: int, total: int):
def _render_progress(progress: int, total: int) -> None:
p = 1000 * progress // total
ui.display.loader(p, False, 18, ui.WHITE, ui.BG)
ui.display.refresh()
def _stop_progress():
def _stop_progress() -> None:
pass

@ -3,7 +3,7 @@ from trezor.crypto import bip39
from apps.common import mnemonic, storage
def get_type():
def get_type() -> int:
return mnemonic.TYPE_BIP39
@ -23,20 +23,19 @@ def process_all(mnemonics: list) -> bytes:
return mnemonics[0].encode()
def store(secret: bytes, needs_backup: bool, no_backup: bool):
def store(secret: bytes, needs_backup: bool, no_backup: bool) -> None:
storage.device.store_mnemonic_secret(
secret, mnemonic.TYPE_BIP39, needs_backup, no_backup
)
def get_seed(secret: bytes, passphrase: str, progress_bar=True):
def get_seed(secret: bytes, passphrase: str, progress_bar: bool = True) -> bytes:
if progress_bar:
mnemonic._start_progress()
seed = bip39.seed(secret.decode(), passphrase, mnemonic._render_progress)
mnemonic._stop_progress()
else:
seed = bip39.seed(secret.decode(), passphrase)
return seed
@ -44,5 +43,5 @@ def get_mnemonic_threshold(mnemonic: str) -> int:
return 1
def check(secret: bytes):
def check(secret: bytes) -> bool:
return bip39.check(secret)

@ -2,6 +2,9 @@ from trezor.crypto import slip39
from apps.common import mnemonic, storage
if False:
from typing import Optional
def generate_from_secret(master_secret: bytes, count: int, threshold: int) -> list:
"""
@ -12,11 +15,11 @@ def generate_from_secret(master_secret: bytes, count: int, threshold: int) -> li
)
def get_type():
def get_type() -> int:
return mnemonic.TYPE_SLIP39
def process_single(mnemonic: str) -> bytes:
def process_single(mnemonic: str) -> Optional[bytes]:
"""
Receives single mnemonic and processes it. Returns what is then stored in storage or
None if more shares are needed.
@ -72,14 +75,16 @@ def process_all(mnemonics: list) -> bytes:
return secret
def store(secret: bytes, needs_backup: bool, no_backup: bool):
def store(secret: bytes, needs_backup: bool, no_backup: bool) -> None:
storage.device.store_mnemonic_secret(
secret, mnemonic.TYPE_SLIP39, needs_backup, no_backup
)
storage.slip39.delete_progress()
def get_seed(encrypted_master_secret: bytes, passphrase: str, progress_bar=True):
def get_seed(
encrypted_master_secret: bytes, passphrase: str, progress_bar: bool = True
) -> bytes:
if progress_bar:
mnemonic._start_progress()
identifier = storage.slip39.get_identifier()

@ -7,20 +7,32 @@ from trezor.ui.text import Text
from apps.common import HARDENED
from apps.common.confirm import require_confirm
if False:
from typing import Any, Callable, List
from trezor import wire
from apps.common import seed
async def validate_path(ctx, validate_func, keychain, path, curve, **kwargs):
async def validate_path(
ctx: wire.Context,
validate_func: Callable[..., bool],
keychain: seed.Keychain,
path: List[int],
curve: str,
**kwargs: Any,
) -> None:
keychain.validate_path(path, curve)
if not validate_func(path, **kwargs):
await show_path_warning(ctx, path)
async def show_path_warning(ctx, path: list):
async def show_path_warning(ctx: wire.Context, path: List[int]) -> None:
text = Text("Confirm path", ui.ICON_WRONG, ui.RED)
text.normal("Path")
text.mono(*break_address_n_to_lines(path))
text.normal("is unknown.")
text.normal("Are you sure?")
return await require_confirm(ctx, text, ButtonRequestType.UnknownDerivationPath)
await require_confirm(ctx, text, ButtonRequestType.UnknownDerivationPath)
def validate_path_for_get_public_key(path: list, slip44_id: int) -> bool:
@ -53,7 +65,7 @@ def is_hardened(i: int) -> bool:
def break_address_n_to_lines(address_n: list) -> list:
def path_item(i: int):
def path_item(i: int) -> str:
if i & HARDENED:
return str(i ^ HARDENED) + "'"
else:

@ -1,9 +1,12 @@
from micropython import const
from trezor import ui, wire
from trezor.messages import ButtonRequestType, MessageType, PassphraseSourceType
from trezor.messages import ButtonRequestType, PassphraseSourceType
from trezor.messages.ButtonAck import ButtonAck
from trezor.messages.ButtonRequest import ButtonRequest
from trezor.messages.PassphraseAck import PassphraseAck
from trezor.messages.PassphraseRequest import PassphraseRequest
from trezor.messages.PassphraseStateAck import PassphraseStateAck
from trezor.messages.PassphraseStateRequest import PassphraseStateRequest
from trezor.ui.passphrase import CANCELLED, PassphraseKeyboard, PassphraseSource
from trezor.ui.popup import Popup
@ -17,14 +20,14 @@ if __debug__:
_MAX_PASSPHRASE_LEN = const(50)
async def protect_by_passphrase(ctx) -> str:
async def protect_by_passphrase(ctx: wire.Context) -> str:
if storage.device.has_passphrase():
return await request_passphrase(ctx)
else:
return ""
async def request_passphrase(ctx) -> str:
async def request_passphrase(ctx: wire.Context) -> str:
source = storage.device.get_passphrase_source()
if source == PassphraseSourceType.ASK:
source = await request_passphrase_source(ctx)
@ -36,9 +39,9 @@ async def request_passphrase(ctx) -> str:
return passphrase
async def request_passphrase_source(ctx) -> int:
async def request_passphrase_source(ctx: wire.Context) -> int:
req = ButtonRequest(code=ButtonRequestType.PassphraseType)
await ctx.call(req, MessageType.ButtonAck)
await ctx.call(req, ButtonAck)
text = Text("Enter passphrase", ui.ICON_CONFIG)
text.normal("Where to enter your", "passphrase?")
@ -47,14 +50,14 @@ async def request_passphrase_source(ctx) -> int:
return await ctx.wait(source)
async def request_passphrase_ack(ctx, on_device: bool) -> str:
async def request_passphrase_ack(ctx: wire.Context, on_device: bool) -> str:
if not on_device:
text = Text("Passphrase entry", ui.ICON_CONFIG)
text.normal("Please, type passphrase", "on connected host.")
await Popup(text)
req = PassphraseRequest(on_device=on_device)
ack = await ctx.call(req, MessageType.PassphraseAck)
ack = await ctx.call(req, PassphraseAck)
if on_device:
if ack.passphrase is not None:
@ -74,6 +77,6 @@ async def request_passphrase_ack(ctx, on_device: bool) -> str:
state = cache.get_state(prev_state=ack.state, passphrase=passphrase)
req = PassphraseStateRequest(state=state)
ack = await ctx.call(req, MessageType.PassphraseStateAck, MessageType.Cancel)
ack = await ctx.call(req, PassphraseStateAck)
return passphrase

@ -4,7 +4,8 @@ from trezor.crypto import bip32
from apps.common import HARDENED, cache, mnemonic, storage
from apps.common.request_passphrase import protect_by_passphrase
allow = list
if False:
from typing import List, Optional
class Keychain:
@ -16,16 +17,16 @@ class Keychain:
def __init__(self, seed: bytes, namespaces: list):
self.seed = seed
self.namespaces = namespaces
self.roots = [None] * len(namespaces)
self.roots = [None] * len(namespaces) # type: List[Optional[bip32.HDNode]]
def __del__(self):
def __del__(self) -> None:
for root in self.roots:
if root is not None:
root.__del__()
del self.roots
del self.seed
def validate_path(self, checked_path: list, checked_curve: str):
def validate_path(self, checked_path: list, checked_curve: str) -> None:
for curve, *path in self.namespaces:
if path == checked_path[: len(path)] and curve == checked_curve:
if "ed25519" in curve and not _path_hardened(checked_path):

@ -5,8 +5,12 @@ from trezor.utils import HashWriter
from apps.wallet.sign_tx.writers import write_varint
if False:
from typing import List
from apps.common.coininfo import CoinType
def message_digest(coin, message):
def message_digest(coin: CoinType, message: bytes) -> bytes:
if coin.decred:
h = HashWriter(blake256())
else:
@ -21,7 +25,7 @@ def message_digest(coin, message):
return ret
def split_message(message):
def split_message(message: bytes) -> List[str]:
try:
m = bytes(message).decode()
words = m.split(" ")

@ -1,5 +1,8 @@
from trezor.utils import ensure
if False:
from trezor.utils import Writer
def empty_bytearray(preallocate: int) -> bytearray:
"""
@ -11,27 +14,27 @@ def empty_bytearray(preallocate: int) -> bytearray:
return b
def write_uint8(w: bytearray, n: int) -> int:
def write_uint8(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFF)
w.append(n)
return 1
def write_uint16_le(w: bytearray, n: int) -> int:
def write_uint16_le(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFF)
w.append(n & 0xFF)
w.append((n >> 8) & 0xFF)
return 2
def write_uint16_be(w: bytearray, n: int):
def write_uint16_be(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFF)
w.append((n >> 8) & 0xFF)
w.append(n & 0xFF)
return 2
def write_uint32_le(w: bytearray, n: int) -> int:
def write_uint32_le(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFFFFFF)
w.append(n & 0xFF)
w.append((n >> 8) & 0xFF)
@ -40,7 +43,7 @@ def write_uint32_le(w: bytearray, n: int) -> int:
return 4
def write_uint32_be(w: bytearray, n: int) -> int:
def write_uint32_be(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFFFFFF)
w.append((n >> 24) & 0xFF)
w.append((n >> 16) & 0xFF)
@ -49,7 +52,7 @@ def write_uint32_be(w: bytearray, n: int) -> int:
return 4
def write_uint64_le(w: bytearray, n: int) -> int:
def write_uint64_le(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFFFFFFFFFFFFFF)
w.append(n & 0xFF)
w.append((n >> 8) & 0xFF)
@ -62,7 +65,7 @@ def write_uint64_le(w: bytearray, n: int) -> int:
return 8
def write_uint64_be(w: bytearray, n: int) -> int:
def write_uint64_be(w: Writer, n: int) -> int:
ensure(0 <= n <= 0xFFFFFFFFFFFFFFFF)
w.append((n >> 56) & 0xFF)
w.append((n >> 48) & 0xFF)
@ -75,11 +78,11 @@ def write_uint64_be(w: bytearray, n: int) -> int:
return 8
def write_bytes(w: bytearray, b: bytes) -> int:
def write_bytes(w: Writer, b: bytes) -> int:
w.extend(b)
return len(b)
def write_bytes_reversed(w: bytearray, b: bytes) -> int:
def write_bytes_reversed(w: Writer, b: bytes) -> int:
w.extend(bytes(reversed(b)))
return len(b)

@ -8,15 +8,24 @@ if __debug__:
from trezor.messages import MessageType
from trezor.wire import register, protobuf_workflow
reset_internal_entropy = None
reset_current_words = None
reset_word_index = None
if False:
from typing import List, Optional
from trezor import wire
from trezor.messages.DebugLinkDecision import DebugLinkDecision
from trezor.messages.DebugLinkGetState import DebugLinkGetState
from trezor.messages.DebugLinkState import DebugLinkState
reset_internal_entropy = None # type: Optional[bytes]
reset_current_words = None # type: Optional[List[str]]
reset_word_index = None # type: Optional[int]
confirm_signal = loop.signal()
swipe_signal = loop.signal()
input_signal = loop.signal()
async def dispatch_DebugLinkDecision(ctx, msg):
async def dispatch_DebugLinkDecision(
ctx: wire.Context, msg: DebugLinkDecision
) -> None:
from trezor.ui import confirm, swipe
if msg.yes_no is not None:
@ -26,7 +35,9 @@ if __debug__:
if msg.input is not None:
input_signal.send(msg.input)
async def dispatch_DebugLinkGetState(ctx, msg):
async def dispatch_DebugLinkGetState(
ctx: wire.Context, msg: DebugLinkGetState
) -> DebugLinkState:
from trezor.messages.DebugLinkState import DebugLinkState
from apps.common import storage, mnemonic
@ -39,7 +50,7 @@ if __debug__:
m.reset_word = " ".join(reset_current_words)
return m
def boot():
def boot() -> None:
# wipe storage when debug build is used on real hardware
if not utils.EMULATOR:
config.wipe()

@ -6,7 +6,7 @@ from apps.common import HARDENED
CURVE = "secp256k1"
def boot():
def boot() -> None:
ns = [[CURVE, HARDENED | 44, HARDENED | 194]]
wire.add(MessageType.EosGetPublicKey, __name__, "get_public_key", ns)

@ -1,13 +1,19 @@
from trezor.crypto.hashlib import sha256
from trezor.messages.EosTxActionAck import EosTxActionAck
from trezor.messages.EosTxActionRequest import EosTxActionRequest
from trezor.messages.MessageType import EosTxActionAck
from trezor.utils import HashWriter
from apps.eos import helpers, writers
from apps.eos.actions import layout
if False:
from trezor import wire
from trezor.utils import Writer
async def process_action(ctx, sha, action):
async def process_action(
ctx: wire.Context, sha: HashWriter, action: EosTxActionAck
) -> None:
name = helpers.eos_name_to_string(action.common.name)
account = helpers.eos_name_to_string(action.common.account)
@ -65,7 +71,9 @@ async def process_action(ctx, sha, action):
writers.write_bytes(sha, w)
async def process_unknown_action(ctx, w, action):
async def process_unknown_action(
ctx: wire.Context, w: Writer, action: EosTxActionAck
) -> None:
checksum = HashWriter(sha256())
writers.write_variant32(checksum, action.unknown.data_size)
checksum.extend(action.unknown.data_chunk)
@ -91,7 +99,7 @@ async def process_unknown_action(ctx, w, action):
await layout.confirm_action_unknown(ctx, action.common, checksum.get_digest())
def check_action(action, name, account):
def check_action(action: EosTxActionAck, name: str, account: str) -> bool:
if account == "eosio":
if (
(name == "buyram" and action.buy_ram is not None)

@ -2,22 +2,7 @@ from micropython import const
from ubinascii import hexlify
from trezor import ui
from trezor.messages import (
ButtonRequestType,
EosActionBuyRam,
EosActionBuyRamBytes,
EosActionDelegate,
EosActionDeleteAuth,
EosActionLinkAuth,
EosActionNewAccount,
EosActionRefund,
EosActionSellRam,
EosActionTransfer,
EosActionUndelegate,
EosActionUnlinkAuth,
EosActionUpdateAuth,
EosActionVoteProducer,
)
from trezor.messages import ButtonRequestType
from trezor.ui.scroll import Paginated
from trezor.ui.text import Text
from trezor.utils import chunks
@ -26,6 +11,25 @@ from apps.eos import helpers
from apps.eos.get_public_key import _public_key_to_wif
from apps.eos.layout import require_confirm
if False:
from typing import List
from trezor import wire
from trezor.messages.EosAuthorization import EosAuthorization
from trezor.messages.EosActionBuyRam import EosActionBuyRam
from trezor.messages.EosActionBuyRamBytes import EosActionBuyRamBytes
from trezor.messages.EosActionCommon import EosActionCommon
from trezor.messages.EosActionDelegate import EosActionDelegate
from trezor.messages.EosActionDeleteAuth import EosActionDeleteAuth
from trezor.messages.EosActionLinkAuth import EosActionLinkAuth
from trezor.messages.EosActionNewAccount import EosActionNewAccount
from trezor.messages.EosActionRefund import EosActionRefund
from trezor.messages.EosActionSellRam import EosActionSellRam
from trezor.messages.EosActionTransfer import EosActionTransfer
from trezor.messages.EosActionUndelegate import EosActionUndelegate
from trezor.messages.EosActionUnlinkAuth import EosActionUnlinkAuth
from trezor.messages.EosActionUpdateAuth import EosActionUpdateAuth
from trezor.messages.EosActionVoteProducer import EosActionVoteProducer
_LINE_LENGTH = const(17)
_LINE_PLACEHOLDER = "{:<" + str(_LINE_LENGTH) + "}"
_FIRST_PAGE = const(0)
@ -35,7 +39,9 @@ _FOUR_FIELDS_PER_PAGE = const(4)
_FIVE_FIELDS_PER_PAGE = const(5)
async def _require_confirm_paginated(ctx, header, fields, per_page):
async def _require_confirm_paginated(
ctx: wire.Context, header: str, fields: List[str], per_page: int
) -> None:
pages = []
for page in chunks(fields, per_page):
if header == "Arbitrary data":
@ -47,7 +53,7 @@ async def _require_confirm_paginated(ctx, header, fields, per_page):
await require_confirm(ctx, Paginated(pages), ButtonRequestType.ConfirmOutput)
async def confirm_action_buyram(ctx, msg: EosActionBuyRam):
async def confirm_action_buyram(ctx: wire.Context, msg: EosActionBuyRam) -> None:
text = "Buy RAM"
fields = []
fields.append("Payer:")
@ -59,7 +65,9 @@ async def confirm_action_buyram(ctx, msg: EosActionBuyRam):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_buyrambytes(ctx, msg: EosActionBuyRamBytes):
async def confirm_action_buyrambytes(
ctx: wire.Context, msg: EosActionBuyRamBytes
) -> None:
text = "Buy RAM"
fields = []
fields.append("Payer:")
@ -71,7 +79,7 @@ async def confirm_action_buyrambytes(ctx, msg: EosActionBuyRamBytes):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_delegate(ctx, msg: EosActionDelegate):
async def confirm_action_delegate(ctx: wire.Context, msg: EosActionDelegate) -> None:
text = "Delegate"
fields = []
fields.append("Sender:")
@ -93,7 +101,7 @@ async def confirm_action_delegate(ctx, msg: EosActionDelegate):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_sellram(ctx, msg: EosActionSellRam):
async def confirm_action_sellram(ctx: wire.Context, msg: EosActionSellRam) -> None:
text = "Sell RAM"
fields = []
fields.append("Receiver:")
@ -103,7 +111,9 @@ async def confirm_action_sellram(ctx, msg: EosActionSellRam):
await _require_confirm_paginated(ctx, text, fields, _TWO_FIELDS_PER_PAGE)
async def confirm_action_undelegate(ctx, msg: EosActionUndelegate):
async def confirm_action_undelegate(
ctx: wire.Context, msg: EosActionUndelegate
) -> None:
text = "Undelegate"
fields = []
fields.append("Sender:")
@ -117,14 +127,16 @@ async def confirm_action_undelegate(ctx, msg: EosActionUndelegate):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_refund(ctx, msg: EosActionRefund):
async def confirm_action_refund(ctx: wire.Context, msg: EosActionRefund) -> None:
text = Text("Refund", ui.ICON_CONFIRM, icon_color=ui.GREEN)
text.normal("Owner:")
text.normal(helpers.eos_name_to_string(msg.owner))
await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput)
async def confirm_action_voteproducer(ctx, msg: EosActionVoteProducer):
async def confirm_action_voteproducer(
ctx: wire.Context, msg: EosActionVoteProducer
) -> None:
if msg.proxy and not msg.producers:
# PROXY
text = Text("Vote for proxy", ui.ICON_CONFIRM, icon_color=ui.GREEN)
@ -151,7 +163,9 @@ async def confirm_action_voteproducer(ctx, msg: EosActionVoteProducer):
await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput)
async def confirm_action_transfer(ctx, msg: EosActionTransfer, account: str):
async def confirm_action_transfer(
ctx: wire.Context, msg: EosActionTransfer, account: str
) -> None:
text = "Transfer"
fields = []
fields.append("From:")
@ -170,7 +184,9 @@ async def confirm_action_transfer(ctx, msg: EosActionTransfer, account: str):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_updateauth(ctx, msg: EosActionUpdateAuth):
async def confirm_action_updateauth(
ctx: wire.Context, msg: EosActionUpdateAuth
) -> None:
text = "Update Auth"
fields = []
fields.append("Account:")
@ -183,7 +199,9 @@ async def confirm_action_updateauth(ctx, msg: EosActionUpdateAuth):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_deleteauth(ctx, msg: EosActionDeleteAuth):
async def confirm_action_deleteauth(
ctx: wire.Context, msg: EosActionDeleteAuth
) -> None:
text = Text("Delete auth", ui.ICON_CONFIRM, icon_color=ui.GREEN)
text.normal("Account:")
text.normal(helpers.eos_name_to_string(msg.account))
@ -192,7 +210,7 @@ async def confirm_action_deleteauth(ctx, msg: EosActionDeleteAuth):
await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput)
async def confirm_action_linkauth(ctx, msg: EosActionLinkAuth):
async def confirm_action_linkauth(ctx: wire.Context, msg: EosActionLinkAuth) -> None:
text = "Link Auth"
fields = []
fields.append("Account:")
@ -206,7 +224,9 @@ async def confirm_action_linkauth(ctx, msg: EosActionLinkAuth):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_unlinkauth(ctx, msg: EosActionUnlinkAuth):
async def confirm_action_unlinkauth(
ctx: wire.Context, msg: EosActionUnlinkAuth
) -> None:
text = "Unlink Auth"
fields = []
fields.append("Account:")
@ -218,7 +238,9 @@ async def confirm_action_unlinkauth(ctx, msg: EosActionUnlinkAuth):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_newaccount(ctx, msg: EosActionNewAccount):
async def confirm_action_newaccount(
ctx: wire.Context, msg: EosActionNewAccount
) -> None:
text = "New Account"
fields = []
fields.append("Creator:")
@ -230,7 +252,9 @@ async def confirm_action_newaccount(ctx, msg: EosActionNewAccount):
await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE)
async def confirm_action_unknown(ctx, action, checksum):
async def confirm_action_unknown(
ctx: wire.Context, action: EosActionCommon, checksum: bytes
) -> None:
text = "Arbitrary data"
fields = []
fields.append("Contract:")
@ -242,7 +266,7 @@ async def confirm_action_unknown(ctx, action, checksum):
await _require_confirm_paginated(ctx, text, fields, _FIVE_FIELDS_PER_PAGE)
def authorization_fields(auth):
def authorization_fields(auth: EosAuthorization) -> List[str]:
fields = []
fields.append("Threshold:")
@ -288,11 +312,9 @@ def authorization_fields(auth):
return fields
def split_data(data):
temp_list = []
len_left = len(data)
while len_left > 0:
temp_list.append("{} ".format(data[:_LINE_LENGTH]))
def split_data(data: str) -> List[str]:
lines = []
while data:
lines.append("{} ".format(data[:_LINE_LENGTH]))
data = data[_LINE_LENGTH:]
len_left = len(data)
return temp_list
return lines

@ -8,6 +8,11 @@ from apps.eos import CURVE
from apps.eos.helpers import base58_encode, validate_full_path
from apps.eos.layout import require_get_public_key
if False:
from typing import Tuple
from trezor.crypto import bip32
from apps.common import seed
def _public_key_to_wif(pub_key: bytes) -> str:
if pub_key[0] == 0x04 and len(pub_key) == 65:
@ -20,14 +25,16 @@ def _public_key_to_wif(pub_key: bytes) -> str:
return base58_encode("EOS", "", compressed_pub_key)
def _get_public_key(node):
def _get_public_key(node: bip32.HDNode) -> Tuple[str, bytes]:
seckey = node.private_key()
public_key = secp256k1.publickey(seckey, True)
wif = _public_key_to_wif(public_key)
return wif, public_key
async def get_public_key(ctx, msg: EosGetPublicKey, keychain):
async def get_public_key(
ctx: wire.Context, msg: EosGetPublicKey, keychain: seed.Keychain
) -> EosPublicKey:
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE)
node = keychain.derive(msg.address_n)

@ -1,5 +1,5 @@
from trezor.crypto import base58
from trezor.messages import EosAsset
from trezor.messages.EosAsset import EosAsset
from apps.common import HARDENED
@ -12,7 +12,7 @@ def base58_encode(prefix: str, sig_prefix: str, data: bytes) -> str:
return prefix + b58
def eos_name_to_string(value) -> str:
def eos_name_to_string(value: int) -> str:
charmap = ".12345abcdefghijklmnopqrstuvwxyz"
tmp = value
string = ""

@ -1,19 +1,19 @@
from trezor import ui
from trezor import ui, wire
from trezor.messages import ButtonRequestType
from trezor.ui.text import Text
from apps.common.confirm import require_confirm
async def require_get_public_key(ctx, public_key):
async def require_get_public_key(ctx: wire.Context, public_key: str) -> None:
text = Text("Confirm public key", ui.ICON_RECEIVE, ui.GREEN)
text.normal(public_key)
return await require_confirm(ctx, text, ButtonRequestType.PublicKey)
await require_confirm(ctx, text, ButtonRequestType.PublicKey)
async def require_sign_tx(ctx, num_actions):
async def require_sign_tx(ctx: wire.Context, num_actions: int) -> None:
text = Text("Sign transaction", ui.ICON_SEND, ui.GREEN)
text.normal("You are about")
text.normal("to sign {}".format(num_actions))
text.normal("action(s).")
return await require_confirm(ctx, text, ButtonRequestType.SignTx)
await require_confirm(ctx, text, ButtonRequestType.SignTx)

@ -3,8 +3,8 @@ from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha256
from trezor.messages.EosSignedTx import EosSignedTx
from trezor.messages.EosSignTx import EosSignTx
from trezor.messages.EosTxActionAck import EosTxActionAck
from trezor.messages.EosTxActionRequest import EosTxActionRequest
from trezor.messages.MessageType import EosTxActionAck
from trezor.utils import HashWriter
from apps.common import paths
@ -13,8 +13,13 @@ from apps.eos.actions import process_action
from apps.eos.helpers import base58_encode, validate_full_path
from apps.eos.layout import require_sign_tx
if False:
from apps.common import seed
async def sign_tx(ctx, msg: EosSignTx, keychain):
async def sign_tx(
ctx: wire.Context, msg: EosSignTx, keychain: seed.Keychain
) -> EosSignedTx:
if msg.chain_id is None:
raise wire.DataError("No chain id")
if msg.header is None:
@ -39,7 +44,7 @@ async def sign_tx(ctx, msg: EosSignTx, keychain):
return EosSignedTx(signature=base58_encode("SIG_", "K1", signature))
async def _init(ctx, sha, msg):
async def _init(ctx: wire.Context, sha: HashWriter, msg: EosSignTx) -> None:
writers.write_bytes(sha, msg.chain_id)
writers.write_header(sha, msg.header)
writers.write_variant32(sha, 0)
@ -48,7 +53,7 @@ async def _init(ctx, sha, msg):
await require_sign_tx(ctx, msg.num_actions)
async def _actions(ctx, sha, num_actions: int):
async def _actions(ctx: wire.Context, sha: HashWriter, num_actions: int) -> None:
for i in range(num_actions):
action = await ctx.call(EosTxActionRequest(), EosTxActionAck)
await process_action(ctx, sha, action)

@ -1,23 +1,3 @@
from trezor.messages import (
EosActionBuyRam,
EosActionBuyRamBytes,
EosActionCommon,
EosActionDelegate,
EosActionDeleteAuth,
EosActionLinkAuth,
EosActionNewAccount,
EosActionRefund,
EosActionSellRam,
EosActionTransfer,
EosActionUndelegate,
EosActionUpdateAuth,
EosActionVoteProducer,
EosAsset,
EosAuthorization,
EosTxHeader,
)
from trezor.utils import HashWriter
from apps.common.writers import (
write_bytes,
write_uint8,
@ -26,8 +6,27 @@ from apps.common.writers import (
write_uint64_le,
)
def write_auth(w: bytearray, auth: EosAuthorization) -> int:
if False:
from trezor.messages.EosActionBuyRam import EosActionBuyRam
from trezor.messages.EosActionBuyRamBytes import EosActionBuyRamBytes
from trezor.messages.EosActionCommon import EosActionCommon
from trezor.messages.EosActionDelegate import EosActionDelegate
from trezor.messages.EosActionDeleteAuth import EosActionDeleteAuth
from trezor.messages.EosActionLinkAuth import EosActionLinkAuth
from trezor.messages.EosActionNewAccount import EosActionNewAccount
from trezor.messages.EosActionRefund import EosActionRefund
from trezor.messages.EosActionSellRam import EosActionSellRam
from trezor.messages.EosActionTransfer import EosActionTransfer
from trezor.messages.EosActionUndelegate import EosActionUndelegate
from trezor.messages.EosActionUpdateAuth import EosActionUpdateAuth
from trezor.messages.EosActionVoteProducer import EosActionVoteProducer
from trezor.messages.EosAsset import EosAsset
from trezor.messages.EosAuthorization import EosAuthorization
from trezor.messages.EosTxHeader import EosTxHeader
from trezor.utils import Writer
def write_auth(w: Writer, auth: EosAuthorization) -> None:
write_uint32_le(w, auth.threshold)
write_variant32(w, len(auth.keys))
for key in auth.keys:
@ -47,7 +46,7 @@ def write_auth(w: bytearray, auth: EosAuthorization) -> int:
write_uint16_le(w, wait.weight)
def write_header(hasher: HashWriter, header: EosTxHeader):
def write_header(hasher: Writer, header: EosTxHeader) -> None:
write_uint32_le(hasher, header.expiration)
write_uint16_le(hasher, header.ref_block_num)
write_uint32_le(hasher, header.ref_block_prefix)
@ -56,7 +55,7 @@ def write_header(hasher: HashWriter, header: EosTxHeader):
write_variant32(hasher, header.delay_sec)
def write_action_transfer(w: bytearray, msg: EosActionTransfer):
def write_action_transfer(w: Writer, msg: EosActionTransfer) -> None:
write_uint64_le(w, msg.sender)
write_uint64_le(w, msg.receiver)
write_asset(w, msg.quantity)
@ -64,24 +63,24 @@ def write_action_transfer(w: bytearray, msg: EosActionTransfer):
write_bytes(w, msg.memo)
def write_action_buyram(w: bytearray, msg: EosActionBuyRam):
def write_action_buyram(w: Writer, msg: EosActionBuyRam) -> None:
write_uint64_le(w, msg.payer)
write_uint64_le(w, msg.receiver)
write_asset(w, msg.quantity)
def write_action_buyrambytes(w: bytearray, msg: EosActionBuyRamBytes):
def write_action_buyrambytes(w: Writer, msg: EosActionBuyRamBytes) -> None:
write_uint64_le(w, msg.payer)
write_uint64_le(w, msg.receiver)
write_uint32_le(w, msg.bytes)
def write_action_sellram(w: bytearray, msg: EosActionSellRam):
def write_action_sellram(w: Writer, msg: EosActionSellRam) -> None:
write_uint64_le(w, msg.account)
write_uint64_le(w, msg.bytes)
def write_action_delegate(w: bytearray, msg: EosActionDelegate):
def write_action_delegate(w: Writer, msg: EosActionDelegate) -> None:
write_uint64_le(w, msg.sender)
write_uint64_le(w, msg.receiver)
write_asset(w, msg.net_quantity)
@ -89,18 +88,18 @@ def write_action_delegate(w: bytearray, msg: EosActionDelegate):
write_uint8(w, 1 if msg.transfer else 0)
def write_action_undelegate(w: bytearray, msg: EosActionUndelegate):
def write_action_undelegate(w: Writer, msg: EosActionUndelegate) -> None:
write_uint64_le(w, msg.sender)
write_uint64_le(w, msg.receiver)
write_asset(w, msg.net_quantity)
write_asset(w, msg.cpu_quantity)
def write_action_refund(w: bytearray, msg: EosActionRefund):
def write_action_refund(w: Writer, msg: EosActionRefund) -> None:
write_uint64_le(w, msg.owner)
def write_action_voteproducer(w: bytearray, msg: EosActionVoteProducer):
def write_action_voteproducer(w: Writer, msg: EosActionVoteProducer) -> None:
write_uint64_le(w, msg.voter)
write_uint64_le(w, msg.proxy)
write_variant32(w, len(msg.producers))
@ -108,61 +107,59 @@ def write_action_voteproducer(w: bytearray, msg: EosActionVoteProducer):
write_uint64_le(w, producer)
def write_action_updateauth(w: bytearray, msg: EosActionUpdateAuth):
def write_action_updateauth(w: Writer, msg: EosActionUpdateAuth) -> None:
write_uint64_le(w, msg.account)
write_uint64_le(w, msg.permission)
write_uint64_le(w, msg.parent)
write_auth(w, msg.auth)
def write_action_deleteauth(w: bytearray, msg: EosActionDeleteAuth):
def write_action_deleteauth(w: Writer, msg: EosActionDeleteAuth) -> None:
write_uint64_le(w, msg.account)
write_uint64_le(w, msg.permission)
def write_action_linkauth(w: bytearray, msg: EosActionLinkAuth):
def write_action_linkauth(w: Writer, msg: EosActionLinkAuth) -> None:
write_uint64_le(w, msg.account)
write_uint64_le(w, msg.code)
write_uint64_le(w, msg.type)
write_uint64_le(w, msg.requirement)
def write_action_unlinkauth(w: bytearray, msg: EosActionLinkAuth):
def write_action_unlinkauth(w: Writer, msg: EosActionLinkAuth) -> None:
write_uint64_le(w, msg.account)
write_uint64_le(w, msg.code)
write_uint64_le(w, msg.type)
def write_action_newaccount(w: bytearray, msg: EosActionNewAccount):
def write_action_newaccount(w: Writer, msg: EosActionNewAccount) -> None:
write_uint64_le(w, msg.creator)
write_uint64_le(w, msg.name)
write_auth(w, msg.owner)
write_auth(w, msg.active)
def write_action_common(hasher: HashWriter, msg: EosActionCommon):
write_uint64_le(hasher, msg.account)
write_uint64_le(hasher, msg.name)
write_variant32(hasher, len(msg.authorization))
def write_action_common(w: Writer, msg: EosActionCommon) -> None:
write_uint64_le(w, msg.account)
write_uint64_le(w, msg.name)
write_variant32(w, len(msg.authorization))
for authorization in msg.authorization:
write_uint64_le(hasher, authorization.actor)
write_uint64_le(hasher, authorization.permission)
write_uint64_le(w, authorization.actor)
write_uint64_le(w, authorization.permission)
def write_asset(w: bytearray, asset: EosAsset) -> int:
def write_asset(w: Writer, asset: EosAsset) -> None:
write_uint64_le(w, asset.amount)
write_uint64_le(w, asset.symbol)
def write_variant32(w: bytearray, value: int) -> int:
def write_variant32(w: Writer, value: int) -> None:
variant = bytearray()
while True:
b = value & 0x7F
value >>= 7
b |= (value > 0) << 7
variant.append(b)
if value == 0:
break
write_bytes(w, bytes(variant))

@ -7,7 +7,7 @@ from apps.ethereum.networks import all_slip44_ids_hardened
CURVE = "secp256k1"
def boot():
def boot() -> None:
ns = []
for i in all_slip44_ids_hardened():
ns.append([CURVE, HARDENED | 44, i])

@ -3,8 +3,8 @@ from trezor.crypto import rlp
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
from trezor.messages.EthereumSignTx import EthereumSignTx
from trezor.messages.EthereumTxAck import EthereumTxAck
from trezor.messages.EthereumTxRequest import EthereumTxRequest
from trezor.messages.MessageType import EthereumTxAck
from trezor.utils import HashWriter
from apps.common import paths

@ -6,8 +6,16 @@ from trezor.wire import protobuf_workflow, register
from apps.common import cache, storage
if False:
from typing import NoReturn
from trezor.messages.Initialize import Initialize
from trezor.messages.GetFeatures import GetFeatures
from trezor.messages.Cancel import Cancel
from trezor.messages.ClearSession import ClearSession
from trezor.messages.Ping import Ping
def get_features():
def get_features() -> Features:
f = Features()
f.vendor = "trezor.io"
f.language = "english"
@ -30,7 +38,7 @@ def get_features():
return f
async def handle_Initialize(ctx, msg):
async def handle_Initialize(ctx: wire.Context, msg: Initialize) -> Features:
if msg.state is None or msg.state != cache.get_state(prev_state=bytes(msg.state)):
cache.clear()
if msg.skip_passphrase:
@ -38,20 +46,20 @@ async def handle_Initialize(ctx, msg):
return get_features()
async def handle_GetFeatures(ctx, msg):
async def handle_GetFeatures(ctx: wire.Context, msg: GetFeatures) -> Features:
return get_features()
async def handle_Cancel(ctx, msg):
async def handle_Cancel(ctx: wire.Context, msg: Cancel) -> NoReturn:
raise wire.ActionCancelled("Cancelled")
async def handle_ClearSession(ctx, msg):
async def handle_ClearSession(ctx: wire.Context, msg: ClearSession) -> Success:
cache.clear(keep_passphrase=True)
return Success(message="Session cleared")
async def handle_Ping(ctx, msg):
async def handle_Ping(ctx: wire.Context, msg: Ping) -> Success:
if msg.button_protection:
from apps.common.confirm import require_confirm
from trezor.messages.ButtonRequestType import ProtectCall
@ -65,7 +73,7 @@ async def handle_Ping(ctx, msg):
return Success(message=msg.message)
def boot():
def boot() -> None:
register(MessageType.Initialize, protobuf_workflow, handle_Initialize)
register(MessageType.GetFeatures, protobuf_workflow, handle_GetFeatures)
register(MessageType.Cancel, protobuf_workflow, handle_Cancel)

@ -3,7 +3,7 @@ from trezor import config, io, loop, res, ui
from apps.common import storage
async def homescreen():
async def homescreen() -> None:
# render homescreen in dimmed mode and fade back in
ui.backlight_fade(ui.BACKLIGHT_DIM)
display_homescreen()
@ -15,7 +15,7 @@ async def homescreen():
await touch
def display_homescreen():
def display_homescreen() -> None:
image = None
if storage.slip39.is_in_progress():
label = "Waiting for other shares"
@ -44,13 +44,13 @@ def display_homescreen():
ui.display.text_center(ui.WIDTH // 2, 220, label, ui.BOLD, ui.FG, ui.BG)
def _warn(message: str):
def _warn(message: str) -> None:
ui.display.bar(0, 0, ui.WIDTH, 30, ui.YELLOW)
ui.display.text_center(ui.WIDTH // 2, 22, message, ui.BOLD, ui.BLACK, ui.YELLOW)
ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG)
def _err(message: str):
def _err(message: str) -> None:
ui.display.bar(0, 0, ui.WIDTH, 30, ui.RED)
ui.display.text_center(ui.WIDTH // 2, 22, message, ui.BOLD, ui.WHITE, ui.RED)
ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG)

@ -6,7 +6,7 @@ from apps.common import HARDENED
CURVE = "ed25519"
def boot():
def boot() -> None:
ns = [[CURVE, HARDENED | 44, HARDENED | 134]]
wire.add(MessageType.LiskGetPublicKey, __name__, "get_public_key", ns)
wire.add(MessageType.LiskGetAddress, __name__, "get_address", ns)

@ -2,7 +2,7 @@ from trezor import wire
from trezor.messages import MessageType
def boot():
def boot() -> None:
# only enable LoadDevice in debug builds
if __debug__:
wire.add(MessageType.LoadDevice, __name__, "load_device")

@ -1,5 +1,6 @@
from trezor import config, ui, wire
from trezor.messages import ButtonRequestType, MessageType
from trezor.messages import ButtonRequestType
from trezor.messages.ButtonAck import ButtonAck
from trezor.messages.ButtonRequest import ButtonRequest
from trezor.messages.Success import Success
from trezor.pin import pin_to_int
@ -74,9 +75,7 @@ async def request_pin_confirm(ctx, *args, **kwargs):
async def request_pin_ack(ctx, *args, **kwargs):
try:
await ctx.call(
ButtonRequest(code=ButtonRequestType.Other), MessageType.ButtonAck
)
await ctx.call(ButtonRequest(code=ButtonRequestType.Other), ButtonAck)
return await ctx.wait(request_pin(*args, **kwargs))
except PinCancelled:
raise wire.ActionCancelled("Cancelled")

@ -1,8 +1,8 @@
from trezor import config, ui, wire
from trezor.crypto import slip39
from trezor.messages import ButtonRequestType
from trezor.messages.ButtonAck import ButtonAck
from trezor.messages.ButtonRequest import ButtonRequest
from trezor.messages.MessageType import ButtonAck
from trezor.messages.Success import Success
from trezor.pin import pin_to_int
from trezor.ui.info import InfoConfirm
@ -20,8 +20,11 @@ from apps.management.change_pin import request_pin_ack, request_pin_confirm
if __debug__:
from apps.debug import confirm_signal, input_signal
if False:
from trezor.messages.RecoveryDevice import RecoveryDevice
async def recovery_device(ctx, msg):
async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
"""
Recover BIP39/SLIP39 seed into empty device.
@ -116,7 +119,7 @@ async def recovery_device(ctx, msg):
return Success(message="Device recovered")
async def request_wordcount(ctx, title: str) -> int:
async def request_wordcount(ctx: wire.Context, title: str) -> int:
await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicWordCount), ButtonAck)
text = Text(title, ui.ICON_RECOVERY)
@ -131,7 +134,7 @@ async def request_wordcount(ctx, title: str) -> int:
return count
async def request_mnemonic(ctx, count: int, slip39: bool) -> str:
async def request_mnemonic(ctx: wire.Context, count: int, slip39: bool) -> str:
await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicInput), ButtonAck)
words = []
@ -149,7 +152,7 @@ async def request_mnemonic(ctx, count: int, slip39: bool) -> str:
return " ".join(words)
async def show_keyboard_info(ctx):
async def show_keyboard_info(ctx: wire.Context) -> None:
await ctx.call(ButtonRequest(code=ButtonRequestType.Other), ButtonAck)
info = InfoConfirm(
@ -174,7 +177,9 @@ async def show_success(ctx):
)
async def show_remaining_slip39_mnemonics(ctx, title, remaining: int):
async def show_remaining_slip39_mnemonics(
ctx: wire.Context, title: str, remaining: int
) -> None:
text = Text(title, ui.ICON_RECOVERY)
text.bold("Good job!")
text.normal("Enter %s more recovery " % remaining)

@ -1,6 +1,6 @@
from trezor import config, wire
from trezor.crypto import bip39, hashlib, random, slip39
from trezor.messages import MessageType
from trezor.messages.EntropyAck import EntropyAck
from trezor.messages.EntropyRequest import EntropyRequest
from trezor.messages.Success import Success
from trezor.pin import pin_to_int
@ -12,8 +12,11 @@ from apps.management.common import layout
if __debug__:
from apps import debug
if False:
from trezor.messages.ResetDevice import ResetDevice
async def reset_device(ctx, msg):
async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
# validate parameters and device state
_validate_reset_device(msg)
@ -34,7 +37,7 @@ async def reset_device(ctx, msg):
await layout.show_internal_entropy(ctx, int_entropy)
# request external entropy and compute the master secret
entropy_ack = await ctx.call(EntropyRequest(), MessageType.EntropyAck)
entropy_ack = await ctx.call(EntropyRequest(), EntropyAck)
ext_entropy = entropy_ack.entropy
secret = _compute_secret_from_entropy(int_entropy, ext_entropy, msg.strength)
@ -84,7 +87,7 @@ async def reset_device(ctx, msg):
return Success(message="Initialized")
async def backup_slip39_wallet(ctx, secret: bytes):
async def backup_slip39_wallet(ctx: wire.Context, secret: bytes) -> None:
# get number of shares
await layout.slip39_show_checklist_set_shares(ctx)
shares_count = await layout.slip39_prompt_number_of_shares(ctx)
@ -101,12 +104,12 @@ async def backup_slip39_wallet(ctx, secret: bytes):
await layout.slip39_show_and_confirm_shares(ctx, mnemonics)
async def backup_bip39_wallet(ctx, secret: bytes):
async def backup_bip39_wallet(ctx: wire.Context, secret: bytes) -> None:
mnemonic = bip39.from_data(secret)
await layout.bip39_show_and_confirm_mnemonic(ctx, mnemonic)
def _validate_reset_device(msg):
def _validate_reset_device(msg: ResetDevice) -> None:
if msg.strength not in (128, 256):
if msg.slip39:
raise wire.ProcessError("Invalid strength (has to be 128 or 256 bits)")

@ -7,7 +7,7 @@ CURVE = "ed25519"
_LIVE_REFRESH_TOKEN = None # live-refresh permission token
def boot():
def boot() -> None:
ns = [[CURVE, HARDENED | 44, HARDENED | 128]]
wire.add(MessageType.MoneroGetAddress, __name__, "get_address", ns)
wire.add(MessageType.MoneroGetWatchKey, __name__, "get_watch_only", ns)

@ -1,11 +1,14 @@
import gc
from trezor import log, wire
from trezor.messages import MessageType
from trezor.messages.MoneroExportedKeyImage import MoneroExportedKeyImage
from trezor.messages.MoneroKeyImageExportInitAck import MoneroKeyImageExportInitAck
from trezor.messages.MoneroKeyImageSyncFinalAck import MoneroKeyImageSyncFinalAck
from trezor.messages.MoneroKeyImageSyncFinalRequest import (
MoneroKeyImageSyncFinalRequest,
)
from trezor.messages.MoneroKeyImageSyncStepAck import MoneroKeyImageSyncStepAck
from trezor.messages.MoneroKeyImageSyncStepRequest import MoneroKeyImageSyncStepRequest
from apps.common import paths
from apps.monero import CURVE, misc
@ -18,19 +21,12 @@ async def key_image_sync(ctx, msg, keychain):
state = KeyImageSync()
res = await _init_step(state, ctx, msg, keychain)
while True:
msg = await ctx.call(
res,
MessageType.MoneroKeyImageSyncStepRequest,
MessageType.MoneroKeyImageSyncFinalRequest,
)
del res
if msg.MESSAGE_WIRE_TYPE == MessageType.MoneroKeyImageSyncStepRequest:
res = await _sync_step(state, ctx, msg)
else:
res = await _final_step(state, ctx)
break
while state.current_output + 1 < state.num_outputs:
msg = await ctx.call(res, MoneroKeyImageSyncStepRequest)
res = await _sync_step(state, ctx, msg)
gc.collect()
msg = await ctx.call(res, MoneroKeyImageSyncFinalRequest)
res = await _final_step(state, ctx)
return res

@ -1,5 +1,6 @@
from trezor import loop, ui, utils
from trezor.messages import ButtonRequestType, MessageType
from trezor.messages import ButtonRequestType
from trezor.messages.ButtonAck import ButtonAck
from trezor.messages.ButtonRequest import ButtonRequest
from trezor.ui.text import Text
@ -27,9 +28,7 @@ async def naive_pagination(
paginated = PaginatedWithButtons(pages, one_by_one=True)
while True:
await ctx.call(
ButtonRequest(code=ButtonRequestType.SignTx), MessageType.ButtonAck
)
await ctx.call(ButtonRequest(code=ButtonRequestType.SignTx), ButtonAck)
if __debug__:
result = await loop.spawn(paginated, confirm_signal)
else:

@ -21,7 +21,7 @@ async def live_refresh(ctx, msg: MoneroLiveRefreshStartRequest, keychain):
res = await _init_step(state, ctx, msg, keychain)
while True:
msg = await ctx.call(
msg = await ctx.call_any(
res,
MessageType.MoneroLiveRefreshStepRequest,
MessageType.MoneroLiveRefreshFinalRequest,

@ -26,7 +26,7 @@ async def sign_tx(ctx, received_msg, keychain):
del (result_msg, received_msg)
utils.unimport_end(mods)
received_msg = await ctx.read(accept_msgs)
received_msg = await ctx.read_any(accept_msgs)
utils.unimport_end(mods)
return result_msg

@ -6,7 +6,7 @@ from apps.common import HARDENED
CURVE = "ed25519-keccak"
def boot():
def boot() -> None:
ns = [[CURVE, HARDENED | 44, HARDENED | 43], [CURVE, HARDENED | 44, HARDENED | 1]]
wire.add(MessageType.NEMGetAddress, __name__, "get_address", ns)
wire.add(MessageType.NEMSignTx, __name__, "sign_tx", ns)

@ -6,7 +6,7 @@ from apps.common import HARDENED
CURVE = "secp256k1"
def boot():
def boot() -> None:
ns = [[CURVE, HARDENED | 44, HARDENED | 144]]
wire.add(MessageType.RippleGetAddress, __name__, "get_address", ns)
wire.add(MessageType.RippleSignTx, __name__, "sign_tx", ns)

@ -6,7 +6,7 @@ from apps.common import HARDENED
CURVE = "ed25519"
def boot():
def boot() -> None:
ns = [[CURVE, HARDENED | 44, HARDENED | 148]]
wire.add(MessageType.StellarGetAddress, __name__, "get_address", ns)
wire.add(MessageType.StellarSignTx, __name__, "sign_tx", ns)

@ -77,7 +77,7 @@ def _timebounds(w: bytearray, start: int, end: int):
async def _operations(ctx, w: bytearray, num_operations: int):
writers.write_uint32(w, num_operations)
for i in range(num_operations):
op = await ctx.call(StellarTxOpRequest(), *consts.op_wire_types)
op = await ctx.call_any(StellarTxOpRequest(), *consts.op_wire_types)
await process_operation(ctx, w, op)

@ -6,7 +6,7 @@ from apps.common import HARDENED
CURVE = "ed25519"
def boot():
def boot() -> None:
ns = [[CURVE, HARDENED | 44, HARDENED | 1729]]
wire.add(MessageType.TezosGetAddress, __name__, "get_address", ns)
wire.add(MessageType.TezosSignTx, __name__, "sign_tx", ns)

@ -2,7 +2,7 @@ from trezor import wire
from trezor.messages import MessageType
def boot():
def boot() -> None:
ns = [
["curve25519"],
["ed25519"],

@ -1,6 +1,6 @@
from trezor import utils, wire
from trezor.messages.MessageType import TxAck
from trezor.messages.RequestType import TXFINISHED
from trezor.messages.TxAck import TxAck
from trezor.messages.TxRequest import TxRequest
from apps.common import paths

@ -5,7 +5,7 @@ from apps.common import storage
from apps.common.request_pin import request_pin
async def bootscreen():
async def bootscreen() -> None:
ui.display.orientation(storage.device.get_rotation())
while True:
try:
@ -27,7 +27,7 @@ async def bootscreen():
log.exception(__name__, e)
async def lockscreen():
async def lockscreen() -> None:
label = storage.device.get_label()
image = storage.device.get_homescreen()
if not label:

@ -1,32 +1,31 @@
'''
"""
Extremely minimal streaming codec for a subset of protobuf. Supports uint32,
bytes, string, embedded message and repeated fields.
"""
For de-serializing (loading) protobuf types, object with `AsyncReader`
interface is required:
from micropython import const
>>> class AsyncReader:
>>> async def areadinto(self, buffer):
>>> """
>>> Reads `len(buffer)` bytes into `buffer`, or raises `EOFError`.
>>> """
if False:
from typing import Any, Dict, List, Type, TypeVar
from typing_extensions import Protocol
For serializing (dumping) protobuf types, object with `AsyncWriter` interface is
required:
class AsyncReader(Protocol):
async def areadinto(self, buf: bytearray) -> int:
"""
Reads `len(buf)` bytes into `buf`, or raises `EOFError`.
"""
>>> class AsyncWriter:
>>> async def awrite(self, buffer):
>>> """
>>> Writes all bytes from `buffer`, or raises `EOFError`.
>>> """
'''
class AsyncWriter(Protocol):
async def awrite(self, buf: bytes) -> int:
"""
Writes all bytes from `buf`, or raises `EOFError`.
"""
from micropython import const
_UVARINT_BUFFER = bytearray(1)
async def load_uvarint(reader):
async def load_uvarint(reader: AsyncReader) -> int:
buffer = _UVARINT_BUFFER
result = 0
shift = 0
@ -39,11 +38,11 @@ async def load_uvarint(reader):
return result
async def dump_uvarint(writer, n):
async def dump_uvarint(writer: AsyncWriter, n: int) -> None:
if n < 0:
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
buffer = _UVARINT_BUFFER
shifted = True
shifted = 1
while shifted:
shifted = n >> 7
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
@ -51,7 +50,7 @@ async def dump_uvarint(writer, n):
n = shifted
def count_uvarint(n):
def count_uvarint(n: int) -> int:
if n < 0:
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
if n <= 0x7F:
@ -95,14 +94,14 @@ def count_uvarint(n):
# So we have to branch on whether the number is negative.
def sint_to_uint(sint):
def sint_to_uint(sint: int) -> int:
res = sint << 1
if sint < 0:
res = ~res
return res
def uint_to_sint(uint):
def uint_to_sint(uint: int) -> int:
sign = uint & 1
res = uint >> 1
if sign:
@ -133,27 +132,31 @@ class UnicodeType:
class MessageType:
WIRE_TYPE = 2
# Type id for the wire codec.
# Technically, not every protobuf message has this.
MESSAGE_WIRE_TYPE = -1
@classmethod
def get_fields(cls):
def get_fields(cls) -> Dict:
return {}
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
for kw in kwargs:
setattr(self, kw, kwargs[kw])
def __eq__(self, rhs):
def __eq__(self, rhs: Any) -> bool:
return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__
def __repr__(self):
def __repr__(self) -> str:
return "<%s>" % self.__class__.__name__
class LimitedReader:
def __init__(self, reader, limit):
def __init__(self, reader: AsyncReader, limit: int) -> None:
self.reader = reader
self.limit = limit
async def areadinto(self, buf):
async def areadinto(self, buf: bytearray) -> int:
if self.limit < len(buf):
raise EOFError
else:
@ -162,20 +165,15 @@ class LimitedReader:
return nread
class CountingWriter:
def __init__(self):
self.size = 0
async def awrite(self, buf):
nwritten = len(buf)
self.size += nwritten
return nwritten
FLAG_REPEATED = const(1)
if False:
LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType)
async def load_message(reader, msg_type):
async def load_message(
reader: AsyncReader, msg_type: Type[LoadedMessageType]
) -> LoadedMessageType:
fields = msg_type.get_fields()
msg = msg_type()
@ -239,7 +237,9 @@ async def load_message(reader, msg_type):
return msg
async def dump_message(writer, msg, fields=None):
async def dump_message(
writer: AsyncWriter, msg: MessageType, fields: Dict = None
) -> None:
repvalue = [0]
if fields is None:
@ -297,7 +297,7 @@ async def dump_message(writer, msg, fields=None):
raise TypeError
def count_message(msg, fields=None):
def count_message(msg: MessageType, fields: Dict = None) -> int:
nbytes = 0
repvalue = [0]
@ -361,7 +361,7 @@ def count_message(msg, fields=None):
return nbytes
def _count_bytes_list(svalue):
def _count_bytes_list(svalue: List[bytes]) -> int:
res = 0
for x in svalue:
res += len(x)

@ -1,5 +1,3 @@
from gc import collect
from trezorcrypto import ( # noqa: F401
aes,
bip32,
@ -12,18 +10,3 @@ from trezorcrypto import ( # noqa: F401
random,
rfc6979,
)
class SecureContext:
def __init__(self):
pass
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
for k in self.__dict__:
o = getattr(self, k)
if hasattr(o, "__del__"):
o.__del__()
collect()

@ -1,5 +1,23 @@
if False:
from typing import Protocol, Type
class HashContext(Protocol):
digest_size = -1 # type: int
block_size = -1 # type: int
def __init__(self, data: bytes = None) -> None:
...
def update(self, data: bytes) -> None:
...
def digest(self) -> bytes:
...
class Hmac:
def __init__(self, key, msg, digestmod):
def __init__(self, key: bytes, msg: bytes, digestmod: Type[HashContext]):
self.digestmod = digestmod
self.inner = digestmod()
self.digest_size = self.inner.digest_size
@ -28,7 +46,7 @@ class Hmac:
return outer.digest()
def new(key, msg, digestmod) -> Hmac:
def new(key: bytes, msg: bytes, digestmod: Type[HashContext]) -> Hmac:
"""
Creates a HMAC context object.
"""

@ -23,6 +23,11 @@ from micropython import const
from trezor.crypto import hashlib, hmac, pbkdf2, random
from trezorcrypto import shamir, slip39
if False:
from typing import Dict, Iterable, List, Optional, Set, Tuple
Indices = Tuple[int, ...]
KEYBOARD_FULL_MASK = const(0x1FF)
"""All buttons are allowed. 9-bit bitmap all set to 1."""
@ -35,7 +40,7 @@ def compute_mask(prefix: str) -> int:
def button_sequence_to_word(prefix: str) -> str:
if not prefix:
return KEYBOARD_FULL_MASK
return ""
return slip39.button_sequence_to_word(int(prefix))
@ -43,11 +48,11 @@ _RADIX_BITS = const(10)
"""The length of the radix in bits."""
def bits_to_bytes(n):
def bits_to_bytes(n: int) -> int:
return (n + 7) // 8
def bits_to_words(n):
def bits_to_words(n: int) -> int:
return (n + _RADIX_BITS - 1) // _RADIX_BITS
@ -103,7 +108,7 @@ class MnemonicError(Exception):
pass
def _rs1024_polymod(values):
def _rs1024_polymod(values: Indices) -> int:
GEN = (
0xE0E040,
0x1C1C080,
@ -125,7 +130,7 @@ def _rs1024_polymod(values):
return chk
def rs1024_create_checksum(data):
def rs1024_create_checksum(data: Indices) -> Indices:
values = tuple(_CUSTOMIZATION_STRING) + data + _CHECKSUM_LENGTH_WORDS * (0,)
polymod = _rs1024_polymod(values) ^ 1
return tuple(
@ -133,11 +138,11 @@ def rs1024_create_checksum(data):
)
def rs1024_verify_checksum(data):
def rs1024_verify_checksum(data: Indices) -> bool:
return _rs1024_polymod(tuple(_CUSTOMIZATION_STRING) + data) == 1
def rs1024_error_index(data):
def rs1024_error_index(data: Indices) -> Optional[int]:
GEN = (
0x91F9F87,
0x122F1F07,
@ -164,11 +169,11 @@ def rs1024_error_index(data):
return None
def xor(a, b):
def xor(a: bytes, b: bytes) -> bytes:
return bytes(x ^ y for x, y in zip(a, b))
def _int_from_indices(indices):
def _int_from_indices(indices: Indices) -> int:
"""Converts a list of base 1024 indices in big endian order to an integer value."""
value = 0
for index in indices:
@ -176,21 +181,21 @@ def _int_from_indices(indices):
return value
def _int_to_indices(value, length, bits):
def _int_to_indices(value: int, length: int, bits: int) -> Iterable[int]:
"""Converts an integer value to indices in big endian order."""
mask = (1 << bits) - 1
return ((value >> (i * bits)) & mask for i in reversed(range(length)))
def mnemonic_from_indices(indices):
def mnemonic_from_indices(indices: Indices) -> str:
return " ".join(slip39.get_word(i) for i in indices)
def mnemonic_to_indices(mnemonic):
def mnemonic_to_indices(mnemonic: str) -> Iterable[int]:
return (slip39.word_index(word.lower()) for word in mnemonic.split())
def _round_function(i, passphrase, e, salt, r):
def _round_function(i: int, passphrase: bytes, e: int, salt: bytes, r: bytes) -> bytes:
"""The round function used internally by the Feistel cipher."""
return pbkdf2(
pbkdf2.HMAC_SHA256,
@ -200,13 +205,15 @@ def _round_function(i, passphrase, e, salt, r):
).key()[: len(r)]
def _get_salt(identifier):
def _get_salt(identifier: int) -> bytes:
return _CUSTOMIZATION_STRING + identifier.to_bytes(
bits_to_bytes(_ID_LENGTH_BITS), "big"
)
def _encrypt(master_secret, passphrase, iteration_exponent, identifier):
def _encrypt(
master_secret: bytes, passphrase: bytes, iteration_exponent: int, identifier: int
) -> bytes:
l = master_secret[: len(master_secret) // 2]
r = master_secret[len(master_secret) // 2 :]
salt = _get_salt(identifier)
@ -218,7 +225,12 @@ def _encrypt(master_secret, passphrase, iteration_exponent, identifier):
return r + l
def decrypt(identifier, iteration_exponent, encrypted_master_secret, passphrase):
def decrypt(
identifier: int,
iteration_exponent: int,
encrypted_master_secret: bytes,
passphrase: bytes,
) -> bytes:
l = encrypted_master_secret[: len(encrypted_master_secret) // 2]
r = encrypted_master_secret[len(encrypted_master_secret) // 2 :]
salt = _get_salt(identifier)
@ -230,13 +242,15 @@ def decrypt(identifier, iteration_exponent, encrypted_master_secret, passphrase)
return r + l
def _create_digest(random_data, shared_secret):
def _create_digest(random_data: bytes, shared_secret: bytes) -> bytes:
return hmac.new(random_data, shared_secret, hashlib.sha256).digest()[
:_DIGEST_LENGTH_BYTES
]
def _split_secret(threshold, share_count, shared_secret):
def _split_secret(
threshold: int, share_count: int, shared_secret: bytes
) -> List[Tuple[int, bytes]]:
if threshold < 1:
raise ValueError(
"The requested threshold ({}) must be a positive integer.".format(threshold)
@ -278,7 +292,7 @@ def _split_secret(threshold, share_count, shared_secret):
return shares
def _recover_secret(threshold, shares):
def _recover_secret(threshold: int, shares: List[Tuple[int, bytes]]) -> bytes:
# If the threshold is 1, then the digest of the shared secret is not used.
if threshold == 1:
return shares[0][1]
@ -295,8 +309,12 @@ def _recover_secret(threshold, shares):
def _group_prefix(
identifier, iteration_exponent, group_index, group_threshold, group_count
):
identifier: int,
iteration_exponent: int,
group_index: int,
group_threshold: int,
group_count: int,
) -> Indices:
id_exp_int = (identifier << _ITERATION_EXP_LENGTH_BITS) + iteration_exponent
return tuple(_int_to_indices(id_exp_int, _ID_EXP_LENGTH_WORDS, _RADIX_BITS)) + (
(group_index << 6) + ((group_threshold - 1) << 2) + ((group_count - 1) >> 2),
@ -304,15 +322,15 @@ def _group_prefix(
def encode_mnemonic(
identifier,
iteration_exponent,
group_index,
group_threshold,
group_count,
member_index,
member_threshold,
value,
):
identifier: int,
iteration_exponent: int,
group_index: int,
group_threshold: int,
group_count: int,
member_index: int,
member_threshold: int,
value: bytes,
) -> str:
"""
Converts share data to a share mnemonic.
:param int identifier: The random identifier.
@ -348,7 +366,7 @@ def encode_mnemonic(
return mnemonic_from_indices(share_data + checksum)
def decode_mnemonic(mnemonic):
def decode_mnemonic(mnemonic: str) -> Tuple[int, int, int, int, int, int, int, bytes]:
"""Converts a share mnemonic to share data."""
mnemonic_data = tuple(mnemonic_to_indices(mnemonic))
@ -401,12 +419,20 @@ def decode_mnemonic(mnemonic):
)
def _decode_mnemonics(mnemonics):
if False:
MnemonicGroups = Dict[int, Tuple[int, Set[Tuple[int, bytes]]]]
def _decode_mnemonics(
mnemonics: List[str]
) -> Tuple[int, int, int, int, MnemonicGroups]:
identifiers = set()
iteration_exponents = set()
group_thresholds = set()
group_counts = set()
groups = {} # { group_index : [member_threshold, set_of_member_shares] }
# { group_index : [member_threshold, set_of_member_shares] }
groups = {} # type: MnemonicGroups
for mnemonic in mnemonics:
identifier, iteration_exponent, group_index, group_threshold, group_count, member_index, member_threshold, share_value = decode_mnemonic(
mnemonic
@ -415,7 +441,7 @@ def _decode_mnemonics(mnemonics):
iteration_exponents.add(iteration_exponent)
group_thresholds.add(group_threshold)
group_counts.add(group_count)
group = groups.setdefault(group_index, [member_threshold, set()])
group = groups.setdefault(group_index, (member_threshold, set()))
if group[0] != member_threshold:
raise MnemonicError(
"Invalid set of mnemonics. All mnemonics in a group must have the same member threshold."
@ -462,13 +488,13 @@ def generate_random_identifier() -> int:
def generate_single_group_mnemonics_from_data(
master_secret,
identifier,
threshold,
count,
passphrase=b"",
iteration_exponent=DEFAULT_ITERATION_EXPONENT,
) -> list:
master_secret: bytes,
identifier: int,
threshold: int,
count: int,
passphrase: bytes = b"",
iteration_exponent: int = DEFAULT_ITERATION_EXPONENT,
) -> List[str]:
return generate_mnemonics_from_data(
master_secret,
identifier,
@ -480,13 +506,13 @@ def generate_single_group_mnemonics_from_data(
def generate_mnemonics_from_data(
master_secret,
identifier,
group_threshold,
groups,
passphrase=b"",
iteration_exponent=DEFAULT_ITERATION_EXPONENT,
) -> list:
master_secret: bytes,
identifier: int,
group_threshold: int,
groups: List[Tuple[int, int]],
passphrase: bytes = b"",
iteration_exponent: int = DEFAULT_ITERATION_EXPONENT,
) -> List[List[str]]:
"""
Splits a master secret into mnemonic shares using Shamir's secret sharing scheme.
:param master_secret: The master secret to split.
@ -544,7 +570,7 @@ def generate_mnemonics_from_data(
group_shares = _split_secret(group_threshold, len(groups), encrypted_master_secret)
mnemonics = []
mnemonics = [] # type: List[List[str]]
for (member_threshold, member_count), (group_index, group_secret) in zip(
groups, group_shares
):
@ -568,7 +594,7 @@ def generate_mnemonics_from_data(
return mnemonics
def combine_mnemonics(mnemonics):
def combine_mnemonics(mnemonics: List[str]) -> Tuple[int, int, bytes]:
"""
Combines mnemonic shares to obtain the master secret which was previously split using
Shamir's secret sharing scheme.

@ -2,6 +2,9 @@ import sys
import utime
from micropython import const
if False:
from typing import Any
NOTSET = const(0)
DEBUG = const(10)
INFO = const(20)
@ -21,7 +24,7 @@ level = DEBUG
color = True
def _log(name, mlevel, msg, *args):
def _log(name: str, mlevel: int, msg: str, *args: Any) -> None:
if __debug__ and mlevel >= level:
if color:
fmt = (
@ -35,26 +38,26 @@ def _log(name, mlevel, msg, *args):
print(fmt % ((utime.ticks_us(), name, _leveldict[mlevel][0]) + args))
def debug(name, msg, *args):
def debug(name: str, msg: str, *args: Any) -> None:
_log(name, DEBUG, msg, *args)
def info(name, msg, *args):
def info(name: str, msg: str, *args: Any) -> None:
_log(name, INFO, msg, *args)
def warning(name, msg, *args):
def warning(name: str, msg: str, *args: Any) -> None:
_log(name, WARNING, msg, *args)
def error(name, msg, *args):
def error(name: str, msg: str, *args: Any) -> None:
_log(name, ERROR, msg, *args)
def exception(name, exc):
_log(name, ERROR, "exception:")
sys.print_exception(exc)
def critical(name: str, msg: str, *args: Any) -> None:
_log(name, CRITICAL, msg, *args)
def critical(name, msg, *args):
_log(name, CRITICAL, msg, *args)
def exception(name: str, exc: BaseException) -> None:
_log(name, ERROR, "exception:")
sys.print_exception(exc)

@ -13,12 +13,33 @@ from micropython import const
from trezor import io, log
after_step_hook = None # function, called after each task step
_QUEUE_SIZE = const(64) # maximum number of scheduled tasks
_queue = utimeq.utimeq(_QUEUE_SIZE)
_paused = {}
_finalizers = {}
if False:
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Dict,
Generator,
List,
Optional,
Set,
)
Task = Coroutine
Finalizer = Callable[[Task, Any], None]
# function to call after every task step
after_step_hook = None # type: Optional[Callable[[], None]]
# tasks scheduled for execution in the future
_queue = utimeq.utimeq(64)
# tasks paused on I/O
_paused = {} # type: Dict[int, Set[Task]]
# functions to execute after a task is finished
_finalizers = {} # type: Dict[int, Finalizer]
if __debug__:
# for performance stats
@ -29,7 +50,9 @@ if __debug__:
log_delay_rb = array.array("i", [0] * log_delay_rb_len)
def schedule(task, value=None, deadline=None, finalizer=None):
def schedule(
task: Task, value: Any = None, deadline: int = None, finalizer: Finalizer = None
) -> None:
"""
Schedule task to be executed with `value` on given `deadline` (in
microseconds). Does not start the event loop itself, see `run`.
@ -41,20 +64,20 @@ def schedule(task, value=None, deadline=None, finalizer=None):
_queue.push(deadline, task, value)
def pause(task, iface):
def pause(task: Task, iface: int) -> None:
tasks = _paused.get(iface, None)
if tasks is None:
tasks = _paused[iface] = set()
tasks.add(task)
def finalize(task, value):
def finalize(task: Task, value: Any) -> None:
fn = _finalizers.pop(id(task), None)
if fn is not None:
fn(task, value)
def close(task):
def close(task: Task) -> None:
for iface in _paused:
_paused[iface].discard(task)
_queue.discard(task)
@ -62,7 +85,7 @@ def close(task):
finalize(task, GeneratorExit())
def run():
def run() -> None:
"""
Loop forever, stepping through scheduled tasks and awaiting I/O events
inbetween. Use `schedule` first to add a coroutine to the task queue.
@ -98,13 +121,17 @@ def run():
# timeout occurred, run the first scheduled task
if _queue:
_queue.pop(task_entry)
_step(task_entry[1], task_entry[2])
_step(task_entry[1], task_entry[2]) # type: ignore
# error: Argument 1 to "_step" has incompatible type "int"; expected "Coroutine[Any, Any, Any]"
# rationale: We use untyped lists here, because that is what the C API supports.
def _step(task, value):
def _step(task: Task, value: Any) -> None:
try:
if isinstance(value, BaseException):
result = task.throw(value)
result = task.throw(value) # type: ignore
# error: Argument 1 to "throw" of "Coroutine" has incompatible type "Exception"; expected "Type[BaseException]"
# rationale: In micropython, generator.throw() accepts the exception object directly.
else:
result = task.send(value)
except StopIteration as e: # as e:
@ -133,10 +160,16 @@ class Syscall:
scheduler, they do so through instances of a class derived from `Syscall`.
"""
def __iter__(self):
def __iter__(self) -> Task: # type: ignore
# support `yield from` or `await` on syscalls
return (yield self)
def __await__(self) -> Generator:
return self.__iter__() # type: ignore
def handle(self, task: Task) -> None:
pass
class sleep(Syscall):
"""
@ -150,10 +183,10 @@ class sleep(Syscall):
>>> print('missed by %d us', utime.ticks_diff(utime.ticks_us(), planned))
"""
def __init__(self, delay_us):
def __init__(self, delay_us: int) -> None:
self.delay_us = delay_us
def handle(self, task):
def handle(self, task: Task) -> None:
deadline = utime.ticks_add(utime.ticks_us(), self.delay_us)
schedule(task, deadline, deadline)
@ -170,14 +203,14 @@ class wait(Syscall):
>>> event, x, y = await loop.wait(io.TOUCH) # await touch event
"""
def __init__(self, msg_iface):
def __init__(self, msg_iface: int) -> None:
self.msg_iface = msg_iface
def handle(self, task):
def handle(self, task: Task) -> None:
pause(task, self.msg_iface)
_NO_VALUE = ()
_NO_VALUE = object()
class signal(Syscall):
@ -196,28 +229,28 @@ class signal(Syscall):
>>> # prints in the next iteration of the event loop
"""
def __init__(self):
def __init__(self) -> None:
self.reset()
def reset(self):
def reset(self) -> None:
self.value = _NO_VALUE
self.task = None
self.task = None # type: Optional[Task]
def handle(self, task):
def handle(self, task: Task) -> None:
self.task = task
self._deliver()
def send(self, value):
def send(self, value: Any) -> None:
self.value = value
self._deliver()
def _deliver(self):
def _deliver(self) -> None:
if self.task is not None and self.value is not _NO_VALUE:
schedule(self.task, self.value)
self.task = None
self.value = _NO_VALUE
def __iter__(self):
def __iter__(self) -> Task: # type: ignore
try:
return (yield self)
except: # noqa: E722
@ -253,14 +286,13 @@ class spawn(Syscall):
`spawn.__iter__` for explanation. Always use `await`.
"""
def __init__(self, *children, exit_others=True):
def __init__(self, *children: Awaitable, exit_others: bool = True) -> None:
self.children = children
self.exit_others = exit_others
self.scheduled = [] # list of scheduled tasks
self.finished = [] # list of children that finished
self.callback = None
self.finished = [] # type: List[Awaitable] # children that finished
self.scheduled = [] # type: List[Task] # scheduled wrapper tasks
def handle(self, task):
def handle(self, task: Task) -> None:
finalizer = self._finish
scheduled = self.scheduled
finished = self.finished
@ -273,16 +305,17 @@ class spawn(Syscall):
if isinstance(child, _type_gen):
child_task = child
else:
child_task = iter(child)
schedule(child_task, None, None, finalizer)
scheduled.append(child_task)
child_task = iter(child) # type: ignore
schedule(child_task, None, None, finalizer) # type: ignore
scheduled.append(child_task) # type: ignore
# TODO: document the types here
def exit(self, except_for=None):
def exit(self, except_for: Task = None) -> None:
for task in self.scheduled:
if task != except_for:
close(task)
def _finish(self, task, result):
def _finish(self, task: Task, result: Any) -> None:
if not self.finished:
for index, child_task in enumerate(self.scheduled):
if child_task is task:
@ -293,7 +326,7 @@ class spawn(Syscall):
self.exit(task)
schedule(self.callback, result)
def __iter__(self):
def __iter__(self) -> Task: # type: ignore
try:
return (yield self)
except: # noqa: E722

@ -1,17 +1,17 @@
try:
from .resources import resdata
except ImportError:
resdata = None
resdata = {}
def load(name):
def load(name: str) -> bytes:
"""
Loads resource of a given name as bytes.
"""
return resdata[name]
def gettext(message):
def gettext(message: str) -> str:
"""
Returns localized string. This function is aliased to _.
"""

@ -5,12 +5,18 @@ from trezorui import Display
from trezor import io, loop, res, utils, workflow
if False:
from typing import Any, Generator, Iterable, Tuple, TypeVar
Pos = Tuple[int, int]
Area = Tuple[int, int, int, int]
display = Display()
# in debug mode, display an indicator in top right corner
if __debug__:
def debug_display_refresh():
def debug_display_refresh() -> None:
display.bar(Display.WIDTH - 8, 0, 8, 8, 0xF800)
display.refresh()
if utils.SAVE_SCREEN:
@ -59,25 +65,25 @@ from trezor.ui import style # isort:skip
from trezor.ui.style import * # isort:skip # noqa: F401,F403
def pulse(delay: int):
def pulse(delay: int) -> float:
# normalize sin from interval -1:1 to 0:1
return 0.5 + 0.5 * math.sin(utime.ticks_us() / delay)
async def click() -> tuple:
async def click() -> Pos:
touch = loop.wait(io.TOUCH)
while True:
ev, *pos = yield touch
ev, *pos = await touch
if ev == io.TOUCH_START:
break
while True:
ev, *pos = yield touch
ev, *pos = await touch
if ev == io.TOUCH_END:
break
return pos
return pos # type: ignore
def backlight_fade(val: int, delay: int = 14000, step: int = 15):
def backlight_fade(val: int, delay: int = 14000, step: int = 15) -> None:
if __debug__:
if utils.DISABLE_FADE:
display.backlight(val)
@ -96,7 +102,7 @@ def header(
fg: int = style.FG,
bg: int = style.BG,
ifg: int = style.GREEN,
):
) -> None:
if icon is not None:
display.icon(14, 15, res.load(icon), ifg, bg)
display.text(44, 35, title, BOLD, fg, bg)
@ -113,7 +119,7 @@ def grid(
cells_x: int = 1,
cells_y: int = 1,
spacing: int = 0,
):
) -> Area:
w = (end_x - start_x) // n_x
h = (end_y - start_y) // n_y
x = (i % n_x) * w
@ -121,7 +127,7 @@ def grid(
return (x + start_x, y + start_y, (w - spacing) * cells_x, (h - spacing) * cells_y)
def in_area(area: tuple, x: int, y: int) -> bool:
def in_area(area: Area, x: int, y: int) -> bool:
ax, ay, aw, ah = area
return ax <= x <= ax + aw and ay <= y <= ay + ah
@ -132,7 +138,7 @@ REPAINT = const(-256)
class Control:
def dispatch(self, event, x, y):
def dispatch(self, event: int, x: int, y: int) -> None:
if event is RENDER:
self.on_render()
elif event is io.TOUCH_START:
@ -144,16 +150,16 @@ class Control:
elif event is REPAINT:
self.repaint = True
def on_render(self):
def on_render(self) -> None:
pass
def on_touch_start(self, x, y):
def on_touch_start(self, x: int, y: int) -> None:
pass
def on_touch_move(self, x, y):
def on_touch_move(self, x: int, y: int) -> None:
pass
def on_touch_end(self, x, y):
def on_touch_end(self, x: int, y: int) -> None:
pass
@ -164,8 +170,12 @@ class LayoutCancelled(Exception):
pass
if False:
ResultValue = TypeVar("ResultValue")
class Result(Exception):
def __init__(self, value):
def __init__(self, value: ResultValue) -> None:
self.value = value
@ -173,7 +183,7 @@ class Layout(Control):
"""
"""
async def __iter__(self):
async def __iter__(self) -> ResultValue:
value = None
try:
if workflow.layout_signal.task is not None:
@ -188,17 +198,20 @@ class Layout(Control):
workflow.onlayoutclose(self)
return value
def create_tasks(self):
def __await__(self) -> Generator[Any, Any, ResultValue]:
return self.__iter__() # type: ignore
def create_tasks(self) -> Iterable[loop.Task]:
return self.handle_input(), self.handle_rendering()
def handle_input(self):
def handle_input(self) -> loop.Task: # type: ignore
touch = loop.wait(io.TOUCH)
while True:
event, x, y = yield touch
self.dispatch(event, x, y)
self.dispatch(RENDER, 0, 0)
def handle_rendering(self):
def handle_rendering(self) -> loop.Task: # type: ignore
backlight_fade(style.BACKLIGHT_DIM)
display.clear()
self.dispatch(RENDER, 0, 0)

@ -3,6 +3,9 @@ from micropython import const
from trezor import ui
from trezor.ui import display, in_area
if False:
from typing import Type, Union
class ButtonDefault:
class normal:
@ -12,14 +15,14 @@ class ButtonDefault:
border_color = ui.BG
radius = ui.RADIUS
class active:
class active(normal):
bg_color = ui.FG
fg_color = ui.BLACKISH
text_style = ui.BOLD
border_color = ui.FG
radius = ui.RADIUS
class disabled:
class disabled(normal):
bg_color = ui.BG
fg_color = ui.GREY
text_style = ui.NORMAL
@ -38,7 +41,7 @@ class ButtonMono(ButtonDefault):
text_style = ui.MONO
class ButtonMonoDark:
class ButtonMonoDark(ButtonDefault):
class normal:
bg_color = ui.DARK_BLACK
fg_color = ui.DARK_WHITE
@ -46,14 +49,14 @@ class ButtonMonoDark:
border_color = ui.BG
radius = ui.RADIUS
class active:
class active(normal):
bg_color = ui.FG
fg_color = ui.DARK_BLACK
text_style = ui.MONO
border_color = ui.FG
radius = ui.RADIUS
class disabled:
class disabled(normal):
bg_color = ui.DARK_BLACK
fg_color = ui.GREY
text_style = ui.MONO
@ -98,6 +101,12 @@ class ButtonMonoConfirm(ButtonDefault):
text_style = ui.MONO
if False:
ButtonContent = Union[str, bytes]
ButtonStyleType = Type[ButtonDefault]
ButtonStyleStateType = Type[ButtonDefault.normal]
# button states
_INITIAL = const(0)
_PRESSED = const(1)
@ -110,39 +119,53 @@ _BORDER = const(4) # border size in pixels
class Button(ui.Control):
def __init__(self, area, content, style=ButtonDefault):
def __init__(
self,
area: ui.Area,
content: ButtonContent,
style: ButtonStyleType = ButtonDefault,
) -> None:
if isinstance(content, str):
self.text = content
self.icon = b""
elif isinstance(content, bytes):
self.icon = content
self.text = ""
else:
raise TypeError
self.area = area
self.content = content
self.normal_style = style.normal
self.active_style = style.active
self.disabled_style = style.disabled
self.state = _INITIAL
self.repaint = True
def enable(self):
def enable(self) -> None:
if self.state is not _INITIAL:
self.state = _INITIAL
self.repaint = True
def disable(self):
def disable(self) -> None:
if self.state is not _DISABLED:
self.state = _DISABLED
self.repaint = True
def on_render(self):
def on_render(self) -> None:
if self.repaint:
if self.state is _DISABLED:
if self.state is _INITIAL or self.state is _RELEASED:
s = self.normal_style
elif self.state is _DISABLED:
s = self.disabled_style
elif self.state is _PRESSED:
s = self.active_style
else:
s = self.normal_style
ax, ay, aw, ah = self.area
self.render_background(s, ax, ay, aw, ah)
self.render_content(s, ax, ay, aw, ah)
self.repaint = False
def render_background(self, s, ax, ay, aw, ah):
def render_background(
self, s: ButtonStyleStateType, ax: int, ay: int, aw: int, ah: int
) -> None:
radius = s.radius
bg_color = s.bg_color
border_color = s.border_color
@ -162,16 +185,21 @@ class Button(ui.Control):
radius,
)
def render_content(self, s, ax, ay, aw, ah):
def render_content(
self, s: ButtonStyleStateType, ax: int, ay: int, aw: int, ah: int
) -> None:
tx = ax + aw // 2
ty = ay + ah // 2 + 8
t = self.content
if isinstance(t, str):
t = self.text
if t:
display.text_center(tx, ty, t, s.text_style, s.fg_color, s.bg_color)
elif isinstance(t, bytes):
display.icon(tx - _ICON // 2, ty - _ICON, t, s.fg_color, s.bg_color)
return
i = self.icon
if i:
display.icon(tx - _ICON // 2, ty - _ICON, i, s.fg_color, s.bg_color)
return
def on_touch_start(self, x, y):
def on_touch_start(self, x: int, y: int) -> None:
if self.state is _DISABLED:
return
if in_area(self.area, x, y):
@ -179,7 +207,7 @@ class Button(ui.Control):
self.repaint = True
self.on_press_start()
def on_touch_move(self, x, y):
def on_touch_move(self, x: int, y: int) -> None:
if self.state is _DISABLED:
return
if in_area(self.area, x, y):
@ -193,7 +221,7 @@ class Button(ui.Control):
self.repaint = True
self.on_press_end()
def on_touch_end(self, x, y):
def on_touch_end(self, x: int, y: int) -> None:
state = self.state
if state is not _INITIAL and state is not _DISABLED:
self.state = _INITIAL
@ -203,11 +231,11 @@ class Button(ui.Control):
self.on_press_end()
self.on_click()
def on_press_start(self):
def on_press_start(self) -> None:
pass
def on_press_end(self):
def on_press_end(self) -> None:
pass
def on_click(self):
def on_click(self) -> None:
pass

@ -3,32 +3,37 @@ from micropython import const
from trezor import res, ui
from trezor.ui.text import TEXT_HEADER_HEIGHT, TEXT_LINE_HEIGHT
if False:
from typing import Iterable, List, Union
ChecklistItem = Union[str, Iterable[str]]
_CHECKLIST_MAX_LINES = const(5)
_CHECKLIST_OFFSET_X = const(24)
_CHECKLIST_OFFSET_X_ICON = const(0)
class Checklist(ui.Control):
def __init__(self, title, icon):
def __init__(self, title: str, icon: str) -> None:
self.title = title
self.icon = icon
self.items = []
self.items = [] # type: List[ChecklistItem]
self.active = 0
self.repaint = True
def add(self, choice):
self.items.append(choice)
def add(self, item: ChecklistItem) -> None:
self.items.append(item)
def select(self, active):
def select(self, active: int) -> None:
self.active = active
def on_render(self):
def on_render(self) -> None:
if self.repaint:
ui.header(self.title, self.icon)
self.render_items()
self.repaint = False
def render_items(self):
def render_items(self) -> None:
offset_x = _CHECKLIST_OFFSET_X
offset_y = TEXT_HEADER_HEIGHT + TEXT_LINE_HEIGHT
bg = ui.BG

@ -2,6 +2,10 @@ from trezor import res, ui
from trezor.ui.button import Button, ButtonCancel, ButtonConfirm
from trezor.ui.loader import Loader, LoaderDefault
if False:
from trezor.ui.button import ButtonContent, ButtonStyleType
from trezor.ui.loader import LoaderStyleType
CONFIRMED = object()
CANCELLED = object()
@ -14,13 +18,13 @@ class Confirm(ui.Layout):
def __init__(
self,
content,
confirm=DEFAULT_CONFIRM,
confirm_style=DEFAULT_CONFIRM_STYLE,
cancel=DEFAULT_CANCEL,
cancel_style=DEFAULT_CANCEL_STYLE,
major_confirm=False,
):
content: ui.Control,
confirm: ButtonContent = DEFAULT_CONFIRM,
confirm_style: ButtonStyleType = DEFAULT_CONFIRM_STYLE,
cancel: ButtonContent = DEFAULT_CANCEL,
cancel_style: ButtonStyleType = DEFAULT_CANCEL_STYLE,
major_confirm: bool = False,
) -> None:
self.content = content
if confirm is not None:
@ -31,7 +35,7 @@ class Confirm(ui.Layout):
else:
area = ui.grid(9, n_x=2)
self.confirm = Button(area, confirm, confirm_style)
self.confirm.on_click = self.on_confirm
self.confirm.on_click = self.on_confirm # type: ignore
else:
self.confirm = None
@ -43,21 +47,21 @@ class Confirm(ui.Layout):
else:
area = ui.grid(8, n_x=2)
self.cancel = Button(area, cancel, cancel_style)
self.cancel.on_click = self.on_cancel
self.cancel.on_click = self.on_cancel # type: ignore
else:
self.cancel = None
def dispatch(self, event, x, y):
def dispatch(self, event: int, x: int, y: int) -> None:
self.content.dispatch(event, x, y)
if self.confirm is not None:
self.confirm.dispatch(event, x, y)
if self.cancel is not None:
self.cancel.dispatch(event, x, y)
def on_confirm(self):
def on_confirm(self) -> None:
raise ui.Result(CONFIRMED)
def on_cancel(self):
def on_cancel(self) -> None:
raise ui.Result(CANCELLED)
@ -68,44 +72,44 @@ class HoldToConfirm(ui.Layout):
def __init__(
self,
content,
confirm=DEFAULT_CONFIRM,
confirm_style=DEFAULT_CONFIRM_STYLE,
loader_style=DEFAULT_LOADER_STYLE,
content: ui.Control,
confirm: str = DEFAULT_CONFIRM,
confirm_style: ButtonStyleType = DEFAULT_CONFIRM_STYLE,
loader_style: LoaderStyleType = DEFAULT_LOADER_STYLE,
):
self.content = content
self.loader = Loader(loader_style)
self.loader.on_start = self._on_loader_start
self.loader.on_start = self._on_loader_start # type: ignore
self.button = Button(ui.grid(4, n_x=1), confirm, confirm_style)
self.button.on_press_start = self._on_press_start
self.button.on_press_end = self._on_press_end
self.button.on_click = self._on_click
self.button.on_press_start = self._on_press_start # type: ignore
self.button.on_press_end = self._on_press_end # type: ignore
self.button.on_click = self._on_click # type: ignore
def _on_press_start(self):
def _on_press_start(self) -> None:
self.loader.start()
def _on_press_end(self):
def _on_press_end(self) -> None:
self.loader.stop()
def _on_loader_start(self):
def _on_loader_start(self) -> None:
# Loader has either started growing, or returned to the 0-position.
# In the first case we need to clear the content leftovers, in the latter
# we need to render the content again.
ui.display.bar(0, 0, ui.WIDTH, ui.HEIGHT - 60, ui.BG)
self.content.dispatch(ui.REPAINT, 0, 0)
def _on_click(self):
def _on_click(self) -> None:
if self.loader.elapsed_ms() >= self.loader.target_ms:
self.on_confirm()
def dispatch(self, event, x, y):
def dispatch(self, event: int, x: int, y: int) -> None:
if self.loader.start_ms is not None:
self.loader.dispatch(event, x, y)
else:
self.content.dispatch(event, x, y)
self.button.dispatch(event, x, y)
def on_confirm(self):
def on_confirm(self) -> None:
raise ui.Result(CONFIRMED)

@ -2,9 +2,9 @@ from trezor import ui
class Container(ui.Control):
def __init__(self, *children):
def __init__(self, *children: ui.Control):
self.children = children
def dispatch(self, event, x, y):
def dispatch(self, event: int, x: int, y: int) -> None:
for child in self.children:
child.dispatch(event, x, y)

@ -3,6 +3,10 @@ from trezor.ui.button import Button, ButtonConfirm
from trezor.ui.confirm import CONFIRMED
from trezor.ui.text import TEXT_LINE_HEIGHT, TEXT_MARGIN_LEFT, render_text
if False:
from typing import Type
from trezor.ui.button import ButtonContent
class DefaultInfoConfirm:
@ -17,26 +21,35 @@ class DefaultInfoConfirm:
border_color = ui.BLACKISH
if False:
InfoConfirmStyleType = Type[DefaultInfoConfirm]
class InfoConfirm(ui.Layout):
DEFAULT_CONFIRM = res.load(ui.ICON_CONFIRM)
DEFAULT_STYLE = DefaultInfoConfirm
def __init__(self, text, confirm=DEFAULT_CONFIRM, style=DEFAULT_STYLE):
def __init__(
self,
text: str,
confirm: ButtonContent = DEFAULT_CONFIRM,
style: InfoConfirmStyleType = DEFAULT_STYLE,
) -> None:
self.text = text.split()
self.style = style
panel_area = ui.grid(0, n_x=1, n_y=1)
self.panel_area = panel_area
confirm_area = ui.grid(4, n_x=1)
self.confirm = Button(confirm_area, confirm, style.button)
self.confirm.on_click = self.on_confirm
self.confirm.on_click = self.on_confirm # type: ignore
self.repaint = True
def dispatch(self, event, x, y):
def dispatch(self, event: int, x: int, y: int) -> None:
if event == ui.RENDER:
self.on_render()
self.confirm.dispatch(event, x, y)
def on_render(self):
def on_render(self) -> None:
if self.repaint:
x, y, w, h = self.panel_area
fg_color = self.style.fg_color
@ -46,7 +59,7 @@ class InfoConfirm(ui.Layout):
ui.display.bar_radius(x, y, w, h, bg_color, ui.BG, ui.RADIUS)
# render the info text
render_text(
render_text( # type: ignore
self.text,
new_lines=False,
max_lines=6,
@ -59,5 +72,5 @@ class InfoConfirm(ui.Layout):
self.repaint = False
def on_confirm(self):
def on_confirm(self) -> None:
raise ui.Result(CONFIRMED)

@ -4,22 +4,25 @@ from micropython import const
from trezor import res, ui
from trezor.ui import display
if False:
from typing import Optional, Type
class LoaderDefault:
class normal:
bg_color = ui.BG
fg_color = ui.GREEN
icon = None
icon_fg_color = None
icon = None # type: Optional[str]
icon_fg_color = None # type: Optional[int]
class active:
class active(normal):
bg_color = ui.BG
fg_color = ui.GREEN
icon = ui.ICON_CHECK
icon_fg_color = ui.WHITE
class LoaderDanger:
class LoaderDanger(LoaderDefault):
class normal(LoaderDefault.normal):
fg_color = ui.RED
@ -27,31 +30,35 @@ class LoaderDanger:
fg_color = ui.RED
if False:
LoaderStyleType = Type[LoaderDefault]
_TARGET_MS = const(1000)
class Loader(ui.Control):
def __init__(self, style=LoaderDefault):
def __init__(self, style: LoaderStyleType = LoaderDefault) -> None:
self.normal_style = style.normal
self.active_style = style.active
self.target_ms = _TARGET_MS
self.start_ms = None
self.stop_ms = None
def start(self):
def start(self) -> None:
self.start_ms = utime.ticks_ms()
self.stop_ms = None
self.on_start()
def stop(self):
def stop(self) -> None:
self.stop_ms = utime.ticks_ms()
def elapsed_ms(self):
def elapsed_ms(self) -> int:
if self.start_ms is None:
return 0
return utime.ticks_ms() - self.start_ms
def on_render(self):
def on_render(self) -> None:
target = self.target_ms
start = self.start_ms
stop = self.stop_ms
@ -60,10 +67,10 @@ class Loader(ui.Control):
r = min(now - start, target)
else:
r = max(stop - start + (stop - now) * 2, 0)
if r == target:
s = self.active_style
else:
if r != target:
s = self.normal_style
else:
s = self.active_style
Y = const(-24)
@ -80,7 +87,7 @@ class Loader(ui.Control):
if r == target:
self.on_finish()
def on_start(self):
def on_start(self) -> None:
pass
def on_finish(self):

@ -3,6 +3,10 @@ from trezor.crypto import bip39
from trezor.ui import display
from trezor.ui.button import Button, ButtonClear, ButtonMono, ButtonMonoConfirm
if False:
from typing import Optional
from trezor.ui.button import ButtonContent, ButtonStyleStateType
def compute_mask(text: str) -> int:
mask = 0
@ -15,44 +19,47 @@ def compute_mask(text: str) -> int:
class KeyButton(Button):
def __init__(self, area, content, keyboard):
def __init__(
self, area: ui.Area, content: ButtonContent, keyboard: "Bip39Keyboard"
):
self.keyboard = keyboard
super().__init__(area, content)
def on_click(self):
def on_click(self) -> None:
self.keyboard.on_key_click(self)
class InputButton(Button):
def __init__(self, area, content, word):
super().__init__(area, content)
def __init__(self, area: ui.Area, text: str, word: str) -> None:
super().__init__(area, text)
self.word = word
self.pending = False # should we draw the pending marker?
self.icon = None # rendered icon
self.pending = False
self.disable()
def edit(self, content, word, pending):
def edit(self, text: str, word: str, pending: bool) -> None:
self.word = word
self.content = content
self.text = text
self.pending = pending
self.repaint = True
if word:
if content == word: # confirm button
if text == word: # confirm button
self.enable()
self.normal_style = ButtonMonoConfirm.normal
self.active_style = ButtonMonoConfirm.active
self.icon = ui.ICON_CONFIRM
self.icon = res.load(ui.ICON_CONFIRM)
else: # auto-complete button
self.enable()
self.normal_style = ButtonMono.normal
self.active_style = ButtonMono.active
self.icon = ui.ICON_CLICK
self.icon = res.load(ui.ICON_CLICK)
else: # disabled button
self.disabled_style = ButtonMono.disabled
self.disable()
self.icon = None
self.icon = b""
def render_content(self, s, ax, ay, aw, ah):
def render_content(
self, s: ButtonStyleStateType, ax: int, ay: int, aw: int, ah: int
) -> None:
text_style = s.text_style
fg_color = s.fg_color
bg_color = s.bg_color
@ -61,29 +68,29 @@ class InputButton(Button):
ty = ay + ah // 2 + 8 # y-offset of the content
# entered content
display.text(tx, ty, self.content, text_style, fg_color, bg_color)
display.text(tx, ty, self.text, text_style, fg_color, bg_color)
# word suggestion
suggested_word = self.word[len(self.content) :]
width = display.text_width(self.content, text_style)
suggested_word = self.word[len(self.text) :]
width = display.text_width(self.text, text_style)
display.text(tx + width, ty, suggested_word, text_style, ui.GREY, bg_color)
if self.pending:
pw = display.text_width(self.content[-1:], text_style)
pw = display.text_width(self.text[-1:], text_style)
px = tx + width - pw
display.bar(px, ty + 2, pw + 1, 3, fg_color)
if self.icon:
ix = ax + aw - 16 * 2
iy = ty - 16
display.icon(ix, iy, res.load(self.icon), fg_color, bg_color)
display.icon(ix, iy, self.icon, fg_color, bg_color)
class Prompt(ui.Control):
def __init__(self, prompt):
def __init__(self, prompt: str) -> None:
self.prompt = prompt
self.repaint = True
def on_render(self):
def on_render(self) -> None:
if self.repaint:
display.bar(0, 8, ui.WIDTH, 60, ui.BG)
display.text(20, 40, self.prompt, ui.BOLD, ui.GREY, ui.BG)
@ -91,15 +98,15 @@ class Prompt(ui.Control):
class Bip39Keyboard(ui.Layout):
def __init__(self, prompt):
def __init__(self, prompt: str) -> None:
self.prompt = Prompt(prompt)
icon_back = res.load(ui.ICON_BACK)
self.back = Button(ui.grid(0, n_x=3, n_y=4), icon_back, ButtonClear)
self.back.on_click = self.on_back_click
self.back.on_click = self.on_back_click # type: ignore
self.input = InputButton(ui.grid(1, n_x=3, n_y=4, cells_x=2), "", "")
self.input.on_click = self.on_input_click
self.input.on_click = self.on_input_click # type: ignore
self.keys = [
KeyButton(ui.grid(i + 3, n_y=4), k, self)
@ -107,82 +114,82 @@ class Bip39Keyboard(ui.Layout):
("abc", "def", "ghi", "jkl", "mno", "pqr", "stu", "vwx", "yz")
)
]
self.pending_button = None
self.pending_button = None # type: Optional[Button]
self.pending_index = 0
def dispatch(self, event: int, x: int, y: int):
def dispatch(self, event: int, x: int, y: int) -> None:
for btn in self.keys:
btn.dispatch(event, x, y)
if self.input.content:
if self.input.text:
self.input.dispatch(event, x, y)
self.back.dispatch(event, x, y)
else:
self.prompt.dispatch(event, x, y)
def on_back_click(self):
def on_back_click(self) -> None:
# Backspace was clicked, let's delete the last character of input.
self.edit(self.input.content[:-1])
self.edit(self.input.text[:-1])
def on_input_click(self):
def on_input_click(self) -> None:
# Input button was clicked. If the content matches the suggested word,
# let's confirm it, otherwise just auto-complete.
content = self.input.content
text = self.input.text
word = self.input.word
if word and word == content:
if word and word == text:
self.edit("")
self.on_confirm(word)
else:
self.edit(word)
def on_key_click(self, btn: Button):
def on_key_click(self, btn: Button) -> None:
# Key button was clicked. If this button is pending, let's cycle the
# pending character in input. If not, let's just append the first
# character.
if self.pending_button is btn:
index = (self.pending_index + 1) % len(btn.content)
content = self.input.content[:-1] + btn.content[index]
index = (self.pending_index + 1) % len(btn.text)
text = self.input.text[:-1] + btn.text[index]
else:
index = 0
content = self.input.content + btn.content[0]
self.edit(content, btn, index)
text = self.input.text + btn.text[0]
self.edit(text, btn, index)
def on_timeout(self):
def on_timeout(self) -> None:
# Timeout occurred. If we can auto-complete current input, let's just
# reset the pending marker. If not, input is invalid, let's backspace
# the last character.
if self.input.word:
self.edit(self.input.content)
self.edit(self.input.text)
else:
self.edit(self.input.content[:-1])
self.edit(self.input.text[:-1])
def on_confirm(self, word):
def on_confirm(self, word: str) -> None:
# Word was confirmed by the user.
raise ui.Result(word)
def edit(self, content: str, button: KeyButton = None, index: int = 0):
def edit(self, text: str, button: Button = None, index: int = 0) -> None:
self.pending_button = button
self.pending_index = index
# find the completions
pending = button is not None
word = bip39.find_word(content) or ""
mask = bip39.complete_word(content)
word = bip39.find_word(text) or ""
mask = bip39.complete_word(text)
# modify the input state
self.input.edit(content, word, pending)
self.input.edit(text, word, pending)
# enable or disable key buttons
for btn in self.keys:
if btn is button or compute_mask(btn.content) & mask:
if btn is button or compute_mask(btn.text) & mask:
btn.enable()
else:
btn.disable()
# invalidate the prompt if we display it next frame
if not self.input.content:
if not self.input.text:
self.prompt.repaint = True
async def handle_input(self):
async def handle_input(self) -> None:
touch = loop.wait(io.TOUCH)
timeout = loop.sleep(1000 * 1000 * 1)
spawn_touch = loop.spawn(touch)

@ -3,44 +3,61 @@ from trezor.crypto import slip39
from trezor.ui import display
from trezor.ui.button import Button, ButtonClear, ButtonMono, ButtonMonoConfirm
if False:
from typing import Optional
from trezor.ui.button import ButtonContent, ButtonStyleStateType
class KeyButton(Button):
def __init__(self, area, content, keyboard, index):
def __init__(
self,
area: ui.Area,
content: ButtonContent,
keyboard: "Slip39Keyboard",
index: int,
):
self.keyboard = keyboard
self.index = index
super().__init__(area, content)
def on_click(self):
def on_click(self) -> None:
self.keyboard.on_key_click(self)
class InputButton(Button):
def __init__(self, area, keyboard):
def __init__(self, area: ui.Area, keyboard: "Slip39Keyboard") -> None:
super().__init__(area, "")
self.word = ""
self.pending_button = None
self.pending_index = None
self.icon = None # rendered icon
self.pending_button = None # type: Optional[Button]
self.pending_index = None # type: Optional[int]
self.keyboard = keyboard
self.disable()
def edit(self, content, word, pending_button, pending_index):
def edit(
self,
text: str,
word: str,
pending_button: Optional[Button],
pending_index: Optional[int],
) -> None:
self.word = word
self.content = content
self.text = text
self.pending_button = pending_button
self.pending_index = pending_index
self.repaint = True
if word:
if word: # confirm button
self.enable()
self.normal_style = ButtonMonoConfirm.normal
self.active_style = ButtonMonoConfirm.active
self.icon = ui.ICON_CONFIRM
self.icon = res.load(ui.ICON_CONFIRM)
else: # disabled button
self.disabled_style = ButtonMono.normal
self.disabled_style = ButtonMono.disabled
self.disable()
self.icon = None
self.icon = b""
def render_content(self, s, ax, ay, aw, ah):
def render_content(
self, s: ButtonStyleStateType, ax: int, ay: int, aw: int, ah: int
) -> None:
text_style = s.text_style
fg_color = s.fg_color
bg_color = s.bg_color
@ -49,11 +66,11 @@ class InputButton(Button):
ty = ay + ah // 2 + 8 # y-offset of the content
if not self.keyboard.is_input_final():
to_display = len(self.content) * "*"
if self.pending_button:
to_display = (
to_display[:-1] + self.pending_button.content[self.pending_index]
)
pending_button = self.pending_button
pending_index = self.pending_index
to_display = len(self.text) * "*"
if pending_button and pending_index is not None:
to_display = to_display[:-1] + pending_button.text[pending_index]
else:
to_display = self.word
@ -61,22 +78,22 @@ class InputButton(Button):
if self.pending_button and not self.keyboard.is_input_final():
width = display.text_width(to_display, text_style)
pw = display.text_width(self.content[-1:], text_style)
pw = display.text_width(self.text[-1:], text_style)
px = tx + width - pw
display.bar(px, ty + 2, pw + 1, 3, fg_color)
if self.icon:
ix = ax + aw - 16 * 2
iy = ty - 16
display.icon(ix, iy, res.load(self.icon), fg_color, bg_color)
display.icon(ix, iy, self.icon, fg_color, bg_color)
class Prompt(ui.Control):
def __init__(self, prompt):
def __init__(self, prompt: str) -> None:
self.prompt = prompt
self.repaint = True
def on_render(self):
def on_render(self) -> None:
if self.repaint:
display.bar(0, 8, ui.WIDTH, 60, ui.BG)
display.text(20, 40, self.prompt, ui.BOLD, ui.GREY, ui.BG)
@ -84,15 +101,15 @@ class Prompt(ui.Control):
class Slip39Keyboard(ui.Layout):
def __init__(self, prompt):
def __init__(self, prompt: str) -> None:
self.prompt = Prompt(prompt)
icon_back = res.load(ui.ICON_BACK)
self.back = Button(ui.grid(0, n_x=3, n_y=4), icon_back, ButtonClear)
self.back.on_click = self.on_back_click
self.back.on_click = self.on_back_click # type: ignore
self.input = InputButton(ui.grid(1, n_x=3, n_y=4, cells_x=2), self)
self.input.on_click = self.on_input_click
self.input.on_click = self.on_input_click # type: ignore
self.keys = [
KeyButton(ui.grid(i + 3, n_y=4), k, self, i + 1)
@ -100,26 +117,26 @@ class Slip39Keyboard(ui.Layout):
("ab", "cd", "ef", "ghij", "klm", "nopq", "rs", "tuv", "wxyz")
)
]
self.pending_button = None
self.pending_button = None # type: Optional[Button]
self.pending_index = 0
self.button_sequence = ""
self.mask = slip39.KEYBOARD_FULL_MASK
def dispatch(self, event: int, x: int, y: int):
def dispatch(self, event: int, x: int, y: int) -> None:
for btn in self.keys:
btn.dispatch(event, x, y)
if self.input.content:
if self.input.text:
self.input.dispatch(event, x, y)
self.back.dispatch(event, x, y)
else:
self.prompt.dispatch(event, x, y)
def on_back_click(self):
def on_back_click(self) -> None:
# Backspace was clicked, let's delete the last character of input.
self.button_sequence = self.button_sequence[:-1]
self.edit()
def on_input_click(self):
def on_input_click(self) -> None:
# Input button was clicked. If the content matches the suggested word,
# let's confirm it, otherwise just auto-complete.
result = self.input.word
@ -128,26 +145,26 @@ class Slip39Keyboard(ui.Layout):
self.edit()
self.on_confirm(result)
def on_key_click(self, btn: KeyButton):
def on_key_click(self, btn: KeyButton) -> None:
# Key button was clicked. If this button is pending, let's cycle the
# pending character in input. If not, let's just append the first
# character.
if self.pending_button is btn:
index = (self.pending_index + 1) % len(btn.content)
index = (self.pending_index + 1) % len(btn.text)
else:
index = 0
self.button_sequence += str(btn.index)
self.edit(btn, index)
def on_timeout(self):
def on_timeout(self) -> None:
# Timeout occurred. Let's redraw to draw asterisks.
self.edit()
def on_confirm(self, word):
def on_confirm(self, word: str) -> None:
# Word was confirmed by the user.
raise ui.Result(word)
def edit(self, button: KeyButton = None, index: int = 0):
def edit(self, button: Button = None, index: int = 0) -> None:
self.pending_button = button
self.pending_index = index
@ -172,7 +189,7 @@ class Slip39Keyboard(ui.Layout):
btn.disable()
# invalidate the prompt if we display it next frame
if not self.input.content:
if not self.input.text:
self.prompt.repaint = True
def is_input_final(self) -> bool:
@ -182,7 +199,7 @@ class Slip39Keyboard(ui.Layout):
def check_mask(self, index: int) -> bool:
return bool((1 << (index - 1)) & self.mask)
async def handle_input(self):
async def handle_input(self) -> None:
touch = loop.wait(io.TOUCH)
timeout = loop.sleep(1000 * 1000 * 1)
spawn_touch = loop.spawn(touch)

@ -6,6 +6,10 @@ from trezor.ui import display
from trezor.ui.button import Button, ButtonClear, ButtonConfirm
from trezor.ui.swipe import SWIPE_HORIZONTAL, SWIPE_LEFT, Swipe
if False:
from typing import List, Iterable, Optional
from trezor.ui.button import ButtonContent, ButtonStyleStateType
SPACE = res.load(ui.ICON_SPACE)
KEYBOARD_KEYS = (
@ -16,13 +20,13 @@ KEYBOARD_KEYS = (
)
def digit_area(i):
def digit_area(i: int) -> ui.Area:
if i == 9: # 0-position
i = 10 # display it in the middle
return ui.grid(i + 3) # skip the first line
def render_scrollbar(page):
def render_scrollbar(page: int) -> None:
BBOX = const(240)
SIZE = const(8)
pages = len(KEYBOARD_KEYS)
@ -43,42 +47,50 @@ def render_scrollbar(page):
class KeyButton(Button):
def __init__(self, area, content, keyboard):
def __init__(
self, area: ui.Area, content: ButtonContent, keyboard: "PassphraseKeyboard"
) -> None:
self.keyboard = keyboard
super().__init__(area, content)
def on_click(self):
def on_click(self) -> None:
self.keyboard.on_key_click(self)
def get_text_content(self):
if self.content is SPACE:
def get_text_content(self) -> str:
if self.text:
return self.text
elif self.icon is SPACE:
return " "
else:
return self.content
raise TypeError
def key_buttons(keys, keyboard):
def key_buttons(
keys: Iterable[ButtonContent], keyboard: "PassphraseKeyboard"
) -> List[KeyButton]:
return [KeyButton(digit_area(i), k, keyboard) for i, k in enumerate(keys)]
class Input(Button):
def __init__(self, area, content):
super().__init__(area, content)
def __init__(self, area: ui.Area, text: str) -> None:
super().__init__(area, text)
self.pending = False
self.disable()
def edit(self, content, pending):
self.content = content
def edit(self, text: str, pending: bool) -> None:
self.text = text
self.pending = pending
self.repaint = True
def render_content(self, s, ax, ay, aw, ah):
def render_content(
self, s: ButtonStyleStateType, ax: int, ay: int, aw: int, ah: int
) -> None:
text_style = s.text_style
fg_color = s.fg_color
bg_color = s.bg_color
p = self.pending # should we draw the pending marker?
t = self.content # input content
t = self.text # input content
tx = ax + 24 # x-offset of the content
ty = ay + ah // 2 + 8 # y-offset of the content
@ -98,16 +110,16 @@ class Input(Button):
cx = tx + width + 1
display.bar(cx, ty - 18, 2, 22, fg_color)
def on_click(self):
def on_click(self) -> None:
pass
class Prompt(ui.Control):
def __init__(self, text):
def __init__(self, text: str) -> None:
self.text = text
self.repaint = True
def on_render(self):
def on_render(self) -> None:
if self.repaint:
display.bar(0, 0, ui.WIDTH, 48, ui.BG)
display.text_center(ui.WIDTH // 2, 32, self.text, ui.BOLD, ui.GREY, ui.BG)
@ -118,7 +130,7 @@ CANCELLED = object()
class PassphraseKeyboard(ui.Layout):
def __init__(self, prompt, max_length, page=1):
def __init__(self, prompt: str, max_length: int, page: int = 1) -> None:
self.prompt = Prompt(prompt)
self.max_length = max_length
self.page = page
@ -126,18 +138,18 @@ class PassphraseKeyboard(ui.Layout):
self.input = Input(ui.grid(0, n_x=1, n_y=6), "")
self.back = Button(ui.grid(12), res.load(ui.ICON_BACK), ButtonClear)
self.back.on_click = self.on_back_click
self.back.on_click = self.on_back_click # type: ignore
self.back.disable()
self.done = Button(ui.grid(14), res.load(ui.ICON_CONFIRM), ButtonConfirm)
self.done.on_click = self.on_confirm
self.done.on_click = self.on_confirm # type: ignore
self.keys = key_buttons(KEYBOARD_KEYS[self.page], self)
self.pending_button = None
self.pending_button = None # type: Optional[KeyButton]
self.pending_index = 0
def dispatch(self, event, x, y):
if self.input.content:
def dispatch(self, event: int, x: int, y: int) -> None:
if self.input.text:
self.input.dispatch(event, x, y)
else:
self.prompt.dispatch(event, x, y)
@ -149,37 +161,37 @@ class PassphraseKeyboard(ui.Layout):
if event == ui.RENDER:
render_scrollbar(self.page)
def on_back_click(self):
def on_back_click(self) -> None:
# Backspace was clicked. If we have any content in the input, let's delete
# the last character. Otherwise cancel.
content = self.input.content
if content:
self.edit(content[:-1])
text = self.input.text
if text:
self.edit(text[:-1])
else:
self.on_cancel()
def on_key_click(self, button: KeyButton):
def on_key_click(self, button: KeyButton) -> None:
# Key button was clicked. If this button is pending, let's cycle the
# pending character in input. If not, let's just append the first
# character.
button_text = button.get_text_content()
if self.pending_button is button:
index = (self.pending_index + 1) % len(button_text)
prefix = self.input.content[:-1]
prefix = self.input.text[:-1]
else:
index = 0
prefix = self.input.content
prefix = self.input.text
if len(button_text) > 1:
self.edit(prefix + button_text[index], button, index)
else:
self.edit(prefix + button_text[index])
def on_timeout(self):
def on_timeout(self) -> None:
# Timeout occurred, let's just reset the pending marker.
self.edit(self.input.content)
self.edit(self.input.text)
def edit(self, content: str, button: Button = None, index: int = 0):
if len(content) > self.max_length:
def edit(self, text: str, button: KeyButton = None, index: int = 0) -> None:
if len(text) > self.max_length:
return
self.pending_button = button
@ -187,15 +199,15 @@ class PassphraseKeyboard(ui.Layout):
# modify the input state
pending = button is not None
self.input.edit(content, pending)
self.input.edit(text, pending)
if content:
if text:
self.back.enable()
else:
self.back.disable()
self.prompt.repaint = True
async def handle_input(self):
async def handle_input(self) -> None:
touch = loop.wait(io.TOUCH)
timeout = loop.sleep(1000 * 1000 * 1)
spawn_touch = loop.spawn(touch)
@ -214,7 +226,7 @@ class PassphraseKeyboard(ui.Layout):
else:
self.on_timeout()
async def handle_paging(self):
async def handle_paging(self) -> None:
swipe = await Swipe(SWIPE_HORIZONTAL)
if swipe == SWIPE_LEFT:
self.page = (self.page + 1) % len(KEYBOARD_KEYS)
@ -226,33 +238,33 @@ class PassphraseKeyboard(ui.Layout):
self.input.repaint = True
self.prompt.repaint = True
def on_cancel(self):
def on_cancel(self) -> None:
raise ui.Result(CANCELLED)
def on_confirm(self):
raise ui.Result(self.input.content)
def on_confirm(self) -> None:
raise ui.Result(self.input.text)
def create_tasks(self):
def create_tasks(self) -> Iterable[loop.Task]:
return self.handle_input(), self.handle_rendering(), self.handle_paging()
class PassphraseSource(ui.Layout):
def __init__(self, content):
def __init__(self, content: ui.Control) -> None:
self.content = content
self.device = Button(ui.grid(8, n_y=4, n_x=4, cells_x=4), "Device")
self.device.on_click = self.on_device
self.device.on_click = self.on_device # type: ignore
self.host = Button(ui.grid(12, n_y=4, n_x=4, cells_x=4), "Host")
self.host.on_click = self.on_host
self.host.on_click = self.on_host # type: ignore
def dispatch(self, event, x, y):
def dispatch(self, event: int, x: int, y: int) -> None:
self.content.dispatch(event, x, y)
self.device.dispatch(event, x, y)
self.host.dispatch(event, x, y)
def on_device(self):
def on_device(self) -> None:
raise ui.Result(PassphraseSourceType.DEVICE)
def on_host(self):
def on_host(self) -> None:
raise ui.Result(PassphraseSourceType.HOST)

@ -11,14 +11,17 @@ from trezor.ui.button import (
ButtonMono,
)
if False:
from typing import Iterable
def digit_area(i):
def digit_area(i: int) -> ui.Area:
if i == 9: # 0-position
i = 10 # display it in the middle
return ui.grid(i + 3) # skip the first line
def generate_digits():
def generate_digits() -> Iterable[int]:
digits = list(range(0, 10)) # 0-9
random.shuffle(digits)
# We lay out the buttons top-left to bottom-right, but the order
@ -27,13 +30,13 @@ def generate_digits():
class PinInput(ui.Control):
def __init__(self, prompt, subprompt, pin):
def __init__(self, prompt: str, subprompt: str, pin: str) -> None:
self.prompt = prompt
self.subprompt = subprompt
self.pin = pin
self.repaint = True
def on_render(self):
def on_render(self) -> None:
if self.repaint:
if self.pin:
self.render_pin()
@ -41,7 +44,7 @@ class PinInput(ui.Control):
self.render_prompt()
self.repaint = False
def render_pin(self):
def render_pin(self) -> None:
display.bar(0, 0, ui.WIDTH, 50, ui.BG)
count = len(self.pin)
BOX_WIDTH = const(240)
@ -54,7 +57,7 @@ class PinInput(ui.Control):
render_x + i * PADDING, RENDER_Y, DOT_SIZE, DOT_SIZE, ui.GREY, ui.BG, 4
)
def render_prompt(self):
def render_prompt(self) -> None:
display.bar(0, 0, ui.WIDTH, 50, ui.BG)
if self.subprompt:
display.text_center(ui.WIDTH // 2, 20, self.prompt, ui.BOLD, ui.GREY, ui.BG)
@ -66,35 +69,37 @@ class PinInput(ui.Control):
class PinButton(Button):
def __init__(self, index, digit, matrix):
self.matrix = matrix
def __init__(self, index: int, digit: int, dialog: "PinDialog"):
self.dialog = dialog
super().__init__(digit_area(index), str(digit), ButtonMono)
def on_click(self):
self.matrix.assign(self.matrix.input.pin + self.content)
def on_click(self) -> None:
self.dialog.assign(self.dialog.input.pin + self.text)
CANCELLED = object()
class PinDialog(ui.Layout):
def __init__(self, prompt, subprompt, allow_cancel=True, maxlength=9):
def __init__(
self, prompt: str, subprompt: str, allow_cancel: bool = True, maxlength: int = 9
) -> None:
self.maxlength = maxlength
self.input = PinInput(prompt, subprompt, "")
icon_confirm = res.load(ui.ICON_CONFIRM)
self.confirm_button = Button(ui.grid(14), icon_confirm, ButtonConfirm)
self.confirm_button.on_click = self.on_confirm
self.confirm_button.on_click = self.on_confirm # type: ignore
self.confirm_button.disable()
icon_back = res.load(ui.ICON_BACK)
self.reset_button = Button(ui.grid(12), icon_back, ButtonClear)
self.reset_button.on_click = self.on_reset
self.reset_button.on_click = self.on_reset # type: ignore
if allow_cancel:
icon_lock = res.load(ui.ICON_LOCK)
self.cancel_button = Button(ui.grid(12), icon_lock, ButtonCancel)
self.cancel_button.on_click = self.on_cancel
self.cancel_button.on_click = self.on_cancel # type: ignore
else:
self.cancel_button = Button(ui.grid(12), "")
self.cancel_button.disable()
@ -103,7 +108,7 @@ class PinDialog(ui.Layout):
PinButton(i, d, self) for i, d in enumerate(generate_digits())
]
def dispatch(self, event, x, y):
def dispatch(self, event: int, x: int, y: int) -> None:
self.input.dispatch(event, x, y)
if self.input.pin:
self.reset_button.dispatch(event, x, y)
@ -113,7 +118,7 @@ class PinDialog(ui.Layout):
for btn in self.pin_buttons:
btn.dispatch(event, x, y)
def assign(self, pin):
def assign(self, pin: str) -> None:
if len(pin) > self.maxlength:
return
for btn in self.pin_buttons:
@ -132,12 +137,12 @@ class PinDialog(ui.Layout):
self.input.pin = pin
self.input.repaint = True
def on_reset(self):
def on_reset(self) -> None:
self.assign("")
def on_cancel(self):
def on_cancel(self) -> None:
raise ui.Result(CANCELLED)
def on_confirm(self):
def on_confirm(self) -> None:
if self.input.pin:
raise ui.Result(self.input.pin)

@ -1,17 +1,20 @@
from trezor import loop, ui
if False:
from typing import Iterable
class Popup(ui.Layout):
def __init__(self, content, time_ms=0):
def __init__(self, content: ui.Control, time_ms: int = 0) -> None:
self.content = content
self.time_ms = time_ms
def dispatch(self, event, x, y):
def dispatch(self, event: int, x: int, y: int) -> None:
self.content.dispatch(event, x, y)
def create_tasks(self):
def create_tasks(self) -> Iterable[loop.Task]:
return self.handle_input(), self.handle_rendering(), self.handle_timeout()
def handle_timeout(self):
def handle_timeout(self) -> loop.Task: # type: ignore
yield loop.sleep(self.time_ms * 1000)
raise ui.Result(None)

@ -2,11 +2,11 @@ from trezor import ui
class Qr(ui.Control):
def __init__(self, data, x, y, scale):
def __init__(self, data: bytes, x: int, y: int, scale: int):
self.data = data
self.x = x
self.y = y
self.scale = scale
def on_render(self):
def on_render(self) -> None:
ui.display.qrcode(self.x, self.y, self.data, self.scale)

@ -8,8 +8,11 @@ from trezor.ui.swipe import SWIPE_DOWN, SWIPE_UP, SWIPE_VERTICAL, Swipe
if __debug__:
from apps.debug import swipe_signal
if False:
from typing import Iterable, Sequence
def render_scrollbar(pages: int, page: int):
def render_scrollbar(pages: int, page: int) -> None:
BBOX = const(220)
SIZE = const(8)
@ -28,7 +31,7 @@ def render_scrollbar(pages: int, page: int):
ui.display.bar_radius(X, Y + i * padding, SIZE, SIZE, fg, ui.BG, 4)
def render_swipe_icon():
def render_swipe_icon() -> None:
DRAW_DELAY = const(200000)
icon = res.load(ui.ICON_SWIPE)
@ -37,18 +40,20 @@ def render_swipe_icon():
ui.display.icon(70, 205, icon, c, ui.BG)
def render_swipe_text():
def render_swipe_text() -> None:
ui.display.text_center(130, 220, "Swipe", ui.BOLD, ui.GREY, ui.BG)
class Paginated(ui.Layout):
def __init__(self, pages, page=0, one_by_one=False):
def __init__(
self, pages: Sequence[ui.Control], page: int = 0, one_by_one: bool = False
):
self.pages = pages
self.page = page
self.one_by_one = one_by_one
self.repaint = True
def dispatch(self, event, x, y):
def dispatch(self, event: int, x: int, y: int) -> None:
pages = self.pages
page = self.page
pages[page].dispatch(event, x, y)
@ -63,7 +68,7 @@ class Paginated(ui.Layout):
render_scrollbar(length, page)
self.repaint = False
async def handle_paging(self):
async def handle_paging(self) -> None:
if self.page == 0:
directions = SWIPE_UP
elif self.page == len(self.pages) - 1:
@ -86,21 +91,33 @@ class Paginated(ui.Layout):
self.on_change()
def create_tasks(self):
def create_tasks(self) -> Iterable[loop.Task]:
return self.handle_input(), self.handle_rendering(), self.handle_paging()
def on_change(self):
def on_change(self) -> None:
if self.one_by_one:
raise ui.Result(self.page)
class PageWithButtons(ui.Control):
def __init__(self, content, paginated, index, count):
def __init__(
self,
content: ui.Control,
paginated: "PaginatedWithButtons",
index: int,
count: int,
) -> None:
self.content = content
self.paginated = paginated
self.index = index
self.count = count
# somewhere in the middle, we can go up or down
left = res.load(ui.ICON_BACK)
left_style = ButtonDefault
right = res.load(ui.ICON_CLICK)
right_style = ButtonDefault
if self.index == 0:
# first page, we can cancel or go down
left = res.load(ui.ICON_CANCEL)
@ -113,31 +130,25 @@ class PageWithButtons(ui.Control):
left_style = ButtonDefault
right = res.load(ui.ICON_CONFIRM)
right_style = ButtonConfirm
else:
# somewhere in the middle, we can go up or down
left = res.load(ui.ICON_BACK)
left_style = ButtonDefault
right = res.load(ui.ICON_CLICK)
right_style = ButtonDefault
self.left = Button(ui.grid(8, n_x=2), left, left_style)
self.left.on_click = self.on_left
self.left.on_click = self.on_left # type: ignore
self.right = Button(ui.grid(9, n_x=2), right, right_style)
self.right.on_click = self.on_right
self.right.on_click = self.on_right # type: ignore
def dispatch(self, event, x, y):
def dispatch(self, event: int, x: int, y: int) -> None:
self.content.dispatch(event, x, y)
self.left.dispatch(event, x, y)
self.right.dispatch(event, x, y)
def on_left(self):
def on_left(self) -> None:
if self.index == 0:
self.paginated.on_cancel()
else:
self.paginated.on_up()
def on_right(self):
def on_right(self) -> None:
if self.index == self.count - 1:
self.paginated.on_confirm()
else:
@ -145,36 +156,38 @@ class PageWithButtons(ui.Control):
class PaginatedWithButtons(ui.Layout):
def __init__(self, pages, page=0, one_by_one=False):
def __init__(
self, pages: Sequence[ui.Control], page: int = 0, one_by_one: bool = False
) -> None:
self.pages = [
PageWithButtons(p, self, i, len(pages)) for i, p in enumerate(pages)
]
self.page = page
self.one_by_one = one_by_one
def dispatch(self, event, x, y):
def dispatch(self, event: int, x: int, y: int) -> None:
pages = self.pages
page = self.page
pages[page].dispatch(event, x, y)
if event is ui.RENDER:
render_scrollbar(len(pages), page)
def on_up(self):
def on_up(self) -> None:
self.page = max(self.page - 1, 0)
self.pages[self.page].dispatch(ui.REPAINT, 0, 0)
self.on_change()
def on_down(self):
def on_down(self) -> None:
self.page = min(self.page + 1, len(self.pages) - 1)
self.pages[self.page].dispatch(ui.REPAINT, 0, 0)
self.on_change()
def on_confirm(self):
def on_confirm(self) -> None:
raise ui.Result(CONFIRMED)
def on_cancel(self):
def on_cancel(self) -> None:
raise ui.Result(CANCELLED)
def on_change(self):
def on_change(self) -> None:
if self.one_by_one:
raise ui.Result(self.page)

@ -4,31 +4,31 @@ from trezor.ui.text import LABEL_CENTER, Label
class NumInput(ui.Control):
def __init__(self, count=5, max_count=16, min_count=1):
def __init__(self, count: int = 5, max_count: int = 16, min_count: int = 1) -> None:
self.count = count
self.max_count = max_count
self.min_count = min_count
self.minus = Button(ui.grid(3), "-")
self.minus.on_click = self.on_minus
self.minus.on_click = self.on_minus # type: ignore
self.plus = Button(ui.grid(5), "+")
self.plus.on_click = self.on_plus
self.plus.on_click = self.on_plus # type: ignore
self.text = Label(ui.grid(4), "", LABEL_CENTER, ui.BOLD)
self.edit(count)
def dispatch(self, event, x, y):
def dispatch(self, event: int, x: int, y: int) -> None:
self.minus.dispatch(event, x, y)
self.plus.dispatch(event, x, y)
self.text.dispatch(event, x, y)
def on_minus(self):
def on_minus(self) -> None:
self.edit(self.count - 1)
def on_plus(self):
def on_plus(self) -> None:
self.edit(self.count + 1)
def edit(self, count):
def edit(self, count: int) -> None:
count = max(count, self.min_count)
count = min(count, self.max_count)
if self.count != count:
@ -45,5 +45,5 @@ class NumInput(ui.Control):
else:
self.plus.enable()
def on_change(self, count):
def on_change(self, count: int) -> None:
pass

@ -2,6 +2,9 @@ from micropython import const
from trezor import io, loop, ui
if False:
from typing import Generator
SWIPE_UP = const(0x01)
SWIPE_DOWN = const(0x02)
SWIPE_LEFT = const(0x04)
@ -15,24 +18,26 @@ _SWIPE_TRESHOLD = const(30)
class Swipe(ui.Control):
def __init__(self, directions=SWIPE_ALL, area=None):
def __init__(self, directions: int = SWIPE_ALL, area: ui.Area = None) -> None:
if area is None:
area = (0, 0, ui.WIDTH, ui.HEIGHT)
self.area = area
self.directions = directions
self.start_x = None
self.start_y = None
self.light_origin = None
self.started = False
self.start_x = 0
self.start_y = 0
self.light_origin = ui.BACKLIGHT_NORMAL
self.light_target = ui.BACKLIGHT_NONE
def on_touch_start(self, x, y):
def on_touch_start(self, x: int, y: int) -> None:
if ui.in_area(self.area, x, y):
self.start_x = x
self.start_y = y
self.light_origin = ui.BACKLIGHT_NORMAL
self.started = True
def on_touch_move(self, x, y):
if self.start_x is None:
def on_touch_move(self, x: int, y: int) -> None:
if not self.started:
return # not started in our area
dirs = self.directions
@ -61,8 +66,8 @@ class Swipe(ui.Control):
)
)
def on_touch_end(self, x, y):
if self.start_x is None:
def on_touch_end(self, x: int, y: int) -> None:
if not self.started:
return # not started in our area
dirs = self.directions
@ -93,13 +98,15 @@ class Swipe(ui.Control):
# no swipe detected, reset the state
ui.display.backlight(self.light_origin)
self.start_x = None
self.start_y = None
self.started = False
def on_swipe(self, swipe):
def on_swipe(self, swipe: int) -> None:
raise ui.Result(swipe)
def __iter__(self):
def __await__(self) -> Generator:
return self.__iter__() # type: ignore
def __iter__(self) -> loop.Task: # type: ignore
try:
touch = loop.wait(io.TOUCH)
while True:

@ -2,6 +2,9 @@ from micropython import const
from trezor import ui
if False:
from typing import List, Union
TEXT_HEADER_HEIGHT = const(48)
TEXT_LINE_HEIGHT = const(26)
TEXT_LINE_HEIGHT_HALF = const(13)
@ -12,9 +15,12 @@ TEXT_MAX_LINES = const(5)
BR = const(-256)
BR_HALF = const(-257)
if False:
TextContent = Union[str, int]
def render_text(
words: list,
words: List[TextContent],
new_lines: bool,
max_lines: int,
font: int = ui.NORMAL,
@ -128,32 +134,32 @@ class Text(ui.Control):
self.icon_color = icon_color
self.max_lines = max_lines
self.new_lines = new_lines
self.content = []
self.content = [] # type: List[Union[str, int]]
self.repaint = True
def normal(self, *content):
def normal(self, *content: TextContent) -> None:
self.content.append(ui.NORMAL)
self.content.extend(content)
def bold(self, *content):
def bold(self, *content: TextContent) -> None:
self.content.append(ui.BOLD)
self.content.extend(content)
def mono(self, *content):
def mono(self, *content: TextContent) -> None:
self.content.append(ui.MONO)
self.content.extend(content)
def mono_bold(self, *content):
def mono_bold(self, *content: TextContent) -> None:
self.content.append(ui.MONO_BOLD)
self.content.extend(content)
def br(self):
def br(self) -> None:
self.content.append(BR)
def br_half(self):
def br_half(self) -> None:
self.content.append(BR_HALF)
def on_render(self):
def on_render(self) -> None:
if self.repaint:
ui.header(
self.header_text,
@ -172,21 +178,27 @@ LABEL_RIGHT = const(2)
class Label(ui.Control):
def __init__(self, area, content, align=LABEL_LEFT, style=ui.NORMAL):
def __init__(
self,
area: ui.Area,
content: str,
align: int = LABEL_LEFT,
style: int = ui.NORMAL,
) -> None:
self.area = area
self.content = content
self.align = align
self.style = style
self.repaint = True
def on_render(self):
def on_render(self) -> None:
if self.repaint:
align = self.align
ax, ay, aw, ah = self.area
tx = ax + aw // 2
ty = ay + ah // 2 + 8
if align is LABEL_LEFT:
ui.display.text_left(tx, ty, self.content, self.style, ui.FG, ui.BG, aw)
ui.display.text(tx, ty, self.content, self.style, ui.FG, ui.BG, aw)
elif align is LABEL_CENTER:
ui.display.text_center(
tx, ty, self.content, self.style, ui.FG, ui.BG, aw

@ -5,20 +5,20 @@ from trezor.ui.button import Button
class WordSelector(ui.Layout):
def __init__(self, content):
def __init__(self, content: ui.Control) -> None:
self.content = content
self.w12 = Button(ui.grid(6, n_y=4), "12")
self.w12.on_click = self.on_w12
self.w12.on_click = self.on_w12 # type: ignore
self.w18 = Button(ui.grid(7, n_y=4), "18")
self.w18.on_click = self.on_w18
self.w18.on_click = self.on_w18 # type: ignore
self.w20 = Button(ui.grid(8, n_y=4), "20")
self.w20.on_click = self.on_w20
self.w20.on_click = self.on_w20 # type: ignore
self.w24 = Button(ui.grid(9, n_y=4), "24")
self.w24.on_click = self.on_w24
self.w24.on_click = self.on_w24 # type: ignore
self.w33 = Button(ui.grid(10, n_y=4), "33")
self.w33.on_click = self.on_w33
self.w33.on_click = self.on_w33 # type: ignore
def dispatch(self, event, x, y):
def dispatch(self, event: int, x: int, y: int) -> None:
self.content.dispatch(event, x, y)
self.w12.dispatch(event, x, y)
self.w18.dispatch(event, x, y)
@ -26,17 +26,17 @@ class WordSelector(ui.Layout):
self.w24.dispatch(event, x, y)
self.w33.dispatch(event, x, y)
def on_w12(self):
def on_w12(self) -> None:
raise ui.Result(12)
def on_w18(self):
def on_w18(self) -> None:
raise ui.Result(18)
def on_w20(self):
def on_w20(self) -> None:
raise ui.Result(20)
def on_w24(self):
def on_w24(self) -> None:
raise ui.Result(24)
def on_w33(self):
def on_w33(self) -> None:
raise ui.Result(33)

@ -27,12 +27,15 @@ if __debug__:
SAVE_SCREEN = 0
LOG_MEMORY = 0
if False:
from typing import Iterable, Iterator, Protocol, List, TypeVar
def unimport_begin():
def unimport_begin() -> Iterable[str]:
return set(sys.modules)
def unimport_end(mods):
def unimport_end(mods: Iterable[str]) -> None:
for mod in sys.modules:
if mod not in mods:
# remove reference from sys.modules
@ -53,7 +56,7 @@ def unimport_end(mods):
gc.collect()
def ensure(cond, msg=None):
def ensure(cond: bool, msg: str = None) -> None:
if not cond:
if msg is None:
raise AssertionError
@ -61,48 +64,71 @@ def ensure(cond, msg=None):
raise AssertionError(msg)
def chunks(items, size):
if False:
Chunked = TypeVar("Chunked")
def chunks(items: List[Chunked], size: int) -> Iterator[List[Chunked]]:
for i in range(0, len(items), size):
yield items[i : i + size]
def format_amount(amount, decimals):
def format_amount(amount: int, decimals: int) -> str:
d = pow(10, decimals)
amount = ("%d.%0*d" % (amount // d, decimals, amount % d)).rstrip("0")
if amount.endswith("."):
amount = amount[:-1]
return amount
s = ("%d.%0*d" % (amount // d, decimals, amount % d)).rstrip("0").rstrip(".")
return s
def format_ordinal(number):
def format_ordinal(number: int) -> str:
return str(number) + {1: "st", 2: "nd", 3: "rd"}.get(
4 if 10 <= number % 100 < 20 else number % 10, "th"
)
if False:
class HashContext(Protocol):
def update(self, buf: bytes) -> None:
...
def digest(self) -> bytes:
...
class Writer(Protocol):
def append(self, b: int) -> None:
...
def extend(self, buf: bytes) -> None:
...
def write(self, buf: bytes) -> None:
...
class HashWriter:
def __init__(self, ctx):
def __init__(self, ctx: HashContext) -> None:
self.ctx = ctx
self.buf = bytearray(1) # used in append()
def extend(self, buf: bytearray):
self.ctx.update(buf)
def append(self, b: int) -> None:
self.buf[0] = b
self.ctx.update(self.buf)
def write(self, buf: bytearray): # alias for extend()
def extend(self, buf: bytes) -> None:
self.ctx.update(buf)
async def awrite(self, buf: bytearray): # AsyncWriter interface
return self.ctx.update(buf)
def write(self, buf: bytes) -> None: # alias for extend()
self.ctx.update(buf)
def append(self, b: int):
self.buf[0] = b
self.ctx.update(self.buf)
async def awrite(self, buf: bytes) -> int: # AsyncWriter interface
self.ctx.update(buf)
return len(buf)
def get_digest(self) -> bytes:
return self.ctx.digest()
def obj_eq(l, r):
def obj_eq(l: object, r: object) -> bool:
"""
Compares object contents, supports __slots__.
"""
@ -118,7 +144,7 @@ def obj_eq(l, r):
return True
def obj_repr(o):
def obj_repr(o: object) -> str:
"""
Returns a string representation of object, supports __slots__.
"""

@ -7,11 +7,28 @@ from trezor.wire.errors import Error
# import all errors into namespace, so that `wire.Error` is available elsewhere
from trezor.wire.errors import * # isort:skip # noqa: F401,F403
if False:
from typing import (
Any,
Awaitable,
Dict,
Callable,
Iterable,
List,
Optional,
Tuple,
Type,
)
from trezorio import WireInterface
from protobuf import LoadedMessageType, MessageType
Handler = Callable[..., loop.Task]
workflow_handlers = {}
workflow_handlers = {} # type: Dict[int, Tuple[Handler, Iterable]]
def add(mtype, pkgname, modname, namespace=None):
def add(mtype: int, pkgname: str, modname: str, namespace: List = None) -> None:
"""Shortcut for registering a dynamically-imported Protobuf workflow."""
if namespace is not None:
register(
@ -27,7 +44,7 @@ def add(mtype, pkgname, modname, namespace=None):
register(mtype, protobuf_workflow, import_workflow, pkgname, modname)
def register(mtype, handler, *args):
def register(mtype: int, handler: Handler, *args: Any) -> None:
"""Register `handler` to get scheduled after `mtype` message is received."""
if isinstance(mtype, type) and issubclass(mtype, protobuf.MessageType):
mtype = mtype.MESSAGE_WIRE_TYPE
@ -36,54 +53,75 @@ def register(mtype, handler, *args):
workflow_handlers[mtype] = (handler, args)
def setup(iface):
def setup(iface: WireInterface) -> None:
"""Initialize the wire stack on passed USB interface."""
loop.schedule(session_handler(iface, codec_v1.SESSION_ID))
class Context:
def __init__(self, iface, sid):
def __init__(self, iface: WireInterface, sid: int) -> None:
self.iface = iface
self.sid = sid
async def call(self, msg, *types):
"""
Reply with `msg` and wait for one of `types`. See `self.write()` and
`self.read()`.
"""
async def call(
self, msg: MessageType, exptype: Type[LoadedMessageType]
) -> LoadedMessageType:
await self.write(msg)
del msg
return await self.read(types)
return await self.read(exptype)
async def read(self, types):
"""
Wait for incoming message on this wire context and return it. Raises
`UnexpectedMessageError` if the message type does not match one of
`types`; and caller should always make sure to re-raise it.
"""
reader = self.getreader()
async def call_any(self, msg: MessageType, *allowed_types: int) -> MessageType:
await self.write(msg)
del msg
return await self.read_any(allowed_types)
async def read(
self, exptype: Optional[Type[LoadedMessageType]]
) -> LoadedMessageType:
reader = self.make_reader()
if __debug__:
log.debug(
__name__, "%s:%x read: %s", self.iface.iface_num(), self.sid, types
__name__, "%s:%x read: %s", self.iface.iface_num(), self.sid, exptype
)
await reader.aopen() # wait for the message header
# if we got a message with unexpected type, raise the reader via
# `UnexpectedMessageError` and let the session handler deal with it
if reader.type not in types:
if exptype is None or reader.type != exptype.MESSAGE_WIRE_TYPE:
raise UnexpectedMessageError(reader)
# look up the protobuf class and parse the message
pbtype = messages.get_type(reader.type)
return await protobuf.load_message(reader, pbtype)
# parse the message and return it
return await protobuf.load_message(reader, exptype)
async def write(self, msg):
"""
Write a protobuf message to this wire context.
"""
writer = self.getwriter()
async def read_any(self, allowed_types: Iterable[int]) -> MessageType:
reader = self.make_reader()
if __debug__:
log.debug(
__name__,
"%s:%x read: %s",
self.iface.iface_num(),
self.sid,
allowed_types,
)
await reader.aopen() # wait for the message header
# if we got a message with unexpected type, raise the reader via
# `UnexpectedMessageError` and let the session handler deal with it
if reader.type not in allowed_types:
raise UnexpectedMessageError(reader)
# find the protobuf type
exptype = messages.get_type(reader.type)
# parse the message and return it
return await protobuf.load_message(reader, exptype)
async def write(self, msg: protobuf.MessageType) -> None:
writer = self.make_writer()
if __debug__:
log.debug(
@ -99,35 +137,35 @@ class Context:
await protobuf.dump_message(writer, msg, fields)
await writer.aclose()
def wait(self, *tasks):
def wait(self, *tasks: Awaitable) -> Any:
"""
Wait until one of the passed tasks finishes, and return the result,
while servicing the wire context. If a message comes until one of the
tasks ends, `UnexpectedMessageError` is raised.
"""
return loop.spawn(self.read(()), *tasks)
return loop.spawn(self.read(None), *tasks)
def getreader(self):
def make_reader(self) -> codec_v1.Reader:
return codec_v1.Reader(self.iface)
def getwriter(self):
def make_writer(self) -> codec_v1.Writer:
return codec_v1.Writer(self.iface)
class UnexpectedMessageError(Exception):
def __init__(self, reader):
def __init__(self, reader: codec_v1.Reader) -> None:
super().__init__()
self.reader = reader
async def session_handler(iface, sid):
async def session_handler(iface: WireInterface, sid: int) -> None:
reader = None
ctx = Context(iface, sid)
while True:
try:
# wait for new message, if needed, and find handler
if not reader:
reader = ctx.getreader()
reader = ctx.make_reader()
await reader.aopen()
try:
handler, args = workflow_handlers[reader.type]
@ -160,7 +198,9 @@ async def session_handler(iface, sid):
reader = None
async def protobuf_workflow(ctx, reader, handler, *args):
async def protobuf_workflow(
ctx: Context, reader: codec_v1.Reader, handler: Handler, *args: Any
) -> None:
from trezor.messages.Failure import Failure
req = await protobuf.load_message(reader, messages.get_type(reader.type))
@ -185,7 +225,13 @@ async def protobuf_workflow(ctx, reader, handler, *args):
await ctx.write(res)
async def keychain_workflow(ctx, req, namespace, handler, *args):
async def keychain_workflow(
ctx: Context,
req: protobuf.MessageType,
namespace: List,
handler: Handler,
*args: Any
) -> Any:
from apps.common import seed
keychain = await seed.get_keychain(ctx, namespace)
@ -196,22 +242,28 @@ async def keychain_workflow(ctx, req, namespace, handler, *args):
keychain.__del__()
def import_workflow(ctx, req, pkgname, modname, *args):
def import_workflow(
ctx: Context, req: protobuf.MessageType, pkgname: str, modname: str, *args: Any
) -> Any:
modpath = "%s.%s" % (pkgname, modname)
module = __import__(modpath, None, None, (modname,), 0)
module = __import__(modpath, None, None, (modname,), 0) # type: ignore
handler = getattr(module, modname)
return handler(ctx, req, *args)
async def unexpected_msg(ctx, reader):
async def unexpected_msg(ctx: Context, reader: codec_v1.Reader) -> None:
from trezor.messages.Failure import Failure
# receive the message and throw it away
while reader.size > 0:
buf = bytearray(reader.size)
await reader.areadinto(buf)
await read_full_msg(reader)
# respond with an unknown message error
await ctx.write(
Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
)
async def read_full_msg(reader: codec_v1.Reader) -> None:
while reader.size > 0:
buf = bytearray(reader.size)
await reader.areadinto(buf)

@ -3,6 +3,9 @@ from micropython import const
from trezor import io, loop, utils
if False:
from trezorio import WireInterface
_REP_LEN = const(64)
_REP_MARKER = const(63) # ord('?')
@ -12,6 +15,7 @@ _REP_INIT_DATA = const(9) # offset of data in the initial report
_REP_CONT_DATA = const(1) # offset of data in the continuation report
SESSION_ID = const(0)
INVALID_TYPE = const(-1)
class Reader:
@ -20,17 +24,14 @@ class Reader:
async-file-like interface.
"""
def __init__(self, iface):
def __init__(self, iface: WireInterface) -> None:
self.iface = iface
self.type = None
self.size = None
self.data = None
self.type = INVALID_TYPE
self.size = 0
self.ofs = 0
self.data = bytes()
def __repr__(self):
return "<ReaderV1: type=%d size=%dB>" % (self.type, self.size)
async def aopen(self):
async def aopen(self) -> None:
"""
Begin the message transmission by waiting for initial V2 message report
on this session. `self.type` and `self.size` are initialized and
@ -53,7 +54,7 @@ class Reader:
self.data = report[_REP_INIT_DATA : _REP_INIT_DATA + msize]
self.ofs = 0
async def areadinto(self, buf):
async def areadinto(self, buf: bytearray) -> int:
"""
Read exactly `len(buf)` bytes into `buf`, waiting for additional
reports, if needed. Raises `EOFError` if end-of-message is encountered
@ -91,17 +92,14 @@ class Writer:
async-file-like interface.
"""
def __init__(self, iface):
def __init__(self, iface: WireInterface):
self.iface = iface
self.type = None
self.size = None
self.data = bytearray(_REP_LEN)
self.type = INVALID_TYPE
self.size = 0
self.ofs = 0
self.data = bytearray(_REP_LEN)
def __repr__(self):
return "<WriterV1: type=%d size=%dB>" % (self.type, self.size)
def setheader(self, mtype, msize):
def setheader(self, mtype: int, msize: int) -> None:
"""
Reset the writer state and load the message header with passed type and
total message size.
@ -113,7 +111,7 @@ class Writer:
)
self.ofs = _REP_INIT_DATA
async def awrite(self, buf):
async def awrite(self, buf: bytes) -> int:
"""
Encode and write every byte from `buf`. Does not need to be called in
case message has zero length. Raises `EOFError` if the length of `buf`
@ -142,7 +140,7 @@ class Writer:
return nwritten
async def aclose(self):
async def aclose(self) -> None:
"""Flush and close the message transmission."""
if self.ofs != _REP_CONT_DATA:
# we didn't write anything or last write() wasn't report-aligned,

@ -2,72 +2,72 @@ from trezor.messages import FailureType
class Error(Exception):
def __init__(self, code, message):
def __init__(self, code: int, message: str) -> None:
super().__init__()
self.code = code
self.message = message
class UnexpectedMessage(Error):
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(FailureType.UnexpectedMessage, message)
class ButtonExpected(Error):
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(FailureType.ButtonExpected, message)
class DataError(Error):
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(FailureType.DataError, message)
class ActionCancelled(Error):
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(FailureType.ActionCancelled, message)
class PinExpected(Error):
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(FailureType.PinExpected, message)
class PinCancelled(Error):
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(FailureType.PinCancelled, message)
class PinInvalid(Error):
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(FailureType.PinInvalid, message)
class InvalidSignature(Error):
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(FailureType.InvalidSignature, message)
class ProcessError(Error):
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(FailureType.ProcessError, message)
class NotEnoughFunds(Error):
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(FailureType.NotEnoughFunds, message)
class NotInitialized(Error):
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(FailureType.NotInitialized, message)
class PinMismatch(Error):
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(FailureType.PinMismatch, message)
class FirmwareError(Error):
def __init__(self, message):
def __init__(self, message: str) -> None:
super().__init__(FailureType.FirmwareError, message)

@ -1,17 +1,21 @@
from trezor import loop
workflows = []
layouts = []
if False:
from trezor import ui
from typing import List, Callable, Optional
workflows = [] # type: List[loop.Task]
layouts = [] # type: List[ui.Layout]
layout_signal = loop.signal()
default = None
default_layout = None
default = None # type: Optional[loop.Task]
default_layout = None # type: Optional[Callable[[], loop.Task]]
def onstart(w):
def onstart(w: loop.Task) -> None:
workflows.append(w)
def onclose(w):
def onclose(w: loop.Task) -> None:
workflows.remove(w)
if not layouts and default_layout:
startdefault(default_layout)
@ -24,7 +28,7 @@ def onclose(w):
micropython.mem_info()
def closedefault():
def closedefault() -> None:
global default
if default:
@ -32,7 +36,7 @@ def closedefault():
default = None
def startdefault(layout):
def startdefault(layout: Callable[[], loop.Task]) -> None:
global default
global default_layout
@ -42,18 +46,19 @@ def startdefault(layout):
loop.schedule(default)
def restartdefault():
def restartdefault() -> None:
global default_layout
d = default_layout
closedefault()
startdefault(d)
if default_layout:
startdefault(default_layout)
def onlayoutstart(l):
def onlayoutstart(l: ui.Layout) -> None:
closedefault()
layouts.append(l)
def onlayoutclose(l):
def onlayoutclose(l: ui.Layout) -> None:
if l in layouts:
layouts.remove(l)

@ -33,15 +33,16 @@ addopts = --strict
xfail_strict = true
[mypy]
mypy_path = mocks,mocks/generated
warn_unused_configs = True
mypy_path = src,mocks,mocks/generated
check_untyped_defs = True
disallow_subclassing_any = True
disallow_untyped_calls = True
disallow_untyped_decorators = True
disallow_untyped_defs = True
disallow_incomplete_defs = True
check_untyped_defs = True
namespace_packages = True
# no_implicit_optional = True
warn_redundant_casts = True
warn_return_any = True
warn_unused_configs = True
warn_unused_ignores = True
disallow_untyped_decorators = True

Loading…
Cancel
Save