Basic THP functinality - not-polished prototype

M1nd3r/thp5
M1nd3r 3 months ago committed by M1nd3r
parent 0d7fe7d643
commit 87365a3148

@ -47,6 +47,12 @@ storage
import storage import storage
storage.cache storage.cache
import storage.cache import storage.cache
storage.cache_codec
import storage.cache_codec
storage.cache_common
import storage.cache_common
storage.cache_thp
import storage.cache_thp
storage.common storage.common
import storage.common import storage.common
storage.debug storage.debug
@ -195,6 +201,14 @@ trezor.wire.context
import trezor.wire.context import trezor.wire.context
trezor.wire.errors trezor.wire.errors
import trezor.wire.errors import trezor.wire.errors
trezor.wire.protocol
import trezor.wire.protocol
trezor.wire.protocol_common
import trezor.wire.protocol_common
trezor.wire.thp_session
import trezor.wire.thp_session
trezor.wire.thp_v1
import trezor.wire.thp_v1
trezor.workflow trezor.workflow
import trezor.workflow import trezor.workflow
apps apps

@ -1,6 +1,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import storage.cache as storage_cache import storage.cache as storage_cache
import storage.cache_thp as storage_thp_cache
import storage.device as storage_device import storage.device as storage_device
from trezor import TR, config, utils, wire, workflow from trezor import TR, config, utils, wire, workflow
from trezor.enums import HomescreenFormat, MessageType from trezor.enums import HomescreenFormat, MessageType
@ -175,10 +176,21 @@ def get_features() -> Features:
return f return f
async def handle_Initialize(msg: Initialize) -> Features: # handle_Initialize should not be used with THP to start a new session
session_id = storage_cache.start_session(msg.session_id) async def handle_Initialize(
msg: Initialize, message_session_id: bytearray | None = None
) -> Features:
if message_session_id is None and utils.USE_THP:
raise ValueError("With THP enabled, a session id must be provided in args")
if utils.USE_THP:
session_id = storage_thp_cache.start_existing_session(msg.session_id)
else:
session_id = storage_cache.start_session(msg.session_id)
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
# TODO this block should be changed in THP
derive_cardano = storage_cache.get(storage_cache.APP_COMMON_DERIVE_CARDANO) derive_cardano = storage_cache.get(storage_cache.APP_COMMON_DERIVE_CARDANO)
have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED) have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED)
@ -190,7 +202,7 @@ async def handle_Initialize(msg: Initialize) -> Features:
# seed is already derived, and host wants to change derive_cardano setting # seed is already derived, and host wants to change derive_cardano setting
# => create a new session # => create a new session
storage_cache.end_current_session() storage_cache.end_current_session()
session_id = storage_cache.start_session() session_id = storage_cache.start_session() # This should not be used in THP
have_seed = False have_seed = False
if not have_seed: if not have_seed:
@ -200,7 +212,7 @@ async def handle_Initialize(msg: Initialize) -> Features:
) )
features = get_features() features = get_features()
features.session_id = session_id features.session_id = session_id # not important in THP
return features return features

