From a8aa7fe632917fcfc45ac221476a434f9ae8f512 Mon Sep 17 00:00:00 2001
From: M1nd3r <petrsedlacek.km@seznam.cz>
Date: Tue, 4 Feb 2025 15:19:56 +0100
Subject: [PATCH] feat(python): implement session based trezorlib

Co-authored-by: mmilata <martin@martinmilata.cz>
---
 python/.changelog.d/4577.changed          |   1 +
 python/src/trezorlib/authentication.py    |   6 +-
 python/src/trezorlib/benchmark.py         |  10 +-
 python/src/trezorlib/binance.py           |  18 +-
 python/src/trezorlib/btc.py               |  51 +-
 python/src/trezorlib/cardano.py           |  36 +-
 python/src/trezorlib/client.py            | 612 ++++++---------------
 python/src/trezorlib/debuglink.py         | 633 +++++++++++++++++-----
 python/src/trezorlib/device.py            | 173 +++---
 python/src/trezorlib/eos.py               |  15 +-
 python/src/trezorlib/ethereum.py          |  48 +-
 python/src/trezorlib/fido.py              |  22 +-
 python/src/trezorlib/firmware/__init__.py |  18 +-
 python/src/trezorlib/mapping.py           |   1 +
 python/src/trezorlib/misc.py              |  26 +-
 python/src/trezorlib/monero.py            |  10 +-
 python/src/trezorlib/nem.py               |  10 +-
 python/src/trezorlib/ripple.py            |  10 +-
 python/src/trezorlib/solana.py            |  14 +-
 python/src/trezorlib/stellar.py           |  12 +-
 python/src/trezorlib/tezos.py             |  14 +-
 python/src/trezorlib/tools.py             |  19 +-
 python/src/trezorlib/transport/session.py | 152 ++++++
 23 files changed, 1048 insertions(+), 863 deletions(-)
 create mode 100644 python/.changelog.d/4577.changed
 create mode 100644 python/src/trezorlib/transport/session.py

diff --git a/python/.changelog.d/4577.changed b/python/.changelog.d/4577.changed
new file mode 100644
index 0000000000..971618ec04
--- /dev/null
+++ b/python/.changelog.d/4577.changed
@@ -0,0 +1 @@
+Changed trezorlib to session-based. Changes also affect trezorctl, python tools, and tests.
diff --git a/python/src/trezorlib/authentication.py b/python/src/trezorlib/authentication.py
index 08c32c3735..28b8e16056 100644
--- a/python/src/trezorlib/authentication.py
+++ b/python/src/trezorlib/authentication.py
@@ -10,7 +10,7 @@ from cryptography.hazmat.primitives import hashes, serialization
 from cryptography.hazmat.primitives.asymmetric import ec, utils
 
 from . import device
-from .client import TrezorClient
+from .transport.session import Session
 
 LOG = logging.getLogger(__name__)
 
