Basic THP functinality - not-polished prototype

M1nd3r/thp5
M1nd3r 2 months ago committed by M1nd3r
parent 0d7fe7d643
commit 87365a3148

@ -47,6 +47,12 @@ storage
import storage
storage.cache
import storage.cache
storage.cache_codec
import storage.cache_codec
storage.cache_common
import storage.cache_common
storage.cache_thp
import storage.cache_thp
storage.common
import storage.common
storage.debug
@ -195,6 +201,14 @@ trezor.wire.context
import trezor.wire.context
trezor.wire.errors
import trezor.wire.errors
trezor.wire.protocol
import trezor.wire.protocol
trezor.wire.protocol_common
import trezor.wire.protocol_common
trezor.wire.thp_session
import trezor.wire.thp_session
trezor.wire.thp_v1
import trezor.wire.thp_v1
trezor.workflow
import trezor.workflow
apps

@ -1,6 +1,7 @@
from typing import TYPE_CHECKING
import storage.cache as storage_cache
import storage.cache_thp as storage_thp_cache
import storage.device as storage_device
from trezor import TR, config, utils, wire, workflow
from trezor.enums import HomescreenFormat, MessageType
@ -175,10 +176,21 @@ def get_features() -> Features:
return f
async def handle_Initialize(msg: Initialize) -> Features:
session_id = storage_cache.start_session(msg.session_id)
# handle_Initialize should not be used with THP to start a new session
async def handle_Initialize(
msg: Initialize, message_session_id: bytearray | None = None
) -> Features:
if message_session_id is None and utils.USE_THP:
raise ValueError("With THP enabled, a session id must be provided in args")
if utils.USE_THP:
session_id = storage_thp_cache.start_existing_session(msg.session_id)
else:
session_id = storage_cache.start_session(msg.session_id)
if not utils.BITCOIN_ONLY:
# TODO this block should be changed in THP
derive_cardano = storage_cache.get(storage_cache.APP_COMMON_DERIVE_CARDANO)
have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED)
@ -190,7 +202,7 @@ async def handle_Initialize(msg: Initialize) -> Features:
# seed is already derived, and host wants to change derive_cardano setting
# => create a new session
storage_cache.end_current_session()
session_id = storage_cache.start_session()
session_id = storage_cache.start_session() # This should not be used in THP
have_seed = False
if not have_seed:
@ -200,7 +212,7 @@ async def handle_Initialize(msg: Initialize) -> Features:
)
features = get_features()
features.session_id = session_id
features.session_id = session_id # not important in THP
return features

