mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-26 17:38:39 +00:00
trezorlib: reentrant session handling
This fixes the breakage introduced by transport reshuffles. It's still not great and I'd love to see context manager based sessions. But it's good enough for now.
This commit is contained in:
parent
daf97afb37
commit
85b85c67b3
@ -84,15 +84,23 @@ class BaseClient(object):
|
|||||||
LOG.info("creating client instance for device: {}".format(transport.get_path()))
|
LOG.info("creating client instance for device: {}".format(transport.get_path()))
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.ui = ui
|
self.ui = ui
|
||||||
|
|
||||||
|
self.session_counter = 0
|
||||||
super(BaseClient, self).__init__() # *args, **kwargs)
|
super(BaseClient, self).__init__() # *args, **kwargs)
|
||||||
|
|
||||||
|
def open(self):
|
||||||
|
if self.session_counter == 0:
|
||||||
|
self.transport.begin_session()
|
||||||
|
self.session_counter += 1
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
pass
|
if self.session_counter == 1:
|
||||||
|
self.transport.end_session()
|
||||||
|
self.session_counter -= 1
|
||||||
|
|
||||||
def cancel(self):
|
def cancel(self):
|
||||||
self._raw_write(proto.Cancel())
|
self._raw_write(proto.Cancel())
|
||||||
|
|
||||||
@tools.session
|
|
||||||
def call_raw(self, msg):
|
def call_raw(self, msg):
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
self._raw_write(msg)
|
self._raw_write(msg)
|
||||||
@ -174,6 +182,7 @@ class ProtocolMixin(object):
|
|||||||
def set_tx_api(self, tx_api):
|
def set_tx_api(self, tx_api):
|
||||||
warnings.warn("set_tx_api is deprecated, use new arguments to sign_tx")
|
warnings.warn("set_tx_api is deprecated, use new arguments to sign_tx")
|
||||||
|
|
||||||
|
@tools.session
|
||||||
def init_device(self):
|
def init_device(self):
|
||||||
resp = self.call(proto.Initialize(state=self.state))
|
resp = self.call(proto.Initialize(state=self.state))
|
||||||
if not isinstance(resp, proto.Features):
|
if not isinstance(resp, proto.Features):
|
||||||
|
@ -25,6 +25,8 @@ from .tools import expect
|
|||||||
class DebugLink:
|
class DebugLink:
|
||||||
def __init__(self, transport):
|
def __init__(self, transport):
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
|
|
||||||
|
def open(self):
|
||||||
self.transport.begin_session()
|
self.transport.begin_session()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
@ -186,10 +188,13 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
self.set_passphrase("")
|
self.set_passphrase("")
|
||||||
super().__init__(transport, ui=self.ui)
|
super().__init__(transport, ui=self.ui)
|
||||||
|
|
||||||
|
def open(self):
|
||||||
|
super().open()
|
||||||
|
self.debug.open()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
self.debug.close()
|
||||||
super().close()
|
super().close()
|
||||||
if self.debug:
|
|
||||||
self.debug.close()
|
|
||||||
|
|
||||||
def set_filter(self, message_type, callback):
|
def set_filter(self, message_type, callback):
|
||||||
self.filters[message_type] = callback
|
self.filters[message_type] = callback
|
||||||
|
@ -41,10 +41,9 @@ class TrezorTest:
|
|||||||
# self.client.set_buttonwait(3)
|
# self.client.set_buttonwait(3)
|
||||||
|
|
||||||
device.wipe(self.client)
|
device.wipe(self.client)
|
||||||
self.client.transport.begin_session()
|
self.client.open()
|
||||||
|
|
||||||
def teardown_method(self, method):
|
def teardown_method(self, method):
|
||||||
self.client.transport.end_session()
|
|
||||||
self.client.close()
|
self.client.close()
|
||||||
|
|
||||||
def _setup_mnemonic(self, mnemonic=None, pin="", passphrase=False):
|
def _setup_mnemonic(self, mnemonic=None, pin="", passphrase=False):
|
||||||
|
@ -55,15 +55,9 @@ def client():
|
|||||||
wirelink = get_device()
|
wirelink = get_device()
|
||||||
client = TrezorClientDebugLink(wirelink)
|
client = TrezorClientDebugLink(wirelink)
|
||||||
wipe_device(client)
|
wipe_device(client)
|
||||||
client.transport.begin_session()
|
|
||||||
|
|
||||||
|
client.open()
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
client.transport.end_session()
|
|
||||||
|
|
||||||
# XXX debuglink session must also be closed
|
|
||||||
# client.close accomplishes that for now; going forward, there should
|
|
||||||
# also be proper session handling for debuglink
|
|
||||||
client.close()
|
client.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ -224,11 +224,11 @@ def session(f):
|
|||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def wrapped_f(client, *args, **kwargs):
|
def wrapped_f(client, *args, **kwargs):
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
client.transport.begin_session()
|
client.open()
|
||||||
try:
|
try:
|
||||||
return f(client, *args, **kwargs)
|
return f(client, *args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
client.transport.begin_session()
|
client.close()
|
||||||
|
|
||||||
return wrapped_f
|
return wrapped_f
|
||||||
|
|
||||||
|
@ -85,15 +85,16 @@ class Protocol:
|
|||||||
self.handle = handle
|
self.handle = handle
|
||||||
self.session_counter = 0
|
self.session_counter = 0
|
||||||
|
|
||||||
|
# XXX we might be able to remove this now that TrezorClient does session handling
|
||||||
def begin_session(self) -> None:
|
def begin_session(self) -> None:
|
||||||
if self.session_counter == 0:
|
if self.session_counter == 0:
|
||||||
self.handle.open()
|
self.handle.open()
|
||||||
self.session_counter += 1
|
self.session_counter += 1
|
||||||
|
|
||||||
def end_session(self) -> None:
|
def end_session(self) -> None:
|
||||||
self.session_counter = max(self.session_counter - 1, 0)
|
if self.session_counter == 1:
|
||||||
if self.session_counter == 0:
|
|
||||||
self.handle.close()
|
self.handle.close()
|
||||||
|
self.session_counter -= 1
|
||||||
|
|
||||||
def read(self) -> protobuf.MessageType:
|
def read(self) -> protobuf.MessageType:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
Loading…
Reference in New Issue
Block a user