diff --git a/core/embed/rust/librust_qstr.h b/core/embed/rust/librust_qstr.h index 6f0e76e58a..478866ddc7 100644 --- a/core/embed/rust/librust_qstr.h +++ b/core/embed/rust/librust_qstr.h @@ -103,6 +103,7 @@ static void _librust_qstrs(void) { MP_QSTR_bounds; MP_QSTR_button; MP_QSTR_button_event; + MP_QSTR_button_request; MP_QSTR_buttons__abort; MP_QSTR_buttons__access; MP_QSTR_buttons__again; diff --git a/core/embed/rust/src/ui/button_request.rs b/core/embed/rust/src/ui/button_request.rs new file mode 100644 index 0000000000..a9529debfd --- /dev/null +++ b/core/embed/rust/src/ui/button_request.rs @@ -0,0 +1,60 @@ +use crate::strutil::TString; +use num_traits::FromPrimitive; + +// ButtonRequestType from messages-common.proto +// Eventually this should be generated +#[derive(Clone, Copy, FromPrimitive)] +#[repr(u16)] +pub enum ButtonRequestCode { + Other = 1, + FeeOverThreshold = 2, + ConfirmOutput = 3, + ResetDevice = 4, + ConfirmWord = 5, + WipeDevice = 6, + ProtectCall = 7, + SignTx = 8, + FirmwareCheck = 9, + Address = 10, + PublicKey = 11, + MnemonicWordCount = 12, + MnemonicInput = 13, + UnknownDerivationPath = 15, + RecoveryHomepage = 16, + Success = 17, + Warning = 18, + PassphraseEntry = 19, + PinEntry = 20, +} + +impl ButtonRequestCode { + pub fn num(&self) -> u16 { + *self as u16 + } + + pub fn with_type(self, br_type: &'static str) -> ButtonRequest { + ButtonRequest::new(self, br_type.into()) + } + + pub fn from(i: u16) -> Self { + unwrap!(Self::from_u16(i)) + } +} + +const MAX_TYPE_LEN: usize = 32; + +#[derive(Clone)] +pub struct ButtonRequest { + pub code: ButtonRequestCode, + pub br_type: TString<'static>, +} + +impl ButtonRequest { + pub fn new(code: ButtonRequestCode, br_type: TString<'static>) -> Self { + ButtonRequest { code, br_type } + } + + pub fn from_tstring(code: u16, br_type: TString<'static>) -> Self { + ButtonRequest::new(ButtonRequestCode::from(code), br_type) + } +} diff --git a/core/embed/rust/src/ui/component/base.rs b/core/embed/rust/src/ui/component/base.rs index 5e575ad5d6..397fcd6c8b 100644 --- a/core/embed/rust/src/ui/component/base.rs +++ b/core/embed/rust/src/ui/component/base.rs @@ -6,6 +6,7 @@ use crate::{ strutil::TString, time::Duration, ui::{ + button_request::{ButtonRequest, ButtonRequestCode}, component::{maybe::PaintOverlapping, MsgMap}, display::{self, Color}, geometry::{Offset, Rect}, @@ -487,6 +488,7 @@ pub struct EventCtx { paint_requested: bool, anim_frame_scheduled: bool, page_count: Option, + button_request: Option, root_repaint_requested: bool, } @@ -513,6 +515,7 @@ impl EventCtx { * `Child::marked_for_paint` being true. */ anim_frame_scheduled: false, page_count: None, + button_request: None, root_repaint_requested: false, } } @@ -570,6 +573,16 @@ impl EventCtx { self.page_count } + pub fn send_button_request(&mut self, code: ButtonRequestCode, br_type: TString<'static>) { + #[cfg(feature = "ui_debug")] + assert!(self.button_request.is_none()); + self.button_request = Some(ButtonRequest::new(code, br_type)); + } + + pub fn button_request(&mut self) -> Option { + self.button_request.take() + } + pub fn pop_timer(&mut self) -> Option<(TimerToken, Duration)> { self.timers.pop() } @@ -579,6 +592,9 @@ impl EventCtx { self.paint_requested = false; self.anim_frame_scheduled = false; self.page_count = None; + #[cfg(feature = "ui_debug")] + assert!(self.button_request.is_none()); + self.button_request = None; self.root_repaint_requested = false; } diff --git a/core/embed/rust/src/ui/component/button_request.rs b/core/embed/rust/src/ui/component/button_request.rs new file mode 100644 index 0000000000..193bd11dba --- /dev/null +++ b/core/embed/rust/src/ui/component/button_request.rs @@ -0,0 +1,65 @@ +use crate::ui::{ + button_request::ButtonRequest, + component::{Component, Event, EventCtx}, + geometry::Rect, +}; + +/// Component that sends a ButtonRequest after receiving Event::Attach. The +/// request is only sent once. +#[derive(Clone)] +pub struct OneButtonRequest { + button_request: Option, + pub inner: T, +} + +impl OneButtonRequest { + pub fn new(button_request: ButtonRequest, inner: T) -> Self { + Self { + button_request: Some(button_request), + inner, + } + } +} + +impl Component for OneButtonRequest { + type Msg = T::Msg; + + fn place(&mut self, bounds: Rect) -> Rect { + self.inner.place(bounds) + } + + fn event(&mut self, ctx: &mut EventCtx, event: Event) -> Option { + if matches!(event, Event::Attach) { + if let Some(button_request) = self.button_request.take() { + ctx.send_button_request(button_request.code, button_request.br_type) + } + } + self.inner.event(ctx, event) + } + + fn paint(&mut self) { + self.inner.paint() + } + + fn render<'s>(&'s self, target: &mut impl crate::ui::shape::Renderer<'s>) { + self.inner.render(target) + } +} + +#[cfg(feature = "ui_debug")] +impl crate::trace::Trace for OneButtonRequest { + fn trace(&self, t: &mut dyn crate::trace::Tracer) { + self.inner.trace(t) + } +} + +pub trait ButtonRequestExt { + fn one_button_request(self, br: ButtonRequest) -> OneButtonRequest + where + Self: Sized, + { + OneButtonRequest::new(br, self) + } +} + +impl ButtonRequestExt for T {} diff --git a/core/embed/rust/src/ui/component/mod.rs b/core/embed/rust/src/ui/component/mod.rs index bc1e36dbb3..3d891580f2 100644 --- a/core/embed/rust/src/ui/component/mod.rs +++ b/core/embed/rust/src/ui/component/mod.rs @@ -3,6 +3,7 @@ pub mod bar; pub mod base; pub mod border; +pub mod button_request; pub mod connect; pub mod empty; pub mod image; @@ -24,6 +25,7 @@ pub mod timeout; pub use bar::Bar; pub use base::{Child, Component, ComponentExt, Event, EventCtx, Never, Root, TimerToken}; pub use border::Border; +pub use button_request::{ButtonRequestExt, OneButtonRequest}; pub use empty::Empty; #[cfg(all(feature = "jpeg", feature = "micropython"))] pub use jpeg::Jpeg; diff --git a/core/embed/rust/src/ui/layout/obj.rs b/core/embed/rust/src/ui/layout/obj.rs index b7c15eb6f5..01130a3ccc 100644 --- a/core/embed/rust/src/ui/layout/obj.rs +++ b/core/embed/rust/src/ui/layout/obj.rs @@ -17,6 +17,7 @@ use crate::{ }, time::Duration, ui::{ + button_request::ButtonRequest, component::{Component, Event, EventCtx, Never, Root, TimerToken}, constant, display::sync, @@ -248,6 +249,17 @@ impl LayoutObj { self.inner.borrow().page_count.into() } + fn obj_button_request(&self) -> Result { + let inner = &mut *self.inner.borrow_mut(); + + match inner.event_ctx.button_request() { + None => Ok(Obj::const_none()), + Some(ButtonRequest { code, br_type }) => { + (code.num().into(), br_type.try_into()?).try_into() + } + } + } + #[cfg(feature = "ui_debug")] fn obj_bounds(&self) { use crate::ui::display; @@ -280,6 +292,7 @@ impl LayoutObj { Qstr::MP_QSTR_trace => obj_fn_2!(ui_layout_trace).as_obj(), Qstr::MP_QSTR_bounds => obj_fn_1!(ui_layout_bounds).as_obj(), Qstr::MP_QSTR_page_count => obj_fn_1!(ui_layout_page_count).as_obj(), + Qstr::MP_QSTR_button_request => obj_fn_1!(ui_layout_button_request).as_obj(), }), }; &TYPE @@ -470,6 +483,14 @@ extern "C" fn ui_layout_page_count(this: Obj) -> Obj { unsafe { util::try_or_raise(block) } } +extern "C" fn ui_layout_button_request(this: Obj) -> Obj { + let block = || { + let this: Gc = this.try_into()?; + this.obj_button_request() + }; + unsafe { util::try_or_raise(block) } +} + #[cfg(feature = "ui_debug")] #[no_mangle] pub extern "C" fn ui_debug_layout_type() -> &'static Type { diff --git a/core/embed/rust/src/ui/mod.rs b/core/embed/rust/src/ui/mod.rs index 75c14bbf74..9e1518c8a0 100644 --- a/core/embed/rust/src/ui/mod.rs +++ b/core/embed/rust/src/ui/mod.rs @@ -2,6 +2,7 @@ pub mod macros; pub mod animation; +pub mod button_request; pub mod component; pub mod constant; pub mod display; diff --git a/core/embed/rust/src/ui/model_tt/layout.rs b/core/embed/rust/src/ui/model_tt/layout.rs index a981b04cac..48e87a4ef2 100644 --- a/core/embed/rust/src/ui/model_tt/layout.rs +++ b/core/embed/rust/src/ui/model_tt/layout.rs @@ -1652,6 +1652,9 @@ pub static mp_module_trezorui2: Module = obj_module! { /// def page_count(self) -> int: /// """Return the number of pages in the layout object.""" /// + /// def button_request(self) -> tuple[int, str] | None: + /// """Return (code, type) of button request made during the last event or timer pass.""" + /// /// class UiResult: /// """Result of a UI operation.""" /// pass diff --git a/core/mocks/generated/trezorui2.pyi b/core/mocks/generated/trezorui2.pyi index e26546dc92..94e8f6b6f0 100644 --- a/core/mocks/generated/trezorui2.pyi +++ b/core/mocks/generated/trezorui2.pyi @@ -1079,6 +1079,8 @@ class LayoutObj(Generic[T]): """Paint bounds of individual components on screen.""" def page_count(self) -> int: """Return the number of pages in the layout object.""" + def button_request(self) -> tuple[int, str] | None: + """Return (code, type) of button request made during the last event or timer pass.""" # rust/src/ui/model_tt/layout.rs diff --git a/core/src/trezor/ui/layouts/common.py b/core/src/trezor/ui/layouts/common.py index f16dcc16e8..86ce7a47fb 100644 --- a/core/src/trezor/ui/layouts/common.py +++ b/core/src/trezor/ui/layouts/common.py @@ -39,4 +39,4 @@ async def interact( # We know for certain how many pages the layout will have pages = layout.page_count() # type: ignore [Cannot access attribute "page_count" for class "LayoutType"] await button_request(br_type, br_code, pages) - return await context.wait(layout) + return await layout diff --git a/core/src/trezor/ui/layouts/tr/__init__.py b/core/src/trezor/ui/layouts/tr/__init__.py index fa75223173..8101160d52 100644 --- a/core/src/trezor/ui/layouts/tr/__init__.py +++ b/core/src/trezor/ui/layouts/tr/__init__.py @@ -1,10 +1,10 @@ from typing import TYPE_CHECKING import trezorui2 -from trezor import TR, io, loop, ui, utils +from trezor import TR, io, log, loop, ui, utils from trezor.enums import ButtonRequestType -from trezor.wire import ActionCancelled -from trezor.wire.context import wait as ctx_wait +from trezor.messages import ButtonAck, ButtonRequest +from trezor.wire import ActionCancelled, context from ..common import button_request, interact @@ -38,9 +38,11 @@ if __debug__: class RustLayout(LayoutParentType[T]): # pylint: disable=super-init-not-called def __init__(self, layout: trezorui2.LayoutObj[T]): + self.br_chan = loop.chan() self.layout = layout self.timer = loop.Timer() self.layout.attach_timer_fn(self.set_timer) + self._send_button_request() def set_timer(self, token: int, deadline: int) -> None: self.timer.schedule(deadline, token) @@ -62,13 +64,23 @@ class RustLayout(LayoutParentType[T]): from trezor.enums import DebugPhysicalButton def create_tasks(self) -> tuple[loop.AwaitableTask, ...]: - return ( - self.handle_input_and_rendering(), - self.handle_timers(), - self.handle_swipe_signal(), - self.handle_button_signal(), - self.handle_result_signal(), - ) + if context.CURRENT_CONTEXT: + return ( + self.handle_input_and_rendering(), + self.handle_timers(), + self.handle_swipe_signal(), + self.handle_button_signal(), + self.handle_result_signal(), + self.handle_usb(context.get_context()), + ) + else: + return ( + self.handle_input_and_rendering(), + self.handle_timers(), + self.handle_swipe_signal(), + self.handle_button_signal(), + self.handle_result_signal(), + ) async def handle_result_signal(self) -> None: """Enables sending arbitrary input - ui.Result. @@ -98,30 +110,41 @@ class RustLayout(LayoutParentType[T]): async def _press_left(self, hold_ms: int | None) -> Any: """Triggers left button press.""" self.layout.button_event(io.BUTTON_PRESSED, io.BUTTON_LEFT) + self._send_button_request() self._paint() if hold_ms is not None: await loop.sleep(hold_ms) - return self.layout.button_event(io.BUTTON_RELEASED, io.BUTTON_LEFT) + r = self.layout.button_event(io.BUTTON_RELEASED, io.BUTTON_LEFT) + self._send_button_request() + return r async def _press_right(self, hold_ms: int | None) -> Any: """Triggers right button press.""" self.layout.button_event(io.BUTTON_PRESSED, io.BUTTON_RIGHT) + self._send_button_request() self._paint() if hold_ms is not None: await loop.sleep(hold_ms) - return self.layout.button_event(io.BUTTON_RELEASED, io.BUTTON_RIGHT) + r = self.layout.button_event(io.BUTTON_RELEASED, io.BUTTON_RIGHT) + self._send_button_request() + return r async def _press_middle(self, hold_ms: int | None) -> Any: """Triggers middle button press.""" self.layout.button_event(io.BUTTON_PRESSED, io.BUTTON_LEFT) + self._send_button_request() self._paint() self.layout.button_event(io.BUTTON_PRESSED, io.BUTTON_RIGHT) + self._send_button_request() self._paint() if hold_ms is not None: await loop.sleep(hold_ms) self.layout.button_event(io.BUTTON_RELEASED, io.BUTTON_LEFT) + self._send_button_request() self._paint() - return self.layout.button_event(io.BUTTON_RELEASED, io.BUTTON_RIGHT) + r = self.layout.button_event(io.BUTTON_RELEASED, io.BUTTON_RIGHT) + self._send_button_request() + return r async def _press_button( self, @@ -197,7 +220,11 @@ class RustLayout(LayoutParentType[T]): else: def create_tasks(self) -> tuple[loop.AwaitableTask, ...]: - return self.handle_timers(), self.handle_input_and_rendering() + return ( + self.handle_timers(), + self.handle_input_and_rendering(), + self.handle_usb(context.get_context()), + ) def _first_paint(self) -> None: self._paint() @@ -233,6 +260,7 @@ class RustLayout(LayoutParentType[T]): msg = None if event in (io.BUTTON_PRESSED, io.BUTTON_RELEASED): msg = self.layout.button_event(event, button_num) + self._send_button_request() if msg is not None: raise ui.Result(msg) self._paint() @@ -242,6 +270,7 @@ class RustLayout(LayoutParentType[T]): # Using `yield` instead of `await` to avoid allocations. token = yield self.timer msg = self.layout.timer(token) + self._send_button_request() if msg is not None: raise ui.Result(msg) self._paint() @@ -250,6 +279,20 @@ class RustLayout(LayoutParentType[T]): """How many paginated pages current screen has.""" return self.layout.page_count() + async def handle_usb(self, ctx: context.Context): + while True: + br_code, br_type, page_count = await loop.race( + ctx.read(()), self.br_chan.take() + ) + log.debug(__name__, "ButtonRequest.type=%s", br_type) + await ctx.call(ButtonRequest(code=br_code, pages=page_count), ButtonAck) + + def _send_button_request(self): + res = self.layout.button_request() + if res is not None: + br_code, br_type = res + self.br_chan.publish((br_code, br_type, self.layout.page_count())) + def draw_simple(layout: trezorui2.LayoutObj[Any]) -> None: # Simple drawing not supported for layouts that set timers. @@ -513,7 +556,7 @@ async def show_address( pages=layout.page_count(), ) layout.request_complete_repaint() - result = await ctx_wait(layout) + result = await layout # User confirmed with middle button. if result is CONFIRMED: @@ -532,27 +575,23 @@ async def show_address( ) return result - result = await ctx_wait( - RustLayout( - trezorui2.show_address_details( - qr_title="", # unused on this model - address=address if address_qr is None else address_qr, - case_sensitive=case_sensitive, - details_title="", # unused on this model - account=account, - path=path, - xpubs=[(xpub_title(i), xpub) for i, xpub in enumerate(xpubs)], - ) - ), + result = await RustLayout( + trezorui2.show_address_details( + qr_title="", # unused on this model + address=address if address_qr is None else address_qr, + case_sensitive=case_sensitive, + details_title="", # unused on this model + account=account, + path=path, + xpubs=[(xpub_title(i), xpub) for i, xpub in enumerate(xpubs)], + ) ) # Can only go back from the address details. assert result is CANCELLED # User pressed left cancel button, show mismatch dialogue. else: - result = await ctx_wait( - RustLayout(trezorui2.show_mismatch(title=mismatch_title)) - ) + result = await RustLayout(trezorui2.show_mismatch(title=mismatch_title)) assert result in (CONFIRMED, CANCELLED) # Right button aborts action, left goes back to showing address. if result is CONFIRMED: @@ -1029,24 +1068,22 @@ async def confirm_value( should_show_more_layout.page_count(), ) - result = await ctx_wait(should_show_more_layout) + result = await should_show_more_layout if result is CONFIRMED: return elif result is INFO: info_title, info_value = info_items_list[0] - await ctx_wait( - RustLayout( - trezorui2.confirm_blob( - title=info_title, - data=info_value, - description=description, - extra=None, - verb="", - verb_cancel="<", - hold=False, - chunkify=chunkify_info, - ) + await RustLayout( + trezorui2.confirm_blob( + title=info_title, + data=info_value, + description=description, + extra=None, + verb="", + verb_cancel="<", + hold=False, + chunkify=chunkify_info, ) ) else: @@ -1287,7 +1324,7 @@ async def confirm_modify_output( address_layout.page_count(), ) address_layout.request_complete_repaint() - await raise_if_not_confirmed(ctx_wait(address_layout)) + await raise_if_not_confirmed(address_layout) if send_button_request: send_button_request = False @@ -1297,7 +1334,7 @@ async def confirm_modify_output( modify_layout.page_count(), ) modify_layout.request_complete_repaint() - result = await ctx_wait(modify_layout) + result = await modify_layout if result is CONFIRMED: break diff --git a/core/src/trezor/ui/layouts/tt/__init__.py b/core/src/trezor/ui/layouts/tt/__init__.py index 3c8f506130..aa954e13c9 100644 --- a/core/src/trezor/ui/layouts/tt/__init__.py +++ b/core/src/trezor/ui/layouts/tt/__init__.py @@ -1,10 +1,10 @@ from typing import TYPE_CHECKING import trezorui2 -from trezor import TR, io, loop, ui, utils +from trezor import TR, io, log, loop, ui, utils from trezor.enums import ButtonRequestType -from trezor.wire import ActionCancelled -from trezor.wire.context import wait as ctx_wait +from trezor.messages import ButtonAck, ButtonRequest +from trezor.wire import ActionCancelled, context from ..common import button_request, interact @@ -40,9 +40,11 @@ class RustLayout(LayoutParentType[T]): # pylint: disable=super-init-not-called def __init__(self, layout: trezorui2.LayoutObj[T]): + self.br_chan = loop.chan() self.layout = layout self.timer = loop.Timer() self.layout.attach_timer_fn(self.set_timer) + self._send_button_request() def set_timer(self, token: int, deadline: int) -> None: self.timer.schedule(deadline, token) @@ -63,13 +65,23 @@ class RustLayout(LayoutParentType[T]): if __debug__: def create_tasks(self) -> tuple[loop.AwaitableTask, ...]: - return ( - self.handle_timers(), - self.handle_input_and_rendering(), - self.handle_swipe(), - self.handle_click_signal(), - self.handle_result_signal(), - ) + if context.CURRENT_CONTEXT: + return ( + self.handle_timers(), + self.handle_input_and_rendering(), + self.handle_swipe(), + self.handle_click_signal(), + self.handle_result_signal(), + self.handle_usb(context.get_context()), + ) + else: + return ( + self.handle_timers(), + self.handle_input_and_rendering(), + self.handle_swipe(), + self.handle_click_signal(), + self.handle_result_signal(), + ) async def handle_result_signal(self) -> None: """Enables sending arbitrary input - ui.Result. @@ -116,6 +128,7 @@ class RustLayout(LayoutParentType[T]): (io.TOUCH_END, orig_x + 2 * off_x, orig_y + 2 * off_y), ): msg = self.layout.touch_event(event, x, y) + self._send_button_request() self._paint() if msg is not None: raise ui.Result(msg) @@ -135,10 +148,12 @@ class RustLayout(LayoutParentType[T]): from apps.debug import notify_layout_change self.layout.touch_event(io.TOUCH_START, x, y) + self._send_button_request() self._paint() if hold_ms is not None: await loop.sleep(hold_ms) msg = self.layout.touch_event(io.TOUCH_END, x, y) + self._send_button_request() if msg is not None: debug_storage.new_layout_event_id = event_id @@ -165,7 +180,17 @@ class RustLayout(LayoutParentType[T]): else: def create_tasks(self) -> tuple[loop.AwaitableTask, ...]: - return self.handle_timers(), self.handle_input_and_rendering() + if context.CURRENT_CONTEXT: + return ( + self.handle_timers(), + self.handle_input_and_rendering(), + self.handle_usb(context.get_context()), + ) + else: + return ( + self.handle_timers(), + self.handle_input_and_rendering(), + ) def _first_paint(self) -> None: ui.backlight_fade(ui.style.BACKLIGHT_NONE) @@ -205,6 +230,7 @@ class RustLayout(LayoutParentType[T]): msg = None if event in (io.TOUCH_START, io.TOUCH_MOVE, io.TOUCH_END): msg = self.layout.touch_event(event, x, y) + self._send_button_request() if msg is not None: raise ui.Result(msg) self._paint() @@ -214,6 +240,7 @@ class RustLayout(LayoutParentType[T]): # Using `yield` instead of `await` to avoid allocations. token = yield self.timer msg = self.layout.timer(token) + self._send_button_request() if msg is not None: raise ui.Result(msg) self._paint() @@ -221,6 +248,20 @@ class RustLayout(LayoutParentType[T]): def page_count(self) -> int: return self.layout.page_count() + async def handle_usb(self, ctx: context.Context): + while True: + br_code, br_type, page_count = await loop.race( + ctx.read(()), self.br_chan.take() + ) + log.debug(__name__, "ButtonRequest.type=%s", br_type) + await ctx.call(ButtonRequest(code=br_code, pages=page_count), ButtonAck) + + def _send_button_request(self): + res = self.layout.button_request() + if res is not None: + br_code, br_type = res + self.br_chan.publish((br_code, br_type, self.layout.page_count())) + def draw_simple(layout: trezorui2.LayoutObj[Any]) -> None: # Simple drawing not supported for layouts that set timers. @@ -460,7 +501,7 @@ async def show_address( pages=layout.page_count(), ) layout.request_complete_repaint() - result = await ctx_wait(layout) + result = await layout # User pressed right button. if result is CONFIRMED: @@ -478,25 +519,21 @@ async def show_address( ) return result - result = await ctx_wait( - RustLayout( - trezorui2.show_address_details( - qr_title=title, - address=address if address_qr is None else address_qr, - case_sensitive=case_sensitive, - details_title=details_title, - account=account, - path=path, - xpubs=[(xpub_title(i), xpub) for i, xpub in enumerate(xpubs)], - ) + result = await RustLayout( + trezorui2.show_address_details( + qr_title=title, + address=address if address_qr is None else address_qr, + case_sensitive=case_sensitive, + details_title=details_title, + account=account, + path=path, + xpubs=[(xpub_title(i), xpub) for i, xpub in enumerate(xpubs)], ) ) assert result is CANCELLED else: - result = await ctx_wait( - RustLayout(trezorui2.show_mismatch(title=mismatch_title)) - ) + result = await RustLayout(trezorui2.show_mismatch(title=mismatch_title)) assert result in (CONFIRMED, CANCELLED) # Right button aborts action, left goes back to showing address. if result is CONFIRMED: @@ -1198,7 +1235,7 @@ async def confirm_modify_output( address_layout.page_count(), ) address_layout.request_complete_repaint() - await raise_if_not_confirmed(ctx_wait(address_layout)) + await raise_if_not_confirmed(address_layout) if send_button_request: send_button_request = False @@ -1208,7 +1245,7 @@ async def confirm_modify_output( modify_layout.page_count(), ) modify_layout.request_complete_repaint() - result = await ctx_wait(modify_layout) + result = await modify_layout if result is CONFIRMED: break @@ -1223,11 +1260,11 @@ async def with_info( await button_request(br_type, br_code, pages=main_layout.page_count()) while True: - result = await ctx_wait(main_layout) + result = await main_layout if result is INFO: info_layout.request_complete_repaint() - result = await ctx_wait(info_layout) + result = await info_layout assert result is CANCELLED main_layout.request_complete_repaint() continue @@ -1355,8 +1392,8 @@ async def confirm_signverify( address_layout, info_layout, br_type, br_code=BR_TYPE_OTHER ) if result is not CONFIRMED: - result = await ctx_wait( - RustLayout(trezorui2.show_mismatch(title=TR.addr_mismatch__mismatch)) + result = await RustLayout( + trezorui2.show_mismatch(title=TR.addr_mismatch__mismatch) ) assert result in (CONFIRMED, CANCELLED) # Right button aborts action, left goes back to showing address. diff --git a/core/src/trezor/ui/layouts/tt/recovery.py b/core/src/trezor/ui/layouts/tt/recovery.py index a3c99e4fc8..82310480a9 100644 --- a/core/src/trezor/ui/layouts/tt/recovery.py +++ b/core/src/trezor/ui/layouts/tt/recovery.py @@ -3,7 +3,6 @@ from typing import Callable, Iterable import trezorui2 from trezor import TR from trezor.enums import ButtonRequestType -from trezor.wire.context import wait as ctx_wait from ..common import interact from . import RustLayout, raise_if_not_confirmed @@ -17,7 +16,7 @@ async def _is_confirmed_info( info_func: Callable, ) -> bool: while True: - result = await ctx_wait(dialog) + result = await dialog if result is trezorui2.INFO: await info_func() @@ -50,7 +49,7 @@ async def request_word( ) ) - word: str = await ctx_wait(keyboard) + word: str = await keyboard return word @@ -143,7 +142,7 @@ async def continue_recovery( if info_func is not None: return await _is_confirmed_info(homepage, info_func) else: - result = await ctx_wait(homepage) + result = await homepage return result is CONFIRMED diff --git a/core/src/trezor/ui/layouts/tt/reset.py b/core/src/trezor/ui/layouts/tt/reset.py index c5a08d10f6..1975770637 100644 --- a/core/src/trezor/ui/layouts/tt/reset.py +++ b/core/src/trezor/ui/layouts/tt/reset.py @@ -4,7 +4,6 @@ import trezorui2 from trezor import TR from trezor.enums import ButtonRequestType from trezor.wire import ActionCancelled -from trezor.wire.context import wait as ctx_wait from ..common import interact from . import RustLayout, raise_if_not_confirmed @@ -95,15 +94,13 @@ async def select_word( while len(words) < 3: words.append(words[-1]) - result = await ctx_wait( - RustLayout( - trezorui2.select_word( - title=title, - description=TR.reset__select_word_x_of_y_template.format( - checked_index + 1, count - ), - words=(words[0], words[1], words[2]), - ) + result = await RustLayout( + trezorui2.select_word( + title=title, + description=TR.reset__select_word_x_of_y_template.format( + checked_index + 1, count + ), + words=(words[0], words[1], words[2]), ) ) if __debug__ and isinstance(result, str): @@ -183,13 +180,11 @@ async def _prompt_number( assert isinstance(value, int) return value - await ctx_wait( - RustLayout( - trezorui2.show_simple( - title=None, - description=info(value), - button=TR.buttons__ok_i_understand, - ) + await RustLayout( + trezorui2.show_simple( + title=None, + description=info(value), + button=TR.buttons__ok_i_understand, ) ) num_input.request_complete_repaint() diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index 08eaab3474..b2f9afa5e4 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -157,6 +157,17 @@ class Context: memoryview(buffer)[:msg_size], ) + async def call( + self, + msg: protobuf.MessageType, + expected_type: type[LoadedMessageType], + ) -> LoadedMessageType: + assert expected_type.MESSAGE_WIRE_TYPE is not None + + await self.write(msg) + del msg + return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type) + CURRENT_CONTEXT: Context | None = None @@ -186,11 +197,7 @@ async def call( if CURRENT_CONTEXT is None: raise RuntimeError("No wire context") - assert expected_type.MESSAGE_WIRE_TYPE is not None - - await CURRENT_CONTEXT.write(msg) - del msg - return await CURRENT_CONTEXT.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type) + return await CURRENT_CONTEXT.call(msg, expected_type) async def call_any( diff --git a/core/tests/test_apps.ethereum.sign_typed_data.py b/core/tests/test_apps.ethereum.sign_typed_data.py index bf6f3772c9..e8a261886a 100644 --- a/core/tests/test_apps.ethereum.sign_typed_data.py +++ b/core/tests/test_apps.ethereum.sign_typed_data.py @@ -41,6 +41,17 @@ class MockContext: async def read(self, _resp_types, _resp_type): return EthereumTypedDataValueAck(value=self.next_response) + async def call( + self, + msg: protobuf.MessageType, + expected_type: type[LoadedMessageType], + ) -> LoadedMessageType: + assert expected_type.MESSAGE_WIRE_TYPE is not None + + await self.write(msg) + del msg + return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type) + # Helper functions from trezorctl to build expected type data structures # TODO: it could be better to group these functions into a class, to visibly differentiate it