@ -4,17 +4,13 @@ from micropython import const
from typing import TYPE_CHECKING
from trezor import utils
from storage.cache_common import SESSIONLESS_FLAG, InvalidSessionError, SessionlessCache
if TYPE_CHECKING:
from typing import Sequence, TypeVar, overload
T = TypeVar("T")
_MAX_SESSIONS_COUNT = const(10)
_SESSIONLESS_FLAG = const(128)
_SESSION_ID_LENGTH = const(32)
# Traditional cache keys
APP_COMMON_SEED = const(0)
APP_COMMON_AUTHORIZATION_TYPE = const(1)
@ -27,14 +23,13 @@ if not utils.BITCOIN_ONLY:
APP_MONERO_LIVE_REFRESH = const(7)
# Keys that are valid across sessions
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | _SESSIONLESS_FLAG)
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | _SESSIONLESS_FLAG)
STORAGE_DEVICE_EXPERIMENTAL_FEATURES = const(2 | _SESSIONLESS_FLAG)
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(3 | _SESSIONLESS_FLAG)
APP_COMMON_BUSY_DEADLINE_MS = const(4 | _SESSIONLESS_FLAG)
APP_MISC_COSI_NONCE = const(5 | _SESSIONLESS_FLAG)
APP_MISC_COSI_COMMITMENT = const(6 | _SESSIONLESS_FLAG)
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG)
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG)
STORAGE_DEVICE_EXPERIMENTAL_FEATURES = const(2 | SESSIONLESS_FLAG)
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(3 | SESSIONLESS_FLAG)
APP_COMMON_BUSY_DEADLINE_MS = const(4 | SESSIONLESS_FLAG)
APP_MISC_COSI_NONCE = const(5 | SESSIONLESS_FLAG)
APP_MISC_COSI_COMMITMENT = const(6 | SESSIONLESS_FLAG)
# === Homescreen storage ===
# This does not logically belong to the "cache" functionality, but the cache module is
@ -52,103 +47,6 @@ homescreen_shown: object | None = None
autolock_last_touch: int | None = None
class InvalidSessionError(Exception):
pass
class DataCache:
fields: Sequence[int]
def __init__(self) -> None:
self.data = [bytearray(f + 1) for f in self.fields]
def set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.fields))
utils.ensure(len(value) <= self.fields[key])
self.data[key][0] = 1
self.data[key][1:] = value
if TYPE_CHECKING:
@overload
def get(self, key: int) -> bytes | None: ...
@overload
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.fields))
if self.data[key][0] != 1:
return default
return bytes(self.data[key][1:])
def is_set(self, key: int) -> bool:
utils.ensure(key < len(self.fields))
return self.data[key][0] == 1
def delete(self, key: int) -> None:
utils.ensure(key < len(self.fields))
self.data[key][:] = b"\x00"
def clear(self) -> None:
for i in range(len(self.fields)):
self.delete(i)
class SessionCache(DataCache):
def __init__(self) -> None:
self.session_id = bytearray(_SESSION_ID_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
1, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
1, # APP_MONERO_LIVE_REFRESH
)
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
from trezorcrypto import random # avoid pulling in trezor.crypto
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.last_usage = 0
self.session_id[:] = b""
class SessionlessCache(DataCache):
def __init__(self) -> None:
self.fields = (
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
1, # STORAGE_DEVICE_EXPERIMENTAL_FEATURES
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
8, # APP_COMMON_BUSY_DEADLINE_MS
32, # APP_MISC_COSI_NONCE
32, # APP_MISC_COSI_COMMITMENT
)
super().__init__()
# XXX
# Allocation notes:
# Instantiation of a DataCache subclass should make as little garbage as possible, so
@ -157,97 +55,46 @@ class SessionlessCache(DataCache):
# bytearrays, then later call `clear()` on all the existing objects, which resets them
# to zero length. This is producing some trash - `b[:]` allocates a slice.
_SESSIONS: list[SessionCache] = []
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionCache())
_SESSIONLESS_CACHE = SessionlessCache()
for session in _SESSIONS:
session.clear()
_SESSIONLESS_CACHE.clear()
gc.collect()
if utils.USE_THP:
from storage import cache_thp
_PROTOCOL_CACHE = cache_thp
else:
from storage import cache_codec
_active_session_idx: int | None = None
_session_usage_counter = 0
_PROTOCOL_CACHE = cache_codec
_PROTOCOL_CACHE.initialize()
_SESSIONLESS_CACHE.clear()
def start_session(received_session_id: bytes | None = None) -> bytes:
global _active_session_idx
global _session_usage_counter
if (
received_session_id is not None
and len(received_session_id) != _SESSION_ID_LENGTH
):
# Prevent the caller from setting received_session_id=b"" and finding a cleared
# session. More generally, short-circuit the session id search, because we know
# that wrong-length session ids should not be in cache.
# Reduce to "session id not provided" case because that's what we do when
# caller supplies an id that is not found.
received_session_id = None
_session_usage_counter += 1
# attempt to find specified session id
if received_session_id:
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].session_id == received_session_id:
_active_session_idx = i
_SESSIONS[i].last_usage = _session_usage_counter
return received_session_id
# allocate least recently used session
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
_active_session_idx = lru_session_idx
selected_session = _SESSIONS[lru_session_idx]
selected_session.clear()
selected_session.last_usage = _session_usage_counter
return selected_session.export_session_id()
gc.collect()
def end_current_session() -> None:
global _active_session_idx
def clear_all() -> None:
global autolock_last_touch
autolock_last_touch = None
_SESSIONLESS_CACHE.clear()
_PROTOCOL_CACHE.clear_all()
if _active_session_idx is None:
return
_SESSIONS[_active_session_idx].clear()
_active_session_idx = None
def start_session(received_session_id: bytes | None = None) -> bytes:
return _PROTOCOL_CACHE.start_session(received_session_id)
def set(key: int, value: bytes) -> None:
if key & _SESSIONLESS_FLAG:
_SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value)
return
if _active_session_idx is None:
raise InvalidSessionError
_SESSIONS[_active_session_idx].set(key, value)
def end_current_session() -> None:
_PROTOCOL_CACHE.end_current_session()
def set_int(key: int, value: int) -> None:
if key & _SESSIONLESS_FLAG:
length = _SESSIONLESS_CACHE.fields[key ^ _SESSIONLESS_FLAG]
elif _active_session_idx is None:
def delete(key: int) -> None:
if key & SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.delete(key ^ SESSIONLESS_FLAG)
active_session = _PROTOCOL_CACHE.get_active_session()
if active_session is None:
raise InvalidSessionError
else:
length = _SESSIONS[_active_session_idx].fields[key]
encoded = value.to_bytes(length, "big")
# Ensure that the value fits within the length. Micropython's int.to_bytes()
# doesn't raise OverflowError.
assert int.from_bytes(encoded, "big") == value
set(key, encoded)
return active_session.delete(key)
if TYPE_CHECKING:
@ -261,11 +108,12 @@ if TYPE_CHECKING:
def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default)
if _active_session_idx is None:
if key & SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.get(key ^ SESSIONLESS_FLAG, default)
active_session = _PROTOCOL_CACHE.get_active_session()
if active_session is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].get(key, default)
return active_session.get(key, default)
def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
@ -277,29 +125,52 @@ def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
def get_int_all_sessions(key: int) -> builtins.set[int]:
sessions = [_SESSIONLESS_CACHE] if key & _SESSIONLESS_FLAG else _SESSIONS
values = builtins.set()
for session in sessions:
encoded = session.get(key)
if key & SESSIONLESS_FLAG:
values = builtins.set()
encoded = _SESSIONLESS_CACHE.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
return values
return _PROTOCOL_CACHE.get_int_all_sessions(key)
def is_set(key: int) -> bool:
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.is_set(key ^ _SESSIONLESS_FLAG)
if _active_session_idx is None:
if key & SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.is_set(key ^ SESSIONLESS_FLAG)
active_session = _PROTOCOL_CACHE.get_active_session()
if active_session is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].is_set(key)
return active_session.is_set(key)
def delete(key: int) -> None:
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG)
if _active_session_idx is None:
def set(key: int, value: bytes) -> None:
if key & SESSIONLESS_FLAG:
_SESSIONLESS_CACHE.set(key ^ SESSIONLESS_FLAG, value)
return
active_session = _PROTOCOL_CACHE.get_active_session()
if active_session is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].delete(key)
active_session.set(key, value)
def set_int(key: int, value: int) -> None:
active_session = _PROTOCOL_CACHE.get_active_session()
if key & SESSIONLESS_FLAG:
length = _SESSIONLESS_CACHE.fields[key ^ SESSIONLESS_FLAG]
if active_session is None:
raise InvalidSessionError
else:
length = active_session.fields[key]
encoded = value.to_bytes(length, "big")
# Ensure that the value fits within the length. Micropython's int.to_bytes()
# doesn't raise OverflowError.
assert int.from_bytes(encoded, "big") == value
set(key, encoded)
if TYPE_CHECKING:
@ -336,15 +207,3 @@ def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
return wrapper
return decorator
def clear_all() -> None:
global _active_session_idx
global autolock_last_touch
_active_session_idx = None
_SESSIONLESS_CACHE.clear()
for session in _SESSIONS:
session.clear()
autolock_last_touch = None

@ -0,0 +1,144 @@
import builtins
from micropython import const
from typing import TYPE_CHECKING
from trezor import utils
from storage.cache_common import DataCache, InvalidSessionError
if TYPE_CHECKING:
from typing import Sequence, TypeVar, overload
T = TypeVar("T")
_MAX_SESSIONS_COUNT = const(10)
_SESSION_ID_LENGTH = const(32)
class SessionCache(DataCache):
def __init__(self) -> None:
self.session_id = bytearray(_SESSION_ID_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
1, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
1, # APP_MONERO_LIVE_REFRESH
)
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
from trezorcrypto import random # avoid pulling in trezor.crypto
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.last_usage = 0
self.session_id[:] = b""
_SESSIONS: list[SessionCache] = []
def initialize() -> None:
global _SESSIONS
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionCache())
for session in _SESSIONS:
session.clear()
initialize()
_active_session_idx: int | None = None
_session_usage_counter = 0
def get_active_session() -> SessionCache | None:
if _active_session_idx is None:
return None
return _SESSIONS[_active_session_idx]
def start_session(received_session_id: bytes | None = None) -> bytes:
global _active_session_idx
global _session_usage_counter
if (
received_session_id is not None
and len(received_session_id) != _SESSION_ID_LENGTH
):
# Prevent the caller from setting received_session_id=b"" and finding a cleared
# session. More generally, short-circuit the session id search, because we know
# that wrong-length session ids should not be in cache.
# Reduce to "session id not provided" case because that's what we do when
# caller supplies an id that is not found.
received_session_id = None
_session_usage_counter += 1
# attempt to find specified session id
if received_session_id:
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].session_id == received_session_id:
_active_session_idx = i
_SESSIONS[i].last_usage = _session_usage_counter
return received_session_id
# allocate least recently used session
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
_active_session_idx = lru_session_idx
selected_session = _SESSIONS[lru_session_idx]
selected_session.clear()
selected_session.last_usage = _session_usage_counter
return selected_session.export_session_id()
def end_current_session() -> None:
global _active_session_idx
if _active_session_idx is None:
return
_SESSIONS[_active_session_idx].clear()
_active_session_idx = None
def get_int_all_sessions(key: int) -> builtins.set[int]:
values = builtins.set()
for session in _SESSIONS:
encoded = session.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
def clear_all() -> None:
global _active_session_idx
_active_session_idx = None
for session in _SESSIONS:
session.clear()

