1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-01 19:10:58 +00:00

client: clean up exception and decorator code

This commit is contained in:
matejcik 2018-05-11 15:27:39 +02:00
parent a478dac5f7
commit cc7c8ccb59

View File

@ -83,21 +83,14 @@ def get_buttonrequest_value(code):
class CallException(Exception):
def __init__(self, code, message):
super(CallException, self).__init__()
self.args = [code, message]
class AssertionException(Exception):
def __init__(self, code, message):
self.args = [code, message]
pass
class PinException(CallException):
pass
class field(object):
class field:
# Decorator extracts single value from
# protobuf object. If the field is not
# present, raises an exception.
@ -105,13 +98,14 @@ class field(object):
self.field = field
def __call__(self, f):
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
ret = f(*args, **kwargs)
return getattr(ret, self.field)
return wrapped_f
class expect(object):
class expect:
# Decorator checks if the method
# returned one of expected protobuf messages
# or raises an exception
@ -119,6 +113,7 @@ class expect(object):
self.expected = expected
def __call__(self, f):
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
ret = f(*args, **kwargs)
if not isinstance(ret, self.expected):
@ -130,6 +125,7 @@ class expect(object):
def session(f):
# Decorator wraps a BaseClient method
# with session activation / deactivation
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
client = args[0]
client.transport.session_begin()
@ -388,19 +384,19 @@ class DebugLinkMixin(object):
try:
expected = self.expected_responses.pop(0)
except IndexError:
raise AssertionException(proto.FailureType.UnexpectedMessage,
"Got %s, but no message has been expected" % repr(msg))
raise AssertionError(proto.FailureType.UnexpectedMessage,
"Got %s, but no message has been expected" % repr(msg))
if msg.__class__ != expected.__class__:
raise AssertionException(proto.FailureType.UnexpectedMessage,
"Expected %s, got %s" % (repr(expected), repr(msg)))
raise AssertionError(proto.FailureType.UnexpectedMessage,
"Expected %s, got %s" % (repr(expected), repr(msg)))
for field, value in expected.__dict__.items():
if value is None or value == []:
continue
if getattr(msg, field) != value:
raise AssertionException(proto.FailureType.UnexpectedMessage,
"Expected %s, got %s" % (repr(expected), repr(msg)))
raise AssertionError(proto.FailureType.UnexpectedMessage,
"Expected %s, got %s" % (repr(expected), repr(msg)))
def callback_ButtonRequest(self, msg):
self.DEBUG("ButtonRequest code: " + get_buttonrequest_value(msg.code))