diff --git a/core/tests/test_protobuf.py b/core/tests/test_protobuf.py index 524b28069..3089469cb 100644 --- a/core/tests/test_protobuf.py +++ b/core/tests/test_protobuf.py @@ -1,7 +1,7 @@ from common import * import protobuf -from trezor.utils import BufferIO +from trezor.utils import BufferReader, BufferWriter class Message(protobuf.MessageType): @@ -18,12 +18,12 @@ class Message(protobuf.MessageType): def load_uvarint(data: bytes) -> int: - reader = BufferIO(data) + reader = BufferReader(data) return protobuf.load_uvarint(reader) def dump_uvarint(value: int) -> bytearray: - writer = BufferIO(bytearray(16)) + writer = BufferWriter(bytearray(16)) protobuf.dump_uvarint(writer, value) return memoryview(writer.buffer)[:writer.offset] @@ -65,22 +65,23 @@ class TestProtobuf(unittest.TestCase): # ok message: msg = Message(-42, 5) length = protobuf.count_message(msg) - buffer_io = BufferIO(bytearray(length)) - protobuf.dump_message(buffer_io, msg) - buffer_io.seek(0) - nmsg = protobuf.load_message(buffer_io, Message) + buffer_writer = BufferWriter(bytearray(length)) + protobuf.dump_message(buffer_writer, msg) + + buffer_reader = BufferReader(buffer_writer.buffer) + nmsg = protobuf.load_message(buffer_reader, Message) self.assertEqual(msg.sint_field, nmsg.sint_field) self.assertEqual(msg.enum_field, nmsg.enum_field) # bad enum value: - buffer_io.seek(0) + buffer_writer.seek(0) msg = Message(-42, 42) # XXX this assumes the message will have equal size - protobuf.dump_message(buffer_io, msg) - buffer_io.seek(0) + protobuf.dump_message(buffer_writer, msg) + buffer_reader.seek(0) with self.assertRaises(TypeError): - protobuf.load_message(buffer_io, Message) + protobuf.load_message(buffer_reader, Message) if __name__ == "__main__": diff --git a/core/tests/test_trezor.wire.codec_v1.py b/core/tests/test_trezor.wire.codec_v1.py index 5ff5c15b0..3e8cab290 100644 --- a/core/tests/test_trezor.wire.codec_v1.py +++ b/core/tests/test_trezor.wire.codec_v1.py @@ -4,7 +4,7 @@ import ustruct from trezor import io from trezor.loop import wait -from trezor.utils import chunks, BufferIO +from trezor.utils import chunks from trezor.wire import codec_v1 @@ -179,17 +179,20 @@ class TestWireCodecV1(unittest.TestCase): self.assertEqual(result.data.buffer, message) def test_read_huge_packet(self): - # length such that it fits into 1 000 000 USB packets: - header = make_header( - mtype=MESSAGE_TYPE, length=999_999 * 63 + HEADER_PAYLOAD_LENGTH - ) + PACKET_COUNT = 100_000 + # message that takes up 100 000 USB packets + message_size = (PACKET_COUNT - 1) * 63 + HEADER_PAYLOAD_LENGTH + # ensure that a message this big won't fit into memory + self.assertRaises(MemoryError, bytearray, message_size) + + header = make_header(mtype=MESSAGE_TYPE, length=message_size) packet = header + b"\x00" * HEADER_PAYLOAD_LENGTH buffer = bytearray(65536) gen = codec_v1.read_message(self.interface, buffer) query = gen.send(None) - for _ in range(999_999): + for _ in range(PACKET_COUNT - 1): self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) query = gen.send(packet)