1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 22:38:08 +00:00

refactor(python,tests): add is_invalidated flag to client, do not set the emulator.client outside from emulator

[no changelog]
This commit is contained in:
M1nd3r 2024-12-04 16:19:22 +01:00
parent 0335a81c44
commit 869e8107ce
7 changed files with 15 additions and 12 deletions

View File

@ -283,7 +283,6 @@ def cli(
assert emulator.client is not None assert emulator.client is not None
trezorlib.device.wipe(emulator.client.get_management_session()) trezorlib.device.wipe(emulator.client.get_management_session())
emulator.client = emulator.client.get_new_client()
trezorlib.debuglink.load_device( trezorlib.debuglink.load_device(
emulator.client.get_management_session(), emulator.client.get_management_session(),

View File

@ -93,17 +93,10 @@ class Emulator:
""" """
if self._client is None: if self._client is None:
raise RuntimeError raise RuntimeError
if self._client.is_invalidated:
self._client = self._client.get_new_client()
return self._client return self._client
@client.setter
def client(self, new_client: TrezorClientDebugLink) -> None:
"""Setter for the client property to update _client."""
if not isinstance(new_client, TrezorClientDebugLink):
raise TypeError(
f"Expected a TrezorClientDebugLink, got {type(new_client).__name__}."
)
self._client = new_client
def make_args(self) -> List[str]: def make_args(self) -> List[str]:
return [] return []

View File

@ -72,6 +72,7 @@ class TrezorClient:
protobuf_mapping: ProtobufMapping | None = None, protobuf_mapping: ProtobufMapping | None = None,
protocol: ProtocolAndChannel | None = None, protocol: ProtocolAndChannel | None = None,
) -> None: ) -> None:
self._is_invalidated: bool = False
self.transport = transport self.transport = transport
if protobuf_mapping is None: if protobuf_mapping is None:
@ -181,6 +182,9 @@ class TrezorClient:
assert self._management_session is not None assert self._management_session is not None
return self._management_session return self._management_session
def invalidate(self) -> None:
self._is_invalidated = True
@property @property
def features(self) -> messages.Features: def features(self) -> messages.Features:
if self._features is None: if self._features is None:
@ -214,6 +218,10 @@ class TrezorClient:
) )
return ver return ver
@property
def is_invalidated(self) -> bool:
return self._is_invalidated
def refresh_features(self) -> None: def refresh_features(self) -> None:
self.protocol.update_features() self.protocol.update_features()
self._features = self.protocol.get_features() self._features = self.protocol.get_features()

View File

@ -139,6 +139,7 @@ def sd_protect(
def wipe(session: "Session") -> "MessageType": def wipe(session: "Session") -> "MessageType":
ret = session.call(messages.WipeDevice()) ret = session.call(messages.WipeDevice())
session.invalidate()
# if not session.features.bootloader_mode: # if not session.features.bootloader_mode:
# session.refresh_features() # session.refresh_features()
return ret return ret

View File

@ -77,6 +77,9 @@ class Session:
) )
return resp.message or "" return resp.message or ""
def invalidate(self) -> None:
self.client.invalidate()
@property @property
def features(self) -> messages.Features: def features(self) -> messages.Features:
return self.client.features return self.client.features

View File

@ -21,7 +21,6 @@ def test_safety_checks_level_after_reboot(
core_emulator: Emulator, set_level: SafetyCheckLevel, after_level: SafetyCheckLevel core_emulator: Emulator, set_level: SafetyCheckLevel, after_level: SafetyCheckLevel
): ):
device.wipe(core_emulator.client.get_management_session()) device.wipe(core_emulator.client.get_management_session())
core_emulator.client = core_emulator.client.get_new_client()
debuglink.load_device( debuglink.load_device(
core_emulator.client.get_management_session(), core_emulator.client.get_management_session(),
mnemonic=MNEMONIC12, mnemonic=MNEMONIC12,

View File

@ -53,7 +53,7 @@ def emulator(gen: str, tag: str) -> Iterator[Emulator]:
pin_protection=False, pin_protection=False,
skip_backup=True, skip_backup=True,
) )
emu.client = emu.client.get_new_client() emu.client.invalidate()
resp = emu.client.get_management_session().call( resp = emu.client.get_management_session().call(
ApplySettingsCompat(use_passphrase=True, passphrase_source=SOURCE_HOST) ApplySettingsCompat(use_passphrase=True, passphrase_source=SOURCE_HOST)
) )