@@ -349,7 +349,7 @@ def verify_authentication_response(
 
 
 def authenticate_device(
-    client: TrezorClient,
+    session: Session,
     challenge: bytes | None = None,
     *,
     whitelist: t.Collection[bytes] | None = None,
@@ -359,7 +359,7 @@ def authenticate_device(
     if challenge is None:
         challenge = secrets.token_bytes(16)
 
-    resp = device.authenticate(client, challenge)
+    resp = device.authenticate(session, challenge)
 
     return verify_authentication_response(
         challenge,
diff --git a/python/src/trezorlib/benchmark.py b/python/src/trezorlib/benchmark.py
index 6587e2a3ab..64218b7aad 100644
--- a/python/src/trezorlib/benchmark.py
+++ b/python/src/trezorlib/benchmark.py
@@ -19,16 +19,16 @@ from typing import TYPE_CHECKING
 from . import messages
 
 if TYPE_CHECKING:
-    from .client import TrezorClient
+    from .transport.session import Session
 
 
 def list_names(
-    client: "TrezorClient",
+    session: "Session",
 ) -> messages.BenchmarkNames:
-    return client.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames)
+    return session.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames)
 
 
-def run(client: "TrezorClient", name: str) -> messages.BenchmarkResult:
-    return client.call(
+def run(session: "Session", name: str) -> messages.BenchmarkResult:
+    return session.call(
         messages.BenchmarkRun(name=name), expect=messages.BenchmarkResult
     )
diff --git a/python/src/trezorlib/binance.py b/python/src/trezorlib/binance.py
index 938092a2df..6b35db0446 100644
--- a/python/src/trezorlib/binance.py
+++ b/python/src/trezorlib/binance.py
@@ -18,20 +18,19 @@ from typing import TYPE_CHECKING
 
 from . import messages
 from .protobuf import dict_to_proto
-from .tools import session
 
 if TYPE_CHECKING:
-    from .client import TrezorClient
     from .tools import Address
+    from .transport.session import Session
 
 
 def get_address(
-    client: "TrezorClient",
+    session: "Session",
     address_n: "Address",
     show_display: bool = False,
     chunkify: bool = False,
 ) -> str:
-    return client.call(
+    return session.call(
         messages.BinanceGetAddress(
             address_n=address_n, show_display=show_display, chunkify=chunkify
         ),
@@ -40,17 +39,16 @@ def get_address(
 
 
 def get_public_key(
-    client: "TrezorClient", address_n: "Address", show_display: bool = False
+    session: "Session", address_n: "Address", show_display: bool = False
 ) -> bytes:
-    return client.call(
+    return session.call(
         messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display),
         expect=messages.BinancePublicKey,
     ).public_key
 
 
-@session
 def sign_tx(
-    client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False
+    session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False
 ) -> messages.BinanceSignedTx:
     msg = tx_json["msgs"][0]
     tx_msg = tx_json.copy()
@@ -59,7 +57,7 @@ def sign_tx(
     tx_msg["chunkify"] = chunkify
     envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
 
-    client.call(envelope, expect=messages.BinanceTxRequest)
+    session.call(envelope, expect=messages.BinanceTxRequest)
 
     if "refid" in msg:
         msg = dict_to_proto(messages.BinanceCancelMsg, msg)
@@ -70,4 +68,4 @@ def sign_tx(
     else:
         raise ValueError("can not determine msg type")
 
-    return client.call(msg, expect=messages.BinanceSignedTx)
+    return session.call(msg, expect=messages.BinanceSignedTx)
diff --git a/python/src/trezorlib/btc.py b/python/src/trezorlib/btc.py
index 078f486d9e..bd2ded07c4 100644
--- a/python/src/trezorlib/btc.py
+++ b/python/src/trezorlib/btc.py
@@ -25,11 +25,11 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple
 from typing_extensions import Protocol, TypedDict
 
 from . import exceptions, messages
-from .tools import _return_success, prepare_message_bytes, session
+from .tools import _return_success, prepare_message_bytes
 
 if TYPE_CHECKING:
-    from .client import TrezorClient
     from .tools import Address
+    from .transport.session import Session
 
     class ScriptSig(TypedDict):
         asm: str
@@ -105,7 +105,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType:
 
 
 def get_public_node(
-    client: "TrezorClient",
+    session: "Session",
     n: "Address",
     ecdsa_curve_name: Optional[str] = None,
     show_display: bool = False,
@@ -116,12 +116,12 @@ def get_public_node(
     unlock_path_mac: Optional[bytes] = None,
 ) -> messages.PublicKey:
     if unlock_path:
-        client.call(
+        session.call(
             messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
             expect=messages.UnlockedPathRequest,
         )
 
-    return client.call(
+    return session.call(
         messages.GetPublicKey(
             address_n=n,
             ecdsa_curve_name=ecdsa_curve_name,
@@ -139,7 +139,7 @@ def get_address(*args: Any, **kwargs: Any) -> str:
 
 
 def get_authenticated_address(
-    client: "TrezorClient",
+    session: "Session",
     coin_name: str,
     n: "Address",
     show_display: bool = False,
@@ -151,12 +151,12 @@ def get_authenticated_address(
     chunkify: bool = False,
 ) -> messages.Address:
     if unlock_path:
-        client.call(
+        session.call(
             messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
             expect=messages.UnlockedPathRequest,
         )
 
-    return client.call(
+    return session.call(
         messages.GetAddress(
             address_n=n,
             coin_name=coin_name,
@@ -171,13 +171,13 @@ def get_authenticated_address(
 
 
 def get_ownership_id(
-    client: "TrezorClient",
+    session: "Session",
     coin_name: str,
     n: "Address",
     multisig: Optional[messages.MultisigRedeemScriptType] = None,
     script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
 ) -> bytes:
-    return client.call(
+    return session.call(
         messages.GetOwnershipId(
             address_n=n,
             coin_name=coin_name,
@@ -189,7 +189,7 @@ def get_ownership_id(
 
 
 def get_ownership_proof(
-    client: "TrezorClient",
+    session: "Session",
     coin_name: str,
     n: "Address",
     multisig: Optional[messages.MultisigRedeemScriptType] = None,
@@ -200,9 +200,9 @@ def get_ownership_proof(
     preauthorized: bool = False,
 ) -> Tuple[bytes, bytes]:
     if preauthorized:
-        client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
+        session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
 
-    res = client.call(
+    res = session.call(
         messages.GetOwnershipProof(
             address_n=n,
             coin_name=coin_name,
@@ -219,7 +219,7 @@ def get_ownership_proof(
 
 
 def sign_message(
-    client: "TrezorClient",
+    session: "Session",
     coin_name: str,
     n: "Address",
     message: AnyStr,
@@ -227,7 +227,7 @@ def sign_message(
     no_script_type: bool = False,
     chunkify: bool = False,
 ) -> messages.MessageSignature:
-    return client.call(
+    return session.call(
         messages.SignMessage(
             coin_name=coin_name,
             address_n=n,
@@ -241,7 +241,7 @@ def sign_message(
 
 
 def verify_message(
-    client: "TrezorClient",
+    session: "Session",
     coin_name: str,
     address: str,
     signature: bytes,
@@ -249,7 +249,7 @@ def verify_message(
     chunkify: bool = False,
 ) -> bool:
     try:
-        client.call(
+        session.call(
             messages.VerifyMessage(
                 address=address,
                 signature=signature,
@@ -264,9 +264,8 @@ def verify_message(
         return False
 
 
-@session
 def sign_tx(
-    client: "TrezorClient",
+    session: "Session",
     coin_name: str,
     inputs: Sequence[messages.TxInputType],
     outputs: Sequence[messages.TxOutputType],
@@ -314,14 +313,14 @@ def sign_tx(
                 setattr(signtx, name, value)
 
     if unlock_path:
-        client.call(
+        session.call(
             messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
             expect=messages.UnlockedPathRequest,
         )
     elif preauthorized:
-        client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
+        session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
 
-    res = client.call(signtx, expect=messages.TxRequest)
+    res = session.call(signtx, expect=messages.TxRequest)
 
     # Prepare structure for signatures
     signatures: List[Optional[bytes]] = [None] * len(inputs)
@@ -380,7 +379,7 @@ def sign_tx(
         if res.request_type == R.TXPAYMENTREQ:
             assert res.details.request_index is not None
             msg = payment_reqs[res.details.request_index]
-            res = client.call(msg, expect=messages.TxRequest)
+            res = session.call(msg, expect=messages.TxRequest)
         else:
             msg = messages.TransactionType()
             if res.request_type == R.TXMETA:
@@ -410,7 +409,7 @@ def sign_tx(
                     f"Unknown request type - {res.request_type}."
                 )
 
-            res = client.call(messages.TxAck(tx=msg), expect=messages.TxRequest)
+            res = session.call(messages.TxAck(tx=msg), expect=messages.TxRequest)
 
     for i, sig in zip(inputs, signatures):
         if i.script_type != messages.InputScriptType.EXTERNAL and sig is None:
@@ -420,7 +419,7 @@ def sign_tx(
 
 
 def authorize_coinjoin(
-    client: "TrezorClient",
+    session: "Session",
     coordinator: str,
     max_rounds: int,
     max_coordinator_fee_rate: int,
@@ -429,7 +428,7 @@ def authorize_coinjoin(
     coin_name: str,
     script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
 ) -> str | None:
-    resp = client.call(
+    resp = session.call(
         messages.AuthorizeCoinJoin(
             coordinator=coordinator,
             max_rounds=max_rounds,
diff --git a/python/src/trezorlib/cardano.py b/python/src/trezorlib/cardano.py
index 4cbc635f1f..a945cc9b10 100644
--- a/python/src/trezorlib/cardano.py
+++ b/python/src/trezorlib/cardano.py
@@ -35,7 +35,7 @@ from . import messages as m
 from . import tools
 
 if TYPE_CHECKING:
-    from .client import TrezorClient
+    from .transport.session import Session
 
 PROTOCOL_MAGICS = {
     "mainnet": 764824073,
@@ -818,7 +818,7 @@ def _get_collateral_inputs_items(
 
 
 def get_address(
-    client: "TrezorClient",
+    session: "Session",
     address_parameters: m.CardanoAddressParametersType,
     protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
     network_id: int = NETWORK_IDS["mainnet"],
@@ -826,7 +826,7 @@ def get_address(
     derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
     chunkify: bool = False,
 ) -> str:
-    return client.call(
+    return session.call(
         m.CardanoGetAddress(
             address_parameters=address_parameters,
             protocol_magic=protocol_magic,
@@ -840,12 +840,12 @@ def get_address(
 
 
 def get_public_key(
-    client: "TrezorClient",
+    session: "Session",
     address_n: List[int],
     derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
     show_display: bool = False,
 ) -> m.CardanoPublicKey:
-    return client.call(
+    return session.call(
         m.CardanoGetPublicKey(
             address_n=address_n,
             derivation_type=derivation_type,
@@ -856,12 +856,12 @@ def get_public_key(
 
 
 def get_native_script_hash(
-    client: "TrezorClient",
+    session: "Session",
     native_script: m.CardanoNativeScript,
     display_format: m.CardanoNativeScriptHashDisplayFormat = m.CardanoNativeScriptHashDisplayFormat.HIDE,
     derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
 ) -> m.CardanoNativeScriptHash:
-    return client.call(
+    return session.call(
         m.CardanoGetNativeScriptHash(
             script=native_script,
             display_format=display_format,
@@ -872,7 +872,7 @@ def get_native_script_hash(
 
 
 def sign_tx(
-    client: "TrezorClient",
+    session: "Session",
     signing_mode: m.CardanoTxSigningMode,
     inputs: List[InputWithPath],
     outputs: List[OutputWithData],
@@ -907,7 +907,7 @@ def sign_tx(
         signing_mode,
     )
 
-    response = client.call(
+    response = session.call(
         m.CardanoSignTxInit(
             signing_mode=signing_mode,
             inputs_count=len(inputs),
@@ -942,12 +942,12 @@ def sign_tx(
         _get_certificates_items(certificates),
         withdrawals,
     ):
-        response = client.call(tx_item, expect=m.CardanoTxItemAck)
+        response = session.call(tx_item, expect=m.CardanoTxItemAck)
 
     sign_tx_response: Dict[str, Any] = {}
 
     if auxiliary_data is not None:
-        auxiliary_data_supplement = client.call(
+        auxiliary_data_supplement = session.call(
             auxiliary_data, expect=m.CardanoTxAuxiliaryDataSupplement
         )
         if (
@@ -958,25 +958,25 @@ def sign_tx(
                 auxiliary_data_supplement.__dict__
             )
 
-        response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck)
+        response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck)
 
     for tx_item in chain(
         _get_mint_items(mint),
         _get_collateral_inputs_items(collateral_inputs),
         required_signers,
     ):
-        response = client.call(tx_item, expect=m.CardanoTxItemAck)
+        response = session.call(tx_item, expect=m.CardanoTxItemAck)
 
     if collateral_return is not None:
         for tx_item in _get_output_items(collateral_return):
-            response = client.call(tx_item, expect=m.CardanoTxItemAck)
+            response = session.call(tx_item, expect=m.CardanoTxItemAck)
 
     for reference_input in reference_inputs:
-        response = client.call(reference_input, expect=m.CardanoTxItemAck)
+        response = session.call(reference_input, expect=m.CardanoTxItemAck)
 
     sign_tx_response["witnesses"] = []
     for witness_request in witness_requests:
-        response = client.call(witness_request, expect=m.CardanoTxWitnessResponse)
+        response = session.call(witness_request, expect=m.CardanoTxWitnessResponse)
         sign_tx_response["witnesses"].append(
             {
                 "type": response.type,
@@ -986,9 +986,9 @@ def sign_tx(
             }
         )
 
-    response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash)
+    response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash)
     sign_tx_response["tx_hash"] = response.tx_hash
 
-    response = client.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished)
+    response = session.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished)
 
     return sign_tx_response
diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py
index 529992dfb0..2d5cb2398e 100644
--- a/python/src/trezorlib/client.py
+++ b/python/src/trezorlib/client.py
@@ -13,28 +13,22 @@
 #
 # You should have received a copy of the License along with this library.
 # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
-
 from __future__ import annotations
 
 import logging
 import os
-import warnings
-from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
-
-from mnemonic import Mnemonic
+import typing as t
+from enum import IntEnum
 
 from . import exceptions, mapping, messages, models
-from .log import DUMP_BYTES
-from .messages import Capability
-from .protobuf import MessageType
-from .tools import parse_path, session
+from .mapping import ProtobufMapping
+from .tools import parse_path
+from .transport import Transport, get_transport
+from .transport.thp.protocol_and_channel import Channel
+from .transport.thp.protocol_v1 import ProtocolV1Channel
 
-if TYPE_CHECKING:
-    from .transport import Transport
-    from .ui import TrezorClientUI
-
-UI = TypeVar("UI", bound="TrezorClientUI")
-MT = TypeVar("MT", bound=MessageType)
+if t.TYPE_CHECKING:
+    from .transport.session import Session
 
 LOG = logging.getLogger(__name__)
 
@@ -51,8 +45,175 @@ Or visit https://suite.trezor.io/
 """.strip()
 
 
+LOG = logging.getLogger(__name__)
+
+
+class ProtocolVersion(IntEnum):
+    UNKNOWN = 0x00
+    PROTOCOL_V1 = 0x01  # Codec
+    PROTOCOL_V2 = 0x02  # THP
+
+
+class TrezorClient:
+    button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
+    passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
+    pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
+
+    _seedless_session: Session | None = None
+    _features: messages.Features | None = None
+    _protocol_version: int
+    _setup_pin: str | None = None  # Should by used only by conftest
+
+    def __init__(
+        self,
+        transport: Transport,
+        protobuf_mapping: ProtobufMapping | None = None,
+        protocol: Channel | None = None,
+    ) -> None:
+        self._is_invalidated: bool = False
+        self.transport = transport
+
+        if protobuf_mapping is None:
+            self.mapping = mapping.DEFAULT_MAPPING
+        else:
+            self.mapping = protobuf_mapping
+        if protocol is None:
+            self.protocol = self._get_protocol()
+        else:
+            self.protocol = protocol
+        self.protocol.mapping = self.mapping
+
+        if isinstance(self.protocol, ProtocolV1Channel):
+            self._protocol_version = ProtocolVersion.PROTOCOL_V1
+        else:
+            self._protocol_version = ProtocolVersion.UNKNOWN
+
+    @classmethod
+    def resume(
+        cls,
+        transport: Transport,
+        protobuf_mapping: ProtobufMapping | None = None,
+    ) -> TrezorClient:
+        if protobuf_mapping is None:
+            protobuf_mapping = mapping.DEFAULT_MAPPING
+        protocol = ProtocolV1Channel(transport, protobuf_mapping)
+        return TrezorClient(transport, protobuf_mapping, protocol)
+
+    def get_session(
+        self,
+        passphrase: str | object | None = None,
+        derive_cardano: bool = False,
+        session_id: int = 0,
+    ) -> Session:
+        """
+        Returns initialized session (with derived seed).
+
+        Will fail if the device is not initialized
+        """
+        from .transport.session import SessionV1
+
+        if isinstance(self.protocol, ProtocolV1Channel):
+            session = SessionV1.new(
+                self,
+                derive_cardano=derive_cardano,
+                session_id=session_id,
+            )
+            if should_derive:
+                if isinstance(passphrase, str):
+                    temporary = self.passphrase_callback
+                    self.passphrase_callback = get_callback_passphrase_v1(
+                        passphrase=passphrase
+                    )
+                    derive_seed(session)
+                    self.passphrase_callback = temporary
+                elif passphrase is PASSPHRASE_ON_DEVICE:
+                    derive_seed(session)
+
+            return session
+        raise NotImplementedError
+
+    def resume_session(self, session: Session):
+        """
+        Note: this function potentially modifies the input session.
+        """
+        from .transport.session import SessionV1
+
+        if isinstance(session, SessionV1):
+            session.init_session()
+            return session
+        else:
+            raise NotImplementedError
+
+    def get_seedless_session(self, new_session: bool = False) -> Session:
+        from .transport.session import SessionV1
+
+        if not new_session and self._seedless_session is not None:
+            return self._seedless_session
+        if isinstance(self.protocol, ProtocolV1Channel):
+            self._seedless_session = SessionV1.new(
+                client=self,
+                passphrase="",
+                derive_cardano=False,
+            )
+        assert self._seedless_session is not None
+        return self._seedless_session
+
+    def invalidate(self) -> None:
+        self._is_invalidated = True
+
+    @property
+    def features(self) -> messages.Features:
+        if self._features is None:
+            self._features = self.protocol.get_features()
+        assert self._features is not None
+        return self._features
+
+    @property
+    def protocol_version(self) -> int:
+        return self._protocol_version
+
+    @property
+    def model(self) -> models.TrezorModel:
+        model = models.detect(self.features)
+        if self.features.vendor not in model.vendors:
+            raise exceptions.TrezorException(
+                f"Unrecognized vendor: {self.features.vendor}"
+            )
+        return model
+
+    @property
+    def version(self) -> tuple[int, int, int]:
+        f = self.features
+        ver = (
+            f.major_version,
+            f.minor_version,
+            f.patch_version,
+        )
+        return ver
+
+    @property
+    def is_invalidated(self) -> bool:
+        return self._is_invalidated
+
+    def refresh_features(self) -> None:
+        self.protocol.update_features()
+        self._features = self.protocol.get_features()
+
+    def _get_protocol(self) -> Channel:
+        self.transport.open()
+
+        protocol = ProtocolV1Channel(self.transport, mapping.DEFAULT_MAPPING)
+
+        protocol.write(messages.Initialize())
+
+        _ = protocol.read()
+        self.transport.close()
+        return protocol
+
+
 def get_default_client(
-    path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any
+    path: t.Optional[str] = None,
+    **kwargs: t.Any,
 ) -> "TrezorClient":
     """Get a client for a connected Trezor device.
 
@@ -62,427 +223,10 @@ def get_default_client(
     the value of TREZOR_PATH env variable, or finds first connected Trezor.
     If no UI is supplied, instantiates the default CLI UI.
     """
-    from .transport import get_transport
-    from .ui import ClickUI
 
     if path is None:
         path = os.getenv("TREZOR_PATH")
 
     transport = get_transport(path, prefix_search=True)
-    if ui is None:
-        ui = ClickUI()
 
-    return TrezorClient(transport, ui, **kwargs)
-
-
-class TrezorClient(Generic[UI]):
-    """Trezor client, a connection to a Trezor device.
-
-    This class allows you to manage connection state, send and receive protobuf
-    messages, handle user interactions, and perform some generic tasks
-    (send a cancel message, initialize or clear a session, ping the device).
-    """
-
-    model: models.TrezorModel
-    transport: "Transport"
-    session_id: Optional[bytes]
-    ui: UI
-    features: messages.Features
-
-    def __init__(
-        self,
-        transport: "Transport",
-        ui: UI,
-        session_id: Optional[bytes] = None,
-        derive_cardano: Optional[bool] = None,
-        model: Optional[models.TrezorModel] = None,
-        _init_device: bool = True,
-    ) -> None:
-        """Create a TrezorClient instance.
-
-        You have to provide a `transport`, i.e., a raw connection to the device. You can
-        use `trezorlib.transport.get_transport` to find one.
-
-        You have to provide a UI implementation for the three kinds of interaction:
-        - button request (notify the user that their interaction is needed)
-        - PIN request (on T1, ask the user to input numbers for a PIN matrix)
-        - passphrase request (ask the user to enter a passphrase) See `trezorlib.ui` for
-          details.
-
-        You can supply a `session_id` you might have saved in the previous session. If
-        you do, the user might not need to enter their passphrase again.
-
-        You can provide Trezor model information. If not provided, it is detected from
-        the model name reported at initialization time.
-
-        By default, the instance will open a connection to the Trezor device, send an
-        `Initialize` message, set up the `features` field from the response, and connect
-        to a session. By specifying `_init_device=False`, this step is skipped. Notably,
-        this means that `client.features` is unset. Use `client.init_device()` or
-        `client.refresh_features()` to fix that, otherwise A LOT OF THINGS will break.
-        Only use this if you are _sure_ that you know what you are doing. This feature
-        might be removed at any time.
-        """
-        LOG.info(f"creating client instance for device: {transport.get_path()}")
-        # Here, self.model could be set to None. Unless _init_device is False, it will
-        # get correctly reconfigured as part of the init_device flow.
-        self.model = model  # type: ignore ["None" is incompatible with "TrezorModel"]
-        if self.model:
-            self.mapping = self.model.default_mapping
-        else:
-            self.mapping = mapping.DEFAULT_MAPPING
-        self.transport = transport
-        self.ui = ui
-        self.session_counter = 0
-        self.session_id = session_id
-        if _init_device:
-            self.init_device(session_id=session_id, derive_cardano=derive_cardano)
-
-    def open(self) -> None:
-        if self.session_counter == 0:
-            self.transport.begin_session()
-        self.session_counter += 1
-
-    def close(self) -> None:
-        self.session_counter = max(self.session_counter - 1, 0)
-        if self.session_counter == 0:
-            # TODO call EndSession here?
-            self.transport.end_session()
-
-    def cancel(self) -> None:
-        self._raw_write(messages.Cancel())
-
-    def call_raw(self, msg: MessageType) -> MessageType:
-        __tracebackhide__ = True  # for pytest # pylint: disable=W0612
-        self._raw_write(msg)
-        return self._raw_read()
-
-    def _raw_write(self, msg: MessageType) -> None:
-        __tracebackhide__ = True  # for pytest # pylint: disable=W0612
-        LOG.debug(
-            f"sending message: {msg.__class__.__name__}",
-            extra={"protobuf": msg},
-        )
-        msg_type, msg_bytes = self.mapping.encode(msg)
-        LOG.log(
-            DUMP_BYTES,
-            f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
-        )
-        self.transport.write(msg_type, msg_bytes)
-
-    def _raw_read(self) -> MessageType:
-        __tracebackhide__ = True  # for pytest # pylint: disable=W0612
-        msg_type, msg_bytes = self.transport.read()
-        LOG.log(
-            DUMP_BYTES,
-            f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
-        )
-        msg = self.mapping.decode(msg_type, msg_bytes)
-        LOG.debug(
-            f"received message: {msg.__class__.__name__}",
-            extra={"protobuf": msg},
-        )
-        return msg
-
-    def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType:
-        try:
-            pin = self.ui.get_pin(msg.type)
-        except exceptions.Cancelled:
-            self.call_raw(messages.Cancel())
-            raise
-
-        if any(d not in "123456789" for d in pin) or not (
-            1 <= len(pin) <= MAX_PIN_LENGTH
-        ):
-            self.call_raw(messages.Cancel())
-            raise ValueError("Invalid PIN provided")
-
-        resp = self.call_raw(messages.PinMatrixAck(pin=pin))
-        if isinstance(resp, messages.Failure) and resp.code in (
-            messages.FailureType.PinInvalid,
-            messages.FailureType.PinCancelled,
-            messages.FailureType.PinExpected,
-        ):
-            raise exceptions.PinException(resp.code, resp.message)
-        else:
-            return resp
-
-    def _callback_passphrase(self, msg: messages.PassphraseRequest) -> MessageType:
-        available_on_device = Capability.PassphraseEntry in self.features.capabilities
-
-        def send_passphrase(
-            passphrase: Optional[str] = None, on_device: Optional[bool] = None
-        ) -> MessageType:
-            msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
-            resp = self.call_raw(msg)
-            if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
-                self.session_id = resp.state
-                resp = self.call_raw(messages.Deprecated_PassphraseStateAck())
-            return resp
-
-        # short-circuit old style entry
-        if msg._on_device is True:
-            return send_passphrase(None, None)
-
-        try:
-            passphrase = self.ui.get_passphrase(available_on_device=available_on_device)
-        except exceptions.Cancelled:
-            self.call_raw(messages.Cancel())
-            raise
-
-        if passphrase is PASSPHRASE_ON_DEVICE:
-            if not available_on_device:
-                self.call_raw(messages.Cancel())
-                raise RuntimeError("Device is not capable of entering passphrase")
-            else:
-                return send_passphrase(on_device=True)
-
-        # else process host-entered passphrase
-        if not isinstance(passphrase, str):
-            raise RuntimeError("Passphrase must be a str")
-        passphrase = Mnemonic.normalize_string(passphrase)
-        if len(passphrase) > MAX_PASSPHRASE_LENGTH:
-            self.call_raw(messages.Cancel())
-            raise ValueError("Passphrase too long")
-
-        return send_passphrase(passphrase, on_device=False)
-
-    def _callback_button(self, msg: messages.ButtonRequest) -> MessageType:
-        __tracebackhide__ = True  # for pytest # pylint: disable=W0612
-        # do this raw - send ButtonAck first, notify UI later
-        self._raw_write(messages.ButtonAck())
-        self.ui.button_request(msg)
-        return self._raw_read()
-
-    @session
-    def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT:
-        self.check_firmware_version()
-        resp = self.call_raw(msg)
-        while True:
-            if isinstance(resp, messages.PinMatrixRequest):
-                resp = self._callback_pin(resp)
-            elif isinstance(resp, messages.PassphraseRequest):
-                resp = self._callback_passphrase(resp)
-            elif isinstance(resp, messages.ButtonRequest):
-                resp = self._callback_button(resp)
-            elif isinstance(resp, messages.Failure):
-                if resp.code == messages.FailureType.ActionCancelled:
-                    raise exceptions.Cancelled
-                raise exceptions.TrezorFailure(resp)
-            elif not isinstance(resp, expect):
-                raise exceptions.UnexpectedMessageError(expect, resp)
-            else:
-                return resp
-
-    def _refresh_features(self, features: messages.Features) -> None:
-        """Update internal fields based on passed-in Features message."""
-
-        if not self.model:
-            self.model = models.detect(features)
-
-        if features.vendor not in self.model.vendors:
-            raise exceptions.TrezorException(f"Unrecognized vendor: {features.vendor}")
-
-        self.features = features
-        self.version = (
-            self.features.major_version,
-            self.features.minor_version,
-            self.features.patch_version,
-        )
-        self.check_firmware_version(warn_only=True)
-        if self.features.session_id is not None:
-            self.session_id = self.features.session_id
-            self.features.session_id = None
-
-    @session
-    def refresh_features(self) -> messages.Features:
-        """Reload features from the device.
-
-        Should be called after changing settings or performing operations that affect
-        device state.
-        """
-        resp = self.call_raw(messages.GetFeatures())
-        if not isinstance(resp, messages.Features):
-            raise exceptions.TrezorException("Unexpected response to GetFeatures")
-        self._refresh_features(resp)
-        return resp
-
-    @session
-    def init_device(
-        self,
-        *,
-        session_id: Optional[bytes] = None,
-        new_session: bool = False,
-        derive_cardano: Optional[bool] = None,
-    ) -> Optional[bytes]:
-        """Initialize the device and return a session ID.
-
-        You can optionally specify a session ID. If the session still exists on the
-        device, the same session ID will be returned and the session is resumed.
-        Otherwise a different session ID is returned.
-
-        Specify `new_session=True` to open a fresh session. Since firmware version
-        1.9.0/2.3.0, the previous session will remain cached on the device, and can be
-        resumed by calling `init_device` again with the appropriate session ID.
-
-        If neither `new_session` nor `session_id` is specified, the current session ID
-        will be reused. If no session ID was cached, a new session ID will be allocated
-        and returned.
-
-        # Version notes:
-
-        Trezor One older than 1.9.0 does not have session management. Optional arguments
-        have no effect and the function returns None
-
-        Trezor T older than 2.3.0 does not have session cache. Requesting a new session
-        will overwrite the old one. In addition, this function will always return None.
-        A valid session_id can be obtained from the `session_id` attribute, but only
-        after a passphrase-protected call is performed. You can use the following code:
-
-        >>> client.init_device()
-        >>> client.ensure_unlocked()
-        >>> valid_session_id = client.session_id
-        """
-        if new_session:
-            self.session_id = None
-        elif session_id is not None:
-            self.session_id = session_id
-
-        resp = self.call_raw(
-            messages.Initialize(
-                session_id=self.session_id,
-                derive_cardano=derive_cardano,
-            )
-        )
-        if isinstance(resp, messages.Failure):
-            # can happen if `derive_cardano` does not match the current session
-            raise exceptions.TrezorFailure(resp)
-        if not isinstance(resp, messages.Features):
-            raise exceptions.TrezorException("Unexpected response to Initialize")
-
-        if self.session_id is not None and resp.session_id == self.session_id:
-            LOG.info("Successfully resumed session")
-        elif session_id is not None:
-            LOG.info("Failed to resume session")
-
-        # TT < 2.3.0 compatibility:
-        # _refresh_features will clear out the session_id field. We want this function
-        # to return its value, so that callers can rely on it being either a valid
-        # session_id, or None if we can't do that.
-        # Older TT FW does not report session_id in Features and self.session_id might
-        # be invalid because TT will not allocate a session_id until a passphrase
-        # exchange happens.
-        reported_session_id = resp.session_id
-        self._refresh_features(resp)
-        return reported_session_id
-
-    def is_outdated(self) -> bool:
-        if self.features.bootloader_mode:
-            return False
-        return self.version < self.model.minimum_version
-
-    def check_firmware_version(self, warn_only: bool = False) -> None:
-        if self.is_outdated():
-            if warn_only:
-                warnings.warn("Firmware is out of date", stacklevel=2)
-            else:
-                raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
-
-    def ping(self, msg: str, button_protection: bool = False) -> str:
-        # We would like ping to work on any valid TrezorClient instance, but
-        # due to the protection modes, we need to go through self.call, and that will
-        # raise an exception if the firmware is too old.
-        # So we short-circuit the simplest variant of ping with call_raw.
-        if not button_protection:
-            # XXX this should be: `with self:`
-            try:
-                self.open()
-                resp = self.call_raw(messages.Ping(message=msg))
-                if isinstance(resp, messages.ButtonRequest):
-                    # device is PIN-locked.
-                    # respond and hope for the best
-                    resp = self._callback_button(resp)
-                resp = messages.Success.ensure_isinstance(resp)
-                assert resp.message is not None
-                return resp.message
-            finally:
-                self.close()
-
-        resp = self.call(
-            messages.Ping(message=msg, button_protection=button_protection),
-            expect=messages.Success,
-        )
-        assert resp.message is not None
-        return resp.message
-
-    def get_device_id(self) -> Optional[str]:
-        return self.features.device_id
-
-    @session
-    def lock(self, *, _refresh_features: bool = True) -> None:
-        """Lock the device.
-
-        If the device does not have a PIN configured, this will do nothing.
-        Otherwise, a lock screen will be shown and the device will prompt for PIN
-        before further actions.
-
-        This call does _not_ invalidate passphrase cache. If passphrase is in use,
-        the device will not prompt for it after unlocking.
-
-        To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate
-        passphrase cache, use `clear_session()`.
-        """
-        # Private argument _refresh_features can be used internally to avoid
-        # refreshing in cases where we will refresh soon anyway. This is used
-        # in TrezorClient.clear_session()
-        self.call(messages.LockDevice())
-        if _refresh_features:
-            self.refresh_features()
-
-    @session
-    def ensure_unlocked(self) -> None:
-        """Ensure the device is unlocked and a passphrase is cached.
-
-        If the device is locked, this will prompt for PIN. If passphrase is enabled
-        and no passphrase is cached for the current session, the device will also
-        prompt for passphrase.
-
-        After calling this method, further actions on the device will not prompt for
-        PIN or passphrase until the device is locked or the session becomes invalid.
-        """
-        from .btc import get_address
-
-        get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
-        self.refresh_features()
-
-    def end_session(self) -> None:
-        """Close the current session and clear cached passphrase.
-
-        The session will become invalid until `init_device()` is called again.
-        If passphrase is enabled, further actions will prompt for it again.
-
-        This is a no-op in bootloader mode, as it does not support session management.
-        """
-        # since: 2.3.4, 1.9.4
-        try:
-            if not self.features.bootloader_mode:
-                self.call(messages.EndSession())
-        except exceptions.TrezorFailure:
-            # A failure most likely means that the FW version does not support
-            # the EndSession call. We ignore the failure and clear the local session_id.
-            # The client-side end result is identical.
-            pass
-        self.session_id = None
-
-    @session
-    def clear_session(self) -> None:
-        """Lock the device and present a fresh session.
-
-        The current session will be invalidated and a new one will be started. If the
-        device has PIN enabled, it will become locked.
-
-        Equivalent to calling `lock()`, `end_session()` and `init_device()`.
-        """
-        self.lock(_refresh_features=False)
-        self.end_session()
-        self.init_device(new_session=True)
+    return TrezorClient(transport, **kwargs)
diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py
index 75cb90bd46..1fa5658bd7 100644
--- a/python/src/trezorlib/debuglink.py
+++ b/python/src/trezorlib/debuglink.py
@@ -21,56 +21,57 @@ import logging
 import re
 import textwrap
 import time
+import typing as t
 from contextlib import contextmanager
 from copy import deepcopy
 from datetime import datetime
 from enum import Enum, IntEnum, auto
 from itertools import zip_longest
 from pathlib import Path
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    Generator,
-    Iterable,
-    Iterator,
-    Sequence,
-    Tuple,
-    Union,
-)
 
 from mnemonic import Mnemonic
 
-from . import mapping, messages, models, protobuf
-from .client import TrezorClient
-from .exceptions import TrezorFailure, UnexpectedMessageError
+from . import btc, mapping, messages, models, protobuf
+from .client import (
+    MAX_PASSPHRASE_LENGTH,
+    MAX_PIN_LENGTH,
+    PASSPHRASE_ON_DEVICE,
+    ProtocolVersion,
+    TrezorClient,
+)
+from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError
 from .log import DUMP_BYTES
-from .messages import DebugWaitType
+from .messages import Capability, DebugWaitType
+from .protobuf import MessageType
+from .tools import parse_path
+from .transport.session import Session
+from .transport.thp.protocol_v1 import ProtocolV1Channel
 
-if TYPE_CHECKING:
+if t.TYPE_CHECKING:
     from typing_extensions import Protocol
 
     from .messages import PinMatrixRequestType
     from .transport import Transport
 
-    ExpectedMessage = Union[
-        protobuf.MessageType, type[protobuf.MessageType], "MessageFilter"
+    ExpectedMessage = t.Union[
+        protobuf.MessageType, t.Type[protobuf.MessageType], "MessageFilter"
     ]
 
-    AnyDict = Dict[str, Any]
-    Coords = Tuple[int, int]
+    AnyDict = t.Dict[str, t.Any]
+    Coords = t.Tuple[int, int]
 
     class InputFunc(Protocol):
+
         def __call__(
             self,
             hold_ms: int | None = None,
         ) -> "None": ...
 
-    InputFlowType = Generator[None, messages.ButtonRequest, None]
+    InputFlowType = t.Generator[None, messages.ButtonRequest, None]
 
 
 EXPECTED_RESPONSES_CONTEXT_LINES = 3
+PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0")
 
 LOG = logging.getLogger(__name__)
 
@@ -108,11 +109,11 @@ class UnstructuredJSONReader:
         except json.JSONDecodeError:
             self.dict = {}
 
-    def top_level_value(self, key: str) -> Any:
+    def top_level_value(self, key: str) -> t.Any:
         return self.dict.get(key)
 
-    def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]:
-        def recursively_find(data: Any) -> Iterator[Any]:
+    def find_objects_with_key_and_value(self, key: str, value: t.Any) -> list[AnyDict]:
+        def recursively_find(data: t.Any) -> t.Iterator[t.Any]:
             if isinstance(data, dict):
                 if data.get(key) == value:
                     yield data
@@ -125,7 +126,7 @@ class UnstructuredJSONReader:
         return list(recursively_find(self.dict))
 
     def find_unique_object_with_key_and_value(
-        self, key: str, value: Any
+        self, key: str, value: t.Any
     ) -> AnyDict | None:
         objects = self.find_objects_with_key_and_value(key, value)
         if not objects:
@@ -133,8 +134,10 @@ class UnstructuredJSONReader:
         assert len(objects) == 1
         return objects[0]
 
-    def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]:
-        def recursively_find(data: Any) -> Iterator[Any]:
+    def find_values_by_key(
+        self, key: str, only_type: type | None = None
+    ) -> list[t.Any]:
+        def recursively_find(data: t.Any) -> t.Iterator[t.Any]:
             if isinstance(data, dict):
                 if key in data:
                     yield data[key]
@@ -152,8 +155,8 @@ class UnstructuredJSONReader:
         return values
 
     def find_unique_value_by_key(
-        self, key: str, default: Any, only_type: type | None = None
-    ) -> Any:
+        self, key: str, default: t.Any, only_type: type | None = None
+    ) -> t.Any:
         values = self.find_values_by_key(key, only_type=only_type)
         if not values:
             return default
@@ -164,7 +167,7 @@ class UnstructuredJSONReader:
 class LayoutContent(UnstructuredJSONReader):
     """Contains helper functions to extract specific parts of the layout."""
 
-    def __init__(self, json_tokens: Sequence[str]) -> None:
+    def __init__(self, json_tokens: t.Sequence[str]) -> None:
         json_str = "".join(json_tokens)
         super().__init__(json_str)
 
@@ -430,6 +433,7 @@ class DebugLink:
         self.allow_interactions = auto_interact
         self.mapping = mapping.DEFAULT_MAPPING
 
+        self.protocol = ProtocolV1Channel(self.transport, self.mapping)
         # To be set by TrezorClientDebugLink (is not known during creation time)
         self.model: models.TrezorModel | None = None
         self.version: tuple[int, int, int] = (0, 0, 0)
@@ -480,10 +484,16 @@ class DebugLink:
         return ButtonActions(self.layout_type)
 
     def open(self) -> None:
-        self.transport.begin_session()
+        self.transport.open()
+        # raise NotImplementedError
+        # TODO is this needed?
+        # self.transport.deprecated_begin_session()
 
     def close(self) -> None:
-        self.transport.end_session()
+        pass
+        # raise NotImplementedError
+        # TODO is this needed?
+        # self.transport.deprecated_end_session()
 
     def _write(self, msg: protobuf.MessageType) -> None:
         if self.waiting_for_layout_change:
@@ -500,15 +510,10 @@ class DebugLink:
             DUMP_BYTES,
             f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
         )
-        self.transport.write(msg_type, msg_bytes)
+        self.protocol.write(msg)
 
     def _read(self) -> protobuf.MessageType:
-        ret_type, ret_bytes = self.transport.read()
-        LOG.log(
-            DUMP_BYTES,
-            f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}",
-        )
-        msg = self.mapping.decode(ret_type, ret_bytes)
+        msg = self.protocol.read()
 
         # Collapse tokens to make log use less lines.
         msg_for_log = msg
@@ -522,7 +527,7 @@ class DebugLink:
         )
         return msg
 
-    def _call(self, msg: protobuf.MessageType) -> Any:
+    def _call(self, msg: protobuf.MessageType) -> t.Any:
         self._write(msg)
         return self._read()
 
@@ -556,7 +561,7 @@ class DebugLink:
 
     def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent:
         # Next layout change will be caused by external event
-        # (e.g. device being auto-locked or as a result of device_handler.run(xxx))
+        # (e.g. device being auto-locked or as a result of device_handler.run_with_session(xxx))
         # and not by our debug actions/decisions.
         # Resetting the debug state so we wait for the next layout change
         # (and do not return the current state).
@@ -571,7 +576,7 @@ class DebugLink:
         return LayoutContent(obj.tokens)
 
     @contextmanager
-    def wait_for_layout_change(self) -> Iterator[None]:
+    def wait_for_layout_change(self) -> t.Iterator[None]:
         # make sure some current layout is up by issuing a dummy GetState
         self.state()
 
@@ -624,7 +629,7 @@ class DebugLink:
 
         return "".join([str(matrix.index(p) + 1) for p in pin])
 
-    def read_recovery_word(self) -> Tuple[str | None, int | None]:
+    def read_recovery_word(self) -> t.Tuple[str | None, int | None]:
         state = self.state()
         return (state.recovery_fake_word, state.recovery_word_pos)
 
@@ -680,7 +685,7 @@ class DebugLink:
         """Send text input to the device. See `_decision` for more details."""
         self._decision(messages.DebugLinkDecision(input=word))
 
-    def click(self, click: Tuple[int, int], hold_ms: int | None = None) -> None:
+    def click(self, click: t.Tuple[int, int], hold_ms: int | None = None) -> None:
         """Send a click to the device. See `_decision` for more details."""
         x, y = click
         self._decision(messages.DebugLinkDecision(x=x, y=y, hold_ms=hold_ms))
@@ -803,10 +808,10 @@ class DebugUI:
         self.clear()
 
     def clear(self) -> None:
-        self.pins: Iterator[str] | None = None
-        self.passphrase = ""
-        self.input_flow: Union[
-            Generator[None, messages.ButtonRequest, None], object, None
+        self.pins: t.Iterator[str] | None = None
+        self.passphrase = None
+        self.input_flow: t.Union[
+            t.Generator[None, messages.ButtonRequest, None], object, None
         ] = None
 
     def _default_input_flow(self, br: messages.ButtonRequest) -> None:
@@ -838,7 +843,7 @@ class DebugUI:
             raise AssertionError("input flow ended prematurely")
         else:
             try:
-                assert isinstance(self.input_flow, Generator)
+                assert isinstance(self.input_flow, t.Generator)
                 self.input_flow.send(br)
             except StopIteration:
                 self.input_flow = self.INPUT_FLOW_DONE
@@ -854,18 +859,21 @@ class DebugUI:
         except StopIteration:
             raise AssertionError("PIN sequence ended prematurely")
 
-    def get_passphrase(self, available_on_device: bool) -> str:
+    def get_passphrase(self, available_on_device: bool) -> str | None | object:
         self.debuglink.snapshot_legacy()
         return self.passphrase
 
 
 class MessageFilter:
-    def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None:
+
+    def __init__(
+        self, message_type: t.Type[protobuf.MessageType], **fields: t.Any
+    ) -> None:
         self.message_type = message_type
-        self.fields: Dict[str, Any] = {}
+        self.fields: t.Dict[str, t.Any] = {}
         self.update_fields(**fields)
 
-    def update_fields(self, **fields: Any) -> "MessageFilter":
+    def update_fields(self, **fields: t.Any) -> "MessageFilter":
         for name, value in fields.items():
             try:
                 self.fields[name] = self.from_message_or_type(value)
@@ -913,7 +921,7 @@ class MessageFilter:
         return True
 
     def to_string(self, maxwidth: int = 80) -> str:
-        fields: list[Tuple[str, str]] = []
+        fields: list[t.Tuple[str, str]] = []
         for field in self.message_type.FIELDS.values():
             if field.name not in self.fields:
                 continue
@@ -943,7 +951,7 @@ class MessageFilter:
 
 
 class MessageFilterGenerator:
-    def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]:
+    def __getattr__(self, key: str) -> t.Callable[..., "MessageFilter"]:
         message_type = getattr(messages, key)
         return MessageFilter(message_type).update_fields
 
@@ -951,6 +959,230 @@ class MessageFilterGenerator:
 message_filters = MessageFilterGenerator()
 
 
+class SessionDebugWrapper(Session):
+    def __init__(self, session: Session) -> None:
+        if isinstance(session, SessionDebugWrapper):
+            raise Exception("Cannot wrap already wrapped session!")
+        self.__dict__["_session"] = session
+        self.reset_debug_features()
+
+    def __getattr__(self, name: str) -> t.Any:
+        return getattr(self._session, name)
+
+    def __setattr__(self, name: str, value: t.Any) -> None:
+        if hasattr(self._session, name):
+            setattr(self._session, name, value)
+        else:
+            self.__dict__[name] = value
+
+    @property
+    def protocol_version(self) -> int:
+        return self.client.protocol_version
+
+    def _write(self, msg: t.Any) -> None:
+        print("writing message:", msg.__class__.__name__)
+        self._session._write(self._filter_message(msg))
+
+    def _read(self) -> t.Any:
+        resp = self._filter_message(self._session._read())
+        print("reading message:", resp.__class__.__name__)
+        if self.actual_responses is not None:
+            self.actual_responses.append(resp)
+        return resp
+
+    def set_expected_responses(
+        self,
+        expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]],
+    ) -> None:
+        """Set a sequence of expected responses to session calls.
+
+        Within a given with-block, the list of received responses from device must
+        match the list of expected responses, otherwise an ``AssertionError`` is raised.
+
+        If an expected response is given a field value other than ``None``, that field value
+        must exactly match the received field value. If a given field is ``None``
+        (or unspecified) in the expected response, the received field value is not
+        checked.
+
+        Each expected response can also be a tuple ``(bool, message)``. In that case, the
+        expected response is only evaluated if the first field is ``True``.
+        This is useful for differentiating sequences between Trezor models:
+
+        >>> trezor_one = session.features.model == "1"
+        >>> session.set_expected_responses([
+        >>>     messages.ButtonRequest(code=ConfirmOutput),
+        >>>     (trezor_one, messages.ButtonRequest(code=ConfirmOutput)),
+        >>>     messages.Success(),
+        >>> ])
+        """
+        if not self.in_with_statement:
+            raise RuntimeError("Must be called inside 'with' statement")
+
+        # make sure all items are (bool, message) tuples
+        expected_with_validity = (
+            e if isinstance(e, tuple) else (True, e) for e in expected
+        )
+
+        # only apply those items that are (True, message)
+        self.expected_responses = [
+            MessageFilter.from_message_or_type(expected)
+            for valid, expected in expected_with_validity
+            if valid
+        ]
+        self.actual_responses = []
+
+    def lock(self) -> None:
+        """Lock the device.
+
+        If the device does not have a PIN configured, this will do nothing.
+        Otherwise, a lock screen will be shown and the device will prompt for PIN
+        before further actions.
+
+        This call does _not_ invalidate passphrase cache. If passphrase is in use,
+        the device will not prompt for it after unlocking.
+
+        To invalidate passphrase cache, use `session.end()`. To lock _and_ invalidate
+        passphrase cache, use `session.lock()` followed by `session.end()`.
+        """
+        self.call(messages.LockDevice())
+        self.refresh_features()
+
+    def ensure_unlocked(self) -> None:
+        btc.get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
+        self.refresh_features()
+
+    def set_filter(
+        self,
+        message_type: t.Type[protobuf.MessageType],
+        callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
+    ) -> None:
+        """Configure a filter function for a specified message type.
+
+        The `callback` must be a function that accepts a protobuf message, and returns
+        a (possibly modified) protobuf message of the same type. Whenever a message
+        is sent or received that matches `message_type`, `callback` is invoked on the
+        message and its result is substituted for the original.
+
+        Useful for test scenarios with an active malicious actor on the wire.
+        """
+        if not self.in_with_statement:
+            raise RuntimeError("Must be called inside 'with' statement")
+
+        self.filters[message_type] = callback
+
+    def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
+        message_type = msg.__class__
+        callback = self.filters.get(message_type)
+        if callable(callback):
+            return callback(deepcopy(msg))
+        else:
+            return msg
+
+    def reset_debug_features(self) -> None:
+        """Prepare the debugging session for a new testcase.
+
+        Clears all debugging state that might have been modified by a testcase.
+        """
+        self.in_with_statement = False
+        self.expected_responses: list[MessageFilter] | None = None
+        self.actual_responses: list[protobuf.MessageType] | None = None
+        self.filters: t.Dict[
+            t.Type[protobuf.MessageType],
+            t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
+        ] = {}
+        self.button_callback = self.client.button_callback
+        self.pin_callback = self.client.pin_callback
+        self.passphrase_callback = self._session.passphrase_callback
+
+    def __enter__(self) -> "SessionDebugWrapper":
+        # For usage in with/expected_responses
+        if self.in_with_statement:
+            raise RuntimeError("Do not nest!")
+        self.in_with_statement = True
+        return self
+
+    def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
+        __tracebackhide__ = True  # for pytest # pylint: disable=W0612
+
+        # copy expected/actual responses before clearing them
+        expected_responses = self.expected_responses
+        actual_responses = self.actual_responses
+
+        # grab a copy of the inputflow generator to raise an exception through it
+        if isinstance(self.client, TrezorClientDebugLink) and isinstance(
+            self.client.ui, DebugUI
+        ):
+            input_flow = self.client.ui.input_flow
+        else:
+            input_flow = None
+
+        self.reset_debug_features()
+
+        if exc_type is None:
+            # If no other exception was raised, evaluate missed responses
+            # (raises AssertionError on mismatch)
+            self._verify_responses(expected_responses, actual_responses)
+
+        elif isinstance(input_flow, t.Generator):
+            # Propagate the exception through the input flow, so that we see in
+            # traceback where it is stuck.
+            input_flow.throw(exc_type, value, traceback)
+
+    @classmethod
+    def _verify_responses(
+        cls,
+        expected: list[MessageFilter] | None,
+        actual: list[protobuf.MessageType] | None,
+    ) -> None:
+        __tracebackhide__ = True  # for pytest # pylint: disable=W0612
+
+        if expected is None and actual is None:
+            return
+
+        assert expected is not None
+        assert actual is not None
+
+        for i, (exp, act) in enumerate(zip_longest(expected, actual)):
+            if exp is None:
+                output = cls._expectation_lines(expected, i)
+                output.append("No more messages were expected, but we got:")
+                for resp in actual[i:]:
+                    output.append(
+                        textwrap.indent(protobuf.format_message(resp), "    ")
+                    )
+                raise AssertionError("\n".join(output))
+
+            if act is None:
+                output = cls._expectation_lines(expected, i)
+                output.append("This and the following message was not received.")
+                raise AssertionError("\n".join(output))
+
+            if not exp.match(act):
+                output = cls._expectation_lines(expected, i)
+                output.append("Actually received:")
+                output.append(textwrap.indent(protobuf.format_message(act), "    "))
+                raise AssertionError("\n".join(output))
+
+    @staticmethod
+    def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
+        start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
+        stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
+        output: list[str] = []
+        output.append("Expected responses:")
+        if start_at > 0:
+            output.append(f"    (...{start_at} previous responses omitted)")
+        for i in range(start_at, stop_at):
+            exp = expected[i]
+            prefix = "    " if i != current else ">>> "
+            output.append(textwrap.indent(exp.to_string(), prefix))
+        if stop_at < len(expected):
+            omitted = len(expected) - stop_at
+            output.append(f"    (...{omitted} following responses omitted)")
+
+        output.append("")
+        return output
+
+
 class TrezorClientDebugLink(TrezorClient):
     # This class implements automatic responses
     # and other functionality for unit tests
@@ -976,11 +1208,13 @@ class TrezorClientDebugLink(TrezorClient):
                 raise
 
         # set transport explicitly so that sync_responses can work
-        self.transport = transport
+        super().__init__(transport)
 
-        self.reset_debug_features()
+        self.transport = transport
+        self.ui: DebugUI = DebugUI(self.debug)
+
+        self.reset_debug_features(new_seedless_session=True)
         self.sync_responses()
-        super().__init__(transport, ui=self.ui)
 
         # So that we can choose right screenshotting logic (T1 vs TT)
         # and know the supported debug capabilities
@@ -991,8 +1225,18 @@ class TrezorClientDebugLink(TrezorClient):
     def layout_type(self) -> LayoutType:
         return self.debug.layout_type
 
-    def reset_debug_features(self) -> None:
-        """Prepare the debugging client for a new testcase.
+    def get_new_client(self) -> TrezorClientDebugLink:
+        new_client = TrezorClientDebugLink(
+            self.transport, self.debug.allow_interactions
+        )
+        new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir
+        new_client.debug.t1_screenshot_directory = self.debug.t1_screenshot_directory
+        new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter
+        return new_client
+
+    def reset_debug_features(self, new_seedless_session: bool = False) -> None:
+        """
+        Prepare the debugging client for a new testcase.
 
         Clears all debugging state that might have been modified by a testcase.
         """
@@ -1000,55 +1244,159 @@ class TrezorClientDebugLink(TrezorClient):
         self.in_with_statement = False
         self.expected_responses: list[MessageFilter] | None = None
         self.actual_responses: list[protobuf.MessageType] | None = None
-        self.filters: dict[
-            type[protobuf.MessageType],
-            Callable[[protobuf.MessageType], protobuf.MessageType] | None,
+        self.filters: t.Dict[
+            t.Type[protobuf.MessageType],
+            t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
         ] = {}
+        if new_seedless_session:
+            self._seedless_session = self.get_seedless_session(new_session=True)
+
+    @property
+    def button_callback(self):
+
+        def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any:
+            __tracebackhide__ = True  # for pytest # pylint: disable=W0612
+            # do this raw - send ButtonAck first, notify UI later
+            session._write(messages.ButtonAck())
+            self.ui.button_request(msg)
+            return session._read()
+
+        return _callback_button
+
+    @property
+    def pin_callback(self):
+
+        def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any:
+            try:
+                pin = self.ui.get_pin(msg.type)
+            except Cancelled:
+                session.call_raw(messages.Cancel())
+                raise
+
+            if any(d not in "123456789" for d in pin) or not (
+                1 <= len(pin) <= MAX_PIN_LENGTH
+            ):
+                session.call_raw(messages.Cancel())
+                raise ValueError("Invalid PIN provided")
+            resp = session.call_raw(messages.PinMatrixAck(pin=pin))
+            if isinstance(resp, messages.Failure) and resp.code in (
+                messages.FailureType.PinInvalid,
+                messages.FailureType.PinCancelled,
+                messages.FailureType.PinExpected,
+            ):
+                raise PinException(resp.code, resp.message)
+            else:
+                return resp
+
+        return _callback_pin
+
+    @property
+    def passphrase_callback(self):
+        def _callback_passphrase(
+            session: Session, msg: messages.PassphraseRequest
+        ) -> t.Any:
+            available_on_device = (
+                Capability.PassphraseEntry in session.features.capabilities
+            )
+
+            def send_passphrase(
+                passphrase: str | None = None, on_device: bool | None = None
+            ) -> MessageType:
+                msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
+                resp = session.call_raw(msg)
+                if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
+                    if resp.state is not None:
+                        session.id = resp.state
+                    else:
+                        raise RuntimeError("Object resp.state is None")
+                    resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
+                return resp
+
+            # short-circuit old style entry
+            if msg._on_device is True:
+                return send_passphrase(None, None)
+
+            try:
+                if isinstance(session, SessionDebugWrapper):
+                    passphrase = self.ui.get_passphrase(
+                        available_on_device=available_on_device
+                    )
+                    if passphrase is None:
+                        passphrase = session.passphrase
+                else:
+                    raise NotImplementedError
+            except Cancelled:
+                session.call_raw(messages.Cancel())
+                raise
+
+            if passphrase is PASSPHRASE_ON_DEVICE:
+                if not available_on_device:
+                    session.call_raw(messages.Cancel())
+                    raise RuntimeError("Device is not capable of entering passphrase")
+                else:
+                    return send_passphrase(on_device=True)
+
+            # else process host-entered passphrase
+            if not isinstance(passphrase, str):
+                raise RuntimeError("Passphrase must be a str")
+            passphrase = Mnemonic.normalize_string(passphrase)
+            if len(passphrase) > MAX_PASSPHRASE_LENGTH:
+                session.call_raw(messages.Cancel())
+                raise ValueError("Passphrase too long")
+
+            return send_passphrase(passphrase, on_device=False)
+
+        return _callback_passphrase
 
     def ensure_open(self) -> None:
         """Only open session if there isn't already an open one."""
-        if self.session_counter == 0:
-            self.open()
+        # if self.session_counter == 0:
+        #     self.open()
+        # TODO check if is this needed
 
     def open(self) -> None:
-        super().open()
-        if self.session_counter == 1:
-            self.debug.open()
+        pass
+        # TODO is this needed?
+        # self.debug.open()
 
     def close(self) -> None:
-        if self.session_counter == 1:
-            self.debug.close()
-        super().close()
+        pass
+        # TODO is this needed?
+        # self.debug.close()
 
-    def set_filter(
+    def lock(self) -> None:
+        s = self.get_seedless_session()
+        s.lock()
+
+    def get_session(
         self,
-        message_type: type[protobuf.MessageType],
-        callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None,
-    ) -> None:
-        """Configure a filter function for a specified message type.
+        passphrase: str | object | None = "",
+        derive_cardano: bool = False,
+        session_id: int = 0,
+    ) -> SessionDebugWrapper:
+        if isinstance(passphrase, str):
+            passphrase = Mnemonic.normalize_string(passphrase)
+        return SessionDebugWrapper(
+            super().get_session(passphrase, derive_cardano, session_id)
+        )
 
-        The `callback` must be a function that accepts a protobuf message, and returns
-        a (possibly modified) protobuf message of the same type. Whenever a message
-        is sent or received that matches `message_type`, `callback` is invoked on the
-        message and its result is substituted for the original.
+    def get_seedless_session(
+        self, *args: t.Any, **kwargs: t.Any
+    ) -> SessionDebugWrapper:
+        session = super().get_seedless_session(*args, **kwargs)
+        if not isinstance(session, SessionDebugWrapper):
+            session = SessionDebugWrapper(session)
+        return session
 
-        Useful for test scenarios with an active malicious actor on the wire.
-        """
-        if not self.in_with_statement:
-            raise RuntimeError("Must be called inside 'with' statement")
-
-        self.filters[message_type] = callback
-
-    def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
-        message_type = msg.__class__
-        callback = self.filters.get(message_type)
-        if callable(callback):
-            return callback(deepcopy(msg))
+    def resume_session(self, session: Session) -> SessionDebugWrapper:
+        if isinstance(session, SessionDebugWrapper):
+            session._session = super().resume_session(session._session)
+            return session
         else:
-            return msg
+            return SessionDebugWrapper(super().resume_session(session))
 
     def set_input_flow(
-        self, input_flow: InputFlowType | Callable[[], InputFlowType]
+        self, input_flow: InputFlowType | t.Callable[[], InputFlowType]
     ) -> None:
         """Configure a sequence of input events for the current with-block.
 
@@ -1104,7 +1452,7 @@ class TrezorClientDebugLink(TrezorClient):
         self.in_with_statement = True
         return self
 
-    def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None:
+    def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
         __tracebackhide__ = True  # for pytest # pylint: disable=W0612
 
         # copy expected/actual responses before clearing them
@@ -1117,21 +1465,23 @@ class TrezorClientDebugLink(TrezorClient):
         else:
             input_flow = None
 
-        self.reset_debug_features()
+        self.reset_debug_features(new_seedless_session=False)
 
         if exc_type is None:
             # If no other exception was raised, evaluate missed responses
             # (raises AssertionError on mismatch)
             self._verify_responses(expected_responses, actual_responses)
 
-        elif isinstance(input_flow, Generator):
+        elif isinstance(input_flow, t.Generator):
             # Propagate the exception through the input flow, so that we see in
             # traceback where it is stuck.
             input_flow.throw(exc_type, value, traceback)
 
     def set_expected_responses(
         self,
-        expected: Sequence[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]],
+        expected: t.Sequence[
+            t.Union["ExpectedMessage", t.Tuple[bool, "ExpectedMessage"]]
+        ],
     ) -> None:
         """Set a sequence of expected responses to client calls.
 
@@ -1170,33 +1520,17 @@ class TrezorClientDebugLink(TrezorClient):
         ]
         self.actual_responses = []
 
-    def use_pin_sequence(self, pins: Iterable[str]) -> None:
+    def use_pin_sequence(self, pins: t.Iterable[str]) -> None:
         """Respond to PIN prompts from device with the provided PINs.
         The sequence must be at least as long as the expected number of PIN prompts.
         """
         self.ui.pins = iter(pins)
 
-    def use_passphrase(self, passphrase: str) -> None:
-        """Respond to passphrase prompts from device with the provided passphrase."""
-        self.ui.passphrase = Mnemonic.normalize_string(passphrase)
-
     def use_mnemonic(self, mnemonic: str) -> None:
         """Use the provided mnemonic to respond to device.
         Only applies to T1, where device prompts the host for mnemonic words."""
         self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
 
-    def _raw_read(self) -> protobuf.MessageType:
-        __tracebackhide__ = True  # for pytest # pylint: disable=W0612
-
-        resp = super()._raw_read()
-        resp = self._filter_message(resp)
-        if self.actual_responses is not None:
-            self.actual_responses.append(resp)
-        return resp
-
-    def _raw_write(self, msg: protobuf.MessageType) -> None:
-        return super()._raw_write(self._filter_message(msg))
-
     @staticmethod
     def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
         start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
@@ -1265,23 +1599,22 @@ class TrezorClientDebugLink(TrezorClient):
 
         # Start by canceling whatever is on screen. This will work to cancel T1 PIN
         # prompt, which is in TINY mode and does not respond to `Ping`.
-        cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel())
-        self.transport.begin_session()
-        try:
-            self.transport.write(*cancel_msg)
-
-            message = "SYNC" + secrets.token_hex(8)
-            ping_msg = mapping.DEFAULT_MAPPING.encode(messages.Ping(message=message))
-            self.transport.write(*ping_msg)
-            resp = None
-            while resp != messages.Success(message=message):
-                msg_id, msg_bytes = self.transport.read()
-                try:
-                    resp = mapping.DEFAULT_MAPPING.decode(msg_id, msg_bytes)
-                except Exception:
-                    pass
-        finally:
-            self.transport.end_session()
+        if self.protocol_version is ProtocolVersion.PROTOCOL_V1:
+            assert isinstance(self.protocol, ProtocolV1Channel)
+            self.transport.open()
+            try:
+                self.protocol.write(messages.Cancel())
+                resp = self.protocol.read()
+                message = "SYNC" + secrets.token_hex(8)
+                self.protocol.write(messages.Ping(message=message))
+                while resp != messages.Success(message=message):
+                    try:
+                        resp = self.protocol.read()
+                    except Exception:
+                        pass
+            finally:
+                pass
+                # TODO fix self.transport.end_session()
 
     def mnemonic_callback(self, _) -> str:
         word, pos = self.debug.read_recovery_word()
@@ -1294,8 +1627,8 @@ class TrezorClientDebugLink(TrezorClient):
 
 
 def load_device(
-    client: "TrezorClient",
-    mnemonic: Union[str, Iterable[str]],
+    session: "Session",
+    mnemonic: str | t.Iterable[str],
     pin: str | None,
     passphrase_protection: bool,
     label: str | None,
@@ -1308,12 +1641,12 @@ def load_device(
 
     mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic]
 
-    if client.features.initialized:
+    if session.features.initialized:
         raise RuntimeError(
             "Device is initialized already. Call device.wipe() and try again."
         )
 
-    client.call(
+    session.call(
         messages.LoadDevice(
             mnemonics=mnemonics,
             pin=pin,
@@ -1325,18 +1658,18 @@ def load_device(
         ),
         expect=messages.Success,
     )
-    client.init_device()
+    session.refresh_features()
 
 
 # keep the old name for compatibility
 load_device_by_mnemonic = load_device
 
 
-def prodtest_t1(client: "TrezorClient") -> None:
-    if client.features.bootloader_mode is not True:
+def prodtest_t1(session: "Session") -> None:
+    if session.features.bootloader_mode is not True:
         raise RuntimeError("Device must be in bootloader mode")
 
-    client.call(
+    session.call(
         messages.ProdTestT1(
             payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
         ),
@@ -1346,8 +1679,8 @@ def prodtest_t1(client: "TrezorClient") -> None:
 
 def record_screen(
     debug_client: "TrezorClientDebugLink",
-    directory: Union[str, None],
-    report_func: Union[Callable[[str], None], None] = None,
+    directory: str | None,
+    report_func: t.Callable[[str], None] | None = None,
 ) -> None:
     """Record screen changes into a specified directory.
 
@@ -1392,8 +1725,8 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool:
     return debug_client.features.fw_vendor == "EMULATOR"
 
 
-def optiga_set_sec_max(client: "TrezorClient") -> None:
-    client.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success)
+def optiga_set_sec_max(session: "Session") -> None:
+    session.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success)
 
 
 class ScreenButtons:
@@ -1656,26 +1989,26 @@ class ButtonActions:
         else:
             return PASSPHRASE_SPECIAL
 
-    def passphrase(self, char: str) -> Tuple[Coords, int]:
+    def passphrase(self, char: str) -> t.Tuple[Coords, int]:
         choices = self._passphrase_choices(char)
         idx = next(i for i, letters in enumerate(choices) if char in letters)
         click_amount = choices[idx].index(char) + 1
         return self.buttons.pin_passphrase_index(idx), click_amount
 
-    def type_word(self, word: str, is_slip39: bool = False) -> Iterator[Coords]:
+    def type_word(self, word: str, is_slip39: bool = False) -> t.Iterator[Coords]:
         if is_slip39:
             yield from self._type_word_slip39(word)
         else:
             yield from self._type_word_bip39(word)
 
-    def _type_word_slip39(self, word: str) -> Iterator[Coords]:
+    def _type_word_slip39(self, word: str) -> t.Iterator[Coords]:
         for l in word:
             idx = next(
                 i for i, letters in enumerate(BUTTON_LETTERS_SLIP39) if l in letters
             )
             yield self.buttons.mnemonic_from_index(idx)
 
-    def _type_word_bip39(self, word: str) -> Iterator[Coords]:
+    def _type_word_bip39(self, word: str) -> t.Iterator[Coords]:
         coords_prev: Coords | None = None
         for letter in word:
             time.sleep(0.1)  # not being so quick to miss something
@@ -1688,7 +2021,7 @@ class ButtonActions:
             for _ in range(amount):
                 yield coords
 
-    def _letter_coords_and_amount(self, letter: str) -> Tuple[Coords, int]:
+    def _letter_coords_and_amount(self, letter: str) -> t.Tuple[Coords, int]:
         idx = next(
             i for i, letters in enumerate(BUTTON_LETTERS_BIP39) if letter in letters
         )
diff --git a/python/src/trezorlib/device.py b/python/src/trezorlib/device.py
index c08d485ed0..a3b24c247d 100644
--- a/python/src/trezorlib/device.py
+++ b/python/src/trezorlib/device.py
@@ -28,16 +28,10 @@ from slip10 import SLIP10
 
 from . import messages
 from .exceptions import Cancelled, TrezorException
-from .tools import (
-    Address,
-    _deprecation_retval_helper,
-    _return_success,
-    parse_path,
-    session,
-)
+from .tools import Address, _deprecation_retval_helper, _return_success, parse_path
 
 if TYPE_CHECKING:
-    from .client import TrezorClient
+    from .transport.session import Session
 
 
 RECOVERY_BACK = "\x08"  # backspace character, sent literally
@@ -46,9 +40,8 @@ SLIP39_EXTENDABLE_MIN_VERSION = (2, 7, 1)
 ENTROPY_CHECK_MIN_VERSION = (2, 8, 7)
 
 
-@session
 def apply_settings(
-    client: "TrezorClient",
+    session: "Session",
     label: Optional[str] = None,
     language: Optional[str] = None,
     use_passphrase: Optional[bool] = None,
@@ -79,13 +72,13 @@ def apply_settings(
         haptic_feedback=haptic_feedback,
     )
 
-    out = client.call(settings, expect=messages.Success)
-    client.refresh_features()
+    out = session.call(settings, expect=messages.Success)
+    session.refresh_features()
     return _return_success(out)
 
 
 def _send_language_data(
-    client: "TrezorClient",
+    session: "Session",
     request: "messages.TranslationDataRequest",
     language_data: bytes,
 ) -> None:
@@ -95,69 +88,63 @@ def _send_language_data(
         data_length = response.data_length
         data_offset = response.data_offset
         chunk = language_data[data_offset : data_offset + data_length]
-        response = client.call(messages.TranslationDataAck(data_chunk=chunk))
+        response = session.call(messages.TranslationDataAck(data_chunk=chunk))
 
 
-@session
 def change_language(
-    client: "TrezorClient",
+    session: "Session",
     language_data: bytes,
     show_display: bool | None = None,
 ) -> str | None:
     data_length = len(language_data)
     msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)
 
-    response = client.call(msg)
+    response = session.call(msg)
     if data_length > 0:
         response = messages.TranslationDataRequest.ensure_isinstance(response)
-        _send_language_data(client, response, language_data)
+        _send_language_data(session, response, language_data)
     else:
         messages.Success.ensure_isinstance(response)
-    client.refresh_features()  # changing the language in features
+    session.refresh_features()  # changing the language in features
     return _return_success(messages.Success(message="Language changed."))
 
 
-@session
-def apply_flags(client: "TrezorClient", flags: int) -> str | None:
-    out = client.call(messages.ApplyFlags(flags=flags), expect=messages.Success)
-    client.refresh_features()
+def apply_flags(session: "Session", flags: int) -> str | None:
+    out = session.call(messages.ApplyFlags(flags=flags), expect=messages.Success)
+    session.refresh_features()
     return _return_success(out)
 
 
-@session
-def change_pin(client: "TrezorClient", remove: bool = False) -> str | None:
-    ret = client.call(messages.ChangePin(remove=remove), expect=messages.Success)
-    client.refresh_features()
+def change_pin(session: "Session", remove: bool = False) -> str | None:
+    ret = session.call(messages.ChangePin(remove=remove), expect=messages.Success)
+    session.refresh_features()
     return _return_success(ret)
 
 
-@session
-def change_wipe_code(client: "TrezorClient", remove: bool = False) -> str | None:
-    ret = client.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success)
-    client.refresh_features()
+def change_wipe_code(session: "Session", remove: bool = False) -> str | None:
+    ret = session.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success)
+    session.refresh_features()
     return _return_success(ret)
 
 
-@session
 def sd_protect(
-    client: "TrezorClient", operation: messages.SdProtectOperationType
+    session: "Session", operation: messages.SdProtectOperationType
 ) -> str | None:
-    ret = client.call(messages.SdProtect(operation=operation), expect=messages.Success)
-    client.refresh_features()
+    ret = session.call(messages.SdProtect(operation=operation), expect=messages.Success)
+    session.refresh_features()
     return _return_success(ret)
 
 
-@session
-def wipe(client: "TrezorClient") -> str | None:
-    ret = client.call(messages.WipeDevice(), expect=messages.Success)
-    if not client.features.bootloader_mode:
-        client.init_device()
+def wipe(session: "Session") -> str | None:
+    ret = session.call(messages.WipeDevice(), expect=messages.Success)
+    session.invalidate()
+    # if not session.features.bootloader_mode:
+    #     session.refresh_features()
     return _return_success(ret)
 
 
-@session
 def recover(
-    client: "TrezorClient",
+    session: "Session",
     word_count: int = 24,
     passphrase_protection: bool = False,
     pin_protection: bool = True,
@@ -193,13 +180,13 @@ def recover(
     if type is None:
         type = messages.RecoveryType.NormalRecovery
 
-    if client.features.model == "1" and input_callback is None:
+    if session.features.model == "1" and input_callback is None:
         raise RuntimeError("Input callback required for Trezor One")
 
     if word_count not in (12, 18, 24):
         raise ValueError("Invalid word count. Use 12/18/24")
 
-    if client.features.initialized and type == messages.RecoveryType.NormalRecovery:
+    if session.features.initialized and type == messages.RecoveryType.NormalRecovery:
         raise RuntimeError(
             "Device already initialized. Call device.wipe() and try again."
         )
@@ -221,20 +208,20 @@ def recover(
         msg.label = label
         msg.u2f_counter = u2f_counter
 
-    res = client.call(msg)
+    res = session.call(msg)
 
     while isinstance(res, messages.WordRequest):
         try:
             assert input_callback is not None
             inp = input_callback(res.type)
-            res = client.call(messages.WordAck(word=inp))
+            res = session.call(messages.WordAck(word=inp))
         except Cancelled:
-            res = client.call(messages.Cancel())
+            res = session.call(messages.Cancel())
 
     # check that the result is a Success
     res = messages.Success.ensure_isinstance(res)
     # reinitialize the device
-    client.init_device()
+    session.refresh_features()
 
     return _deprecation_retval_helper(res)
 
@@ -280,7 +267,7 @@ def _seed_from_entropy(
 
 
 def reset(
-    client: "TrezorClient",
+    session: "Session",
     display_random: bool = False,
     strength: Optional[int] = None,
     passphrase_protection: bool = False,
@@ -313,7 +300,7 @@ def reset(
         )
 
     setup(
-        client,
+        session,
         strength=strength,
         passphrase_protection=passphrase_protection,
         pin_protection=pin_protection,
@@ -331,9 +318,8 @@ def _get_external_entropy() -> bytes:
     return secrets.token_bytes(32)
 
 
-@session
 def setup(
-    client: "TrezorClient",
+    session: "Session",
     *,
     strength: Optional[int] = None,
     passphrase_protection: bool = True,
@@ -388,19 +374,19 @@ def setup(
         check.
     """
 
-    if client.features.initialized:
+    if session.features.initialized:
         raise RuntimeError(
             "Device is initialized already. Call wipe_device() and try again."
         )
 
     if strength is None:
-        if client.features.model == "1":
+        if session.features.model == "1":
             strength = 256
         else:
             strength = 128
 
     if backup_type is None:
-        if client.version < SLIP39_EXTENDABLE_MIN_VERSION:
+        if session.version < SLIP39_EXTENDABLE_MIN_VERSION:
             # includes Trezor One 1.x.x
             backup_type = messages.BackupType.Bip39
         else:
@@ -411,7 +397,7 @@ def setup(
         paths = [parse_path("m/84h/0h/0h"), parse_path("m/44h/60h/0h")]
 
     if entropy_check_count is None:
-        if client.version < ENTROPY_CHECK_MIN_VERSION:
+        if session.version < ENTROPY_CHECK_MIN_VERSION:
             # includes Trezor One 1.x.x
             entropy_check_count = 0
         else:
@@ -431,18 +417,18 @@ def setup(
     )
     if entropy_check_count > 0:
         xpubs = _reset_with_entropycheck(
-            client, msg, entropy_check_count, paths, _get_entropy
+            session, msg, entropy_check_count, paths, _get_entropy
         )
     else:
-        _reset_no_entropycheck(client, msg, _get_entropy)
+        _reset_no_entropycheck(session, msg, _get_entropy)
         xpubs = []
 
-    client.init_device()
+    session.refresh_features()
     return xpubs
 
 
 def _reset_no_entropycheck(
-    client: "TrezorClient",
+    session: "Session",
     msg: messages.ResetDevice,
     get_entropy: Callable[[], bytes],
 ) -> None:
@@ -454,12 +440,12 @@ def _reset_no_entropycheck(
     << Success
     """
     assert msg.entropy_check is False
-    client.call(msg, expect=messages.EntropyRequest)
-    client.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success)
+    session.call(msg, expect=messages.EntropyRequest)
+    session.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success)
 
 
 def _reset_with_entropycheck(
-    client: "TrezorClient",
+    session: "Session",
     reset_msg: messages.ResetDevice,
     entropy_check_count: int,
     paths: Iterable[Address],
@@ -495,7 +481,7 @@ def _reset_with_entropycheck(
     def get_xpubs() -> list[tuple[Address, str]]:
         xpubs = []
         for path in paths:
-            resp = client.call(
+            resp = session.call(
                 messages.GetPublicKey(address_n=path), expect=messages.PublicKey
             )
             xpubs.append((path, resp.xpub))
@@ -524,13 +510,13 @@ def _reset_with_entropycheck(
                 raise TrezorException("Invalid XPUB in entropy check")
 
     xpubs = []
-    resp = client.call(reset_msg, expect=messages.EntropyRequest)
+    resp = session.call(reset_msg, expect=messages.EntropyRequest)
     entropy_commitment = resp.entropy_commitment
 
     while True:
         # provide external entropy for this round
         external_entropy = get_entropy()
-        client.call(
+        session.call(
             messages.EntropyAck(entropy=external_entropy),
             expect=messages.EntropyCheckReady,
         )
@@ -540,7 +526,7 @@ def _reset_with_entropycheck(
 
         if entropy_check_count <= 0:
             # last round, wait for a Success and exit the loop
-            client.call(
+            session.call(
                 messages.EntropyCheckContinue(finish=True),
                 expect=messages.Success,
             )
@@ -549,7 +535,7 @@ def _reset_with_entropycheck(
         entropy_check_count -= 1
 
         # Next round starts.
-        resp = client.call(
+        resp = session.call(
             messages.EntropyCheckContinue(finish=False),
             expect=messages.EntropyRequest,
         )
@@ -570,13 +556,12 @@ def _reset_with_entropycheck(
     return xpubs
 
 
-@session
 def backup(
-    client: "TrezorClient",
+    session: "Session",
     group_threshold: Optional[int] = None,
     groups: Iterable[tuple[int, int]] = (),
 ) -> str | None:
-    ret = client.call(
+    ret = session.call(
         messages.BackupDevice(
             group_threshold=group_threshold,
             groups=[
@@ -586,37 +571,36 @@ def backup(
         ),
         expect=messages.Success,
     )
-    client.refresh_features()
+    session.refresh_features()
     return _return_success(ret)
 
 
-def cancel_authorization(client: "TrezorClient") -> str | None:
-    ret = client.call(messages.CancelAuthorization(), expect=messages.Success)
+def cancel_authorization(session: "Session") -> str | None:
+    ret = session.call(messages.CancelAuthorization(), expect=messages.Success)
     return _return_success(ret)
 
 
-def unlock_path(client: "TrezorClient", n: "Address") -> bytes:
-    resp = client.call(
+def unlock_path(session: "Session", n: "Address") -> bytes:
+    resp = session.call(
         messages.UnlockPath(address_n=n), expect=messages.UnlockedPathRequest
     )
 
     # Cancel the UnlockPath workflow now that we have the authentication code.
     try:
-        client.call(messages.Cancel())
+        session.call(messages.Cancel())
     except Cancelled:
         return resp.mac
     else:
         raise TrezorException("Unexpected response in UnlockPath flow")
 
 
-@session
 def reboot_to_bootloader(
-    client: "TrezorClient",
+    session: "Session",
     boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT,
     firmware_header: Optional[bytes] = None,
     language_data: bytes = b"",
 ) -> str | None:
-    response = client.call(
+    response = session.call(
         messages.RebootToBootloader(
             boot_command=boot_command,
             firmware_header=firmware_header,
@@ -624,43 +608,38 @@ def reboot_to_bootloader(
         )
     )
     if isinstance(response, messages.TranslationDataRequest):
-        response = _send_language_data(client, response, language_data)
+        response = _send_language_data(session, response, language_data)
     return _return_success(messages.Success(message=""))
 
 
-@session
-def show_device_tutorial(client: "TrezorClient") -> str | None:
-    ret = client.call(messages.ShowDeviceTutorial(), expect=messages.Success)
+def show_device_tutorial(session: "Session") -> str | None:
+    ret = session.call(messages.ShowDeviceTutorial(), expect=messages.Success)
     return _return_success(ret)
 
 
-@session
-def unlock_bootloader(client: "TrezorClient") -> str | None:
-    ret = client.call(messages.UnlockBootloader(), expect=messages.Success)
+def unlock_bootloader(session: "Session") -> str | None:
+    ret = session.call(messages.UnlockBootloader(), expect=messages.Success)
     return _return_success(ret)
 
 
-@session
-def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> str | None:
+def set_busy(session: "Session", expiry_ms: Optional[int]) -> str | None:
     """Sets or clears the busy state of the device.
 
     In the busy state the device shows a "Do not disconnect" message instead of the homescreen.
     Setting `expiry_ms=None` clears the busy state.
     """
-    ret = client.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success)
-    client.refresh_features()
+    ret = session.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success)
+    session.refresh_features()
     return _return_success(ret)
 
 
-def authenticate(
-    client: "TrezorClient", challenge: bytes
-) -> messages.AuthenticityProof:
-    return client.call(
+def authenticate(session: "Session", challenge: bytes) -> messages.AuthenticityProof:
+    return session.call(
         messages.AuthenticateDevice(challenge=challenge),
         expect=messages.AuthenticityProof,
     )
 
 
-def set_brightness(client: "TrezorClient", value: Optional[int] = None) -> str | None:
-    ret = client.call(messages.SetBrightness(value=value), expect=messages.Success)
+def set_brightness(session: "Session", value: Optional[int] = None) -> str | None:
+    ret = session.call(messages.SetBrightness(value=value), expect=messages.Success)
     return _return_success(ret)
diff --git a/python/src/trezorlib/eos.py b/python/src/trezorlib/eos.py
index eb491f204c..990adf3855 100644
--- a/python/src/trezorlib/eos.py
+++ b/python/src/trezorlib/eos.py
@@ -18,11 +18,11 @@ from datetime import datetime
 from typing import TYPE_CHECKING, List, Tuple
 
 from . import exceptions, messages
-from .tools import b58decode, session
+from .tools import b58decode
 
 if TYPE_CHECKING:
-    from .client import TrezorClient
     from .tools import Address
+    from .transport.session import Session
 
 
 def name_to_number(name: str) -> int:
@@ -319,17 +319,16 @@ def parse_transaction_json(
 
 
 def get_public_key(
-    client: "TrezorClient", n: "Address", show_display: bool = False
+    session: "Session", n: "Address", show_display: bool = False
 ) -> messages.EosPublicKey:
-    return client.call(
+    return session.call(
         messages.EosGetPublicKey(address_n=n, show_display=show_display),
         expect=messages.EosPublicKey,
     )
 
 
-@session
 def sign_tx(
-    client: "TrezorClient",
+    session: "Session",
     address: "Address",
     transaction: dict,
     chain_id: str,
@@ -345,11 +344,11 @@ def sign_tx(
         chunkify=chunkify,
     )
 
-    response = client.call(msg)
+    response = session.call(msg)
 
     try:
         while isinstance(response, messages.EosTxActionRequest):
-            response = client.call(actions.pop(0))
+            response = session.call(actions.pop(0))
     except IndexError:
         # pop from empty list
         raise exceptions.TrezorException(
diff --git a/python/src/trezorlib/ethereum.py b/python/src/trezorlib/ethereum.py
index 96ce4d1066..f3f3e57e06 100644
--- a/python/src/trezorlib/ethereum.py
+++ b/python/src/trezorlib/ethereum.py
@@ -18,11 +18,11 @@ import re
 from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
 
 from . import definitions, exceptions, messages
-from .tools import prepare_message_bytes, session, unharden
+from .tools import prepare_message_bytes, unharden
 
 if TYPE_CHECKING:
-    from .client import TrezorClient
     from .tools import Address
+    from .transport.session import Session
 
 
 def int_to_big_endian(value: int) -> bytes:
@@ -161,13 +161,13 @@ def network_from_address_n(
 
 
 def get_address(
-    client: "TrezorClient",
+    session: "Session",
     n: "Address",
     show_display: bool = False,
     encoded_network: Optional[bytes] = None,
     chunkify: bool = False,
 ) -> str:
-    resp = client.call(
+    resp = session.call(
         messages.EthereumGetAddress(
             address_n=n,
             show_display=show_display,
@@ -181,17 +181,16 @@ def get_address(
 
 
 def get_public_node(
-    client: "TrezorClient", n: "Address", show_display: bool = False
+    session: "Session", n: "Address", show_display: bool = False
 ) -> messages.EthereumPublicKey:
-    return client.call(
+    return session.call(
         messages.EthereumGetPublicKey(address_n=n, show_display=show_display),
         expect=messages.EthereumPublicKey,
     )
 
 
-@session
 def sign_tx(
-    client: "TrezorClient",
+    session: "Session",
     n: "Address",
     nonce: int,
     gas_price: int,
@@ -227,13 +226,13 @@ def sign_tx(
     data, chunk = data[1024:], data[:1024]
     msg.data_initial_chunk = chunk
 
-    response = client.call(msg)
+    response = session.call(msg)
     assert isinstance(response, messages.EthereumTxRequest)
 
     while response.data_length is not None:
         data_length = response.data_length
         data, chunk = data[data_length:], data[:data_length]
-        response = client.call(messages.EthereumTxAck(data_chunk=chunk))
+        response = session.call(messages.EthereumTxAck(data_chunk=chunk))
         assert isinstance(response, messages.EthereumTxRequest)
 
     assert response.signature_v is not None
@@ -248,9 +247,8 @@ def sign_tx(
     return response.signature_v, response.signature_r, response.signature_s
 
 
-@session
 def sign_tx_eip1559(
-    client: "TrezorClient",
+    session: "Session",
     n: "Address",
     *,
     nonce: int,
@@ -283,13 +281,13 @@ def sign_tx_eip1559(
         chunkify=chunkify,
     )
 
-    response = client.call(msg)
+    response = session.call(msg)
     assert isinstance(response, messages.EthereumTxRequest)
 
     while response.data_length is not None:
         data_length = response.data_length
         data, chunk = data[data_length:], data[:data_length]
-        response = client.call(messages.EthereumTxAck(data_chunk=chunk))
+        response = session.call(messages.EthereumTxAck(data_chunk=chunk))
         assert isinstance(response, messages.EthereumTxRequest)
 
     assert response.signature_v is not None
@@ -299,13 +297,13 @@ def sign_tx_eip1559(
 
 
 def sign_message(
-    client: "TrezorClient",
+    session: "Session",
     n: "Address",
     message: AnyStr,
     encoded_network: Optional[bytes] = None,
     chunkify: bool = False,
 ) -> messages.EthereumMessageSignature:
-    return client.call(
+    return session.call(
         messages.EthereumSignMessage(
             address_n=n,
             message=prepare_message_bytes(message),
@@ -317,7 +315,7 @@ def sign_message(
 
 
 def sign_typed_data(
-    client: "TrezorClient",
+    session: "Session",
     n: "Address",
     data: Dict[str, Any],
     *,
@@ -333,7 +331,7 @@ def sign_typed_data(
         metamask_v4_compat=metamask_v4_compat,
         definitions=definitions,
     )
-    response = client.call(request)
+    response = session.call(request)
 
     # Sending all the types
     while isinstance(response, messages.EthereumTypedDataStructRequest):
@@ -349,7 +347,7 @@ def sign_typed_data(
             members.append(struct_member)
 
         request = messages.EthereumTypedDataStructAck(members=members)
-        response = client.call(request)
+        response = session.call(request)
 
     # Sending the whole message that should be signed
     while isinstance(response, messages.EthereumTypedDataValueRequest):
@@ -362,7 +360,7 @@ def sign_typed_data(
             member_typename = data["primaryType"]
             member_data = data["message"]
         else:
-            client.cancel()
+            session.cancel()
             raise exceptions.TrezorException("Root index can only be 0 or 1")
 
         # It can be asking for a nested structure (the member path being [X, Y, Z, ...])
@@ -385,20 +383,20 @@ def sign_typed_data(
             encoded_data = encode_data(member_data, member_typename)
 
         request = messages.EthereumTypedDataValueAck(value=encoded_data)
-        response = client.call(request)
+        response = session.call(request)
 
     return messages.EthereumTypedDataSignature.ensure_isinstance(response)
 
 
 def verify_message(
-    client: "TrezorClient",
+    session: "Session",
     address: str,
     signature: bytes,
     message: AnyStr,
     chunkify: bool = False,
 ) -> bool:
     try:
-        client.call(
+        session.call(
             messages.EthereumVerifyMessage(
                 address=address,
                 signature=signature,
@@ -413,13 +411,13 @@ def verify_message(
 
 
 def sign_typed_data_hash(
-    client: "TrezorClient",
+    session: "Session",
     n: "Address",
     domain_hash: bytes,
     message_hash: Optional[bytes],
     encoded_network: Optional[bytes] = None,
 ) -> messages.EthereumTypedDataSignature:
-    return client.call(
+    return session.call(
         messages.EthereumSignTypedHash(
             address_n=n,
             domain_separator_hash=domain_hash,
diff --git a/python/src/trezorlib/fido.py b/python/src/trezorlib/fido.py
index a2618b72db..aaa3b084bf 100644
--- a/python/src/trezorlib/fido.py
+++ b/python/src/trezorlib/fido.py
@@ -22,37 +22,37 @@ from . import messages
 from .tools import _return_success
 
 if TYPE_CHECKING:
-    from .client import TrezorClient
+    from .transport.session import Session
 
 
-def list_credentials(client: "TrezorClient") -> Sequence[messages.WebAuthnCredential]:
-    return client.call(
+def list_credentials(session: "Session") -> Sequence[messages.WebAuthnCredential]:
+    return session.call(
         messages.WebAuthnListResidentCredentials(), expect=messages.WebAuthnCredentials
     ).credentials
 
 
-def add_credential(client: "TrezorClient", credential_id: bytes) -> str | None:
-    ret = client.call(
+def add_credential(session: "Session", credential_id: bytes) -> str | None:
+    ret = session.call(
         messages.WebAuthnAddResidentCredential(credential_id=credential_id),
         expect=messages.Success,
     )
     return _return_success(ret)
 
 
-def remove_credential(client: "TrezorClient", index: int) -> str | None:
-    ret = client.call(
+def remove_credential(session: "Session", index: int) -> str | None:
+    ret = session.call(
         messages.WebAuthnRemoveResidentCredential(index=index), expect=messages.Success
     )
     return _return_success(ret)
 
 
-def set_counter(client: "TrezorClient", u2f_counter: int) -> str | None:
-    ret = client.call(
+def set_counter(session: "Session", u2f_counter: int) -> str | None:
+    ret = session.call(
         messages.SetU2FCounter(u2f_counter=u2f_counter), expect=messages.Success
     )
     return _return_success(ret)
 
 
-def get_next_counter(client: "TrezorClient") -> int:
-    ret = client.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter)
+def get_next_counter(session: "Session") -> int:
+    ret = session.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter)
     return ret.u2f_counter
diff --git a/python/src/trezorlib/firmware/__init__.py b/python/src/trezorlib/firmware/__init__.py
index 1c36ba9acc..ac766b42d0 100644
--- a/python/src/trezorlib/firmware/__init__.py
+++ b/python/src/trezorlib/firmware/__init__.py
@@ -22,7 +22,6 @@ from hashlib import blake2s
 from typing_extensions import Protocol, TypeGuard
 
 from .. import messages
-from ..tools import session
 from .core import VendorFirmware
 from .legacy import LegacyFirmware, LegacyV2Firmware
 from .models import Model
@@ -41,7 +40,7 @@ if True:
     from .vendor import *  # noqa: F401, F403
 
 if t.TYPE_CHECKING:
-    from ..client import TrezorClient
+    from ..transport.session import Session
 
     T = t.TypeVar("T", bound="FirmwareType")
 
@@ -77,20 +76,19 @@ def is_onev2(fw: FirmwareType) -> TypeGuard[LegacyFirmware]:
 # ====== Client functions ====== #
 
 
-@session
 def update(
-    client: TrezorClient,
+    session: Session,
     data: bytes,
     progress_update: t.Callable[[int], t.Any] = lambda _: None,
 ):
-    if client.features.bootloader_mode is False:
+    if session.features.bootloader_mode is False:
         raise RuntimeError("Device must be in bootloader mode")
 
-    resp = client.call(messages.FirmwareErase(length=len(data)))
+    resp = session.call(messages.FirmwareErase(length=len(data)))
 
     # TREZORv1 method
     if isinstance(resp, messages.Success):
-        resp = client.call(messages.FirmwareUpload(payload=data))
+        resp = session.call(messages.FirmwareUpload(payload=data))
         progress_update(len(data))
         if isinstance(resp, messages.Success):
             return
@@ -102,7 +100,7 @@ def update(
         length = resp.length
         payload = data[resp.offset : resp.offset + length]
         digest = blake2s(payload).digest()
-        resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest))
+        resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest))
         progress_update(length)
 
     if isinstance(resp, messages.Success):
@@ -111,7 +109,7 @@ def update(
         raise RuntimeError(f"Unexpected message {resp}")
 
 
-def get_hash(client: TrezorClient, challenge: bytes | None) -> bytes:
-    return client.call(
+def get_hash(session: Session, challenge: bytes | None) -> bytes:
+    return session.call(
         messages.GetFirmwareHash(challenge=challenge), expect=messages.FirmwareHash
     ).hash
diff --git a/python/src/trezorlib/mapping.py b/python/src/trezorlib/mapping.py
index 532277078f..1d5b867e4a 100644
--- a/python/src/trezorlib/mapping.py
+++ b/python/src/trezorlib/mapping.py
@@ -85,6 +85,7 @@ class ProtobufMapping:
         mapping = cls()
 
         message_types = getattr(module, "MessageType")
+
         for entry in message_types:
             msg_class = getattr(module, entry.name, None)
             if msg_class is None:
diff --git a/python/src/trezorlib/misc.py b/python/src/trezorlib/misc.py
index 578c1fa19f..eeaea26872 100644
--- a/python/src/trezorlib/misc.py
+++ b/python/src/trezorlib/misc.py
@@ -19,22 +19,22 @@ from typing import TYPE_CHECKING, Optional
 from . import messages
 
 if TYPE_CHECKING:
-    from .client import TrezorClient
     from .tools import Address
+    from .transport.session import Session
 
 
-def get_entropy(client: "TrezorClient", size: int) -> bytes:
-    return client.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy
+def get_entropy(session: "Session", size: int) -> bytes:
+    return session.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy
 
 
 def sign_identity(
-    client: "TrezorClient",
+    session: "Session",
     identity: messages.IdentityType,
     challenge_hidden: bytes,
     challenge_visual: str,
     ecdsa_curve_name: Optional[str] = None,
 ) -> messages.SignedIdentity:
-    return client.call(
+    return session.call(
         messages.SignIdentity(
             identity=identity,
             challenge_hidden=challenge_hidden,
@@ -46,12 +46,12 @@ def sign_identity(
 
 
 def get_ecdh_session_key(
-    client: "TrezorClient",
+    session: "Session",
     identity: messages.IdentityType,
     peer_public_key: bytes,
     ecdsa_curve_name: Optional[str] = None,
 ) -> messages.ECDHSessionKey:
-    return client.call(
+    return session.call(
         messages.GetECDHSessionKey(
             identity=identity,
             peer_public_key=peer_public_key,
@@ -62,7 +62,7 @@ def get_ecdh_session_key(
 
 
 def encrypt_keyvalue(
-    client: "TrezorClient",
+    session: "Session",
     n: "Address",
     key: str,
     value: bytes,
@@ -70,7 +70,7 @@ def encrypt_keyvalue(
     ask_on_decrypt: bool = True,
     iv: bytes = b"",
 ) -> bytes:
-    return client.call(
+    return session.call(
         messages.CipherKeyValue(
             address_n=n,
             key=key,
@@ -85,7 +85,7 @@ def encrypt_keyvalue(
 
 
 def decrypt_keyvalue(
-    client: "TrezorClient",
+    session: "Session",
     n: "Address",
     key: str,
     value: bytes,
@@ -93,7 +93,7 @@ def decrypt_keyvalue(
     ask_on_decrypt: bool = True,
     iv: bytes = b"",
 ) -> bytes:
-    return client.call(
+    return session.call(
         messages.CipherKeyValue(
             address_n=n,
             key=key,
@@ -107,5 +107,5 @@ def decrypt_keyvalue(
     ).value
 
 
-def get_nonce(client: "TrezorClient") -> bytes:
-    return client.call(messages.GetNonce(), expect=messages.Nonce).nonce
+def get_nonce(session: "Session") -> bytes:
+    return session.call(messages.GetNonce(), expect=messages.Nonce).nonce
diff --git a/python/src/trezorlib/monero.py b/python/src/trezorlib/monero.py
index b2e3214fb9..9e32346156 100644
--- a/python/src/trezorlib/monero.py
+++ b/python/src/trezorlib/monero.py
@@ -19,8 +19,8 @@ from typing import TYPE_CHECKING
 from . import messages
 
 if TYPE_CHECKING:
-    from .client import TrezorClient
     from .tools import Address
+    from .transport.session import Session
 
 
 # MAINNET = 0
@@ -30,13 +30,13 @@ if TYPE_CHECKING:
 
 
 def get_address(
-    client: "TrezorClient",
+    session: "Session",
     n: "Address",
     show_display: bool = False,
     network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
     chunkify: bool = False,
 ) -> bytes:
-    return client.call(
+    return session.call(
         messages.MoneroGetAddress(
             address_n=n,
             show_display=show_display,
@@ -48,11 +48,11 @@ def get_address(
 
 
 def get_watch_key(
-    client: "TrezorClient",
+    session: "Session",
     n: "Address",
     network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
 ) -> messages.MoneroWatchKey:
-    return client.call(
+    return session.call(
         messages.MoneroGetWatchKey(address_n=n, network_type=network_type),
         expect=messages.MoneroWatchKey,
     )
diff --git a/python/src/trezorlib/nem.py b/python/src/trezorlib/nem.py
index 744dc3205f..357de145ad 100644
--- a/python/src/trezorlib/nem.py
+++ b/python/src/trezorlib/nem.py
@@ -20,8 +20,8 @@ from typing import TYPE_CHECKING
 from . import exceptions, messages
 
 if TYPE_CHECKING:
-    from .client import TrezorClient
     from .tools import Address
+    from .transport.session import Session
 
 TYPE_TRANSACTION_TRANSFER = 0x0101
 TYPE_IMPORTANCE_TRANSFER = 0x0801
@@ -195,13 +195,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig
 
 
 def get_address(
-    client: "TrezorClient",
+    session: "Session",
     n: "Address",
     network: int,
     show_display: bool = False,
     chunkify: bool = False,
 ) -> str:
-    return client.call(
+    return session.call(
         messages.NEMGetAddress(
             address_n=n, network=network, show_display=show_display, chunkify=chunkify
         ),
@@ -210,7 +210,7 @@ def get_address(
 
 
 def sign_tx(
-    client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False
+    session: "Session", n: "Address", transaction: dict, chunkify: bool = False
 ) -> messages.NEMSignedTx:
     try:
         msg = create_sign_tx(transaction, chunkify=chunkify)
@@ -219,4 +219,4 @@ def sign_tx(
 
     assert msg.transaction is not None
     msg.transaction.address_n = n
-    return client.call(msg, expect=messages.NEMSignedTx)
+    return session.call(msg, expect=messages.NEMSignedTx)
diff --git a/python/src/trezorlib/ripple.py b/python/src/trezorlib/ripple.py
index 00a027c6d9..e5e0f524cc 100644
--- a/python/src/trezorlib/ripple.py
+++ b/python/src/trezorlib/ripple.py
@@ -21,20 +21,20 @@ from .protobuf import dict_to_proto
 from .tools import dict_from_camelcase
 
 if TYPE_CHECKING:
-    from .client import TrezorClient
     from .tools import Address
+    from .transport.session import Session
 
 REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
 REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
 
 
 def get_address(
-    client: "TrezorClient",
+    session: "Session",
     address_n: "Address",
     show_display: bool = False,
     chunkify: bool = False,
 ) -> str:
-    return client.call(
+    return session.call(
         messages.RippleGetAddress(
             address_n=address_n, show_display=show_display, chunkify=chunkify
         ),
@@ -43,14 +43,14 @@ def get_address(
 
 
 def sign_tx(
-    client: "TrezorClient",
+    session: "Session",
     address_n: "Address",
     msg: messages.RippleSignTx,
     chunkify: bool = False,
 ) -> messages.RippleSignedTx:
     msg.address_n = address_n
     msg.chunkify = chunkify
-    return client.call(msg, expect=messages.RippleSignedTx)
+    return session.call(msg, expect=messages.RippleSignedTx)
 
 
 def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:
diff --git a/python/src/trezorlib/solana.py b/python/src/trezorlib/solana.py
index 0054e0fd92..3d0ee75549 100644
--- a/python/src/trezorlib/solana.py
+++ b/python/src/trezorlib/solana.py
@@ -3,27 +3,27 @@ from typing import TYPE_CHECKING, List, Optional
 from . import messages
 
 if TYPE_CHECKING:
-    from .client import TrezorClient
+    from .transport.session import Session
 
 
 def get_public_key(
-    client: "TrezorClient",
+    session: "Session",
     address_n: List[int],
     show_display: bool,
 ) -> bytes:
-    return client.call(
+    return session.call(
         messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display),
         expect=messages.SolanaPublicKey,
     ).public_key
 
 
 def get_address(
-    client: "TrezorClient",
+    session: "Session",
     address_n: List[int],
     show_display: bool,
     chunkify: bool = False,
 ) -> str:
-    return client.call(
+    return session.call(
         messages.SolanaGetAddress(
             address_n=address_n,
             show_display=show_display,
@@ -34,12 +34,12 @@ def get_address(
 
 
 def sign_tx(
-    client: "TrezorClient",
+    session: "Session",
     address_n: List[int],
     serialized_tx: bytes,
     additional_info: Optional[messages.SolanaTxAdditionalInfo],
 ) -> bytes:
-    return client.call(
+    return session.call(
         messages.SolanaSignTx(
             address_n=address_n,
             serialized_tx=serialized_tx,
diff --git a/python/src/trezorlib/stellar.py b/python/src/trezorlib/stellar.py
index 5bd0a749e4..843a2e0c39 100644
--- a/python/src/trezorlib/stellar.py
+++ b/python/src/trezorlib/stellar.py
@@ -20,8 +20,8 @@ from typing import TYPE_CHECKING, List, Tuple, Union
 from . import exceptions, messages
 
 if TYPE_CHECKING:
-    from .client import TrezorClient
     from .tools import Address
+    from .transport.session import Session
 
     StellarMessageType = Union[
         messages.StellarAccountMergeOp,
@@ -322,12 +322,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
 
 
 def get_address(
-    client: "TrezorClient",
+    session: "Session",
     address_n: "Address",
     show_display: bool = False,
     chunkify: bool = False,
 ) -> str:
-    return client.call(
+    return session.call(
         messages.StellarGetAddress(
             address_n=address_n, show_display=show_display, chunkify=chunkify
         ),
@@ -336,7 +336,7 @@ def get_address(
 
 
 def sign_tx(
-    client: "TrezorClient",
+    session: "Session",
     tx: messages.StellarSignTx,
     operations: List["StellarMessageType"],
     address_n: "Address",
@@ -352,10 +352,10 @@ def sign_tx(
     # 3. Receive a StellarTxOpRequest message
     # 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message
     # 5. The final message received will be StellarSignedTx which is returned from this method
-    resp = client.call(tx)
+    resp = session.call(tx)
     try:
         while isinstance(resp, messages.StellarTxOpRequest):
-            resp = client.call(operations.pop(0))
+            resp = session.call(operations.pop(0))
     except IndexError:
         # pop from empty list
         raise exceptions.TrezorException(
diff --git a/python/src/trezorlib/tezos.py b/python/src/trezorlib/tezos.py
index 9319aa1eaa..06bcafe759 100644
--- a/python/src/trezorlib/tezos.py
+++ b/python/src/trezorlib/tezos.py
@@ -19,17 +19,17 @@ from typing import TYPE_CHECKING
 from . import messages
 
 if TYPE_CHECKING:
-    from .client import TrezorClient
     from .tools import Address
+    from .transport.session import Session
 
 
 def get_address(
-    client: "TrezorClient",
+    session: "Session",
     address_n: "Address",
     show_display: bool = False,
     chunkify: bool = False,
 ) -> str:
-    return client.call(
+    return session.call(
         messages.TezosGetAddress(
             address_n=address_n, show_display=show_display, chunkify=chunkify
         ),
@@ -38,12 +38,12 @@ def get_address(
 
 
 def get_public_key(
-    client: "TrezorClient",
+    session: "Session",
     address_n: "Address",
     show_display: bool = False,
     chunkify: bool = False,
 ) -> str:
-    return client.call(
+    return session.call(
         messages.TezosGetPublicKey(
             address_n=address_n, show_display=show_display, chunkify=chunkify
         ),
@@ -52,11 +52,11 @@ def get_public_key(
 
 
 def sign_tx(
-    client: "TrezorClient",
+    session: "Session",
     address_n: "Address",
     sign_tx_msg: messages.TezosSignTx,
     chunkify: bool = False,
 ) -> messages.TezosSignedTx:
     sign_tx_msg.address_n = address_n
     sign_tx_msg.chunkify = chunkify
-    return client.call(sign_tx_msg, expect=messages.TezosSignedTx)
+    return session.call(sign_tx_msg, expect=messages.TezosSignedTx)
diff --git a/python/src/trezorlib/tools.py b/python/src/trezorlib/tools.py
index 6ba8c64dba..f753e68a33 100644
--- a/python/src/trezorlib/tools.py
+++ b/python/src/trezorlib/tools.py
@@ -45,7 +45,7 @@ if TYPE_CHECKING:
     # More details: https://www.python.org/dev/peps/pep-0612/
     from typing import TypeVar
 
-    from typing_extensions import Concatenate, ParamSpec
+    from typing_extensions import ParamSpec
 
     from . import client
     from .messages import Success
@@ -389,23 +389,6 @@ def _return_success(msg: "Success") -> str | None:
     return _deprecation_retval_helper(msg.message, stacklevel=1)
 
 
-def session(
-    f: "Callable[Concatenate[TrezorClient, P], R]",
-) -> "Callable[Concatenate[TrezorClient, P], R]":
-    # Decorator wraps a BaseClient method
-    # with session activation / deactivation
-    @functools.wraps(f)
-    def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R":
-        __tracebackhide__ = True  # for pytest # pylint: disable=W0612
-        client.open()
-        try:
-            return f(client, *args, **kwargs)
-        finally:
-            client.close()
-
-    return wrapped_f
-
-
 # de-camelcasifier
 # https://stackoverflow.com/a/1176023/222189
 
diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py
new file mode 100644
index 0000000000..f75a4c7c15
--- /dev/null
+++ b/python/src/trezorlib/transport/session.py
@@ -0,0 +1,152 @@
+from __future__ import annotations
+
+import logging
+import typing as t
+
+from .. import exceptions, messages, models
+from ..protobuf import MessageType
+from .thp.protocol_v1 import ProtocolV1Channel
+
+if t.TYPE_CHECKING:
+    from ..client import TrezorClient
+
+LOG = logging.getLogger(__name__)
+
+MT = t.TypeVar("MT", bound=MessageType)
+
+
+class Session:
+    def __init__(self, client: TrezorClient, id: bytes) -> None:
+        self.client = client
+        self._id = id
+
+    def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT:
+        self.client.check_firmware_version()
+        resp = self.call_raw(msg)
+
+        while True:
+            if isinstance(resp, messages.PinMatrixRequest):
+                if self.client.pin_callback is None:
+                    raise NotImplementedError("Missing pin_callback")
+                resp = self.client.pin_callback(self, resp)
+            elif isinstance(resp, messages.PassphraseRequest):
+                if self.client.passphrase_callback is None:
+                    raise NotImplementedError("Missing passphrase_callback")
+                resp = self.client.passphrase_callback(self, resp)
+            elif isinstance(resp, messages.ButtonRequest):
+                resp = (self.client.button_callback or default_button_callback)(
+                    self, resp
+                )
+            elif isinstance(resp, messages.Failure):
+                if resp.code == messages.FailureType.ActionCancelled:
+                    raise exceptions.Cancelled
+                raise exceptions.TrezorFailure(resp)
+            elif not isinstance(resp, expect):
+                raise exceptions.UnexpectedMessageError(expect, resp)
+            else:
+                return resp
+
+    def call_raw(self, msg: t.Any) -> t.Any:
+        self._write(msg)
+        return self._read()
+
+    def _write(self, msg: t.Any) -> None:
+        raise NotImplementedError
+
+    def _read(self) -> t.Any:
+        raise NotImplementedError
+
+    def refresh_features(self) -> None:
+        self.client.refresh_features()
+
+    def end(self) -> t.Any:
+        return self.call(messages.EndSession())
+
+    def cancel(self) -> None:
+        self._write(messages.Cancel())
+
+    def ping(self, message: str, button_protection: bool | None = None) -> str:
+        resp = self.call(
+            messages.Ping(message=message, button_protection=button_protection),
+            expect=messages.Success,
+        )
+        assert resp.message is not None
+        return resp.message
+
+    def invalidate(self) -> None:
+        self.client.invalidate()
+
+    @property
+    def features(self) -> messages.Features:
+        return self.client.features
+
+    @property
+    def model(self) -> models.TrezorModel:
+        return self.client.model
+
+    @property
+    def version(self) -> t.Tuple[int, int, int]:
+        return self.client.version
+
+    @property
+    def id(self) -> bytes:
+        return self._id
+
+    @id.setter
+    def id(self, value: bytes) -> None:
+        if not isinstance(value, bytes):
+            raise ValueError("id must be of type bytes")
+        self._id = value
+
+
+class SessionV1(Session):
+    derive_cardano: bool | None = False
+
+    @classmethod
+    def new(
+        cls,
+        client: TrezorClient,
+        passphrase: str | object = "",
+        derive_cardano: bool = False,
+        session_id: bytes | None = None,
+    ) -> SessionV1:
+        assert isinstance(client.protocol, ProtocolV1Channel)
+        session = SessionV1(client, id=session_id or b"")
+
+        session.passphrase = passphrase
+        session.derive_cardano = derive_cardano
+        session.init_session(session.derive_cardano)
+        return session
+
+    @classmethod
+    def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1:
+        assert isinstance(client.protocol, ProtocolV1Channel)
+        session = SessionV1(client, session_id)
+        session.init_session()
+        return session
+
+    def _write(self, msg: t.Any) -> None:
+        if t.TYPE_CHECKING:
+            assert isinstance(self.client.protocol, ProtocolV1Channel)
+        self.client.protocol.write(msg)
+
+    def _read(self) -> t.Any:
+        if t.TYPE_CHECKING:
+            assert isinstance(self.client.protocol, ProtocolV1Channel)
+        return self.client.protocol.read()
+
+    def init_session(self, derive_cardano: bool | None = None):
+        if self.id == b"":
+            session_id = None
+        else:
+            session_id = self.id
+        resp: messages.Features = self.call_raw(
+            messages.Initialize(session_id=session_id, derive_cardano=derive_cardano)
+        )
+        assert isinstance(resp, messages.Features)
+        if resp.session_id is not None:
+            self.id = resp.session_id
+
+
+def default_button_callback(session: Session, msg: t.Any) -> t.Any:
+    return session.call(messages.ButtonAck())