@ -4,17 +4,13 @@ from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import utils from trezor import utils
from storage.cache_common import SESSIONLESS_FLAG, InvalidSessionError, SessionlessCache
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Sequence, TypeVar, overload from typing import Sequence, TypeVar, overload
T = TypeVar("T") T = TypeVar("T")
_MAX_SESSIONS_COUNT = const(10)
_SESSIONLESS_FLAG = const(128)
_SESSION_ID_LENGTH = const(32)
# Traditional cache keys # Traditional cache keys
APP_COMMON_SEED = const(0) APP_COMMON_SEED = const(0)
APP_COMMON_AUTHORIZATION_TYPE = const(1) APP_COMMON_AUTHORIZATION_TYPE = const(1)
@ -27,14 +23,13 @@ if not utils.BITCOIN_ONLY:
APP_MONERO_LIVE_REFRESH = const(7) APP_MONERO_LIVE_REFRESH = const(7)
# Keys that are valid across sessions # Keys that are valid across sessions
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | _SESSIONLESS_FLAG) APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG)
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | _SESSIONLESS_FLAG) APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG)
STORAGE_DEVICE_EXPERIMENTAL_FEATURES = const(2 | _SESSIONLESS_FLAG) STORAGE_DEVICE_EXPERIMENTAL_FEATURES = const(2 | SESSIONLESS_FLAG)
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(3 | _SESSIONLESS_FLAG) APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(3 | SESSIONLESS_FLAG)
APP_COMMON_BUSY_DEADLINE_MS = const(4 | _SESSIONLESS_FLAG) APP_COMMON_BUSY_DEADLINE_MS = const(4 | SESSIONLESS_FLAG)
APP_MISC_COSI_NONCE = const(5 | _SESSIONLESS_FLAG) APP_MISC_COSI_NONCE = const(5 | SESSIONLESS_FLAG)
APP_MISC_COSI_COMMITMENT = const(6 | _SESSIONLESS_FLAG) APP_MISC_COSI_COMMITMENT = const(6 | SESSIONLESS_FLAG)
# === Homescreen storage === # === Homescreen storage ===
# This does not logically belong to the "cache" functionality, but the cache module is # This does not logically belong to the "cache" functionality, but the cache module is
@ -52,103 +47,6 @@ homescreen_shown: object | None = None
autolock_last_touch: int | None = None autolock_last_touch: int | None = None
class InvalidSessionError(Exception):
pass
class DataCache:
fields: Sequence[int]
def __init__(self) -> None:
self.data = [bytearray(f + 1) for f in self.fields]
def set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.fields))
utils.ensure(len(value) <= self.fields[key])
self.data[key][0] = 1
self.data[key][1:] = value
if TYPE_CHECKING:
@overload
def get(self, key: int) -> bytes | None: ...
@overload
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.fields))
if self.data[key][0] != 1:
return default
return bytes(self.data[key][1:])
def is_set(self, key: int) -> bool:
utils.ensure(key < len(self.fields))
return self.data[key][0] == 1
def delete(self, key: int) -> None:
utils.ensure(key < len(self.fields))
self.data[key][:] = b"\x00"
def clear(self) -> None:
for i in range(len(self.fields)):
self.delete(i)
class SessionCache(DataCache):
def __init__(self) -> None:
self.session_id = bytearray(_SESSION_ID_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
1, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
1, # APP_MONERO_LIVE_REFRESH
)
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
from trezorcrypto import random # avoid pulling in trezor.crypto
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.last_usage = 0
self.session_id[:] = b""
class SessionlessCache(DataCache):
def __init__(self) -> None:
self.fields = (
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
1, # STORAGE_DEVICE_EXPERIMENTAL_FEATURES
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
8, # APP_COMMON_BUSY_DEADLINE_MS
32, # APP_MISC_COSI_NONCE
32, # APP_MISC_COSI_COMMITMENT
)
super().__init__()
# XXX # XXX
# Allocation notes: # Allocation notes:
# Instantiation of a DataCache subclass should make as little garbage as possible, so # Instantiation of a DataCache subclass should make as little garbage as possible, so
@ -157,97 +55,46 @@ class SessionlessCache(DataCache):
# bytearrays, then later call `clear()` on all the existing objects, which resets them # bytearrays, then later call `clear()` on all the existing objects, which resets them
# to zero length. This is producing some trash - `b[:]` allocates a slice. # to zero length. This is producing some trash - `b[:]` allocates a slice.
_SESSIONS: list[SessionCache] = []
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionCache())
_SESSIONLESS_CACHE = SessionlessCache() _SESSIONLESS_CACHE = SessionlessCache()
for session in _SESSIONS:
session.clear()
_SESSIONLESS_CACHE.clear()
gc.collect() if utils.USE_THP:
from storage import cache_thp
_PROTOCOL_CACHE = cache_thp
else:
from storage import cache_codec
_active_session_idx: int | None = None _PROTOCOL_CACHE = cache_codec
_session_usage_counter = 0
_PROTOCOL_CACHE.initialize()
_SESSIONLESS_CACHE.clear()
def start_session(received_session_id: bytes | None = None) -> bytes: gc.collect()
global _active_session_idx
global _session_usage_counter
if (
received_session_id is not None
and len(received_session_id) != _SESSION_ID_LENGTH
):
# Prevent the caller from setting received_session_id=b"" and finding a cleared
# session. More generally, short-circuit the session id search, because we know
# that wrong-length session ids should not be in cache.
# Reduce to "session id not provided" case because that's what we do when
# caller supplies an id that is not found.
received_session_id = None
_session_usage_counter += 1
# attempt to find specified session id
if received_session_id:
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].session_id == received_session_id:
_active_session_idx = i
_SESSIONS[i].last_usage = _session_usage_counter
return received_session_id
# allocate least recently used session
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
_active_session_idx = lru_session_idx
selected_session = _SESSIONS[lru_session_idx]
selected_session.clear()
selected_session.last_usage = _session_usage_counter
return selected_session.export_session_id()
def end_current_session() -> None: def clear_all() -> None:
global _active_session_idx global autolock_last_touch
autolock_last_touch = None
_SESSIONLESS_CACHE.clear()
_PROTOCOL_CACHE.clear_all()
if _active_session_idx is None:
return
_SESSIONS[_active_session_idx].clear() def start_session(received_session_id: bytes | None = None) -> bytes:
_active_session_idx = None return _PROTOCOL_CACHE.start_session(received_session_id)
def set(key: int, value: bytes) -> None: def end_current_session() -> None:
if key & _SESSIONLESS_FLAG: _PROTOCOL_CACHE.end_current_session()
_SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value)
return
if _active_session_idx is None:
raise InvalidSessionError
_SESSIONS[_active_session_idx].set(key, value)
def set_int(key: int, value: int) -> None: def delete(key: int) -> None:
if key & _SESSIONLESS_FLAG: if key & SESSIONLESS_FLAG:
length = _SESSIONLESS_CACHE.fields[key ^ _SESSIONLESS_FLAG] return _SESSIONLESS_CACHE.delete(key ^ SESSIONLESS_FLAG)
elif _active_session_idx is None: active_session = _PROTOCOL_CACHE.get_active_session()
if active_session is None:
raise InvalidSessionError raise InvalidSessionError
else: return active_session.delete(key)
length = _SESSIONS[_active_session_idx].fields[key]
encoded = value.to_bytes(length, "big")
# Ensure that the value fits within the length. Micropython's int.to_bytes()
# doesn't raise OverflowError.
assert int.from_bytes(encoded, "big") == value
set(key, encoded)
if TYPE_CHECKING: if TYPE_CHECKING:
@ -261,11 +108,12 @@ if TYPE_CHECKING:
def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
if key & _SESSIONLESS_FLAG: if key & SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default) return _SESSIONLESS_CACHE.get(key ^ SESSIONLESS_FLAG, default)
if _active_session_idx is None: active_session = _PROTOCOL_CACHE.get_active_session()
if active_session is None:
raise InvalidSessionError raise InvalidSessionError
return _SESSIONS[_active_session_idx].get(key, default) return active_session.get(key, default)
def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811 def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
@ -277,29 +125,52 @@ def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
def get_int_all_sessions(key: int) -> builtins.set[int]: def get_int_all_sessions(key: int) -> builtins.set[int]:
sessions = [_SESSIONLESS_CACHE] if key & _SESSIONLESS_FLAG else _SESSIONS if key & SESSIONLESS_FLAG:
values = builtins.set() values = builtins.set()
for session in sessions: encoded = _SESSIONLESS_CACHE.get(key)
encoded = session.get(key)
if encoded is not None: if encoded is not None:
values.add(int.from_bytes(encoded, "big")) values.add(int.from_bytes(encoded, "big"))
return values return values
return _PROTOCOL_CACHE.get_int_all_sessions(key)
def is_set(key: int) -> bool: def is_set(key: int) -> bool:
if key & _SESSIONLESS_FLAG: if key & SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.is_set(key ^ _SESSIONLESS_FLAG) return _SESSIONLESS_CACHE.is_set(key ^ SESSIONLESS_FLAG)
if _active_session_idx is None: active_session = _PROTOCOL_CACHE.get_active_session()
if active_session is None:
raise InvalidSessionError raise InvalidSessionError
return _SESSIONS[_active_session_idx].is_set(key) return active_session.is_set(key)
def delete(key: int) -> None: def set(key: int, value: bytes) -> None:
if key & _SESSIONLESS_FLAG: if key & SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG) _SESSIONLESS_CACHE.set(key ^ SESSIONLESS_FLAG, value)
if _active_session_idx is None: return
active_session = _PROTOCOL_CACHE.get_active_session()
if active_session is None:
raise InvalidSessionError raise InvalidSessionError
return _SESSIONS[_active_session_idx].delete(key) active_session.set(key, value)
def set_int(key: int, value: int) -> None:
active_session = _PROTOCOL_CACHE.get_active_session()
if key & SESSIONLESS_FLAG:
length = _SESSIONLESS_CACHE.fields[key ^ SESSIONLESS_FLAG]
if active_session is None:
raise InvalidSessionError
else:
length = active_session.fields[key]
encoded = value.to_bytes(length, "big")
# Ensure that the value fits within the length. Micropython's int.to_bytes()
# doesn't raise OverflowError.
assert int.from_bytes(encoded, "big") == value
set(key, encoded)
if TYPE_CHECKING: if TYPE_CHECKING:
@ -336,15 +207,3 @@ def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
return wrapper return wrapper
return decorator return decorator
def clear_all() -> None:
global _active_session_idx
global autolock_last_touch
_active_session_idx = None
_SESSIONLESS_CACHE.clear()
for session in _SESSIONS:
session.clear()
autolock_last_touch = None

@ -0,0 +1,144 @@
import builtins
from micropython import const
from typing import TYPE_CHECKING
from trezor import utils
from storage.cache_common import DataCache, InvalidSessionError
if TYPE_CHECKING:
from typing import Sequence, TypeVar, overload
T = TypeVar("T")
_MAX_SESSIONS_COUNT = const(10)
_SESSION_ID_LENGTH = const(32)
class SessionCache(DataCache):
def __init__(self) -> None:
self.session_id = bytearray(_SESSION_ID_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
1, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
1, # APP_MONERO_LIVE_REFRESH
)
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
from trezorcrypto import random # avoid pulling in trezor.crypto
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.last_usage = 0
self.session_id[:] = b""
_SESSIONS: list[SessionCache] = []
def initialize() -> None:
global _SESSIONS
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionCache())
for session in _SESSIONS:
session.clear()
initialize()
_active_session_idx: int | None = None
_session_usage_counter = 0
def get_active_session() -> SessionCache | None:
if _active_session_idx is None:
return None
return _SESSIONS[_active_session_idx]
def start_session(received_session_id: bytes | None = None) -> bytes:
global _active_session_idx
global _session_usage_counter
if (
received_session_id is not None
and len(received_session_id) != _SESSION_ID_LENGTH
):
# Prevent the caller from setting received_session_id=b"" and finding a cleared
# session. More generally, short-circuit the session id search, because we know
# that wrong-length session ids should not be in cache.
# Reduce to "session id not provided" case because that's what we do when
# caller supplies an id that is not found.
received_session_id = None
_session_usage_counter += 1
# attempt to find specified session id
if received_session_id:
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].session_id == received_session_id:
_active_session_idx = i
_SESSIONS[i].last_usage = _session_usage_counter
return received_session_id
# allocate least recently used session
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
_active_session_idx = lru_session_idx
selected_session = _SESSIONS[lru_session_idx]
selected_session.clear()
selected_session.last_usage = _session_usage_counter
return selected_session.export_session_id()
def end_current_session() -> None:
global _active_session_idx
if _active_session_idx is None:
return
_SESSIONS[_active_session_idx].clear()
_active_session_idx = None
def get_int_all_sessions(key: int) -> builtins.set[int]:
values = builtins.set()
for session in _SESSIONS:
encoded = session.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
def clear_all() -> None:
global _active_session_idx
_active_session_idx = None
for session in _SESSIONS:
session.clear()

