1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-22 15:38:11 +00:00

mypy: use GenericContext protocol to work-around DummyContext

This commit is contained in:
matejcik 2019-10-21 15:52:02 +02:00 committed by Tomas Susanka
parent ed190c772c
commit d17f879d97
5 changed files with 56 additions and 31 deletions

View File

@ -17,7 +17,7 @@ if False:
async def confirm( async def confirm(
ctx: wire.Context, ctx: wire.GenericContext,
content: ui.Component, content: ui.Component,
code: EnumTypeButtonRequestType = ButtonRequestType.Other, code: EnumTypeButtonRequestType = ButtonRequestType.Other,
confirm: Optional[ButtonContent] = Confirm.DEFAULT_CONFIRM, confirm: Optional[ButtonContent] = Confirm.DEFAULT_CONFIRM,
@ -54,7 +54,7 @@ async def confirm(
async def info_confirm( async def info_confirm(
ctx: wire.Context, ctx: wire.GenericContext,
content: ui.Component, content: ui.Component,
info_func: Callable, info_func: Callable,
code: EnumTypeButtonRequestType = ButtonRequestType.Other, code: EnumTypeButtonRequestType = ButtonRequestType.Other,

View File

@ -78,7 +78,7 @@ def address_n_to_str(address_n: list) -> str:
async def show_warning( async def show_warning(
ctx: wire.Context, ctx: wire.GenericContext,
content: Iterable[str], content: Iterable[str],
subheader: Iterable[str] = [], subheader: Iterable[str] = [],
button: str = "Try again", button: str = "Try again",
@ -96,7 +96,7 @@ async def show_warning(
async def show_success( async def show_success(
ctx: wire.Context, ctx: wire.GenericContext,
content: Iterable[str] = [], content: Iterable[str] = [],
subheader: Iterable[str] = [], subheader: Iterable[str] = [],
button: str = "Continue", button: str = "Continue",

View File

@ -18,13 +18,13 @@ from apps.management import backup_types
from apps.management.recovery_device import layout from apps.management.recovery_device import layout
if False: if False:
from typing import Optional, Tuple, cast from typing import Optional, Tuple
from trezor.messages.ResetDevice import EnumTypeBackupType from trezor.messages.ResetDevice import EnumTypeBackupType
async def recovery_homescreen() -> None: async def recovery_homescreen() -> None:
# recovery process does not communicate on the wire # recovery process does not communicate on the wire
ctx = cast(wire.Context, wire.DummyContext()) # TODO ctx = wire.DummyContext()
try: try:
await recovery_process(ctx) await recovery_process(ctx)
finally: finally:
@ -34,7 +34,7 @@ async def recovery_homescreen() -> None:
wire.clear() wire.clear()
async def recovery_process(ctx: wire.Context) -> Success: async def recovery_process(ctx: wire.GenericContext) -> Success:
try: try:
result = await _continue_recovery_process(ctx) result = await _continue_recovery_process(ctx)
except recover.RecoveryAborted: except recover.RecoveryAborted:
@ -47,7 +47,7 @@ async def recovery_process(ctx: wire.Context) -> Success:
return result return result
async def _continue_recovery_process(ctx: wire.Context) -> Success: async def _continue_recovery_process(ctx: wire.GenericContext) -> Success:
# gather the current recovery state from storage # gather the current recovery state from storage
dry_run = storage_recovery.is_dry_run() dry_run = storage_recovery.is_dry_run()
word_count, backup_type = recover.load_slip39_state() word_count, backup_type = recover.load_slip39_state()
@ -98,7 +98,7 @@ async def _continue_recovery_process(ctx: wire.Context) -> Success:
async def _finish_recovery_dry_run( async def _finish_recovery_dry_run(
ctx: wire.Context, secret: bytes, backup_type: EnumTypeBackupType ctx: wire.GenericContext, secret: bytes, backup_type: EnumTypeBackupType
) -> Success: ) -> Success:
if backup_type is None: if backup_type is None:
raise RuntimeError raise RuntimeError
@ -131,7 +131,7 @@ async def _finish_recovery_dry_run(
async def _finish_recovery( async def _finish_recovery(
ctx: wire.Context, secret: bytes, backup_type: EnumTypeBackupType ctx: wire.GenericContext, secret: bytes, backup_type: EnumTypeBackupType
) -> Success: ) -> Success:
if backup_type is None: if backup_type is None:
raise RuntimeError raise RuntimeError
@ -154,7 +154,7 @@ async def _finish_recovery(
return Success(message="Device recovered") return Success(message="Device recovered")
async def _request_word_count(ctx: wire.Context, dry_run: bool) -> int: async def _request_word_count(ctx: wire.GenericContext, dry_run: bool) -> int:
homepage = layout.RecoveryHomescreen("Select number of words") homepage = layout.RecoveryHomescreen("Select number of words")
await layout.homescreen_dialog(ctx, homepage, "Select") await layout.homescreen_dialog(ctx, homepage, "Select")
@ -163,7 +163,7 @@ async def _request_word_count(ctx: wire.Context, dry_run: bool) -> int:
async def _process_words( async def _process_words(
ctx: wire.Context, words: str ctx: wire.GenericContext, words: str
) -> Tuple[Optional[bytes], EnumTypeBackupType]: ) -> Tuple[Optional[bytes], EnumTypeBackupType]:
word_count = len(words.split(" ")) word_count = len(words.split(" "))
is_slip39 = backup_types.is_slip39_word_count(word_count) is_slip39 = backup_types.is_slip39_word_count(word_count)
@ -184,7 +184,9 @@ async def _process_words(
return secret, backup_type return secret, backup_type
async def _request_share_first_screen(ctx: wire.Context, word_count: int) -> None: async def _request_share_first_screen(
ctx: wire.GenericContext, word_count: int
) -> None:
if backup_types.is_slip39_word_count(word_count): if backup_types.is_slip39_word_count(word_count):
remaining = storage_recovery.fetch_slip39_remaining_shares() remaining = storage_recovery.fetch_slip39_remaining_shares()
if remaining: if remaining:
@ -201,7 +203,7 @@ async def _request_share_first_screen(ctx: wire.Context, word_count: int) -> Non
await layout.homescreen_dialog(ctx, content, "Enter seed") await layout.homescreen_dialog(ctx, content, "Enter seed")
async def _request_share_next_screen(ctx: wire.Context) -> None: async def _request_share_next_screen(ctx: wire.GenericContext) -> None:
remaining = storage_recovery.fetch_slip39_remaining_shares() remaining = storage_recovery.fetch_slip39_remaining_shares()
group_count = storage_recovery.get_slip39_group_count() group_count = storage_recovery.get_slip39_group_count()
if not remaining: if not remaining:
@ -222,7 +224,7 @@ async def _request_share_next_screen(ctx: wire.Context) -> None:
await layout.homescreen_dialog(ctx, content, "Enter share") await layout.homescreen_dialog(ctx, content, "Enter share")
async def _show_remaining_groups_and_shares(ctx: wire.Context) -> None: async def _show_remaining_groups_and_shares(ctx: wire.GenericContext) -> None:
""" """
Show info dialog for Slip39 Advanced - what shares are to be entered. Show info dialog for Slip39 Advanced - what shares are to be entered.
""" """

View File

@ -25,7 +25,7 @@ if False:
from trezor.messages.ResetDevice import EnumTypeBackupType from trezor.messages.ResetDevice import EnumTypeBackupType
async def confirm_abort(ctx: wire.Context, dry_run: bool = False) -> bool: async def confirm_abort(ctx: wire.GenericContext, dry_run: bool = False) -> bool:
if dry_run: if dry_run:
text = Text("Abort seed check", ui.ICON_WIPE) text = Text("Abort seed check", ui.ICON_WIPE)
text.normal("Do you really want to", "abort the seed check?") text.normal("Do you really want to", "abort the seed check?")
@ -36,7 +36,7 @@ async def confirm_abort(ctx: wire.Context, dry_run: bool = False) -> bool:
return await confirm(ctx, text, code=ButtonRequestType.ProtectCall) return await confirm(ctx, text, code=ButtonRequestType.ProtectCall)
async def request_word_count(ctx: wire.Context, dry_run: bool) -> int: async def request_word_count(ctx: wire.GenericContext, dry_run: bool) -> int:
await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicWordCount), ButtonAck) await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicWordCount), ButtonAck)
if dry_run: if dry_run:
@ -55,7 +55,7 @@ async def request_word_count(ctx: wire.Context, dry_run: bool) -> int:
async def request_mnemonic( async def request_mnemonic(
ctx: wire.Context, word_count: int, backup_type: Optional[EnumTypeBackupType] ctx: wire.GenericContext, word_count: int, backup_type: Optional[EnumTypeBackupType]
) -> Optional[str]: ) -> Optional[str]:
await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicInput), ButtonAck) await ctx.call(ButtonRequest(code=ButtonRequestType.MnemonicInput), ButtonAck)
@ -81,7 +81,7 @@ async def request_mnemonic(
async def check_word_validity( async def check_word_validity(
ctx: wire.Context, ctx: wire.GenericContext,
current_index: int, current_index: int,
current_word: str, current_word: str,
backup_type: Optional[EnumTypeBackupType], backup_type: Optional[EnumTypeBackupType],
@ -155,7 +155,7 @@ async def check_word_validity(
async def show_remaining_shares( async def show_remaining_shares(
ctx: wire.Context, ctx: wire.GenericContext,
groups: Iterable[Tuple[int, Tuple[str, ...]]], # remaining + list 3 words groups: Iterable[Tuple[int, Tuple[str, ...]]], # remaining + list 3 words
shares_remaining: List[int], shares_remaining: List[int],
group_threshold: int, group_threshold: int,
@ -187,7 +187,7 @@ async def show_remaining_shares(
async def show_group_share_success( async def show_group_share_success(
ctx: wire.Context, share_index: int, group_index: int ctx: wire.GenericContext, share_index: int, group_index: int
) -> None: ) -> None:
text = Text("Success", ui.ICON_CONFIRM) text = Text("Success", ui.ICON_CONFIRM)
text.bold("You have entered") text.bold("You have entered")
@ -198,7 +198,9 @@ async def show_group_share_success(
await confirm(ctx, text, confirm="Continue", cancel=None) await confirm(ctx, text, confirm="Continue", cancel=None)
async def show_dry_run_result(ctx: wire.Context, result: bool, is_slip39: bool) -> None: async def show_dry_run_result(
ctx: wire.GenericContext, result: bool, is_slip39: bool
) -> None:
if result: if result:
if is_slip39: if is_slip39:
text = ( text = (
@ -233,7 +235,7 @@ async def show_dry_run_result(ctx: wire.Context, result: bool, is_slip39: bool)
await show_warning(ctx, text, button="Continue") await show_warning(ctx, text, button="Continue")
async def show_dry_run_different_type(ctx: wire.Context) -> None: async def show_dry_run_different_type(ctx: wire.GenericContext) -> None:
text = Text("Dry run failure", ui.ICON_CANCEL) text = Text("Dry run failure", ui.ICON_CANCEL)
text.normal("Seed in the device was") text.normal("Seed in the device was")
text.normal("created using another") text.normal("created using another")
@ -243,26 +245,26 @@ async def show_dry_run_different_type(ctx: wire.Context) -> None:
) )
async def show_invalid_mnemonic(ctx: wire.Context, word_count: int) -> None: async def show_invalid_mnemonic(ctx: wire.GenericContext, word_count: int) -> None:
if backup_types.is_slip39_word_count(word_count): if backup_types.is_slip39_word_count(word_count):
await show_warning(ctx, ("You have entered", "an invalid recovery", "share.")) await show_warning(ctx, ("You have entered", "an invalid recovery", "share."))
else: else:
await show_warning(ctx, ("You have entered", "an invalid recovery", "seed.")) await show_warning(ctx, ("You have entered", "an invalid recovery", "seed."))
async def show_share_already_added(ctx: wire.Context) -> None: async def show_share_already_added(ctx: wire.GenericContext) -> None:
await show_warning( await show_warning(
ctx, ("Share already entered,", "please enter", "a different share.") ctx, ("Share already entered,", "please enter", "a different share.")
) )
async def show_identifier_mismatch(ctx: wire.Context) -> None: async def show_identifier_mismatch(ctx: wire.GenericContext) -> None:
await show_warning( await show_warning(
ctx, ("You have entered", "a share from another", "Shamir Backup.") ctx, ("You have entered", "a share from another", "Shamir Backup.")
) )
async def show_group_threshold_reached(ctx: wire.Context) -> None: async def show_group_threshold_reached(ctx: wire.GenericContext) -> None:
await show_warning( await show_warning(
ctx, ctx,
( (
@ -310,7 +312,7 @@ class RecoveryHomescreen(ui.Component):
async def homescreen_dialog( async def homescreen_dialog(
ctx: wire.Context, ctx: wire.GenericContext,
homepage: RecoveryHomescreen, homepage: RecoveryHomescreen,
button_label: str, button_label: str,
info_func: Callable = None, info_func: Callable = None,

View File

@ -99,14 +99,35 @@ def clear() -> None:
workflow_namespaces.clear() workflow_namespaces.clear()
if False:
from typing import Protocol
class GenericContext(Protocol):
async def call(
self,
msg: protobuf.MessageType,
expected_type: Type[protobuf.LoadedMessageType],
) -> Any:
...
async def read(self, expected_type: Type[protobuf.LoadedMessageType]) -> Any:
...
async def write(self, msg: protobuf.MessageType) -> None:
...
async def wait(self, *tasks: Awaitable) -> Any:
...
class DummyContext: class DummyContext:
async def call(*argv: Any) -> None: async def call(self, *argv: Any) -> None:
pass pass
async def read(*argv: Any) -> None: async def read(self, *argv: Any) -> None:
pass pass
async def write(*argv: Any) -> None: async def write(self, *argv: Any) -> None:
pass pass
async def wait(self, *tasks: Awaitable) -> Any: async def wait(self, *tasks: Awaitable) -> Any: