From ac6936d65d19d3cf2b1c2a6bfe86cd897583e4dc Mon Sep 17 00:00:00 2001 From: matejcik Date: Wed, 19 Oct 2022 11:43:56 +0200 Subject: [PATCH] fix(core/debug): separate buffer for debuglink to prevent BufferLock deadlocks [no changelog] --- core/src/trezor/wire/__init__.py | 11 ++- core/src/trezor/wire/codec_v1.py | 125 ++++++++++++------------------- 2 files changed, 57 insertions(+), 79 deletions(-) diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 8d18adb93..5d53e0bef 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -136,6 +136,10 @@ PROTOBUF_BUFFER_SIZE = 8192 WIRE_BUFFER = bytearray(PROTOBUF_BUFFER_SIZE) +if __debug__: + PROTOBUF_BUFFER_SIZE_DEBUG = 1024 + WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG) + class Context: def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None: @@ -381,7 +385,12 @@ async def _handle_single_message( async def handle_session( iface: WireInterface, session_id: int, is_debug_session: bool = False ) -> None: - ctx = Context(iface, session_id, WIRE_BUFFER) + if __debug__ and is_debug_session: + ctx_buffer = WIRE_BUFFER_DEBUG + else: + ctx_buffer = WIRE_BUFFER + + ctx = Context(iface, session_id, ctx_buffer) next_msg: codec_v1.Message | None = None if __debug__ and is_debug_session: diff --git a/core/src/trezor/wire/codec_v1.py b/core/src/trezor/wire/codec_v1.py index 459ebb7ee..045b868bb 100644 --- a/core/src/trezor/wire/codec_v1.py +++ b/core/src/trezor/wire/codec_v1.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING from trezor import io, loop, utils if TYPE_CHECKING: - from typing import Any from trezorio import WireInterface _REP_LEN = const(64) @@ -19,34 +18,6 @@ _REP_CONT_DATA = const(1) # offset of data in the continuation report SESSION_ID = const(0) INVALID_TYPE = const(-1) -# The wire buffer is shared between the main wire interface and debuglink -# (see __init__.py). There's no obvious guarantee that both interfaces won't -# use it at the same time, thus we check this at runtime in debug builds. -if __debug__: - - class BufferLock: # type: ignore [Class declaration "BufferLock" is obscured by a declaration of the same name] - def __init__(self) -> None: - self.in_use = False - - def __enter__(self) -> None: - assert not self.in_use, "global buffer already used by another context" - self.in_use = True - - def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None: - self.in_use = False - -else: - - class BufferLock: - def __enter__(self) -> None: - pass - - def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None: - pass - - -buffer_lock = BufferLock() - class CodecError(Exception): pass @@ -71,32 +42,31 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag read_and_throw_away = False - with buffer_lock: - if msize > len(buffer): - # allocate a new buffer to fit the message - try: - mdata: utils.BufferType = bytearray(msize) - except MemoryError: - mdata = bytearray(_REP_LEN) - read_and_throw_away = True + if msize > len(buffer): + # allocate a new buffer to fit the message + try: + mdata: utils.BufferType = bytearray(msize) + except MemoryError: + mdata = bytearray(_REP_LEN) + read_and_throw_away = True + else: + # reuse a part of the supplied buffer + mdata = memoryview(buffer)[:msize] + + # buffer the initial data + nread = utils.memcpy(mdata, 0, report, _REP_INIT_DATA) + + while nread < msize: + # wait for continuation report + report = await read + if report[0] != _REP_MARKER: + raise CodecError("Invalid magic") + + # buffer the continuation data + if read_and_throw_away: + nread += len(report) - 1 else: - # reuse a part of the supplied buffer - mdata = memoryview(buffer)[:msize] - - # buffer the initial data - nread = utils.memcpy(mdata, 0, report, _REP_INIT_DATA) - - while nread < msize: - # wait for continuation report - report = await read - if report[0] != _REP_MARKER: - raise CodecError("Invalid magic") - - # buffer the continuation data - if read_and_throw_away: - nread += len(report) - 1 - else: - nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA) + nread += utils.memcpy(mdata, nread, report, _REP_CONT_DATA) if read_and_throw_away: raise CodecError("Message too large") @@ -107,31 +77,30 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None: write = loop.wait(iface.iface_num() | io.POLL_WRITE) - with buffer_lock: - # gather data from msg - msize = len(mdata) + # gather data from msg + msize = len(mdata) - # prepare the report buffer with header data - report = bytearray(_REP_LEN) - repofs = _REP_INIT_DATA - ustruct.pack_into( - _REP_INIT, report, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize - ) + # prepare the report buffer with header data + report = bytearray(_REP_LEN) + repofs = _REP_INIT_DATA + ustruct.pack_into( + _REP_INIT, report, 0, _REP_MARKER, _REP_MAGIC, _REP_MAGIC, mtype, msize + ) - nwritten = 0 + nwritten = 0 + while True: + # copy as much as possible to the report buffer + nwritten += utils.memcpy(report, repofs, mdata, nwritten) + + # write the report while True: - # copy as much as possible to the report buffer - nwritten += utils.memcpy(report, repofs, mdata, nwritten) - - # write the report - while True: - await write - n = iface.write(report) - if n == len(report): - break - - # if we have more data to write, use continuation reports for it - if nwritten < msize: - repofs = _REP_CONT_DATA - else: + await write + n = iface.write(report) + if n == len(report): break + + # if we have more data to write, use continuation reports for it + if nwritten < msize: + repofs = _REP_CONT_DATA + else: + break