@ -0,0 +1,70 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor import utils
if TYPE_CHECKING:
from typing import Sequence, TypeVar, overload
T = TypeVar("T")
SESSIONLESS_FLAG = const(128)
class InvalidSessionError(Exception):
pass
class DataCache:
fields: Sequence[int]
def __init__(self) -> None:
self.data = [bytearray(f + 1) for f in self.fields]
def set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.fields))
utils.ensure(len(value) <= self.fields[key])
self.data[key][0] = 1
self.data[key][1:] = value
if TYPE_CHECKING:
@overload
def get(self, key: int) -> bytes | None:
...
@overload
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.fields))
if self.data[key][0] != 1:
return default
return bytes(self.data[key][1:])
def is_set(self, key: int) -> bool:
utils.ensure(key < len(self.fields))
return self.data[key][0] == 1
def delete(self, key: int) -> None:
utils.ensure(key < len(self.fields))
self.data[key][:] = b"\x00"
def clear(self) -> None:
for i in range(len(self.fields)):
self.delete(i)
class SessionlessCache(DataCache):
def __init__(self) -> None:
self.fields = (
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
1, # STORAGE_DEVICE_EXPERIMENTAL_FEATURES
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
8, # APP_COMMON_BUSY_DEADLINE_MS
32, # APP_MISC_COSI_NONCE
32, # APP_MISC_COSI_COMMITMENT
)
super().__init__()

@ -0,0 +1,256 @@
import builtins
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_common import DataCache, InvalidSessionError
from trezor import utils
if TYPE_CHECKING:
from typing import Sequence, TypeVar, overload
T = TypeVar("T")
# THP specific constants
_MAX_SESSIONS_COUNT = const(20)
_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5)
_THP_SESSION_STATE_LENGTH = const(1)
_SESSION_ID_LENGTH = const(4)
BROADCAST_CHANNEL_ID = const(65535)
class SessionThpCache(DataCache): # TODO implement, this is just copied SessionCache
def __init__(self) -> None:
self.session_id = bytearray(_SESSION_ID_LENGTH)
self.state = bytearray(_THP_SESSION_STATE_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
)
else:
self.fields = (
64, # APP_COMMON_SEED
2, # APP_COMMON_AUTHORIZATION_TYPE
128, # APP_COMMON_AUTHORIZATION_DATA
32, # APP_COMMON_NONCE
1, # APP_COMMON_DERIVE_CARDANO
96, # APP_CARDANO_ICARUS_SECRET
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
1, # APP_MONERO_LIVE_REFRESH
)
self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5)
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
from trezorcrypto import random # avoid pulling in trezor.crypto
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.last_usage = 0
self.session_id[:] = b""
_SESSIONS: list[SessionThpCache] = []
_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = []
def initialize() -> None:
global _SESSIONS
global _UNAUTHENTICATED_SESSIONS
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionThpCache())
for _ in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
_UNAUTHENTICATED_SESSIONS.append(SessionThpCache())
for session in _SESSIONS:
session.clear()
for session in _UNAUTHENTICATED_SESSIONS:
session.clear()
initialize()
# THP vars
_next_unauthenicated_session_index: int = 0
_is_active_session_authenticated: bool
_active_session_idx: int | None = None
_session_usage_counter = 0
# with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex)
cid_counter: int = 4659
def get_active_session_id():
if get_active_session() is None:
return None
return get_active_session().session_id
def get_active_session() -> SessionThpCache | None:
if _active_session_idx is None:
return None
if _is_active_session_authenticated:
return _SESSIONS[_active_session_idx]
return _UNAUTHENTICATED_SESSIONS[_active_session_idx]
def get_next_channel_id() -> int:
global cid_counter
while True:
cid_counter += 1
if cid_counter >= BROADCAST_CHANNEL_ID:
cid_counter = 1
if _is_cid_unique():
break
return cid_counter
def _is_cid_unique() -> bool:
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
if cid_counter == _get_cid(session):
return False
return True
def _get_cid(session: SessionThpCache) -> int:
return int.from_bytes(session.session_id[2:], "big")
def create_new_unauthenticated_session(session_id: bytearray) -> SessionThpCache:
if len(session_id) != 4:
raise ValueError("session_id must be 4 bytes long.")
global _active_session_idx
global _is_active_session_authenticated
global _next_unauthenicated_session_index
i = _next_unauthenicated_session_index
_UNAUTHENTICATED_SESSIONS[i] = SessionThpCache()
_UNAUTHENTICATED_SESSIONS[i].session_id = bytearray(session_id)
_next_unauthenicated_session_index += 1
if _next_unauthenicated_session_index >= _MAX_UNAUTHENTICATED_SESSIONS_COUNT:
_next_unauthenicated_session_index = 0
# Set session as active if and only if there is no active session
if _active_session_idx is None:
_active_session_idx = i
_is_active_session_authenticated = False
return _UNAUTHENTICATED_SESSIONS[i]
def get_unauth_session_index(unauth_session: SessionThpCache) -> int | None:
for i in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
if unauth_session == _UNAUTHENTICATED_SESSIONS[i]:
return i
return None
def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache:
global _session_usage_counter
unauth_session_idx = get_unauth_session_index(unauth_session)
if unauth_session_idx is None:
raise InvalidSessionError
# replace least recently used authenticated session by the new session
new_auth_session_index = get_least_recently_used_authetnicated_session_index()
_SESSIONS[new_auth_session_index] = _UNAUTHENTICATED_SESSIONS[unauth_session_idx]
_UNAUTHENTICATED_SESSIONS[unauth_session_idx] = None
_session_usage_counter += 1
_SESSIONS[new_auth_session_index].last_usage = _session_usage_counter
def get_least_recently_used_authetnicated_session_index() -> int:
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
return lru_session_idx
# The function start_session should not be used in production code. It is present only to assure compatibility with old tests.
def start_session(session_id: bytes) -> bytes: # TODO incomplete
global _active_session_idx
global _is_active_session_authenticated
if session_id is not None:
if get_active_session_id() == session_id:
return session_id
for index in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[index].session_id == session_id:
_active_session_idx = index
_is_active_session_authenticated = True
return session_id
for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id:
_active_session_idx = index
_is_active_session_authenticated = False
return session_id
new_session_id = b"\x00\x00" + get_next_channel_id().to_bytes(2, "big")
new_session = create_new_unauthenticated_session(new_session_id)
index = get_unauth_session_index(new_session)
_active_session_idx = index
_is_active_session_authenticated = False
return new_session_id
def start_existing_session(session_id: bytearray) -> bytes:
if session_id is None:
raise ValueError("session_id cannot be None")
if get_active_session_id() == session_id:
return session_id
for index in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[index].session_id == session_id:
_active_session_idx = index
_is_active_session_authenticated = True
return session_id
for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id:
_active_session_idx = index
_is_active_session_authenticated = False
return session_id
raise ValueError("There is no active session with provided session_id")
def end_current_session() -> None:
global _active_session_idx
active_session = get_active_session()
if active_session is None:
return
active_session.clear()
_active_session_idx = None
def get_int_all_sessions(key: int) -> builtins.set[int]:
values = builtins.set()
for session in _SESSIONS: # Should there be _SESSIONS + _UNAUTHENTICATED_SESSIONS ?
encoded = session.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
def clear_all() -> None:
global _active_session_idx
_active_session_idx = None
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
session.clear()

