From 081e4590756af5eefa3952a8a5fe25a19616536d Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Tue, 9 Jul 2019 14:05:35 +0200 Subject: [PATCH] core: fix various types --- core/src/apps/common/storage/__init__.py | 8 ++--- core/src/apps/common/storage/common.py | 21 ++++++++------ core/src/apps/common/storage/device.py | 27 +++++++++-------- core/src/apps/common/storage/slip39.py | 29 ++++++++++--------- .../apps/common/storage/slip39_mnemonics.py | 13 +++++---- core/src/apps/ethereum/networks.py | 15 ++++++---- core/src/apps/ethereum/networks.py.mako | 15 ++++++---- core/src/trezor/ui/loader.py | 10 +++---- 8 files changed, 78 insertions(+), 60 deletions(-) diff --git a/core/src/apps/common/storage/__init__.py b/core/src/apps/common/storage/__init__.py index 15e4238b3..2d3afa426 100644 --- a/core/src/apps/common/storage/__init__.py +++ b/core/src/apps/common/storage/__init__.py @@ -4,7 +4,7 @@ from apps.common import cache from apps.common.storage import common, device, slip39 -def set_current_version(): +def set_current_version() -> None: device.set_version(common._STORAGE_VERSION_CURRENT) @@ -12,19 +12,19 @@ def is_initialized() -> bool: return device.is_version_stored() and not slip39.is_in_progress() -def wipe(): +def wipe() -> None: config.wipe() cache.clear() -def init_unlocked(): +def init_unlocked() -> None: # Check for storage version upgrade. version = device.get_version() if version == common._STORAGE_VERSION_01: _migrate_from_version_01() -def _migrate_from_version_01(): +def _migrate_from_version_01() -> None: # Make the U2F counter public and writable even when storage is locked. # U2F counter wasn't public, so we are intentionally not using storage.device module. counter = common._get(common._APP_DEVICE, device._U2F_COUNTER) diff --git a/core/src/apps/common/storage/common.py b/core/src/apps/common/storage/common.py index 16eda101a..96e9f57ea 100644 --- a/core/src/apps/common/storage/common.py +++ b/core/src/apps/common/storage/common.py @@ -1,5 +1,8 @@ from trezor import config +if False: + from typing import Optional + # Namespaces: # fmt: off # Intentionally not using const() to allow import in submodules. @@ -15,19 +18,19 @@ _STORAGE_VERSION_01 = b"\x01" _STORAGE_VERSION_CURRENT = b"\x02" -def _set(app: int, key: int, data: bytes, public: bool = False): +def _set(app: int, key: int, data: bytes, public: bool = False) -> None: config.set(app, key, data, public) -def _get(app: int, key: int, public: bool = False): +def _get(app: int, key: int, public: bool = False) -> Optional[bytes]: return config.get(app, key, public) -def _delete(app: int, key: int): +def _delete(app: int, key: int) -> None: config.delete(app, key) -def _set_true_or_delete(app: int, key: int, value: bool): +def _set_true_or_delete(app: int, key: int, value: bool) -> None: if value: _set_bool(app, key, value) else: @@ -45,29 +48,29 @@ def _get_bool(app: int, key: int, public: bool = False) -> bool: return _get(app, key, public) == _TRUE_BYTE -def _set_uint8(app: int, key: int, val: int): +def _set_uint8(app: int, key: int, val: int) -> None: _set(app, key, val.to_bytes(1, "big")) -def _get_uint8(app: int, key: int) -> int: +def _get_uint8(app: int, key: int) -> Optional[int]: val = _get(app, key) if not val: return None return int.from_bytes(val, "big") -def _set_uint16(app: int, key: int, val: int): +def _set_uint16(app: int, key: int, val: int) -> None: _set(app, key, val.to_bytes(2, "big")) -def _get_uint16(app: int, key: int) -> int: +def _get_uint16(app: int, key: int) -> Optional[int]: val = _get(app, key) if not val: return None return int.from_bytes(val, "big") -def _next_counter(app: int, key: int, public: bool = False): +def _next_counter(app: int, key: int, public: bool = False) -> Optional[int]: return config.next_counter(app, key, public) diff --git a/core/src/apps/common/storage/device.py b/core/src/apps/common/storage/device.py index 83bc6b77b..be93a09bf 100644 --- a/core/src/apps/common/storage/device.py +++ b/core/src/apps/common/storage/device.py @@ -5,6 +5,9 @@ from trezor.crypto import random from apps.common.storage import common +if False: + from typing import Optional + # Namespace: _NAMESPACE = common._APP_DEVICE @@ -35,12 +38,12 @@ def is_version_stored() -> bool: return bool(common._get(_NAMESPACE, _VERSION)) -def get_version() -> bool: +def get_version() -> Optional[bytes]: return common._get(_NAMESPACE, _VERSION) -def set_version(version: bytes) -> bool: - return common._set(_NAMESPACE, _VERSION, version) +def set_version(version: bytes) -> None: + common._set(_NAMESPACE, _VERSION, version) def _new_device_id() -> str: @@ -62,18 +65,18 @@ def get_rotation() -> int: return int.from_bytes(rotation, "big") -def get_label() -> str: +def get_label() -> Optional[str]: label = common._get(_NAMESPACE, _LABEL, True) # public if label is None: return None return label.decode() -def get_mnemonic_secret() -> bytes: +def get_mnemonic_secret() -> Optional[bytes]: return common._get(_NAMESPACE, _MNEMONIC_SECRET) -def get_mnemonic_type() -> int: +def get_mnemonic_type() -> Optional[int]: return common._get_uint8(_NAMESPACE, _MNEMONIC_TYPE) @@ -81,7 +84,7 @@ def has_passphrase() -> bool: return common._get_bool(_NAMESPACE, _USE_PASSPHRASE) -def get_homescreen() -> bytes: +def get_homescreen() -> Optional[bytes]: return common._get(_NAMESPACE, _HOMESCREEN, True) # public @@ -171,11 +174,11 @@ def get_flags() -> int: def set_flags(flags: int) -> None: b = common._get(_NAMESPACE, _FLAGS) if b is None: - b = 0 + i = 0 else: - b = int.from_bytes(b, "big") - flags = (flags | b) & 0xFFFFFFFF - if flags != b: + i = int.from_bytes(b, "big") + flags = (flags | i) & 0xFFFFFFFF + if flags != i: common._set(_NAMESPACE, _FLAGS, flags.to_bytes(4, "big")) @@ -193,7 +196,7 @@ def set_autolock_delay_ms(delay_ms: int) -> None: common._set(_NAMESPACE, _AUTOLOCK_DELAY_MS, delay_ms.to_bytes(4, "big")) -def next_u2f_counter() -> int: +def next_u2f_counter() -> Optional[int]: return common._next_counter(_NAMESPACE, _U2F_COUNTER, True) # writable when locked diff --git a/core/src/apps/common/storage/slip39.py b/core/src/apps/common/storage/slip39.py index 140e3b700..dd5f0527e 100644 --- a/core/src/apps/common/storage/slip39.py +++ b/core/src/apps/common/storage/slip39.py @@ -2,6 +2,9 @@ from micropython import const from apps.common.storage import common, slip39_mnemonics +if False: + from typing import Optional + # Namespace: _NAMESPACE = common._APP_SLIP39 @@ -16,55 +19,55 @@ _SLIP39_ITERATION_EXPONENT = const(0x05) # int # fmt: on -def set_in_progress(val: bool): +def set_in_progress(val: bool) -> None: common._set_bool(_NAMESPACE, _SLIP39_IN_PROGRESS, val) -def is_in_progress(): +def is_in_progress() -> bool: return common._get_bool(_NAMESPACE, _SLIP39_IN_PROGRESS) -def set_identifier(identifier: int): +def set_identifier(identifier: int) -> None: common._set_uint16(_NAMESPACE, _SLIP39_IDENTIFIER, identifier) -def get_identifier() -> int: +def get_identifier() -> Optional[int]: return common._get_uint16(_NAMESPACE, _SLIP39_IDENTIFIER) -def set_threshold(threshold: int): +def set_threshold(threshold: int) -> None: common._set_uint8(_NAMESPACE, _SLIP39_THRESHOLD, threshold) -def get_threshold() -> int: +def get_threshold() -> Optional[int]: return common._get_uint8(_NAMESPACE, _SLIP39_THRESHOLD) -def set_remaining(remaining: int): +def set_remaining(remaining: int) -> None: common._set_uint8(_NAMESPACE, _SLIP39_REMAINING, remaining) -def get_remaining() -> int: +def get_remaining() -> Optional[int]: return common._get_uint8(_NAMESPACE, _SLIP39_REMAINING) -def set_words_count(count: int): +def set_words_count(count: int) -> None: common._set_uint8(_NAMESPACE, _SLIP39_WORDS_COUNT, count) -def get_words_count() -> int: +def get_words_count() -> Optional[int]: return common._get_uint8(_NAMESPACE, _SLIP39_WORDS_COUNT) -def set_iteration_exponent(exponent: int): +def set_iteration_exponent(exponent: int) -> None: common._set_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT, exponent) -def get_iteration_exponent() -> int: +def get_iteration_exponent() -> Optional[int]: return common._get_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT) -def delete_progress(): +def delete_progress() -> None: common._delete(_NAMESPACE, _SLIP39_IN_PROGRESS) common._delete(_NAMESPACE, _SLIP39_REMAINING) common._delete(_NAMESPACE, _SLIP39_THRESHOLD) diff --git a/core/src/apps/common/storage/slip39_mnemonics.py b/core/src/apps/common/storage/slip39_mnemonics.py index b6e796632..7887dddbf 100644 --- a/core/src/apps/common/storage/slip39_mnemonics.py +++ b/core/src/apps/common/storage/slip39_mnemonics.py @@ -2,22 +2,25 @@ from trezor.crypto import slip39 from apps.common.storage import common +if False: + from typing import List, Optional + # Mnemonics stored during SLIP-39 recovery process. # Each mnemonic is stored under key = index. -def set(index: int, mnemonic: str): +def set(index: int, mnemonic: str) -> None: common._set(common._APP_SLIP39_MNEMONICS, index, mnemonic.encode()) -def get(index: int) -> str: +def get(index: int) -> Optional[str]: m = common._get(common._APP_SLIP39_MNEMONICS, index) if m: return m.decode() - return False + return None -def fetch() -> list: +def fetch() -> List[str]: mnemonics = [] for index in range(0, slip39.MAX_SHARE_COUNT): m = get(index) @@ -26,6 +29,6 @@ def fetch() -> list: return mnemonics -def delete(): +def delete() -> None: for index in range(0, slip39.MAX_SHARE_COUNT): common._delete(common._APP_SLIP39_MNEMONICS, index) diff --git a/core/src/apps/ethereum/networks.py b/core/src/apps/ethereum/networks.py index 41667f83f..6783bab01 100644 --- a/core/src/apps/ethereum/networks.py +++ b/core/src/apps/ethereum/networks.py @@ -3,30 +3,33 @@ from apps.common import HARDENED +if False: + from typing import Iterator, Optional -def shortcut_by_chain_id(chain_id, tx_type=None): - if tx_type in [1, 6] and chain_id in [1, 3]: + +def shortcut_by_chain_id(chain_id: int, tx_type: int = None) -> str: + if tx_type in (1, 6) and chain_id in (1, 3): return "WAN" else: n = by_chain_id(chain_id) return n.shortcut if n is not None else "UNKN" -def by_chain_id(chain_id): +def by_chain_id(chain_id: int) -> Optional["NetworkInfo"]: for n in NETWORKS: if n.chain_id == chain_id: return n return None -def by_slip44(slip44): +def by_slip44(slip44: int) -> Optional["NetworkInfo"]: for n in NETWORKS: if n.slip44 == slip44: return n return None -def all_slip44_ids_hardened(): +def all_slip44_ids_hardened() -> Iterator[int]: for n in NETWORKS: yield n.slip44 | HARDENED @@ -34,7 +37,7 @@ def all_slip44_ids_hardened(): class NetworkInfo: def __init__( self, chain_id: int, slip44: int, shortcut: str, name: str, rskip60: bool - ): + ) -> None: self.chain_id = chain_id self.slip44 = slip44 self.shortcut = shortcut diff --git a/core/src/apps/ethereum/networks.py.mako b/core/src/apps/ethereum/networks.py.mako index fa0312eaf..15a43e7bb 100644 --- a/core/src/apps/ethereum/networks.py.mako +++ b/core/src/apps/ethereum/networks.py.mako @@ -3,30 +3,33 @@ from apps.common import HARDENED +if False: + from typing import Iterator, Optional -def shortcut_by_chain_id(chain_id, tx_type=None): - if tx_type in [1, 6] and chain_id in [1, 3]: + +def shortcut_by_chain_id(chain_id: int, tx_type: int = None) -> str: + if tx_type in (1, 6) and chain_id in (1, 3): return "WAN" else: n = by_chain_id(chain_id) return n.shortcut if n is not None else "UNKN" -def by_chain_id(chain_id): +def by_chain_id(chain_id: int) -> Optional["NetworkInfo"]: for n in NETWORKS: if n.chain_id == chain_id: return n return None -def by_slip44(slip44): +def by_slip44(slip44: int) -> Optional["NetworkInfo"]: for n in NETWORKS: if n.slip44 == slip44: return n return None -def all_slip44_ids_hardened(): +def all_slip44_ids_hardened() -> Iterator[int]: for n in NETWORKS: yield n.slip44 | HARDENED @@ -34,7 +37,7 @@ def all_slip44_ids_hardened(): class NetworkInfo: def __init__( self, chain_id: int, slip44: int, shortcut: str, name: str, rskip60: bool - ): + ) -> None: self.chain_id = chain_id self.slip44 = slip44 self.shortcut = shortcut diff --git a/core/src/trezor/ui/loader.py b/core/src/trezor/ui/loader.py index 03786ea8f..bfc8c3442 100644 --- a/core/src/trezor/ui/loader.py +++ b/core/src/trezor/ui/loader.py @@ -90,20 +90,20 @@ class Loader(ui.Control): def on_start(self) -> None: pass - def on_finish(self): + def on_finish(self) -> None: pass class LoadingAnimation(ui.Layout): - def __init__(self, style=LoaderDefault): + def __init__(self, style: LoaderStyleType = LoaderDefault) -> None: self.loader = Loader(style) - self.loader.on_finish = self.on_finish + self.loader.on_finish = self.on_finish # type: ignore self.loader.start() - def dispatch(self, event, x, y): + def dispatch(self, event: int, x: int, y: int) -> None: if not self.loader.elapsed_ms(): self.loader.start() self.loader.dispatch(event, x, y) - def on_finish(self): + def on_finish(self) -> None: raise ui.Result(None)