From e2a9f6654dcaa47e41a1a6dbc35afebbfa443cb6 Mon Sep 17 00:00:00 2001
From: M1nd3r <petrsedlacek.km@seznam.cz>
Date: Mon, 2 Dec 2024 15:41:37 +0100
Subject: [PATCH] feat(core): implement THP

---
 core/src/all_modules.py                       |  44 ++
 core/src/apps/base.py                         |  51 +-
 core/src/apps/cardano/seed.py                 |  19 +-
 core/src/apps/common/backup.py                |  27 +-
 core/src/apps/common/keychain.py              |   4 +
 core/src/apps/common/passphrase.py            | 170 +++--
 core/src/apps/common/seed.py                  | 147 ++--
 core/src/apps/debug/__init__.py               |  68 +-
 .../apps/management/reboot_to_bootloader.py   |   2 +-
 .../management/recovery_device/__init__.py    |   5 +-
 .../management/recovery_device/homescreen.py  |  39 +-
 .../apps/management/reset_device/__init__.py  |   8 +-
 core/src/apps/management/wipe_device.py       |  27 +-
 core/src/apps/thp/create_new_session.py       |  59 ++
 core/src/apps/thp/pairing.py                  | 403 +++++++++++
 core/src/apps/workflow_handlers.py            |   7 +
 core/src/storage/__init__.py                  |  22 +-
 core/src/storage/cache.py                     |  29 +-
 core/src/storage/cache_common.py              |   8 +
 core/src/storage/cache_thp.py                 | 363 ++++++++++
 core/src/trezor/enums/FailureType.py          |   2 +
 core/src/trezor/enums/ThpMessageType.py       |  22 +
 core/src/trezor/enums/ThpPairingMethod.py     |   8 +
 core/src/trezor/enums/__init__.py             |  28 +
 core/src/trezor/messages.py                   | 282 ++++++++
 core/src/trezor/utils.py                      |   4 +
 core/src/trezor/wire/__init__.py              | 127 ++--
 core/src/trezor/wire/context.py               |  13 +-
 core/src/trezor/wire/errors.py                |   6 +
 core/src/trezor/wire/message_handler.py       |  64 +-
 core/src/trezor/wire/protocol_common.py       |   5 +-
 core/src/trezor/wire/thp/__init__.py          | 184 +++++
 .../wire/thp/alternating_bit_protocol.py      | 102 +++
 core/src/trezor/wire/thp/channel.py           | 405 +++++++++++
 core/src/trezor/wire/thp/channel_manager.py   |  34 +
 core/src/trezor/wire/thp/checksum.py          |  22 +
 core/src/trezor/wire/thp/control_byte.py      |  50 ++
 core/src/trezor/wire/thp/cpace.py             |  36 +
 core/src/trezor/wire/thp/crypto.py            | 211 ++++++
 core/src/trezor/wire/thp/interface_manager.py |  28 +
 core/src/trezor/wire/thp/memory_manager.py    | 179 +++++
 core/src/trezor/wire/thp/pairing_context.py   | 262 +++++++
 .../wire/thp/received_message_handler.py      | 446 ++++++++++++
 core/src/trezor/wire/thp/session_context.py   | 169 +++++
 core/src/trezor/wire/thp/session_manager.py   |  48 ++
 core/src/trezor/wire/thp/thp_main.py          | 190 +++++
 core/src/trezor/wire/thp/transmission_loop.py |  54 ++
 core/src/trezor/wire/thp/writer.py            |  93 +++
 core/src/trezor/workflow.py                   |   9 +-
 core/tests/mock_wire_interface.py             |  50 ++
 core/tests/myTests.sh                         |  42 ++
 core/tests/test_apps.bitcoin.approver.py      |  25 +-
 core/tests/test_apps.bitcoin.authorization.py |  25 +-
 core/tests/test_apps.bitcoin.keychain.py      |  60 +-
 core/tests/test_apps.common.keychain.py       |  26 +-
 core/tests/test_apps.ethereum.keychain.py     |  35 +-
 core/tests/test_storage.cache.py              | 674 +++++++++++++-----
 core/tests/test_trezor.wire.codec.codec_v1.py |  52 +-
 core/tests/test_trezor.wire.thp.checksum.py   |  94 +++
 core/tests/test_trezor.wire.thp.crypto.py     | 156 ++++
 core/tests/test_trezor.wire.thp.py            | 378 ++++++++++
 core/tests/test_trezor.wire.thp.writer.py     | 151 ++++
 core/tests/test_trezor.wire.thp_deprecated.py | 338 +++++++++
 core/tests/thp_common.py                      |  44 ++
 core/tools/codegen/get_trezor_keys.py         |   2 +-
 65 files changed, 6198 insertions(+), 539 deletions(-)
 create mode 100644 core/src/apps/thp/create_new_session.py
 create mode 100644 core/src/apps/thp/pairing.py
 create mode 100644 core/src/storage/cache_thp.py
 create mode 100644 core/src/trezor/enums/ThpMessageType.py
 create mode 100644 core/src/trezor/enums/ThpPairingMethod.py
 create mode 100644 core/src/trezor/wire/thp/__init__.py
 create mode 100644 core/src/trezor/wire/thp/alternating_bit_protocol.py
 create mode 100644 core/src/trezor/wire/thp/channel.py
 create mode 100644 core/src/trezor/wire/thp/channel_manager.py
 create mode 100644 core/src/trezor/wire/thp/checksum.py
 create mode 100644 core/src/trezor/wire/thp/control_byte.py
 create mode 100644 core/src/trezor/wire/thp/cpace.py
 create mode 100644 core/src/trezor/wire/thp/crypto.py
 create mode 100644 core/src/trezor/wire/thp/interface_manager.py
 create mode 100644 core/src/trezor/wire/thp/memory_manager.py
 create mode 100644 core/src/trezor/wire/thp/pairing_context.py
 create mode 100644 core/src/trezor/wire/thp/received_message_handler.py
 create mode 100644 core/src/trezor/wire/thp/session_context.py
 create mode 100644 core/src/trezor/wire/thp/session_manager.py
 create mode 100644 core/src/trezor/wire/thp/thp_main.py
 create mode 100644 core/src/trezor/wire/thp/transmission_loop.py
 create mode 100644 core/src/trezor/wire/thp/writer.py
 create mode 100644 core/tests/mock_wire_interface.py
 create mode 100755 core/tests/myTests.sh
 create mode 100644 core/tests/test_trezor.wire.thp.checksum.py
 create mode 100644 core/tests/test_trezor.wire.thp.crypto.py
 create mode 100644 core/tests/test_trezor.wire.thp.py
 create mode 100644 core/tests/test_trezor.wire.thp.writer.py
 create mode 100644 core/tests/test_trezor.wire.thp_deprecated.py
 create mode 100644 core/tests/thp_common.py

diff --git a/core/src/all_modules.py b/core/src/all_modules.py
index 70651ea3a8..42914e34ba 100644
--- a/core/src/all_modules.py
+++ b/core/src/all_modules.py
@@ -51,6 +51,8 @@ storage.cache_codec
 import storage.cache_codec
 storage.cache_common
 import storage.cache_common
+storage.cache_thp
+import storage.cache_thp
 storage.common
 import storage.common
 storage.debug
@@ -419,10 +421,52 @@ apps.workflow_handlers
 import apps.workflow_handlers
 
 if utils.USE_THP:
+    trezor.enums.ThpMessageType
+    import trezor.enums.ThpMessageType
+    trezor.enums.ThpPairingMethod
+    import trezor.enums.ThpPairingMethod
+    trezor.wire.thp
+    import trezor.wire.thp
+    trezor.wire.thp.alternating_bit_protocol
+    import trezor.wire.thp.alternating_bit_protocol
+    trezor.wire.thp.channel
+    import trezor.wire.thp.channel
+    trezor.wire.thp.channel_manager
+    import trezor.wire.thp.channel_manager
+    trezor.wire.thp.checksum
+    import trezor.wire.thp.checksum
+    trezor.wire.thp.control_byte
+    import trezor.wire.thp.control_byte
+    trezor.wire.thp.cpace
+    import trezor.wire.thp.cpace
+    trezor.wire.thp.crypto
+    import trezor.wire.thp.crypto
+    trezor.wire.thp.interface_manager
+    import trezor.wire.thp.interface_manager
+    trezor.wire.thp.memory_manager
+    import trezor.wire.thp.memory_manager
+    trezor.wire.thp.pairing_context
+    import trezor.wire.thp.pairing_context
+    trezor.wire.thp.received_message_handler
+    import trezor.wire.thp.received_message_handler
+    trezor.wire.thp.session_context
+    import trezor.wire.thp.session_context
+    trezor.wire.thp.session_manager
+    import trezor.wire.thp.session_manager
+    trezor.wire.thp.thp_main
+    import trezor.wire.thp.thp_main
+    trezor.wire.thp.transmission_loop
+    import trezor.wire.thp.transmission_loop
+    trezor.wire.thp.writer
+    import trezor.wire.thp.writer
     apps.thp
     import apps.thp
+    apps.thp.create_new_session
+    import apps.thp.create_new_session
     apps.thp.credential_manager
     import apps.thp.credential_manager
+    apps.thp.pairing
+    import apps.thp.pairing
 
 if not utils.BITCOIN_ONLY:
     trezor.enums.BinanceOrderSide
diff --git a/core/src/apps/base.py b/core/src/apps/base.py
index 5552fc86ba..ca923b7337 100644
--- a/core/src/apps/base.py
+++ b/core/src/apps/base.py
@@ -204,33 +204,37 @@ def get_features() -> Features:
     return f
 
 
-async def handle_Initialize(msg: Initialize) -> Features:
-    import storage.cache_codec as cache_codec
+if not utils.USE_THP:
 
-    session_id = cache_codec.start_session(msg.session_id)
+    async def handle_Initialize(msg: Initialize) -> Features:
+        import storage.cache_codec as cache_codec
 
-    if not utils.BITCOIN_ONLY:
-        from storage.cache_common import APP_COMMON_DERIVE_CARDANO
+        session_id = cache_codec.start_session(msg.session_id)
 
-        derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO)
-        have_seed = context.cache_is_set(APP_COMMON_SEED)
-        if (
-            have_seed
-            and msg.derive_cardano is not None
-            and msg.derive_cardano != bool(derive_cardano)
-        ):
-            # seed is already derived, and host wants to change derive_cardano setting
-            # => create a new session
-            cache_codec.end_current_session()
-            session_id = cache_codec.start_session()
-            have_seed = False
+        if not utils.BITCOIN_ONLY:
+            from storage.cache_common import APP_COMMON_DERIVE_CARDANO
 
-        if not have_seed:
-            context.cache_set_bool(APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano))
+            derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO)
+            have_seed = context.cache_is_set(APP_COMMON_SEED)
+            if (
+                have_seed
+                and msg.derive_cardano is not None
+                and msg.derive_cardano != bool(derive_cardano)
+            ):
+                # seed is already derived, and host wants to change derive_cardano setting
+                # => create a new session
+                cache_codec.end_current_session()
+                session_id = cache_codec.start_session()
+                have_seed = False
 
-    features = get_features()
-    features.session_id = session_id
-    return features
+            if not have_seed:
+                context.cache_set_bool(
+                    APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)
+                )
+
+        features = get_features()
+        features.session_id = session_id
+        return features
 
 
 async def handle_GetFeatures(msg: GetFeatures) -> Features:
@@ -464,8 +468,9 @@ def boot() -> None:
     MT = MessageType  # local_cache_global
 
     # Register workflow handlers
+    if not utils.USE_THP:
+        workflow_handlers.register(MT.Initialize, handle_Initialize)
     for msg_type, handler in [
-        (MT.Initialize, handle_Initialize),
         (MT.GetFeatures, handle_GetFeatures),
         (MT.Cancel, handle_Cancel),
         (MT.LockDevice, handle_LockDevice),
diff --git a/core/src/apps/cardano/seed.py b/core/src/apps/cardano/seed.py
index 35f6b3f60c..e4e77825aa 100644
--- a/core/src/apps/cardano/seed.py
+++ b/core/src/apps/cardano/seed.py
@@ -6,7 +6,7 @@ from storage.cache_common import (
     APP_CARDANO_ICARUS_TREZOR_SECRET,
     APP_COMMON_DERIVE_CARDANO,
 )
-from trezor import wire
+from trezor import utils, wire
 from trezor.crypto import cardano
 from trezor.wire import context
 
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
     from trezor import messages
     from trezor.crypto import bip32
     from trezor.enums import CardanoDerivationType
+    from trezor.wire.protocol_common import Context
 
     from apps.common.keychain import Handler, MsgOut
     from apps.common.paths import Bip32Path
@@ -116,7 +117,7 @@ def is_minting_path(path: Bip32Path) -> bool:
     return path[: len(MINTING_ROOT)] == MINTING_ROOT
 
 
-def derive_and_store_secrets(passphrase: str) -> None:
+def derive_and_store_secrets(ctx: Context, passphrase: str) -> None:
     assert device.is_initialized()
     assert context.cache_get_bool(APP_COMMON_DERIVE_CARDANO)
 
@@ -144,8 +145,7 @@ def derive_and_store_secrets(passphrase: str) -> None:
 
 async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain:
     from trezor.enums import CardanoDerivationType
-
-    from apps.common.seed import derive_and_store_roots
+    from trezor.wire import context
 
     if not device.is_initialized():
         raise wire.NotInitialized("Device is not initialized")
@@ -164,10 +164,13 @@ async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychai
 
     # _get_secret
     secret = context.cache_get(cache_entry)
-    if secret is None:
-        await derive_and_store_roots()
-        secret = context.cache_get(cache_entry)
-        assert secret is not None
+    if not utils.USE_THP:
+        if secret is None:
+            from apps.common.seed import derive_and_store_roots_legacy
+
+            await derive_and_store_roots_legacy()
+            secret = context.cache_get(cache_entry)
+    assert secret is not None
 
     root = cardano.from_secret(secret)
     return Keychain(root)
diff --git a/core/src/apps/common/backup.py b/core/src/apps/common/backup.py
index fc56f42f9b..8037aba698 100644
--- a/core/src/apps/common/backup.py
+++ b/core/src/apps/common/backup.py
@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING
 
 from storage.cache_common import APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
-from trezor import wire
+from trezor import utils, wire
 from trezor.enums import MessageType
 from trezor.wire import context
 from trezor.wire.message_handler import filters, remove_filter
@@ -24,14 +24,23 @@ def deactivate_repeated_backup() -> None:
     remove_filter(_repeated_backup_filter)
 
 
-_ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = (
-    MessageType.Initialize,
-    MessageType.GetFeatures,
-    MessageType.EndSession,
-    MessageType.BackupDevice,
-    MessageType.WipeDevice,
-    MessageType.Cancel,
-)
+if utils.USE_THP:
+    _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = (
+        MessageType.GetFeatures,
+        MessageType.EndSession,
+        MessageType.BackupDevice,
+        MessageType.WipeDevice,
+        MessageType.Cancel,
+    )
+else:
+    _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = (
+        MessageType.Initialize,
+        MessageType.GetFeatures,
+        MessageType.EndSession,
+        MessageType.BackupDevice,
+        MessageType.WipeDevice,
+        MessageType.Cancel,
+    )
 
 
 def _repeated_backup_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]:
diff --git a/core/src/apps/common/keychain.py b/core/src/apps/common/keychain.py
index 16913d1529..7959789b25 100644
--- a/core/src/apps/common/keychain.py
+++ b/core/src/apps/common/keychain.py
@@ -1,5 +1,6 @@
 from typing import TYPE_CHECKING
 
+from trezor import utils
 from trezor.crypto import bip32
 from trezor.wire import DataError
 
@@ -172,6 +173,9 @@ async def get_keychain(
 ) -> Keychain:
     from .seed import get_seed
 
+    if not utils.USE_THP:
+        pass
+        # try to ask for passphrase here
     seed = await get_seed()
     keychain = Keychain(seed, curve, schemas, slip21_namespaces)
     return keychain
diff --git a/core/src/apps/common/passphrase.py b/core/src/apps/common/passphrase.py
index ef8bb5b185..d150dd4736 100644
--- a/core/src/apps/common/passphrase.py
+++ b/core/src/apps/common/passphrase.py
@@ -1,84 +1,122 @@
 from micropython import const
+from typing import TYPE_CHECKING
 
 import storage.device as storage_device
+from trezor import utils
 from trezor.wire import DataError
 
 _MAX_PASSPHRASE_LEN = const(50)
 
+if TYPE_CHECKING:
+    from trezor.messages import ThpCreateNewSession
+
 
 def is_enabled() -> bool:
     return storage_device.is_passphrase_enabled()
 
 
-async def get() -> str:
-    from trezor import workflow
-
+async def get_passphrase(msg: ThpCreateNewSession) -> str:
     if not is_enabled():
         return ""
+
+    if msg.on_device or storage_device.get_passphrase_always_on_device():
+        passphrase = await _get_on_device()
     else:
-        workflow.close_others()  # request exclusive UI access
-        if storage_device.get_passphrase_always_on_device():
-            from trezor.ui.layouts import request_passphrase_on_device
+        passphrase = msg.passphrase or ""
+        if passphrase:
+            await _handle_displaying_passphrase_from_host(passphrase)
 
-            passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
-        else:
-            passphrase = await _request_on_host()
-        if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
-            raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes")
-
-        return passphrase
-
-
-async def _request_on_host() -> str:
-    from trezor import TR
-    from trezor.messages import PassphraseAck, PassphraseRequest
-    from trezor.ui.layouts import request_passphrase_on_host
-    from trezor.wire.context import call
-
-    request_passphrase_on_host()
-
-    request = PassphraseRequest()
-    ack = await call(request, PassphraseAck)
-    passphrase = ack.passphrase  # local_cache_attribute
-
-    if ack.on_device:
-        from trezor.ui.layouts import request_passphrase_on_device
-
-        if passphrase is not None:
-            raise DataError("Passphrase provided when it should not be")
-        return await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
-
-    if passphrase is None:
-        raise DataError(
-            "Passphrase not provided and on_device is False. Use empty string to set an empty passphrase."
-        )
-
-    # non-empty passphrase
-    if passphrase:
-        from trezor.ui.layouts import confirm_action, confirm_blob
-
-        # We want to hide the passphrase, or show it, according to settings.
-        if storage_device.get_hide_passphrase_from_host():
-            await confirm_action(
-                "passphrase_host1_hidden",
-                TR.passphrase__wallet,
-                description=TR.passphrase__from_host_not_shown,
-                prompt_screen=True,
-                prompt_title=TR.passphrase__access_wallet,
-            )
-        else:
-            await confirm_action(
-                "passphrase_host1",
-                TR.passphrase__wallet,
-                description=TR.passphrase__next_screen_will_show_passphrase,
-                verb=TR.buttons__continue,
-            )
-
-            await confirm_blob(
-                "passphrase_host2",
-                TR.passphrase__title_confirm,
-                passphrase,
-                info=False,
-            )
+    if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
+        raise DataError(f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes")
 
     return passphrase
+
+
+async def _get_on_device() -> str:
+    from trezor import workflow
+    from trezor.ui.layouts import request_passphrase_on_device
+
+    workflow.close_others()  # request exclusive UI access
+    passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
+
+    return passphrase
+
+
+async def _handle_displaying_passphrase_from_host(passphrase: str) -> None:
+    from trezor import TR
+    from trezor.ui.layouts import confirm_action, confirm_blob
+
+    # We want to hide the passphrase, or show it, according to settings.
+    if storage_device.get_hide_passphrase_from_host():
+        await confirm_action(
+            "passphrase_host1_hidden",
+            TR.passphrase__wallet,
+            description=TR.passphrase__from_host_not_shown,
+            prompt_screen=True,
+            prompt_title=TR.passphrase__access_wallet,
+        )
+    else:
+        await confirm_action(
+            "passphrase_host1",
+            TR.passphrase__wallet,
+            description=TR.passphrase__next_screen_will_show_passphrase,
+            verb=TR.buttons__continue,
+        )
+
+        await confirm_blob(
+            "passphrase_host2",
+            TR.passphrase__title_confirm,
+            passphrase,
+        )
+
+
+if not utils.USE_THP:
+
+    async def get() -> str:
+        from trezor import workflow
+
+        if not is_enabled():
+            return ""
+        else:
+            workflow.close_others()  # request exclusive UI access
+            if storage_device.get_passphrase_always_on_device():
+                from trezor.ui.layouts import request_passphrase_on_device
+
+                passphrase = await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
+            else:
+                passphrase = await _request_on_host()
+            if len(passphrase.encode()) > _MAX_PASSPHRASE_LEN:
+                raise DataError(
+                    f"Maximum passphrase length is {_MAX_PASSPHRASE_LEN} bytes"
+                )
+
+            return passphrase
+
+    async def _request_on_host() -> str:
+        from trezor.messages import PassphraseAck, PassphraseRequest
+        from trezor.ui.layouts import request_passphrase_on_host
+        from trezor.wire.context import call
+
+        request_passphrase_on_host()
+
+        request = PassphraseRequest()
+        ack = await call(request, PassphraseAck)
+        passphrase = ack.passphrase  # local_cache_attribute
+
+        if ack.on_device:
+            from trezor.ui.layouts import request_passphrase_on_device
+
+            if passphrase is not None:
+                raise DataError("Passphrase provided when it should not be")
+            return await request_passphrase_on_device(_MAX_PASSPHRASE_LEN)
+
+        if passphrase is None:
+            raise DataError(
+                "Passphrase not provided and on_device is False. Use empty string to set an empty passphrase."
+            )
+
+        # non-empty passphrase
+        if passphrase:
+            await _handle_displaying_passphrase_from_host(passphrase)
+
+        return passphrase
diff --git a/core/src/apps/common/seed.py b/core/src/apps/common/seed.py
index b09004ae69..4bb15184f8 100644
--- a/core/src/apps/common/seed.py
+++ b/core/src/apps/common/seed.py
@@ -5,14 +5,18 @@ from storage.cache_common import APP_COMMON_SEED, APP_COMMON_SEED_WITHOUT_PASSPH
 from trezor import utils
 from trezor.crypto import hmac
 from trezor.wire import context
+from trezor.wire.context import get_context
+from trezor.wire.errors import DataError
 
 from apps.common import cache
 
 from . import mnemonic
-from .passphrase import get as get_passphrase
+from .passphrase import get_passphrase as get_passphrase
 
 if TYPE_CHECKING:
     from trezor.crypto import bip32
+    from trezor.messages import ThpCreateNewSession
+    from trezor.wire.protocol_common import Context
 
     from .paths import Bip32Path, Slip21Path
 
@@ -22,6 +26,9 @@ if not utils.BITCOIN_ONLY:
         APP_COMMON_DERIVE_CARDANO,
     )
 