@ -33,6 +33,8 @@ MODEL_IS_T2B1: bool = INTERNAL_MODEL == "T2B1"
DISABLE_ANIMATION = 0 DISABLE_ANIMATION = 0
USE_THP = True # TODO move elsewhere, probably to core/embed/trezorhal/...
if __debug__: if __debug__:
if EMULATOR: if EMULATOR:
import uos import uos

@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is:
- Request / response. - Request / response.
- Protobuf-encoded, see `protobuf.py`. - Protobuf-encoded, see `protobuf.py`.
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`. - Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py` or `trezor/wire/thp_v1.py`.
- Transferred over USB interface, or UDP in case of Unix emulation. - Transferred over USB interface, or UDP in case of Unix emulation.
This module: This module:
@ -23,15 +23,17 @@ reads the message's header. When the message type is known the first handler is
""" """
from apps import workflow_handlers
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from storage.cache import InvalidSessionError from storage.cache_codec import InvalidSessionError
from trezor import log, loop, protobuf, utils, workflow from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType from trezor.enums import FailureType
from trezor.messages import Failure from trezor.messages import Failure
from trezor.wire import codec_v1, context from trezor.wire import codec_v1, context, protocol_common
from trezor.wire.errors import ActionCancelled, DataError, Error from trezor.wire.errors import ActionCancelled, DataError, Error
import trezor.enums.MessageType as MT
# Import all errors into namespace, so that `wire.Error` is available from # Import all errors into namespace, so that `wire.Error` is available from
# other packages. # other packages.
@ -88,8 +90,8 @@ if __debug__:
async def _handle_single_message( async def _handle_single_message(
ctx: context.Context, msg: codec_v1.Message, use_workflow: bool ctx: context.Context, msg: protocol_common.Message, use_workflow: bool
) -> codec_v1.Message | None: ) -> protocol_common.Message | None:
"""Handle a message that was loaded from USB by the caller. """Handle a message that was loaded from USB by the caller.
Find the appropriate handler, run it and write its result on the wire. In case Find the appropriate handler, run it and write its result on the wire. In case
@ -113,7 +115,7 @@ async def _handle_single_message(
__name__, __name__,
"%s:%x receive: <%s>", "%s:%x receive: <%s>",
ctx.iface.iface_num(), ctx.iface.iface_num(),
ctx.sid, ctx.session_id,
msg_type, msg_type,
) )
@ -143,7 +145,11 @@ async def _handle_single_message(
req_msg = wrap_protobuf_load(msg.data, req_type) req_msg = wrap_protobuf_load(msg.data, req_type)
# Create the handler task. # Create the handler task.
task = handler(req_msg) if msg.type is MT.Initialize:
# Special case for handle_initialize to have access to the verified session_id
task = handler(req_msg, ctx.session_id)
else:
task = handler(req_msg)
# Run the workflow task. Workflow can do more on-the-wire # Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a # communication inside, but it should eventually return a
@ -201,7 +207,7 @@ async def handle_session(
ctx_buffer = WIRE_BUFFER ctx_buffer = WIRE_BUFFER
ctx = context.Context(iface, session_id, ctx_buffer) ctx = context.Context(iface, session_id, ctx_buffer)
next_msg: codec_v1.Message | None = None next_msg: protocol_common.Message | None = None
if __debug__ and is_debug_session: if __debug__ and is_debug_session:
import apps.debug import apps.debug
@ -218,7 +224,7 @@ async def handle_session(
# wait for a new one coming from the wire. # wait for a new one coming from the wire.
try: try:
msg = await ctx.read_from_wire() msg = await ctx.read_from_wire()
except codec_v1.CodecError as exc: except protocol_common.WireError as exc:
if __debug__: if __debug__:
log.exception(__name__, exc) log.exception(__name__, exc)
await ctx.write(failure(exc)) await ctx.write(failure(exc))
@ -229,6 +235,9 @@ async def handle_session(
msg = next_msg msg = next_msg
next_msg = None next_msg = None
# Set ctx.session_id to the value msg.session_id
ctx.session_id = msg.session_id
try: try:
next_msg = await _handle_single_message( next_msg = await _handle_single_message(
ctx, msg, use_workflow=not is_debug_session ctx, msg, use_workflow=not is_debug_session

@ -3,6 +3,7 @@ from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import io, loop, utils from trezor import io, loop, utils
from trezor.wire.protocol_common import Message, WireError
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface from trezorio import WireInterface
@ -18,16 +19,10 @@ _REP_CONT_DATA = const(1) # offset of data in the continuation report
SESSION_ID = const(0) SESSION_ID = const(0)
class CodecError(Exception): class CodecError(WireError):
pass pass
class Message:
def __init__(self, mtype: int, mdata: bytes) -> None:
self.type = mtype
self.data = mdata
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
read = loop.wait(iface.iface_num() | io.POLL_READ) read = loop.wait(iface.iface_num() | io.POLL_READ)

@ -17,7 +17,8 @@ from typing import TYPE_CHECKING
from trezor import log, loop, protobuf from trezor import log, loop, protobuf
from . import codec_v1 from .protocol import WireProtocol
from .protocol_common import Message
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface from trezorio import WireInterface
@ -48,7 +49,7 @@ class UnexpectedMessage(Exception):
should be aborted and a new one started as if `msg` was the first message. should be aborted and a new one started as if `msg` was the first message.
""" """
def __init__(self, msg: codec_v1.Message) -> None: def __init__(self, msg: Message) -> None:
super().__init__() super().__init__()
self.msg = msg self.msg = msg
@ -60,14 +61,14 @@ class Context:
(i.e., wire, debug, single BT connection, etc.) (i.e., wire, debug, single BT connection, etc.)
""" """
def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None: def __init__(self, iface: WireInterface, buffer: bytearray) -> None:
self.iface = iface self.iface = iface
self.sid = sid
self.buffer = buffer self.buffer = buffer
self.session_id: bytearray | None = None
def read_from_wire(self) -> Awaitable[codec_v1.Message]: def read_from_wire(self) -> Awaitable[Message]:
"""Read a whole message from the wire without parsing it.""" """Read a whole message from the wire without parsing it."""
return codec_v1.read_message(self.iface, self.buffer) return WireProtocol.read_message(self.iface, self.buffer)
if TYPE_CHECKING: if TYPE_CHECKING:
@ -97,7 +98,7 @@ class Context:
__name__, __name__,
"%s:%x expect: %s", "%s:%x expect: %s",
self.iface.iface_num(), self.iface.iface_num(),
self.sid, self.session_id,
expected_type.MESSAGE_NAME if expected_type else expected_types, expected_type.MESSAGE_NAME if expected_type else expected_types,
) )
@ -109,6 +110,9 @@ class Context:
if msg.type not in expected_types: if msg.type not in expected_types:
raise UnexpectedMessage(msg) raise UnexpectedMessage(msg)
# TODO check that the message has the expected session_id. If not, raise UnexpectedMessageError
# (and maybe update ctx.session_id - depends on expected behaviour)
if expected_type is None: if expected_type is None:
expected_type = protobuf.type_for_wire(msg.type) expected_type = protobuf.type_for_wire(msg.type)
@ -117,7 +121,7 @@ class Context:
__name__, __name__,
"%s:%x read: %s", "%s:%x read: %s",
self.iface.iface_num(), self.iface.iface_num(),
self.sid, self.session_id,
expected_type.MESSAGE_NAME, expected_type.MESSAGE_NAME,
) )
@ -133,7 +137,7 @@ class Context:
__name__, __name__,
"%s:%x write: %s", "%s:%x write: %s",
self.iface.iface_num(), self.iface.iface_num(),
self.sid, self.session_id,
msg.MESSAGE_NAME, msg.MESSAGE_NAME,
) )
@ -151,10 +155,13 @@ class Context:
msg_size = protobuf.encode(buffer, msg) msg_size = protobuf.encode(buffer, msg)
await codec_v1.write_message( await WireProtocol.write_message(
self.iface, self.iface,
msg.MESSAGE_WIRE_TYPE, Message(
memoryview(buffer)[:msg_size], message_type=msg.MESSAGE_WIRE_TYPE,
message_data=memoryview(buffer)[:msg_size],
session_id=self.session_id,
),
) )

