diff --git a/core/embed/rust/build.rs b/core/embed/rust/build.rs index d5a0104a9..4e26dfcff 100644 --- a/core/embed/rust/build.rs +++ b/core/embed/rust/build.rs @@ -176,6 +176,7 @@ fn generate_micropython_bindings() { .allowlist_type("mp_obj_list_t") .allowlist_function("mp_obj_new_list") .allowlist_function("mp_obj_list_append") + .allowlist_function("mp_obj_list_get") .allowlist_function("mp_obj_list_set_len") .allowlist_var("mp_type_list") // map diff --git a/core/embed/rust/src/micropython/list.rs b/core/embed/rust/src/micropython/list.rs index a34d3ceb5..a6da3d425 100644 --- a/core/embed/rust/src/micropython/list.rs +++ b/core/embed/rust/src/micropython/list.rs @@ -42,16 +42,56 @@ impl List { Ok(gc_list) } + // Internal helper to get the `Obj` variant of this. + // SAFETY: For convenience, the function works on an immutable reference, but + // the returned `Obj` is inherently mutable. + // Caller is responsible for ensuring that self is borrowed mutably if any + // mutation is to occur. + unsafe fn as_mut_obj(&self) -> Obj { + unsafe { + let ptr = self as *const Self as *mut _; + Obj::from_ptr(ptr) + } + } + pub fn append(&mut self, value: Obj) -> Result<(), Error> { unsafe { - let ptr = self as *mut Self; - let list = Obj::from_ptr(ptr.cast()); + // SAFETY: self is borrowed mutably. + let list = self.as_mut_obj(); // EXCEPTION: Will raise if allocation fails. catch_exception(|| { ffi::mp_obj_list_append(list, value); }) } } + + pub fn len(&self) -> usize { + self.as_slice().len() + } + + pub fn as_slice(&self) -> &[Obj] { + unsafe { + // SAFETY: mp_obj_list_get() does not mutate the list. + let list = self.as_mut_obj(); + let mut len: usize = 0; + let mut items_ptr: *mut Obj = ptr::null_mut(); + ffi::mp_obj_list_get(list, &mut len, &mut items_ptr); + assert!(!items_ptr.is_null()); + core::slice::from_raw_parts(items_ptr, len) + } + } + + pub fn as_mut_slice(&mut self) -> &mut [Obj] { + unsafe { + // SAFETY: self is borrowed mutably. + let list = self.as_mut_obj(); + let mut len: usize = 0; + let mut items_ptr: *mut Obj = ptr::null_mut(); + ffi::mp_obj_list_get(list, &mut len, &mut items_ptr); + assert!(!items_ptr.is_null()); + core::slice::from_raw_parts_mut(items_ptr, len) + } + } } impl From> for Obj { @@ -105,4 +145,54 @@ mod tests { .unwrap(); assert_eq!(vec, retrieved_vec); } + + #[test] + fn list_len() { + unsafe { mpy_init() }; + + let vec: Vec = (0..17).collect(); + let list = List::from_iter(vec.iter().copied()).unwrap(); + assert_eq!(list.len(), vec.len()); + } + + #[test] + fn list_as_slice() { + unsafe { mpy_init() }; + + let vec: Vec = (13..13 + 17).collect(); + let list = List::from_iter(vec.iter().copied()).unwrap(); + + let slice = list.as_slice(); + assert_eq!(slice.len(), vec.len()); + for i in 0..slice.len() { + assert_eq!(vec[i], slice[i].try_into().unwrap()); + } + } + + #[test] + fn list_as_mut_slice() { + unsafe { mpy_init() }; + + let vec: Vec = (0..5).collect(); + let mut list = List::from_iter(vec.iter().copied()).unwrap(); + + let slice = unsafe { Gc::::as_mut(&mut list) }.as_mut_slice(); + assert_eq!(slice.len(), vec.len()); + assert_eq!(vec[0], slice[0].try_into().unwrap()); + + for i in 0..slice.len() { + slice[i] = ((i + 10) as u16).into(); + } + + let mut buf = IterBuf::new(); + let iter = Iter::try_from_obj_with_buf(list.into(), &mut buf).unwrap(); + let retrieved_vec: Vec = iter + .map(TryInto::try_into) + .collect::, Error>>() + .unwrap(); + + for i in 0..retrieved_vec.len() { + assert_eq!(retrieved_vec[i], vec[i] + 10); + } + } }