@ -0,0 +1,70 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor import utils
if TYPE_CHECKING:
from typing import Sequence, TypeVar, overload
T = TypeVar("T")
SESSIONLESS_FLAG = const(128)
class InvalidSessionError(Exception):
pass
class DataCache:
fields: Sequence[int]
def __init__(self) -> None:
self.data = [bytearray(f + 1) for f in self.fields]
def set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.fields))
utils.ensure(len(value) <= self.fields[key])
self.data[key][0] = 1
self.data[key][1:] = value
if TYPE_CHECKING:
@overload
def get(self, key: int) -> bytes | None:
...
@overload
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.fields))
if self.data[key][0] != 1:
return default
return bytes(self.data[key][1:])
def is_set(self, key: int) -> bool:
utils.ensure(key < len(self.fields))
return self.data[key][0] == 1
def delete(self, key: int) -> None:
utils.ensure(key < len(self.fields))
self.data[key][:] = b"\x00"
def clear(self) -> None:
for i in range(len(self.fields)):
self.delete(i)
class SessionlessCache(DataCache):
def __init__(self) -> None:
self.fields = (
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
1, # STORAGE_DEVICE_EXPERIMENTAL_FEATURES
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
8, # APP_COMMON_BUSY_DEADLINE_MS
32, # APP_MISC_COSI_NONCE
32, # APP_MISC_COSI_COMMITMENT
)
super().__init__()

@ -0,0 +1,256 @@
import builtins
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_common import DataCache, InvalidSessionError
from trezor import utils
if TYPE_CHECKING:
from typing import Sequence, TypeVar, overload
T = TypeVar("T")
# THP specific constants
_MAX_SESSIONS_COUNT = const(20)
_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5)
_THP_SESSION_STATE_LENGTH = const(1)
_SESSION_ID_LENGTH = const(4)
BROADCAST_CHANNEL_ID = const(65535)
class SessionThpCache(DataCache): # TODO implement, this is just copied SessionCache
def __init__(self) -> None:
self.session_id = bytearray(_SESSION_ID_LENGTH)
self.state = bytearray(_THP_SESSION_STATE_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
1, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
1, # APP_MONERO_LIVE_REFRESH
)
self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5)
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
from trezorcrypto import random # avoid pulling in trezor.crypto
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.last_usage = 0
self.session_id[:] = b""
_SESSIONS: list[SessionThpCache] = []
_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = []
def initialize() -> None:
global _SESSIONS
global _UNAUTHENTICATED_SESSIONS
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionThpCache())
for _ in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
_UNAUTHENTICATED_SESSIONS.append(SessionThpCache())
for session in _SESSIONS:
session.clear()
for session in _UNAUTHENTICATED_SESSIONS:
session.clear()
initialize()
# THP vars
_next_unauthenicated_session_index: int = 0
_is_active_session_authenticated: bool
_active_session_idx: int | None = None
_session_usage_counter = 0
# with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex)
cid_counter: int = 4659
def get_active_session_id():
if get_active_session() is None:
return None
return get_active_session().session_id
def get_active_session() -> SessionThpCache | None:
if _active_session_idx is None:
return None
if _is_active_session_authenticated:
return _SESSIONS[_active_session_idx]
return _UNAUTHENTICATED_SESSIONS[_active_session_idx]
def get_next_channel_id() -> int:
global cid_counter
while True:
cid_counter += 1
if cid_counter >= BROADCAST_CHANNEL_ID:
cid_counter = 1
if _is_cid_unique():
break
return cid_counter
def _is_cid_unique() -> bool:
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
if cid_counter == _get_cid(session):
return False
return True
def _get_cid(session: SessionThpCache) -> int:
return int.from_bytes(session.session_id[2:], "big")
def create_new_unauthenticated_session(session_id: bytearray) -> SessionThpCache:
if len(session_id) != 4:
raise ValueError("session_id must be 4 bytes long.")
global _active_session_idx
global _is_active_session_authenticated
global _next_unauthenicated_session_index
i = _next_unauthenicated_session_index
_UNAUTHENTICATED_SESSIONS[i] = SessionThpCache()
_UNAUTHENTICATED_SESSIONS[i].session_id = bytearray(session_id)
_next_unauthenicated_session_index += 1
if _next_unauthenicated_session_index >= _MAX_UNAUTHENTICATED_SESSIONS_COUNT:
_next_unauthenicated_session_index = 0
# Set session as active if and only if there is no active session
if _active_session_idx is None:
_active_session_idx = i
_is_active_session_authenticated = False
return _UNAUTHENTICATED_SESSIONS[i]
def get_unauth_session_index(unauth_session: SessionThpCache) -> int | None:
for i in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
if unauth_session == _UNAUTHENTICATED_SESSIONS[i]:
return i
return None
def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache:
global _session_usage_counter
unauth_session_idx = get_unauth_session_index(unauth_session)
if unauth_session_idx is None:
raise InvalidSessionError
# replace least recently used authenticated session by the new session
new_auth_session_index = get_least_recently_used_authetnicated_session_index()
_SESSIONS[new_auth_session_index] = _UNAUTHENTICATED_SESSIONS[unauth_session_idx]
_UNAUTHENTICATED_SESSIONS[unauth_session_idx] = None
_session_usage_counter += 1
_SESSIONS[new_auth_session_index].last_usage = _session_usage_counter
def get_least_recently_used_authetnicated_session_index() -> int:
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
return lru_session_idx
# The function start_session should not be used in production code. It is present only to assure compatibility with old tests.
def start_session(session_id: bytes) -> bytes: # TODO incomplete
global _active_session_idx
global _is_active_session_authenticated
if session_id is not None:
if get_active_session_id() == session_id:
return session_id
for index in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[index].session_id == session_id:
_active_session_idx = index
_is_active_session_authenticated = True
return session_id
for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id:
_active_session_idx = index
_is_active_session_authenticated = False
return session_id
new_session_id = b"\x00\x00" + get_next_channel_id().to_bytes(2, "big")
new_session = create_new_unauthenticated_session(new_session_id)
index = get_unauth_session_index(new_session)
_active_session_idx = index
_is_active_session_authenticated = False
return new_session_id
def start_existing_session(session_id: bytearray) -> bytes:
if session_id is None:
raise ValueError("session_id cannot be None")
if get_active_session_id() == session_id:
return session_id
for index in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[index].session_id == session_id:
_active_session_idx = index
_is_active_session_authenticated = True
return session_id
for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id:
_active_session_idx = index
_is_active_session_authenticated = False
return session_id
raise ValueError("There is no active session with provided session_id")
def end_current_session() -> None:
global _active_session_idx
active_session = get_active_session()
if active_session is None:
return
active_session.clear()
_active_session_idx = None
def get_int_all_sessions(key: int) -> builtins.set[int]:
values = builtins.set()
for session in _SESSIONS: # Should there be _SESSIONS + _UNAUTHENTICATED_SESSIONS ?
encoded = session.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
def clear_all() -> None:
global _active_session_idx
_active_session_idx = None
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
session.clear()

