diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index b29cf0221..bcd32c298 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -1055,7 +1055,7 @@ class TrezorClientDebugLink(TrezorClient): return msg def set_input_flow( - self, input_flow: Generator[None, Optional[messages.ButtonRequest], None] + self, input_flow: Generator[None, Optional[messages.ButtonRequest], None] | None ) -> None: """Configure a sequence of input events for the current with-block. @@ -1083,6 +1083,10 @@ class TrezorClientDebugLink(TrezorClient): """ if not self.in_with_statement: raise RuntimeError("Must be called inside 'with' statement") + + if input_flow is None: + self.ui.input_flow = None + return if callable(input_flow): input_flow = input_flow() diff --git a/tests/input_flows.py b/tests/input_flows.py index 9bd1f97be..2f48c47a0 100644 --- a/tests/input_flows.py +++ b/tests/input_flows.py @@ -45,35 +45,27 @@ class InputFlowBase: self.BAK = BackupFlow(self.client) self.ETH = EthereumFlow(self.client) - def model(self) -> str | models.TrezorModel: + @property + def model(self) -> models.TrezorModel: return self.client.model - def get(self) -> Callable[[], BRGeneratorType]: + def get(self) -> Callable[[], BRGeneratorType] | None: self.client.watch_layout(True) # There could be one common input flow for all models if hasattr(self, "input_flow_common"): return getattr(self, "input_flow_common") - elif self.model() is models.T2T1: - return self.input_flow_tt - elif self.model() is models.T2B1: - return self.input_flow_tr - elif self.model() is models.T3T1: - return self.input_flow_t3t1 - else: + + models_dict = { + models.T2T1: "tt", + models.T2B1: "tr", + models.T3T1: "t3t1", + } + model_str = models_dict.get(self.model) + if model_str is None: raise ValueError("Unknown model") - def input_flow_tt(self) -> BRGeneratorType: - """Special for TT""" - raise NotImplementedError - - def input_flow_tr(self) -> BRGeneratorType: - """Special for TR""" - raise NotImplementedError - - def input_flow_t3t1(self) -> BRGeneratorType: - """Special for T3T1""" - raise NotImplementedError + return getattr(self, f"input_flow_{model_str}", None) def text_content(self) -> str: return self.debug.read_layout().text_content() @@ -98,7 +90,7 @@ class InputFlowSetupDevicePINWIpeCode(InputFlowBase): yield # do you want to set/change the wipe code? self.debug.press_yes() - if self.model() is models.T2B1: + if self.model is models.T2B1: layout = self.debug.read_layout() if "PinKeyboard" not in layout.all_components(): yield from swipe_if_necessary(self.debug) # wipe code info @@ -129,7 +121,7 @@ class InputFlowNewCodeMismatch(InputFlowBase): yield # do you want to set/change the pin/wipe code? self.debug.press_yes() - if self.model() is models.T2B1: + if self.model is models.T2B1: layout = self.debug.read_layout() if "PinKeyboard" not in layout.all_components(): yield from swipe_if_necessary(self.debug) # code info @@ -845,43 +837,6 @@ def sign_tx_go_to_info_t3t1( return content -def sign_tx_go_to_info_t3t1( - client: Client, multi_account: bool = False -) -> Generator[None, None, str]: - yield # confirm output - client.debug.read_layout() - client.debug.swipe_up() - yield # confirm output - client.debug.read_layout() - client.debug.swipe_up() - - if multi_account: - yield - client.debug.read_layout() - client.debug.swipe_up() - - yield # confirm transaction - client.debug.read_layout() - client.debug.click(buttons.CORNER_BUTTON) - client.debug.synchronize_at("VerticalMenu") - client.debug.click(buttons.VERTICAL_MENU[0]) - - layout = client.debug.read_layout() - content = layout.text_content() - - client.debug.click(buttons.CORNER_BUTTON) - client.debug.synchronize_at("VerticalMenu") - client.debug.click(buttons.VERTICAL_MENU[1]) - - layout = client.debug.read_layout() - content += " " + layout.text_content() - - client.debug.click(buttons.CORNER_BUTTON) - client.debug.click(buttons.CORNER_BUTTON) - - return content - - def sign_tx_go_to_info_tr( client: Client, ) -> Generator[None, None, str]: @@ -1146,9 +1101,9 @@ class InputFlowEIP712ShowMore(InputFlowBase): def _confirm_show_more(self) -> None: """Model-specific, either clicks a screen or presses a button.""" - if self.model() in (models.T2T1, models.T3T1): + if self.model in (models.T2T1, models.T3T1): self.debug.click(self.SHOW_MORE) - elif self.model() is models.T2B1: + elif self.model is models.T2B1: self.debug.press_right() def input_flow_common(self) -> BRGeneratorType: @@ -2006,7 +1961,7 @@ class InputFlowSlip39AdvancedRecoveryAbort(InputFlowBase): def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() - if self.model() in (models.T2T1, models.T3T1): + if self.model in (models.T2T1, models.T3T1): yield from self.REC.input_number_of_words(20) yield from self.REC.abort_recovery(True) @@ -2019,7 +1974,7 @@ class InputFlowSlip39AdvancedRecoveryNoAbort(InputFlowBase): def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() - if self.model() in (models.T2T1, models.T3T1): + if self.model in (models.T2T1, models.T3T1): yield from self.REC.input_number_of_words(self.word_count) yield from self.REC.abort_recovery(False) else: @@ -2128,7 +2083,7 @@ class InputFlowSlip39BasicRecoveryAbort(InputFlowBase): def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() - if self.model() in (models.T2T1, models.T3T1): + if self.model in (models.T2T1, models.T3T1): yield from self.REC.input_number_of_words(20) yield from self.REC.abort_recovery(True) @@ -2161,7 +2116,7 @@ class InputFlowSlip39BasicRecoveryNoAbort(InputFlowBase): def input_flow_common(self) -> BRGeneratorType: yield from self.REC.confirm_recovery() - if self.model() in (models.T2T1, models.T3T1): + if self.model in (models.T2T1, models.T3T1): yield from self.REC.input_number_of_words(self.word_count) yield from self.REC.abort_recovery(False) else: @@ -2295,17 +2250,6 @@ class InputFlowConfirmAllWarnings(InputFlowBase): def __init__(self, client: Client): super().__init__(client) - def input_flow_tt(self) -> BRGeneratorType: - br = yield - while True: - # wait for homescreen to go away - self.debug.read_layout() - self.client.ui._default_input_flow(br) - br = yield - - def input_flow_tr(self) -> BRGeneratorType: - return self.input_flow_tt() - def input_flow_t3t1(self) -> BRGeneratorType: br = yield while True: