From 7c65f0357a4be5681bd39af039b8271918d0e0d9 Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Tue, 22 Jun 2021 19:47:03 +0200 Subject: [PATCH] feat(core/rust): Implement exception catching in Rust chore(core): Add test for Rust exc catching chore(core): Document exception catching in Rust [no changelog] --- core/embed/extmod/trezorobj.c | 12 ++++ core/embed/extmod/trezorobj.h | 2 + core/embed/rust/build.rs | 1 + core/embed/rust/src/micropython/runtime.rs | 79 +++++++++++++++++++++- 4 files changed, 93 insertions(+), 1 deletion(-) diff --git a/core/embed/extmod/trezorobj.c b/core/embed/extmod/trezorobj.c index 518d44db2..29575266b 100644 --- a/core/embed/extmod/trezorobj.c +++ b/core/embed/extmod/trezorobj.c @@ -21,6 +21,7 @@ #include "memzero.h" #include "py/objint.h" +#include "py/runtime.h" static bool mpz_as_ll_checked(const mpz_t *i, long long *value) { // Analogue of `mpz_as_int_checked` from mpz.c @@ -62,3 +63,14 @@ bool trezor_obj_get_ll_checked(mp_obj_t obj, long long *value) { return false; } } + +mp_obj_t trezor_obj_call_protected(void (*func)(void *), void *arg) { + nlr_buf_t nlr; + if (nlr_push(&nlr) == 0) { + (*func)(arg); + nlr_pop(); + return mp_const_none; + } else { + return MP_OBJ_FROM_PTR(nlr.ret_val); + } +} diff --git a/core/embed/extmod/trezorobj.h b/core/embed/extmod/trezorobj.h index 98a53703b..2fa43f5e4 100644 --- a/core/embed/extmod/trezorobj.h +++ b/core/embed/extmod/trezorobj.h @@ -77,4 +77,6 @@ static inline uint8_t trezor_obj_get_uint8(mp_obj_t obj) { bool trezor_obj_get_ll_checked(mp_obj_t obj, long long *value); +mp_obj_t trezor_obj_call_protected(void (*func)(void *), void *arg); + #endif diff --git a/core/embed/rust/build.rs b/core/embed/rust/build.rs index b9ec24932..a84544738 100644 --- a/core/embed/rust/build.rs +++ b/core/embed/rust/build.rs @@ -95,6 +95,7 @@ fn generate_micropython_bindings() { .allowlist_function("mp_map_lookup") // runtime .allowlist_function("mp_raise_ValueError") + .allowlist_function("trezor_obj_call_protected") // typ .allowlist_var("mp_type_type"); diff --git a/core/embed/rust/src/micropython/runtime.rs b/core/embed/rust/src/micropython/runtime.rs index c39291348..7c7b45121 100644 --- a/core/embed/rust/src/micropython/runtime.rs +++ b/core/embed/rust/src/micropython/runtime.rs @@ -1,6 +1,8 @@ +use core::mem::MaybeUninit; + use cstr_core::CStr; -use super::ffi; +use super::{ffi, obj::Obj}; pub fn raise_value_error(msg: &'static CStr) -> ! { unsafe { @@ -8,3 +10,78 @@ pub fn raise_value_error(msg: &'static CStr) -> ! { } panic!(); } + +/// Execute `func` while catching MicroPython exceptions. Returns `Ok` in the +/// successful case, and `Err` with the caught `Obj` in case of a raise. +pub fn except(mut func: F) -> Result +where + F: FnMut() -> T, +{ + // Because MicroPython exceptions use `setjmp` and `longjmp`-like mechanism that + // doesn't play too well with Rust, we setup the non-local return pads in C, and + // execute `func` through a callback. + + unsafe { + // First, we craft a wrapping closure that calls `func`. Because we are generic + // over the return type, we cannot pass the returned value over the FFI + // boundary, so we assign it explicitly in `wrapper`. + let mut result = MaybeUninit::zeroed(); + let mut wrapper = || { + result = MaybeUninit::new(func()); + }; + // `wrapper` is a closure, and to pass it over the FFI, we split it into a function + // pointer, and a user-data pointer. `ffi::trezor_obj_call_protected` then calls + // the `callback` with the `argument`. + let (callback, argument) = split_func_into_callback_and_argument(&mut wrapper); + let exception = ffi::trezor_obj_call_protected(Some(callback), argument); + if exception == Obj::const_none() { + Ok(result.assume_init()) + } else { + Err(exception) + } + } +} + +type ProtectedArgument = *mut cty::c_void; +type ProtectedCallback = unsafe extern "C" fn(ProtectedArgument); + +fn split_func_into_callback_and_argument(func: &mut F) -> (ProtectedCallback, ProtectedArgument) +where + F: FnMut(), +{ + // Here we mono-morphize a version of `trampoline` for each type `F`, so it + // calls the correct `FnMut` impl, and cast `func` into its data part to use + // as the argument. + (trampoline::, func as *mut _ as *mut _) +} + +unsafe extern "C" fn trampoline(arg: ProtectedArgument) +where + F: FnMut(), +{ + // Synthesize a callable `*mut F` from the closure environment pointer `arg`, + // and call it. + let func = arg as *mut F; + unsafe { + (*func)(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn except_returns_ok_on_no_exception() { + let result = except(|| 1); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), 1); + } + + #[test] + fn except_catches_value_error() { + let msg = unsafe { CStr::from_bytes_with_nul_unchecked(b"msg\0") }; + let result = except(|| raise_value_error(&msg)); + assert!(result.is_err()); + } +}