mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-02 10:51:06 +00:00
core/tests: extract common await_result() method
This commit is contained in:
parent
ee07b32f52
commit
d4171aaedc
@ -1,9 +1,23 @@
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
sys.path.append('../src')
|
sys.path.append("../src")
|
||||||
|
|
||||||
from ubinascii import hexlify, unhexlify # noqa: F401
|
from ubinascii import hexlify, unhexlify # noqa: F401
|
||||||
|
|
||||||
import unittest # noqa: F401
|
import unittest # noqa: F401
|
||||||
|
|
||||||
from trezor import utils # noqa: F401
|
from trezor import utils # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
def await_result(task: Awaitable) -> Any:
|
||||||
|
value = None
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
result = task.send(value)
|
||||||
|
except StopIteration as e:
|
||||||
|
return e.value
|
||||||
|
|
||||||
|
if result:
|
||||||
|
value = await_result(result)
|
||||||
|
else:
|
||||||
|
value = None
|
||||||
|
@ -44,28 +44,14 @@ class ByteArrayWriter:
|
|||||||
return len(buf)
|
return len(buf)
|
||||||
|
|
||||||
|
|
||||||
def run_until_complete(task: Awaitable) -> Any:
|
|
||||||
value = None
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
result = task.send(value)
|
|
||||||
except StopIteration as e:
|
|
||||||
return e.value
|
|
||||||
|
|
||||||
if result:
|
|
||||||
value = run_until_complete(result)
|
|
||||||
else:
|
|
||||||
value = None
|
|
||||||
|
|
||||||
|
|
||||||
def load_uvarint(data: bytes) -> int:
|
def load_uvarint(data: bytes) -> int:
|
||||||
reader = ByteReader(data)
|
reader = ByteReader(data)
|
||||||
return run_until_complete(protobuf.load_uvarint(reader))
|
return await_result(protobuf.load_uvarint(reader))
|
||||||
|
|
||||||
|
|
||||||
def dump_uvarint(value: int) -> bytearray:
|
def dump_uvarint(value: int) -> bytearray:
|
||||||
writer = ByteArrayWriter()
|
writer = ByteArrayWriter()
|
||||||
run_until_complete(protobuf.dump_uvarint(writer, value))
|
await_result(protobuf.dump_uvarint(writer, value))
|
||||||
return writer.buf
|
return writer.buf
|
||||||
|
|
||||||
|
|
||||||
@ -106,9 +92,9 @@ class TestProtobuf(unittest.TestCase):
|
|||||||
# ok message:
|
# ok message:
|
||||||
msg = Message(-42, 5)
|
msg = Message(-42, 5)
|
||||||
writer = ByteArrayWriter()
|
writer = ByteArrayWriter()
|
||||||
run_until_complete(protobuf.dump_message(writer, msg))
|
await_result(protobuf.dump_message(writer, msg))
|
||||||
reader = ByteReader(bytes(writer.buf))
|
reader = ByteReader(bytes(writer.buf))
|
||||||
nmsg = run_until_complete(protobuf.load_message(reader, Message))
|
nmsg = await_result(protobuf.load_message(reader, Message))
|
||||||
|
|
||||||
self.assertEqual(msg.sint_field, nmsg.sint_field)
|
self.assertEqual(msg.sint_field, nmsg.sint_field)
|
||||||
self.assertEqual(msg.enum_field, nmsg.enum_field)
|
self.assertEqual(msg.enum_field, nmsg.enum_field)
|
||||||
@ -116,10 +102,10 @@ class TestProtobuf(unittest.TestCase):
|
|||||||
# bad enum value:
|
# bad enum value:
|
||||||
msg = Message(-42, 42)
|
msg = Message(-42, 42)
|
||||||
writer = ByteArrayWriter()
|
writer = ByteArrayWriter()
|
||||||
run_until_complete(protobuf.dump_message(writer, msg))
|
await_result(protobuf.dump_message(writer, msg))
|
||||||
reader = ByteReader(bytes(writer.buf))
|
reader = ByteReader(bytes(writer.buf))
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
run_until_complete(protobuf.load_message(reader, Message))
|
await_result(protobuf.load_message(reader, Message))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user