mypy: use GenericContext protocol to work-around DummyContext

pull/645/head
matejcik 5 years ago committed by Tomas Susanka
parent ed190c772c
commit d17f879d97

@ -17,7 +17,7 @@ if False:
async def confirm(
ctx: wire.Context,
ctx: wire.GenericContext,
content: ui.Component,
code: EnumTypeButtonRequestType = ButtonRequestType.Other,
confirm: Optional[ButtonContent] = Confirm.DEFAULT_CONFIRM,
@ -54,7 +54,7 @@ async def confirm(
async def info_confirm(
ctx: wire.Context,
ctx: wire.GenericContext,
content: ui.Component,
info_func: Callable,
code: EnumTypeButtonRequestType = ButtonRequestType.Other,

@ -78,7 +78,7 @@ def address_n_to_str(address_n: list) -> str:
async def show_warning(
ctx: wire.Context,
ctx: wire.GenericContext,
content: Iterable[str],
subheader: Iterable[str] = [],
button: str = "Try again",
@ -96,7 +96,7 @@ async def show_warning(
async def show_success(
ctx: wire.Context,
ctx: wire.GenericContext,
content: Iterable[str] = [],
subheader: Iterable[str] = [],
button: str = "Continue",

@ -18,13 +18,13 @@ from apps.management import backup_types
from apps.management.recovery_device import layout
if False:
from typing import Optional, Tuple, cast
from typing import Optional, Tuple
from trezor.messages.ResetDevice import EnumTypeBackupType
async def recovery_homescreen() -> None:
# recovery process does not communicate on the wire
ctx = cast(wire.Context, wire.DummyContext()) # TODO
ctx = wire.DummyContext()
try:
await recovery_process(ctx)
finally:
@ -34,7 +34,7 @@ async def recovery_homescreen() -> None:
wire.clear()
async def recovery_process(ctx: wire.Context) -> Success:
async def recovery_process(ctx: wire.GenericContext) -> Success:
try:
result = await _continue_recovery_process(ctx)
except recover.RecoveryAborted:
@ -47,7 +47,7 @@ async def recovery_process(ctx: wire.Context) -> Success:
return result
async def _continue_recovery_process(ctx: wire.Context) -> Success:
async def _continue_recovery_process(ctx: wire.GenericContext) -> Success:
# gather the current recovery state from storage
dry_run = storage_recovery.is_dry_run()
word_count, backup_type = recover.load_slip39_state()
@ -98,7 +98,7 @@ async def _continue_recovery_process(ctx: wire.Context) -> Success:
async def _finish_recovery_dry_run(
ctx: wire.Context, secret: bytes, backup_type: EnumTypeBackupType
ctx: wire.GenericContext, secret: bytes, backup_type: EnumTypeBackupType
) -> Success:
if backup_type is None:
raise RuntimeError
@ -131,7 +131,7 @@ async def _finish_recovery_dry_run(
async def _finish_recovery(
ctx: wire.Context, secret: bytes, backup_type: EnumTypeBackupType
ctx: wire.GenericContext, secret: bytes, backup_type: EnumTypeBackupType
) -> Success:
if backup_type is None:
raise RuntimeError
@ -154,7 +154,7 @@ async def _finish_recovery(
return Success(message="Device recovered")
async def _request_word_count(ctx: wire.Context, dry_run: bool) -> int:
async def _request_word_count(ctx: wire.GenericContext, dry_run: bool) -> int:
homepage = layout.RecoveryHomescreen("Select number of words")
await layout.homescreen_dialog(ctx, homepage, "Select")
@ -163,7 +163,7 @@ async def _request_word_count(ctx: wire.Context, dry_run: bool) -> int:
async def _process_words(
ctx: wire.Context, words: str
ctx: wire.GenericContext, words: str
) -> Tuple[Optional[bytes], EnumTypeBackupType]:
word_count = len(words.split(" "))
is_slip39 = backup_types.is_slip39_word_count(word_count)
@ -184,7 +184,9 @@ async def _process_words(
return secret, backup_type
async def _request_share_first_screen(ctx: wire.Context, word_count: int) -> None:
async def _request_share_first_screen(
ctx: wire.GenericContext, word_count: int
) -> None:
if backup_types.is_slip39_word_count(word_count):
remaining = storage_recovery.fetch_slip39_remaining_shares()
if remaining:
@ -201,7 +203,7 @@ async def _request_share_first_screen(ctx: wire.Context, word_count: int) -> Non
await layout.homescreen_dialog(ctx, content, "Enter seed")
async def _request_share_next_screen(ctx: wire.Context) -> None:
async def _request_share_next_screen(ctx: wire.GenericContext) -> None:
remaining = storage_recovery.fetch_slip39_remaining_shares()
group_count = storage_recovery.get_slip39_group_count()
if not remaining:
@ -222,7 +224,7 @@ async def _request_share_next_screen(ctx: wire.Context) -> None:
await layout.homescreen_dialog(ctx, content, "Enter share")
async def _show_remaining_groups_and_shares(ctx: wire.Context) -> None:
async def _show_remaining_groups_and_shares(ctx: wire.GenericContext) -> None:
"""
Show info dialog for Slip39 Advanced - what shares are to be entered.
"""

@ -25,7 +25,7 @@ if False:
from trezor.messages.ResetDevice import EnumTypeBackupType
async def confirm_abort(ctx: wire.Context, dry_run: bool = False) -> bool:
async def confirm_abort(ctx: wire.GenericContext, dry_run: bool = False) -> bool:
if dry_run:
text = Text("Abort seed check", ui.ICON_WIPE)
text.normal("Do you really want to", "abort the seed check?")
@ -36,7 +36,7 @@ async def confirm_abort(ctx: wire.Context, dry_run: bool = False) -> bool:
return await confirm(ctx, text, code=ButtonRequestType.ProtectCall)
async def request_word_count(ctx: wire.Context, dry_run: bool) -> int:
async def request_word_count(ctx: wire.GenericContext, dry_run: bool) -> int:
await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicWordCount), ButtonAck)
if dry_run:
@ -55,7 +55,7 @@ async def request_word_count(ctx: wire.Context, dry_run: bool) -> int:
async def request_mnemonic(
ctx: wire.Context, word_count: int, backup_type: Optional[EnumTypeBackupType]
ctx: wire.GenericContext, word_count: int, backup_type: Optional[EnumTypeBackupType]
) -> Optional[str]:
await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicInput), ButtonAck)
@ -81,7 +81,7 @@ async def request_mnemonic(
async def check_word_validity(
ctx: wire.Context,
ctx: wire.GenericContext,
current_index: int,
current_word: str,
backup_type: Optional[EnumTypeBackupType],
@ -155,7 +155,7 @@ async def check_word_validity(
async def show_remaining_shares(
ctx: wire.Context,
ctx: wire.GenericContext,
groups: Iterable[Tuple[int, Tuple[str, ...]]], # remaining + list 3 words
shares_remaining: List[int],
group_threshold: int,
@ -187,7 +187,7 @@ async def show_remaining_shares(
async def show_group_share_success(
ctx: wire.Context, share_index: int, group_index: int
ctx: wire.GenericContext, share_index: int, group_index: int
) -> None:
text = Text("Success", ui.ICON_CONFIRM)
text.bold("You have entered")
@ -198,7 +198,9 @@ async def show_group_share_success(
await confirm(ctx, text, confirm="Continue", cancel=None)
async def show_dry_run_result(ctx: wire.Context, result: bool, is_slip39: bool) -> None:
async def show_dry_run_result(
ctx: wire.GenericContext, result: bool, is_slip39: bool
) -> None:
if result:
if is_slip39:
text = (
@ -233,7 +235,7 @@ async def show_dry_run_result(ctx: wire.Context, result: bool, is_slip39: bool)
await show_warning(ctx, text, button="Continue")
async def show_dry_run_different_type(ctx: wire.Context) -> None:
async def show_dry_run_different_type(ctx: wire.GenericContext) -> None:
text = Text("Dry run failure", ui.ICON_CANCEL)
text.normal("Seed in the device was")
text.normal("created using another")
@ -243,26 +245,26 @@ async def show_dry_run_different_type(ctx: wire.Context) -> None:
)
async def show_invalid_mnemonic(ctx: wire.Context, word_count: int) -> None:
async def show_invalid_mnemonic(ctx: wire.GenericContext, word_count: int) -> None:
if backup_types.is_slip39_word_count(word_count):
await show_warning(ctx, ("You have entered", "an invalid recovery", "share."))
else:
await show_warning(ctx, ("You have entered", "an invalid recovery", "seed."))
async def show_share_already_added(ctx: wire.Context) -> None:
async def show_share_already_added(ctx: wire.GenericContext) -> None:
await show_warning(
ctx, ("Share already entered,", "please enter", "a different share.")
)
async def show_identifier_mismatch(ctx: wire.Context) -> None:
async def show_identifier_mismatch(ctx: wire.GenericContext) -> None:
await show_warning(
ctx, ("You have entered", "a share from another", "Shamir Backup.")
)
async def show_group_threshold_reached(ctx: wire.Context) -> None:
async def show_group_threshold_reached(ctx: wire.GenericContext) -> None:
await show_warning(
ctx,
(
@ -310,7 +312,7 @@ class RecoveryHomescreen(ui.Component):
async def homescreen_dialog(
ctx: wire.Context,
ctx: wire.GenericContext,
homepage: RecoveryHomescreen,
button_label: str,
info_func: Callable = None,

@ -99,14 +99,35 @@ def clear() -> None:
workflow_namespaces.clear()
if False:
from typing import Protocol
class GenericContext(Protocol):
async def call(
self,
msg: protobuf.MessageType,
expected_type: Type[protobuf.LoadedMessageType],
) -> Any:
...
async def read(self, expected_type: Type[protobuf.LoadedMessageType]) -> Any:
...
async def write(self, msg: protobuf.MessageType) -> None:
...
async def wait(self, *tasks: Awaitable) -> Any:
...
class DummyContext:
async def call(*argv: Any) -> None:
async def call(self, *argv: Any) -> None:
pass
async def read(*argv: Any) -> None:
async def read(self, *argv: Any) -> None:
pass
async def write(*argv: Any) -> None:
async def write(self, *argv: Any) -> None:
pass
async def wait(self, *tasks: Awaitable) -> Any:

Loading…
Cancel
Save