mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-14 01:10:58 +00:00
core: implement synchronous v1 codec
This commit is contained in:
parent
34bd57006f
commit
85d74ece76
@ -58,15 +58,16 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorutils_consteq_obj,
|
||||
mod_trezorutils_consteq);
|
||||
|
||||
/// def memcpy(
|
||||
/// dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int
|
||||
/// dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None
|
||||
/// ) -> int:
|
||||
/// """
|
||||
/// Copies at most `n` bytes from `src` at offset `src_ofs` to
|
||||
/// `dst` at offset `dst_ofs`. Returns the number of actually
|
||||
/// copied bytes.
|
||||
/// copied bytes. If `n` is not specified, tries to copy
|
||||
/// as much as possible.
|
||||
/// """
|
||||
STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) {
|
||||
mp_arg_check_num(n_args, 0, 5, 5, false);
|
||||
mp_arg_check_num(n_args, 0, 4, 5, false);
|
||||
|
||||
mp_buffer_info_t dst;
|
||||
mp_get_buffer_raise(args[0], &dst, MP_BUFFER_WRITE);
|
||||
@ -76,7 +77,12 @@ STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) {
|
||||
mp_get_buffer_raise(args[2], &src, MP_BUFFER_READ);
|
||||
uint32_t src_ofs = trezor_obj_get_uint(args[3]);
|
||||
|
||||
uint32_t n = trezor_obj_get_uint(args[4]);
|
||||
uint32_t n = 0;
|
||||
if (n_args > 4) {
|
||||
n = trezor_obj_get_uint(args[4]);
|
||||
} else {
|
||||
n = src.len;
|
||||
}
|
||||
|
||||
size_t dst_rem = (dst_ofs < dst.len) ? dst.len - dst_ofs : 0;
|
||||
size_t src_rem = (src_ofs < src.len) ? src.len - src_ofs : 0;
|
||||
@ -86,7 +92,7 @@ STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) {
|
||||
|
||||
return mp_obj_new_int(ncpy);
|
||||
}
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_memcpy_obj, 5, 5,
|
||||
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_memcpy_obj, 4, 5,
|
||||
mod_trezorutils_memcpy);
|
||||
|
||||
/// def halt(msg: str = None) -> None:
|
||||
|
@ -13,12 +13,13 @@ def consteq(sec: bytes, pub: bytes) -> bool:
|
||||
|
||||
# extmod/modtrezorutils/modtrezorutils.c
|
||||
def memcpy(
|
||||
dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int
|
||||
dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None
|
||||
) -> int:
|
||||
"""
|
||||
Copies at most `n` bytes from `src` at offset `src_ofs` to
|
||||
`dst` at offset `dst_ofs`. Returns the number of actually
|
||||
copied bytes.
|
||||
copied bytes. If `n` is not specified, tries to copy
|
||||
as much as possible.
|
||||
"""
|
||||
|
||||
|
||||
|
@ -9,14 +9,14 @@ if False:
|
||||
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union
|
||||
from typing_extensions import Protocol
|
||||
|
||||
class AsyncReader(Protocol):
|
||||
async def areadinto(self, buf: bytearray) -> int:
|
||||
class Reader(Protocol):
|
||||
def readinto(self, buf: bytearray) -> int:
|
||||
"""
|
||||
Reads `len(buf)` bytes into `buf`, or raises `EOFError`.
|
||||
"""
|
||||
|
||||
class AsyncWriter(Protocol):
|
||||
async def awrite(self, buf: bytes) -> int:
|
||||
class Writer(Protocol):
|
||||
def write(self, buf: bytes) -> int:
|
||||
"""
|
||||
Writes all bytes from `buf`, or raises `EOFError`.
|
||||
"""
|
||||
@ -25,20 +25,20 @@ if False:
|
||||
_UVARINT_BUFFER = bytearray(1)
|
||||
|
||||
|
||||
async def load_uvarint(reader: AsyncReader) -> int:
|
||||
def load_uvarint(reader: Reader) -> int:
|
||||
buffer = _UVARINT_BUFFER
|
||||
result = 0
|
||||
shift = 0
|
||||
byte = 0x80
|
||||
while byte & 0x80:
|
||||
await reader.areadinto(buffer)
|
||||
reader.readinto(buffer)
|
||||
byte = buffer[0]
|
||||
result += (byte & 0x7F) << shift
|
||||
shift += 7
|
||||
return result
|
||||
|
||||
|
||||
async def dump_uvarint(writer: AsyncWriter, n: int) -> None:
|
||||
def dump_uvarint(writer: Writer, n: int) -> None:
|
||||
if n < 0:
|
||||
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
|
||||
buffer = _UVARINT_BUFFER
|
||||
@ -46,7 +46,7 @@ async def dump_uvarint(writer: AsyncWriter, n: int) -> None:
|
||||
while shifted:
|
||||
shifted = n >> 7
|
||||
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
|
||||
await writer.awrite(buffer)
|
||||
writer.write(buffer)
|
||||
n = shifted
|
||||
|
||||
|
||||
@ -165,15 +165,15 @@ class MessageType:
|
||||
|
||||
|
||||
class LimitedReader:
|
||||
def __init__(self, reader: AsyncReader, limit: int) -> None:
|
||||
def __init__(self, reader: Reader, limit: int) -> None:
|
||||
self.reader = reader
|
||||
self.limit = limit
|
||||
|
||||
async def areadinto(self, buf: bytearray) -> int:
|
||||
def readinto(self, buf: bytearray) -> int:
|
||||
if self.limit < len(buf):
|
||||
raise EOFError
|
||||
else:
|
||||
nread = await self.reader.areadinto(buf)
|
||||
nread = self.reader.readinto(buf)
|
||||
self.limit -= nread
|
||||
return nread
|
||||
|
||||
@ -184,8 +184,8 @@ if False:
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType)
|
||||
|
||||
|
||||
async def load_message(
|
||||
reader: AsyncReader, msg_type: Type[LoadedMessageType]
|
||||
def load_message(
|
||||
reader: Reader, msg_type: Type[LoadedMessageType]
|
||||
) -> LoadedMessageType:
|
||||
fields = msg_type.get_fields()
|
||||
msg = msg_type()
|
||||
@ -197,7 +197,7 @@ async def load_message(
|
||||
|
||||
while True:
|
||||
try:
|
||||
fkey = await load_uvarint(reader)
|
||||
fkey = load_uvarint(reader)
|
||||
except EOFError:
|
||||
break # no more fields to load
|
||||
|
||||
@ -208,10 +208,10 @@ async def load_message(
|
||||
|
||||
if field is None: # unknown field, skip it
|
||||
if wtype == 0:
|
||||
await load_uvarint(reader)
|
||||
load_uvarint(reader)
|
||||
elif wtype == 2:
|
||||
ivalue = await load_uvarint(reader)
|
||||
await reader.areadinto(bytearray(ivalue))
|
||||
ivalue = load_uvarint(reader)
|
||||
reader.readinto(bytearray(ivalue))
|
||||
else:
|
||||
raise ValueError
|
||||
continue
|
||||
@ -220,7 +220,7 @@ async def load_message(
|
||||
if wtype != ftype.WIRE_TYPE:
|
||||
raise TypeError # parsed wire type differs from the schema
|
||||
|
||||
ivalue = await load_uvarint(reader)
|
||||
ivalue = load_uvarint(reader)
|
||||
|
||||
if ftype is UVarintType:
|
||||
fvalue = ivalue
|
||||
@ -232,13 +232,13 @@ async def load_message(
|
||||
fvalue = ftype.validate(ivalue)
|
||||
elif ftype is BytesType:
|
||||
fvalue = bytearray(ivalue)
|
||||
await reader.areadinto(fvalue)
|
||||
reader.readinto(fvalue)
|
||||
elif ftype is UnicodeType:
|
||||
fvalue = bytearray(ivalue)
|
||||
await reader.areadinto(fvalue)
|
||||
reader.readinto(fvalue)
|
||||
fvalue = bytes(fvalue).decode()
|
||||
elif issubclass(ftype, MessageType):
|
||||
fvalue = await load_message(LimitedReader(reader, ivalue), ftype)
|
||||
fvalue = load_message(LimitedReader(reader, ivalue), ftype)
|
||||
else:
|
||||
raise TypeError # field type is unknown
|
||||
|
||||
@ -257,9 +257,7 @@ async def load_message(
|
||||
return msg
|
||||
|
||||
|
||||
async def dump_message(
|
||||
writer: AsyncWriter, msg: MessageType, fields: Dict = None
|
||||
) -> None:
|
||||
def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None:
|
||||
repvalue = [0]
|
||||
|
||||
if fields is None:
|
||||
@ -281,39 +279,39 @@ async def dump_message(
|
||||
ffields = None # type: Optional[Dict]
|
||||
|
||||
for svalue in fvalue:
|
||||
await dump_uvarint(writer, fkey)
|
||||
dump_uvarint(writer, fkey)
|
||||
|
||||
if ftype is UVarintType:
|
||||
await dump_uvarint(writer, svalue)
|
||||
dump_uvarint(writer, svalue)
|
||||
|
||||
elif ftype is SVarintType:
|
||||
await dump_uvarint(writer, sint_to_uint(svalue))
|
||||
dump_uvarint(writer, sint_to_uint(svalue))
|
||||
|
||||
elif ftype is BoolType:
|
||||
await dump_uvarint(writer, int(svalue))
|
||||
dump_uvarint(writer, int(svalue))
|
||||
|
||||
elif isinstance(ftype, EnumType):
|
||||
await dump_uvarint(writer, svalue)
|
||||
dump_uvarint(writer, svalue)
|
||||
|
||||
elif ftype is BytesType:
|
||||
if isinstance(svalue, list):
|
||||
await dump_uvarint(writer, _count_bytes_list(svalue))
|
||||
dump_uvarint(writer, _count_bytes_list(svalue))
|
||||
for sub_svalue in svalue:
|
||||
await writer.awrite(sub_svalue)
|
||||
writer.write(sub_svalue)
|
||||
else:
|
||||
await dump_uvarint(writer, len(svalue))
|
||||
await writer.awrite(svalue)
|
||||
dump_uvarint(writer, len(svalue))
|
||||
writer.write(svalue)
|
||||
|
||||
elif ftype is UnicodeType:
|
||||
svalue = svalue.encode()
|
||||
await dump_uvarint(writer, len(svalue))
|
||||
await writer.awrite(svalue)
|
||||
dump_uvarint(writer, len(svalue))
|
||||
writer.write(svalue)
|
||||
|
||||
elif issubclass(ftype, MessageType):
|
||||
if ffields is None:
|
||||
ffields = ftype.get_fields()
|
||||
await dump_uvarint(writer, count_message(svalue, ffields))
|
||||
await dump_message(writer, svalue, ffields)
|
||||
dump_uvarint(writer, count_message(svalue, ffields))
|
||||
dump_message(writer, svalue, ffields)
|
||||
|
||||
else:
|
||||
raise TypeError
|
||||
|
@ -138,6 +138,7 @@ class Context:
|
||||
def __init__(self, iface: WireInterface, sid: int) -> None:
|
||||
self.iface = iface
|
||||
self.sid = sid
|
||||
self.buffer_io = codec_v1.BytesIO(bytearray(8192))
|
||||
|
||||
async def call(
|
||||
self, msg: protobuf.MessageType, expected_type: Type[protobuf.LoadedMessageType]
|
||||
@ -153,11 +154,13 @@ class Context:
|
||||
del msg
|
||||
return await self.read_any(expected_wire_types)
|
||||
|
||||
async def read_from_wire(self) -> codec_v1.Message:
|
||||
self.buffer_io.seek(0)
|
||||
return await codec_v1.read_message(self.iface, self.buffer_io.buffer)
|
||||
|
||||
async def read(
|
||||
self, expected_type: Type[protobuf.LoadedMessageType]
|
||||
) -> protobuf.LoadedMessageType:
|
||||
reader = self.make_reader()
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
@ -167,14 +170,13 @@ class Context:
|
||||
expected_type,
|
||||
)
|
||||
|
||||
# Wait for the message header, contained in the first report. After
|
||||
# we receive it, we have a message type to match on.
|
||||
await reader.aopen()
|
||||
# Load the full message into a buffer, parse out type and data payload
|
||||
msg = await self.read_from_wire()
|
||||
|
||||
# If we got a message with unexpected type, raise the reader via
|
||||
# If we got a message with unexpected type, raise the message via
|
||||
# `UnexpectedMessageError` and let the session handler deal with it.
|
||||
if reader.type != expected_type.MESSAGE_WIRE_TYPE:
|
||||
raise UnexpectedMessageError(reader)
|
||||
if msg.type != expected_type.MESSAGE_WIRE_TYPE:
|
||||
raise UnexpectedMessageError(msg)
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
@ -187,14 +189,13 @@ class Context:
|
||||
|
||||
workflow.idle_timer.touch()
|
||||
|
||||
# parse the message and return it
|
||||
return await protobuf.load_message(reader, expected_type)
|
||||
# look up the protobuf class and parse the message
|
||||
pbtype = messages.get_type(msg.type)
|
||||
return protobuf.load_message(msg.data, pbtype)
|
||||
|
||||
async def read_any(
|
||||
self, expected_wire_types: Iterable[int]
|
||||
) -> protobuf.MessageType:
|
||||
reader = self.make_reader()
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
@ -204,17 +205,16 @@ class Context:
|
||||
expected_wire_types,
|
||||
)
|
||||
|
||||
# Wait for the message header, contained in the first report. After
|
||||
# we receive it, we have a message type to match on.
|
||||
await reader.aopen()
|
||||
# Load the full message into a buffer, parse out type and data payload
|
||||
msg = await self.read_from_wire()
|
||||
|
||||
# If we got a message with unexpected type, raise the reader via
|
||||
# If we got a message with unexpected type, raise the message via
|
||||
# `UnexpectedMessageError` and let the session handler deal with it.
|
||||
if reader.type not in expected_wire_types:
|
||||
raise UnexpectedMessageError(reader)
|
||||
if msg.type not in expected_wire_types:
|
||||
raise UnexpectedMessageError(msg)
|
||||
|
||||
# find the protobuf type
|
||||
exptype = messages.get_type(reader.type)
|
||||
exptype = messages.get_type(msg.type)
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
@ -224,24 +224,20 @@ class Context:
|
||||
workflow.idle_timer.touch()
|
||||
|
||||
# parse the message and return it
|
||||
return await protobuf.load_message(reader, exptype)
|
||||
return protobuf.load_message(msg.data, exptype)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
writer = self.make_writer()
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
|
||||
)
|
||||
|
||||
# get the message size
|
||||
fields = msg.get_fields()
|
||||
size = protobuf.count_message(msg, fields)
|
||||
|
||||
# write the message
|
||||
writer.setheader(msg.MESSAGE_WIRE_TYPE, size)
|
||||
await protobuf.dump_message(writer, msg, fields)
|
||||
await writer.aclose()
|
||||
self.buffer_io.seek(0)
|
||||
protobuf.dump_message(self.buffer_io, msg)
|
||||
await codec_v1.write_message(
|
||||
self.iface, msg.MESSAGE_WIRE_TYPE, self.buffer_io.get_written()
|
||||
)
|
||||
|
||||
def wait(self, *tasks: Awaitable) -> Any:
|
||||
"""
|
||||
@ -251,43 +247,35 @@ class Context:
|
||||
"""
|
||||
return loop.race(self.read_any(()), *tasks)
|
||||
|
||||
def make_reader(self) -> codec_v1.Reader:
|
||||
return codec_v1.Reader(self.iface)
|
||||
|
||||
def make_writer(self) -> codec_v1.Writer:
|
||||
return codec_v1.Writer(self.iface)
|
||||
|
||||
|
||||
class UnexpectedMessageError(Exception):
|
||||
def __init__(self, reader: codec_v1.Reader) -> None:
|
||||
self.reader = reader
|
||||
def __init__(self, msg: codec_v1.Message) -> None:
|
||||
self.msg = msg
|
||||
|
||||
|
||||
async def handle_session(
|
||||
iface: WireInterface, session_id: int, use_workflow: bool = True
|
||||
) -> None:
|
||||
ctx = Context(iface, session_id)
|
||||
next_reader = None # type: Optional[codec_v1.Reader]
|
||||
next_msg = None # type: Optional[codec_v1.Message]
|
||||
res_msg = None # type: Optional[protobuf.MessageType]
|
||||
req_reader = None
|
||||
req_type = None
|
||||
req_msg = None
|
||||
while True:
|
||||
try:
|
||||
if next_reader is None:
|
||||
if next_msg is None:
|
||||
# We are not currently reading a message, so let's wait for one.
|
||||
# If the decoding fails, exception is raised and we try again
|
||||
# (with the same `Reader` instance, it's OK). Even in case of
|
||||
# de-synchronized wire communication, report with a message
|
||||
# header is eventually received, after a couple of tries.
|
||||
req_reader = ctx.make_reader()
|
||||
await req_reader.aopen()
|
||||
msg = await ctx.read_from_wire()
|
||||
|
||||
if __debug__:
|
||||
try:
|
||||
msg_type = messages.get_type(req_reader.type).__name__
|
||||
msg_type = messages.get_type(msg.type).__name__
|
||||
except KeyError:
|
||||
msg_type = "%d - unknown message type" % req_reader.type
|
||||
msg_type = "%d - unknown message type" % msg.type
|
||||
log.debug(
|
||||
__name__,
|
||||
"%s:%x receive: <%s>",
|
||||
@ -298,8 +286,8 @@ async def handle_session(
|
||||
else:
|
||||
# We have a reader left over from earlier. We should process
|
||||
# this message instead of waiting for new one.
|
||||
req_reader = next_reader
|
||||
next_reader = None
|
||||
msg = next_msg
|
||||
next_msg = None
|
||||
|
||||
# Now we are in a middle of reading a message and we need to decide
|
||||
# what to do with it, based on its type from the message header.
|
||||
@ -312,13 +300,11 @@ async def handle_session(
|
||||
|
||||
# We need to find a handler for this message type. Should not
|
||||
# raise.
|
||||
handler = find_handler(iface, req_reader.type)
|
||||
handler = find_handler(iface, msg.type)
|
||||
|
||||
if handler is None:
|
||||
# If no handler is found, we can skip decoding and directly
|
||||
# respond with failure, but first, we should read the rest of
|
||||
# the message reports. Should not raise.
|
||||
await read_and_throw_away(req_reader)
|
||||
# respond with failure. Should not raise.
|
||||
res_msg = unexpected_message()
|
||||
|
||||
else:
|
||||
@ -332,11 +318,11 @@ async def handle_session(
|
||||
try:
|
||||
# Find a protobuf.MessageType subclass that describes this
|
||||
# message. Raises if the type is not found.
|
||||
req_type = messages.get_type(req_reader.type)
|
||||
req_type = messages.get_type(msg.type)
|
||||
|
||||
# Try to decode the message according to schema from
|
||||
# `req_type`. Raises if the message is malformed.
|
||||
req_msg = await protobuf.load_message(req_reader, req_type)
|
||||
req_msg = protobuf.load_message(msg.data, req_type)
|
||||
|
||||
# At this point, message reports are all processed and
|
||||
# correctly parsed into `req_msg`.
|
||||
@ -364,7 +350,7 @@ async def handle_session(
|
||||
# TODO:
|
||||
# We might handle only the few common cases here, like
|
||||
# Initialize and Cancel.
|
||||
next_reader = exc.reader
|
||||
next_msg = exc.msg
|
||||
res_msg = None
|
||||
|
||||
except Exception as exc:
|
||||
@ -401,7 +387,6 @@ async def handle_session(
|
||||
|
||||
# Cleanup, so garbage collection triggered after un-importing can
|
||||
# pick up the trash.
|
||||
req_reader = None
|
||||
req_type = None
|
||||
req_msg = None
|
||||
res_msg = None
|
||||
|
@ -18,143 +18,112 @@ SESSION_ID = const(0)
|
||||
INVALID_TYPE = const(-1)
|
||||
|
||||
|
||||
class Reader:
|
||||
"""
|
||||
Decoder for a wire codec over the HID (or UDP) layer. Provides readable
|
||||
async-file-like interface.
|
||||
"""
|
||||
class CodecError(Exception):
|
||||
pass
|
||||
|
||||
def __init__(self, iface: WireInterface) -> None:
|
||||
self.iface = iface
|
||||
self.type = INVALID_TYPE
|
||||
self.size = 0
|
||||
self.ofs = 0
|
||||
self.data = bytes()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "<Reader type: %s>" % self.type
|
||||
class BytesIO:
|
||||
def __init__(self, buffer: bytearray) -> None:
|
||||
self.buffer = buffer
|
||||
self.offset = 0
|
||||
|
||||
async def aopen(self) -> None:
|
||||
"""
|
||||
Start reading a message by waiting for initial message report. Because
|
||||
the first report contains the message header, `self.type` and
|
||||
`self.size` are initialized and available after `aopen()` returns.
|
||||
"""
|
||||
read = loop.wait(self.iface.iface_num() | io.POLL_READ)
|
||||
while True:
|
||||
# wait for initial report
|
||||
report = await read
|
||||
marker = report[0]
|
||||
if marker == _REP_MARKER:
|
||||
_, m1, m2, mtype, msize = ustruct.unpack(_REP_INIT, report)
|
||||
if m1 != _REP_MAGIC or m2 != _REP_MAGIC:
|
||||
raise ValueError
|
||||
break
|
||||
def seek(self, offset: int) -> None:
|
||||
offset = min(offset, len(self.buffer))
|
||||
offset = max(offset, 0)
|
||||
self.offset = offset
|
||||
|
||||
# load received message header
|
||||
self.type = mtype
|
||||
self.size = msize
|
||||
self.data = report[_REP_INIT_DATA : _REP_INIT_DATA + msize]
|
||||
self.ofs = 0
|
||||
|
||||
async def areadinto(self, buf: bytearray) -> int:
|
||||
"""
|
||||
Read exactly `len(buf)` bytes into `buf`, waiting for additional
|
||||
reports, if needed. Raises `EOFError` if end-of-message is encountered
|
||||
before the full read can be completed.
|
||||
"""
|
||||
if self.size < len(buf):
|
||||
def readinto(self, dst: bytearray) -> int:
|
||||
buffer = self.buffer
|
||||
offset = self.offset
|
||||
if len(dst) > len(buffer) - offset:
|
||||
raise EOFError
|
||||
|
||||
read = loop.wait(self.iface.iface_num() | io.POLL_READ)
|
||||
nread = 0
|
||||
while nread < len(buf):
|
||||
if self.ofs == len(self.data):
|
||||
# we are at the end of received data
|
||||
# wait for continuation report
|
||||
while True:
|
||||
report = await read
|
||||
marker = report[0]
|
||||
if marker == _REP_MARKER:
|
||||
break
|
||||
self.data = report[_REP_CONT_DATA : _REP_CONT_DATA + self.size]
|
||||
self.ofs = 0
|
||||
|
||||
# copy as much as possible to target buffer
|
||||
nbytes = utils.memcpy(buf, nread, self.data, self.ofs, len(buf))
|
||||
nread += nbytes
|
||||
self.ofs += nbytes
|
||||
self.size -= nbytes
|
||||
|
||||
nread = utils.memcpy(dst, 0, buffer, offset)
|
||||
self.offset += nread
|
||||
return nread
|
||||
|
||||
|
||||
class Writer:
|
||||
"""
|
||||
Encoder for a wire codec over the HID (or UDP) layer. Provides writable
|
||||
async-file-like interface.
|
||||
"""
|
||||
|
||||
def __init__(self, iface: WireInterface):
|
||||
self.iface = iface
|
||||
self.type = INVALID_TYPE
|
||||
self.size = 0
|
||||
self.ofs = 0
|
||||
self.data = bytearray(_REP_LEN)
|
||||
|
||||
def setheader(self, mtype: int, msize: int) -> None:
|
||||
"""
|
||||
Reset the writer state and load the message header with passed type and
|
||||
total message size.
|
||||
"""
|
||||
self.type = mtype
|
||||
self.size = msize
|
||||
ustruct.pack_into(
|
||||
_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize
|
||||
)
|
||||
self.ofs = _REP_INIT_DATA
|
||||
|
||||
async def awrite(self, buf: bytes) -> int:
|
||||
"""
|
||||
Encode and write every byte from `buf`. Does not need to be called in
|
||||
case message has zero length. Raises `EOFError` if the length of `buf`
|
||||
exceeds the remaining message length.
|
||||
"""
|
||||
if self.size < len(buf):
|
||||
def write(self, src: bytes) -> int:
|
||||
buffer = self.buffer
|
||||
offset = self.offset
|
||||
if len(src) > len(buffer) - offset:
|
||||
raise EOFError
|
||||
nwrite = utils.memcpy(buffer, offset, src, 0)
|
||||
self.offset += nwrite
|
||||
return nwrite
|
||||
|
||||
def get_written(self) -> bytes:
|
||||
return memoryview(self.buffer)[: self.offset]
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(self, mtype: int, mdata: BytesIO) -> None:
|
||||
self.type = mtype
|
||||
self.data = mdata
|
||||
|
||||
|
||||
async def read_message(iface: WireInterface, buffer: bytearray) -> Message:
|
||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||
|
||||
# wait for initial report
|
||||
report = await read
|
||||
if report[0] != _REP_MARKER:
|
||||
raise CodecError("Invalid magic")
|
||||
_, magic1, magic2, mtype, msize = ustruct.unpack(_REP_INIT, report)
|
||||
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
|
||||
raise CodecError("Invalid magic")
|
||||
|
||||
throw_away = False
|
||||
if msize > len(buffer):
|
||||
throw_away = True
|
||||
|
||||
# prepare the backing buffer
|
||||
mdata = memoryview(buffer)[:msize]
|
||||
|
||||
# buffer the initial data
|
||||
nread = utils.memcpy(mdata, 0, report, _REP_INIT_DATA)
|
||||
|
||||
while nread < msize:
|
||||
# wait for continuation report
|
||||
report = await read
|
||||
if report[0] != _REP_MARKER:
|
||||
raise CodecError("Invalid magic")
|
||||
|
||||
# buffer the continuation data
|
||||
if not throw_away:
|
||||
nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA)
|
||||
|
||||
if throw_away:
|
||||
raise CodecError("Message too large")
|
||||
|
||||
return Message(mtype, BytesIO(mdata))
|
||||
|
||||
|
||||
async def write_message(iface: WireInterface, mtype: int, mdata: bytearray) -> None:
|
||||
write = loop.wait(iface.iface_num() | io.POLL_WRITE)
|
||||
|
||||
# gather data from msg
|
||||
msize = len(mdata)
|
||||
|
||||
# prepare the report buffer with header data
|
||||
report = bytearray(_REP_LEN)
|
||||
repofs = _REP_INIT_DATA
|
||||
ustruct.pack_into(
|
||||
_REP_INIT, report, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize
|
||||
)
|
||||
|
||||
write = loop.wait(self.iface.iface_num() | io.POLL_WRITE)
|
||||
nwritten = 0
|
||||
while nwritten < len(buf):
|
||||
# copy as much as possible to report buffer
|
||||
nbytes = utils.memcpy(self.data, self.ofs, buf, nwritten, len(buf))
|
||||
nwritten += nbytes
|
||||
self.ofs += nbytes
|
||||
self.size -= nbytes
|
||||
while True:
|
||||
# copy as much as possible to the report buffer
|
||||
nwritten += utils.memcpy(report, repofs, mdata, nwritten)
|
||||
|
||||
if self.ofs == _REP_LEN:
|
||||
# we are at the end of the report, flush it
|
||||
# write the report
|
||||
while True:
|
||||
await write
|
||||
n = self.iface.write(self.data)
|
||||
if n == len(self.data):
|
||||
n = iface.write(report)
|
||||
if n == len(report):
|
||||
break
|
||||
self.ofs = _REP_CONT_DATA
|
||||
|
||||
return nwritten
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Flush and close the message transmission."""
|
||||
if self.ofs != _REP_CONT_DATA:
|
||||
# we didn't write anything or last write() wasn't report-aligned,
|
||||
# pad the final report and flush it
|
||||
while self.ofs < _REP_LEN:
|
||||
self.data[self.ofs] = 0x00
|
||||
self.ofs += 1
|
||||
|
||||
write = loop.wait(self.iface.iface_num() | io.POLL_WRITE)
|
||||
while True:
|
||||
await write
|
||||
n = self.iface.write(self.data)
|
||||
if n == len(self.data):
|
||||
# if we have more data to write, use continuation reports for it
|
||||
if nwritten < msize:
|
||||
repofs = _REP_CONT_DATA
|
||||
else:
|
||||
break
|
||||
|
Loading…
Reference in New Issue
Block a user