@ -0,0 +1,19 @@
from trezor import utils
from trezor.wire import codec_v1, thp_v1
from trezor.wire.protocol_common import Message
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezorio import WireInterface
class WireProtocol:
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
if utils.USE_THP:
return thp_v1.read_message(iface, buffer)
return codec_v1.read_message(iface, buffer)
async def write_message(iface: WireInterface, message: Message) -> None:
if utils.USE_THP:
return thp_v1.write_to_wire(iface, message)
return codec_v1.write_message(iface, message.type, message.data)

@ -0,0 +1,14 @@
class Message:
def __init__(
self,
message_type: int,
message_data: bytes,
session_id: bytearray | None = None,
) -> None:
self.type = message_type
self.data = message_data
self.session_id = session_id
class WireError(Exception):
pass

@ -0,0 +1,166 @@
import ustruct
from storage import cache_thp as storage_thp_cache
from storage.cache_thp import SessionThpCache, BROADCAST_CHANNEL_ID
from trezor import io
from trezor.wire.protocol_common import WireError
from typing import TYPE_CHECKING
from ubinascii import hexlify
if TYPE_CHECKING:
from enum import IntEnum
else:
IntEnum = object
class ThpError(WireError):
pass
class WorkflowState(IntEnum):
NOT_STARTED = 0
PENDING = 1
FINISHED = 2
class Workflow:
id: int
workflow_state: WorkflowState
class SessionState(IntEnum):
UNALLOCATED = 0
INITIALIZED = 1 # do not change, is denoted as constant in storage.cache _THP_SESSION_STATE_INITIALIZED = 1
PAIRED = 2
UNPAIRED = 3
PAIRING = 4
APP_TRAFFIC = 5
def get_workflow() -> Workflow:
pass # TODO
def print_all_test_sessions() -> None:
for session in storage_thp_cache._UNAUTHENTICATED_SESSIONS:
if session is None:
print("none")
else:
print(hexlify(session.session_id).decode("utf-8"), session.state)
#
def create_autenticated_session(unauthenticated_session: SessionThpCache):
storage_thp_cache.start_session() # TODO something like this but for THP
raise
def create_new_unauthenticated_session(iface: WireInterface, cid: int):
session_id = _get_id(iface, cid)
new_session = storage_thp_cache.create_new_unauthenticated_session(session_id)
set_session_state(new_session, SessionState.INITIALIZED)
def get_active_session() -> SessionThpCache | None:
return storage_thp_cache.get_active_session()
def get_session(iface: WireInterface, cid: int) -> SessionThpCache | None:
session_id = _get_id(iface, cid)
return get_session_from_id(session_id)
def get_session_from_id(session_id) -> SessionThpCache | None:
session = _get_authenticated_session_or_none(session_id)
if session is None:
session = _get_unauthenticated_session_or_none(session_id)
return session
def get_state(session: SessionThpCache) -> int:
if session is None:
return SessionState.UNALLOCATED
return _decode_session_state(session.state)
def get_cid(session: SessionThpCache) -> int:
return storage_thp_cache._get_cid(session)
def get_next_channel_id() -> int:
return storage_thp_cache.get_next_channel_id()
def sync_can_send_message(session: SessionThpCache) -> bool:
return session.sync & 0x80 == 0x80
def sync_get_receive_expected_bit(session: SessionThpCache) -> int:
return (session.sync & 0x40) >> 6
def sync_get_send_bit(session: SessionThpCache) -> int:
return (session.sync & 0x20) >> 5
def sync_set_can_send_message(session: SessionThpCache, can_send: bool) -> None:
session.sync &= 0x7F
if can_send:
session.sync |= 0x80
def sync_set_receive_expected_bit(session: SessionThpCache, bit: int) -> None:
if bit != 0 and bit != 1:
raise ThpError("Unexpected receive sync bit")
# set second bit to "bit" value
session.sync &= 0xBF
session.sync |= 0x40
def sync_set_send_bit_to_opposite(session: SessionThpCache) -> None:
_sync_set_send_bit(session=session, bit=1 - sync_get_send_bit(session))
def is_active_session(session: SessionThpCache):
if session is None:
return False
return session.session_id == storage_thp_cache.get_active_session_id()
def set_session_state(session: SessionThpCache, new_state: SessionState):
session.state = new_state.to_bytes(1, "big")
def _get_id(iface: WireInterface, cid: int) -> bytearray:
return ustruct.pack(">HH", iface.iface_num(), cid)
def _get_authenticated_session_or_none(session_id) -> SessionThpCache:
for authenticated_session in storage_thp_cache._SESSIONS:
if authenticated_session.session_id == session_id:
return authenticated_session
return None
def _get_unauthenticated_session_or_none(session_id) -> SessionThpCache:
for unauthenticated_session in storage_thp_cache._UNAUTHENTICATED_SESSIONS:
if unauthenticated_session.session_id == session_id:
return unauthenticated_session
return None
def _sync_set_send_bit(session: SessionThpCache, bit: int) -> None:
if bit != 0 and bit != 1:
raise ThpError("Unexpected send sync bit")
# set third bit to "bit" value
session.sync &= 0xDF
session.sync |= 0x20
def _decode_session_state(state: bytearray) -> int:
return ustruct.unpack("B", state)[0]
def _encode_session_state(state: SessionState) -> bytearray:
return ustruct.pack("B", state)

