diff --git a/trezorlib/client.py b/trezorlib/client.py index 03053949a..c27fb8393 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -83,21 +83,14 @@ def get_buttonrequest_value(code): class CallException(Exception): - def __init__(self, code, message): - super(CallException, self).__init__() - self.args = [code, message] - - -class AssertionException(Exception): - def __init__(self, code, message): - self.args = [code, message] + pass class PinException(CallException): pass -class field(object): +class field: # Decorator extracts single value from # protobuf object. If the field is not # present, raises an exception. @@ -105,13 +98,14 @@ class field(object): self.field = field def __call__(self, f): + @functools.wraps(f) def wrapped_f(*args, **kwargs): ret = f(*args, **kwargs) return getattr(ret, self.field) return wrapped_f -class expect(object): +class expect: # Decorator checks if the method # returned one of expected protobuf messages # or raises an exception @@ -119,6 +113,7 @@ class expect(object): self.expected = expected def __call__(self, f): + @functools.wraps(f) def wrapped_f(*args, **kwargs): ret = f(*args, **kwargs) if not isinstance(ret, self.expected): @@ -130,6 +125,7 @@ class expect(object): def session(f): # Decorator wraps a BaseClient method # with session activation / deactivation + @functools.wraps(f) def wrapped_f(*args, **kwargs): client = args[0] client.transport.session_begin() @@ -388,19 +384,19 @@ class DebugLinkMixin(object): try: expected = self.expected_responses.pop(0) except IndexError: - raise AssertionException(proto.FailureType.UnexpectedMessage, - "Got %s, but no message has been expected" % repr(msg)) + raise AssertionError(proto.FailureType.UnexpectedMessage, + "Got %s, but no message has been expected" % repr(msg)) if msg.__class__ != expected.__class__: - raise AssertionException(proto.FailureType.UnexpectedMessage, - "Expected %s, got %s" % (repr(expected), repr(msg))) + raise AssertionError(proto.FailureType.UnexpectedMessage, + "Expected %s, got %s" % (repr(expected), repr(msg))) for field, value in expected.__dict__.items(): if value is None or value == []: continue if getattr(msg, field) != value: - raise AssertionException(proto.FailureType.UnexpectedMessage, - "Expected %s, got %s" % (repr(expected), repr(msg))) + raise AssertionError(proto.FailureType.UnexpectedMessage, + "Expected %s, got %s" % (repr(expected), repr(msg))) def callback_ButtonRequest(self, msg): self.DEBUG("ButtonRequest code: " + get_buttonrequest_value(msg.code))