1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-26 16:18:22 +00:00

core: implement synchronous v1 codec

This commit is contained in:
matejcik 2020-06-26 12:30:12 +02:00 committed by Tomas Susanka
parent 34bd57006f
commit 85d74ece76
5 changed files with 186 additions and 227 deletions

View File

@ -58,15 +58,16 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorutils_consteq_obj,
mod_trezorutils_consteq);
/// def memcpy(
/// dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int
/// dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None
/// ) -> int:
/// """
/// Copies at most `n` bytes from `src` at offset `src_ofs` to
/// `dst` at offset `dst_ofs`. Returns the number of actually
/// copied bytes.
/// `dst` at offset `dst_ofs`. Returns the number of actually
/// copied bytes. If `n` is not specified, tries to copy
/// as much as possible.
/// """
STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) {
mp_arg_check_num(n_args, 0, 5, 5, false);
mp_arg_check_num(n_args, 0, 4, 5, false);
mp_buffer_info_t dst;
mp_get_buffer_raise(args[0], &dst, MP_BUFFER_WRITE);
@ -76,7 +77,12 @@ STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) {
mp_get_buffer_raise(args[2], &src, MP_BUFFER_READ);
uint32_t src_ofs = trezor_obj_get_uint(args[3]);
uint32_t n = trezor_obj_get_uint(args[4]);
uint32_t n = 0;
if (n_args > 4) {
n = trezor_obj_get_uint(args[4]);
} else {
n = src.len;
}
size_t dst_rem = (dst_ofs < dst.len) ? dst.len - dst_ofs : 0;
size_t src_rem = (src_ofs < src.len) ? src.len - src_ofs : 0;
@ -86,7 +92,7 @@ STATIC mp_obj_t mod_trezorutils_memcpy(size_t n_args, const mp_obj_t *args) {
return mp_obj_new_int(ncpy);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_memcpy_obj, 5, 5,
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorutils_memcpy_obj, 4, 5,
mod_trezorutils_memcpy);
/// def halt(msg: str = None) -> None:

View File

@ -13,12 +13,13 @@ def consteq(sec: bytes, pub: bytes) -> bool:
# extmod/modtrezorutils/modtrezorutils.c
def memcpy(
dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int
dst: bytearray, dst_ofs: int, src: bytes, src_ofs: int, n: int = None
) -> int:
"""
Copies at most `n` bytes from `src` at offset `src_ofs` to
`dst` at offset `dst_ofs`. Returns the number of actually
copied bytes.
`dst` at offset `dst_ofs`. Returns the number of actually
copied bytes. If `n` is not specified, tries to copy
as much as possible.
"""

View File

@ -9,14 +9,14 @@ if False:
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union
from typing_extensions import Protocol
class AsyncReader(Protocol):
async def areadinto(self, buf: bytearray) -> int:
class Reader(Protocol):
def readinto(self, buf: bytearray) -> int:
"""
Reads `len(buf)` bytes into `buf`, or raises `EOFError`.
"""
class AsyncWriter(Protocol):
async def awrite(self, buf: bytes) -> int:
class Writer(Protocol):
def write(self, buf: bytes) -> int:
"""
Writes all bytes from `buf`, or raises `EOFError`.
"""
@ -25,20 +25,20 @@ if False:
_UVARINT_BUFFER = bytearray(1)
async def load_uvarint(reader: AsyncReader) -> int:
def load_uvarint(reader: Reader) -> int:
buffer = _UVARINT_BUFFER
result = 0
shift = 0
byte = 0x80
while byte & 0x80:
await reader.areadinto(buffer)
reader.readinto(buffer)
byte = buffer[0]
result += (byte & 0x7F) << shift
shift += 7
return result
async def dump_uvarint(writer: AsyncWriter, n: int) -> None:
def dump_uvarint(writer: Writer, n: int) -> None:
if n < 0:
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
buffer = _UVARINT_BUFFER
@ -46,7 +46,7 @@ async def dump_uvarint(writer: AsyncWriter, n: int) -> None:
while shifted:
shifted = n >> 7
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
await writer.awrite(buffer)
writer.write(buffer)
n = shifted
@ -165,15 +165,15 @@ class MessageType:
class LimitedReader:
def __init__(self, reader: AsyncReader, limit: int) -> None:
def __init__(self, reader: Reader, limit: int) -> None:
self.reader = reader
self.limit = limit
async def areadinto(self, buf: bytearray) -> int:
def readinto(self, buf: bytearray) -> int:
if self.limit < len(buf):
raise EOFError
else:
nread = await self.reader.areadinto(buf)
nread = self.reader.readinto(buf)
self.limit -= nread
return nread
@ -184,8 +184,8 @@ if False:
LoadedMessageType = TypeVar("LoadedMessageType", bound=MessageType)
async def load_message(
reader: AsyncReader, msg_type: Type[LoadedMessageType]
def load_message(
reader: Reader, msg_type: Type[LoadedMessageType]
) -> LoadedMessageType:
fields = msg_type.get_fields()
msg = msg_type()
@ -197,7 +197,7 @@ async def load_message(
while True:
try:
fkey = await load_uvarint(reader)
fkey = load_uvarint(reader)
except EOFError:
break # no more fields to load
@ -208,10 +208,10 @@ async def load_message(
if field is None: # unknown field, skip it
if wtype == 0:
await load_uvarint(reader)
load_uvarint(reader)
elif wtype == 2:
ivalue = await load_uvarint(reader)
await reader.areadinto(bytearray(ivalue))
ivalue = load_uvarint(reader)
reader.readinto(bytearray(ivalue))
else:
raise ValueError
continue
@ -220,7 +220,7 @@ async def load_message(
if wtype != ftype.WIRE_TYPE:
raise TypeError # parsed wire type differs from the schema
ivalue = await load_uvarint(reader)
ivalue = load_uvarint(reader)
if ftype is UVarintType:
fvalue = ivalue
@ -232,13 +232,13 @@ async def load_message(
fvalue = ftype.validate(ivalue)
elif ftype is BytesType:
fvalue = bytearray(ivalue)
await reader.areadinto(fvalue)
reader.readinto(fvalue)
elif ftype is UnicodeType:
fvalue = bytearray(ivalue)
await reader.areadinto(fvalue)
reader.readinto(fvalue)
fvalue = bytes(fvalue).decode()
elif issubclass(ftype, MessageType):
fvalue = await load_message(LimitedReader(reader, ivalue), ftype)
fvalue = load_message(LimitedReader(reader, ivalue), ftype)
else:
raise TypeError # field type is unknown
@ -257,9 +257,7 @@ async def load_message(
return msg
async def dump_message(
writer: AsyncWriter, msg: MessageType, fields: Dict = None
) -> None:
def dump_message(writer: Writer, msg: MessageType, fields: Dict = None) -> None:
repvalue = [0]
if fields is None:
@ -281,39 +279,39 @@ async def dump_message(
ffields = None # type: Optional[Dict]
for svalue in fvalue:
await dump_uvarint(writer, fkey)
dump_uvarint(writer, fkey)
if ftype is UVarintType:
await dump_uvarint(writer, svalue)
dump_uvarint(writer, svalue)
elif ftype is SVarintType:
await dump_uvarint(writer, sint_to_uint(svalue))
dump_uvarint(writer, sint_to_uint(svalue))
elif ftype is BoolType:
await dump_uvarint(writer, int(svalue))
dump_uvarint(writer, int(svalue))
elif isinstance(ftype, EnumType):
await dump_uvarint(writer, svalue)
dump_uvarint(writer, svalue)
elif ftype is BytesType:
if isinstance(svalue, list):
await dump_uvarint(writer, _count_bytes_list(svalue))
dump_uvarint(writer, _count_bytes_list(svalue))
for sub_svalue in svalue:
await writer.awrite(sub_svalue)
writer.write(sub_svalue)
else:
await dump_uvarint(writer, len(svalue))
await writer.awrite(svalue)
dump_uvarint(writer, len(svalue))
writer.write(svalue)
elif ftype is UnicodeType:
svalue = svalue.encode()
await dump_uvarint(writer, len(svalue))
await writer.awrite(svalue)
dump_uvarint(writer, len(svalue))
writer.write(svalue)
elif issubclass(ftype, MessageType):
if ffields is None:
ffields = ftype.get_fields()
await dump_uvarint(writer, count_message(svalue, ffields))
await dump_message(writer, svalue, ffields)
dump_uvarint(writer, count_message(svalue, ffields))
dump_message(writer, svalue, ffields)
else:
raise TypeError

View File

@ -138,6 +138,7 @@ class Context:
def __init__(self, iface: WireInterface, sid: int) -> None:
self.iface = iface
self.sid = sid
self.buffer_io = codec_v1.BytesIO(bytearray(8192))
async def call(
self, msg: protobuf.MessageType, expected_type: Type[protobuf.LoadedMessageType]
@ -153,11 +154,13 @@ class Context:
del msg
return await self.read_any(expected_wire_types)
async def read_from_wire(self) -> codec_v1.Message:
self.buffer_io.seek(0)
return await codec_v1.read_message(self.iface, self.buffer_io.buffer)
async def read(
self, expected_type: Type[protobuf.LoadedMessageType]
) -> protobuf.LoadedMessageType:
reader = self.make_reader()
if __debug__:
log.debug(
__name__,
@ -167,14 +170,13 @@ class Context:
expected_type,
)
# Wait for the message header, contained in the first report. After
# we receive it, we have a message type to match on.
await reader.aopen()
# Load the full message into a buffer, parse out type and data payload
msg = await self.read_from_wire()
# If we got a message with unexpected type, raise the reader via
# If we got a message with unexpected type, raise the message via
# `UnexpectedMessageError` and let the session handler deal with it.
if reader.type != expected_type.MESSAGE_WIRE_TYPE:
raise UnexpectedMessageError(reader)
if msg.type != expected_type.MESSAGE_WIRE_TYPE:
raise UnexpectedMessageError(msg)
if __debug__:
log.debug(
@ -187,14 +189,13 @@ class Context:
workflow.idle_timer.touch()
# parse the message and return it
return await protobuf.load_message(reader, expected_type)
# look up the protobuf class and parse the message
pbtype = messages.get_type(msg.type)
return protobuf.load_message(msg.data, pbtype)
async def read_any(
self, expected_wire_types: Iterable[int]
) -> protobuf.MessageType:
reader = self.make_reader()
if __debug__:
log.debug(
__name__,
@ -204,17 +205,16 @@ class Context:
expected_wire_types,
)
# Wait for the message header, contained in the first report. After
# we receive it, we have a message type to match on.
await reader.aopen()
# Load the full message into a buffer, parse out type and data payload
msg = await self.read_from_wire()
# If we got a message with unexpected type, raise the reader via
# If we got a message with unexpected type, raise the message via
# `UnexpectedMessageError` and let the session handler deal with it.
if reader.type not in expected_wire_types:
raise UnexpectedMessageError(reader)
if msg.type not in expected_wire_types:
raise UnexpectedMessageError(msg)
# find the protobuf type
exptype = messages.get_type(reader.type)
exptype = messages.get_type(msg.type)
if __debug__:
log.debug(
@ -224,24 +224,20 @@ class Context:
workflow.idle_timer.touch()
# parse the message and return it
return await protobuf.load_message(reader, exptype)
return protobuf.load_message(msg.data, exptype)
async def write(self, msg: protobuf.MessageType) -> None:
writer = self.make_writer()
if __debug__:
log.debug(
__name__, "%s:%x write: %s", self.iface.iface_num(), self.sid, msg
)
# get the message size
fields = msg.get_fields()
size = protobuf.count_message(msg, fields)
# write the message
writer.setheader(msg.MESSAGE_WIRE_TYPE, size)
await protobuf.dump_message(writer, msg, fields)
await writer.aclose()
self.buffer_io.seek(0)
protobuf.dump_message(self.buffer_io, msg)
await codec_v1.write_message(
self.iface, msg.MESSAGE_WIRE_TYPE, self.buffer_io.get_written()
)
def wait(self, *tasks: Awaitable) -> Any:
"""
@ -251,43 +247,35 @@ class Context:
"""
return loop.race(self.read_any(()), *tasks)
def make_reader(self) -> codec_v1.Reader:
return codec_v1.Reader(self.iface)
def make_writer(self) -> codec_v1.Writer:
return codec_v1.Writer(self.iface)
class UnexpectedMessageError(Exception):
def __init__(self, reader: codec_v1.Reader) -> None:
self.reader = reader
def __init__(self, msg: codec_v1.Message) -> None:
self.msg = msg
async def handle_session(
iface: WireInterface, session_id: int, use_workflow: bool = True
) -> None:
ctx = Context(iface, session_id)
next_reader = None # type: Optional[codec_v1.Reader]
next_msg = None # type: Optional[codec_v1.Message]
res_msg = None # type: Optional[protobuf.MessageType]
req_reader = None
req_type = None
req_msg = None
while True:
try:
if next_reader is None:
if next_msg is None:
# We are not currently reading a message, so let's wait for one.
# If the decoding fails, exception is raised and we try again
# (with the same `Reader` instance, it's OK). Even in case of
# de-synchronized wire communication, report with a message
# header is eventually received, after a couple of tries.
req_reader = ctx.make_reader()
await req_reader.aopen()
msg = await ctx.read_from_wire()
if __debug__:
try:
msg_type = messages.get_type(req_reader.type).__name__
msg_type = messages.get_type(msg.type).__name__
except KeyError:
msg_type = "%d - unknown message type" % req_reader.type
msg_type = "%d - unknown message type" % msg.type
log.debug(
__name__,
"%s:%x receive: <%s>",
@ -298,8 +286,8 @@ async def handle_session(
else:
# We have a reader left over from earlier. We should process
# this message instead of waiting for new one.
req_reader = next_reader
next_reader = None
msg = next_msg
next_msg = None
# Now we are in a middle of reading a message and we need to decide
# what to do with it, based on its type from the message header.
@ -312,13 +300,11 @@ async def handle_session(
# We need to find a handler for this message type. Should not
# raise.
handler = find_handler(iface, req_reader.type)
handler = find_handler(iface, msg.type)
if handler is None:
# If no handler is found, we can skip decoding and directly
# respond with failure, but first, we should read the rest of
# the message reports. Should not raise.
await read_and_throw_away(req_reader)
# respond with failure. Should not raise.
res_msg = unexpected_message()
else:
@ -332,11 +318,11 @@ async def handle_session(
try:
# Find a protobuf.MessageType subclass that describes this
# message. Raises if the type is not found.
req_type = messages.get_type(req_reader.type)
req_type = messages.get_type(msg.type)
# Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed.
req_msg = await protobuf.load_message(req_reader, req_type)
req_msg = protobuf.load_message(msg.data, req_type)
# At this point, message reports are all processed and
# correctly parsed into `req_msg`.
@ -364,7 +350,7 @@ async def handle_session(
# TODO:
# We might handle only the few common cases here, like
# Initialize and Cancel.
next_reader = exc.reader
next_msg = exc.msg
res_msg = None
except Exception as exc:
@ -401,7 +387,6 @@ async def handle_session(
# Cleanup, so garbage collection triggered after un-importing can
# pick up the trash.
req_reader = None
req_type = None
req_msg = None
res_msg = None

View File

@ -18,143 +18,112 @@ SESSION_ID = const(0)
INVALID_TYPE = const(-1)
class Reader:
"""
Decoder for a wire codec over the HID (or UDP) layer. Provides readable
async-file-like interface.
"""
class CodecError(Exception):
pass
def __init__(self, iface: WireInterface) -> None:
self.iface = iface
self.type = INVALID_TYPE
self.size = 0
self.ofs = 0
self.data = bytes()
def __repr__(self) -> str:
return "<Reader type: %s>" % self.type
class BytesIO:
def __init__(self, buffer: bytearray) -> None:
self.buffer = buffer
self.offset = 0
async def aopen(self) -> None:
"""
Start reading a message by waiting for initial message report. Because
the first report contains the message header, `self.type` and
`self.size` are initialized and available after `aopen()` returns.
"""
read = loop.wait(self.iface.iface_num() | io.POLL_READ)
while True:
# wait for initial report
report = await read
marker = report[0]
if marker == _REP_MARKER:
_, m1, m2, mtype, msize = ustruct.unpack(_REP_INIT, report)
if m1 != _REP_MAGIC or m2 != _REP_MAGIC:
raise ValueError
break
def seek(self, offset: int) -> None:
offset = min(offset, len(self.buffer))
offset = max(offset, 0)
self.offset = offset
# load received message header
self.type = mtype
self.size = msize
self.data = report[_REP_INIT_DATA : _REP_INIT_DATA + msize]
self.ofs = 0
async def areadinto(self, buf: bytearray) -> int:
"""
Read exactly `len(buf)` bytes into `buf`, waiting for additional
reports, if needed. Raises `EOFError` if end-of-message is encountered
before the full read can be completed.
"""
if self.size < len(buf):
def readinto(self, dst: bytearray) -> int:
buffer = self.buffer
offset = self.offset
if len(dst) > len(buffer) - offset:
raise EOFError
read = loop.wait(self.iface.iface_num() | io.POLL_READ)
nread = 0
while nread < len(buf):
if self.ofs == len(self.data):
# we are at the end of received data
# wait for continuation report
while True:
report = await read
marker = report[0]
if marker == _REP_MARKER:
break
self.data = report[_REP_CONT_DATA : _REP_CONT_DATA + self.size]
self.ofs = 0
# copy as much as possible to target buffer
nbytes = utils.memcpy(buf, nread, self.data, self.ofs, len(buf))
nread += nbytes
self.ofs += nbytes
self.size -= nbytes
nread = utils.memcpy(dst, 0, buffer, offset)
self.offset += nread
return nread
class Writer:
"""
Encoder for a wire codec over the HID (or UDP) layer. Provides writable
async-file-like interface.
"""
def __init__(self, iface: WireInterface):
self.iface = iface
self.type = INVALID_TYPE
self.size = 0
self.ofs = 0
self.data = bytearray(_REP_LEN)
def setheader(self, mtype: int, msize: int) -> None:
"""
Reset the writer state and load the message header with passed type and
total message size.
"""
self.type = mtype
self.size = msize
ustruct.pack_into(
_REP_INIT, self.data, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize
)
self.ofs = _REP_INIT_DATA
async def awrite(self, buf: bytes) -> int:
"""
Encode and write every byte from `buf`. Does not need to be called in
case message has zero length. Raises `EOFError` if the length of `buf`
exceeds the remaining message length.
"""
if self.size < len(buf):
def write(self, src: bytes) -> int:
buffer = self.buffer
offset = self.offset
if len(src) > len(buffer) - offset:
raise EOFError
nwrite = utils.memcpy(buffer, offset, src, 0)
self.offset += nwrite
return nwrite
write = loop.wait(self.iface.iface_num() | io.POLL_WRITE)
nwritten = 0
while nwritten < len(buf):
# copy as much as possible to report buffer
nbytes = utils.memcpy(self.data, self.ofs, buf, nwritten, len(buf))
nwritten += nbytes
self.ofs += nbytes
self.size -= nbytes
def get_written(self) -> bytes:
return memoryview(self.buffer)[: self.offset]
if self.ofs == _REP_LEN:
# we are at the end of the report, flush it
while True:
await write
n = self.iface.write(self.data)
if n == len(self.data):
break
self.ofs = _REP_CONT_DATA
return nwritten
class Message:
def __init__(self, mtype: int, mdata: BytesIO) -> None:
self.type = mtype
self.data = mdata
async def aclose(self) -> None:
"""Flush and close the message transmission."""
if self.ofs != _REP_CONT_DATA:
# we didn't write anything or last write() wasn't report-aligned,
# pad the final report and flush it
while self.ofs < _REP_LEN:
self.data[self.ofs] = 0x00
self.ofs += 1
write = loop.wait(self.iface.iface_num() | io.POLL_WRITE)
while True:
await write
n = self.iface.write(self.data)
if n == len(self.data):
break
async def read_message(iface: WireInterface, buffer: bytearray) -> Message:
read = loop.wait(iface.iface_num() | io.POLL_READ)
# wait for initial report
report = await read
if report[0] != _REP_MARKER:
raise CodecError("Invalid magic")
_, magic1, magic2, mtype, msize = ustruct.unpack(_REP_INIT, report)
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
raise CodecError("Invalid magic")
throw_away = False
if msize > len(buffer):
throw_away = True
# prepare the backing buffer
mdata = memoryview(buffer)[:msize]
# buffer the initial data
nread = utils.memcpy(mdata, 0, report, _REP_INIT_DATA)
while nread < msize:
# wait for continuation report
report = await read
if report[0] != _REP_MARKER:
raise CodecError("Invalid magic")
# buffer the continuation data
if not throw_away:
nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA)
if throw_away:
raise CodecError("Message too large")
return Message(mtype, BytesIO(mdata))
async def write_message(iface: WireInterface, mtype: int, mdata: bytearray) -> None:
write = loop.wait(iface.iface_num() | io.POLL_WRITE)
# gather data from msg
msize = len(mdata)
# prepare the report buffer with header data
report = bytearray(_REP_LEN)
repofs = _REP_INIT_DATA
ustruct.pack_into(
_REP_INIT, report, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize
)
nwritten = 0
while True:
# copy as much as possible to the report buffer
nwritten += utils.memcpy(report, repofs, mdata, nwritten)
# write the report
while True:
await write
n = iface.write(report)
if n == len(report):
break
# if we have more data to write, use continuation reports for it
if nwritten < msize:
repofs = _REP_CONT_DATA
else:
break