mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-04 03:40:58 +00:00
core: make USB codec resilient to OOM conditions
This commit is contained in:
parent
0a758b8181
commit
3514a31bc9
@ -39,9 +39,15 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
|||||||
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
|
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
|
||||||
raise CodecError("Invalid magic")
|
raise CodecError("Invalid magic")
|
||||||
|
|
||||||
|
read_and_throw_away = False
|
||||||
|
|
||||||
if msize > len(buffer):
|
if msize > len(buffer):
|
||||||
# allocate a new buffer to fit the message
|
# allocate a new buffer to fit the message
|
||||||
|
try:
|
||||||
mdata = bytearray(msize) # type: utils.BufferType
|
mdata = bytearray(msize) # type: utils.BufferType
|
||||||
|
except MemoryError:
|
||||||
|
mdata = bytearray(_REP_LEN)
|
||||||
|
read_and_throw_away = True
|
||||||
else:
|
else:
|
||||||
# reuse a part of the supplied buffer
|
# reuse a part of the supplied buffer
|
||||||
mdata = memoryview(buffer)[:msize]
|
mdata = memoryview(buffer)[:msize]
|
||||||
@ -56,8 +62,14 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
|||||||
raise CodecError("Invalid magic")
|
raise CodecError("Invalid magic")
|
||||||
|
|
||||||
# buffer the continuation data
|
# buffer the continuation data
|
||||||
|
if read_and_throw_away:
|
||||||
|
nread += len(report) - 1
|
||||||
|
else:
|
||||||
nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA)
|
nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA)
|
||||||
|
|
||||||
|
if read_and_throw_away:
|
||||||
|
raise CodecError("Message too large")
|
||||||
|
|
||||||
return Message(mtype, utils.BufferIO(mdata))
|
return Message(mtype, utils.BufferIO(mdata))
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
declare -a results
|
declare -a results
|
||||||
declare -i passed=0 failed=0 exit_code=0
|
declare -i passed=0 failed=0 exit_code=0
|
||||||
declare COLOR_GREEN='\e[32m' COLOR_RED='\e[91m' COLOR_RESET='\e[39m'
|
declare COLOR_GREEN='\e[32m' COLOR_RED='\e[91m' COLOR_RESET='\e[39m'
|
||||||
MICROPYTHON="${MICROPYTHON:-../build/unix/micropython}"
|
MICROPYTHON="${MICROPYTHON:-../build/unix/micropython -X heapsize=1M}"
|
||||||
|
|
||||||
print_summary() {
|
print_summary() {
|
||||||
echo
|
echo
|
||||||
|
@ -26,6 +26,8 @@ class MockHID:
|
|||||||
|
|
||||||
MESSAGE_TYPE = 0x4242
|
MESSAGE_TYPE = 0x4242
|
||||||
|
|
||||||
|
HEADER_PAYLOAD_LENGTH = codec_v1._REP_LEN - 3 - ustruct.calcsize(">HL")
|
||||||
|
|
||||||
|
|
||||||
def make_header(mtype, length):
|
def make_header(mtype, length):
|
||||||
return b"?##" + ustruct.pack(">HL", mtype, length)
|
return b"?##" + ustruct.pack(">HL", mtype, length)
|
||||||
@ -60,11 +62,11 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
message = bytes(range(256))
|
message = bytes(range(256))
|
||||||
|
|
||||||
header = make_header(mtype=MESSAGE_TYPE, length=len(message))
|
header = make_header(mtype=MESSAGE_TYPE, length=len(message))
|
||||||
first_len = codec_v1._REP_LEN - len(header)
|
|
||||||
# first packet is header + (remaining)data
|
# first packet is header + (remaining)data
|
||||||
# other packets are "?" + 63 bytes of data
|
# other packets are "?" + 63 bytes of data
|
||||||
packets = [header + message[:first_len]] + [
|
packets = [header + message[:HEADER_PAYLOAD_LENGTH]] + [
|
||||||
b"?" + chunk for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1)
|
b"?" + chunk
|
||||||
|
for chunk in chunks(message[HEADER_PAYLOAD_LENGTH:], codec_v1._REP_LEN - 1)
|
||||||
]
|
]
|
||||||
|
|
||||||
buffer = bytearray(256)
|
buffer = bytearray(256)
|
||||||
@ -120,7 +122,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
gen.send(None)
|
gen.send(None)
|
||||||
|
|
||||||
header = make_header(mtype=MESSAGE_TYPE, length=0)
|
header = make_header(mtype=MESSAGE_TYPE, length=0)
|
||||||
expected_message = header + b"\x00" * (codec_v1._REP_LEN - len(header))
|
expected_message = header + b"\x00" * HEADER_PAYLOAD_LENGTH
|
||||||
self.assertTrue(self.interface.data == [expected_message])
|
self.assertTrue(self.interface.data == [expected_message])
|
||||||
|
|
||||||
def test_write_multiple_packets(self):
|
def test_write_multiple_packets(self):
|
||||||
@ -128,11 +130,11 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message)
|
gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message)
|
||||||
|
|
||||||
header = make_header(mtype=MESSAGE_TYPE, length=len(message))
|
header = make_header(mtype=MESSAGE_TYPE, length=len(message))
|
||||||
first_len = codec_v1._REP_LEN - len(header)
|
|
||||||
# first packet is header + (remaining)data
|
# first packet is header + (remaining)data
|
||||||
# other packets are "?" + 63 bytes of data
|
# other packets are "?" + 63 bytes of data
|
||||||
packets = [header + message[:first_len]] + [
|
packets = [header + message[:HEADER_PAYLOAD_LENGTH]] + [
|
||||||
b"?" + chunk for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1)
|
b"?" + chunk
|
||||||
|
for chunk in chunks(message[HEADER_PAYLOAD_LENGTH:], codec_v1._REP_LEN - 1)
|
||||||
]
|
]
|
||||||
|
|
||||||
for _ in packets:
|
for _ in packets:
|
||||||
@ -150,7 +152,7 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
# last packet must be identical up to message length. remaining bytes in
|
# last packet must be identical up to message length. remaining bytes in
|
||||||
# the 64-byte packets are garbage -- in particular, it's the bytes of the
|
# the 64-byte packets are garbage -- in particular, it's the bytes of the
|
||||||
# previous packet
|
# previous packet
|
||||||
last_packet = packets[-1] + packets[-2][len(packets[-1]):]
|
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
|
||||||
self.assertEqual(last_packet, self.interface.data[-1])
|
self.assertEqual(last_packet, self.interface.data[-1])
|
||||||
|
|
||||||
def test_roundtrip(self):
|
def test_roundtrip(self):
|
||||||
@ -173,8 +175,28 @@ class TestWireCodecV1(unittest.TestCase):
|
|||||||
gen.send(self.interface.data[-1])
|
gen.send(self.interface.data[-1])
|
||||||
|
|
||||||
result = e.value.value
|
result = e.value.value
|
||||||
assert result.type == MESSAGE_TYPE
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||||
assert result.data.buffer == message
|
self.assertEqual(result.data.buffer, message)
|
||||||
|
|
||||||
|
def test_read_huge_packet(self):
|
||||||
|
# length such that it fits into 1 000 000 USB packets:
|
||||||
|
header = make_header(
|
||||||
|
mtype=MESSAGE_TYPE, length=999_999 * 63 + HEADER_PAYLOAD_LENGTH
|
||||||
|
)
|
||||||
|
packet = header + b"\x00" * HEADER_PAYLOAD_LENGTH
|
||||||
|
|
||||||
|
buffer = bytearray(65536)
|
||||||
|
gen = codec_v1.read_message(self.interface, buffer)
|
||||||
|
|
||||||
|
query = gen.send(None)
|
||||||
|
for _ in range(999_999):
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
|
query = gen.send(packet)
|
||||||
|
|
||||||
|
with self.assertRaises(codec_v1.CodecError) as e:
|
||||||
|
gen.send(packet)
|
||||||
|
|
||||||
|
self.assertEqual(e.value.args[0], "Message too large")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user