mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-08 14:31:06 +00:00
refactor(core): split polling can_read and reading from USB
[no changelog]
This commit is contained in:
parent
13df961317
commit
e942f8e40d
@ -141,6 +141,20 @@ STATIC mp_obj_t mod_trezorio_HID_write(mp_obj_t self, mp_obj_t msg) {
|
|||||||
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_write_obj,
|
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_write_obj,
|
||||||
mod_trezorio_HID_write);
|
mod_trezorio_HID_write);
|
||||||
|
|
||||||
|
/// def read(self, buf: bytes) -> int:
|
||||||
|
/// """
|
||||||
|
/// Reads message using USB HID (device) or UDP (emulator).
|
||||||
|
/// """
|
||||||
|
STATIC mp_obj_t mod_trezorio_HID_read(mp_obj_t self, mp_obj_t buffer) {
|
||||||
|
mp_obj_HID_t *o = MP_OBJ_TO_PTR(self);
|
||||||
|
mp_buffer_info_t buf = {0};
|
||||||
|
mp_get_buffer_raise(buffer, &buf, MP_BUFFER_WRITE);
|
||||||
|
ssize_t r = usb_hid_read(o->info.iface_num, buf.buf, buf.len);
|
||||||
|
return MP_OBJ_NEW_SMALL_INT(r);
|
||||||
|
}
|
||||||
|
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_HID_read_obj,
|
||||||
|
mod_trezorio_HID_read);
|
||||||
|
|
||||||
/// def write_blocking(self, msg: bytes, timeout_ms: int) -> int:
|
/// def write_blocking(self, msg: bytes, timeout_ms: int) -> int:
|
||||||
/// """
|
/// """
|
||||||
/// Sends message using USB HID (device) or UDP (emulator).
|
/// Sends message using USB HID (device) or UDP (emulator).
|
||||||
@ -162,6 +176,7 @@ STATIC const mp_rom_map_elem_t mod_trezorio_HID_locals_dict_table[] = {
|
|||||||
{MP_ROM_QSTR(MP_QSTR_iface_num),
|
{MP_ROM_QSTR(MP_QSTR_iface_num),
|
||||||
MP_ROM_PTR(&mod_trezorio_HID_iface_num_obj)},
|
MP_ROM_PTR(&mod_trezorio_HID_iface_num_obj)},
|
||||||
{MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_HID_write_obj)},
|
{MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_HID_write_obj)},
|
||||||
|
{MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_HID_read_obj)},
|
||||||
{MP_ROM_QSTR(MP_QSTR_write_blocking),
|
{MP_ROM_QSTR(MP_QSTR_write_blocking),
|
||||||
MP_ROM_PTR(&mod_trezorio_HID_write_blocking_obj)},
|
MP_ROM_PTR(&mod_trezorio_HID_write_blocking_obj)},
|
||||||
};
|
};
|
||||||
|
@ -166,23 +166,12 @@ STATIC mp_obj_t mod_trezorio_poll(mp_obj_t ifaces, mp_obj_t list_ref,
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
else if (mode == POLL_READ) {
|
else if (mode == POLL_READ) {
|
||||||
if (sectrue == usb_hid_can_read(iface)) {
|
if ((sectrue == usb_hid_can_read(iface)) ||
|
||||||
uint8_t buf[64] = {0};
|
(sectrue == usb_webusb_can_read(iface))) {
|
||||||
int len = usb_hid_read(iface, buf, sizeof(buf));
|
|
||||||
if (len > 0) {
|
|
||||||
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
|
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
|
||||||
ret->items[1] = mp_obj_new_bytes(buf, len);
|
ret->items[1] = MP_OBJ_NEW_SMALL_INT(64);
|
||||||
return mp_const_true;
|
return mp_const_true;
|
||||||
}
|
}
|
||||||
} else if (sectrue == usb_webusb_can_read(iface)) {
|
|
||||||
uint8_t buf[64] = {0};
|
|
||||||
int len = usb_webusb_read(iface, buf, sizeof(buf));
|
|
||||||
if (len > 0) {
|
|
||||||
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
|
|
||||||
ret->items[1] = mp_obj_new_bytes(buf, len);
|
|
||||||
return mp_const_true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (mode == POLL_WRITE) {
|
} else if (mode == POLL_WRITE) {
|
||||||
if (sectrue == usb_hid_can_write(iface)) {
|
if (sectrue == usb_hid_can_write(iface)) {
|
||||||
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
|
ret->items[0] = MP_OBJ_NEW_SMALL_INT(i);
|
||||||
|
@ -127,10 +127,25 @@ STATIC mp_obj_t mod_trezorio_WebUSB_write(mp_obj_t self, mp_obj_t msg) {
|
|||||||
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_write_obj,
|
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_write_obj,
|
||||||
mod_trezorio_WebUSB_write);
|
mod_trezorio_WebUSB_write);
|
||||||
|
|
||||||
|
/// def read(self, buf: bytes) -> int:
|
||||||
|
/// """
|
||||||
|
/// Reads message using WebUSB (device) or UDP (emulator).
|
||||||
|
/// """
|
||||||
|
STATIC mp_obj_t mod_trezorio_WebUSB_read(mp_obj_t self, mp_obj_t buffer) {
|
||||||
|
mp_obj_HID_t *o = MP_OBJ_TO_PTR(self);
|
||||||
|
mp_buffer_info_t buf = {0};
|
||||||
|
mp_get_buffer_raise(buffer, &buf, MP_BUFFER_WRITE);
|
||||||
|
ssize_t r = usb_webusb_read(o->info.iface_num, buf.buf, buf.len);
|
||||||
|
return MP_OBJ_NEW_SMALL_INT(r);
|
||||||
|
}
|
||||||
|
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorio_WebUSB_read_obj,
|
||||||
|
mod_trezorio_WebUSB_read);
|
||||||
|
|
||||||
STATIC const mp_rom_map_elem_t mod_trezorio_WebUSB_locals_dict_table[] = {
|
STATIC const mp_rom_map_elem_t mod_trezorio_WebUSB_locals_dict_table[] = {
|
||||||
{MP_ROM_QSTR(MP_QSTR_iface_num),
|
{MP_ROM_QSTR(MP_QSTR_iface_num),
|
||||||
MP_ROM_PTR(&mod_trezorio_WebUSB_iface_num_obj)},
|
MP_ROM_PTR(&mod_trezorio_WebUSB_iface_num_obj)},
|
||||||
{MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_WebUSB_write_obj)},
|
{MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mod_trezorio_WebUSB_write_obj)},
|
||||||
|
{MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mod_trezorio_WebUSB_read_obj)},
|
||||||
};
|
};
|
||||||
STATIC MP_DEFINE_CONST_DICT(mod_trezorio_WebUSB_locals_dict,
|
STATIC MP_DEFINE_CONST_DICT(mod_trezorio_WebUSB_locals_dict,
|
||||||
mod_trezorio_WebUSB_locals_dict_table);
|
mod_trezorio_WebUSB_locals_dict_table);
|
||||||
|
@ -32,6 +32,11 @@ class HID:
|
|||||||
Sends message using USB HID (device) or UDP (emulator).
|
Sends message using USB HID (device) or UDP (emulator).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def read(self, buf: bytes) -> int:
|
||||||
|
"""
|
||||||
|
Reads message using USB HID (device) or UDP (emulator).
|
||||||
|
"""
|
||||||
|
|
||||||
def write_blocking(self, msg: bytes, timeout_ms: int) -> int:
|
def write_blocking(self, msg: bytes, timeout_ms: int) -> int:
|
||||||
"""
|
"""
|
||||||
Sends message using USB HID (device) or UDP (emulator).
|
Sends message using USB HID (device) or UDP (emulator).
|
||||||
@ -148,6 +153,11 @@ class WebUSB:
|
|||||||
"""
|
"""
|
||||||
Sends message using USB WebUSB (device) or UDP (emulator).
|
Sends message using USB WebUSB (device) or UDP (emulator).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def read(self, buf: bytes) -> int:
|
||||||
|
"""
|
||||||
|
Reads message using WebUSB (device) or UDP (emulator).
|
||||||
|
"""
|
||||||
from . import fatfs, haptic, sdcard
|
from . import fatfs, haptic, sdcard
|
||||||
POLL_READ: int # wait until interface is readable and return read data
|
POLL_READ: int # wait until interface is readable and return read data
|
||||||
POLL_WRITE: int # wait until interface is writable
|
POLL_WRITE: int # wait until interface is writable
|
||||||
|
@ -376,7 +376,11 @@ async def _read_cmd(iface: HID) -> Cmd | None:
|
|||||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||||
|
|
||||||
# wait for incoming command indefinitely
|
# wait for incoming command indefinitely
|
||||||
buf = await read
|
msg_len = await read
|
||||||
|
buf = bytearray(msg_len)
|
||||||
|
read_len = iface.read(buf)
|
||||||
|
if read_len != msg_len:
|
||||||
|
raise ValueError("Invalid length")
|
||||||
while True:
|
while True:
|
||||||
ifrm = overlay_struct(bytearray(buf), desc_init)
|
ifrm = overlay_struct(bytearray(buf), desc_init)
|
||||||
bcnt = ifrm.bcnt
|
bcnt = ifrm.bcnt
|
||||||
@ -415,7 +419,11 @@ async def _read_cmd(iface: HID) -> Cmd | None:
|
|||||||
read.timeout_ms = _CTAP_HID_TIMEOUT_MS
|
read.timeout_ms = _CTAP_HID_TIMEOUT_MS
|
||||||
while datalen < bcnt:
|
while datalen < bcnt:
|
||||||
try:
|
try:
|
||||||
buf = await read
|
msg_len = await read
|
||||||
|
buf = bytearray(msg_len)
|
||||||
|
read_len = iface.read(buf)
|
||||||
|
if read_len != msg_len:
|
||||||
|
raise ValueError("Invalid length")
|
||||||
except loop.Timeout:
|
except loop.Timeout:
|
||||||
if __debug__:
|
if __debug__:
|
||||||
warning(__name__, "_ERR_MSG_TIMEOUT")
|
warning(__name__, "_ERR_MSG_TIMEOUT")
|
||||||
|
@ -25,7 +25,11 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
|||||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||||
|
|
||||||
# wait for initial report
|
# wait for initial report
|
||||||
report = await read
|
msg_len = await read
|
||||||
|
report = bytearray(msg_len)
|
||||||
|
read_len = iface.read(report)
|
||||||
|
if read_len != msg_len:
|
||||||
|
raise CodecError("Invalid length")
|
||||||
if report[0] != _REP_MARKER:
|
if report[0] != _REP_MARKER:
|
||||||
raise CodecError("Invalid magic")
|
raise CodecError("Invalid magic")
|
||||||
_, magic1, magic2, mtype, msize = ustruct.unpack(_REP_INIT, report)
|
_, magic1, magic2, mtype, msize = ustruct.unpack(_REP_INIT, report)
|
||||||
@ -50,7 +54,11 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
|||||||
|
|
||||||
while nread < msize:
|
while nread < msize:
|
||||||
# wait for continuation report
|
# wait for continuation report
|
||||||
report = await read
|
msg_len = await read
|
||||||
|
report = bytearray(msg_len)
|
||||||
|
read_len = iface.read(report)
|
||||||
|
if read_len != msg_len:
|
||||||
|
raise CodecError("Invalid length")
|
||||||
if report[0] != _REP_MARKER:
|
if report[0] != _REP_MARKER:
|
||||||
raise CodecError("Invalid magic")
|
raise CodecError("Invalid magic")
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ class MockHID:
|
|||||||
def __init__(self, num):
|
def __init__(self, num):
|
||||||
self.num = num
|
self.num = num
|
||||||
self.data = []
|
self.data = []
|
||||||
|
self.packet = None
|
||||||
|
|
||||||
def iface_num(self):
|
def iface_num(self):
|
||||||
return self.num
|
return self.num
|
||||||
@ -20,6 +21,17 @@ class MockHID:
|
|||||||
self.data.append(bytearray(msg))
|
self.data.append(bytearray(msg))
|
||||||
return len(msg)
|
return len(msg)
|
||||||
|
|
||||||
|
def mock_read(self, packet, gen):
|
||||||
|
self.packet = packet
|
||||||
|
return gen.send(len(packet))
|
||||||
|
|
||||||
|
def read(self, buffer):
|
||||||
|
if self.packet is None:
|
||||||
|
raise Exception("No packet to read")
|
||||||
|
buffer[:] = self.packet
|
||||||
|
self.packet = None
|
||||||
|
return len(buffer)
|
||||||
|
|
||||||
def wait_object(self, mode):
|
def wait_object(self, mode):
|
||||||
return wait(mode | self.num)
|
return wait(mode | self.num)
|
||||||
|
|
||||||
@ -48,7 +60,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
|
|
||||||
with self.assertRaises(StopIteration) as e:
|
with self.assertRaises(StopIteration) as e:
|
||||||
gen.send(message_packet)
|
self.interface.mock_read(message_packet, gen)
|
||||||
|
|
||||||
# e.value is StopIteration. e.value.value is the return value of the call
|
# e.value is StopIteration. e.value.value is the return value of the call
|
||||||
result = e.value.value
|
result = e.value.value
|
||||||
@ -74,11 +86,11 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
query = gen.send(None)
|
query = gen.send(None)
|
||||||
for packet in packets[:-1]:
|
for packet in packets[:-1]:
|
||||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
query = gen.send(packet)
|
query = self.interface.mock_read(packet, gen)
|
||||||
|
|
||||||
# last packet will stop
|
# last packet will stop
|
||||||
with self.assertRaises(StopIteration) as e:
|
with self.assertRaises(StopIteration) as e:
|
||||||
gen.send(packets[-1])
|
self.interface.mock_read(packets[-1], gen)
|
||||||
|
|
||||||
# e.value is StopIteration. e.value.value is the return value of the call
|
# e.value is StopIteration. e.value.value is the return value of the call
|
||||||
result = e.value.value
|
result = e.value.value
|
||||||
@ -103,7 +115,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
query = gen.send(None)
|
query = gen.send(None)
|
||||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
with self.assertRaises(StopIteration) as e:
|
with self.assertRaises(StopIteration) as e:
|
||||||
gen.send(packet)
|
self.interface.mock_read(packet, gen)
|
||||||
|
|
||||||
# e.value is StopIteration. e.value.value is the return value of the call
|
# e.value is StopIteration. e.value.value is the return value of the call
|
||||||
result = e.value.value
|
result = e.value.value
|
||||||
@ -169,10 +181,10 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
query = gen.send(None)
|
query = gen.send(None)
|
||||||
for packet in self.interface.data[:-1]:
|
for packet in self.interface.data[:-1]:
|
||||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
query = gen.send(packet)
|
query = self.interface.mock_read(packet, gen)
|
||||||
|
|
||||||
with self.assertRaises(StopIteration) as e:
|
with self.assertRaises(StopIteration) as e:
|
||||||
gen.send(self.interface.data[-1])
|
self.interface.mock_read(self.interface.data[-1], gen)
|
||||||
|
|
||||||
result = e.value.value
|
result = e.value.value
|
||||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||||
@ -194,10 +206,10 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
query = gen.send(None)
|
query = gen.send(None)
|
||||||
for _ in range(PACKET_COUNT - 1):
|
for _ in range(PACKET_COUNT - 1):
|
||||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
query = gen.send(packet)
|
query = self.interface.mock_read(packet, gen)
|
||||||
|
|
||||||
with self.assertRaises(codec_v1.CodecError) as e:
|
with self.assertRaises(codec_v1.CodecError) as e:
|
||||||
gen.send(packet)
|
self.interface.mock_read(packet,gen)
|
||||||
|
|
||||||
self.assertEqual(e.value.args[0], "Message too large")
|
self.assertEqual(e.value.args[0], "Message too large")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user