mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-03 19:31:02 +00:00
core: make USB codec resilient to OOM conditions
This commit is contained in:
parent
0a758b8181
commit
3514a31bc9
@ -39,9 +39,15 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
||||
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
|
||||
raise CodecError("Invalid magic")
|
||||
|
||||
read_and_throw_away = False
|
||||
|
||||
if msize > len(buffer):
|
||||
# allocate a new buffer to fit the message
|
||||
mdata = bytearray(msize) # type: utils.BufferType
|
||||
try:
|
||||
mdata = bytearray(msize) # type: utils.BufferType
|
||||
except MemoryError:
|
||||
mdata = bytearray(_REP_LEN)
|
||||
read_and_throw_away = True
|
||||
else:
|
||||
# reuse a part of the supplied buffer
|
||||
mdata = memoryview(buffer)[:msize]
|
||||
@ -56,7 +62,13 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
||||
raise CodecError("Invalid magic")
|
||||
|
||||
# buffer the continuation data
|
||||
nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA)
|
||||
if read_and_throw_away:
|
||||
nread += len(report) - 1
|
||||
else:
|
||||
nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA)
|
||||
|
||||
if read_and_throw_away:
|
||||
raise CodecError("Message too large")
|
||||
|
||||
return Message(mtype, utils.BufferIO(mdata))
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
declare -a results
|
||||
declare -i passed=0 failed=0 exit_code=0
|
||||
declare COLOR_GREEN='\e[32m' COLOR_RED='\e[91m' COLOR_RESET='\e[39m'
|
||||
MICROPYTHON="${MICROPYTHON:-../build/unix/micropython}"
|
||||
MICROPYTHON="${MICROPYTHON:-../build/unix/micropython -X heapsize=1M}"
|
||||
|
||||
print_summary() {
|
||||
echo
|
||||
|
@ -26,6 +26,8 @@ class MockHID:
|
||||
|
||||
MESSAGE_TYPE = 0x4242
|
||||
|
||||
HEADER_PAYLOAD_LENGTH = codec_v1._REP_LEN - 3 - ustruct.calcsize(">HL")
|
||||
|
||||
|
||||
def make_header(mtype, length):
|
||||
return b"?##" + ustruct.pack(">HL", mtype, length)
|
||||
@ -60,11 +62,11 @@ class TestWireCodecV1(unittest.TestCase):
|
||||
message = bytes(range(256))
|
||||
|
||||
header = make_header(mtype=MESSAGE_TYPE, length=len(message))
|
||||
first_len = codec_v1._REP_LEN - len(header)
|
||||
# first packet is header + (remaining)data
|
||||
# other packets are "?" + 63 bytes of data
|
||||
packets = [header + message[:first_len]] + [
|
||||
b"?" + chunk for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1)
|
||||
packets = [header + message[:HEADER_PAYLOAD_LENGTH]] + [
|
||||
b"?" + chunk
|
||||
for chunk in chunks(message[HEADER_PAYLOAD_LENGTH:], codec_v1._REP_LEN - 1)
|
||||
]
|
||||
|
||||
buffer = bytearray(256)
|
||||
@ -120,7 +122,7 @@ class TestWireCodecV1(unittest.TestCase):
|
||||
gen.send(None)
|
||||
|
||||
header = make_header(mtype=MESSAGE_TYPE, length=0)
|
||||
expected_message = header + b"\x00" * (codec_v1._REP_LEN - len(header))
|
||||
expected_message = header + b"\x00" * HEADER_PAYLOAD_LENGTH
|
||||
self.assertTrue(self.interface.data == [expected_message])
|
||||
|
||||
def test_write_multiple_packets(self):
|
||||
@ -128,11 +130,11 @@ class TestWireCodecV1(unittest.TestCase):
|
||||
gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message)
|
||||
|
||||
header = make_header(mtype=MESSAGE_TYPE, length=len(message))
|
||||
first_len = codec_v1._REP_LEN - len(header)
|
||||
# first packet is header + (remaining)data
|
||||
# other packets are "?" + 63 bytes of data
|
||||
packets = [header + message[:first_len]] + [
|
||||
b"?" + chunk for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1)
|
||||
packets = [header + message[:HEADER_PAYLOAD_LENGTH]] + [
|
||||
b"?" + chunk
|
||||
for chunk in chunks(message[HEADER_PAYLOAD_LENGTH:], codec_v1._REP_LEN - 1)
|
||||
]
|
||||
|
||||
for _ in packets:
|
||||
@ -150,7 +152,7 @@ class TestWireCodecV1(unittest.TestCase):
|
||||
# last packet must be identical up to message length. remaining bytes in
|
||||
# the 64-byte packets are garbage -- in particular, it's the bytes of the
|
||||
# previous packet
|
||||
last_packet = packets[-1] + packets[-2][len(packets[-1]):]
|
||||
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
|
||||
self.assertEqual(last_packet, self.interface.data[-1])
|
||||
|
||||
def test_roundtrip(self):
|
||||
@ -173,8 +175,28 @@ class TestWireCodecV1(unittest.TestCase):
|
||||
gen.send(self.interface.data[-1])
|
||||
|
||||
result = e.value.value
|
||||
assert result.type == MESSAGE_TYPE
|
||||
assert result.data.buffer == message
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data.buffer, message)
|
||||
|
||||
def test_read_huge_packet(self):
|
||||
# length such that it fits into 1 000 000 USB packets:
|
||||
header = make_header(
|
||||
mtype=MESSAGE_TYPE, length=999_999 * 63 + HEADER_PAYLOAD_LENGTH
|
||||
)
|
||||
packet = header + b"\x00" * HEADER_PAYLOAD_LENGTH
|
||||
|
||||
buffer = bytearray(65536)
|
||||
gen = codec_v1.read_message(self.interface, buffer)
|
||||
|
||||
query = gen.send(None)
|
||||
for _ in range(999_999):
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
query = gen.send(packet)
|
||||
|
||||
with self.assertRaises(codec_v1.CodecError) as e:
|
||||
gen.send(packet)
|
||||
|
||||
self.assertEqual(e.value.args[0], "Message too large")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user