+if not utils.USE_THP:
+    from .passphrase import get as get_passphrase_legacy
+
 
 class Slip21Node:
     """
@@ -54,51 +61,111 @@ class Slip21Node:
         return Slip21Node(data=self.data)
 
 
-if not utils.BITCOIN_ONLY:
-    # === Cardano variant ===
-    # We want to derive both the normal seed and the Cardano seed together, AND
-    # expose a method for Cardano to do the same
+if utils.USE_THP:
 
-    async def derive_and_store_roots() -> None:
-        from trezor import wire
-
-        if not storage_device.is_initialized():
-            raise wire.NotInitialized("Device is not initialized")
-
-        need_seed = not context.cache_is_set(APP_COMMON_SEED)
-        need_cardano_secret = context.cache_get_bool(
-            APP_COMMON_DERIVE_CARDANO
-        ) and not context.cache_is_set(APP_CARDANO_ICARUS_SECRET)
-
-        if not need_seed and not need_cardano_secret:
-            return
-
-        passphrase = await get_passphrase()
-
-        if need_seed:
-            common_seed = mnemonic.get_seed(passphrase)
-            context.cache_set(APP_COMMON_SEED, common_seed)
-
-        if need_cardano_secret:
-            from apps.cardano.seed import derive_and_store_secrets
-
-            derive_and_store_secrets(passphrase)
-
-    @cache.stored_async(APP_COMMON_SEED)
-    async def get_seed() -> bytes:
-        await derive_and_store_roots()
+    async def get_seed() -> bytes:  # type: ignore [Function declaration "get_seed" is obscured by a declaration of the same name]
         common_seed = context.cache_get(APP_COMMON_SEED)
         assert common_seed is not None
         return common_seed
 
-else:
-    # === Bitcoin-only variant ===
-    # We use the simple version of `get_seed` that never needs to derive anything else.
+    if utils.BITCOIN_ONLY:
+        # === Bitcoin_only variant ===
+        # We want to derive the normal seed ONLY
 
-    @cache.stored_async(APP_COMMON_SEED)
-    async def get_seed() -> bytes:
-        passphrase = await get_passphrase()
-        return mnemonic.get_seed(passphrase)
+        async def derive_and_store_roots(
+            ctx: Context, msg: ThpCreateNewSession
+        ) -> None:
+
+            if msg.passphrase is not None and msg.on_device:
+                raise DataError("Passphrase provided when it shouldn't be!")
+
+            if ctx.cache.is_set(APP_COMMON_SEED):
+                raise Exception("Seed is already set!")
+
+            from trezor import wire
+
+            if not storage_device.is_initialized():
+                raise wire.NotInitialized("Device is not initialized")
+
+            passphrase = await get_passphrase(msg)
+            common_seed = mnemonic.get_seed(passphrase)
+            ctx.cache.set(APP_COMMON_SEED, common_seed)
+
+    else:
+        # === Cardano variant ===
+        # We want to derive both the normal seed and the Cardano seed together
+        async def derive_and_store_roots(
+            ctx: Context, msg: ThpCreateNewSession
+        ) -> None:
+
+            if msg.passphrase is not None and msg.on_device:
+                raise DataError("Passphrase provided when it shouldn't be!")
+
+            from trezor import wire
+
+            if not storage_device.is_initialized():
+                raise wire.NotInitialized("Device is not initialized")
+
+            if ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET):
+                raise Exception("Cardano icarus secret is already set!")
+
+            passphrase = await get_passphrase(msg)
+            common_seed = mnemonic.get_seed(passphrase)
+            ctx.cache.set(APP_COMMON_SEED, common_seed)
+
+            if msg.derive_cardano:
+                from apps.cardano.seed import derive_and_store_secrets
+
+                ctx.cache.set_bool(APP_COMMON_DERIVE_CARDANO, True)
+                derive_and_store_secrets(ctx, passphrase)
+
+else:
+    if utils.BITCOIN_ONLY:
+        # === Bitcoin-only variant ===
+        # We use the simple version of `get_seed` that never needs to derive anything else.
+
+        @cache.stored_async(APP_COMMON_SEED)
+        async def get_seed() -> bytes:
+            passphrase = await get_passphrase_legacy()
+            return mnemonic.get_seed(passphrase=passphrase)
+
+    else:
+        # === Cardano variant ===
+        # We want to derive both the normal seed and the Cardano seed together, AND
+        # expose a method for Cardano to do the same
+
+        @cache.stored_async(APP_COMMON_SEED)
+        async def get_seed() -> bytes:
+            await derive_and_store_roots_legacy()
+            common_seed = context.cache_get(APP_COMMON_SEED)
+            assert common_seed is not None
+            return common_seed
+
+        async def derive_and_store_roots_legacy() -> None:
+            from trezor import wire
+
+            if not storage_device.is_initialized():
+                raise wire.NotInitialized("Device is not initialized")
+
+            ctx = get_context()
+            need_seed = not ctx.cache.is_set(APP_COMMON_SEED)
+            need_cardano_secret = ctx.cache.get_bool(
+                APP_COMMON_DERIVE_CARDANO
+            ) and not ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET)
+
+            if not need_seed and not need_cardano_secret:
+                return
+
+            passphrase = await get_passphrase_legacy()
+
+            if need_seed:
+                common_seed = mnemonic.get_seed(passphrase)
+                ctx.cache.set(APP_COMMON_SEED, common_seed)
+
+            if need_cardano_secret:
+                from apps.cardano.seed import derive_and_store_secrets
+
+                derive_and_store_secrets(ctx, passphrase)
 
 
 @cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE)
diff --git a/core/src/apps/debug/__init__.py b/core/src/apps/debug/__init__.py
index 6486a27e1c..71c5745eaa 100644
--- a/core/src/apps/debug/__init__.py
+++ b/core/src/apps/debug/__init__.py
@@ -1,3 +1,5 @@
+from trezor.wire import message_handler
+
 if not __debug__:
     from trezor.utils import halt
 
@@ -29,13 +31,14 @@ if __debug__:
             DebugLinkState,
         )
         from trezor.ui import Layout
-        from trezor.wire import WireInterface, context
+        from trezor.wire import WireInterface
+        from trezor.wire.protocol_common import Context
 
         Handler = Callable[[Any], Awaitable[Any]]
 
     layout_change_box = loop.mailbox()
 
-    DEBUG_CONTEXT: context.Context | None = None
+    DEBUG_CONTEXT: Context | None = None
 
     REFRESH_INDEX = 0
 
@@ -70,9 +73,7 @@ if __debug__:
                     "layout deadlock detected (did you send a ButtonAck?)"
                 )
 
-    async def return_layout_change(
-        ctx: wire.protocol_common.Context, detect_deadlock: bool = False
-    ) -> None:
+    async def return_layout_change(ctx: Context, detect_deadlock: bool = False) -> None:
         # set up the wait
         storage.layout_watcher = True
 
@@ -212,12 +213,12 @@ if __debug__:
 
         x = msg.x  # local_cache_attribute
         y = msg.y  # local_cache_attribute
-
         await wait_until_layout_is_running()
         assert isinstance(ui.CURRENT_LAYOUT, ui.Layout)
         layout_change_box.clear()
 
         try:
+
             # click on specific coordinates, with possible hold
             if x is not None and y is not None:
                 await _layout_click(x, y, msg.hold_ms or 0)
@@ -229,7 +230,11 @@ if __debug__:
             elif msg.button is not None:
                 await _layout_event(msg.button)
             elif msg.input is not None:
-                ui.CURRENT_LAYOUT._emit_message(msg.input)
+                try:
+                    ui.CURRENT_LAYOUT._emit_message(msg.input)
+                except Exception as e:
+                    print(type(e))
+
             else:
                 raise RuntimeError("Invalid DebugLinkDecision message")
 
@@ -244,7 +249,11 @@ if __debug__:
         # If no exception was raised, the layout did not shut down. That means that it
         # just updated itself. The update is already live for the caller to retrieve.
 
-    def _state() -> DebugLinkState:
+    def _state(
+        thp_pairing_code_entry_code: int | None = None,
+        thp_pairing_code_qr_code: bytes | None = None,
+        thp_pairing_code_nfc_unidirectional: bytes | None = None,
+    ) -> DebugLinkState:
         from trezor.messages import DebugLinkState
 
         from apps.common import mnemonic, passphrase
@@ -263,13 +272,45 @@ if __debug__:
             passphrase_protection=passphrase.is_enabled(),
             reset_entropy=storage.reset_internal_entropy,
             tokens=tokens,
+            thp_pairing_code_entry_code=thp_pairing_code_entry_code,
+            thp_pairing_code_qr_code=thp_pairing_code_qr_code,
+            thp_pairing_code_nfc_unidirectional=thp_pairing_code_nfc_unidirectional,
         )
 
     async def dispatch_DebugLinkGetState(
         msg: DebugLinkGetState,
     ) -> DebugLinkState | None:
+
+        thp_pairing_code_entry_code: int | None = None
+        thp_pairing_code_qr_code: bytes | None = None
+        thp_pairing_code_nfc_unidirectional: bytes | None = None
+        if utils.USE_THP and msg.thp_channel_id is not None:
+            channel_id = int.from_bytes(msg.thp_channel_id, "big")
+
+            from trezor.wire.thp.channel import Channel
+            from trezor.wire.thp.pairing_context import PairingContext
+            from trezor.wire.thp.thp_main import _CHANNELS
+
+            channel: Channel | None = None
+            ctx: PairingContext | None = None
+            try:
+                channel = _CHANNELS[channel_id]
+                ctx = channel.connection_context
+            except KeyError:
+                pass
+            if ctx is not None and isinstance(ctx, PairingContext):
+                thp_pairing_code_entry_code = ctx.display_data.code_code_entry
+                thp_pairing_code_qr_code = ctx.display_data.code_qr_code
+                thp_pairing_code_nfc_unidirectional = (
+                    ctx.display_data.code_nfc_unidirectional
+                )
+
         if msg.wait_layout == DebugWaitType.IMMEDIATE:
-            return _state()
+            return _state(
+                thp_pairing_code_entry_code,
+                thp_pairing_code_qr_code,
+                thp_pairing_code_nfc_unidirectional,
+            )
 
         assert DEBUG_CONTEXT is not None
         if msg.wait_layout == DebugWaitType.NEXT_LAYOUT:
@@ -280,7 +321,11 @@ if __debug__:
         if not layout_is_ready():
             return await return_layout_change(DEBUG_CONTEXT, detect_deadlock=True)
         else:
-            return _state()
+            return _state(
+                thp_pairing_code_entry_code,
+                thp_pairing_code_qr_code,
+                thp_pairing_code_nfc_unidirectional,
+            )
 
     async def dispatch_DebugLinkRecordScreen(msg: DebugLinkRecordScreen) -> Success:
         if msg.target_directory:
@@ -390,7 +435,6 @@ if __debug__:
                     ctx.iface.iface_num(),
                     msg_type,
                 )
-
                 if msg.type not in WORKFLOW_HANDLERS:
                     await ctx.write(wire.message_handler.unexpected_message())
                     continue
@@ -403,7 +447,7 @@ if __debug__:
                     await ctx.write(Success())
                     continue
 
-                req_msg = wire.message_handler.wrap_protobuf_load(msg.data, req_type)
+                req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
                 try:
                     res_msg = await WORKFLOW_HANDLERS[msg.type](req_msg)
                 except Exception as exc:
diff --git a/core/src/apps/management/reboot_to_bootloader.py b/core/src/apps/management/reboot_to_bootloader.py
index 85596c0268..2213d2c17a 100644
--- a/core/src/apps/management/reboot_to_bootloader.py
+++ b/core/src/apps/management/reboot_to_bootloader.py
@@ -89,7 +89,7 @@ async def reboot_to_bootloader(msg: RebootToBootloader) -> NoReturn:
         boot_args = None
 
     ctx = get_context()
-    await ctx.write(Success(message="Rebooting"))
+    await ctx.write_force(Success(message="Rebooting"))
     # make sure the outgoing USB buffer is flushed
     await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE)
     # reboot to the bootloader, pass the firmware header hash if any
diff --git a/core/src/apps/management/recovery_device/__init__.py b/core/src/apps/management/recovery_device/__init__.py
index 10ca5f6377..f722a63c7f 100644
--- a/core/src/apps/management/recovery_device/__init__.py
+++ b/core/src/apps/management/recovery_device/__init__.py
@@ -24,6 +24,7 @@ async def recovery_device(msg: RecoveryDevice) -> Success:
     from trezor import TR, config, wire, workflow
     from trezor.enums import BackupType, ButtonRequestType
     from trezor.ui.layouts import confirm_action, confirm_reset_device
+    from trezor.wire.context import try_get_ctx_ids
 
     from apps.common import mnemonic
     from apps.common.request_pin import (
@@ -69,8 +70,8 @@ async def recovery_device(msg: RecoveryDevice) -> Success:
     if recovery_type == RecoveryType.NormalRecovery:
         await confirm_reset_device(recovery=True)
 
-        # wipe storage to make sure the device is in a clear state
-        storage.reset()
+        # wipe storage to make sure the device is in a clear state (except protocol cache)
+        storage.reset(excluded=try_get_ctx_ids())
 
         # set up pin if requested
         if msg.pin_protection:
diff --git a/core/src/apps/management/recovery_device/homescreen.py b/core/src/apps/management/recovery_device/homescreen.py
index 7ad56a4742..f90acaacee 100644
--- a/core/src/apps/management/recovery_device/homescreen.py
+++ b/core/src/apps/management/recovery_device/homescreen.py
@@ -3,8 +3,9 @@ from typing import TYPE_CHECKING
 import storage.device as storage_device
 import storage.recovery as storage_recovery
 import storage.recovery_shares as storage_recovery_shares
-from trezor import TR, wire
+from trezor import TR, utils, wire
 from trezor.messages import Success
+from trezor.wire import message_handler
 
 from apps.common import backup_types
 
@@ -38,18 +39,26 @@ async def recovery_process() -> Success:
 
     recovery_type = storage_recovery.get_type()
 
-    wire.message_handler.AVOID_RESTARTING_FOR = (
-        MessageType.Initialize,
-        MessageType.GetFeatures,
-        MessageType.EndSession,
-    )
+    if utils.USE_THP:
+        message_handler.AVOID_RESTARTING_FOR = (
+            MessageType.GetFeatures,
+            MessageType.EndSession,
+        )
+    else:
+        message_handler.AVOID_RESTARTING_FOR = (
+            MessageType.Initialize,
+            MessageType.GetFeatures,
+            MessageType.EndSession,
+        )
     try:
         return await _continue_recovery_process()
     except recover.RecoveryAborted:
         storage_recovery.end_progress()
         backup.deactivate_repeated_backup()
         if recovery_type == RecoveryType.NormalRecovery:
-            storage.wipe()
+            from trezor.wire.context import try_get_ctx_ids
+
+            storage.wipe(excluded=try_get_ctx_ids())
         raise wire.ActionCancelled
 
 
@@ -59,11 +68,17 @@ async def _continue_repeated_backup() -> None:
     from apps.common import backup
     from apps.management.backup_device import perform_backup
 
-    wire.message_handler.AVOID_RESTARTING_FOR = (
-        MessageType.Initialize,
-        MessageType.GetFeatures,
-        MessageType.EndSession,
-    )
+    if utils.USE_THP:
+        message_handler.AVOID_RESTARTING_FOR = (
+            MessageType.GetFeatures,
+            MessageType.EndSession,
+        )
+    else:
+        message_handler.AVOID_RESTARTING_FOR = (
+            MessageType.Initialize,
+            MessageType.GetFeatures,
+            MessageType.EndSession,
+        )
 
     try:
         await perform_backup(is_repeated_backup=True)
diff --git a/core/src/apps/management/reset_device/__init__.py b/core/src/apps/management/reset_device/__init__.py
index 4b3d8bf2ef..4840d31a21 100644
--- a/core/src/apps/management/reset_device/__init__.py
+++ b/core/src/apps/management/reset_device/__init__.py
@@ -38,7 +38,7 @@ async def reset_device(msg: ResetDevice) -> Success:
         prompt_backup,
         show_wallet_created_success,
     )
-    from trezor.wire.context import call
+    from trezor.wire.context import call, try_get_ctx_ids
 
     from apps.common.request_pin import request_pin_confirm
 
@@ -60,8 +60,8 @@ async def reset_device(msg: ResetDevice) -> Success:
     # Rendering empty loader so users do not feel a freezing screen
     render_empty_loader(config.StorageMessage.PROCESSING_MSG)
 
-    # wipe storage to make sure the device is in a clear state
-    storage.reset()
+    # wipe storage to make sure the device is in a clear state (except protocol cache)
+    storage.reset(excluded=try_get_ctx_ids())
 
     # Check backup type, perform type-specific handling
     if backup_types.is_slip39_backup_type(backup_type):
@@ -139,7 +139,7 @@ async def reset_device(msg: ResetDevice) -> Success:
     if perform_backup:
         await layout.show_backup_success()
 
-    return Success(message="Initialized")
+    return Success(message="Initialized")  # TODO: Why "Initialized?"
 
 
 async def _entropy_check(secret: bytes) -> bool:
diff --git a/core/src/apps/management/wipe_device.py b/core/src/apps/management/wipe_device.py
index b6e60057a6..1abdc3f3e6 100644
--- a/core/src/apps/management/wipe_device.py
+++ b/core/src/apps/management/wipe_device.py
@@ -1,12 +1,19 @@
 from typing import TYPE_CHECKING
 
+from trezor.wire.context import get_context, try_get_ctx_ids
+
 if TYPE_CHECKING:
-    from trezor.messages import Success, WipeDevice
+    from typing import NoReturn
+
+    from trezor.messages import WipeDevice
+
+if __debug__:
+    from trezor import log
 
 
-async def wipe_device(msg: WipeDevice) -> Success:
+async def wipe_device(msg: WipeDevice) -> NoReturn:
     import storage
-    from trezor import TR, config, translations
+    from trezor import TR, config, loop, translations
     from trezor.enums import ButtonRequestType
     from trezor.messages import Success
     from trezor.pin import render_empty_loader
@@ -26,16 +33,22 @@ async def wipe_device(msg: WipeDevice) -> Success:
         br_code=ButtonRequestType.WipeDevice,
     )
 
+    if __debug__:
+        log.debug(__name__, "Device wipe - start")
+
     # start an empty progress screen so that the screen is not blank while waiting
     render_empty_loader(config.StorageMessage.PROCESSING_MSG)
-
     # wipe storage
-    storage.wipe()
+    storage.wipe(excluded=try_get_ctx_ids())
     # erase translations
     translations.deinit()
     translations.erase()
 
+    await get_context().write_force(Success(message="Device wiped"))
+    storage.wipe_cache()
+
     # reload settings
     reload_settings_from_storage()
-
-    return Success(message="Device wiped")
+    loop.clear()
+    if __debug__:
+        log.debug(__name__, "Device wipe - finished")
diff --git a/core/src/apps/thp/create_new_session.py b/core/src/apps/thp/create_new_session.py
new file mode 100644
index 0000000000..156b852d46
--- /dev/null
+++ b/core/src/apps/thp/create_new_session.py
@@ -0,0 +1,59 @@
+from trezor import log, loop
+from trezor.enums import FailureType
+from trezor.messages import Failure, ThpCreateNewSession, ThpNewSession
+from trezor.wire.context import get_context
+from trezor.wire.errors import ActionCancelled, DataError
+from trezor.wire.thp import SessionState
+
+
+async def create_new_session(message: ThpCreateNewSession) -> ThpNewSession | Failure:
+    """
+    Creates a new `ThpSession` based on the provided parameters and returns a
+    `ThpNewSession` message containing the new session ID.
+
+    Returns an appropriate `Failure` message if session creation fails.
+    """
+    from trezor.wire import NotInitialized
+    from trezor.wire.thp.session_context import GenericSessionContext
+    from trezor.wire.thp.session_manager import create_new_session
+
+    from apps.common.seed import derive_and_store_roots
+
+    ctx = get_context()
+
+    # Assert that context `ctx` is `GenericSessionContext`
+    assert isinstance(ctx, GenericSessionContext)
+
+    channel = ctx.channel
+
+    # Do not use `ctx` beyond this point, as it is techically
+    # allowed to change in between await statements
+
+    new_session = create_new_session(channel)
+    try:
+        await derive_and_store_roots(new_session, message)
+    except DataError as e:
+        return Failure(code=FailureType.DataError, message=e.message)
+    except ActionCancelled as e:
+        return Failure(code=FailureType.ActionCancelled, message=e.message)
+    except NotInitialized as e:
+        return Failure(code=FailureType.NotInitialized, message=e.message)
+    # TODO handle other errors (`Exception`` when "Cardano icarus secret is already set!"
+    # and `RuntimeError` when accessing storage for mnemonic.get_secret - it actually
+    # happens for locked devices)
+
+    new_session.set_session_state(SessionState.ALLOCATED)
+    channel.sessions[new_session.session_id] = new_session
+    loop.schedule(new_session.handle())
+    new_session_id: int = new_session.session_id
+
+    if __debug__:
+        log.debug(
+            __name__,
+            "create_new_session - new session created. Passphrase: %s, Session id: %d\n%s",
+            message.passphrase if message.passphrase is not None else "",
+            new_session.session_id,
+            str(channel.sessions),
+        )
+
+    return ThpNewSession(new_session_id=new_session_id)
diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py
new file mode 100644
index 0000000000..0f47e6b967
--- /dev/null
+++ b/core/src/apps/thp/pairing.py
@@ -0,0 +1,403 @@
+from typing import TYPE_CHECKING
+from ubinascii import hexlify
+
+from trezor import loop, protobuf
+from trezor.crypto.hashlib import sha256
+from trezor.enums import ThpMessageType, ThpPairingMethod
+from trezor.messages import (
+    Cancel,
+    ThpCodeEntryChallenge,
+    ThpCodeEntryCommitment,
+    ThpCodeEntryCpaceHost,
+    ThpCodeEntryCpaceTrezor,
+    ThpCodeEntrySecret,
+    ThpCodeEntryTag,
+    ThpCredentialMetadata,
+    ThpCredentialRequest,
+    ThpCredentialResponse,
+    ThpEndRequest,
+    ThpEndResponse,
+    ThpNfcUnidirectionalSecret,
+    ThpNfcUnidirectionalTag,
+    ThpPairingPreparationsFinished,
+    ThpQrCodeSecret,
+    ThpQrCodeTag,
+    ThpStartPairingRequest,
+)
+from trezor.wire.errors import ActionCancelled, SilentError, UnexpectedMessage
+from trezor.wire.thp import ChannelState, ThpError, crypto
+from trezor.wire.thp.pairing_context import PairingContext
+
+from .credential_manager import issue_credential
+
+if __debug__:
+    from trezor import log
+
+if TYPE_CHECKING:
+    from typing import Any, Callable, Concatenate, Container, ParamSpec, Tuple
+
+    P = ParamSpec("P")
+    FuncWithContext = Callable[Concatenate[PairingContext, P], Any]
+
+#
+# Helpers - decorators
+
+
+def check_state_and_log(
+    *allowed_states: ChannelState,
+) -> Callable[[FuncWithContext], FuncWithContext]:
+    def decorator(f: FuncWithContext) -> FuncWithContext:
+        def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object:
+            _check_state(context, *allowed_states)
+            if __debug__:
+                try:
+                    log.debug(__name__, "started %s", f.__name__)
+                except AttributeError:
+                    log.debug(
+                        __name__,
+                        "started a function that cannot be named, because it raises AttributeError, eg. closure",
+                    )
+            return f(context, *args, **kwargs)
+
+        return inner
+
+    return decorator
+
+
+def check_method_is_allowed(
+    pairing_method: ThpPairingMethod,
+) -> Callable[[FuncWithContext], FuncWithContext]:
+    def decorator(f: FuncWithContext) -> FuncWithContext:
+        def inner(context: PairingContext, *args: P.args, **kwargs: P.kwargs) -> object:
+            _check_method_is_allowed(context, pairing_method)
+            return f(context, *args, **kwargs)
+
+        return inner
+
+    return decorator
+
+
+#
+# Pairing handlers
+
+
+@check_state_and_log(ChannelState.TP1)
+async def handle_pairing_request(
+    ctx: PairingContext, message: protobuf.MessageType
+) -> ThpEndResponse:
+
+    if not ThpStartPairingRequest.is_type_of(message):
+        raise UnexpectedMessage("Unexpected message")
+
+    ctx.host_name = message.host_name or ""
+
+    skip_pairing = _is_method_included(ctx, ThpPairingMethod.NoMethod)
+    if skip_pairing:
+        return await _end_pairing(ctx)
+
+    await _prepare_pairing(ctx)
+    await ctx.write(ThpPairingPreparationsFinished())
+    ctx.channel_ctx.set_channel_state(ChannelState.TP3)
+    response = await show_display_data(
+        ctx, _get_possible_pairing_methods_and_cancel(ctx)
+    )
+
+    if Cancel.is_type_of(response):
+        ctx.channel_ctx.clear()
+        raise SilentError("Action was cancelled by the Host")
+    # TODO disable NFC (if enabled)
+    response = await _handle_different_pairing_methods(ctx, response)
+
+    while ThpCredentialRequest.is_type_of(response):
+        response = await _handle_credential_request(ctx, response)
+
+    return await _handle_end_request(ctx, response)
+
+
+async def _prepare_pairing(ctx: PairingContext) -> None:
+
+    if _is_method_included(ctx, ThpPairingMethod.CodeEntry):
+        await _handle_code_entry_is_included(ctx)
+
+    if _is_method_included(ctx, ThpPairingMethod.QrCode):
+        _handle_qr_code_is_included(ctx)
+
+    if _is_method_included(ctx, ThpPairingMethod.NFC_Unidirectional):
+        _handle_nfc_unidirectional_is_included(ctx)
+
+
+async def show_display_data(
+    ctx: PairingContext, expected_types: Container[int] = ()
+) -> type[protobuf.MessageType]:
+    from trezorui_api import CANCELLED
+
+    read_task = ctx.read(expected_types)
+    cancel_task = ctx.display_data.get_display_layout()
+    race = loop.race(read_task, cancel_task.get_result())
+    result: type[protobuf.MessageType] = await race
+
+    if result is CANCELLED:
+        raise ActionCancelled
+
+    return result
+
+
+@check_state_and_log(ChannelState.TP1)
+async def _handle_code_entry_is_included(ctx: PairingContext) -> None:
+    commitment = sha256(ctx.secret).digest()
+
+    challenge_message = await ctx.call(  # noqa: F841
+        ThpCodeEntryCommitment(commitment=commitment), ThpCodeEntryChallenge
+    )
+    ctx.channel_ctx.set_channel_state(ChannelState.TP2)
+
+    if not ThpCodeEntryChallenge.is_type_of(challenge_message):
+        raise UnexpectedMessage("Unexpected message")
+
+    if challenge_message.challenge is None:
+        raise Exception("Invalid message")
+    sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
+    sha_ctx.update(ctx.secret)
+    sha_ctx.update(challenge_message.challenge)
+    sha_ctx.update(bytes("PairingMethod_CodeEntry", "utf-8"))
+    code_code_entry_hash = sha_ctx.digest()
+    ctx.display_data.code_code_entry = (
+        int.from_bytes(code_code_entry_hash, "big") % 1000000
+    )
+
+
+@check_state_and_log(ChannelState.TP1, ChannelState.TP2)
+def _handle_qr_code_is_included(ctx: PairingContext) -> None:
+    sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
+    sha_ctx.update(ctx.secret)
+    sha_ctx.update(bytes("PairingMethod_QrCode", "utf-8"))
+    ctx.display_data.code_qr_code = sha_ctx.digest()[:16]
+
+
+@check_state_and_log(ChannelState.TP1, ChannelState.TP2)
+def _handle_nfc_unidirectional_is_included(ctx: PairingContext) -> None:
+    sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
+    sha_ctx.update(ctx.secret)
+    sha_ctx.update(bytes("PairingMethod_NfcUnidirectional", "utf-8"))
+    ctx.display_data.code_nfc_unidirectional = sha_ctx.digest()[:16]
+
+
+@check_state_and_log(ChannelState.TP3)
+async def _handle_different_pairing_methods(
+    ctx: PairingContext, response: protobuf.MessageType
+) -> protobuf.MessageType:
+    if ThpCodeEntryCpaceHost.is_type_of(response):
+        return await _handle_code_entry_cpace(ctx, response)
+    if ThpQrCodeTag.is_type_of(response):
+        return await _handle_qr_code_tag(ctx, response)
+    if ThpNfcUnidirectionalTag.is_type_of(response):
+        return await _handle_nfc_unidirectional_tag(ctx, response)
+    raise UnexpectedMessage("Unexpected message")
+
+
+@check_state_and_log(ChannelState.TP3)
+@check_method_is_allowed(ThpPairingMethod.CodeEntry)
+async def _handle_code_entry_cpace(
+    ctx: PairingContext, message: protobuf.MessageType
+) -> protobuf.MessageType:
+    from trezor.wire.thp.cpace import Cpace
+
+    # TODO check that ThpCodeEntryCpaceHost message is valid
+
+    if TYPE_CHECKING:
+        assert isinstance(message, ThpCodeEntryCpaceHost)
+    if message.cpace_host_public_key is None:
+        raise ThpError("Message ThpCodeEntryCpaceHost has no public key")
+
+    ctx.cpace = Cpace(
+        message.cpace_host_public_key,
+        ctx.channel_ctx.get_handshake_hash(),
+    )
+    assert ctx.display_data.code_code_entry is not None
+    ctx.cpace.generate_keys_and_secret(
+        ctx.display_data.code_code_entry.to_bytes(6, "big")
+    )
+
+    ctx.channel_ctx.set_channel_state(ChannelState.TP4)
+    response = await ctx.call(
+        ThpCodeEntryCpaceTrezor(cpace_trezor_public_key=ctx.cpace.trezor_public_key),
+        ThpCodeEntryTag,
+    )
+    return await _handle_code_entry_tag(ctx, response)
+
+
+@check_state_and_log(ChannelState.TP4)
+@check_method_is_allowed(ThpPairingMethod.CodeEntry)
+async def _handle_code_entry_tag(
+    ctx: PairingContext, message: protobuf.MessageType
+) -> protobuf.MessageType:
+
+    if TYPE_CHECKING:
+        assert isinstance(message, ThpCodeEntryTag)
+
+    expected_tag = sha256(ctx.cpace.shared_secret).digest()
+    if expected_tag != message.tag:
+        print(
+            "expected code entry tag:", hexlify(expected_tag).decode()
+        )  # TODO remove after testing
+        print(
+            "expected code entry shared secret:",
+            hexlify(ctx.cpace.shared_secret).decode(),
+        )  # TODO remove after testing
+        raise ThpError("Unexpected Code Entry Tag")
+
+    return await _handle_secret_reveal(
+        ctx,
+        msg=ThpCodeEntrySecret(secret=ctx.secret),
+    )
+
+
+@check_state_and_log(ChannelState.TP3)
+@check_method_is_allowed(ThpPairingMethod.QrCode)
+async def _handle_qr_code_tag(
+    ctx: PairingContext, message: protobuf.MessageType
+) -> protobuf.MessageType:
+    if TYPE_CHECKING:
+        assert isinstance(message, ThpQrCodeTag)
+    assert ctx.display_data.code_qr_code is not None
+    expected_tag = sha256(ctx.display_data.code_qr_code).digest()
+    if expected_tag != message.tag:
+        print(
+            "expected qr code tag:", hexlify(expected_tag).decode()
+        )  # TODO remove after testing
+        print(
+            "expected code qr code tag:",
+            hexlify(ctx.display_data.code_qr_code).decode(),
+        )  # TODO remove after testing
+        print(
+            "expected secret:", hexlify(ctx.secret).decode()
+        )  # TODO remove after testing
+        raise ThpError("Unexpected QR Code Tag")
+
+    return await _handle_secret_reveal(
+        ctx,
+        msg=ThpQrCodeSecret(secret=ctx.secret),
+    )
+
+
+@check_state_and_log(ChannelState.TP3)
+@check_method_is_allowed(ThpPairingMethod.NFC_Unidirectional)
+async def _handle_nfc_unidirectional_tag(
+    ctx: PairingContext, message: protobuf.MessageType
+) -> protobuf.MessageType:
+    if TYPE_CHECKING:
+        assert isinstance(message, ThpNfcUnidirectionalTag)
+
+    expected_tag = sha256(ctx.display_data.code_nfc_unidirectional).digest()
+    if expected_tag != message.tag:
+        print(
+            "expected nfc tag:", hexlify(expected_tag).decode()
+        )  # TODO remove after testing
+        raise ThpError("Unexpected NFC Unidirectional Tag")
+
+    return await _handle_secret_reveal(
+        ctx,
+        msg=ThpNfcUnidirectionalSecret(secret=ctx.secret),
+    )
+
+
+@check_state_and_log(ChannelState.TP3, ChannelState.TP4)
+async def _handle_secret_reveal(
+    ctx: PairingContext,
+    msg: protobuf.MessageType,
+) -> protobuf.MessageType:
+    ctx.channel_ctx.set_channel_state(ChannelState.TC1)
+    return await ctx.call_any(
+        msg,
+        ThpMessageType.ThpCredentialRequest,
+        ThpMessageType.ThpEndRequest,
+    )
+
+
+@check_state_and_log(ChannelState.TC1)
+async def _handle_credential_request(
+    ctx: PairingContext, message: protobuf.MessageType
+) -> protobuf.MessageType:
+    ctx.secret
+
+    if not ThpCredentialRequest.is_type_of(message):
+        raise UnexpectedMessage("Unexpected message")
+    if message.host_static_pubkey is None:
+        raise Exception("Invalid message")  # TODO change failure type
+
+    trezor_static_pubkey = crypto.get_trezor_static_pubkey()
+    credential_metadata = ThpCredentialMetadata(host_name=ctx.host_name)
+    credential = issue_credential(message.host_static_pubkey, credential_metadata)
+
+    return await ctx.call_any(
+        ThpCredentialResponse(
+            trezor_static_pubkey=trezor_static_pubkey, credential=credential
+        ),
+        ThpMessageType.ThpCredentialRequest,
+        ThpMessageType.ThpEndRequest,
+    )
+
+
+@check_state_and_log(ChannelState.TC1)
+async def _handle_end_request(
+    ctx: PairingContext, message: protobuf.MessageType
+) -> ThpEndResponse:
+    if not ThpEndRequest.is_type_of(message):
+        raise UnexpectedMessage("Unexpected message")
+    return await _end_pairing(ctx)
+
+
+async def _end_pairing(ctx: PairingContext) -> ThpEndResponse:
+    ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
+    return ThpEndResponse()
+
+
+#
+# Helpers - checkers
+
+
+def _check_state(ctx: PairingContext, *allowed_states: ChannelState) -> None:
+    if ctx.channel_ctx.get_channel_state() not in allowed_states:
+        raise UnexpectedMessage("Unexpected message")
+
+
+def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> None:
+    if not _is_method_included(ctx, method):
+        raise ThpError("Unexpected pairing method")
+
+
+def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool:
+    return method in ctx.channel_ctx.selected_pairing_methods
+
+
+#
+# Helpers - getters
+
+
+def _get_possible_pairing_methods_and_cancel(ctx: PairingContext) -> Tuple[int, ...]:
+    r = _get_possible_pairing_methods(ctx)
+    mtype = Cancel.MESSAGE_WIRE_TYPE
+    return r + ((mtype,) if mtype is not None else ())
+
+
+def _get_possible_pairing_methods(ctx: PairingContext) -> Tuple[int, ...]:
+    r = tuple(
+        _get_message_type_for_method(method)
+        for method in ctx.channel_ctx.selected_pairing_methods
+    )
+    if __debug__:
+        from trezor.messages import DebugLinkGetState
+
+        mtype = DebugLinkGetState.MESSAGE_WIRE_TYPE
+        return r + ((mtype,) if mtype is not None else ())
+    return r
+
+
+def _get_message_type_for_method(method: int) -> int:
+    if method is ThpPairingMethod.CodeEntry:
+        return ThpMessageType.ThpCodeEntryCpaceHost
+    if method is ThpPairingMethod.NFC_Unidirectional:
+        return ThpMessageType.ThpNfcUnidirectionalTag
+    if method is ThpPairingMethod.QrCode:
+        return ThpMessageType.ThpQrCodeTag
+    raise ValueError("Unexpected pairing method - no message type available")
diff --git a/core/src/apps/workflow_handlers.py b/core/src/apps/workflow_handlers.py
index b65c853c93..3013516382 100644
--- a/core/src/apps/workflow_handlers.py
+++ b/core/src/apps/workflow_handlers.py
@@ -35,6 +35,13 @@ def _find_message_handler_module(msg_type: int) -> str:
     if __debug__ and msg_type == MessageType.BenchmarkRun:
         return "apps.benchmark.run"
 
+    if utils.USE_THP:
+        from trezor.enums import ThpMessageType
+
+        # thp management
+        if msg_type == ThpMessageType.ThpCreateNewSession:
+            return "apps.thp.create_new_session"
+
     # management
     if msg_type == MessageType.ResetDevice:
         return "apps.management.reset_device"
diff --git a/core/src/storage/__init__.py b/core/src/storage/__init__.py
index 3a012874f3..2fe2c845d9 100644
--- a/core/src/storage/__init__.py
+++ b/core/src/storage/__init__.py
@@ -1,11 +1,27 @@
 # make sure to import cache unconditionally at top level so that it is imported (and retained) together with the storage module
+from typing import TYPE_CHECKING
+
 from storage import cache, common, device
 
+if TYPE_CHECKING:
+    from typing import Tuple
 
-def wipe() -> None:
+    pass
+
+
+def wipe(excluded: Tuple[bytes, bytes] | None) -> None:
+    """
+    TODO REPHRASE SO THAT IT IS TRUE! Wipes the storage. Using `exclude_protocol=False` destroys the THP communication channel.
+    If the device should communicate after wipe, use `exclude_protocol=True` and clear cache manually later using
+    `wipe_cache()`.
+    """
     from trezor import config
 
     config.wipe()
+    cache.clear_all(excluded)
+
+
+def wipe_cache() -> None:
     cache.clear_all()
 
 
@@ -21,12 +37,12 @@ def init_unlocked() -> None:
         common.set_bool(common.APP_DEVICE, device.INITIALIZED, True, public=True)
 
 
-def reset() -> None:
+def reset(excluded: Tuple[bytes, bytes] | None) -> None:
     """
     Wipes storage but keeps the device id unchanged.
     """
     device_id = device.get_device_id()
-    wipe()
+    wipe(excluded)
     common.set(common.APP_DEVICE, device.DEVICE_ID, device_id.encode(), public=True)
 
 
diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py
index 72d8a1e418..6db224a782 100644
--- a/core/src/storage/cache.py
+++ b/core/src/storage/cache.py
@@ -1,26 +1,47 @@
 import builtins
 import gc
+from typing import TYPE_CHECKING
 
-from storage import cache_codec
 from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache
+from trezor import utils
+
+if TYPE_CHECKING:
+    from typing import Tuple
+
+    pass
 
 # Cache initialization
 _SESSIONLESS_CACHE = SessionlessCache()
-_PROTOCOL_CACHE = cache_codec
+
+
+if utils.USE_THP:
+    from storage import cache_thp
+
+    _PROTOCOL_CACHE = cache_thp
+else:
+    from storage import cache_codec
+
+    _PROTOCOL_CACHE = cache_codec
+
 _PROTOCOL_CACHE.initialize()
 _SESSIONLESS_CACHE.clear()
 
 gc.collect()
 
 
-def clear_all() -> None:
+def clear_all(excluded: Tuple[bytes, bytes] | None = None) -> None:
     """
     Clears all data from both the protocol cache and the sessionless cache.
     """
     global autolock_last_touch
     autolock_last_touch = None
     _SESSIONLESS_CACHE.clear()
-    _PROTOCOL_CACHE.clear_all()
+
+    if utils.USE_THP and excluded is not None:
+        # If we want to keep THP connection alive, we do not clear communication keys
+        cache_thp.clear_all_except_one_session_keys(excluded)
+    else:
+        _PROTOCOL_CACHE.clear_all()
 
 
 def get_int_all_sessions(key: int) -> builtins.set[int]:
diff --git a/core/src/storage/cache_common.py b/core/src/storage/cache_common.py
index 90cead81db..40eee905cc 100644
--- a/core/src/storage/cache_common.py
+++ b/core/src/storage/cache_common.py
@@ -14,6 +14,14 @@ if not utils.BITCOIN_ONLY:
     APP_CARDANO_ICARUS_TREZOR_SECRET = const(6)
     APP_MONERO_LIVE_REFRESH = const(7)
 
+# Cache keys for THP channel
+if utils.USE_THP:
+    CHANNEL_HANDSHAKE_HASH = const(0)
+    CHANNEL_KEY_RECEIVE = const(1)
+    CHANNEL_KEY_SEND = const(2)
+    CHANNEL_NONCE_RECEIVE = const(3)
+    CHANNEL_NONCE_SEND = const(4)
+
 # Keys that are valid across sessions
 SESSIONLESS_FLAG = const(128)
 APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG)
diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py
new file mode 100644
index 0000000000..6ed41b8415
--- /dev/null
+++ b/core/src/storage/cache_thp.py
@@ -0,0 +1,363 @@
+import builtins
+from micropython import const
+from typing import TYPE_CHECKING
+
+from storage.cache_common import DataCache
+
+if TYPE_CHECKING:
+    from typing import Tuple
+
+    pass
+
+
+# THP specific constants
+_MAX_CHANNELS_COUNT = const(10)
+_MAX_SESSIONS_COUNT = const(20)
+
+
+_CHANNEL_STATE_LENGTH = const(1)
+_WIRE_INTERFACE_LENGTH = const(1)
+_SESSION_STATE_LENGTH = const(1)
+_CHANNEL_ID_LENGTH = const(2)
+SESSION_ID_LENGTH = const(1)
+BROADCAST_CHANNEL_ID = const(0xFFFF)
+KEY_LENGTH = const(32)
+TAG_LENGTH = const(16)
+_UNALLOCATED_STATE = const(0)
+_MANAGEMENT_STATE = const(2)
+MANAGEMENT_SESSION_ID = const(0)
+
+
+class ThpDataCache(DataCache):
+    def __init__(self) -> None:
+        self.channel_id = bytearray(_CHANNEL_ID_LENGTH)
+        self.last_usage = 0
+        super().__init__()
+
+    def clear(self) -> None:
+        self.channel_id[:] = b""
+        self.last_usage = 0
+        super().clear()
+
+
+class ChannelCache(ThpDataCache):
+    def __init__(self) -> None:
+        self.host_ephemeral_pubkey = bytearray(KEY_LENGTH)
+        self.state = bytearray(_CHANNEL_STATE_LENGTH)
+        self.iface = bytearray(1)  # TODO add decoding
+        self.sync = 0x80  # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5)
+        self.session_id_counter = 0x00
+        self.fields = (
+            32,  # CHANNEL_HANDSHAKE_HASH
+            32,  # CHANNEL_KEY_RECEIVE
+            32,  # CHANNEL_KEY_SEND
+            8,  # CHANNEL_NONCE_RECEIVE
+            8,  # CHANNEL_NONCE_SEND
+        )
+        super().__init__()
+
+    def clear(self) -> None:
+        self.state[:] = bytearray(
+            int.to_bytes(0, _CHANNEL_STATE_LENGTH, "big")
+        )  # Set state to UNALLOCATED
+        self.host_ephemeral_pubkey[:] = bytearray(KEY_LENGTH)
+        self.state[:] = bytearray(_CHANNEL_STATE_LENGTH)
+        self.iface[:] = bytearray(1)
+        super().clear()
+
+
+class SessionThpCache(ThpDataCache):
+    def __init__(self) -> None:
+        from trezor import utils
+
+        self.session_id = bytearray(SESSION_ID_LENGTH)
+        self.state = bytearray(_SESSION_STATE_LENGTH)
+        if utils.BITCOIN_ONLY:
+            self.fields = (
+                64,  # APP_COMMON_SEED
+                2,  # APP_COMMON_AUTHORIZATION_TYPE
+                128,  # APP_COMMON_AUTHORIZATION_DATA
+                32,  # APP_COMMON_NONCE
+            )
+        else:
+            self.fields = (
+                64,  # APP_COMMON_SEED
+                2,  # APP_COMMON_AUTHORIZATION_TYPE
+                128,  # APP_COMMON_AUTHORIZATION_DATA
+                32,  # APP_COMMON_NONCE
+                0,  # APP_COMMON_DERIVE_CARDANO
+                96,  # APP_CARDANO_ICARUS_SECRET
+                96,  # APP_CARDANO_ICARUS_TREZOR_SECRET
+                0,  # APP_MONERO_LIVE_REFRESH
+            )
+        super().__init__()
+
+    def clear(self) -> None:
+        self.state[:] = bytearray(int.to_bytes(0, 1, "big"))  # Set state to UNALLOCATED
+        self.session_id[:] = b""
+        super().clear()
+
+
+_CHANNELS: list[ChannelCache] = []
+_SESSIONS: list[SessionThpCache] = []
+cid_counter: int = 0
+
+# Last-used counter
+_usage_counter = 0
+
+
+def initialize() -> None:
+    global _CHANNELS
+    global _SESSIONS
+    global cid_counter
+
+    for _ in range(_MAX_CHANNELS_COUNT):
+        _CHANNELS.append(ChannelCache())
+    for _ in range(_MAX_SESSIONS_COUNT):
+        _SESSIONS.append(SessionThpCache())
+
+    for channel in _CHANNELS:
+        channel.clear()
+    for session in _SESSIONS:
+        session.clear()
+
+    from trezorcrypto import random
+
+    cid_counter = random.uniform(0xFFFE)
+
+
+def get_new_channel(iface: bytes) -> ChannelCache:
+    if len(iface) != _WIRE_INTERFACE_LENGTH:
+        raise Exception("Invalid WireInterface (encoded) length")
+
+    new_cid = get_next_channel_id()
+    index = _get_next_channel_index()
+
+    # clear sessions from replaced channel
+    if _get_channel_state(_CHANNELS[index]) != _UNALLOCATED_STATE:
+        old_cid = _CHANNELS[index].channel_id
+        clear_sessions_with_channel_id(old_cid)
+
+    _CHANNELS[index] = ChannelCache()
+    _CHANNELS[index].channel_id[:] = new_cid
+    _CHANNELS[index].last_usage = _get_usage_counter_and_increment()
+    _CHANNELS[index].state[:] = bytearray(
+        _UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big")
+    )
+    _CHANNELS[index].iface[:] = bytearray(iface)
+    return _CHANNELS[index]
+
+
+def update_channel_last_used(channel_id: bytes) -> None:
+    for channel in _CHANNELS:
+        if channel.channel_id == channel_id:
+            channel.last_usage = _get_usage_counter_and_increment()
+            return
+
+
+def update_session_last_used(channel_id: bytes, session_id: bytes) -> None:
+    for session in _SESSIONS:
+        if session.channel_id == channel_id and session.session_id == session_id:
+            session.last_usage = _get_usage_counter_and_increment()
+            update_channel_last_used(channel_id)
+            return
+
+
+def get_all_allocated_channels() -> list[ChannelCache]:
+    _list: list[ChannelCache] = []
+    for channel in _CHANNELS:
+        if _get_channel_state(channel) != _UNALLOCATED_STATE:
+            _list.append(channel)
+    return _list
+
+
+def get_allocated_session(
+    channel_id: bytes, session_id: bytes
+) -> SessionThpCache | None:
+    """
+    Finds and returns the first allocated session matching the given `channel_id` and `session_id`,
+    or `None` if no match is found.
+
+    Raises `Exception` if either channel_id or session_id has an invalid length.
+    """
+    if len(channel_id) != _CHANNEL_ID_LENGTH or len(session_id) != SESSION_ID_LENGTH:
+        raise Exception("At least one of arguments has invalid length")
+
+    for session in _SESSIONS:
+        if _get_session_state(session) == _UNALLOCATED_STATE:
+            continue
+        if session.channel_id != channel_id:
+            continue
+        if session.session_id != session_id:
+            continue
+        return session
+    return None
+
+
+def is_management_session(session_cache: SessionThpCache) -> bool:
+    return _get_session_state(session_cache) == _MANAGEMENT_STATE
+
+
+def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> None:
+    if len(key) != KEY_LENGTH:
+        raise Exception("Invalid key length")
+    channel.host_ephemeral_pubkey = key
+
+
+def get_new_session(channel: ChannelCache) -> SessionThpCache:
+    new_sid = get_next_session_id(channel)
+    index = _get_next_session_index()
+
+    _SESSIONS[index] = SessionThpCache()
+    _SESSIONS[index].channel_id[:] = channel.channel_id
+    _SESSIONS[index].session_id[:] = new_sid
+    _SESSIONS[index].last_usage = _get_usage_counter_and_increment()
+    channel.last_usage = (
+        _get_usage_counter_and_increment()
+    )  # increment also use of the channel so it does not get replaced
+    _SESSIONS[index].state[:] = bytearray(
+        _UNALLOCATED_STATE.to_bytes(_SESSION_STATE_LENGTH, "big")
+    )
+    return _SESSIONS[index]
+
+
+def _get_usage_counter_and_increment() -> int:
+    global _usage_counter
+    _usage_counter += 1
+    return _usage_counter
+
+
+def _get_next_channel_index() -> int:
+    idx = _get_unallocated_channel_index()
+    if idx is not None:
+        return idx
+    return get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT)
+
+
+def _get_next_session_index() -> int:
+    idx = _get_unallocated_session_index()
+    if idx is not None:
+        return idx
+    return get_least_recently_used_item(_SESSIONS, max_count=_MAX_SESSIONS_COUNT)
+
+
+def _get_unallocated_channel_index() -> int | None:
+    for i in range(_MAX_CHANNELS_COUNT):
+        if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE:
+            return i
+    return None
+
+
+def _get_unallocated_session_index() -> int | None:
+    for i in range(_MAX_SESSIONS_COUNT):
+        if (_SESSIONS[i]) is _UNALLOCATED_STATE:
+            return i
+    return None
+
+
+def _get_channel_state(channel: ChannelCache) -> int:
+    return int.from_bytes(channel.state, "big")
+
+
+def _get_session_state(session: SessionThpCache) -> int:
+    return int.from_bytes(session.state, "big")
+
+
+def get_next_channel_id() -> bytes:
+    global cid_counter
+    while True:
+        cid_counter += 1
+        if cid_counter >= BROADCAST_CHANNEL_ID:
+            cid_counter = 1
+        if _is_cid_unique():
+            break
+    return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big")
+
+
+def get_next_session_id(channel: ChannelCache) -> bytes:
+    while True:
+        if channel.session_id_counter >= 255:
+            channel.session_id_counter = 1
+        else:
+            channel.session_id_counter += 1
+        if _is_session_id_unique(channel):
+            break
+    new_sid = channel.session_id_counter
+    return new_sid.to_bytes(SESSION_ID_LENGTH, "big")
+
+
+def _is_session_id_unique(channel: ChannelCache) -> bool:
+    for session in _SESSIONS:
+        if session.channel_id == channel.channel_id:
+            if session.session_id == channel.session_id_counter:
+                return False
+    return True
+
+
+def _is_cid_unique() -> bool:
+    global cid_counter
+    cid_counter_bytes = cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big")
+    for channel in _CHANNELS:
+        if channel.channel_id == cid_counter_bytes:
+            return False
+    return True
+
+
+def get_least_recently_used_item(
+    list: list[ChannelCache] | list[SessionThpCache], max_count: int
+) -> int:
+    global _usage_counter
+    lru_counter = _usage_counter + 1
+    lru_item_index = 0
+    for i in range(max_count):
+        if list[i].last_usage < lru_counter:
+            lru_counter = list[i].last_usage
+            lru_item_index = i
+    return lru_item_index
+
+
+def get_int_all_sessions(key: int) -> builtins.set[int]:
+    values = builtins.set()
+    for session in _SESSIONS:
+        encoded = session.get(key)
+        if encoded is not None:
+            values.add(int.from_bytes(encoded, "big"))
+    return values
+
+
+def clear_sessions_with_channel_id(channel_id: bytes) -> None:
+    for session in _SESSIONS:
+        if session.channel_id == channel_id:
+            session.clear()
+
+
+def clear_session(session: SessionThpCache) -> None:
+    for s in _SESSIONS:
+        if s.channel_id == session.channel_id and s.session_id == session.session_id:
+            session.clear()
+
+
+def clear_all() -> None:
+    for session in _SESSIONS:
+        session.clear()
+    for channel in _CHANNELS:
+        channel.clear()
+
+
+def clear_all_except_one_session_keys(excluded: Tuple[bytes, bytes]) -> None:
+    cid, sid = excluded
+
+    for channel in _CHANNELS:
+        if channel.channel_id != cid:
+            channel.clear()
+
+    for session in _SESSIONS:
+        if session.channel_id != cid and session.session_id != sid:
+            session.clear()
+        else:
+            s_last_usage = session.last_usage
+            session.clear()
+            session.last_usage = s_last_usage
+            session.state = bytearray(_MANAGEMENT_STATE.to_bytes(1, "big"))
+            session.session_id[:] = bytearray(sid)
+            session.channel_id[:] = bytearray(cid)
diff --git a/core/src/trezor/enums/FailureType.py b/core/src/trezor/enums/FailureType.py
index fbb2001e54..883844307a 100644
--- a/core/src/trezor/enums/FailureType.py
+++ b/core/src/trezor/enums/FailureType.py
@@ -16,4 +16,6 @@ NotInitialized = 11
 PinMismatch = 12
 WipeCodeMismatch = 13
 InvalidSession = 14
+ThpUnallocatedSession = 15
+InvalidProtocol = 16
 FirmwareError = 99
diff --git a/core/src/trezor/enums/ThpMessageType.py b/core/src/trezor/enums/ThpMessageType.py
new file mode 100644
index 0000000000..45a34120e5
--- /dev/null
+++ b/core/src/trezor/enums/ThpMessageType.py
@@ -0,0 +1,22 @@
+# Automatically generated by pb2py
+# fmt: off
+# isort:skip_file
+
+ThpCreateNewSession = 1000
+ThpNewSession = 1001
+ThpStartPairingRequest = 1008
+ThpPairingPreparationsFinished = 1009
+ThpCredentialRequest = 1010
+ThpCredentialResponse = 1011
+ThpEndRequest = 1012
+ThpEndResponse = 1013
+ThpCodeEntryCommitment = 1016
+ThpCodeEntryChallenge = 1017
+ThpCodeEntryCpaceHost = 1018
+ThpCodeEntryCpaceTrezor = 1019
+ThpCodeEntryTag = 1020
+ThpCodeEntrySecret = 1021
+ThpQrCodeTag = 1024
+ThpQrCodeSecret = 1025
+ThpNfcUnidirectionalTag = 1032
+ThpNfcUnidirectionalSecret = 1033
diff --git a/core/src/trezor/enums/ThpPairingMethod.py b/core/src/trezor/enums/ThpPairingMethod.py
new file mode 100644
index 0000000000..b356cdf470
--- /dev/null
+++ b/core/src/trezor/enums/ThpPairingMethod.py
@@ -0,0 +1,8 @@
+# Automatically generated by pb2py
+# fmt: off
+# isort:skip_file
+
+NoMethod = 1
+CodeEntry = 2
+QrCode = 3
+NFC_Unidirectional = 4
diff --git a/core/src/trezor/enums/__init__.py b/core/src/trezor/enums/__init__.py
index d16c3c4a66..d8421393b8 100644
--- a/core/src/trezor/enums/__init__.py
+++ b/core/src/trezor/enums/__init__.py
@@ -39,6 +39,8 @@ if TYPE_CHECKING:
         PinMismatch = 12
         WipeCodeMismatch = 13
         InvalidSession = 14
+        ThpUnallocatedSession = 15
+        InvalidProtocol = 16
         FirmwareError = 99
 
     class ButtonRequestType(IntEnum):
@@ -347,6 +349,32 @@ if TYPE_CHECKING:
         Nay = 1
         Pass = 2
 
+    class ThpMessageType(IntEnum):
+        ThpCreateNewSession = 1000
+        ThpNewSession = 1001
+        ThpStartPairingRequest = 1008
+        ThpPairingPreparationsFinished = 1009
+        ThpCredentialRequest = 1010
+        ThpCredentialResponse = 1011
+        ThpEndRequest = 1012
+        ThpEndResponse = 1013
+        ThpCodeEntryCommitment = 1016
+        ThpCodeEntryChallenge = 1017
+        ThpCodeEntryCpaceHost = 1018
+        ThpCodeEntryCpaceTrezor = 1019
+        ThpCodeEntryTag = 1020
+        ThpCodeEntrySecret = 1021
+        ThpQrCodeTag = 1024
+        ThpQrCodeSecret = 1025
+        ThpNfcUnidirectionalTag = 1032
+        ThpNfcUnidirectionalSecret = 1033
+
+    class ThpPairingMethod(IntEnum):
+        NoMethod = 1
+        CodeEntry = 2
+        QrCode = 3
+        NFC_Unidirectional = 4
+
     class MessageType(IntEnum):
         Initialize = 0
         Ping = 1
diff --git a/core/src/trezor/messages.py b/core/src/trezor/messages.py
index e529707f4b..0c64aae3c0 100644
--- a/core/src/trezor/messages.py
+++ b/core/src/trezor/messages.py
@@ -68,6 +68,8 @@ if TYPE_CHECKING:
     from trezor.enums import StellarSignerType  # noqa: F401
     from trezor.enums import TezosBallotType  # noqa: F401
     from trezor.enums import TezosContractType  # noqa: F401
+    from trezor.enums import ThpMessageType  # noqa: F401
+    from trezor.enums import ThpPairingMethod  # noqa: F401
     from trezor.enums import WordRequestType  # noqa: F401
 
     class BenchmarkListNames(protobuf.MessageType):
@@ -2898,11 +2900,13 @@ if TYPE_CHECKING:
 
     class DebugLinkGetState(protobuf.MessageType):
         wait_layout: "DebugWaitType"
+        thp_channel_id: "bytes | None"
 
         def __init__(
             self,
             *,
             wait_layout: "DebugWaitType | None" = None,
+            thp_channel_id: "bytes | None" = None,
         ) -> None:
             pass
 
@@ -2924,6 +2928,9 @@ if TYPE_CHECKING:
         reset_word_pos: "int | None"
         mnemonic_type: "BackupType | None"
         tokens: "list[str]"
+        thp_pairing_code_entry_code: "int | None"
+        thp_pairing_code_qr_code: "bytes | None"
+        thp_pairing_code_nfc_unidirectional: "bytes | None"
 
         def __init__(
             self,
@@ -2941,6 +2948,9 @@ if TYPE_CHECKING:
             recovery_word_pos: "int | None" = None,
             reset_word_pos: "int | None" = None,
             mnemonic_type: "BackupType | None" = None,
+            thp_pairing_code_entry_code: "int | None" = None,
+            thp_pairing_code_qr_code: "bytes | None" = None,
+            thp_pairing_code_nfc_unidirectional: "bytes | None" = None,
         ) -> None:
             pass
 
@@ -6162,6 +6172,278 @@ if TYPE_CHECKING:
         def is_type_of(cls, msg: Any) -> TypeGuard["TezosManagerTransfer"]:
             return isinstance(msg, cls)
 
+    class ThpDeviceProperties(protobuf.MessageType):
+        internal_model: "str | None"
+        model_variant: "int | None"
+        bootloader_mode: "bool | None"
+        protocol_version: "int | None"
+        pairing_methods: "list[ThpPairingMethod]"
+
+        def __init__(
+            self,
+            *,
+            pairing_methods: "list[ThpPairingMethod] | None" = None,
+            internal_model: "str | None" = None,
+            model_variant: "int | None" = None,
+            bootloader_mode: "bool | None" = None,
+            protocol_version: "int | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpDeviceProperties"]:
+            return isinstance(msg, cls)
+
+    class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType):
+        host_pairing_credential: "bytes | None"
+        pairing_methods: "list[ThpPairingMethod]"
+
+        def __init__(
+            self,
+            *,
+            pairing_methods: "list[ThpPairingMethod] | None" = None,
+            host_pairing_credential: "bytes | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpHandshakeCompletionReqNoisePayload"]:
+            return isinstance(msg, cls)
+
+    class ThpCreateNewSession(protobuf.MessageType):
+        passphrase: "str | None"
+        on_device: "bool | None"
+        derive_cardano: "bool | None"
+
+        def __init__(
+            self,
+            *,
+            passphrase: "str | None" = None,
+            on_device: "bool | None" = None,
+            derive_cardano: "bool | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpCreateNewSession"]:
+            return isinstance(msg, cls)
+
+    class ThpNewSession(protobuf.MessageType):
+        new_session_id: "int | None"
+
+        def __init__(
+            self,
+            *,
+            new_session_id: "int | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpNewSession"]:
+            return isinstance(msg, cls)
+
+    class ThpStartPairingRequest(protobuf.MessageType):
+        host_name: "str | None"
+
+        def __init__(
+            self,
+            *,
+            host_name: "str | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpStartPairingRequest"]:
+            return isinstance(msg, cls)
+
+    class ThpPairingPreparationsFinished(protobuf.MessageType):
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpPairingPreparationsFinished"]:
+            return isinstance(msg, cls)
+
+    class ThpCodeEntryCommitment(protobuf.MessageType):
+        commitment: "bytes | None"
+
+        def __init__(
+            self,
+            *,
+            commitment: "bytes | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCommitment"]:
+            return isinstance(msg, cls)
+
+    class ThpCodeEntryChallenge(protobuf.MessageType):
+        challenge: "bytes | None"
+
+        def __init__(
+            self,
+            *,
+            challenge: "bytes | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryChallenge"]:
+            return isinstance(msg, cls)
+
+    class ThpCodeEntryCpaceHost(protobuf.MessageType):
+        cpace_host_public_key: "bytes | None"
+
+        def __init__(
+            self,
+            *,
+            cpace_host_public_key: "bytes | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceHost"]:
+            return isinstance(msg, cls)
+
+    class ThpCodeEntryCpaceTrezor(protobuf.MessageType):
+        cpace_trezor_public_key: "bytes | None"
+
+        def __init__(
+            self,
+            *,
+            cpace_trezor_public_key: "bytes | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryCpaceTrezor"]:
+            return isinstance(msg, cls)
+
+    class ThpCodeEntryTag(protobuf.MessageType):
+        tag: "bytes | None"
+
+        def __init__(
+            self,
+            *,
+            tag: "bytes | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntryTag"]:
+            return isinstance(msg, cls)
+
+    class ThpCodeEntrySecret(protobuf.MessageType):
+        secret: "bytes | None"
+
+        def __init__(
+            self,
+            *,
+            secret: "bytes | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpCodeEntrySecret"]:
+            return isinstance(msg, cls)
+
+    class ThpQrCodeTag(protobuf.MessageType):
+        tag: "bytes | None"
+
+        def __init__(
+            self,
+            *,
+            tag: "bytes | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeTag"]:
+            return isinstance(msg, cls)
+
+    class ThpQrCodeSecret(protobuf.MessageType):
+        secret: "bytes | None"
+
+        def __init__(
+            self,
+            *,
+            secret: "bytes | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpQrCodeSecret"]:
+            return isinstance(msg, cls)
+
+    class ThpNfcUnidirectionalTag(protobuf.MessageType):
+        tag: "bytes | None"
+
+        def __init__(
+            self,
+            *,
+            tag: "bytes | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcUnidirectionalTag"]:
+            return isinstance(msg, cls)
+
+    class ThpNfcUnidirectionalSecret(protobuf.MessageType):
+        secret: "bytes | None"
+
+        def __init__(
+            self,
+            *,
+            secret: "bytes | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpNfcUnidirectionalSecret"]:
+            return isinstance(msg, cls)
+
+    class ThpCredentialRequest(protobuf.MessageType):
+        host_static_pubkey: "bytes | None"
+
+        def __init__(
+            self,
+            *,
+            host_static_pubkey: "bytes | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialRequest"]:
+            return isinstance(msg, cls)
+
+    class ThpCredentialResponse(protobuf.MessageType):
+        trezor_static_pubkey: "bytes | None"
+        credential: "bytes | None"
+
+        def __init__(
+            self,
+            *,
+            trezor_static_pubkey: "bytes | None" = None,
+            credential: "bytes | None" = None,
+        ) -> None:
+            pass
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpCredentialResponse"]:
+            return isinstance(msg, cls)
+
+    class ThpEndRequest(protobuf.MessageType):
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndRequest"]:
+            return isinstance(msg, cls)
+
+    class ThpEndResponse(protobuf.MessageType):
+
+        @classmethod
+        def is_type_of(cls, msg: Any) -> TypeGuard["ThpEndResponse"]:
+            return isinstance(msg, cls)
+
     class ThpCredentialMetadata(protobuf.MessageType):
         host_name: "str | None"
 
diff --git a/core/src/trezor/utils.py b/core/src/trezor/utils.py
index 3fcb47af35..30d379fd7f 100644
--- a/core/src/trezor/utils.py
+++ b/core/src/trezor/utils.py
@@ -34,6 +34,10 @@ from trezorutils import (  # noqa: F401
 )
 from typing import TYPE_CHECKING
 
+DISABLE_ENCRYPTION: bool = False
+
+ALLOW_DEBUG_MESSAGES: bool = True
+
 if __debug__:
     if EMULATOR:
         import uos
diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py
index 2662a5610a..1bc847f273 100644
--- a/core/src/trezor/wire/__init__.py
+++ b/core/src/trezor/wire/__init__.py
@@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is:
 
 - Request / response.
 - Protobuf-encoded, see `protobuf.py`.
-- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py`.
+- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py` or `trezor/wire/thp/thp_main.py`.
 - Transferred over USB interface, or UDP in case of Unix emulation.
 
 This module:
@@ -29,7 +29,12 @@ from typing import TYPE_CHECKING
 from trezor import log, loop, protobuf, utils
 
 from . import message_handler, protocol_common
-from .codec.codec_context import CodecContext
+
+if utils.USE_THP:
+    from .thp import thp_main
+else:
+    from .codec.codec_context import CodecContext
+
 from .context import UnexpectedMessageException
 from .message_handler import failure
 
@@ -40,6 +45,8 @@ from .errors import *  # isort:skip # noqa: F401,F403
 _PROTOBUF_BUFFER_SIZE = const(8192)
 
 WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
+if utils.USE_THP:
+    WIRE_BUFFER_2 = bytearray(_PROTOBUF_BUFFER_SIZE)
 
 if TYPE_CHECKING:
     from trezorio import WireInterface
@@ -57,57 +64,89 @@ def setup(iface: WireInterface) -> None:
     loop.schedule(handle_session(iface))
 
 
-async def handle_session(iface: WireInterface) -> None:
-    ctx = CodecContext(iface, WIRE_BUFFER)
-    next_msg: protocol_common.Message | None = None
+if utils.USE_THP:
 
-    # Take a mark of modules that are imported at this point, so we can
-    # roll back and un-import any others.
-    modules = utils.unimport_begin()
-    while True:
-        try:
-            if next_msg is None:
-                # If the previous run did not keep an unprocessed message for us,
-                # wait for a new one coming from the wire.
-                try:
-                    msg = await ctx.read_from_wire()
-                except protocol_common.WireError as exc:
-                    if __debug__:
-                        log.exception(__name__, exc)
-                    await ctx.write(failure(exc))
-                    continue
+    async def handle_session(iface: WireInterface) -> None:
 
-            else:
-                # Process the message from previous run.
-                msg = next_msg
-                next_msg = None
+        thp_main.set_read_buffer(WIRE_BUFFER)
+        thp_main.set_write_buffer(WIRE_BUFFER_2)
 
-            do_not_restart = False
+        # Take a mark of modules that are imported at this point, so we can
+        # roll back and un-import any others.
+        modules = utils.unimport_begin()
+
+        while True:
             try:
-                do_not_restart = await message_handler.handle_single_message(ctx, msg)
-            except UnexpectedMessageException as unexpected:
-                # The workflow was interrupted by an unexpected message. We need to
-                # process it as if it was a new message...
-                next_msg = unexpected.msg
-                # ...and we must not restart because that would lose the message.
-                do_not_restart = True
-                continue
+                await thp_main.thp_main_loop(iface)
             except Exception as exc:
-                # Log and ignore. The session handler can only exit explicitly in the
-                # following finally block.
+                # Log and try again.
                 if __debug__:
                     log.exception(__name__, exc)
             finally:
                 # Unload modules imported by the workflow. Should not raise.
+                if __debug__:
+                    log.debug(__name__, "utils.unimport_end(modules) and loop.clear()")
                 utils.unimport_end(modules)
+                loop.clear()
+                return  # pylint: disable=lost-exception
 
-                if not do_not_restart:
-                    # Let the session be restarted from `main`.
-                    loop.clear()
-                    return  # pylint: disable=lost-exception
+else:
 
-        except Exception as exc:
-            # Log and try again. The session handler can only exit explicitly via
-            # loop.clear() above.
-            if __debug__:
-                log.exception(__name__, exc)
+    async def handle_session(iface: WireInterface) -> None:
+        ctx = CodecContext(iface, WIRE_BUFFER)
+        next_msg: protocol_common.Message | None = None
+
+        # Take a mark of modules that are imported at this point, so we can
+        # roll back and un-import any others.
+        modules = utils.unimport_begin()
+        while True:
+            try:
+                if next_msg is None:
+                    # If the previous run did not keep an unprocessed message for us,
+                    # wait for a new one coming from the wire.
+                    try:
+                        msg = await ctx.read_from_wire()
+                    except protocol_common.WireError as exc:
+                        if __debug__:
+                            log.exception(__name__, exc)
+                        await ctx.write(failure(exc))
+                        continue
+
+                else:
+                    # Process the message from previous run.
+                    msg = next_msg
+                    next_msg = None
+
+                do_not_restart = False
+                try:
+                    do_not_restart = await message_handler.handle_single_message(
+                        ctx, msg
+                    )
+                except UnexpectedMessageException as unexpected:
+                    # The workflow was interrupted by an unexpected message. We need to
+                    # process it as if it was a new message...
+                    next_msg = unexpected.msg
+                    # ...and we must not restart because that would lose the message.
+                    do_not_restart = True
+                    continue
+                except Exception as exc:
+                    # Log and ignore. The session handler can only exit explicitly in the
+                    # following finally block.
+                    if __debug__:
+                        log.exception(__name__, exc)
+                finally:
+                    # Unload modules imported by the workflow. Should not raise.
+                    utils.unimport_end(modules)
+
+                    if not do_not_restart:
+                        # Let the session be restarted from `main`.
+                        if __debug__:
+                            log.debug(__name__, "loop.clear()")
+                        loop.clear()
+                        return  # pylint: disable=lost-exception
+
+            except Exception as exc:
+                # Log and try again. The session handler can only exit explicitly via
+                # loop.clear() above.
+                if __debug__:
+                    log.exception(__name__, exc)
diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py
index 56df34fbc5..00bfeb77d4 100644
--- a/core/src/trezor/wire/context.py
+++ b/core/src/trezor/wire/context.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
 
 from storage import cache
 from storage.cache_common import SESSIONLESS_FLAG
-from trezor import loop, protobuf
+from trezor import loop, protobuf, utils
 
 from .protocol_common import Context, Message
 
@@ -138,6 +138,17 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator:
             send_exc = None
 
 
+def try_get_ctx_ids() -> tuple[bytes, bytes] | None:
+    ids = None
+    if utils.USE_THP:
+        from trezor.wire.thp.session_context import GenericSessionContext
+
+        ctx = get_context()
+        if isinstance(ctx, GenericSessionContext):
+            ids = (ctx.channel_id, ctx.session_id.to_bytes(1, "big"))
+    return ids
+
+
 # ACCESS TO CACHE
 
 if TYPE_CHECKING:
diff --git a/core/src/trezor/wire/errors.py b/core/src/trezor/wire/errors.py
index 376820b583..e8b2d3feb4 100644
--- a/core/src/trezor/wire/errors.py
+++ b/core/src/trezor/wire/errors.py
@@ -8,6 +8,12 @@ class Error(Exception):
         self.message = message
 
 
+class SilentError(Exception):
+    def __init__(self, message: str) -> None:
+        super().__init__()
+        self.message = message
+
+
 class UnexpectedMessage(Error):
     def __init__(self, message: str) -> None:
         super().__init__(FailureType.UnexpectedMessage, message)
diff --git a/core/src/trezor/wire/message_handler.py b/core/src/trezor/wire/message_handler.py
index 21c901dc90..c0f201de22 100644
--- a/core/src/trezor/wire/message_handler.py
+++ b/core/src/trezor/wire/message_handler.py
@@ -25,7 +25,12 @@ def wrap_protobuf_load(
     expected_type: type[LoadedMessageType],
 ) -> LoadedMessageType:
     try:
-        if __debug__ and utils.EMULATOR and utils.USE_THP:
+        if (
+            __debug__
+            and utils.EMULATOR
+            and utils.USE_THP
+            and utils.ALLOW_DEBUG_MESSAGES
+        ):
             log.debug(
                 __name__,
                 "Buffer to be parsed to a LoadedMessage: %s",
@@ -38,7 +43,7 @@ def wrap_protobuf_load(
             )
         return msg
     except Exception as e:
-        if __debug__:
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
             log.exception(__name__, e)
         if e.args:
             raise DataError("Failed to decode message: " + " ".join(e.args))
@@ -46,6 +51,25 @@ def wrap_protobuf_load(
             raise DataError("Failed to decode message")
 
 
+if utils.USE_THP:
+    from trezor.enums import ThpMessageType
+
+    def get_msg_name(msg_type: int) -> str | None:
+        for name in dir(ThpMessageType):
+            if not name.startswith("__"):  # Skip built-in attributes
+                value = getattr(ThpMessageType, name)
+                if isinstance(value, int):
+                    if value == msg_type:
+                        return name
+        return None
+
+    def get_msg_type(msg_name: str) -> int | None:
+        value = getattr(ThpMessageType, msg_name)
+        if isinstance(value, int):
+            return value
+        return None
+
+
 async def handle_single_message(ctx: Context, msg: Message) -> bool:
     """Handle a message that was loaded from a WireInterface by the caller.
 
