diff --git a/core/tests/common.py b/core/tests/common.py index 1e699e3820..fce79b33c1 100644 --- a/core/tests/common.py +++ b/core/tests/common.py @@ -1,9 +1,23 @@ import sys -sys.path.append('../src') +sys.path.append("../src") from ubinascii import hexlify, unhexlify # noqa: F401 import unittest # noqa: F401 from trezor import utils # noqa: F401 + + +def await_result(task: Awaitable) -> Any: + value = None + while True: + try: + result = task.send(value) + except StopIteration as e: + return e.value + + if result: + value = await_result(result) + else: + value = None diff --git a/core/tests/test_protobuf.py b/core/tests/test_protobuf.py index d362401e52..c6de123319 100644 --- a/core/tests/test_protobuf.py +++ b/core/tests/test_protobuf.py @@ -44,28 +44,14 @@ class ByteArrayWriter: return len(buf) -def run_until_complete(task: Awaitable) -> Any: - value = None - while True: - try: - result = task.send(value) - except StopIteration as e: - return e.value - - if result: - value = run_until_complete(result) - else: - value = None - - def load_uvarint(data: bytes) -> int: reader = ByteReader(data) - return run_until_complete(protobuf.load_uvarint(reader)) + return await_result(protobuf.load_uvarint(reader)) def dump_uvarint(value: int) -> bytearray: writer = ByteArrayWriter() - run_until_complete(protobuf.dump_uvarint(writer, value)) + await_result(protobuf.dump_uvarint(writer, value)) return writer.buf @@ -106,9 +92,9 @@ class TestProtobuf(unittest.TestCase): # ok message: msg = Message(-42, 5) writer = ByteArrayWriter() - run_until_complete(protobuf.dump_message(writer, msg)) + await_result(protobuf.dump_message(writer, msg)) reader = ByteReader(bytes(writer.buf)) - nmsg = run_until_complete(protobuf.load_message(reader, Message)) + nmsg = await_result(protobuf.load_message(reader, Message)) self.assertEqual(msg.sint_field, nmsg.sint_field) self.assertEqual(msg.enum_field, nmsg.enum_field) @@ -116,10 +102,10 @@ class TestProtobuf(unittest.TestCase): # bad enum value: msg = Message(-42, 42) writer = ByteArrayWriter() - run_until_complete(protobuf.dump_message(writer, msg)) + await_result(protobuf.dump_message(writer, msg)) reader = ByteReader(bytes(writer.buf)) with self.assertRaises(TypeError): - run_until_complete(protobuf.load_message(reader, Message)) + await_result(protobuf.load_message(reader, Message)) if __name__ == "__main__":