From 3514a31bc96cb9550ec39c13ded50c5e7530a87c Mon Sep 17 00:00:00 2001 From: matejcik Date: Mon, 13 Jul 2020 13:20:23 +0200 Subject: [PATCH] core: make USB codec resilient to OOM conditions --- core/src/trezor/wire/codec_v1.py | 16 ++++++++-- core/tests/run_tests.sh | 2 +- core/tests/test_trezor.wire.codec_v1.py | 42 +++++++++++++++++++------ 3 files changed, 47 insertions(+), 13 deletions(-) diff --git a/core/src/trezor/wire/codec_v1.py b/core/src/trezor/wire/codec_v1.py index dfb1d9a0dd..0ae747b7ba 100644 --- a/core/src/trezor/wire/codec_v1.py +++ b/core/src/trezor/wire/codec_v1.py @@ -39,9 +39,15 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC: raise CodecError("Invalid magic") + read_and_throw_away = False + if msize > len(buffer): # allocate a new buffer to fit the message - mdata = bytearray(msize) # type: utils.BufferType + try: + mdata = bytearray(msize) # type: utils.BufferType + except MemoryError: + mdata = bytearray(_REP_LEN) + read_and_throw_away = True else: # reuse a part of the supplied buffer mdata = memoryview(buffer)[:msize] @@ -56,7 +62,13 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag raise CodecError("Invalid magic") # buffer the continuation data - nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA) + if read_and_throw_away: + nread += len(report) - 1 + else: + nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA) + + if read_and_throw_away: + raise CodecError("Message too large") return Message(mtype, utils.BufferIO(mdata)) diff --git a/core/tests/run_tests.sh b/core/tests/run_tests.sh index 6dbc97659b..204eae110e 100755 --- a/core/tests/run_tests.sh +++ b/core/tests/run_tests.sh @@ -3,7 +3,7 @@ declare -a results declare -i passed=0 failed=0 exit_code=0 declare COLOR_GREEN='\e[32m' COLOR_RED='\e[91m' COLOR_RESET='\e[39m' -MICROPYTHON="${MICROPYTHON:-../build/unix/micropython}" +MICROPYTHON="${MICROPYTHON:-../build/unix/micropython -X heapsize=1M}" print_summary() { echo diff --git a/core/tests/test_trezor.wire.codec_v1.py b/core/tests/test_trezor.wire.codec_v1.py index 0785c2f929..5ff5c15b03 100644 --- a/core/tests/test_trezor.wire.codec_v1.py +++ b/core/tests/test_trezor.wire.codec_v1.py @@ -26,6 +26,8 @@ class MockHID: MESSAGE_TYPE = 0x4242 +HEADER_PAYLOAD_LENGTH = codec_v1._REP_LEN - 3 - ustruct.calcsize(">HL") + def make_header(mtype, length): return b"?##" + ustruct.pack(">HL", mtype, length) @@ -60,11 +62,11 @@ class TestWireCodecV1(unittest.TestCase): message = bytes(range(256)) header = make_header(mtype=MESSAGE_TYPE, length=len(message)) - first_len = codec_v1._REP_LEN - len(header) # first packet is header + (remaining)data # other packets are "?" + 63 bytes of data - packets = [header + message[:first_len]] + [ - b"?" + chunk for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1) + packets = [header + message[:HEADER_PAYLOAD_LENGTH]] + [ + b"?" + chunk + for chunk in chunks(message[HEADER_PAYLOAD_LENGTH:], codec_v1._REP_LEN - 1) ] buffer = bytearray(256) @@ -120,7 +122,7 @@ class TestWireCodecV1(unittest.TestCase): gen.send(None) header = make_header(mtype=MESSAGE_TYPE, length=0) - expected_message = header + b"\x00" * (codec_v1._REP_LEN - len(header)) + expected_message = header + b"\x00" * HEADER_PAYLOAD_LENGTH self.assertTrue(self.interface.data == [expected_message]) def test_write_multiple_packets(self): @@ -128,11 +130,11 @@ class TestWireCodecV1(unittest.TestCase): gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message) header = make_header(mtype=MESSAGE_TYPE, length=len(message)) - first_len = codec_v1._REP_LEN - len(header) # first packet is header + (remaining)data # other packets are "?" + 63 bytes of data - packets = [header + message[:first_len]] + [ - b"?" + chunk for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1) + packets = [header + message[:HEADER_PAYLOAD_LENGTH]] + [ + b"?" + chunk + for chunk in chunks(message[HEADER_PAYLOAD_LENGTH:], codec_v1._REP_LEN - 1) ] for _ in packets: @@ -150,7 +152,7 @@ class TestWireCodecV1(unittest.TestCase): # last packet must be identical up to message length. remaining bytes in # the 64-byte packets are garbage -- in particular, it's the bytes of the # previous packet - last_packet = packets[-1] + packets[-2][len(packets[-1]):] + last_packet = packets[-1] + packets[-2][len(packets[-1]) :] self.assertEqual(last_packet, self.interface.data[-1]) def test_roundtrip(self): @@ -173,8 +175,28 @@ class TestWireCodecV1(unittest.TestCase): gen.send(self.interface.data[-1]) result = e.value.value - assert result.type == MESSAGE_TYPE - assert result.data.buffer == message + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data.buffer, message) + + def test_read_huge_packet(self): + # length such that it fits into 1 000 000 USB packets: + header = make_header( + mtype=MESSAGE_TYPE, length=999_999 * 63 + HEADER_PAYLOAD_LENGTH + ) + packet = header + b"\x00" * HEADER_PAYLOAD_LENGTH + + buffer = bytearray(65536) + gen = codec_v1.read_message(self.interface, buffer) + + query = gen.send(None) + for _ in range(999_999): + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + query = gen.send(packet) + + with self.assertRaises(codec_v1.CodecError) as e: + gen.send(packet) + + self.assertEqual(e.value.args[0], "Message too large") if __name__ == "__main__":