From d17f879d9766c8744d3d3abfb5f8dc69c6c79833 Mon Sep 17 00:00:00 2001 From: matejcik Date: Mon, 21 Oct 2019 15:52:02 +0200 Subject: [PATCH] mypy: use GenericContext protocol to work-around DummyContext --- core/src/apps/common/confirm.py | 4 +-- core/src/apps/common/layout.py | 4 +-- .../management/recovery_device/homescreen.py | 24 ++++++++-------- .../apps/management/recovery_device/layout.py | 28 ++++++++++--------- core/src/trezor/wire/__init__.py | 27 ++++++++++++++++-- 5 files changed, 56 insertions(+), 31 deletions(-) diff --git a/core/src/apps/common/confirm.py b/core/src/apps/common/confirm.py index b2ccca7b7..7bba2fc5c 100644 --- a/core/src/apps/common/confirm.py +++ b/core/src/apps/common/confirm.py @@ -17,7 +17,7 @@ if False: async def confirm( - ctx: wire.Context, + ctx: wire.GenericContext, content: ui.Component, code: EnumTypeButtonRequestType = ButtonRequestType.Other, confirm: Optional[ButtonContent] = Confirm.DEFAULT_CONFIRM, @@ -54,7 +54,7 @@ async def confirm( async def info_confirm( - ctx: wire.Context, + ctx: wire.GenericContext, content: ui.Component, info_func: Callable, code: EnumTypeButtonRequestType = ButtonRequestType.Other, diff --git a/core/src/apps/common/layout.py b/core/src/apps/common/layout.py index a10eab827..1e16c5d97 100644 --- a/core/src/apps/common/layout.py +++ b/core/src/apps/common/layout.py @@ -78,7 +78,7 @@ def address_n_to_str(address_n: list) -> str: async def show_warning( - ctx: wire.Context, + ctx: wire.GenericContext, content: Iterable[str], subheader: Iterable[str] = [], button: str = "Try again", @@ -96,7 +96,7 @@ async def show_warning( async def show_success( - ctx: wire.Context, + ctx: wire.GenericContext, content: Iterable[str] = [], subheader: Iterable[str] = [], button: str = "Continue", diff --git a/core/src/apps/management/recovery_device/homescreen.py b/core/src/apps/management/recovery_device/homescreen.py index 05a22d5c1..c48a9a4bd 100644 --- a/core/src/apps/management/recovery_device/homescreen.py +++ b/core/src/apps/management/recovery_device/homescreen.py @@ -18,13 +18,13 @@ from apps.management import backup_types from apps.management.recovery_device import layout if False: - from typing import Optional, Tuple, cast + from typing import Optional, Tuple from trezor.messages.ResetDevice import EnumTypeBackupType async def recovery_homescreen() -> None: # recovery process does not communicate on the wire - ctx = cast(wire.Context, wire.DummyContext()) # TODO + ctx = wire.DummyContext() try: await recovery_process(ctx) finally: @@ -34,7 +34,7 @@ async def recovery_homescreen() -> None: wire.clear() -async def recovery_process(ctx: wire.Context) -> Success: +async def recovery_process(ctx: wire.GenericContext) -> Success: try: result = await _continue_recovery_process(ctx) except recover.RecoveryAborted: @@ -47,7 +47,7 @@ async def recovery_process(ctx: wire.Context) -> Success: return result -async def _continue_recovery_process(ctx: wire.Context) -> Success: +async def _continue_recovery_process(ctx: wire.GenericContext) -> Success: # gather the current recovery state from storage dry_run = storage_recovery.is_dry_run() word_count, backup_type = recover.load_slip39_state() @@ -98,7 +98,7 @@ async def _continue_recovery_process(ctx: wire.Context) -> Success: async def _finish_recovery_dry_run( - ctx: wire.Context, secret: bytes, backup_type: EnumTypeBackupType + ctx: wire.GenericContext, secret: bytes, backup_type: EnumTypeBackupType ) -> Success: if backup_type is None: raise RuntimeError @@ -131,7 +131,7 @@ async def _finish_recovery_dry_run( async def _finish_recovery( - ctx: wire.Context, secret: bytes, backup_type: EnumTypeBackupType + ctx: wire.GenericContext, secret: bytes, backup_type: EnumTypeBackupType ) -> Success: if backup_type is None: raise RuntimeError @@ -154,7 +154,7 @@ async def _finish_recovery( return Success(message="Device recovered") -async def _request_word_count(ctx: wire.Context, dry_run: bool) -> int: +async def _request_word_count(ctx: wire.GenericContext, dry_run: bool) -> int: homepage = layout.RecoveryHomescreen("Select number of words") await layout.homescreen_dialog(ctx, homepage, "Select") @@ -163,7 +163,7 @@ async def _request_word_count(ctx: wire.Context, dry_run: bool) -> int: async def _process_words( - ctx: wire.Context, words: str + ctx: wire.GenericContext, words: str ) -> Tuple[Optional[bytes], EnumTypeBackupType]: word_count = len(words.split(" ")) is_slip39 = backup_types.is_slip39_word_count(word_count) @@ -184,7 +184,9 @@ async def _process_words( return secret, backup_type -async def _request_share_first_screen(ctx: wire.Context, word_count: int) -> None: +async def _request_share_first_screen( + ctx: wire.GenericContext, word_count: int +) -> None: if backup_types.is_slip39_word_count(word_count): remaining = storage_recovery.fetch_slip39_remaining_shares() if remaining: @@ -201,7 +203,7 @@ async def _request_share_first_screen(ctx: wire.Context, word_count: int) -> Non await layout.homescreen_dialog(ctx, content, "Enter seed") -async def _request_share_next_screen(ctx: wire.Context) -> None: +async def _request_share_next_screen(ctx: wire.GenericContext) -> None: remaining = storage_recovery.fetch_slip39_remaining_shares() group_count = storage_recovery.get_slip39_group_count() if not remaining: @@ -222,7 +224,7 @@ async def _request_share_next_screen(ctx: wire.Context) -> None: await layout.homescreen_dialog(ctx, content, "Enter share") -async def _show_remaining_groups_and_shares(ctx: wire.Context) -> None: +async def _show_remaining_groups_and_shares(ctx: wire.GenericContext) -> None: """ Show info dialog for Slip39 Advanced - what shares are to be entered. """ diff --git a/core/src/apps/management/recovery_device/layout.py b/core/src/apps/management/recovery_device/layout.py index c56e2ee5e..905a74adb 100644 --- a/core/src/apps/management/recovery_device/layout.py +++ b/core/src/apps/management/recovery_device/layout.py @@ -25,7 +25,7 @@ if False: from trezor.messages.ResetDevice import EnumTypeBackupType -async def confirm_abort(ctx: wire.Context, dry_run: bool = False) -> bool: +async def confirm_abort(ctx: wire.GenericContext, dry_run: bool = False) -> bool: if dry_run: text = Text("Abort seed check", ui.ICON_WIPE) text.normal("Do you really want to", "abort the seed check?") @@ -36,7 +36,7 @@ async def confirm_abort(ctx: wire.Context, dry_run: bool = False) -> bool: return await confirm(ctx, text, code=ButtonRequestType.ProtectCall) -async def request_word_count(ctx: wire.Context, dry_run: bool) -> int: +async def request_word_count(ctx: wire.GenericContext, dry_run: bool) -> int: await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicWordCount), ButtonAck) if dry_run: @@ -55,7 +55,7 @@ async def request_word_count(ctx: wire.Context, dry_run: bool) -> int: async def request_mnemonic( - ctx: wire.Context, word_count: int, backup_type: Optional[EnumTypeBackupType] + ctx: wire.GenericContext, word_count: int, backup_type: Optional[EnumTypeBackupType] ) -> Optional[str]: await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicInput), ButtonAck) @@ -81,7 +81,7 @@ async def request_mnemonic( async def check_word_validity( - ctx: wire.Context, + ctx: wire.GenericContext, current_index: int, current_word: str, backup_type: Optional[EnumTypeBackupType], @@ -155,7 +155,7 @@ async def check_word_validity( async def show_remaining_shares( - ctx: wire.Context, + ctx: wire.GenericContext, groups: Iterable[Tuple[int, Tuple[str, ...]]], # remaining + list 3 words shares_remaining: List[int], group_threshold: int, @@ -187,7 +187,7 @@ async def show_remaining_shares( async def show_group_share_success( - ctx: wire.Context, share_index: int, group_index: int + ctx: wire.GenericContext, share_index: int, group_index: int ) -> None: text = Text("Success", ui.ICON_CONFIRM) text.bold("You have entered") @@ -198,7 +198,9 @@ async def show_group_share_success( await confirm(ctx, text, confirm="Continue", cancel=None) -async def show_dry_run_result(ctx: wire.Context, result: bool, is_slip39: bool) -> None: +async def show_dry_run_result( + ctx: wire.GenericContext, result: bool, is_slip39: bool +) -> None: if result: if is_slip39: text = ( @@ -233,7 +235,7 @@ async def show_dry_run_result(ctx: wire.Context, result: bool, is_slip39: bool) await show_warning(ctx, text, button="Continue") -async def show_dry_run_different_type(ctx: wire.Context) -> None: +async def show_dry_run_different_type(ctx: wire.GenericContext) -> None: text = Text("Dry run failure", ui.ICON_CANCEL) text.normal("Seed in the device was") text.normal("created using another") @@ -243,26 +245,26 @@ async def show_dry_run_different_type(ctx: wire.Context) -> None: ) -async def show_invalid_mnemonic(ctx: wire.Context, word_count: int) -> None: +async def show_invalid_mnemonic(ctx: wire.GenericContext, word_count: int) -> None: if backup_types.is_slip39_word_count(word_count): await show_warning(ctx, ("You have entered", "an invalid recovery", "share.")) else: await show_warning(ctx, ("You have entered", "an invalid recovery", "seed.")) -async def show_share_already_added(ctx: wire.Context) -> None: +async def show_share_already_added(ctx: wire.GenericContext) -> None: await show_warning( ctx, ("Share already entered,", "please enter", "a different share.") ) -async def show_identifier_mismatch(ctx: wire.Context) -> None: +async def show_identifier_mismatch(ctx: wire.GenericContext) -> None: await show_warning( ctx, ("You have entered", "a share from another", "Shamir Backup.") ) -async def show_group_threshold_reached(ctx: wire.Context) -> None: +async def show_group_threshold_reached(ctx: wire.GenericContext) -> None: await show_warning( ctx, ( @@ -310,7 +312,7 @@ class RecoveryHomescreen(ui.Component): async def homescreen_dialog( - ctx: wire.Context, + ctx: wire.GenericContext, homepage: RecoveryHomescreen, button_label: str, info_func: Callable = None, diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 7c99e4e28..c12cc809d 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -99,14 +99,35 @@ def clear() -> None: workflow_namespaces.clear() +if False: + from typing import Protocol + + class GenericContext(Protocol): + async def call( + self, + msg: protobuf.MessageType, + expected_type: Type[protobuf.LoadedMessageType], + ) -> Any: + ... + + async def read(self, expected_type: Type[protobuf.LoadedMessageType]) -> Any: + ... + + async def write(self, msg: protobuf.MessageType) -> None: + ... + + async def wait(self, *tasks: Awaitable) -> Any: + ... + + class DummyContext: - async def call(*argv: Any) -> None: + async def call(self, *argv: Any) -> None: pass - async def read(*argv: Any) -> None: + async def read(self, *argv: Any) -> None: pass - async def write(*argv: Any) -> None: + async def write(self, *argv: Any) -> None: pass async def wait(self, *tasks: Awaitable) -> Any: