Structural adjustments

M1nd3r/thp5
M1nd3r 2 months ago
parent 13fa46518d
commit 543843e05d

@ -11,50 +11,53 @@ if TYPE_CHECKING:
T = TypeVar("T")
# THP specific constants
_MAX_UNAUTHENTICATED_CHANNELS_COUNT = const(5)
_MAX_CHANNELS_COUNT = 10
_MAX_SESSIONS_COUNT = const(20)
_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5) # TODO remove
_THP_CHANNEL_STATE_LENGTH = const(1)
_THP_SESSION_STATE_LENGTH = const(1)
_CHANNEL_STATE_LENGTH = const(1)
_WIRE_INTERFACE_LENGTH = const(1)
_SESSION_STATE_LENGTH = const(1)
_CHANNEL_ID_LENGTH = const(4)
_SESSION_ID_LENGTH = const(4)
BROADCAST_CHANNEL_ID = const(65535)
_UNALLOCATED_STATE = const(0)
class UnauthenticatedChannelCache(DataCache):
class ConnectionCache(DataCache):
def __init__(self) -> None:
self.channel_id = bytearray(_CHANNEL_ID_LENGTH)
self.fields = ()
self.last_usage = 0
super().__init__()
def clear(self) -> None:
self.channel_id[:] = b""
self.last_usage = 0
super().clear()
class ChannelCache(UnauthenticatedChannelCache):
class ChannelCache(ConnectionCache):
def __init__(self) -> None:
self.enc_key = 0 # TODO change
self.dec_key = 1 # TODO change
self.state = bytearray(_THP_CHANNEL_STATE_LENGTH)
self.last_usage = 0
self.channel_id = bytearray(_CHANNEL_ID_LENGTH)
self.state = bytearray(_CHANNEL_STATE_LENGTH)
self.iface = bytearray(1) # TODO add decoding
super().__init__()
def clear(self) -> None:
self.state = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED
self.last_usage = 0
self.state[:] = bytearray(
int.to_bytes(0, _CHANNEL_STATE_LENGTH, "big")
) # Set state to UNALLOCATED
# TODO clear all sessions that are under this channel
super().clear()
class SessionThpCache(DataCache):
class SessionThpCache(ConnectionCache):
def __init__(self) -> None:
self.channel_id = bytearray(_CHANNEL_ID_LENGTH)
self.session_id = bytearray(_SESSION_ID_LENGTH)
self.state = bytearray(_THP_SESSION_STATE_LENGTH)
self.state = bytearray(_SESSION_STATE_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
@ -78,27 +81,21 @@ class SessionThpCache(DataCache):
super().__init__()
def clear(self) -> None:
self.state = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED
self.last_usage = 0
self.state[:] = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED
self.session_id[:] = b""
self.channel_id[:] = b""
super().clear()
_UNAUTHENTICATED_CHANNELS: list[UnauthenticatedChannelCache] = []
_CHANNELS: list[ChannelCache] = []
_SESSIONS: list[SessionThpCache] = []
_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = [] # TODO remove/replace
def initialize() -> None:
global _UNAUTHENTICATED_CHANNELS
global _CHANNELS
global _SESSIONS
global _UNAUTHENTICATED_SESSIONS
for _ in range(_MAX_UNAUTHENTICATED_CHANNELS_COUNT):
_UNAUTHENTICATED_CHANNELS.append(UnauthenticatedChannelCache())
for _ in range(_MAX_CHANNELS_COUNT):
_CHANNELS.append(ChannelCache())
for _ in range(_MAX_SESSIONS_COUNT):
@ -107,8 +104,6 @@ def initialize() -> None:
for _ in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
_UNAUTHENTICATED_SESSIONS.append(SessionThpCache())
for unauth_channel in _UNAUTHENTICATED_CHANNELS:
unauth_channel.clear()
for channel in _CHANNELS:
channel.clear()
for session in _SESSIONS:
@ -122,16 +117,73 @@ initialize()
# THP vars
_next_unauthenicated_session_index: int = 0
_next_unauthenicated_session_index: int = 0 # TODO remove
# First unauthenticated channel will have index 0
_is_active_session_authenticated: bool
_active_session_idx: int | None = None
_session_usage_counter = 0
_usage_counter = 0
# with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex)
cid_counter: int = 4659 # TODO change to random value on start
def get_new_unauthenticated_channel(iface: bytes) -> ChannelCache:
if len(iface) != _WIRE_INTERFACE_LENGTH:
raise Exception("Invalid WireInterface (encoded) length")
new_cid = get_next_channel_id()
index = _get_next_unauthenticated_channel_index()
_CHANNELS[index] = ChannelCache()
_CHANNELS[index].channel_id[:] = new_cid
_CHANNELS[index].last_usage = _get_usage_counter_and_increment()
_CHANNELS[index].state = bytearray(
_UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big")
)
_CHANNELS[index].iface = bytearray(iface)
return _CHANNELS[index]
def get_all_allocated_channels() -> list[ChannelCache]:
_list: list[ChannelCache] = []
for channel in _CHANNELS:
if _get_channel_state(channel) != _UNALLOCATED_STATE:
_list.append(channel)
return _list
def _get_usage_counter() -> int:
global _usage_counter
return _usage_counter
def _get_usage_counter_and_increment() -> int:
global _usage_counter
_usage_counter += 1
return _usage_counter
def _get_next_unauthenticated_channel_index() -> int:
idx = _get_unallocated_channel_index()
if idx is not None:
return idx
return get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT)
def _get_unallocated_channel_index() -> int | None:
for i in range(_MAX_CHANNELS_COUNT):
if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE:
return i
return None
def _get_channel_state(channel: ChannelCache) -> int:
if channel is None:
return _UNALLOCATED_STATE
return int.from_bytes(channel.state, "big")
def get_active_session_id() -> bytearray | None:
active_session = get_active_session()
@ -148,7 +200,10 @@ def get_active_session() -> SessionThpCache | None:
return _UNAUTHENTICATED_SESSIONS[_active_session_idx]
def get_next_channel_id() -> int:
_session_usage_counter = 0
def get_next_channel_id() -> bytes:
global cid_counter
while True:
cid_counter += 1
@ -156,7 +211,7 @@ def get_next_channel_id() -> int:
cid_counter = 1
if _is_cid_unique():
break
return cid_counter
return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big")
def _is_cid_unique() -> bool:
@ -199,8 +254,6 @@ def get_unauth_session_index(unauth_session: SessionThpCache) -> int | None:
def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache:
global _session_usage_counter
unauth_session_idx = get_unauth_session_index(unauth_session)
if unauth_session_idx is None:
raise InvalidSessionError
@ -211,19 +264,24 @@ def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache:
_SESSIONS[new_auth_session_index] = _UNAUTHENTICATED_SESSIONS[unauth_session_idx]
_UNAUTHENTICATED_SESSIONS[unauth_session_idx].clear()
_session_usage_counter += 1
_SESSIONS[new_auth_session_index].last_usage = _session_usage_counter
_SESSIONS[new_auth_session_index].last_usage = _get_usage_counter_and_increment()
return _SESSIONS[new_auth_session_index]
def get_least_recently_used_authetnicated_session_index() -> int:
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
return lru_session_idx
return get_least_recently_used_item(_SESSIONS, _MAX_SESSIONS_COUNT)
def get_least_recently_used_item(
list: list[ChannelCache] | list[SessionThpCache], max_count: int
):
lru_counter = _get_usage_counter()
lru_item_index = 0
for i in range(max_count):
if list[i].last_usage < lru_counter:
lru_counter = list[i].last_usage
lru_item_index = i
return lru_item_index
# The function start_session should not be used in production code. It is present only to assure compatibility with old tests.
@ -244,7 +302,7 @@ def start_session(session_id: bytes | None) -> bytes: # TODO incomplete
_active_session_idx = index
_is_active_session_authenticated = False
return session_id
new_session_id = b"\x00\x00" + get_next_channel_id().to_bytes(2, "big")
new_session_id = b"\x00\x00" + get_next_channel_id()
new_session = create_new_unauthenticated_session(new_session_id)

