1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-26 17:38:39 +00:00

trezorlib: reentrant session handling

This fixes the breakage introduced by transport reshuffles.
It's still not great and I'd love to see context manager based sessions.
But it's good enough for now.
This commit is contained in:
matejcik 2018-11-08 18:08:02 +01:00
parent daf97afb37
commit 85b85c67b3
6 changed files with 25 additions and 17 deletions

View File

@ -84,15 +84,23 @@ class BaseClient(object):
LOG.info("creating client instance for device: {}".format(transport.get_path())) LOG.info("creating client instance for device: {}".format(transport.get_path()))
self.transport = transport self.transport = transport
self.ui = ui self.ui = ui
self.session_counter = 0
super(BaseClient, self).__init__() # *args, **kwargs) super(BaseClient, self).__init__() # *args, **kwargs)
def open(self):
if self.session_counter == 0:
self.transport.begin_session()
self.session_counter += 1
def close(self): def close(self):
pass if self.session_counter == 1:
self.transport.end_session()
self.session_counter -= 1
def cancel(self): def cancel(self):
self._raw_write(proto.Cancel()) self._raw_write(proto.Cancel())
@tools.session
def call_raw(self, msg): def call_raw(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
self._raw_write(msg) self._raw_write(msg)
@ -174,6 +182,7 @@ class ProtocolMixin(object):
def set_tx_api(self, tx_api): def set_tx_api(self, tx_api):
warnings.warn("set_tx_api is deprecated, use new arguments to sign_tx") warnings.warn("set_tx_api is deprecated, use new arguments to sign_tx")
@tools.session
def init_device(self): def init_device(self):
resp = self.call(proto.Initialize(state=self.state)) resp = self.call(proto.Initialize(state=self.state))
if not isinstance(resp, proto.Features): if not isinstance(resp, proto.Features):

View File

@ -25,6 +25,8 @@ from .tools import expect
class DebugLink: class DebugLink:
def __init__(self, transport): def __init__(self, transport):
self.transport = transport self.transport = transport
def open(self):
self.transport.begin_session() self.transport.begin_session()
def close(self): def close(self):
@ -186,10 +188,13 @@ class TrezorClientDebugLink(TrezorClient):
self.set_passphrase("") self.set_passphrase("")
super().__init__(transport, ui=self.ui) super().__init__(transport, ui=self.ui)
def open(self):
super().open()
self.debug.open()
def close(self): def close(self):
super().close()
if self.debug:
self.debug.close() self.debug.close()
super().close()
def set_filter(self, message_type, callback): def set_filter(self, message_type, callback):
self.filters[message_type] = callback self.filters[message_type] = callback

View File

@ -41,10 +41,9 @@ class TrezorTest:
# self.client.set_buttonwait(3) # self.client.set_buttonwait(3)
device.wipe(self.client) device.wipe(self.client)
self.client.transport.begin_session() self.client.open()
def teardown_method(self, method): def teardown_method(self, method):
self.client.transport.end_session()
self.client.close() self.client.close()
def _setup_mnemonic(self, mnemonic=None, pin="", passphrase=False): def _setup_mnemonic(self, mnemonic=None, pin="", passphrase=False):

View File

@ -55,15 +55,9 @@ def client():
wirelink = get_device() wirelink = get_device()
client = TrezorClientDebugLink(wirelink) client = TrezorClientDebugLink(wirelink)
wipe_device(client) wipe_device(client)
client.transport.begin_session()
client.open()
yield client yield client
client.transport.end_session()
# XXX debuglink session must also be closed
# client.close accomplishes that for now; going forward, there should
# also be proper session handling for debuglink
client.close() client.close()

View File

@ -224,11 +224,11 @@ def session(f):
@functools.wraps(f) @functools.wraps(f)
def wrapped_f(client, *args, **kwargs): def wrapped_f(client, *args, **kwargs):
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
client.transport.begin_session() client.open()
try: try:
return f(client, *args, **kwargs) return f(client, *args, **kwargs)
finally: finally:
client.transport.begin_session() client.close()
return wrapped_f return wrapped_f

View File

@ -85,15 +85,16 @@ class Protocol:
self.handle = handle self.handle = handle
self.session_counter = 0 self.session_counter = 0
# XXX we might be able to remove this now that TrezorClient does session handling
def begin_session(self) -> None: def begin_session(self) -> None:
if self.session_counter == 0: if self.session_counter == 0:
self.handle.open() self.handle.open()
self.session_counter += 1 self.session_counter += 1
def end_session(self) -> None: def end_session(self) -> None:
self.session_counter = max(self.session_counter - 1, 0) if self.session_counter == 1:
if self.session_counter == 0:
self.handle.close() self.handle.close()
self.session_counter -= 1
def read(self) -> protobuf.MessageType: def read(self) -> protobuf.MessageType:
raise NotImplementedError raise NotImplementedError