@ -0,0 +1,370 @@
import ustruct
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_thp import SessionThpCache
from trezor import io, loop, utils
from trezor.crypto import crc
from trezor.wire.protocol_common import Message
import trezor.wire.thp_session as THP
from trezor.wire.thp_session import (
ThpError,
SessionState,
BROADCAST_CHANNEL_ID,
)
from ubinascii import hexlify
if TYPE_CHECKING:
from trezorio import WireInterface
_MAX_PAYLOAD_LEN = const(60000)
_CHECKSUM_LENGTH = const(4)
_CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x40
_ERROR = 0x41
_CONTINUATION_PACKET = 0x80
_ACK_MESSAGE = 0x20
_HANDSHAKE_INIT = 0x00
_PLAINTEXT = 0x01
ENCRYPTED_TRANSPORT = 0x02
_ENCODED_PROTOBUF_DEVICE_PROPERTIES = (
b"\x0A\x04\x54\x33\x57\x31\x10\x05\x18\x00\x20\x01\x28\x01\x28\x02"
)
_UNALLOCATED_SESSION_ERROR = (
b"\x55\x4e\x41\x4c\x4c\x4f\x43\x41\x54\x45\x44\x5f\x53\x45\x53\x53\x49\x4f\x4e"
)
_REPORT_LENGTH = const(64)
_REPORT_INIT_DATA_OFFSET = const(5)
_REPORT_CONT_DATA_OFFSET = const(3)
class InitHeader:
format_str = ">BHH"
def __init__(self, ctrl_byte, cid, length) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.length = length
def to_bytes(self) -> bytes:
return ustruct.pack(
InitHeader.format_str, self.ctrl_byte, self.cid, self.length
)
def pack_to_buffer(self, buffer, buffer_offset=0) -> None:
ustruct.pack_into(
InitHeader.format_str,
buffer,
buffer_offset,
self.ctrl_byte,
self.cid,
self.length,
)
def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None:
ustruct.pack_into(">BH", buffer, buffer_offset, _CONTINUATION_PACKET, self.cid)
class InterruptingInitPacket:
def __init__(self, report) -> None:
self.initReport = report
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
msg = await read_message_or_init_packet(iface, buffer)
while type(msg) is not Message:
if msg is InterruptingInitPacket:
msg = await read_message_or_init_packet(iface, buffer, msg.initReport)
else:
raise ThpError("Unexpected output of read_message_or_init_packet")
return msg
async def read_message_or_init_packet(
iface: WireInterface, buffer: utils.BufferType, firstReport=None
) -> Message | InterruptingInitPacket:
while True:
# Wait for an initial report
if firstReport is None:
report = await _get_loop_wait_read(iface)
else:
report = firstReport
# Channel multiplexing
ctrl_byte, cid = ustruct.unpack(">BH", report)
if cid == BROADCAST_CHANNEL_ID:
_handle_broadcast(iface, ctrl_byte, report)
continue
# We allow for only one message to be read simultaneously. We do not
# support reading multiple messages with interleaven packets - with
# the sole exception of cid_request which can be handled independently.
if _is_ctrl_byte_continuation(ctrl_byte):
# continuation packet is not expected - ignore
continue
payload_length = ustruct.unpack(">H", report[3:])[0]
payload = _get_buffer_for_payload(payload_length, buffer)
header = InitHeader(ctrl_byte, cid, payload_length)
# buffer the received data
interruptingPacket = await _buffer_received_data(payload, header, iface, report)
if interruptingPacket is not None:
return interruptingPacket
# Check CRC
if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]):
# checksum is not valid -> ignore message
continue
session = THP.get_session(iface, cid)
session_state = THP.get_state(session)
# Handle message on unallocated channel
if session_state == SessionState.UNALLOCATED:
message = _handle_unallocated(iface, cid)
# unallocated should not return regular message, TODO, but it might change
if message is not None:
return message
continue
# Note: In the Host, the UNALLOCATED_CHANNEL error should be handled here
# Synchronization process
sync_bit = (ctrl_byte & 0x10) >> 4
# 1: Handle ACKs
if _is_ctrl_byte_ack(ctrl_byte):
_handle_received_ACK(session, sync_bit)
continue
# 2: Handle message with unexpected synchronization bit
if sync_bit != THP.sync_get_receive_expected_bit(session):
message = _handle_unexpected_sync_bit(iface, cid, sync_bit)
# unsynchronized messages should not return regular message, TODO,
# but it might change with the cancelation message
if message is not None:
return message
continue
# 3: Send ACK in response
_sendAck(iface, cid, sync_bit)
THP.sync_set_receive_expected_bit(session, 1 - sync_bit)
return _handle_allocated(ctrl_byte, session, payload)
def _get_loop_wait_read(iface: WireInterface):
return loop.wait(iface.iface_num() | io.POLL_READ)
def _get_buffer_for_payload(
payload_length: int, existing_buffer: utils.BufferType
) -> utils.BufferType:
if payload_length > _MAX_PAYLOAD_LEN:
raise ThpError("Message too large")
if payload_length > len(existing_buffer):
# allocate a new buffer to fit the message
try:
payload: utils.BufferType = bytearray(payload_length)
except MemoryError:
payload = bytearray(_REPORT_LENGTH)
raise ("Message too large")
return payload
# reuse a part of the supplied buffer
return memoryview(existing_buffer)[:payload_length]
async def _buffer_received_data(
payload: utils.BufferType, header: InitHeader, iface, report
) -> None | InterruptingInitPacket:
# buffer the initial data
nread = utils.memcpy(payload, 0, report, _REPORT_INIT_DATA_OFFSET)
while nread < header.length:
# wait for continuation report
report = await _get_loop_wait_read(iface)
# channel multiplexing
cont_ctrl_byte, cont_cid = ustruct.unpack(">BH", report)
# handle broadcast - allows the reading process
# to survive interruption by broadcast
if cont_cid == BROADCAST_CHANNEL_ID:
_handle_broadcast(iface, cont_ctrl_byte, report)
continue
# handle unexpected initiation packet
if not _is_ctrl_byte_continuation(cont_ctrl_byte):
# TODO possibly add timeout - allow interruption only after a long time
return InterruptingInitPacket(report)
# ignore continuation packets on different channels
if cont_cid != header.cid:
continue
# buffer the continuation data
nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET)
async def write_message(
iface: WireInterface, message: Message, is_retransmission: bool = False
) -> None:
session = THP.get_session_from_id(message.session_id)
cid = THP.get_cid(session)
payload = message.type.to_bytes(2, "big") + message.data
payload_length = len(payload)
if THP.get_state(session) == SessionState.INITIALIZED:
# write message in plaintext, TODO check if it is allowed
ctrl_byte = _PLAINTEXT
elif THP.get_state(session) == SessionState.APP_TRAFFIC:
ctrl_byte = ENCRYPTED_TRANSPORT
else:
raise ThpError("Session in not implemented state" + str(THP.get_state(session)))
if not is_retransmission:
ctrl_byte = _add_sync_bit_to_ctrl_byte(
ctrl_byte, THP.sync_get_send_bit(session)
)
THP.sync_set_send_bit_to_opposite(session)
else:
# retransmission must have the same sync bit as the previously sent message
ctrl_byte = _add_sync_bit_to_ctrl_byte(ctrl_byte, 1 - THP.sync_get_send_bit())
header = InitHeader(ctrl_byte, cid, payload_length + _CHECKSUM_LENGTH)
checksum = _compute_checksum_bytes(header.to_bytes() + payload)
await write_to_wire(iface, header, payload + checksum)
# TODO set timeout for retransmission
async def write_to_wire(
iface: WireInterface, header: InitHeader, payload: bytes
) -> None:
write = loop.wait(iface.iface_num() | io.POLL_WRITE)
payload_length = len(payload)
# prepare the report buffer with header data
report = bytearray(_REPORT_LENGTH)
header.pack_to_buffer(report)
# write initial report
nwritten = utils.memcpy(report, _REPORT_INIT_DATA_OFFSET, payload, 0)
await _write_report(write, iface, report)
# if we have more data to write, use continuation reports for it
if nwritten < payload_length:
header.pack_to_cont_buffer(report)
while nwritten < payload_length:
nwritten += utils.memcpy(report, _REPORT_CONT_DATA_OFFSET, payload, nwritten)
await _write_report(write, iface, report)
async def _write_report(write, iface: WireInterface, report: bytearray) -> None:
while True:
await write
n = iface.write(report)
if n == len(report):
return
def _handle_broadcast(iface: WireIntreface, ctrl_byte, report) -> Message | None:
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
length, nonce, checksum = ustruct.unpack(">H8s4s", report[3:])
if not _is_checksum_valid(checksum, data=report[:-4]):
raise ThpError("Checksum is not valid")
channel_id = _get_new_channel_id()
THP.create_new_unauthenticated_session(iface, channel_id)
response_data = (
ustruct.pack(">8sH", nonce, channel_id) + _ENCODED_PROTOBUF_DEVICE_PROPERTIES
)
response_header = InitHeader(
_CHANNEL_ALLOCATION_RES,
BROADCAST_CHANNEL_ID,
len(response_data) + _CHECKSUM_LENGTH,
)
checksum = _compute_checksum_bytes(response_header.to_bytes() + response_data)
write_to_wire(iface, response_header, response_data + checksum)
def _handle_allocated(ctrl_byte, session: SessionThpCache, payload) -> Message:
# Parameters session and ctrl_byte will be used to determine if the
# communication should be encrypted or not
message_type = ustruct.unpack(">H", payload)[0]
# trim message type and checksum from payload
message_data = payload[2:-_CHECKSUM_LENGTH]
return Message(message_type, message_data, session.session_id)
def _handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None:
# No ACKs expected
if THP.sync_can_send_message(session):
return
# ACK has incorrect sync bit
if THP.sync_get_send_bit(session) != sync_bit:
return
# ACK is expected and it has correct sync bit
THP.sync_set_can_send_message(session, True)
async def _handle_unallocated(iface, cid) -> Message | None:
data = _UNALLOCATED_SESSION_ERROR
header = InitHeader(_ERROR, cid, len(data) + _CHECKSUM_LENGTH)
checksum = _compute_checksum_bytes(header.to_bytes() + data)
write_to_wire(iface, header, data + checksum)
async def _sendAck(iface: WireInterface, cid: int, ack_bit: int) -> None:
ctrl_byte = _add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
header = InitHeader(ctrl_byte, cid, _CHECKSUM_LENGTH)
checksum = _compute_checksum_bytes(header.to_bytes())
write_to_wire(iface, header, checksum)
def _handle_unexpected_sync_bit(
iface: WireInterface, cid: int, sync_bit: int
) -> Message | None:
_sendAck(iface, cid, sync_bit)
# TODO handle cancelation messages and messages on allocated channels without synchronization
# (some such messages might be handled in the classical "allocated" way, if the sync bit is right)
def _get_new_channel_id() -> int:
return THP.get_next_channel_id()
def _is_checksum_valid(checksum: bytearray, data: bytearray) -> bool:
data_checksum = _compute_checksum_bytes(data)
return checksum == data_checksum
def _is_ctrl_byte_continuation(ctrl_byte) -> bool:
return ctrl_byte & 0x80 == _CONTINUATION_PACKET
def _is_ctrl_byte_ack(ctrl_byte) -> bool:
return ctrl_byte & 0x20 == _ACK_MESSAGE
def _add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit):
if sync_bit == 0:
return ctrl_byte & 0xEF
if sync_bit == 1:
return ctrl_byte | 0x10
raise ThpError("Unexpected synchronization bit")
def _compute_checksum_bytes(data: bytearray):
return crc.crc32(data).to_bytes(4, "big")