@ -33,6 +33,8 @@ MODEL_IS_T2B1: bool = INTERNAL_MODEL == "T2B1"
DISABLE_ANIMATION = 0
USE_THP = True # TODO move elsewhere, probably to core/embed/trezorhal/...
if __debug__:
if EMULATOR:
import uos

@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is:
- Request / response.
- Protobuf-encoded, see `protobuf.py`.
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`.
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py` or `trezor/wire/thp_v1.py`.
- Transferred over USB interface, or UDP in case of Unix emulation.
This module:
@ -23,15 +23,17 @@ reads the message's header. When the message type is known the first handler is
"""
from apps import workflow_handlers
from micropython import const
from typing import TYPE_CHECKING
from storage.cache import InvalidSessionError
from storage.cache_codec import InvalidSessionError
from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType
from trezor.messages import Failure
from trezor.wire import codec_v1, context
from trezor.wire import codec_v1, context, protocol_common
from trezor.wire.errors import ActionCancelled, DataError, Error
import trezor.enums.MessageType as MT
# Import all errors into namespace, so that `wire.Error` is available from
# other packages.
@ -88,8 +90,8 @@ if __debug__:
async def _handle_single_message(
ctx: context.Context, msg: codec_v1.Message, use_workflow: bool
) -> codec_v1.Message | None:
ctx: context.Context, msg: protocol_common.Message, use_workflow: bool
) -> protocol_common.Message | None:
"""Handle a message that was loaded from USB by the caller.
Find the appropriate handler, run it and write its result on the wire. In case
@ -113,7 +115,7 @@ async def _handle_single_message(
__name__,
"%s:%x receive: <%s>",
ctx.iface.iface_num(),
ctx.sid,
ctx.session_id,
msg_type,
)
@ -143,7 +145,11 @@ async def _handle_single_message(
req_msg = wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handler(req_msg)
if msg.type is MT.Initialize:
# Special case for handle_initialize to have access to the verified session_id
task = handler(req_msg, ctx.session_id)
else:
task = handler(req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
@ -201,7 +207,7 @@ async def handle_session(
ctx_buffer = WIRE_BUFFER
ctx = context.Context(iface, session_id, ctx_buffer)
next_msg: codec_v1.Message | None = None
next_msg: protocol_common.Message | None = None
if __debug__ and is_debug_session:
import apps.debug
@ -218,7 +224,7 @@ async def handle_session(
# wait for a new one coming from the wire.
try:
msg = await ctx.read_from_wire()
except codec_v1.CodecError as exc:
except protocol_common.WireError as exc:
if __debug__:
log.exception(__name__, exc)
await ctx.write(failure(exc))
@ -229,6 +235,9 @@ async def handle_session(
msg = next_msg
next_msg = None
# Set ctx.session_id to the value msg.session_id
ctx.session_id = msg.session_id
try:
next_msg = await _handle_single_message(
ctx, msg, use_workflow=not is_debug_session

@ -3,6 +3,7 @@ from micropython import const
from typing import TYPE_CHECKING
from trezor import io, loop, utils
from trezor.wire.protocol_common import Message, WireError
if TYPE_CHECKING:
from trezorio import WireInterface
@ -18,16 +19,10 @@ _REP_CONT_DATA = const(1) # offset of data in the continuation report
SESSION_ID = const(0)
class CodecError(Exception):
class CodecError(WireError):
pass
class Message:
def __init__(self, mtype: int, mdata: bytes) -> None:
self.type = mtype
self.data = mdata
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
read = loop.wait(iface.iface_num() | io.POLL_READ)

@ -17,7 +17,8 @@ from typing import TYPE_CHECKING
from trezor import log, loop, protobuf
from . import codec_v1
from .protocol import WireProtocol
from .protocol_common import Message
if TYPE_CHECKING:
from trezorio import WireInterface
@ -48,7 +49,7 @@ class UnexpectedMessage(Exception):
should be aborted and a new one started as if `msg` was the first message.
"""
def __init__(self, msg: codec_v1.Message) -> None:
def __init__(self, msg: Message) -> None:
super().__init__()
self.msg = msg
@ -60,14 +61,14 @@ class Context:
(i.e., wire, debug, single BT connection, etc.)
"""
def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None:
def __init__(self, iface: WireInterface, buffer: bytearray) -> None:
self.iface = iface
self.sid = sid
self.buffer = buffer
self.session_id: bytearray | None = None
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
def read_from_wire(self) -> Awaitable[Message]:
"""Read a whole message from the wire without parsing it."""
return codec_v1.read_message(self.iface, self.buffer)
return WireProtocol.read_message(self.iface, self.buffer)
if TYPE_CHECKING:
@ -97,7 +98,7 @@ class Context:
__name__,
"%s:%x expect: %s",
self.iface.iface_num(),
self.sid,
self.session_id,
expected_type.MESSAGE_NAME if expected_type else expected_types,
)
@ -109,6 +110,9 @@ class Context:
if msg.type not in expected_types:
raise UnexpectedMessage(msg)
# TODO check that the message has the expected session_id. If not, raise UnexpectedMessageError
# (and maybe update ctx.session_id - depends on expected behaviour)
if expected_type is None:
expected_type = protobuf.type_for_wire(msg.type)
@ -117,7 +121,7 @@ class Context:
__name__,
"%s:%x read: %s",
self.iface.iface_num(),
self.sid,
self.session_id,
expected_type.MESSAGE_NAME,
)
@ -133,7 +137,7 @@ class Context:
__name__,
"%s:%x write: %s",
self.iface.iface_num(),
self.sid,
self.session_id,
msg.MESSAGE_NAME,
)
@ -151,10 +155,13 @@ class Context:
msg_size = protobuf.encode(buffer, msg)
await codec_v1.write_message(
await WireProtocol.write_message(
self.iface,
msg.MESSAGE_WIRE_TYPE,
memoryview(buffer)[:msg_size],
Message(
message_type=msg.MESSAGE_WIRE_TYPE,
message_data=memoryview(buffer)[:msg_size],
session_id=self.session_id,
),
)

@ -0,0 +1,19 @@
from trezor import utils
from trezor.wire import codec_v1, thp_v1
from trezor.wire.protocol_common import Message
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezorio import WireInterface
class WireProtocol:
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
if utils.USE_THP:
return thp_v1.read_message(iface, buffer)
return codec_v1.read_message(iface, buffer)
async def write_message(iface: WireInterface, message: Message) -> None:
if utils.USE_THP:
return thp_v1.write_to_wire(iface, message)
return codec_v1.write_message(iface, message.type, message.data)

@ -0,0 +1,14 @@
class Message:
def __init__(
self,
message_type: int,
message_data: bytes,
session_id: bytearray | None = None,
) -> None:
self.type = message_type
self.data = message_data
self.session_id = session_id
class WireError(Exception):
pass

@ -0,0 +1,166 @@
import ustruct
from storage import cache_thp as storage_thp_cache
from storage.cache_thp import SessionThpCache, BROADCAST_CHANNEL_ID
from trezor import io
from trezor.wire.protocol_common import WireError
from typing import TYPE_CHECKING
from ubinascii import hexlify
if TYPE_CHECKING:
from enum import IntEnum
else:
IntEnum = object
class ThpError(WireError):
pass
class WorkflowState(IntEnum):
NOT_STARTED = 0
PENDING = 1
FINISHED = 2
class Workflow:
id: int
workflow_state: WorkflowState
class SessionState(IntEnum):
UNALLOCATED = 0
INITIALIZED = 1 # do not change, is denoted as constant in storage.cache _THP_SESSION_STATE_INITIALIZED = 1
PAIRED = 2
UNPAIRED = 3
PAIRING = 4
APP_TRAFFIC = 5
def get_workflow() -> Workflow:
pass # TODO
def print_all_test_sessions() -> None:
for session in storage_thp_cache._UNAUTHENTICATED_SESSIONS:
if session is None:
print("none")
else:
print(hexlify(session.session_id).decode("utf-8"), session.state)
#
def create_autenticated_session(unauthenticated_session: SessionThpCache):
storage_thp_cache.start_session() # TODO something like this but for THP
raise
def create_new_unauthenticated_session(iface: WireInterface, cid: int):
session_id = _get_id(iface, cid)
new_session = storage_thp_cache.create_new_unauthenticated_session(session_id)
set_session_state(new_session, SessionState.INITIALIZED)
def get_active_session() -> SessionThpCache | None:
return storage_thp_cache.get_active_session()
def get_session(iface: WireInterface, cid: int) -> SessionThpCache | None:
session_id = _get_id(iface, cid)
return get_session_from_id(session_id)
def get_session_from_id(session_id) -> SessionThpCache | None:
session = _get_authenticated_session_or_none(session_id)
if session is None:
session = _get_unauthenticated_session_or_none(session_id)
return session
def get_state(session: SessionThpCache) -> int:
if session is None:
return SessionState.UNALLOCATED
return _decode_session_state(session.state)
def get_cid(session: SessionThpCache) -> int:
return storage_thp_cache._get_cid(session)
def get_next_channel_id() -> int:
return storage_thp_cache.get_next_channel_id()
def sync_can_send_message(session: SessionThpCache) -> bool:
return session.sync & 0x80 == 0x80
def sync_get_receive_expected_bit(session: SessionThpCache) -> int:
return (session.sync & 0x40) >> 6
def sync_get_send_bit(session: SessionThpCache) -> int:
return (session.sync & 0x20) >> 5
def sync_set_can_send_message(session: SessionThpCache, can_send: bool) -> None:
session.sync &= 0x7F
if can_send:
session.sync |= 0x80
def sync_set_receive_expected_bit(session: SessionThpCache, bit: int) -> None:
if bit != 0 and bit != 1:
raise ThpError("Unexpected receive sync bit")
# set second bit to "bit" value
session.sync &= 0xBF
session.sync |= 0x40
def sync_set_send_bit_to_opposite(session: SessionThpCache) -> None:
_sync_set_send_bit(session=session, bit=1 - sync_get_send_bit(session))
def is_active_session(session: SessionThpCache):
if session is None:
return False
return session.session_id == storage_thp_cache.get_active_session_id()
def set_session_state(session: SessionThpCache, new_state: SessionState):
session.state = new_state.to_bytes(1, "big")
def _get_id(iface: WireInterface, cid: int) -> bytearray:
return ustruct.pack(">HH", iface.iface_num(), cid)
def _get_authenticated_session_or_none(session_id) -> SessionThpCache:
for authenticated_session in storage_thp_cache._SESSIONS:
if authenticated_session.session_id == session_id:
return authenticated_session
return None
def _get_unauthenticated_session_or_none(session_id) -> SessionThpCache:
for unauthenticated_session in storage_thp_cache._UNAUTHENTICATED_SESSIONS:
if unauthenticated_session.session_id == session_id:
return unauthenticated_session
return None
def _sync_set_send_bit(session: SessionThpCache, bit: int) -> None:
if bit != 0 and bit != 1:
raise ThpError("Unexpected send sync bit")
# set third bit to "bit" value
session.sync &= 0xDF
session.sync |= 0x20
def _decode_session_state(state: bytearray) -> int:
return ustruct.unpack("B", state)[0]
def _encode_session_state(state: SessionState) -> bytearray:
return ustruct.pack("B", state)

@ -0,0 +1,370 @@
import ustruct
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_thp import SessionThpCache
from trezor import io, loop, utils
from trezor.crypto import crc
from trezor.wire.protocol_common import Message
import trezor.wire.thp_session as THP
from trezor.wire.thp_session import (
ThpError,
SessionState,
BROADCAST_CHANNEL_ID,
)
from ubinascii import hexlify
if TYPE_CHECKING:
from trezorio import WireInterface
_MAX_PAYLOAD_LEN = const(60000)
_CHECKSUM_LENGTH = const(4)
_CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x40
_ERROR = 0x41
_CONTINUATION_PACKET = 0x80
_ACK_MESSAGE = 0x20
_HANDSHAKE_INIT = 0x00
_PLAINTEXT = 0x01
ENCRYPTED_TRANSPORT = 0x02
_ENCODED_PROTOBUF_DEVICE_PROPERTIES = (
b"\x0A\x04\x54\x33\x57\x31\x10\x05\x18\x00\x20\x01\x28\x01\x28\x02"
)
_UNALLOCATED_SESSION_ERROR = (
b"\x55\x4e\x41\x4c\x4c\x4f\x43\x41\x54\x45\x44\x5f\x53\x45\x53\x53\x49\x4f\x4e"
)
_REPORT_LENGTH = const(64)
_REPORT_INIT_DATA_OFFSET = const(5)
_REPORT_CONT_DATA_OFFSET = const(3)
class InitHeader:
format_str = ">BHH"
def __init__(self, ctrl_byte, cid, length) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.length = length
def to_bytes(self) -> bytes:
return ustruct.pack(
InitHeader.format_str, self.ctrl_byte, self.cid, self.length
)
def pack_to_buffer(self, buffer, buffer_offset=0) -> None:
ustruct.pack_into(
InitHeader.format_str,
buffer,
buffer_offset,
self.ctrl_byte,
self.cid,
self.length,
)
def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None:
ustruct.pack_into(">BH", buffer, buffer_offset, _CONTINUATION_PACKET, self.cid)
class InterruptingInitPacket:
def __init__(self, report) -> None:
self.initReport = report
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
msg = await read_message_or_init_packet(iface, buffer)
while type(msg) is not Message:
if msg is InterruptingInitPacket:
msg = await read_message_or_init_packet(iface, buffer, msg.initReport)
else:
raise ThpError("Unexpected output of read_message_or_init_packet")
return msg
async def read_message_or_init_packet(
iface: WireInterface, buffer: utils.BufferType, firstReport=None
) -> Message | InterruptingInitPacket:
while True:
# Wait for an initial report
if firstReport is None:
report = await _get_loop_wait_read(iface)
else:
report = firstReport
# Channel multiplexing
ctrl_byte, cid = ustruct.unpack(">BH", report)
if cid == BROADCAST_CHANNEL_ID:
_handle_broadcast(iface, ctrl_byte, report)
continue
# We allow for only one message to be read simultaneously. We do not
# support reading multiple messages with interleaven packets - with
# the sole exception of cid_request which can be handled independently.
if _is_ctrl_byte_continuation(ctrl_byte):
# continuation packet is not expected - ignore
continue
payload_length = ustruct.unpack(">H", report[3:])[0]
payload = _get_buffer_for_payload(payload_length, buffer)
header = InitHeader(ctrl_byte, cid, payload_length)
# buffer the received data
interruptingPacket = await _buffer_received_data(payload, header, iface, report)
if interruptingPacket is not None:
return interruptingPacket
# Check CRC
if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]):
# checksum is not valid -> ignore message
continue
session = THP.get_session(iface, cid)
session_state = THP.get_state(session)
# Handle message on unallocated channel
if session_state == SessionState.UNALLOCATED:
message = _handle_unallocated(iface, cid)
# unallocated should not return regular message, TODO, but it might change
if message is not None:
return message
continue
# Note: In the Host, the UNALLOCATED_CHANNEL error should be handled here
# Synchronization process
sync_bit = (ctrl_byte & 0x10) >> 4
# 1: Handle ACKs
if _is_ctrl_byte_ack(ctrl_byte):
_handle_received_ACK(session, sync_bit)
continue
# 2: Handle message with unexpected synchronization bit
if sync_bit != THP.sync_get_receive_expected_bit(session):
message = _handle_unexpected_sync_bit(iface, cid, sync_bit)
# unsynchronized messages should not return regular message, TODO,
# but it might change with the cancelation message
if message is not None:
return message
continue
# 3: Send ACK in response
_sendAck(iface, cid, sync_bit)
THP.sync_set_receive_expected_bit(session, 1 - sync_bit)
return _handle_allocated(ctrl_byte, session, payload)
def _get_loop_wait_read(iface: WireInterface):
return loop.wait(iface.iface_num() | io.POLL_READ)
def _get_buffer_for_payload(
payload_length: int, existing_buffer: utils.BufferType
) -> utils.BufferType:
if payload_length > _MAX_PAYLOAD_LEN:
raise ThpError("Message too large")
if payload_length > len(existing_buffer):
# allocate a new buffer to fit the message
try:
payload: utils.BufferType = bytearray(payload_length)
except MemoryError:
payload = bytearray(_REPORT_LENGTH)
raise ("Message too large")
return payload
# reuse a part of the supplied buffer
return memoryview(existing_buffer)[:payload_length]
async def _buffer_received_data(
payload: utils.BufferType, header: InitHeader, iface, report
) -> None | InterruptingInitPacket:
# buffer the initial data
nread = utils.memcpy(payload, 0, report, _REPORT_INIT_DATA_OFFSET)
while nread < header.length:
# wait for continuation report
report = await _get_loop_wait_read(iface)
# channel multiplexing
cont_ctrl_byte, cont_cid = ustruct.unpack(">BH", report)
# handle broadcast - allows the reading process
# to survive interruption by broadcast
if cont_cid == BROADCAST_CHANNEL_ID:
_handle_broadcast(iface, cont_ctrl_byte, report)
continue
# handle unexpected initiation packet
if not _is_ctrl_byte_continuation(cont_ctrl_byte):
# TODO possibly add timeout - allow interruption only after a long time
return InterruptingInitPacket(report)
# ignore continuation packets on different channels
if cont_cid != header.cid:
continue
# buffer the continuation data
nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET)
async def write_message(
iface: WireInterface, message: Message, is_retransmission: bool = False
) -> None:
session = THP.get_session_from_id(message.session_id)
cid = THP.get_cid(session)
payload = message.type.to_bytes(2, "big") + message.data
payload_length = len(payload)
if THP.get_state(session) == SessionState.INITIALIZED:
# write message in plaintext, TODO check if it is allowed
ctrl_byte = _PLAINTEXT
elif THP.get_state(session) == SessionState.APP_TRAFFIC:
ctrl_byte = ENCRYPTED_TRANSPORT
else:
raise ThpError("Session in not implemented state" + str(THP.get_state(session)))
if not is_retransmission:
ctrl_byte = _add_sync_bit_to_ctrl_byte(
ctrl_byte, THP.sync_get_send_bit(session)
)
THP.sync_set_send_bit_to_opposite(session)
else:
# retransmission must have the same sync bit as the previously sent message
ctrl_byte = _add_sync_bit_to_ctrl_byte(ctrl_byte, 1 - THP.sync_get_send_bit())
header = InitHeader(ctrl_byte, cid, payload_length + _CHECKSUM_LENGTH)
checksum = _compute_checksum_bytes(header.to_bytes() + payload)
await write_to_wire(iface, header, payload + checksum)
# TODO set timeout for retransmission
async def write_to_wire(
iface: WireInterface, header: InitHeader, payload: bytes
) -> None:
write = loop.wait(iface.iface_num() | io.POLL_WRITE)
payload_length = len(payload)
# prepare the report buffer with header data
report = bytearray(_REPORT_LENGTH)
header.pack_to_buffer(report)
# write initial report
nwritten = utils.memcpy(report, _REPORT_INIT_DATA_OFFSET, payload, 0)
await _write_report(write, iface, report)
# if we have more data to write, use continuation reports for it
if nwritten < payload_length:
header.pack_to_cont_buffer(report)
while nwritten < payload_length:
nwritten += utils.memcpy(report, _REPORT_CONT_DATA_OFFSET, payload, nwritten)
await _write_report(write, iface, report)
async def _write_report(write, iface: WireInterface, report: bytearray) -> None:
while True:
await write
n = iface.write(report)
if n == len(report):
return
def _handle_broadcast(iface: WireIntreface, ctrl_byte, report) -> Message | None:
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
length, nonce, checksum = ustruct.unpack(">H8s4s", report[3:])
if not _is_checksum_valid(checksum, data=report[:-4]):
raise ThpError("Checksum is not valid")
channel_id = _get_new_channel_id()
THP.create_new_unauthenticated_session(iface, channel_id)
response_data = (
ustruct.pack(">8sH", nonce, channel_id) + _ENCODED_PROTOBUF_DEVICE_PROPERTIES
)
response_header = InitHeader(
_CHANNEL_ALLOCATION_RES,
BROADCAST_CHANNEL_ID,
len(response_data) + _CHECKSUM_LENGTH,
)
checksum = _compute_checksum_bytes(response_header.to_bytes() + response_data)
write_to_wire(iface, response_header, response_data + checksum)
def _handle_allocated(ctrl_byte, session: SessionThpCache, payload) -> Message:
# Parameters session and ctrl_byte will be used to determine if the
# communication should be encrypted or not
message_type = ustruct.unpack(">H", payload)[0]
# trim message type and checksum from payload
message_data = payload[2:-_CHECKSUM_LENGTH]
return Message(message_type, message_data, session.session_id)
def _handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None:
# No ACKs expected
if THP.sync_can_send_message(session):
return
# ACK has incorrect sync bit
if THP.sync_get_send_bit(session) != sync_bit:
return
# ACK is expected and it has correct sync bit
THP.sync_set_can_send_message(session, True)
async def _handle_unallocated(iface, cid) -> Message | None:
data = _UNALLOCATED_SESSION_ERROR
header = InitHeader(_ERROR, cid, len(data) + _CHECKSUM_LENGTH)
checksum = _compute_checksum_bytes(header.to_bytes() + data)
write_to_wire(iface, header, data + checksum)
async def _sendAck(iface: WireInterface, cid: int, ack_bit: int) -> None:
ctrl_byte = _add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
header = InitHeader(ctrl_byte, cid, _CHECKSUM_LENGTH)
checksum = _compute_checksum_bytes(header.to_bytes())
write_to_wire(iface, header, checksum)
def _handle_unexpected_sync_bit(
iface: WireInterface, cid: int, sync_bit: int
) -> Message | None:
_sendAck(iface, cid, sync_bit)
# TODO handle cancelation messages and messages on allocated channels without synchronization
# (some such messages might be handled in the classical "allocated" way, if the sync bit is right)
def _get_new_channel_id() -> int:
return THP.get_next_channel_id()
def _is_checksum_valid(checksum: bytearray, data: bytearray) -> bool:
data_checksum = _compute_checksum_bytes(data)
return checksum == data_checksum
def _is_ctrl_byte_continuation(ctrl_byte) -> bool:
return ctrl_byte & 0x80 == _CONTINUATION_PACKET
def _is_ctrl_byte_ack(ctrl_byte) -> bool:
return ctrl_byte & 0x20 == _ACK_MESSAGE
def _add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit):
if sync_bit == 0:
return ctrl_byte & 0xEF
if sync_bit == 1:
return ctrl_byte | 0x10
raise ThpError("Unexpected synchronization bit")
def _compute_checksum_bytes(data: bytearray):
return crc.crc32(data).to_bytes(4, "big")

