diff --git a/core/src/trezor/wire/thp/channel_context.py b/core/src/trezor/wire/thp/channel_context.py index 3423f3f52..27bddd863 100644 --- a/core/src/trezor/wire/thp/channel_context.py +++ b/core/src/trezor/wire/thp/channel_context.py @@ -64,6 +64,9 @@ class ChannelContext(Context): def set_channel_state(self, state: ChannelState) -> None: self.channel_cache.state = bytearray(state.value.to_bytes(1, "big")) + def set_buffer(self, buffer: utils.BufferType) -> None: + self.buffer = buffer + # CALLED BY THP_MAIN_LOOP async def receive_packet(self, packet: utils.BufferType): @@ -102,7 +105,7 @@ class ChannelContext(Context): # TODO use small buffer # TODO for now, we create a new big buffer every time. It should be changed - self.buffer = _get_buffer_for_payload(payload_length, self.buffer) + self.buffer = _get_buffer_for_payload(payload_length, packet) await self._buffer_packet_data(self.buffer, packet, 0) @@ -239,11 +242,13 @@ class ChannelContext(Context): return THP.sync_get_send_bit(self.channel_cache) != sync_bit -def load_cached_channels() -> dict[int, ChannelContext]: # TODO +def load_cached_channels(buffer: utils.BufferType) -> dict[int, ChannelContext]: # TODO channels: dict[int, ChannelContext] = {} cached_channels = cache_thp.get_all_allocated_channels() for c in cached_channels: channels[int.from_bytes(c.channel_id, "big")] = ChannelContext(c) + for c in channels.values(): + c.set_buffer(buffer) return channels diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index fb5193eab..c1f5b1621 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -52,13 +52,14 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag def set_buffer(buffer): + global _BUFFER _BUFFER = buffer - print(_BUFFER) # TODO remove async def thp_main_loop(iface: WireInterface, is_debug_session=False): global _CHANNEL_CONTEXTS - _CHANNEL_CONTEXTS = load_cached_channels() + global _BUFFER + _CHANNEL_CONTEXTS = load_cached_channels(_BUFFER) read = loop.wait(iface.iface_num() | io.POLL_READ)