@ -1,17 +1,26 @@
from common import * # isort:skip from common import * # isort:skip
from mock_storage import mock_storage from mock_storage import mock_storage
from storage import cache
from trezor.messages import EndSession, Initialize from storage import cache, cache_codec, cache_thp
from storage.cache_common import InvalidSessionError
from trezor import utils
from trezor.messages import Initialize
from trezor.messages import EndSession
from apps.base import handle_EndSession, handle_Initialize from apps.base import handle_EndSession, handle_Initialize
KEY = 0 KEY = 0
if utils.USE_THP:
_PROTOCOL_CACHE = cache_thp
else:
_PROTOCOL_CACHE = cache_codec
# Function moved from cache.py, as it was not used there # Function moved from cache.py, as it was not used there
def is_session_started() -> bool: def is_session_started() -> bool:
return cache._active_session_idx is not None return _PROTOCOL_CACHE.get_active_session() is not None
class TestStorageCache(unittest.TestCase): class TestStorageCache(unittest.TestCase):
@ -25,9 +34,9 @@ class TestStorageCache(unittest.TestCase):
self.assertNotEqual(session_id_a, session_id_b) self.assertNotEqual(session_id_a, session_id_b)
cache.clear_all() cache.clear_all()
with self.assertRaises(cache.InvalidSessionError): with self.assertRaises(InvalidSessionError):
cache.set(KEY, "something") cache.set(KEY, "something")
with self.assertRaises(cache.InvalidSessionError): with self.assertRaises(InvalidSessionError):
cache.get(KEY) cache.get(KEY)
def test_end_session(self): def test_end_session(self):
@ -36,7 +45,7 @@ class TestStorageCache(unittest.TestCase):
cache.set(KEY, b"A") cache.set(KEY, b"A")
cache.end_current_session() cache.end_current_session()
self.assertFalse(is_session_started()) self.assertFalse(is_session_started())
self.assertRaises(cache.InvalidSessionError, cache.get, KEY) self.assertRaises(InvalidSessionError, cache.get, KEY)
# ending an ended session should be a no-op # ending an ended session should be a no-op
cache.end_current_session() cache.end_current_session()
@ -63,7 +72,7 @@ class TestStorageCache(unittest.TestCase):
session_id = cache.start_session() session_id = cache.start_session()
self.assertEqual(cache.start_session(session_id), session_id) self.assertEqual(cache.start_session(session_id), session_id)
cache.set(KEY, b"A") cache.set(KEY, b"A")
for i in range(cache._MAX_SESSIONS_COUNT): for i in range(_PROTOCOL_CACHE._MAX_SESSIONS_COUNT):
cache.start_session() cache.start_session()
self.assertNotEqual(cache.start_session(session_id), session_id) self.assertNotEqual(cache.start_session(session_id), session_id)
self.assertIsNone(cache.get(KEY)) self.assertIsNone(cache.get(KEY))
@ -83,7 +92,7 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get(KEY), b"hello") self.assertEqual(cache.get(KEY), b"hello")
cache.clear_all() cache.clear_all()
with self.assertRaises(cache.InvalidSessionError): with self.assertRaises(InvalidSessionError):
cache.get(KEY) cache.get(KEY)
def test_get_set_int(self): def test_get_set_int(self):
@ -101,7 +110,7 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get_int(KEY), 1234) self.assertEqual(cache.get_int(KEY), 1234)
cache.clear_all() cache.clear_all()
with self.assertRaises(cache.InvalidSessionError): with self.assertRaises(InvalidSessionError):
cache.get_int(KEY) cache.get_int(KEY)
def test_delete(self): def test_delete(self):
@ -186,6 +195,9 @@ class TestStorageCache(unittest.TestCase):
@mock_storage @mock_storage
def test_Initialize(self): def test_Initialize(self):
if utils.USE_THP: # INITIALIZE SHOULD NOT BE IN THP!!! TODO
return
def call_Initialize(**kwargs): def call_Initialize(**kwargs):
msg = Initialize(**kwargs) msg = Initialize(**kwargs)
return await_result(handle_Initialize(msg)) return await_result(handle_Initialize(msg))
@ -210,7 +222,7 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get(KEY), b"hello") self.assertEqual(cache.get(KEY), b"hello")
# supplying a different session ID starts a new cache # supplying a different session ID starts a new cache
call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH) call_Initialize(session_id=b"A" * _PROTOCOL_CACHE._SESSION_ID_LENGTH)
self.assertIsNone(cache.get(KEY)) self.assertIsNone(cache.get(KEY))
# but resuming a session loads the previous one # but resuming a session loads the previous one
@ -218,13 +230,18 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get(KEY), b"hello") self.assertEqual(cache.get(KEY), b"hello")
def test_EndSession(self): def test_EndSession(self):
<<<<<<< HEAD
self.assertRaises(cache.InvalidSessionError, cache.get, KEY) self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
cache.start_session() cache.start_session()
=======
self.assertRaises(InvalidSessionError, cache.get, KEY)
session_id = cache.start_session()
>>>>>>> 8681ba167 (Basic THP functinality - not-polished prototype)
self.assertTrue(is_session_started()) self.assertTrue(is_session_started())
self.assertIsNone(cache.get(KEY)) self.assertIsNone(cache.get(KEY))
await_result(handle_EndSession(EndSession())) await_result(handle_EndSession(EndSession()))
self.assertFalse(is_session_started()) self.assertFalse(is_session_started())
self.assertRaises(cache.InvalidSessionError, cache.get, KEY) self.assertRaises(InvalidSessionError, cache.get, KEY)
if __name__ == "__main__": if __name__ == "__main__":

