1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-28 00:58:09 +00:00

core: implement synchronous v1 codec

This commit is contained in:
matejcik 2020-06-26 12:30:12 +02:00 committed by Tomas Susanka
parent 34bd57006f
commit 85d74ece76
5 changed files with 186 additions and 227 deletions

View File

@ -58,15 +58,16 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorutils_consteq_obj,
mod_trezorutils_consteq); mod_trezorutils_consteq);
/// def memcpy( /// def memcpy(
/// dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int /// dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None
/// ) -> int: /// ) -> int:
/// """ /// """
/// Copies at most `n` bytes from `src` at offset `src_ofs` to /// Copies at most `n` bytes from `src` at offset `src_ofs` to
/// `dst` at offset `dst_ofs`. Returns the number of actually /// `dst` at offset `dst_ofs`. Returns the number of actually
/// copied bytes. /// copied bytes. If `n` is not specified, tries to copy
/// as much as possible.
/// """ /// """
STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) { STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) {
mp_arg_check_num(n_args, 0, 5, 5, false); mp_arg_check_num(n_args, 0, 4, 5, false);
mp_buffer_info_t dst; mp_buffer_info_t dst;
mp_get_buffer_raise(args[0], &dst, MP_BUFFER_WRITE); mp_get_buffer_raise(args[0], &dst, MP_BUFFER_WRITE);
@ -76,7 +77,12 @@ STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) {
mp_get_buffer_raise(args[2], &src, MP_BUFFER_READ); mp_get_buffer_raise(args[2], &src, MP_BUFFER_READ);
uint32_t src_ofs = trezor_obj_get_uint(args[3]); uint32_t src_ofs = trezor_obj_get_uint(args[3]);
uint32_t n = trezor_obj_get_uint(args[4]); uint32_t n = 0;
if (n_args > 4) {
n = trezor_obj_get_uint(args[4]);
} else {
n = src.len;
}
size_t dst_rem = (dst_ofs < dst.len) ? dst.len - dst_ofs : 0; size_t dst_rem = (dst_ofs < dst.len) ? dst.len - dst_ofs : 0;
size_t src_rem = (src_ofs < src.len) ? src.len - src_ofs : 0; size_t src_rem = (src_ofs < src.len) ? src.len - src_ofs : 0;
@ -86,7 +92,7 @@ STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) {
return mp_obj_new_int(ncpy); return mp_obj_new_int(ncpy);
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_memcpy_obj, 5, 5, STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_memcpy_obj, 4, 5,
mod_trezorutils_memcpy); mod_trezorutils_memcpy);
/// def halt(msg: str = None) -> None: /// def halt(msg: str = None) -> None:

View File

@ -13,12 +13,13 @@ def consteq(sec: bytes, pub: bytes) -> bool:
# extmod/modtrezorutils/modtrezorutils.c # extmod/modtrezorutils/modtrezorutils.c
def memcpy( def memcpy(
dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None
) -> int: ) -> int:
""" """
Copies at most `n` bytes from `src` at offset `src_ofs` to Copies at most `n` bytes from `src` at offset `src_ofs` to
`dst` at offset `dst_ofs`. Returns the number of actually `dst` at offset `dst_ofs`. Returns the number of actually
copied bytes. copied bytes. If `n` is not specified, tries to copy
as much as possible.
""" """

View File

