debuglink: add support for arbitrary message filters

(this replaces `debug_processor` from sign_tx)
pull/25/head
matejcik 6 years ago
parent 5087f30a69
commit 3239d53bc0

@ -13,13 +13,12 @@
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from copy import deepcopy
from mnemonic import Mnemonic
from . import messages as proto, tools
from . import messages as proto, tools, protobuf
from .client import TrezorClient
from .protobuf import format_message
from .tools import expect
@ -172,9 +171,10 @@ class TrezorClientDebugLink(TrezorClient):
self.ui = DebugUI(self.debug)
self.in_with_statement = 0
self.button_wait = 0
self.screenshot_id = 0
self.filters = {}
# Always press Yes and provide correct pin
self.setup_debuglink(True, True)
@ -191,8 +191,16 @@ class TrezorClientDebugLink(TrezorClient):
if self.debug:
self.debug.close()
def set_buttonwait(self, secs):
self.button_wait = secs
def set_filter(self, message_type, callback):
self.filters[message_type] = callback
def _filter_message(self, msg):
message_type = msg.__class__
callback = self.filters.get(message_type)
if callable(callback):
return callback(deepcopy(msg))
else:
return msg
def set_input_flow(self, input_flow):
if callable(input_flow):
@ -265,9 +273,13 @@ class TrezorClientDebugLink(TrezorClient):
# self.screenshot_id += 1
resp = super()._raw_read()
resp = self._filter_message(resp)
self._check_request(resp)
return resp
def _raw_write(self, msg):
return super()._raw_write(self._filter_message(msg))
def _raise_unexpected_response(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
@ -294,7 +306,7 @@ class TrezorClientDebugLink(TrezorClient):
output.append("")
if msg is not None:
output.append("Actually received:")
output.append(format_message(msg))
output.append(protobuf.format_message(msg))
else:
output.append("This message was never received.")
raise AssertionError("\n".join(output))
@ -306,7 +318,8 @@ class TrezorClientDebugLink(TrezorClient):
if self.current_response >= len(self.expected_responses):
raise AssertionError(
"No more messages were expected, but we got:\n" + format_message(msg)
"No more messages were expected, but we got:\n"
+ protobuf.format_message(msg)
)
expected = self.expected_responses[self.current_response]

Loading…
Cancel
Save