@ -0,0 +1,341 @@
from common import *
from ubinascii import hexlify, unhexlify
import ustruct
from trezor import io, utils
from trezor.loop import wait
from trezor.utils import chunks
from trezor.wire import thp_v1
from trezor.wire.thp_v1 import _CHECKSUM_LENGTH, BROADCAST_CHANNEL_ID
from trezor.wire.protocol_common import Message
import trezor.wire.thp_session as THP
from micropython import const
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)
MESSAGE_TYPE = 0x4242
MESSAGE_TYPE_BYTES = b"\x42\x42"
_MESSAGE_TYPE_LEN = 2
PLAINTEXT_0 = 0x01
PLAINTEXT_1 = 0x11
COMMON_CID = 4660
CONT = 0x80
HEADER_INIT_LENGTH = 5
HEADER_CONT_LENGTH = 3
INIT_MESSAGE_DATA_LENGTH = (
thp_v1._REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
)
def make_header(ctrl_byte, cid, length):
return ustruct.pack(">BHH", ctrl_byte, cid, length)
def make_cont_header():
return ustruct.pack(">BH", CONT, COMMON_CID)
def makeSimpleMessage(header, message_type, message_data):
return header + ustruct.pack(">H", message_type) + message_data
def makeCidRequest(header, message_data):
return header + message_data
def printBytes(a):
print(hexlify(a).decode("utf-8"))
def getPlaintext() -> bytes:
if THP.sync_get_receive_expected_bit(THP.get_active_session()) == 1:
return PLAINTEXT_1
PLAINTEXT_0
def getCid() -> int:
return THP.get_cid(THP.get_active_session())
# This test suite is an adaptation of test_trezor.wire.codec_v1
class TestWireTrezorHostProtocolV1(unittest.TestCase):
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
if not utils.USE_THP:
import storage.cache_thp # noQA:F401
def test_simple(self):
cid_req_header = make_header(
ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
)
cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3C\x6C"
cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data)
message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18)
cid_request_dummy_data_checksum = b"\x67\x8E\xAC\xE0"
message = makeSimpleMessage(
message_header,
MESSAGE_TYPE,
cid_request_dummy_data + cid_request_dummy_data_checksum,
)
buffer = bytearray(64)
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
with self.assertRaises(StopIteration) as e:
gen.send(cid_req_message)
gen.send(message)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, cid_request_dummy_data)
buffer_without_zeroes = buffer[: len(message) - 5]
message_without_header = message[5:]
# message should have been read into the buffer
self.assertEqual(buffer_without_zeroes, message_without_header)
def test_read_one_packet(self):
# zero length message - just a header
PLAINTEXT = getPlaintext()
header = make_header(
PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH
)
checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES)
message = header + MESSAGE_TYPE_BYTES + checksum
buffer = bytearray(64)
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
with self.assertRaises(StopIteration) as e:
gen.send(message)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, b"")
# message should have been read into the buffer
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + checksum + b"\x00" * 58)
def test_read_many_packets(self):
message = bytes(range(256))
header = make_header(
getPlaintext(),
COMMON_CID,
len(message) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH,
)
checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message)
# message = MESSAGE_TYPE_BYTES + message + checksum
# first packet is init header + 59 bytes of data
# other packets are cont header + 61 bytes of data
cont_header = make_cont_header()
packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
cont_header + chunk
for chunk in chunks(
message[INIT_MESSAGE_DATA_LENGTH:] + checksum,
64 - HEADER_CONT_LENGTH,
)
]
buffer = bytearray(262)
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
for packet in packets[:-1]:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet)
# last packet will stop
with self.assertRaises(StopIteration) as e:
gen.send(packets[-1])
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message)
# message should have been read into the buffer )
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + checksum)
def test_read_large_message(self):
message = b"hello world"
header = make_header(
getPlaintext(),
COMMON_CID,
_MESSAGE_TYPE_LEN + len(message) + _CHECKSUM_LENGTH,
)
packet = (
header
+ MESSAGE_TYPE_BYTES
+ message
+ thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message)
)
# make sure we fit into one packet, to make this easier
self.assertTrue(len(packet) <= thp_v1._REPORT_LENGTH)
buffer = bytearray(1)
self.assertTrue(len(buffer) <= len(packet))
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
with self.assertRaises(StopIteration) as e:
gen.send(packet)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message)
# read should have allocated its own buffer and not touch ours
self.assertEqual(buffer, b"\x00")
def test_write_one_packet(self):
message = Message(MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID))
gen = thp_v1.write_message(self.interface, message)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
with self.assertRaises(StopIteration):
gen.send(None)
header = make_header(
PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH
)
expected_message = (
header
+ MESSAGE_TYPE_BYTES
+ thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES)
+ b"\x00" * (INIT_MESSAGE_DATA_LENGTH - _CHECKSUM_LENGTH)
)
self.assertTrue(self.interface.data == [expected_message])
def test_write_multiple_packets(self):
message_payload = bytes(range(256))
message = Message(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
)
gen = thp_v1.write_message(self.interface, message)
header = make_header(
PLAINTEXT_1,
COMMON_CID,
len(message.data) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH,
)
cont_header = make_cont_header()
checksum = thp_v1._compute_checksum_bytes(
header + message.type.to_bytes(2, "big") + message.data
)
packets = [
header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH]
] + [
cont_header + chunk
for chunk in chunks(
message.data[INIT_MESSAGE_DATA_LENGTH:] + checksum,
thp_v1._REPORT_LENGTH - HEADER_CONT_LENGTH,
)
]
for _ in packets:
# we receive as many queries as there are packets
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
# the first sent None only started the generator. the len(packets)-th None
# will finish writing and raise StopIteration
with self.assertRaises(StopIteration):
gen.send(None)
# packets must be identical up to the last one
self.assertListEqual(packets[:-1], self.interface.data[:-1])
# last packet must be identical up to message length. remaining bytes in
# the 64-byte packets are garbage -- in particular, it's the bytes of the
# previous packet
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
self.assertEqual(last_packet, self.interface.data[-1])
def test_roundtrip(self):
message_payload = bytes(range(256))
message = Message(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
)
gen = thp_v1.write_message(self.interface, message)
# exhaust the iterator:
# (XXX we can only do this because the iterator is only accepting None and returns None)
for query in gen:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
buffer = bytearray(1024)
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
for packet in self.interface.data[:-1]:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet)
with self.assertRaises(StopIteration) as e:
gen.send(self.interface.data[-1])
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message.data)
def test_read_huge_packet(self):
PACKET_COUNT = 1180
# message that takes up 1 180 USB packets
message_size = (PACKET_COUNT - 1) * (
thp_v1._REPORT_LENGTH
- HEADER_CONT_LENGTH
- _CHECKSUM_LENGTH
- _MESSAGE_TYPE_LEN
) + INIT_MESSAGE_DATA_LENGTH
# ensure that a message this big won't fit into memory
# Note: this control is changed, because THP has only 2 byte length field
self.assertTrue(message_size > thp_v1._MAX_PAYLOAD_LEN)
# self.assertRaises(MemoryError, bytearray, message_size)
header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)
buffer = bytearray(65536)
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
# THP returns "Message too large" error after reading the message size,
# it is different from codec_v1 as it does not allow big enough messages
# to raise MemoryError in this test
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
with self.assertRaises(thp_v1.ThpError) as e:
query = gen.send(packet)
self.assertEqual(e.value.args[0], "Message too large")
if __name__ == "__main__":
unittest.main()
Loading…
Cancel
Save