From e942f8e40d6f9bc629c62bcdc2d2c27215640e13 Mon Sep 17 00:00:00 2001 From: tychovrahe Date: Tue, 3 Dec 2024 19:41:31 +0100 Subject: [PATCH] refactor(core): split polling can_read and reading from USB [no changelog] --- .../upymod/modtrezorio/modtrezorio-hid.h | 15 ++++++++++ .../upymod/modtrezorio/modtrezorio-poll.h | 21 ++++---------- .../upymod/modtrezorio/modtrezorio-webusb.h | 15 ++++++++++ core/mocks/generated/trezorio/__init__.pyi | 10 +++++++ core/src/apps/webauthn/fido2.py | 12 ++++++-- core/src/trezor/wire/codec/codec_v1.py | 12 ++++++-- core/tests/test_trezor.wire.codec.codec_v1.py | 28 +++++++++++++------ 7 files changed, 85 insertions(+), 28 deletions(-) diff --git a/core/embed/upymod/modtrezorio/modtrezorio-hid.h b/core/embed/upymod/modtrezorio/modtrezorio-hid.h index bcf141d776..cedc2b437d 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio-hid.h +++ b/core/embed/upymod/modtrezorio/modtrezorio-hid.h @@ -141,6 +141,20 @@ STATIC mp_obj_t mod_trezorio_HID_write(mp_obj_t self, mp_obj_t msg) { STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_write_obj, mod_trezorio_HID_write); +/// def read(self, buf: bytes) -> int: +/// """ +/// Reads message using USB HID (device) or UDP (emulator). +/// """ +STATIC mp_obj_t mod_trezorio_HID_read(mp_obj_t self, mp_obj_t buffer) { + mp_obj_HID_t *o = MP_OBJ_TO_PTR(self); + mp_buffer_info_t buf = {0}; + mp_get_buffer_raise(buffer, &buf, MP_BUFFER_WRITE); + ssize_t r = usb_hid_read(o->info.iface_num, buf.buf, buf.len); + return MP_OBJ_NEW_SMALL_INT(r); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_read_obj, + mod_trezorio_HID_read); + /// def write_blocking(self, msg: bytes, timeout_ms: int) -> int: /// """ /// Sends message using USB HID (device) or UDP (emulator). @@ -162,6 +176,7 @@ STATIC const mp_rom_map_elem_t mod_trezorio_HID_locals_dict_table[] = { {MP_ROM_QSTR(MP_QSTR_iface_num), MP_ROM_PTR(&mod_trezorio_HID_iface_num_obj)}, {MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_HID_write_obj)}, + {MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_HID_read_obj)}, {MP_ROM_QSTR(MP_QSTR_write_blocking), MP_ROM_PTR(&mod_trezorio_HID_write_blocking_obj)}, }; diff --git a/core/embed/upymod/modtrezorio/modtrezorio-poll.h b/core/embed/upymod/modtrezorio/modtrezorio-poll.h index 2d97e491e7..3e0a738abd 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio-poll.h +++ b/core/embed/upymod/modtrezorio/modtrezorio-poll.h @@ -166,22 +166,11 @@ STATIC mp_obj_t mod_trezorio_poll(mp_obj_t ifaces, mp_obj_t list_ref, } #endif else if (mode == POLL_READ) { - if (sectrue == usb_hid_can_read(iface)) { - uint8_t buf[64] = {0}; - int len = usb_hid_read(iface, buf, sizeof(buf)); - if (len > 0) { - ret->items[0] = MP_OBJ_NEW_SMALL_INT(i); - ret->items[1] = mp_obj_new_bytes(buf, len); - return mp_const_true; - } - } else if (sectrue == usb_webusb_can_read(iface)) { - uint8_t buf[64] = {0}; - int len = usb_webusb_read(iface, buf, sizeof(buf)); - if (len > 0) { - ret->items[0] = MP_OBJ_NEW_SMALL_INT(i); - ret->items[1] = mp_obj_new_bytes(buf, len); - return mp_const_true; - } + if ((sectrue == usb_hid_can_read(iface)) || + (sectrue == usb_webusb_can_read(iface))) { + ret->items[0] = MP_OBJ_NEW_SMALL_INT(i); + ret->items[1] = MP_OBJ_NEW_SMALL_INT(64); + return mp_const_true; } } else if (mode == POLL_WRITE) { if (sectrue == usb_hid_can_write(iface)) { diff --git a/core/embed/upymod/modtrezorio/modtrezorio-webusb.h b/core/embed/upymod/modtrezorio/modtrezorio-webusb.h index d893b10717..4773be0764 100644 --- a/core/embed/upymod/modtrezorio/modtrezorio-webusb.h +++ b/core/embed/upymod/modtrezorio/modtrezorio-webusb.h @@ -127,10 +127,25 @@ STATIC mp_obj_t mod_trezorio_WebUSB_write(mp_obj_t self, mp_obj_t msg) { STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_write_obj, mod_trezorio_WebUSB_write); +/// def read(self, buf: bytes) -> int: +/// """ +/// Reads message using WebUSB (device) or UDP (emulator). +/// """ +STATIC mp_obj_t mod_trezorio_WebUSB_read(mp_obj_t self, mp_obj_t buffer) { + mp_obj_HID_t *o = MP_OBJ_TO_PTR(self); + mp_buffer_info_t buf = {0}; + mp_get_buffer_raise(buffer, &buf, MP_BUFFER_WRITE); + ssize_t r = usb_webusb_read(o->info.iface_num, buf.buf, buf.len); + return MP_OBJ_NEW_SMALL_INT(r); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_read_obj, + mod_trezorio_WebUSB_read); + STATIC const mp_rom_map_elem_t mod_trezorio_WebUSB_locals_dict_table[] = { {MP_ROM_QSTR(MP_QSTR_iface_num), MP_ROM_PTR(&mod_trezorio_WebUSB_iface_num_obj)}, {MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_WebUSB_write_obj)}, + {MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_WebUSB_read_obj)}, }; STATIC MP_DEFINE_CONST_DICT(mod_trezorio_WebUSB_locals_dict, mod_trezorio_WebUSB_locals_dict_table); diff --git a/core/mocks/generated/trezorio/__init__.pyi b/core/mocks/generated/trezorio/__init__.pyi index efb11e08e9..8f34d9d1f0 100644 --- a/core/mocks/generated/trezorio/__init__.pyi +++ b/core/mocks/generated/trezorio/__init__.pyi @@ -32,6 +32,11 @@ class HID: Sends message using USB HID (device) or UDP (emulator). """ + def read(self, buf: bytes) -> int: + """ + Reads message using USB HID (device) or UDP (emulator). + """ + def write_blocking(self, msg: bytes, timeout_ms: int) -> int: """ Sends message using USB HID (device) or UDP (emulator). @@ -148,6 +153,11 @@ class WebUSB: """ Sends message using USB WebUSB (device) or UDP (emulator). """ + + def read(self, buf: bytes) -> int: + """ + Reads message using WebUSB (device) or UDP (emulator). + """ from . import fatfs, haptic, sdcard POLL_READ: int # wait until interface is readable and return read data POLL_WRITE: int # wait until interface is writable diff --git a/core/src/apps/webauthn/fido2.py b/core/src/apps/webauthn/fido2.py index 5a1bedd4ae..2125ec8699 100644 --- a/core/src/apps/webauthn/fido2.py +++ b/core/src/apps/webauthn/fido2.py @@ -376,7 +376,11 @@ async def _read_cmd(iface: HID) -> Cmd | None: read = loop.wait(iface.iface_num() | io.POLL_READ) # wait for incoming command indefinitely - buf = await read + msg_len = await read + buf = bytearray(msg_len) + read_len = iface.read(buf) + if read_len != msg_len: + raise ValueError("Invalid length") while True: ifrm = overlay_struct(bytearray(buf), desc_init) bcnt = ifrm.bcnt @@ -415,7 +419,11 @@ async def _read_cmd(iface: HID) -> Cmd | None: read.timeout_ms = _CTAP_HID_TIMEOUT_MS while datalen < bcnt: try: - buf = await read + msg_len = await read + buf = bytearray(msg_len) + read_len = iface.read(buf) + if read_len != msg_len: + raise ValueError("Invalid length") except loop.Timeout: if __debug__: warning(__name__, "_ERR_MSG_TIMEOUT") diff --git a/core/src/trezor/wire/codec/codec_v1.py b/core/src/trezor/wire/codec/codec_v1.py index 02ff37f0ea..68a2797ff5 100644 --- a/core/src/trezor/wire/codec/codec_v1.py +++ b/core/src/trezor/wire/codec/codec_v1.py @@ -25,7 +25,11 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag read = loop.wait(iface.iface_num() | io.POLL_READ) # wait for initial report - report = await read + msg_len = await read + report = bytearray(msg_len) + read_len = iface.read(report) + if read_len != msg_len: + raise CodecError("Invalid length") if report[0] != _REP_MARKER: raise CodecError("Invalid magic") _, magic1, magic2, mtype, msize = ustruct.unpack(_REP_INIT, report) @@ -50,7 +54,11 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag while nread < msize: # wait for continuation report - report = await read + msg_len = await read + report = bytearray(msg_len) + read_len = iface.read(report) + if read_len != msg_len: + raise CodecError("Invalid length") if report[0] != _REP_MARKER: raise CodecError("Invalid magic") diff --git a/core/tests/test_trezor.wire.codec.codec_v1.py b/core/tests/test_trezor.wire.codec.codec_v1.py index 78675859e2..5a73467e25 100644 --- a/core/tests/test_trezor.wire.codec.codec_v1.py +++ b/core/tests/test_trezor.wire.codec.codec_v1.py @@ -12,6 +12,7 @@ class MockHID: def __init__(self, num): self.num = num self.data = [] + self.packet = None def iface_num(self): return self.num @@ -20,6 +21,17 @@ class MockHID: self.data.append(bytearray(msg)) return len(msg) + def mock_read(self, packet, gen): + self.packet = packet + return gen.send(len(packet)) + + def read(self, buffer): + if self.packet is None: + raise Exception("No packet to read") + buffer[:] = self.packet + self.packet = None + return len(buffer) + def wait_object(self, mode): return wait(mode | self.num) @@ -48,7 +60,7 @@ class TestWireCodecV1(unittest.TestCase): self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) with self.assertRaises(StopIteration) as e: - gen.send(message_packet) + self.interface.mock_read(message_packet, gen) # e.value is StopIteration. e.value.value is the return value of the call result = e.value.value @@ -74,11 +86,11 @@ class TestWireCodecV1(unittest.TestCase): query = gen.send(None) for packet in packets[:-1]: self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) - query = gen.send(packet) + query = self.interface.mock_read(packet, gen) # last packet will stop with self.assertRaises(StopIteration) as e: - gen.send(packets[-1]) + self.interface.mock_read(packets[-1], gen) # e.value is StopIteration. e.value.value is the return value of the call result = e.value.value @@ -103,7 +115,7 @@ class TestWireCodecV1(unittest.TestCase): query = gen.send(None) self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) with self.assertRaises(StopIteration) as e: - gen.send(packet) + self.interface.mock_read(packet, gen) # e.value is StopIteration. e.value.value is the return value of the call result = e.value.value @@ -169,10 +181,10 @@ class TestWireCodecV1(unittest.TestCase): query = gen.send(None) for packet in self.interface.data[:-1]: self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) - query = gen.send(packet) + query = self.interface.mock_read(packet, gen) with self.assertRaises(StopIteration) as e: - gen.send(self.interface.data[-1]) + self.interface.mock_read(self.interface.data[-1], gen) result = e.value.value self.assertEqual(result.type, MESSAGE_TYPE) @@ -194,10 +206,10 @@ class TestWireCodecV1(unittest.TestCase): query = gen.send(None) for _ in range(PACKET_COUNT - 1): self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) - query = gen.send(packet) + query = self.interface.mock_read(packet, gen) with self.assertRaises(codec_v1.CodecError) as e: - gen.send(packet) + self.interface.mock_read(packet,gen) self.assertEqual(e.value.args[0], "Message too large")