1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-04-20 00:59:02 +00:00

fix(tests): record actual_responses for session init

This commit is contained in:
Martin Milata 2025-04-03 00:24:58 +02:00 committed by M1nd3r
parent 18a2527f4c
commit 4f6f073b54
4 changed files with 26 additions and 7 deletions

View File

@ -189,6 +189,12 @@ class TrezorClient:
else:
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
def _write(self, msg: t.Any) -> None:
self.protocol.write(msg)
def _read(self) -> t.Any:
return self.protocol.read()
def get_default_client(
path: t.Optional[str] = None,

View File

@ -968,13 +968,10 @@ class SessionDebugWrapper(Session):
return self.client
def _write(self, msg: t.Any) -> None:
self._session._write(self.debug_client._filter_message(msg))
self._session._write(msg)
def _read(self) -> t.Any:
resp = self.debug_client._filter_message(self._session._read())
if self.debug_client.actual_responses is not None:
self.debug_client.actual_responses.append(resp)
return resp
return self._session._read()
def resume(self) -> None:
self._session.resume()
@ -1372,6 +1369,15 @@ class TrezorClientDebugLink(TrezorClient):
next(input_flow) # start the generator
def _write(self, msg: t.Any) -> None:
super()._write(self._filter_message(msg))
def _read(self) -> t.Any:
resp = self._filter_message(super()._read())
if self.actual_responses is not None:
self.actual_responses.append(resp)
return resp
def load_device(
session: "Session",

View File

@ -186,7 +186,7 @@ class SessionV1(Session):
self._activate_self()
if t.TYPE_CHECKING:
assert isinstance(self.client.protocol, ProtocolV1Channel)
self.client.protocol.write(msg)
self.client._write(msg)
def _activate_self(self) -> None:
if self.client._last_active_session is not self:
@ -197,7 +197,7 @@ class SessionV1(Session):
assert self.client._last_active_session is self
if t.TYPE_CHECKING:
assert isinstance(self.client.protocol, ProtocolV1Channel)
return self.client.protocol.read()
return self.client._read()
def init_session(self, derive_cardano: bool | None = None) -> None:
if self.id == b"":

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import logging
import typing as t
from ... import messages
from ...mapping import ProtobufMapping
@ -24,3 +25,9 @@ class Channel:
def update_features(self) -> None:
raise NotImplementedError
def read(self, timeout: float | None = None) -> t.Any:
raise NotImplementedError
def write(self, msg: t.Any) -> None:
raise NotImplementedError