mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-24 05:12:02 +00:00
Basic THP functinality - not-polished prototype
This commit is contained in:
parent
89fdaed31e
commit
94278d5c01
14
core/src/all_modules.py
generated
14
core/src/all_modules.py
generated
@ -47,6 +47,12 @@ storage
|
|||||||
import storage
|
import storage
|
||||||
storage.cache
|
storage.cache
|
||||||
import storage.cache
|
import storage.cache
|
||||||
|
storage.cache_codec
|
||||||
|
import storage.cache_codec
|
||||||
|
storage.cache_common
|
||||||
|
import storage.cache_common
|
||||||
|
storage.cache_thp
|
||||||
|
import storage.cache_thp
|
||||||
storage.common
|
storage.common
|
||||||
import storage.common
|
import storage.common
|
||||||
storage.debug
|
storage.debug
|
||||||
@ -195,6 +201,14 @@ trezor.wire.context
|
|||||||
import trezor.wire.context
|
import trezor.wire.context
|
||||||
trezor.wire.errors
|
trezor.wire.errors
|
||||||
import trezor.wire.errors
|
import trezor.wire.errors
|
||||||
|
trezor.wire.protocol
|
||||||
|
import trezor.wire.protocol
|
||||||
|
trezor.wire.protocol_common
|
||||||
|
import trezor.wire.protocol_common
|
||||||
|
trezor.wire.thp_session
|
||||||
|
import trezor.wire.thp_session
|
||||||
|
trezor.wire.thp_v1
|
||||||
|
import trezor.wire.thp_v1
|
||||||
trezor.workflow
|
trezor.workflow
|
||||||
import trezor.workflow
|
import trezor.workflow
|
||||||
apps
|
apps
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import storage.cache as storage_cache
|
import storage.cache as storage_cache
|
||||||
|
import storage.cache_thp as storage_thp_cache
|
||||||
import storage.device as storage_device
|
import storage.device as storage_device
|
||||||
from trezor import TR, config, utils, wire, workflow
|
from trezor import TR, config, utils, wire, workflow
|
||||||
from trezor.enums import HomescreenFormat, MessageType
|
from trezor.enums import HomescreenFormat, MessageType
|
||||||
@ -174,10 +175,21 @@ def get_features() -> Features:
|
|||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
async def handle_Initialize(msg: Initialize) -> Features:
|
# handle_Initialize should not be used with THP to start a new session
|
||||||
|
async def handle_Initialize(
|
||||||
|
msg: Initialize, message_session_id: bytearray | None = None
|
||||||
|
) -> Features:
|
||||||
|
if message_session_id is None and utils.USE_THP:
|
||||||
|
raise ValueError("With THP enabled, a session id must be provided in args")
|
||||||
|
|
||||||
|
if utils.USE_THP:
|
||||||
|
session_id = storage_thp_cache.start_existing_session(msg.session_id)
|
||||||
|
else:
|
||||||
session_id = storage_cache.start_session(msg.session_id)
|
session_id = storage_cache.start_session(msg.session_id)
|
||||||
|
|
||||||
if not utils.BITCOIN_ONLY:
|
if not utils.BITCOIN_ONLY:
|
||||||
|
# TODO this block should be changed in THP
|
||||||
|
|
||||||
derive_cardano = storage_cache.get(storage_cache.APP_COMMON_DERIVE_CARDANO)
|
derive_cardano = storage_cache.get(storage_cache.APP_COMMON_DERIVE_CARDANO)
|
||||||
have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED)
|
have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED)
|
||||||
|
|
||||||
@ -189,7 +201,7 @@ async def handle_Initialize(msg: Initialize) -> Features:
|
|||||||
# seed is already derived, and host wants to change derive_cardano setting
|
# seed is already derived, and host wants to change derive_cardano setting
|
||||||
# => create a new session
|
# => create a new session
|
||||||
storage_cache.end_current_session()
|
storage_cache.end_current_session()
|
||||||
session_id = storage_cache.start_session()
|
session_id = storage_cache.start_session() # This should not be used in THP
|
||||||
have_seed = False
|
have_seed = False
|
||||||
|
|
||||||
if not have_seed:
|
if not have_seed:
|
||||||
@ -199,7 +211,7 @@ async def handle_Initialize(msg: Initialize) -> Features:
|
|||||||
)
|
)
|
||||||
|
|
||||||
features = get_features()
|
features = get_features()
|
||||||
features.session_id = session_id
|
features.session_id = session_id # not important in THP
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,17 +4,13 @@ from micropython import const
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor import utils
|
from trezor import utils
|
||||||
|
from storage.cache_common import SESSIONLESS_FLAG, InvalidSessionError, SessionlessCache
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Sequence, TypeVar, overload
|
from typing import Sequence, TypeVar, overload
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
_MAX_SESSIONS_COUNT = const(10)
|
|
||||||
_SESSIONLESS_FLAG = const(128)
|
|
||||||
_SESSION_ID_LENGTH = const(32)
|
|
||||||
|
|
||||||
# Traditional cache keys
|
# Traditional cache keys
|
||||||
APP_COMMON_SEED = const(0)
|
APP_COMMON_SEED = const(0)
|
||||||
APP_COMMON_AUTHORIZATION_TYPE = const(1)
|
APP_COMMON_AUTHORIZATION_TYPE = const(1)
|
||||||
@ -27,14 +23,13 @@ if not utils.BITCOIN_ONLY:
|
|||||||
APP_MONERO_LIVE_REFRESH = const(7)
|
APP_MONERO_LIVE_REFRESH = const(7)
|
||||||
|
|
||||||
# Keys that are valid across sessions
|
# Keys that are valid across sessions
|
||||||
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | _SESSIONLESS_FLAG)
|
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG)
|
||||||
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | _SESSIONLESS_FLAG)
|
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG)
|
||||||
STORAGE_DEVICE_EXPERIMENTAL_FEATURES = const(2 | _SESSIONLESS_FLAG)
|
STORAGE_DEVICE_EXPERIMENTAL_FEATURES = const(2 | SESSIONLESS_FLAG)
|
||||||
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(3 | _SESSIONLESS_FLAG)
|
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(3 | SESSIONLESS_FLAG)
|
||||||
APP_COMMON_BUSY_DEADLINE_MS = const(4 | _SESSIONLESS_FLAG)
|
APP_COMMON_BUSY_DEADLINE_MS = const(4 | SESSIONLESS_FLAG)
|
||||||
APP_MISC_COSI_NONCE = const(5 | _SESSIONLESS_FLAG)
|
APP_MISC_COSI_NONCE = const(5 | SESSIONLESS_FLAG)
|
||||||
APP_MISC_COSI_COMMITMENT = const(6 | _SESSIONLESS_FLAG)
|
APP_MISC_COSI_COMMITMENT = const(6 | SESSIONLESS_FLAG)
|
||||||
|
|
||||||
|
|
||||||
# === Homescreen storage ===
|
# === Homescreen storage ===
|
||||||
# This does not logically belong to the "cache" functionality, but the cache module is
|
# This does not logically belong to the "cache" functionality, but the cache module is
|
||||||
@ -52,103 +47,6 @@ homescreen_shown: object | None = None
|
|||||||
autolock_last_touch: int | None = None
|
autolock_last_touch: int | None = None
|
||||||
|
|
||||||
|
|
||||||
class InvalidSessionError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DataCache:
|
|
||||||
fields: Sequence[int]
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.data = [bytearray(f + 1) for f in self.fields]
|
|
||||||
|
|
||||||
def set(self, key: int, value: bytes) -> None:
|
|
||||||
utils.ensure(key < len(self.fields))
|
|
||||||
utils.ensure(len(value) <= self.fields[key])
|
|
||||||
self.data[key][0] = 1
|
|
||||||
self.data[key][1:] = value
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get(self, key: int) -> bytes | None: ...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
|
|
||||||
...
|
|
||||||
|
|
||||||
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
|
||||||
utils.ensure(key < len(self.fields))
|
|
||||||
if self.data[key][0] != 1:
|
|
||||||
return default
|
|
||||||
return bytes(self.data[key][1:])
|
|
||||||
|
|
||||||
def is_set(self, key: int) -> bool:
|
|
||||||
utils.ensure(key < len(self.fields))
|
|
||||||
return self.data[key][0] == 1
|
|
||||||
|
|
||||||
def delete(self, key: int) -> None:
|
|
||||||
utils.ensure(key < len(self.fields))
|
|
||||||
self.data[key][:] = b"\x00"
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
for i in range(len(self.fields)):
|
|
||||||
self.delete(i)
|
|
||||||
|
|
||||||
|
|
||||||
class SessionCache(DataCache):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.session_id = bytearray(_SESSION_ID_LENGTH)
|
|
||||||
if utils.BITCOIN_ONLY:
|
|
||||||
self.fields = (
|
|
||||||
64, # APP_COMMON_SEED
|
|
||||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
|
||||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
|
||||||
32, # APP_COMMON_NONCE
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.fields = (
|
|
||||||
64, # APP_COMMON_SEED
|
|
||||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
|
||||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
|
||||||
32, # APP_COMMON_NONCE
|
|
||||||
1, # APP_COMMON_DERIVE_CARDANO
|
|
||||||
96, # APP_CARDANO_ICARUS_SECRET
|
|
||||||
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
|
||||||
1, # APP_MONERO_LIVE_REFRESH
|
|
||||||
)
|
|
||||||
self.last_usage = 0
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def export_session_id(self) -> bytes:
|
|
||||||
from trezorcrypto import random # avoid pulling in trezor.crypto
|
|
||||||
|
|
||||||
# generate a new session id if we don't have it yet
|
|
||||||
if not self.session_id:
|
|
||||||
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
|
|
||||||
# export it as immutable bytes
|
|
||||||
return bytes(self.session_id)
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
super().clear()
|
|
||||||
self.last_usage = 0
|
|
||||||
self.session_id[:] = b""
|
|
||||||
|
|
||||||
|
|
||||||
class SessionlessCache(DataCache):
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.fields = (
|
|
||||||
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
|
|
||||||
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
|
||||||
1, # STORAGE_DEVICE_EXPERIMENTAL_FEATURES
|
|
||||||
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
|
|
||||||
8, # APP_COMMON_BUSY_DEADLINE_MS
|
|
||||||
32, # APP_MISC_COSI_NONCE
|
|
||||||
32, # APP_MISC_COSI_COMMITMENT
|
|
||||||
)
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
|
|
||||||
# XXX
|
# XXX
|
||||||
# Allocation notes:
|
# Allocation notes:
|
||||||
# Instantiation of a DataCache subclass should make as little garbage as possible, so
|
# Instantiation of a DataCache subclass should make as little garbage as possible, so
|
||||||
@ -157,97 +55,46 @@ class SessionlessCache(DataCache):
|
|||||||
# bytearrays, then later call `clear()` on all the existing objects, which resets them
|
# bytearrays, then later call `clear()` on all the existing objects, which resets them
|
||||||
# to zero length. This is producing some trash - `b[:]` allocates a slice.
|
# to zero length. This is producing some trash - `b[:]` allocates a slice.
|
||||||
|
|
||||||
_SESSIONS: list[SessionCache] = []
|
|
||||||
for _ in range(_MAX_SESSIONS_COUNT):
|
|
||||||
_SESSIONS.append(SessionCache())
|
|
||||||
|
|
||||||
_SESSIONLESS_CACHE = SessionlessCache()
|
_SESSIONLESS_CACHE = SessionlessCache()
|
||||||
|
|
||||||
for session in _SESSIONS:
|
|
||||||
session.clear()
|
if utils.USE_THP:
|
||||||
|
from storage import cache_thp
|
||||||
|
|
||||||
|
_PROTOCOL_CACHE = cache_thp
|
||||||
|
else:
|
||||||
|
from storage import cache_codec
|
||||||
|
|
||||||
|
_PROTOCOL_CACHE = cache_codec
|
||||||
|
|
||||||
|
_PROTOCOL_CACHE.initialize()
|
||||||
_SESSIONLESS_CACHE.clear()
|
_SESSIONLESS_CACHE.clear()
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
_active_session_idx: int | None = None
|
def clear_all() -> None:
|
||||||
_session_usage_counter = 0
|
global autolock_last_touch
|
||||||
|
autolock_last_touch = None
|
||||||
|
_SESSIONLESS_CACHE.clear()
|
||||||
|
_PROTOCOL_CACHE.clear_all()
|
||||||
|
|
||||||
|
|
||||||
def start_session(received_session_id: bytes | None = None) -> bytes:
|
def start_session(received_session_id: bytes | None = None) -> bytes:
|
||||||
global _active_session_idx
|
return _PROTOCOL_CACHE.start_session(received_session_id)
|
||||||
global _session_usage_counter
|
|
||||||
|
|
||||||
if (
|
|
||||||
received_session_id is not None
|
|
||||||
and len(received_session_id) != _SESSION_ID_LENGTH
|
|
||||||
):
|
|
||||||
# Prevent the caller from setting received_session_id=b"" and finding a cleared
|
|
||||||
# session. More generally, short-circuit the session id search, because we know
|
|
||||||
# that wrong-length session ids should not be in cache.
|
|
||||||
# Reduce to "session id not provided" case because that's what we do when
|
|
||||||
# caller supplies an id that is not found.
|
|
||||||
received_session_id = None
|
|
||||||
|
|
||||||
_session_usage_counter += 1
|
|
||||||
|
|
||||||
# attempt to find specified session id
|
|
||||||
if received_session_id:
|
|
||||||
for i in range(_MAX_SESSIONS_COUNT):
|
|
||||||
if _SESSIONS[i].session_id == received_session_id:
|
|
||||||
_active_session_idx = i
|
|
||||||
_SESSIONS[i].last_usage = _session_usage_counter
|
|
||||||
return received_session_id
|
|
||||||
|
|
||||||
# allocate least recently used session
|
|
||||||
lru_counter = _session_usage_counter
|
|
||||||
lru_session_idx = 0
|
|
||||||
for i in range(_MAX_SESSIONS_COUNT):
|
|
||||||
if _SESSIONS[i].last_usage < lru_counter:
|
|
||||||
lru_counter = _SESSIONS[i].last_usage
|
|
||||||
lru_session_idx = i
|
|
||||||
|
|
||||||
_active_session_idx = lru_session_idx
|
|
||||||
selected_session = _SESSIONS[lru_session_idx]
|
|
||||||
selected_session.clear()
|
|
||||||
selected_session.last_usage = _session_usage_counter
|
|
||||||
return selected_session.export_session_id()
|
|
||||||
|
|
||||||
|
|
||||||
def end_current_session() -> None:
|
def end_current_session() -> None:
|
||||||
global _active_session_idx
|
_PROTOCOL_CACHE.end_current_session()
|
||||||
|
|
||||||
if _active_session_idx is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
_SESSIONS[_active_session_idx].clear()
|
|
||||||
_active_session_idx = None
|
|
||||||
|
|
||||||
|
|
||||||
def set(key: int, value: bytes) -> None:
|
def delete(key: int) -> None:
|
||||||
if key & _SESSIONLESS_FLAG:
|
if key & SESSIONLESS_FLAG:
|
||||||
_SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value)
|
return _SESSIONLESS_CACHE.delete(key ^ SESSIONLESS_FLAG)
|
||||||
return
|
active_session = _PROTOCOL_CACHE.get_active_session()
|
||||||
if _active_session_idx is None:
|
if active_session is None:
|
||||||
raise InvalidSessionError
|
raise InvalidSessionError
|
||||||
_SESSIONS[_active_session_idx].set(key, value)
|
return active_session.delete(key)
|
||||||
|
|
||||||
|
|
||||||
def set_int(key: int, value: int) -> None:
|
|
||||||
if key & _SESSIONLESS_FLAG:
|
|
||||||
length = _SESSIONLESS_CACHE.fields[key ^ _SESSIONLESS_FLAG]
|
|
||||||
elif _active_session_idx is None:
|
|
||||||
raise InvalidSessionError
|
|
||||||
else:
|
|
||||||
length = _SESSIONS[_active_session_idx].fields[key]
|
|
||||||
|
|
||||||
encoded = value.to_bytes(length, "big")
|
|
||||||
|
|
||||||
# Ensure that the value fits within the length. Micropython's int.to_bytes()
|
|
||||||
# doesn't raise OverflowError.
|
|
||||||
assert int.from_bytes(encoded, "big") == value
|
|
||||||
|
|
||||||
set(key, encoded)
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -261,11 +108,12 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||||
if key & _SESSIONLESS_FLAG:
|
if key & SESSIONLESS_FLAG:
|
||||||
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default)
|
return _SESSIONLESS_CACHE.get(key ^ SESSIONLESS_FLAG, default)
|
||||||
if _active_session_idx is None:
|
active_session = _PROTOCOL_CACHE.get_active_session()
|
||||||
|
if active_session is None:
|
||||||
raise InvalidSessionError
|
raise InvalidSessionError
|
||||||
return _SESSIONS[_active_session_idx].get(key, default)
|
return active_session.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
|
def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
|
||||||
@ -277,29 +125,52 @@ def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
|
|||||||
|
|
||||||
|
|
||||||
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||||
sessions = [_SESSIONLESS_CACHE] if key & _SESSIONLESS_FLAG else _SESSIONS
|
if key & SESSIONLESS_FLAG:
|
||||||
values = builtins.set()
|
values = builtins.set()
|
||||||
for session in sessions:
|
encoded = _SESSIONLESS_CACHE.get(key)
|
||||||
encoded = session.get(key)
|
|
||||||
if encoded is not None:
|
if encoded is not None:
|
||||||
values.add(int.from_bytes(encoded, "big"))
|
values.add(int.from_bytes(encoded, "big"))
|
||||||
return values
|
return values
|
||||||
|
return _PROTOCOL_CACHE.get_int_all_sessions(key)
|
||||||
|
|
||||||
|
|
||||||
def is_set(key: int) -> bool:
|
def is_set(key: int) -> bool:
|
||||||
if key & _SESSIONLESS_FLAG:
|
if key & SESSIONLESS_FLAG:
|
||||||
return _SESSIONLESS_CACHE.is_set(key ^ _SESSIONLESS_FLAG)
|
return _SESSIONLESS_CACHE.is_set(key ^ SESSIONLESS_FLAG)
|
||||||
if _active_session_idx is None:
|
active_session = _PROTOCOL_CACHE.get_active_session()
|
||||||
|
if active_session is None:
|
||||||
raise InvalidSessionError
|
raise InvalidSessionError
|
||||||
return _SESSIONS[_active_session_idx].is_set(key)
|
return active_session.is_set(key)
|
||||||
|
|
||||||
|
|
||||||
def delete(key: int) -> None:
|
def set(key: int, value: bytes) -> None:
|
||||||
if key & _SESSIONLESS_FLAG:
|
if key & SESSIONLESS_FLAG:
|
||||||
return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG)
|
_SESSIONLESS_CACHE.set(key ^ SESSIONLESS_FLAG, value)
|
||||||
if _active_session_idx is None:
|
return
|
||||||
|
active_session = _PROTOCOL_CACHE.get_active_session()
|
||||||
|
if active_session is None:
|
||||||
raise InvalidSessionError
|
raise InvalidSessionError
|
||||||
return _SESSIONS[_active_session_idx].delete(key)
|
active_session.set(key, value)
|
||||||
|
|
||||||
|
|
||||||
|
def set_int(key: int, value: int) -> None:
|
||||||
|
active_session = _PROTOCOL_CACHE.get_active_session()
|
||||||
|
|
||||||
|
if key & SESSIONLESS_FLAG:
|
||||||
|
length = _SESSIONLESS_CACHE.fields[key ^ SESSIONLESS_FLAG]
|
||||||
|
|
||||||
|
if active_session is None:
|
||||||
|
raise InvalidSessionError
|
||||||
|
else:
|
||||||
|
length = active_session.fields[key]
|
||||||
|
|
||||||
|
encoded = value.to_bytes(length, "big")
|
||||||
|
|
||||||
|
# Ensure that the value fits within the length. Micropython's int.to_bytes()
|
||||||
|
# doesn't raise OverflowError.
|
||||||
|
assert int.from_bytes(encoded, "big") == value
|
||||||
|
|
||||||
|
set(key, encoded)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -336,15 +207,3 @@ def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def clear_all() -> None:
|
|
||||||
global _active_session_idx
|
|
||||||
global autolock_last_touch
|
|
||||||
|
|
||||||
_active_session_idx = None
|
|
||||||
_SESSIONLESS_CACHE.clear()
|
|
||||||
for session in _SESSIONS:
|
|
||||||
session.clear()
|
|
||||||
|
|
||||||
autolock_last_touch = None
|
|
||||||
|
144
core/src/storage/cache_codec.py
Normal file
144
core/src/storage/cache_codec.py
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
import builtins
|
||||||
|
from micropython import const
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from trezor import utils
|
||||||
|
from storage.cache_common import DataCache, InvalidSessionError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Sequence, TypeVar, overload
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
_MAX_SESSIONS_COUNT = const(10)
|
||||||
|
_SESSION_ID_LENGTH = const(32)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionCache(DataCache):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.session_id = bytearray(_SESSION_ID_LENGTH)
|
||||||
|
if utils.BITCOIN_ONLY:
|
||||||
|
self.fields = (
|
||||||
|
64, # APP_COMMON_SEED
|
||||||
|
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||||
|
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||||
|
32, # APP_COMMON_NONCE
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.fields = (
|
||||||
|
64, # APP_COMMON_SEED
|
||||||
|
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||||
|
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||||
|
32, # APP_COMMON_NONCE
|
||||||
|
1, # APP_COMMON_DERIVE_CARDANO
|
||||||
|
96, # APP_CARDANO_ICARUS_SECRET
|
||||||
|
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||||
|
1, # APP_MONERO_LIVE_REFRESH
|
||||||
|
)
|
||||||
|
self.last_usage = 0
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def export_session_id(self) -> bytes:
|
||||||
|
from trezorcrypto import random # avoid pulling in trezor.crypto
|
||||||
|
|
||||||
|
# generate a new session id if we don't have it yet
|
||||||
|
if not self.session_id:
|
||||||
|
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
|
||||||
|
# export it as immutable bytes
|
||||||
|
return bytes(self.session_id)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
super().clear()
|
||||||
|
self.last_usage = 0
|
||||||
|
self.session_id[:] = b""
|
||||||
|
|
||||||
|
|
||||||
|
_SESSIONS: list[SessionCache] = []
|
||||||
|
|
||||||
|
|
||||||
|
def initialize() -> None:
|
||||||
|
global _SESSIONS
|
||||||
|
for _ in range(_MAX_SESSIONS_COUNT):
|
||||||
|
_SESSIONS.append(SessionCache())
|
||||||
|
|
||||||
|
for session in _SESSIONS:
|
||||||
|
session.clear()
|
||||||
|
|
||||||
|
|
||||||
|
initialize()
|
||||||
|
|
||||||
|
|
||||||
|
_active_session_idx: int | None = None
|
||||||
|
_session_usage_counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_session() -> SessionCache | None:
|
||||||
|
if _active_session_idx is None:
|
||||||
|
return None
|
||||||
|
return _SESSIONS[_active_session_idx]
|
||||||
|
|
||||||
|
|
||||||
|
def start_session(received_session_id: bytes | None = None) -> bytes:
|
||||||
|
global _active_session_idx
|
||||||
|
global _session_usage_counter
|
||||||
|
|
||||||
|
if (
|
||||||
|
received_session_id is not None
|
||||||
|
and len(received_session_id) != _SESSION_ID_LENGTH
|
||||||
|
):
|
||||||
|
# Prevent the caller from setting received_session_id=b"" and finding a cleared
|
||||||
|
# session. More generally, short-circuit the session id search, because we know
|
||||||
|
# that wrong-length session ids should not be in cache.
|
||||||
|
# Reduce to "session id not provided" case because that's what we do when
|
||||||
|
# caller supplies an id that is not found.
|
||||||
|
received_session_id = None
|
||||||
|
|
||||||
|
_session_usage_counter += 1
|
||||||
|
|
||||||
|
# attempt to find specified session id
|
||||||
|
if received_session_id:
|
||||||
|
for i in range(_MAX_SESSIONS_COUNT):
|
||||||
|
if _SESSIONS[i].session_id == received_session_id:
|
||||||
|
_active_session_idx = i
|
||||||
|
_SESSIONS[i].last_usage = _session_usage_counter
|
||||||
|
return received_session_id
|
||||||
|
|
||||||
|
# allocate least recently used session
|
||||||
|
lru_counter = _session_usage_counter
|
||||||
|
lru_session_idx = 0
|
||||||
|
for i in range(_MAX_SESSIONS_COUNT):
|
||||||
|
if _SESSIONS[i].last_usage < lru_counter:
|
||||||
|
lru_counter = _SESSIONS[i].last_usage
|
||||||
|
lru_session_idx = i
|
||||||
|
|
||||||
|
_active_session_idx = lru_session_idx
|
||||||
|
selected_session = _SESSIONS[lru_session_idx]
|
||||||
|
selected_session.clear()
|
||||||
|
selected_session.last_usage = _session_usage_counter
|
||||||
|
return selected_session.export_session_id()
|
||||||
|
|
||||||
|
|
||||||
|
def end_current_session() -> None:
|
||||||
|
global _active_session_idx
|
||||||
|
|
||||||
|
if _active_session_idx is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
_SESSIONS[_active_session_idx].clear()
|
||||||
|
_active_session_idx = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||||
|
values = builtins.set()
|
||||||
|
for session in _SESSIONS:
|
||||||
|
encoded = session.get(key)
|
||||||
|
if encoded is not None:
|
||||||
|
values.add(int.from_bytes(encoded, "big"))
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
def clear_all() -> None:
|
||||||
|
global _active_session_idx
|
||||||
|
_active_session_idx = None
|
||||||
|
for session in _SESSIONS:
|
||||||
|
session.clear()
|
70
core/src/storage/cache_common.py
Normal file
70
core/src/storage/cache_common.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
from micropython import const
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from trezor import utils
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Sequence, TypeVar, overload
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
SESSIONLESS_FLAG = const(128)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSessionError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DataCache:
|
||||||
|
fields: Sequence[int]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.data = [bytearray(f + 1) for f in self.fields]
|
||||||
|
|
||||||
|
def set(self, key: int, value: bytes) -> None:
|
||||||
|
utils.ensure(key < len(self.fields))
|
||||||
|
utils.ensure(len(value) <= self.fields[key])
|
||||||
|
self.data[key][0] = 1
|
||||||
|
self.data[key][1:] = value
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get(self, key: int) -> bytes | None:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
|
||||||
|
...
|
||||||
|
|
||||||
|
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||||
|
utils.ensure(key < len(self.fields))
|
||||||
|
if self.data[key][0] != 1:
|
||||||
|
return default
|
||||||
|
return bytes(self.data[key][1:])
|
||||||
|
|
||||||
|
def is_set(self, key: int) -> bool:
|
||||||
|
utils.ensure(key < len(self.fields))
|
||||||
|
return self.data[key][0] == 1
|
||||||
|
|
||||||
|
def delete(self, key: int) -> None:
|
||||||
|
utils.ensure(key < len(self.fields))
|
||||||
|
self.data[key][:] = b"\x00"
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
for i in range(len(self.fields)):
|
||||||
|
self.delete(i)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionlessCache(DataCache):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.fields = (
|
||||||
|
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
|
||||||
|
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
||||||
|
1, # STORAGE_DEVICE_EXPERIMENTAL_FEATURES
|
||||||
|
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
|
||||||
|
8, # APP_COMMON_BUSY_DEADLINE_MS
|
||||||
|
32, # APP_MISC_COSI_NONCE
|
||||||
|
32, # APP_MISC_COSI_COMMITMENT
|
||||||
|
)
|
||||||
|
super().__init__()
|
256
core/src/storage/cache_thp.py
Normal file
256
core/src/storage/cache_thp.py
Normal file
@ -0,0 +1,256 @@
|
|||||||
|
import builtins
|
||||||
|
from micropython import const
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from storage.cache_common import DataCache, InvalidSessionError
|
||||||
|
from trezor import utils
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Sequence, TypeVar, overload
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
# THP specific constants
|
||||||
|
_MAX_SESSIONS_COUNT = const(20)
|
||||||
|
_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5)
|
||||||
|
_THP_SESSION_STATE_LENGTH = const(1)
|
||||||
|
_SESSION_ID_LENGTH = const(4)
|
||||||
|
BROADCAST_CHANNEL_ID = const(65535)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionThpCache(DataCache): # TODO implement, this is just copied SessionCache
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.session_id = bytearray(_SESSION_ID_LENGTH)
|
||||||
|
self.state = bytearray(_THP_SESSION_STATE_LENGTH)
|
||||||
|
if utils.BITCOIN_ONLY:
|
||||||
|
self.fields = (
|
||||||
|
64, # APP_COMMON_SEED
|
||||||
|
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||||
|
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||||
|
32, # APP_COMMON_NONCE
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.fields = (
|
||||||
|
64, # APP_COMMON_SEED
|
||||||
|
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||||
|
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||||
|
32, # APP_COMMON_NONCE
|
||||||
|
1, # APP_COMMON_DERIVE_CARDANO
|
||||||
|
96, # APP_CARDANO_ICARUS_SECRET
|
||||||
|
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||||
|
1, # APP_MONERO_LIVE_REFRESH
|
||||||
|
)
|
||||||
|
self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5)
|
||||||
|
self.last_usage = 0
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def export_session_id(self) -> bytes:
|
||||||
|
from trezorcrypto import random # avoid pulling in trezor.crypto
|
||||||
|
|
||||||
|
# generate a new session id if we don't have it yet
|
||||||
|
if not self.session_id:
|
||||||
|
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
|
||||||
|
# export it as immutable bytes
|
||||||
|
return bytes(self.session_id)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
super().clear()
|
||||||
|
self.last_usage = 0
|
||||||
|
self.session_id[:] = b""
|
||||||
|
|
||||||
|
|
||||||
|
_SESSIONS: list[SessionThpCache] = []
|
||||||
|
_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = []
|
||||||
|
|
||||||
|
|
||||||
|
def initialize() -> None:
|
||||||
|
global _SESSIONS
|
||||||
|
global _UNAUTHENTICATED_SESSIONS
|
||||||
|
|
||||||
|
for _ in range(_MAX_SESSIONS_COUNT):
|
||||||
|
_SESSIONS.append(SessionThpCache())
|
||||||
|
for _ in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||||
|
_UNAUTHENTICATED_SESSIONS.append(SessionThpCache())
|
||||||
|
|
||||||
|
for session in _SESSIONS:
|
||||||
|
session.clear()
|
||||||
|
for session in _UNAUTHENTICATED_SESSIONS:
|
||||||
|
session.clear()
|
||||||
|
|
||||||
|
|
||||||
|
initialize()
|
||||||
|
|
||||||
|
|
||||||
|
# THP vars
|
||||||
|
_next_unauthenicated_session_index: int = 0
|
||||||
|
_is_active_session_authenticated: bool
|
||||||
|
_active_session_idx: int | None = None
|
||||||
|
_session_usage_counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
# with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex)
|
||||||
|
cid_counter: int = 4659
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_session_id():
|
||||||
|
if get_active_session() is None:
|
||||||
|
return None
|
||||||
|
return get_active_session().session_id
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_session() -> SessionThpCache | None:
|
||||||
|
if _active_session_idx is None:
|
||||||
|
return None
|
||||||
|
if _is_active_session_authenticated:
|
||||||
|
return _SESSIONS[_active_session_idx]
|
||||||
|
return _UNAUTHENTICATED_SESSIONS[_active_session_idx]
|
||||||
|
|
||||||
|
|
||||||
|
def get_next_channel_id() -> int:
|
||||||
|
global cid_counter
|
||||||
|
while True:
|
||||||
|
cid_counter += 1
|
||||||
|
if cid_counter >= BROADCAST_CHANNEL_ID:
|
||||||
|
cid_counter = 1
|
||||||
|
if _is_cid_unique():
|
||||||
|
break
|
||||||
|
return cid_counter
|
||||||
|
|
||||||
|
|
||||||
|
def _is_cid_unique() -> bool:
|
||||||
|
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
|
||||||
|
if cid_counter == _get_cid(session):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cid(session: SessionThpCache) -> int:
|
||||||
|
return int.from_bytes(session.session_id[2:], "big")
|
||||||
|
|
||||||
|
|
||||||
|
def create_new_unauthenticated_session(session_id: bytearray) -> SessionThpCache:
|
||||||
|
if len(session_id) != 4:
|
||||||
|
raise ValueError("session_id must be 4 bytes long.")
|
||||||
|
global _active_session_idx
|
||||||
|
global _is_active_session_authenticated
|
||||||
|
global _next_unauthenicated_session_index
|
||||||
|
|
||||||
|
i = _next_unauthenicated_session_index
|
||||||
|
_UNAUTHENTICATED_SESSIONS[i] = SessionThpCache()
|
||||||
|
_UNAUTHENTICATED_SESSIONS[i].session_id = bytearray(session_id)
|
||||||
|
_next_unauthenicated_session_index += 1
|
||||||
|
if _next_unauthenicated_session_index >= _MAX_UNAUTHENTICATED_SESSIONS_COUNT:
|
||||||
|
_next_unauthenicated_session_index = 0
|
||||||
|
|
||||||
|
# Set session as active if and only if there is no active session
|
||||||
|
if _active_session_idx is None:
|
||||||
|
_active_session_idx = i
|
||||||
|
_is_active_session_authenticated = False
|
||||||
|
return _UNAUTHENTICATED_SESSIONS[i]
|
||||||
|
|
||||||
|
|
||||||
|
def get_unauth_session_index(unauth_session: SessionThpCache) -> int | None:
|
||||||
|
for i in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||||
|
if unauth_session == _UNAUTHENTICATED_SESSIONS[i]:
|
||||||
|
return i
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache:
|
||||||
|
global _session_usage_counter
|
||||||
|
|
||||||
|
unauth_session_idx = get_unauth_session_index(unauth_session)
|
||||||
|
if unauth_session_idx is None:
|
||||||
|
raise InvalidSessionError
|
||||||
|
|
||||||
|
# replace least recently used authenticated session by the new session
|
||||||
|
new_auth_session_index = get_least_recently_used_authetnicated_session_index()
|
||||||
|
|
||||||
|
_SESSIONS[new_auth_session_index] = _UNAUTHENTICATED_SESSIONS[unauth_session_idx]
|
||||||
|
_UNAUTHENTICATED_SESSIONS[unauth_session_idx] = None
|
||||||
|
|
||||||
|
_session_usage_counter += 1
|
||||||
|
_SESSIONS[new_auth_session_index].last_usage = _session_usage_counter
|
||||||
|
|
||||||
|
|
||||||
|
def get_least_recently_used_authetnicated_session_index() -> int:
|
||||||
|
lru_counter = _session_usage_counter
|
||||||
|
lru_session_idx = 0
|
||||||
|
for i in range(_MAX_SESSIONS_COUNT):
|
||||||
|
if _SESSIONS[i].last_usage < lru_counter:
|
||||||
|
lru_counter = _SESSIONS[i].last_usage
|
||||||
|
lru_session_idx = i
|
||||||
|
return lru_session_idx
|
||||||
|
|
||||||
|
|
||||||
|
# The function start_session should not be used in production code. It is present only to assure compatibility with old tests.
|
||||||
|
def start_session(session_id: bytes) -> bytes: # TODO incomplete
|
||||||
|
global _active_session_idx
|
||||||
|
global _is_active_session_authenticated
|
||||||
|
|
||||||
|
if session_id is not None:
|
||||||
|
if get_active_session_id() == session_id:
|
||||||
|
return session_id
|
||||||
|
for index in range(_MAX_SESSIONS_COUNT):
|
||||||
|
if _SESSIONS[index].session_id == session_id:
|
||||||
|
_active_session_idx = index
|
||||||
|
_is_active_session_authenticated = True
|
||||||
|
return session_id
|
||||||
|
for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||||
|
if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id:
|
||||||
|
_active_session_idx = index
|
||||||
|
_is_active_session_authenticated = False
|
||||||
|
return session_id
|
||||||
|
new_session_id = b"\x00\x00" + get_next_channel_id().to_bytes(2, "big")
|
||||||
|
|
||||||
|
new_session = create_new_unauthenticated_session(new_session_id)
|
||||||
|
|
||||||
|
index = get_unauth_session_index(new_session)
|
||||||
|
_active_session_idx = index
|
||||||
|
_is_active_session_authenticated = False
|
||||||
|
|
||||||
|
return new_session_id
|
||||||
|
|
||||||
|
|
||||||
|
def start_existing_session(session_id: bytearray) -> bytes:
|
||||||
|
if session_id is None:
|
||||||
|
raise ValueError("session_id cannot be None")
|
||||||
|
if get_active_session_id() == session_id:
|
||||||
|
return session_id
|
||||||
|
for index in range(_MAX_SESSIONS_COUNT):
|
||||||
|
if _SESSIONS[index].session_id == session_id:
|
||||||
|
_active_session_idx = index
|
||||||
|
_is_active_session_authenticated = True
|
||||||
|
return session_id
|
||||||
|
for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||||
|
if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id:
|
||||||
|
_active_session_idx = index
|
||||||
|
_is_active_session_authenticated = False
|
||||||
|
return session_id
|
||||||
|
raise ValueError("There is no active session with provided session_id")
|
||||||
|
|
||||||
|
|
||||||
|
def end_current_session() -> None:
|
||||||
|
global _active_session_idx
|
||||||
|
active_session = get_active_session()
|
||||||
|
if active_session is None:
|
||||||
|
return
|
||||||
|
active_session.clear()
|
||||||
|
_active_session_idx = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||||
|
values = builtins.set()
|
||||||
|
for session in _SESSIONS: # Should there be _SESSIONS + _UNAUTHENTICATED_SESSIONS ?
|
||||||
|
encoded = session.get(key)
|
||||||
|
if encoded is not None:
|
||||||
|
values.add(int.from_bytes(encoded, "big"))
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
def clear_all() -> None:
|
||||||
|
global _active_session_idx
|
||||||
|
_active_session_idx = None
|
||||||
|
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
|
||||||
|
session.clear()
|
@ -32,6 +32,8 @@ MODEL_IS_T2B1: bool = INTERNAL_MODEL == "T2B1"
|
|||||||
|
|
||||||
DISABLE_ANIMATION = 0
|
DISABLE_ANIMATION = 0
|
||||||
|
|
||||||
|
USE_THP = True # TODO move elsewhere, probably to core/embed/trezorhal/...
|
||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
if EMULATOR:
|
if EMULATOR:
|
||||||
import uos
|
import uos
|
||||||
|
@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is:
|
|||||||
|
|
||||||
- Request / response.
|
- Request / response.
|
||||||
- Protobuf-encoded, see `protobuf.py`.
|
- Protobuf-encoded, see `protobuf.py`.
|
||||||
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`.
|
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py` or `trezor/wire/thp_v1.py`.
|
||||||
- Transferred over USB interface, or UDP in case of Unix emulation.
|
- Transferred over USB interface, or UDP in case of Unix emulation.
|
||||||
|
|
||||||
This module:
|
This module:
|
||||||
@ -23,15 +23,17 @@ reads the message's header. When the message type is known the first handler is
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from apps import workflow_handlers
|
||||||
from micropython import const
|
from micropython import const
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from storage.cache import InvalidSessionError
|
from storage.cache_codec import InvalidSessionError
|
||||||
from trezor import log, loop, protobuf, utils, workflow
|
from trezor import log, loop, protobuf, utils, workflow
|
||||||
from trezor.enums import FailureType
|
from trezor.enums import FailureType
|
||||||
from trezor.messages import Failure
|
from trezor.messages import Failure
|
||||||
from trezor.wire import codec_v1, context
|
from trezor.wire import codec_v1, context, protocol_common
|
||||||
from trezor.wire.errors import ActionCancelled, DataError, Error
|
from trezor.wire.errors import ActionCancelled, DataError, Error
|
||||||
|
import trezor.enums.MessageType as MT
|
||||||
|
|
||||||
# Import all errors into namespace, so that `wire.Error` is available from
|
# Import all errors into namespace, so that `wire.Error` is available from
|
||||||
# other packages.
|
# other packages.
|
||||||
@ -88,8 +90,8 @@ if __debug__:
|
|||||||
|
|
||||||
|
|
||||||
async def _handle_single_message(
|
async def _handle_single_message(
|
||||||
ctx: context.Context, msg: codec_v1.Message, use_workflow: bool
|
ctx: context.Context, msg: protocol_common.Message, use_workflow: bool
|
||||||
) -> codec_v1.Message | None:
|
) -> protocol_common.Message | None:
|
||||||
"""Handle a message that was loaded from USB by the caller.
|
"""Handle a message that was loaded from USB by the caller.
|
||||||
|
|
||||||
Find the appropriate handler, run it and write its result on the wire. In case
|
Find the appropriate handler, run it and write its result on the wire. In case
|
||||||
@ -113,7 +115,7 @@ async def _handle_single_message(
|
|||||||
__name__,
|
__name__,
|
||||||
"%s:%x receive: <%s>",
|
"%s:%x receive: <%s>",
|
||||||
ctx.iface.iface_num(),
|
ctx.iface.iface_num(),
|
||||||
ctx.sid,
|
ctx.session_id,
|
||||||
msg_type,
|
msg_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -143,6 +145,10 @@ async def _handle_single_message(
|
|||||||
req_msg = wrap_protobuf_load(msg.data, req_type)
|
req_msg = wrap_protobuf_load(msg.data, req_type)
|
||||||
|
|
||||||
# Create the handler task.
|
# Create the handler task.
|
||||||
|
if msg.type is MT.Initialize:
|
||||||
|
# Special case for handle_initialize to have access to the verified session_id
|
||||||
|
task = handler(req_msg, ctx.session_id)
|
||||||
|
else:
|
||||||
task = handler(req_msg)
|
task = handler(req_msg)
|
||||||
|
|
||||||
# Run the workflow task. Workflow can do more on-the-wire
|
# Run the workflow task. Workflow can do more on-the-wire
|
||||||
@ -201,7 +207,7 @@ async def handle_session(
|
|||||||
ctx_buffer = WIRE_BUFFER
|
ctx_buffer = WIRE_BUFFER
|
||||||
|
|
||||||
ctx = context.Context(iface, session_id, ctx_buffer)
|
ctx = context.Context(iface, session_id, ctx_buffer)
|
||||||
next_msg: codec_v1.Message | None = None
|
next_msg: protocol_common.Message | None = None
|
||||||
|
|
||||||
if __debug__ and is_debug_session:
|
if __debug__ and is_debug_session:
|
||||||
import apps.debug
|
import apps.debug
|
||||||
@ -218,7 +224,7 @@ async def handle_session(
|
|||||||
# wait for a new one coming from the wire.
|
# wait for a new one coming from the wire.
|
||||||
try:
|
try:
|
||||||
msg = await ctx.read_from_wire()
|
msg = await ctx.read_from_wire()
|
||||||
except codec_v1.CodecError as exc:
|
except protocol_common.WireError as exc:
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.exception(__name__, exc)
|
log.exception(__name__, exc)
|
||||||
await ctx.write(failure(exc))
|
await ctx.write(failure(exc))
|
||||||
@ -229,6 +235,9 @@ async def handle_session(
|
|||||||
msg = next_msg
|
msg = next_msg
|
||||||
next_msg = None
|
next_msg = None
|
||||||
|
|
||||||
|
# Set ctx.session_id to the value msg.session_id
|
||||||
|
ctx.session_id = msg.session_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
next_msg = await _handle_single_message(
|
next_msg = await _handle_single_message(
|
||||||
ctx, msg, use_workflow=not is_debug_session
|
ctx, msg, use_workflow=not is_debug_session
|
||||||
|
@ -3,6 +3,7 @@ from micropython import const
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor import io, loop, utils
|
from trezor import io, loop, utils
|
||||||
|
from trezor.wire.protocol_common import Message, WireError
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
@ -18,16 +19,10 @@ _REP_CONT_DATA = const(1) # offset of data in the continuation report
|
|||||||
SESSION_ID = const(0)
|
SESSION_ID = const(0)
|
||||||
|
|
||||||
|
|
||||||
class CodecError(Exception):
|
class CodecError(WireError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Message:
|
|
||||||
def __init__(self, mtype: int, mdata: bytes) -> None:
|
|
||||||
self.type = mtype
|
|
||||||
self.data = mdata
|
|
||||||
|
|
||||||
|
|
||||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
||||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||||
|
|
||||||
|
@ -17,7 +17,8 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from trezor import log, loop, protobuf
|
from trezor import log, loop, protobuf
|
||||||
|
|
||||||
from . import codec_v1
|
from .protocol import WireProtocol
|
||||||
|
from .protocol_common import Message
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
@ -48,7 +49,7 @@ class UnexpectedMessage(Exception):
|
|||||||
should be aborted and a new one started as if `msg` was the first message.
|
should be aborted and a new one started as if `msg` was the first message.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, msg: codec_v1.Message) -> None:
|
def __init__(self, msg: Message) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
|
|
||||||
@ -60,14 +61,14 @@ class Context:
|
|||||||
(i.e., wire, debug, single BT connection, etc.)
|
(i.e., wire, debug, single BT connection, etc.)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None:
|
def __init__(self, iface: WireInterface, buffer: bytearray) -> None:
|
||||||
self.iface = iface
|
self.iface = iface
|
||||||
self.sid = sid
|
|
||||||
self.buffer = buffer
|
self.buffer = buffer
|
||||||
|
self.session_id: bytearray | None = None
|
||||||
|
|
||||||
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
|
def read_from_wire(self) -> Awaitable[Message]:
|
||||||
"""Read a whole message from the wire without parsing it."""
|
"""Read a whole message from the wire without parsing it."""
|
||||||
return codec_v1.read_message(self.iface, self.buffer)
|
return WireProtocol.read_message(self.iface, self.buffer)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
@ -97,7 +98,7 @@ class Context:
|
|||||||
__name__,
|
__name__,
|
||||||
"%s:%x expect: %s",
|
"%s:%x expect: %s",
|
||||||
self.iface.iface_num(),
|
self.iface.iface_num(),
|
||||||
self.sid,
|
self.session_id,
|
||||||
expected_type.MESSAGE_NAME if expected_type else expected_types,
|
expected_type.MESSAGE_NAME if expected_type else expected_types,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -109,6 +110,9 @@ class Context:
|
|||||||
if msg.type not in expected_types:
|
if msg.type not in expected_types:
|
||||||
raise UnexpectedMessage(msg)
|
raise UnexpectedMessage(msg)
|
||||||
|
|
||||||
|
# TODO check that the message has the expected session_id. If not, raise UnexpectedMessageError
|
||||||
|
# (and maybe update ctx.session_id - depends on expected behaviour)
|
||||||
|
|
||||||
if expected_type is None:
|
if expected_type is None:
|
||||||
expected_type = protobuf.type_for_wire(msg.type)
|
expected_type = protobuf.type_for_wire(msg.type)
|
||||||
|
|
||||||
@ -117,7 +121,7 @@ class Context:
|
|||||||
__name__,
|
__name__,
|
||||||
"%s:%x read: %s",
|
"%s:%x read: %s",
|
||||||
self.iface.iface_num(),
|
self.iface.iface_num(),
|
||||||
self.sid,
|
self.session_id,
|
||||||
expected_type.MESSAGE_NAME,
|
expected_type.MESSAGE_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -133,7 +137,7 @@ class Context:
|
|||||||
__name__,
|
__name__,
|
||||||
"%s:%x write: %s",
|
"%s:%x write: %s",
|
||||||
self.iface.iface_num(),
|
self.iface.iface_num(),
|
||||||
self.sid,
|
self.session_id,
|
||||||
msg.MESSAGE_NAME,
|
msg.MESSAGE_NAME,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -151,10 +155,13 @@ class Context:
|
|||||||
|
|
||||||
msg_size = protobuf.encode(buffer, msg)
|
msg_size = protobuf.encode(buffer, msg)
|
||||||
|
|
||||||
await codec_v1.write_message(
|
await WireProtocol.write_message(
|
||||||
self.iface,
|
self.iface,
|
||||||
msg.MESSAGE_WIRE_TYPE,
|
Message(
|
||||||
memoryview(buffer)[:msg_size],
|
message_type=msg.MESSAGE_WIRE_TYPE,
|
||||||
|
message_data=memoryview(buffer)[:msg_size],
|
||||||
|
session_id=self.session_id,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
19
core/src/trezor/wire/protocol.py
Normal file
19
core/src/trezor/wire/protocol.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from trezor import utils
|
||||||
|
from trezor.wire import codec_v1, thp_v1
|
||||||
|
from trezor.wire.protocol_common import Message
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from trezorio import WireInterface
|
||||||
|
|
||||||
|
|
||||||
|
class WireProtocol:
|
||||||
|
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
||||||
|
if utils.USE_THP:
|
||||||
|
return thp_v1.read_message(iface, buffer)
|
||||||
|
return codec_v1.read_message(iface, buffer)
|
||||||
|
|
||||||
|
async def write_message(iface: WireInterface, message: Message) -> None:
|
||||||
|
if utils.USE_THP:
|
||||||
|
return thp_v1.write_to_wire(iface, message)
|
||||||
|
return codec_v1.write_message(iface, message.type, message.data)
|
14
core/src/trezor/wire/protocol_common.py
Normal file
14
core/src/trezor/wire/protocol_common.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
class Message:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_type: int,
|
||||||
|
message_data: bytes,
|
||||||
|
session_id: bytearray | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.type = message_type
|
||||||
|
self.data = message_data
|
||||||
|
self.session_id = session_id
|
||||||
|
|
||||||
|
|
||||||
|
class WireError(Exception):
|
||||||
|
pass
|
166
core/src/trezor/wire/thp_session.py
Normal file
166
core/src/trezor/wire/thp_session.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
import ustruct
|
||||||
|
from storage import cache_thp as storage_thp_cache
|
||||||
|
from storage.cache_thp import SessionThpCache, BROADCAST_CHANNEL_ID
|
||||||
|
from trezor import io
|
||||||
|
from trezor.wire.protocol_common import WireError
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from ubinascii import hexlify
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from enum import IntEnum
|
||||||
|
else:
|
||||||
|
IntEnum = object
|
||||||
|
|
||||||
|
|
||||||
|
class ThpError(WireError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowState(IntEnum):
|
||||||
|
NOT_STARTED = 0
|
||||||
|
PENDING = 1
|
||||||
|
FINISHED = 2
|
||||||
|
|
||||||
|
|
||||||
|
class Workflow:
|
||||||
|
id: int
|
||||||
|
workflow_state: WorkflowState
|
||||||
|
|
||||||
|
|
||||||
|
class SessionState(IntEnum):
|
||||||
|
UNALLOCATED = 0
|
||||||
|
INITIALIZED = 1 # do not change, is denoted as constant in storage.cache _THP_SESSION_STATE_INITIALIZED = 1
|
||||||
|
PAIRED = 2
|
||||||
|
UNPAIRED = 3
|
||||||
|
PAIRING = 4
|
||||||
|
APP_TRAFFIC = 5
|
||||||
|
|
||||||
|
|
||||||
|
def get_workflow() -> Workflow:
|
||||||
|
pass # TODO
|
||||||
|
|
||||||
|
|
||||||
|
def print_all_test_sessions() -> None:
|
||||||
|
for session in storage_thp_cache._UNAUTHENTICATED_SESSIONS:
|
||||||
|
if session is None:
|
||||||
|
print("none")
|
||||||
|
else:
|
||||||
|
print(hexlify(session.session_id).decode("utf-8"), session.state)
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
def create_autenticated_session(unauthenticated_session: SessionThpCache):
|
||||||
|
storage_thp_cache.start_session() # TODO something like this but for THP
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def create_new_unauthenticated_session(iface: WireInterface, cid: int):
|
||||||
|
session_id = _get_id(iface, cid)
|
||||||
|
new_session = storage_thp_cache.create_new_unauthenticated_session(session_id)
|
||||||
|
set_session_state(new_session, SessionState.INITIALIZED)
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_session() -> SessionThpCache | None:
|
||||||
|
return storage_thp_cache.get_active_session()
|
||||||
|
|
||||||
|
|
||||||
|
def get_session(iface: WireInterface, cid: int) -> SessionThpCache | None:
|
||||||
|
session_id = _get_id(iface, cid)
|
||||||
|
return get_session_from_id(session_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_from_id(session_id) -> SessionThpCache | None:
|
||||||
|
session = _get_authenticated_session_or_none(session_id)
|
||||||
|
if session is None:
|
||||||
|
session = _get_unauthenticated_session_or_none(session_id)
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
def get_state(session: SessionThpCache) -> int:
|
||||||
|
if session is None:
|
||||||
|
return SessionState.UNALLOCATED
|
||||||
|
return _decode_session_state(session.state)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cid(session: SessionThpCache) -> int:
|
||||||
|
return storage_thp_cache._get_cid(session)
|
||||||
|
|
||||||
|
|
||||||
|
def get_next_channel_id() -> int:
|
||||||
|
return storage_thp_cache.get_next_channel_id()
|
||||||
|
|
||||||
|
|
||||||
|
def sync_can_send_message(session: SessionThpCache) -> bool:
|
||||||
|
return session.sync & 0x80 == 0x80
|
||||||
|
|
||||||
|
|
||||||
|
def sync_get_receive_expected_bit(session: SessionThpCache) -> int:
|
||||||
|
return (session.sync & 0x40) >> 6
|
||||||
|
|
||||||
|
|
||||||
|
def sync_get_send_bit(session: SessionThpCache) -> int:
|
||||||
|
return (session.sync & 0x20) >> 5
|
||||||
|
|
||||||
|
|
||||||
|
def sync_set_can_send_message(session: SessionThpCache, can_send: bool) -> None:
|
||||||
|
session.sync &= 0x7F
|
||||||
|
if can_send:
|
||||||
|
session.sync |= 0x80
|
||||||
|
|
||||||
|
|
||||||
|
def sync_set_receive_expected_bit(session: SessionThpCache, bit: int) -> None:
|
||||||
|
if bit != 0 and bit != 1:
|
||||||
|
raise ThpError("Unexpected receive sync bit")
|
||||||
|
|
||||||
|
# set second bit to "bit" value
|
||||||
|
session.sync &= 0xBF
|
||||||
|
session.sync |= 0x40
|
||||||
|
|
||||||
|
|
||||||
|
def sync_set_send_bit_to_opposite(session: SessionThpCache) -> None:
|
||||||
|
_sync_set_send_bit(session=session, bit=1 - sync_get_send_bit(session))
|
||||||
|
|
||||||
|
|
||||||
|
def is_active_session(session: SessionThpCache):
|
||||||
|
if session is None:
|
||||||
|
return False
|
||||||
|
return session.session_id == storage_thp_cache.get_active_session_id()
|
||||||
|
|
||||||
|
|
||||||
|
def set_session_state(session: SessionThpCache, new_state: SessionState):
|
||||||
|
session.state = new_state.to_bytes(1, "big")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_id(iface: WireInterface, cid: int) -> bytearray:
|
||||||
|
return ustruct.pack(">HH", iface.iface_num(), cid)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_authenticated_session_or_none(session_id) -> SessionThpCache:
|
||||||
|
for authenticated_session in storage_thp_cache._SESSIONS:
|
||||||
|
if authenticated_session.session_id == session_id:
|
||||||
|
return authenticated_session
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_unauthenticated_session_or_none(session_id) -> SessionThpCache:
|
||||||
|
for unauthenticated_session in storage_thp_cache._UNAUTHENTICATED_SESSIONS:
|
||||||
|
if unauthenticated_session.session_id == session_id:
|
||||||
|
return unauthenticated_session
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_set_send_bit(session: SessionThpCache, bit: int) -> None:
|
||||||
|
if bit != 0 and bit != 1:
|
||||||
|
raise ThpError("Unexpected send sync bit")
|
||||||
|
|
||||||
|
# set third bit to "bit" value
|
||||||
|
session.sync &= 0xDF
|
||||||
|
session.sync |= 0x20
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_session_state(state: bytearray) -> int:
|
||||||
|
return ustruct.unpack("B", state)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_session_state(state: SessionState) -> bytearray:
|
||||||
|
return ustruct.pack("B", state)
|
370
core/src/trezor/wire/thp_v1.py
Normal file
370
core/src/trezor/wire/thp_v1.py
Normal file
@ -0,0 +1,370 @@
|
|||||||
|
import ustruct
|
||||||
|
from micropython import const
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from storage.cache_thp import SessionThpCache
|
||||||
|
from trezor import io, loop, utils
|
||||||
|
from trezor.crypto import crc
|
||||||
|
from trezor.wire.protocol_common import Message
|
||||||
|
import trezor.wire.thp_session as THP
|
||||||
|
from trezor.wire.thp_session import (
|
||||||
|
ThpError,
|
||||||
|
SessionState,
|
||||||
|
BROADCAST_CHANNEL_ID,
|
||||||
|
)
|
||||||
|
from ubinascii import hexlify
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from trezorio import WireInterface
|
||||||
|
|
||||||
|
_MAX_PAYLOAD_LEN = const(60000)
|
||||||
|
_CHECKSUM_LENGTH = const(4)
|
||||||
|
_CHANNEL_ALLOCATION_REQ = 0x40
|
||||||
|
_CHANNEL_ALLOCATION_RES = 0x40
|
||||||
|
_ERROR = 0x41
|
||||||
|
_CONTINUATION_PACKET = 0x80
|
||||||
|
_ACK_MESSAGE = 0x20
|
||||||
|
_HANDSHAKE_INIT = 0x00
|
||||||
|
_PLAINTEXT = 0x01
|
||||||
|
ENCRYPTED_TRANSPORT = 0x02
|
||||||
|
_ENCODED_PROTOBUF_DEVICE_PROPERTIES = (
|
||||||
|
b"\x0A\x04\x54\x33\x57\x31\x10\x05\x18\x00\x20\x01\x28\x01\x28\x02"
|
||||||
|
)
|
||||||
|
_UNALLOCATED_SESSION_ERROR = (
|
||||||
|
b"\x55\x4e\x41\x4c\x4c\x4f\x43\x41\x54\x45\x44\x5f\x53\x45\x53\x53\x49\x4f\x4e"
|
||||||
|
)
|
||||||
|
|
||||||
|
_REPORT_LENGTH = const(64)
|
||||||
|
_REPORT_INIT_DATA_OFFSET = const(5)
|
||||||
|
_REPORT_CONT_DATA_OFFSET = const(3)
|
||||||
|
|
||||||
|
|
||||||
|
class InitHeader:
|
||||||
|
format_str = ">BHH"
|
||||||
|
|
||||||
|
def __init__(self, ctrl_byte, cid, length) -> None:
|
||||||
|
self.ctrl_byte = ctrl_byte
|
||||||
|
self.cid = cid
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def to_bytes(self) -> bytes:
|
||||||
|
return ustruct.pack(
|
||||||
|
InitHeader.format_str, self.ctrl_byte, self.cid, self.length
|
||||||
|
)
|
||||||
|
|
||||||
|
def pack_to_buffer(self, buffer, buffer_offset=0) -> None:
|
||||||
|
ustruct.pack_into(
|
||||||
|
InitHeader.format_str,
|
||||||
|
buffer,
|
||||||
|
buffer_offset,
|
||||||
|
self.ctrl_byte,
|
||||||
|
self.cid,
|
||||||
|
self.length,
|
||||||
|
)
|
||||||
|
|
||||||
|
def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None:
|
||||||
|
ustruct.pack_into(">BH", buffer, buffer_offset, _CONTINUATION_PACKET, self.cid)
|
||||||
|
|
||||||
|
|
||||||
|
class InterruptingInitPacket:
|
||||||
|
def __init__(self, report) -> None:
|
||||||
|
self.initReport = report
|
||||||
|
|
||||||
|
|
||||||
|
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
||||||
|
msg = await read_message_or_init_packet(iface, buffer)
|
||||||
|
while type(msg) is not Message:
|
||||||
|
if msg is InterruptingInitPacket:
|
||||||
|
msg = await read_message_or_init_packet(iface, buffer, msg.initReport)
|
||||||
|
else:
|
||||||
|
raise ThpError("Unexpected output of read_message_or_init_packet")
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
async def read_message_or_init_packet(
|
||||||
|
iface: WireInterface, buffer: utils.BufferType, firstReport=None
|
||||||
|
) -> Message | InterruptingInitPacket:
|
||||||
|
while True:
|
||||||
|
# Wait for an initial report
|
||||||
|
if firstReport is None:
|
||||||
|
report = await _get_loop_wait_read(iface)
|
||||||
|
else:
|
||||||
|
report = firstReport
|
||||||
|
|
||||||
|
# Channel multiplexing
|
||||||
|
ctrl_byte, cid = ustruct.unpack(">BH", report)
|
||||||
|
|
||||||
|
if cid == BROADCAST_CHANNEL_ID:
|
||||||
|
_handle_broadcast(iface, ctrl_byte, report)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# We allow for only one message to be read simultaneously. We do not
|
||||||
|
# support reading multiple messages with interleaven packets - with
|
||||||
|
# the sole exception of cid_request which can be handled independently.
|
||||||
|
if _is_ctrl_byte_continuation(ctrl_byte):
|
||||||
|
# continuation packet is not expected - ignore
|
||||||
|
continue
|
||||||
|
|
||||||
|
payload_length = ustruct.unpack(">H", report[3:])[0]
|
||||||
|
payload = _get_buffer_for_payload(payload_length, buffer)
|
||||||
|
header = InitHeader(ctrl_byte, cid, payload_length)
|
||||||
|
|
||||||
|
# buffer the received data
|
||||||
|
interruptingPacket = await _buffer_received_data(payload, header, iface, report)
|
||||||
|
if interruptingPacket is not None:
|
||||||
|
return interruptingPacket
|
||||||
|
|
||||||
|
# Check CRC
|
||||||
|
if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
||||||
|
# checksum is not valid -> ignore message
|
||||||
|
continue
|
||||||
|
|
||||||
|
session = THP.get_session(iface, cid)
|
||||||
|
session_state = THP.get_state(session)
|
||||||
|
|
||||||
|
# Handle message on unallocated channel
|
||||||
|
if session_state == SessionState.UNALLOCATED:
|
||||||
|
message = _handle_unallocated(iface, cid)
|
||||||
|
# unallocated should not return regular message, TODO, but it might change
|
||||||
|
if message is not None:
|
||||||
|
return message
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Note: In the Host, the UNALLOCATED_CHANNEL error should be handled here
|
||||||
|
|
||||||
|
# Synchronization process
|
||||||
|
sync_bit = (ctrl_byte & 0x10) >> 4
|
||||||
|
# 1: Handle ACKs
|
||||||
|
if _is_ctrl_byte_ack(ctrl_byte):
|
||||||
|
_handle_received_ACK(session, sync_bit)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 2: Handle message with unexpected synchronization bit
|
||||||
|
if sync_bit != THP.sync_get_receive_expected_bit(session):
|
||||||
|
message = _handle_unexpected_sync_bit(iface, cid, sync_bit)
|
||||||
|
# unsynchronized messages should not return regular message, TODO,
|
||||||
|
# but it might change with the cancelation message
|
||||||
|
if message is not None:
|
||||||
|
return message
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 3: Send ACK in response
|
||||||
|
_sendAck(iface, cid, sync_bit)
|
||||||
|
THP.sync_set_receive_expected_bit(session, 1 - sync_bit)
|
||||||
|
|
||||||
|
return _handle_allocated(ctrl_byte, session, payload)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_loop_wait_read(iface: WireInterface):
|
||||||
|
return loop.wait(iface.iface_num() | io.POLL_READ)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_buffer_for_payload(
|
||||||
|
payload_length: int, existing_buffer: utils.BufferType
|
||||||
|
) -> utils.BufferType:
|
||||||
|
if payload_length > _MAX_PAYLOAD_LEN:
|
||||||
|
raise ThpError("Message too large")
|
||||||
|
if payload_length > len(existing_buffer):
|
||||||
|
# allocate a new buffer to fit the message
|
||||||
|
try:
|
||||||
|
payload: utils.BufferType = bytearray(payload_length)
|
||||||
|
except MemoryError:
|
||||||
|
payload = bytearray(_REPORT_LENGTH)
|
||||||
|
raise ("Message too large")
|
||||||
|
return payload
|
||||||
|
|
||||||
|
# reuse a part of the supplied buffer
|
||||||
|
return memoryview(existing_buffer)[:payload_length]
|
||||||
|
|
||||||
|
|
||||||
|
async def _buffer_received_data(
|
||||||
|
payload: utils.BufferType, header: InitHeader, iface, report
|
||||||
|
) -> None | InterruptingInitPacket:
|
||||||
|
# buffer the initial data
|
||||||
|
nread = utils.memcpy(payload, 0, report, _REPORT_INIT_DATA_OFFSET)
|
||||||
|
while nread < header.length:
|
||||||
|
# wait for continuation report
|
||||||
|
report = await _get_loop_wait_read(iface)
|
||||||
|
|
||||||
|
# channel multiplexing
|
||||||
|
cont_ctrl_byte, cont_cid = ustruct.unpack(">BH", report)
|
||||||
|
|
||||||
|
# handle broadcast - allows the reading process
|
||||||
|
# to survive interruption by broadcast
|
||||||
|
if cont_cid == BROADCAST_CHANNEL_ID:
|
||||||
|
_handle_broadcast(iface, cont_ctrl_byte, report)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# handle unexpected initiation packet
|
||||||
|
if not _is_ctrl_byte_continuation(cont_ctrl_byte):
|
||||||
|
# TODO possibly add timeout - allow interruption only after a long time
|
||||||
|
return InterruptingInitPacket(report)
|
||||||
|
|
||||||
|
# ignore continuation packets on different channels
|
||||||
|
if cont_cid != header.cid:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# buffer the continuation data
|
||||||
|
nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET)
|
||||||
|
|
||||||
|
|
||||||
|
async def write_message(
|
||||||
|
iface: WireInterface, message: Message, is_retransmission: bool = False
|
||||||
|
) -> None:
|
||||||
|
session = THP.get_session_from_id(message.session_id)
|
||||||
|
cid = THP.get_cid(session)
|
||||||
|
payload = message.type.to_bytes(2, "big") + message.data
|
||||||
|
payload_length = len(payload)
|
||||||
|
|
||||||
|
if THP.get_state(session) == SessionState.INITIALIZED:
|
||||||
|
# write message in plaintext, TODO check if it is allowed
|
||||||
|
ctrl_byte = _PLAINTEXT
|
||||||
|
elif THP.get_state(session) == SessionState.APP_TRAFFIC:
|
||||||
|
ctrl_byte = ENCRYPTED_TRANSPORT
|
||||||
|
else:
|
||||||
|
raise ThpError("Session in not implemented state" + str(THP.get_state(session)))
|
||||||
|
|
||||||
|
if not is_retransmission:
|
||||||
|
ctrl_byte = _add_sync_bit_to_ctrl_byte(
|
||||||
|
ctrl_byte, THP.sync_get_send_bit(session)
|
||||||
|
)
|
||||||
|
THP.sync_set_send_bit_to_opposite(session)
|
||||||
|
else:
|
||||||
|
# retransmission must have the same sync bit as the previously sent message
|
||||||
|
ctrl_byte = _add_sync_bit_to_ctrl_byte(ctrl_byte, 1 - THP.sync_get_send_bit())
|
||||||
|
|
||||||
|
header = InitHeader(ctrl_byte, cid, payload_length + _CHECKSUM_LENGTH)
|
||||||
|
checksum = _compute_checksum_bytes(header.to_bytes() + payload)
|
||||||
|
await write_to_wire(iface, header, payload + checksum)
|
||||||
|
# TODO set timeout for retransmission
|
||||||
|
|
||||||
|
|
||||||
|
async def write_to_wire(
|
||||||
|
iface: WireInterface, header: InitHeader, payload: bytes
|
||||||
|
) -> None:
|
||||||
|
write = loop.wait(iface.iface_num() | io.POLL_WRITE)
|
||||||
|
|
||||||
|
payload_length = len(payload)
|
||||||
|
|
||||||
|
# prepare the report buffer with header data
|
||||||
|
report = bytearray(_REPORT_LENGTH)
|
||||||
|
header.pack_to_buffer(report)
|
||||||
|
|
||||||
|
# write initial report
|
||||||
|
nwritten = utils.memcpy(report, _REPORT_INIT_DATA_OFFSET, payload, 0)
|
||||||
|
await _write_report(write, iface, report)
|
||||||
|
|
||||||
|
# if we have more data to write, use continuation reports for it
|
||||||
|
if nwritten < payload_length:
|
||||||
|
header.pack_to_cont_buffer(report)
|
||||||
|
|
||||||
|
while nwritten < payload_length:
|
||||||
|
nwritten += utils.memcpy(report, _REPORT_CONT_DATA_OFFSET, payload, nwritten)
|
||||||
|
await _write_report(write, iface, report)
|
||||||
|
|
||||||
|
|
||||||
|
async def _write_report(write, iface: WireInterface, report: bytearray) -> None:
|
||||||
|
while True:
|
||||||
|
await write
|
||||||
|
n = iface.write(report)
|
||||||
|
if n == len(report):
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_broadcast(iface: WireIntreface, ctrl_byte, report) -> Message | None:
|
||||||
|
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
|
||||||
|
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
|
||||||
|
length, nonce, checksum = ustruct.unpack(">H8s4s", report[3:])
|
||||||
|
|
||||||
|
if not _is_checksum_valid(checksum, data=report[:-4]):
|
||||||
|
raise ThpError("Checksum is not valid")
|
||||||
|
|
||||||
|
channel_id = _get_new_channel_id()
|
||||||
|
THP.create_new_unauthenticated_session(iface, channel_id)
|
||||||
|
response_data = (
|
||||||
|
ustruct.pack(">8sH", nonce, channel_id) + _ENCODED_PROTOBUF_DEVICE_PROPERTIES
|
||||||
|
)
|
||||||
|
|
||||||
|
response_header = InitHeader(
|
||||||
|
_CHANNEL_ALLOCATION_RES,
|
||||||
|
BROADCAST_CHANNEL_ID,
|
||||||
|
len(response_data) + _CHECKSUM_LENGTH,
|
||||||
|
)
|
||||||
|
|
||||||
|
checksum = _compute_checksum_bytes(response_header.to_bytes() + response_data)
|
||||||
|
write_to_wire(iface, response_header, response_data + checksum)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_allocated(ctrl_byte, session: SessionThpCache, payload) -> Message:
|
||||||
|
# Parameters session and ctrl_byte will be used to determine if the
|
||||||
|
# communication should be encrypted or not
|
||||||
|
|
||||||
|
message_type = ustruct.unpack(">H", payload)[0]
|
||||||
|
|
||||||
|
# trim message type and checksum from payload
|
||||||
|
message_data = payload[2:-_CHECKSUM_LENGTH]
|
||||||
|
return Message(message_type, message_data, session.session_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None:
|
||||||
|
# No ACKs expected
|
||||||
|
if THP.sync_can_send_message(session):
|
||||||
|
return
|
||||||
|
|
||||||
|
# ACK has incorrect sync bit
|
||||||
|
if THP.sync_get_send_bit(session) != sync_bit:
|
||||||
|
return
|
||||||
|
|
||||||
|
# ACK is expected and it has correct sync bit
|
||||||
|
THP.sync_set_can_send_message(session, True)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_unallocated(iface, cid) -> Message | None:
|
||||||
|
data = _UNALLOCATED_SESSION_ERROR
|
||||||
|
header = InitHeader(_ERROR, cid, len(data) + _CHECKSUM_LENGTH)
|
||||||
|
checksum = _compute_checksum_bytes(header.to_bytes() + data)
|
||||||
|
write_to_wire(iface, header, data + checksum)
|
||||||
|
|
||||||
|
|
||||||
|
async def _sendAck(iface: WireInterface, cid: int, ack_bit: int) -> None:
|
||||||
|
ctrl_byte = _add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
|
||||||
|
header = InitHeader(ctrl_byte, cid, _CHECKSUM_LENGTH)
|
||||||
|
checksum = _compute_checksum_bytes(header.to_bytes())
|
||||||
|
write_to_wire(iface, header, checksum)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_unexpected_sync_bit(
|
||||||
|
iface: WireInterface, cid: int, sync_bit: int
|
||||||
|
) -> Message | None:
|
||||||
|
_sendAck(iface, cid, sync_bit)
|
||||||
|
|
||||||
|
# TODO handle cancelation messages and messages on allocated channels without synchronization
|
||||||
|
# (some such messages might be handled in the classical "allocated" way, if the sync bit is right)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_new_channel_id() -> int:
|
||||||
|
return THP.get_next_channel_id()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_checksum_valid(checksum: bytearray, data: bytearray) -> bool:
|
||||||
|
data_checksum = _compute_checksum_bytes(data)
|
||||||
|
return checksum == data_checksum
|
||||||
|
|
||||||
|
|
||||||
|
def _is_ctrl_byte_continuation(ctrl_byte) -> bool:
|
||||||
|
return ctrl_byte & 0x80 == _CONTINUATION_PACKET
|
||||||
|
|
||||||
|
|
||||||
|
def _is_ctrl_byte_ack(ctrl_byte) -> bool:
|
||||||
|
return ctrl_byte & 0x20 == _ACK_MESSAGE
|
||||||
|
|
||||||
|
|
||||||
|
def _add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit):
|
||||||
|
if sync_bit == 0:
|
||||||
|
return ctrl_byte & 0xEF
|
||||||
|
if sync_bit == 1:
|
||||||
|
return ctrl_byte | 0x10
|
||||||
|
raise ThpError("Unexpected synchronization bit")
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_checksum_bytes(data: bytearray):
|
||||||
|
return crc.crc32(data).to_bytes(4, "big")
|
@ -1,17 +1,26 @@
|
|||||||
from common import * # isort:skip
|
from common import * # isort:skip
|
||||||
|
|
||||||
from mock_storage import mock_storage
|
from mock_storage import mock_storage
|
||||||
from storage import cache
|
|
||||||
from trezor.messages import EndSession, Initialize
|
from storage import cache, cache_codec, cache_thp
|
||||||
|
from storage.cache_common import InvalidSessionError
|
||||||
|
from trezor import utils
|
||||||
|
from trezor.messages import Initialize
|
||||||
|
from trezor.messages import EndSession
|
||||||
|
|
||||||
from apps.base import handle_EndSession, handle_Initialize
|
from apps.base import handle_EndSession, handle_Initialize
|
||||||
|
|
||||||
KEY = 0
|
KEY = 0
|
||||||
|
|
||||||
|
if utils.USE_THP:
|
||||||
|
_PROTOCOL_CACHE = cache_thp
|
||||||
|
else:
|
||||||
|
_PROTOCOL_CACHE = cache_codec
|
||||||
|
|
||||||
|
|
||||||
# Function moved from cache.py, as it was not used there
|
# Function moved from cache.py, as it was not used there
|
||||||
def is_session_started() -> bool:
|
def is_session_started() -> bool:
|
||||||
return cache._active_session_idx is not None
|
return _PROTOCOL_CACHE.get_active_session() is not None
|
||||||
|
|
||||||
|
|
||||||
class TestStorageCache(unittest.TestCase):
|
class TestStorageCache(unittest.TestCase):
|
||||||
@ -25,9 +34,9 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
self.assertNotEqual(session_id_a, session_id_b)
|
self.assertNotEqual(session_id_a, session_id_b)
|
||||||
|
|
||||||
cache.clear_all()
|
cache.clear_all()
|
||||||
with self.assertRaises(cache.InvalidSessionError):
|
with self.assertRaises(InvalidSessionError):
|
||||||
cache.set(KEY, "something")
|
cache.set(KEY, "something")
|
||||||
with self.assertRaises(cache.InvalidSessionError):
|
with self.assertRaises(InvalidSessionError):
|
||||||
cache.get(KEY)
|
cache.get(KEY)
|
||||||
|
|
||||||
def test_end_session(self):
|
def test_end_session(self):
|
||||||
@ -36,7 +45,7 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
cache.set(KEY, b"A")
|
cache.set(KEY, b"A")
|
||||||
cache.end_current_session()
|
cache.end_current_session()
|
||||||
self.assertFalse(is_session_started())
|
self.assertFalse(is_session_started())
|
||||||
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
self.assertRaises(InvalidSessionError, cache.get, KEY)
|
||||||
|
|
||||||
# ending an ended session should be a no-op
|
# ending an ended session should be a no-op
|
||||||
cache.end_current_session()
|
cache.end_current_session()
|
||||||
@ -63,7 +72,7 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
session_id = cache.start_session()
|
session_id = cache.start_session()
|
||||||
self.assertEqual(cache.start_session(session_id), session_id)
|
self.assertEqual(cache.start_session(session_id), session_id)
|
||||||
cache.set(KEY, b"A")
|
cache.set(KEY, b"A")
|
||||||
for i in range(cache._MAX_SESSIONS_COUNT):
|
for i in range(_PROTOCOL_CACHE._MAX_SESSIONS_COUNT):
|
||||||
cache.start_session()
|
cache.start_session()
|
||||||
self.assertNotEqual(cache.start_session(session_id), session_id)
|
self.assertNotEqual(cache.start_session(session_id), session_id)
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertIsNone(cache.get(KEY))
|
||||||
@ -83,7 +92,7 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
self.assertEqual(cache.get(KEY), b"hello")
|
self.assertEqual(cache.get(KEY), b"hello")
|
||||||
|
|
||||||
cache.clear_all()
|
cache.clear_all()
|
||||||
with self.assertRaises(cache.InvalidSessionError):
|
with self.assertRaises(InvalidSessionError):
|
||||||
cache.get(KEY)
|
cache.get(KEY)
|
||||||
|
|
||||||
def test_get_set_int(self):
|
def test_get_set_int(self):
|
||||||
@ -101,7 +110,7 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
self.assertEqual(cache.get_int(KEY), 1234)
|
self.assertEqual(cache.get_int(KEY), 1234)
|
||||||
|
|
||||||
cache.clear_all()
|
cache.clear_all()
|
||||||
with self.assertRaises(cache.InvalidSessionError):
|
with self.assertRaises(InvalidSessionError):
|
||||||
cache.get_int(KEY)
|
cache.get_int(KEY)
|
||||||
|
|
||||||
def test_delete(self):
|
def test_delete(self):
|
||||||
@ -186,6 +195,9 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
|
|
||||||
@mock_storage
|
@mock_storage
|
||||||
def test_Initialize(self):
|
def test_Initialize(self):
|
||||||
|
if utils.USE_THP: # INITIALIZE SHOULD NOT BE IN THP!!! TODO
|
||||||
|
return
|
||||||
|
|
||||||
def call_Initialize(**kwargs):
|
def call_Initialize(**kwargs):
|
||||||
msg = Initialize(**kwargs)
|
msg = Initialize(**kwargs)
|
||||||
return await_result(handle_Initialize(msg))
|
return await_result(handle_Initialize(msg))
|
||||||
@ -210,7 +222,7 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
self.assertEqual(cache.get(KEY), b"hello")
|
self.assertEqual(cache.get(KEY), b"hello")
|
||||||
|
|
||||||
# supplying a different session ID starts a new cache
|
# supplying a different session ID starts a new cache
|
||||||
call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH)
|
call_Initialize(session_id=b"A" * _PROTOCOL_CACHE._SESSION_ID_LENGTH)
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertIsNone(cache.get(KEY))
|
||||||
|
|
||||||
# but resuming a session loads the previous one
|
# but resuming a session loads the previous one
|
||||||
@ -218,13 +230,18 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
self.assertEqual(cache.get(KEY), b"hello")
|
self.assertEqual(cache.get(KEY), b"hello")
|
||||||
|
|
||||||
def test_EndSession(self):
|
def test_EndSession(self):
|
||||||
|
<<<<<<< HEAD
|
||||||
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
||||||
cache.start_session()
|
cache.start_session()
|
||||||
|
=======
|
||||||
|
self.assertRaises(InvalidSessionError, cache.get, KEY)
|
||||||
|
session_id = cache.start_session()
|
||||||
|
>>>>>>> 8681ba167 (Basic THP functinality - not-polished prototype)
|
||||||
self.assertTrue(is_session_started())
|
self.assertTrue(is_session_started())
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertIsNone(cache.get(KEY))
|
||||||
await_result(handle_EndSession(EndSession()))
|
await_result(handle_EndSession(EndSession()))
|
||||||
self.assertFalse(is_session_started())
|
self.assertFalse(is_session_started())
|
||||||
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
self.assertRaises(InvalidSessionError, cache.get, KEY)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
341
core/tests/test_trezor.wire.thp_v1.py
Normal file
341
core/tests/test_trezor.wire.thp_v1.py
Normal file
@ -0,0 +1,341 @@
|
|||||||
|
from common import *
|
||||||
|
from ubinascii import hexlify, unhexlify
|
||||||
|
import ustruct
|
||||||
|
|
||||||
|
from trezor import io, utils
|
||||||
|
from trezor.loop import wait
|
||||||
|
from trezor.utils import chunks
|
||||||
|
from trezor.wire import thp_v1
|
||||||
|
from trezor.wire.thp_v1 import _CHECKSUM_LENGTH, BROADCAST_CHANNEL_ID
|
||||||
|
from trezor.wire.protocol_common import Message
|
||||||
|
import trezor.wire.thp_session as THP
|
||||||
|
|
||||||
|
from micropython import const
|
||||||
|
|
||||||
|
|
||||||
|
class MockHID:
|
||||||
|
def __init__(self, num):
|
||||||
|
self.num = num
|
||||||
|
self.data = []
|
||||||
|
|
||||||
|
def iface_num(self):
|
||||||
|
return self.num
|
||||||
|
|
||||||
|
def write(self, msg):
|
||||||
|
self.data.append(bytearray(msg))
|
||||||
|
return len(msg)
|
||||||
|
|
||||||
|
def wait_object(self, mode):
|
||||||
|
return wait(mode | self.num)
|
||||||
|
|
||||||
|
|
||||||
|
MESSAGE_TYPE = 0x4242
|
||||||
|
MESSAGE_TYPE_BYTES = b"\x42\x42"
|
||||||
|
_MESSAGE_TYPE_LEN = 2
|
||||||
|
PLAINTEXT_0 = 0x01
|
||||||
|
PLAINTEXT_1 = 0x11
|
||||||
|
COMMON_CID = 4660
|
||||||
|
CONT = 0x80
|
||||||
|
|
||||||
|
HEADER_INIT_LENGTH = 5
|
||||||
|
HEADER_CONT_LENGTH = 3
|
||||||
|
INIT_MESSAGE_DATA_LENGTH = (
|
||||||
|
thp_v1._REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_header(ctrl_byte, cid, length):
|
||||||
|
return ustruct.pack(">BHH", ctrl_byte, cid, length)
|
||||||
|
|
||||||
|
|
||||||
|
def make_cont_header():
|
||||||
|
return ustruct.pack(">BH", CONT, COMMON_CID)
|
||||||
|
|
||||||
|
|
||||||
|
def makeSimpleMessage(header, message_type, message_data):
|
||||||
|
return header + ustruct.pack(">H", message_type) + message_data
|
||||||
|
|
||||||
|
|
||||||
|
def makeCidRequest(header, message_data):
|
||||||
|
return header + message_data
|
||||||
|
|
||||||
|
|
||||||
|
def printBytes(a):
|
||||||
|
print(hexlify(a).decode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def getPlaintext() -> bytes:
|
||||||
|
if THP.sync_get_receive_expected_bit(THP.get_active_session()) == 1:
|
||||||
|
return PLAINTEXT_1
|
||||||
|
PLAINTEXT_0
|
||||||
|
|
||||||
|
|
||||||
|
def getCid() -> int:
|
||||||
|
return THP.get_cid(THP.get_active_session())
|
||||||
|
|
||||||
|
|
||||||
|
# This test suite is an adaptation of test_trezor.wire.codec_v1
|
||||||
|
class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.interface = MockHID(0xDEADBEEF)
|
||||||
|
if not utils.USE_THP:
|
||||||
|
import storage.cache_thp # noQA:F401
|
||||||
|
|
||||||
|
def test_simple(self):
|
||||||
|
cid_req_header = make_header(
|
||||||
|
ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
|
||||||
|
)
|
||||||
|
cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3C\x6C"
|
||||||
|
cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data)
|
||||||
|
|
||||||
|
message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18)
|
||||||
|
cid_request_dummy_data_checksum = b"\x67\x8E\xAC\xE0"
|
||||||
|
message = makeSimpleMessage(
|
||||||
|
message_header,
|
||||||
|
MESSAGE_TYPE,
|
||||||
|
cid_request_dummy_data + cid_request_dummy_data_checksum,
|
||||||
|
)
|
||||||
|
|
||||||
|
buffer = bytearray(64)
|
||||||
|
|
||||||
|
gen = thp_v1.read_message(self.interface, buffer)
|
||||||
|
query = gen.send(None)
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
|
with self.assertRaises(StopIteration) as e:
|
||||||
|
gen.send(cid_req_message)
|
||||||
|
gen.send(message)
|
||||||
|
|
||||||
|
# e.value is StopIteration. e.value.value is the return value of the call
|
||||||
|
result = e.value.value
|
||||||
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||||
|
self.assertEqual(result.data, cid_request_dummy_data)
|
||||||
|
|
||||||
|
buffer_without_zeroes = buffer[: len(message) - 5]
|
||||||
|
message_without_header = message[5:]
|
||||||
|
# message should have been read into the buffer
|
||||||
|
self.assertEqual(buffer_without_zeroes, message_without_header)
|
||||||
|
|
||||||
|
def test_read_one_packet(self):
|
||||||
|
# zero length message - just a header
|
||||||
|
PLAINTEXT = getPlaintext()
|
||||||
|
header = make_header(
|
||||||
|
PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH
|
||||||
|
)
|
||||||
|
checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES)
|
||||||
|
message = header + MESSAGE_TYPE_BYTES + checksum
|
||||||
|
|
||||||
|
buffer = bytearray(64)
|
||||||
|
gen = thp_v1.read_message(self.interface, buffer)
|
||||||
|
|
||||||
|
query = gen.send(None)
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
|
|
||||||
|
with self.assertRaises(StopIteration) as e:
|
||||||
|
gen.send(message)
|
||||||
|
|
||||||
|
# e.value is StopIteration. e.value.value is the return value of the call
|
||||||
|
result = e.value.value
|
||||||
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||||
|
self.assertEqual(result.data, b"")
|
||||||
|
|
||||||
|
# message should have been read into the buffer
|
||||||
|
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + checksum + b"\x00" * 58)
|
||||||
|
|
||||||
|
def test_read_many_packets(self):
|
||||||
|
message = bytes(range(256))
|
||||||
|
header = make_header(
|
||||||
|
getPlaintext(),
|
||||||
|
COMMON_CID,
|
||||||
|
len(message) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH,
|
||||||
|
)
|
||||||
|
checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message)
|
||||||
|
# message = MESSAGE_TYPE_BYTES + message + checksum
|
||||||
|
|
||||||
|
# first packet is init header + 59 bytes of data
|
||||||
|
# other packets are cont header + 61 bytes of data
|
||||||
|
cont_header = make_cont_header()
|
||||||
|
packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
|
||||||
|
cont_header + chunk
|
||||||
|
for chunk in chunks(
|
||||||
|
message[INIT_MESSAGE_DATA_LENGTH:] + checksum,
|
||||||
|
64 - HEADER_CONT_LENGTH,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
buffer = bytearray(262)
|
||||||
|
gen = thp_v1.read_message(self.interface, buffer)
|
||||||
|
query = gen.send(None)
|
||||||
|
for packet in packets[:-1]:
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
|
query = gen.send(packet)
|
||||||
|
|
||||||
|
# last packet will stop
|
||||||
|
with self.assertRaises(StopIteration) as e:
|
||||||
|
gen.send(packets[-1])
|
||||||
|
|
||||||
|
# e.value is StopIteration. e.value.value is the return value of the call
|
||||||
|
result = e.value.value
|
||||||
|
|
||||||
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||||
|
self.assertEqual(result.data, message)
|
||||||
|
|
||||||
|
# message should have been read into the buffer )
|
||||||
|
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + checksum)
|
||||||
|
|
||||||
|
def test_read_large_message(self):
|
||||||
|
message = b"hello world"
|
||||||
|
header = make_header(
|
||||||
|
getPlaintext(),
|
||||||
|
COMMON_CID,
|
||||||
|
_MESSAGE_TYPE_LEN + len(message) + _CHECKSUM_LENGTH,
|
||||||
|
)
|
||||||
|
|
||||||
|
packet = (
|
||||||
|
header
|
||||||
|
+ MESSAGE_TYPE_BYTES
|
||||||
|
+ message
|
||||||
|
+ thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message)
|
||||||
|
)
|
||||||
|
|
||||||
|
# make sure we fit into one packet, to make this easier
|
||||||
|
self.assertTrue(len(packet) <= thp_v1._REPORT_LENGTH)
|
||||||
|
|
||||||
|
buffer = bytearray(1)
|
||||||
|
self.assertTrue(len(buffer) <= len(packet))
|
||||||
|
|
||||||
|
gen = thp_v1.read_message(self.interface, buffer)
|
||||||
|
query = gen.send(None)
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
|
with self.assertRaises(StopIteration) as e:
|
||||||
|
gen.send(packet)
|
||||||
|
|
||||||
|
# e.value is StopIteration. e.value.value is the return value of the call
|
||||||
|
result = e.value.value
|
||||||
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||||
|
self.assertEqual(result.data, message)
|
||||||
|
|
||||||
|
# read should have allocated its own buffer and not touch ours
|
||||||
|
self.assertEqual(buffer, b"\x00")
|
||||||
|
|
||||||
|
def test_write_one_packet(self):
|
||||||
|
message = Message(MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID))
|
||||||
|
gen = thp_v1.write_message(self.interface, message)
|
||||||
|
|
||||||
|
query = gen.send(None)
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||||
|
with self.assertRaises(StopIteration):
|
||||||
|
gen.send(None)
|
||||||
|
|
||||||
|
header = make_header(
|
||||||
|
PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH
|
||||||
|
)
|
||||||
|
expected_message = (
|
||||||
|
header
|
||||||
|
+ MESSAGE_TYPE_BYTES
|
||||||
|
+ thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES)
|
||||||
|
+ b"\x00" * (INIT_MESSAGE_DATA_LENGTH - _CHECKSUM_LENGTH)
|
||||||
|
)
|
||||||
|
self.assertTrue(self.interface.data == [expected_message])
|
||||||
|
|
||||||
|
def test_write_multiple_packets(self):
|
||||||
|
message_payload = bytes(range(256))
|
||||||
|
message = Message(
|
||||||
|
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
||||||
|
)
|
||||||
|
gen = thp_v1.write_message(self.interface, message)
|
||||||
|
|
||||||
|
header = make_header(
|
||||||
|
PLAINTEXT_1,
|
||||||
|
COMMON_CID,
|
||||||
|
len(message.data) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH,
|
||||||
|
)
|
||||||
|
cont_header = make_cont_header()
|
||||||
|
checksum = thp_v1._compute_checksum_bytes(
|
||||||
|
header + message.type.to_bytes(2, "big") + message.data
|
||||||
|
)
|
||||||
|
packets = [
|
||||||
|
header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH]
|
||||||
|
] + [
|
||||||
|
cont_header + chunk
|
||||||
|
for chunk in chunks(
|
||||||
|
message.data[INIT_MESSAGE_DATA_LENGTH:] + checksum,
|
||||||
|
thp_v1._REPORT_LENGTH - HEADER_CONT_LENGTH,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
for _ in packets:
|
||||||
|
# we receive as many queries as there are packets
|
||||||
|
query = gen.send(None)
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||||
|
|
||||||
|
# the first sent None only started the generator. the len(packets)-th None
|
||||||
|
# will finish writing and raise StopIteration
|
||||||
|
with self.assertRaises(StopIteration):
|
||||||
|
gen.send(None)
|
||||||
|
|
||||||
|
# packets must be identical up to the last one
|
||||||
|
self.assertListEqual(packets[:-1], self.interface.data[:-1])
|
||||||
|
# last packet must be identical up to message length. remaining bytes in
|
||||||
|
# the 64-byte packets are garbage -- in particular, it's the bytes of the
|
||||||
|
# previous packet
|
||||||
|
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
|
||||||
|
self.assertEqual(last_packet, self.interface.data[-1])
|
||||||
|
|
||||||
|
def test_roundtrip(self):
|
||||||
|
message_payload = bytes(range(256))
|
||||||
|
message = Message(
|
||||||
|
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
||||||
|
)
|
||||||
|
gen = thp_v1.write_message(self.interface, message)
|
||||||
|
|
||||||
|
# exhaust the iterator:
|
||||||
|
# (XXX we can only do this because the iterator is only accepting None and returns None)
|
||||||
|
for query in gen:
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||||
|
|
||||||
|
buffer = bytearray(1024)
|
||||||
|
gen = thp_v1.read_message(self.interface, buffer)
|
||||||
|
query = gen.send(None)
|
||||||
|
for packet in self.interface.data[:-1]:
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
|
query = gen.send(packet)
|
||||||
|
|
||||||
|
with self.assertRaises(StopIteration) as e:
|
||||||
|
gen.send(self.interface.data[-1])
|
||||||
|
|
||||||
|
result = e.value.value
|
||||||
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||||
|
self.assertEqual(result.data, message.data)
|
||||||
|
|
||||||
|
def test_read_huge_packet(self):
|
||||||
|
PACKET_COUNT = 1180
|
||||||
|
# message that takes up 1 180 USB packets
|
||||||
|
message_size = (PACKET_COUNT - 1) * (
|
||||||
|
thp_v1._REPORT_LENGTH
|
||||||
|
- HEADER_CONT_LENGTH
|
||||||
|
- _CHECKSUM_LENGTH
|
||||||
|
- _MESSAGE_TYPE_LEN
|
||||||
|
) + INIT_MESSAGE_DATA_LENGTH
|
||||||
|
|
||||||
|
# ensure that a message this big won't fit into memory
|
||||||
|
# Note: this control is changed, because THP has only 2 byte length field
|
||||||
|
self.assertTrue(message_size > thp_v1._MAX_PAYLOAD_LEN)
|
||||||
|
# self.assertRaises(MemoryError, bytearray, message_size)
|
||||||
|
header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
|
||||||
|
packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)
|
||||||
|
buffer = bytearray(65536)
|
||||||
|
gen = thp_v1.read_message(self.interface, buffer)
|
||||||
|
|
||||||
|
query = gen.send(None)
|
||||||
|
|
||||||
|
# THP returns "Message too large" error after reading the message size,
|
||||||
|
# it is different from codec_v1 as it does not allow big enough messages
|
||||||
|
# to raise MemoryError in this test
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
|
with self.assertRaises(thp_v1.ThpError) as e:
|
||||||
|
query = gen.send(packet)
|
||||||
|
|
||||||
|
self.assertEqual(e.value.args[0], "Message too large")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
@ -252,16 +252,16 @@ def run_class(c, test_result):
|
|||||||
raise RuntimeError(f"{name} should not return a result.")
|
raise RuntimeError(f"{name} should not return a result.")
|
||||||
finally:
|
finally:
|
||||||
tear_down()
|
tear_down()
|
||||||
print(" ok")
|
print("\033[32mok\033[0m")
|
||||||
except SkipTest as e:
|
except SkipTest as e:
|
||||||
print(" skipped:", e.args[0])
|
print(" skipped:", e.args[0])
|
||||||
test_result.skippedNum += 1
|
test_result.skippedNum += 1
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
print(" failed")
|
print("\033[31mfailed\033[0m")
|
||||||
sys.print_exception(e)
|
sys.print_exception(e)
|
||||||
test_result.failuresNum += 1
|
test_result.failuresNum += 1
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
print(" errored:", e)
|
print("\033[31merrored:\033[0m", e)
|
||||||
sys.print_exception(e)
|
sys.print_exception(e)
|
||||||
test_result.errorsNum += 1
|
test_result.errorsNum += 1
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user