mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-28 00:58:09 +00:00
core: implement synchronous v1 codec
This commit is contained in:
parent
34bd57006f
commit
85d74ece76
@ -58,15 +58,16 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorutils_consteq_obj,
|
|||||||
mod_trezorutils_consteq);
|
mod_trezorutils_consteq);
|
||||||
|
|
||||||
/// def memcpy(
|
/// def memcpy(
|
||||||
/// dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int
|
/// dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None
|
||||||
/// ) -> int:
|
/// ) -> int:
|
||||||
/// """
|
/// """
|
||||||
/// Copies at most `n` bytes from `src` at offset `src_ofs` to
|
/// Copies at most `n` bytes from `src` at offset `src_ofs` to
|
||||||
/// `dst` at offset `dst_ofs`. Returns the number of actually
|
/// `dst` at offset `dst_ofs`. Returns the number of actually
|
||||||
/// copied bytes.
|
/// copied bytes. If `n` is not specified, tries to copy
|
||||||
|
/// as much as possible.
|
||||||
/// """
|
/// """
|
||||||
STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) {
|
STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) {
|
||||||
mp_arg_check_num(n_args, 0, 5, 5, false);
|
mp_arg_check_num(n_args, 0, 4, 5, false);
|
||||||
|
|
||||||
mp_buffer_info_t dst;
|
mp_buffer_info_t dst;
|
||||||
mp_get_buffer_raise(args[0], &dst, MP_BUFFER_WRITE);
|
mp_get_buffer_raise(args[0], &dst, MP_BUFFER_WRITE);
|
||||||
@ -76,7 +77,12 @@ STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) {
|
|||||||
mp_get_buffer_raise(args[2], &src, MP_BUFFER_READ);
|
mp_get_buffer_raise(args[2], &src, MP_BUFFER_READ);
|
||||||
uint32_t src_ofs = trezor_obj_get_uint(args[3]);
|
uint32_t src_ofs = trezor_obj_get_uint(args[3]);
|
||||||
|
|
||||||
uint32_t n = trezor_obj_get_uint(args[4]);
|
uint32_t n = 0;
|
||||||
|
if (n_args > 4) {
|
||||||
|
n = trezor_obj_get_uint(args[4]);
|
||||||
|
} else {
|
||||||
|
n = src.len;
|
||||||
|
}
|
||||||
|
|
||||||
size_t dst_rem = (dst_ofs < dst.len) ? dst.len - dst_ofs : 0;
|
size_t dst_rem = (dst_ofs < dst.len) ? dst.len - dst_ofs : 0;
|
||||||
size_t src_rem = (src_ofs < src.len) ? src.len - src_ofs : 0;
|
size_t src_rem = (src_ofs < src.len) ? src.len - src_ofs : 0;
|
||||||
@ -86,7 +92,7 @@ STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) {
|
|||||||
|
|
||||||
return mp_obj_new_int(ncpy);
|
return mp_obj_new_int(ncpy);
|
||||||
}
|
}
|
||||||
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_memcpy_obj, 5, 5,
|
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_memcpy_obj, 4, 5,
|
||||||
mod_trezorutils_memcpy);
|
mod_trezorutils_memcpy);
|
||||||
|
|
||||||
/// def halt(msg: str = None) -> None:
|
/// def halt(msg: str = None) -> None:
|
||||||
|
@ -13,12 +13,13 @@ def consteq(sec: bytes, pub: bytes) -> bool:
|
|||||||
|
|
||||||
# extmod/modtrezorutils/modtrezorutils.c
|
# extmod/modtrezorutils/modtrezorutils.c
|
||||||
def memcpy(
|
def memcpy(
|
||||||
dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int
|
dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Copies at most `n` bytes from `src` at offset `src_ofs` to
|
Copies at most `n` bytes from `src` at offset `src_ofs` to
|
||||||
`dst` at offset `dst_ofs`. Returns the number of actually
|
`dst` at offset `dst_ofs`. Returns the number of actually
|
||||||
copied bytes.
|
copied bytes. If `n` is not specified, tries to copy
|
||||||
|
as much as possible.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,14 +9,14 @@ if False:
|
|||||||
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union
|
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
class AsyncReader(Protocol):
|
class Reader(Protocol):
|
||||||
async def areadinto(self, buf: bytearray) -> int:
|
def readinto(self, buf: bytearray) -> int:
|
||||||
"""
|
"""
|
||||||
Reads `len(buf)` bytes into `buf`, or raises `EOFError`.
|
Reads `len(buf)` bytes into `buf`, or raises `EOFError`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class AsyncWriter(Protocol):
|
class Writer(Protocol):
|
||||||
async def awrite(self, buf: bytes) -> int:
|
def write(self, buf: bytes) -> int:
|
||||||
"""
|
"""
|
||||||
Writes all bytes from `buf`, or raises `EOFError`.
|
Writes all bytes from `buf`, or raises `EOFError`.
|
||||||
"""
|
"""
|
||||||
@ -25,20 +25,20 @@ if False:
|
|||||||
_UVARINT_BUFFER = bytearray(1)
|
_UVARINT_BUFFER = bytearray(1)
|
||||||
|
|
||||||
|
|
||||||
async def load_uvarint(reader: AsyncReader) -> int:
|
def load_uvarint(reader: Reader) -> int:
|
||||||
buffer = _UVARINT_BUFFER
|
buffer = _UVARINT_BUFFER
|
||||||
result = 0
|
result = 0
|
||||||
shift = 0
|
shift = 0
|
||||||
byte = 0x80
|
byte = 0x80
|
||||||
while byte & 0x80:
|
while byte & 0x80:
|
||||||
await reader.areadinto(buffer)
|
reader.readinto(buffer)
|
||||||
byte = buffer[0]
|
byte = buffer[0]
|
||||||
result += (byte & 0x7F) << shift
|
result += (byte & 0x7F) << shift
|
||||||
shift += 7
|
shift += 7
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def dump_uvarint(writer: AsyncWriter, n: int) -> None:
|
def dump_uvarint(writer: Writer, n: int) -> None:
|
||||||
if n < 0:
|
if n < 0:
|
||||||
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
|
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
|
||||||
buffer = _UVARINT_BUFFER
|
buffer = _UVARINT_BUFFER
|
||||||
@ -46,7 +46,7 @@ async def dump_uvarint(writer: AsyncWriter, n: int) -> None:
|
|||||||
while shifted:
|
while shifted:
|
||||||
shifted = n >> 7
|
shifted = n >> 7
|
||||||
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
|
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
|
||||||
await writer.awrite(buffer)
|
writer.write(buffer)
|
||||||
n = shifted
|
n = shifted
|
||||||
|
|
||||||
|
|
||||||
@ -165,15 +165,15 @@ class MessageType:
|
|||||||
|
|
||||||
|
|
||||||
class LimitedReader:
|
class LimitedReader:
|
||||||
def __init__(self, reader: AsyncReader, limit: int) -> None:
|
def __init__(self, reader: Reader, limit: int) -> None:
|
||||||
self.reader = reader
|
self.reader = reader
|
||||||
self.limit = limit
|
self.limit = limit
|
||||||
|
|
||||||
async def areadinto(self, buf: bytearray) -> int:
|
def readinto(self, buf: bytearray) -> int:
|
||||||
if self.limit < len(buf):
|
if self.limit < len(buf):
|
||||||
raise EOFError
|
raise EOFError
|
||||||
else:
|
else:
|
||||||
nread = await self.reader.areadinto(buf)
|
nread = self.reader.readinto(buf)
|
||||||
self.limit -= nread
|
self.limit -= nread
|
||||||
return nread
|
return nread
|
||||||
|
|
||||||
@ -184,8 +184,8 @@ if False:
|
|||||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType)
|
LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType)
|
||||||
|
|
||||||
|
|
||||||
async def load_message(
|
def load_message(
|
||||||
reader: AsyncReader, msg_type: Type[LoadedMessageType]
|
reader: Reader, msg_type: Type[LoadedMessageType]
|
||||||
) -> LoadedMessageType:
|
) -> LoadedMessageType:
|
||||||
fields = msg_type.get_fields()
|
fields = msg_type.get_fields()
|
||||||
msg = msg_type()
|
msg = msg_type()
|
||||||
@ -197,7 +197,7 @@ async def load_message(
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
fkey = await load_uvarint(reader)
|
fkey = load_uvarint(reader)
|
||||||
except EOFError:
|
except EOFError:
|
||||||
break # no more fields to load
|
break # no more fields to load
|
||||||
|
|
||||||
@ -208,10 +208,10 @@ async def load_message(
|
|||||||
|
|
||||||
if field is None: # unknown field, skip it
|
if field is None: # unknown field, skip it
|
||||||
if wtype == 0:
|
if wtype == 0:
|
||||||
await load_uvarint(reader)
|
load_uvarint(reader)
|
||||||
elif wtype == 2:
|
elif wtype == 2:
|
||||||
ivalue = await load_uvarint(reader)
|
ivalue = load_uvarint(reader)
|
||||||
await reader.areadinto(bytearray(ivalue))
|
reader.readinto(bytearray(ivalue))
|
||||||
else:
|
else:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
continue
|
continue
|
||||||
@ -220,7 +220,7 @@ async def load_message(
|
|||||||
if wtype != ftype.WIRE_TYPE:
|
if wtype != ftype.WIRE_TYPE:
|
||||||
raise TypeError # parsed wire type differs from the schema
|
raise TypeError # parsed wire type differs from the schema
|
||||||
|
|
||||||
ivalue = await load_uvarint(reader)
|
ivalue = load_uvarint(reader)
|
||||||
|
|
||||||
if ftype is UVarintType:
|
if ftype is UVarintType:
|
||||||
fvalue = ivalue
|
fvalue = ivalue
|
||||||
@ -232,13 +232,13 @@ async def load_message(
|
|||||||
fvalue = ftype.validate(ivalue)
|
fvalue = ftype.validate(ivalue)
|
||||||
elif ftype is BytesType:
|
elif ftype is BytesType:
|
||||||
fvalue = bytearray(ivalue)
|
fvalue = bytearray(ivalue)
|
||||||
await reader.areadinto(fvalue)
|
reader.readinto(fvalue)
|
||||||
elif ftype is UnicodeType:
|
elif ftype is UnicodeType:
|
||||||
fvalue = bytearray(ivalue)
|
fvalue = bytearray(ivalue)
|
||||||
await reader.areadinto(fvalue)
|
reader.readinto(fvalue)
|
||||||
fvalue = bytes(fvalue).decode()
|
fvalue = bytes(fvalue).decode()
|
||||||
elif issubclass(ftype, MessageType):
|
elif issubclass(ftype, MessageType):
|
||||||
fvalue = await load_message(LimitedReader(reader, ivalue), ftype)
|
fvalue = load_message(LimitedReader(reader, ivalue), ftype)
|
||||||
else:
|
else:
|
||||||
raise TypeError # field type is unknown
|
raise TypeError # field type is unknown
|
||||||
|
|
||||||
@ -257,9 +257,7 @@ async def load_message(
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
async def dump_message(
|
def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None:
|
||||||
writer: AsyncWriter, msg: MessageType, fields: Dict = None
|
|
||||||
) -> None:
|
|
||||||
repvalue = [0]
|
repvalue = [0]
|
||||||
|
|
||||||
if fields is None:
|
if fields is None:
|
||||||
@ -281,39 +279,39 @@ async def dump_message(
|
|||||||
ffields = None # type: Optional[Dict]
|
ffields = None # type: Optional[Dict]
|
||||||
|
|
||||||
for svalue in fvalue:
|
for svalue in fvalue:
|
||||||
await dump_uvarint(writer, fkey)
|
dump_uvarint(writer, fkey)
|
||||||
|
|
||||||
if ftype is UVarintType:
|
if ftype is UVarintType:
|
||||||
await dump_uvarint(writer, svalue)
|
dump_uvarint(writer, svalue)
|
||||||
|
|
||||||
elif ftype is SVarintType:
|
elif ftype is SVarintType:
|
||||||
await dump_uvarint(writer, sint_to_uint(svalue))
|
dump_uvarint(writer, sint_to_uint(svalue))
|
||||||
|
|
||||||
elif ftype is BoolType:
|
elif ftype is BoolType:
|
||||||
await dump_uvarint(writer, int(svalue))
|
dump_uvarint(writer, int(svalue))
|
||||||
|
|
||||||
elif isinstance(ftype, EnumType):
|
elif isinstance(ftype, EnumType):
|
||||||
await dump_uvarint(writer, svalue)
|
dump_uvarint(writer, svalue)
|
||||||
|
|
||||||
elif ftype is BytesType:
|
elif ftype is BytesType:
|
||||||
if isinstance(svalue, list):
|
if isinstance(svalue, list):
|
||||||
await dump_uvarint(writer, _count_bytes_list(svalue))
|
dump_uvarint(writer, _count_bytes_list(svalue))
|
||||||
for sub_svalue in svalue:
|
for sub_svalue in svalue:
|
||||||
await writer.awrite(sub_svalue)
|
writer.write(sub_svalue)
|
||||||
else:
|
else:
|
||||||
await dump_uvarint(writer, len(svalue))
|
dump_uvarint(writer, len(svalue))
|
||||||
await writer.awrite(svalue)
|
writer.write(svalue)
|
||||||
|
|
||||||
elif ftype is UnicodeType:
|
elif ftype is UnicodeType:
|
||||||
svalue = svalue.encode()
|
svalue = svalue.encode()
|
||||||
await dump_uvarint(writer, len(svalue))
|
dump_uvarint(writer, len(svalue))
|
||||||
await writer.awrite(svalue)
|
writer.write(svalue)
|
||||||
|
|
||||||
elif issubclass(ftype, MessageType):
|
elif issubclass(ftype, MessageType):
|
||||||
if ffields is None:
|
if ffields is None:
|
||||||
ffields = ftype.get_fields()
|
ffields = ftype.get_fields()
|
||||||
await dump_uvarint(writer, count_message(svalue, ffields))
|
dump_uvarint(writer, count_message(svalue, ffields))
|
||||||
await dump_message(writer, svalue, ffields)
|
dump_message(writer, svalue, ffields)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
@ -138,6 +138,7 @@ class Context:
|
|||||||
def __init__(self, iface: WireInterface, sid: int) -> None:
|
def __init__(self, iface: WireInterface, sid: int) -> None:
|
||||||
self.iface = iface
|
self.iface = iface
|
||||||
self.sid = sid
|
self.sid = sid
|
||||||
|
self.buffer_io = codec_v1.BytesIO(bytearray(8192))
|
||||||
|
|
||||||
async def call(
|
async def call(
|
||||||
self, msg: protobuf.MessageType, expected_type: Type[protobuf.LoadedMessageType]
|
self, msg: protobuf.MessageType, expected_type: Type[protobuf.LoadedMessageType]
|
||||||
@ -153,11 +154,13 @@ class Context:
|
|||||||
del msg
|
del msg
|
||||||
return await self.read_any(expected_wire_types)
|
return await self.read_any(expected_wire_types)
|
||||||
|
|
||||||
|
async def read_from_wire(self) -> codec_v1.Message:
|
||||||
|
self.buffer_io.seek(0)
|
||||||
|
return await codec_v1.read_message(self.iface, self.buffer_io.buffer)
|
||||||
|
|
||||||
async def read(
|
async def read(
|
||||||
self, expected_type: Type[protobuf.LoadedMessageType]
|
self, expected_type: Type[protobuf.LoadedMessageType]
|
||||||
) -> protobuf.LoadedMessageType:
|
) -> protobuf.LoadedMessageType:
|
||||||
reader = self.make_reader()
|
|
||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
@ -167,14 +170,13 @@ class Context:
|
|||||||
expected_type,
|
expected_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait for the message header, contained in the first report. After
|
# Load the full message into a buffer, parse out type and data payload
|
||||||
# we receive it, we have a message type to match on.
|
msg = await self.read_from_wire()
|
||||||
await reader.aopen()
|
|
||||||
|
|
||||||
# If we got a message with unexpected type, raise the reader via
|
# If we got a message with unexpected type, raise the message via
|
||||||
# `UnexpectedMessageError` and let the session handler deal with it.
|
# `UnexpectedMessageError` and let the session handler deal with it.
|
||||||
if reader.type != expected_type.MESSAGE_WIRE_TYPE:
|
if msg.type != expected_type.MESSAGE_WIRE_TYPE:
|
||||||
raise UnexpectedMessageError(reader)
|
raise UnexpectedMessageError(msg)
|
||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
@ -187,14 +189,13 @@ class Context:
|
|||||||
|
|
||||||
workflow.idle_timer.touch()
|
workflow.idle_timer.touch()
|
||||||
|
|
||||||
# parse the message and return it
|
# look up the protobuf class and parse the message
|
||||||
return await protobuf.load_message(reader, expected_type)
|
pbtype = messages.get_type(msg.type)
|
||||||
|
return protobuf.load_message(msg.data, pbtype)
|
||||||
|
|
||||||
async def read_any(
|
async def read_any(
|
||||||
self, expected_wire_types: Iterable[int]
|
self, expected_wire_types: Iterable[int]
|
||||||
) -> protobuf.MessageType:
|
) -> protobuf.MessageType:
|
||||||
reader = self.make_reader()
|
|
||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
@ -204,17 +205,16 @@ class Context:
|
|||||||
expected_wire_types,
|
expected_wire_types,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait for the message header, contained in the first report. After
|
# Load the full message into a buffer, parse out type and data payload
|
||||||
# we receive it, we have a message type to match on.
|
msg = await self.read_from_wire()
|
||||||
await reader.aopen()
|
|
||||||
|
|
||||||
# If we got a message with unexpected type, raise the reader via
|
# If we got a message with unexpected type, raise the message via
|
||||||
# `UnexpectedMessageError` and let the session handler deal with it.
|
# `UnexpectedMessageError` and let the session handler deal with it.
|
||||||
if reader.type not in expected_wire_types:
|
if msg.type not in expected_wire_types:
|
||||||
raise UnexpectedMessageError(reader)
|
raise UnexpectedMessageError(msg)
|
||||||
|
|
||||||
# find the protobuf type
|
# find the protobuf type
|
||||||
exptype = messages.get_type(reader.type)
|
exptype = messages.get_type(msg.type)
|
||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
@ -224,24 +224,20 @@ class Context:
|
|||||||
workflow.idle_timer.touch()
|
workflow.idle_timer.touch()
|
||||||
|
|
||||||
# parse the message and return it
|
# parse the message and return it
|
||||||
return await protobuf.load_message(reader, exptype)
|
return protobuf.load_message(msg.data, exptype)
|
||||||
|
|
||||||
async def write(self, msg: protobuf.MessageType) -> None:
|
async def write(self, msg: protobuf.MessageType) -> None:
|
||||||
writer = self.make_writer()
|
|
||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
|
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
|
||||||
)
|
)
|
||||||
|
|
||||||
# get the message size
|
|
||||||
fields = msg.get_fields()
|
|
||||||
size = protobuf.count_message(msg, fields)
|
|
||||||
|
|
||||||
# write the message
|
# write the message
|
||||||
writer.setheader(msg.MESSAGE_WIRE_TYPE, size)
|
self.buffer_io.seek(0)
|
||||||
await protobuf.dump_message(writer, msg, fields)
|
protobuf.dump_message(self.buffer_io, msg)
|
||||||
await writer.aclose()
|
await codec_v1.write_message(
|
||||||
|
self.iface, msg.MESSAGE_WIRE_TYPE, self.buffer_io.get_written()
|
||||||
|
)
|
||||||
|
|
||||||
def wait(self, *tasks: Awaitable) -> Any:
|
def wait(self, *tasks: Awaitable) -> Any:
|
||||||
"""
|
"""
|
||||||
@ -251,43 +247,35 @@ class Context:
|
|||||||
"""
|
"""
|
||||||
return loop.race(self.read_any(()), *tasks)
|
return loop.race(self.read_any(()), *tasks)
|
||||||
|
|
||||||
def make_reader(self) -> codec_v1.Reader:
|
|
||||||
return codec_v1.Reader(self.iface)
|
|
||||||
|
|
||||||
def make_writer(self) -> codec_v1.Writer:
|
|
||||||
return codec_v1.Writer(self.iface)
|
|
||||||
|
|
||||||
|
|
||||||
class UnexpectedMessageError(Exception):
|
class UnexpectedMessageError(Exception):
|
||||||
def __init__(self, reader: codec_v1.Reader) -> None:
|
def __init__(self, msg: codec_v1.Message) -> None:
|
||||||
self.reader = reader
|
self.msg = msg
|
||||||
|
|
||||||
|
|
||||||
async def handle_session(
|
async def handle_session(
|
||||||
iface: WireInterface, session_id: int, use_workflow: bool = True
|
iface: WireInterface, session_id: int, use_workflow: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
ctx = Context(iface, session_id)
|
ctx = Context(iface, session_id)
|
||||||
next_reader = None # type: Optional[codec_v1.Reader]
|
next_msg = None # type: Optional[codec_v1.Message]
|
||||||
res_msg = None # type: Optional[protobuf.MessageType]
|
res_msg = None # type: Optional[protobuf.MessageType]
|
||||||
req_reader = None
|
|
||||||
req_type = None
|
req_type = None
|
||||||
req_msg = None
|
req_msg = None
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
if next_reader is None:
|
if next_msg is None:
|
||||||
# We are not currently reading a message, so let's wait for one.
|
# We are not currently reading a message, so let's wait for one.
|
||||||
# If the decoding fails, exception is raised and we try again
|
# If the decoding fails, exception is raised and we try again
|
||||||
# (with the same `Reader` instance, it's OK). Even in case of
|
# (with the same `Reader` instance, it's OK). Even in case of
|
||||||
# de-synchronized wire communication, report with a message
|
# de-synchronized wire communication, report with a message
|
||||||
# header is eventually received, after a couple of tries.
|
# header is eventually received, after a couple of tries.
|
||||||
req_reader = ctx.make_reader()
|
msg = await ctx.read_from_wire()
|
||||||
await req_reader.aopen()
|
|
||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
try:
|
try:
|
||||||
msg_type = messages.get_type(req_reader.type).__name__
|
msg_type = messages.get_type(msg.type).__name__
|
||||||
except KeyError:
|
except KeyError:
|
||||||
msg_type = "%d - unknown message type" % req_reader.type
|
msg_type = "%d - unknown message type" % msg.type
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"%s:%x receive: <%s>",
|
"%s:%x receive: <%s>",
|
||||||
@ -298,8 +286,8 @@ async def handle_session(
|
|||||||
else:
|
else:
|
||||||
# We have a reader left over from earlier. We should process
|
# We have a reader left over from earlier. We should process
|
||||||
# this message instead of waiting for new one.
|
# this message instead of waiting for new one.
|
||||||
req_reader = next_reader
|
msg = next_msg
|
||||||
next_reader = None
|
next_msg = None
|
||||||
|
|
||||||
# Now we are in a middle of reading a message and we need to decide
|
# Now we are in a middle of reading a message and we need to decide
|
||||||
# what to do with it, based on its type from the message header.
|
# what to do with it, based on its type from the message header.
|
||||||
@ -312,13 +300,11 @@ async def handle_session(
|
|||||||
|
|
||||||
# We need to find a handler for this message type. Should not
|
# We need to find a handler for this message type. Should not
|
||||||
# raise.
|
# raise.
|
||||||
handler = find_handler(iface, req_reader.type)
|
handler = find_handler(iface, msg.type)
|
||||||
|
|
||||||
if handler is None:
|
if handler is None:
|
||||||
# If no handler is found, we can skip decoding and directly
|
# If no handler is found, we can skip decoding and directly
|
||||||
# respond with failure, but first, we should read the rest of
|
# respond with failure. Should not raise.
|
||||||
# the message reports. Should not raise.
|
|
||||||
await read_and_throw_away(req_reader)
|
|
||||||
res_msg = unexpected_message()
|
res_msg = unexpected_message()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -332,11 +318,11 @@ async def handle_session(
|
|||||||
try:
|
try:
|
||||||
# Find a protobuf.MessageType subclass that describes this
|
# Find a protobuf.MessageType subclass that describes this
|
||||||
# message. Raises if the type is not found.
|
# message. Raises if the type is not found.
|
||||||
req_type = messages.get_type(req_reader.type)
|
req_type = messages.get_type(msg.type)
|
||||||
|
|
||||||
# Try to decode the message according to schema from
|
# Try to decode the message according to schema from
|
||||||
# `req_type`. Raises if the message is malformed.
|
# `req_type`. Raises if the message is malformed.
|
||||||
req_msg = await protobuf.load_message(req_reader, req_type)
|
req_msg = protobuf.load_message(msg.data, req_type)
|
||||||
|
|
||||||
# At this point, message reports are all processed and
|
# At this point, message reports are all processed and
|
||||||
# correctly parsed into `req_msg`.
|
# correctly parsed into `req_msg`.
|
||||||
@ -364,7 +350,7 @@ async def handle_session(
|
|||||||
# TODO:
|
# TODO:
|
||||||
# We might handle only the few common cases here, like
|
# We might handle only the few common cases here, like
|
||||||
# Initialize and Cancel.
|
# Initialize and Cancel.
|
||||||
next_reader = exc.reader
|
next_msg = exc.msg
|
||||||
res_msg = None
|
res_msg = None
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@ -401,7 +387,6 @@ async def handle_session(
|
|||||||
|
|
||||||
# Cleanup, so garbage collection triggered after un-importing can
|
# Cleanup, so garbage collection triggered after un-importing can
|
||||||
# pick up the trash.
|
# pick up the trash.
|
||||||
req_reader = None
|
|
||||||
req_type = None
|
req_type = None
|
||||||
req_msg = None
|
req_msg = None
|
||||||
res_msg = None
|
res_msg = None
|
||||||
|
@ -18,143 +18,112 @@ SESSION_ID = const(0)
|
|||||||
INVALID_TYPE = const(-1)
|
INVALID_TYPE = const(-1)
|
||||||
|
|
||||||
|
|
||||||
class Reader:
|
class CodecError(Exception):
|
||||||
"""
|
pass
|
||||||
Decoder for a wire codec over the HID (or UDP) layer. Provides readable
|
|
||||||
async-file-like interface.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, iface: WireInterface) -> None:
|
|
||||||
self.iface = iface
|
|
||||||
self.type = INVALID_TYPE
|
|
||||||
self.size = 0
|
|
||||||
self.ofs = 0
|
|
||||||
self.data = bytes()
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
class BytesIO:
|
||||||
return "<Reader type: %s>" % self.type
|
def __init__(self, buffer: bytearray) -> None:
|
||||||
|
self.buffer = buffer
|
||||||
|
self.offset = 0
|
||||||
|
|
||||||
async def aopen(self) -> None:
|
def seek(self, offset: int) -> None:
|
||||||
"""
|
offset = min(offset, len(self.buffer))
|
||||||
Start reading a message by waiting for initial message report. Because
|
offset = max(offset, 0)
|
||||||
the first report contains the message header, `self.type` and
|
self.offset = offset
|
||||||
`self.size` are initialized and available after `aopen()` returns.
|
|
||||||
"""
|
|
||||||
read = loop.wait(self.iface.iface_num() | io.POLL_READ)
|
|
||||||
while True:
|
|
||||||
# wait for initial report
|
|
||||||
report = await read
|
|
||||||
marker = report[0]
|
|
||||||
if marker == _REP_MARKER:
|
|
||||||
_, m1, m2, mtype, msize = ustruct.unpack(_REP_INIT, report)
|
|
||||||
if m1 != _REP_MAGIC or m2 != _REP_MAGIC:
|
|
||||||
raise ValueError
|
|
||||||
break
|
|
||||||
|
|
||||||
# load received message header
|
def readinto(self, dst: bytearray) -> int:
|
||||||
self.type = mtype
|
buffer = self.buffer
|
||||||
self.size = msize
|
offset = self.offset
|
||||||
self.data = report[_REP_INIT_DATA : _REP_INIT_DATA + msize]
|
if len(dst) > len(buffer) - offset:
|
||||||
self.ofs = 0
|
|
||||||
|
|
||||||
async def areadinto(self, buf: bytearray) -> int:
|
|
||||||
"""
|
|
||||||
Read exactly `len(buf)` bytes into `buf`, waiting for additional
|
|
||||||
reports, if needed. Raises `EOFError` if end-of-message is encountered
|
|
||||||
before the full read can be completed.
|
|
||||||
"""
|
|
||||||
if self.size < len(buf):
|
|
||||||
raise EOFError
|
raise EOFError
|
||||||
|
nread = utils.memcpy(dst, 0, buffer, offset)
|
||||||
read = loop.wait(self.iface.iface_num() | io.POLL_READ)
|
self.offset += nread
|
||||||
nread = 0
|
|
||||||
while nread < len(buf):
|
|
||||||
if self.ofs == len(self.data):
|
|
||||||
# we are at the end of received data
|
|
||||||
# wait for continuation report
|
|
||||||
while True:
|
|
||||||
report = await read
|
|
||||||
marker = report[0]
|
|
||||||
if marker == _REP_MARKER:
|
|
||||||
break
|
|
||||||
self.data = report[_REP_CONT_DATA : _REP_CONT_DATA + self.size]
|
|
||||||
self.ofs = 0
|
|
||||||
|
|
||||||
# copy as much as possible to target buffer
|
|
||||||
nbytes = utils.memcpy(buf, nread, self.data, self.ofs, len(buf))
|
|
||||||
nread += nbytes
|
|
||||||
self.ofs += nbytes
|
|
||||||
self.size -= nbytes
|
|
||||||
|
|
||||||
return nread
|
return nread
|
||||||
|
|
||||||
|
def write(self, src: bytes) -> int:
|
||||||
class Writer:
|
buffer = self.buffer
|
||||||
"""
|
offset = self.offset
|
||||||
Encoder for a wire codec over the HID (or UDP) layer. Provides writable
|
if len(src) > len(buffer) - offset:
|
||||||
async-file-like interface.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, iface: WireInterface):
|
|
||||||
self.iface = iface
|
|
||||||
self.type = INVALID_TYPE
|
|
||||||
self.size = 0
|
|
||||||
self.ofs = 0
|
|
||||||
self.data = bytearray(_REP_LEN)
|
|
||||||
|
|
||||||
def setheader(self, mtype: int, msize: int) -> None:
|
|
||||||
"""
|
|
||||||
Reset the writer state and load the message header with passed type and
|
|
||||||
total message size.
|
|
||||||
"""
|
|
||||||
self.type = mtype
|
|
||||||
self.size = msize
|
|
||||||
ustruct.pack_into(
|
|
||||||
_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize
|
|
||||||
)
|
|
||||||
self.ofs = _REP_INIT_DATA
|
|
||||||
|
|
||||||
async def awrite(self, buf: bytes) -> int:
|
|
||||||
"""
|
|
||||||
Encode and write every byte from `buf`. Does not need to be called in
|
|
||||||
case message has zero length. Raises `EOFError` if the length of `buf`
|
|
||||||
exceeds the remaining message length.
|
|
||||||
"""
|
|
||||||
if self.size < len(buf):
|
|
||||||
raise EOFError
|
raise EOFError
|
||||||
|
nwrite = utils.memcpy(buffer, offset, src, 0)
|
||||||
|
self.offset += nwrite
|
||||||
|
return nwrite
|
||||||
|
|
||||||
write = loop.wait(self.iface.iface_num() | io.POLL_WRITE)
|
def get_written(self) -> bytes:
|
||||||
nwritten = 0
|
return memoryview(self.buffer)[: self.offset]
|
||||||
while nwritten < len(buf):
|
|
||||||
# copy as much as possible to report buffer
|
|
||||||
nbytes = utils.memcpy(self.data, self.ofs, buf, nwritten, len(buf))
|
|
||||||
nwritten += nbytes
|
|
||||||
self.ofs += nbytes
|
|
||||||
self.size -= nbytes
|
|
||||||
|
|
||||||
if self.ofs == _REP_LEN:
|
|
||||||
# we are at the end of the report, flush it
|
|
||||||
while True:
|
|
||||||
await write
|
|
||||||
n = self.iface.write(self.data)
|
|
||||||
if n == len(self.data):
|
|
||||||
break
|
|
||||||
self.ofs = _REP_CONT_DATA
|
|
||||||
|
|
||||||
return nwritten
|
class Message:
|
||||||
|
def __init__(self, mtype: int, mdata: BytesIO) -> None:
|
||||||
|
self.type = mtype
|
||||||
|
self.data = mdata
|
||||||
|
|
||||||
async def aclose(self) -> None:
|
|
||||||
"""Flush and close the message transmission."""
|
|
||||||
if self.ofs != _REP_CONT_DATA:
|
|
||||||
# we didn't write anything or last write() wasn't report-aligned,
|
|
||||||
# pad the final report and flush it
|
|
||||||
while self.ofs < _REP_LEN:
|
|
||||||
self.data[self.ofs] = 0x00
|
|
||||||
self.ofs += 1
|
|
||||||
|
|
||||||
write = loop.wait(self.iface.iface_num() | io.POLL_WRITE)
|
async def read_message(iface: WireInterface, buffer: bytearray) -> Message:
|
||||||
while True:
|
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||||
await write
|
|
||||||
n = self.iface.write(self.data)
|
# wait for initial report
|
||||||
if n == len(self.data):
|
report = await read
|
||||||
break
|
if report[0] != _REP_MARKER:
|
||||||
|
raise CodecError("Invalid magic")
|
||||||
|
_, magic1, magic2, mtype, msize = ustruct.unpack(_REP_INIT, report)
|
||||||
|
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
|
||||||
|
raise CodecError("Invalid magic")
|
||||||
|
|
||||||
|
throw_away = False
|
||||||
|
if msize > len(buffer):
|
||||||
|
throw_away = True
|
||||||
|
|
||||||
|
# prepare the backing buffer
|
||||||
|
mdata = memoryview(buffer)[:msize]
|
||||||
|
|
||||||
|
# buffer the initial data
|
||||||
|
nread = utils.memcpy(mdata, 0, report, _REP_INIT_DATA)
|
||||||
|
|
||||||
|
while nread < msize:
|
||||||
|
# wait for continuation report
|
||||||
|
report = await read
|
||||||
|
if report[0] != _REP_MARKER:
|
||||||
|
raise CodecError("Invalid magic")
|
||||||
|
|
||||||
|
# buffer the continuation data
|
||||||
|
if not throw_away:
|
||||||
|
nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA)
|
||||||
|
|
||||||
|
if throw_away:
|
||||||
|
raise CodecError("Message too large")
|
||||||
|
|
||||||
|
return Message(mtype, BytesIO(mdata))
|
||||||
|
|
||||||
|
|
||||||
|
async def write_message(iface: WireInterface, mtype: int, mdata: bytearray) -> None:
|
||||||
|
write = loop.wait(iface.iface_num() | io.POLL_WRITE)
|
||||||
|
|
||||||
|
# gather data from msg
|
||||||
|
msize = len(mdata)
|
||||||
|
|
||||||
|
# prepare the report buffer with header data
|
||||||
|
report = bytearray(_REP_LEN)
|
||||||
|
repofs = _REP_INIT_DATA
|
||||||
|
ustruct.pack_into(
|
||||||
|
_REP_INIT, report, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize
|
||||||
|
)
|
||||||
|
|
||||||
|
nwritten = 0
|
||||||
|
while True:
|
||||||
|
# copy as much as possible to the report buffer
|
||||||
|
nwritten += utils.memcpy(report, repofs, mdata, nwritten)
|
||||||
|
|
||||||
|
# write the report
|
||||||
|
while True:
|
||||||
|
await write
|
||||||
|
n = iface.write(report)
|
||||||
|
if n == len(report):
|
||||||
|
break
|
||||||
|
|
||||||
|
# if we have more data to write, use continuation reports for it
|
||||||
|
if nwritten < msize:
|
||||||
|
repofs = _REP_CONT_DATA
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
Loading…
Reference in New Issue
Block a user