@@ -60,17 +84,27 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool:
     the type of message is supposed to be optimized and not disrupt the running state,
     this function will return `True`.
     """
-    if __debug__:
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
         try:
             msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
         except Exception:
             msg_type = f"{msg.type} - unknown message type"
-        log.debug(
-            __name__,
-            "%d receive: <%s>",
-            ctx.iface.iface_num(),
-            msg_type,
-        )
+        if utils.USE_THP:
+            cid = int.from_bytes(ctx.channel_id, "big")
+            log.debug(
+                __name__,
+                "%d:%d receive: <%s>",
+                ctx.iface.iface_num(),
+                cid,
+                msg_type,
+            )
+        else:
+            log.debug(
+                __name__,
+                "%d receive: <%s>",
+                ctx.iface.iface_num(),
+                msg_type,
+            )
 
     res_msg: protobuf.MessageType | None = None
 
@@ -91,7 +125,15 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool:
     try:
         # Find a protobuf.MessageType subclass that describes this
         # message.  Raises if the type is not found.
-        req_type = protobuf.type_for_wire(msg.type)
+
+        if utils.USE_THP:
+            name = get_msg_name(msg.type)
+            if name is None:
+                req_type = protobuf.type_for_wire(msg.type)
+            else:
+                req_type = protobuf.type_for_name(name)
+        else:
+            req_type = protobuf.type_for_wire(msg.type)
 
         # Try to decode the message according to schema from
         # `req_type`. Raises if the message is malformed.
@@ -132,7 +174,7 @@ async def handle_single_message(ctx: Context, msg: Message) -> bool:
         # - the message was not valid protobuf
         # - workflow raised some kind of an exception while running
         # - something canceled the workflow from the outside
-        if __debug__:
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
             if isinstance(exc, ActionCancelled):
                 log.debug(__name__, "cancelled: %s", exc.message)
             elif isinstance(exc, loop.TaskClosed):
diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py
index ed4105517b..0e54afe8c3 100644
--- a/core/src/trezor/wire/protocol_common.py
+++ b/core/src/trezor/wire/protocol_common.py
@@ -4,7 +4,7 @@ from trezor import protobuf
 
 if TYPE_CHECKING:
     from trezorio import WireInterface
-    from typing import Container, TypeVar, overload
+    from typing import Awaitable, Container, TypeVar, overload
 
     from storage.cache_common import DataCache
 
@@ -72,6 +72,9 @@ class Context:
         """Write a message to the wire."""
         ...
 
+    def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]:
+        return self.write(msg)
+
     async def call(
         self,
         msg: protobuf.MessageType,
diff --git a/core/src/trezor/wire/thp/__init__.py b/core/src/trezor/wire/thp/__init__.py
new file mode 100644
index 0000000000..ce61cf815a
--- /dev/null
+++ b/core/src/trezor/wire/thp/__init__.py
@@ -0,0 +1,184 @@
+import ustruct
+from micropython import const
+from typing import TYPE_CHECKING
+
+from storage.cache_thp import BROADCAST_CHANNEL_ID
+from trezor import protobuf, utils
+from trezor.enums import ThpPairingMethod
+from trezor.messages import ThpDeviceProperties
+
+from ..protocol_common import WireError
+
+if TYPE_CHECKING:
+    from enum import IntEnum
+
+    from trezor.wire import WireInterface
+    from typing_extensions import Self
+else:
+    IntEnum = object
+
+CODEC_V1 = const(0x3F)
+
+HANDSHAKE_INIT_REQ = const(0x00)
+HANDSHAKE_INIT_RES = const(0x01)
+HANDSHAKE_COMP_REQ = const(0x02)
+HANDSHAKE_COMP_RES = const(0x03)
+ENCRYPTED = const(0x04)
+
+ACK_MESSAGE = const(0x20)
+CHANNEL_ALLOCATION_REQ = const(0x40)
+_CHANNEL_ALLOCATION_RES = const(0x41)
+_ERROR = const(0x42)
+CONTINUATION_PACKET = const(0x80)
+
+
+class ThpError(WireError):
+    pass
+
+
+class ThpDecryptionError(ThpError):
+    pass
+
+
+class ThpInvalidDataError(ThpError):
+    pass
+
+
+class ThpUnallocatedSessionError(ThpError):
+
+    def __init__(self, session_id: int) -> None:
+        self.session_id = session_id
+
+
+class ThpErrorType(IntEnum):
+    TRANSPORT_BUSY = 1
+    UNALLOCATED_CHANNEL = 2
+    DECRYPTION_FAILED = 3
+    INVALID_DATA = 4
+
+
+class ChannelState(IntEnum):
+    UNALLOCATED = 0
+    TH1 = 1
+    TH2 = 2
+    TP1 = 3
+    TP2 = 4
+    TP3 = 5
+    TP4 = 6
+    TC1 = 7
+    ENCRYPTED_TRANSPORT = 8
+
+
+class SessionState(IntEnum):
+    UNALLOCATED = 0
+    ALLOCATED = 1
+    MANAGEMENT = 2
+
+
+class PacketHeader:
+    format_str_init = ">BHH"
+    format_str_cont = ">BH"
+
+    def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
+        self.ctrl_byte = ctrl_byte
+        self.cid = cid
+        self.length = length
+
+    def to_bytes(self) -> bytes:
+        return ustruct.pack(self.format_str_init, self.ctrl_byte, self.cid, self.length)
+
+    def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
+        """
+        Packs header information in the form of **intial** packet
+        into the provided buffer.
+        """
+        ustruct.pack_into(
+            self.format_str_init,
+            buffer,
+            buffer_offset,
+            self.ctrl_byte,
+            self.cid,
+            self.length,
+        )
+
+    def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
+        """
+        Packs header information in the form of **continuation** packet header
+        into the provided buffer.
+        """
+        ustruct.pack_into(
+            self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
+        )
+
+    @classmethod
+    def get_error_header(cls, cid: int, length: int) -> Self:
+        """
+        Returns header for protocol-level error messages.
+        """
+        return cls(_ERROR, cid, length)
+
+    @classmethod
+    def get_channel_allocation_response_header(cls, length: int) -> Self:
+        """
+        Returns header for allocation response handshake message.
+        """
+        return cls(_CHANNEL_ALLOCATION_RES, BROADCAST_CHANNEL_ID, length)
+
+
+_DEFAULT_ENABLED_PAIRING_METHODS = [
+    ThpPairingMethod.CodeEntry,
+    ThpPairingMethod.QrCode,
+    ThpPairingMethod.NFC_Unidirectional,
+]
+
+
+def get_enabled_pairing_methods(
+    iface: WireInterface | None = None,
+) -> list[ThpPairingMethod]:
+    """
+    Returns pairing methods that are currently allowed by the device
+    with respect to the wire interface the host communicates on.
+    """
+    import usb
+
+    methods = _DEFAULT_ENABLED_PAIRING_METHODS.copy()
+    if iface is not None and iface is usb.iface_wire:
+        methods.append(ThpPairingMethod.NoMethod)
+    return methods
+
+
+def _get_device_properties(iface: WireInterface) -> ThpDeviceProperties:
+    # TODO define model variants
+    return ThpDeviceProperties(
+        pairing_methods=get_enabled_pairing_methods(iface),
+        internal_model=utils.INTERNAL_MODEL,
+        model_variant=0,
+        bootloader_mode=False,
+        protocol_version=2,
+    )
+
+
+def get_encoded_device_properties(iface: WireInterface) -> bytes:
+    props = _get_device_properties(iface)
+    length = protobuf.encoded_length(props)
+    encoded_properties = bytearray(length)
+    protobuf.encode(encoded_properties, props)
+    return encoded_properties
+
+
+def get_channel_allocation_response(
+    nonce: bytes, new_cid: bytes, iface: WireInterface
+) -> bytes:
+    props_msg = get_encoded_device_properties(iface)
+    return nonce + new_cid + props_msg
+
+
+if __debug__:
+
+    def state_to_str(state: int) -> str:
+        name = {
+            v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__")
+        }.get(state)
+        if name is not None:
+            return name
+        return "UNKNOWN_STATE"
diff --git a/core/src/trezor/wire/thp/alternating_bit_protocol.py b/core/src/trezor/wire/thp/alternating_bit_protocol.py
new file mode 100644
index 0000000000..d8ba60c5b2
--- /dev/null
+++ b/core/src/trezor/wire/thp/alternating_bit_protocol.py
@@ -0,0 +1,102 @@
+from storage.cache_thp import ChannelCache
+from trezor import log, utils
+from trezor.wire.thp import ThpError
+
+
+def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool:
+    """
+    Checks if:
+    - an ACK message is expected
+    - the received ACK message acknowledges correct sequence number (bit)
+    """
+    if not _is_ack_expected(cache):
+        return False
+
+    if not _has_ack_correct_sync_bit(cache, ack_bit):
+        return False
+
+    return True
+
+
+def _is_ack_expected(cache: ChannelCache) -> bool:
+    is_expected: bool = not is_sending_allowed(cache)
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES and not is_expected:
+        log.debug(__name__, "Received unexpected ACK message")
+    return is_expected
+
+
+def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
+    is_correct: bool = get_send_seq_bit(cache) == sync_bit
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES and not is_correct:
+        log.debug(__name__, "Received ACK message with wrong ack bit")
+    return is_correct
+
+
+def is_sending_allowed(cache: ChannelCache) -> bool:
+    """
+    Checks whether sending a message in the provided channel is allowed.
+
+    Note: Sending a message in a channel before receipt of ACK message for the previously
+    sent message (in the channel) is prohibited, as it can lead to desynchronization.
+    """
+    return bool(cache.sync >> 7)
+
+
+def get_send_seq_bit(cache: ChannelCache) -> int:
+    """
+    Returns the sequential number (bit) of the next message to be sent
+    in the provided channel.
+    """
+    return (cache.sync & 0x20) >> 5
+
+
+def get_expected_receive_seq_bit(cache: ChannelCache) -> int:
+    """
+    Returns the (expected) sequential number (bit) of the next message
+    to be received in the provided channel.
+    """
+    return (cache.sync & 0x40) >> 6
+
+
+def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None:
+    """
+    Set the flag whether sending a message in this channel is allowed or not.
+    """
+    cache.sync &= 0x7F
+    if sending_allowed:
+        cache.sync |= 0x80
+
+
+def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
+    """
+    Set the expected sequential number (bit) of the next message to be received
+    in the provided channel
+    """
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit)
+    if seq_bit not in (0, 1):
+        raise ThpError("Unexpected receive sync bit")
+
+    # set second bit to "seq_bit" value
+    cache.sync &= 0xBF
+    if seq_bit:
+        cache.sync |= 0x40
+
+
+def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
+    if seq_bit not in (0, 1):
+        raise ThpError("Unexpected send seq bit")
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(__name__, "setting sync send seq bit to %d", seq_bit)
+    # set third bit to "seq_bit" value
+    cache.sync &= 0xDF
+    if seq_bit:
+        cache.sync |= 0x20
+
+
+def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None:
+    """
+    Set the sequential bit of the "next message to be send" to the opposite value,
+    i.e. 1 -> 0 and 0 -> 1
+    """
+    _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache))
diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py
new file mode 100644
index 0000000000..8e6e65647f
--- /dev/null
+++ b/core/src/trezor/wire/thp/channel.py
@@ -0,0 +1,405 @@
+import ustruct
+from typing import TYPE_CHECKING
+
+from storage.cache_common import (
+    CHANNEL_HANDSHAKE_HASH,
+    CHANNEL_KEY_RECEIVE,
+    CHANNEL_KEY_SEND,
+    CHANNEL_NONCE_RECEIVE,
+    CHANNEL_NONCE_SEND,
+)
+from storage.cache_thp import TAG_LENGTH, ChannelCache, clear_sessions_with_channel_id
+from trezor import log, loop, protobuf, utils, workflow
+
+from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError
+from . import alternating_bit_protocol as ABP
+from . import (
+    control_byte,
+    crypto,
+    interface_manager,
+    memory_manager,
+    received_message_handler,
+)
+from .checksum import CHECKSUM_LENGTH
+from .transmission_loop import TransmissionLoop
+from .writer import (
+    CONT_HEADER_LENGTH,
+    INIT_HEADER_LENGTH,
+    write_payload_to_wire_and_add_checksum,
+)
+
+if __debug__:
+    from ubinascii import hexlify
+
+    from . import state_to_str
+
+if TYPE_CHECKING:
+    from trezorio import WireInterface
+    from typing import Awaitable
+
+    from .pairing_context import PairingContext
+    from .session_context import GenericSessionContext
+
+
+class Channel:
+    """
+    THP protocol encrypted communication channel.
+    """
+
+    def __init__(self, channel_cache: ChannelCache) -> None:
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(__name__, "channel initialization")
+        self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
+        self.channel_cache: ChannelCache = channel_cache
+        self.is_cont_packet_expected: bool = False
+        self.expected_payload_length: int = 0
+        self.bytes_read: int = 0
+        self.buffer: utils.BufferType
+        self.channel_id: bytes = channel_cache.channel_id
+        self.selected_pairing_methods = []
+        self.sessions: dict[int, GenericSessionContext] = {}
+        self.write_task_spawn: loop.spawn | None = None
+        self.connection_context: PairingContext | None = None
+        self.transmission_loop: TransmissionLoop | None = None
+        self.handshake: crypto.Handshake | None = None
+
+    def clear(self) -> None:
+        clear_sessions_with_channel_id(self.channel_id)
+        self.channel_cache.clear()
+
+    # ACCESS TO CHANNEL_DATA
+    def get_channel_id_int(self) -> int:
+        return int.from_bytes(self.channel_id, "big")
+
+    def get_channel_state(self) -> int:
+        state = int.from_bytes(self.channel_cache.state, "big")
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(
+                __name__,
+                "(cid: %s) get_channel_state: %s",
+                utils.get_bytes_as_str(self.channel_id),
+                state_to_str(state),
+            )
+        return state
+
+    def get_handshake_hash(self) -> bytes:
+        h = self.channel_cache.get(CHANNEL_HANDSHAKE_HASH)
+        assert h is not None
+        return h
+
+    def set_channel_state(self, state: ChannelState) -> None:
+        self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(
+                __name__,
+                "(cid: %s) set_channel_state: %s",
+                utils.get_bytes_as_str(self.channel_id),
+                state_to_str(state),
+            )
+
+    def set_buffer(self, buffer: utils.BufferType) -> None:
+        self.buffer = buffer
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(
+                __name__,
+                "(cid: %s) set_buffer: %s",
+                utils.get_bytes_as_str(self.channel_id),
+                type(self.buffer),
+            )
+
+    # CALLED BY THP_MAIN_LOOP
+
+    def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(
+                __name__,
+                "(cid: %s) receive_packet",
+                utils.get_bytes_as_str(self.channel_id),
+            )
+
+        self._handle_received_packet(packet)
+
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(
+                __name__,
+                "(cid: %s) self.buffer: %s",
+                utils.get_bytes_as_str(self.channel_id),
+                utils.get_bytes_as_str(self.buffer),
+            )
+
+        if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
+            self._finish_message()
+            return received_message_handler.handle_received_message(self, self.buffer)
+        elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read:
+            self.is_cont_packet_expected = True
+        else:
+            raise ThpError(
+                "Read more bytes than is the expected length of the message!"
+            )
+        return None
+
+    def _handle_received_packet(self, packet: utils.BufferType) -> None:
+        ctrl_byte = packet[0]
+        if control_byte.is_continuation(ctrl_byte):
+            return self._handle_cont_packet(packet)
+        return self._handle_init_packet(packet)
+
+    def _handle_init_packet(self, packet: utils.BufferType) -> None:
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(
+                __name__,
+                "(cid: %s) handle_init_packet",
+                utils.get_bytes_as_str(self.channel_id),
+            )
+        # ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) # TODO use this with single packet decryption
+        _, _, payload_length = ustruct.unpack(">BHH", packet)
+        self.expected_payload_length = payload_length
+        packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:]
+
+        # If the channel does not "own" the buffer lock, decrypt first packet
+        # TODO do it only when needed!
+        # TODO FIX: If "_decrypt_single_packet_payload" is implemented, it will (possibly) break "decrypt_buffer" and nonces incrementation.
+        # On the other hand, without the single packet decryption, the "advanced" buffer selection cannot be implemented
+        # in "memory_manager.select_buffer", because the session id is unknown (encrypted).
+
+        # if control_byte.is_encrypted_transport(ctrl_byte):
+        #   packet_payload = self._decrypt_single_packet_payload(packet_payload)
+
+        self.buffer = memory_manager.select_buffer(
+            self.get_channel_state(),
+            self.buffer,
+            packet_payload,
+            payload_length,
+        )
+
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(
+                __name__,
+                "(cid: %s) handle_init_packet - payload len: %d",
+                utils.get_bytes_as_str(self.channel_id),
+                payload_length,
+            )
+            log.debug(
+                __name__,
+                "(cid: %s) handle_init_packet - buffer len: %d",
+                utils.get_bytes_as_str(self.channel_id),
+                len(self.buffer),
+            )
+        return self._buffer_packet_data(self.buffer, packet, 0)
+
+    def _handle_cont_packet(self, packet: utils.BufferType) -> None:
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(
+                __name__,
+                "(cid: %s) handle_cont_packet",
+                utils.get_bytes_as_str(self.channel_id),
+            )
+        if not self.is_cont_packet_expected:
+            raise ThpError("Continuation packet is not expected, ignoring")
+        return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
+
+    def _decrypt_single_packet_payload(
+        self, payload: utils.BufferType
+    ) -> utils.BufferType:
+        # crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload))
+        return payload
+
+    def decrypt_buffer(
+        self, message_length: int, offset: int = INIT_HEADER_LENGTH
+    ) -> None:
+        noise_buffer = memoryview(self.buffer)[
+            offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH
+        ]
+        tag = self.buffer[
+            message_length
+            - CHECKSUM_LENGTH
+            - TAG_LENGTH : message_length
+            - CHECKSUM_LENGTH
+        ]
+        if utils.DISABLE_ENCRYPTION:
+            is_tag_valid = tag == crypto.DUMMY_TAG
+        else:
+            key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE)
+            nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE)
+
+            assert key_receive is not None
+            assert nonce_receive is not None
+            if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+                log.debug(
+                    __name__,
+                    "(cid: %s) Buffer before decryption: %s",
+                    utils.get_bytes_as_str(self.channel_id),
+                    hexlify(noise_buffer),
+                )
+            is_tag_valid = crypto.dec(
+                noise_buffer, tag, key_receive, nonce_receive, b""
+            )
+            if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+                log.debug(
+                    __name__,
+                    "(cid: %s) Buffer after decryption: %s",
+                    utils.get_bytes_as_str(self.channel_id),
+                    hexlify(noise_buffer),
+                )
+
+            self.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, nonce_receive + 1)
+
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(
+                __name__,
+                "(cid: %s) Is decrypted tag valid? %s",
+                utils.get_bytes_as_str(self.channel_id),
+                str(is_tag_valid),
+            )
+            log.debug(
+                __name__,
+                "(cid: %s) Received tag: %s",
+                utils.get_bytes_as_str(self.channel_id),
+                (hexlify(tag).decode()),
+            )
+            log.debug(
+                __name__,
+                "(cid: %s) New nonce_receive: %i",
+                utils.get_bytes_as_str(self.channel_id),
+                nonce_receive + 1,
+            )
+
+        if not is_tag_valid:
+            raise ThpDecryptionError()
+
+    def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(
+                __name__, "(cid: %s) encrypt", utils.get_bytes_as_str(self.channel_id)
+            )
+        assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
+
+        noise_buffer = memoryview(buffer)[0:noise_payload_len]
+
+        if utils.DISABLE_ENCRYPTION:
+            tag = crypto.DUMMY_TAG
+        else:
+            key_send = self.channel_cache.get(CHANNEL_KEY_SEND)
+            nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND)
+
+            assert key_send is not None
+            assert nonce_send is not None
+
+            tag = crypto.enc(noise_buffer, key_send, nonce_send, b"")
+
+            self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1)
+            if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+                log.debug(__name__, "New nonce_send: %i", nonce_send + 1)
+
+        buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
+
+    def _buffer_packet_data(
+        self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
+    ) -> None:
+        self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
+
+    def _finish_message(self) -> None:
+        self.bytes_read = 0
+        self.expected_payload_length = 0
+        self.is_cont_packet_expected = False
+
+    # CALLED BY WORKFLOW / SESSION CONTEXT
+
+    async def write(
+        self,
+        msg: protobuf.MessageType,
+        session_id: int = 0,
+        force: bool = False,
+    ) -> None:
+        if __debug__ and utils.EMULATOR:
+            log.debug(
+                __name__,
+                "(cid: %s) write message: %s\n%s",
+                utils.get_bytes_as_str(self.channel_id),
+                msg.MESSAGE_NAME,
+                utils.dump_protobuf(msg),
+            )
+
+        self.buffer = memory_manager.get_write_buffer(self.buffer, msg)
+        noise_payload_len = memory_manager.encode_into_buffer(
+            self.buffer, msg, session_id
+        )
+        task = self.write_and_encrypt(self.buffer[:noise_payload_len], force)
+        if task is not None:
+            await task
+
+    def write_error(self, err_type: int) -> Awaitable[None]:
+        msg_data = err_type.to_bytes(1, "big")
+        length = len(msg_data) + CHECKSUM_LENGTH
+        header = PacketHeader.get_error_header(self.get_channel_id_int(), length)
+        return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data)
+
+    def write_and_encrypt(
+        self, payload: bytes, force: bool = False
+    ) -> Awaitable[None] | None:
+        payload_length = len(payload)
+        self._encrypt(self.buffer, payload_length)
+        payload_length = payload_length + TAG_LENGTH
+
+        if self.write_task_spawn is not None:
+            self.write_task_spawn.close()  # UPS TODO might break something
+            print("\nCLOSED\n")
+        self._prepare_write()
+        if force:
+            if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+                log.debug(
+                    __name__, "Writing FORCE message (without async or retransmission)."
+                )
+            return self._write_encrypted_payload_loop(
+                ENCRYPTED, memoryview(self.buffer[:payload_length])
+            )
+        self.write_task_spawn = loop.spawn(
+            self._write_encrypted_payload_loop(
+                ENCRYPTED, memoryview(self.buffer[:payload_length])
+            )
+        )
+        return None
+
+    def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
+        self._prepare_write()
+        self.write_task_spawn = loop.spawn(
+            self._write_encrypted_payload_loop(ctrl_byte, payload)
+        )
+
+    def _prepare_write(self) -> None:
+        # TODO add condition that disallows to write when can_send_message is false
+        ABP.set_sending_allowed(self.channel_cache, False)
+
+    async def _write_encrypted_payload_loop(
+        self, ctrl_byte: int, payload: bytes
+    ) -> None:
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(
+                __name__,
+                "(cid %s) write_encrypted_payload_loop",
+                utils.get_bytes_as_str(self.channel_id),
+            )
+        payload_len = len(payload) + CHECKSUM_LENGTH
+        sync_bit = ABP.get_send_seq_bit(self.channel_cache)
+        ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit)
+        header = PacketHeader(ctrl_byte, self.get_channel_id_int(), payload_len)
+        self.transmission_loop = TransmissionLoop(self, header, payload)
+        await self.transmission_loop.start()
+
+        ABP.set_send_seq_bit_to_opposite(self.channel_cache)
+
+        # Let the main loop be restarted and clear loop, if there is no other
+        # workflow and the state is ENCRYPTED_TRANSPORT
+        if self._can_clear_loop():
+            if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+                log.debug(
+                    __name__,
+                    "(cid: %s) clearing loop from channel",
+                    utils.get_bytes_as_str(self.channel_id),
+                )
+            loop.clear()
+
+    def _can_clear_loop(self) -> bool:
+        return (
+            not workflow.tasks
+        ) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT
diff --git a/core/src/trezor/wire/thp/channel_manager.py b/core/src/trezor/wire/thp/channel_manager.py
new file mode 100644
index 0000000000..a48f6d7fdb
--- /dev/null
+++ b/core/src/trezor/wire/thp/channel_manager.py
@@ -0,0 +1,34 @@
+from typing import TYPE_CHECKING
+
+from storage import cache_thp
+from trezor import utils
+
+from . import ChannelState, interface_manager
+from .channel import Channel
+
+if TYPE_CHECKING:
+    from trezorio import WireInterface
+
+
+def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> Channel:
+    """
+    Creates a new channel for the interface `iface` with the buffer `buffer`.
+    """
+    channel_cache = cache_thp.get_new_channel(interface_manager.encode_iface(iface))
+    r = Channel(channel_cache)
+    r.set_buffer(buffer)
+    r.set_channel_state(ChannelState.TH1)
+    return r
+
+
+def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]:
+    """
+    Returns all allocated channels from cache.
+    """
+    channels: dict[int, Channel] = {}
+    cached_channels = cache_thp.get_all_allocated_channels()
+    for c in cached_channels:
+        channels[int.from_bytes(c.channel_id, "big")] = Channel(c)
+    for c in channels.values():
+        c.set_buffer(buffer)
+    return channels
diff --git a/core/src/trezor/wire/thp/checksum.py b/core/src/trezor/wire/thp/checksum.py
new file mode 100644
index 0000000000..9c28f2e78d
--- /dev/null
+++ b/core/src/trezor/wire/thp/checksum.py
@@ -0,0 +1,22 @@
+from micropython import const
+
+from trezor import utils
+from trezor.crypto import crc
+
+CHECKSUM_LENGTH = const(4)
+
+
+def compute(data: bytes | utils.BufferType) -> bytes:
+    """
+    Returns a CRC-32 checksum of the provided `data`.
+    """
+    return crc.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
+
+
+def is_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool:
+    """
+    Checks whether the CRC-32 checksum of the `data` is the same
+    as the checksum provided in `checksum`.
+    """
+    data_checksum = compute(data)
+    return checksum == data_checksum
diff --git a/core/src/trezor/wire/thp/control_byte.py b/core/src/trezor/wire/thp/control_byte.py
new file mode 100644
index 0000000000..5d4d69b040
--- /dev/null
+++ b/core/src/trezor/wire/thp/control_byte.py
@@ -0,0 +1,50 @@
+from micropython import const
+
+from . import (
+    ACK_MESSAGE,
+    CONTINUATION_PACKET,
+    ENCRYPTED,
+    HANDSHAKE_COMP_REQ,
+    HANDSHAKE_INIT_REQ,
+    ThpError,
+)
+
+_CONTINUATION_PACKET_MASK = const(0x80)
+_ACK_MASK = const(0xF7)
+_DATA_MASK = const(0xE7)
+
+
+def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int:
+    if seq_bit == 0:
+        return ctrl_byte & 0xEF
+    if seq_bit == 1:
+        return ctrl_byte | 0x10
+    raise ThpError("Unexpected sequence bit")
+
+
+def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int:
+    if ack_bit == 0:
+        return ctrl_byte & 0xF7
+    if ack_bit == 1:
+        return ctrl_byte | 0x08
+    raise ThpError("Unexpected acknowledgement bit")
+
+
+def is_ack(ctrl_byte: int) -> bool:
+    return ctrl_byte & _ACK_MASK == ACK_MESSAGE
+
+
+def is_continuation(ctrl_byte: int) -> bool:
+    return ctrl_byte & _CONTINUATION_PACKET_MASK == CONTINUATION_PACKET
+
+
+def is_encrypted_transport(ctrl_byte: int) -> bool:
+    return ctrl_byte & _DATA_MASK == ENCRYPTED
+
+
+def is_handshake_init_req(ctrl_byte: int) -> bool:
+    return ctrl_byte & _DATA_MASK == HANDSHAKE_INIT_REQ
+
+
+def is_handshake_comp_req(ctrl_byte: int) -> bool:
+    return ctrl_byte & _DATA_MASK == HANDSHAKE_COMP_REQ
diff --git a/core/src/trezor/wire/thp/cpace.py b/core/src/trezor/wire/thp/cpace.py
new file mode 100644
index 0000000000..302dd3e5e3
--- /dev/null
+++ b/core/src/trezor/wire/thp/cpace.py
@@ -0,0 +1,36 @@
+from trezor.crypto import elligator2, random
+from trezor.crypto.curve import curve25519
+from trezor.crypto.hashlib import sha512
+
+_PREFIX = b"\x08\x43\x50\x61\x63\x65\x32\x35\x35\x06"
+_PADDING = b"\x6f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x20"
+
+
+class Cpace:
+    """
+    CPace, a balanced composable PAKE: https://datatracker.ietf.org/doc/draft-irtf-cfrg-cpace/
+    """
+
+    def __init__(self, cpace_host_public_key: bytes, handshake_hash: bytes) -> None:
+        self.handshake_hash: bytes = handshake_hash
+        self.host_public_key: bytes = cpace_host_public_key
+        self.shared_secret: bytes
+        self.trezor_private_key: bytes
+        self.trezor_public_key: bytes
+
+    def generate_keys_and_secret(self, code_code_entry: bytes) -> None:
+        """
+        Generate ephemeral key pair and a shared secret using Elligator2 with X25519.
+        """
+        sha_ctx = sha512(_PREFIX)
+        sha_ctx.update(code_code_entry)
+        sha_ctx.update(_PADDING)
+        sha_ctx.update(self.handshake_hash)
+        sha_ctx.update(b"\x00")
+        pregenerator = sha_ctx.digest()[:32]
+        generator = elligator2.map_to_curve25519(pregenerator)
+        self.trezor_private_key = random.bytes(32)
+        self.trezor_public_key = curve25519.multiply(self.trezor_private_key, generator)
+        self.shared_secret = curve25519.multiply(
+            self.trezor_private_key, self.host_public_key
+        )
diff --git a/core/src/trezor/wire/thp/crypto.py b/core/src/trezor/wire/thp/crypto.py
new file mode 100644
index 0000000000..aa7d9c146e
--- /dev/null
+++ b/core/src/trezor/wire/thp/crypto.py
@@ -0,0 +1,211 @@
+from micropython import const
+from trezorcrypto import aesgcm, bip32, curve25519, hmac
+
+from storage import device
+from trezor import log, utils
+from trezor.crypto.hashlib import sha256
+from trezor.wire.thp import ThpDecryptionError
+
+# The HARDENED flag is taken from apps.common.paths
+# It is not imported to save on resources
+HARDENED = const(0x8000_0000)
+PUBKEY_LENGTH = const(32)
+if utils.DISABLE_ENCRYPTION:
+    DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5"
+
+if __debug__:
+    from ubinascii import hexlify
+
+
+def enc(buffer: utils.BufferType, key: bytes, nonce: int, auth_data: bytes) -> bytes:
+    """
+    Encrypts the provided `buffer` with AES-GCM (in place).
+    Returns a 16-byte long encryption tag.
+    """
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(__name__, "enc (key: %s, nonce: %d)", hexlify(key), nonce)
+    iv = _get_iv_from_nonce(nonce)
+    aes_ctx = aesgcm(key, iv)
+    aes_ctx.auth(auth_data)
+    aes_ctx.encrypt_in_place(buffer)
+    return aes_ctx.finish()
+
+
+def dec(
+    buffer: utils.BufferType, tag: bytes, key: bytes, nonce: int, auth_data: bytes
+) -> bool:
+    """
+    Decrypts the provided buffer (in place). Returns `True` if the provided authentication `tag` is the same as
+    the tag computed in decryption, otherwise it returns `False`.
+    """
+    iv = _get_iv_from_nonce(nonce)
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(__name__, "dec (key: %s, nonce: %d)", hexlify(key), nonce)
+    aes_ctx = aesgcm(key, iv)
+    aes_ctx.auth(auth_data)
+    aes_ctx.decrypt_in_place(buffer)
+    computed_tag = aes_ctx.finish()
+    return computed_tag == tag
+
+
+class BusyDecoder:
+    def __init__(self, key: bytes, nonce: int, auth_data: bytes) -> None:
+        iv = _get_iv_from_nonce(nonce)
+        self.aes_ctx = aesgcm(key, iv)
+        self.aes_ctx.auth(auth_data)
+
+    def decrypt_part(self, part: utils.BufferType) -> None:
+        self.aes_ctx.decrypt_in_place(part)
+
+    def finish_and_check_tag(self, tag: bytes) -> bool:
+        computed_tag = self.aes_ctx.finish()
+        return computed_tag == tag
+
+
+PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
+IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+
+
+class Handshake:
+    """
+    `Handshake` holds (temporary) values and keys that are used during the creation of an encrypted channel.
+    The following values should be saved for future use before disposing of this object:
+    - `h` - handshake hash, can be used to bind other values to the channel
+    - `key_receive` - key for decrypting incoming communication
+    - `key_send` - key for encrypting outgoing communication
+    """
+
+    def __init__(self) -> None:
+        self.trezor_ephemeral_privkey: bytes
+        self.ck: bytes
+        self.k: bytes
+        self.h: bytes
+        self.key_receive: bytes
+        self.key_send: bytes
+
+    def handle_th1_crypto(
+        self,
+        device_properties: bytes,
+        host_ephemeral_pubkey: bytes,
+    ) -> tuple[bytes, bytes, bytes]:
+
+        trezor_static_privkey, trezor_static_pubkey = _derive_static_key_pair()
+        self.trezor_ephemeral_privkey = curve25519.generate_secret()
+        trezor_ephemeral_pubkey = curve25519.publickey(self.trezor_ephemeral_privkey)
+        self.h = _hash_of_two(PROTOCOL_NAME, device_properties)
+        self.h = _hash_of_two(self.h, host_ephemeral_pubkey)
+        self.h = _hash_of_two(self.h, trezor_ephemeral_pubkey)
+        point = curve25519.multiply(
+            self.trezor_ephemeral_privkey, host_ephemeral_pubkey
+        )
+        self.ck, self.k = _hkdf(PROTOCOL_NAME, point)
+        mask = _hash_of_two(trezor_static_pubkey, trezor_ephemeral_pubkey)
+        trezor_masked_static_pubkey = curve25519.multiply(mask, trezor_static_pubkey)
+        aes_ctx = aesgcm(self.k, IV_1)
+        encrypted_trezor_static_pubkey = aes_ctx.encrypt(trezor_masked_static_pubkey)
+        if __debug__:
+            log.debug(__name__, "th1 - enc (key: %s, nonce: %d)", hexlify(self.k), 0)
+        aes_ctx.auth(self.h)
+        tag_to_encrypted_key = aes_ctx.finish()
+        encrypted_trezor_static_pubkey = (
+            encrypted_trezor_static_pubkey + tag_to_encrypted_key
+        )
+        self.h = _hash_of_two(self.h, encrypted_trezor_static_pubkey)
+        point = curve25519.multiply(trezor_static_privkey, host_ephemeral_pubkey)
+        self.ck, self.k = _hkdf(self.ck, curve25519.multiply(mask, point))
+        aes_ctx = aesgcm(self.k, IV_1)
+        aes_ctx.auth(self.h)
+        tag = aes_ctx.finish()
+        self.h = _hash_of_two(self.h, tag)
+        return (trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag)
+
+    def handle_th2_crypto(
+        self,
+        encrypted_host_static_pubkey: utils.BufferType,
+        encrypted_payload: utils.BufferType,
+    ) -> None:
+
+        aes_ctx = aesgcm(self.k, IV_2)
+
+        # The new value of hash `h` MUST be computed before the `encrypted_host_static_pubkey` is decrypted.
+        # However, decryption of `encrypted_host_static_pubkey` MUST use the previous value of `h` for
+        # authentication of the gcm tag.
+        aes_ctx.auth(self.h)  # Authenticate with the previous value of `h`
+        self.h = _hash_of_two(self.h, encrypted_host_static_pubkey)  # Compute new value
+        aes_ctx.decrypt_in_place(
+            memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH]
+        )
+        if __debug__:
+            log.debug(__name__, "th2 - dec (key: %s, nonce: %d)", hexlify(self.k), 1)
+        host_static_pubkey = memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH]
+        tag = aes_ctx.finish()
+        if tag != encrypted_host_static_pubkey[-16:]:
+            raise ThpDecryptionError()
+
+        self.ck, self.k = _hkdf(
+            self.ck,
+            curve25519.multiply(self.trezor_ephemeral_privkey, host_static_pubkey),
+        )
+        aes_ctx = aesgcm(self.k, IV_1)
+        aes_ctx.auth(self.h)
+        aes_ctx.decrypt_in_place(memoryview(encrypted_payload)[:-16])
+        if __debug__:
+            log.debug(__name__, "th2 - dec (key: %s, nonce: %d)", hexlify(self.k), 0)
+        tag = aes_ctx.finish()
+        if tag != encrypted_payload[-16:]:
+            raise ThpDecryptionError()
+
+        self.h = _hash_of_two(self.h, memoryview(encrypted_payload)[:-16])
+        self.key_receive, self.key_send = _hkdf(self.ck, b"")
+        if __debug__:
+            log.debug(
+                __name__,
+                "(key_receive: %s, key_send: %s)",
+                hexlify(self.key_receive),
+                hexlify(self.key_send),
+            )
+
+    def get_handshake_completion_response(self, trezor_state: bytes) -> bytes:
+        aes_ctx = aesgcm(self.key_send, IV_1)
+        encrypted_trezor_state = aes_ctx.encrypt(trezor_state)
+        tag = aes_ctx.finish()
+        return encrypted_trezor_state + tag
+
+
+def _derive_static_key_pair() -> tuple[bytes, bytes]:
+    node_int = HARDENED | int.from_bytes(b"\x00THP", "big")
+    node = bip32.from_seed(device.get_device_secret(), "curve25519")
+    node.derive(node_int)
+
+    trezor_static_privkey = node.private_key()
+    trezor_static_pubkey = node.public_key()[1:33]
+    # Note: the first byte (\x01) of the public key is removed, as it
+    # only indicates the type of the elliptic curve used
+
+    return trezor_static_privkey, trezor_static_pubkey
+
+
+def get_trezor_static_pubkey() -> bytes:
+    _, pubkey = _derive_static_key_pair()
+    return pubkey
+
+
+def _hkdf(chaining_key: bytes, input: bytes) -> tuple[bytes, bytes]:
+    temp_key = hmac(hmac.SHA256, chaining_key, input).digest()
+    output_1 = hmac(hmac.SHA256, temp_key, b"\x01").digest()
+    ctx_output_2 = hmac(hmac.SHA256, temp_key, output_1)
+    ctx_output_2.update(b"\x02")
+    output_2 = ctx_output_2.digest()
+    return (output_1, output_2)
+
+
+def _hash_of_two(part_1: bytes, part_2: bytes) -> bytes:
+    ctx = sha256(part_1)
+    ctx.update(part_2)
+    return ctx.digest()
+
+
+def _get_iv_from_nonce(nonce: int) -> bytes:
+    utils.ensure(nonce <= 0xFFFFFFFFFFFFFFFF, "Nonce overflow, terminate the channel")
+    return bytes(4) + nonce.to_bytes(8, "big")
diff --git a/core/src/trezor/wire/thp/interface_manager.py b/core/src/trezor/wire/thp/interface_manager.py
new file mode 100644
index 0000000000..a1fecfe7d6
--- /dev/null
+++ b/core/src/trezor/wire/thp/interface_manager.py
@@ -0,0 +1,28 @@
+from typing import TYPE_CHECKING
+
+import usb
+
+_WIRE_INTERFACE_USB = b"\x01"
+# TODO _WIRE_INTERFACE_BLE = b"\x02"
+
+if TYPE_CHECKING:
+    from trezorio import WireInterface
+
+
+def decode_iface(cached_iface: bytes) -> WireInterface:
+    """Decode the cached wire interface."""
+    if cached_iface == _WIRE_INTERFACE_USB:
+        iface = usb.iface_wire
+        if iface is None:
+            raise RuntimeError("There is no valid USB WireInterface")
+        return iface
+    # TODO implement bluetooth interface
+    raise Exception("Unknown WireInterface")
+
+
+def encode_iface(iface: WireInterface) -> bytes:
+    """Encode wire interface into bytes."""
+    if iface is usb.iface_wire:
+        return _WIRE_INTERFACE_USB
+    # TODO implement bluetooth interface
+    raise Exception("Unknown WireInterface")
diff --git a/core/src/trezor/wire/thp/memory_manager.py b/core/src/trezor/wire/thp/memory_manager.py
new file mode 100644
index 0000000000..0a117c16f7
--- /dev/null
+++ b/core/src/trezor/wire/thp/memory_manager.py
@@ -0,0 +1,179 @@
+from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH
+from trezor import log, protobuf, utils
+from trezor.wire.message_handler import get_msg_type
+
+from . import ChannelState, ThpError
+from .checksum import CHECKSUM_LENGTH
+from .writer import (
+    INIT_HEADER_LENGTH,
+    MAX_PAYLOAD_LEN,
+    MESSAGE_TYPE_LENGTH,
+    PACKET_LENGTH,
+)
+
+
+def select_buffer(
+    channel_state: int,
+    channel_buffer: utils.BufferType,
+    packet_payload: utils.BufferType,
+    payload_length: int,
+) -> utils.BufferType:
+
+    if channel_state is ChannelState.ENCRYPTED_TRANSPORT:
+        session_id = packet_payload[0]
+        if session_id == 0:
+            pass
+            # TODO use small buffer
+        else:
+            pass
+            # TODO use big buffer but only if the channel owns the buffer lock.
+            # Otherwise send BUSY message and return
+    else:
+        pass
+        # TODO use small buffer
+    try:
+        # TODO for now, we create a new big buffer every time. It should be changed
+        buffer: utils.BufferType = _get_buffer_for_read(payload_length, channel_buffer)
+        return buffer
+    except Exception as e:
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.exception(__name__, e)
+    raise Exception("Failed to create a buffer for channel")  # TODO handle better
+
+
+def get_write_buffer(
+    buffer: utils.BufferType, msg: protobuf.MessageType
+) -> utils.BufferType:
+    msg_size = protobuf.encoded_length(msg)
+    payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
+    required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
+
+    if required_min_size > len(buffer):
+        return _get_buffer_for_write(required_min_size, buffer)
+    return buffer
+
+
+def encode_into_buffer(
+    buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
+) -> int:
+    # cannot write message without wire type
+    msg_type = msg.MESSAGE_WIRE_TYPE
+    if msg_type is None:
+        msg_type = get_msg_type(msg.MESSAGE_NAME)
+    assert msg_type is not None
+
+    msg_size = protobuf.encoded_length(msg)
+    payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
+
+    _encode_session_into_buffer(memoryview(buffer), session_id)
+    _encode_message_type_into_buffer(memoryview(buffer), msg_type, SESSION_ID_LENGTH)
+    _encode_message_into_buffer(
+        memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
+    )
+
+    return payload_size
+
+
+def _encode_session_into_buffer(
+    buffer: memoryview, session_id: int, buffer_offset: int = 0
+) -> None:
+    session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
+    utils.memcpy(buffer, buffer_offset, session_id_bytes, 0)
+
+
+def _encode_message_type_into_buffer(
+    buffer: memoryview, message_type: int, offset: int = 0
+) -> None:
+    msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big")
+    utils.memcpy(buffer, offset, msg_type_bytes, 0)
+
+
+def _encode_message_into_buffer(
+    buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0
+) -> None:
+    protobuf.encode(memoryview(buffer[buffer_offset:]), message)
+
+
+def _get_buffer_for_read(
+    payload_length: int,
+    existing_buffer: utils.BufferType,
+    max_length: int = MAX_PAYLOAD_LEN,
+) -> utils.BufferType:
+    length = payload_length + INIT_HEADER_LENGTH
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(
+            __name__,
+            "get_buffer_for_read - length: %d, %s %s",
+            length,
+            "existing buffer type:",
+            type(existing_buffer),
+        )
+    if length > max_length:
+        raise ThpError("Message too large")
+
+    if length > len(existing_buffer):
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(__name__, "Allocating a new buffer")
+
+        from .thp_main import get_raw_read_buffer
+
+        if length > len(get_raw_read_buffer()):
+            if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+                log.debug(
+                    __name__,
+                    "Required length is %d, where raw buffer has capacity only %d",
+                    length,
+                    len(get_raw_read_buffer()),
+                )
+            raise ThpError("Message is too large")
+
+        try:
+            payload: utils.BufferType = memoryview(get_raw_read_buffer())[:length]
+        except MemoryError:
+            payload = memoryview(get_raw_read_buffer())[:PACKET_LENGTH]
+            raise ThpError("Message is too large")
+        return payload
+
+    # reuse a part of the supplied buffer
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(__name__, "Reusing already allocated buffer")
+    return memoryview(existing_buffer)[:length]
+
+
+def _get_buffer_for_write(
+    payload_length: int,
+    existing_buffer: utils.BufferType,
+    max_length: int = MAX_PAYLOAD_LEN,
+) -> utils.BufferType:
+    length = payload_length + INIT_HEADER_LENGTH
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(
+            __name__,
+            "get_buffer_for_write - length: %d, %s %s",
+            length,
+            "existing buffer type:",
+            type(existing_buffer),
+        )
+    if length > max_length:
+        raise ThpError("Message too large")
+
+    if length > len(existing_buffer):
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(__name__, "Creating a new write buffer from raw write buffer")
+
+        from .thp_main import get_raw_write_buffer
+
+        if length > len(get_raw_write_buffer()):
+            raise ThpError("Message is too large")
+
+        try:
+            payload: utils.BufferType = memoryview(get_raw_write_buffer())[:length]
+        except MemoryError:
+            payload = memoryview(get_raw_write_buffer())[:PACKET_LENGTH]
+            raise ThpError("Message is too large")
+        return payload
+
+    # reuse a part of the supplied buffer
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(__name__, "Reusing already allocated buffer")
+    return memoryview(existing_buffer)[:length]
diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py
new file mode 100644
index 0000000000..02f8db3a8a
--- /dev/null
+++ b/core/src/trezor/wire/thp/pairing_context.py
@@ -0,0 +1,262 @@
+from typing import TYPE_CHECKING
+from ubinascii import hexlify
+
+import trezorui_api
+from trezor import loop, protobuf, workflow
+from trezor.crypto import random
+from trezor.wire import context, message_handler, protocol_common
+from trezor.wire.context import UnexpectedMessageException
+from trezor.wire.errors import ActionCancelled, SilentError
+from trezor.wire.protocol_common import Context, Message
+
+if TYPE_CHECKING:
+    from typing import Container
+
+    from trezor import ui
+
+    from .channel import Channel
+    from .cpace import Cpace
+
+    pass
+
+if __debug__:
+    from trezor import log
+
+
+class PairingDisplayData:
+
+    def __init__(self) -> None:
+        self.code_code_entry: int | None = None
+        self.code_qr_code: bytes | None = None
+        self.code_nfc_unidirectional: bytes | None = None
+
+    def get_display_layout(self) -> ui.Layout:
+        from trezor import ui
+
+        # TODO have different layouts when there is only QR code or only Code Entry
+        qr_str = ""
+        code_str = ""
+        if self.code_qr_code is not None:
+            qr_str = self._get_code_qr_code_str()
+        if self.code_code_entry is not None:
+            code_str = self._get_code_code_entry_str()
+
+        return ui.Layout(
+            trezorui_api.show_address_details(  # noqa
+                qr_title="Scan QR code to pair",
+                address=qr_str,
+                case_sensitive=True,
+                details_title="",
+                account="Code to rewrite:\n" + code_str,
+                path="",
+                xpubs=[],
+            )
+        )
+
+    def _get_code_code_entry_str(self) -> str:
+        if self.code_code_entry is not None:
+            code_str = f"{self.code_code_entry:06}"
+            if __debug__:
+                log.debug(__name__, "code_code_entry: %s", code_str)
+
+            return code_str[:3] + " " + code_str[3:]
+        raise Exception("Code entry string is not available")
+
+    def _get_code_qr_code_str(self) -> str:
+        if self.code_qr_code is not None:
+            code_str = (hexlify(self.code_qr_code)).decode("utf-8")
+            if __debug__:
+                log.debug(__name__, "code_qr_code_hexlified: %s", code_str)
+            return code_str
+        raise Exception("QR code string is not available")
+
+
+class PairingContext(Context):
+
+    def __init__(self, channel_ctx: Channel) -> None:
+        super().__init__(channel_ctx.iface, channel_ctx.channel_id)
+        self.channel_ctx: Channel = channel_ctx
+        self.incoming_message = loop.mailbox()
+        self.secret: bytes = random.bytes(16)
+
+        self.display_data: PairingDisplayData = PairingDisplayData()
+        self.cpace: Cpace
+        self.host_name: str
+
+    async def handle(self, is_debug_session: bool = False) -> None:
+        # if __debug__:
+        #     log.debug(__name__, "handle - start")
+        #     if is_debug_session:
+        #         import apps.debug
+
+        #         apps.debug.DEBUG_CONTEXT = self
+
+        next_message: Message | None = None
+
+        while True:
+            try:
+                if next_message is None:
+                    # If the previous run did not keep an unprocessed message for us,
+                    # wait for a new one.
+                    try:
+                        message: Message = await self.incoming_message
+                    except protocol_common.WireError as e:
+                        if __debug__:
+                            log.exception(__name__, e)
+                        await self.write(message_handler.failure(e))
+                        continue
+                else:
+                    # Process the message from previous run.
+                    message = next_message
+                    next_message = None
+
+                try:
+                    next_message = await handle_pairing_request_message(self, message)
+                except Exception as exc:
+                    # Log and ignore. The session handler can only exit explicitly in the
+                    # following finally block.
+                    if __debug__:
+                        log.exception(__name__, exc)
+                finally:
+                    # Unload modules imported by the workflow.  Should not raise.
+                    # This is not done for the debug session because the snapshot taken
+                    # in a debug session would clear modules which are in use by the
+                    # workflow running on wire.
+                    # TODO utils.unimport_end(modules)
+
+                    if next_message is None:
+
+                        # Shut down the loop if there is no next message waiting.
+                        return  # pylint: disable=lost-exception
+
+            except Exception as exc:
+                # Log and try again. The session handler can only exit explicitly via
+                # loop.clear() above. # TODO not updated comments
+                if __debug__:
+                    log.exception(__name__, exc)
+
+    async def read(
+        self,
+        expected_types: Container[int],
+        expected_type: type[protobuf.MessageType] | None = None,
+    ) -> protobuf.MessageType:
+        if __debug__:
+            exp_type: str = str(expected_type)
+            if expected_type is not None:
+                exp_type = expected_type.MESSAGE_NAME
+            log.debug(
+                __name__,
+                "Read - with expected types %s and expected type %s",
+                str(expected_types),
+                exp_type,
+            )
+
+        message: Message = await self.incoming_message
+
+        if message.type not in expected_types:
+            raise UnexpectedMessageException(message)
+
+        if expected_type is None:
+            name = message_handler.get_msg_name(message.type)
+            if name is None:
+                expected_type = protobuf.type_for_wire(message.type)
+            else:
+                expected_type = protobuf.type_for_name(name)
+
+        return message_handler.wrap_protobuf_load(message.data, expected_type)
+
+    async def write(self, msg: protobuf.MessageType) -> None:
+        return await self.channel_ctx.write(msg)
+
+    async def call(
+        self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
+    ) -> protobuf.MessageType:
+        expected_wire_type = message_handler.get_msg_type(expected_type.MESSAGE_NAME)
+        if expected_wire_type is None:
+            expected_wire_type = expected_type.MESSAGE_WIRE_TYPE
+
+        assert expected_wire_type is not None
+
+        await self.write(msg)
+        del msg
+
+        return await self.read((expected_wire_type,), expected_type)
+
+    async def call_any(
+        self, msg: protobuf.MessageType, *expected_types: int
+    ) -> protobuf.MessageType:
+        await self.write(msg)
+        del msg
+        return await self.read(expected_types)
+
+
+async def handle_pairing_request_message(
+    pairing_ctx: PairingContext,
+    msg: protocol_common.Message,
+) -> protocol_common.Message | None:
+
+    res_msg: protobuf.MessageType | None = None
+
+    from apps.thp.pairing import handle_pairing_request
+
+    if msg.type in workflow.ALLOW_WHILE_LOCKED:
+        workflow.autolock_interrupts_workflow = False
+
+    # Here we make sure we always respond with a Failure response
+    # in case of any errors.
+    try:
+        # Find a protobuf.MessageType subclass that describes this
+        # message.  Raises if the type is not found.
+        name = message_handler.get_msg_name(msg.type)
+        if name is None:
+            req_type = protobuf.type_for_wire(msg.type)
+        else:
+            req_type = protobuf.type_for_name(name)
+
+        # Try to decode the message according to schema from
+        # `req_type`. Raises if the message is malformed.
+        req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
+
+        # Create the handler task.
+        task = handle_pairing_request(pairing_ctx, req_msg)
+
+        # Run the workflow task.  Workflow can do more on-the-wire
+        # communication inside, but it should eventually return a
+        # response message, or raise an exception (a rather common
+        # thing to do).  Exceptions are handled in the code below.
+        res_msg = await workflow.spawn(context.with_context(pairing_ctx, task))
+
+    except UnexpectedMessageException as exc:
+        # Workflow was trying to read a message from the wire, and
+        # something unexpected came in.  See Context.read() for
+        # example, which expects some particular message and raises
+        # UnexpectedMessage if another one comes in.
+        # In order not to lose the message, we return it to the caller.
+        # TODO:
+        # We might handle only the few common cases here, like
+        # Initialize and Cancel.
+        return exc.msg
+    except SilentError as exc:
+        if __debug__:
+            log.error(__name__, "SilentError: %s", exc.message)
+    except BaseException as exc:
+        # Either:
+        # - the message had a type that has a registered handler, but does not have
+        #   a protobuf class
+        # - the message was not valid protobuf
+        # - workflow raised some kind of an exception while running
+        # - something canceled the workflow from the outside
+        if __debug__:
+            if isinstance(exc, ActionCancelled):
+                log.debug(__name__, "cancelled: %s", exc.message)
+            elif isinstance(exc, loop.TaskClosed):
+                log.debug(__name__, "cancelled: loop task was closed")
+            else:
+                log.exception(__name__, exc)
+        res_msg = message_handler.failure(exc)
+
+    if res_msg is not None:
+        # perform the write outside the big try-except block, so that usb write
+        # problem bubbles up
+        await pairing_ctx.write(res_msg)
+    return None
diff --git a/core/src/trezor/wire/thp/received_message_handler.py b/core/src/trezor/wire/thp/received_message_handler.py
new file mode 100644
index 0000000000..3f9cd8f693
--- /dev/null
+++ b/core/src/trezor/wire/thp/received_message_handler.py
@@ -0,0 +1,446 @@
+import ustruct
+from typing import TYPE_CHECKING
+
+from storage.cache_common import (
+    CHANNEL_HANDSHAKE_HASH,
+    CHANNEL_KEY_RECEIVE,
+    CHANNEL_KEY_SEND,
+    CHANNEL_NONCE_RECEIVE,
+    CHANNEL_NONCE_SEND,
+)
+from storage.cache_thp import (
+    KEY_LENGTH,
+    MANAGEMENT_SESSION_ID,
+    SESSION_ID_LENGTH,
+    TAG_LENGTH,
+    update_channel_last_used,
+    update_session_last_used,
+)
+from trezor import log, loop, protobuf, utils
+from trezor.enums import FailureType
+from trezor.messages import Failure
+
+from .. import message_handler
+from ..errors import DataError
+from ..protocol_common import Message
+from . import (
+    ACK_MESSAGE,
+    HANDSHAKE_COMP_RES,
+    HANDSHAKE_INIT_RES,
+    ChannelState,
+    PacketHeader,
+    SessionState,
+    ThpDecryptionError,
+    ThpError,
+    ThpErrorType,
+    ThpInvalidDataError,
+    ThpUnallocatedSessionError,
+)
+from . import alternating_bit_protocol as ABP
+from . import (
+    checksum,
+    control_byte,
+    get_enabled_pairing_methods,
+    get_encoded_device_properties,
+    session_manager,
+)
+from .checksum import CHECKSUM_LENGTH
+from .crypto import PUBKEY_LENGTH, Handshake
+from .writer import (
+    INIT_HEADER_LENGTH,
+    MESSAGE_TYPE_LENGTH,
+    write_payload_to_wire_and_add_checksum,
+)
+
+if TYPE_CHECKING:
+    from typing import Awaitable
+
+    from trezor.messages import ThpHandshakeCompletionReqNoisePayload
+
+    from .channel import Channel
+
+if __debug__:
+    from ubinascii import hexlify
+
+    from . import state_to_str
+
+
+_TREZOR_STATE_UNPAIRED = b"\x00"
+_TREZOR_STATE_PAIRED = b"\x01"
+
+
+async def handle_received_message(
+    ctx: Channel, message_buffer: utils.BufferType
+) -> None:
+    """Handle a message received from the channel."""
+
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(__name__, "handle_received_message")
+        if utils.ALLOW_DEBUG_MESSAGES:  # TODO remove after performance tests are done
+            try:
+                import micropython
+
+                print("micropython.mem_info() from received_message_handler.py")
+                micropython.mem_info()
+                print("Allocation count:", micropython.alloc_count())  # type: ignore ["alloc_count" is not a known attribute of module "micropython"]
+            except AttributeError:
+                print(
+                    "To show allocation count, create the build with TREZOR_MEMPERF=1"
+                )
+    ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer)
+    message_length = payload_length + INIT_HEADER_LENGTH
+
+    _check_checksum(message_length, message_buffer)
+
+    # Synchronization process
+    seq_bit = (ctrl_byte & 0x10) >> 4
+    ack_bit = (ctrl_byte & 0x08) >> 3
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(
+            __name__,
+            "handle_completed_message - seq bit of message: %d, ack bit of message: %d",
+            seq_bit,
+            ack_bit,
+        )
+    # 0: Update "last-time used"
+    update_channel_last_used(ctx.channel_id)
+
+    # 1: Handle ACKs
+    if control_byte.is_ack(ctrl_byte):
+        await _handle_ack(ctx, ack_bit)
+        return
+
+    if _should_have_ctrl_byte_encrypted_transport(
+        ctx
+    ) and not control_byte.is_encrypted_transport(ctrl_byte):
+        raise ThpError("Message is not encrypted. Ignoring")
+
+    # 2: Handle message with unexpected sequential bit
+    if seq_bit != ABP.get_expected_receive_seq_bit(ctx.channel_cache):
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(__name__, "Received message with an unexpected sequential bit")
+        await _send_ack(ctx, ack_bit=seq_bit)
+        raise ThpError("Received message with an unexpected sequential bit")
+
+    # 3: Send ACK in response
+    await _send_ack(ctx, ack_bit=seq_bit)
+
+    ABP.set_expected_receive_seq_bit(ctx.channel_cache, 1 - seq_bit)
+
+    try:
+        await _handle_message_to_app_or_channel(
+            ctx, payload_length, message_length, ctrl_byte
+        )
+    except ThpUnallocatedSessionError as e:
+        error_message = Failure(code=FailureType.ThpUnallocatedSession)
+        await ctx.write(error_message, e.session_id)
+    except ThpDecryptionError:
+        await ctx.write_error(ThpErrorType.DECRYPTION_FAILED)
+        ctx.clear()
+    except ThpInvalidDataError:
+        await ctx.write_error(ThpErrorType.INVALID_DATA)
+        ctx.clear()
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(__name__, "handle_received_message - end")
+
+
+def _send_ack(ctx: Channel, ack_bit: int) -> Awaitable[None]:
+    ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
+    header = PacketHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH)
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(
+            __name__,
+            "Writing ACK message to a channel with id: %d, ack_bit: %d",
+            ctx.get_channel_id_int(),
+            ack_bit,
+        )
+    return write_payload_to_wire_and_add_checksum(ctx.iface, header, b"")
+
+
+def _check_checksum(message_length: int, message_buffer: utils.BufferType) -> None:
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(__name__, "check_checksum")
+    if not checksum.is_valid(
+        checksum=message_buffer[message_length - CHECKSUM_LENGTH : message_length],
+        data=memoryview(message_buffer)[: message_length - CHECKSUM_LENGTH],
+    ):
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(__name__, "Invalid checksum, ignoring message.")
+        raise ThpError("Invalid checksum, ignoring message.")
+
+
+async def _handle_ack(ctx: Channel, ack_bit: int) -> None:
+    if not ABP.is_ack_valid(ctx.channel_cache, ack_bit):
+        return
+    # ACK is expected and it has correct sync bit
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(__name__, "Received ACK message with correct ack bit")
+    if ctx.transmission_loop is not None:
+        ctx.transmission_loop.stop_immediately()
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(__name__, "Stopped transmission loop")
+
+    ABP.set_sending_allowed(ctx.channel_cache, True)
+
+    if ctx.write_task_spawn is not None:
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
+        await ctx.write_task_spawn
+        # Note that no the write_task_spawn could result in loop.clear(),
+        # which will result in termination of this function - any code after
+        # this await might not be executed
+
+
+def _handle_message_to_app_or_channel(
+    ctx: Channel,
+    payload_length: int,
+    message_length: int,
+    ctrl_byte: int,
+) -> Awaitable[None]:
+    state = ctx.get_channel_state()
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(__name__, "state: %s", state_to_str(state))
+
+    if state is ChannelState.ENCRYPTED_TRANSPORT:
+        return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
+
+    if state is ChannelState.TH1:
+        return _handle_state_TH1(ctx, payload_length, message_length, ctrl_byte)
+
+    if state is ChannelState.TH2:
+        return _handle_state_TH2(ctx, message_length, ctrl_byte)
+
+    if _is_channel_state_pairing(state):
+        return _handle_pairing(ctx, message_length)
+
+    raise ThpError("Unimplemented channel state")
+
+
+async def _handle_state_TH1(
+    ctx: Channel,
+    payload_length: int,
+    message_length: int,
+    ctrl_byte: int,
+) -> None:
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(__name__, "handle_state_TH1")
+    if not control_byte.is_handshake_init_req(ctrl_byte):
+        raise ThpError("Message received is not a handshake init request!")
+    if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
+        raise ThpError("Message received is not a valid handshake init request!")
+
+    ctx.handshake = Handshake()
+
+    host_ephemeral_pubkey = bytearray(
+        ctx.buffer[INIT_HEADER_LENGTH : message_length - CHECKSUM_LENGTH]
+    )
+    trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag = (
+        ctx.handshake.handle_th1_crypto(
+            get_encoded_device_properties(ctx.iface), host_ephemeral_pubkey
+        )
+    )
+
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(
+            __name__,
+            "trezor ephemeral pubkey: %s",
+            hexlify(trezor_ephemeral_pubkey).decode(),
+        )
+        log.debug(
+            __name__,
+            "encrypted trezor masked static pubkey: %s",
+            hexlify(encrypted_trezor_static_pubkey).decode(),
+        )
+        log.debug(__name__, "tag: %s", hexlify(tag))
+
+    payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag
+
+    # send handshake init response message
+    ctx.write_handshake_message(HANDSHAKE_INIT_RES, payload)
+    ctx.set_channel_state(ChannelState.TH2)
+    return
+
+
+async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None:
+    from apps.thp.credential_manager import validate_credential
+
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(__name__, "handle_state_TH2")
+    if not control_byte.is_handshake_comp_req(ctrl_byte):
+        raise ThpError("Message received is not a handshake completion request!")
+    if ctx.handshake is None:
+        raise Exception("Handshake object is not prepared. Retry handshake.")
+
+    host_encrypted_static_pubkey = memoryview(ctx.buffer)[
+        INIT_HEADER_LENGTH : INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH
+    ]
+    handshake_completion_request_noise_payload = memoryview(ctx.buffer)[
+        INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH
+    ]
+
+    ctx.handshake.handle_th2_crypto(
+        host_encrypted_static_pubkey, handshake_completion_request_noise_payload
+    )
+
+    ctx.channel_cache.set(CHANNEL_KEY_RECEIVE, ctx.handshake.key_receive)
+    ctx.channel_cache.set(CHANNEL_KEY_SEND, ctx.handshake.key_send)
+    ctx.channel_cache.set(CHANNEL_HANDSHAKE_HASH, ctx.handshake.h)
+    ctx.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
+    ctx.channel_cache.set_int(CHANNEL_NONCE_SEND, 1)
+
+    noise_payload = _decode_message(
+        ctx.buffer[
+            INIT_HEADER_LENGTH
+            + KEY_LENGTH
+            + TAG_LENGTH : message_length
+            - CHECKSUM_LENGTH
+            - TAG_LENGTH
+        ],
+        0,
+        "ThpHandshakeCompletionReqNoisePayload",
+    )
+    if TYPE_CHECKING:
+        assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload)
+    enabled_methods = get_enabled_pairing_methods(ctx.iface)
+    for method in noise_payload.pairing_methods:
+        if method not in enabled_methods:
+            raise ThpInvalidDataError()
+        if method not in ctx.selected_pairing_methods:
+            ctx.selected_pairing_methods.append(method)
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(
+            __name__,
+            "host static pubkey: %s, noise payload: %s",
+            utils.get_bytes_as_str(host_encrypted_static_pubkey),
+            utils.get_bytes_as_str(handshake_completion_request_noise_payload),
+        )
+
+    # key is decoded in handshake._handle_th2_crypto
+    host_static_pubkey = host_encrypted_static_pubkey[:PUBKEY_LENGTH]
+
+    paired: bool = False
+
+    if noise_payload.host_pairing_credential is not None:
+        try:  # TODO change try-except for something better
+            paired = validate_credential(
+                noise_payload.host_pairing_credential,
+                host_static_pubkey,
+            )
+        except DataError as e:
+            if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+                log.exception(__name__, e)
+            pass
+
+    trezor_state = _TREZOR_STATE_UNPAIRED
+    if paired:
+        trezor_state = _TREZOR_STATE_PAIRED
+    # send hanshake completion response
+    ctx.write_handshake_message(
+        HANDSHAKE_COMP_RES,
+        ctx.handshake.get_handshake_completion_response(trezor_state),
+    )
+
+    ctx.handshake = None
+
+    if paired:
+        ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
+    else:
+        ctx.set_channel_state(ChannelState.TP1)
+
+
+async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -> None:
+    if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+        log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
+
+    ctx.decrypt_buffer(message_length)
+    session_id, message_type = ustruct.unpack(
+        ">BH", memoryview(ctx.buffer)[INIT_HEADER_LENGTH:]
+    )
+    if session_id not in ctx.sessions:
+        if session_id == MANAGEMENT_SESSION_ID:
+            s = session_manager.create_new_management_session(ctx)
+        else:
+            s = session_manager.get_session_from_cache(ctx, session_id)
+        if s is None:
+            raise ThpUnallocatedSessionError(session_id)
+        ctx.sessions[session_id] = s
+        loop.schedule(s.handle())
+
+    elif ctx.sessions[session_id].get_session_state() is SessionState.UNALLOCATED:
+        raise ThpUnallocatedSessionError(session_id)
+
+    s = ctx.sessions[session_id]
+    update_session_last_used(s.channel_id, (s.session_id).to_bytes(1, "big"))
+
+    s.incoming_message.put(
+        Message(
+            message_type,
+            ctx.buffer[
+                INIT_HEADER_LENGTH
+                + MESSAGE_TYPE_LENGTH
+                + SESSION_ID_LENGTH : message_length
+                - CHECKSUM_LENGTH
+                - TAG_LENGTH
+            ],
+        )
+    )
+
+
+async def _handle_pairing(ctx: Channel, message_length: int) -> None:
+    from .pairing_context import PairingContext
+
+    if ctx.connection_context is None:
+        ctx.connection_context = PairingContext(ctx)
+        loop.schedule(ctx.connection_context.handle())
+
+    ctx.decrypt_buffer(message_length)
+    message_type = ustruct.unpack(
+        ">H", ctx.buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :]
+    )[0]
+
+    ctx.connection_context.incoming_message.put(
+        Message(
+            message_type,
+            ctx.buffer[
+                INIT_HEADER_LENGTH
+                + MESSAGE_TYPE_LENGTH
+                + SESSION_ID_LENGTH : message_length
+                - CHECKSUM_LENGTH
+                - TAG_LENGTH
+            ],
+        )
+    )
+
+
+def _should_have_ctrl_byte_encrypted_transport(ctx: Channel) -> bool:
+    if ctx.get_channel_state() in [
+        ChannelState.UNALLOCATED,
+        ChannelState.TH1,
+        ChannelState.TH2,
+    ]:
+        return False
+    return True
+
+
+def _decode_message(
+    buffer: bytes, msg_type: int, message_name: str | None = None
+) -> protobuf.MessageType:
+    if __debug__:
+        log.debug(__name__, "decode message")
+    if message_name is not None:
+        expected_type = protobuf.type_for_name(message_name)
+    else:
+        expected_type = protobuf.type_for_wire(msg_type)
+    return message_handler.wrap_protobuf_load(buffer, expected_type)
+
+
+def _is_channel_state_pairing(state: int) -> bool:
+    if state in (
+        ChannelState.TP1,
+        ChannelState.TP2,
+        ChannelState.TP3,
+        ChannelState.TP4,
+        ChannelState.TC1,
+    ):
+        return True
+    return False
diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py
new file mode 100644
index 0000000000..688fa46b37
--- /dev/null
+++ b/core/src/trezor/wire/thp/session_context.py
@@ -0,0 +1,169 @@
+from typing import TYPE_CHECKING
+
+from storage import cache_thp
+from storage.cache_thp import MANAGEMENT_SESSION_ID, SessionThpCache
+from trezor import log, loop, protobuf, utils
+from trezor.wire import message_handler, protocol_common
+from trezor.wire.context import UnexpectedMessageException
+from trezor.wire.message_handler import failure
+
+from ..protocol_common import Context, Message
+from . import SessionState
+
+if TYPE_CHECKING:
+    from typing import Awaitable, Container
+
+    from storage.cache_common import DataCache
+
+    from .channel import Channel
+
+    pass
+
+_EXIT_LOOP = True
+_REPEAT_LOOP = False
+
+if __debug__:
+    from trezor.utils import get_bytes_as_str
+
+
+class GenericSessionContext(Context):
+
+    def __init__(self, channel: Channel, session_id: int) -> None:
+        super().__init__(channel.iface, channel.channel_id)
+        self.channel: Channel = channel
+        self.session_id: int = session_id
+        self.incoming_message = loop.mailbox()
+
+    async def handle(self) -> None:
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            self._handle_debug()
+
+        next_message: Message | None = None
+
+        while True:
+            message = next_message
+            next_message = None
+            try:
+                if await self._handle_message(message):
+                    loop.schedule(self.handle())
+                    return
+            except UnexpectedMessageException as unexpected:
+                # The workflow was interrupted by an unexpected message. We need to
+                # process it as if it was a new message...
+                next_message = unexpected.msg
+                continue
+            except Exception as exc:
+                # Log and try again.
+                if __debug__:
+                    log.exception(__name__, exc)
+
+    def _handle_debug(self) -> None:
+        log.debug(
+            __name__,
+            "handle - start (channel_id (bytes): %s, session_id: %d)",
+            get_bytes_as_str(self.channel_id),
+            self.session_id,
+        )
+
+    async def _handle_message(
+        self,
+        next_message: Message | None,
+    ) -> bool:
+
+        try:
+            if next_message is not None:
+                # Process the message from previous run.
+                message = next_message
+                next_message = None
+            else:
+                # Wait for a new message from wire
+                message = await self.incoming_message
+
+        except protocol_common.WireError as e:
+            if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+                log.exception(__name__, e)
+            await self.write(failure(e))
+            return _REPEAT_LOOP
+
+        await message_handler.handle_single_message(self, message)
+        return _EXIT_LOOP
+
+    async def read(
+        self,
+        expected_types: Container[int],
+        expected_type: type[protobuf.MessageType] | None = None,
+    ) -> protobuf.MessageType:
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            exp_type: str = str(expected_type)
+            if expected_type is not None:
+                exp_type = expected_type.MESSAGE_NAME
+            log.debug(
+                __name__,
+                "Read - with expected types %s and expected type %s",
+                str(expected_types),
+                exp_type,
+            )
+        message: Message = await self.incoming_message
+        if message.type not in expected_types:
+            if __debug__:
+                log.debug(
+                    __name__,
+                    "EXPECTED TYPES: %s\nRECEIVED TYPE: %s",
+                    str(expected_types),
+                    str(message.type),
+                )
+            raise UnexpectedMessageException(message)
+
+        if expected_type is None:
+            expected_type = protobuf.type_for_wire(message.type)
+
+        return message_handler.wrap_protobuf_load(message.data, expected_type)
+
+    async def write(self, msg: protobuf.MessageType) -> None:
+        return await self.channel.write(msg, self.session_id)
+
+    def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]:
+        return self.channel.write(msg, self.session_id, force=True)
+
+    def get_session_state(self) -> SessionState: ...
+
+
+class ManagementSessionContext(GenericSessionContext):
+
+    def __init__(
+        self, channel_ctx: Channel, session_id: int = MANAGEMENT_SESSION_ID
+    ) -> None:
+        super().__init__(channel_ctx, session_id)
+
+    def get_session_state(self) -> SessionState:
+        return SessionState.MANAGEMENT
+
+
+class SessionContext(GenericSessionContext):
+
+    def __init__(self, channel_ctx: Channel, session_cache: SessionThpCache) -> None:
+        if channel_ctx.channel_id != session_cache.channel_id:
+            raise Exception(
+                "The session has different channel id than the provided channel context!"
+            )
+        session_id = int.from_bytes(session_cache.session_id, "big")
+        super().__init__(channel_ctx, session_id)
+        self.session_cache = session_cache
+
+    # ACCESS TO SESSION DATA
+
+    def get_session_state(self) -> SessionState:
+        state = int.from_bytes(self.session_cache.state, "big")
+        return SessionState(state)
+
+    def set_session_state(self, state: SessionState) -> None:
+        self.session_cache.state = bytearray(state.to_bytes(1, "big"))
+
+    def release(self) -> None:
+        if self.session_cache is not None:
+            cache_thp.clear_session(self.session_cache)
+
+    # ACCESS TO CACHE
+    @property
+    def cache(self) -> DataCache:
+        return self.session_cache
diff --git a/core/src/trezor/wire/thp/session_manager.py b/core/src/trezor/wire/thp/session_manager.py
new file mode 100644
index 0000000000..3377ce437f
--- /dev/null
+++ b/core/src/trezor/wire/thp/session_manager.py
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING
+
+from storage import cache_thp
+
+from .session_context import (
+    GenericSessionContext,
+    ManagementSessionContext,
+    SessionContext,
+)
+
+if TYPE_CHECKING:
+    from .channel import Channel
+
+
+def create_new_session(channel_ctx: Channel) -> SessionContext:
+    """
+    Creates new `SessionContext` backed by cache.
+    """
+    session_cache = cache_thp.get_new_session(channel_ctx.channel_cache)
+    return SessionContext(channel_ctx, session_cache)
+
+
+def create_new_management_session(
+    channel_ctx: Channel, session_id: int = cache_thp.MANAGEMENT_SESSION_ID
+) -> ManagementSessionContext:
+    """
+    Creates new `ManagementSessionContext` that is not backed by cache entry.
+
+    Seed cannot be derived with this type of session.
+    """
+    return ManagementSessionContext(channel_ctx, session_id)
+
+
+def get_session_from_cache(
+    channel_ctx: Channel, session_id: int
+) -> GenericSessionContext | None:
+    """
+    Returns a `SessionContext` (or `ManagementSessionContext`) reconstructed from a cache or `None` if backing cache is not found.
+    """
+    session_id_bytes = session_id.to_bytes(1, "big")
+    session_cache = cache_thp.get_allocated_session(
+        channel_ctx.channel_id, session_id_bytes
+    )
+    if session_cache is None:
+        return None
+    elif cache_thp.is_management_session(session_cache):
+        return ManagementSessionContext(channel_ctx, session_id)
+    return SessionContext(channel_ctx, session_cache)
diff --git a/core/src/trezor/wire/thp/thp_main.py b/core/src/trezor/wire/thp/thp_main.py
new file mode 100644
index 0000000000..d83bce301e
--- /dev/null
+++ b/core/src/trezor/wire/thp/thp_main.py
@@ -0,0 +1,190 @@
+import ustruct
+from micropython import const
+from typing import TYPE_CHECKING
+
+from storage.cache_thp import BROADCAST_CHANNEL_ID
+from trezor import io, log, loop, utils
+
+from . import (
+    CHANNEL_ALLOCATION_REQ,
+    CODEC_V1,
+    ChannelState,
+    PacketHeader,
+    ThpError,
+    ThpErrorType,
+    channel_manager,
+    checksum,
+    get_channel_allocation_response,
+    writer,
+)
+from .channel import Channel
+from .checksum import CHECKSUM_LENGTH
+from .writer import (
+    INIT_HEADER_LENGTH,
+    MAX_PAYLOAD_LEN,
+    PACKET_LENGTH,
+    write_payload_to_wire_and_add_checksum,
+)
+
+if TYPE_CHECKING:
+    from trezorio import WireInterface
+
+_CID_REQ_PAYLOAD_LENGTH = const(12)
+_READ_BUFFER: bytearray
+_WRITE_BUFFER: bytearray
+_CHANNELS: dict[int, Channel] = {}
+
+
+def set_read_buffer(buffer: bytearray) -> None:
+    global _READ_BUFFER
+    _READ_BUFFER = buffer
+
+
+def set_write_buffer(buffer: bytearray) -> None:
+    global _WRITE_BUFFER
+    _WRITE_BUFFER = buffer
+
+
+def get_raw_read_buffer() -> bytearray:
+    global _READ_BUFFER
+    return _READ_BUFFER
+
+
+def get_raw_write_buffer() -> bytearray:
+    global _WRITE_BUFFER
+    return _WRITE_BUFFER
+
+
+async def thp_main_loop(iface: WireInterface) -> None:
+    global _CHANNELS
+    global _READ_BUFFER
+    _CHANNELS = channel_manager.load_cached_channels(_READ_BUFFER)
+
+    read = loop.wait(iface.iface_num() | io.POLL_READ)
+    packet = bytearray(PACKET_LENGTH)
+    while True:
+        try:
+            if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+                log.debug(__name__, "thp_main_loop")
+            packet_len = await read
+            assert packet_len == len(packet)
+            iface.read(packet, 0)
+
+            ctrl_byte, cid = ustruct.unpack(">BH", packet)
+
+            if ctrl_byte == CODEC_V1:
+                await _handle_codec_v1(iface, packet)
+                continue
+
+            if cid == BROADCAST_CHANNEL_ID:
+                await _handle_broadcast(iface, ctrl_byte, packet)
+                continue
+
+            if cid in _CHANNELS:
+                await _handle_allocated(iface, cid, packet)
+            else:
+                await _handle_unallocated(iface, cid)
+
+        except ThpError as e:
+            if __debug__:
+                log.exception(__name__, e)
+
+
+async def _handle_codec_v1(iface: WireInterface, packet: bytes) -> None:
+    # If the received packet is not an initial codec_v1 packet, do not send error message
+    if not packet[1:3] == b"##":
+        return
+    if __debug__:
+        log.debug(__name__, "Received codec_v1 message, returning error")
+    error_message = _get_codec_v1_error_message()
+    await writer.write_packet_to_wire(iface, error_message)
+
+
+async def _handle_broadcast(
+    iface: WireInterface, ctrl_byte: int, packet: utils.BufferType
+) -> None:
+    global _READ_BUFFER
+    if ctrl_byte != CHANNEL_ALLOCATION_REQ:
+        raise ThpError("Unexpected ctrl_byte in a broadcast channel packet")
+    if __debug__:
+        log.debug(__name__, "Received valid message on the broadcast channel")
+
+    length, nonce = ustruct.unpack(">H8s", packet[3:])
+    payload = _get_buffer_for_payload(length, packet[5:], _CID_REQ_PAYLOAD_LENGTH)
+    if not checksum.is_valid(
+        payload[-4:],
+        packet[: _CID_REQ_PAYLOAD_LENGTH + INIT_HEADER_LENGTH - CHECKSUM_LENGTH],
+    ):
+        raise ThpError("Checksum is not valid")
+
+    new_channel: Channel = channel_manager.create_new_channel(iface, _READ_BUFFER)
+    cid = int.from_bytes(new_channel.channel_id, "big")
+    _CHANNELS[cid] = new_channel
+
+    response_data = get_channel_allocation_response(
+        nonce, new_channel.channel_id, iface
+    )
+    response_header = PacketHeader.get_channel_allocation_response_header(
+        len(response_data) + CHECKSUM_LENGTH,
+    )
+    if __debug__:
+        log.debug(__name__, "New channel allocated with id %d", cid)
+
+    await write_payload_to_wire_and_add_checksum(iface, response_header, response_data)
+
+
+async def _handle_allocated(
+    iface: WireInterface, cid: int, packet: utils.BufferType
+) -> None:
+    channel = _CHANNELS[cid]
+    if channel is None:
+        await _handle_unallocated(iface, cid)
+        raise ThpError("Invalid state of a channel")
+    if channel.iface is not iface:
+        # TODO send error message to wire
+        raise ThpError("Channel has different WireInterface")
+
+    if channel.get_channel_state() != ChannelState.UNALLOCATED:
+        x = channel.receive_packet(packet)
+        if x is not None:
+            await x
+
+
+async def _handle_unallocated(iface: WireInterface, cid: int) -> None:
+    data = (ThpErrorType.UNALLOCATED_CHANNEL).to_bytes(1, "big")
+    header = PacketHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)
+    await write_payload_to_wire_and_add_checksum(iface, header, data)
+
+
+def _get_buffer_for_payload(
+    payload_length: int,
+    existing_buffer: utils.BufferType,
+    max_length: int = MAX_PAYLOAD_LEN,
+) -> utils.BufferType:
+    if payload_length > max_length:
+        raise ThpError("Message too large")
+    if payload_length > len(existing_buffer):
+        return _try_allocate_new_buffer(payload_length)
+    return _reuse_existing_buffer(payload_length, existing_buffer)
+
+
+def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType:
+    try:
+        payload: utils.BufferType = bytearray(payload_length)
+    except MemoryError:
+        payload = bytearray(PACKET_LENGTH)
+        raise ThpError("Message too large")
+    return payload
+
+
+def _reuse_existing_buffer(
+    payload_length: int, existing_buffer: utils.BufferType
+) -> utils.BufferType:
+    return memoryview(existing_buffer)[:payload_length]
+
+
+def _get_codec_v1_error_message() -> bytes:
+    # Codec_v1 magic constant "?##" + Failure message type + msg_size
+    # + msg_data (code = "Failure_InvalidProtocol") + padding to 64 B
+    ERROR_MSG = b"\x3f\x23\x23\x00\x03\x00\x00\x00\x14\x08\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+    return ERROR_MSG
diff --git a/core/src/trezor/wire/thp/transmission_loop.py b/core/src/trezor/wire/thp/transmission_loop.py
new file mode 100644
index 0000000000..cd3e3ba2f8
--- /dev/null
+++ b/core/src/trezor/wire/thp/transmission_loop.py
@@ -0,0 +1,54 @@
+from micropython import const
+from typing import TYPE_CHECKING
+
+from trezor import loop
+
+from .writer import write_payload_to_wire_and_add_checksum
+
+if TYPE_CHECKING:
+    from . import PacketHeader
+    from .channel import Channel
+
+MAX_RETRANSMISSION_COUNT = const(50)
+MIN_RETRANSMISSION_COUNT = const(2)
+
+
+class TransmissionLoop:
+
+    def __init__(
+        self, channel: Channel, header: PacketHeader, transport_payload: bytes
+    ) -> None:
+        self.channel: Channel = channel
+        self.header: PacketHeader = header
+        self.transport_payload: bytes = transport_payload
+        self.wait_task: loop.spawn | None = None
+        self.min_retransmisson_count_achieved: bool = False
+
+    async def start(
+        self, max_retransmission_count: int = MAX_RETRANSMISSION_COUNT
+    ) -> None:
+        self.min_retransmisson_count_achieved = False
+        for i in range(max_retransmission_count):
+            if i >= MIN_RETRANSMISSION_COUNT:
+                self.min_retransmisson_count_achieved = True
+            await write_payload_to_wire_and_add_checksum(
+                self.channel.iface, self.header, self.transport_payload
+            )
+            self.wait_task = loop.spawn(self._wait(i))
+            try:
+                await self.wait_task
+            except loop.TaskClosed:
+                self.wait_task = None
+                break
+
+    def stop_immediately(self) -> None:
+        if self.wait_task is not None:
+            self.wait_task.close()
+        self.wait_task = None
+
+    async def _wait(self, counter: int = 0) -> None:
+        timeout_ms = round(10200 - 1010000 / (counter + 100))
+        await loop.sleep(timeout_ms)
+
+    def __del__(self) -> None:
+        self.stop_immediately()
diff --git a/core/src/trezor/wire/thp/writer.py b/core/src/trezor/wire/thp/writer.py
new file mode 100644
index 0000000000..f6963bdf6f
--- /dev/null
+++ b/core/src/trezor/wire/thp/writer.py
@@ -0,0 +1,93 @@
+from micropython import const
+from trezorcrypto import crc
+from typing import TYPE_CHECKING
+
+from trezor import io, log, loop, utils
+
+from . import PacketHeader
+
+INIT_HEADER_LENGTH = const(5)
+CONT_HEADER_LENGTH = const(3)
+CHECKSUM_LENGTH = const(4)
+MAX_PAYLOAD_LEN = const(60000)
+MESSAGE_TYPE_LENGTH = const(2)
+
+PACKET_LENGTH = io.WebUSB.PACKET_LEN
+
+if TYPE_CHECKING:
+    from trezorio import WireInterface
+    from typing import Awaitable, Sequence
+
+
+def write_payload_to_wire_and_add_checksum(
+    iface: WireInterface, header: PacketHeader, transport_payload: bytes
+) -> Awaitable[None]:
+    header_checksum: int = crc.crc32(header.to_bytes())
+    checksum: bytes = crc.crc32(transport_payload, header_checksum).to_bytes(
+        CHECKSUM_LENGTH, "big"
+    )
+    data = (transport_payload, checksum)
+    return write_payloads_to_wire(iface, header, data)
+
+
+async def write_payloads_to_wire(
+    iface: WireInterface, header: PacketHeader, data: Sequence[bytes]
+) -> None:
+    n_of_data = len(data)
+    total_length = sum(len(item) for item in data)
+
+    current_data_idx = 0
+    current_data_offset = 0
+
+    packet = bytearray(PACKET_LENGTH)
+    header.pack_to_init_buffer(packet)
+    packet_offset: int = INIT_HEADER_LENGTH
+    packet_number = 0
+    nwritten = 0
+    while nwritten < total_length:
+        if packet_number == 1:
+            header.pack_to_cont_buffer(packet)
+        if packet_number >= 1 and nwritten >= total_length - PACKET_LENGTH:
+            packet[:] = bytearray(PACKET_LENGTH)
+            header.pack_to_cont_buffer(packet)
+        while True:
+            n = utils.memcpy(
+                packet, packet_offset, data[current_data_idx], current_data_offset
+            )
+            packet_offset += n
+            current_data_offset += n
+            nwritten += n
+
+            if packet_offset < PACKET_LENGTH:
+                current_data_idx += 1
+                current_data_offset = 0
+                if current_data_idx >= n_of_data:
+                    break
+            elif packet_offset == PACKET_LENGTH:
+                break
+            else:
+                raise Exception("Should not happen!!!")
+        packet_number += 1
+        packet_offset = CONT_HEADER_LENGTH
+
+        # write packet to wire (in-lined)
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(
+                __name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
+            )
+        written_by_iface: int = 0
+        while written_by_iface < len(packet):
+            await loop.wait(iface.iface_num() | io.POLL_WRITE)
+            written_by_iface = iface.write(packet)
+
+
+async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None:
+    while True:
+        await loop.wait(iface.iface_num() | io.POLL_WRITE)
+        if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
+            log.debug(
+                __name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
+            )
+        n_written = iface.write(packet)
+        if n_written == len(packet):
+            return
diff --git a/core/src/trezor/workflow.py b/core/src/trezor/workflow.py
index 67b88f8e68..9fc72c3e98 100644
--- a/core/src/trezor/workflow.py
+++ b/core/src/trezor/workflow.py
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
 
 import storage.cache as storage_cache
 from trezor import log, loop
-from trezor.enums import MessageType
+from trezor.enums import MessageType, ThpMessageType
 
 if TYPE_CHECKING:
     from typing import Callable
@@ -17,9 +17,14 @@ if __debug__:
 
     from trezor import utils
 
+if utils.USE_THP:
+    protocol_specific = ThpMessageType.ThpCreateNewSession
+else:
+    protocol_specific = MessageType.Initialize
+
 
 ALLOW_WHILE_LOCKED = (
-    MessageType.Initialize,
+    protocol_specific,
     MessageType.EndSession,
     MessageType.GetFeatures,
     MessageType.Cancel,
diff --git a/core/tests/mock_wire_interface.py b/core/tests/mock_wire_interface.py
new file mode 100644
index 0000000000..13cd033375
--- /dev/null
+++ b/core/tests/mock_wire_interface.py
@@ -0,0 +1,50 @@
+from trezor.loop import wait
+
+
+class MockHID:
+
+    TX_PACKET_LEN = 64
+    RX_PACKET_LEN = 64
+
+    def __init__(self, num):
+        self.num = num
+        self.data = []
+        self.packet = None
+
+    def pad_packet(self, data):
+        if len(data) > self.RX_PACKET_LEN:
+            raise Exception("Too long packet")
+        padding_length = self.RX_PACKET_LEN - len(data)
+        return data + b"\x00" * padding_length
+
+    def iface_num(self):
+        return self.num
+
+    def write(self, msg):
+        self.data.append(bytearray(msg))
+        return len(msg)
+
+    def mock_read(self, packet, gen):
+        self.packet = self.pad_packet(packet)
+        return gen.send(self.RX_PACKET_LEN)
+
+    def read(self, buffer, offset=0):
+        if self.packet is None:
+            raise Exception("No packet to read")
+
+        if offset > len(buffer):
+            raise Exception("Offset out of bounds")
+
+        buffer_space = len(buffer) - offset
+
+        if len(self.packet) > buffer_space:
+            raise Exception("Buffer too small")
+        else:
+            end = offset + len(self.packet)
+            buffer[offset:end] = self.packet
+            read = len(self.packet)
+            self.packet = None
+            return read
+
+    def wait_object(self, mode):
+        return wait(mode | self.num)
diff --git a/core/tests/myTests.sh b/core/tests/myTests.sh
new file mode 100755
index 0000000000..1c29c1fd01
--- /dev/null
+++ b/core/tests/myTests.sh
@@ -0,0 +1,42 @@
+#!/usr/bin/env bash
+
+declare -a results
+declare -i passed=0 failed=0 exit_code=0
+declare COLOR_GREEN='\e[32m' COLOR_RED='\e[91m' COLOR_RESET='\e[39m'
+MICROPYTHON="${MICROPYTHON:-../build/unix/trezor-emu-core -X heapsize=2M}"
+print_summary() {
+    echo
+    echo 'Summary:'
+    echo '-------------------'
+    printf '%b\n' "${results[@]}"
+    if [ $exit_code == 0 ]; then
+        echo -e "${COLOR_GREEN}PASSED:${COLOR_RESET} $passed/$num_of_tests tests OK!"
+    else
+        echo -e "${COLOR_RED}FAILED:${COLOR_RESET} $failed/$num_of_tests tests failed!"
+    fi
+}
+
+trap 'print_summary; echo -e "${COLOR_RED}Interrupted by user!${COLOR_RESET}"; exit 1' SIGINT
+
+cd $(dirname $0)
+
+[ -z "$*" ] && tests=(test_trezor.wire.t*.py ) || tests=($*)
+
+declare -i num_of_tests=${#tests[@]}
+
+for test_case in ${tests[@]}; do
+    echo ${MICROPYTHON}
+    echo ${test_case}
+    echo
+    if $MICROPYTHON $test_case; then
+        results+=("${COLOR_GREEN}OK:${COLOR_RESET} $test_case")
+        ((passed++))
+    else
+        results+=("${COLOR_RED}FAIL:${COLOR_RESET} $test_case")
+        ((failed++))
+        exit_code=1
+    fi
+done
+
+print_summary
+exit $exit_code
diff --git a/core/tests/test_apps.bitcoin.approver.py b/core/tests/test_apps.bitcoin.approver.py
index 8086fd8e2d..17a870df7f 100644
--- a/core/tests/test_apps.bitcoin.approver.py
+++ b/core/tests/test_apps.bitcoin.approver.py
@@ -1,4 +1,4 @@
-from common import H_, await_result, unittest  # isort:skip
+from common import *  # isort:skip
 
 import storage.cache_codec
 from trezor import wire
@@ -12,19 +12,33 @@ from trezor.messages import (
     TxOutput,
 )
 from trezor.wire import context
-from trezor.wire.codec.codec_context import CodecContext
 
 from apps.bitcoin.authorization import FEE_RATE_DECIMALS, CoinJoinAuthorization
 from apps.bitcoin.sign_tx.approvers import CoinJoinApprover
 from apps.bitcoin.sign_tx.bitcoin import Bitcoin
 from apps.bitcoin.sign_tx.tx_info import TxInfo
 from apps.common import coins
+from trezor.wire.codec.codec_context import CodecContext
+
+if utils.USE_THP:
+    import thp_common
+else:
+    import storage.cache_codec
+    from trezor.wire.codec.codec_context import CodecContext
 
 
 class TestApprover(unittest.TestCase):
+    if utils.USE_THP:
 
-    def setUpClass(self):
-        context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
+        def setUpClass(self):
+            if __debug__:
+                thp_common.suppres_debug_log()
+            thp_common.prepare_context()
+
+    else:
+
+        def setUpClass(self):
+            context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
 
     def tearDownClass(self):
         context.CURRENT_CONTEXT = None
@@ -54,7 +68,8 @@ class TestApprover(unittest.TestCase):
             coin_name=self.coin.coin_name,
             script_type=InputScriptType.SPENDTAPROOT,
         )
-        storage.cache_codec.start_session()
+        if not utils.USE_THP:
+            storage.cache_codec.start_session()
 
     def make_coinjoin_request(self, inputs):
         return CoinJoinRequest(
diff --git a/core/tests/test_apps.bitcoin.authorization.py b/core/tests/test_apps.bitcoin.authorization.py
index 03d32651c7..4faf202989 100644
--- a/core/tests/test_apps.bitcoin.authorization.py
+++ b/core/tests/test_apps.bitcoin.authorization.py
@@ -1,23 +1,37 @@
-from common import H_, unittest  # isort:skip
+from common import *  # isort:skip
 
 import storage.cache_codec
 from trezor.enums import InputScriptType
 from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx
 from trezor.wire import context
-from trezor.wire.codec.codec_context import CodecContext
 
 from apps.bitcoin.authorization import CoinJoinAuthorization
 from apps.common import coins
 
 _ROUND_ID_LEN = 32
 
+if utils.USE_THP:
+    import thp_common
+else:
+    import storage.cache_codec
+    from trezor.wire.codec.codec_context import CodecContext
+
 
 class TestAuthorization(unittest.TestCase):
 
     coin = coins.by_name("Bitcoin")
 
-    def setUpClass(self):
-        context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
+    if utils.USE_THP:
+
+        def setUpClass(self):
+            if __debug__:
+                thp_common.suppres_debug_log()
+            thp_common.prepare_context()
+
+    else:
+
+        def setUpClass(self):
+            context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
 
     def tearDownClass(self):
         context.CURRENT_CONTEXT = None
@@ -34,7 +48,8 @@ class TestAuthorization(unittest.TestCase):
         )
 
         self.authorization = CoinJoinAuthorization(self.msg_auth)
-        storage.cache_codec.start_session()
+        if not utils.USE_THP:
+            storage.cache_codec.start_session()
 
     def test_ownership_proof_account_depth_mismatch(self):
         # Account depth mismatch.
diff --git a/core/tests/test_apps.bitcoin.keychain.py b/core/tests/test_apps.bitcoin.keychain.py
index 232d2bf01d..25239dad8c 100644
--- a/core/tests/test_apps.bitcoin.keychain.py
+++ b/core/tests/test_apps.bitcoin.keychain.py
@@ -1,7 +1,7 @@
 # flake8: noqa: F403,F405
 from common import *  # isort:skip
 
-from storage import cache_codec, cache_common
+from storage import cache_common
 from trezor import wire
 from trezor.crypto import bip39
 from trezor.wire import context
@@ -9,20 +9,38 @@ from trezor.wire.codec.codec_context import CodecContext
 
 from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin
 
+if utils.USE_THP:
+    import thp_common
+else:
+    from storage import cache_codec
+
 
 class TestBitcoinKeychain(unittest.TestCase):
 
-    def setUpClass(self):
-        context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
+    if utils.USE_THP:
+
+        def setUpClass(self):
+            if __debug__:
+                thp_common.suppres_debug_log()
+            thp_common.prepare_context()
+
+        def setUp(self):
+            seed = bip39.seed(" ".join(["all"] * 12), "")
+            context.cache_set(cache_common.APP_COMMON_SEED, seed)
+
+    else:
+
+        def setUpClass(self):
+            context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
+
+        def setUp(self):
+            cache_codec.start_session()
+            seed = bip39.seed(" ".join(["all"] * 12), "")
+            cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
 
     def tearDownClass(self):
         context.CURRENT_CONTEXT = None
 
-    def setUp(self):
-        cache_codec.start_session()
-        seed = bip39.seed(" ".join(["all"] * 12), "")
-        cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
-
     def test_bitcoin(self):
         coin = _get_coin_by_name("Bitcoin")
         keychain = await_result(_get_keychain_for_coin(coin))
@@ -98,18 +116,30 @@ class TestBitcoinKeychain(unittest.TestCase):
 
 @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
 class TestAltcoinKeychains(unittest.TestCase):
+    if utils.USE_THP:
 
-    def setUpClass(self):
-        context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
+        def setUpClass(self):
+            if __debug__:
+                thp_common.suppres_debug_log()
+            thp_common.prepare_context()
+
+        def setUp(self):
+            seed = bip39.seed(" ".join(["all"] * 12), "")
+            context.cache_set(cache_common.APP_COMMON_SEED, seed)
+
+    else:
+
+        def setUpClass(self):
+            context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
+
+        def setUp(self):
+            cache_codec.start_session()
+            seed = bip39.seed(" ".join(["all"] * 12), "")
+            cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
 
     def tearDownClass(self):
         context.CURRENT_CONTEXT = None
 
-    def setUp(self):
-        cache_codec.start_session()
-        seed = bip39.seed(" ".join(["all"] * 12), "")
-        cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
-
     def test_bcash(self):
         coin = _get_coin_by_name("Bcash")
         keychain = await_result(_get_keychain_for_coin(coin))
diff --git a/core/tests/test_apps.common.keychain.py b/core/tests/test_apps.common.keychain.py
index f54f64d74f..8d0839f374 100644
--- a/core/tests/test_apps.common.keychain.py
+++ b/core/tests/test_apps.common.keychain.py
@@ -2,7 +2,7 @@
 from common import *  # isort:skip
 
 from mock_storage import mock_storage
-from storage import cache, cache_codec, cache_common
+from storage import cache, cache_common
 from trezor import wire
 from trezor.crypto import bip39
 from trezor.enums import SafetyCheckLevel
@@ -13,18 +13,32 @@ from apps.common import safety_checks
 from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain
 from apps.common.paths import PATTERN_SEP5, PathSchema
 
+if utils.USE_THP:
+    import thp_common
+if not utils.USE_THP:
+    from storage import cache_codec
+
 
 class TestKeychain(unittest.TestCase):
 
-    def setUpClass(self):
-        context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
+    if utils.USE_THP:
+
+        def setUpClass(self):
+            if __debug__:
+                thp_common.suppres_debug_log()
+            thp_common.prepare_context()
+
+    else:
+
+        def setUpClass(self):
+            context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
+
+        def setUp(self):
+            cache_codec.start_session()
 
     def tearDownClass(self):
         context.CURRENT_CONTEXT = None
 
-    def setUp(self):
-        cache_codec.start_session()
-
     def tearDown(self):
         cache.clear_all()
 
diff --git a/core/tests/test_apps.ethereum.keychain.py b/core/tests/test_apps.ethereum.keychain.py
index 3215aba267..6355da641c 100644
--- a/core/tests/test_apps.ethereum.keychain.py
+++ b/core/tests/test_apps.ethereum.keychain.py
@@ -3,7 +3,7 @@ from common import *  # isort:skip
 
 import unittest
 
-from storage import cache_codec, cache_common
+from storage import cache_common
 from trezor import wire
 from trezor.crypto import bip39
 from trezor.wire import context
@@ -12,6 +12,12 @@ from trezor.wire.codec.codec_context import CodecContext
 from apps.common.keychain import get_keychain
 from apps.common.paths import HARDENED
 
+if utils.USE_THP:
+    import thp_common
+else:
+    from storage import cache_codec
+
+
 if not utils.BITCOIN_ONLY:
     from ethereum_common import encode_network, make_network
     from trezor.messages import (
@@ -74,17 +80,30 @@ class TestEthereumKeychain(unittest.TestCase):
                 addr,
             )
 
-    def setUpClass(self):
-        context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
+    if utils.USE_THP:
+
+        def setUpClass(self):
+            if __debug__:
+                thp_common.suppres_debug_log()
+            thp_common.prepare_context()
+
+        def setUp(self):
+            seed = bip39.seed(" ".join(["all"] * 12), "")
+            context.cache_set(cache_common.APP_COMMON_SEED, seed)
+
+    else:
+
+        def setUpClass(self):
+            context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
+
+        def setUp(self):
+            cache_codec.start_session()
+            seed = bip39.seed(" ".join(["all"] * 12), "")
+            cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
 
     def tearDownClass(self):
         context.CURRENT_CONTEXT = None
 
-    def setUp(self):
-        cache_codec.start_session()
-        seed = bip39.seed(" ".join(["all"] * 12), "")
-        cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
-
     def from_address_n(self, address_n):
         slip44 = _slip44_from_address_n(address_n)
         network = make_network(slip44=slip44)
diff --git a/core/tests/test_storage.cache.py b/core/tests/test_storage.cache.py
index cc93015e05..0c92bf254a 100644
--- a/core/tests/test_storage.cache.py
+++ b/core/tests/test_storage.cache.py
@@ -1,241 +1,519 @@
 # flake8: noqa: F403,F405
 from common import *  # isort:skip
 
-from mock_storage import mock_storage
-from storage import cache, cache_codec, cache_common
-from trezor.messages import EndSession, Initialize
-from trezor.wire import context
-from trezor.wire.codec.codec_context import CodecContext
-
-from apps.base import handle_EndSession, handle_Initialize
-from apps.common.cache import stored, stored_async
 
 KEY = 0
 
+if utils.USE_THP:
+    import thp_common
+    from mock_wire_interface import MockHID
+    from storage import cache, cache_thp
+    from trezor.wire.thp import ChannelState
+    from trezor.wire.thp.session_context import SessionContext
 
-# Function moved from cache.py, as it was not used there
-def is_session_started() -> bool:
-    return cache_codec._active_session_idx is not None
+    _PROTOCOL_CACHE = cache_thp
+
+else:
+    from storage import cache, cache_codec
+    from trezor.messages import EndSession, Initialize
+    from apps.base import handle_EndSession
+    from mock_storage import mock_storage
+
+    _PROTOCOL_CACHE = cache_codec
+
+    def is_session_started() -> bool:
+        return cache_codec.get_active_session() is not None
+
+    def get_active_session():
+        return cache_codec.get_active_session()
 
 
 class TestStorageCache(unittest.TestCase):
 
-    def setUpClass(self):
-        context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
+    if utils.USE_THP:
 
-    def tearDownClass(self):
-        context.CURRENT_CONTEXT = None
+        def setUpClass(self):
+            if __debug__:
+                thp_common.suppres_debug_log()
+            super().__init__()
 
-    def setUp(self):
-        cache.clear_all()
+        def setUp(self):
+            self.interface = MockHID(0xDEADBEEF)
+            cache.clear_all()
 
-    def test_start_session(self):
-        session_id_a = cache_codec.start_session()
-        self.assertIsNotNone(session_id_a)
-        session_id_b = cache_codec.start_session()
-        self.assertNotEqual(session_id_a, session_id_b)
+        def test_new_channel_and_session(self):
+            channel = thp_common.get_new_channel(self.interface)
 
-        cache.clear_all()
-        with self.assertRaises(cache_common.InvalidSessionError):
-            context.cache_set(KEY, "something")
-        with self.assertRaises(cache_common.InvalidSessionError):
-            context.cache_get(KEY)
+            # Assert that channel is created without any sessions
+            self.assertEqual(len(channel.sessions), 0)
 
-    def test_end_session(self):
-        session_id = cache_codec.start_session()
-        self.assertTrue(is_session_started())
-        context.cache_set(KEY, b"A")
-        cache_codec.end_current_session()
-        self.assertFalse(is_session_started())
-        self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY)
+            cid_1 = channel.channel_id
+            session_cache_1 = cache_thp.get_new_session(channel.channel_cache)
+            session_1 = SessionContext(channel, session_cache_1)
+            self.assertEqual(session_1.channel_id, cid_1)
 
-        # ending an ended session should be a no-op
-        cache_codec.end_current_session()
-        self.assertFalse(is_session_started())
+            session_cache_2 = cache_thp.get_new_session(channel.channel_cache)
+            session_2 = SessionContext(channel, session_cache_2)
+            self.assertEqual(session_2.channel_id, cid_1)
+            self.assertEqual(session_1.channel_id, session_2.channel_id)
+            self.assertNotEqual(session_1.session_id, session_2.session_id)
 
-        session_id_a = cache_codec.start_session(session_id)
-        # original session no longer exists
-        self.assertNotEqual(session_id_a, session_id)
-        # original session data no longer exists
-        self.assertIsNone(context.cache_get(KEY))
+            channel_2 = thp_common.get_new_channel(self.interface)
+            cid_2 = channel_2.channel_id
+            self.assertNotEqual(cid_1, cid_2)
 
-        # create a new session
-        session_id_b = cache_codec.start_session()
-        # switch back to original session
-        session_id = cache_codec.start_session(session_id_a)
-        self.assertEqual(session_id, session_id_a)
-        # end original session
-        cache_codec.end_current_session()
-        # switch back to B
-        session_id = cache_codec.start_session(session_id_b)
-        self.assertEqual(session_id, session_id_b)
+            session_cache_3 = cache_thp.get_new_session(channel_2.channel_cache)
+            session_3 = SessionContext(channel_2, session_cache_3)
+            self.assertEqual(session_3.channel_id, cid_2)
 
-    def test_session_queue(self):
-        session_id = cache_codec.start_session()
-        self.assertEqual(cache_codec.start_session(session_id), session_id)
-        context.cache_set(KEY, b"A")
-        for _ in range(cache_codec._MAX_SESSIONS_COUNT):
+            # Sessions 1 and 3 should have different channel_id, but the same session_id
+            self.assertNotEqual(session_1.channel_id, session_3.channel_id)
+            self.assertEqual(session_1.session_id, session_3.session_id)
+
+            self.assertEqual(cache_thp._SESSIONS[0], session_cache_1)
+            self.assertNotEqual(cache_thp._SESSIONS[0], session_cache_2)
+            self.assertEqual(cache_thp._SESSIONS[0].channel_id, session_1.channel_id)
+
+            # Check that session data IS in cache for created sessions ONLY
+            for i in range(3):
+                self.assertNotEqual(cache_thp._SESSIONS[i].channel_id, b"")
+                self.assertNotEqual(cache_thp._SESSIONS[i].session_id, b"")
+                self.assertNotEqual(cache_thp._SESSIONS[i].last_usage, 0)
+            for i in range(3, cache_thp._MAX_SESSIONS_COUNT):
+                self.assertEqual(cache_thp._SESSIONS[i].channel_id, b"")
+                self.assertEqual(cache_thp._SESSIONS[i].session_id, b"")
+                self.assertEqual(cache_thp._SESSIONS[i].last_usage, 0)
+
+            # Check that session data IS NOT in cache after cache.clear_all()
+            cache.clear_all()
+            for session in cache_thp._SESSIONS:
+                self.assertEqual(session.channel_id, b"")
+                self.assertEqual(session.session_id, b"")
+                self.assertEqual(session.last_usage, 0)
+                self.assertEqual(session.state, b"\x00")
+
+        def test_channel_capacity_in_cache(self):
+            self.assertTrue(cache_thp._MAX_CHANNELS_COUNT >= 3)
+            channels = []
+            for i in range(cache_thp._MAX_CHANNELS_COUNT):
+                channels.append(thp_common.get_new_channel(self.interface))
+            channel_ids = [channel.channel_cache.channel_id for channel in channels]
+
+            # Assert that each channel_id is unique and that cache and list of channels
+            # have the same "channels" on the same indexes
+            for i in range(len(channel_ids)):
+                self.assertEqual(cache_thp._CHANNELS[i].channel_id, channel_ids[i])
+                for j in range(i + 1, len(channel_ids)):
+                    self.assertNotEqual(channel_ids[i], channel_ids[j])
+
+            # Create a new channel that is over the capacity
+            new_channel = thp_common.get_new_channel(self.interface)
+            for c in channels:
+                self.assertNotEqual(c.channel_id, new_channel.channel_id)
+
+            # Test that the oldest (least used) channel was replaced (_CHANNELS[0])
+            self.assertNotEqual(cache_thp._CHANNELS[0].channel_id, channel_ids[0])
+            self.assertEqual(cache_thp._CHANNELS[0].channel_id, new_channel.channel_id)
+
+            # Update the "last used" value of the second channel in cache (_CHANNELS[1]) and
+            # assert that it is not replaced when creating a new channel
+            cache_thp.update_channel_last_used(channel_ids[1])
+            new_new_channel = thp_common.get_new_channel(self.interface)
+            self.assertEqual(cache_thp._CHANNELS[1].channel_id, channel_ids[1])
+
+            # Assert that it was in fact the _CHANNEL[2] that was replaced
+            self.assertNotEqual(cache_thp._CHANNELS[2].channel_id, channel_ids[2])
+            self.assertEqual(
+                cache_thp._CHANNELS[2].channel_id, new_new_channel.channel_id
+            )
+
+        def test_session_capacity_in_cache(self):
+            self.assertTrue(cache_thp._MAX_SESSIONS_COUNT >= 4)
+            channel_cache_A = thp_common.get_new_channel(self.interface).channel_cache
+            channel_cache_B = thp_common.get_new_channel(self.interface).channel_cache
+
+            sesions_A = []
+            cid = []
+            sid = []
+            for i in range(3):
+                sesions_A.append(cache_thp.get_new_session(channel_cache_A))
+                cid.append(sesions_A[i].channel_id)
+                sid.append(sesions_A[i].session_id)
+
+            sessions_B = []
+            for i in range(cache_thp._MAX_SESSIONS_COUNT - 3):
+                sessions_B.append(cache_thp.get_new_session(channel_cache_B))
+
+            for i in range(3):
+                self.assertEqual(sesions_A[i], cache_thp._SESSIONS[i])
+                self.assertEqual(cid[i], cache_thp._SESSIONS[i].channel_id)
+                self.assertEqual(sid[i], cache_thp._SESSIONS[i].session_id)
+            for i in range(3, cache_thp._MAX_SESSIONS_COUNT):
+                self.assertEqual(sessions_B[i - 3], cache_thp._SESSIONS[i])
+
+            # Assert that new session replaces the oldest (least used) one (_SESSOIONS[0])
+            new_session = cache_thp.get_new_session(channel_cache_B)
+            self.assertEqual(new_session, cache_thp._SESSIONS[0])
+            self.assertNotEqual(new_session.channel_id, cid[0])
+            self.assertNotEqual(new_session.session_id, sid[0])
+
+            # Assert that updating "last used" for session on channel A increases also
+            # the "last usage" of channel A.
+            self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage)
+            cache_thp.update_session_last_used(
+                channel_cache_A.channel_id, sesions_A[1].session_id
+            )
+            self.assertTrue(channel_cache_A.last_usage > channel_cache_B.last_usage)
+
+            new_new_session = cache_thp.get_new_session(channel_cache_B)
+
+            # Assert that creating a new session on channel B shifts the "last usage" again
+            # and that _SESSIONS[1] was not replaced, but that _SESSIONS[2] was replaced
+            self.assertTrue(channel_cache_A.last_usage < channel_cache_B.last_usage)
+            self.assertEqual(sesions_A[1], cache_thp._SESSIONS[1])
+            self.assertNotEqual(sesions_A[2], cache_thp._SESSIONS[2])
+            self.assertEqual(new_new_session, cache_thp._SESSIONS[2])
+
+        def test_clear(self):
+            channel_A = thp_common.get_new_channel(self.interface)
+            channel_B = thp_common.get_new_channel(self.interface)
+            cid_A = channel_A.channel_id
+            cid_B = channel_B.channel_id
+            sessions = []
+
+            for i in range(3):
+                sessions.append(cache_thp.get_new_session(channel_A.channel_cache))
+                sessions.append(cache_thp.get_new_session(channel_B.channel_cache))
+
+                self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, cid_A)
+                self.assertNotEqual(cache_thp._SESSIONS[2 * i].last_usage, 0)
+
+                self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B)
+                self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0)
+
+            # Assert that clearing of channel A works
+            self.assertNotEqual(channel_A.channel_cache.channel_id, b"")
+            self.assertNotEqual(channel_A.channel_cache.last_usage, 0)
+            self.assertEqual(channel_A.get_channel_state(), ChannelState.TH1)
+
+            channel_A.clear()
+
+            self.assertEqual(channel_A.channel_cache.channel_id, b"")
+            self.assertEqual(channel_A.channel_cache.last_usage, 0)
+            self.assertEqual(channel_A.get_channel_state(), ChannelState.UNALLOCATED)
+
+            # Assert that clearing channel A also cleared all its sessions
+            for i in range(3):
+                self.assertEqual(cache_thp._SESSIONS[2 * i].last_usage, 0)
+                self.assertEqual(cache_thp._SESSIONS[2 * i].channel_id, b"")
+
+                self.assertNotEqual(cache_thp._SESSIONS[2 * i + 1].last_usage, 0)
+                self.assertEqual(cache_thp._SESSIONS[2 * i + 1].channel_id, cid_B)
+
+            cache.clear_all()
+            for session in cache_thp._SESSIONS:
+                self.assertEqual(session.last_usage, 0)
+                self.assertEqual(session.channel_id, b"")
+            for channel in cache_thp._CHANNELS:
+                self.assertEqual(channel.channel_id, b"")
+                self.assertEqual(channel.last_usage, 0)
+                self.assertEqual(
+                    cache_thp._get_channel_state(channel), ChannelState.UNALLOCATED
+                )
+
+        def test_get_set(self):
+            channel = thp_common.get_new_channel(self.interface)
+
+            session_1 = cache_thp.get_new_session(channel.channel_cache)
+            session_1.set(KEY, b"hello")
+            self.assertEqual(session_1.get(KEY), b"hello")
+
+            session_2 = cache_thp.get_new_session(channel.channel_cache)
+            session_2.set(KEY, b"world")
+            self.assertEqual(session_2.get(KEY), b"world")
+
+            self.assertEqual(session_1.get(KEY), b"hello")
+
+            cache.clear_all()
+            self.assertIsNone(session_1.get(KEY))
+            self.assertIsNone(session_2.get(KEY))
+
+        def test_get_set_int(self):
+            channel = thp_common.get_new_channel(self.interface)
+
+            session_1 = cache_thp.get_new_session(channel.channel_cache)
+            session_1.set_int(KEY, 1234)
+
+            self.assertEqual(session_1.get_int(KEY), 1234)
+
+            session_2 = cache_thp.get_new_session(channel.channel_cache)
+            session_2.set_int(KEY, 5678)
+            self.assertEqual(session_2.get_int(KEY), 5678)
+
+            self.assertEqual(session_1.get_int(KEY), 1234)
+
+            cache.clear_all()
+            self.assertIsNone(session_1.get_int(KEY))
+            self.assertIsNone(session_2.get_int(KEY))
+
+        def test_get_set_bool(self):
+            channel = thp_common.get_new_channel(self.interface)
+
+            session_1 = cache_thp.get_new_session(channel.channel_cache)
+            with self.assertRaises(AssertionError):
+                session_1.set_bool(KEY, True)
+
+            # Change length of first session field to 0 so that the length check passes
+            session_1.fields = (0,) + session_1.fields[1:]
+
+            # with self.assertRaises(AssertionError) as e:
+            session_1.set_bool(KEY, True)
+            self.assertEqual(session_1.get_bool(KEY), True)
+
+            session_2 = cache_thp.get_new_session(channel.channel_cache)
+            session_2.fields = session_2.fields = (0,) + session_2.fields[1:]
+            session_2.set_bool(KEY, False)
+            self.assertEqual(session_2.get_bool(KEY), False)
+
+            self.assertEqual(session_1.get_bool(KEY), True)
+
+            cache.clear_all()
+
+            # Default value is False
+            self.assertFalse(session_1.get_bool(KEY))
+            self.assertFalse(session_2.get_bool(KEY))
+
+        def test_delete(self):
+            channel = thp_common.get_new_channel(self.interface)
+            session_1 = cache_thp.get_new_session(channel.channel_cache)
+
+            self.assertIsNone(session_1.get(KEY))
+            session_1.set(KEY, b"hello")
+            self.assertEqual(session_1.get(KEY), b"hello")
+            session_1.delete(KEY)
+            self.assertIsNone(session_1.get(KEY))
+
+            session_1.set(KEY, b"hello")
+            session_2 = cache_thp.get_new_session(channel.channel_cache)
+
+            self.assertIsNone(session_2.get(KEY))
+            session_2.set(KEY, b"hello")
+            self.assertEqual(session_2.get(KEY), b"hello")
+            session_2.delete(KEY)
+            self.assertIsNone(session_2.get(KEY))
+
+            self.assertEqual(session_1.get(KEY), b"hello")
+
+    else:
+
+        def setUpClass(self):
+            from trezor.wire.codec.codec_context import CodecContext
+            from trezor.wire import context
+
+            context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
+
+        def tearDownClass(self):
+            from trezor.wire import context
+
+            context.CURRENT_CONTEXT = None
+
+        def setUp(self):
+            cache.clear_all()
+
+        def test_start_session(self):
+            session_id_a = cache_codec.start_session()
+            self.assertIsNotNone(session_id_a)
+            session_id_b = cache_codec.start_session()
+            self.assertNotEqual(session_id_a, session_id_b)
+
+            cache.clear_all()
+            self.assertIsNone(get_active_session())
+            for session in cache_codec._SESSIONS:
+                self.assertEqual(session.session_id, b"")
+                self.assertEqual(session.last_usage, 0)
+
+        def test_end_session(self):
+            session_id = cache_codec.start_session()
+            self.assertTrue(is_session_started())
+            get_active_session().set(KEY, b"A")
+            cache_codec.end_current_session()
+            self.assertFalse(is_session_started())
+            self.assertIsNone(get_active_session())
+
+            # ending an ended session should be a no-op
+            cache_codec.end_current_session()
+            self.assertFalse(is_session_started())
+
+            session_id_a = cache_codec.start_session(session_id)
+            # original session no longer exists
+            self.assertNotEqual(session_id_a, session_id)
+            # original session data no longer exists
+            self.assertIsNone(get_active_session().get(KEY))
+
+            # create a new session
+            session_id_b = cache_codec.start_session()
+            # switch back to original session
+            session_id = cache_codec.start_session(session_id_a)
+            self.assertEqual(session_id, session_id_a)
+            # end original session
+            cache_codec.end_current_session()
+            # switch back to B
+            session_id = cache_codec.start_session(session_id_b)
+            self.assertEqual(session_id, session_id_b)
+
+        def test_session_queue(self):
+            session_id = cache_codec.start_session()
+            self.assertEqual(cache_codec.start_session(session_id), session_id)
+            get_active_session().set(KEY, b"A")
+            for i in range(_PROTOCOL_CACHE._MAX_SESSIONS_COUNT):
+                cache_codec.start_session()
+            self.assertNotEqual(cache_codec.start_session(session_id), session_id)
+            self.assertIsNone(get_active_session().get(KEY))
+
+        def test_get_set(self):
+            session_id1 = cache_codec.start_session()
+            cache_codec.get_active_session().set(KEY, b"hello")
+            self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello")
+
+            session_id2 = cache_codec.start_session()
+            cache_codec.get_active_session().set(KEY, b"world")
+            self.assertEqual(cache_codec.get_active_session().get(KEY), b"world")
+
+            cache_codec.start_session(session_id2)
+            self.assertEqual(cache_codec.get_active_session().get(KEY), b"world")
+            cache_codec.start_session(session_id1)
+            self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello")
+
+            cache.clear_all()
+            self.assertIsNone(cache_codec.get_active_session())
+
+        def test_get_set_int(self):
+            session_id1 = cache_codec.start_session()
+            get_active_session().set_int(KEY, 1234)
+            self.assertEqual(get_active_session().get_int(KEY), 1234)
+
+            session_id2 = cache_codec.start_session()
+            get_active_session().set_int(KEY, 5678)
+            self.assertEqual(get_active_session().get_int(KEY), 5678)
+
+            cache_codec.start_session(session_id2)
+            self.assertEqual(get_active_session().get_int(KEY), 5678)
+            cache_codec.start_session(session_id1)
+            self.assertEqual(get_active_session().get_int(KEY), 1234)
+
+            cache.clear_all()
+            self.assertIsNone(get_active_session())
+
+        def test_delete(self):
+            session_id1 = cache_codec.start_session()
+            self.assertIsNone(get_active_session().get(KEY))
+            get_active_session().set(KEY, b"hello")
+            self.assertEqual(get_active_session().get(KEY), b"hello")
+            get_active_session().delete(KEY)
+            self.assertIsNone(get_active_session().get(KEY))
+
+            get_active_session().set(KEY, b"hello")
             cache_codec.start_session()
-        self.assertNotEqual(cache_codec.start_session(session_id), session_id)
-        self.assertIsNone(context.cache_get(KEY))
+            self.assertIsNone(get_active_session().get(KEY))
+            get_active_session().set(KEY, b"hello")
+            self.assertEqual(get_active_session().get(KEY), b"hello")
+            get_active_session().delete(KEY)
+            self.assertIsNone(get_active_session().get(KEY))
 
-    def test_get_set(self):
-        session_id1 = cache_codec.start_session()
-        context.cache_set(KEY, b"hello")
-        self.assertEqual(context.cache_get(KEY), b"hello")
+            cache_codec.start_session(session_id1)
+            self.assertEqual(get_active_session().get(KEY), b"hello")
 
-        session_id2 = cache_codec.start_session()
-        context.cache_set(KEY, b"world")
-        self.assertEqual(context.cache_get(KEY), b"world")
+        def test_decorators(self):
+            run_count = 0
+            cache_codec.start_session()
+            from apps.common.cache import stored
 
-        cache_codec.start_session(session_id2)
-        self.assertEqual(context.cache_get(KEY), b"world")
-        cache_codec.start_session(session_id1)
-        self.assertEqual(context.cache_get(KEY), b"hello")
+            @stored(KEY)
+            def func():
+                nonlocal run_count
+                run_count += 1
+                return b"foo"
 
-        cache.clear_all()
-        with self.assertRaises(cache_common.InvalidSessionError):
-            context.cache_get(KEY)
+            # cache is empty
+            self.assertIsNone(get_active_session().get(KEY))
+            self.assertEqual(run_count, 0)
+            self.assertEqual(func(), b"foo")
+            # function was run
+            self.assertEqual(run_count, 1)
+            self.assertEqual(get_active_session().get(KEY), b"foo")
+            # function does not run again but returns cached value
+            self.assertEqual(func(), b"foo")
+            self.assertEqual(run_count, 1)
 
-    def test_get_set_int(self):
-        session_id1 = cache_codec.start_session()
-        context.cache_set_int(KEY, 1234)
-        self.assertEqual(context.cache_get_int(KEY), 1234)
+        def test_empty_value(self):
+            cache_codec.start_session()
 
-        session_id2 = cache_codec.start_session()
-        context.cache_set_int(KEY, 5678)
-        self.assertEqual(context.cache_get_int(KEY), 5678)
+            self.assertIsNone(get_active_session().get(KEY))
+            get_active_session().set(KEY, b"")
+            self.assertEqual(get_active_session().get(KEY), b"")
 
-        cache_codec.start_session(session_id2)
-        self.assertEqual(context.cache_get_int(KEY), 5678)
-        cache_codec.start_session(session_id1)
-        self.assertEqual(context.cache_get_int(KEY), 1234)
+            get_active_session().delete(KEY)
+            run_count = 0
 
-        cache.clear_all()
-        with self.assertRaises(cache_common.InvalidSessionError):
-            context.cache_get_int(KEY)
+            from apps.common.cache import stored
 
-    def test_delete(self):
-        session_id1 = cache_codec.start_session()
-        self.assertIsNone(context.cache_get(KEY))
-        context.cache_set(KEY, b"hello")
-        self.assertEqual(context.cache_get(KEY), b"hello")
-        context.cache_delete(KEY)
-        self.assertIsNone(context.cache_get(KEY))
+            @stored(KEY)
+            def func():
+                nonlocal run_count
+                run_count += 1
+                return b""
 
-        context.cache_set(KEY, b"hello")
-        cache_codec.start_session()
-        self.assertIsNone(context.cache_get(KEY))
-        context.cache_set(KEY, b"hello")
-        self.assertEqual(context.cache_get(KEY), b"hello")
-        context.cache_delete(KEY)
-        self.assertIsNone(context.cache_get(KEY))
+            self.assertEqual(func(), b"")
+            # function gets called once
+            self.assertEqual(run_count, 1)
+            self.assertEqual(func(), b"")
+            # function is not called for a second time
+            self.assertEqual(run_count, 1)
 
-        cache_codec.start_session(session_id1)
-        self.assertEqual(context.cache_get(KEY), b"hello")
+        @mock_storage
+        def test_Initialize(self):
+            from apps.base import handle_Initialize
 
-    def test_decorators(self):
-        run_count = 0
-        cache_codec.start_session()
+            def call_Initialize(**kwargs):
+                msg = Initialize(**kwargs)
+                return await_result(handle_Initialize(msg))
 
-        @stored(KEY)
-        def func():
-            nonlocal run_count
-            run_count += 1
-            return b"foo"
+            # calling Initialize without an ID allocates a new one
+            session_id = cache_codec.start_session()
+            features = call_Initialize()
+            self.assertNotEqual(session_id, features.session_id)
 
-        # cache is empty
-        self.assertIsNone(context.cache_get(KEY))
-        self.assertEqual(run_count, 0)
-        self.assertEqual(func(), b"foo")
-        # function was run
-        self.assertEqual(run_count, 1)
-        self.assertEqual(context.cache_get(KEY), b"foo")
-        # function does not run again but returns cached value
-        self.assertEqual(func(), b"foo")
-        self.assertEqual(run_count, 1)
+            # calling Initialize with the current ID does not allocate a new one
+            features = call_Initialize(session_id=session_id)
+            self.assertEqual(session_id, features.session_id)
 
-        @stored_async(KEY)
-        async def async_func():
-            nonlocal run_count
-            run_count += 1
-            return b"bar"
+            # store "hello"
+            get_active_session().set(KEY, b"hello")
+            # check that it is cleared
+            features = call_Initialize()
+            session_id = features.session_id
+            self.assertIsNone(get_active_session().get(KEY))
+            # store "hello" again
+            get_active_session().set(KEY, b"hello")
+            self.assertEqual(get_active_session().get(KEY), b"hello")
 
-        # cache is still full
-        self.assertEqual(await_result(async_func()), b"foo")
-        self.assertEqual(run_count, 1)
+            # supplying a different session ID starts a new session
+            call_Initialize(session_id=b"A" * _PROTOCOL_CACHE.SESSION_ID_LENGTH)
+            self.assertIsNone(get_active_session().get(KEY))
 
-        cache_codec.start_session()
-        self.assertEqual(await_result(async_func()), b"bar")
-        self.assertEqual(run_count, 2)
-        # awaitable is also run only once
-        self.assertEqual(await_result(async_func()), b"bar")
-        self.assertEqual(run_count, 2)
+            # but resuming a session loads the previous one
+            call_Initialize(session_id=session_id)
+            self.assertEqual(get_active_session().get(KEY), b"hello")
 
-    def test_empty_value(self):
-        cache_codec.start_session()
+        def test_EndSession(self):
 
-        self.assertIsNone(context.cache_get(KEY))
-        context.cache_set(KEY, b"")
-        self.assertEqual(context.cache_get(KEY), b"")
-
-        context.cache_delete(KEY)
-        run_count = 0
-
-        @stored(KEY)
-        def func():
-            nonlocal run_count
-            run_count += 1
-            return b""
-
-        self.assertEqual(func(), b"")
-        # function gets called once
-        self.assertEqual(run_count, 1)
-        self.assertEqual(func(), b"")
-        # function is not called for a second time
-        self.assertEqual(run_count, 1)
-
-    @mock_storage
-    def test_Initialize(self):
-        def call_Initialize(**kwargs):
-            msg = Initialize(**kwargs)
-            return await_result(handle_Initialize(msg))
-
-        # calling Initialize without an ID allocates a new one
-        session_id = cache_codec.start_session()
-        features = call_Initialize()
-        self.assertNotEqual(session_id, features.session_id)
-
-        # calling Initialize with the current ID does not allocate a new one
-        features = call_Initialize(session_id=session_id)
-        self.assertEqual(session_id, features.session_id)
-
-        # store "hello"
-        context.cache_set(KEY, b"hello")
-        # check that it is cleared
-        features = call_Initialize()
-        session_id = features.session_id
-        self.assertIsNone(context.cache_get(KEY))
-        # store "hello" again
-        context.cache_set(KEY, b"hello")
-        self.assertEqual(context.cache_get(KEY), b"hello")
-
-        # supplying a different session ID starts a new cache
-        call_Initialize(session_id=b"A" * cache_codec.SESSION_ID_LENGTH)
-        self.assertIsNone(context.cache_get(KEY))
-
-        # but resuming a session loads the previous one
-        call_Initialize(session_id=session_id)
-        self.assertEqual(context.cache_get(KEY), b"hello")
-
-    def test_EndSession(self):
-        self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY)
-        cache_codec.start_session()
-        self.assertTrue(is_session_started())
-        self.assertIsNone(context.cache_get(KEY))
-        await_result(handle_EndSession(EndSession()))
-        self.assertFalse(is_session_started())
-        self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY)
+            self.assertIsNone(get_active_session())
+            cache_codec.start_session()
+            self.assertTrue(is_session_started())
+            self.assertIsNone(get_active_session().get(KEY))
+            await_result(handle_EndSession(EndSession()))
+            self.assertFalse(is_session_started())
+            self.assertIsNone(cache_codec.get_active_session())
 
 
 if __name__ == "__main__":
diff --git a/core/tests/test_trezor.wire.codec.codec_v1.py b/core/tests/test_trezor.wire.codec.codec_v1.py
index 852f5f5b8b..9e6d478590 100644
--- a/core/tests/test_trezor.wire.codec.codec_v1.py
+++ b/core/tests/test_trezor.wire.codec.codec_v1.py
@@ -3,61 +3,11 @@ from common import *  # isort:skip
 
 import ustruct
 
+from mock_wire_interface import MockHID
 from trezor import io
-from trezor.loop import wait
 from trezor.utils import chunks
 from trezor.wire.codec import codec_v1
 
-
-class MockHID:
-
-    TX_PACKET_LEN = 64
-    RX_PACKET_LEN = 64
-
-    def __init__(self, num):
-        self.num = num
-        self.data = []
-        self.packet = None
-
-    def pad_packet(self, data):
-        if len(data) > self.RX_PACKET_LEN:
-            raise Exception("Too long packet")
-        padding_length = self.RX_PACKET_LEN - len(data)
-        return data + b"\x00" * padding_length
-
-    def iface_num(self):
-        return self.num
-
-    def write(self, msg):
-        self.data.append(bytearray(msg))
-        return len(msg)
-
-    def mock_read(self, packet, gen):
-        self.packet = self.pad_packet(packet)
-        return gen.send(self.RX_PACKET_LEN)
-
-    def read(self, buffer, offset=0):
-        if self.packet is None:
-            raise Exception("No packet to read")
-
-        if offset > len(buffer):
-            raise Exception("Offset out of bounds")
-
-        buffer_space = len(buffer) - offset
-
-        if len(self.packet) > buffer_space:
-            raise Exception("Buffer too small")
-        else:
-            end = offset + len(self.packet)
-            buffer[offset:end] = self.packet
-            read = len(self.packet)
-            self.packet = None
-            return read
-
-    def wait_object(self, mode):
-        return wait(mode | self.num)
-
-
 MESSAGE_TYPE = 0x4242
 
 HEADER_PAYLOAD_LENGTH = MockHID.RX_PACKET_LEN - 3 - ustruct.calcsize(">HL")
diff --git a/core/tests/test_trezor.wire.thp.checksum.py b/core/tests/test_trezor.wire.thp.checksum.py
new file mode 100644
index 0000000000..41c9325001
--- /dev/null
+++ b/core/tests/test_trezor.wire.thp.checksum.py
@@ -0,0 +1,94 @@
+from common import *  # isort:skip
+
+if utils.USE_THP:
+    from trezor.wire.thp import checksum
+
+
+@unittest.skipUnless(utils.USE_THP, "only needed for THP")
+class TestTrezorHostProtocolChecksum(unittest.TestCase):
+    vectors_correct = [
+        (
+            b"",
+            b"\x00\x00\x00\x00",
+        ),
+        (
+            b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+            b"\x19\x0A\x55\xAD",
+        ),
+        (
+            bytes("a", "ascii"),
+            b"\xE8\xB7\xBE\x43",
+        ),
+        (
+            bytes("abc", "ascii"),
+            b"\x35\x24\x41\xC2",
+        ),
+        (
+            bytes("123456789", "ascii"),
+            b"\xCB\xF4\x39\x26",
+        ),
+        (
+            bytes(
+                "12345678901234567890123456789012345678901234567890123456789012345678901234567890",
+                "ascii",
+            ),
+            b"\x7C\xA9\x4A\x72",
+        ),
+        (
+            b"\x76\x61\x72\x69\x6F\x75\x73\x20\x43\x52\x43\x20\x61\x6C\x67\x6F\x72\x69\x74\x68\x6D\x73\x20\x69\x6E\x70\x75\x74\x20\x64\x61\x74\x61",
+            b"\x9B\xD3\x66\xAE",
+        ),
+        (
+            b"\x67\x3a\x5f\x0e\x39\xc0\x3c\x79\x58\x22\x74\x76\x64\x9e\x36\xe9\x0b\x04\x8c\xd2\xc0\x4d\x76\x63\x1a\xa2\x17\x85\xe8\x50\xa7\x14\x18\xfb\x86\xed\xa3\x59\x2d\x62\x62\x49\x64\x62\x26\x12\xdb\x95\x3d\xd6\xb5\xca\x4b\x22\x0d\xc5\x78\xb2\x12\x97\x8e\x54\x4e\x06\xb7\x9c\x90\xf5\xa0\x21\xa6\xc7\xd8\x39\xfd\xea\x3a\xf1\x7b\xa2\xe8\x71\x41\xd6\xcb\x1e\x5b\x0e\x29\xf7\x0c\xc7\x57\x8b\x53\x20\x1d\x2b\x41\x1c\x25\xf9\x07\xbb\xb4\x37\x79\x6a\x13\x1f\x6c\x43\x71\xc1\x1e\x70\xe6\x74\xd3\x9c\xbf\x32\x15\xee\xf2\xa7\x86\xbe\x59\x99\xc4\x10\x09\x8a\x6a\xaa\xd4\xd1\xd0\x71\xd2\x06\x1a\xdd\x2a\xa0\x08\xeb\x08\x6c\xfb\xd2\x2d\xfb\xaa\x72\x56\xeb\xd1\x92\x92\xe5\x0e\x95\x67\xf8\x38\xc3\xab\x59\x37\xe6\xfd\x42\xb0\xd0\x31\xd0\xcb\x8a\x66\xce\x2d\x53\x72\x1e\x72\xd3\x84\x25\xb0\xb8\x93\xd2\x61\x5b\x32\xd5\xe7\xe4\x0e\x31\x11\xaf\xdc\xb4\xb8\xee\xa4\x55\x16\x5f\x78\x86\x8b\x50\x4d\xc5\x6d\x6e\xfc\xe1\x6b\x06\x5b\x37\x84\x2a\x67\x95\x28\x00\xa4\xd1\x32\x9f\xbf\xe1\x64\xf8\x17\x47\xe1\xad\x8b\x72\xd2\xd9\x45\x5b\x73\x43\x3c\xe6\x21\xf7\x53\xa3\x73\xf9\x2a\xb0\xe9\x75\x5e\xa6\xbe\x9a\xad\xfc\xed\xb5\x46\x5b\x9f\xa9\x5a\x4f\xcb\xb6\x60\x96\x31\x91\x42\xca\xaf\xee\xa5\x0c\xe0\xab\x3e\x83\xb8\xac\x88\x10\x2c\x63\xd3\xc9\xd2\xf2\x44\xef\xea\x3d\x19\x24\x3c\x5b\xe7\x0c\x52\xfd\xfe\x47\x41\x14\xd5\x4c\x67\x8d\xdb\xe5\xd9\xfa\x67\x9c\x06\x31\x01\x92\xba\x96\xc4\x0d\xef\xf7\xc1\xe9\x23\x28\x0f\xae\x27\x9b\xff\x28\x0b\x3e\x85\x0c\xae\x02\xda\x27\xb6\x04\x51\x04\x43\x04\x99\x8c\xa3\x97\x1d\x84\xec\x55\x59\xfb\xf3\x84\xe5\xf8\x40\xf8\x5f\x81\x65\x92\x4c\x92\x7a\x07\x51\x8d\x6f\xff\x8d\x15\x36\x5c\x57\x7a\x5b\x3a\x63\x1c\x87\x65\xee\x54\xd5\x96\x50\x73\x1a\x9c\xff\x59\xe5\xea\x6f\x89\xd2\xbb\xa9\x6a\x12\x21\xf5\x08\x8e\x8a\xc0\xd8\xf5\x14\xe9\x9d\x7e\x99\x13\x88\x29\xa8\xb4\x22\x2a\x41\x7c\xc5\x10\xdf\x11\x5e\xf8\x8d\x0e\xd9\x98\xd5\xaf\xa8\xf9\x55\x1e\xe3\x29\xcd\x2c\x51\x7b\x8a\x8d\x52\xaa\x8b\x87\xae\x8e\xb2\xfa\x31\x27\x60\x90\xcb\x01\x6f\x7a\x79\x38\x04\x05\x7c\x11\x79\x10\x40\x33\x70\x75\xfd\x0b\x88\xa5\xcd\x35\xd8\xa6\x3b\xb0\x45\x82\x64\xd1\xb5\xdc\x06\xc9\x89\xf4\x16\x3e\xc7\xb3\xf1\x9d\xd3\xc5\xe3\xaf\xe8\x25\x86\x7a\x4a\xfd\x10\x5d\x20\xe5\x76\x5a\x22\x5f\x8f\xbc\xaa\x97\xee\xf2\xc2\x4c\x0e\xdc\x7b\xc4\xee\x53\xa3\xe0\xfa\xcd\x1e\x4e\x54\x1d\x5e\xe1\x51\x17\x1f\x1a\x75\x7f\xed\x12\xd7\xf7\xe3\x18\x56\x24\xcf\xc6\x96\x30\x77\x0d\x73\x98\x9c\x09\x69\xa3\xbc\x96\x5e\xaf\xde\x76\xa4\x66\x04\x6b\x36\x2a\xac\x6d\x37\xf8\x1e\xe1\x2a\x3e\x42\x2d\x1d\xe6\x46\xdd\x28\xb9\x08\x44\xa1\x9e\xb2\x22\x7a\x45\x8a\x37\x39\x74\xb4\xae\xc8\x3b\x40\xf7\xec\xbf\xfd\xe5\xde\xb2\x83\x5e\xa4\x46\x19\xa6\x9d\xb0\xe8\x76\x80\xbd\xc1\x80\x7a\xd9\xeb\xe7\x90\x5b\x81\x25\x21\xd9\x5b\x4a\x80\x48\x92\x71\x77\x04\xb2\xac\x05\xc9\xdf\x5e\x44\x5a\xae\x6e\xb3\xd8\x30\x5e\xdc\x77\x2f\x79\xc2\x8e\x8b\x28\x24\x06\x1b\x6f\x8d\x88\x53\x80\x55\x0c\x3a\x7b\x85\xb8\x96\x85\xe9\xf0\x57\x63\xfe\x32\x80\xff\x57\xc9\x3c\xdb\xf6\xcd\x67\x14\x47\x6c\x43\x3d\x6d\x48\x3f\x9c\x00\x60\x0e\xf5\x94\xe4\x52\x97\x86\xcd\xac\xbc\xe4\xe3\xe7\xee\xa2\x91\x6e\x92\xbb\xd1\x55\x0c\x5c\x0d\x63\xdb\x6b\xb8\x6e\x45\x48\x0f\xdf\x44\x48\xd2\xf5\xf7\x4d\x7b\xd4\x4d\xd3\xcd\xcd\x5b\x40\x60\xb1\xb2\x8e\xc9\x9a\x65\xc5\x06\x24\xcf\xe9\xcc\x5e\x2c\x49\x47\x38\x45\x5d\xc5\xc0\x0d\x8a\x07\x1c\xb3\xbb\xb1\x69\xf5\x6d\x0e\x9c\x96\x14\x93\x58\x0c\xc9\x48\x74\xfc\x35\xda\x7d\x4e\x32\x73\xa3\x77\x4a\x9e\xc5\xd1\x08\xfe\xa6\xa0\xf1\x66\x72\xea\xc7\xae\x21\x81\x0e\x8a\xba\x99\x06\x97\xfc\xc6\x2b\x69\x53\xc6\x67\xec\x5d\xa1\xfc\xa1\x3b\xdd\x2a\xd6\x8f\x31\xa7\x8d\xec\xfe\x0a\x3b\x6b\x39\x70\x70\x09\x72\x12\xbc\x84\x67\xca\xd2\x4a\x17\x33\x94\x45\x25\xc7\xfd\x1e\xa2\x4a\x9e\x27\x9d\xfb\x87\xea\xe4\xfd\xb0\x11\x06\x9d\x72\xb9\x1d\xea\x9b\x81\x2e\x6a\x36\x76\x62\xfa\xbe\x96\x67\x7d\x35\xdd\x5e\x5c\x4f\x41\x0d\xce\xdb\x13\xb0\x46\x89\x92\x45\x02\x39\x0f\xe6\xd1\x20\x96\x1c\x34\x00\x8c\xc9\xdf\xe3\xf0\xb6\x92\x3a\xda\x5c\x96\xd9\x0b\x7d\x57\xf5\x78\x11\xc0\xcf\xbf\xb0\x92\x3d\xe5\x6a\x67\x34\xce\xd9\x16\x08\xa0\x09\x42\x0b\x07\x13\x7c\x73\x0c\xc6\x50\x17\x42\xcf\xd9\x85\xd9\x23\x3c\xb1\x40\x40\x0f\x94\x20\xed\x2d\xbf\x10\x44\x6e\x64\x65\xe5\x1d\x5f\xec\x24\xd8\x4b\xe8\xc2\xfb\x06\x11\x24\x3f\xdf\x54\x2d\xe8\x4d\xc2\x1c\x27\x11\xb8\xb3\xd4",
+            b"\x6B\xA4\xEC\x92",
+        ),
+    ]
+    vectors_incorrect = [
+        (
+            b"",
+            b"\x00\x00\x00\x00\x00",
+        ),
+        (
+            b"",
+            b"",
+        ),
+        (
+            b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+            b"\x19\x0A\x55\xAE",
+        ),
+        (
+            bytes("A", "ascii"),
+            b"\xE8\xB7\xBE\x43",
+        ),
+        (
+            bytes("abc ", "ascii"),
+            b"\x35\x24\x41\xC2",
+        ),
+        (
+            bytes("1234567890", "ascii"),
+            b"\xCB\xF4\x39\x26",
+        ),
+        (
+            bytes(
+                "1234567890123456789012345678901234567890123456789012345678901234567890123456789",
+                "ascii",
+            ),
+            b"\x7C\xA9\x4A\x72",
+        ),
+    ]
+
+    def test_computation(self):
+        for data, chksum in self.vectors_correct:
+            self.assertEqual(checksum.compute(data), chksum)
+
+    def test_validation_correct(self):
+        for data, chksum in self.vectors_correct:
+            self.assertTrue(checksum.is_valid(chksum, data))
+
+    def test_validation_incorrect(self):
+        for data, chksum in self.vectors_incorrect:
+            self.assertFalse(checksum.is_valid(chksum, data))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/core/tests/test_trezor.wire.thp.crypto.py b/core/tests/test_trezor.wire.thp.crypto.py
new file mode 100644
index 0000000000..d26785ce65
--- /dev/null
+++ b/core/tests/test_trezor.wire.thp.crypto.py
@@ -0,0 +1,156 @@
+from common import *  # isort:skip
+from trezorcrypto import aesgcm, curve25519
+
+import storage
+
+if utils.USE_THP:
+    import thp_common
+    from trezor.wire.thp import crypto
+    from trezor.wire.thp.crypto import IV_1, IV_2, Handshake
+
+    def get_dummy_device_secret():
+        return b"\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08"
+
+
+@unittest.skipUnless(utils.USE_THP, "only needed for THP")
+class TestTrezorHostProtocolCrypto(unittest.TestCase):
+    if utils.USE_THP:
+        handshake = Handshake()
+        key_1 = b"\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07"
+        # 0:key, 1:nonce, 2:auth_data, 3:plaintext, 4:expected_ciphertext, 5:expected_tag
+        vectors_enc = [
+            (
+                key_1,
+                0,
+                b"\x55\x64",
+                b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09",
+                b"e2c9dd152fbee5821ea7",
+                b"10625812de81b14a46b9f1e5100a6d0c",
+            ),
+            (
+                key_1,
+                1,
+                b"\x55\x64",
+                b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09",
+                b"79811619ddb07c2b99f8",
+                b"71c6b872cdc499a7e9a3c7441f053214",
+            ),
+            (
+                key_1,
+                369,
+                b"\x55\x64",
+                b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+                b"03bd030390f2dfe815a61c2b157a064f",
+                b"c1200f8a7ae9a6d32cef0fff878d55c2",
+            ),
+            (
+                key_1,
+                369,
+                b"\x55\x64\x73\x82\x91",
+                b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+                b"03bd030390f2dfe815a61c2b157a064f",
+                b"693ac160cd93a20f7fc255f049d808d0",
+            ),
+        ]
+        # 0:chaining key, 1:input, 2:output_1, 3:output:2
+        vectors_hkdf = [
+            (
+                crypto.PROTOCOL_NAME,
+                b"\x01\x02",
+                b"c784373a217d6be057cddc6068e6748f255fc8beb6f99b7b90cbc64aad947514",
+                b"12695451e29bf08ffe5e4e6ab734b0c3d7cdd99b16cd409f57bd4eaa874944ba",
+            ),
+            (
+                b"\xc7\x84\x37\x3a\x21\x7d\x6b\xe0\x57\xcd\xdc\x60\x68\xe6\x74\x8f\x25\x5f\xc8\xbe\xb6\xf9\x9b\x7b\x90\xcb\xc6\x4a\xad\x94\x75\x14",
+                b"\x31\x41\x59\x26\x52\x12\x34\x56\x78\x89\x04\xaa",
+                b"f88c1e08d5c3bae8f6e4a3d3324c8cbc60a805603e399e69c4bf4eacb27c2f48",
+                b"5f0216bdb7110ee05372286974da8c9c8b96e2efa15b4af430755f462bd79a76",
+            ),
+        ]
+        vectors_iv = [
+            (0, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"),
+            (1, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"),
+            (7, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07"),
+            (1025, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x01"),
+            (4294967295, b"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"),
+            (0xFFFFFFFFFFFFFFFF, b"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"),
+        ]
+
+    def __init__(self):
+        if __debug__:
+            thp_common.suppres_debug_log()
+        super().__init__()
+
+    def setUp(self):
+        utils.DISABLE_ENCRYPTION = False
+
+    def test_encryption(self):
+        for v in self.vectors_enc:
+            buffer = bytearray(v[3])
+            tag = crypto.enc(buffer, v[0], v[1], v[2])
+            self.assertEqual(hexlify(buffer), v[4])
+            self.assertEqual(hexlify(tag), v[5])
+            self.assertTrue(crypto.dec(buffer, tag, v[0], v[1], v[2]))
+            self.assertEqual(buffer, v[3])
+
+    def test_hkdf(self):
+        for v in self.vectors_hkdf:
+            ck, k = crypto._hkdf(v[0], v[1])
+            self.assertEqual(hexlify(ck), v[2])
+            self.assertEqual(hexlify(k), v[3])
+
+    def test_iv_from_nonce(self):
+        for v in self.vectors_iv:
+            x = v[0]
+            y = x.to_bytes(8, "big")
+            iv = crypto._get_iv_from_nonce(v[0])
+            self.assertEqual(iv, v[1])
+        with self.assertRaises(AssertionError) as e:
+            iv = crypto._get_iv_from_nonce(0xFFFFFFFFFFFFFFFF + 1)
+        self.assertEqual(e.value.value, "Nonce overflow, terminate the channel")
+
+    def test_incorrect_vectors(self):
+        pass
+
+    def test_th1_crypto(self):
+        storage.device.get_device_secret = get_dummy_device_secret
+        handshake = self.handshake
+
+        host_ephemeral_privkey = curve25519.generate_secret()
+        host_ephemeral_pubkey = curve25519.publickey(host_ephemeral_privkey)
+        handshake.handle_th1_crypto(b"", host_ephemeral_pubkey)
+
+    def test_th2_crypto(self):
+        handshake = self.handshake
+
+        host_static_privkey = curve25519.generate_secret()
+        host_static_pubkey = curve25519.publickey(host_static_privkey)
+        aes_ctx = aesgcm(handshake.k, IV_2)
+        aes_ctx.auth(handshake.h)
+        encrypted_host_static_pubkey = bytearray(
+            aes_ctx.encrypt(host_static_pubkey) + aes_ctx.finish()
+        )
+
+        # Code to encrypt Host's noise encrypted payload correctly:
+        protomsg = bytearray(b"\x10\x02\x10\x03")
+        temp_k = handshake.k
+        temp_h = handshake.h
+
+        temp_h = crypto._hash_of_two(temp_h, encrypted_host_static_pubkey)
+        _, temp_k = crypto._hkdf(
+            handshake.ck,
+            curve25519.multiply(handshake.trezor_ephemeral_privkey, host_static_pubkey),
+        )
+        aes_ctx = aesgcm(temp_k, IV_1)
+        aes_ctx.encrypt_in_place(protomsg)
+        aes_ctx.auth(temp_h)
+        tag = aes_ctx.finish()
+        encrypted_payload = bytearray(protomsg + tag)
+        # end of encrypted payload generation
+
+        handshake.handle_th2_crypto(encrypted_host_static_pubkey, encrypted_payload)
+        self.assertEqual(encrypted_payload[:4], b"\x10\x02\x10\x03")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/core/tests/test_trezor.wire.thp.py b/core/tests/test_trezor.wire.thp.py
new file mode 100644
index 0000000000..576ddab4db
--- /dev/null
+++ b/core/tests/test_trezor.wire.thp.py
@@ -0,0 +1,378 @@
+from common import *  # isort:skip
+from mock_wire_interface import MockHID
+from trezor import config, io, protobuf
+from trezor.crypto.curve import curve25519
+from trezor.enums import ThpMessageType
+from trezor.wire.errors import UnexpectedMessage
+from trezor.wire.protocol_common import Message
+
+if utils.USE_THP:
+    from typing import TYPE_CHECKING
+
+    import thp_common
+    from storage import cache_thp
+    from storage.cache_common import (
+        CHANNEL_HANDSHAKE_HASH,
+        CHANNEL_KEY_RECEIVE,
+        CHANNEL_KEY_SEND,
+        CHANNEL_NONCE_RECEIVE,
+        CHANNEL_NONCE_SEND,
+    )
+    from trezor.crypto import elligator2
+    from trezor.enums import ThpPairingMethod
+    from trezor.messages import (
+        ThpCodeEntryChallenge,
+        ThpCodeEntryCpaceHost,
+        ThpCodeEntryTag,
+        ThpCredentialRequest,
+        ThpEndRequest,
+        ThpStartPairingRequest,
+    )
+    from trezor.wire.thp import thp_main
+    from trezor.wire.thp import ChannelState, checksum, interface_manager
+    from trezor.wire.thp.crypto import Handshake
+    from trezor.wire.thp.pairing_context import PairingContext
+
+    from apps.thp import pairing
+
+    if TYPE_CHECKING:
+        from trezor.wire import WireInterface
+
+    def get_dummy_key() -> bytes:
+        return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x01\x02\x03\x04\x05\x06\x07\x08\x09\x20\x01\x02\x03\x04\x05\x06\x07\x08\x09\x30\x31"
+
+    def dummy_encode_iface(iface: WireInterface):
+        return thp_common._MOCK_INTERFACE_HID
+
+    def send_channel_allocation_request(
+        interface: WireInterface, nonce: bytes | None = None
+    ) -> bytes:
+        if nonce is None or len(nonce) != 8:
+            nonce = b"\x00\x11\x22\x33\x44\x55\x66\x77"
+        header = b"\x40\xff\xff\x00\x0c"
+        chksum = checksum.compute(header + nonce)
+        cid_req = header + nonce + chksum
+        gen = thp_main.thp_main_loop(interface)
+        expected_channel_index = cache_thp._get_next_channel_index()
+        gen.send(None)
+        gen.send(cid_req)
+        gen.send(None)
+        model = bytes(utils.INTERNAL_MODEL, "big")
+        response_data = (
+            b"\x0a\x04" + model + "\x10\x00\x18\x00\x20\x02\x28\x02\x28\x03\x28\x04"
+        )
+        response_without_crc = (
+            b"\x41\xff\xff\x00\x20"
+            + nonce
+            + cache_thp._CHANNELS[expected_channel_index].channel_id
+            + response_data
+        )
+        chkcsum = checksum.compute(response_without_crc)
+        expected_response = response_without_crc + chkcsum + b"\x00" * 27
+        return expected_response
+
+    def get_channel_id_from_response(channel_allocation_response: bytes) -> int:
+        return int.from_bytes(channel_allocation_response[13:15], "big")
+
+    def get_ack(channel_id: bytes) -> bytes:
+        if len(channel_id) != 2:
+            raise Exception("Channel id should by two bytes long")
+        return (
+            b"\x20"
+            + channel_id
+            + b"\x00\x04"
+            + checksum.compute(b"\x20" + channel_id + b"\x00\x04")
+            + b"\x00" * 55
+        )
+
+
+@unittest.skipUnless(utils.USE_THP, "only needed for THP")
+class TestTrezorHostProtocol(unittest.TestCase):
+
+    def __init__(self):
+        if __debug__:
+            thp_common.suppres_debug_log()
+        interface_manager.encode_iface = dummy_encode_iface
+        super().__init__()
+
+    def setUp(self):
+        self.interface = MockHID(0xDEADBEEF)
+        buffer = bytearray(64)
+        buffer2 = bytearray(256)
+        thp_main.set_read_buffer(buffer)
+        thp_main.set_write_buffer(buffer2)
+        interface_manager.decode_iface = thp_common.dummy_decode_iface
+
+    def test_codec_message(self):
+        self.assertEqual(len(self.interface.data), 0)
+        gen = thp_main.thp_main_loop(self.interface)
+        gen.send(None)
+
+        # There should be a failiure response to received init packet (starts with "?##")
+        test_codec_message = b"?## Some data"
+        gen.send(test_codec_message)
+        gen.send(None)
+        self.assertEqual(len(self.interface.data), 1)
+
+        expected_response = b"?##\x00\x03\x00\x00\x00\x14\x08\x10"
+        self.assertEqual(
+            self.interface.data[-1][: len(expected_response)], expected_response
+        )
+
+        # There should be no response for continuation packet (starts with "?" only)
+        test_codec_message_2 = b"? Cont packet"
+        gen.send(test_codec_message_2)
+        with self.assertRaises(TypeError) as e:
+            gen.send(None)
+        self.assertEqual(e.value.value, "object with buffer protocol required")
+        self.assertEqual(len(self.interface.data), 1)
+
+    def test_message_on_unallocated_channel(self):
+        gen = thp_main.thp_main_loop(self.interface)
+        query = gen.send(None)
+        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
+        message_to_channel_789a = (
+            b"\x04\x78\x9a\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
+        )
+        gen.send(message_to_channel_789a)
+        gen.send(None)
+        unallocated_chanel_error_on_channel_789a = "42789a0005027b743563000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
+        self.assertEqual(
+            utils.get_bytes_as_str(self.interface.data[-1]),
+            unallocated_chanel_error_on_channel_789a,
+        )
+
+    def test_channel_allocation(self):
+        self.assertEqual(len(thp_main._CHANNELS), 0)
+        for c in cache_thp._CHANNELS:
+            self.assertEqual(int.from_bytes(c.state, "big"), ChannelState.UNALLOCATED)
+
+        expected_channel_index = cache_thp._get_next_channel_index()
+        expected_response = send_channel_allocation_request(self.interface)
+        self.assertEqual(self.interface.data[-1], expected_response)
+
+        cid = cache_thp._CHANNELS[expected_channel_index].channel_id
+        self.assertTrue(int.from_bytes(cid, "big") in thp_main._CHANNELS)
+        self.assertEqual(len(thp_main._CHANNELS), 1)
+
+        # test channel's default state is TH1:
+        cid = get_channel_id_from_response(self.interface.data[-1])
+        self.assertEqual(thp_main._CHANNELS[cid].get_channel_state(), ChannelState.TH1)
+
+    def test_invalid_encrypted_tag(self):
+        gen = thp_main.thp_main_loop(self.interface)
+        gen.send(None)
+        # prepare 2 new channels
+        expected_response_1 = send_channel_allocation_request(self.interface)
+        expected_response_2 = send_channel_allocation_request(self.interface)
+        self.assertEqual(self.interface.data[-2], expected_response_1)
+        self.assertEqual(self.interface.data[-1], expected_response_2)
+
+        # test invalid encryption tag
+        config.init()
+        config.wipe()
+        cid_1 = get_channel_id_from_response(expected_response_1)
+        channel = thp_main._CHANNELS[cid_1]
+        channel.iface = self.interface
+        channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
+        header = b"\x04" + channel.channel_id + b"\x00\x14"
+
+        tag = b"\x00" * 16
+        chksum = checksum.compute(header + tag)
+        message_with_invalid_tag = header + tag + chksum
+
+        channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
+        channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
+
+        cid_1_bytes = int.to_bytes(cid_1, 2, "big")
+        expected_ack_on_received_message = get_ack(cid_1_bytes)
+
+        gen.send(message_with_invalid_tag)
+        gen.send(None)
+
+        self.assertEqual(
+            self.interface.data[-1],
+            expected_ack_on_received_message,
+        )
+        error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03"
+        chksum_err = checksum.compute(error_without_crc)
+        gen.send(None)
+
+        decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54
+
+        self.assertEqual(
+            self.interface.data[-1],
+            decryption_failed_error,
+        )
+
+    def test_channel_errors(self):
+        gen = thp_main.thp_main_loop(self.interface)
+        gen.send(None)
+        # prepare 2 new channels
+        expected_response_1 = send_channel_allocation_request(self.interface)
+        expected_response_2 = send_channel_allocation_request(self.interface)
+        self.assertEqual(self.interface.data[-2], expected_response_1)
+        self.assertEqual(self.interface.data[-1], expected_response_2)
+
+        # test invalid encryption tag
+        config.init()
+        config.wipe()
+        cid_1 = get_channel_id_from_response(expected_response_1)
+        channel = thp_main._CHANNELS[cid_1]
+        channel.iface = self.interface
+        channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
+        header = b"\x04" + channel.channel_id + b"\x00\x14"
+
+        tag = b"\x00" * 16
+        chksum = checksum.compute(header + tag)
+        message_with_invalid_tag = header + tag + chksum
+
+        channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
+        channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
+
+        cid_1_bytes = int.to_bytes(cid_1, 2, "big")
+        expected_ack_on_received_message = get_ack(cid_1_bytes)
+
+        gen.send(message_with_invalid_tag)
+        gen.send(None)
+
+        self.assertEqual(
+            self.interface.data[-1],
+            expected_ack_on_received_message,
+        )
+        error_without_crc = b"\x42" + cid_1_bytes + b"\x00\x05\x03"
+        chksum_err = checksum.compute(error_without_crc)
+        gen.send(None)
+
+        decryption_failed_error = error_without_crc + chksum_err + b"\x00" * 54
+
+        self.assertEqual(
+            self.interface.data[-1],
+            decryption_failed_error,
+        )
+
+        # test invalid tag in handshake phase
+        cid_2 = get_channel_id_from_response(expected_response_1)
+        cid_2_bytes = cid_2.to_bytes(2, "big")
+        channel = thp_main._CHANNELS[cid_2]
+        channel.iface = self.interface
+
+        channel.set_channel_state(ChannelState.TH2)
+
+        message_with_invalid_tag = b"\x0a\x12\x36\x00\x14\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x91\x65\x4c\xf9"
+
+        channel.channel_cache.set(CHANNEL_KEY_RECEIVE, get_dummy_key())
+        channel.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, 0)
+
+        # gen.send(message_with_invalid_tag)
+        # gen.send(None)
+        # gen.send(None)
+        # for i in self.interface.data:
+        #    print(utils.get_bytes_as_str(i))
+
+    def test_skip_pairing(self):
+        config.init()
+        config.wipe()
+        channel = next(iter(thp_main._CHANNELS.values()))
+        channel.selected_pairing_methods = [
+            ThpPairingMethod.NoMethod,
+            ThpPairingMethod.CodeEntry,
+            ThpPairingMethod.NFC_Unidirectional,
+            ThpPairingMethod.QrCode,
+        ]
+        pairing_ctx = PairingContext(channel)
+        request_message = ThpStartPairingRequest()
+        channel.set_channel_state(ChannelState.TP1)
+        gen = pairing.handle_pairing_request(pairing_ctx, request_message)
+
+        with self.assertRaises(StopIteration):
+            gen.send(None)
+        self.assertEqual(channel.get_channel_state(), ChannelState.ENCRYPTED_TRANSPORT)
+
+        # Teardown: set back initial channel state value
+        channel.set_channel_state(ChannelState.TH1)
+
+    def TODO_test_pairing(self):
+        config.init()
+        config.wipe()
+        cid = get_channel_id_from_response(
+            send_channel_allocation_request(self.interface)
+        )
+        channel = thp_main._CHANNELS[cid]
+        channel.selected_pairing_methods = [
+            ThpPairingMethod.CodeEntry,
+            ThpPairingMethod.NFC_Unidirectional,
+            ThpPairingMethod.QrCode,
+        ]
+        pairing_ctx = PairingContext(channel)
+        request_message = ThpStartPairingRequest()
+        with self.assertRaises(UnexpectedMessage) as e:
+            pairing.handle_pairing_request(pairing_ctx, request_message)
+        print(e.value.message)
+        channel.set_channel_state(ChannelState.TP1)
+        gen = pairing.handle_pairing_request(pairing_ctx, request_message)
+
+        channel.channel_cache.set(CHANNEL_KEY_SEND, get_dummy_key())
+        channel.channel_cache.set_int(CHANNEL_NONCE_SEND, 0)
+        channel.channel_cache.set(CHANNEL_HANDSHAKE_HASH, b"")
+
+        gen.send(None)
+
+        async def _dummy(ctx: PairingContext, expected_types):
+            return await ctx.read([1018, 1024])
+
+        pairing.show_display_data = _dummy
+
+        msg_code_entry = ThpCodeEntryChallenge(challenge=b"\x12\x34")
+        buffer: bytearray = bytearray(protobuf.encoded_length(msg_code_entry))
+        protobuf.encode(buffer, msg_code_entry)
+        code_entry_challenge = Message(ThpMessageType.ThpCodeEntryChallenge, buffer)
+        gen.send(code_entry_challenge)
+
+        # tag_qrc = b"\x55\xdf\x6c\xba\x0b\xe9\x5e\xd1\x4b\x78\x61\xec\xfa\x07\x9b\x5d\x37\x60\xd8\x79\x9c\xd7\x89\xb4\x22\xc1\x6f\x39\xde\x8f\x3b\xc3"
+        # tag_nfc = b"\x8f\xf0\xfa\x37\x0a\x5b\xdb\x29\x32\x21\xd8\x2f\x95\xdd\xb6\xb8\xee\xfd\x28\x6f\x56\x9f\xa9\x0b\x64\x8c\xfc\x62\x46\x5a\xdd\xd0"
+
+        pregenerator_host = b"\xf6\x94\xc3\x6f\xb3\xbd\xfb\xba\x2f\xfd\x0c\xd0\x71\xed\x54\x76\x73\x64\x37\xfa\x25\x85\x12\x8d\xcf\xb5\x6c\x02\xaf\x9d\xe8\xbe"
+        generator_host = elligator2.map_to_curve25519(pregenerator_host)
+        cpace_host_private_key = b"\x02\x80\x70\x3c\x06\x45\x19\x75\x87\x0c\x82\xe1\x64\x11\xc0\x18\x13\xb2\x29\x04\xb3\xf0\xe4\x1e\x6b\xfd\x77\x63\x11\x73\x07\xa9"
+        cpace_host_public_key: bytes = curve25519.multiply(
+            cpace_host_private_key, generator_host
+        )
+        msg = ThpCodeEntryCpaceHost(cpace_host_public_key=cpace_host_public_key)
+
+        # msg = ThpQrCodeTag(tag=tag_qrc)
+        # msg = ThpNfcUnidirectionalTag(tag=tag_nfc)
+        buffer: bytearray = bytearray(protobuf.encoded_length(msg))
+
+        protobuf.encode(buffer, msg)
+        user_message = Message(ThpMessageType.ThpCodeEntryCpaceHost, buffer)
+        gen.send(user_message)
+
+        tag_ent = b"\xd0\x15\xd6\x72\x7c\xa6\x9b\x2a\x07\xfa\x30\xee\x03\xf0\x2d\x04\xdc\x96\x06\x77\x0c\xbd\xb4\xaa\x77\xc7\x68\x6f\xae\xa9\xdd\x81"
+        msg = ThpCodeEntryTag(tag=tag_ent)
+
+        buffer: bytearray = bytearray(protobuf.encoded_length(msg))
+
+        protobuf.encode(buffer, msg)
+        user_message = Message(ThpMessageType.ThpCodeEntryTag, buffer)
+        gen.send(user_message)
+
+        host_static_pubkey = b"\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77\x00\x11\x22\x33\x44\x55\x66\x77"
+        msg = ThpCredentialRequest(host_static_pubkey=host_static_pubkey)
+        buffer: bytearray = bytearray(protobuf.encoded_length(msg))
+        protobuf.encode(buffer, msg)
+        credential_request = Message(ThpMessageType.ThpCredentialRequest, buffer)
+        gen.send(credential_request)
+
+        msg = ThpEndRequest()
+
+        buffer: bytearray = bytearray(protobuf.encoded_length(msg))
+        protobuf.encode(buffer, msg)
+        end_request = Message(1012, buffer)
+        with self.assertRaises(StopIteration) as e:
+            gen.send(end_request)
+        print("response message:", e.value.value.MESSAGE_NAME)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/core/tests/test_trezor.wire.thp.writer.py b/core/tests/test_trezor.wire.thp.writer.py
new file mode 100644
index 0000000000..0e50f5c4b5
--- /dev/null
+++ b/core/tests/test_trezor.wire.thp.writer.py
@@ -0,0 +1,151 @@
+from common import *  # isort:skip
+
+from typing import Any, Awaitable
+
+
+if utils.USE_THP:
+    import thp_common
+    from mock_wire_interface import MockHID
+    from trezor.wire.thp import writer
+    from trezor.wire.thp import ENCRYPTED, PacketHeader
+
+
+@unittest.skipUnless(utils.USE_THP, "only needed for THP")
+class TestTrezorHostProtocolWriter(unittest.TestCase):
+    short_payload_expected = b"04123400050700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
+    longer_payload_expected = [
+        b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
+        b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
+        b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
+        b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
+        b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
+    ]
+    eight_longer_payloads_expected = [
+        b"0412340800000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
+        b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
+        b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
+        b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
+        b"801234f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e",
+        b"8012342f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b",
+        b"8012346c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8",
+        b"801234a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5",
+        b"801234e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122",
+        b"801234232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f",
+        b"801234606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c",
+        b"8012349d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9",
+        b"801234dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f10111213141516",
+        b"8012341718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f50515253",
+        b"8012345455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f90",
+        b"8012349192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccd",
+        b"801234cecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a",
+        b"8012340b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f4041424344454647",
+        b"80123448494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f8081828384",
+        b"80123485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1",
+        b"801234c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfe",
+        b"801234ff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b",
+        b"8012343c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778",
+        b"801234797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5",
+        b"801234b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2",
+        b"801234f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f",
+        b"801234303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c",
+        b"8012346d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9",
+        b"801234aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6",
+        b"801234e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20212223",
+        b"8012342425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f60",
+        b"8012346162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d",
+        b"8012349e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9da",
+        b"801234dbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000000000000000000000000000000000000000000000000",
+    ]
+    empty_payload_with_checksum_expected = b"0412340004edbd479c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"
+    longer_payload_with_checksum_expected = [
+        b"0412340100000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a",
+        b"8012343b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f7071727374757677",
+        b"80123478797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4",
+        b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1",
+        b"801234f2f3f4f5f6f7f8f9fafbfcfdfefff40c65ee00000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
+    ]
+
+    def await_until_result(self, task: Awaitable) -> Any:
+        with self.assertRaises(StopIteration):
+            while True:
+                task.send(None)
+
+    def __init__(self):
+        if __debug__:
+            thp_common.suppres_debug_log()
+        super().__init__()
+
+    def setUp(self):
+        self.interface = MockHID(0xDEADBEEF)
+
+    def test_write_empty_packet(self):
+        self.await_until_result(writer.write_packet_to_wire(self.interface, b""))
+
+        print(self.interface.data[0])
+        self.assertEqual(len(self.interface.data), 1)
+        self.assertEqual(self.interface.data[0], b"")
+
+    def test_write_empty_payload(self):
+        header = PacketHeader(ENCRYPTED, 4660, 4)
+        await_result(writer.write_payloads_to_wire(self.interface, header, (b"",)))
+        self.assertEqual(len(self.interface.data), 0)
+
+    def test_write_short_payload(self):
+        header = PacketHeader(ENCRYPTED, 4660, 5)
+        data = b"\x07"
+        self.await_until_result(
+            writer.write_payloads_to_wire(self.interface, header, (data,))
+        )
+        self.assertEqual(hexlify(self.interface.data[0]), self.short_payload_expected)
+
+    def test_write_longer_payload(self):
+        data = bytearray(range(256))
+        header = PacketHeader(ENCRYPTED, 4660, 256)
+        self.await_until_result(
+            writer.write_payloads_to_wire(self.interface, header, (data,))
+        )
+
+        for i in range(len(self.longer_payload_expected)):
+            self.assertEqual(
+                hexlify(self.interface.data[i]), self.longer_payload_expected[i]
+            )
+
+    def test_write_eight_longer_payloads(self):
+        data = bytearray(range(256))
+        header = PacketHeader(ENCRYPTED, 4660, 2048)
+        self.await_until_result(
+            writer.write_payloads_to_wire(
+                self.interface, header, (data, data, data, data, data, data, data, data)
+            )
+        )
+        for i in range(len(self.eight_longer_payloads_expected)):
+            self.assertEqual(
+                hexlify(self.interface.data[i]), self.eight_longer_payloads_expected[i]
+            )
+
+    def test_write_empty_payload_with_checksum(self):
+        header = PacketHeader(ENCRYPTED, 4660, 4)
+        self.await_until_result(
+            writer.write_payload_to_wire_and_add_checksum(self.interface, header, b"")
+        )
+
+        self.assertEqual(
+            hexlify(self.interface.data[0]), self.empty_payload_with_checksum_expected
+        )
+
+    def test_write_longer_payload_with_checksum(self):
+        data = bytearray(range(256))
+        header = PacketHeader(ENCRYPTED, 4660, 256)
+        self.await_until_result(
+            writer.write_payload_to_wire_and_add_checksum(self.interface, header, data)
+        )
+
+        for i in range(len(self.longer_payload_with_checksum_expected)):
+            self.assertEqual(
+                hexlify(self.interface.data[i]),
+                self.longer_payload_with_checksum_expected[i],
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/core/tests/test_trezor.wire.thp_deprecated.py b/core/tests/test_trezor.wire.thp_deprecated.py
new file mode 100644
index 0000000000..12ef40bb7e
--- /dev/null
+++ b/core/tests/test_trezor.wire.thp_deprecated.py
@@ -0,0 +1,338 @@
+from common import *  # isort:skip
+import ustruct
+from typing import TYPE_CHECKING
+
+from mock_wire_interface import MockHID
+from storage.cache_thp import BROADCAST_CHANNEL_ID
+from trezor import io
+from trezor.utils import chunks
+from trezor.wire.protocol_common import Message
+
+if utils.USE_THP:
+    import thp_common
+    import trezor.wire.thp
+    from trezor.wire.thp import thp_main
+    from trezor.wire.thp import alternating_bit_protocol as ABP
+    from trezor.wire.thp import checksum
+    from trezor.wire.thp.checksum import CHECKSUM_LENGTH
+    from trezor.wire.thp.writer import PACKET_LENGTH
+
+if TYPE_CHECKING:
+    from trezorio import WireInterface
+
+
+MESSAGE_TYPE = 0x4242
+MESSAGE_TYPE_BYTES = b"\x42\x42"
+_MESSAGE_TYPE_LEN = 2
+PLAINTEXT_0 = 0x01
+PLAINTEXT_1 = 0x11
+COMMON_CID = 4660
+CONT = 0x80
+
+HEADER_INIT_LENGTH = 5
+HEADER_CONT_LENGTH = 3
+if utils.USE_THP:
+    INIT_MESSAGE_DATA_LENGTH = PACKET_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
+
+
+def make_header(ctrl_byte, cid, length):
+    return ustruct.pack(">BHH", ctrl_byte, cid, length)
+
+
+def make_cont_header():
+    return ustruct.pack(">BH", CONT, COMMON_CID)
+
+
+def makeSimpleMessage(header, message_type, message_data):
+    return header + ustruct.pack(">H", message_type) + message_data
+
+
+def makeCidRequest(header, message_data):
+    return header + message_data
+
+
+def getPlaintext() -> bytes:
+    if ABP.get_expected_receive_seq_bit(THP.get_active_session()) == 1:
+        return PLAINTEXT_1
+    return PLAINTEXT_0
+
+
+async def deprecated_read_message(
+    iface: WireInterface, buffer: utils.BufferType
+) -> Message:
+    return Message(-1, b"\x00")
+
+
+async def deprecated_write_message(
+    iface: WireInterface, message: Message, is_retransmission: bool = False
+) -> None:
+    pass
+
+
+# This test suite is an adaptation of test_trezor.wire.codec_v1
+@unittest.skipUnless(utils.USE_THP, "only needed for THP")
+class TestWireTrezorHostProtocolV1(unittest.TestCase):
+
+    def __init__(self):
+        if __debug__:
+            thp_common.suppres_debug_log()
+        super().__init__()
+
+    def setUp(self):
+        self.interface = MockHID(0xDEADBEEF)
+
+    def _simple(self):
+        cid_req_header = make_header(
+            ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
+        )
+        cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
+        cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data)
+
+        message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18)
+        cid_request_dummy_data_checksum = b"\x67\x8e\xac\xe0"
+        message = makeSimpleMessage(
+            message_header,
+            MESSAGE_TYPE,
+            cid_request_dummy_data + cid_request_dummy_data_checksum,
+        )
+
+        buffer = bytearray(64)
+        gen = deprecated_read_message(self.interface, buffer)
+        query = gen.send(None)
+        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
+        gen.send(cid_req_message)
+        gen.send(None)
+        gen.send(message)
+        with self.assertRaises(StopIteration) as e:
+            gen.send(None)
+
+        # e.value is StopIteration. e.value.value is the return value of the call
+        result = e.value.value
+        self.assertEqual(result.type, MESSAGE_TYPE)
+        self.assertEqual(result.data, cid_request_dummy_data)
+
+        buffer_without_zeroes = buffer[: len(message) - 5]
+        message_without_header = message[5:]
+        # message should have been read into the buffer
+        self.assertEqual(buffer_without_zeroes, message_without_header)
+
+    def _read_one_packet(self):
+        # zero length message - just a header
+        PLAINTEXT = getPlaintext()
+        header = make_header(
+            PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
+        )
+        chksum = checksum.compute(header + MESSAGE_TYPE_BYTES)
+        message = header + MESSAGE_TYPE_BYTES + chksum
+
+        buffer = bytearray(64)
+        gen = deprecated_read_message(self.interface, buffer)
+
+        query = gen.send(None)
+        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
+        gen.send(message)
+        with self.assertRaises(StopIteration) as e:
+            gen.send(None)
+
+        # e.value is StopIteration. e.value.value is the return value of the call
+        result = e.value.value
+        self.assertEqual(result.type, MESSAGE_TYPE)
+        self.assertEqual(result.data, b"")
+
+        # message should have been read into the buffer
+        self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58)
+
+    def _read_many_packets(self):
+        message = bytes(range(256))
+        header = make_header(
+            getPlaintext(),
+            COMMON_CID,
+            len(message) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
+        )
+        chksum = checksum.compute(header + MESSAGE_TYPE_BYTES + message)
+        # message = MESSAGE_TYPE_BYTES + message + checksum
+
+        # first packet is init header + 59 bytes of data
+        # other packets are cont header + 61 bytes of data
+        cont_header = make_cont_header()
+        packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
+            cont_header + chunk
+            for chunk in chunks(
+                message[INIT_MESSAGE_DATA_LENGTH:] + chksum,
+                64 - HEADER_CONT_LENGTH,
+            )
+        ]
+        buffer = bytearray(262)
+        gen = deprecated_read_message(self.interface, buffer)
+        query = gen.send(None)
+        for packet in packets:
+            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
+            query = gen.send(packet)
+
+        # last packet will stop
+        with self.assertRaises(StopIteration) as e:
+            gen.send(None)
+
+        # e.value is StopIteration. e.value.value is the return value of the call
+        result = e.value.value
+
+        self.assertEqual(result.type, MESSAGE_TYPE)
+        self.assertEqual(result.data, message)
+
+        # message should have been read into the buffer )
+        self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum)
+
+    def _read_large_message(self):
+        message = b"hello world"
+        header = make_header(
+            getPlaintext(),
+            COMMON_CID,
+            _MESSAGE_TYPE_LEN + len(message) + CHECKSUM_LENGTH,
+        )
+
+        packet = (
+            header
+            + MESSAGE_TYPE_BYTES
+            + message
+            + checksum.compute(header + MESSAGE_TYPE_BYTES + message)
+        )
+
+        # make sure we fit into one packet, to make this easier
+        self.assertTrue(len(packet) <= thp_main.PACKET_LENGTH)
+
+        buffer = bytearray(1)
+        self.assertTrue(len(buffer) <= len(packet))
+
+        gen = deprecated_read_message(self.interface, buffer)
+        query = gen.send(None)
+        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
+        gen.send(packet)
+        with self.assertRaises(StopIteration) as e:
+            gen.send(None)
+
+        # e.value is StopIteration. e.value.value is the return value of the call
+        result = e.value.value
+        self.assertEqual(result.type, MESSAGE_TYPE)
+        self.assertEqual(result.data, message)
+
+        # read should have allocated its own buffer and not touch ours
+        self.assertEqual(buffer, b"\x00")
+
+    def _roundtrip(self):
+        message_payload = bytes(range(256))
+        message = Message(
+            MESSAGE_TYPE, message_payload, 1
+        )  # TODO use different session id
+        gen = deprecated_write_message(self.interface, message)
+        # exhaust the iterator:
+        # (XXX we can only do this because the iterator is only accepting None and returns None)
+        for query in gen:
+            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
+
+        buffer = bytearray(1024)
+        gen = deprecated_read_message(self.interface, buffer)
+        query = gen.send(None)
+        for packet in self.interface.data:
+            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
+            print(utils.get_bytes_as_str(packet))
+            query = gen.send(packet)
+
+        with self.assertRaises(StopIteration) as e:
+            gen.send(None)
+
+        result = e.value.value
+        self.assertEqual(result.type, MESSAGE_TYPE)
+        self.assertEqual(result.data, message.data)
+
+    def _write_one_packet(self):
+        message = Message(MESSAGE_TYPE, b"")
+        gen = deprecated_write_message(self.interface, message)
+
+        query = gen.send(None)
+        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
+        with self.assertRaises(StopIteration):
+            gen.send(None)
+
+        header = make_header(
+            getPlaintext(), COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
+        )
+        expected_message = (
+            header
+            + MESSAGE_TYPE_BYTES
+            + checksum.compute(header + MESSAGE_TYPE_BYTES)
+            + b"\x00" * (INIT_MESSAGE_DATA_LENGTH - CHECKSUM_LENGTH)
+        )
+        self.assertTrue(self.interface.data == [expected_message])
+
+    def _write_multiple_packets(self):
+        message_payload = bytes(range(256))
+        message = Message(MESSAGE_TYPE, message_payload)
+        gen = deprecated_write_message(self.interface, message)
+
+        header = make_header(
+            PLAINTEXT_1,
+            COMMON_CID,
+            len(message.data) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
+        )
+        cont_header = make_cont_header()
+        chksum = checksum.compute(
+            header + message.type.to_bytes(2, "big") + message.data
+        )
+        packets = [
+            header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH]
+        ] + [
+            cont_header + chunk
+            for chunk in chunks(
+                message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum,
+                thp_main.PACKET_LENGTH - HEADER_CONT_LENGTH,
+            )
+        ]
+
+        for _ in packets:
+            # we receive as many queries as there are packets
+            query = gen.send(None)
+            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
+
+        # the first sent None only started the generator. the len(packets)-th None
+        # will finish writing and raise StopIteration
+        with self.assertRaises(StopIteration):
+            gen.send(None)
+
+        # packets must be identical up to the last one
+        self.assertListEqual(packets[:-1], self.interface.data[:-1])
+        # last packet must be identical up to message length. remaining bytes in
+        # the 64-byte packets are garbage -- in particular, it's the bytes of the
+        # previous packet
+        last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
+        self.assertEqual(last_packet, self.interface.data[-1])
+
+    def _read_huge_packet(self):
+        PACKET_COUNT = 1180
+        # message that takes up 1 180 USB packets
+        message_size = (PACKET_COUNT - 1) * (
+            PACKET_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN
+        ) + INIT_MESSAGE_DATA_LENGTH
+
+        # ensure that a message this big won't fit into memory
+        # Note: this control is changed, because THP has only 2 byte length field
+        self.assertTrue(message_size > thp_main.MAX_PAYLOAD_LEN)
+        # self.assertRaises(MemoryError, bytearray, message_size)
+        header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
+        packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)
+        buffer = bytearray(65536)
+        gen = deprecated_read_message(self.interface, buffer)
+
+        query = gen.send(None)
+
+        # THP returns "Message too large" error after reading the message size,
+        # it is different from codec_v1 as it does not allow big enough messages
+        # to raise MemoryError in this test
+        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
+        with self.assertRaises(trezor.wire.thp.ThpError) as e:
+            query = gen.send(packet)
+
+        self.assertEqual(e.value.args[0], "Message too large")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/core/tests/thp_common.py b/core/tests/thp_common.py
new file mode 100644
index 0000000000..298a513002
--- /dev/null
+++ b/core/tests/thp_common.py
@@ -0,0 +1,44 @@
+from trezor import utils
+from trezor.wire.thp import ChannelState
+
+if utils.USE_THP:
+    import unittest
+    from typing import TYPE_CHECKING
+
+    from mock_wire_interface import MockHID
+    from storage import cache_thp
+    from trezor.wire import context
+    from trezor.wire.thp import interface_manager
+    from trezor.wire.thp.channel import Channel
+    from trezor.wire.thp.session_context import SessionContext
+
+    _MOCK_INTERFACE_HID = b"\x00"
+
+    if TYPE_CHECKING:
+        from trezor.wire import WireInterface
+
+    def dummy_decode_iface(cached_iface: bytes):
+        return MockHID(0xDEADBEEF)
+
+    def get_new_channel(channel_iface: WireInterface | None = None) -> Channel:
+        interface_manager.decode_iface = dummy_decode_iface
+        channel_cache = cache_thp.get_new_channel(_MOCK_INTERFACE_HID)
+        channel = Channel(channel_cache)
+        channel.set_channel_state(ChannelState.TH1)
+        if channel_iface is not None:
+            channel.iface = channel_iface
+        return channel
+
+    def prepare_context() -> None:
+        channel = get_new_channel()
+        session_cache = cache_thp.get_new_session(channel.channel_cache)
+        session_ctx = SessionContext(channel, session_cache)
+        context.CURRENT_CONTEXT = session_ctx
+
+
+if __debug__:
+    # Disable log.debug
+    def suppres_debug_log() -> None:
+        from trezor import log
+
+        log.debug = lambda name, msg, *args: None
diff --git a/core/tools/codegen/get_trezor_keys.py b/core/tools/codegen/get_trezor_keys.py
index 31c40fef1f..b511abd807 100755
--- a/core/tools/codegen/get_trezor_keys.py
+++ b/core/tools/codegen/get_trezor_keys.py
@@ -2,7 +2,7 @@
 
 import binascii
 from trezorlib.client import TrezorClient
-from trezorlib.transport_hid import HidTransport
+from trezorlib.transport.hid import HidTransport
 
 devices = HidTransport.enumerate()
 if len(devices) > 0: