1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-03-04 09:16:06 +00:00

debuglink: add support for arbitrary message filters

(this replaces `debug_processor` from sign_tx)
This commit is contained in:
matejcik 2018-11-02 16:20:10 +01:00
parent 5087f30a69
commit 3239d53bc0

View File

@ -13,13 +13,12 @@
# #
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from copy import deepcopy
from mnemonic import Mnemonic from mnemonic import Mnemonic
from . import messages as proto, tools from . import messages as proto, tools, protobuf
from .client import TrezorClient from .client import TrezorClient
from .protobuf import format_message
from .tools import expect from .tools import expect
@ -172,9 +171,10 @@ class TrezorClientDebugLink(TrezorClient):
self.ui = DebugUI(self.debug) self.ui = DebugUI(self.debug)
self.in_with_statement = 0 self.in_with_statement = 0
self.button_wait = 0
self.screenshot_id = 0 self.screenshot_id = 0
self.filters = {}
# Always press Yes and provide correct pin # Always press Yes and provide correct pin
self.setup_debuglink(True, True) self.setup_debuglink(True, True)
@ -191,8 +191,16 @@ class TrezorClientDebugLink(TrezorClient):
if self.debug: if self.debug:
self.debug.close() self.debug.close()
def set_buttonwait(self, secs): def set_filter(self, message_type, callback):
self.button_wait = secs self.filters[message_type] = callback
def _filter_message(self, msg):
message_type = msg.__class__
callback = self.filters.get(message_type)
if callable(callback):
return callback(deepcopy(msg))
else:
return msg
def set_input_flow(self, input_flow): def set_input_flow(self, input_flow):
if callable(input_flow): if callable(input_flow):
@ -265,9 +273,13 @@ class TrezorClientDebugLink(TrezorClient):
# self.screenshot_id += 1 # self.screenshot_id += 1
resp = super()._raw_read() resp = super()._raw_read()
resp = self._filter_message(resp)
self._check_request(resp) self._check_request(resp)
return resp return resp
def _raw_write(self, msg):
return super()._raw_write(self._filter_message(msg))
def _raise_unexpected_response(self, msg): def _raise_unexpected_response(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
@ -294,7 +306,7 @@ class TrezorClientDebugLink(TrezorClient):
output.append("") output.append("")
if msg is not None: if msg is not None:
output.append("Actually received:") output.append("Actually received:")
output.append(format_message(msg)) output.append(protobuf.format_message(msg))
else: else:
output.append("This message was never received.") output.append("This message was never received.")
raise AssertionError("\n".join(output)) raise AssertionError("\n".join(output))
@ -306,7 +318,8 @@ class TrezorClientDebugLink(TrezorClient):
if self.current_response >= len(self.expected_responses): if self.current_response >= len(self.expected_responses):
raise AssertionError( raise AssertionError(
"No more messages were expected, but we got:\n" + format_message(msg) "No more messages were expected, but we got:\n"
+ protobuf.format_message(msg)
) )
expected = self.expected_responses[self.current_response] expected = self.expected_responses[self.current_response]