mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
trezorlib: reentrant session handling
This fixes the breakage introduced by transport reshuffles. It's still not great and I'd love to see context manager based sessions. But it's good enough for now.
This commit is contained in:
parent
daf97afb37
commit
85b85c67b3
@ -84,15 +84,23 @@ class BaseClient(object):
|
||||
LOG.info("creating client instance for device: {}".format(transport.get_path()))
|
||||
self.transport = transport
|
||||
self.ui = ui
|
||||
|
||||
self.session_counter = 0
|
||||
super(BaseClient, self).__init__() # *args, **kwargs)
|
||||
|
||||
def open(self):
|
||||
if self.session_counter == 0:
|
||||
self.transport.begin_session()
|
||||
self.session_counter += 1
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
if self.session_counter == 1:
|
||||
self.transport.end_session()
|
||||
self.session_counter -= 1
|
||||
|
||||
def cancel(self):
|
||||
self._raw_write(proto.Cancel())
|
||||
|
||||
@tools.session
|
||||
def call_raw(self, msg):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
self._raw_write(msg)
|
||||
@ -174,6 +182,7 @@ class ProtocolMixin(object):
|
||||
def set_tx_api(self, tx_api):
|
||||
warnings.warn("set_tx_api is deprecated, use new arguments to sign_tx")
|
||||
|
||||
@tools.session
|
||||
def init_device(self):
|
||||
resp = self.call(proto.Initialize(state=self.state))
|
||||
if not isinstance(resp, proto.Features):
|
||||
|
@ -25,6 +25,8 @@ from .tools import expect
|
||||
class DebugLink:
|
||||
def __init__(self, transport):
|
||||
self.transport = transport
|
||||
|
||||
def open(self):
|
||||
self.transport.begin_session()
|
||||
|
||||
def close(self):
|
||||
@ -186,10 +188,13 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
self.set_passphrase("")
|
||||
super().__init__(transport, ui=self.ui)
|
||||
|
||||
def open(self):
|
||||
super().open()
|
||||
self.debug.open()
|
||||
|
||||
def close(self):
|
||||
self.debug.close()
|
||||
super().close()
|
||||
if self.debug:
|
||||
self.debug.close()
|
||||
|
||||
def set_filter(self, message_type, callback):
|
||||
self.filters[message_type] = callback
|
||||
|
@ -41,10 +41,9 @@ class TrezorTest:
|
||||
# self.client.set_buttonwait(3)
|
||||
|
||||
device.wipe(self.client)
|
||||
self.client.transport.begin_session()
|
||||
self.client.open()
|
||||
|
||||
def teardown_method(self, method):
|
||||
self.client.transport.end_session()
|
||||
self.client.close()
|
||||
|
||||
def _setup_mnemonic(self, mnemonic=None, pin="", passphrase=False):
|
||||
|
@ -55,15 +55,9 @@ def client():
|
||||
wirelink = get_device()
|
||||
client = TrezorClientDebugLink(wirelink)
|
||||
wipe_device(client)
|
||||
client.transport.begin_session()
|
||||
|
||||
client.open()
|
||||
yield client
|
||||
|
||||
client.transport.end_session()
|
||||
|
||||
# XXX debuglink session must also be closed
|
||||
# client.close accomplishes that for now; going forward, there should
|
||||
# also be proper session handling for debuglink
|
||||
client.close()
|
||||
|
||||
|
||||
|
@ -224,11 +224,11 @@ def session(f):
|
||||
@functools.wraps(f)
|
||||
def wrapped_f(client, *args, **kwargs):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
client.transport.begin_session()
|
||||
client.open()
|
||||
try:
|
||||
return f(client, *args, **kwargs)
|
||||
finally:
|
||||
client.transport.begin_session()
|
||||
client.close()
|
||||
|
||||
return wrapped_f
|
||||
|
||||
|
@ -85,15 +85,16 @@ class Protocol:
|
||||
self.handle = handle
|
||||
self.session_counter = 0
|
||||
|
||||
# XXX we might be able to remove this now that TrezorClient does session handling
|
||||
def begin_session(self) -> None:
|
||||
if self.session_counter == 0:
|
||||
self.handle.open()
|
||||
self.session_counter += 1
|
||||
|
||||
def end_session(self) -> None:
|
||||
self.session_counter = max(self.session_counter - 1, 0)
|
||||
if self.session_counter == 0:
|
||||
if self.session_counter == 1:
|
||||
self.handle.close()
|
||||
self.session_counter -= 1
|
||||
|
||||
def read(self) -> protobuf.MessageType:
|
||||
raise NotImplementedError
|
||||
|
Loading…
Reference in New Issue
Block a user