core: make USB codec resilient to OOM conditions

pull/1128/head
matejcik 4 years ago committed by Tomas Susanka
parent 0a758b8181
commit 3514a31bc9

@ -39,9 +39,15 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
raise CodecError("Invalid magic")
read_and_throw_away = False
if msize > len(buffer):
# allocate a new buffer to fit the message
mdata = bytearray(msize) # type: utils.BufferType
try:
mdata = bytearray(msize) # type: utils.BufferType
except MemoryError:
mdata = bytearray(_REP_LEN)
read_and_throw_away = True
else:
# reuse a part of the supplied buffer
mdata = memoryview(buffer)[:msize]
@ -56,7 +62,13 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
raise CodecError("Invalid magic")
# buffer the continuation data
nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA)
if read_and_throw_away:
nread += len(report) - 1
else:
nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA)
if read_and_throw_away:
raise CodecError("Message too large")
return Message(mtype, utils.BufferIO(mdata))

@ -3,7 +3,7 @@
declare -a results
declare -i passed=0 failed=0 exit_code=0
declare COLOR_GREEN='\e[32m' COLOR_RED='\e[91m' COLOR_RESET='\e[39m'
MICROPYTHON="${MICROPYTHON:-../build/unix/micropython}"
MICROPYTHON="${MICROPYTHON:-../build/unix/micropython -X heapsize=1M}"
print_summary() {
echo

@ -26,6 +26,8 @@ class MockHID:
MESSAGE_TYPE = 0x4242
HEADER_PAYLOAD_LENGTH = codec_v1._REP_LEN - 3 - ustruct.calcsize(">HL")
def make_header(mtype, length):
return b"?##" + ustruct.pack(">HL", mtype, length)
@ -60,11 +62,11 @@ class TestWireCodecV1(unittest.TestCase):
message = bytes(range(256))
header = make_header(mtype=MESSAGE_TYPE, length=len(message))
first_len = codec_v1._REP_LEN - len(header)
# first packet is header + (remaining)data
# other packets are "?" + 63 bytes of data
packets = [header + message[:first_len]] + [
b"?" + chunk for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1)
packets = [header + message[:HEADER_PAYLOAD_LENGTH]] + [
b"?" + chunk
for chunk in chunks(message[HEADER_PAYLOAD_LENGTH:], codec_v1._REP_LEN - 1)
]
buffer = bytearray(256)
@ -120,7 +122,7 @@ class TestWireCodecV1(unittest.TestCase):
gen.send(None)
header = make_header(mtype=MESSAGE_TYPE, length=0)
expected_message = header + b"\x00" * (codec_v1._REP_LEN - len(header))
expected_message = header + b"\x00" * HEADER_PAYLOAD_LENGTH
self.assertTrue(self.interface.data == [expected_message])
def test_write_multiple_packets(self):
@ -128,11 +130,11 @@ class TestWireCodecV1(unittest.TestCase):
gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message)
header = make_header(mtype=MESSAGE_TYPE, length=len(message))
first_len = codec_v1._REP_LEN - len(header)
# first packet is header + (remaining)data
# other packets are "?" + 63 bytes of data
packets = [header + message[:first_len]] + [
b"?" + chunk for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1)
packets = [header + message[:HEADER_PAYLOAD_LENGTH]] + [
b"?" + chunk
for chunk in chunks(message[HEADER_PAYLOAD_LENGTH:], codec_v1._REP_LEN - 1)
]
for _ in packets:
@ -150,7 +152,7 @@ class TestWireCodecV1(unittest.TestCase):
# last packet must be identical up to message length. remaining bytes in
# the 64-byte packets are garbage -- in particular, it's the bytes of the
# previous packet
last_packet = packets[-1] + packets[-2][len(packets[-1]):]
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
self.assertEqual(last_packet, self.interface.data[-1])
def test_roundtrip(self):
@ -173,8 +175,28 @@ class TestWireCodecV1(unittest.TestCase):
gen.send(self.interface.data[-1])
result = e.value.value
assert result.type == MESSAGE_TYPE
assert result.data.buffer == message
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data.buffer, message)
def test_read_huge_packet(self):
# length such that it fits into 1 000 000 USB packets:
header = make_header(
mtype=MESSAGE_TYPE, length=999_999 * 63 + HEADER_PAYLOAD_LENGTH
)
packet = header + b"\x00" * HEADER_PAYLOAD_LENGTH
buffer = bytearray(65536)
gen = codec_v1.read_message(self.interface, buffer)
query = gen.send(None)
for _ in range(999_999):
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet)
with self.assertRaises(codec_v1.CodecError) as e:
gen.send(packet)
self.assertEqual(e.value.args[0], "Message too large")
if __name__ == "__main__":

Loading…
Cancel
Save