mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-23 21:02:23 +00:00
Basic THP functinality - not-polished prototype
This commit is contained in:
parent
89fdaed31e
commit
94278d5c01
14
core/src/all_modules.py
generated
14
core/src/all_modules.py
generated
@ -47,6 +47,12 @@ storage
|
||||
import storage
|
||||
storage.cache
|
||||
import storage.cache
|
||||
storage.cache_codec
|
||||
import storage.cache_codec
|
||||
storage.cache_common
|
||||
import storage.cache_common
|
||||
storage.cache_thp
|
||||
import storage.cache_thp
|
||||
storage.common
|
||||
import storage.common
|
||||
storage.debug
|
||||
@ -195,6 +201,14 @@ trezor.wire.context
|
||||
import trezor.wire.context
|
||||
trezor.wire.errors
|
||||
import trezor.wire.errors
|
||||
trezor.wire.protocol
|
||||
import trezor.wire.protocol
|
||||
trezor.wire.protocol_common
|
||||
import trezor.wire.protocol_common
|
||||
trezor.wire.thp_session
|
||||
import trezor.wire.thp_session
|
||||
trezor.wire.thp_v1
|
||||
import trezor.wire.thp_v1
|
||||
trezor.workflow
|
||||
import trezor.workflow
|
||||
apps
|
||||
|
@ -1,6 +1,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import storage.cache as storage_cache
|
||||
import storage.cache_thp as storage_thp_cache
|
||||
import storage.device as storage_device
|
||||
from trezor import TR, config, utils, wire, workflow
|
||||
from trezor.enums import HomescreenFormat, MessageType
|
||||
@ -174,10 +175,21 @@ def get_features() -> Features:
|
||||
return f
|
||||
|
||||
|
||||
async def handle_Initialize(msg: Initialize) -> Features:
|
||||
# handle_Initialize should not be used with THP to start a new session
|
||||
async def handle_Initialize(
|
||||
msg: Initialize, message_session_id: bytearray | None = None
|
||||
) -> Features:
|
||||
if message_session_id is None and utils.USE_THP:
|
||||
raise ValueError("With THP enabled, a session id must be provided in args")
|
||||
|
||||
if utils.USE_THP:
|
||||
session_id = storage_thp_cache.start_existing_session(msg.session_id)
|
||||
else:
|
||||
session_id = storage_cache.start_session(msg.session_id)
|
||||
|
||||
if not utils.BITCOIN_ONLY:
|
||||
# TODO this block should be changed in THP
|
||||
|
||||
derive_cardano = storage_cache.get(storage_cache.APP_COMMON_DERIVE_CARDANO)
|
||||
have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED)
|
||||
|
||||
@ -189,7 +201,7 @@ async def handle_Initialize(msg: Initialize) -> Features:
|
||||
# seed is already derived, and host wants to change derive_cardano setting
|
||||
# => create a new session
|
||||
storage_cache.end_current_session()
|
||||
session_id = storage_cache.start_session()
|
||||
session_id = storage_cache.start_session() # This should not be used in THP
|
||||
have_seed = False
|
||||
|
||||
if not have_seed:
|
||||
@ -199,7 +211,7 @@ async def handle_Initialize(msg: Initialize) -> Features:
|
||||
)
|
||||
|
||||
features = get_features()
|
||||
features.session_id = session_id
|
||||
features.session_id = session_id # not important in THP
|
||||
return features
|
||||
|
||||
|
||||
|
@ -4,17 +4,13 @@ from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import utils
|
||||
from storage.cache_common import SESSIONLESS_FLAG, InvalidSessionError, SessionlessCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Sequence, TypeVar, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
_MAX_SESSIONS_COUNT = const(10)
|
||||
_SESSIONLESS_FLAG = const(128)
|
||||
_SESSION_ID_LENGTH = const(32)
|
||||
|
||||
# Traditional cache keys
|
||||
APP_COMMON_SEED = const(0)
|
||||
APP_COMMON_AUTHORIZATION_TYPE = const(1)
|
||||
@ -27,14 +23,13 @@ if not utils.BITCOIN_ONLY:
|
||||
APP_MONERO_LIVE_REFRESH = const(7)
|
||||
|
||||
# Keys that are valid across sessions
|
||||
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | _SESSIONLESS_FLAG)
|
||||
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | _SESSIONLESS_FLAG)
|
||||
STORAGE_DEVICE_EXPERIMENTAL_FEATURES = const(2 | _SESSIONLESS_FLAG)
|
||||
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(3 | _SESSIONLESS_FLAG)
|
||||
APP_COMMON_BUSY_DEADLINE_MS = const(4 | _SESSIONLESS_FLAG)
|
||||
APP_MISC_COSI_NONCE = const(5 | _SESSIONLESS_FLAG)
|
||||
APP_MISC_COSI_COMMITMENT = const(6 | _SESSIONLESS_FLAG)
|
||||
|
||||
APP_COMMON_SEED_WITHOUT_PASSPHRASE = const(0 | SESSIONLESS_FLAG)
|
||||
APP_COMMON_SAFETY_CHECKS_TEMPORARY = const(1 | SESSIONLESS_FLAG)
|
||||
STORAGE_DEVICE_EXPERIMENTAL_FEATURES = const(2 | SESSIONLESS_FLAG)
|
||||
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = const(3 | SESSIONLESS_FLAG)
|
||||
APP_COMMON_BUSY_DEADLINE_MS = const(4 | SESSIONLESS_FLAG)
|
||||
APP_MISC_COSI_NONCE = const(5 | SESSIONLESS_FLAG)
|
||||
APP_MISC_COSI_COMMITMENT = const(6 | SESSIONLESS_FLAG)
|
||||
|
||||
# === Homescreen storage ===
|
||||
# This does not logically belong to the "cache" functionality, but the cache module is
|
||||
@ -52,103 +47,6 @@ homescreen_shown: object | None = None
|
||||
autolock_last_touch: int | None = None
|
||||
|
||||
|
||||
class InvalidSessionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DataCache:
|
||||
fields: Sequence[int]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = [bytearray(f + 1) for f in self.fields]
|
||||
|
||||
def set(self, key: int, value: bytes) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
utils.ensure(len(value) <= self.fields[key])
|
||||
self.data[key][0] = 1
|
||||
self.data[key][1:] = value
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
def get(self, key: int) -> bytes | None: ...
|
||||
|
||||
@overload
|
||||
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
|
||||
...
|
||||
|
||||
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
utils.ensure(key < len(self.fields))
|
||||
if self.data[key][0] != 1:
|
||||
return default
|
||||
return bytes(self.data[key][1:])
|
||||
|
||||
def is_set(self, key: int) -> bool:
|
||||
utils.ensure(key < len(self.fields))
|
||||
return self.data[key][0] == 1
|
||||
|
||||
def delete(self, key: int) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
self.data[key][:] = b"\x00"
|
||||
|
||||
def clear(self) -> None:
|
||||
for i in range(len(self.fields)):
|
||||
self.delete(i)
|
||||
|
||||
|
||||
class SessionCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.session_id = bytearray(_SESSION_ID_LENGTH)
|
||||
if utils.BITCOIN_ONLY:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
)
|
||||
else:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
1, # APP_COMMON_DERIVE_CARDANO
|
||||
96, # APP_CARDANO_ICARUS_SECRET
|
||||
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||
1, # APP_MONERO_LIVE_REFRESH
|
||||
)
|
||||
self.last_usage = 0
|
||||
super().__init__()
|
||||
|
||||
def export_session_id(self) -> bytes:
|
||||
from trezorcrypto import random # avoid pulling in trezor.crypto
|
||||
|
||||
# generate a new session id if we don't have it yet
|
||||
if not self.session_id:
|
||||
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
|
||||
# export it as immutable bytes
|
||||
return bytes(self.session_id)
|
||||
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self.last_usage = 0
|
||||
self.session_id[:] = b""
|
||||
|
||||
|
||||
class SessionlessCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
|
||||
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
||||
1, # STORAGE_DEVICE_EXPERIMENTAL_FEATURES
|
||||
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
|
||||
8, # APP_COMMON_BUSY_DEADLINE_MS
|
||||
32, # APP_MISC_COSI_NONCE
|
||||
32, # APP_MISC_COSI_COMMITMENT
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
|
||||
# XXX
|
||||
# Allocation notes:
|
||||
# Instantiation of a DataCache subclass should make as little garbage as possible, so
|
||||
@ -157,97 +55,46 @@ class SessionlessCache(DataCache):
|
||||
# bytearrays, then later call `clear()` on all the existing objects, which resets them
|
||||
# to zero length. This is producing some trash - `b[:]` allocates a slice.
|
||||
|
||||
_SESSIONS: list[SessionCache] = []
|
||||
for _ in range(_MAX_SESSIONS_COUNT):
|
||||
_SESSIONS.append(SessionCache())
|
||||
|
||||
_SESSIONLESS_CACHE = SessionlessCache()
|
||||
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
|
||||
if utils.USE_THP:
|
||||
from storage import cache_thp
|
||||
|
||||
_PROTOCOL_CACHE = cache_thp
|
||||
else:
|
||||
from storage import cache_codec
|
||||
|
||||
_PROTOCOL_CACHE = cache_codec
|
||||
|
||||
_PROTOCOL_CACHE.initialize()
|
||||
_SESSIONLESS_CACHE.clear()
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
||||
_active_session_idx: int | None = None
|
||||
_session_usage_counter = 0
|
||||
def clear_all() -> None:
|
||||
global autolock_last_touch
|
||||
autolock_last_touch = None
|
||||
_SESSIONLESS_CACHE.clear()
|
||||
_PROTOCOL_CACHE.clear_all()
|
||||
|
||||
|
||||
def start_session(received_session_id: bytes | None = None) -> bytes:
|
||||
global _active_session_idx
|
||||
global _session_usage_counter
|
||||
|
||||
if (
|
||||
received_session_id is not None
|
||||
and len(received_session_id) != _SESSION_ID_LENGTH
|
||||
):
|
||||
# Prevent the caller from setting received_session_id=b"" and finding a cleared
|
||||
# session. More generally, short-circuit the session id search, because we know
|
||||
# that wrong-length session ids should not be in cache.
|
||||
# Reduce to "session id not provided" case because that's what we do when
|
||||
# caller supplies an id that is not found.
|
||||
received_session_id = None
|
||||
|
||||
_session_usage_counter += 1
|
||||
|
||||
# attempt to find specified session id
|
||||
if received_session_id:
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].session_id == received_session_id:
|
||||
_active_session_idx = i
|
||||
_SESSIONS[i].last_usage = _session_usage_counter
|
||||
return received_session_id
|
||||
|
||||
# allocate least recently used session
|
||||
lru_counter = _session_usage_counter
|
||||
lru_session_idx = 0
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].last_usage < lru_counter:
|
||||
lru_counter = _SESSIONS[i].last_usage
|
||||
lru_session_idx = i
|
||||
|
||||
_active_session_idx = lru_session_idx
|
||||
selected_session = _SESSIONS[lru_session_idx]
|
||||
selected_session.clear()
|
||||
selected_session.last_usage = _session_usage_counter
|
||||
return selected_session.export_session_id()
|
||||
return _PROTOCOL_CACHE.start_session(received_session_id)
|
||||
|
||||
|
||||
def end_current_session() -> None:
|
||||
global _active_session_idx
|
||||
|
||||
if _active_session_idx is None:
|
||||
return
|
||||
|
||||
_SESSIONS[_active_session_idx].clear()
|
||||
_active_session_idx = None
|
||||
_PROTOCOL_CACHE.end_current_session()
|
||||
|
||||
|
||||
def set(key: int, value: bytes) -> None:
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
_SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value)
|
||||
return
|
||||
if _active_session_idx is None:
|
||||
def delete(key: int) -> None:
|
||||
if key & SESSIONLESS_FLAG:
|
||||
return _SESSIONLESS_CACHE.delete(key ^ SESSIONLESS_FLAG)
|
||||
active_session = _PROTOCOL_CACHE.get_active_session()
|
||||
if active_session is None:
|
||||
raise InvalidSessionError
|
||||
_SESSIONS[_active_session_idx].set(key, value)
|
||||
|
||||
|
||||
def set_int(key: int, value: int) -> None:
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
length = _SESSIONLESS_CACHE.fields[key ^ _SESSIONLESS_FLAG]
|
||||
elif _active_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
else:
|
||||
length = _SESSIONS[_active_session_idx].fields[key]
|
||||
|
||||
encoded = value.to_bytes(length, "big")
|
||||
|
||||
# Ensure that the value fits within the length. Micropython's int.to_bytes()
|
||||
# doesn't raise OverflowError.
|
||||
assert int.from_bytes(encoded, "big") == value
|
||||
|
||||
set(key, encoded)
|
||||
return active_session.delete(key)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -261,11 +108,12 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default)
|
||||
if _active_session_idx is None:
|
||||
if key & SESSIONLESS_FLAG:
|
||||
return _SESSIONLESS_CACHE.get(key ^ SESSIONLESS_FLAG, default)
|
||||
active_session = _PROTOCOL_CACHE.get_active_session()
|
||||
if active_session is None:
|
||||
raise InvalidSessionError
|
||||
return _SESSIONS[_active_session_idx].get(key, default)
|
||||
return active_session.get(key, default)
|
||||
|
||||
|
||||
def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
|
||||
@ -277,29 +125,52 @@ def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
|
||||
|
||||
|
||||
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||
sessions = [_SESSIONLESS_CACHE] if key & _SESSIONLESS_FLAG else _SESSIONS
|
||||
if key & SESSIONLESS_FLAG:
|
||||
values = builtins.set()
|
||||
for session in sessions:
|
||||
encoded = session.get(key)
|
||||
encoded = _SESSIONLESS_CACHE.get(key)
|
||||
if encoded is not None:
|
||||
values.add(int.from_bytes(encoded, "big"))
|
||||
return values
|
||||
return _PROTOCOL_CACHE.get_int_all_sessions(key)
|
||||
|
||||
|
||||
def is_set(key: int) -> bool:
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
return _SESSIONLESS_CACHE.is_set(key ^ _SESSIONLESS_FLAG)
|
||||
if _active_session_idx is None:
|
||||
if key & SESSIONLESS_FLAG:
|
||||
return _SESSIONLESS_CACHE.is_set(key ^ SESSIONLESS_FLAG)
|
||||
active_session = _PROTOCOL_CACHE.get_active_session()
|
||||
if active_session is None:
|
||||
raise InvalidSessionError
|
||||
return _SESSIONS[_active_session_idx].is_set(key)
|
||||
return active_session.is_set(key)
|
||||
|
||||
|
||||
def delete(key: int) -> None:
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG)
|
||||
if _active_session_idx is None:
|
||||
def set(key: int, value: bytes) -> None:
|
||||
if key & SESSIONLESS_FLAG:
|
||||
_SESSIONLESS_CACHE.set(key ^ SESSIONLESS_FLAG, value)
|
||||
return
|
||||
active_session = _PROTOCOL_CACHE.get_active_session()
|
||||
if active_session is None:
|
||||
raise InvalidSessionError
|
||||
return _SESSIONS[_active_session_idx].delete(key)
|
||||
active_session.set(key, value)
|
||||
|
||||
|
||||
def set_int(key: int, value: int) -> None:
|
||||
active_session = _PROTOCOL_CACHE.get_active_session()
|
||||
|
||||
if key & SESSIONLESS_FLAG:
|
||||
length = _SESSIONLESS_CACHE.fields[key ^ SESSIONLESS_FLAG]
|
||||
|
||||
if active_session is None:
|
||||
raise InvalidSessionError
|
||||
else:
|
||||
length = active_session.fields[key]
|
||||
|
||||
encoded = value.to_bytes(length, "big")
|
||||
|
||||
# Ensure that the value fits within the length. Micropython's int.to_bytes()
|
||||
# doesn't raise OverflowError.
|
||||
assert int.from_bytes(encoded, "big") == value
|
||||
|
||||
set(key, encoded)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -336,15 +207,3 @@ def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def clear_all() -> None:
|
||||
global _active_session_idx
|
||||
global autolock_last_touch
|
||||
|
||||
_active_session_idx = None
|
||||
_SESSIONLESS_CACHE.clear()
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
|
||||
autolock_last_touch = None
|
||||
|
144
core/src/storage/cache_codec.py
Normal file
144
core/src/storage/cache_codec.py
Normal file
@ -0,0 +1,144 @@
|
||||
import builtins
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
from trezor import utils
|
||||
from storage.cache_common import DataCache, InvalidSessionError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Sequence, TypeVar, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
_MAX_SESSIONS_COUNT = const(10)
|
||||
_SESSION_ID_LENGTH = const(32)
|
||||
|
||||
|
||||
class SessionCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.session_id = bytearray(_SESSION_ID_LENGTH)
|
||||
if utils.BITCOIN_ONLY:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
)
|
||||
else:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
1, # APP_COMMON_DERIVE_CARDANO
|
||||
96, # APP_CARDANO_ICARUS_SECRET
|
||||
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||
1, # APP_MONERO_LIVE_REFRESH
|
||||
)
|
||||
self.last_usage = 0
|
||||
super().__init__()
|
||||
|
||||
def export_session_id(self) -> bytes:
|
||||
from trezorcrypto import random # avoid pulling in trezor.crypto
|
||||
|
||||
# generate a new session id if we don't have it yet
|
||||
if not self.session_id:
|
||||
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
|
||||
# export it as immutable bytes
|
||||
return bytes(self.session_id)
|
||||
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self.last_usage = 0
|
||||
self.session_id[:] = b""
|
||||
|
||||
|
||||
_SESSIONS: list[SessionCache] = []
|
||||
|
||||
|
||||
def initialize() -> None:
|
||||
global _SESSIONS
|
||||
for _ in range(_MAX_SESSIONS_COUNT):
|
||||
_SESSIONS.append(SessionCache())
|
||||
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
|
||||
|
||||
initialize()
|
||||
|
||||
|
||||
_active_session_idx: int | None = None
|
||||
_session_usage_counter = 0
|
||||
|
||||
|
||||
def get_active_session() -> SessionCache | None:
|
||||
if _active_session_idx is None:
|
||||
return None
|
||||
return _SESSIONS[_active_session_idx]
|
||||
|
||||
|
||||
def start_session(received_session_id: bytes | None = None) -> bytes:
|
||||
global _active_session_idx
|
||||
global _session_usage_counter
|
||||
|
||||
if (
|
||||
received_session_id is not None
|
||||
and len(received_session_id) != _SESSION_ID_LENGTH
|
||||
):
|
||||
# Prevent the caller from setting received_session_id=b"" and finding a cleared
|
||||
# session. More generally, short-circuit the session id search, because we know
|
||||
# that wrong-length session ids should not be in cache.
|
||||
# Reduce to "session id not provided" case because that's what we do when
|
||||
# caller supplies an id that is not found.
|
||||
received_session_id = None
|
||||
|
||||
_session_usage_counter += 1
|
||||
|
||||
# attempt to find specified session id
|
||||
if received_session_id:
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].session_id == received_session_id:
|
||||
_active_session_idx = i
|
||||
_SESSIONS[i].last_usage = _session_usage_counter
|
||||
return received_session_id
|
||||
|
||||
# allocate least recently used session
|
||||
lru_counter = _session_usage_counter
|
||||
lru_session_idx = 0
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].last_usage < lru_counter:
|
||||
lru_counter = _SESSIONS[i].last_usage
|
||||
lru_session_idx = i
|
||||
|
||||
_active_session_idx = lru_session_idx
|
||||
selected_session = _SESSIONS[lru_session_idx]
|
||||
selected_session.clear()
|
||||
selected_session.last_usage = _session_usage_counter
|
||||
return selected_session.export_session_id()
|
||||
|
||||
|
||||
def end_current_session() -> None:
|
||||
global _active_session_idx
|
||||
|
||||
if _active_session_idx is None:
|
||||
return
|
||||
|
||||
_SESSIONS[_active_session_idx].clear()
|
||||
_active_session_idx = None
|
||||
|
||||
|
||||
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||
values = builtins.set()
|
||||
for session in _SESSIONS:
|
||||
encoded = session.get(key)
|
||||
if encoded is not None:
|
||||
values.add(int.from_bytes(encoded, "big"))
|
||||
return values
|
||||
|
||||
|
||||
def clear_all() -> None:
|
||||
global _active_session_idx
|
||||
_active_session_idx = None
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
70
core/src/storage/cache_common.py
Normal file
70
core/src/storage/cache_common.py
Normal file
@ -0,0 +1,70 @@
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Sequence, TypeVar, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
SESSIONLESS_FLAG = const(128)
|
||||
|
||||
|
||||
class InvalidSessionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DataCache:
|
||||
fields: Sequence[int]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = [bytearray(f + 1) for f in self.fields]
|
||||
|
||||
def set(self, key: int, value: bytes) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
utils.ensure(len(value) <= self.fields[key])
|
||||
self.data[key][0] = 1
|
||||
self.data[key][1:] = value
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
def get(self, key: int) -> bytes | None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
|
||||
...
|
||||
|
||||
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
utils.ensure(key < len(self.fields))
|
||||
if self.data[key][0] != 1:
|
||||
return default
|
||||
return bytes(self.data[key][1:])
|
||||
|
||||
def is_set(self, key: int) -> bool:
|
||||
utils.ensure(key < len(self.fields))
|
||||
return self.data[key][0] == 1
|
||||
|
||||
def delete(self, key: int) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
self.data[key][:] = b"\x00"
|
||||
|
||||
def clear(self) -> None:
|
||||
for i in range(len(self.fields)):
|
||||
self.delete(i)
|
||||
|
||||
|
||||
class SessionlessCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
|
||||
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
||||
1, # STORAGE_DEVICE_EXPERIMENTAL_FEATURES
|
||||
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
|
||||
8, # APP_COMMON_BUSY_DEADLINE_MS
|
||||
32, # APP_MISC_COSI_NONCE
|
||||
32, # APP_MISC_COSI_COMMITMENT
|
||||
)
|
||||
super().__init__()
|
256
core/src/storage/cache_thp.py
Normal file
256
core/src/storage/cache_thp.py
Normal file
@ -0,0 +1,256 @@
|
||||
import builtins
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_common import DataCache, InvalidSessionError
|
||||
from trezor import utils
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Sequence, TypeVar, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# THP specific constants
|
||||
_MAX_SESSIONS_COUNT = const(20)
|
||||
_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5)
|
||||
_THP_SESSION_STATE_LENGTH = const(1)
|
||||
_SESSION_ID_LENGTH = const(4)
|
||||
BROADCAST_CHANNEL_ID = const(65535)
|
||||
|
||||
|
||||
class SessionThpCache(DataCache): # TODO implement, this is just copied SessionCache
|
||||
def __init__(self) -> None:
|
||||
self.session_id = bytearray(_SESSION_ID_LENGTH)
|
||||
self.state = bytearray(_THP_SESSION_STATE_LENGTH)
|
||||
if utils.BITCOIN_ONLY:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
)
|
||||
else:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
1, # APP_COMMON_DERIVE_CARDANO
|
||||
96, # APP_CARDANO_ICARUS_SECRET
|
||||
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||
1, # APP_MONERO_LIVE_REFRESH
|
||||
)
|
||||
self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5)
|
||||
self.last_usage = 0
|
||||
super().__init__()
|
||||
|
||||
def export_session_id(self) -> bytes:
|
||||
from trezorcrypto import random # avoid pulling in trezor.crypto
|
||||
|
||||
# generate a new session id if we don't have it yet
|
||||
if not self.session_id:
|
||||
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
|
||||
# export it as immutable bytes
|
||||
return bytes(self.session_id)
|
||||
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self.last_usage = 0
|
||||
self.session_id[:] = b""
|
||||
|
||||
|
||||
_SESSIONS: list[SessionThpCache] = []
|
||||
_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = []
|
||||
|
||||
|
||||
def initialize() -> None:
|
||||
global _SESSIONS
|
||||
global _UNAUTHENTICATED_SESSIONS
|
||||
|
||||
for _ in range(_MAX_SESSIONS_COUNT):
|
||||
_SESSIONS.append(SessionThpCache())
|
||||
for _ in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||
_UNAUTHENTICATED_SESSIONS.append(SessionThpCache())
|
||||
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
for session in _UNAUTHENTICATED_SESSIONS:
|
||||
session.clear()
|
||||
|
||||
|
||||
initialize()
|
||||
|
||||
|
||||
# THP vars
|
||||
_next_unauthenicated_session_index: int = 0
|
||||
_is_active_session_authenticated: bool
|
||||
_active_session_idx: int | None = None
|
||||
_session_usage_counter = 0
|
||||
|
||||
|
||||
# with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex)
|
||||
cid_counter: int = 4659
|
||||
|
||||
|
||||
def get_active_session_id():
|
||||
if get_active_session() is None:
|
||||
return None
|
||||
return get_active_session().session_id
|
||||
|
||||
|
||||
def get_active_session() -> SessionThpCache | None:
|
||||
if _active_session_idx is None:
|
||||
return None
|
||||
if _is_active_session_authenticated:
|
||||
return _SESSIONS[_active_session_idx]
|
||||
return _UNAUTHENTICATED_SESSIONS[_active_session_idx]
|
||||
|
||||
|
||||
def get_next_channel_id() -> int:
|
||||
global cid_counter
|
||||
while True:
|
||||
cid_counter += 1
|
||||
if cid_counter >= BROADCAST_CHANNEL_ID:
|
||||
cid_counter = 1
|
||||
if _is_cid_unique():
|
||||
break
|
||||
return cid_counter
|
||||
|
||||
|
||||
def _is_cid_unique() -> bool:
|
||||
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
|
||||
if cid_counter == _get_cid(session):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _get_cid(session: SessionThpCache) -> int:
|
||||
return int.from_bytes(session.session_id[2:], "big")
|
||||
|
||||
|
||||
def create_new_unauthenticated_session(session_id: bytearray) -> SessionThpCache:
|
||||
if len(session_id) != 4:
|
||||
raise ValueError("session_id must be 4 bytes long.")
|
||||
global _active_session_idx
|
||||
global _is_active_session_authenticated
|
||||
global _next_unauthenicated_session_index
|
||||
|
||||
i = _next_unauthenicated_session_index
|
||||
_UNAUTHENTICATED_SESSIONS[i] = SessionThpCache()
|
||||
_UNAUTHENTICATED_SESSIONS[i].session_id = bytearray(session_id)
|
||||
_next_unauthenicated_session_index += 1
|
||||
if _next_unauthenicated_session_index >= _MAX_UNAUTHENTICATED_SESSIONS_COUNT:
|
||||
_next_unauthenicated_session_index = 0
|
||||
|
||||
# Set session as active if and only if there is no active session
|
||||
if _active_session_idx is None:
|
||||
_active_session_idx = i
|
||||
_is_active_session_authenticated = False
|
||||
return _UNAUTHENTICATED_SESSIONS[i]
|
||||
|
||||
|
||||
def get_unauth_session_index(unauth_session: SessionThpCache) -> int | None:
|
||||
for i in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||
if unauth_session == _UNAUTHENTICATED_SESSIONS[i]:
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache:
|
||||
global _session_usage_counter
|
||||
|
||||
unauth_session_idx = get_unauth_session_index(unauth_session)
|
||||
if unauth_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
|
||||
# replace least recently used authenticated session by the new session
|
||||
new_auth_session_index = get_least_recently_used_authetnicated_session_index()
|
||||
|
||||
_SESSIONS[new_auth_session_index] = _UNAUTHENTICATED_SESSIONS[unauth_session_idx]
|
||||
_UNAUTHENTICATED_SESSIONS[unauth_session_idx] = None
|
||||
|
||||
_session_usage_counter += 1
|
||||
_SESSIONS[new_auth_session_index].last_usage = _session_usage_counter
|
||||
|
||||
|
||||
def get_least_recently_used_authetnicated_session_index() -> int:
|
||||
lru_counter = _session_usage_counter
|
||||
lru_session_idx = 0
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].last_usage < lru_counter:
|
||||
lru_counter = _SESSIONS[i].last_usage
|
||||
lru_session_idx = i
|
||||
return lru_session_idx
|
||||
|
||||
|
||||
# The function start_session should not be used in production code. It is present only to assure compatibility with old tests.
|
||||
def start_session(session_id: bytes) -> bytes: # TODO incomplete
|
||||
global _active_session_idx
|
||||
global _is_active_session_authenticated
|
||||
|
||||
if session_id is not None:
|
||||
if get_active_session_id() == session_id:
|
||||
return session_id
|
||||
for index in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[index].session_id == session_id:
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = True
|
||||
return session_id
|
||||
for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||
if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id:
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = False
|
||||
return session_id
|
||||
new_session_id = b"\x00\x00" + get_next_channel_id().to_bytes(2, "big")
|
||||
|
||||
new_session = create_new_unauthenticated_session(new_session_id)
|
||||
|
||||
index = get_unauth_session_index(new_session)
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = False
|
||||
|
||||
return new_session_id
|
||||
|
||||
|
||||
def start_existing_session(session_id: bytearray) -> bytes:
|
||||
if session_id is None:
|
||||
raise ValueError("session_id cannot be None")
|
||||
if get_active_session_id() == session_id:
|
||||
return session_id
|
||||
for index in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[index].session_id == session_id:
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = True
|
||||
return session_id
|
||||
for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||
if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id:
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = False
|
||||
return session_id
|
||||
raise ValueError("There is no active session with provided session_id")
|
||||
|
||||
|
||||
def end_current_session() -> None:
|
||||
global _active_session_idx
|
||||
active_session = get_active_session()
|
||||
if active_session is None:
|
||||
return
|
||||
active_session.clear()
|
||||
_active_session_idx = None
|
||||
|
||||
|
||||
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||
values = builtins.set()
|
||||
for session in _SESSIONS: # Should there be _SESSIONS + _UNAUTHENTICATED_SESSIONS ?
|
||||
encoded = session.get(key)
|
||||
if encoded is not None:
|
||||
values.add(int.from_bytes(encoded, "big"))
|
||||
return values
|
||||
|
||||
|
||||
def clear_all() -> None:
|
||||
global _active_session_idx
|
||||
_active_session_idx = None
|
||||
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
|
||||
session.clear()
|
@ -32,6 +32,8 @@ MODEL_IS_T2B1: bool = INTERNAL_MODEL == "T2B1"
|
||||
|
||||
DISABLE_ANIMATION = 0
|
||||
|
||||
USE_THP = True # TODO move elsewhere, probably to core/embed/trezorhal/...
|
||||
|
||||
if __debug__:
|
||||
if EMULATOR:
|
||||
import uos
|
||||
|
@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is:
|
||||
|
||||
- Request / response.
|
||||
- Protobuf-encoded, see `protobuf.py`.
|
||||
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`.
|
||||
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py` or `trezor/wire/thp_v1.py`.
|
||||
- Transferred over USB interface, or UDP in case of Unix emulation.
|
||||
|
||||
This module:
|
||||
@ -23,15 +23,17 @@ reads the message's header. When the message type is known the first handler is
|
||||
|
||||
"""
|
||||
|
||||
from apps import workflow_handlers
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache import InvalidSessionError
|
||||
from storage.cache_codec import InvalidSessionError
|
||||
from trezor import log, loop, protobuf, utils, workflow
|
||||
from trezor.enums import FailureType
|
||||
from trezor.messages import Failure
|
||||
from trezor.wire import codec_v1, context
|
||||
from trezor.wire import codec_v1, context, protocol_common
|
||||
from trezor.wire.errors import ActionCancelled, DataError, Error
|
||||
import trezor.enums.MessageType as MT
|
||||
|
||||
# Import all errors into namespace, so that `wire.Error` is available from
|
||||
# other packages.
|
||||
@ -88,8 +90,8 @@ if __debug__:
|
||||
|
||||
|
||||
async def _handle_single_message(
|
||||
ctx: context.Context, msg: codec_v1.Message, use_workflow: bool
|
||||
) -> codec_v1.Message | None:
|
||||
ctx: context.Context, msg: protocol_common.Message, use_workflow: bool
|
||||
) -> protocol_common.Message | None:
|
||||
"""Handle a message that was loaded from USB by the caller.
|
||||
|
||||
Find the appropriate handler, run it and write its result on the wire. In case
|
||||
@ -113,7 +115,7 @@ async def _handle_single_message(
|
||||
__name__,
|
||||
"%s:%x receive: <%s>",
|
||||
ctx.iface.iface_num(),
|
||||
ctx.sid,
|
||||
ctx.session_id,
|
||||
msg_type,
|
||||
)
|
||||
|
||||
@ -143,6 +145,10 @@ async def _handle_single_message(
|
||||
req_msg = wrap_protobuf_load(msg.data, req_type)
|
||||
|
||||
# Create the handler task.
|
||||
if msg.type is MT.Initialize:
|
||||
# Special case for handle_initialize to have access to the verified session_id
|
||||
task = handler(req_msg, ctx.session_id)
|
||||
else:
|
||||
task = handler(req_msg)
|
||||
|
||||
# Run the workflow task. Workflow can do more on-the-wire
|
||||
@ -201,7 +207,7 @@ async def handle_session(
|
||||
ctx_buffer = WIRE_BUFFER
|
||||
|
||||
ctx = context.Context(iface, session_id, ctx_buffer)
|
||||
next_msg: codec_v1.Message | None = None
|
||||
next_msg: protocol_common.Message | None = None
|
||||
|
||||
if __debug__ and is_debug_session:
|
||||
import apps.debug
|
||||
@ -218,7 +224,7 @@ async def handle_session(
|
||||
# wait for a new one coming from the wire.
|
||||
try:
|
||||
msg = await ctx.read_from_wire()
|
||||
except codec_v1.CodecError as exc:
|
||||
except protocol_common.WireError as exc:
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
await ctx.write(failure(exc))
|
||||
@ -229,6 +235,9 @@ async def handle_session(
|
||||
msg = next_msg
|
||||
next_msg = None
|
||||
|
||||
# Set ctx.session_id to the value msg.session_id
|
||||
ctx.session_id = msg.session_id
|
||||
|
||||
try:
|
||||
next_msg = await _handle_single_message(
|
||||
ctx, msg, use_workflow=not is_debug_session
|
||||
|
@ -3,6 +3,7 @@ from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import io, loop, utils
|
||||
from trezor.wire.protocol_common import Message, WireError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
@ -18,16 +19,10 @@ _REP_CONT_DATA = const(1) # offset of data in the continuation report
|
||||
SESSION_ID = const(0)
|
||||
|
||||
|
||||
class CodecError(Exception):
|
||||
class CodecError(WireError):
|
||||
pass
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(self, mtype: int, mdata: bytes) -> None:
|
||||
self.type = mtype
|
||||
self.data = mdata
|
||||
|
||||
|
||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||
|
||||
|
@ -17,7 +17,8 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import log, loop, protobuf
|
||||
|
||||
from . import codec_v1
|
||||
from .protocol import WireProtocol
|
||||
from .protocol_common import Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
@ -48,7 +49,7 @@ class UnexpectedMessage(Exception):
|
||||
should be aborted and a new one started as if `msg` was the first message.
|
||||
"""
|
||||
|
||||
def __init__(self, msg: codec_v1.Message) -> None:
|
||||
def __init__(self, msg: Message) -> None:
|
||||
super().__init__()
|
||||
self.msg = msg
|
||||
|
||||
@ -60,14 +61,14 @@ class Context:
|
||||
(i.e., wire, debug, single BT connection, etc.)
|
||||
"""
|
||||
|
||||
def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None:
|
||||
def __init__(self, iface: WireInterface, buffer: bytearray) -> None:
|
||||
self.iface = iface
|
||||
self.sid = sid
|
||||
self.buffer = buffer
|
||||
self.session_id: bytearray | None = None
|
||||
|
||||
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
|
||||
def read_from_wire(self) -> Awaitable[Message]:
|
||||
"""Read a whole message from the wire without parsing it."""
|
||||
return codec_v1.read_message(self.iface, self.buffer)
|
||||
return WireProtocol.read_message(self.iface, self.buffer)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@ -97,7 +98,7 @@ class Context:
|
||||
__name__,
|
||||
"%s:%x expect: %s",
|
||||
self.iface.iface_num(),
|
||||
self.sid,
|
||||
self.session_id,
|
||||
expected_type.MESSAGE_NAME if expected_type else expected_types,
|
||||
)
|
||||
|
||||
@ -109,6 +110,9 @@ class Context:
|
||||
if msg.type not in expected_types:
|
||||
raise UnexpectedMessage(msg)
|
||||
|
||||
# TODO check that the message has the expected session_id. If not, raise UnexpectedMessageError
|
||||
# (and maybe update ctx.session_id - depends on expected behaviour)
|
||||
|
||||
if expected_type is None:
|
||||
expected_type = protobuf.type_for_wire(msg.type)
|
||||
|
||||
@ -117,7 +121,7 @@ class Context:
|
||||
__name__,
|
||||
"%s:%x read: %s",
|
||||
self.iface.iface_num(),
|
||||
self.sid,
|
||||
self.session_id,
|
||||
expected_type.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
@ -133,7 +137,7 @@ class Context:
|
||||
__name__,
|
||||
"%s:%x write: %s",
|
||||
self.iface.iface_num(),
|
||||
self.sid,
|
||||
self.session_id,
|
||||
msg.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
@ -151,10 +155,13 @@ class Context:
|
||||
|
||||
msg_size = protobuf.encode(buffer, msg)
|
||||
|
||||
await codec_v1.write_message(
|
||||
await WireProtocol.write_message(
|
||||
self.iface,
|
||||
msg.MESSAGE_WIRE_TYPE,
|
||||
memoryview(buffer)[:msg_size],
|
||||
Message(
|
||||
message_type=msg.MESSAGE_WIRE_TYPE,
|
||||
message_data=memoryview(buffer)[:msg_size],
|
||||
session_id=self.session_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
19
core/src/trezor/wire/protocol.py
Normal file
19
core/src/trezor/wire/protocol.py
Normal file
@ -0,0 +1,19 @@
|
||||
from trezor import utils
|
||||
from trezor.wire import codec_v1, thp_v1
|
||||
from trezor.wire.protocol_common import Message
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
|
||||
|
||||
class WireProtocol:
|
||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
||||
if utils.USE_THP:
|
||||
return thp_v1.read_message(iface, buffer)
|
||||
return codec_v1.read_message(iface, buffer)
|
||||
|
||||
async def write_message(iface: WireInterface, message: Message) -> None:
|
||||
if utils.USE_THP:
|
||||
return thp_v1.write_to_wire(iface, message)
|
||||
return codec_v1.write_message(iface, message.type, message.data)
|
14
core/src/trezor/wire/protocol_common.py
Normal file
14
core/src/trezor/wire/protocol_common.py
Normal file
@ -0,0 +1,14 @@
|
||||
class Message:
|
||||
def __init__(
|
||||
self,
|
||||
message_type: int,
|
||||
message_data: bytes,
|
||||
session_id: bytearray | None = None,
|
||||
) -> None:
|
||||
self.type = message_type
|
||||
self.data = message_data
|
||||
self.session_id = session_id
|
||||
|
||||
|
||||
class WireError(Exception):
|
||||
pass
|
166
core/src/trezor/wire/thp_session.py
Normal file
166
core/src/trezor/wire/thp_session.py
Normal file
@ -0,0 +1,166 @@
|
||||
import ustruct
|
||||
from storage import cache_thp as storage_thp_cache
|
||||
from storage.cache_thp import SessionThpCache, BROADCAST_CHANNEL_ID
|
||||
from trezor import io
|
||||
from trezor.wire.protocol_common import WireError
|
||||
from typing import TYPE_CHECKING
|
||||
from ubinascii import hexlify
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enum import IntEnum
|
||||
else:
|
||||
IntEnum = object
|
||||
|
||||
|
||||
class ThpError(WireError):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowState(IntEnum):
|
||||
NOT_STARTED = 0
|
||||
PENDING = 1
|
||||
FINISHED = 2
|
||||
|
||||
|
||||
class Workflow:
|
||||
id: int
|
||||
workflow_state: WorkflowState
|
||||
|
||||
|
||||
class SessionState(IntEnum):
|
||||
UNALLOCATED = 0
|
||||
INITIALIZED = 1 # do not change, is denoted as constant in storage.cache _THP_SESSION_STATE_INITIALIZED = 1
|
||||
PAIRED = 2
|
||||
UNPAIRED = 3
|
||||
PAIRING = 4
|
||||
APP_TRAFFIC = 5
|
||||
|
||||
|
||||
def get_workflow() -> Workflow:
|
||||
pass # TODO
|
||||
|
||||
|
||||
def print_all_test_sessions() -> None:
|
||||
for session in storage_thp_cache._UNAUTHENTICATED_SESSIONS:
|
||||
if session is None:
|
||||
print("none")
|
||||
else:
|
||||
print(hexlify(session.session_id).decode("utf-8"), session.state)
|
||||
|
||||
|
||||
#
|
||||
def create_autenticated_session(unauthenticated_session: SessionThpCache):
|
||||
storage_thp_cache.start_session() # TODO something like this but for THP
|
||||
raise
|
||||
|
||||
|
||||
def create_new_unauthenticated_session(iface: WireInterface, cid: int):
|
||||
session_id = _get_id(iface, cid)
|
||||
new_session = storage_thp_cache.create_new_unauthenticated_session(session_id)
|
||||
set_session_state(new_session, SessionState.INITIALIZED)
|
||||
|
||||
|
||||
def get_active_session() -> SessionThpCache | None:
|
||||
return storage_thp_cache.get_active_session()
|
||||
|
||||
|
||||
def get_session(iface: WireInterface, cid: int) -> SessionThpCache | None:
|
||||
session_id = _get_id(iface, cid)
|
||||
return get_session_from_id(session_id)
|
||||
|
||||
|
||||
def get_session_from_id(session_id) -> SessionThpCache | None:
|
||||
session = _get_authenticated_session_or_none(session_id)
|
||||
if session is None:
|
||||
session = _get_unauthenticated_session_or_none(session_id)
|
||||
return session
|
||||
|
||||
|
||||
def get_state(session: SessionThpCache) -> int:
|
||||
if session is None:
|
||||
return SessionState.UNALLOCATED
|
||||
return _decode_session_state(session.state)
|
||||
|
||||
|
||||
def get_cid(session: SessionThpCache) -> int:
|
||||
return storage_thp_cache._get_cid(session)
|
||||
|
||||
|
||||
def get_next_channel_id() -> int:
|
||||
return storage_thp_cache.get_next_channel_id()
|
||||
|
||||
|
||||
def sync_can_send_message(session: SessionThpCache) -> bool:
|
||||
return session.sync & 0x80 == 0x80
|
||||
|
||||
|
||||
def sync_get_receive_expected_bit(session: SessionThpCache) -> int:
|
||||
return (session.sync & 0x40) >> 6
|
||||
|
||||
|
||||
def sync_get_send_bit(session: SessionThpCache) -> int:
|
||||
return (session.sync & 0x20) >> 5
|
||||
|
||||
|
||||
def sync_set_can_send_message(session: SessionThpCache, can_send: bool) -> None:
|
||||
session.sync &= 0x7F
|
||||
if can_send:
|
||||
session.sync |= 0x80
|
||||
|
||||
|
||||
def sync_set_receive_expected_bit(session: SessionThpCache, bit: int) -> None:
|
||||
if bit != 0 and bit != 1:
|
||||
raise ThpError("Unexpected receive sync bit")
|
||||
|
||||
# set second bit to "bit" value
|
||||
session.sync &= 0xBF
|
||||
session.sync |= 0x40
|
||||
|
||||
|
||||
def sync_set_send_bit_to_opposite(session: SessionThpCache) -> None:
|
||||
_sync_set_send_bit(session=session, bit=1 - sync_get_send_bit(session))
|
||||
|
||||
|
||||
def is_active_session(session: SessionThpCache):
|
||||
if session is None:
|
||||
return False
|
||||
return session.session_id == storage_thp_cache.get_active_session_id()
|
||||
|
||||
|
||||
def set_session_state(session: SessionThpCache, new_state: SessionState):
|
||||
session.state = new_state.to_bytes(1, "big")
|
||||
|
||||
|
||||
def _get_id(iface: WireInterface, cid: int) -> bytearray:
|
||||
return ustruct.pack(">HH", iface.iface_num(), cid)
|
||||
|
||||
|
||||
def _get_authenticated_session_or_none(session_id) -> SessionThpCache:
|
||||
for authenticated_session in storage_thp_cache._SESSIONS:
|
||||
if authenticated_session.session_id == session_id:
|
||||
return authenticated_session
|
||||
return None
|
||||
|
||||
|
||||
def _get_unauthenticated_session_or_none(session_id) -> SessionThpCache:
|
||||
for unauthenticated_session in storage_thp_cache._UNAUTHENTICATED_SESSIONS:
|
||||
if unauthenticated_session.session_id == session_id:
|
||||
return unauthenticated_session
|
||||
return None
|
||||
|
||||
|
||||
def _sync_set_send_bit(session: SessionThpCache, bit: int) -> None:
|
||||
if bit != 0 and bit != 1:
|
||||
raise ThpError("Unexpected send sync bit")
|
||||
|
||||
# set third bit to "bit" value
|
||||
session.sync &= 0xDF
|
||||
session.sync |= 0x20
|
||||
|
||||
|
||||
def _decode_session_state(state: bytearray) -> int:
|
||||
return ustruct.unpack("B", state)[0]
|
||||
|
||||
|
||||
def _encode_session_state(state: SessionState) -> bytearray:
|
||||
return ustruct.pack("B", state)
|
370
core/src/trezor/wire/thp_v1.py
Normal file
370
core/src/trezor/wire/thp_v1.py
Normal file
@ -0,0 +1,370 @@
|
||||
import ustruct
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
from storage.cache_thp import SessionThpCache
|
||||
from trezor import io, loop, utils
|
||||
from trezor.crypto import crc
|
||||
from trezor.wire.protocol_common import Message
|
||||
import trezor.wire.thp_session as THP
|
||||
from trezor.wire.thp_session import (
|
||||
ThpError,
|
||||
SessionState,
|
||||
BROADCAST_CHANNEL_ID,
|
||||
)
|
||||
from ubinascii import hexlify
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
|
||||
_MAX_PAYLOAD_LEN = const(60000)
|
||||
_CHECKSUM_LENGTH = const(4)
|
||||
_CHANNEL_ALLOCATION_REQ = 0x40
|
||||
_CHANNEL_ALLOCATION_RES = 0x40
|
||||
_ERROR = 0x41
|
||||
_CONTINUATION_PACKET = 0x80
|
||||
_ACK_MESSAGE = 0x20
|
||||
_HANDSHAKE_INIT = 0x00
|
||||
_PLAINTEXT = 0x01
|
||||
ENCRYPTED_TRANSPORT = 0x02
|
||||
_ENCODED_PROTOBUF_DEVICE_PROPERTIES = (
|
||||
b"\x0A\x04\x54\x33\x57\x31\x10\x05\x18\x00\x20\x01\x28\x01\x28\x02"
|
||||
)
|
||||
_UNALLOCATED_SESSION_ERROR = (
|
||||
b"\x55\x4e\x41\x4c\x4c\x4f\x43\x41\x54\x45\x44\x5f\x53\x45\x53\x53\x49\x4f\x4e"
|
||||
)
|
||||
|
||||
_REPORT_LENGTH = const(64)
|
||||
_REPORT_INIT_DATA_OFFSET = const(5)
|
||||
_REPORT_CONT_DATA_OFFSET = const(3)
|
||||
|
||||
|
||||
class InitHeader:
|
||||
format_str = ">BHH"
|
||||
|
||||
def __init__(self, ctrl_byte, cid, length) -> None:
|
||||
self.ctrl_byte = ctrl_byte
|
||||
self.cid = cid
|
||||
self.length = length
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
return ustruct.pack(
|
||||
InitHeader.format_str, self.ctrl_byte, self.cid, self.length
|
||||
)
|
||||
|
||||
def pack_to_buffer(self, buffer, buffer_offset=0) -> None:
|
||||
ustruct.pack_into(
|
||||
InitHeader.format_str,
|
||||
buffer,
|
||||
buffer_offset,
|
||||
self.ctrl_byte,
|
||||
self.cid,
|
||||
self.length,
|
||||
)
|
||||
|
||||
def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None:
|
||||
ustruct.pack_into(">BH", buffer, buffer_offset, _CONTINUATION_PACKET, self.cid)
|
||||
|
||||
|
||||
class InterruptingInitPacket:
|
||||
def __init__(self, report) -> None:
|
||||
self.initReport = report
|
||||
|
||||
|
||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
||||
msg = await read_message_or_init_packet(iface, buffer)
|
||||
while type(msg) is not Message:
|
||||
if msg is InterruptingInitPacket:
|
||||
msg = await read_message_or_init_packet(iface, buffer, msg.initReport)
|
||||
else:
|
||||
raise ThpError("Unexpected output of read_message_or_init_packet")
|
||||
return msg
|
||||
|
||||
|
||||
async def read_message_or_init_packet(
|
||||
iface: WireInterface, buffer: utils.BufferType, firstReport=None
|
||||
) -> Message | InterruptingInitPacket:
|
||||
while True:
|
||||
# Wait for an initial report
|
||||
if firstReport is None:
|
||||
report = await _get_loop_wait_read(iface)
|
||||
else:
|
||||
report = firstReport
|
||||
|
||||
# Channel multiplexing
|
||||
ctrl_byte, cid = ustruct.unpack(">BH", report)
|
||||
|
||||
if cid == BROADCAST_CHANNEL_ID:
|
||||
_handle_broadcast(iface, ctrl_byte, report)
|
||||
continue
|
||||
|
||||
# We allow for only one message to be read simultaneously. We do not
|
||||
# support reading multiple messages with interleaven packets - with
|
||||
# the sole exception of cid_request which can be handled independently.
|
||||
if _is_ctrl_byte_continuation(ctrl_byte):
|
||||
# continuation packet is not expected - ignore
|
||||
continue
|
||||
|
||||
payload_length = ustruct.unpack(">H", report[3:])[0]
|
||||
payload = _get_buffer_for_payload(payload_length, buffer)
|
||||
header = InitHeader(ctrl_byte, cid, payload_length)
|
||||
|
||||
# buffer the received data
|
||||
interruptingPacket = await _buffer_received_data(payload, header, iface, report)
|
||||
if interruptingPacket is not None:
|
||||
return interruptingPacket
|
||||
|
||||
# Check CRC
|
||||
if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
||||
# checksum is not valid -> ignore message
|
||||
continue
|
||||
|
||||
session = THP.get_session(iface, cid)
|
||||
session_state = THP.get_state(session)
|
||||
|
||||
# Handle message on unallocated channel
|
||||
if session_state == SessionState.UNALLOCATED:
|
||||
message = _handle_unallocated(iface, cid)
|
||||
# unallocated should not return regular message, TODO, but it might change
|
||||
if message is not None:
|
||||
return message
|
||||
continue
|
||||
|
||||
# Note: In the Host, the UNALLOCATED_CHANNEL error should be handled here
|
||||
|
||||
# Synchronization process
|
||||
sync_bit = (ctrl_byte & 0x10) >> 4
|
||||
# 1: Handle ACKs
|
||||
if _is_ctrl_byte_ack(ctrl_byte):
|
||||
_handle_received_ACK(session, sync_bit)
|
||||
continue
|
||||
|
||||
# 2: Handle message with unexpected synchronization bit
|
||||
if sync_bit != THP.sync_get_receive_expected_bit(session):
|
||||
message = _handle_unexpected_sync_bit(iface, cid, sync_bit)
|
||||
# unsynchronized messages should not return regular message, TODO,
|
||||
# but it might change with the cancelation message
|
||||
if message is not None:
|
||||
return message
|
||||
continue
|
||||
|
||||
# 3: Send ACK in response
|
||||
_sendAck(iface, cid, sync_bit)
|
||||
THP.sync_set_receive_expected_bit(session, 1 - sync_bit)
|
||||
|
||||
return _handle_allocated(ctrl_byte, session, payload)
|
||||
|
||||
|
||||
def _get_loop_wait_read(iface: WireInterface):
|
||||
return loop.wait(iface.iface_num() | io.POLL_READ)
|
||||
|
||||
|
||||
def _get_buffer_for_payload(
|
||||
payload_length: int, existing_buffer: utils.BufferType
|
||||
) -> utils.BufferType:
|
||||
if payload_length > _MAX_PAYLOAD_LEN:
|
||||
raise ThpError("Message too large")
|
||||
if payload_length > len(existing_buffer):
|
||||
# allocate a new buffer to fit the message
|
||||
try:
|
||||
payload: utils.BufferType = bytearray(payload_length)
|
||||
except MemoryError:
|
||||
payload = bytearray(_REPORT_LENGTH)
|
||||
raise ("Message too large")
|
||||
return payload
|
||||
|
||||
# reuse a part of the supplied buffer
|
||||
return memoryview(existing_buffer)[:payload_length]
|
||||
|
||||
|
||||
async def _buffer_received_data(
|
||||
payload: utils.BufferType, header: InitHeader, iface, report
|
||||
) -> None | InterruptingInitPacket:
|
||||
# buffer the initial data
|
||||
nread = utils.memcpy(payload, 0, report, _REPORT_INIT_DATA_OFFSET)
|
||||
while nread < header.length:
|
||||
# wait for continuation report
|
||||
report = await _get_loop_wait_read(iface)
|
||||
|
||||
# channel multiplexing
|
||||
cont_ctrl_byte, cont_cid = ustruct.unpack(">BH", report)
|
||||
|
||||
# handle broadcast - allows the reading process
|
||||
# to survive interruption by broadcast
|
||||
if cont_cid == BROADCAST_CHANNEL_ID:
|
||||
_handle_broadcast(iface, cont_ctrl_byte, report)
|
||||
continue
|
||||
|
||||
# handle unexpected initiation packet
|
||||
if not _is_ctrl_byte_continuation(cont_ctrl_byte):
|
||||
# TODO possibly add timeout - allow interruption only after a long time
|
||||
return InterruptingInitPacket(report)
|
||||
|
||||
# ignore continuation packets on different channels
|
||||
if cont_cid != header.cid:
|
||||
continue
|
||||
|
||||
# buffer the continuation data
|
||||
nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET)
|
||||
|
||||
|
||||
async def write_message(
|
||||
iface: WireInterface, message: Message, is_retransmission: bool = False
|
||||
) -> None:
|
||||
session = THP.get_session_from_id(message.session_id)
|
||||
cid = THP.get_cid(session)
|
||||
payload = message.type.to_bytes(2, "big") + message.data
|
||||
payload_length = len(payload)
|
||||
|
||||
if THP.get_state(session) == SessionState.INITIALIZED:
|
||||
# write message in plaintext, TODO check if it is allowed
|
||||
ctrl_byte = _PLAINTEXT
|
||||
elif THP.get_state(session) == SessionState.APP_TRAFFIC:
|
||||
ctrl_byte = ENCRYPTED_TRANSPORT
|
||||
else:
|
||||
raise ThpError("Session in not implemented state" + str(THP.get_state(session)))
|
||||
|
||||
if not is_retransmission:
|
||||
ctrl_byte = _add_sync_bit_to_ctrl_byte(
|
||||
ctrl_byte, THP.sync_get_send_bit(session)
|
||||
)
|
||||
THP.sync_set_send_bit_to_opposite(session)
|
||||
else:
|
||||
# retransmission must have the same sync bit as the previously sent message
|
||||
ctrl_byte = _add_sync_bit_to_ctrl_byte(ctrl_byte, 1 - THP.sync_get_send_bit())
|
||||
|
||||
header = InitHeader(ctrl_byte, cid, payload_length + _CHECKSUM_LENGTH)
|
||||
checksum = _compute_checksum_bytes(header.to_bytes() + payload)
|
||||
await write_to_wire(iface, header, payload + checksum)
|
||||
# TODO set timeout for retransmission
|
||||
|
||||
|
||||
async def write_to_wire(
|
||||
iface: WireInterface, header: InitHeader, payload: bytes
|
||||
) -> None:
|
||||
write = loop.wait(iface.iface_num() | io.POLL_WRITE)
|
||||
|
||||
payload_length = len(payload)
|
||||
|
||||
# prepare the report buffer with header data
|
||||
report = bytearray(_REPORT_LENGTH)
|
||||
header.pack_to_buffer(report)
|
||||
|
||||
# write initial report
|
||||
nwritten = utils.memcpy(report, _REPORT_INIT_DATA_OFFSET, payload, 0)
|
||||
await _write_report(write, iface, report)
|
||||
|
||||
# if we have more data to write, use continuation reports for it
|
||||
if nwritten < payload_length:
|
||||
header.pack_to_cont_buffer(report)
|
||||
|
||||
while nwritten < payload_length:
|
||||
nwritten += utils.memcpy(report, _REPORT_CONT_DATA_OFFSET, payload, nwritten)
|
||||
await _write_report(write, iface, report)
|
||||
|
||||
|
||||
async def _write_report(write, iface: WireInterface, report: bytearray) -> None:
|
||||
while True:
|
||||
await write
|
||||
n = iface.write(report)
|
||||
if n == len(report):
|
||||
return
|
||||
|
||||
|
||||
def _handle_broadcast(iface: WireIntreface, ctrl_byte, report) -> Message | None:
|
||||
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
|
||||
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
|
||||
length, nonce, checksum = ustruct.unpack(">H8s4s", report[3:])
|
||||
|
||||
if not _is_checksum_valid(checksum, data=report[:-4]):
|
||||
raise ThpError("Checksum is not valid")
|
||||
|
||||
channel_id = _get_new_channel_id()
|
||||
THP.create_new_unauthenticated_session(iface, channel_id)
|
||||
response_data = (
|
||||
ustruct.pack(">8sH", nonce, channel_id) + _ENCODED_PROTOBUF_DEVICE_PROPERTIES
|
||||
)
|
||||
|
||||
response_header = InitHeader(
|
||||
_CHANNEL_ALLOCATION_RES,
|
||||
BROADCAST_CHANNEL_ID,
|
||||
len(response_data) + _CHECKSUM_LENGTH,
|
||||
)
|
||||
|
||||
checksum = _compute_checksum_bytes(response_header.to_bytes() + response_data)
|
||||
write_to_wire(iface, response_header, response_data + checksum)
|
||||
|
||||
|
||||
def _handle_allocated(ctrl_byte, session: SessionThpCache, payload) -> Message:
|
||||
# Parameters session and ctrl_byte will be used to determine if the
|
||||
# communication should be encrypted or not
|
||||
|
||||
message_type = ustruct.unpack(">H", payload)[0]
|
||||
|
||||
# trim message type and checksum from payload
|
||||
message_data = payload[2:-_CHECKSUM_LENGTH]
|
||||
return Message(message_type, message_data, session.session_id)
|
||||
|
||||
|
||||
def _handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None:
|
||||
# No ACKs expected
|
||||
if THP.sync_can_send_message(session):
|
||||
return
|
||||
|
||||
# ACK has incorrect sync bit
|
||||
if THP.sync_get_send_bit(session) != sync_bit:
|
||||
return
|
||||
|
||||
# ACK is expected and it has correct sync bit
|
||||
THP.sync_set_can_send_message(session, True)
|
||||
|
||||
|
||||
async def _handle_unallocated(iface, cid) -> Message | None:
|
||||
data = _UNALLOCATED_SESSION_ERROR
|
||||
header = InitHeader(_ERROR, cid, len(data) + _CHECKSUM_LENGTH)
|
||||
checksum = _compute_checksum_bytes(header.to_bytes() + data)
|
||||
write_to_wire(iface, header, data + checksum)
|
||||
|
||||
|
||||
async def _sendAck(iface: WireInterface, cid: int, ack_bit: int) -> None:
|
||||
ctrl_byte = _add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
|
||||
header = InitHeader(ctrl_byte, cid, _CHECKSUM_LENGTH)
|
||||
checksum = _compute_checksum_bytes(header.to_bytes())
|
||||
write_to_wire(iface, header, checksum)
|
||||
|
||||
|
||||
def _handle_unexpected_sync_bit(
|
||||
iface: WireInterface, cid: int, sync_bit: int
|
||||
) -> Message | None:
|
||||
_sendAck(iface, cid, sync_bit)
|
||||
|
||||
# TODO handle cancelation messages and messages on allocated channels without synchronization
|
||||
# (some such messages might be handled in the classical "allocated" way, if the sync bit is right)
|
||||
|
||||
|
||||
def _get_new_channel_id() -> int:
|
||||
return THP.get_next_channel_id()
|
||||
|
||||
|
||||
def _is_checksum_valid(checksum: bytearray, data: bytearray) -> bool:
|
||||
data_checksum = _compute_checksum_bytes(data)
|
||||
return checksum == data_checksum
|
||||
|
||||
|
||||
def _is_ctrl_byte_continuation(ctrl_byte) -> bool:
|
||||
return ctrl_byte & 0x80 == _CONTINUATION_PACKET
|
||||
|
||||
|
||||
def _is_ctrl_byte_ack(ctrl_byte) -> bool:
|
||||
return ctrl_byte & 0x20 == _ACK_MESSAGE
|
||||
|
||||
|
||||
def _add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit):
|
||||
if sync_bit == 0:
|
||||
return ctrl_byte & 0xEF
|
||||
if sync_bit == 1:
|
||||
return ctrl_byte | 0x10
|
||||
raise ThpError("Unexpected synchronization bit")
|
||||
|
||||
|
||||
def _compute_checksum_bytes(data: bytearray):
|
||||
return crc.crc32(data).to_bytes(4, "big")
|
@ -1,17 +1,26 @@
|
||||
from common import * # isort:skip
|
||||
|
||||
from mock_storage import mock_storage
|
||||
from storage import cache
|
||||
from trezor.messages import EndSession, Initialize
|
||||
|
||||
from storage import cache, cache_codec, cache_thp
|
||||
from storage.cache_common import InvalidSessionError
|
||||
from trezor import utils
|
||||
from trezor.messages import Initialize
|
||||
from trezor.messages import EndSession
|
||||
|
||||
from apps.base import handle_EndSession, handle_Initialize
|
||||
|
||||
KEY = 0
|
||||
|
||||
if utils.USE_THP:
|
||||
_PROTOCOL_CACHE = cache_thp
|
||||
else:
|
||||
_PROTOCOL_CACHE = cache_codec
|
||||
|
||||
|
||||
# Function moved from cache.py, as it was not used there
|
||||
def is_session_started() -> bool:
|
||||
return cache._active_session_idx is not None
|
||||
return _PROTOCOL_CACHE.get_active_session() is not None
|
||||
|
||||
|
||||
class TestStorageCache(unittest.TestCase):
|
||||
@ -25,9 +34,9 @@ class TestStorageCache(unittest.TestCase):
|
||||
self.assertNotEqual(session_id_a, session_id_b)
|
||||
|
||||
cache.clear_all()
|
||||
with self.assertRaises(cache.InvalidSessionError):
|
||||
with self.assertRaises(InvalidSessionError):
|
||||
cache.set(KEY, "something")
|
||||
with self.assertRaises(cache.InvalidSessionError):
|
||||
with self.assertRaises(InvalidSessionError):
|
||||
cache.get(KEY)
|
||||
|
||||
def test_end_session(self):
|
||||
@ -36,7 +45,7 @@ class TestStorageCache(unittest.TestCase):
|
||||
cache.set(KEY, b"A")
|
||||
cache.end_current_session()
|
||||
self.assertFalse(is_session_started())
|
||||
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
||||
self.assertRaises(InvalidSessionError, cache.get, KEY)
|
||||
|
||||
# ending an ended session should be a no-op
|
||||
cache.end_current_session()
|
||||
@ -63,7 +72,7 @@ class TestStorageCache(unittest.TestCase):
|
||||
session_id = cache.start_session()
|
||||
self.assertEqual(cache.start_session(session_id), session_id)
|
||||
cache.set(KEY, b"A")
|
||||
for i in range(cache._MAX_SESSIONS_COUNT):
|
||||
for i in range(_PROTOCOL_CACHE._MAX_SESSIONS_COUNT):
|
||||
cache.start_session()
|
||||
self.assertNotEqual(cache.start_session(session_id), session_id)
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
@ -83,7 +92,7 @@ class TestStorageCache(unittest.TestCase):
|
||||
self.assertEqual(cache.get(KEY), b"hello")
|
||||
|
||||
cache.clear_all()
|
||||
with self.assertRaises(cache.InvalidSessionError):
|
||||
with self.assertRaises(InvalidSessionError):
|
||||
cache.get(KEY)
|
||||
|
||||
def test_get_set_int(self):
|
||||
@ -101,7 +110,7 @@ class TestStorageCache(unittest.TestCase):
|
||||
self.assertEqual(cache.get_int(KEY), 1234)
|
||||
|
||||
cache.clear_all()
|
||||
with self.assertRaises(cache.InvalidSessionError):
|
||||
with self.assertRaises(InvalidSessionError):
|
||||
cache.get_int(KEY)
|
||||
|
||||
def test_delete(self):
|
||||
@ -186,6 +195,9 @@ class TestStorageCache(unittest.TestCase):
|
||||
|
||||
@mock_storage
|
||||
def test_Initialize(self):
|
||||
if utils.USE_THP: # INITIALIZE SHOULD NOT BE IN THP!!! TODO
|
||||
return
|
||||
|
||||
def call_Initialize(**kwargs):
|
||||
msg = Initialize(**kwargs)
|
||||
return await_result(handle_Initialize(msg))
|
||||
@ -210,7 +222,7 @@ class TestStorageCache(unittest.TestCase):
|
||||
self.assertEqual(cache.get(KEY), b"hello")
|
||||
|
||||
# supplying a different session ID starts a new cache
|
||||
call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH)
|
||||
call_Initialize(session_id=b"A" * _PROTOCOL_CACHE._SESSION_ID_LENGTH)
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
|
||||
# but resuming a session loads the previous one
|
||||
@ -218,13 +230,18 @@ class TestStorageCache(unittest.TestCase):
|
||||
self.assertEqual(cache.get(KEY), b"hello")
|
||||
|
||||
def test_EndSession(self):
|
||||
<<<<<<< HEAD
|
||||
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
||||
cache.start_session()
|
||||
=======
|
||||
self.assertRaises(InvalidSessionError, cache.get, KEY)
|
||||
session_id = cache.start_session()
|
||||
>>>>>>> 8681ba167 (Basic THP functinality - not-polished prototype)
|
||||
self.assertTrue(is_session_started())
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
await_result(handle_EndSession(EndSession()))
|
||||
self.assertFalse(is_session_started())
|
||||
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
||||
self.assertRaises(InvalidSessionError, cache.get, KEY)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
341
core/tests/test_trezor.wire.thp_v1.py
Normal file
341
core/tests/test_trezor.wire.thp_v1.py
Normal file
@ -0,0 +1,341 @@
|
||||
from common import *
|
||||
from ubinascii import hexlify, unhexlify
|
||||
import ustruct
|
||||
|
||||
from trezor import io, utils
|
||||
from trezor.loop import wait
|
||||
from trezor.utils import chunks
|
||||
from trezor.wire import thp_v1
|
||||
from trezor.wire.thp_v1 import _CHECKSUM_LENGTH, BROADCAST_CHANNEL_ID
|
||||
from trezor.wire.protocol_common import Message
|
||||
import trezor.wire.thp_session as THP
|
||||
|
||||
from micropython import const
|
||||
|
||||
|
||||
class MockHID:
|
||||
def __init__(self, num):
|
||||
self.num = num
|
||||
self.data = []
|
||||
|
||||
def iface_num(self):
|
||||
return self.num
|
||||
|
||||
def write(self, msg):
|
||||
self.data.append(bytearray(msg))
|
||||
return len(msg)
|
||||
|
||||
def wait_object(self, mode):
|
||||
return wait(mode | self.num)
|
||||
|
||||
|
||||
MESSAGE_TYPE = 0x4242
|
||||
MESSAGE_TYPE_BYTES = b"\x42\x42"
|
||||
_MESSAGE_TYPE_LEN = 2
|
||||
PLAINTEXT_0 = 0x01
|
||||
PLAINTEXT_1 = 0x11
|
||||
COMMON_CID = 4660
|
||||
CONT = 0x80
|
||||
|
||||
HEADER_INIT_LENGTH = 5
|
||||
HEADER_CONT_LENGTH = 3
|
||||
INIT_MESSAGE_DATA_LENGTH = (
|
||||
thp_v1._REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
|
||||
)
|
||||
|
||||
|
||||
def make_header(ctrl_byte, cid, length):
|
||||
return ustruct.pack(">BHH", ctrl_byte, cid, length)
|
||||
|
||||
|
||||
def make_cont_header():
|
||||
return ustruct.pack(">BH", CONT, COMMON_CID)
|
||||
|
||||
|
||||
def makeSimpleMessage(header, message_type, message_data):
|
||||
return header + ustruct.pack(">H", message_type) + message_data
|
||||
|
||||
|
||||
def makeCidRequest(header, message_data):
|
||||
return header + message_data
|
||||
|
||||
|
||||
def printBytes(a):
|
||||
print(hexlify(a).decode("utf-8"))
|
||||
|
||||
|
||||
def getPlaintext() -> bytes:
|
||||
if THP.sync_get_receive_expected_bit(THP.get_active_session()) == 1:
|
||||
return PLAINTEXT_1
|
||||
PLAINTEXT_0
|
||||
|
||||
|
||||
def getCid() -> int:
|
||||
return THP.get_cid(THP.get_active_session())
|
||||
|
||||
|
||||
# This test suite is an adaptation of test_trezor.wire.codec_v1
|
||||
class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.interface = MockHID(0xDEADBEEF)
|
||||
if not utils.USE_THP:
|
||||
import storage.cache_thp # noQA:F401
|
||||
|
||||
def test_simple(self):
|
||||
cid_req_header = make_header(
|
||||
ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
|
||||
)
|
||||
cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3C\x6C"
|
||||
cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data)
|
||||
|
||||
message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18)
|
||||
cid_request_dummy_data_checksum = b"\x67\x8E\xAC\xE0"
|
||||
message = makeSimpleMessage(
|
||||
message_header,
|
||||
MESSAGE_TYPE,
|
||||
cid_request_dummy_data + cid_request_dummy_data_checksum,
|
||||
)
|
||||
|
||||
buffer = bytearray(64)
|
||||
|
||||
gen = thp_v1.read_message(self.interface, buffer)
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(cid_req_message)
|
||||
gen.send(message)
|
||||
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, cid_request_dummy_data)
|
||||
|
||||
buffer_without_zeroes = buffer[: len(message) - 5]
|
||||
message_without_header = message[5:]
|
||||
# message should have been read into the buffer
|
||||
self.assertEqual(buffer_without_zeroes, message_without_header)
|
||||
|
||||
def test_read_one_packet(self):
|
||||
# zero length message - just a header
|
||||
PLAINTEXT = getPlaintext()
|
||||
header = make_header(
|
||||
PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH
|
||||
)
|
||||
checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES)
|
||||
message = header + MESSAGE_TYPE_BYTES + checksum
|
||||
|
||||
buffer = bytearray(64)
|
||||
gen = thp_v1.read_message(self.interface, buffer)
|
||||
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(message)
|
||||
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, b"")
|
||||
|
||||
# message should have been read into the buffer
|
||||
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + checksum + b"\x00" * 58)
|
||||
|
||||
def test_read_many_packets(self):
|
||||
message = bytes(range(256))
|
||||
header = make_header(
|
||||
getPlaintext(),
|
||||
COMMON_CID,
|
||||
len(message) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH,
|
||||
)
|
||||
checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message)
|
||||
# message = MESSAGE_TYPE_BYTES + message + checksum
|
||||
|
||||
# first packet is init header + 59 bytes of data
|
||||
# other packets are cont header + 61 bytes of data
|
||||
cont_header = make_cont_header()
|
||||
packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
|
||||
cont_header + chunk
|
||||
for chunk in chunks(
|
||||
message[INIT_MESSAGE_DATA_LENGTH:] + checksum,
|
||||
64 - HEADER_CONT_LENGTH,
|
||||
)
|
||||
]
|
||||
buffer = bytearray(262)
|
||||
gen = thp_v1.read_message(self.interface, buffer)
|
||||
query = gen.send(None)
|
||||
for packet in packets[:-1]:
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
query = gen.send(packet)
|
||||
|
||||
# last packet will stop
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(packets[-1])
|
||||
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, message)
|
||||
|
||||
# message should have been read into the buffer )
|
||||
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + checksum)
|
||||
|
||||
def test_read_large_message(self):
|
||||
message = b"hello world"
|
||||
header = make_header(
|
||||
getPlaintext(),
|
||||
COMMON_CID,
|
||||
_MESSAGE_TYPE_LEN + len(message) + _CHECKSUM_LENGTH,
|
||||
)
|
||||
|
||||
packet = (
|
||||
header
|
||||
+ MESSAGE_TYPE_BYTES
|
||||
+ message
|
||||
+ thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message)
|
||||
)
|
||||
|
||||
# make sure we fit into one packet, to make this easier
|
||||
self.assertTrue(len(packet) <= thp_v1._REPORT_LENGTH)
|
||||
|
||||
buffer = bytearray(1)
|
||||
self.assertTrue(len(buffer) <= len(packet))
|
||||
|
||||
gen = thp_v1.read_message(self.interface, buffer)
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(packet)
|
||||
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, message)
|
||||
|
||||
# read should have allocated its own buffer and not touch ours
|
||||
self.assertEqual(buffer, b"\x00")
|
||||
|
||||
def test_write_one_packet(self):
|
||||
message = Message(MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID))
|
||||
gen = thp_v1.write_message(self.interface, message)
|
||||
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||
with self.assertRaises(StopIteration):
|
||||
gen.send(None)
|
||||
|
||||
header = make_header(
|
||||
PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH
|
||||
)
|
||||
expected_message = (
|
||||
header
|
||||
+ MESSAGE_TYPE_BYTES
|
||||
+ thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES)
|
||||
+ b"\x00" * (INIT_MESSAGE_DATA_LENGTH - _CHECKSUM_LENGTH)
|
||||
)
|
||||
self.assertTrue(self.interface.data == [expected_message])
|
||||
|
||||
def test_write_multiple_packets(self):
|
||||
message_payload = bytes(range(256))
|
||||
message = Message(
|
||||
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
||||
)
|
||||
gen = thp_v1.write_message(self.interface, message)
|
||||
|
||||
header = make_header(
|
||||
PLAINTEXT_1,
|
||||
COMMON_CID,
|
||||
len(message.data) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH,
|
||||
)
|
||||
cont_header = make_cont_header()
|
||||
checksum = thp_v1._compute_checksum_bytes(
|
||||
header + message.type.to_bytes(2, "big") + message.data
|
||||
)
|
||||
packets = [
|
||||
header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH]
|
||||
] + [
|
||||
cont_header + chunk
|
||||
for chunk in chunks(
|
||||
message.data[INIT_MESSAGE_DATA_LENGTH:] + checksum,
|
||||
thp_v1._REPORT_LENGTH - HEADER_CONT_LENGTH,
|
||||
)
|
||||
]
|
||||
|
||||
for _ in packets:
|
||||
# we receive as many queries as there are packets
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||
|
||||
# the first sent None only started the generator. the len(packets)-th None
|
||||
# will finish writing and raise StopIteration
|
||||
with self.assertRaises(StopIteration):
|
||||
gen.send(None)
|
||||
|
||||
# packets must be identical up to the last one
|
||||
self.assertListEqual(packets[:-1], self.interface.data[:-1])
|
||||
# last packet must be identical up to message length. remaining bytes in
|
||||
# the 64-byte packets are garbage -- in particular, it's the bytes of the
|
||||
# previous packet
|
||||
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
|
||||
self.assertEqual(last_packet, self.interface.data[-1])
|
||||
|
||||
def test_roundtrip(self):
|
||||
message_payload = bytes(range(256))
|
||||
message = Message(
|
||||
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
||||
)
|
||||
gen = thp_v1.write_message(self.interface, message)
|
||||
|
||||
# exhaust the iterator:
|
||||
# (XXX we can only do this because the iterator is only accepting None and returns None)
|
||||
for query in gen:
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||
|
||||
buffer = bytearray(1024)
|
||||
gen = thp_v1.read_message(self.interface, buffer)
|
||||
query = gen.send(None)
|
||||
for packet in self.interface.data[:-1]:
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
query = gen.send(packet)
|
||||
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(self.interface.data[-1])
|
||||
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, message.data)
|
||||
|
||||
def test_read_huge_packet(self):
|
||||
PACKET_COUNT = 1180
|
||||
# message that takes up 1 180 USB packets
|
||||
message_size = (PACKET_COUNT - 1) * (
|
||||
thp_v1._REPORT_LENGTH
|
||||
- HEADER_CONT_LENGTH
|
||||
- _CHECKSUM_LENGTH
|
||||
- _MESSAGE_TYPE_LEN
|
||||
) + INIT_MESSAGE_DATA_LENGTH
|
||||
|
||||
# ensure that a message this big won't fit into memory
|
||||
# Note: this control is changed, because THP has only 2 byte length field
|
||||
self.assertTrue(message_size > thp_v1._MAX_PAYLOAD_LEN)
|
||||
# self.assertRaises(MemoryError, bytearray, message_size)
|
||||
header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
|
||||
packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)
|
||||
buffer = bytearray(65536)
|
||||
gen = thp_v1.read_message(self.interface, buffer)
|
||||
|
||||
query = gen.send(None)
|
||||
|
||||
# THP returns "Message too large" error after reading the message size,
|
||||
# it is different from codec_v1 as it does not allow big enough messages
|
||||
# to raise MemoryError in this test
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
with self.assertRaises(thp_v1.ThpError) as e:
|
||||
query = gen.send(packet)
|
||||
|
||||
self.assertEqual(e.value.args[0], "Message too large")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -252,16 +252,16 @@ def run_class(c, test_result):
|
||||
raise RuntimeError(f"{name} should not return a result.")
|
||||
finally:
|
||||
tear_down()
|
||||
print(" ok")
|
||||
print("\033[32mok\033[0m")
|
||||
except SkipTest as e:
|
||||
print(" skipped:", e.args[0])
|
||||
test_result.skippedNum += 1
|
||||
except AssertionError as e:
|
||||
print(" failed")
|
||||
print("\033[31mfailed\033[0m")
|
||||
sys.print_exception(e)
|
||||
test_result.failuresNum += 1
|
||||
except BaseException as e:
|
||||
print(" errored:", e)
|
||||
print("\033[31merrored:\033[0m", e)
|
||||
sys.print_exception(e)
|
||||
test_result.errorsNum += 1
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user