@ -1,17 +1,26 @@
from common import * # isort:skip
from mock_storage import mock_storage
from storage import cache
from trezor.messages import EndSession, Initialize
from storage import cache, cache_codec, cache_thp
from storage.cache_common import InvalidSessionError
from trezor import utils
from trezor.messages import Initialize
from trezor.messages import EndSession
from apps.base import handle_EndSession, handle_Initialize
KEY = 0
if utils.USE_THP:
_PROTOCOL_CACHE = cache_thp
else:
_PROTOCOL_CACHE = cache_codec
# Function moved from cache.py, as it was not used there
def is_session_started() -> bool:
return cache._active_session_idx is not None
return _PROTOCOL_CACHE.get_active_session() is not None
class TestStorageCache(unittest.TestCase):
@ -25,9 +34,9 @@ class TestStorageCache(unittest.TestCase):
self.assertNotEqual(session_id_a, session_id_b)
cache.clear_all()
with self.assertRaises(cache.InvalidSessionError):
with self.assertRaises(InvalidSessionError):
cache.set(KEY, "something")
with self.assertRaises(cache.InvalidSessionError):
with self.assertRaises(InvalidSessionError):
cache.get(KEY)
def test_end_session(self):
@ -36,7 +45,7 @@ class TestStorageCache(unittest.TestCase):
cache.set(KEY, b"A")
cache.end_current_session()
self.assertFalse(is_session_started())
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
self.assertRaises(InvalidSessionError, cache.get, KEY)
# ending an ended session should be a no-op
cache.end_current_session()
@ -63,7 +72,7 @@ class TestStorageCache(unittest.TestCase):
session_id = cache.start_session()
self.assertEqual(cache.start_session(session_id), session_id)
cache.set(KEY, b"A")
for i in range(cache._MAX_SESSIONS_COUNT):
for i in range(_PROTOCOL_CACHE._MAX_SESSIONS_COUNT):
cache.start_session()
self.assertNotEqual(cache.start_session(session_id), session_id)
self.assertIsNone(cache.get(KEY))
@ -83,7 +92,7 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get(KEY), b"hello")
cache.clear_all()
with self.assertRaises(cache.InvalidSessionError):
with self.assertRaises(InvalidSessionError):
cache.get(KEY)
def test_get_set_int(self):
@ -101,7 +110,7 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get_int(KEY), 1234)
cache.clear_all()
with self.assertRaises(cache.InvalidSessionError):
with self.assertRaises(InvalidSessionError):
cache.get_int(KEY)
def test_delete(self):
@ -186,6 +195,9 @@ class TestStorageCache(unittest.TestCase):
@mock_storage
def test_Initialize(self):
if utils.USE_THP: # INITIALIZE SHOULD NOT BE IN THP!!! TODO
return
def call_Initialize(**kwargs):
msg = Initialize(**kwargs)
return await_result(handle_Initialize(msg))
@ -210,7 +222,7 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get(KEY), b"hello")
# supplying a different session ID starts a new cache
call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH)
call_Initialize(session_id=b"A" * _PROTOCOL_CACHE._SESSION_ID_LENGTH)
self.assertIsNone(cache.get(KEY))
# but resuming a session loads the previous one
@ -218,13 +230,18 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get(KEY), b"hello")
def test_EndSession(self):
<<<<<<< HEAD
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
cache.start_session()
=======
self.assertRaises(InvalidSessionError, cache.get, KEY)
session_id = cache.start_session()
>>>>>>> 8681ba167 (Basic THP functinality - not-polished prototype)
self.assertTrue(is_session_started())
self.assertIsNone(cache.get(KEY))
await_result(handle_EndSession(EndSession()))
self.assertFalse(is_session_started())
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
self.assertRaises(InvalidSessionError, cache.get, KEY)
if __name__ == "__main__":

