"""Minimal test runner with assertion and call-mocking helpers."""

import sys

__all__ = [
    'run_tests',
    'run_test',
    'assert_eq',
    'assert_not_eq',
    'assert_is_instance',
    'mock_call',
]


# Running

def run_tests(mod_name='__main__'):
    """Run every ``test_*`` callable of module *mod_name* and print a summary.

    Each test is reported as it finishes. The run stops at the first failing
    test; the totals are then printed and the process exits with status 1
    through ``sys.exit``. When nothing fails the function simply returns.
    """
    ntotal = nok = nfailed = 0
    for name, test in get_tests(mod_name):
        result = run_test(test)
        report_test(name, test, result)
        ntotal += 1
        if not result:
            nfailed += 1
            break  # fail fast: later tests are not run
        nok += 1
    report_total(ntotal, nok, nfailed)
    if nfailed > 0:
        sys.exit(1)


def get_tests(mod_name):
    """Yield ``(name, callable)`` for each ``test_*`` callable in *mod_name*.

    The module is imported by name. Names are visited in ``dir()`` order.
    """
    imported = __import__(mod_name)
    # __import__('pkg.mod') returns the top-level package 'pkg', not the
    # requested submodule, so resolve the real module through sys.modules and
    # only fall back to the returned object when it is not registered there.
    module = getattr(sys, 'modules', {}).get(mod_name, imported)
    for name in dir(module):
        if not name.startswith('test_'):
            continue
        candidate = getattr(module, name)
        # Skip module-level data that merely happens to be named test_*.
        if callable(candidate):
            yield name, candidate


def run_test(test):
    """Call *test* with no arguments; return True on success, False on error.

    Any ``Exception`` raised by the test is passed to ``report_exception``
    and swallowed.
    """
    try:
        test()
    except Exception as e:
        report_exception(e)
        return False
    else:
        return True


# Reporting

def report_test(name, test, result):
    """Print ``OK <name>`` or ``ERR <name>`` depending on *result*.

    *test* is accepted for interface symmetry but is not used.
    """
    if result:
        print('OK', name)
    else:
        print('ERR', name)


def report_exception(exc):
    """Print the traceback of *exc*.

    Uses ``sys.print_exception`` when the interpreter provides it
    (MicroPython) and falls back to the ``traceback`` module otherwise.
    """
    printer = getattr(sys, 'print_exception', None)
    if printer is not None:
        printer(exc)
        return
    # Imported lazily: the module is only needed on interpreters that lack
    # sys.print_exception, and may not exist on the ones that have it.
    import traceback
    traceback.print_exception(type(exc), exc, exc.__traceback__)


def report_total(total, ok, failed):
    """Print the final counters of a test run."""
    print('Total:', total, 'OK:', ok, 'Failed:', failed)


# Assertions
#
# The helpers raise AssertionError explicitly instead of using ``assert``
# statements, so the checks stay active when assertions are compiled out.

def assert_eq(a, b, msg=None):
    """Raise ``AssertionError`` unless ``a == b``.

    *msg* replaces the default message when it is truthy.
    """
    if not a == b:
        raise AssertionError(msg or format_eq(a, b))


def assert_not_eq(a, b, msg=None):
    """Raise ``AssertionError`` unless ``a != b``.

    *msg* replaces the default message when it is truthy.
    """
    if not a != b:
        raise AssertionError(msg or format_not_eq(a, b))


def assert_is_instance(obj, cls, msg=None):
    """Raise ``AssertionError`` unless ``isinstance(obj, cls)``.

    *msg* replaces the default message when it is truthy.
    """
    if not isinstance(obj, cls):
        raise AssertionError(msg or format_is_instance(obj, cls))


def assert_eq_obj(a, b, msg=None):
    """Assert that *a* is an instance of *b*'s class with an equal ``__dict__``."""
    assert_is_instance(a, b.__class__, msg)
    assert_eq(a.__dict__, b.__dict__, msg)


def format_eq(a, b):
    """Return the default failure message used by ``assert_eq``."""
    return '\n%r\nvs (expected)\n%r' % (a, b)


def format_not_eq(a, b):
    """Return the default failure message used by ``assert_not_eq``."""
    return '%r not expected to be equal %r' % (a, b)


def format_is_instance(obj, cls):
    """Return the default failure message used by ``assert_is_instance``."""
    return '%r expected to be instance of %r' % (obj, cls)


def assert_async(task, syscalls):
    """Drive *task* with ``send`` and check what each step produces.

    *syscalls* is an iterable of ``(prev_result, expected)`` pairs. For each
    pair ``prev_result`` is sent into *task*. When ``expected`` is an
    exception instance the send must raise an exception of that class;
    otherwise the value returned by ``send`` is compared with ``expected``
    using ``assert_eq_obj``.
    """
    for prev_result, expected in syscalls:
        if isinstance(expected, Exception):
            with assert_raises(expected.__class__):
                task.send(prev_result)
        else:
            syscall = task.send(prev_result)
            assert_eq_obj(syscall, expected)


class assert_raises:
    """Context manager asserting that its body raises *exc_type*.

    A matching exception (per ``issubclass``) is suppressed, a non-matching
    one propagates, and ``AssertionError`` is raised when the body raises
    nothing.
    """

    def __init__(self, exc_type):
        self.exc_type = exc_type

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            # Wrap in a 1-tuple so a tuple of exception classes is formatted
            # as one value instead of being spread over the format string.
            raise AssertionError('%r not raised' % (self.exc_type,))
        return issubclass(exc_type, self.exc_type)


class mock_call:
    """Callable stand-in that checks positional arguments against a script.

    *expected* is a list of argument tuples, one per anticipated call; it is
    consumed from the front (and therefore mutated) as calls arrive.
    *original* is only stored on the instance. Every call's arguments are
    appended to ``record``.
    """

    def __init__(self, original, expected):
        self.original = original
        self.expected = expected
        self.record = []

    def __call__(self, *args):
        """Record *args* and assert they match the next expected tuple."""
        self.record.append(args)
        if not self.expected:
            # Report a call beyond the script as a test failure rather than
            # leaking the IndexError from pop(0).
            raise AssertionError('unexpected call with args %r' % (args,))
        assert_eq(args, self.expected.pop(0))

    def assert_called_n_times(self, n, msg=None):
        """Assert that the mock has been called exactly *n* times."""
        assert_eq(len(self.record), n, msg)