parent
0d7fe7d643
commit
87365a3148
@ -0,0 +1,144 @@
|
||||
import builtins
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
from trezor import utils
|
||||
from storage.cache_common import DataCache, InvalidSessionError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Sequence, TypeVar, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
_MAX_SESSIONS_COUNT = const(10)
|
||||
_SESSION_ID_LENGTH = const(32)
|
||||
|
||||
|
||||
class SessionCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.session_id = bytearray(_SESSION_ID_LENGTH)
|
||||
if utils.BITCOIN_ONLY:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
)
|
||||
else:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
1, # APP_COMMON_DERIVE_CARDANO
|
||||
96, # APP_CARDANO_ICARUS_SECRET
|
||||
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||
1, # APP_MONERO_LIVE_REFRESH
|
||||
)
|
||||
self.last_usage = 0
|
||||
super().__init__()
|
||||
|
||||
def export_session_id(self) -> bytes:
|
||||
from trezorcrypto import random # avoid pulling in trezor.crypto
|
||||
|
||||
# generate a new session id if we don't have it yet
|
||||
if not self.session_id:
|
||||
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
|
||||
# export it as immutable bytes
|
||||
return bytes(self.session_id)
|
||||
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self.last_usage = 0
|
||||
self.session_id[:] = b""
|
||||
|
||||
|
||||
_SESSIONS: list[SessionCache] = []
|
||||
|
||||
|
||||
def initialize() -> None:
|
||||
global _SESSIONS
|
||||
for _ in range(_MAX_SESSIONS_COUNT):
|
||||
_SESSIONS.append(SessionCache())
|
||||
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
|
||||
|
||||
initialize()
|
||||
|
||||
|
||||
_active_session_idx: int | None = None
|
||||
_session_usage_counter = 0
|
||||
|
||||
|
||||
def get_active_session() -> SessionCache | None:
|
||||
if _active_session_idx is None:
|
||||
return None
|
||||
return _SESSIONS[_active_session_idx]
|
||||
|
||||
|
||||
def start_session(received_session_id: bytes | None = None) -> bytes:
|
||||
global _active_session_idx
|
||||
global _session_usage_counter
|
||||
|
||||
if (
|
||||
received_session_id is not None
|
||||
and len(received_session_id) != _SESSION_ID_LENGTH
|
||||
):
|
||||
# Prevent the caller from setting received_session_id=b"" and finding a cleared
|
||||
# session. More generally, short-circuit the session id search, because we know
|
||||
# that wrong-length session ids should not be in cache.
|
||||
# Reduce to "session id not provided" case because that's what we do when
|
||||
# caller supplies an id that is not found.
|
||||
received_session_id = None
|
||||
|
||||
_session_usage_counter += 1
|
||||
|
||||
# attempt to find specified session id
|
||||
if received_session_id:
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].session_id == received_session_id:
|
||||
_active_session_idx = i
|
||||
_SESSIONS[i].last_usage = _session_usage_counter
|
||||
return received_session_id
|
||||
|
||||
# allocate least recently used session
|
||||
lru_counter = _session_usage_counter
|
||||
lru_session_idx = 0
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].last_usage < lru_counter:
|
||||
lru_counter = _SESSIONS[i].last_usage
|
||||
lru_session_idx = i
|
||||
|
||||
_active_session_idx = lru_session_idx
|
||||
selected_session = _SESSIONS[lru_session_idx]
|
||||
selected_session.clear()
|
||||
selected_session.last_usage = _session_usage_counter
|
||||
return selected_session.export_session_id()
|
||||
|
||||
|
||||
def end_current_session() -> None:
|
||||
global _active_session_idx
|
||||
|
||||
if _active_session_idx is None:
|
||||
return
|
||||
|
||||
_SESSIONS[_active_session_idx].clear()
|
||||
_active_session_idx = None
|
||||
|
||||
|
||||
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||
values = builtins.set()
|
||||
for session in _SESSIONS:
|
||||
encoded = session.get(key)
|
||||
if encoded is not None:
|
||||
values.add(int.from_bytes(encoded, "big"))
|
||||
return values
|
||||
|
||||
|
||||
def clear_all() -> None:
|
||||
global _active_session_idx
|
||||
_active_session_idx = None
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
@ -0,0 +1,70 @@
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Sequence, TypeVar, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
SESSIONLESS_FLAG = const(128)
|
||||
|
||||
|
||||
class InvalidSessionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DataCache:
|
||||
fields: Sequence[int]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = [bytearray(f + 1) for f in self.fields]
|
||||
|
||||
def set(self, key: int, value: bytes) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
utils.ensure(len(value) <= self.fields[key])
|
||||
self.data[key][0] = 1
|
||||
self.data[key][1:] = value
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
def get(self, key: int) -> bytes | None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
|
||||
...
|
||||
|
||||
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
utils.ensure(key < len(self.fields))
|
||||
if self.data[key][0] != 1:
|
||||
return default
|
||||
return bytes(self.data[key][1:])
|
||||
|
||||
def is_set(self, key: int) -> bool:
|
||||
utils.ensure(key < len(self.fields))
|
||||
return self.data[key][0] == 1
|
||||
|
||||
def delete(self, key: int) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
self.data[key][:] = b"\x00"
|
||||
|
||||
def clear(self) -> None:
|
||||
for i in range(len(self.fields)):
|
||||
self.delete(i)
|
||||
|
||||
|
||||
class SessionlessCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
|
||||
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
||||
1, # STORAGE_DEVICE_EXPERIMENTAL_FEATURES
|
||||
8, # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
|
||||
8, # APP_COMMON_BUSY_DEADLINE_MS
|
||||
32, # APP_MISC_COSI_NONCE
|
||||
32, # APP_MISC_COSI_COMMITMENT
|
||||
)
|
||||
super().__init__()
|
@ -0,0 +1,256 @@
|
||||
import builtins
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_common import DataCache, InvalidSessionError
|
||||
from trezor import utils
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Sequence, TypeVar, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# THP specific constants
|
||||
_MAX_SESSIONS_COUNT = const(20)
|
||||
_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5)
|
||||
_THP_SESSION_STATE_LENGTH = const(1)
|
||||
_SESSION_ID_LENGTH = const(4)
|
||||
BROADCAST_CHANNEL_ID = const(65535)
|
||||
|
||||
|
||||
class SessionThpCache(DataCache): # TODO implement, this is just copied SessionCache
|
||||
def __init__(self) -> None:
|
||||
self.session_id = bytearray(_SESSION_ID_LENGTH)
|
||||
self.state = bytearray(_THP_SESSION_STATE_LENGTH)
|
||||
if utils.BITCOIN_ONLY:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
)
|
||||
else:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
2, # APP_COMMON_AUTHORIZATION_TYPE
|
||||
128, # APP_COMMON_AUTHORIZATION_DATA
|
||||
32, # APP_COMMON_NONCE
|
||||
1, # APP_COMMON_DERIVE_CARDANO
|
||||
96, # APP_CARDANO_ICARUS_SECRET
|
||||
96, # APP_CARDANO_ICARUS_TREZOR_SECRET
|
||||
1, # APP_MONERO_LIVE_REFRESH
|
||||
)
|
||||
self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5)
|
||||
self.last_usage = 0
|
||||
super().__init__()
|
||||
|
||||
def export_session_id(self) -> bytes:
|
||||
from trezorcrypto import random # avoid pulling in trezor.crypto
|
||||
|
||||
# generate a new session id if we don't have it yet
|
||||
if not self.session_id:
|
||||
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
|
||||
# export it as immutable bytes
|
||||
return bytes(self.session_id)
|
||||
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self.last_usage = 0
|
||||
self.session_id[:] = b""
|
||||
|
||||
|
||||
_SESSIONS: list[SessionThpCache] = []
|
||||
_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = []
|
||||
|
||||
|
||||
def initialize() -> None:
|
||||
global _SESSIONS
|
||||
global _UNAUTHENTICATED_SESSIONS
|
||||
|
||||
for _ in range(_MAX_SESSIONS_COUNT):
|
||||
_SESSIONS.append(SessionThpCache())
|
||||
for _ in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||
_UNAUTHENTICATED_SESSIONS.append(SessionThpCache())
|
||||
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
for session in _UNAUTHENTICATED_SESSIONS:
|
||||
session.clear()
|
||||
|
||||
|
||||
initialize()
|
||||
|
||||
|
||||
# THP vars
|
||||
_next_unauthenicated_session_index: int = 0
|
||||
_is_active_session_authenticated: bool
|
||||
_active_session_idx: int | None = None
|
||||
_session_usage_counter = 0
|
||||
|
||||
|
||||
# with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex)
|
||||
cid_counter: int = 4659
|
||||
|
||||
|
||||
def get_active_session_id():
|
||||
if get_active_session() is None:
|
||||
return None
|
||||
return get_active_session().session_id
|
||||
|
||||
|
||||
def get_active_session() -> SessionThpCache | None:
|
||||
if _active_session_idx is None:
|
||||
return None
|
||||
if _is_active_session_authenticated:
|
||||
return _SESSIONS[_active_session_idx]
|
||||
return _UNAUTHENTICATED_SESSIONS[_active_session_idx]
|
||||
|
||||
|
||||
def get_next_channel_id() -> int:
|
||||
global cid_counter
|
||||
while True:
|
||||
cid_counter += 1
|
||||
if cid_counter >= BROADCAST_CHANNEL_ID:
|
||||
cid_counter = 1
|
||||
if _is_cid_unique():
|
||||
break
|
||||
return cid_counter
|
||||
|
||||
|
||||
def _is_cid_unique() -> bool:
|
||||
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
|
||||
if cid_counter == _get_cid(session):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _get_cid(session: SessionThpCache) -> int:
|
||||
return int.from_bytes(session.session_id[2:], "big")
|
||||
|
||||
|
||||
def create_new_unauthenticated_session(session_id: bytearray) -> SessionThpCache:
|
||||
if len(session_id) != 4:
|
||||
raise ValueError("session_id must be 4 bytes long.")
|
||||
global _active_session_idx
|
||||
global _is_active_session_authenticated
|
||||
global _next_unauthenicated_session_index
|
||||
|
||||
i = _next_unauthenicated_session_index
|
||||
_UNAUTHENTICATED_SESSIONS[i] = SessionThpCache()
|
||||
_UNAUTHENTICATED_SESSIONS[i].session_id = bytearray(session_id)
|
||||
_next_unauthenicated_session_index += 1
|
||||
if _next_unauthenicated_session_index >= _MAX_UNAUTHENTICATED_SESSIONS_COUNT:
|
||||
_next_unauthenicated_session_index = 0
|
||||
|
||||
# Set session as active if and only if there is no active session
|
||||
if _active_session_idx is None:
|
||||
_active_session_idx = i
|
||||
_is_active_session_authenticated = False
|
||||
return _UNAUTHENTICATED_SESSIONS[i]
|
||||
|
||||
|
||||
def get_unauth_session_index(unauth_session: SessionThpCache) -> int | None:
|
||||
for i in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||
if unauth_session == _UNAUTHENTICATED_SESSIONS[i]:
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache:
|
||||
global _session_usage_counter
|
||||
|
||||
unauth_session_idx = get_unauth_session_index(unauth_session)
|
||||
if unauth_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
|
||||
# replace least recently used authenticated session by the new session
|
||||
new_auth_session_index = get_least_recently_used_authetnicated_session_index()
|
||||
|
||||
_SESSIONS[new_auth_session_index] = _UNAUTHENTICATED_SESSIONS[unauth_session_idx]
|
||||
_UNAUTHENTICATED_SESSIONS[unauth_session_idx] = None
|
||||
|
||||
_session_usage_counter += 1
|
||||
_SESSIONS[new_auth_session_index].last_usage = _session_usage_counter
|
||||
|
||||
|
||||
def get_least_recently_used_authetnicated_session_index() -> int:
|
||||
lru_counter = _session_usage_counter
|
||||
lru_session_idx = 0
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].last_usage < lru_counter:
|
||||
lru_counter = _SESSIONS[i].last_usage
|
||||
lru_session_idx = i
|
||||
return lru_session_idx
|
||||
|
||||
|
||||
# The function start_session should not be used in production code. It is present only to assure compatibility with old tests.
|
||||
def start_session(session_id: bytes) -> bytes: # TODO incomplete
|
||||
global _active_session_idx
|
||||
global _is_active_session_authenticated
|
||||
|
||||
if session_id is not None:
|
||||
if get_active_session_id() == session_id:
|
||||
return session_id
|
||||
for index in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[index].session_id == session_id:
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = True
|
||||
return session_id
|
||||
for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||
if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id:
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = False
|
||||
return session_id
|
||||
new_session_id = b"\x00\x00" + get_next_channel_id().to_bytes(2, "big")
|
||||
|
||||
new_session = create_new_unauthenticated_session(new_session_id)
|
||||
|
||||
index = get_unauth_session_index(new_session)
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = False
|
||||
|
||||
return new_session_id
|
||||
|
||||
|
||||
def start_existing_session(session_id: bytearray) -> bytes:
|
||||
if session_id is None:
|
||||
raise ValueError("session_id cannot be None")
|
||||
if get_active_session_id() == session_id:
|
||||
return session_id
|
||||
for index in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[index].session_id == session_id:
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = True
|
||||
return session_id
|
||||
for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||
if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id:
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = False
|
||||
return session_id
|
||||
raise ValueError("There is no active session with provided session_id")
|
||||
|
||||
|
||||
def end_current_session() -> None:
|
||||
global _active_session_idx
|
||||
active_session = get_active_session()
|
||||
if active_session is None:
|
||||
return
|
||||
active_session.clear()
|
||||
_active_session_idx = None
|
||||
|
||||
|
||||
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
||||
values = builtins.set()
|
||||
for session in _SESSIONS: # Should there be _SESSIONS + _UNAUTHENTICATED_SESSIONS ?
|
||||
encoded = session.get(key)
|
||||
if encoded is not None:
|
||||
values.add(int.from_bytes(encoded, "big"))
|
||||
return values
|
||||
|
||||
|
||||
def clear_all() -> None:
|
||||
global _active_session_idx
|
||||
_active_session_idx = None
|
||||
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
|
||||
session.clear()
|
@ -0,0 +1,19 @@
|
||||
from trezor import utils
|
||||
from trezor.wire import codec_v1, thp_v1
|
||||
from trezor.wire.protocol_common import Message
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
|
||||
|
||||
class WireProtocol:
|
||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
||||
if utils.USE_THP:
|
||||
return thp_v1.read_message(iface, buffer)
|
||||
return codec_v1.read_message(iface, buffer)
|
||||
|
||||
async def write_message(iface: WireInterface, message: Message) -> None:
|
||||
if utils.USE_THP:
|
||||
return thp_v1.write_to_wire(iface, message)
|
||||
return codec_v1.write_message(iface, message.type, message.data)
|
@ -0,0 +1,14 @@
|
||||
class Message:
|
||||
def __init__(
|
||||
self,
|
||||
message_type: int,
|
||||
message_data: bytes,
|
||||
session_id: bytearray | None = None,
|
||||
) -> None:
|
||||
self.type = message_type
|
||||
self.data = message_data
|
||||
self.session_id = session_id
|
||||
|
||||
|
||||
class WireError(Exception):
|
||||
pass
|
@ -0,0 +1,166 @@
|
||||
import ustruct
|
||||
from storage import cache_thp as storage_thp_cache
|
||||
from storage.cache_thp import SessionThpCache, BROADCAST_CHANNEL_ID
|
||||
from trezor import io
|
||||
from trezor.wire.protocol_common import WireError
|
||||
from typing import TYPE_CHECKING
|
||||
from ubinascii import hexlify
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enum import IntEnum
|
||||
else:
|
||||
IntEnum = object
|
||||
|
||||
|
||||
class ThpError(WireError):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowState(IntEnum):
|
||||
NOT_STARTED = 0
|
||||
PENDING = 1
|
||||
FINISHED = 2
|
||||
|
||||
|
||||
class Workflow:
|
||||
id: int
|
||||
workflow_state: WorkflowState
|
||||
|
||||
|
||||
class SessionState(IntEnum):
|
||||
UNALLOCATED = 0
|
||||
INITIALIZED = 1 # do not change, is denoted as constant in storage.cache _THP_SESSION_STATE_INITIALIZED = 1
|
||||
PAIRED = 2
|
||||
UNPAIRED = 3
|
||||
PAIRING = 4
|
||||
APP_TRAFFIC = 5
|
||||
|
||||
|
||||
def get_workflow() -> Workflow:
|
||||
pass # TODO
|
||||
|
||||
|
||||
def print_all_test_sessions() -> None:
|
||||
for session in storage_thp_cache._UNAUTHENTICATED_SESSIONS:
|
||||
if session is None:
|
||||
print("none")
|
||||
else:
|
||||
print(hexlify(session.session_id).decode("utf-8"), session.state)
|
||||
|
||||
|
||||
#
|
||||
def create_autenticated_session(unauthenticated_session: SessionThpCache):
|
||||
storage_thp_cache.start_session() # TODO something like this but for THP
|
||||
raise
|
||||
|
||||
|
||||
def create_new_unauthenticated_session(iface: WireInterface, cid: int):
|
||||
session_id = _get_id(iface, cid)
|
||||
new_session = storage_thp_cache.create_new_unauthenticated_session(session_id)
|
||||
set_session_state(new_session, SessionState.INITIALIZED)
|
||||
|
||||
|
||||
def get_active_session() -> SessionThpCache | None:
|
||||
return storage_thp_cache.get_active_session()
|
||||
|
||||
|
||||
def get_session(iface: WireInterface, cid: int) -> SessionThpCache | None:
|
||||
session_id = _get_id(iface, cid)
|
||||
return get_session_from_id(session_id)
|
||||
|
||||
|
||||
def get_session_from_id(session_id) -> SessionThpCache | None:
|
||||
session = _get_authenticated_session_or_none(session_id)
|
||||
if session is None:
|
||||
session = _get_unauthenticated_session_or_none(session_id)
|
||||
return session
|
||||
|
||||
|
||||
def get_state(session: SessionThpCache) -> int:
|
||||
if session is None:
|
||||
return SessionState.UNALLOCATED
|
||||
return _decode_session_state(session.state)
|
||||
|
||||
|
||||
def get_cid(session: SessionThpCache) -> int:
|
||||
return storage_thp_cache._get_cid(session)
|
||||
|
||||
|
||||
def get_next_channel_id() -> int:
|
||||
return storage_thp_cache.get_next_channel_id()
|
||||
|
||||
|
||||
def sync_can_send_message(session: SessionThpCache) -> bool:
|
||||
return session.sync & 0x80 == 0x80
|
||||
|
||||
|
||||
def sync_get_receive_expected_bit(session: SessionThpCache) -> int:
|
||||
return (session.sync & 0x40) >> 6
|
||||
|
||||
|
||||
def sync_get_send_bit(session: SessionThpCache) -> int:
|
||||
return (session.sync & 0x20) >> 5
|
||||
|
||||
|
||||
def sync_set_can_send_message(session: SessionThpCache, can_send: bool) -> None:
|
||||
session.sync &= 0x7F
|
||||
if can_send:
|
||||
session.sync |= 0x80
|
||||
|
||||
|
||||
def sync_set_receive_expected_bit(session: SessionThpCache, bit: int) -> None:
|
||||
if bit != 0 and bit != 1:
|
||||
raise ThpError("Unexpected receive sync bit")
|
||||
|
||||
# set second bit to "bit" value
|
||||
session.sync &= 0xBF
|
||||
session.sync |= 0x40
|
||||
|
||||
|
||||
def sync_set_send_bit_to_opposite(session: SessionThpCache) -> None:
|
||||
_sync_set_send_bit(session=session, bit=1 - sync_get_send_bit(session))
|
||||
|
||||
|
||||
def is_active_session(session: SessionThpCache):
|
||||
if session is None:
|
||||
return False
|
||||
return session.session_id == storage_thp_cache.get_active_session_id()
|
||||
|
||||
|
||||
def set_session_state(session: SessionThpCache, new_state: SessionState):
|
||||
session.state = new_state.to_bytes(1, "big")
|
||||
|
||||
|
||||
def _get_id(iface: WireInterface, cid: int) -> bytearray:
|
||||
return ustruct.pack(">HH", iface.iface_num(), cid)
|
||||
|
||||
|
||||
def _get_authenticated_session_or_none(session_id) -> SessionThpCache:
|
||||
for authenticated_session in storage_thp_cache._SESSIONS:
|
||||
if authenticated_session.session_id == session_id:
|
||||
return authenticated_session
|
||||
return None
|
||||
|
||||
|
||||
def _get_unauthenticated_session_or_none(session_id) -> SessionThpCache:
|
||||
for unauthenticated_session in storage_thp_cache._UNAUTHENTICATED_SESSIONS:
|
||||
if unauthenticated_session.session_id == session_id:
|
||||
return unauthenticated_session
|
||||
return None
|
||||
|
||||
|
||||
def _sync_set_send_bit(session: SessionThpCache, bit: int) -> None:
|
||||
if bit != 0 and bit != 1:
|
||||
raise ThpError("Unexpected send sync bit")
|
||||
|
||||
# set third bit to "bit" value
|
||||
session.sync &= 0xDF
|
||||
session.sync |= 0x20
|
||||
|
||||
|
||||
def _decode_session_state(state: bytearray) -> int:
|
||||
return ustruct.unpack("B", state)[0]
|
||||
|
||||
|
||||
def _encode_session_state(state: SessionState) -> bytearray:
|
||||
return ustruct.pack("B", state)
|
@ -0,0 +1,370 @@
|
||||
import ustruct
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
from storage.cache_thp import SessionThpCache
|
||||
from trezor import io, loop, utils
|
||||
from trezor.crypto import crc
|
||||
from trezor.wire.protocol_common import Message
|
||||
import trezor.wire.thp_session as THP
|
||||
from trezor.wire.thp_session import (
|
||||
ThpError,
|
||||
SessionState,
|
||||
BROADCAST_CHANNEL_ID,
|
||||
)
|
||||
from ubinascii import hexlify
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
|
||||
_MAX_PAYLOAD_LEN = const(60000)
|
||||
_CHECKSUM_LENGTH = const(4)
|
||||
_CHANNEL_ALLOCATION_REQ = 0x40
|
||||
_CHANNEL_ALLOCATION_RES = 0x40
|
||||
_ERROR = 0x41
|
||||
_CONTINUATION_PACKET = 0x80
|
||||
_ACK_MESSAGE = 0x20
|
||||
_HANDSHAKE_INIT = 0x00
|
||||
_PLAINTEXT = 0x01
|
||||
ENCRYPTED_TRANSPORT = 0x02
|
||||
_ENCODED_PROTOBUF_DEVICE_PROPERTIES = (
|
||||
b"\x0A\x04\x54\x33\x57\x31\x10\x05\x18\x00\x20\x01\x28\x01\x28\x02"
|
||||
)
|
||||
_UNALLOCATED_SESSION_ERROR = (
|
||||
b"\x55\x4e\x41\x4c\x4c\x4f\x43\x41\x54\x45\x44\x5f\x53\x45\x53\x53\x49\x4f\x4e"
|
||||
)
|
||||
|
||||
_REPORT_LENGTH = const(64)
|
||||
_REPORT_INIT_DATA_OFFSET = const(5)
|
||||
_REPORT_CONT_DATA_OFFSET = const(3)
|
||||
|
||||
|
||||
class InitHeader:
|
||||
format_str = ">BHH"
|
||||
|
||||
def __init__(self, ctrl_byte, cid, length) -> None:
|
||||
self.ctrl_byte = ctrl_byte
|
||||
self.cid = cid
|
||||
self.length = length
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
return ustruct.pack(
|
||||
InitHeader.format_str, self.ctrl_byte, self.cid, self.length
|
||||
)
|
||||
|
||||
def pack_to_buffer(self, buffer, buffer_offset=0) -> None:
|
||||
ustruct.pack_into(
|
||||
InitHeader.format_str,
|
||||
buffer,
|
||||
buffer_offset,
|
||||
self.ctrl_byte,
|
||||
self.cid,
|
||||
self.length,
|
||||
)
|
||||
|
||||
def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None:
|
||||
ustruct.pack_into(">BH", buffer, buffer_offset, _CONTINUATION_PACKET, self.cid)
|
||||
|
||||
|
||||
class InterruptingInitPacket:
|
||||
def __init__(self, report) -> None:
|
||||
self.initReport = report
|
||||
|
||||
|
||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
||||
msg = await read_message_or_init_packet(iface, buffer)
|
||||
while type(msg) is not Message:
|
||||
if msg is InterruptingInitPacket:
|
||||
msg = await read_message_or_init_packet(iface, buffer, msg.initReport)
|
||||
else:
|
||||
raise ThpError("Unexpected output of read_message_or_init_packet")
|
||||
return msg
|
||||
|
||||
|
||||
async def read_message_or_init_packet(
|
||||
iface: WireInterface, buffer: utils.BufferType, firstReport=None
|
||||
) -> Message | InterruptingInitPacket:
|
||||
while True:
|
||||
# Wait for an initial report
|
||||
if firstReport is None:
|
||||
report = await _get_loop_wait_read(iface)
|
||||
else:
|
||||
report = firstReport
|
||||
|
||||
# Channel multiplexing
|
||||
ctrl_byte, cid = ustruct.unpack(">BH", report)
|
||||
|
||||
if cid == BROADCAST_CHANNEL_ID:
|
||||
_handle_broadcast(iface, ctrl_byte, report)
|
||||
continue
|
||||
|
||||
# We allow for only one message to be read simultaneously. We do not
|
||||
# support reading multiple messages with interleaven packets - with
|
||||
# the sole exception of cid_request which can be handled independently.
|
||||
if _is_ctrl_byte_continuation(ctrl_byte):
|
||||
# continuation packet is not expected - ignore
|
||||
continue
|
||||
|
||||
payload_length = ustruct.unpack(">H", report[3:])[0]
|
||||
payload = _get_buffer_for_payload(payload_length, buffer)
|
||||
header = InitHeader(ctrl_byte, cid, payload_length)
|
||||
|
||||
# buffer the received data
|
||||
interruptingPacket = await _buffer_received_data(payload, header, iface, report)
|
||||
if interruptingPacket is not None:
|
||||
return interruptingPacket
|
||||
|
||||
# Check CRC
|
||||
if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
||||
# checksum is not valid -> ignore message
|
||||
continue
|
||||
|
||||
session = THP.get_session(iface, cid)
|
||||
session_state = THP.get_state(session)
|
||||
|
||||
# Handle message on unallocated channel
|
||||
if session_state == SessionState.UNALLOCATED:
|
||||
message = _handle_unallocated(iface, cid)
|
||||
# unallocated should not return regular message, TODO, but it might change
|
||||
if message is not None:
|
||||
return message
|
||||
continue
|
||||
|
||||
# Note: In the Host, the UNALLOCATED_CHANNEL error should be handled here
|
||||
|
||||
# Synchronization process
|
||||
sync_bit = (ctrl_byte & 0x10) >> 4
|
||||
# 1: Handle ACKs
|
||||
if _is_ctrl_byte_ack(ctrl_byte):
|
||||
_handle_received_ACK(session, sync_bit)
|
||||
continue
|
||||
|
||||
# 2: Handle message with unexpected synchronization bit
|
||||
if sync_bit != THP.sync_get_receive_expected_bit(session):
|
||||
message = _handle_unexpected_sync_bit(iface, cid, sync_bit)
|
||||
# unsynchronized messages should not return regular message, TODO,
|
||||
# but it might change with the cancelation message
|
||||
if message is not None:
|
||||
return message
|
||||
continue
|
||||
|
||||
# 3: Send ACK in response
|
||||
_sendAck(iface, cid, sync_bit)
|
||||
THP.sync_set_receive_expected_bit(session, 1 - sync_bit)
|
||||
|
||||
return _handle_allocated(ctrl_byte, session, payload)
|
||||
|
||||
|
||||
def _get_loop_wait_read(iface: WireInterface):
|
||||
return loop.wait(iface.iface_num() | io.POLL_READ)
|
||||
|
||||
|
||||
def _get_buffer_for_payload(
|
||||
payload_length: int, existing_buffer: utils.BufferType
|
||||
) -> utils.BufferType:
|
||||
if payload_length > _MAX_PAYLOAD_LEN:
|
||||
raise ThpError("Message too large")
|
||||
if payload_length > len(existing_buffer):
|
||||
# allocate a new buffer to fit the message
|
||||
try:
|
||||
payload: utils.BufferType = bytearray(payload_length)
|
||||
except MemoryError:
|
||||
payload = bytearray(_REPORT_LENGTH)
|
||||
raise ("Message too large")
|
||||
return payload
|
||||
|
||||
# reuse a part of the supplied buffer
|
||||
return memoryview(existing_buffer)[:payload_length]
|
||||
|
||||
|
||||
async def _buffer_received_data(
|
||||
payload: utils.BufferType, header: InitHeader, iface, report
|
||||
) -> None | InterruptingInitPacket:
|
||||
# buffer the initial data
|
||||
nread = utils.memcpy(payload, 0, report, _REPORT_INIT_DATA_OFFSET)
|
||||
while nread < header.length:
|
||||
# wait for continuation report
|
||||
report = await _get_loop_wait_read(iface)
|
||||
|
||||
# channel multiplexing
|
||||
cont_ctrl_byte, cont_cid = ustruct.unpack(">BH", report)
|
||||
|
||||
# handle broadcast - allows the reading process
|
||||
# to survive interruption by broadcast
|
||||
if cont_cid == BROADCAST_CHANNEL_ID:
|
||||
_handle_broadcast(iface, cont_ctrl_byte, report)
|
||||
continue
|
||||
|
||||
# handle unexpected initiation packet
|
||||
if not _is_ctrl_byte_continuation(cont_ctrl_byte):
|
||||
# TODO possibly add timeout - allow interruption only after a long time
|
||||
return InterruptingInitPacket(report)
|
||||
|
||||
# ignore continuation packets on different channels
|
||||
if cont_cid != header.cid:
|
||||
continue
|
||||
|
||||
# buffer the continuation data
|
||||
nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET)
|
||||
|
||||
|
||||
async def write_message(
|
||||
iface: WireInterface, message: Message, is_retransmission: bool = False
|
||||
) -> None:
|
||||
session = THP.get_session_from_id(message.session_id)
|
||||
cid = THP.get_cid(session)
|
||||
payload = message.type.to_bytes(2, "big") + message.data
|
||||
payload_length = len(payload)
|
||||
|
||||
if THP.get_state(session) == SessionState.INITIALIZED:
|
||||
# write message in plaintext, TODO check if it is allowed
|
||||
ctrl_byte = _PLAINTEXT
|
||||
elif THP.get_state(session) == SessionState.APP_TRAFFIC:
|
||||
ctrl_byte = ENCRYPTED_TRANSPORT
|
||||
else:
|
||||
raise ThpError("Session in not implemented state" + str(THP.get_state(session)))
|
||||
|
||||
if not is_retransmission:
|
||||
ctrl_byte = _add_sync_bit_to_ctrl_byte(
|
||||
ctrl_byte, THP.sync_get_send_bit(session)
|
||||
)
|
||||
THP.sync_set_send_bit_to_opposite(session)
|
||||
else:
|
||||
# retransmission must have the same sync bit as the previously sent message
|
||||
ctrl_byte = _add_sync_bit_to_ctrl_byte(ctrl_byte, 1 - THP.sync_get_send_bit())
|
||||
|
||||
header = InitHeader(ctrl_byte, cid, payload_length + _CHECKSUM_LENGTH)
|
||||
checksum = _compute_checksum_bytes(header.to_bytes() + payload)
|
||||
await write_to_wire(iface, header, payload + checksum)
|
||||
# TODO set timeout for retransmission
|
||||
|
||||
|
||||
async def write_to_wire(
|
||||
iface: WireInterface, header: InitHeader, payload: bytes
|
||||
) -> None:
|
||||
write = loop.wait(iface.iface_num() | io.POLL_WRITE)
|
||||
|
||||
payload_length = len(payload)
|
||||
|
||||
# prepare the report buffer with header data
|
||||
report = bytearray(_REPORT_LENGTH)
|
||||
header.pack_to_buffer(report)
|
||||
|
||||
# write initial report
|
||||
nwritten = utils.memcpy(report, _REPORT_INIT_DATA_OFFSET, payload, 0)
|
||||
await _write_report(write, iface, report)
|
||||
|
||||
# if we have more data to write, use continuation reports for it
|
||||
if nwritten < payload_length:
|
||||
header.pack_to_cont_buffer(report)
|
||||
|
||||
while nwritten < payload_length:
|
||||
nwritten += utils.memcpy(report, _REPORT_CONT_DATA_OFFSET, payload, nwritten)
|
||||
await _write_report(write, iface, report)
|
||||
|
||||
|
||||
async def _write_report(write, iface: WireInterface, report: bytearray) -> None:
|
||||
while True:
|
||||
await write
|
||||
n = iface.write(report)
|
||||
if n == len(report):
|
||||
return
|
||||
|
||||
|
||||
def _handle_broadcast(iface: WireIntreface, ctrl_byte, report) -> Message | None:
|
||||
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
|
||||
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
|
||||
length, nonce, checksum = ustruct.unpack(">H8s4s", report[3:])
|
||||
|
||||
if not _is_checksum_valid(checksum, data=report[:-4]):
|
||||
raise ThpError("Checksum is not valid")
|
||||
|
||||
channel_id = _get_new_channel_id()
|
||||
THP.create_new_unauthenticated_session(iface, channel_id)
|
||||
response_data = (
|
||||
ustruct.pack(">8sH", nonce, channel_id) + _ENCODED_PROTOBUF_DEVICE_PROPERTIES
|
||||
)
|
||||
|
||||
response_header = InitHeader(
|
||||
_CHANNEL_ALLOCATION_RES,
|
||||
BROADCAST_CHANNEL_ID,
|
||||
len(response_data) + _CHECKSUM_LENGTH,
|
||||
)
|
||||
|
||||
checksum = _compute_checksum_bytes(response_header.to_bytes() + response_data)
|
||||
write_to_wire(iface, response_header, response_data + checksum)
|
||||
|
||||
|
||||
def _handle_allocated(ctrl_byte, session: SessionThpCache, payload) -> Message:
|
||||
# Parameters session and ctrl_byte will be used to determine if the
|
||||
# communication should be encrypted or not
|
||||
|
||||
message_type = ustruct.unpack(">H", payload)[0]
|
||||
|
||||
# trim message type and checksum from payload
|
||||
message_data = payload[2:-_CHECKSUM_LENGTH]
|
||||
return Message(message_type, message_data, session.session_id)
|
||||
|
||||
|
||||
def _handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None:
|
||||
# No ACKs expected
|
||||
if THP.sync_can_send_message(session):
|
||||
return
|
||||
|
||||
# ACK has incorrect sync bit
|
||||
if THP.sync_get_send_bit(session) != sync_bit:
|
||||
return
|
||||
|
||||
# ACK is expected and it has correct sync bit
|
||||
THP.sync_set_can_send_message(session, True)
|
||||
|
||||
|
||||
async def _handle_unallocated(iface, cid) -> Message | None:
|
||||
data = _UNALLOCATED_SESSION_ERROR
|
||||
header = InitHeader(_ERROR, cid, len(data) + _CHECKSUM_LENGTH)
|
||||
checksum = _compute_checksum_bytes(header.to_bytes() + data)
|
||||
write_to_wire(iface, header, data + checksum)
|
||||
|
||||
|
||||
async def _sendAck(iface: WireInterface, cid: int, ack_bit: int) -> None:
|
||||
ctrl_byte = _add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
|
||||
header = InitHeader(ctrl_byte, cid, _CHECKSUM_LENGTH)
|
||||
checksum = _compute_checksum_bytes(header.to_bytes())
|
||||
write_to_wire(iface, header, checksum)
|
||||
|
||||
|
||||
def _handle_unexpected_sync_bit(
|
||||
iface: WireInterface, cid: int, sync_bit: int
|
||||
) -> Message | None:
|
||||
_sendAck(iface, cid, sync_bit)
|
||||
|
||||
# TODO handle cancelation messages and messages on allocated channels without synchronization
|
||||
# (some such messages might be handled in the classical "allocated" way, if the sync bit is right)
|
||||
|
||||
|
||||
def _get_new_channel_id() -> int:
|
||||
return THP.get_next_channel_id()
|
||||
|
||||
|
||||
def _is_checksum_valid(checksum: bytearray, data: bytearray) -> bool:
|
||||
data_checksum = _compute_checksum_bytes(data)
|
||||
return checksum == data_checksum
|
||||
|
||||
|
||||
def _is_ctrl_byte_continuation(ctrl_byte) -> bool:
|
||||
return ctrl_byte & 0x80 == _CONTINUATION_PACKET
|
||||
|
||||
|
||||
def _is_ctrl_byte_ack(ctrl_byte) -> bool:
|
||||
return ctrl_byte & 0x20 == _ACK_MESSAGE
|
||||
|
||||
|
||||
def _add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit):
|
||||
if sync_bit == 0:
|
||||
return ctrl_byte & 0xEF
|
||||
if sync_bit == 1:
|
||||
return ctrl_byte | 0x10
|
||||
raise ThpError("Unexpected synchronization bit")
|
||||
|
||||
|
||||
def _compute_checksum_bytes(data: bytearray):
|
||||
return crc.crc32(data).to_bytes(4, "big")
|
@ -0,0 +1,341 @@
|
||||
from common import *
|
||||
from ubinascii import hexlify, unhexlify
|
||||
import ustruct
|
||||
|
||||
from trezor import io, utils
|
||||
from trezor.loop import wait
|
||||
from trezor.utils import chunks
|
||||
from trezor.wire import thp_v1
|
||||
from trezor.wire.thp_v1 import _CHECKSUM_LENGTH, BROADCAST_CHANNEL_ID
|
||||
from trezor.wire.protocol_common import Message
|
||||
import trezor.wire.thp_session as THP
|
||||
|
||||
from micropython import const
|
||||
|
||||
|
||||
class MockHID:
|
||||
def __init__(self, num):
|
||||
self.num = num
|
||||
self.data = []
|
||||
|
||||
def iface_num(self):
|
||||
return self.num
|
||||
|
||||
def write(self, msg):
|
||||
self.data.append(bytearray(msg))
|
||||
return len(msg)
|
||||
|
||||
def wait_object(self, mode):
|
||||
return wait(mode | self.num)
|
||||
|
||||
|
||||
MESSAGE_TYPE = 0x4242
|
||||
MESSAGE_TYPE_BYTES = b"\x42\x42"
|
||||
_MESSAGE_TYPE_LEN = 2
|
||||
PLAINTEXT_0 = 0x01
|
||||
PLAINTEXT_1 = 0x11
|
||||
COMMON_CID = 4660
|
||||
CONT = 0x80
|
||||
|
||||
HEADER_INIT_LENGTH = 5
|
||||
HEADER_CONT_LENGTH = 3
|
||||
INIT_MESSAGE_DATA_LENGTH = (
|
||||
thp_v1._REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
|
||||
)
|
||||
|
||||
|
||||
def make_header(ctrl_byte, cid, length):
|
||||
return ustruct.pack(">BHH", ctrl_byte, cid, length)
|
||||
|
||||
|
||||
def make_cont_header():
|
||||
return ustruct.pack(">BH", CONT, COMMON_CID)
|
||||
|
||||
|
||||
def makeSimpleMessage(header, message_type, message_data):
|
||||
return header + ustruct.pack(">H", message_type) + message_data
|
||||
|
||||
|
||||
def makeCidRequest(header, message_data):
|
||||
return header + message_data
|
||||
|
||||
|
||||
def printBytes(a):
|
||||
print(hexlify(a).decode("utf-8"))
|
||||
|
||||
|
||||
def getPlaintext() -> bytes:
|
||||
if THP.sync_get_receive_expected_bit(THP.get_active_session()) == 1:
|
||||
return PLAINTEXT_1
|
||||
PLAINTEXT_0
|
||||
|
||||
|
||||
def getCid() -> int:
|
||||
return THP.get_cid(THP.get_active_session())
|
||||
|
||||
|
||||
# This test suite is an adaptation of test_trezor.wire.codec_v1
|
||||
class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.interface = MockHID(0xDEADBEEF)
|
||||
if not utils.USE_THP:
|
||||
import storage.cache_thp # noQA:F401
|
||||
|
||||
def test_simple(self):
|
||||
cid_req_header = make_header(
|
||||
ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
|
||||
)
|
||||
cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3C\x6C"
|
||||
cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data)
|
||||
|
||||
message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18)
|
||||
cid_request_dummy_data_checksum = b"\x67\x8E\xAC\xE0"
|
||||
message = makeSimpleMessage(
|
||||
message_header,
|
||||
MESSAGE_TYPE,
|
||||
cid_request_dummy_data + cid_request_dummy_data_checksum,
|
||||
)
|
||||
|
||||
buffer = bytearray(64)
|
||||
|
||||
gen = thp_v1.read_message(self.interface, buffer)
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(cid_req_message)
|
||||
gen.send(message)
|
||||
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, cid_request_dummy_data)
|
||||
|
||||
buffer_without_zeroes = buffer[: len(message) - 5]
|
||||
message_without_header = message[5:]
|
||||
# message should have been read into the buffer
|
||||
self.assertEqual(buffer_without_zeroes, message_without_header)
|
||||
|
||||
def test_read_one_packet(self):
|
||||
# zero length message - just a header
|
||||
PLAINTEXT = getPlaintext()
|
||||
header = make_header(
|
||||
PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH
|
||||
)
|
||||
checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES)
|
||||
message = header + MESSAGE_TYPE_BYTES + checksum
|
||||
|
||||
buffer = bytearray(64)
|
||||
gen = thp_v1.read_message(self.interface, buffer)
|
||||
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(message)
|
||||
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, b"")
|
||||
|
||||
# message should have been read into the buffer
|
||||
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + checksum + b"\x00" * 58)
|
||||
|
||||
def test_read_many_packets(self):
|
||||
message = bytes(range(256))
|
||||
header = make_header(
|
||||
getPlaintext(),
|
||||
COMMON_CID,
|
||||
len(message) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH,
|
||||
)
|
||||
checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message)
|
||||
# message = MESSAGE_TYPE_BYTES + message + checksum
|
||||
|
||||
# first packet is init header + 59 bytes of data
|
||||
# other packets are cont header + 61 bytes of data
|
||||
cont_header = make_cont_header()
|
||||
packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
|
||||
cont_header + chunk
|
||||
for chunk in chunks(
|
||||
message[INIT_MESSAGE_DATA_LENGTH:] + checksum,
|
||||
64 - HEADER_CONT_LENGTH,
|
||||
)
|
||||
]
|
||||
buffer = bytearray(262)
|
||||
gen = thp_v1.read_message(self.interface, buffer)
|
||||
query = gen.send(None)
|
||||
for packet in packets[:-1]:
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
query = gen.send(packet)
|
||||
|
||||
# last packet will stop
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(packets[-1])
|
||||
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, message)
|
||||
|
||||
# message should have been read into the buffer )
|
||||
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + checksum)
|
||||
|
||||
def test_read_large_message(self):
|
||||
message = b"hello world"
|
||||
header = make_header(
|
||||
getPlaintext(),
|
||||
COMMON_CID,
|
||||
_MESSAGE_TYPE_LEN + len(message) + _CHECKSUM_LENGTH,
|
||||
)
|
||||
|
||||
packet = (
|
||||
header
|
||||
+ MESSAGE_TYPE_BYTES
|
||||
+ message
|
||||
+ thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message)
|
||||
)
|
||||
|
||||
# make sure we fit into one packet, to make this easier
|
||||
self.assertTrue(len(packet) <= thp_v1._REPORT_LENGTH)
|
||||
|
||||
buffer = bytearray(1)
|
||||
self.assertTrue(len(buffer) <= len(packet))
|
||||
|
||||
gen = thp_v1.read_message(self.interface, buffer)
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(packet)
|
||||
|
||||
# e.value is StopIteration. e.value.value is the return value of the call
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, message)
|
||||
|
||||
# read should have allocated its own buffer and not touch ours
|
||||
self.assertEqual(buffer, b"\x00")
|
||||
|
||||
def test_write_one_packet(self):
|
||||
message = Message(MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID))
|
||||
gen = thp_v1.write_message(self.interface, message)
|
||||
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||
with self.assertRaises(StopIteration):
|
||||
gen.send(None)
|
||||
|
||||
header = make_header(
|
||||
PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH
|
||||
)
|
||||
expected_message = (
|
||||
header
|
||||
+ MESSAGE_TYPE_BYTES
|
||||
+ thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES)
|
||||
+ b"\x00" * (INIT_MESSAGE_DATA_LENGTH - _CHECKSUM_LENGTH)
|
||||
)
|
||||
self.assertTrue(self.interface.data == [expected_message])
|
||||
|
||||
def test_write_multiple_packets(self):
|
||||
message_payload = bytes(range(256))
|
||||
message = Message(
|
||||
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
||||
)
|
||||
gen = thp_v1.write_message(self.interface, message)
|
||||
|
||||
header = make_header(
|
||||
PLAINTEXT_1,
|
||||
COMMON_CID,
|
||||
len(message.data) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH,
|
||||
)
|
||||
cont_header = make_cont_header()
|
||||
checksum = thp_v1._compute_checksum_bytes(
|
||||
header + message.type.to_bytes(2, "big") + message.data
|
||||
)
|
||||
packets = [
|
||||
header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH]
|
||||
] + [
|
||||
cont_header + chunk
|
||||
for chunk in chunks(
|
||||
message.data[INIT_MESSAGE_DATA_LENGTH:] + checksum,
|
||||
thp_v1._REPORT_LENGTH - HEADER_CONT_LENGTH,
|
||||
)
|
||||
]
|
||||
|
||||
for _ in packets:
|
||||
# we receive as many queries as there are packets
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||
|
||||
# the first sent None only started the generator. the len(packets)-th None
|
||||
# will finish writing and raise StopIteration
|
||||
with self.assertRaises(StopIteration):
|
||||
gen.send(None)
|
||||
|
||||
# packets must be identical up to the last one
|
||||
self.assertListEqual(packets[:-1], self.interface.data[:-1])
|
||||
# last packet must be identical up to message length. remaining bytes in
|
||||
# the 64-byte packets are garbage -- in particular, it's the bytes of the
|
||||
# previous packet
|
||||
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
|
||||
self.assertEqual(last_packet, self.interface.data[-1])
|
||||
|
||||
def test_roundtrip(self):
|
||||
message_payload = bytes(range(256))
|
||||
message = Message(
|
||||
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
||||
)
|
||||
gen = thp_v1.write_message(self.interface, message)
|
||||
|
||||
# exhaust the iterator:
|
||||
# (XXX we can only do this because the iterator is only accepting None and returns None)
|
||||
for query in gen:
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||
|
||||
buffer = bytearray(1024)
|
||||
gen = thp_v1.read_message(self.interface, buffer)
|
||||
query = gen.send(None)
|
||||
for packet in self.interface.data[:-1]:
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
query = gen.send(packet)
|
||||
|
||||
with self.assertRaises(StopIteration) as e:
|
||||
gen.send(self.interface.data[-1])
|
||||
|
||||
result = e.value.value
|
||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||
self.assertEqual(result.data, message.data)
|
||||
|
||||
def test_read_huge_packet(self):
|
||||
PACKET_COUNT = 1180
|
||||
# message that takes up 1 180 USB packets
|
||||
message_size = (PACKET_COUNT - 1) * (
|
||||
thp_v1._REPORT_LENGTH
|
||||
- HEADER_CONT_LENGTH
|
||||
- _CHECKSUM_LENGTH
|
||||
- _MESSAGE_TYPE_LEN
|
||||
) + INIT_MESSAGE_DATA_LENGTH
|
||||
|
||||
# ensure that a message this big won't fit into memory
|
||||
# Note: this control is changed, because THP has only 2 byte length field
|
||||
self.assertTrue(message_size > thp_v1._MAX_PAYLOAD_LEN)
|
||||
# self.assertRaises(MemoryError, bytearray, message_size)
|
||||
header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
|
||||
packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)
|
||||
buffer = bytearray(65536)
|
||||
gen = thp_v1.read_message(self.interface, buffer)
|
||||
|
||||
query = gen.send(None)
|
||||
|
||||
# THP returns "Message too large" error after reading the message size,
|
||||
# it is different from codec_v1 as it does not allow big enough messages
|
||||
# to raise MemoryError in this test
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||
with self.assertRaises(thp_v1.ThpError) as e:
|
||||
query = gen.send(packet)
|
||||
|
||||
self.assertEqual(e.value.args[0], "Message too large")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue