From b0a2297b14d07a01f624af5a5037da44e75df82e Mon Sep 17 00:00:00 2001 From: matejcik Date: Mon, 14 Sep 2020 13:33:42 +0200 Subject: [PATCH] feat(core): convert protobuf decoding errors to DataErrors --- core/src/trezor/wire/__init__.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index ef13bdee94..9783045bbd 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -40,7 +40,7 @@ from trezor import log, loop, messages, ui, utils, workflow from trezor.messages import FailureType from trezor.messages.Failure import Failure from trezor.wire import codec_v1 -from trezor.wire.errors import ActionCancelled, Error +from trezor.wire.errors import ActionCancelled, DataError, Error # Import all errors into namespace, so that `wire.Error` is available from # other packages. @@ -117,6 +117,20 @@ if False: ... +def _wrap_protobuf_load( + reader: protobuf.Reader, + expected_type: Type[protobuf.LoadedMessageType], + field_cache: protobuf.FieldCache = None, +) -> protobuf.LoadedMessageType: + try: + return protobuf.load_message(reader, expected_type, field_cache) + except Exception as e: + if e.args: + raise DataError("Failed to decode message: {}".format(e.args[0])) + else: + raise DataError("Failed to decode message") + + class DummyContext: async def call(self, *argv: Any) -> None: pass @@ -201,8 +215,7 @@ class Context: workflow.idle_timer.touch() # look up the protobuf class and parse the message - pbtype = messages.get_type(msg.type) - return protobuf.load_message(msg.data, pbtype, field_cache) # type: ignore + return _wrap_protobuf_load(msg.data, expected_type, field_cache) async def read_any( self, expected_wire_types: Iterable[int] @@ -235,7 +248,7 @@ class Context: workflow.idle_timer.touch() # parse the message and return it - return protobuf.load_message(msg.data, exptype) + return _wrap_protobuf_load(msg.data, exptype) async def write( self, msg: protobuf.MessageType, field_cache: protobuf.FieldCache = None @@ -353,7 +366,7 @@ async def handle_session( # Try to decode the message according to schema from # `req_type`. Raises if the message is malformed. - req_msg = protobuf.load_message(msg.data, req_type) + req_msg = _wrap_protobuf_load(msg.data, req_type) # At this point, message reports are all processed and # correctly parsed into `req_msg`.