diff --git a/core/src/trezor/wire/thp/channel_context.py b/core/src/trezor/wire/thp/channel_context.py index 3423f3f52..8a7a6573b 100644 --- a/core/src/trezor/wire/thp/channel_context.py +++ b/core/src/trezor/wire/thp/channel_context.py @@ -64,6 +64,9 @@ class ChannelContext(Context): def set_channel_state(self, state: ChannelState) -> None: self.channel_cache.state = bytearray(state.value.to_bytes(1, "big")) + def set_buffer(self, buffer: utils.BufferType) -> None: + self.buffer = buffer + # CALLED BY THP_MAIN_LOOP async def receive_packet(self, packet: utils.BufferType): @@ -239,11 +242,13 @@ class ChannelContext(Context): return THP.sync_get_send_bit(self.channel_cache) != sync_bit -def load_cached_channels() -> dict[int, ChannelContext]: # TODO +def load_cached_channels(buffer: utils.BufferType) -> dict[int, ChannelContext]: # TODO channels: dict[int, ChannelContext] = {} cached_channels = cache_thp.get_all_allocated_channels() for c in cached_channels: channels[int.from_bytes(c.channel_id, "big")] = ChannelContext(c) + for c in channels.values(): + c.set_buffer(buffer) return channels diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index 9c6b9a04d..c1f5b1621 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -58,7 +58,8 @@ def set_buffer(buffer): async def thp_main_loop(iface: WireInterface, is_debug_session=False): global _CHANNEL_CONTEXTS - _CHANNEL_CONTEXTS = load_cached_channels() + global _BUFFER + _CHANNEL_CONTEXTS = load_cached_channels(_BUFFER) read = loop.wait(iface.iface_num() | io.POLL_READ)