@ -9,14 +9,14 @@ if False:
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union
from typing_extensions import Protocol from typing_extensions import Protocol
class AsyncReader(Protocol): class Reader(Protocol):
async def areadinto(self, buf: bytearray) -> int: def readinto(self, buf: bytearray) -> int:
""" """
Reads `len(buf)` bytes into `buf`, or raises `EOFError`. Reads `len(buf)` bytes into `buf`, or raises `EOFError`.
""" """
class AsyncWriter(Protocol): class Writer(Protocol):
async def awrite(self, buf: bytes) -> int: def write(self, buf: bytes) -> int:
""" """
Writes all bytes from `buf`, or raises `EOFError`. Writes all bytes from `buf`, or raises `EOFError`.
""" """
@ -25,20 +25,20 @@ if False:
_UVARINT_BUFFER = bytearray(1) _UVARINT_BUFFER = bytearray(1)
async def load_uvarint(reader: AsyncReader) -> int: def load_uvarint(reader: Reader) -> int:
buffer = _UVARINT_BUFFER buffer = _UVARINT_BUFFER
result = 0 result = 0
shift = 0 shift = 0
byte = 0x80 byte = 0x80
while byte & 0x80: while byte & 0x80:
await reader.areadinto(buffer) reader.readinto(buffer)
byte = buffer[0] byte = buffer[0]
result += (byte & 0x7F) << shift result += (byte & 0x7F) << shift
shift += 7 shift += 7
return result return result
async def dump_uvarint(writer: AsyncWriter, n: int) -> None: def dump_uvarint(writer: Writer, n: int) -> None:
if n < 0: if n < 0:
raise ValueError("Cannot dump signed value, convert it to unsigned first.") raise ValueError("Cannot dump signed value, convert it to unsigned first.")
buffer = _UVARINT_BUFFER buffer = _UVARINT_BUFFER
@ -46,7 +46,7 @@ async def dump_uvarint(writer: AsyncWriter, n: int) -> None:
while shifted: while shifted:
shifted = n >> 7 shifted = n >> 7
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00) buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
await writer.awrite(buffer) writer.write(buffer)
n = shifted n = shifted
@ -165,15 +165,15 @@ class MessageType:
class LimitedReader: class LimitedReader:
def __init__(self, reader: AsyncReader, limit: int) -> None: def __init__(self, reader: Reader, limit: int) -> None:
self.reader = reader self.reader = reader
self.limit = limit self.limit = limit
async def areadinto(self, buf: bytearray) -> int: def readinto(self, buf: bytearray) -> int:
if self.limit < len(buf): if self.limit < len(buf):
raise EOFError raise EOFError
else: else:
nread = await self.reader.areadinto(buf) nread = self.reader.readinto(buf)
self.limit -= nread self.limit -= nread
return nread return nread
@ -184,8 +184,8 @@ if False:
LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType) LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType)
async def load_message( def load_message(
reader: AsyncReader, msg_type: Type[LoadedMessageType] reader: Reader, msg_type: Type[LoadedMessageType]
) -> LoadedMessageType: ) -> LoadedMessageType:
fields = msg_type.get_fields() fields = msg_type.get_fields()
msg = msg_type() msg = msg_type()
@ -197,7 +197,7 @@ async def load_message(
while True: while True:
try: try:
fkey = await load_uvarint(reader) fkey = load_uvarint(reader)
except EOFError: except EOFError:
break # no more fields to load break # no more fields to load
@ -208,10 +208,10 @@ async def load_message(
if field is None: # unknown field, skip it if field is None: # unknown field, skip it
if wtype == 0: if wtype == 0:
await load_uvarint(reader) load_uvarint(reader)
elif wtype == 2: elif wtype == 2:
ivalue = await load_uvarint(reader) ivalue = load_uvarint(reader)
await reader.areadinto(bytearray(ivalue)) reader.readinto(bytearray(ivalue))
else: else:
raise ValueError raise ValueError
continue continue
@ -220,7 +220,7 @@ async def load_message(
if wtype != ftype.WIRE_TYPE: if wtype != ftype.WIRE_TYPE:
raise TypeError # parsed wire type differs from the schema raise TypeError # parsed wire type differs from the schema
ivalue = await load_uvarint(reader) ivalue = load_uvarint(reader)
if ftype is UVarintType: if ftype is UVarintType:
fvalue = ivalue fvalue = ivalue
@ -232,13 +232,13 @@ async def load_message(
fvalue = ftype.validate(ivalue) fvalue = ftype.validate(ivalue)
elif ftype is BytesType: elif ftype is BytesType:
fvalue = bytearray(ivalue) fvalue = bytearray(ivalue)
await reader.areadinto(fvalue) reader.readinto(fvalue)
elif ftype is UnicodeType: elif ftype is UnicodeType:
fvalue = bytearray(ivalue) fvalue = bytearray(ivalue)
await reader.areadinto(fvalue) reader.readinto(fvalue)
fvalue = bytes(fvalue).decode() fvalue = bytes(fvalue).decode()
elif issubclass(ftype, MessageType): elif issubclass(ftype, MessageType):
fvalue = await load_message(LimitedReader(reader, ivalue), ftype) fvalue = load_message(LimitedReader(reader, ivalue), ftype)
else: else:
raise TypeError # field type is unknown raise TypeError # field type is unknown
@ -257,9 +257,7 @@ async def load_message(
return msg return msg
async def dump_message( def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None:
writer: AsyncWriter, msg: MessageType, fields: Dict = None
) -> None:
repvalue = [0] repvalue = [0]
if fields is None: if fields is None:
@ -281,39 +279,39 @@ async def dump_message(
ffields = None # type: Optional[Dict] ffields = None # type: Optional[Dict]
for svalue in fvalue: for svalue in fvalue:
await dump_uvarint(writer, fkey) dump_uvarint(writer, fkey)
if ftype is UVarintType: if ftype is UVarintType:
await dump_uvarint(writer, svalue) dump_uvarint(writer, svalue)
elif ftype is SVarintType: elif ftype is SVarintType:
await dump_uvarint(writer, sint_to_uint(svalue)) dump_uvarint(writer, sint_to_uint(svalue))
elif ftype is BoolType: elif ftype is BoolType:
await dump_uvarint(writer, int(svalue)) dump_uvarint(writer, int(svalue))
elif isinstance(ftype, EnumType): elif isinstance(ftype, EnumType):
await dump_uvarint(writer, svalue) dump_uvarint(writer, svalue)
elif ftype is BytesType: elif ftype is BytesType:
if isinstance(svalue, list): if isinstance(svalue, list):
await dump_uvarint(writer, _count_bytes_list(svalue)) dump_uvarint(writer, _count_bytes_list(svalue))
for sub_svalue in svalue: for sub_svalue in svalue:
await writer.awrite(sub_svalue) writer.write(sub_svalue)
else: else:
await dump_uvarint(writer, len(svalue)) dump_uvarint(writer, len(svalue))
await writer.awrite(svalue) writer.write(svalue)
elif ftype is UnicodeType: elif ftype is UnicodeType:
svalue = svalue.encode() svalue = svalue.encode()
await dump_uvarint(writer, len(svalue)) dump_uvarint(writer, len(svalue))
await writer.awrite(svalue) writer.write(svalue)
elif issubclass(ftype, MessageType): elif issubclass(ftype, MessageType):
if ffields is None: if ffields is None:
ffields = ftype.get_fields() ffields = ftype.get_fields()
await dump_uvarint(writer, count_message(svalue, ffields)) dump_uvarint(writer, count_message(svalue, ffields))
await dump_message(writer, svalue, ffields) dump_message(writer, svalue, ffields)
else: else:
raise TypeError raise TypeError

View File

@ -138,6 +138,7 @@ class Context:
def __init__(self, iface: WireInterface, sid: int) -> None: def __init__(self, iface: WireInterface, sid: int) -> None:
self.iface = iface self.iface = iface
self.sid = sid self.sid = sid
self.buffer_io = codec_v1.BytesIO(bytearray(8192))
async def call( async def call(
self, msg: protobuf.MessageType, expected_type: Type[protobuf.LoadedMessageType] self, msg: protobuf.MessageType, expected_type: Type[protobuf.LoadedMessageType]
@ -153,11 +154,13 @@ class Context:
del msg del msg
return await self.read_any(expected_wire_types) return await self.read_any(expected_wire_types)
async def read_from_wire(self) -> codec_v1.Message:
self.buffer_io.seek(0)
return await codec_v1.read_message(self.iface, self.buffer_io.buffer)
async def read( async def read(
self, expected_type: Type[protobuf.LoadedMessageType] self, expected_type: Type[protobuf.LoadedMessageType]
) -> protobuf.LoadedMessageType: ) -> protobuf.LoadedMessageType:
reader = self.make_reader()
if __debug__: if __debug__:
log.debug( log.debug(
__name__, __name__,
@ -167,14 +170,13 @@ class Context:
expected_type, expected_type,
) )
# Wait for the message header, contained in the first report. After # Load the full message into a buffer, parse out type and data payload
# we receive it, we have a message type to match on. msg = await self.read_from_wire()
await reader.aopen()
# If we got a message with unexpected type, raise the reader via # If we got a message with unexpected type, raise the message via
# `UnexpectedMessageError` and let the session handler deal with it. # `UnexpectedMessageError` and let the session handler deal with it.
if reader.type != expected_type.MESSAGE_WIRE_TYPE: if msg.type != expected_type.MESSAGE_WIRE_TYPE:
raise UnexpectedMessageError(reader) raise UnexpectedMessageError(msg)
if __debug__: if __debug__:
log.debug( log.debug(
@ -187,14 +189,13 @@ class Context:
workflow.idle_timer.touch() workflow.idle_timer.touch()
# parse the message and return it # look up the protobuf class and parse the message
return await protobuf.load_message(reader, expected_type) pbtype = messages.get_type(msg.type)
return protobuf.load_message(msg.data, pbtype)
async def read_any( async def read_any(
self, expected_wire_types: Iterable[int] self, expected_wire_types: Iterable[int]
) -> protobuf.MessageType: ) -> protobuf.MessageType:
reader = self.make_reader()
if __debug__: if __debug__:
log.debug( log.debug(
__name__, __name__,
@ -204,17 +205,16 @@ class Context:
expected_wire_types, expected_wire_types,
) )
# Wait for the message header, contained in the first report. After # Load the full message into a buffer, parse out type and data payload
# we receive it, we have a message type to match on. msg = await self.read_from_wire()
await reader.aopen()
# If we got a message with unexpected type, raise the reader via # If we got a message with unexpected type, raise the message via
# `UnexpectedMessageError` and let the session handler deal with it. # `UnexpectedMessageError` and let the session handler deal with it.
if reader.type not in expected_wire_types: if msg.type not in expected_wire_types:
raise UnexpectedMessageError(reader) raise UnexpectedMessageError(msg)
# find the protobuf type # find the protobuf type
exptype = messages.get_type(reader.type) exptype = messages.get_type(msg.type)
if __debug__: if __debug__:
log.debug( log.debug(
@ -224,24 +224,20 @@ class Context:
workflow.idle_timer.touch() workflow.idle_timer.touch()
# parse the message and return it # parse the message and return it
return await protobuf.load_message(reader, exptype) return protobuf.load_message(msg.data, exptype)
async def write(self, msg: protobuf.MessageType) -> None: async def write(self, msg: protobuf.MessageType) -> None:
writer = self.make_writer()
if __debug__: if __debug__:
log.debug( log.debug(
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg __name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
) )
# get the message size
fields = msg.get_fields()
size = protobuf.count_message(msg, fields)
# write the message # write the message
writer.setheader(msg.MESSAGE_WIRE_TYPE, size) self.buffer_io.seek(0)
await protobuf.dump_message(writer, msg, fields) protobuf.dump_message(self.buffer_io, msg)
await writer.aclose() await codec_v1.write_message(
self.iface, msg.MESSAGE_WIRE_TYPE, self.buffer_io.get_written()
)
def wait(self, *tasks: Awaitable) -> Any: def wait(self, *tasks: Awaitable) -> Any:
""" """
@ -251,43 +247,35 @@ class Context:
""" """
return loop.race(self.read_any(()), *tasks) return loop.race(self.read_any(()), *tasks)
def make_reader(self) -> codec_v1.Reader:
return codec_v1.Reader(self.iface)
def make_writer(self) -> codec_v1.Writer:
return codec_v1.Writer(self.iface)
class UnexpectedMessageError(Exception): class UnexpectedMessageError(Exception):
def __init__(self, reader: codec_v1.Reader) -> None: def __init__(self, msg: codec_v1.Message) -> None:
self.reader = reader self.msg = msg
async def handle_session( async def handle_session(
iface: WireInterface, session_id: int, use_workflow: bool = True iface: WireInterface, session_id: int, use_workflow: bool = True
) -> None: ) -> None:
ctx = Context(iface, session_id) ctx = Context(iface, session_id)
next_reader = None # type: Optional[codec_v1.Reader] next_msg = None # type: Optional[codec_v1.Message]
res_msg = None # type: Optional[protobuf.MessageType] res_msg = None # type: Optional[protobuf.MessageType]
req_reader = None
req_type = None req_type = None
req_msg = None req_msg = None
while True: while True:
try: try:
if next_reader is None: if next_msg is None:
# We are not currently reading a message, so let's wait for one. # We are not currently reading a message, so let's wait for one.
# If the decoding fails, exception is raised and we try again # If the decoding fails, exception is raised and we try again
# (with the same `Reader` instance, it's OK). Even in case of # (with the same `Reader` instance, it's OK). Even in case of
# de-synchronized wire communication, report with a message # de-synchronized wire communication, report with a message
# header is eventually received, after a couple of tries. # header is eventually received, after a couple of tries.
req_reader = ctx.make_reader() msg = await ctx.read_from_wire()
await req_reader.aopen()
if __debug__: if __debug__:
try: try:
msg_type = messages.get_type(req_reader.type).__name__ msg_type = messages.get_type(msg.type).__name__
except KeyError: except KeyError:
msg_type = "%d - unknown message type" % req_reader.type msg_type = "%d - unknown message type" % msg.type
log.debug( log.debug(
__name__, __name__,
"%s:%x receive: <%s>", "%s:%x receive: <%s>",
@ -298,8 +286,8 @@ async def handle_session(
else: else:
# We have a reader left over from earlier. We should process # We have a reader left over from earlier. We should process
# this message instead of waiting for new one. # this message instead of waiting for new one.
req_reader = next_reader msg = next_msg
next_reader = None next_msg = None
# Now we are in a middle of reading a message and we need to decide # Now we are in a middle of reading a message and we need to decide
# what to do with it, based on its type from the message header. # what to do with it, based on its type from the message header.
@ -312,13 +300,11 @@ async def handle_session(
# We need to find a handler for this message type. Should not # We need to find a handler for this message type. Should not
# raise. # raise.
handler = find_handler(iface, req_reader.type) handler = find_handler(iface, msg.type)
if handler is None: if handler is None:
# If no handler is found, we can skip decoding and directly # If no handler is found, we can skip decoding and directly
# respond with failure, but first, we should read the rest of # respond with failure. Should not raise.
# the message reports. Should not raise.
await read_and_throw_away(req_reader)
res_msg = unexpected_message() res_msg = unexpected_message()
else: else:
@ -332,11 +318,11 @@ async def handle_session(
try: try:
# Find a protobuf.MessageType subclass that describes this # Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found. # message. Raises if the type is not found.
req_type = messages.get_type(req_reader.type) req_type = messages.get_type(msg.type)
# Try to decode the message according to schema from # Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed. # `req_type`. Raises if the message is malformed.
req_msg = await protobuf.load_message(req_reader, req_type) req_msg = protobuf.load_message(msg.data, req_type)
# At this point, message reports are all processed and # At this point, message reports are all processed and
# correctly parsed into `req_msg`. # correctly parsed into `req_msg`.
@ -364,7 +350,7 @@ async def handle_session(
# TODO: # TODO:
# We might handle only the few common cases here, like # We might handle only the few common cases here, like
# Initialize and Cancel. # Initialize and Cancel.
next_reader = exc.reader next_msg = exc.msg
res_msg = None res_msg = None
except Exception as exc: except Exception as exc:
@ -401,7 +387,6 @@ async def handle_session(
# Cleanup, so garbage collection triggered after un-importing can # Cleanup, so garbage collection triggered after un-importing can
# pick up the trash. # pick up the trash.
req_reader = None
req_type = None req_type = None
req_msg = None req_msg = None
res_msg = None res_msg = None

View File

@ -18,143 +18,112 @@ SESSION_ID = const(0)
INVALID_TYPE = const(-1) INVALID_TYPE = const(-1)
class Reader: class CodecError(Exception):
""" pass
Decoder for a wire codec over the HID (or UDP) layer. Provides readable
async-file-like interface.
"""
def __init__(self, iface: WireInterface) -> None:
self.iface = iface
self.type = INVALID_TYPE
self.size = 0
self.ofs = 0
self.data = bytes()
def __repr__(self) -> str: class BytesIO:
return "<Reader type: %s>" % self.type def __init__(self, buffer: bytearray) -> None:
self.buffer = buffer
self.offset = 0
async def aopen(self) -> None: def seek(self, offset: int) -> None:
""" offset = min(offset, len(self.buffer))
Start reading a message by waiting for initial message report. Because offset = max(offset, 0)
the first report contains the message header, `self.type` and self.offset = offset
`self.size` are initialized and available after `aopen()` returns.
"""
read = loop.wait(self.iface.iface_num() | io.POLL_READ)
while True:
# wait for initial report
report = await read
marker = report[0]
if marker == _REP_MARKER:
_, m1, m2, mtype, msize = ustruct.unpack(_REP_INIT, report)
if m1 != _REP_MAGIC or m2 != _REP_MAGIC:
raise ValueError
break
# load received message header def readinto(self, dst: bytearray) -> int:
self.type = mtype buffer = self.buffer
self.size = msize offset = self.offset
self.data = report[_REP_INIT_DATA : _REP_INIT_DATA + msize] if len(dst) > len(buffer) - offset:
self.ofs = 0
async def areadinto(self, buf: bytearray) -> int:
"""
Read exactly `len(buf)` bytes into `buf`, waiting for additional
reports, if needed. Raises `EOFError` if end-of-message is encountered
before the full read can be completed.
"""
if self.size < len(buf):
raise EOFError raise EOFError
nread = utils.memcpy(dst, 0, buffer, offset)
read = loop.wait(self.iface.iface_num() | io.POLL_READ) self.offset += nread
nread = 0
while nread < len(buf):
if self.ofs == len(self.data):
# we are at the end of received data
# wait for continuation report
while True:
report = await read
marker = report[0]
if marker == _REP_MARKER:
break
self.data = report[_REP_CONT_DATA : _REP_CONT_DATA + self.size]
self.ofs = 0
# copy as much as possible to target buffer
nbytes = utils.memcpy(buf, nread, self.data, self.ofs, len(buf))
nread += nbytes
self.ofs += nbytes
self.size -= nbytes
return nread return nread
def write(self, src: bytes) -> int:
class Writer: buffer = self.buffer
""" offset = self.offset
Encoder for a wire codec over the HID (or UDP) layer. Provides writable if len(src) > len(buffer) - offset:
async-file-like interface.
"""
def __init__(self, iface: WireInterface):
self.iface = iface
self.type = INVALID_TYPE
self.size = 0
self.ofs = 0
self.data = bytearray(_REP_LEN)
def setheader(self, mtype: int, msize: int) -> None:
"""
Reset the writer state and load the message header with passed type and
total message size.
"""
self.type = mtype
self.size = msize
ustruct.pack_into(
_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize
)
self.ofs = _REP_INIT_DATA
async def awrite(self, buf: bytes) -> int:
"""
Encode and write every byte from `buf`. Does not need to be called in
case message has zero length. Raises `EOFError` if the length of `buf`
exceeds the remaining message length.
"""
if self.size < len(buf):
raise EOFError raise EOFError
nwrite = utils.memcpy(buffer, offset, src, 0)
self.offset += nwrite
return nwrite
write = loop.wait(self.iface.iface_num() | io.POLL_WRITE) def get_written(self) -> bytes:
nwritten = 0 return memoryview(self.buffer)[: self.offset]
while nwritten < len(buf):
# copy as much as possible to report buffer
nbytes = utils.memcpy(self.data, self.ofs, buf, nwritten, len(buf))
nwritten += nbytes
self.ofs += nbytes
self.size -= nbytes
if self.ofs == _REP_LEN:
# we are at the end of the report, flush it
while True:
await write
n = self.iface.write(self.data)
if n == len(self.data):
break
self.ofs = _REP_CONT_DATA
return nwritten class Message:
def __init__(self, mtype: int, mdata: BytesIO) -> None:
self.type = mtype
self.data = mdata
async def aclose(self) -> None:
"""Flush and close the message transmission."""
if self.ofs != _REP_CONT_DATA:
# we didn't write anything or last write() wasn't report-aligned,
# pad the final report and flush it
while self.ofs < _REP_LEN:
self.data[self.ofs] = 0x00
self.ofs += 1
write = loop.wait(self.iface.iface_num() | io.POLL_WRITE) async def read_message(iface: WireInterface, buffer: bytearray) -> Message:
while True: read = loop.wait(iface.iface_num() | io.POLL_READ)
await write
n = self.iface.write(self.data) # wait for initial report
if n == len(self.data): report = await read
break if report[0] != _REP_MARKER:
raise CodecError("Invalid magic")
_, magic1, magic2, mtype, msize = ustruct.unpack(_REP_INIT, report)
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
raise CodecError("Invalid magic")
throw_away = False
if msize > len(buffer):
throw_away = True
# prepare the backing buffer
mdata = memoryview(buffer)[:msize]
# buffer the initial data
nread = utils.memcpy(mdata, 0, report, _REP_INIT_DATA)
while nread < msize:
# wait for continuation report
report = await read
if report[0] != _REP_MARKER:
raise CodecError("Invalid magic")
# buffer the continuation data
if not throw_away:
nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA)
if throw_away:
raise CodecError("Message too large")
return Message(mtype, BytesIO(mdata))
async def write_message(iface: WireInterface, mtype: int, mdata: bytearray) -> None:
write = loop.wait(iface.iface_num() | io.POLL_WRITE)
# gather data from msg
msize = len(mdata)
# prepare the report buffer with header data
report = bytearray(_REP_LEN)
repofs = _REP_INIT_DATA
ustruct.pack_into(
_REP_INIT, report, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize
)
nwritten = 0
while True:
# copy as much as possible to the report buffer
nwritten += utils.memcpy(report, repofs, mdata, nwritten)
# write the report
while True:
await write
n = iface.write(report)
if n == len(report):
break
# if we have more data to write, use continuation reports for it
if nwritten < msize:
repofs = _REP_CONT_DATA
else:
break