From 20bcc6892679c592821b3535a6defe73b7ab20ff Mon Sep 17 00:00:00 2001 From: Tomas Susanka Date: Mon, 23 Dec 2019 13:35:06 +0000 Subject: [PATCH] core/tests: remove utest.py --- core/tests/README.md | 18 ++ core/tests/test_apps.monero.serializer.py | 1 - core/tests/test_trezor.wire.codec_v1.py | 254 +++++++++++----------- core/tests/unittest.py | 20 +- core/tests/utest.py | 138 ------------ 5 files changed, 158 insertions(+), 273 deletions(-) create mode 100644 core/tests/README.md delete mode 100644 core/tests/utest.py diff --git a/core/tests/README.md b/core/tests/README.md new file mode 100644 index 000000000..3bf324715 --- /dev/null +++ b/core/tests/README.md @@ -0,0 +1,18 @@ +# Unit tests + +Unit tests test some smaller individual parts of code (mainly functions and classes) and are run by micropython directly. + +## Usage + +Please use the unittest.TestCase class: + +```python +from common import * + +class TestSomething(unittest.TestCase): + + test_something(self): + self.assertTrue(True) +``` + +Usage of `assert` is discouraged because it is not evaluated in production code (when `PYOPT=1`). Use `self.assertXY` instead, see `unittest.py`. diff --git a/core/tests/test_apps.monero.serializer.py b/core/tests/test_apps.monero.serializer.py index 18197dc01..95915646d 100644 --- a/core/tests/test_apps.monero.serializer.py +++ b/core/tests/test_apps.monero.serializer.py @@ -1,4 +1,3 @@ -import utest from common import * from trezor import log, loop, utils diff --git a/core/tests/test_trezor.wire.codec_v1.py b/core/tests/test_trezor.wire.codec_v1.py index 11f0e602f..931c0d85e 100644 --- a/core/tests/test_trezor.wire.codec_v1.py +++ b/core/tests/test_trezor.wire.codec_v1.py @@ -1,8 +1,4 @@ -import sys - -sys.path.append('../src') - -from utest import * +from common import * from ubinascii import unhexlify from trezor import io @@ -25,147 +21,149 @@ class MockHID: return len(msg) -def test_reader(): - rep_len = 64 - interface_num = 0xdeadbeef - message_type = 0x4321 - message_len = 250 - interface = MockHID(interface_num) - reader = codec_v1.Reader(interface) +class TestWireCodecV1(unittest.TestCase): - message = bytearray(range(message_len)) - report_header = bytearray(unhexlify('3f23234321000000fa')) + def test_reader(self): + rep_len = 64 + interface_num = 0xdeadbeef + message_type = 0x4321 + message_len = 250 + interface = MockHID(interface_num) + reader = codec_v1.Reader(interface) - # open, expected one read - first_report = report_header + message[:rep_len - len(report_header)] - assert_async(reader.aopen(), [(None, wait(io.POLL_READ | interface_num)), (first_report, StopIteration()), ]) - assert_eq(reader.type, message_type) - assert_eq(reader.size, message_len) + message = bytearray(range(message_len)) + report_header = bytearray(unhexlify('3f23234321000000fa')) - # empty read - empty_buffer = bytearray() - assert_async(reader.areadinto(empty_buffer), [(None, StopIteration()), ]) - assert_eq(len(empty_buffer), 0) - assert_eq(reader.size, message_len) + # open, expected one read + first_report = report_header + message[:rep_len - len(report_header)] + self.assertAsync(reader.aopen(), [(None, wait(io.POLL_READ | interface_num)), (first_report, StopIteration()), ]) + self.assertEqual(reader.type, message_type) + self.assertEqual(reader.size, message_len) - # short read, expected no read - short_buffer = bytearray(32) - assert_async(reader.areadinto(short_buffer), [(None, StopIteration()), ]) - assert_eq(len(short_buffer), 32) - assert_eq(short_buffer, message[:len(short_buffer)]) - assert_eq(reader.size, message_len - len(short_buffer)) + # empty read + empty_buffer = bytearray() + self.assertAsync(reader.areadinto(empty_buffer), [(None, StopIteration()), ]) + self.assertEqual(len(empty_buffer), 0) + self.assertEqual(reader.size, message_len) - # aligned read, expected no read - aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer)) - assert_async(reader.areadinto(aligned_buffer), [(None, StopIteration()), ]) - assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)]) - assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer)) + # short read, expected no read + short_buffer = bytearray(32) + self.assertAsync(reader.areadinto(short_buffer), [(None, StopIteration()), ]) + self.assertEqual(len(short_buffer), 32) + self.assertEqual(short_buffer, message[:len(short_buffer)]) + self.assertEqual(reader.size, message_len - len(short_buffer)) - # one byte read, expected one read - next_report_header = bytearray(unhexlify('3f')) - next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)] - onebyte_buffer = bytearray(1) - assert_async(reader.areadinto(onebyte_buffer), [(None, wait(io.POLL_READ | interface_num)), (next_report, StopIteration()), ]) - assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)]) - assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer)) + # aligned read, expected no read + aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer)) + self.assertAsync(reader.areadinto(aligned_buffer), [(None, StopIteration()), ]) + self.assertEqual(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)]) + self.assertEqual(reader.size, message_len - len(short_buffer) - len(aligned_buffer)) - # too long read, raises eof - assert_async(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()), ]) + # one byte read, expected one read + next_report_header = bytearray(unhexlify('3f')) + next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)] + onebyte_buffer = bytearray(1) + self.assertAsync(reader.areadinto(onebyte_buffer), [(None, wait(io.POLL_READ | interface_num)), (next_report, StopIteration()), ]) + self.assertEqual(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)]) + self.assertEqual(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer)) - # long read, expect multiple reads - start_size = reader.size - long_buffer = bytearray(start_size) - report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):] - report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)] - report_payload_rest = report_payload[len(report_payload_head):] - report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header))) - report_payloads = [report_payload_head] + report_payload_rest - next_reports = [next_report_header + r for r in report_payloads] - expected_syscalls = [] - for i, _ in enumerate(next_reports): - prev_report = next_reports[i - 1] if i > 0 else None - expected_syscalls.append((prev_report, wait(io.POLL_READ | interface_num))) - expected_syscalls.append((next_reports[-1], StopIteration())) - assert_async(reader.areadinto(long_buffer), expected_syscalls) - assert_eq(long_buffer, message[-start_size:]) - assert_eq(reader.size, 0) + # too long read, raises eof + self.assertAsync(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()), ]) - # one byte read, raises eof - assert_async(reader.areadinto(onebyte_buffer), [(None, EOFError()), ]) + # long read, expect multiple reads + start_size = reader.size + long_buffer = bytearray(start_size) + report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):] + report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)] + report_payload_rest = report_payload[len(report_payload_head):] + report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header))) + report_payloads = [report_payload_head] + report_payload_rest + next_reports = [next_report_header + r for r in report_payloads] + expected_syscalls = [] + for i, _ in enumerate(next_reports): + prev_report = next_reports[i - 1] if i > 0 else None + expected_syscalls.append((prev_report, wait(io.POLL_READ | interface_num))) + expected_syscalls.append((next_reports[-1], StopIteration())) + self.assertAsync(reader.areadinto(long_buffer), expected_syscalls) + self.assertEqual(long_buffer, message[-start_size:]) + self.assertEqual(reader.size, 0) + + # one byte read, raises eof + self.assertAsync(reader.areadinto(onebyte_buffer), [(None, EOFError()), ]) -def test_writer(): - rep_len = 64 - interface_num = 0xdeadbeef - message_type = 0x87654321 - message_len = 1024 - interface = MockHID(interface_num) - writer = codec_v1.Writer(interface) - writer.setheader(message_type, message_len) + def test_writer(self): + rep_len = 64 + interface_num = 0xdeadbeef + message_type = 0x87654321 + message_len = 1024 + interface = MockHID(interface_num) + writer = codec_v1.Writer(interface) + writer.setheader(message_type, message_len) - # init header corresponding to the data above - report_header = bytearray(unhexlify('3f2323432100000400')) + # init header corresponding to the data above + report_header = bytearray(unhexlify('3f2323432100000400')) - assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header))) + self.assertEqual(writer.data, report_header + bytearray(rep_len - len(report_header))) - # empty write - start_size = writer.size - assert_async(writer.awrite(bytearray()), [(None, StopIteration()), ]) - assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header))) - assert_eq(writer.size, start_size) + # empty write + start_size = writer.size + self.assertAsync(writer.awrite(bytearray()), [(None, StopIteration()), ]) + self.assertEqual(writer.data, report_header + bytearray(rep_len - len(report_header))) + self.assertEqual(writer.size, start_size) - # short write, expected no report - start_size = writer.size - short_payload = bytearray(range(4)) - assert_async(writer.awrite(short_payload), [(None, StopIteration()), ]) - assert_eq(writer.size, start_size - len(short_payload)) - assert_eq(writer.data, - report_header + - short_payload + - bytearray(rep_len - len(report_header) - len(short_payload))) + # short write, expected no report + start_size = writer.size + short_payload = bytearray(range(4)) + self.assertAsync(writer.awrite(short_payload), [(None, StopIteration()), ]) + self.assertEqual(writer.size, start_size - len(short_payload)) + self.assertEqual(writer.data, + report_header + + short_payload + + bytearray(rep_len - len(report_header) - len(short_payload))) - # aligned write, expected one report - start_size = writer.size - aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload))) - assert_async(writer.awrite(aligned_payload), [(None, wait(io.POLL_WRITE | interface_num)), (None, StopIteration()), ]) - assert_eq(interface.data, [report_header + - short_payload + - aligned_payload + - bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload)), ]) - assert_eq(writer.size, start_size - len(aligned_payload)) - interface.data.clear() + # aligned write, expected one report + start_size = writer.size + aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload))) + self.assertAsync(writer.awrite(aligned_payload), [(None, wait(io.POLL_WRITE | interface_num)), (None, StopIteration()), ]) + self.assertEqual(interface.data, [report_header + + short_payload + + aligned_payload + + bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload)), ]) + self.assertEqual(writer.size, start_size - len(aligned_payload)) + interface.data.clear() - # short write, expected no report, but data starts with correct seq and cont marker - report_header = bytearray(unhexlify('3f')) - start_size = writer.size - assert_async(writer.awrite(short_payload), [(None, StopIteration()), ]) - assert_eq(writer.size, start_size - len(short_payload)) - assert_eq(writer.data[:len(report_header) + len(short_payload)], - report_header + short_payload) + # short write, expected no report, but data starts with correct seq and cont marker + report_header = bytearray(unhexlify('3f')) + start_size = writer.size + self.assertAsync(writer.awrite(short_payload), [(None, StopIteration()), ]) + self.assertEqual(writer.size, start_size - len(short_payload)) + self.assertEqual(writer.data[:len(report_header) + len(short_payload)], + report_header + short_payload) - # long write, expected multiple reports - start_size = writer.size - long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload))) - long_payload_rest = bytearray(range(start_size - len(long_payload_head))) - long_payload = long_payload_head + long_payload_rest - expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header))) - expected_reports = [report_header + r for r in expected_payloads] - expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1]))) - # test write - expected_write_reports = expected_reports[:-1] - assert_async(writer.awrite(long_payload), len(expected_write_reports) * [(None, wait(io.POLL_WRITE | interface_num))] + [(None, StopIteration())]) - assert_eq(interface.data, expected_write_reports) - assert_eq(writer.size, start_size - len(long_payload)) - interface.data.clear() - # test write raises eof - assert_async(writer.awrite(bytearray(1)), [(None, EOFError())]) - assert_eq(interface.data, []) - # test close - expected_close_reports = expected_reports[-1:] - assert_async(writer.aclose(), len(expected_close_reports) * [(None, wait(io.POLL_WRITE | interface_num))] + [(None, StopIteration())]) - assert_eq(interface.data, expected_close_reports) - assert_eq(writer.size, 0) + # long write, expected multiple reports + start_size = writer.size + long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload))) + long_payload_rest = bytearray(range(start_size - len(long_payload_head))) + long_payload = long_payload_head + long_payload_rest + expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header))) + expected_reports = [report_header + r for r in expected_payloads] + expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1]))) + # test write + expected_write_reports = expected_reports[:-1] + self.assertAsync(writer.awrite(long_payload), len(expected_write_reports) * [(None, wait(io.POLL_WRITE | interface_num))] + [(None, StopIteration())]) + self.assertEqual(interface.data, expected_write_reports) + self.assertEqual(writer.size, start_size - len(long_payload)) + interface.data.clear() + # test write raises eof + self.assertAsync(writer.awrite(bytearray(1)), [(None, EOFError())]) + self.assertEqual(interface.data, []) + # test close + expected_close_reports = expected_reports[-1:] + self.assertAsync(writer.aclose(), len(expected_close_reports) * [(None, wait(io.POLL_WRITE | interface_num))] + [(None, StopIteration())]) + self.assertEqual(interface.data, expected_close_reports) + self.assertEqual(writer.size, 0) if __name__ == '__main__': - run_tests() + unittest.main() diff --git a/core/tests/unittest.py b/core/tests/unittest.py index 0c52a08da..4cb339fe3 100644 --- a/core/tests/unittest.py +++ b/core/tests/unittest.py @@ -1,5 +1,4 @@ from trezor.utils import ensure -from utest import assert_async class SkipTest(Exception): @@ -135,6 +134,19 @@ class TestCase: for i in range(len(x)): self.assertEqual(x[i], y[i], msg) + def assertAsync(self, task, syscalls): + for prev_result, expected in syscalls: + if isinstance(expected, Exception): + with self.assertRaises(expected.__class__): + task.send(prev_result) + else: + syscall = task.send(prev_result) + self.assertObjectEqual(syscall, expected) + + def assertObjectEqual(self, a, b, msg=''): + self.assertIsInstance(a, b.__class__, msg) + self.assertEqual(a.__dict__, b.__dict__, msg) + def skip(msg): def _decor(fun): @@ -188,16 +200,12 @@ def run_class(c, test_result): print('class', c.__qualname__) for name in dir(o): if name.startswith("test"): - is_async = name.startswith("test_async") print(' ', name, end=' ...') m = getattr(o, name) try: set_up() test_result.testsRun += 1 - if is_async: - assert_async(m(), [(None, StopIteration()), ]) - else: - m() + m() tear_down() print(" ok") except SkipTest as e: diff --git a/core/tests/utest.py b/core/tests/utest.py deleted file mode 100644 index 17d2fde1c..000000000 --- a/core/tests/utest.py +++ /dev/null @@ -1,138 +0,0 @@ -import sys - -__all__ = [ - 'run_tests', - 'run_test', - 'assert_eq', - 'assert_not_eq', - 'assert_is_instance', - 'mock_call', -] - - -# Running - - -def run_tests(mod_name='__main__'): - ntotal = 0 - nok = 0 - nfailed = 0 - - for name, test in get_tests(mod_name): - result = run_test(test) - report_test(name, test, result) - ntotal += 1 - if result: - nok += 1 - else: - nfailed += 1 - break - report_total(ntotal, nok, nfailed) - - if nfailed > 0: - sys.exit(1) - - -def get_tests(mod_name): - module = __import__(mod_name) - for name in dir(module): - if name.startswith('test_'): - yield name, getattr(module, name) - - -def run_test(test): - try: - test() - except Exception as e: - report_exception(e) - return False - else: - return True - - -# Reporting - - -def report_test(name, test, result): - if result: - print('OK', name) - else: - print('ERR', name) - - -def report_exception(exc): - sys.print_exception(exc) - - -def report_total(total, ok, failed): - print('Total:', total, 'OK:', ok, 'Failed:', failed) - - -# Assertions - - -def assert_eq(a, b, msg=None): - assert a == b, msg or format_eq(a, b) - - -def assert_not_eq(a, b, msg=None): - assert a != b, msg or format_not_eq(a, b) - - -def assert_is_instance(obj, cls, msg=None): - assert isinstance(obj, cls), msg or format_is_instance(obj, cls) - - -def assert_eq_obj(a, b, msg=None): - assert_is_instance(a, b.__class__, msg) - assert_eq(a.__dict__, b.__dict__, msg) - - -def format_eq(a, b): - return '\n%r\nvs (expected)\n%r' % (a, b) - - -def format_not_eq(a, b): - return '%r not expected to be equal %r' % (a, b) - - -def format_is_instance(obj, cls): - return '%r expected to be instance of %r' % (obj, cls) - - -def assert_async(task, syscalls): - for prev_result, expected in syscalls: - if isinstance(expected, Exception): - with assert_raises(expected.__class__): - task.send(prev_result) - else: - syscall = task.send(prev_result) - assert_eq_obj(syscall, expected) - - -class assert_raises: - - def __init__(self, exc_type): - self.exc_type = exc_type - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - assert exc_type is not None, '%r not raised' % self.exc_type - return issubclass(exc_type, self.exc_type) - - -class mock_call: - - def __init__(self, original, expected): - self.original = original - self.expected = expected - self.record = [] - - def __call__(self, *args): - self.record.append(args) - assert_eq(args, self.expected.pop(0)) - - def assert_called_n_times(self, n, msg=None): - assert_eq(len(self.record), n, msg)