@ -0,0 +1,341 @@
from common import *
from ubinascii import hexlify, unhexlify
import ustruct
from trezor import io, utils
from trezor.loop import wait
from trezor.utils import chunks
from trezor.wire import thp_v1
from trezor.wire.thp_v1 import _CHECKSUM_LENGTH, BROADCAST_CHANNEL_ID
from trezor.wire.protocol_common import Message
import trezor.wire.thp_session as THP
from micropython import const
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)
MESSAGE_TYPE = 0x4242
MESSAGE_TYPE_BYTES = b"\x42\x42"
_MESSAGE_TYPE_LEN = 2
PLAINTEXT_0 = 0x01
PLAINTEXT_1 = 0x11
COMMON_CID = 4660
CONT = 0x80
HEADER_INIT_LENGTH = 5
HEADER_CONT_LENGTH = 3
INIT_MESSAGE_DATA_LENGTH = (
thp_v1._REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
)
def make_header(ctrl_byte, cid, length):
return ustruct.pack(">BHH", ctrl_byte, cid, length)
def make_cont_header():
return ustruct.pack(">BH", CONT, COMMON_CID)
def makeSimpleMessage(header, message_type, message_data):
return header + ustruct.pack(">H", message_type) + message_data
def makeCidRequest(header, message_data):
return header + message_data
def printBytes(a):
print(hexlify(a).decode("utf-8"))
def getPlaintext() -> bytes:
if THP.sync_get_receive_expected_bit(THP.get_active_session()) == 1:
return PLAINTEXT_1
PLAINTEXT_0
def getCid() -> int:
return THP.get_cid(THP.get_active_session())
# This test suite is an adaptation of test_trezor.wire.codec_v1
class TestWireTrezorHostProtocolV1(unittest.TestCase):
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
if not utils.USE_THP:
import storage.cache_thp # noQA:F401
def test_simple(self):
cid_req_header = make_header(
ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
)
cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3C\x6C"
cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data)
message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18)
cid_request_dummy_data_checksum = b"\x67\x8E\xAC\xE0"
message = makeSimpleMessage(
message_header,
MESSAGE_TYPE,
cid_request_dummy_data + cid_request_dummy_data_checksum,
)
buffer = bytearray(64)
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
with self.assertRaises(StopIteration) as e:
gen.send(cid_req_message)
gen.send(message)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, cid_request_dummy_data)
buffer_without_zeroes = buffer[: len(message) - 5]
message_without_header = message[5:]
# message should have been read into the buffer
self.assertEqual(buffer_without_zeroes, message_without_header)
def test_read_one_packet(self):
# zero length message - just a header
PLAINTEXT = getPlaintext()
header = make_header(
PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH
)
checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES)
message = header + MESSAGE_TYPE_BYTES + checksum
buffer = bytearray(64)
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
with self.assertRaises(StopIteration) as e:
gen.send(message)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, b"")
# message should have been read into the buffer
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + checksum + b"\x00" * 58)
def test_read_many_packets(self):
message = bytes(range(256))
header = make_header(
getPlaintext(),
COMMON_CID,
len(message) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH,
)
checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message)
# message = MESSAGE_TYPE_BYTES + message + checksum
# first packet is init header + 59 bytes of data
# other packets are cont header + 61 bytes of data
cont_header = make_cont_header()
packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
cont_header + chunk
for chunk in chunks(
message[INIT_MESSAGE_DATA_LENGTH:] + checksum,
64 - HEADER_CONT_LENGTH,
)
]
buffer = bytearray(262)
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
for packet in packets[:-1]:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet)
# last packet will stop
with self.assertRaises(StopIteration) as e:
gen.send(packets[-1])
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message)
# message should have been read into the buffer )
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + checksum)
def test_read_large_message(self):
message = b"hello world"
header = make_header(
getPlaintext(),
COMMON_CID,
_MESSAGE_TYPE_LEN + len(message) + _CHECKSUM_LENGTH,
)
packet = (
header
+ MESSAGE_TYPE_BYTES
+ message
+ thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message)
)
# make sure we fit into one packet, to make this easier
self.assertTrue(len(packet) <= thp_v1._REPORT_LENGTH)
buffer = bytearray(1)
self.assertTrue(len(buffer) <= len(packet))
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
with self.assertRaises(StopIteration) as e:
gen.send(packet)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message)
# read should have allocated its own buffer and not touch ours
self.assertEqual(buffer, b"\x00")
def test_write_one_packet(self):
message = Message(MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID))
gen = thp_v1.write_message(self.interface, message)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
with self.assertRaises(StopIteration):
gen.send(None)
header = make_header(
PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH
)
expected_message = (
header
+ MESSAGE_TYPE_BYTES
+ thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES)
+ b"\x00" * (INIT_MESSAGE_DATA_LENGTH - _CHECKSUM_LENGTH)
)
self.assertTrue(self.interface.data == [expected_message])
def test_write_multiple_packets(self):
message_payload = bytes(range(256))
message = Message(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
)
gen = thp_v1.write_message(self.interface, message)
header = make_header(
PLAINTEXT_1,
COMMON_CID,
len(message.data) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH,
)
cont_header = make_cont_header()
checksum = thp_v1._compute_checksum_bytes(
header + message.type.to_bytes(2, "big") + message.data
)
packets = [
header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH]
] + [
cont_header + chunk
for chunk in chunks(
message.data[INIT_MESSAGE_DATA_LENGTH:] + checksum,
thp_v1._REPORT_LENGTH - HEADER_CONT_LENGTH,
)
]
for _ in packets:
# we receive as many queries as there are packets
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
# the first sent None only started the generator. the len(packets)-th None
# will finish writing and raise StopIteration
with self.assertRaises(StopIteration):
gen.send(None)
# packets must be identical up to the last one
self.assertListEqual(packets[:-1], self.interface.data[:-1])
# last packet must be identical up to message length. remaining bytes in
# the 64-byte packets are garbage -- in particular, it's the bytes of the
# previous packet
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
self.assertEqual(last_packet, self.interface.data[-1])
def test_roundtrip(self):
message_payload = bytes(range(256))
message = Message(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
)
gen = thp_v1.write_message(self.interface, message)
# exhaust the iterator:
# (XXX we can only do this because the iterator is only accepting None and returns None)
for query in gen:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
buffer = bytearray(1024)
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
for packet in self.interface.data[:-1]:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet)
with self.assertRaises(StopIteration) as e:
gen.send(self.interface.data[-1])
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message.data)
def test_read_huge_packet(self):
PACKET_COUNT = 1180
# message that takes up 1 180 USB packets
message_size = (PACKET_COUNT - 1) * (
thp_v1._REPORT_LENGTH
- HEADER_CONT_LENGTH
- _CHECKSUM_LENGTH
- _MESSAGE_TYPE_LEN
) + INIT_MESSAGE_DATA_LENGTH
# ensure that a message this big won't fit into memory
# Note: this control is changed, because THP has only 2 byte length field
self.assertTrue(message_size > thp_v1._MAX_PAYLOAD_LEN)
# self.assertRaises(MemoryError, bytearray, message_size)
header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)
buffer = bytearray(65536)
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
# THP returns "Message too large" error after reading the message size,
# it is different from codec_v1 as it does not allow big enough messages
# to raise MemoryError in this test
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
with self.assertRaises(thp_v1.ThpError) as e:
query = gen.send(packet)
self.assertEqual(e.value.args[0], "Message too large")
if __name__ == "__main__":
unittest.main()
Loading…
Cancel
Save