mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-22 21:30:56 +00:00
core/tests: extract common await_result() method
This commit is contained in:
parent
ee07b32f52
commit
d4171aaedc
@ -1,9 +1,23 @@
|
||||
import sys
|
||||
|
||||
sys.path.append('../src')
|
||||
sys.path.append("../src")
|
||||
|
||||
from ubinascii import hexlify, unhexlify # noqa: F401
|
||||
|
||||
import unittest # noqa: F401
|
||||
|
||||
from trezor import utils # noqa: F401
|
||||
|
||||
|
||||
def await_result(task: Awaitable) -> Any:
|
||||
value = None
|
||||
while True:
|
||||
try:
|
||||
result = task.send(value)
|
||||
except StopIteration as e:
|
||||
return e.value
|
||||
|
||||
if result:
|
||||
value = await_result(result)
|
||||
else:
|
||||
value = None
|
||||
|
@ -44,28 +44,14 @@ class ByteArrayWriter:
|
||||
return len(buf)
|
||||
|
||||
|
||||
def run_until_complete(task: Awaitable) -> Any:
|
||||
value = None
|
||||
while True:
|
||||
try:
|
||||
result = task.send(value)
|
||||
except StopIteration as e:
|
||||
return e.value
|
||||
|
||||
if result:
|
||||
value = run_until_complete(result)
|
||||
else:
|
||||
value = None
|
||||
|
||||
|
||||
def load_uvarint(data: bytes) -> int:
|
||||
reader = ByteReader(data)
|
||||
return run_until_complete(protobuf.load_uvarint(reader))
|
||||
return await_result(protobuf.load_uvarint(reader))
|
||||
|
||||
|
||||
def dump_uvarint(value: int) -> bytearray:
|
||||
writer = ByteArrayWriter()
|
||||
run_until_complete(protobuf.dump_uvarint(writer, value))
|
||||
await_result(protobuf.dump_uvarint(writer, value))
|
||||
return writer.buf
|
||||
|
||||
|
||||
@ -106,9 +92,9 @@ class TestProtobuf(unittest.TestCase):
|
||||
# ok message:
|
||||
msg = Message(-42, 5)
|
||||
writer = ByteArrayWriter()
|
||||
run_until_complete(protobuf.dump_message(writer, msg))
|
||||
await_result(protobuf.dump_message(writer, msg))
|
||||
reader = ByteReader(bytes(writer.buf))
|
||||
nmsg = run_until_complete(protobuf.load_message(reader, Message))
|
||||
nmsg = await_result(protobuf.load_message(reader, Message))
|
||||
|
||||
self.assertEqual(msg.sint_field, nmsg.sint_field)
|
||||
self.assertEqual(msg.enum_field, nmsg.enum_field)
|
||||
@ -116,10 +102,10 @@ class TestProtobuf(unittest.TestCase):
|
||||
# bad enum value:
|
||||
msg = Message(-42, 42)
|
||||
writer = ByteArrayWriter()
|
||||
run_until_complete(protobuf.dump_message(writer, msg))
|
||||
await_result(protobuf.dump_message(writer, msg))
|
||||
reader = ByteReader(bytes(writer.buf))
|
||||
with self.assertRaises(TypeError):
|
||||
run_until_complete(protobuf.load_message(reader, Message))
|
||||
await_result(protobuf.load_message(reader, Message))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user