1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-04-20 17:19:01 +00:00

chore(python): fix filters and expected vs actual responses

This commit is contained in:
M1nd3r 2025-04-16 16:52:58 +02:00
parent 7addc2de39
commit 22580b4890
3 changed files with 10 additions and 14 deletions

View File

@ -63,7 +63,10 @@ class TrezorClient:
_last_active_session: SessionV1 | None = None
_session_id_counter: int = 0
_default_pairing_method: int = messages.ThpPairingMethod.CodeEntry
_default_pairing_method: messages.ThpPairingMethod = (
messages.ThpPairingMethod.CodeEntry
)
def __init__(
self,
transport: Transport,
@ -101,7 +104,9 @@ class TrezorClient:
else:
raise Exception("Unknown protocol version")
def do_pairing(self, pairing_method: int | None = None) -> None:
def do_pairing(
self, pairing_method: messages.ThpPairingMethod | None = None
) -> None:
from .transport.session import SessionV2
assert self.protocol_version == ProtocolVersion.V2

View File

@ -1003,6 +1003,7 @@ class SessionDebugWrapper(Session):
msg = self._session._read()
if isinstance(self.client, TrezorClientDebugLink):
msg = self.client._filter_message(msg)
self.client.notify_read(msg)
return msg
def resume(self) -> None:
@ -1098,7 +1099,6 @@ class TrezorClientDebugLink(TrezorClient):
if self.protocol_version is ProtocolVersion.V2:
assert isinstance(self.protocol, ProtocolV2Channel)
self.do_pairing(pairing_method=messages.ThpPairingMethod.SkipPairing)
# self.protocol = self.protocol.get_channel()
self.debug.model = self.model
self.debug.version = self.version
@ -1418,15 +1418,6 @@ class TrezorClientDebugLink(TrezorClient):
next(input_flow) # start the generator
def _write(self, msg: t.Any) -> None:
super()._write(self._filter_message(msg))
def _read(self) -> t.Any:
resp = self._filter_message(super()._read())
if self.actual_responses is not None:
self.actual_responses.append(resp)
return resp
def notify_read(self, msg: protobuf.MessageType) -> None:
try:
if self.actual_responses is not None:

View File

@ -189,7 +189,7 @@ class SessionV1(Session):
self._activate_self()
if t.TYPE_CHECKING:
assert isinstance(self.client.protocol, ProtocolV1Channel)
self.client._write(msg)
self.client.protocol.write(msg)
def _activate_self(self) -> None:
if self.client._last_active_session is not self:
@ -200,7 +200,7 @@ class SessionV1(Session):
assert self.client._last_active_session is self
if t.TYPE_CHECKING:
assert isinstance(self.client.protocol, ProtocolV1Channel)
return self.client._read()
return self.client.protocol.read()
def init_session(self, derive_cardano: bool | None = None) -> None:
if self.id == b"":