mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-06-25 01:18:54 +00:00
core/tests: remove utest.py
This commit is contained in:
parent
cf80f9cc43
commit
20bcc68926
18
core/tests/README.md
Normal file
18
core/tests/README.md
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# Unit tests
|
||||||
|
|
||||||
|
Unit tests test some smaller individual parts of code (mainly functions and classes) and are run by micropython directly.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
Please use the unittest.TestCase class:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from common import *
|
||||||
|
|
||||||
|
class TestSomething(unittest.TestCase):
|
||||||
|
|
||||||
|
test_something(self):
|
||||||
|
self.assertTrue(True)
|
||||||
|
```
|
||||||
|
|
||||||
|
Usage of `assert` is discouraged because it is not evaluated in production code (when `PYOPT=1`). Use `self.assertXY` instead, see `unittest.py`.
|
@ -1,4 +1,3 @@
|
|||||||
import utest
|
|
||||||
from common import *
|
from common import *
|
||||||
from trezor import log, loop, utils
|
from trezor import log, loop, utils
|
||||||
|
|
||||||
|
@ -1,8 +1,4 @@
|
|||||||
import sys
|
from common import *
|
||||||
|
|
||||||
sys.path.append('../src')
|
|
||||||
|
|
||||||
from utest import *
|
|
||||||
from ubinascii import unhexlify
|
from ubinascii import unhexlify
|
||||||
|
|
||||||
from trezor import io
|
from trezor import io
|
||||||
@ -25,147 +21,149 @@ class MockHID:
|
|||||||
return len(msg)
|
return len(msg)
|
||||||
|
|
||||||
|
|
||||||
def test_reader():
|
class TestWireCodecV1(unittest.TestCase):
|
||||||
rep_len = 64
|
|
||||||
interface_num = 0xdeadbeef
|
|
||||||
message_type = 0x4321
|
|
||||||
message_len = 250
|
|
||||||
interface = MockHID(interface_num)
|
|
||||||
reader = codec_v1.Reader(interface)
|
|
||||||
|
|
||||||
message = bytearray(range(message_len))
|
def test_reader(self):
|
||||||
report_header = bytearray(unhexlify('3f23234321000000fa'))
|
rep_len = 64
|
||||||
|
interface_num = 0xdeadbeef
|
||||||
|
message_type = 0x4321
|
||||||
|
message_len = 250
|
||||||
|
interface = MockHID(interface_num)
|
||||||
|
reader = codec_v1.Reader(interface)
|
||||||
|
|
||||||
# open, expected one read
|
message = bytearray(range(message_len))
|
||||||
first_report = report_header + message[:rep_len - len(report_header)]
|
report_header = bytearray(unhexlify('3f23234321000000fa'))
|
||||||
assert_async(reader.aopen(), [(None, wait(io.POLL_READ | interface_num)), (first_report, StopIteration()), ])
|
|
||||||
assert_eq(reader.type, message_type)
|
|
||||||
assert_eq(reader.size, message_len)
|
|
||||||
|
|
||||||
# empty read
|
# open, expected one read
|
||||||
empty_buffer = bytearray()
|
first_report = report_header + message[:rep_len - len(report_header)]
|
||||||
assert_async(reader.areadinto(empty_buffer), [(None, StopIteration()), ])
|
self.assertAsync(reader.aopen(), [(None, wait(io.POLL_READ | interface_num)), (first_report, StopIteration()), ])
|
||||||
assert_eq(len(empty_buffer), 0)
|
self.assertEqual(reader.type, message_type)
|
||||||
assert_eq(reader.size, message_len)
|
self.assertEqual(reader.size, message_len)
|
||||||
|
|
||||||
# short read, expected no read
|
# empty read
|
||||||
short_buffer = bytearray(32)
|
empty_buffer = bytearray()
|
||||||
assert_async(reader.areadinto(short_buffer), [(None, StopIteration()), ])
|
self.assertAsync(reader.areadinto(empty_buffer), [(None, StopIteration()), ])
|
||||||
assert_eq(len(short_buffer), 32)
|
self.assertEqual(len(empty_buffer), 0)
|
||||||
assert_eq(short_buffer, message[:len(short_buffer)])
|
self.assertEqual(reader.size, message_len)
|
||||||
assert_eq(reader.size, message_len - len(short_buffer))
|
|
||||||
|
|
||||||
# aligned read, expected no read
|
# short read, expected no read
|
||||||
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
|
short_buffer = bytearray(32)
|
||||||
assert_async(reader.areadinto(aligned_buffer), [(None, StopIteration()), ])
|
self.assertAsync(reader.areadinto(short_buffer), [(None, StopIteration()), ])
|
||||||
assert_eq(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
|
self.assertEqual(len(short_buffer), 32)
|
||||||
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
|
self.assertEqual(short_buffer, message[:len(short_buffer)])
|
||||||
|
self.assertEqual(reader.size, message_len - len(short_buffer))
|
||||||
|
|
||||||
# one byte read, expected one read
|
# aligned read, expected no read
|
||||||
next_report_header = bytearray(unhexlify('3f'))
|
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
|
||||||
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
|
self.assertAsync(reader.areadinto(aligned_buffer), [(None, StopIteration()), ])
|
||||||
onebyte_buffer = bytearray(1)
|
self.assertEqual(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
|
||||||
assert_async(reader.areadinto(onebyte_buffer), [(None, wait(io.POLL_READ | interface_num)), (next_report, StopIteration()), ])
|
self.assertEqual(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
|
||||||
assert_eq(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
|
|
||||||
assert_eq(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
|
|
||||||
|
|
||||||
# too long read, raises eof
|
# one byte read, expected one read
|
||||||
assert_async(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()), ])
|
next_report_header = bytearray(unhexlify('3f'))
|
||||||
|
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
|
||||||
|
onebyte_buffer = bytearray(1)
|
||||||
|
self.assertAsync(reader.areadinto(onebyte_buffer), [(None, wait(io.POLL_READ | interface_num)), (next_report, StopIteration()), ])
|
||||||
|
self.assertEqual(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
|
||||||
|
self.assertEqual(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
|
||||||
|
|
||||||
# long read, expect multiple reads
|
# too long read, raises eof
|
||||||
start_size = reader.size
|
self.assertAsync(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()), ])
|
||||||
long_buffer = bytearray(start_size)
|
|
||||||
report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):]
|
|
||||||
report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)]
|
|
||||||
report_payload_rest = report_payload[len(report_payload_head):]
|
|
||||||
report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header)))
|
|
||||||
report_payloads = [report_payload_head] + report_payload_rest
|
|
||||||
next_reports = [next_report_header + r for r in report_payloads]
|
|
||||||
expected_syscalls = []
|
|
||||||
for i, _ in enumerate(next_reports):
|
|
||||||
prev_report = next_reports[i - 1] if i > 0 else None
|
|
||||||
expected_syscalls.append((prev_report, wait(io.POLL_READ | interface_num)))
|
|
||||||
expected_syscalls.append((next_reports[-1], StopIteration()))
|
|
||||||
assert_async(reader.areadinto(long_buffer), expected_syscalls)
|
|
||||||
assert_eq(long_buffer, message[-start_size:])
|
|
||||||
assert_eq(reader.size, 0)
|
|
||||||
|
|
||||||
# one byte read, raises eof
|
# long read, expect multiple reads
|
||||||
assert_async(reader.areadinto(onebyte_buffer), [(None, EOFError()), ])
|
start_size = reader.size
|
||||||
|
long_buffer = bytearray(start_size)
|
||||||
|
report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):]
|
||||||
|
report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)]
|
||||||
|
report_payload_rest = report_payload[len(report_payload_head):]
|
||||||
|
report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header)))
|
||||||
|
report_payloads = [report_payload_head] + report_payload_rest
|
||||||
|
next_reports = [next_report_header + r for r in report_payloads]
|
||||||
|
expected_syscalls = []
|
||||||
|
for i, _ in enumerate(next_reports):
|
||||||
|
prev_report = next_reports[i - 1] if i > 0 else None
|
||||||
|
expected_syscalls.append((prev_report, wait(io.POLL_READ | interface_num)))
|
||||||
|
expected_syscalls.append((next_reports[-1], StopIteration()))
|
||||||
|
self.assertAsync(reader.areadinto(long_buffer), expected_syscalls)
|
||||||
|
self.assertEqual(long_buffer, message[-start_size:])
|
||||||
|
self.assertEqual(reader.size, 0)
|
||||||
|
|
||||||
|
# one byte read, raises eof
|
||||||
|
self.assertAsync(reader.areadinto(onebyte_buffer), [(None, EOFError()), ])
|
||||||
|
|
||||||
|
|
||||||
def test_writer():
|
def test_writer(self):
|
||||||
rep_len = 64
|
rep_len = 64
|
||||||
interface_num = 0xdeadbeef
|
interface_num = 0xdeadbeef
|
||||||
message_type = 0x87654321
|
message_type = 0x87654321
|
||||||
message_len = 1024
|
message_len = 1024
|
||||||
interface = MockHID(interface_num)
|
interface = MockHID(interface_num)
|
||||||
writer = codec_v1.Writer(interface)
|
writer = codec_v1.Writer(interface)
|
||||||
writer.setheader(message_type, message_len)
|
writer.setheader(message_type, message_len)
|
||||||
|
|
||||||
# init header corresponding to the data above
|
# init header corresponding to the data above
|
||||||
report_header = bytearray(unhexlify('3f2323432100000400'))
|
report_header = bytearray(unhexlify('3f2323432100000400'))
|
||||||
|
|
||||||
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
self.assertEqual(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
||||||
|
|
||||||
# empty write
|
# empty write
|
||||||
start_size = writer.size
|
start_size = writer.size
|
||||||
assert_async(writer.awrite(bytearray()), [(None, StopIteration()), ])
|
self.assertAsync(writer.awrite(bytearray()), [(None, StopIteration()), ])
|
||||||
assert_eq(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
self.assertEqual(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
||||||
assert_eq(writer.size, start_size)
|
self.assertEqual(writer.size, start_size)
|
||||||
|
|
||||||
# short write, expected no report
|
# short write, expected no report
|
||||||
start_size = writer.size
|
start_size = writer.size
|
||||||
short_payload = bytearray(range(4))
|
short_payload = bytearray(range(4))
|
||||||
assert_async(writer.awrite(short_payload), [(None, StopIteration()), ])
|
self.assertAsync(writer.awrite(short_payload), [(None, StopIteration()), ])
|
||||||
assert_eq(writer.size, start_size - len(short_payload))
|
self.assertEqual(writer.size, start_size - len(short_payload))
|
||||||
assert_eq(writer.data,
|
self.assertEqual(writer.data,
|
||||||
report_header +
|
report_header +
|
||||||
short_payload +
|
short_payload +
|
||||||
bytearray(rep_len - len(report_header) - len(short_payload)))
|
bytearray(rep_len - len(report_header) - len(short_payload)))
|
||||||
|
|
||||||
# aligned write, expected one report
|
# aligned write, expected one report
|
||||||
start_size = writer.size
|
start_size = writer.size
|
||||||
aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload)))
|
aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload)))
|
||||||
assert_async(writer.awrite(aligned_payload), [(None, wait(io.POLL_WRITE | interface_num)), (None, StopIteration()), ])
|
self.assertAsync(writer.awrite(aligned_payload), [(None, wait(io.POLL_WRITE | interface_num)), (None, StopIteration()), ])
|
||||||
assert_eq(interface.data, [report_header +
|
self.assertEqual(interface.data, [report_header +
|
||||||
short_payload +
|
short_payload +
|
||||||
aligned_payload +
|
aligned_payload +
|
||||||
bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload)), ])
|
bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload)), ])
|
||||||
assert_eq(writer.size, start_size - len(aligned_payload))
|
self.assertEqual(writer.size, start_size - len(aligned_payload))
|
||||||
interface.data.clear()
|
interface.data.clear()
|
||||||
|
|
||||||
# short write, expected no report, but data starts with correct seq and cont marker
|
# short write, expected no report, but data starts with correct seq and cont marker
|
||||||
report_header = bytearray(unhexlify('3f'))
|
report_header = bytearray(unhexlify('3f'))
|
||||||
start_size = writer.size
|
start_size = writer.size
|
||||||
assert_async(writer.awrite(short_payload), [(None, StopIteration()), ])
|
self.assertAsync(writer.awrite(short_payload), [(None, StopIteration()), ])
|
||||||
assert_eq(writer.size, start_size - len(short_payload))
|
self.assertEqual(writer.size, start_size - len(short_payload))
|
||||||
assert_eq(writer.data[:len(report_header) + len(short_payload)],
|
self.assertEqual(writer.data[:len(report_header) + len(short_payload)],
|
||||||
report_header + short_payload)
|
report_header + short_payload)
|
||||||
|
|
||||||
# long write, expected multiple reports
|
# long write, expected multiple reports
|
||||||
start_size = writer.size
|
start_size = writer.size
|
||||||
long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload)))
|
long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload)))
|
||||||
long_payload_rest = bytearray(range(start_size - len(long_payload_head)))
|
long_payload_rest = bytearray(range(start_size - len(long_payload_head)))
|
||||||
long_payload = long_payload_head + long_payload_rest
|
long_payload = long_payload_head + long_payload_rest
|
||||||
expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header)))
|
expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header)))
|
||||||
expected_reports = [report_header + r for r in expected_payloads]
|
expected_reports = [report_header + r for r in expected_payloads]
|
||||||
expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1])))
|
expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1])))
|
||||||
# test write
|
# test write
|
||||||
expected_write_reports = expected_reports[:-1]
|
expected_write_reports = expected_reports[:-1]
|
||||||
assert_async(writer.awrite(long_payload), len(expected_write_reports) * [(None, wait(io.POLL_WRITE | interface_num))] + [(None, StopIteration())])
|
self.assertAsync(writer.awrite(long_payload), len(expected_write_reports) * [(None, wait(io.POLL_WRITE | interface_num))] + [(None, StopIteration())])
|
||||||
assert_eq(interface.data, expected_write_reports)
|
self.assertEqual(interface.data, expected_write_reports)
|
||||||
assert_eq(writer.size, start_size - len(long_payload))
|
self.assertEqual(writer.size, start_size - len(long_payload))
|
||||||
interface.data.clear()
|
interface.data.clear()
|
||||||
# test write raises eof
|
# test write raises eof
|
||||||
assert_async(writer.awrite(bytearray(1)), [(None, EOFError())])
|
self.assertAsync(writer.awrite(bytearray(1)), [(None, EOFError())])
|
||||||
assert_eq(interface.data, [])
|
self.assertEqual(interface.data, [])
|
||||||
# test close
|
# test close
|
||||||
expected_close_reports = expected_reports[-1:]
|
expected_close_reports = expected_reports[-1:]
|
||||||
assert_async(writer.aclose(), len(expected_close_reports) * [(None, wait(io.POLL_WRITE | interface_num))] + [(None, StopIteration())])
|
self.assertAsync(writer.aclose(), len(expected_close_reports) * [(None, wait(io.POLL_WRITE | interface_num))] + [(None, StopIteration())])
|
||||||
assert_eq(interface.data, expected_close_reports)
|
self.assertEqual(interface.data, expected_close_reports)
|
||||||
assert_eq(writer.size, 0)
|
self.assertEqual(writer.size, 0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_tests()
|
unittest.main()
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
from trezor.utils import ensure
|
from trezor.utils import ensure
|
||||||
from utest import assert_async
|
|
||||||
|
|
||||||
|
|
||||||
class SkipTest(Exception):
|
class SkipTest(Exception):
|
||||||
@ -135,6 +134,19 @@ class TestCase:
|
|||||||
for i in range(len(x)):
|
for i in range(len(x)):
|
||||||
self.assertEqual(x[i], y[i], msg)
|
self.assertEqual(x[i], y[i], msg)
|
||||||
|
|
||||||
|
def assertAsync(self, task, syscalls):
|
||||||
|
for prev_result, expected in syscalls:
|
||||||
|
if isinstance(expected, Exception):
|
||||||
|
with self.assertRaises(expected.__class__):
|
||||||
|
task.send(prev_result)
|
||||||
|
else:
|
||||||
|
syscall = task.send(prev_result)
|
||||||
|
self.assertObjectEqual(syscall, expected)
|
||||||
|
|
||||||
|
def assertObjectEqual(self, a, b, msg=''):
|
||||||
|
self.assertIsInstance(a, b.__class__, msg)
|
||||||
|
self.assertEqual(a.__dict__, b.__dict__, msg)
|
||||||
|
|
||||||
|
|
||||||
def skip(msg):
|
def skip(msg):
|
||||||
def _decor(fun):
|
def _decor(fun):
|
||||||
@ -188,16 +200,12 @@ def run_class(c, test_result):
|
|||||||
print('class', c.__qualname__)
|
print('class', c.__qualname__)
|
||||||
for name in dir(o):
|
for name in dir(o):
|
||||||
if name.startswith("test"):
|
if name.startswith("test"):
|
||||||
is_async = name.startswith("test_async")
|
|
||||||
print(' ', name, end=' ...')
|
print(' ', name, end=' ...')
|
||||||
m = getattr(o, name)
|
m = getattr(o, name)
|
||||||
try:
|
try:
|
||||||
set_up()
|
set_up()
|
||||||
test_result.testsRun += 1
|
test_result.testsRun += 1
|
||||||
if is_async:
|
m()
|
||||||
assert_async(m(), [(None, StopIteration()), ])
|
|
||||||
else:
|
|
||||||
m()
|
|
||||||
tear_down()
|
tear_down()
|
||||||
print(" ok")
|
print(" ok")
|
||||||
except SkipTest as e:
|
except SkipTest as e:
|
||||||
|
@ -1,138 +0,0 @@
|
|||||||
import sys
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'run_tests',
|
|
||||||
'run_test',
|
|
||||||
'assert_eq',
|
|
||||||
'assert_not_eq',
|
|
||||||
'assert_is_instance',
|
|
||||||
'mock_call',
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# Running
|
|
||||||
|
|
||||||
|
|
||||||
def run_tests(mod_name='__main__'):
|
|
||||||
ntotal = 0
|
|
||||||
nok = 0
|
|
||||||
nfailed = 0
|
|
||||||
|
|
||||||
for name, test in get_tests(mod_name):
|
|
||||||
result = run_test(test)
|
|
||||||
report_test(name, test, result)
|
|
||||||
ntotal += 1
|
|
||||||
if result:
|
|
||||||
nok += 1
|
|
||||||
else:
|
|
||||||
nfailed += 1
|
|
||||||
break
|
|
||||||
report_total(ntotal, nok, nfailed)
|
|
||||||
|
|
||||||
if nfailed > 0:
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def get_tests(mod_name):
|
|
||||||
module = __import__(mod_name)
|
|
||||||
for name in dir(module):
|
|
||||||
if name.startswith('test_'):
|
|
||||||
yield name, getattr(module, name)
|
|
||||||
|
|
||||||
|
|
||||||
def run_test(test):
|
|
||||||
try:
|
|
||||||
test()
|
|
||||||
except Exception as e:
|
|
||||||
report_exception(e)
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
# Reporting
|
|
||||||
|
|
||||||
|
|
||||||
def report_test(name, test, result):
|
|
||||||
if result:
|
|
||||||
print('OK', name)
|
|
||||||
else:
|
|
||||||
print('ERR', name)
|
|
||||||
|
|
||||||
|
|
||||||
def report_exception(exc):
|
|
||||||
sys.print_exception(exc)
|
|
||||||
|
|
||||||
|
|
||||||
def report_total(total, ok, failed):
|
|
||||||
print('Total:', total, 'OK:', ok, 'Failed:', failed)
|
|
||||||
|
|
||||||
|
|
||||||
# Assertions
|
|
||||||
|
|
||||||
|
|
||||||
def assert_eq(a, b, msg=None):
|
|
||||||
assert a == b, msg or format_eq(a, b)
|
|
||||||
|
|
||||||
|
|
||||||
def assert_not_eq(a, b, msg=None):
|
|
||||||
assert a != b, msg or format_not_eq(a, b)
|
|
||||||
|
|
||||||
|
|
||||||
def assert_is_instance(obj, cls, msg=None):
|
|
||||||
assert isinstance(obj, cls), msg or format_is_instance(obj, cls)
|
|
||||||
|
|
||||||
|
|
||||||
def assert_eq_obj(a, b, msg=None):
|
|
||||||
assert_is_instance(a, b.__class__, msg)
|
|
||||||
assert_eq(a.__dict__, b.__dict__, msg)
|
|
||||||
|
|
||||||
|
|
||||||
def format_eq(a, b):
|
|
||||||
return '\n%r\nvs (expected)\n%r' % (a, b)
|
|
||||||
|
|
||||||
|
|
||||||
def format_not_eq(a, b):
|
|
||||||
return '%r not expected to be equal %r' % (a, b)
|
|
||||||
|
|
||||||
|
|
||||||
def format_is_instance(obj, cls):
|
|
||||||
return '%r expected to be instance of %r' % (obj, cls)
|
|
||||||
|
|
||||||
|
|
||||||
def assert_async(task, syscalls):
|
|
||||||
for prev_result, expected in syscalls:
|
|
||||||
if isinstance(expected, Exception):
|
|
||||||
with assert_raises(expected.__class__):
|
|
||||||
task.send(prev_result)
|
|
||||||
else:
|
|
||||||
syscall = task.send(prev_result)
|
|
||||||
assert_eq_obj(syscall, expected)
|
|
||||||
|
|
||||||
|
|
||||||
class assert_raises:
|
|
||||||
|
|
||||||
def __init__(self, exc_type):
|
|
||||||
self.exc_type = exc_type
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
|
||||||
assert exc_type is not None, '%r not raised' % self.exc_type
|
|
||||||
return issubclass(exc_type, self.exc_type)
|
|
||||||
|
|
||||||
|
|
||||||
class mock_call:
|
|
||||||
|
|
||||||
def __init__(self, original, expected):
|
|
||||||
self.original = original
|
|
||||||
self.expected = expected
|
|
||||||
self.record = []
|
|
||||||
|
|
||||||
def __call__(self, *args):
|
|
||||||
self.record.append(args)
|
|
||||||
assert_eq(args, self.expected.pop(0))
|
|
||||||
|
|
||||||
def assert_called_n_times(self, n, msg=None):
|
|
||||||
assert_eq(len(self.record), n, msg)
|
|
Loading…
Reference in New Issue
Block a user