From 758a1a252831af0748f5514fbfd9604829c968c3 Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Wed, 3 Jul 2019 15:07:04 +0200 Subject: [PATCH] core/typing: add annotations --- core/Makefile | 2 +- core/src/apps/__init__.py | 0 core/src/apps/cardano/__init__.py | 2 +- core/src/apps/cardano/sign_tx.py | 2 +- core/src/apps/common/address_type.py | 15 ++- core/src/apps/common/cache.py | 25 ++-- core/src/apps/common/cbor.py | 41 +++--- core/src/apps/common/confirm.py | 49 ++++--- core/src/apps/common/layout.py | 20 ++- core/src/apps/common/mnemonic/__init__.py | 18 ++- core/src/apps/common/mnemonic/bip39.py | 9 +- core/src/apps/common/mnemonic/slip39.py | 13 +- core/src/apps/common/paths.py | 20 ++- core/src/apps/common/request_passphrase.py | 19 +-- core/src/apps/common/seed.py | 9 +- core/src/apps/common/signverify.py | 8 +- core/src/apps/common/writers.py | 21 +-- core/src/apps/debug/__init__.py | 23 +++- core/src/apps/eos/__init__.py | 2 +- core/src/apps/eos/actions/__init__.py | 16 ++- core/src/apps/eos/actions/layout.py | 100 ++++++++------ core/src/apps/eos/get_public_key.py | 11 +- core/src/apps/eos/helpers.py | 4 +- core/src/apps/eos/layout.py | 10 +- core/src/apps/eos/sign_tx.py | 13 +- core/src/apps/eos/writers.py | 89 ++++++------- core/src/apps/ethereum/__init__.py | 2 +- core/src/apps/ethereum/sign_tx.py | 2 +- core/src/apps/homescreen/__init__.py | 22 +++- core/src/apps/homescreen/homescreen.py | 8 +- core/src/apps/lisk/__init__.py | 2 +- core/src/apps/management/__init__.py | 2 +- core/src/apps/management/change_pin.py | 7 +- core/src/apps/management/recovery_device.py | 17 ++- core/src/apps/management/reset_device.py | 15 ++- core/src/apps/monero/__init__.py | 2 +- core/src/apps/monero/key_image_sync.py | 22 ++-- core/src/apps/monero/layout/common.py | 7 +- core/src/apps/monero/live_refresh.py | 2 +- core/src/apps/monero/sign_tx.py | 2 +- core/src/apps/nem/__init__.py | 2 +- core/src/apps/ripple/__init__.py | 2 +- core/src/apps/stellar/__init__.py | 2 +- core/src/apps/stellar/sign_tx.py | 2 +- core/src/apps/tezos/__init__.py | 2 +- core/src/apps/wallet/__init__.py | 2 +- core/src/apps/wallet/sign_tx/__init__.py | 2 +- core/src/boot.py | 4 +- core/src/protobuf.py | 86 ++++++------ core/src/trezor/crypto/__init__.py | 17 --- core/src/trezor/crypto/hmac.py | 22 +++- core/src/trezor/crypto/slip39.py | 126 +++++++++++------- core/src/trezor/log.py | 23 ++-- core/src/trezor/loop.py | 109 ++++++++++------ core/src/trezor/res/__init__.py | 6 +- core/src/trezor/ui/__init__.py | 53 +++++--- core/src/trezor/ui/button.py | 78 +++++++---- core/src/trezor/ui/checklist.py | 19 ++- core/src/trezor/ui/confirm.py | 56 ++++---- core/src/trezor/ui/container.py | 4 +- core/src/trezor/ui/info.py | 25 +++- core/src/trezor/ui/loader.py | 33 +++-- core/src/trezor/ui/mnemonic_bip39.py | 101 +++++++------- core/src/trezor/ui/mnemonic_slip39.py | 91 +++++++------ core/src/trezor/ui/passphrase.py | 108 ++++++++------- core/src/trezor/ui/pin.py | 43 +++--- core/src/trezor/ui/popup.py | 11 +- core/src/trezor/ui/qr.py | 4 +- core/src/trezor/ui/scroll.py | 67 ++++++---- core/src/trezor/ui/shamir.py | 16 +-- core/src/trezor/ui/swipe.py | 33 +++-- core/src/trezor/ui/text.py | 36 +++-- core/src/trezor/ui/word_select.py | 24 ++-- core/src/trezor/utils.py | 68 +++++++--- core/src/trezor/wire/__init__.py | 138 ++++++++++++++------ core/src/trezor/wire/codec_v1.py | 36 +++-- core/src/trezor/wire/errors.py | 28 ++-- core/src/trezor/workflow.py | 31 +++-- setup.cfg | 9 +- 79 files changed, 1292 insertions(+), 880 deletions(-) create mode 100644 core/src/apps/__init__.py diff --git a/core/Makefile b/core/Makefile index 913d46086..6e3cfbfd2 100644 --- a/core/Makefile +++ b/core/Makefile @@ -72,7 +72,7 @@ pylint: ## run pylint on application sources and tests pylint -E $(shell find src tests -name *.py) mypy: - mypy \ + mypy --config-file ../setup.cfg \ src/main.py ## code generation: diff --git a/core/src/apps/__init__.py b/core/src/apps/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/core/src/apps/cardano/__init__.py b/core/src/apps/cardano/__init__.py index 648340077..bb7d3e8b0 100644 --- a/core/src/apps/cardano/__init__.py +++ b/core/src/apps/cardano/__init__.py @@ -7,7 +7,7 @@ CURVE = "ed25519" SEED_NAMESPACE = [HARDENED | 44, HARDENED | 1815] -def boot(): +def boot() -> None: wire.add(MessageType.CardanoGetAddress, __name__, "get_address") wire.add(MessageType.CardanoGetPublicKey, __name__, "get_public_key") wire.add(MessageType.CardanoSignTx, __name__, "sign_tx") diff --git a/core/src/apps/cardano/sign_tx.py b/core/src/apps/cardano/sign_tx.py index 9e204bd18..1978e9680 100644 --- a/core/src/apps/cardano/sign_tx.py +++ b/core/src/apps/cardano/sign_tx.py @@ -4,8 +4,8 @@ from trezor import log, wire from trezor.crypto import base58, hashlib from trezor.crypto.curve import ed25519 from trezor.messages.CardanoSignedTx import CardanoSignedTx +from trezor.messages.CardanoTxAck import CardanoTxAck from trezor.messages.CardanoTxRequest import CardanoTxRequest -from trezor.messages.MessageType import CardanoTxAck from apps.cardano import CURVE, seed from apps.cardano.address import ( diff --git a/core/src/apps/common/address_type.py b/core/src/apps/common/address_type.py index 41ba68b6d..16a43efd1 100644 --- a/core/src/apps/common/address_type.py +++ b/core/src/apps/common/address_type.py @@ -1,4 +1,9 @@ -def length(address_type): +if False: + from typing import Tuple + from apps.common.coininfo import CoinType + + +def length(address_type: int) -> int: if address_type <= 0xFF: return 1 if address_type <= 0xFFFF: @@ -9,21 +14,21 @@ def length(address_type): return 4 -def tobytes(address_type: int): +def tobytes(address_type: int) -> bytes: return address_type.to_bytes(length(address_type), "big") -def check(address_type, raw_address): +def check(address_type: int, raw_address: bytes) -> bool: return raw_address.startswith(tobytes(address_type)) -def strip(address_type, raw_address): +def strip(address_type: int, raw_address: bytes) -> bytes: if not check(address_type, raw_address): raise ValueError("Invalid address") return raw_address[length(address_type) :] -def split(coin, raw_address): +def split(coin: CoinType, raw_address: bytes) -> Tuple[bytes, bytes]: for f in ( "address_type", "address_type_p2sh", diff --git a/core/src/apps/common/cache.py b/core/src/apps/common/cache.py index f16a603db..435cc4771 100644 --- a/core/src/apps/common/cache.py +++ b/core/src/apps/common/cache.py @@ -2,12 +2,15 @@ from trezor.crypto import hashlib, hmac, random from apps.common import storage -_cached_seed = None -_cached_passphrase = None -_cached_passphrase_fprint = b"\x00\x00\x00\x00" +if False: + from typing import Optional +_cached_seed = None # type: Optional[bytes] +_cached_passphrase = None # type: Optional[str] +_cached_passphrase_fprint = b"\x00\x00\x00\x00" # type: bytes -def get_state(prev_state: bytes = None, passphrase: str = None) -> bytes: + +def get_state(prev_state: bytes = None, passphrase: str = None) -> Optional[bytes]: if prev_state is None: salt = random.bytes(32) # generate a random salt if no state provided else: @@ -29,34 +32,34 @@ def _compute_state(salt: bytes, passphrase: str) -> bytes: return salt + state -def get_seed(): +def get_seed() -> Optional[bytes]: return _cached_seed -def get_passphrase(): +def get_passphrase() -> Optional[str]: return _cached_passphrase -def get_passphrase_fprint(): +def get_passphrase_fprint() -> bytes: return _cached_passphrase_fprint -def has_passphrase(): +def has_passphrase() -> bool: return _cached_passphrase is not None -def set_seed(seed): +def set_seed(seed: Optional[bytes]) -> None: global _cached_seed _cached_seed = seed -def set_passphrase(passphrase): +def set_passphrase(passphrase: Optional[str]) -> None: global _cached_passphrase, _cached_passphrase_fprint _cached_passphrase = passphrase _cached_passphrase_fprint = _compute_state(b"FPRINT", passphrase or "")[:4] -def clear(keep_passphrase: bool = False): +def clear(keep_passphrase: bool = False) -> None: set_seed(None) if not keep_passphrase: set_passphrase(None) diff --git a/core/src/apps/common/cbor.py b/core/src/apps/common/cbor.py index 982ef13df..95789091d 100644 --- a/core/src/apps/common/cbor.py +++ b/core/src/apps/common/cbor.py @@ -6,7 +6,11 @@ import ustruct as struct from micropython import const from trezor import log -from trezor.utils import ensure + +if False: + from typing import Any, Dict, Iterable, List, Tuple + + Value = Any _CBOR_TYPE_MASK = const(0xE0) _CBOR_INFO_BITS = const(0x1F) @@ -32,7 +36,7 @@ _CBOR_BREAK = const(0x1F) _CBOR_RAW_TAG = const(0x18) -def _header(typ, l: int): +def _header(typ: int, l: int) -> bytes: if l < 24: return struct.pack(">B", typ + l) elif l < 2 ** 8: @@ -47,7 +51,7 @@ def _header(typ, l: int): raise NotImplementedError("Length %d not suppported" % l) -def _cbor_encode(value): +def _cbor_encode(value: Value) -> Iterable[bytes]: if isinstance(value, int): if value >= 0: yield _header(_CBOR_UNSIGNED_INT, value) @@ -95,7 +99,7 @@ def _cbor_encode(value): raise NotImplementedError -def _read_length(cbor, aux): +def _read_length(cbor: bytes, aux: int) -> Tuple[int, bytes]: if aux < _CBOR_UINT8_FOLLOWS: return (aux, cbor) elif aux == _CBOR_UINT8_FOLLOWS: @@ -124,7 +128,7 @@ def _read_length(cbor, aux): raise NotImplementedError("Length %d not suppported" % aux) -def _cbor_decode(cbor): +def _cbor_decode(cbor: bytes) -> Tuple[Value, bytes]: fb = cbor[0] data = b"" fb_type = fb & _CBOR_TYPE_MASK @@ -158,7 +162,7 @@ def _cbor_decode(cbor): res.append(item) return (res, data) elif fb_type == _CBOR_MAP: - res = {} + res = {} # type: Dict[Value, Value] if fb_aux == _CBOR_VAR_FOLLOWS: data = cbor[1:] while True: @@ -201,36 +205,41 @@ def _cbor_decode(cbor): class Tagged: - def __init__(self, tag, value): + def __init__(self, tag: int, value: Value) -> None: self.tag = tag self.value = value - def __eq__(self, other): - return self.tag == other.tag and self.value == other.value + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, Tagged) + and self.tag == other.tag + and self.value == other.value + ) class Raw: - def __init__(self, value): + def __init__(self, value: Value): self.value = value class IndefiniteLengthArray: - def __init__(self, array): - ensure(isinstance(array, list)) + def __init__(self, array: List[Value]) -> None: self.array = array - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, IndefiniteLengthArray): return self.array == other.array - else: + elif isinstance(other, list): return self.array == other + else: + return False -def encode(value): +def encode(value: Value) -> bytes: return b"".join(_cbor_encode(value)) -def decode(cbor: bytes): +def decode(cbor: bytes) -> Value: res, check = _cbor_decode(cbor) if not (check == b""): raise ValueError diff --git a/core/src/apps/common/confirm.py b/core/src/apps/common/confirm.py index 280f3a0e1..f64cbcf2c 100644 --- a/core/src/apps/common/confirm.py +++ b/core/src/apps/common/confirm.py @@ -1,23 +1,30 @@ from trezor import wire -from trezor.messages import ButtonRequestType, MessageType +from trezor.messages import ButtonRequestType +from trezor.messages.ButtonAck import ButtonAck from trezor.messages.ButtonRequest import ButtonRequest from trezor.ui.confirm import CONFIRMED, Confirm, HoldToConfirm if __debug__: from apps.debug import confirm_signal +if False: + from typing import Any + from trezor import ui + from trezor.ui.confirm import ButtonContent, ButtonStyleType + from trezor.ui.loader import LoaderStyleType + async def confirm( - ctx, - content, - code=ButtonRequestType.Other, - confirm=Confirm.DEFAULT_CONFIRM, - confirm_style=Confirm.DEFAULT_CONFIRM_STYLE, - cancel=Confirm.DEFAULT_CANCEL, - cancel_style=Confirm.DEFAULT_CANCEL_STYLE, - major_confirm=None, -): - await ctx.call(ButtonRequest(code=code), MessageType.ButtonAck) + ctx: wire.Context, + content: ui.Control, + code: int = ButtonRequestType.Other, + confirm: ButtonContent = Confirm.DEFAULT_CONFIRM, + confirm_style: ButtonStyleType = Confirm.DEFAULT_CONFIRM_STYLE, + cancel: ButtonContent = Confirm.DEFAULT_CANCEL, + cancel_style: ButtonStyleType = Confirm.DEFAULT_CANCEL_STYLE, + major_confirm: bool = False, +) -> bool: + await ctx.call(ButtonRequest(code=code), ButtonAck) if content.__class__.__name__ == "Paginated": content.pages[-1] = Confirm( @@ -41,14 +48,14 @@ async def confirm( async def hold_to_confirm( - ctx, - content, - code=ButtonRequestType.Other, - confirm=HoldToConfirm.DEFAULT_CONFIRM, - confirm_style=HoldToConfirm.DEFAULT_CONFIRM_STYLE, - loader_style=HoldToConfirm.DEFAULT_LOADER_STYLE, -): - await ctx.call(ButtonRequest(code=code), MessageType.ButtonAck) + ctx: wire.Context, + content: ui.Control, + code: int = ButtonRequestType.Other, + confirm: ButtonContent = Confirm.DEFAULT_CONFIRM, + confirm_style: ButtonStyleType = Confirm.DEFAULT_CONFIRM_STYLE, + loader_style: LoaderStyleType = HoldToConfirm.DEFAULT_LOADER_STYLE, +) -> bool: + await ctx.call(ButtonRequest(code=code), ButtonAck) if content.__class__.__name__ == "Paginated": content.pages[-1] = HoldToConfirm( @@ -64,13 +71,13 @@ async def hold_to_confirm( return await ctx.wait(dialog) is CONFIRMED -async def require_confirm(*args, **kwargs): +async def require_confirm(*args: Any, **kwargs: Any) -> None: confirmed = await confirm(*args, **kwargs) if not confirmed: raise wire.ActionCancelled("Cancelled") -async def require_hold_to_confirm(*args, **kwargs): +async def require_hold_to_confirm(*args: Any, **kwargs: Any) -> None: confirmed = await hold_to_confirm(*args, **kwargs) if not confirmed: raise wire.ActionCancelled("Cancelled") diff --git a/core/src/apps/common/layout.py b/core/src/apps/common/layout.py index 6b4571492..061e3c3e2 100644 --- a/core/src/apps/common/layout.py +++ b/core/src/apps/common/layout.py @@ -12,10 +12,14 @@ from trezor.utils import chunks from apps.common import HARDENED from apps.common.confirm import confirm, require_confirm +if False: + from typing import Iterable + from trezor import wire + async def show_address( - ctx, address: str, desc: str = "Confirm address", network: str = None -): + ctx: wire.Context, address: str, desc: str = "Confirm address", network: str = None +) -> bool: text = Text(desc, ui.ICON_RECEIVE, ui.GREEN) if network is not None: text.normal("%s network" % network) @@ -30,7 +34,9 @@ async def show_address( ) -async def show_qr(ctx, address: str, desc: str = "Confirm address"): +async def show_qr( + ctx: wire.Context, address: str, desc: str = "Confirm address" +) -> bool: QR_X = const(120) QR_Y = const(115) QR_COEF = const(4) @@ -47,19 +53,19 @@ async def show_qr(ctx, address: str, desc: str = "Confirm address"): ) -async def show_pubkey(ctx, pubkey: bytes): +async def show_pubkey(ctx: wire.Context, pubkey: bytes) -> None: lines = chunks(hexlify(pubkey).decode(), 18) text = Text("Confirm public key", ui.ICON_RECEIVE, ui.GREEN) text.mono(*lines) - return await require_confirm(ctx, text, ButtonRequestType.PublicKey) + await require_confirm(ctx, text, ButtonRequestType.PublicKey) -def split_address(address: str): +def split_address(address: str) -> Iterable[str]: return chunks(address, 17) def address_n_to_str(address_n: list) -> str: - def path_item(i: int): + def path_item(i: int) -> str: if i & HARDENED: return str(i ^ HARDENED) + "'" else: diff --git a/core/src/apps/common/mnemonic/__init__.py b/core/src/apps/common/mnemonic/__init__.py index 62a1cb734..3730b0dfa 100644 --- a/core/src/apps/common/mnemonic/__init__.py +++ b/core/src/apps/common/mnemonic/__init__.py @@ -8,13 +8,16 @@ from trezor.utils import consteq from apps.common import storage from apps.common.mnemonic import bip39, slip39 +if False: + from typing import Any, Tuple + TYPE_BIP39 = const(0) TYPE_SLIP39 = const(1) TYPES_WORD_COUNT = {12: bip39, 18: bip39, 24: bip39, 20: slip39, 33: slip39} -def get() -> (bytes, int): +def get() -> Tuple[bytes, int]: mnemonic_secret = storage.device.get_mnemonic_secret() mnemonic_type = storage.device.get_mnemonic_type() or TYPE_BIP39 if mnemonic_type not in (TYPE_BIP39, TYPE_SLIP39): @@ -22,15 +25,16 @@ def get() -> (bytes, int): return mnemonic_secret, mnemonic_type -def get_seed(passphrase: str = "", progress_bar=True): +def get_seed(passphrase: str = "", progress_bar: bool = True) -> bytes: mnemonic_secret, mnemonic_type = get() if mnemonic_type == TYPE_BIP39: return bip39.get_seed(mnemonic_secret, passphrase, progress_bar) elif mnemonic_type == TYPE_SLIP39: return slip39.get_seed(mnemonic_secret, passphrase, progress_bar) + raise ValueError("Unknown mnemonic type") -def dry_run(secret: bytes): +def dry_run(secret: bytes) -> None: digest_input = sha256(secret).digest() stored, _ = get() digest_stored = sha256(stored).digest() @@ -42,11 +46,11 @@ def dry_run(secret: bytes): ) -def module_from_words_count(count: int): +def module_from_words_count(count: int) -> Any: return TYPES_WORD_COUNT[count] -def _start_progress(): +def _start_progress() -> None: workflow.closedefault() ui.backlight_fade(ui.BACKLIGHT_DIM) ui.display.clear() @@ -55,11 +59,11 @@ def _start_progress(): ui.backlight_fade(ui.BACKLIGHT_NORMAL) -def _render_progress(progress: int, total: int): +def _render_progress(progress: int, total: int) -> None: p = 1000 * progress // total ui.display.loader(p, False, 18, ui.WHITE, ui.BG) ui.display.refresh() -def _stop_progress(): +def _stop_progress() -> None: pass diff --git a/core/src/apps/common/mnemonic/bip39.py b/core/src/apps/common/mnemonic/bip39.py index 74e4ebe81..9b757e256 100644 --- a/core/src/apps/common/mnemonic/bip39.py +++ b/core/src/apps/common/mnemonic/bip39.py @@ -3,7 +3,7 @@ from trezor.crypto import bip39 from apps.common import mnemonic, storage -def get_type(): +def get_type() -> int: return mnemonic.TYPE_BIP39 @@ -23,20 +23,19 @@ def process_all(mnemonics: list) -> bytes: return mnemonics[0].encode() -def store(secret: bytes, needs_backup: bool, no_backup: bool): +def store(secret: bytes, needs_backup: bool, no_backup: bool) -> None: storage.device.store_mnemonic_secret( secret, mnemonic.TYPE_BIP39, needs_backup, no_backup ) -def get_seed(secret: bytes, passphrase: str, progress_bar=True): +def get_seed(secret: bytes, passphrase: str, progress_bar: bool = True) -> bytes: if progress_bar: mnemonic._start_progress() seed = bip39.seed(secret.decode(), passphrase, mnemonic._render_progress) mnemonic._stop_progress() else: seed = bip39.seed(secret.decode(), passphrase) - return seed @@ -44,5 +43,5 @@ def get_mnemonic_threshold(mnemonic: str) -> int: return 1 -def check(secret: bytes): +def check(secret: bytes) -> bool: return bip39.check(secret) diff --git a/core/src/apps/common/mnemonic/slip39.py b/core/src/apps/common/mnemonic/slip39.py index 3e54476c5..cde3ebcd3 100644 --- a/core/src/apps/common/mnemonic/slip39.py +++ b/core/src/apps/common/mnemonic/slip39.py @@ -2,6 +2,9 @@ from trezor.crypto import slip39 from apps.common import mnemonic, storage +if False: + from typing import Optional + def generate_from_secret(master_secret: bytes, count: int, threshold: int) -> list: """ @@ -12,11 +15,11 @@ def generate_from_secret(master_secret: bytes, count: int, threshold: int) -> li ) -def get_type(): +def get_type() -> int: return mnemonic.TYPE_SLIP39 -def process_single(mnemonic: str) -> bytes: +def process_single(mnemonic: str) -> Optional[bytes]: """ Receives single mnemonic and processes it. Returns what is then stored in storage or None if more shares are needed. @@ -72,14 +75,16 @@ def process_all(mnemonics: list) -> bytes: return secret -def store(secret: bytes, needs_backup: bool, no_backup: bool): +def store(secret: bytes, needs_backup: bool, no_backup: bool) -> None: storage.device.store_mnemonic_secret( secret, mnemonic.TYPE_SLIP39, needs_backup, no_backup ) storage.slip39.delete_progress() -def get_seed(encrypted_master_secret: bytes, passphrase: str, progress_bar=True): +def get_seed( + encrypted_master_secret: bytes, passphrase: str, progress_bar: bool = True +) -> bytes: if progress_bar: mnemonic._start_progress() identifier = storage.slip39.get_identifier() diff --git a/core/src/apps/common/paths.py b/core/src/apps/common/paths.py index aebfc900e..2cbb8017c 100644 --- a/core/src/apps/common/paths.py +++ b/core/src/apps/common/paths.py @@ -7,20 +7,32 @@ from trezor.ui.text import Text from apps.common import HARDENED from apps.common.confirm import require_confirm +if False: + from typing import Any, Callable, List + from trezor import wire + from apps.common import seed -async def validate_path(ctx, validate_func, keychain, path, curve, **kwargs): + +async def validate_path( + ctx: wire.Context, + validate_func: Callable[..., bool], + keychain: seed.Keychain, + path: List[int], + curve: str, + **kwargs: Any, +) -> None: keychain.validate_path(path, curve) if not validate_func(path, **kwargs): await show_path_warning(ctx, path) -async def show_path_warning(ctx, path: list): +async def show_path_warning(ctx: wire.Context, path: List[int]) -> None: text = Text("Confirm path", ui.ICON_WRONG, ui.RED) text.normal("Path") text.mono(*break_address_n_to_lines(path)) text.normal("is unknown.") text.normal("Are you sure?") - return await require_confirm(ctx, text, ButtonRequestType.UnknownDerivationPath) + await require_confirm(ctx, text, ButtonRequestType.UnknownDerivationPath) def validate_path_for_get_public_key(path: list, slip44_id: int) -> bool: @@ -53,7 +65,7 @@ def is_hardened(i: int) -> bool: def break_address_n_to_lines(address_n: list) -> list: - def path_item(i: int): + def path_item(i: int) -> str: if i & HARDENED: return str(i ^ HARDENED) + "'" else: diff --git a/core/src/apps/common/request_passphrase.py b/core/src/apps/common/request_passphrase.py index e293ff2ff..92306ebf8 100644 --- a/core/src/apps/common/request_passphrase.py +++ b/core/src/apps/common/request_passphrase.py @@ -1,9 +1,12 @@ from micropython import const from trezor import ui, wire -from trezor.messages import ButtonRequestType, MessageType, PassphraseSourceType +from trezor.messages import ButtonRequestType, PassphraseSourceType +from trezor.messages.ButtonAck import ButtonAck from trezor.messages.ButtonRequest import ButtonRequest +from trezor.messages.PassphraseAck import PassphraseAck from trezor.messages.PassphraseRequest import PassphraseRequest +from trezor.messages.PassphraseStateAck import PassphraseStateAck from trezor.messages.PassphraseStateRequest import PassphraseStateRequest from trezor.ui.passphrase import CANCELLED, PassphraseKeyboard, PassphraseSource from trezor.ui.popup import Popup @@ -17,14 +20,14 @@ if __debug__: _MAX_PASSPHRASE_LEN = const(50) -async def protect_by_passphrase(ctx) -> str: +async def protect_by_passphrase(ctx: wire.Context) -> str: if storage.device.has_passphrase(): return await request_passphrase(ctx) else: return "" -async def request_passphrase(ctx) -> str: +async def request_passphrase(ctx: wire.Context) -> str: source = storage.device.get_passphrase_source() if source == PassphraseSourceType.ASK: source = await request_passphrase_source(ctx) @@ -36,9 +39,9 @@ async def request_passphrase(ctx) -> str: return passphrase -async def request_passphrase_source(ctx) -> int: +async def request_passphrase_source(ctx: wire.Context) -> int: req = ButtonRequest(code=ButtonRequestType.PassphraseType) - await ctx.call(req, MessageType.ButtonAck) + await ctx.call(req, ButtonAck) text = Text("Enter passphrase", ui.ICON_CONFIG) text.normal("Where to enter your", "passphrase?") @@ -47,14 +50,14 @@ async def request_passphrase_source(ctx) -> int: return await ctx.wait(source) -async def request_passphrase_ack(ctx, on_device: bool) -> str: +async def request_passphrase_ack(ctx: wire.Context, on_device: bool) -> str: if not on_device: text = Text("Passphrase entry", ui.ICON_CONFIG) text.normal("Please, type passphrase", "on connected host.") await Popup(text) req = PassphraseRequest(on_device=on_device) - ack = await ctx.call(req, MessageType.PassphraseAck) + ack = await ctx.call(req, PassphraseAck) if on_device: if ack.passphrase is not None: @@ -74,6 +77,6 @@ async def request_passphrase_ack(ctx, on_device: bool) -> str: state = cache.get_state(prev_state=ack.state, passphrase=passphrase) req = PassphraseStateRequest(state=state) - ack = await ctx.call(req, MessageType.PassphraseStateAck, MessageType.Cancel) + ack = await ctx.call(req, PassphraseStateAck) return passphrase diff --git a/core/src/apps/common/seed.py b/core/src/apps/common/seed.py index 15ec2c592..6129e765c 100644 --- a/core/src/apps/common/seed.py +++ b/core/src/apps/common/seed.py @@ -4,7 +4,8 @@ from trezor.crypto import bip32 from apps.common import HARDENED, cache, mnemonic, storage from apps.common.request_passphrase import protect_by_passphrase -allow = list +if False: + from typing import List, Optional class Keychain: @@ -16,16 +17,16 @@ class Keychain: def __init__(self, seed: bytes, namespaces: list): self.seed = seed self.namespaces = namespaces - self.roots = [None] * len(namespaces) + self.roots = [None] * len(namespaces) # type: List[Optional[bip32.HDNode]] - def __del__(self): + def __del__(self) -> None: for root in self.roots: if root is not None: root.__del__() del self.roots del self.seed - def validate_path(self, checked_path: list, checked_curve: str): + def validate_path(self, checked_path: list, checked_curve: str) -> None: for curve, *path in self.namespaces: if path == checked_path[: len(path)] and curve == checked_curve: if "ed25519" in curve and not _path_hardened(checked_path): diff --git a/core/src/apps/common/signverify.py b/core/src/apps/common/signverify.py index 4b1373123..0bd94df2f 100644 --- a/core/src/apps/common/signverify.py +++ b/core/src/apps/common/signverify.py @@ -5,8 +5,12 @@ from trezor.utils import HashWriter from apps.wallet.sign_tx.writers import write_varint +if False: + from typing import List + from apps.common.coininfo import CoinType -def message_digest(coin, message): + +def message_digest(coin: CoinType, message: bytes) -> bytes: if coin.decred: h = HashWriter(blake256()) else: @@ -21,7 +25,7 @@ def message_digest(coin, message): return ret -def split_message(message): +def split_message(message: bytes) -> List[str]: try: m = bytes(message).decode() words = m.split(" ") diff --git a/core/src/apps/common/writers.py b/core/src/apps/common/writers.py index feeb32590..ec23c2518 100644 --- a/core/src/apps/common/writers.py +++ b/core/src/apps/common/writers.py @@ -1,5 +1,8 @@ from trezor.utils import ensure +if False: + from trezor.utils import Writer + def empty_bytearray(preallocate: int) -> bytearray: """ @@ -11,27 +14,27 @@ def empty_bytearray(preallocate: int) -> bytearray: return b -def write_uint8(w: bytearray, n: int) -> int: +def write_uint8(w: Writer, n: int) -> int: ensure(0 <= n <= 0xFF) w.append(n) return 1 -def write_uint16_le(w: bytearray, n: int) -> int: +def write_uint16_le(w: Writer, n: int) -> int: ensure(0 <= n <= 0xFFFF) w.append(n & 0xFF) w.append((n >> 8) & 0xFF) return 2 -def write_uint16_be(w: bytearray, n: int): +def write_uint16_be(w: Writer, n: int) -> int: ensure(0 <= n <= 0xFFFF) w.append((n >> 8) & 0xFF) w.append(n & 0xFF) return 2 -def write_uint32_le(w: bytearray, n: int) -> int: +def write_uint32_le(w: Writer, n: int) -> int: ensure(0 <= n <= 0xFFFFFFFF) w.append(n & 0xFF) w.append((n >> 8) & 0xFF) @@ -40,7 +43,7 @@ def write_uint32_le(w: bytearray, n: int) -> int: return 4 -def write_uint32_be(w: bytearray, n: int) -> int: +def write_uint32_be(w: Writer, n: int) -> int: ensure(0 <= n <= 0xFFFFFFFF) w.append((n >> 24) & 0xFF) w.append((n >> 16) & 0xFF) @@ -49,7 +52,7 @@ def write_uint32_be(w: bytearray, n: int) -> int: return 4 -def write_uint64_le(w: bytearray, n: int) -> int: +def write_uint64_le(w: Writer, n: int) -> int: ensure(0 <= n <= 0xFFFFFFFFFFFFFFFF) w.append(n & 0xFF) w.append((n >> 8) & 0xFF) @@ -62,7 +65,7 @@ def write_uint64_le(w: bytearray, n: int) -> int: return 8 -def write_uint64_be(w: bytearray, n: int) -> int: +def write_uint64_be(w: Writer, n: int) -> int: ensure(0 <= n <= 0xFFFFFFFFFFFFFFFF) w.append((n >> 56) & 0xFF) w.append((n >> 48) & 0xFF) @@ -75,11 +78,11 @@ def write_uint64_be(w: bytearray, n: int) -> int: return 8 -def write_bytes(w: bytearray, b: bytes) -> int: +def write_bytes(w: Writer, b: bytes) -> int: w.extend(b) return len(b) -def write_bytes_reversed(w: bytearray, b: bytes) -> int: +def write_bytes_reversed(w: Writer, b: bytes) -> int: w.extend(bytes(reversed(b))) return len(b) diff --git a/core/src/apps/debug/__init__.py b/core/src/apps/debug/__init__.py index c6d7edabd..a945f7990 100644 --- a/core/src/apps/debug/__init__.py +++ b/core/src/apps/debug/__init__.py @@ -8,15 +8,24 @@ if __debug__: from trezor.messages import MessageType from trezor.wire import register, protobuf_workflow - reset_internal_entropy = None - reset_current_words = None - reset_word_index = None + if False: + from typing import List, Optional + from trezor import wire + from trezor.messages.DebugLinkDecision import DebugLinkDecision + from trezor.messages.DebugLinkGetState import DebugLinkGetState + from trezor.messages.DebugLinkState import DebugLinkState + + reset_internal_entropy = None # type: Optional[bytes] + reset_current_words = None # type: Optional[List[str]] + reset_word_index = None # type: Optional[int] confirm_signal = loop.signal() swipe_signal = loop.signal() input_signal = loop.signal() - async def dispatch_DebugLinkDecision(ctx, msg): + async def dispatch_DebugLinkDecision( + ctx: wire.Context, msg: DebugLinkDecision + ) -> None: from trezor.ui import confirm, swipe if msg.yes_no is not None: @@ -26,7 +35,9 @@ if __debug__: if msg.input is not None: input_signal.send(msg.input) - async def dispatch_DebugLinkGetState(ctx, msg): + async def dispatch_DebugLinkGetState( + ctx: wire.Context, msg: DebugLinkGetState + ) -> DebugLinkState: from trezor.messages.DebugLinkState import DebugLinkState from apps.common import storage, mnemonic @@ -39,7 +50,7 @@ if __debug__: m.reset_word = " ".join(reset_current_words) return m - def boot(): + def boot() -> None: # wipe storage when debug build is used on real hardware if not utils.EMULATOR: config.wipe() diff --git a/core/src/apps/eos/__init__.py b/core/src/apps/eos/__init__.py index 5ba816604..4c53f4ec5 100755 --- a/core/src/apps/eos/__init__.py +++ b/core/src/apps/eos/__init__.py @@ -6,7 +6,7 @@ from apps.common import HARDENED CURVE = "secp256k1" -def boot(): +def boot() -> None: ns = [[CURVE, HARDENED | 44, HARDENED | 194]] wire.add(MessageType.EosGetPublicKey, __name__, "get_public_key", ns) diff --git a/core/src/apps/eos/actions/__init__.py b/core/src/apps/eos/actions/__init__.py index 4c6456499..8adf5ccfb 100644 --- a/core/src/apps/eos/actions/__init__.py +++ b/core/src/apps/eos/actions/__init__.py @@ -1,13 +1,19 @@ from trezor.crypto.hashlib import sha256 +from trezor.messages.EosTxActionAck import EosTxActionAck from trezor.messages.EosTxActionRequest import EosTxActionRequest -from trezor.messages.MessageType import EosTxActionAck from trezor.utils import HashWriter from apps.eos import helpers, writers from apps.eos.actions import layout +if False: + from trezor import wire + from trezor.utils import Writer -async def process_action(ctx, sha, action): + +async def process_action( + ctx: wire.Context, sha: HashWriter, action: EosTxActionAck +) -> None: name = helpers.eos_name_to_string(action.common.name) account = helpers.eos_name_to_string(action.common.account) @@ -65,7 +71,9 @@ async def process_action(ctx, sha, action): writers.write_bytes(sha, w) -async def process_unknown_action(ctx, w, action): +async def process_unknown_action( + ctx: wire.Context, w: Writer, action: EosTxActionAck +) -> None: checksum = HashWriter(sha256()) writers.write_variant32(checksum, action.unknown.data_size) checksum.extend(action.unknown.data_chunk) @@ -91,7 +99,7 @@ async def process_unknown_action(ctx, w, action): await layout.confirm_action_unknown(ctx, action.common, checksum.get_digest()) -def check_action(action, name, account): +def check_action(action: EosTxActionAck, name: str, account: str) -> bool: if account == "eosio": if ( (name == "buyram" and action.buy_ram is not None) diff --git a/core/src/apps/eos/actions/layout.py b/core/src/apps/eos/actions/layout.py index a888af0b3..e9a4296fc 100644 --- a/core/src/apps/eos/actions/layout.py +++ b/core/src/apps/eos/actions/layout.py @@ -2,22 +2,7 @@ from micropython import const from ubinascii import hexlify from trezor import ui -from trezor.messages import ( - ButtonRequestType, - EosActionBuyRam, - EosActionBuyRamBytes, - EosActionDelegate, - EosActionDeleteAuth, - EosActionLinkAuth, - EosActionNewAccount, - EosActionRefund, - EosActionSellRam, - EosActionTransfer, - EosActionUndelegate, - EosActionUnlinkAuth, - EosActionUpdateAuth, - EosActionVoteProducer, -) +from trezor.messages import ButtonRequestType from trezor.ui.scroll import Paginated from trezor.ui.text import Text from trezor.utils import chunks @@ -26,6 +11,25 @@ from apps.eos import helpers from apps.eos.get_public_key import _public_key_to_wif from apps.eos.layout import require_confirm +if False: + from typing import List + from trezor import wire + from trezor.messages.EosAuthorization import EosAuthorization + from trezor.messages.EosActionBuyRam import EosActionBuyRam + from trezor.messages.EosActionBuyRamBytes import EosActionBuyRamBytes + from trezor.messages.EosActionCommon import EosActionCommon + from trezor.messages.EosActionDelegate import EosActionDelegate + from trezor.messages.EosActionDeleteAuth import EosActionDeleteAuth + from trezor.messages.EosActionLinkAuth import EosActionLinkAuth + from trezor.messages.EosActionNewAccount import EosActionNewAccount + from trezor.messages.EosActionRefund import EosActionRefund + from trezor.messages.EosActionSellRam import EosActionSellRam + from trezor.messages.EosActionTransfer import EosActionTransfer + from trezor.messages.EosActionUndelegate import EosActionUndelegate + from trezor.messages.EosActionUnlinkAuth import EosActionUnlinkAuth + from trezor.messages.EosActionUpdateAuth import EosActionUpdateAuth + from trezor.messages.EosActionVoteProducer import EosActionVoteProducer + _LINE_LENGTH = const(17) _LINE_PLACEHOLDER = "{:<" + str(_LINE_LENGTH) + "}" _FIRST_PAGE = const(0) @@ -35,7 +39,9 @@ _FOUR_FIELDS_PER_PAGE = const(4) _FIVE_FIELDS_PER_PAGE = const(5) -async def _require_confirm_paginated(ctx, header, fields, per_page): +async def _require_confirm_paginated( + ctx: wire.Context, header: str, fields: List[str], per_page: int +) -> None: pages = [] for page in chunks(fields, per_page): if header == "Arbitrary data": @@ -47,7 +53,7 @@ async def _require_confirm_paginated(ctx, header, fields, per_page): await require_confirm(ctx, Paginated(pages), ButtonRequestType.ConfirmOutput) -async def confirm_action_buyram(ctx, msg: EosActionBuyRam): +async def confirm_action_buyram(ctx: wire.Context, msg: EosActionBuyRam) -> None: text = "Buy RAM" fields = [] fields.append("Payer:") @@ -59,7 +65,9 @@ async def confirm_action_buyram(ctx, msg: EosActionBuyRam): await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) -async def confirm_action_buyrambytes(ctx, msg: EosActionBuyRamBytes): +async def confirm_action_buyrambytes( + ctx: wire.Context, msg: EosActionBuyRamBytes +) -> None: text = "Buy RAM" fields = [] fields.append("Payer:") @@ -71,7 +79,7 @@ async def confirm_action_buyrambytes(ctx, msg: EosActionBuyRamBytes): await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) -async def confirm_action_delegate(ctx, msg: EosActionDelegate): +async def confirm_action_delegate(ctx: wire.Context, msg: EosActionDelegate) -> None: text = "Delegate" fields = [] fields.append("Sender:") @@ -93,7 +101,7 @@ async def confirm_action_delegate(ctx, msg: EosActionDelegate): await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) -async def confirm_action_sellram(ctx, msg: EosActionSellRam): +async def confirm_action_sellram(ctx: wire.Context, msg: EosActionSellRam) -> None: text = "Sell RAM" fields = [] fields.append("Receiver:") @@ -103,7 +111,9 @@ async def confirm_action_sellram(ctx, msg: EosActionSellRam): await _require_confirm_paginated(ctx, text, fields, _TWO_FIELDS_PER_PAGE) -async def confirm_action_undelegate(ctx, msg: EosActionUndelegate): +async def confirm_action_undelegate( + ctx: wire.Context, msg: EosActionUndelegate +) -> None: text = "Undelegate" fields = [] fields.append("Sender:") @@ -117,14 +127,16 @@ async def confirm_action_undelegate(ctx, msg: EosActionUndelegate): await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) -async def confirm_action_refund(ctx, msg: EosActionRefund): +async def confirm_action_refund(ctx: wire.Context, msg: EosActionRefund) -> None: text = Text("Refund", ui.ICON_CONFIRM, icon_color=ui.GREEN) text.normal("Owner:") text.normal(helpers.eos_name_to_string(msg.owner)) await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput) -async def confirm_action_voteproducer(ctx, msg: EosActionVoteProducer): +async def confirm_action_voteproducer( + ctx: wire.Context, msg: EosActionVoteProducer +) -> None: if msg.proxy and not msg.producers: # PROXY text = Text("Vote for proxy", ui.ICON_CONFIRM, icon_color=ui.GREEN) @@ -151,7 +163,9 @@ async def confirm_action_voteproducer(ctx, msg: EosActionVoteProducer): await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput) -async def confirm_action_transfer(ctx, msg: EosActionTransfer, account: str): +async def confirm_action_transfer( + ctx: wire.Context, msg: EosActionTransfer, account: str +) -> None: text = "Transfer" fields = [] fields.append("From:") @@ -170,7 +184,9 @@ async def confirm_action_transfer(ctx, msg: EosActionTransfer, account: str): await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) -async def confirm_action_updateauth(ctx, msg: EosActionUpdateAuth): +async def confirm_action_updateauth( + ctx: wire.Context, msg: EosActionUpdateAuth +) -> None: text = "Update Auth" fields = [] fields.append("Account:") @@ -183,7 +199,9 @@ async def confirm_action_updateauth(ctx, msg: EosActionUpdateAuth): await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) -async def confirm_action_deleteauth(ctx, msg: EosActionDeleteAuth): +async def confirm_action_deleteauth( + ctx: wire.Context, msg: EosActionDeleteAuth +) -> None: text = Text("Delete auth", ui.ICON_CONFIRM, icon_color=ui.GREEN) text.normal("Account:") text.normal(helpers.eos_name_to_string(msg.account)) @@ -192,7 +210,7 @@ async def confirm_action_deleteauth(ctx, msg: EosActionDeleteAuth): await require_confirm(ctx, text, ButtonRequestType.ConfirmOutput) -async def confirm_action_linkauth(ctx, msg: EosActionLinkAuth): +async def confirm_action_linkauth(ctx: wire.Context, msg: EosActionLinkAuth) -> None: text = "Link Auth" fields = [] fields.append("Account:") @@ -206,7 +224,9 @@ async def confirm_action_linkauth(ctx, msg: EosActionLinkAuth): await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) -async def confirm_action_unlinkauth(ctx, msg: EosActionUnlinkAuth): +async def confirm_action_unlinkauth( + ctx: wire.Context, msg: EosActionUnlinkAuth +) -> None: text = "Unlink Auth" fields = [] fields.append("Account:") @@ -218,7 +238,9 @@ async def confirm_action_unlinkauth(ctx, msg: EosActionUnlinkAuth): await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) -async def confirm_action_newaccount(ctx, msg: EosActionNewAccount): +async def confirm_action_newaccount( + ctx: wire.Context, msg: EosActionNewAccount +) -> None: text = "New Account" fields = [] fields.append("Creator:") @@ -230,7 +252,9 @@ async def confirm_action_newaccount(ctx, msg: EosActionNewAccount): await _require_confirm_paginated(ctx, text, fields, _FOUR_FIELDS_PER_PAGE) -async def confirm_action_unknown(ctx, action, checksum): +async def confirm_action_unknown( + ctx: wire.Context, action: EosActionCommon, checksum: bytes +) -> None: text = "Arbitrary data" fields = [] fields.append("Contract:") @@ -242,7 +266,7 @@ async def confirm_action_unknown(ctx, action, checksum): await _require_confirm_paginated(ctx, text, fields, _FIVE_FIELDS_PER_PAGE) -def authorization_fields(auth): +def authorization_fields(auth: EosAuthorization) -> List[str]: fields = [] fields.append("Threshold:") @@ -288,11 +312,9 @@ def authorization_fields(auth): return fields -def split_data(data): - temp_list = [] - len_left = len(data) - while len_left > 0: - temp_list.append("{} ".format(data[:_LINE_LENGTH])) +def split_data(data: str) -> List[str]: + lines = [] + while data: + lines.append("{} ".format(data[:_LINE_LENGTH])) data = data[_LINE_LENGTH:] - len_left = len(data) - return temp_list + return lines diff --git a/core/src/apps/eos/get_public_key.py b/core/src/apps/eos/get_public_key.py index 39a68b04c..74f5de0d3 100755 --- a/core/src/apps/eos/get_public_key.py +++ b/core/src/apps/eos/get_public_key.py @@ -8,6 +8,11 @@ from apps.eos import CURVE from apps.eos.helpers import base58_encode, validate_full_path from apps.eos.layout import require_get_public_key +if False: + from typing import Tuple + from trezor.crypto import bip32 + from apps.common import seed + def _public_key_to_wif(pub_key: bytes) -> str: if pub_key[0] == 0x04 and len(pub_key) == 65: @@ -20,14 +25,16 @@ def _public_key_to_wif(pub_key: bytes) -> str: return base58_encode("EOS", "", compressed_pub_key) -def _get_public_key(node): +def _get_public_key(node: bip32.HDNode) -> Tuple[str, bytes]: seckey = node.private_key() public_key = secp256k1.publickey(seckey, True) wif = _public_key_to_wif(public_key) return wif, public_key -async def get_public_key(ctx, msg: EosGetPublicKey, keychain): +async def get_public_key( + ctx: wire.Context, msg: EosGetPublicKey, keychain: seed.Keychain +) -> EosPublicKey: await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE) node = keychain.derive(msg.address_n) diff --git a/core/src/apps/eos/helpers.py b/core/src/apps/eos/helpers.py index 0fc6bac9b..a440b8cf2 100644 --- a/core/src/apps/eos/helpers.py +++ b/core/src/apps/eos/helpers.py @@ -1,5 +1,5 @@ from trezor.crypto import base58 -from trezor.messages import EosAsset +from trezor.messages.EosAsset import EosAsset from apps.common import HARDENED @@ -12,7 +12,7 @@ def base58_encode(prefix: str, sig_prefix: str, data: bytes) -> str: return prefix + b58 -def eos_name_to_string(value) -> str: +def eos_name_to_string(value: int) -> str: charmap = ".12345abcdefghijklmnopqrstuvwxyz" tmp = value string = "" diff --git a/core/src/apps/eos/layout.py b/core/src/apps/eos/layout.py index 70677b849..41f0b17fa 100644 --- a/core/src/apps/eos/layout.py +++ b/core/src/apps/eos/layout.py @@ -1,19 +1,19 @@ -from trezor import ui +from trezor import ui, wire from trezor.messages import ButtonRequestType from trezor.ui.text import Text from apps.common.confirm import require_confirm -async def require_get_public_key(ctx, public_key): +async def require_get_public_key(ctx: wire.Context, public_key: str) -> None: text = Text("Confirm public key", ui.ICON_RECEIVE, ui.GREEN) text.normal(public_key) - return await require_confirm(ctx, text, ButtonRequestType.PublicKey) + await require_confirm(ctx, text, ButtonRequestType.PublicKey) -async def require_sign_tx(ctx, num_actions): +async def require_sign_tx(ctx: wire.Context, num_actions: int) -> None: text = Text("Sign transaction", ui.ICON_SEND, ui.GREEN) text.normal("You are about") text.normal("to sign {}".format(num_actions)) text.normal("action(s).") - return await require_confirm(ctx, text, ButtonRequestType.SignTx) + await require_confirm(ctx, text, ButtonRequestType.SignTx) diff --git a/core/src/apps/eos/sign_tx.py b/core/src/apps/eos/sign_tx.py index f959da686..db82da859 100644 --- a/core/src/apps/eos/sign_tx.py +++ b/core/src/apps/eos/sign_tx.py @@ -3,8 +3,8 @@ from trezor.crypto.curve import secp256k1 from trezor.crypto.hashlib import sha256 from trezor.messages.EosSignedTx import EosSignedTx from trezor.messages.EosSignTx import EosSignTx +from trezor.messages.EosTxActionAck import EosTxActionAck from trezor.messages.EosTxActionRequest import EosTxActionRequest -from trezor.messages.MessageType import EosTxActionAck from trezor.utils import HashWriter from apps.common import paths @@ -13,8 +13,13 @@ from apps.eos.actions import process_action from apps.eos.helpers import base58_encode, validate_full_path from apps.eos.layout import require_sign_tx +if False: + from apps.common import seed -async def sign_tx(ctx, msg: EosSignTx, keychain): + +async def sign_tx( + ctx: wire.Context, msg: EosSignTx, keychain: seed.Keychain +) -> EosSignedTx: if msg.chain_id is None: raise wire.DataError("No chain id") if msg.header is None: @@ -39,7 +44,7 @@ async def sign_tx(ctx, msg: EosSignTx, keychain): return EosSignedTx(signature=base58_encode("SIG_", "K1", signature)) -async def _init(ctx, sha, msg): +async def _init(ctx: wire.Context, sha: HashWriter, msg: EosSignTx) -> None: writers.write_bytes(sha, msg.chain_id) writers.write_header(sha, msg.header) writers.write_variant32(sha, 0) @@ -48,7 +53,7 @@ async def _init(ctx, sha, msg): await require_sign_tx(ctx, msg.num_actions) -async def _actions(ctx, sha, num_actions: int): +async def _actions(ctx: wire.Context, sha: HashWriter, num_actions: int) -> None: for i in range(num_actions): action = await ctx.call(EosTxActionRequest(), EosTxActionAck) await process_action(ctx, sha, action) diff --git a/core/src/apps/eos/writers.py b/core/src/apps/eos/writers.py index 5b392bb3b..81ea92912 100644 --- a/core/src/apps/eos/writers.py +++ b/core/src/apps/eos/writers.py @@ -1,23 +1,3 @@ -from trezor.messages import ( - EosActionBuyRam, - EosActionBuyRamBytes, - EosActionCommon, - EosActionDelegate, - EosActionDeleteAuth, - EosActionLinkAuth, - EosActionNewAccount, - EosActionRefund, - EosActionSellRam, - EosActionTransfer, - EosActionUndelegate, - EosActionUpdateAuth, - EosActionVoteProducer, - EosAsset, - EosAuthorization, - EosTxHeader, -) -from trezor.utils import HashWriter - from apps.common.writers import ( write_bytes, write_uint8, @@ -26,8 +6,27 @@ from apps.common.writers import ( write_uint64_le, ) - -def write_auth(w: bytearray, auth: EosAuthorization) -> int: +if False: + from trezor.messages.EosActionBuyRam import EosActionBuyRam + from trezor.messages.EosActionBuyRamBytes import EosActionBuyRamBytes + from trezor.messages.EosActionCommon import EosActionCommon + from trezor.messages.EosActionDelegate import EosActionDelegate + from trezor.messages.EosActionDeleteAuth import EosActionDeleteAuth + from trezor.messages.EosActionLinkAuth import EosActionLinkAuth + from trezor.messages.EosActionNewAccount import EosActionNewAccount + from trezor.messages.EosActionRefund import EosActionRefund + from trezor.messages.EosActionSellRam import EosActionSellRam + from trezor.messages.EosActionTransfer import EosActionTransfer + from trezor.messages.EosActionUndelegate import EosActionUndelegate + from trezor.messages.EosActionUpdateAuth import EosActionUpdateAuth + from trezor.messages.EosActionVoteProducer import EosActionVoteProducer + from trezor.messages.EosAsset import EosAsset + from trezor.messages.EosAuthorization import EosAuthorization + from trezor.messages.EosTxHeader import EosTxHeader + from trezor.utils import Writer + + +def write_auth(w: Writer, auth: EosAuthorization) -> None: write_uint32_le(w, auth.threshold) write_variant32(w, len(auth.keys)) for key in auth.keys: @@ -47,7 +46,7 @@ def write_auth(w: bytearray, auth: EosAuthorization) -> int: write_uint16_le(w, wait.weight) -def write_header(hasher: HashWriter, header: EosTxHeader): +def write_header(hasher: Writer, header: EosTxHeader) -> None: write_uint32_le(hasher, header.expiration) write_uint16_le(hasher, header.ref_block_num) write_uint32_le(hasher, header.ref_block_prefix) @@ -56,7 +55,7 @@ def write_header(hasher: HashWriter, header: EosTxHeader): write_variant32(hasher, header.delay_sec) -def write_action_transfer(w: bytearray, msg: EosActionTransfer): +def write_action_transfer(w: Writer, msg: EosActionTransfer) -> None: write_uint64_le(w, msg.sender) write_uint64_le(w, msg.receiver) write_asset(w, msg.quantity) @@ -64,24 +63,24 @@ def write_action_transfer(w: bytearray, msg: EosActionTransfer): write_bytes(w, msg.memo) -def write_action_buyram(w: bytearray, msg: EosActionBuyRam): +def write_action_buyram(w: Writer, msg: EosActionBuyRam) -> None: write_uint64_le(w, msg.payer) write_uint64_le(w, msg.receiver) write_asset(w, msg.quantity) -def write_action_buyrambytes(w: bytearray, msg: EosActionBuyRamBytes): +def write_action_buyrambytes(w: Writer, msg: EosActionBuyRamBytes) -> None: write_uint64_le(w, msg.payer) write_uint64_le(w, msg.receiver) write_uint32_le(w, msg.bytes) -def write_action_sellram(w: bytearray, msg: EosActionSellRam): +def write_action_sellram(w: Writer, msg: EosActionSellRam) -> None: write_uint64_le(w, msg.account) write_uint64_le(w, msg.bytes) -def write_action_delegate(w: bytearray, msg: EosActionDelegate): +def write_action_delegate(w: Writer, msg: EosActionDelegate) -> None: write_uint64_le(w, msg.sender) write_uint64_le(w, msg.receiver) write_asset(w, msg.net_quantity) @@ -89,18 +88,18 @@ def write_action_delegate(w: bytearray, msg: EosActionDelegate): write_uint8(w, 1 if msg.transfer else 0) -def write_action_undelegate(w: bytearray, msg: EosActionUndelegate): +def write_action_undelegate(w: Writer, msg: EosActionUndelegate) -> None: write_uint64_le(w, msg.sender) write_uint64_le(w, msg.receiver) write_asset(w, msg.net_quantity) write_asset(w, msg.cpu_quantity) -def write_action_refund(w: bytearray, msg: EosActionRefund): +def write_action_refund(w: Writer, msg: EosActionRefund) -> None: write_uint64_le(w, msg.owner) -def write_action_voteproducer(w: bytearray, msg: EosActionVoteProducer): +def write_action_voteproducer(w: Writer, msg: EosActionVoteProducer) -> None: write_uint64_le(w, msg.voter) write_uint64_le(w, msg.proxy) write_variant32(w, len(msg.producers)) @@ -108,61 +107,59 @@ def write_action_voteproducer(w: bytearray, msg: EosActionVoteProducer): write_uint64_le(w, producer) -def write_action_updateauth(w: bytearray, msg: EosActionUpdateAuth): +def write_action_updateauth(w: Writer, msg: EosActionUpdateAuth) -> None: write_uint64_le(w, msg.account) write_uint64_le(w, msg.permission) write_uint64_le(w, msg.parent) write_auth(w, msg.auth) -def write_action_deleteauth(w: bytearray, msg: EosActionDeleteAuth): +def write_action_deleteauth(w: Writer, msg: EosActionDeleteAuth) -> None: write_uint64_le(w, msg.account) write_uint64_le(w, msg.permission) -def write_action_linkauth(w: bytearray, msg: EosActionLinkAuth): +def write_action_linkauth(w: Writer, msg: EosActionLinkAuth) -> None: write_uint64_le(w, msg.account) write_uint64_le(w, msg.code) write_uint64_le(w, msg.type) write_uint64_le(w, msg.requirement) -def write_action_unlinkauth(w: bytearray, msg: EosActionLinkAuth): +def write_action_unlinkauth(w: Writer, msg: EosActionLinkAuth) -> None: write_uint64_le(w, msg.account) write_uint64_le(w, msg.code) write_uint64_le(w, msg.type) -def write_action_newaccount(w: bytearray, msg: EosActionNewAccount): +def write_action_newaccount(w: Writer, msg: EosActionNewAccount) -> None: write_uint64_le(w, msg.creator) write_uint64_le(w, msg.name) write_auth(w, msg.owner) write_auth(w, msg.active) -def write_action_common(hasher: HashWriter, msg: EosActionCommon): - write_uint64_le(hasher, msg.account) - write_uint64_le(hasher, msg.name) - write_variant32(hasher, len(msg.authorization)) +def write_action_common(w: Writer, msg: EosActionCommon) -> None: + write_uint64_le(w, msg.account) + write_uint64_le(w, msg.name) + write_variant32(w, len(msg.authorization)) for authorization in msg.authorization: - write_uint64_le(hasher, authorization.actor) - write_uint64_le(hasher, authorization.permission) + write_uint64_le(w, authorization.actor) + write_uint64_le(w, authorization.permission) -def write_asset(w: bytearray, asset: EosAsset) -> int: +def write_asset(w: Writer, asset: EosAsset) -> None: write_uint64_le(w, asset.amount) write_uint64_le(w, asset.symbol) -def write_variant32(w: bytearray, value: int) -> int: +def write_variant32(w: Writer, value: int) -> None: variant = bytearray() while True: b = value & 0x7F value >>= 7 b |= (value > 0) << 7 variant.append(b) - if value == 0: break - write_bytes(w, bytes(variant)) diff --git a/core/src/apps/ethereum/__init__.py b/core/src/apps/ethereum/__init__.py index 553f35ecd..4117dd4f0 100644 --- a/core/src/apps/ethereum/__init__.py +++ b/core/src/apps/ethereum/__init__.py @@ -7,7 +7,7 @@ from apps.ethereum.networks import all_slip44_ids_hardened CURVE = "secp256k1" -def boot(): +def boot() -> None: ns = [] for i in all_slip44_ids_hardened(): ns.append([CURVE, HARDENED | 44, i]) diff --git a/core/src/apps/ethereum/sign_tx.py b/core/src/apps/ethereum/sign_tx.py index 16985d40e..0f0ad5923 100644 --- a/core/src/apps/ethereum/sign_tx.py +++ b/core/src/apps/ethereum/sign_tx.py @@ -3,8 +3,8 @@ from trezor.crypto import rlp from trezor.crypto.curve import secp256k1 from trezor.crypto.hashlib import sha3_256 from trezor.messages.EthereumSignTx import EthereumSignTx +from trezor.messages.EthereumTxAck import EthereumTxAck from trezor.messages.EthereumTxRequest import EthereumTxRequest -from trezor.messages.MessageType import EthereumTxAck from trezor.utils import HashWriter from apps.common import paths diff --git a/core/src/apps/homescreen/__init__.py b/core/src/apps/homescreen/__init__.py index 7fd95958b..536edd8b5 100644 --- a/core/src/apps/homescreen/__init__.py +++ b/core/src/apps/homescreen/__init__.py @@ -6,8 +6,16 @@ from trezor.wire import protobuf_workflow, register from apps.common import cache, storage +if False: + from typing import NoReturn + from trezor.messages.Initialize import Initialize + from trezor.messages.GetFeatures import GetFeatures + from trezor.messages.Cancel import Cancel + from trezor.messages.ClearSession import ClearSession + from trezor.messages.Ping import Ping -def get_features(): + +def get_features() -> Features: f = Features() f.vendor = "trezor.io" f.language = "english" @@ -30,7 +38,7 @@ def get_features(): return f -async def handle_Initialize(ctx, msg): +async def handle_Initialize(ctx: wire.Context, msg: Initialize) -> Features: if msg.state is None or msg.state != cache.get_state(prev_state=bytes(msg.state)): cache.clear() if msg.skip_passphrase: @@ -38,20 +46,20 @@ async def handle_Initialize(ctx, msg): return get_features() -async def handle_GetFeatures(ctx, msg): +async def handle_GetFeatures(ctx: wire.Context, msg: GetFeatures) -> Features: return get_features() -async def handle_Cancel(ctx, msg): +async def handle_Cancel(ctx: wire.Context, msg: Cancel) -> NoReturn: raise wire.ActionCancelled("Cancelled") -async def handle_ClearSession(ctx, msg): +async def handle_ClearSession(ctx: wire.Context, msg: ClearSession) -> Success: cache.clear(keep_passphrase=True) return Success(message="Session cleared") -async def handle_Ping(ctx, msg): +async def handle_Ping(ctx: wire.Context, msg: Ping) -> Success: if msg.button_protection: from apps.common.confirm import require_confirm from trezor.messages.ButtonRequestType import ProtectCall @@ -65,7 +73,7 @@ async def handle_Ping(ctx, msg): return Success(message=msg.message) -def boot(): +def boot() -> None: register(MessageType.Initialize, protobuf_workflow, handle_Initialize) register(MessageType.GetFeatures, protobuf_workflow, handle_GetFeatures) register(MessageType.Cancel, protobuf_workflow, handle_Cancel) diff --git a/core/src/apps/homescreen/homescreen.py b/core/src/apps/homescreen/homescreen.py index 177e52990..3ba4cee35 100644 --- a/core/src/apps/homescreen/homescreen.py +++ b/core/src/apps/homescreen/homescreen.py @@ -3,7 +3,7 @@ from trezor import config, io, loop, res, ui from apps.common import storage -async def homescreen(): +async def homescreen() -> None: # render homescreen in dimmed mode and fade back in ui.backlight_fade(ui.BACKLIGHT_DIM) display_homescreen() @@ -15,7 +15,7 @@ async def homescreen(): await touch -def display_homescreen(): +def display_homescreen() -> None: image = None if storage.slip39.is_in_progress(): label = "Waiting for other shares" @@ -44,13 +44,13 @@ def display_homescreen(): ui.display.text_center(ui.WIDTH // 2, 220, label, ui.BOLD, ui.FG, ui.BG) -def _warn(message: str): +def _warn(message: str) -> None: ui.display.bar(0, 0, ui.WIDTH, 30, ui.YELLOW) ui.display.text_center(ui.WIDTH // 2, 22, message, ui.BOLD, ui.BLACK, ui.YELLOW) ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG) -def _err(message: str): +def _err(message: str) -> None: ui.display.bar(0, 0, ui.WIDTH, 30, ui.RED) ui.display.text_center(ui.WIDTH // 2, 22, message, ui.BOLD, ui.WHITE, ui.RED) ui.display.bar(0, 30, ui.WIDTH, ui.HEIGHT - 30, ui.BG) diff --git a/core/src/apps/lisk/__init__.py b/core/src/apps/lisk/__init__.py index 4032a1951..c525a5526 100644 --- a/core/src/apps/lisk/__init__.py +++ b/core/src/apps/lisk/__init__.py @@ -6,7 +6,7 @@ from apps.common import HARDENED CURVE = "ed25519" -def boot(): +def boot() -> None: ns = [[CURVE, HARDENED | 44, HARDENED | 134]] wire.add(MessageType.LiskGetPublicKey, __name__, "get_public_key", ns) wire.add(MessageType.LiskGetAddress, __name__, "get_address", ns) diff --git a/core/src/apps/management/__init__.py b/core/src/apps/management/__init__.py index 865e5b95b..e1bbbc386 100644 --- a/core/src/apps/management/__init__.py +++ b/core/src/apps/management/__init__.py @@ -2,7 +2,7 @@ from trezor import wire from trezor.messages import MessageType -def boot(): +def boot() -> None: # only enable LoadDevice in debug builds if __debug__: wire.add(MessageType.LoadDevice, __name__, "load_device") diff --git a/core/src/apps/management/change_pin.py b/core/src/apps/management/change_pin.py index b0c48c30c..a315b22d5 100644 --- a/core/src/apps/management/change_pin.py +++ b/core/src/apps/management/change_pin.py @@ -1,5 +1,6 @@ from trezor import config, ui, wire -from trezor.messages import ButtonRequestType, MessageType +from trezor.messages import ButtonRequestType +from trezor.messages.ButtonAck import ButtonAck from trezor.messages.ButtonRequest import ButtonRequest from trezor.messages.Success import Success from trezor.pin import pin_to_int @@ -74,9 +75,7 @@ async def request_pin_confirm(ctx, *args, **kwargs): async def request_pin_ack(ctx, *args, **kwargs): try: - await ctx.call( - ButtonRequest(code=ButtonRequestType.Other), MessageType.ButtonAck - ) + await ctx.call(ButtonRequest(code=ButtonRequestType.Other), ButtonAck) return await ctx.wait(request_pin(*args, **kwargs)) except PinCancelled: raise wire.ActionCancelled("Cancelled") diff --git a/core/src/apps/management/recovery_device.py b/core/src/apps/management/recovery_device.py index 0d5cb3fd7..6cf186723 100644 --- a/core/src/apps/management/recovery_device.py +++ b/core/src/apps/management/recovery_device.py @@ -1,8 +1,8 @@ from trezor import config, ui, wire from trezor.crypto import slip39 from trezor.messages import ButtonRequestType +from trezor.messages.ButtonAck import ButtonAck from trezor.messages.ButtonRequest import ButtonRequest -from trezor.messages.MessageType import ButtonAck from trezor.messages.Success import Success from trezor.pin import pin_to_int from trezor.ui.info import InfoConfirm @@ -20,8 +20,11 @@ from apps.management.change_pin import request_pin_ack, request_pin_confirm if __debug__: from apps.debug import confirm_signal, input_signal +if False: + from trezor.messages.RecoveryDevice import RecoveryDevice -async def recovery_device(ctx, msg): + +async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success: """ Recover BIP39/SLIP39 seed into empty device. @@ -116,7 +119,7 @@ async def recovery_device(ctx, msg): return Success(message="Device recovered") -async def request_wordcount(ctx, title: str) -> int: +async def request_wordcount(ctx: wire.Context, title: str) -> int: await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicWordCount), ButtonAck) text = Text(title, ui.ICON_RECOVERY) @@ -131,7 +134,7 @@ async def request_wordcount(ctx, title: str) -> int: return count -async def request_mnemonic(ctx, count: int, slip39: bool) -> str: +async def request_mnemonic(ctx: wire.Context, count: int, slip39: bool) -> str: await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicInput), ButtonAck) words = [] @@ -149,7 +152,7 @@ async def request_mnemonic(ctx, count: int, slip39: bool) -> str: return " ".join(words) -async def show_keyboard_info(ctx): +async def show_keyboard_info(ctx: wire.Context) -> None: await ctx.call(ButtonRequest(code=ButtonRequestType.Other), ButtonAck) info = InfoConfirm( @@ -174,7 +177,9 @@ async def show_success(ctx): ) -async def show_remaining_slip39_mnemonics(ctx, title, remaining: int): +async def show_remaining_slip39_mnemonics( + ctx: wire.Context, title: str, remaining: int +) -> None: text = Text(title, ui.ICON_RECOVERY) text.bold("Good job!") text.normal("Enter %s more recovery " % remaining) diff --git a/core/src/apps/management/reset_device.py b/core/src/apps/management/reset_device.py index 8e2b310ef..d3dd0ee41 100644 --- a/core/src/apps/management/reset_device.py +++ b/core/src/apps/management/reset_device.py @@ -1,6 +1,6 @@ from trezor import config, wire from trezor.crypto import bip39, hashlib, random, slip39 -from trezor.messages import MessageType +from trezor.messages.EntropyAck import EntropyAck from trezor.messages.EntropyRequest import EntropyRequest from trezor.messages.Success import Success from trezor.pin import pin_to_int @@ -12,8 +12,11 @@ from apps.management.common import layout if __debug__: from apps import debug +if False: + from trezor.messages.ResetDevice import ResetDevice -async def reset_device(ctx, msg): + +async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success: # validate parameters and device state _validate_reset_device(msg) @@ -34,7 +37,7 @@ async def reset_device(ctx, msg): await layout.show_internal_entropy(ctx, int_entropy) # request external entropy and compute the master secret - entropy_ack = await ctx.call(EntropyRequest(), MessageType.EntropyAck) + entropy_ack = await ctx.call(EntropyRequest(), EntropyAck) ext_entropy = entropy_ack.entropy secret = _compute_secret_from_entropy(int_entropy, ext_entropy, msg.strength) @@ -84,7 +87,7 @@ async def reset_device(ctx, msg): return Success(message="Initialized") -async def backup_slip39_wallet(ctx, secret: bytes): +async def backup_slip39_wallet(ctx: wire.Context, secret: bytes) -> None: # get number of shares await layout.slip39_show_checklist_set_shares(ctx) shares_count = await layout.slip39_prompt_number_of_shares(ctx) @@ -101,12 +104,12 @@ async def backup_slip39_wallet(ctx, secret: bytes): await layout.slip39_show_and_confirm_shares(ctx, mnemonics) -async def backup_bip39_wallet(ctx, secret: bytes): +async def backup_bip39_wallet(ctx: wire.Context, secret: bytes) -> None: mnemonic = bip39.from_data(secret) await layout.bip39_show_and_confirm_mnemonic(ctx, mnemonic) -def _validate_reset_device(msg): +def _validate_reset_device(msg: ResetDevice) -> None: if msg.strength not in (128, 256): if msg.slip39: raise wire.ProcessError("Invalid strength (has to be 128 or 256 bits)") diff --git a/core/src/apps/monero/__init__.py b/core/src/apps/monero/__init__.py index 68704690a..9f774b0f3 100644 --- a/core/src/apps/monero/__init__.py +++ b/core/src/apps/monero/__init__.py @@ -7,7 +7,7 @@ CURVE = "ed25519" _LIVE_REFRESH_TOKEN = None # live-refresh permission token -def boot(): +def boot() -> None: ns = [[CURVE, HARDENED | 44, HARDENED | 128]] wire.add(MessageType.MoneroGetAddress, __name__, "get_address", ns) wire.add(MessageType.MoneroGetWatchKey, __name__, "get_watch_only", ns) diff --git a/core/src/apps/monero/key_image_sync.py b/core/src/apps/monero/key_image_sync.py index f4c8c70fd..59031f398 100644 --- a/core/src/apps/monero/key_image_sync.py +++ b/core/src/apps/monero/key_image_sync.py @@ -1,11 +1,14 @@ import gc from trezor import log, wire -from trezor.messages import MessageType from trezor.messages.MoneroExportedKeyImage import MoneroExportedKeyImage from trezor.messages.MoneroKeyImageExportInitAck import MoneroKeyImageExportInitAck from trezor.messages.MoneroKeyImageSyncFinalAck import MoneroKeyImageSyncFinalAck +from trezor.messages.MoneroKeyImageSyncFinalRequest import ( + MoneroKeyImageSyncFinalRequest, +) from trezor.messages.MoneroKeyImageSyncStepAck import MoneroKeyImageSyncStepAck +from trezor.messages.MoneroKeyImageSyncStepRequest import MoneroKeyImageSyncStepRequest from apps.common import paths from apps.monero import CURVE, misc @@ -18,19 +21,12 @@ async def key_image_sync(ctx, msg, keychain): state = KeyImageSync() res = await _init_step(state, ctx, msg, keychain) - while True: - msg = await ctx.call( - res, - MessageType.MoneroKeyImageSyncStepRequest, - MessageType.MoneroKeyImageSyncFinalRequest, - ) - del res - if msg.MESSAGE_WIRE_TYPE == MessageType.MoneroKeyImageSyncStepRequest: - res = await _sync_step(state, ctx, msg) - else: - res = await _final_step(state, ctx) - break + while state.current_output + 1 < state.num_outputs: + msg = await ctx.call(res, MoneroKeyImageSyncStepRequest) + res = await _sync_step(state, ctx, msg) gc.collect() + msg = await ctx.call(res, MoneroKeyImageSyncFinalRequest) + res = await _final_step(state, ctx) return res diff --git a/core/src/apps/monero/layout/common.py b/core/src/apps/monero/layout/common.py index 05bc7a3c0..b3555030e 100644 --- a/core/src/apps/monero/layout/common.py +++ b/core/src/apps/monero/layout/common.py @@ -1,5 +1,6 @@ from trezor import loop, ui, utils -from trezor.messages import ButtonRequestType, MessageType +from trezor.messages import ButtonRequestType +from trezor.messages.ButtonAck import ButtonAck from trezor.messages.ButtonRequest import ButtonRequest from trezor.ui.text import Text @@ -27,9 +28,7 @@ async def naive_pagination( paginated = PaginatedWithButtons(pages, one_by_one=True) while True: - await ctx.call( - ButtonRequest(code=ButtonRequestType.SignTx), MessageType.ButtonAck - ) + await ctx.call(ButtonRequest(code=ButtonRequestType.SignTx), ButtonAck) if __debug__: result = await loop.spawn(paginated, confirm_signal) else: diff --git a/core/src/apps/monero/live_refresh.py b/core/src/apps/monero/live_refresh.py index 50e04a3c9..b9063544b 100644 --- a/core/src/apps/monero/live_refresh.py +++ b/core/src/apps/monero/live_refresh.py @@ -21,7 +21,7 @@ async def live_refresh(ctx, msg: MoneroLiveRefreshStartRequest, keychain): res = await _init_step(state, ctx, msg, keychain) while True: - msg = await ctx.call( + msg = await ctx.call_any( res, MessageType.MoneroLiveRefreshStepRequest, MessageType.MoneroLiveRefreshFinalRequest, diff --git a/core/src/apps/monero/sign_tx.py b/core/src/apps/monero/sign_tx.py index 22f29ce50..892a7373f 100644 --- a/core/src/apps/monero/sign_tx.py +++ b/core/src/apps/monero/sign_tx.py @@ -26,7 +26,7 @@ async def sign_tx(ctx, received_msg, keychain): del (result_msg, received_msg) utils.unimport_end(mods) - received_msg = await ctx.read(accept_msgs) + received_msg = await ctx.read_any(accept_msgs) utils.unimport_end(mods) return result_msg diff --git a/core/src/apps/nem/__init__.py b/core/src/apps/nem/__init__.py index fb3d13a11..2253845bc 100644 --- a/core/src/apps/nem/__init__.py +++ b/core/src/apps/nem/__init__.py @@ -6,7 +6,7 @@ from apps.common import HARDENED CURVE = "ed25519-keccak" -def boot(): +def boot() -> None: ns = [[CURVE, HARDENED | 44, HARDENED | 43], [CURVE, HARDENED | 44, HARDENED | 1]] wire.add(MessageType.NEMGetAddress, __name__, "get_address", ns) wire.add(MessageType.NEMSignTx, __name__, "sign_tx", ns) diff --git a/core/src/apps/ripple/__init__.py b/core/src/apps/ripple/__init__.py index ac116bf7e..f92768db2 100644 --- a/core/src/apps/ripple/__init__.py +++ b/core/src/apps/ripple/__init__.py @@ -6,7 +6,7 @@ from apps.common import HARDENED CURVE = "secp256k1" -def boot(): +def boot() -> None: ns = [[CURVE, HARDENED | 44, HARDENED | 144]] wire.add(MessageType.RippleGetAddress, __name__, "get_address", ns) wire.add(MessageType.RippleSignTx, __name__, "sign_tx", ns) diff --git a/core/src/apps/stellar/__init__.py b/core/src/apps/stellar/__init__.py index f868d8054..fe40b24a8 100644 --- a/core/src/apps/stellar/__init__.py +++ b/core/src/apps/stellar/__init__.py @@ -6,7 +6,7 @@ from apps.common import HARDENED CURVE = "ed25519" -def boot(): +def boot() -> None: ns = [[CURVE, HARDENED | 44, HARDENED | 148]] wire.add(MessageType.StellarGetAddress, __name__, "get_address", ns) wire.add(MessageType.StellarSignTx, __name__, "sign_tx", ns) diff --git a/core/src/apps/stellar/sign_tx.py b/core/src/apps/stellar/sign_tx.py index ad32a57f5..5bfbf1a64 100644 --- a/core/src/apps/stellar/sign_tx.py +++ b/core/src/apps/stellar/sign_tx.py @@ -77,7 +77,7 @@ def _timebounds(w: bytearray, start: int, end: int): async def _operations(ctx, w: bytearray, num_operations: int): writers.write_uint32(w, num_operations) for i in range(num_operations): - op = await ctx.call(StellarTxOpRequest(), *consts.op_wire_types) + op = await ctx.call_any(StellarTxOpRequest(), *consts.op_wire_types) await process_operation(ctx, w, op) diff --git a/core/src/apps/tezos/__init__.py b/core/src/apps/tezos/__init__.py index 0022075be..d0068be04 100644 --- a/core/src/apps/tezos/__init__.py +++ b/core/src/apps/tezos/__init__.py @@ -6,7 +6,7 @@ from apps.common import HARDENED CURVE = "ed25519" -def boot(): +def boot() -> None: ns = [[CURVE, HARDENED | 44, HARDENED | 1729]] wire.add(MessageType.TezosGetAddress, __name__, "get_address", ns) wire.add(MessageType.TezosSignTx, __name__, "sign_tx", ns) diff --git a/core/src/apps/wallet/__init__.py b/core/src/apps/wallet/__init__.py index fc70179cc..03d5321c3 100644 --- a/core/src/apps/wallet/__init__.py +++ b/core/src/apps/wallet/__init__.py @@ -2,7 +2,7 @@ from trezor import wire from trezor.messages import MessageType -def boot(): +def boot() -> None: ns = [ ["curve25519"], ["ed25519"], diff --git a/core/src/apps/wallet/sign_tx/__init__.py b/core/src/apps/wallet/sign_tx/__init__.py index b38178cf8..feb7cfbcf 100644 --- a/core/src/apps/wallet/sign_tx/__init__.py +++ b/core/src/apps/wallet/sign_tx/__init__.py @@ -1,6 +1,6 @@ from trezor import utils, wire -from trezor.messages.MessageType import TxAck from trezor.messages.RequestType import TXFINISHED +from trezor.messages.TxAck import TxAck from trezor.messages.TxRequest import TxRequest from apps.common import paths diff --git a/core/src/boot.py b/core/src/boot.py index 7181c0ca1..f6c414482 100644 --- a/core/src/boot.py +++ b/core/src/boot.py @@ -5,7 +5,7 @@ from apps.common import storage from apps.common.request_pin import request_pin -async def bootscreen(): +async def bootscreen() -> None: ui.display.orientation(storage.device.get_rotation()) while True: try: @@ -27,7 +27,7 @@ async def bootscreen(): log.exception(__name__, e) -async def lockscreen(): +async def lockscreen() -> None: label = storage.device.get_label() image = storage.device.get_homescreen() if not label: diff --git a/core/src/protobuf.py b/core/src/protobuf.py index 64d6f047e..17fa08a8e 100644 --- a/core/src/protobuf.py +++ b/core/src/protobuf.py @@ -1,32 +1,31 @@ -''' +""" Extremely minimal streaming codec for a subset of protobuf. Supports uint32, bytes, string, embedded message and repeated fields. +""" -For de-serializing (loading) protobuf types, object with `AsyncReader` -interface is required: +from micropython import const ->>> class AsyncReader: ->>> async def areadinto(self, buffer): ->>> """ ->>> Reads `len(buffer)` bytes into `buffer`, or raises `EOFError`. ->>> """ +if False: + from typing import Any, Dict, List, Type, TypeVar + from typing_extensions import Protocol -For serializing (dumping) protobuf types, object with `AsyncWriter` interface is -required: + class AsyncReader(Protocol): + async def areadinto(self, buf: bytearray) -> int: + """ + Reads `len(buf)` bytes into `buf`, or raises `EOFError`. + """ ->>> class AsyncWriter: ->>> async def awrite(self, buffer): ->>> """ ->>> Writes all bytes from `buffer`, or raises `EOFError`. ->>> """ -''' + class AsyncWriter(Protocol): + async def awrite(self, buf: bytes) -> int: + """ + Writes all bytes from `buf`, or raises `EOFError`. + """ -from micropython import const _UVARINT_BUFFER = bytearray(1) -async def load_uvarint(reader): +async def load_uvarint(reader: AsyncReader) -> int: buffer = _UVARINT_BUFFER result = 0 shift = 0 @@ -39,11 +38,11 @@ async def load_uvarint(reader): return result -async def dump_uvarint(writer, n): +async def dump_uvarint(writer: AsyncWriter, n: int) -> None: if n < 0: raise ValueError("Cannot dump signed value, convert it to unsigned first.") buffer = _UVARINT_BUFFER - shifted = True + shifted = 1 while shifted: shifted = n >> 7 buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00) @@ -51,7 +50,7 @@ async def dump_uvarint(writer, n): n = shifted -def count_uvarint(n): +def count_uvarint(n: int) -> int: if n < 0: raise ValueError("Cannot dump signed value, convert it to unsigned first.") if n <= 0x7F: @@ -95,14 +94,14 @@ def count_uvarint(n): # So we have to branch on whether the number is negative. -def sint_to_uint(sint): +def sint_to_uint(sint: int) -> int: res = sint << 1 if sint < 0: res = ~res return res -def uint_to_sint(uint): +def uint_to_sint(uint: int) -> int: sign = uint & 1 res = uint >> 1 if sign: @@ -133,27 +132,31 @@ class UnicodeType: class MessageType: WIRE_TYPE = 2 + # Type id for the wire codec. + # Technically, not every protobuf message has this. + MESSAGE_WIRE_TYPE = -1 + @classmethod - def get_fields(cls): + def get_fields(cls) -> Dict: return {} - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: for kw in kwargs: setattr(self, kw, kwargs[kw]) - def __eq__(self, rhs): + def __eq__(self, rhs: Any) -> bool: return self.__class__ is rhs.__class__ and self.__dict__ == rhs.__dict__ - def __repr__(self): + def __repr__(self) -> str: return "<%s>" % self.__class__.__name__ class LimitedReader: - def __init__(self, reader, limit): + def __init__(self, reader: AsyncReader, limit: int) -> None: self.reader = reader self.limit = limit - async def areadinto(self, buf): + async def areadinto(self, buf: bytearray) -> int: if self.limit < len(buf): raise EOFError else: @@ -162,20 +165,15 @@ class LimitedReader: return nread -class CountingWriter: - def __init__(self): - self.size = 0 - - async def awrite(self, buf): - nwritten = len(buf) - self.size += nwritten - return nwritten - - FLAG_REPEATED = const(1) +if False: + LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType) + -async def load_message(reader, msg_type): +async def load_message( + reader: AsyncReader, msg_type: Type[LoadedMessageType] +) -> LoadedMessageType: fields = msg_type.get_fields() msg = msg_type() @@ -239,7 +237,9 @@ async def load_message(reader, msg_type): return msg -async def dump_message(writer, msg, fields=None): +async def dump_message( + writer: AsyncWriter, msg: MessageType, fields: Dict = None +) -> None: repvalue = [0] if fields is None: @@ -297,7 +297,7 @@ async def dump_message(writer, msg, fields=None): raise TypeError -def count_message(msg, fields=None): +def count_message(msg: MessageType, fields: Dict = None) -> int: nbytes = 0 repvalue = [0] @@ -361,7 +361,7 @@ def count_message(msg, fields=None): return nbytes -def _count_bytes_list(svalue): +def _count_bytes_list(svalue: List[bytes]) -> int: res = 0 for x in svalue: res += len(x) diff --git a/core/src/trezor/crypto/__init__.py b/core/src/trezor/crypto/__init__.py index 7bc581b6a..5d538d3e1 100644 --- a/core/src/trezor/crypto/__init__.py +++ b/core/src/trezor/crypto/__init__.py @@ -1,5 +1,3 @@ -from gc import collect - from trezorcrypto import ( # noqa: F401 aes, bip32, @@ -12,18 +10,3 @@ from trezorcrypto import ( # noqa: F401 random, rfc6979, ) - - -class SecureContext: - def __init__(self): - pass - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - for k in self.__dict__: - o = getattr(self, k) - if hasattr(o, "__del__"): - o.__del__() - collect() diff --git a/core/src/trezor/crypto/hmac.py b/core/src/trezor/crypto/hmac.py index 6f16fffbe..69e2322ad 100644 --- a/core/src/trezor/crypto/hmac.py +++ b/core/src/trezor/crypto/hmac.py @@ -1,5 +1,23 @@ +if False: + from typing import Protocol, Type + + class HashContext(Protocol): + + digest_size = -1 # type: int + block_size = -1 # type: int + + def __init__(self, data: bytes = None) -> None: + ... + + def update(self, data: bytes) -> None: + ... + + def digest(self) -> bytes: + ... + + class Hmac: - def __init__(self, key, msg, digestmod): + def __init__(self, key: bytes, msg: bytes, digestmod: Type[HashContext]): self.digestmod = digestmod self.inner = digestmod() self.digest_size = self.inner.digest_size @@ -28,7 +46,7 @@ class Hmac: return outer.digest() -def new(key, msg, digestmod) -> Hmac: +def new(key: bytes, msg: bytes, digestmod: Type[HashContext]) -> Hmac: """ Creates a HMAC context object. """ diff --git a/core/src/trezor/crypto/slip39.py b/core/src/trezor/crypto/slip39.py index 1188200c3..02854dbc1 100644 --- a/core/src/trezor/crypto/slip39.py +++ b/core/src/trezor/crypto/slip39.py @@ -23,6 +23,11 @@ from micropython import const from trezor.crypto import hashlib, hmac, pbkdf2, random from trezorcrypto import shamir, slip39 +if False: + from typing import Dict, Iterable, List, Optional, Set, Tuple + + Indices = Tuple[int, ...] + KEYBOARD_FULL_MASK = const(0x1FF) """All buttons are allowed. 9-bit bitmap all set to 1.""" @@ -35,7 +40,7 @@ def compute_mask(prefix: str) -> int: def button_sequence_to_word(prefix: str) -> str: if not prefix: - return KEYBOARD_FULL_MASK + return "" return slip39.button_sequence_to_word(int(prefix)) @@ -43,11 +48,11 @@ _RADIX_BITS = const(10) """The length of the radix in bits.""" -def bits_to_bytes(n): +def bits_to_bytes(n: int) -> int: return (n + 7) // 8 -def bits_to_words(n): +def bits_to_words(n: int) -> int: return (n + _RADIX_BITS - 1) // _RADIX_BITS @@ -103,7 +108,7 @@ class MnemonicError(Exception): pass -def _rs1024_polymod(values): +def _rs1024_polymod(values: Indices) -> int: GEN = ( 0xE0E040, 0x1C1C080, @@ -125,7 +130,7 @@ def _rs1024_polymod(values): return chk -def rs1024_create_checksum(data): +def rs1024_create_checksum(data: Indices) -> Indices: values = tuple(_CUSTOMIZATION_STRING) + data + _CHECKSUM_LENGTH_WORDS * (0,) polymod = _rs1024_polymod(values) ^ 1 return tuple( @@ -133,11 +138,11 @@ def rs1024_create_checksum(data): ) -def rs1024_verify_checksum(data): +def rs1024_verify_checksum(data: Indices) -> bool: return _rs1024_polymod(tuple(_CUSTOMIZATION_STRING) + data) == 1 -def rs1024_error_index(data): +def rs1024_error_index(data: Indices) -> Optional[int]: GEN = ( 0x91F9F87, 0x122F1F07, @@ -164,11 +169,11 @@ def rs1024_error_index(data): return None -def xor(a, b): +def xor(a: bytes, b: bytes) -> bytes: return bytes(x ^ y for x, y in zip(a, b)) -def _int_from_indices(indices): +def _int_from_indices(indices: Indices) -> int: """Converts a list of base 1024 indices in big endian order to an integer value.""" value = 0 for index in indices: @@ -176,21 +181,21 @@ def _int_from_indices(indices): return value -def _int_to_indices(value, length, bits): +def _int_to_indices(value: int, length: int, bits: int) -> Iterable[int]: """Converts an integer value to indices in big endian order.""" mask = (1 << bits) - 1 return ((value >> (i * bits)) & mask for i in reversed(range(length))) -def mnemonic_from_indices(indices): +def mnemonic_from_indices(indices: Indices) -> str: return " ".join(slip39.get_word(i) for i in indices) -def mnemonic_to_indices(mnemonic): +def mnemonic_to_indices(mnemonic: str) -> Iterable[int]: return (slip39.word_index(word.lower()) for word in mnemonic.split()) -def _round_function(i, passphrase, e, salt, r): +def _round_function(i: int, passphrase: bytes, e: int, salt: bytes, r: bytes) -> bytes: """The round function used internally by the Feistel cipher.""" return pbkdf2( pbkdf2.HMAC_SHA256, @@ -200,13 +205,15 @@ def _round_function(i, passphrase, e, salt, r): ).key()[: len(r)] -def _get_salt(identifier): +def _get_salt(identifier: int) -> bytes: return _CUSTOMIZATION_STRING + identifier.to_bytes( bits_to_bytes(_ID_LENGTH_BITS), "big" ) -def _encrypt(master_secret, passphrase, iteration_exponent, identifier): +def _encrypt( + master_secret: bytes, passphrase: bytes, iteration_exponent: int, identifier: int +) -> bytes: l = master_secret[: len(master_secret) // 2] r = master_secret[len(master_secret) // 2 :] salt = _get_salt(identifier) @@ -218,7 +225,12 @@ def _encrypt(master_secret, passphrase, iteration_exponent, identifier): return r + l -def decrypt(identifier, iteration_exponent, encrypted_master_secret, passphrase): +def decrypt( + identifier: int, + iteration_exponent: int, + encrypted_master_secret: bytes, + passphrase: bytes, +) -> bytes: l = encrypted_master_secret[: len(encrypted_master_secret) // 2] r = encrypted_master_secret[len(encrypted_master_secret) // 2 :] salt = _get_salt(identifier) @@ -230,13 +242,15 @@ def decrypt(identifier, iteration_exponent, encrypted_master_secret, passphrase) return r + l -def _create_digest(random_data, shared_secret): +def _create_digest(random_data: bytes, shared_secret: bytes) -> bytes: return hmac.new(random_data, shared_secret, hashlib.sha256).digest()[ :_DIGEST_LENGTH_BYTES ] -def _split_secret(threshold, share_count, shared_secret): +def _split_secret( + threshold: int, share_count: int, shared_secret: bytes +) -> List[Tuple[int, bytes]]: if threshold < 1: raise ValueError( "The requested threshold ({}) must be a positive integer.".format(threshold) @@ -278,7 +292,7 @@ def _split_secret(threshold, share_count, shared_secret): return shares -def _recover_secret(threshold, shares): +def _recover_secret(threshold: int, shares: List[Tuple[int, bytes]]) -> bytes: # If the threshold is 1, then the digest of the shared secret is not used. if threshold == 1: return shares[0][1] @@ -295,8 +309,12 @@ def _recover_secret(threshold, shares): def _group_prefix( - identifier, iteration_exponent, group_index, group_threshold, group_count -): + identifier: int, + iteration_exponent: int, + group_index: int, + group_threshold: int, + group_count: int, +) -> Indices: id_exp_int = (identifier << _ITERATION_EXP_LENGTH_BITS) + iteration_exponent return tuple(_int_to_indices(id_exp_int, _ID_EXP_LENGTH_WORDS, _RADIX_BITS)) + ( (group_index << 6) + ((group_threshold - 1) << 2) + ((group_count - 1) >> 2), @@ -304,15 +322,15 @@ def _group_prefix( def encode_mnemonic( - identifier, - iteration_exponent, - group_index, - group_threshold, - group_count, - member_index, - member_threshold, - value, -): + identifier: int, + iteration_exponent: int, + group_index: int, + group_threshold: int, + group_count: int, + member_index: int, + member_threshold: int, + value: bytes, +) -> str: """ Converts share data to a share mnemonic. :param int identifier: The random identifier. @@ -348,7 +366,7 @@ def encode_mnemonic( return mnemonic_from_indices(share_data + checksum) -def decode_mnemonic(mnemonic): +def decode_mnemonic(mnemonic: str) -> Tuple[int, int, int, int, int, int, int, bytes]: """Converts a share mnemonic to share data.""" mnemonic_data = tuple(mnemonic_to_indices(mnemonic)) @@ -401,12 +419,20 @@ def decode_mnemonic(mnemonic): ) -def _decode_mnemonics(mnemonics): +if False: + MnemonicGroups = Dict[int, Tuple[int, Set[Tuple[int, bytes]]]] + + +def _decode_mnemonics( + mnemonics: List[str] +) -> Tuple[int, int, int, int, MnemonicGroups]: identifiers = set() iteration_exponents = set() group_thresholds = set() group_counts = set() - groups = {} # { group_index : [member_threshold, set_of_member_shares] } + + # { group_index : [member_threshold, set_of_member_shares] } + groups = {} # type: MnemonicGroups for mnemonic in mnemonics: identifier, iteration_exponent, group_index, group_threshold, group_count, member_index, member_threshold, share_value = decode_mnemonic( mnemonic @@ -415,7 +441,7 @@ def _decode_mnemonics(mnemonics): iteration_exponents.add(iteration_exponent) group_thresholds.add(group_threshold) group_counts.add(group_count) - group = groups.setdefault(group_index, [member_threshold, set()]) + group = groups.setdefault(group_index, (member_threshold, set())) if group[0] != member_threshold: raise MnemonicError( "Invalid set of mnemonics. All mnemonics in a group must have the same member threshold." @@ -462,13 +488,13 @@ def generate_random_identifier() -> int: def generate_single_group_mnemonics_from_data( - master_secret, - identifier, - threshold, - count, - passphrase=b"", - iteration_exponent=DEFAULT_ITERATION_EXPONENT, -) -> list: + master_secret: bytes, + identifier: int, + threshold: int, + count: int, + passphrase: bytes = b"", + iteration_exponent: int = DEFAULT_ITERATION_EXPONENT, +) -> List[str]: return generate_mnemonics_from_data( master_secret, identifier, @@ -480,13 +506,13 @@ def generate_single_group_mnemonics_from_data( def generate_mnemonics_from_data( - master_secret, - identifier, - group_threshold, - groups, - passphrase=b"", - iteration_exponent=DEFAULT_ITERATION_EXPONENT, -) -> list: + master_secret: bytes, + identifier: int, + group_threshold: int, + groups: List[Tuple[int, int]], + passphrase: bytes = b"", + iteration_exponent: int = DEFAULT_ITERATION_EXPONENT, +) -> List[List[str]]: """ Splits a master secret into mnemonic shares using Shamir's secret sharing scheme. :param master_secret: The master secret to split. @@ -544,7 +570,7 @@ def generate_mnemonics_from_data( group_shares = _split_secret(group_threshold, len(groups), encrypted_master_secret) - mnemonics = [] + mnemonics = [] # type: List[List[str]] for (member_threshold, member_count), (group_index, group_secret) in zip( groups, group_shares ): @@ -568,7 +594,7 @@ def generate_mnemonics_from_data( return mnemonics -def combine_mnemonics(mnemonics): +def combine_mnemonics(mnemonics: List[str]) -> Tuple[int, int, bytes]: """ Combines mnemonic shares to obtain the master secret which was previously split using Shamir's secret sharing scheme. diff --git a/core/src/trezor/log.py b/core/src/trezor/log.py index a1ba47685..9df43c8cc 100644 --- a/core/src/trezor/log.py +++ b/core/src/trezor/log.py @@ -2,6 +2,9 @@ import sys import utime from micropython import const +if False: + from typing import Any + NOTSET = const(0) DEBUG = const(10) INFO = const(20) @@ -21,7 +24,7 @@ level = DEBUG color = True -def _log(name, mlevel, msg, *args): +def _log(name: str, mlevel: int, msg: str, *args: Any) -> None: if __debug__ and mlevel >= level: if color: fmt = ( @@ -35,26 +38,26 @@ def _log(name, mlevel, msg, *args): print(fmt % ((utime.ticks_us(), name, _leveldict[mlevel][0]) + args)) -def debug(name, msg, *args): +def debug(name: str, msg: str, *args: Any) -> None: _log(name, DEBUG, msg, *args) -def info(name, msg, *args): +def info(name: str, msg: str, *args: Any) -> None: _log(name, INFO, msg, *args) -def warning(name, msg, *args): +def warning(name: str, msg: str, *args: Any) -> None: _log(name, WARNING, msg, *args) -def error(name, msg, *args): +def error(name: str, msg: str, *args: Any) -> None: _log(name, ERROR, msg, *args) -def exception(name, exc): - _log(name, ERROR, "exception:") - sys.print_exception(exc) +def critical(name: str, msg: str, *args: Any) -> None: + _log(name, CRITICAL, msg, *args) -def critical(name, msg, *args): - _log(name, CRITICAL, msg, *args) +def exception(name: str, exc: BaseException) -> None: + _log(name, ERROR, "exception:") + sys.print_exception(exc) diff --git a/core/src/trezor/loop.py b/core/src/trezor/loop.py index 5a4424640..6218ed18c 100644 --- a/core/src/trezor/loop.py +++ b/core/src/trezor/loop.py @@ -13,12 +13,33 @@ from micropython import const from trezor import io, log -after_step_hook = None # function, called after each task step - -_QUEUE_SIZE = const(64) # maximum number of scheduled tasks -_queue = utimeq.utimeq(_QUEUE_SIZE) -_paused = {} -_finalizers = {} +if False: + from typing import ( + Any, + Awaitable, + Callable, + Coroutine, + Dict, + Generator, + List, + Optional, + Set, + ) + + Task = Coroutine + Finalizer = Callable[[Task, Any], None] + +# function to call after every task step +after_step_hook = None # type: Optional[Callable[[], None]] + +# tasks scheduled for execution in the future +_queue = utimeq.utimeq(64) + +# tasks paused on I/O +_paused = {} # type: Dict[int, Set[Task]] + +# functions to execute after a task is finished +_finalizers = {} # type: Dict[int, Finalizer] if __debug__: # for performance stats @@ -29,7 +50,9 @@ if __debug__: log_delay_rb = array.array("i", [0] * log_delay_rb_len) -def schedule(task, value=None, deadline=None, finalizer=None): +def schedule( + task: Task, value: Any = None, deadline: int = None, finalizer: Finalizer = None +) -> None: """ Schedule task to be executed with `value` on given `deadline` (in microseconds). Does not start the event loop itself, see `run`. @@ -41,20 +64,20 @@ def schedule(task, value=None, deadline=None, finalizer=None): _queue.push(deadline, task, value) -def pause(task, iface): +def pause(task: Task, iface: int) -> None: tasks = _paused.get(iface, None) if tasks is None: tasks = _paused[iface] = set() tasks.add(task) -def finalize(task, value): +def finalize(task: Task, value: Any) -> None: fn = _finalizers.pop(id(task), None) if fn is not None: fn(task, value) -def close(task): +def close(task: Task) -> None: for iface in _paused: _paused[iface].discard(task) _queue.discard(task) @@ -62,7 +85,7 @@ def close(task): finalize(task, GeneratorExit()) -def run(): +def run() -> None: """ Loop forever, stepping through scheduled tasks and awaiting I/O events inbetween. Use `schedule` first to add a coroutine to the task queue. @@ -98,13 +121,17 @@ def run(): # timeout occurred, run the first scheduled task if _queue: _queue.pop(task_entry) - _step(task_entry[1], task_entry[2]) + _step(task_entry[1], task_entry[2]) # type: ignore + # error: Argument 1 to "_step" has incompatible type "int"; expected "Coroutine[Any, Any, Any]" + # rationale: We use untyped lists here, because that is what the C API supports. -def _step(task, value): +def _step(task: Task, value: Any) -> None: try: if isinstance(value, BaseException): - result = task.throw(value) + result = task.throw(value) # type: ignore + # error: Argument 1 to "throw" of "Coroutine" has incompatible type "Exception"; expected "Type[BaseException]" + # rationale: In micropython, generator.throw() accepts the exception object directly. else: result = task.send(value) except StopIteration as e: # as e: @@ -133,10 +160,16 @@ class Syscall: scheduler, they do so through instances of a class derived from `Syscall`. """ - def __iter__(self): + def __iter__(self) -> Task: # type: ignore # support `yield from` or `await` on syscalls return (yield self) + def __await__(self) -> Generator: + return self.__iter__() # type: ignore + + def handle(self, task: Task) -> None: + pass + class sleep(Syscall): """ @@ -150,10 +183,10 @@ class sleep(Syscall): >>> print('missed by %d us', utime.ticks_diff(utime.ticks_us(), planned)) """ - def __init__(self, delay_us): + def __init__(self, delay_us: int) -> None: self.delay_us = delay_us - def handle(self, task): + def handle(self, task: Task) -> None: deadline = utime.ticks_add(utime.ticks_us(), self.delay_us) schedule(task, deadline, deadline) @@ -170,14 +203,14 @@ class wait(Syscall): >>> event, x, y = await loop.wait(io.TOUCH) # await touch event """ - def __init__(self, msg_iface): + def __init__(self, msg_iface: int) -> None: self.msg_iface = msg_iface - def handle(self, task): + def handle(self, task: Task) -> None: pause(task, self.msg_iface) -_NO_VALUE = () +_NO_VALUE = object() class signal(Syscall): @@ -196,28 +229,28 @@ class signal(Syscall): >>> # prints in the next iteration of the event loop """ - def __init__(self): + def __init__(self) -> None: self.reset() - def reset(self): + def reset(self) -> None: self.value = _NO_VALUE - self.task = None + self.task = None # type: Optional[Task] - def handle(self, task): + def handle(self, task: Task) -> None: self.task = task self._deliver() - def send(self, value): + def send(self, value: Any) -> None: self.value = value self._deliver() - def _deliver(self): + def _deliver(self) -> None: if self.task is not None and self.value is not _NO_VALUE: schedule(self.task, self.value) self.task = None self.value = _NO_VALUE - def __iter__(self): + def __iter__(self) -> Task: # type: ignore try: return (yield self) except: # noqa: E722 @@ -253,14 +286,13 @@ class spawn(Syscall): `spawn.__iter__` for explanation. Always use `await`. """ - def __init__(self, *children, exit_others=True): + def __init__(self, *children: Awaitable, exit_others: bool = True) -> None: self.children = children self.exit_others = exit_others - self.scheduled = [] # list of scheduled tasks - self.finished = [] # list of children that finished - self.callback = None + self.finished = [] # type: List[Awaitable] # children that finished + self.scheduled = [] # type: List[Task] # scheduled wrapper tasks - def handle(self, task): + def handle(self, task: Task) -> None: finalizer = self._finish scheduled = self.scheduled finished = self.finished @@ -273,16 +305,17 @@ class spawn(Syscall): if isinstance(child, _type_gen): child_task = child else: - child_task = iter(child) - schedule(child_task, None, None, finalizer) - scheduled.append(child_task) + child_task = iter(child) # type: ignore + schedule(child_task, None, None, finalizer) # type: ignore + scheduled.append(child_task) # type: ignore + # TODO: document the types here - def exit(self, except_for=None): + def exit(self, except_for: Task = None) -> None: for task in self.scheduled: if task != except_for: close(task) - def _finish(self, task, result): + def _finish(self, task: Task, result: Any) -> None: if not self.finished: for index, child_task in enumerate(self.scheduled): if child_task is task: @@ -293,7 +326,7 @@ class spawn(Syscall): self.exit(task) schedule(self.callback, result) - def __iter__(self): + def __iter__(self) -> Task: # type: ignore try: return (yield self) except: # noqa: E722 diff --git a/core/src/trezor/res/__init__.py b/core/src/trezor/res/__init__.py index cc63ef5c8..ff3fd6782 100644 --- a/core/src/trezor/res/__init__.py +++ b/core/src/trezor/res/__init__.py @@ -1,17 +1,17 @@ try: from .resources import resdata except ImportError: - resdata = None + resdata = {} -def load(name): +def load(name: str) -> bytes: """ Loads resource of a given name as bytes. """ return resdata[name] -def gettext(message): +def gettext(message: str) -> str: """ Returns localized string. This function is aliased to _. """ diff --git a/core/src/trezor/ui/__init__.py b/core/src/trezor/ui/__init__.py index 2260e4861..9c6c5cc0c 100644 --- a/core/src/trezor/ui/__init__.py +++ b/core/src/trezor/ui/__init__.py @@ -5,12 +5,18 @@ from trezorui import Display from trezor import io, loop, res, utils, workflow +if False: + from typing import Any, Generator, Iterable, Tuple, TypeVar + + Pos = Tuple[int, int] + Area = Tuple[int, int, int, int] + display = Display() # in debug mode, display an indicator in top right corner if __debug__: - def debug_display_refresh(): + def debug_display_refresh() -> None: display.bar(Display.WIDTH - 8, 0, 8, 8, 0xF800) display.refresh() if utils.SAVE_SCREEN: @@ -59,25 +65,25 @@ from trezor.ui import style # isort:skip from trezor.ui.style import * # isort:skip # noqa: F401,F403 -def pulse(delay: int): +def pulse(delay: int) -> float: # normalize sin from interval -1:1 to 0:1 return 0.5 + 0.5 * math.sin(utime.ticks_us() / delay) -async def click() -> tuple: +async def click() -> Pos: touch = loop.wait(io.TOUCH) while True: - ev, *pos = yield touch + ev, *pos = await touch if ev == io.TOUCH_START: break while True: - ev, *pos = yield touch + ev, *pos = await touch if ev == io.TOUCH_END: break - return pos + return pos # type: ignore -def backlight_fade(val: int, delay: int = 14000, step: int = 15): +def backlight_fade(val: int, delay: int = 14000, step: int = 15) -> None: if __debug__: if utils.DISABLE_FADE: display.backlight(val) @@ -96,7 +102,7 @@ def header( fg: int = style.FG, bg: int = style.BG, ifg: int = style.GREEN, -): +) -> None: if icon is not None: display.icon(14, 15, res.load(icon), ifg, bg) display.text(44, 35, title, BOLD, fg, bg) @@ -113,7 +119,7 @@ def grid( cells_x: int = 1, cells_y: int = 1, spacing: int = 0, -): +) -> Area: w = (end_x - start_x) // n_x h = (end_y - start_y) // n_y x = (i % n_x) * w @@ -121,7 +127,7 @@ def grid( return (x + start_x, y + start_y, (w - spacing) * cells_x, (h - spacing) * cells_y) -def in_area(area: tuple, x: int, y: int) -> bool: +def in_area(area: Area, x: int, y: int) -> bool: ax, ay, aw, ah = area return ax <= x <= ax + aw and ay <= y <= ay + ah @@ -132,7 +138,7 @@ REPAINT = const(-256) class Control: - def dispatch(self, event, x, y): + def dispatch(self, event: int, x: int, y: int) -> None: if event is RENDER: self.on_render() elif event is io.TOUCH_START: @@ -144,16 +150,16 @@ class Control: elif event is REPAINT: self.repaint = True - def on_render(self): + def on_render(self) -> None: pass - def on_touch_start(self, x, y): + def on_touch_start(self, x: int, y: int) -> None: pass - def on_touch_move(self, x, y): + def on_touch_move(self, x: int, y: int) -> None: pass - def on_touch_end(self, x, y): + def on_touch_end(self, x: int, y: int) -> None: pass @@ -164,8 +170,12 @@ class LayoutCancelled(Exception): pass +if False: + ResultValue = TypeVar("ResultValue") + + class Result(Exception): - def __init__(self, value): + def __init__(self, value: ResultValue) -> None: self.value = value @@ -173,7 +183,7 @@ class Layout(Control): """ """ - async def __iter__(self): + async def __iter__(self) -> ResultValue: value = None try: if workflow.layout_signal.task is not None: @@ -188,17 +198,20 @@ class Layout(Control): workflow.onlayoutclose(self) return value - def create_tasks(self): + def __await__(self) -> Generator[Any, Any, ResultValue]: + return self.__iter__() # type: ignore + + def create_tasks(self) -> Iterable[loop.Task]: return self.handle_input(), self.handle_rendering() - def handle_input(self): + def handle_input(self) -> loop.Task: # type: ignore touch = loop.wait(io.TOUCH) while True: event, x, y = yield touch self.dispatch(event, x, y) self.dispatch(RENDER, 0, 0) - def handle_rendering(self): + def handle_rendering(self) -> loop.Task: # type: ignore backlight_fade(style.BACKLIGHT_DIM) display.clear() self.dispatch(RENDER, 0, 0) diff --git a/core/src/trezor/ui/button.py b/core/src/trezor/ui/button.py index cf82889b8..c433185b3 100644 --- a/core/src/trezor/ui/button.py +++ b/core/src/trezor/ui/button.py @@ -3,6 +3,9 @@ from micropython import const from trezor import ui from trezor.ui import display, in_area +if False: + from typing import Type, Union + class ButtonDefault: class normal: @@ -12,14 +15,14 @@ class ButtonDefault: border_color = ui.BG radius = ui.RADIUS - class active: + class active(normal): bg_color = ui.FG fg_color = ui.BLACKISH text_style = ui.BOLD border_color = ui.FG radius = ui.RADIUS - class disabled: + class disabled(normal): bg_color = ui.BG fg_color = ui.GREY text_style = ui.NORMAL @@ -38,7 +41,7 @@ class ButtonMono(ButtonDefault): text_style = ui.MONO -class ButtonMonoDark: +class ButtonMonoDark(ButtonDefault): class normal: bg_color = ui.DARK_BLACK fg_color = ui.DARK_WHITE @@ -46,14 +49,14 @@ class ButtonMonoDark: border_color = ui.BG radius = ui.RADIUS - class active: + class active(normal): bg_color = ui.FG fg_color = ui.DARK_BLACK text_style = ui.MONO border_color = ui.FG radius = ui.RADIUS - class disabled: + class disabled(normal): bg_color = ui.DARK_BLACK fg_color = ui.GREY text_style = ui.MONO @@ -98,6 +101,12 @@ class ButtonMonoConfirm(ButtonDefault): text_style = ui.MONO +if False: + ButtonContent = Union[str, bytes] + ButtonStyleType = Type[ButtonDefault] + ButtonStyleStateType = Type[ButtonDefault.normal] + + # button states _INITIAL = const(0) _PRESSED = const(1) @@ -110,39 +119,53 @@ _BORDER = const(4) # border size in pixels class Button(ui.Control): - def __init__(self, area, content, style=ButtonDefault): + def __init__( + self, + area: ui.Area, + content: ButtonContent, + style: ButtonStyleType = ButtonDefault, + ) -> None: + if isinstance(content, str): + self.text = content + self.icon = b"" + elif isinstance(content, bytes): + self.icon = content + self.text = "" + else: + raise TypeError self.area = area - self.content = content self.normal_style = style.normal self.active_style = style.active self.disabled_style = style.disabled self.state = _INITIAL self.repaint = True - def enable(self): + def enable(self) -> None: if self.state is not _INITIAL: self.state = _INITIAL self.repaint = True - def disable(self): + def disable(self) -> None: if self.state is not _DISABLED: self.state = _DISABLED self.repaint = True - def on_render(self): + def on_render(self) -> None: if self.repaint: - if self.state is _DISABLED: + if self.state is _INITIAL or self.state is _RELEASED: + s = self.normal_style + elif self.state is _DISABLED: s = self.disabled_style elif self.state is _PRESSED: s = self.active_style - else: - s = self.normal_style ax, ay, aw, ah = self.area self.render_background(s, ax, ay, aw, ah) self.render_content(s, ax, ay, aw, ah) self.repaint = False - def render_background(self, s, ax, ay, aw, ah): + def render_background( + self, s: ButtonStyleStateType, ax: int, ay: int, aw: int, ah: int + ) -> None: radius = s.radius bg_color = s.bg_color border_color = s.border_color @@ -162,16 +185,21 @@ class Button(ui.Control): radius, ) - def render_content(self, s, ax, ay, aw, ah): + def render_content( + self, s: ButtonStyleStateType, ax: int, ay: int, aw: int, ah: int + ) -> None: tx = ax + aw // 2 ty = ay + ah // 2 + 8 - t = self.content - if isinstance(t, str): + t = self.text + if t: display.text_center(tx, ty, t, s.text_style, s.fg_color, s.bg_color) - elif isinstance(t, bytes): - display.icon(tx - _ICON // 2, ty - _ICON, t, s.fg_color, s.bg_color) + return + i = self.icon + if i: + display.icon(tx - _ICON // 2, ty - _ICON, i, s.fg_color, s.bg_color) + return - def on_touch_start(self, x, y): + def on_touch_start(self, x: int, y: int) -> None: if self.state is _DISABLED: return if in_area(self.area, x, y): @@ -179,7 +207,7 @@ class Button(ui.Control): self.repaint = True self.on_press_start() - def on_touch_move(self, x, y): + def on_touch_move(self, x: int, y: int) -> None: if self.state is _DISABLED: return if in_area(self.area, x, y): @@ -193,7 +221,7 @@ class Button(ui.Control): self.repaint = True self.on_press_end() - def on_touch_end(self, x, y): + def on_touch_end(self, x: int, y: int) -> None: state = self.state if state is not _INITIAL and state is not _DISABLED: self.state = _INITIAL @@ -203,11 +231,11 @@ class Button(ui.Control): self.on_press_end() self.on_click() - def on_press_start(self): + def on_press_start(self) -> None: pass - def on_press_end(self): + def on_press_end(self) -> None: pass - def on_click(self): + def on_click(self) -> None: pass diff --git a/core/src/trezor/ui/checklist.py b/core/src/trezor/ui/checklist.py index 789223897..ba505b48b 100644 --- a/core/src/trezor/ui/checklist.py +++ b/core/src/trezor/ui/checklist.py @@ -3,32 +3,37 @@ from micropython import const from trezor import res, ui from trezor.ui.text import TEXT_HEADER_HEIGHT, TEXT_LINE_HEIGHT +if False: + from typing import Iterable, List, Union + + ChecklistItem = Union[str, Iterable[str]] + _CHECKLIST_MAX_LINES = const(5) _CHECKLIST_OFFSET_X = const(24) _CHECKLIST_OFFSET_X_ICON = const(0) class Checklist(ui.Control): - def __init__(self, title, icon): + def __init__(self, title: str, icon: str) -> None: self.title = title self.icon = icon - self.items = [] + self.items = [] # type: List[ChecklistItem] self.active = 0 self.repaint = True - def add(self, choice): - self.items.append(choice) + def add(self, item: ChecklistItem) -> None: + self.items.append(item) - def select(self, active): + def select(self, active: int) -> None: self.active = active - def on_render(self): + def on_render(self) -> None: if self.repaint: ui.header(self.title, self.icon) self.render_items() self.repaint = False - def render_items(self): + def render_items(self) -> None: offset_x = _CHECKLIST_OFFSET_X offset_y = TEXT_HEADER_HEIGHT + TEXT_LINE_HEIGHT bg = ui.BG diff --git a/core/src/trezor/ui/confirm.py b/core/src/trezor/ui/confirm.py index b28d1438c..f5d611f58 100644 --- a/core/src/trezor/ui/confirm.py +++ b/core/src/trezor/ui/confirm.py @@ -2,6 +2,10 @@ from trezor import res, ui from trezor.ui.button import Button, ButtonCancel, ButtonConfirm from trezor.ui.loader import Loader, LoaderDefault +if False: + from trezor.ui.button import ButtonContent, ButtonStyleType + from trezor.ui.loader import LoaderStyleType + CONFIRMED = object() CANCELLED = object() @@ -14,13 +18,13 @@ class Confirm(ui.Layout): def __init__( self, - content, - confirm=DEFAULT_CONFIRM, - confirm_style=DEFAULT_CONFIRM_STYLE, - cancel=DEFAULT_CANCEL, - cancel_style=DEFAULT_CANCEL_STYLE, - major_confirm=False, - ): + content: ui.Control, + confirm: ButtonContent = DEFAULT_CONFIRM, + confirm_style: ButtonStyleType = DEFAULT_CONFIRM_STYLE, + cancel: ButtonContent = DEFAULT_CANCEL, + cancel_style: ButtonStyleType = DEFAULT_CANCEL_STYLE, + major_confirm: bool = False, + ) -> None: self.content = content if confirm is not None: @@ -31,7 +35,7 @@ class Confirm(ui.Layout): else: area = ui.grid(9, n_x=2) self.confirm = Button(area, confirm, confirm_style) - self.confirm.on_click = self.on_confirm + self.confirm.on_click = self.on_confirm # type: ignore else: self.confirm = None @@ -43,21 +47,21 @@ class Confirm(ui.Layout): else: area = ui.grid(8, n_x=2) self.cancel = Button(area, cancel, cancel_style) - self.cancel.on_click = self.on_cancel + self.cancel.on_click = self.on_cancel # type: ignore else: self.cancel = None - def dispatch(self, event, x, y): + def dispatch(self, event: int, x: int, y: int) -> None: self.content.dispatch(event, x, y) if self.confirm is not None: self.confirm.dispatch(event, x, y) if self.cancel is not None: self.cancel.dispatch(event, x, y) - def on_confirm(self): + def on_confirm(self) -> None: raise ui.Result(CONFIRMED) - def on_cancel(self): + def on_cancel(self) -> None: raise ui.Result(CANCELLED) @@ -68,44 +72,44 @@ class HoldToConfirm(ui.Layout): def __init__( self, - content, - confirm=DEFAULT_CONFIRM, - confirm_style=DEFAULT_CONFIRM_STYLE, - loader_style=DEFAULT_LOADER_STYLE, + content: ui.Control, + confirm: str = DEFAULT_CONFIRM, + confirm_style: ButtonStyleType = DEFAULT_CONFIRM_STYLE, + loader_style: LoaderStyleType = DEFAULT_LOADER_STYLE, ): self.content = content self.loader = Loader(loader_style) - self.loader.on_start = self._on_loader_start + self.loader.on_start = self._on_loader_start # type: ignore self.button = Button(ui.grid(4, n_x=1), confirm, confirm_style) - self.button.on_press_start = self._on_press_start - self.button.on_press_end = self._on_press_end - self.button.on_click = self._on_click + self.button.on_press_start = self._on_press_start # type: ignore + self.button.on_press_end = self._on_press_end # type: ignore + self.button.on_click = self._on_click # type: ignore - def _on_press_start(self): + def _on_press_start(self) -> None: self.loader.start() - def _on_press_end(self): + def _on_press_end(self) -> None: self.loader.stop() - def _on_loader_start(self): + def _on_loader_start(self) -> None: # Loader has either started growing, or returned to the 0-position. # In the first case we need to clear the content leftovers, in the latter # we need to render the content again. ui.display.bar(0, 0, ui.WIDTH, ui.HEIGHT - 60, ui.BG) self.content.dispatch(ui.REPAINT, 0, 0) - def _on_click(self): + def _on_click(self) -> None: if self.loader.elapsed_ms() >= self.loader.target_ms: self.on_confirm() - def dispatch(self, event, x, y): + def dispatch(self, event: int, x: int, y: int) -> None: if self.loader.start_ms is not None: self.loader.dispatch(event, x, y) else: self.content.dispatch(event, x, y) self.button.dispatch(event, x, y) - def on_confirm(self): + def on_confirm(self) -> None: raise ui.Result(CONFIRMED) diff --git a/core/src/trezor/ui/container.py b/core/src/trezor/ui/container.py index 98da6bdfd..476a7f5c9 100644 --- a/core/src/trezor/ui/container.py +++ b/core/src/trezor/ui/container.py @@ -2,9 +2,9 @@ from trezor import ui class Container(ui.Control): - def __init__(self, *children): + def __init__(self, *children: ui.Control): self.children = children - def dispatch(self, event, x, y): + def dispatch(self, event: int, x: int, y: int) -> None: for child in self.children: child.dispatch(event, x, y) diff --git a/core/src/trezor/ui/info.py b/core/src/trezor/ui/info.py index 07fb77f31..b47021d2b 100644 --- a/core/src/trezor/ui/info.py +++ b/core/src/trezor/ui/info.py @@ -3,6 +3,10 @@ from trezor.ui.button import Button, ButtonConfirm from trezor.ui.confirm import CONFIRMED from trezor.ui.text import TEXT_LINE_HEIGHT, TEXT_MARGIN_LEFT, render_text +if False: + from typing import Type + from trezor.ui.button import ButtonContent + class DefaultInfoConfirm: @@ -17,26 +21,35 @@ class DefaultInfoConfirm: border_color = ui.BLACKISH +if False: + InfoConfirmStyleType = Type[DefaultInfoConfirm] + + class InfoConfirm(ui.Layout): DEFAULT_CONFIRM = res.load(ui.ICON_CONFIRM) DEFAULT_STYLE = DefaultInfoConfirm - def __init__(self, text, confirm=DEFAULT_CONFIRM, style=DEFAULT_STYLE): + def __init__( + self, + text: str, + confirm: ButtonContent = DEFAULT_CONFIRM, + style: InfoConfirmStyleType = DEFAULT_STYLE, + ) -> None: self.text = text.split() self.style = style panel_area = ui.grid(0, n_x=1, n_y=1) self.panel_area = panel_area confirm_area = ui.grid(4, n_x=1) self.confirm = Button(confirm_area, confirm, style.button) - self.confirm.on_click = self.on_confirm + self.confirm.on_click = self.on_confirm # type: ignore self.repaint = True - def dispatch(self, event, x, y): + def dispatch(self, event: int, x: int, y: int) -> None: if event == ui.RENDER: self.on_render() self.confirm.dispatch(event, x, y) - def on_render(self): + def on_render(self) -> None: if self.repaint: x, y, w, h = self.panel_area fg_color = self.style.fg_color @@ -46,7 +59,7 @@ class InfoConfirm(ui.Layout): ui.display.bar_radius(x, y, w, h, bg_color, ui.BG, ui.RADIUS) # render the info text - render_text( + render_text( # type: ignore self.text, new_lines=False, max_lines=6, @@ -59,5 +72,5 @@ class InfoConfirm(ui.Layout): self.repaint = False - def on_confirm(self): + def on_confirm(self) -> None: raise ui.Result(CONFIRMED) diff --git a/core/src/trezor/ui/loader.py b/core/src/trezor/ui/loader.py index fd992fbb5..03786ea8f 100644 --- a/core/src/trezor/ui/loader.py +++ b/core/src/trezor/ui/loader.py @@ -4,22 +4,25 @@ from micropython import const from trezor import res, ui from trezor.ui import display +if False: + from typing import Optional, Type + class LoaderDefault: class normal: bg_color = ui.BG fg_color = ui.GREEN - icon = None - icon_fg_color = None + icon = None # type: Optional[str] + icon_fg_color = None # type: Optional[int] - class active: + class active(normal): bg_color = ui.BG fg_color = ui.GREEN icon = ui.ICON_CHECK icon_fg_color = ui.WHITE -class LoaderDanger: +class LoaderDanger(LoaderDefault): class normal(LoaderDefault.normal): fg_color = ui.RED @@ -27,31 +30,35 @@ class LoaderDanger: fg_color = ui.RED +if False: + LoaderStyleType = Type[LoaderDefault] + + _TARGET_MS = const(1000) class Loader(ui.Control): - def __init__(self, style=LoaderDefault): + def __init__(self, style: LoaderStyleType = LoaderDefault) -> None: self.normal_style = style.normal self.active_style = style.active self.target_ms = _TARGET_MS self.start_ms = None self.stop_ms = None - def start(self): + def start(self) -> None: self.start_ms = utime.ticks_ms() self.stop_ms = None self.on_start() - def stop(self): + def stop(self) -> None: self.stop_ms = utime.ticks_ms() - def elapsed_ms(self): + def elapsed_ms(self) -> int: if self.start_ms is None: return 0 return utime.ticks_ms() - self.start_ms - def on_render(self): + def on_render(self) -> None: target = self.target_ms start = self.start_ms stop = self.stop_ms @@ -60,10 +67,10 @@ class Loader(ui.Control): r = min(now - start, target) else: r = max(stop - start + (stop - now) * 2, 0) - if r == target: - s = self.active_style - else: + if r != target: s = self.normal_style + else: + s = self.active_style Y = const(-24) @@ -80,7 +87,7 @@ class Loader(ui.Control): if r == target: self.on_finish() - def on_start(self): + def on_start(self) -> None: pass def on_finish(self): diff --git a/core/src/trezor/ui/mnemonic_bip39.py b/core/src/trezor/ui/mnemonic_bip39.py index 1c0186a54..2f67c45cf 100644 --- a/core/src/trezor/ui/mnemonic_bip39.py +++ b/core/src/trezor/ui/mnemonic_bip39.py @@ -3,6 +3,10 @@ from trezor.crypto import bip39 from trezor.ui import display from trezor.ui.button import Button, ButtonClear, ButtonMono, ButtonMonoConfirm +if False: + from typing import Optional + from trezor.ui.button import ButtonContent, ButtonStyleStateType + def compute_mask(text: str) -> int: mask = 0 @@ -15,44 +19,47 @@ def compute_mask(text: str) -> int: class KeyButton(Button): - def __init__(self, area, content, keyboard): + def __init__( + self, area: ui.Area, content: ButtonContent, keyboard: "Bip39Keyboard" + ): self.keyboard = keyboard super().__init__(area, content) - def on_click(self): + def on_click(self) -> None: self.keyboard.on_key_click(self) class InputButton(Button): - def __init__(self, area, content, word): - super().__init__(area, content) + def __init__(self, area: ui.Area, text: str, word: str) -> None: + super().__init__(area, text) self.word = word - self.pending = False # should we draw the pending marker? - self.icon = None # rendered icon + self.pending = False self.disable() - def edit(self, content, word, pending): + def edit(self, text: str, word: str, pending: bool) -> None: self.word = word - self.content = content + self.text = text self.pending = pending self.repaint = True if word: - if content == word: # confirm button + if text == word: # confirm button self.enable() self.normal_style = ButtonMonoConfirm.normal self.active_style = ButtonMonoConfirm.active - self.icon = ui.ICON_CONFIRM + self.icon = res.load(ui.ICON_CONFIRM) else: # auto-complete button self.enable() self.normal_style = ButtonMono.normal self.active_style = ButtonMono.active - self.icon = ui.ICON_CLICK + self.icon = res.load(ui.ICON_CLICK) else: # disabled button self.disabled_style = ButtonMono.disabled self.disable() - self.icon = None + self.icon = b"" - def render_content(self, s, ax, ay, aw, ah): + def render_content( + self, s: ButtonStyleStateType, ax: int, ay: int, aw: int, ah: int + ) -> None: text_style = s.text_style fg_color = s.fg_color bg_color = s.bg_color @@ -61,29 +68,29 @@ class InputButton(Button): ty = ay + ah // 2 + 8 # y-offset of the content # entered content - display.text(tx, ty, self.content, text_style, fg_color, bg_color) + display.text(tx, ty, self.text, text_style, fg_color, bg_color) # word suggestion - suggested_word = self.word[len(self.content) :] - width = display.text_width(self.content, text_style) + suggested_word = self.word[len(self.text) :] + width = display.text_width(self.text, text_style) display.text(tx + width, ty, suggested_word, text_style, ui.GREY, bg_color) if self.pending: - pw = display.text_width(self.content[-1:], text_style) + pw = display.text_width(self.text[-1:], text_style) px = tx + width - pw display.bar(px, ty + 2, pw + 1, 3, fg_color) if self.icon: ix = ax + aw - 16 * 2 iy = ty - 16 - display.icon(ix, iy, res.load(self.icon), fg_color, bg_color) + display.icon(ix, iy, self.icon, fg_color, bg_color) class Prompt(ui.Control): - def __init__(self, prompt): + def __init__(self, prompt: str) -> None: self.prompt = prompt self.repaint = True - def on_render(self): + def on_render(self) -> None: if self.repaint: display.bar(0, 8, ui.WIDTH, 60, ui.BG) display.text(20, 40, self.prompt, ui.BOLD, ui.GREY, ui.BG) @@ -91,15 +98,15 @@ class Prompt(ui.Control): class Bip39Keyboard(ui.Layout): - def __init__(self, prompt): + def __init__(self, prompt: str) -> None: self.prompt = Prompt(prompt) icon_back = res.load(ui.ICON_BACK) self.back = Button(ui.grid(0, n_x=3, n_y=4), icon_back, ButtonClear) - self.back.on_click = self.on_back_click + self.back.on_click = self.on_back_click # type: ignore self.input = InputButton(ui.grid(1, n_x=3, n_y=4, cells_x=2), "", "") - self.input.on_click = self.on_input_click + self.input.on_click = self.on_input_click # type: ignore self.keys = [ KeyButton(ui.grid(i + 3, n_y=4), k, self) @@ -107,82 +114,82 @@ class Bip39Keyboard(ui.Layout): ("abc", "def", "ghi", "jkl", "mno", "pqr", "stu", "vwx", "yz") ) ] - self.pending_button = None + self.pending_button = None # type: Optional[Button] self.pending_index = 0 - def dispatch(self, event: int, x: int, y: int): + def dispatch(self, event: int, x: int, y: int) -> None: for btn in self.keys: btn.dispatch(event, x, y) - if self.input.content: + if self.input.text: self.input.dispatch(event, x, y) self.back.dispatch(event, x, y) else: self.prompt.dispatch(event, x, y) - def on_back_click(self): + def on_back_click(self) -> None: # Backspace was clicked, let's delete the last character of input. - self.edit(self.input.content[:-1]) + self.edit(self.input.text[:-1]) - def on_input_click(self): + def on_input_click(self) -> None: # Input button was clicked. If the content matches the suggested word, # let's confirm it, otherwise just auto-complete. - content = self.input.content + text = self.input.text word = self.input.word - if word and word == content: + if word and word == text: self.edit("") self.on_confirm(word) else: self.edit(word) - def on_key_click(self, btn: Button): + def on_key_click(self, btn: Button) -> None: # Key button was clicked. If this button is pending, let's cycle the # pending character in input. If not, let's just append the first # character. if self.pending_button is btn: - index = (self.pending_index + 1) % len(btn.content) - content = self.input.content[:-1] + btn.content[index] + index = (self.pending_index + 1) % len(btn.text) + text = self.input.text[:-1] + btn.text[index] else: index = 0 - content = self.input.content + btn.content[0] - self.edit(content, btn, index) + text = self.input.text + btn.text[0] + self.edit(text, btn, index) - def on_timeout(self): + def on_timeout(self) -> None: # Timeout occurred. If we can auto-complete current input, let's just # reset the pending marker. If not, input is invalid, let's backspace # the last character. if self.input.word: - self.edit(self.input.content) + self.edit(self.input.text) else: - self.edit(self.input.content[:-1]) + self.edit(self.input.text[:-1]) - def on_confirm(self, word): + def on_confirm(self, word: str) -> None: # Word was confirmed by the user. raise ui.Result(word) - def edit(self, content: str, button: KeyButton = None, index: int = 0): + def edit(self, text: str, button: Button = None, index: int = 0) -> None: self.pending_button = button self.pending_index = index # find the completions pending = button is not None - word = bip39.find_word(content) or "" - mask = bip39.complete_word(content) + word = bip39.find_word(text) or "" + mask = bip39.complete_word(text) # modify the input state - self.input.edit(content, word, pending) + self.input.edit(text, word, pending) # enable or disable key buttons for btn in self.keys: - if btn is button or compute_mask(btn.content) & mask: + if btn is button or compute_mask(btn.text) & mask: btn.enable() else: btn.disable() # invalidate the prompt if we display it next frame - if not self.input.content: + if not self.input.text: self.prompt.repaint = True - async def handle_input(self): + async def handle_input(self) -> None: touch = loop.wait(io.TOUCH) timeout = loop.sleep(1000 * 1000 * 1) spawn_touch = loop.spawn(touch) diff --git a/core/src/trezor/ui/mnemonic_slip39.py b/core/src/trezor/ui/mnemonic_slip39.py index 3c980f9fb..a3375e45b 100644 --- a/core/src/trezor/ui/mnemonic_slip39.py +++ b/core/src/trezor/ui/mnemonic_slip39.py @@ -3,44 +3,61 @@ from trezor.crypto import slip39 from trezor.ui import display from trezor.ui.button import Button, ButtonClear, ButtonMono, ButtonMonoConfirm +if False: + from typing import Optional + from trezor.ui.button import ButtonContent, ButtonStyleStateType + class KeyButton(Button): - def __init__(self, area, content, keyboard, index): + def __init__( + self, + area: ui.Area, + content: ButtonContent, + keyboard: "Slip39Keyboard", + index: int, + ): self.keyboard = keyboard self.index = index super().__init__(area, content) - def on_click(self): + def on_click(self) -> None: self.keyboard.on_key_click(self) class InputButton(Button): - def __init__(self, area, keyboard): + def __init__(self, area: ui.Area, keyboard: "Slip39Keyboard") -> None: super().__init__(area, "") self.word = "" - self.pending_button = None - self.pending_index = None - self.icon = None # rendered icon + self.pending_button = None # type: Optional[Button] + self.pending_index = None # type: Optional[int] self.keyboard = keyboard self.disable() - def edit(self, content, word, pending_button, pending_index): + def edit( + self, + text: str, + word: str, + pending_button: Optional[Button], + pending_index: Optional[int], + ) -> None: self.word = word - self.content = content + self.text = text self.pending_button = pending_button self.pending_index = pending_index self.repaint = True - if word: + if word: # confirm button self.enable() self.normal_style = ButtonMonoConfirm.normal self.active_style = ButtonMonoConfirm.active - self.icon = ui.ICON_CONFIRM + self.icon = res.load(ui.ICON_CONFIRM) else: # disabled button - self.disabled_style = ButtonMono.normal + self.disabled_style = ButtonMono.disabled self.disable() - self.icon = None + self.icon = b"" - def render_content(self, s, ax, ay, aw, ah): + def render_content( + self, s: ButtonStyleStateType, ax: int, ay: int, aw: int, ah: int + ) -> None: text_style = s.text_style fg_color = s.fg_color bg_color = s.bg_color @@ -49,11 +66,11 @@ class InputButton(Button): ty = ay + ah // 2 + 8 # y-offset of the content if not self.keyboard.is_input_final(): - to_display = len(self.content) * "*" - if self.pending_button: - to_display = ( - to_display[:-1] + self.pending_button.content[self.pending_index] - ) + pending_button = self.pending_button + pending_index = self.pending_index + to_display = len(self.text) * "*" + if pending_button and pending_index is not None: + to_display = to_display[:-1] + pending_button.text[pending_index] else: to_display = self.word @@ -61,22 +78,22 @@ class InputButton(Button): if self.pending_button and not self.keyboard.is_input_final(): width = display.text_width(to_display, text_style) - pw = display.text_width(self.content[-1:], text_style) + pw = display.text_width(self.text[-1:], text_style) px = tx + width - pw display.bar(px, ty + 2, pw + 1, 3, fg_color) if self.icon: ix = ax + aw - 16 * 2 iy = ty - 16 - display.icon(ix, iy, res.load(self.icon), fg_color, bg_color) + display.icon(ix, iy, self.icon, fg_color, bg_color) class Prompt(ui.Control): - def __init__(self, prompt): + def __init__(self, prompt: str) -> None: self.prompt = prompt self.repaint = True - def on_render(self): + def on_render(self) -> None: if self.repaint: display.bar(0, 8, ui.WIDTH, 60, ui.BG) display.text(20, 40, self.prompt, ui.BOLD, ui.GREY, ui.BG) @@ -84,15 +101,15 @@ class Prompt(ui.Control): class Slip39Keyboard(ui.Layout): - def __init__(self, prompt): + def __init__(self, prompt: str) -> None: self.prompt = Prompt(prompt) icon_back = res.load(ui.ICON_BACK) self.back = Button(ui.grid(0, n_x=3, n_y=4), icon_back, ButtonClear) - self.back.on_click = self.on_back_click + self.back.on_click = self.on_back_click # type: ignore self.input = InputButton(ui.grid(1, n_x=3, n_y=4, cells_x=2), self) - self.input.on_click = self.on_input_click + self.input.on_click = self.on_input_click # type: ignore self.keys = [ KeyButton(ui.grid(i + 3, n_y=4), k, self, i + 1) @@ -100,26 +117,26 @@ class Slip39Keyboard(ui.Layout): ("ab", "cd", "ef", "ghij", "klm", "nopq", "rs", "tuv", "wxyz") ) ] - self.pending_button = None + self.pending_button = None # type: Optional[Button] self.pending_index = 0 self.button_sequence = "" self.mask = slip39.KEYBOARD_FULL_MASK - def dispatch(self, event: int, x: int, y: int): + def dispatch(self, event: int, x: int, y: int) -> None: for btn in self.keys: btn.dispatch(event, x, y) - if self.input.content: + if self.input.text: self.input.dispatch(event, x, y) self.back.dispatch(event, x, y) else: self.prompt.dispatch(event, x, y) - def on_back_click(self): + def on_back_click(self) -> None: # Backspace was clicked, let's delete the last character of input. self.button_sequence = self.button_sequence[:-1] self.edit() - def on_input_click(self): + def on_input_click(self) -> None: # Input button was clicked. If the content matches the suggested word, # let's confirm it, otherwise just auto-complete. result = self.input.word @@ -128,26 +145,26 @@ class Slip39Keyboard(ui.Layout): self.edit() self.on_confirm(result) - def on_key_click(self, btn: KeyButton): + def on_key_click(self, btn: KeyButton) -> None: # Key button was clicked. If this button is pending, let's cycle the # pending character in input. If not, let's just append the first # character. if self.pending_button is btn: - index = (self.pending_index + 1) % len(btn.content) + index = (self.pending_index + 1) % len(btn.text) else: index = 0 self.button_sequence += str(btn.index) self.edit(btn, index) - def on_timeout(self): + def on_timeout(self) -> None: # Timeout occurred. Let's redraw to draw asterisks. self.edit() - def on_confirm(self, word): + def on_confirm(self, word: str) -> None: # Word was confirmed by the user. raise ui.Result(word) - def edit(self, button: KeyButton = None, index: int = 0): + def edit(self, button: Button = None, index: int = 0) -> None: self.pending_button = button self.pending_index = index @@ -172,7 +189,7 @@ class Slip39Keyboard(ui.Layout): btn.disable() # invalidate the prompt if we display it next frame - if not self.input.content: + if not self.input.text: self.prompt.repaint = True def is_input_final(self) -> bool: @@ -182,7 +199,7 @@ class Slip39Keyboard(ui.Layout): def check_mask(self, index: int) -> bool: return bool((1 << (index - 1)) & self.mask) - async def handle_input(self): + async def handle_input(self) -> None: touch = loop.wait(io.TOUCH) timeout = loop.sleep(1000 * 1000 * 1) spawn_touch = loop.spawn(touch) diff --git a/core/src/trezor/ui/passphrase.py b/core/src/trezor/ui/passphrase.py index b4b308763..0544f973f 100644 --- a/core/src/trezor/ui/passphrase.py +++ b/core/src/trezor/ui/passphrase.py @@ -6,6 +6,10 @@ from trezor.ui import display from trezor.ui.button import Button, ButtonClear, ButtonConfirm from trezor.ui.swipe import SWIPE_HORIZONTAL, SWIPE_LEFT, Swipe +if False: + from typing import List, Iterable, Optional + from trezor.ui.button import ButtonContent, ButtonStyleStateType + SPACE = res.load(ui.ICON_SPACE) KEYBOARD_KEYS = ( @@ -16,13 +20,13 @@ KEYBOARD_KEYS = ( ) -def digit_area(i): +def digit_area(i: int) -> ui.Area: if i == 9: # 0-position i = 10 # display it in the middle return ui.grid(i + 3) # skip the first line -def render_scrollbar(page): +def render_scrollbar(page: int) -> None: BBOX = const(240) SIZE = const(8) pages = len(KEYBOARD_KEYS) @@ -43,42 +47,50 @@ def render_scrollbar(page): class KeyButton(Button): - def __init__(self, area, content, keyboard): + def __init__( + self, area: ui.Area, content: ButtonContent, keyboard: "PassphraseKeyboard" + ) -> None: self.keyboard = keyboard super().__init__(area, content) - def on_click(self): + def on_click(self) -> None: self.keyboard.on_key_click(self) - def get_text_content(self): - if self.content is SPACE: + def get_text_content(self) -> str: + if self.text: + return self.text + elif self.icon is SPACE: return " " else: - return self.content + raise TypeError -def key_buttons(keys, keyboard): +def key_buttons( + keys: Iterable[ButtonContent], keyboard: "PassphraseKeyboard" +) -> List[KeyButton]: return [KeyButton(digit_area(i), k, keyboard) for i, k in enumerate(keys)] class Input(Button): - def __init__(self, area, content): - super().__init__(area, content) + def __init__(self, area: ui.Area, text: str) -> None: + super().__init__(area, text) self.pending = False self.disable() - def edit(self, content, pending): - self.content = content + def edit(self, text: str, pending: bool) -> None: + self.text = text self.pending = pending self.repaint = True - def render_content(self, s, ax, ay, aw, ah): + def render_content( + self, s: ButtonStyleStateType, ax: int, ay: int, aw: int, ah: int + ) -> None: text_style = s.text_style fg_color = s.fg_color bg_color = s.bg_color p = self.pending # should we draw the pending marker? - t = self.content # input content + t = self.text # input content tx = ax + 24 # x-offset of the content ty = ay + ah // 2 + 8 # y-offset of the content @@ -98,16 +110,16 @@ class Input(Button): cx = tx + width + 1 display.bar(cx, ty - 18, 2, 22, fg_color) - def on_click(self): + def on_click(self) -> None: pass class Prompt(ui.Control): - def __init__(self, text): + def __init__(self, text: str) -> None: self.text = text self.repaint = True - def on_render(self): + def on_render(self) -> None: if self.repaint: display.bar(0, 0, ui.WIDTH, 48, ui.BG) display.text_center(ui.WIDTH // 2, 32, self.text, ui.BOLD, ui.GREY, ui.BG) @@ -118,7 +130,7 @@ CANCELLED = object() class PassphraseKeyboard(ui.Layout): - def __init__(self, prompt, max_length, page=1): + def __init__(self, prompt: str, max_length: int, page: int = 1) -> None: self.prompt = Prompt(prompt) self.max_length = max_length self.page = page @@ -126,18 +138,18 @@ class PassphraseKeyboard(ui.Layout): self.input = Input(ui.grid(0, n_x=1, n_y=6), "") self.back = Button(ui.grid(12), res.load(ui.ICON_BACK), ButtonClear) - self.back.on_click = self.on_back_click + self.back.on_click = self.on_back_click # type: ignore self.back.disable() self.done = Button(ui.grid(14), res.load(ui.ICON_CONFIRM), ButtonConfirm) - self.done.on_click = self.on_confirm + self.done.on_click = self.on_confirm # type: ignore self.keys = key_buttons(KEYBOARD_KEYS[self.page], self) - self.pending_button = None + self.pending_button = None # type: Optional[KeyButton] self.pending_index = 0 - def dispatch(self, event, x, y): - if self.input.content: + def dispatch(self, event: int, x: int, y: int) -> None: + if self.input.text: self.input.dispatch(event, x, y) else: self.prompt.dispatch(event, x, y) @@ -149,37 +161,37 @@ class PassphraseKeyboard(ui.Layout): if event == ui.RENDER: render_scrollbar(self.page) - def on_back_click(self): + def on_back_click(self) -> None: # Backspace was clicked. If we have any content in the input, let's delete # the last character. Otherwise cancel. - content = self.input.content - if content: - self.edit(content[:-1]) + text = self.input.text + if text: + self.edit(text[:-1]) else: self.on_cancel() - def on_key_click(self, button: KeyButton): + def on_key_click(self, button: KeyButton) -> None: # Key button was clicked. If this button is pending, let's cycle the # pending character in input. If not, let's just append the first # character. button_text = button.get_text_content() if self.pending_button is button: index = (self.pending_index + 1) % len(button_text) - prefix = self.input.content[:-1] + prefix = self.input.text[:-1] else: index = 0 - prefix = self.input.content + prefix = self.input.text if len(button_text) > 1: self.edit(prefix + button_text[index], button, index) else: self.edit(prefix + button_text[index]) - def on_timeout(self): + def on_timeout(self) -> None: # Timeout occurred, let's just reset the pending marker. - self.edit(self.input.content) + self.edit(self.input.text) - def edit(self, content: str, button: Button = None, index: int = 0): - if len(content) > self.max_length: + def edit(self, text: str, button: KeyButton = None, index: int = 0) -> None: + if len(text) > self.max_length: return self.pending_button = button @@ -187,15 +199,15 @@ class PassphraseKeyboard(ui.Layout): # modify the input state pending = button is not None - self.input.edit(content, pending) + self.input.edit(text, pending) - if content: + if text: self.back.enable() else: self.back.disable() self.prompt.repaint = True - async def handle_input(self): + async def handle_input(self) -> None: touch = loop.wait(io.TOUCH) timeout = loop.sleep(1000 * 1000 * 1) spawn_touch = loop.spawn(touch) @@ -214,7 +226,7 @@ class PassphraseKeyboard(ui.Layout): else: self.on_timeout() - async def handle_paging(self): + async def handle_paging(self) -> None: swipe = await Swipe(SWIPE_HORIZONTAL) if swipe == SWIPE_LEFT: self.page = (self.page + 1) % len(KEYBOARD_KEYS) @@ -226,33 +238,33 @@ class PassphraseKeyboard(ui.Layout): self.input.repaint = True self.prompt.repaint = True - def on_cancel(self): + def on_cancel(self) -> None: raise ui.Result(CANCELLED) - def on_confirm(self): - raise ui.Result(self.input.content) + def on_confirm(self) -> None: + raise ui.Result(self.input.text) - def create_tasks(self): + def create_tasks(self) -> Iterable[loop.Task]: return self.handle_input(), self.handle_rendering(), self.handle_paging() class PassphraseSource(ui.Layout): - def __init__(self, content): + def __init__(self, content: ui.Control) -> None: self.content = content self.device = Button(ui.grid(8, n_y=4, n_x=4, cells_x=4), "Device") - self.device.on_click = self.on_device + self.device.on_click = self.on_device # type: ignore self.host = Button(ui.grid(12, n_y=4, n_x=4, cells_x=4), "Host") - self.host.on_click = self.on_host + self.host.on_click = self.on_host # type: ignore - def dispatch(self, event, x, y): + def dispatch(self, event: int, x: int, y: int) -> None: self.content.dispatch(event, x, y) self.device.dispatch(event, x, y) self.host.dispatch(event, x, y) - def on_device(self): + def on_device(self) -> None: raise ui.Result(PassphraseSourceType.DEVICE) - def on_host(self): + def on_host(self) -> None: raise ui.Result(PassphraseSourceType.HOST) diff --git a/core/src/trezor/ui/pin.py b/core/src/trezor/ui/pin.py index ac4b472d3..c66c2e2d8 100644 --- a/core/src/trezor/ui/pin.py +++ b/core/src/trezor/ui/pin.py @@ -11,14 +11,17 @@ from trezor.ui.button import ( ButtonMono, ) +if False: + from typing import Iterable -def digit_area(i): + +def digit_area(i: int) -> ui.Area: if i == 9: # 0-position i = 10 # display it in the middle return ui.grid(i + 3) # skip the first line -def generate_digits(): +def generate_digits() -> Iterable[int]: digits = list(range(0, 10)) # 0-9 random.shuffle(digits) # We lay out the buttons top-left to bottom-right, but the order @@ -27,13 +30,13 @@ def generate_digits(): class PinInput(ui.Control): - def __init__(self, prompt, subprompt, pin): + def __init__(self, prompt: str, subprompt: str, pin: str) -> None: self.prompt = prompt self.subprompt = subprompt self.pin = pin self.repaint = True - def on_render(self): + def on_render(self) -> None: if self.repaint: if self.pin: self.render_pin() @@ -41,7 +44,7 @@ class PinInput(ui.Control): self.render_prompt() self.repaint = False - def render_pin(self): + def render_pin(self) -> None: display.bar(0, 0, ui.WIDTH, 50, ui.BG) count = len(self.pin) BOX_WIDTH = const(240) @@ -54,7 +57,7 @@ class PinInput(ui.Control): render_x + i * PADDING, RENDER_Y, DOT_SIZE, DOT_SIZE, ui.GREY, ui.BG, 4 ) - def render_prompt(self): + def render_prompt(self) -> None: display.bar(0, 0, ui.WIDTH, 50, ui.BG) if self.subprompt: display.text_center(ui.WIDTH // 2, 20, self.prompt, ui.BOLD, ui.GREY, ui.BG) @@ -66,35 +69,37 @@ class PinInput(ui.Control): class PinButton(Button): - def __init__(self, index, digit, matrix): - self.matrix = matrix + def __init__(self, index: int, digit: int, dialog: "PinDialog"): + self.dialog = dialog super().__init__(digit_area(index), str(digit), ButtonMono) - def on_click(self): - self.matrix.assign(self.matrix.input.pin + self.content) + def on_click(self) -> None: + self.dialog.assign(self.dialog.input.pin + self.text) CANCELLED = object() class PinDialog(ui.Layout): - def __init__(self, prompt, subprompt, allow_cancel=True, maxlength=9): + def __init__( + self, prompt: str, subprompt: str, allow_cancel: bool = True, maxlength: int = 9 + ) -> None: self.maxlength = maxlength self.input = PinInput(prompt, subprompt, "") icon_confirm = res.load(ui.ICON_CONFIRM) self.confirm_button = Button(ui.grid(14), icon_confirm, ButtonConfirm) - self.confirm_button.on_click = self.on_confirm + self.confirm_button.on_click = self.on_confirm # type: ignore self.confirm_button.disable() icon_back = res.load(ui.ICON_BACK) self.reset_button = Button(ui.grid(12), icon_back, ButtonClear) - self.reset_button.on_click = self.on_reset + self.reset_button.on_click = self.on_reset # type: ignore if allow_cancel: icon_lock = res.load(ui.ICON_LOCK) self.cancel_button = Button(ui.grid(12), icon_lock, ButtonCancel) - self.cancel_button.on_click = self.on_cancel + self.cancel_button.on_click = self.on_cancel # type: ignore else: self.cancel_button = Button(ui.grid(12), "") self.cancel_button.disable() @@ -103,7 +108,7 @@ class PinDialog(ui.Layout): PinButton(i, d, self) for i, d in enumerate(generate_digits()) ] - def dispatch(self, event, x, y): + def dispatch(self, event: int, x: int, y: int) -> None: self.input.dispatch(event, x, y) if self.input.pin: self.reset_button.dispatch(event, x, y) @@ -113,7 +118,7 @@ class PinDialog(ui.Layout): for btn in self.pin_buttons: btn.dispatch(event, x, y) - def assign(self, pin): + def assign(self, pin: str) -> None: if len(pin) > self.maxlength: return for btn in self.pin_buttons: @@ -132,12 +137,12 @@ class PinDialog(ui.Layout): self.input.pin = pin self.input.repaint = True - def on_reset(self): + def on_reset(self) -> None: self.assign("") - def on_cancel(self): + def on_cancel(self) -> None: raise ui.Result(CANCELLED) - def on_confirm(self): + def on_confirm(self) -> None: if self.input.pin: raise ui.Result(self.input.pin) diff --git a/core/src/trezor/ui/popup.py b/core/src/trezor/ui/popup.py index fb6671784..f1ed196a6 100644 --- a/core/src/trezor/ui/popup.py +++ b/core/src/trezor/ui/popup.py @@ -1,17 +1,20 @@ from trezor import loop, ui +if False: + from typing import Iterable + class Popup(ui.Layout): - def __init__(self, content, time_ms=0): + def __init__(self, content: ui.Control, time_ms: int = 0) -> None: self.content = content self.time_ms = time_ms - def dispatch(self, event, x, y): + def dispatch(self, event: int, x: int, y: int) -> None: self.content.dispatch(event, x, y) - def create_tasks(self): + def create_tasks(self) -> Iterable[loop.Task]: return self.handle_input(), self.handle_rendering(), self.handle_timeout() - def handle_timeout(self): + def handle_timeout(self) -> loop.Task: # type: ignore yield loop.sleep(self.time_ms * 1000) raise ui.Result(None) diff --git a/core/src/trezor/ui/qr.py b/core/src/trezor/ui/qr.py index 62527915a..bde3de18e 100644 --- a/core/src/trezor/ui/qr.py +++ b/core/src/trezor/ui/qr.py @@ -2,11 +2,11 @@ from trezor import ui class Qr(ui.Control): - def __init__(self, data, x, y, scale): + def __init__(self, data: bytes, x: int, y: int, scale: int): self.data = data self.x = x self.y = y self.scale = scale - def on_render(self): + def on_render(self) -> None: ui.display.qrcode(self.x, self.y, self.data, self.scale) diff --git a/core/src/trezor/ui/scroll.py b/core/src/trezor/ui/scroll.py index dcfdb6835..e370211f5 100644 --- a/core/src/trezor/ui/scroll.py +++ b/core/src/trezor/ui/scroll.py @@ -8,8 +8,11 @@ from trezor.ui.swipe import SWIPE_DOWN, SWIPE_UP, SWIPE_VERTICAL, Swipe if __debug__: from apps.debug import swipe_signal +if False: + from typing import Iterable, Sequence -def render_scrollbar(pages: int, page: int): + +def render_scrollbar(pages: int, page: int) -> None: BBOX = const(220) SIZE = const(8) @@ -28,7 +31,7 @@ def render_scrollbar(pages: int, page: int): ui.display.bar_radius(X, Y + i * padding, SIZE, SIZE, fg, ui.BG, 4) -def render_swipe_icon(): +def render_swipe_icon() -> None: DRAW_DELAY = const(200000) icon = res.load(ui.ICON_SWIPE) @@ -37,18 +40,20 @@ def render_swipe_icon(): ui.display.icon(70, 205, icon, c, ui.BG) -def render_swipe_text(): +def render_swipe_text() -> None: ui.display.text_center(130, 220, "Swipe", ui.BOLD, ui.GREY, ui.BG) class Paginated(ui.Layout): - def __init__(self, pages, page=0, one_by_one=False): + def __init__( + self, pages: Sequence[ui.Control], page: int = 0, one_by_one: bool = False + ): self.pages = pages self.page = page self.one_by_one = one_by_one self.repaint = True - def dispatch(self, event, x, y): + def dispatch(self, event: int, x: int, y: int) -> None: pages = self.pages page = self.page pages[page].dispatch(event, x, y) @@ -63,7 +68,7 @@ class Paginated(ui.Layout): render_scrollbar(length, page) self.repaint = False - async def handle_paging(self): + async def handle_paging(self) -> None: if self.page == 0: directions = SWIPE_UP elif self.page == len(self.pages) - 1: @@ -86,21 +91,33 @@ class Paginated(ui.Layout): self.on_change() - def create_tasks(self): + def create_tasks(self) -> Iterable[loop.Task]: return self.handle_input(), self.handle_rendering(), self.handle_paging() - def on_change(self): + def on_change(self) -> None: if self.one_by_one: raise ui.Result(self.page) class PageWithButtons(ui.Control): - def __init__(self, content, paginated, index, count): + def __init__( + self, + content: ui.Control, + paginated: "PaginatedWithButtons", + index: int, + count: int, + ) -> None: self.content = content self.paginated = paginated self.index = index self.count = count + # somewhere in the middle, we can go up or down + left = res.load(ui.ICON_BACK) + left_style = ButtonDefault + right = res.load(ui.ICON_CLICK) + right_style = ButtonDefault + if self.index == 0: # first page, we can cancel or go down left = res.load(ui.ICON_CANCEL) @@ -113,31 +130,25 @@ class PageWithButtons(ui.Control): left_style = ButtonDefault right = res.load(ui.ICON_CONFIRM) right_style = ButtonConfirm - else: - # somewhere in the middle, we can go up or down - left = res.load(ui.ICON_BACK) - left_style = ButtonDefault - right = res.load(ui.ICON_CLICK) - right_style = ButtonDefault self.left = Button(ui.grid(8, n_x=2), left, left_style) - self.left.on_click = self.on_left + self.left.on_click = self.on_left # type: ignore self.right = Button(ui.grid(9, n_x=2), right, right_style) - self.right.on_click = self.on_right + self.right.on_click = self.on_right # type: ignore - def dispatch(self, event, x, y): + def dispatch(self, event: int, x: int, y: int) -> None: self.content.dispatch(event, x, y) self.left.dispatch(event, x, y) self.right.dispatch(event, x, y) - def on_left(self): + def on_left(self) -> None: if self.index == 0: self.paginated.on_cancel() else: self.paginated.on_up() - def on_right(self): + def on_right(self) -> None: if self.index == self.count - 1: self.paginated.on_confirm() else: @@ -145,36 +156,38 @@ class PageWithButtons(ui.Control): class PaginatedWithButtons(ui.Layout): - def __init__(self, pages, page=0, one_by_one=False): + def __init__( + self, pages: Sequence[ui.Control], page: int = 0, one_by_one: bool = False + ) -> None: self.pages = [ PageWithButtons(p, self, i, len(pages)) for i, p in enumerate(pages) ] self.page = page self.one_by_one = one_by_one - def dispatch(self, event, x, y): + def dispatch(self, event: int, x: int, y: int) -> None: pages = self.pages page = self.page pages[page].dispatch(event, x, y) if event is ui.RENDER: render_scrollbar(len(pages), page) - def on_up(self): + def on_up(self) -> None: self.page = max(self.page - 1, 0) self.pages[self.page].dispatch(ui.REPAINT, 0, 0) self.on_change() - def on_down(self): + def on_down(self) -> None: self.page = min(self.page + 1, len(self.pages) - 1) self.pages[self.page].dispatch(ui.REPAINT, 0, 0) self.on_change() - def on_confirm(self): + def on_confirm(self) -> None: raise ui.Result(CONFIRMED) - def on_cancel(self): + def on_cancel(self) -> None: raise ui.Result(CANCELLED) - def on_change(self): + def on_change(self) -> None: if self.one_by_one: raise ui.Result(self.page) diff --git a/core/src/trezor/ui/shamir.py b/core/src/trezor/ui/shamir.py index 5a7bc91ac..5f77639df 100644 --- a/core/src/trezor/ui/shamir.py +++ b/core/src/trezor/ui/shamir.py @@ -4,31 +4,31 @@ from trezor.ui.text import LABEL_CENTER, Label class NumInput(ui.Control): - def __init__(self, count=5, max_count=16, min_count=1): + def __init__(self, count: int = 5, max_count: int = 16, min_count: int = 1) -> None: self.count = count self.max_count = max_count self.min_count = min_count self.minus = Button(ui.grid(3), "-") - self.minus.on_click = self.on_minus + self.minus.on_click = self.on_minus # type: ignore self.plus = Button(ui.grid(5), "+") - self.plus.on_click = self.on_plus + self.plus.on_click = self.on_plus # type: ignore self.text = Label(ui.grid(4), "", LABEL_CENTER, ui.BOLD) self.edit(count) - def dispatch(self, event, x, y): + def dispatch(self, event: int, x: int, y: int) -> None: self.minus.dispatch(event, x, y) self.plus.dispatch(event, x, y) self.text.dispatch(event, x, y) - def on_minus(self): + def on_minus(self) -> None: self.edit(self.count - 1) - def on_plus(self): + def on_plus(self) -> None: self.edit(self.count + 1) - def edit(self, count): + def edit(self, count: int) -> None: count = max(count, self.min_count) count = min(count, self.max_count) if self.count != count: @@ -45,5 +45,5 @@ class NumInput(ui.Control): else: self.plus.enable() - def on_change(self, count): + def on_change(self, count: int) -> None: pass diff --git a/core/src/trezor/ui/swipe.py b/core/src/trezor/ui/swipe.py index 34fe9b82a..b5ecc2514 100644 --- a/core/src/trezor/ui/swipe.py +++ b/core/src/trezor/ui/swipe.py @@ -2,6 +2,9 @@ from micropython import const from trezor import io, loop, ui +if False: + from typing import Generator + SWIPE_UP = const(0x01) SWIPE_DOWN = const(0x02) SWIPE_LEFT = const(0x04) @@ -15,24 +18,26 @@ _SWIPE_TRESHOLD = const(30) class Swipe(ui.Control): - def __init__(self, directions=SWIPE_ALL, area=None): + def __init__(self, directions: int = SWIPE_ALL, area: ui.Area = None) -> None: if area is None: area = (0, 0, ui.WIDTH, ui.HEIGHT) self.area = area self.directions = directions - self.start_x = None - self.start_y = None - self.light_origin = None + self.started = False + self.start_x = 0 + self.start_y = 0 + self.light_origin = ui.BACKLIGHT_NORMAL self.light_target = ui.BACKLIGHT_NONE - def on_touch_start(self, x, y): + def on_touch_start(self, x: int, y: int) -> None: if ui.in_area(self.area, x, y): self.start_x = x self.start_y = y self.light_origin = ui.BACKLIGHT_NORMAL + self.started = True - def on_touch_move(self, x, y): - if self.start_x is None: + def on_touch_move(self, x: int, y: int) -> None: + if not self.started: return # not started in our area dirs = self.directions @@ -61,8 +66,8 @@ class Swipe(ui.Control): ) ) - def on_touch_end(self, x, y): - if self.start_x is None: + def on_touch_end(self, x: int, y: int) -> None: + if not self.started: return # not started in our area dirs = self.directions @@ -93,13 +98,15 @@ class Swipe(ui.Control): # no swipe detected, reset the state ui.display.backlight(self.light_origin) - self.start_x = None - self.start_y = None + self.started = False - def on_swipe(self, swipe): + def on_swipe(self, swipe: int) -> None: raise ui.Result(swipe) - def __iter__(self): + def __await__(self) -> Generator: + return self.__iter__() # type: ignore + + def __iter__(self) -> loop.Task: # type: ignore try: touch = loop.wait(io.TOUCH) while True: diff --git a/core/src/trezor/ui/text.py b/core/src/trezor/ui/text.py index a5cad1f95..86eeacac5 100644 --- a/core/src/trezor/ui/text.py +++ b/core/src/trezor/ui/text.py @@ -2,6 +2,9 @@ from micropython import const from trezor import ui +if False: + from typing import List, Union + TEXT_HEADER_HEIGHT = const(48) TEXT_LINE_HEIGHT = const(26) TEXT_LINE_HEIGHT_HALF = const(13) @@ -12,9 +15,12 @@ TEXT_MAX_LINES = const(5) BR = const(-256) BR_HALF = const(-257) +if False: + TextContent = Union[str, int] + def render_text( - words: list, + words: List[TextContent], new_lines: bool, max_lines: int, font: int = ui.NORMAL, @@ -128,32 +134,32 @@ class Text(ui.Control): self.icon_color = icon_color self.max_lines = max_lines self.new_lines = new_lines - self.content = [] + self.content = [] # type: List[Union[str, int]] self.repaint = True - def normal(self, *content): + def normal(self, *content: TextContent) -> None: self.content.append(ui.NORMAL) self.content.extend(content) - def bold(self, *content): + def bold(self, *content: TextContent) -> None: self.content.append(ui.BOLD) self.content.extend(content) - def mono(self, *content): + def mono(self, *content: TextContent) -> None: self.content.append(ui.MONO) self.content.extend(content) - def mono_bold(self, *content): + def mono_bold(self, *content: TextContent) -> None: self.content.append(ui.MONO_BOLD) self.content.extend(content) - def br(self): + def br(self) -> None: self.content.append(BR) - def br_half(self): + def br_half(self) -> None: self.content.append(BR_HALF) - def on_render(self): + def on_render(self) -> None: if self.repaint: ui.header( self.header_text, @@ -172,21 +178,27 @@ LABEL_RIGHT = const(2) class Label(ui.Control): - def __init__(self, area, content, align=LABEL_LEFT, style=ui.NORMAL): + def __init__( + self, + area: ui.Area, + content: str, + align: int = LABEL_LEFT, + style: int = ui.NORMAL, + ) -> None: self.area = area self.content = content self.align = align self.style = style self.repaint = True - def on_render(self): + def on_render(self) -> None: if self.repaint: align = self.align ax, ay, aw, ah = self.area tx = ax + aw // 2 ty = ay + ah // 2 + 8 if align is LABEL_LEFT: - ui.display.text_left(tx, ty, self.content, self.style, ui.FG, ui.BG, aw) + ui.display.text(tx, ty, self.content, self.style, ui.FG, ui.BG, aw) elif align is LABEL_CENTER: ui.display.text_center( tx, ty, self.content, self.style, ui.FG, ui.BG, aw diff --git a/core/src/trezor/ui/word_select.py b/core/src/trezor/ui/word_select.py index 358c2f705..38fb3ede1 100644 --- a/core/src/trezor/ui/word_select.py +++ b/core/src/trezor/ui/word_select.py @@ -5,20 +5,20 @@ from trezor.ui.button import Button class WordSelector(ui.Layout): - def __init__(self, content): + def __init__(self, content: ui.Control) -> None: self.content = content self.w12 = Button(ui.grid(6, n_y=4), "12") - self.w12.on_click = self.on_w12 + self.w12.on_click = self.on_w12 # type: ignore self.w18 = Button(ui.grid(7, n_y=4), "18") - self.w18.on_click = self.on_w18 + self.w18.on_click = self.on_w18 # type: ignore self.w20 = Button(ui.grid(8, n_y=4), "20") - self.w20.on_click = self.on_w20 + self.w20.on_click = self.on_w20 # type: ignore self.w24 = Button(ui.grid(9, n_y=4), "24") - self.w24.on_click = self.on_w24 + self.w24.on_click = self.on_w24 # type: ignore self.w33 = Button(ui.grid(10, n_y=4), "33") - self.w33.on_click = self.on_w33 + self.w33.on_click = self.on_w33 # type: ignore - def dispatch(self, event, x, y): + def dispatch(self, event: int, x: int, y: int) -> None: self.content.dispatch(event, x, y) self.w12.dispatch(event, x, y) self.w18.dispatch(event, x, y) @@ -26,17 +26,17 @@ class WordSelector(ui.Layout): self.w24.dispatch(event, x, y) self.w33.dispatch(event, x, y) - def on_w12(self): + def on_w12(self) -> None: raise ui.Result(12) - def on_w18(self): + def on_w18(self) -> None: raise ui.Result(18) - def on_w20(self): + def on_w20(self) -> None: raise ui.Result(20) - def on_w24(self): + def on_w24(self) -> None: raise ui.Result(24) - def on_w33(self): + def on_w33(self) -> None: raise ui.Result(33) diff --git a/core/src/trezor/utils.py b/core/src/trezor/utils.py index 3250eb909..070c5b994 100644 --- a/core/src/trezor/utils.py +++ b/core/src/trezor/utils.py @@ -27,12 +27,15 @@ if __debug__: SAVE_SCREEN = 0 LOG_MEMORY = 0 +if False: + from typing import Iterable, Iterator, Protocol, List, TypeVar -def unimport_begin(): + +def unimport_begin() -> Iterable[str]: return set(sys.modules) -def unimport_end(mods): +def unimport_end(mods: Iterable[str]) -> None: for mod in sys.modules: if mod not in mods: # remove reference from sys.modules @@ -53,7 +56,7 @@ def unimport_end(mods): gc.collect() -def ensure(cond, msg=None): +def ensure(cond: bool, msg: str = None) -> None: if not cond: if msg is None: raise AssertionError @@ -61,48 +64,71 @@ def ensure(cond, msg=None): raise AssertionError(msg) -def chunks(items, size): +if False: + Chunked = TypeVar("Chunked") + + +def chunks(items: List[Chunked], size: int) -> Iterator[List[Chunked]]: for i in range(0, len(items), size): yield items[i : i + size] -def format_amount(amount, decimals): +def format_amount(amount: int, decimals: int) -> str: d = pow(10, decimals) - amount = ("%d.%0*d" % (amount // d, decimals, amount % d)).rstrip("0") - if amount.endswith("."): - amount = amount[:-1] - return amount + s = ("%d.%0*d" % (amount // d, decimals, amount % d)).rstrip("0").rstrip(".") + return s -def format_ordinal(number): +def format_ordinal(number: int) -> str: return str(number) + {1: "st", 2: "nd", 3: "rd"}.get( 4 if 10 <= number % 100 < 20 else number % 10, "th" ) +if False: + + class HashContext(Protocol): + def update(self, buf: bytes) -> None: + ... + + def digest(self) -> bytes: + ... + + class Writer(Protocol): + def append(self, b: int) -> None: + ... + + def extend(self, buf: bytes) -> None: + ... + + def write(self, buf: bytes) -> None: + ... + + class HashWriter: - def __init__(self, ctx): + def __init__(self, ctx: HashContext) -> None: self.ctx = ctx self.buf = bytearray(1) # used in append() - def extend(self, buf: bytearray): - self.ctx.update(buf) + def append(self, b: int) -> None: + self.buf[0] = b + self.ctx.update(self.buf) - def write(self, buf: bytearray): # alias for extend() + def extend(self, buf: bytes) -> None: self.ctx.update(buf) - async def awrite(self, buf: bytearray): # AsyncWriter interface - return self.ctx.update(buf) + def write(self, buf: bytes) -> None: # alias for extend() + self.ctx.update(buf) - def append(self, b: int): - self.buf[0] = b - self.ctx.update(self.buf) + async def awrite(self, buf: bytes) -> int: # AsyncWriter interface + self.ctx.update(buf) + return len(buf) def get_digest(self) -> bytes: return self.ctx.digest() -def obj_eq(l, r): +def obj_eq(l: object, r: object) -> bool: """ Compares object contents, supports __slots__. """ @@ -118,7 +144,7 @@ def obj_eq(l, r): return True -def obj_repr(o): +def obj_repr(o: object) -> str: """ Returns a string representation of object, supports __slots__. """ diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index e2fe9b6b9..ca1cdec24 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -7,11 +7,28 @@ from trezor.wire.errors import Error # import all errors into namespace, so that `wire.Error` is available elsewhere from trezor.wire.errors import * # isort:skip # noqa: F401,F403 +if False: + from typing import ( + Any, + Awaitable, + Dict, + Callable, + Iterable, + List, + Optional, + Tuple, + Type, + ) + from trezorio import WireInterface + from protobuf import LoadedMessageType, MessageType + + Handler = Callable[..., loop.Task] + -workflow_handlers = {} +workflow_handlers = {} # type: Dict[int, Tuple[Handler, Iterable]] -def add(mtype, pkgname, modname, namespace=None): +def add(mtype: int, pkgname: str, modname: str, namespace: List = None) -> None: """Shortcut for registering a dynamically-imported Protobuf workflow.""" if namespace is not None: register( @@ -27,7 +44,7 @@ def add(mtype, pkgname, modname, namespace=None): register(mtype, protobuf_workflow, import_workflow, pkgname, modname) -def register(mtype, handler, *args): +def register(mtype: int, handler: Handler, *args: Any) -> None: """Register `handler` to get scheduled after `mtype` message is received.""" if isinstance(mtype, type) and issubclass(mtype, protobuf.MessageType): mtype = mtype.MESSAGE_WIRE_TYPE @@ -36,54 +53,75 @@ def register(mtype, handler, *args): workflow_handlers[mtype] = (handler, args) -def setup(iface): +def setup(iface: WireInterface) -> None: """Initialize the wire stack on passed USB interface.""" loop.schedule(session_handler(iface, codec_v1.SESSION_ID)) class Context: - def __init__(self, iface, sid): + def __init__(self, iface: WireInterface, sid: int) -> None: self.iface = iface self.sid = sid - async def call(self, msg, *types): - """ - Reply with `msg` and wait for one of `types`. See `self.write()` and - `self.read()`. - """ + async def call( + self, msg: MessageType, exptype: Type[LoadedMessageType] + ) -> LoadedMessageType: await self.write(msg) del msg - return await self.read(types) + return await self.read(exptype) - async def read(self, types): - """ - Wait for incoming message on this wire context and return it. Raises - `UnexpectedMessageError` if the message type does not match one of - `types`; and caller should always make sure to re-raise it. - """ - reader = self.getreader() + async def call_any(self, msg: MessageType, *allowed_types: int) -> MessageType: + await self.write(msg) + del msg + return await self.read_any(allowed_types) + + async def read( + self, exptype: Optional[Type[LoadedMessageType]] + ) -> LoadedMessageType: + reader = self.make_reader() if __debug__: log.debug( - __name__, "%s:%x read: %s", self.iface.iface_num(), self.sid, types + __name__, "%s:%x read: %s", self.iface.iface_num(), self.sid, exptype ) await reader.aopen() # wait for the message header # if we got a message with unexpected type, raise the reader via # `UnexpectedMessageError` and let the session handler deal with it - if reader.type not in types: + if exptype is None or reader.type != exptype.MESSAGE_WIRE_TYPE: raise UnexpectedMessageError(reader) - # look up the protobuf class and parse the message - pbtype = messages.get_type(reader.type) - return await protobuf.load_message(reader, pbtype) + # parse the message and return it + return await protobuf.load_message(reader, exptype) - async def write(self, msg): - """ - Write a protobuf message to this wire context. - """ - writer = self.getwriter() + async def read_any(self, allowed_types: Iterable[int]) -> MessageType: + reader = self.make_reader() + + if __debug__: + log.debug( + __name__, + "%s:%x read: %s", + self.iface.iface_num(), + self.sid, + allowed_types, + ) + + await reader.aopen() # wait for the message header + + # if we got a message with unexpected type, raise the reader via + # `UnexpectedMessageError` and let the session handler deal with it + if reader.type not in allowed_types: + raise UnexpectedMessageError(reader) + + # find the protobuf type + exptype = messages.get_type(reader.type) + + # parse the message and return it + return await protobuf.load_message(reader, exptype) + + async def write(self, msg: protobuf.MessageType) -> None: + writer = self.make_writer() if __debug__: log.debug( @@ -99,35 +137,35 @@ class Context: await protobuf.dump_message(writer, msg, fields) await writer.aclose() - def wait(self, *tasks): + def wait(self, *tasks: Awaitable) -> Any: """ Wait until one of the passed tasks finishes, and return the result, while servicing the wire context. If a message comes until one of the tasks ends, `UnexpectedMessageError` is raised. """ - return loop.spawn(self.read(()), *tasks) + return loop.spawn(self.read(None), *tasks) - def getreader(self): + def make_reader(self) -> codec_v1.Reader: return codec_v1.Reader(self.iface) - def getwriter(self): + def make_writer(self) -> codec_v1.Writer: return codec_v1.Writer(self.iface) class UnexpectedMessageError(Exception): - def __init__(self, reader): + def __init__(self, reader: codec_v1.Reader) -> None: super().__init__() self.reader = reader -async def session_handler(iface, sid): +async def session_handler(iface: WireInterface, sid: int) -> None: reader = None ctx = Context(iface, sid) while True: try: # wait for new message, if needed, and find handler if not reader: - reader = ctx.getreader() + reader = ctx.make_reader() await reader.aopen() try: handler, args = workflow_handlers[reader.type] @@ -160,7 +198,9 @@ async def session_handler(iface, sid): reader = None -async def protobuf_workflow(ctx, reader, handler, *args): +async def protobuf_workflow( + ctx: Context, reader: codec_v1.Reader, handler: Handler, *args: Any +) -> None: from trezor.messages.Failure import Failure req = await protobuf.load_message(reader, messages.get_type(reader.type)) @@ -185,7 +225,13 @@ async def protobuf_workflow(ctx, reader, handler, *args): await ctx.write(res) -async def keychain_workflow(ctx, req, namespace, handler, *args): +async def keychain_workflow( + ctx: Context, + req: protobuf.MessageType, + namespace: List, + handler: Handler, + *args: Any +) -> Any: from apps.common import seed keychain = await seed.get_keychain(ctx, namespace) @@ -196,22 +242,28 @@ async def keychain_workflow(ctx, req, namespace, handler, *args): keychain.__del__() -def import_workflow(ctx, req, pkgname, modname, *args): +def import_workflow( + ctx: Context, req: protobuf.MessageType, pkgname: str, modname: str, *args: Any +) -> Any: modpath = "%s.%s" % (pkgname, modname) - module = __import__(modpath, None, None, (modname,), 0) + module = __import__(modpath, None, None, (modname,), 0) # type: ignore handler = getattr(module, modname) return handler(ctx, req, *args) -async def unexpected_msg(ctx, reader): +async def unexpected_msg(ctx: Context, reader: codec_v1.Reader) -> None: from trezor.messages.Failure import Failure # receive the message and throw it away - while reader.size > 0: - buf = bytearray(reader.size) - await reader.areadinto(buf) + await read_full_msg(reader) # respond with an unknown message error await ctx.write( Failure(code=FailureType.UnexpectedMessage, message="Unexpected message") ) + + +async def read_full_msg(reader: codec_v1.Reader) -> None: + while reader.size > 0: + buf = bytearray(reader.size) + await reader.areadinto(buf) diff --git a/core/src/trezor/wire/codec_v1.py b/core/src/trezor/wire/codec_v1.py index 2b9c6f8b5..4cda63230 100644 --- a/core/src/trezor/wire/codec_v1.py +++ b/core/src/trezor/wire/codec_v1.py @@ -3,6 +3,9 @@ from micropython import const from trezor import io, loop, utils +if False: + from trezorio import WireInterface + _REP_LEN = const(64) _REP_MARKER = const(63) # ord('?') @@ -12,6 +15,7 @@ _REP_INIT_DATA = const(9) # offset of data in the initial report _REP_CONT_DATA = const(1) # offset of data in the continuation report SESSION_ID = const(0) +INVALID_TYPE = const(-1) class Reader: @@ -20,17 +24,14 @@ class Reader: async-file-like interface. """ - def __init__(self, iface): + def __init__(self, iface: WireInterface) -> None: self.iface = iface - self.type = None - self.size = None - self.data = None + self.type = INVALID_TYPE + self.size = 0 self.ofs = 0 + self.data = bytes() - def __repr__(self): - return "" % (self.type, self.size) - - async def aopen(self): + async def aopen(self) -> None: """ Begin the message transmission by waiting for initial V2 message report on this session. `self.type` and `self.size` are initialized and @@ -53,7 +54,7 @@ class Reader: self.data = report[_REP_INIT_DATA : _REP_INIT_DATA + msize] self.ofs = 0 - async def areadinto(self, buf): + async def areadinto(self, buf: bytearray) -> int: """ Read exactly `len(buf)` bytes into `buf`, waiting for additional reports, if needed. Raises `EOFError` if end-of-message is encountered @@ -91,17 +92,14 @@ class Writer: async-file-like interface. """ - def __init__(self, iface): + def __init__(self, iface: WireInterface): self.iface = iface - self.type = None - self.size = None - self.data = bytearray(_REP_LEN) + self.type = INVALID_TYPE + self.size = 0 self.ofs = 0 + self.data = bytearray(_REP_LEN) - def __repr__(self): - return "" % (self.type, self.size) - - def setheader(self, mtype, msize): + def setheader(self, mtype: int, msize: int) -> None: """ Reset the writer state and load the message header with passed type and total message size. @@ -113,7 +111,7 @@ class Writer: ) self.ofs = _REP_INIT_DATA - async def awrite(self, buf): + async def awrite(self, buf: bytes) -> int: """ Encode and write every byte from `buf`. Does not need to be called in case message has zero length. Raises `EOFError` if the length of `buf` @@ -142,7 +140,7 @@ class Writer: return nwritten - async def aclose(self): + async def aclose(self) -> None: """Flush and close the message transmission.""" if self.ofs != _REP_CONT_DATA: # we didn't write anything or last write() wasn't report-aligned, diff --git a/core/src/trezor/wire/errors.py b/core/src/trezor/wire/errors.py index 4a4e152f5..2e38e3ca6 100644 --- a/core/src/trezor/wire/errors.py +++ b/core/src/trezor/wire/errors.py @@ -2,72 +2,72 @@ from trezor.messages import FailureType class Error(Exception): - def __init__(self, code, message): + def __init__(self, code: int, message: str) -> None: super().__init__() self.code = code self.message = message class UnexpectedMessage(Error): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(FailureType.UnexpectedMessage, message) class ButtonExpected(Error): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(FailureType.ButtonExpected, message) class DataError(Error): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(FailureType.DataError, message) class ActionCancelled(Error): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(FailureType.ActionCancelled, message) class PinExpected(Error): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(FailureType.PinExpected, message) class PinCancelled(Error): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(FailureType.PinCancelled, message) class PinInvalid(Error): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(FailureType.PinInvalid, message) class InvalidSignature(Error): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(FailureType.InvalidSignature, message) class ProcessError(Error): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(FailureType.ProcessError, message) class NotEnoughFunds(Error): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(FailureType.NotEnoughFunds, message) class NotInitialized(Error): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(FailureType.NotInitialized, message) class PinMismatch(Error): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(FailureType.PinMismatch, message) class FirmwareError(Error): - def __init__(self, message): + def __init__(self, message: str) -> None: super().__init__(FailureType.FirmwareError, message) diff --git a/core/src/trezor/workflow.py b/core/src/trezor/workflow.py index 89f1fdab7..da5a149c8 100644 --- a/core/src/trezor/workflow.py +++ b/core/src/trezor/workflow.py @@ -1,17 +1,21 @@ from trezor import loop -workflows = [] -layouts = [] +if False: + from trezor import ui + from typing import List, Callable, Optional + +workflows = [] # type: List[loop.Task] +layouts = [] # type: List[ui.Layout] layout_signal = loop.signal() -default = None -default_layout = None +default = None # type: Optional[loop.Task] +default_layout = None # type: Optional[Callable[[], loop.Task]] -def onstart(w): +def onstart(w: loop.Task) -> None: workflows.append(w) -def onclose(w): +def onclose(w: loop.Task) -> None: workflows.remove(w) if not layouts and default_layout: startdefault(default_layout) @@ -24,7 +28,7 @@ def onclose(w): micropython.mem_info() -def closedefault(): +def closedefault() -> None: global default if default: @@ -32,7 +36,7 @@ def closedefault(): default = None -def startdefault(layout): +def startdefault(layout: Callable[[], loop.Task]) -> None: global default global default_layout @@ -42,18 +46,19 @@ def startdefault(layout): loop.schedule(default) -def restartdefault(): +def restartdefault() -> None: global default_layout - d = default_layout + closedefault() - startdefault(d) + if default_layout: + startdefault(default_layout) -def onlayoutstart(l): +def onlayoutstart(l: ui.Layout) -> None: closedefault() layouts.append(l) -def onlayoutclose(l): +def onlayoutclose(l: ui.Layout) -> None: if l in layouts: layouts.remove(l) diff --git a/setup.cfg b/setup.cfg index f0963e0ce..5361caa3c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,15 +33,16 @@ addopts = --strict xfail_strict = true [mypy] -mypy_path = mocks,mocks/generated -warn_unused_configs = True +mypy_path = src,mocks,mocks/generated +check_untyped_defs = True disallow_subclassing_any = True disallow_untyped_calls = True +disallow_untyped_decorators = True disallow_untyped_defs = True disallow_incomplete_defs = True -check_untyped_defs = True +namespace_packages = True # no_implicit_optional = True warn_redundant_casts = True warn_return_any = True +warn_unused_configs = True warn_unused_ignores = True -disallow_untyped_decorators = True