@ -165,7 +165,8 @@ async def handle_session(
next_msg = None
# Set ctx.session_id to the value msg.session_id
ctx.channel_id = msg.session_id
if msg.session_id is not None:
ctx.channel_id = msg.session_id
try:
next_msg = await message_handler.handle_single_message(

@ -65,7 +65,7 @@ class CodecContext(Context):
self,
iface: WireInterface,
buffer: bytearray,
channel_id: bytes | None = None,
channel_id: bytes,
) -> None:
self.iface = iface
self.buffer = buffer

@ -1,5 +1,10 @@
from typing import TYPE_CHECKING
from trezor import protobuf
if TYPE_CHECKING:
from trezorio import WireInterface
class Message:
def __init__(
@ -41,8 +46,8 @@ class WireError(Exception):
class Context:
def __init__(self, iface, channel_id) -> None:
self.iface = iface
self.channel_id = channel_id
def __init__(self, iface: WireInterface, channel_id: bytes) -> None:
self.iface: WireInterface = iface
self.channel_id: bytes = channel_id
async def write(self, msg: protobuf.MessageType) -> None: ...

@ -8,11 +8,18 @@ else:
class ChannelState(IntEnum):
UNALLOCATED = 0
TH1 = 1
TH2 = 2
TP1 = 3
TP2 = 4
TP3 = 5
TP4 = 6
TP5 = 7
ENCRYPTED_TRANSPORT = 8
UNAUTHENTICATED = 1
TH1 = 2
TH2 = 3
TP1 = 4
TP2 = 5
TP3 = 6
TP4 = 7
TP5 = 8
ENCRYPTED_TRANSPORT = 9
class WireInterfaceType(IntEnum):
MOCK = 0
USB = 1
BLE = 2

@ -2,11 +2,14 @@ import ustruct
from micropython import const
from typing import TYPE_CHECKING
from storage.cache_thp import SessionThpCache
import usb
from storage import cache_thp
from storage.cache_thp import ChannelCache
from trezor import loop, protobuf, utils
from ..protocol_common import Context
from . import thp_session
# from . import thp_session
from .thp_messages import CONTINUATION_PACKET, ENCRYPTED_TRANSPORT
# from .thp_session import SessionState, ThpError
@ -17,23 +20,30 @@ if TYPE_CHECKING:
_INIT_DATA_OFFSET = const(5)
_CONT_DATA_OFFSET = const(3)
_WIRE_INTERFACE_USB = b"\x00"
class ChannelContext(Context):
def __init__(
self, iface: WireInterface, channel_id: int, session_data: SessionThpCache
) -> None:
super().__init__(iface, channel_id)
self.session_data = session_data
def __init__(self, channel_cache: ChannelCache) -> None:
iface = _decode_iface(channel_cache.iface)
super().__init__(iface, channel_cache.channel_id)
self.channel_cache = channel_cache
self.buffer: utils.BufferType
self.waiting_for_ack_timeout: loop.Task | None
self.is_cont_packet_expected: bool = False
self.expected_payload_length: int = 0
self.bytes_read = 0
# ACCESS TO SESSION_DATA
@classmethod
def create_new_channel(cls, iface: WireInterface) -> "ChannelContext":
channel_cache = cache_thp.get_new_unauthenticated_channel(_encode_iface(iface))
return cls(channel_cache)
def get_management_session_state(self):
return thp_session.get_state(self.session_data)
# ACCESS TO CHANNEL_DATA
def get_management_session_state(self): # TODO redo for channel state
# return thp_session.get_state(self.session_data)
pass
# CALLED BY THP_MAIN_LOOP
@ -96,6 +106,31 @@ class ChannelContext(Context):
# create a new session with this passphrase
def load_cached_channels() -> dict[int, ChannelContext]: # TODO
channels: dict[int, ChannelContext] = {}
cached_channels = cache_thp.get_all_allocated_channels()
for c in cached_channels:
channels[int.from_bytes(c.channel_id, "big")] = ChannelContext(c)
return channels
def _decode_iface(cached_iface: bytes) -> WireInterface:
if cached_iface == _WIRE_INTERFACE_USB:
iface = usb.iface_wire
if iface is None:
raise RuntimeError("There is no valid USB WireInterface")
return iface
# TODO implement bluetooth interface
raise Exception("Unknown WireInterface")
def _encode_iface(iface: WireInterface) -> bytes:
if iface is usb.iface_wire:
return _WIRE_INTERFACE_USB
# TODO implement bluetooth interface
raise Exception("Unknown WireInterface")
def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool:
return ctrl_byte & 0x80 == CONTINUATION_PACKET

@ -62,8 +62,8 @@ def get_cid(session: SessionThpCache) -> int:
return storage_thp_cache._get_cid(session)
def get_next_channel_id() -> int:
return storage_thp_cache.get_next_channel_id()
def get_next_channel_id() -> int: # deprecated TODO remove
return int.from_bytes(storage_thp_cache.get_next_channel_id(), "big")
def sync_can_send_message(session: SessionThpCache) -> bool:

@ -8,7 +8,7 @@ from trezor import io, log, loop, utils
from .protocol_common import MessageWithId
from .thp import ack_handler, checksum, thp_messages
from .thp import thp_session as THP
from .thp.channel_context import ChannelContext
from .thp.channel_context import ChannelContext, load_cached_channels
from .thp.checksum import CHECKSUM_LENGTH
from .thp.thp_messages import (
CONTINUATION_PACKET,
@ -35,6 +35,8 @@ _REPORT_CONT_DATA_OFFSET = const(3)
_BUFFER: bytearray
_BUFFER_LOCK = None
_CHANNEL_CONTEXTS: dict[int, ChannelContext] = {}
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
msg = await read_message_or_init_packet(iface, buffer)
@ -52,9 +54,8 @@ def set_buffer(buffer):
async def thp_main_loop(iface: WireInterface, is_debug_session=False):
CHANNELS: dict[int, ChannelContext] = {}
# TODO load cached channels/sessions
global _CHANNEL_CONTEXTS
_CHANNEL_CONTEXTS = load_cached_channels()
read = loop.wait(iface.iface_num() | io.POLL_READ)
@ -63,18 +64,23 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
ctrl_byte, cid = ustruct.unpack(">BH", packet)
if cid == BROADCAST_CHANNEL_ID:
# TODO handle exceptions, try-catch?
await _handle_broadcast(iface, ctrl_byte, packet)
continue
if cid in CHANNELS:
channel = CHANNELS[cid]
if cid in _CHANNEL_CONTEXTS:
channel = _CHANNEL_CONTEXTS[cid]
if channel is None:
raise ThpError("Invalid state of a channel")
# TODO if the channelContext interface is not None and is different from
# the one used in the transmission of the packet, raise an exception
# TODO add current wire interface to channelContext if its iface is None
if channel.get_management_session_state != SessionState.UNALLOCATED:
await channel.receive_packet(packet)
continue
await _handle_unallocated(iface, cid)
# TODO add cleaning sequence if no workflow/channel is active (or some condition like that)
async def read_message_or_init_packet(
@ -316,26 +322,29 @@ async def _handle_broadcast(
) -> MessageWithId | None:
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
if __debug__:
log.debug(__name__, "Received valid message on broadcast channel ")
length, nonce = ustruct.unpack(">H8s", packet[3:])
header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length)
payload = _get_buffer_for_payload(length, packet[5:], _MAX_CID_REQ_PAYLOAD_LENGTH)
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
raise ThpError("Checksum is not valid")
channel_id = _get_new_channel_id()
THP.create_new_unauthenticated_session(iface, channel_id)
deprecated_channel_id = _get_new_channel_id() # TODO remove
THP.create_new_unauthenticated_session(iface, deprecated_channel_id) # TODO remove
new_context: ChannelContext = ChannelContext.create_new_channel(iface)
cid = int.from_bytes(new_context.channel_id, "big")
_CHANNEL_CONTEXTS[cid] = new_context
response_data = thp_messages.get_channel_allocation_response(nonce, channel_id)
response_data = thp_messages.get_channel_allocation_response(nonce, cid)
response_header = InitHeader.get_channel_allocation_response_header(
len(response_data) + CHECKSUM_LENGTH,
)
chksum = checksum.compute(response_header.to_bytes() + response_data)
if __debug__:
log.debug(__name__, "New channel allocated with id %d", channel_id)
log.debug(__name__, "New channel allocated with id %d", cid)
await write_to_wire(iface, response_header, response_data + chksum)

Loading…
Cancel
Save