core/tests: update unit tests

pull/1128/head
matejcik 4 years ago committed by Tomas Susanka
parent 5e7fd3aea6
commit 7befdd07e4

@ -1,7 +1,7 @@
from common import * from common import *
import protobuf import protobuf
from trezor.utils import BufferIO from trezor.utils import BufferReader, BufferWriter
class Message(protobuf.MessageType): class Message(protobuf.MessageType):
@ -18,12 +18,12 @@ class Message(protobuf.MessageType):
def load_uvarint(data: bytes) -> int: def load_uvarint(data: bytes) -> int:
reader = BufferIO(data) reader = BufferReader(data)
return protobuf.load_uvarint(reader) return protobuf.load_uvarint(reader)
def dump_uvarint(value: int) -> bytearray: def dump_uvarint(value: int) -> bytearray:
writer = BufferIO(bytearray(16)) writer = BufferWriter(bytearray(16))
protobuf.dump_uvarint(writer, value) protobuf.dump_uvarint(writer, value)
return memoryview(writer.buffer)[:writer.offset] return memoryview(writer.buffer)[:writer.offset]
@ -65,22 +65,23 @@ class TestProtobuf(unittest.TestCase):
# ok message: # ok message:
msg = Message(-42, 5) msg = Message(-42, 5)
length = protobuf.count_message(msg) length = protobuf.count_message(msg)
buffer_io = BufferIO(bytearray(length)) buffer_writer = BufferWriter(bytearray(length))
protobuf.dump_message(buffer_io, msg) protobuf.dump_message(buffer_writer, msg)
buffer_io.seek(0)
nmsg = protobuf.load_message(buffer_io, Message) buffer_reader = BufferReader(buffer_writer.buffer)
nmsg = protobuf.load_message(buffer_reader, Message)
self.assertEqual(msg.sint_field, nmsg.sint_field) self.assertEqual(msg.sint_field, nmsg.sint_field)
self.assertEqual(msg.enum_field, nmsg.enum_field) self.assertEqual(msg.enum_field, nmsg.enum_field)
# bad enum value: # bad enum value:
buffer_io.seek(0) buffer_writer.seek(0)
msg = Message(-42, 42) msg = Message(-42, 42)
# XXX this assumes the message will have equal size # XXX this assumes the message will have equal size
protobuf.dump_message(buffer_io, msg) protobuf.dump_message(buffer_writer, msg)
buffer_io.seek(0) buffer_reader.seek(0)
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
protobuf.load_message(buffer_io, Message) protobuf.load_message(buffer_reader, Message)
if __name__ == "__main__": if __name__ == "__main__":

@ -4,7 +4,7 @@ import ustruct
from trezor import io from trezor import io
from trezor.loop import wait from trezor.loop import wait
from trezor.utils import chunks, BufferIO from trezor.utils import chunks
from trezor.wire import codec_v1 from trezor.wire import codec_v1
@ -179,17 +179,20 @@ class TestWireCodecV1(unittest.TestCase):
self.assertEqual(result.data.buffer, message) self.assertEqual(result.data.buffer, message)
def test_read_huge_packet(self): def test_read_huge_packet(self):
# length such that it fits into 1 000 000 USB packets: PACKET_COUNT = 100_000
header = make_header( # message that takes up 100 000 USB packets
mtype=MESSAGE_TYPE, length=999_999 * 63 + HEADER_PAYLOAD_LENGTH message_size = (PACKET_COUNT - 1) * 63 + HEADER_PAYLOAD_LENGTH
) # ensure that a message this big won't fit into memory
self.assertRaises(MemoryError, bytearray, message_size)
header = make_header(mtype=MESSAGE_TYPE, length=message_size)
packet = header + b"\x00" * HEADER_PAYLOAD_LENGTH packet = header + b"\x00" * HEADER_PAYLOAD_LENGTH
buffer = bytearray(65536) buffer = bytearray(65536)
gen = codec_v1.read_message(self.interface, buffer) gen = codec_v1.read_message(self.interface, buffer)
query = gen.send(None) query = gen.send(None)
for _ in range(999_999): for _ in range(PACKET_COUNT - 1):
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet) query = gen.send(packet)

Loading…
Cancel
Save