mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
refactor(core): change structure of channels and sessions
This commit is contained in:
parent
f75ee29ffa
commit
42873b1c30
@ -165,7 +165,8 @@ async def handle_session(
|
||||
next_msg = None
|
||||
|
||||
# Set ctx.session_id to the value msg.session_id
|
||||
ctx.channel_id = msg.session_id
|
||||
if msg.session_id is not None:
|
||||
ctx.channel_id = msg.session_id
|
||||
|
||||
try:
|
||||
next_msg = await message_handler.handle_single_message(
|
||||
|
@ -65,7 +65,7 @@ class CodecContext(Context):
|
||||
self,
|
||||
iface: WireInterface,
|
||||
buffer: bytearray,
|
||||
channel_id: bytes | None = None,
|
||||
channel_id: bytes,
|
||||
) -> None:
|
||||
self.iface = iface
|
||||
self.buffer = buffer
|
||||
|
@ -1,5 +1,10 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import protobuf
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(
|
||||
@ -41,8 +46,8 @@ class WireError(Exception):
|
||||
|
||||
|
||||
class Context:
|
||||
def __init__(self, iface, channel_id) -> None:
|
||||
self.iface = iface
|
||||
self.channel_id = channel_id
|
||||
def __init__(self, iface: WireInterface, channel_id: bytes) -> None:
|
||||
self.iface: WireInterface = iface
|
||||
self.channel_id: bytes = channel_id
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None: ...
|
||||
|
@ -8,11 +8,18 @@ else:
|
||||
|
||||
class ChannelState(IntEnum):
|
||||
UNALLOCATED = 0
|
||||
TH1 = 1
|
||||
TH2 = 2
|
||||
TP1 = 3
|
||||
TP2 = 4
|
||||
TP3 = 5
|
||||
TP4 = 6
|
||||
TP5 = 7
|
||||
ENCRYPTED_TRANSPORT = 8
|
||||
UNAUTHENTICATED = 1
|
||||
TH1 = 2
|
||||
TH2 = 3
|
||||
TP1 = 4
|
||||
TP2 = 5
|
||||
TP3 = 6
|
||||
TP4 = 7
|
||||
TP5 = 8
|
||||
ENCRYPTED_TRANSPORT = 9
|
||||
|
||||
|
||||
class WireInterfaceType(IntEnum):
|
||||
MOCK = 0
|
||||
USB = 1
|
||||
BLE = 2
|
||||
|
@ -2,11 +2,14 @@ import ustruct
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_thp import SessionThpCache
|
||||
import usb
|
||||
from storage import cache_thp
|
||||
from storage.cache_thp import ChannelCache
|
||||
from trezor import loop, protobuf, utils
|
||||
|
||||
from ..protocol_common import Context
|
||||
from . import thp_session
|
||||
|
||||
# from . import thp_session
|
||||
from .thp_messages import CONTINUATION_PACKET, ENCRYPTED_TRANSPORT
|
||||
|
||||
# from .thp_session import SessionState, ThpError
|
||||
@ -17,23 +20,30 @@ if TYPE_CHECKING:
|
||||
_INIT_DATA_OFFSET = const(5)
|
||||
_CONT_DATA_OFFSET = const(3)
|
||||
|
||||
_WIRE_INTERFACE_USB = b"\x00"
|
||||
|
||||
|
||||
class ChannelContext(Context):
|
||||
def __init__(
|
||||
self, iface: WireInterface, channel_id: int, session_data: SessionThpCache
|
||||
) -> None:
|
||||
super().__init__(iface, channel_id)
|
||||
self.session_data = session_data
|
||||
def __init__(self, channel_cache: ChannelCache) -> None:
|
||||
iface = _decode_iface(channel_cache.iface)
|
||||
super().__init__(iface, channel_cache.channel_id)
|
||||
self.channel_cache = channel_cache
|
||||
self.buffer: utils.BufferType
|
||||
self.waiting_for_ack_timeout: loop.Task | None
|
||||
self.is_cont_packet_expected: bool = False
|
||||
self.expected_payload_length: int = 0
|
||||
self.bytes_read = 0
|
||||
|
||||
# ACCESS TO SESSION_DATA
|
||||
@classmethod
|
||||
def create_new_channel(cls, iface: WireInterface) -> "ChannelContext":
|
||||
channel_cache = cache_thp.get_new_unauthenticated_channel(_encode_iface(iface))
|
||||
return cls(channel_cache)
|
||||
|
||||
def get_management_session_state(self):
|
||||
return thp_session.get_state(self.session_data)
|
||||
# ACCESS TO CHANNEL_DATA
|
||||
|
||||
def get_management_session_state(self): # TODO redo for channel state
|
||||
# return thp_session.get_state(self.session_data)
|
||||
pass
|
||||
|
||||
# CALLED BY THP_MAIN_LOOP
|
||||
|
||||
@ -96,6 +106,31 @@ class ChannelContext(Context):
|
||||
# create a new session with this passphrase
|
||||
|
||||
|
||||
def load_cached_channels() -> dict[int, ChannelContext]: # TODO
|
||||
channels: dict[int, ChannelContext] = {}
|
||||
cached_channels = cache_thp.get_all_allocated_channels()
|
||||
for c in cached_channels:
|
||||
channels[int.from_bytes(c.channel_id, "big")] = ChannelContext(c)
|
||||
return channels
|
||||
|
||||
|
||||
def _decode_iface(cached_iface: bytes) -> WireInterface:
|
||||
if cached_iface == _WIRE_INTERFACE_USB:
|
||||
iface = usb.iface_wire
|
||||
if iface is None:
|
||||
raise RuntimeError("There is no valid USB WireInterface")
|
||||
return iface
|
||||
# TODO implement bluetooth interface
|
||||
raise Exception("Unknown WireInterface")
|
||||
|
||||
|
||||
def _encode_iface(iface: WireInterface) -> bytes:
|
||||
if iface is usb.iface_wire:
|
||||
return _WIRE_INTERFACE_USB
|
||||
# TODO implement bluetooth interface
|
||||
raise Exception("Unknown WireInterface")
|
||||
|
||||
|
||||
def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & 0x80 == CONTINUATION_PACKET
|
||||
|
||||
|
@ -62,8 +62,8 @@ def get_cid(session: SessionThpCache) -> int:
|
||||
return storage_thp_cache._get_cid(session)
|
||||
|
||||
|
||||
def get_next_channel_id() -> int:
|
||||
return storage_thp_cache.get_next_channel_id()
|
||||
def get_next_channel_id() -> int: # deprecated TODO remove
|
||||
return int.from_bytes(storage_thp_cache.get_next_channel_id(), "big")
|
||||
|
||||
|
||||
def sync_can_send_message(session: SessionThpCache) -> bool:
|
||||
|
@ -8,7 +8,7 @@ from trezor import io, log, loop, utils
|
||||
from .protocol_common import MessageWithId
|
||||
from .thp import ack_handler, checksum, thp_messages
|
||||
from .thp import thp_session as THP
|
||||
from .thp.channel_context import ChannelContext
|
||||
from .thp.channel_context import ChannelContext, load_cached_channels
|
||||
from .thp.checksum import CHECKSUM_LENGTH
|
||||
from .thp.thp_messages import (
|
||||
CONTINUATION_PACKET,
|
||||
@ -35,6 +35,8 @@ _REPORT_CONT_DATA_OFFSET = const(3)
|
||||
_BUFFER: bytearray
|
||||
_BUFFER_LOCK = None
|
||||
|
||||
_CHANNEL_CONTEXTS: dict[int, ChannelContext] = {}
|
||||
|
||||
|
||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
|
||||
msg = await read_message_or_init_packet(iface, buffer)
|
||||
@ -52,9 +54,8 @@ def set_buffer(buffer):
|
||||
|
||||
|
||||
async def thp_main_loop(iface: WireInterface, is_debug_session=False):
|
||||
|
||||
CHANNELS: dict[int, ChannelContext] = {}
|
||||
# TODO load cached channels/sessions
|
||||
global _CHANNEL_CONTEXTS
|
||||
_CHANNEL_CONTEXTS = load_cached_channels()
|
||||
|
||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||
|
||||
@ -63,18 +64,23 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
|
||||
ctrl_byte, cid = ustruct.unpack(">BH", packet)
|
||||
|
||||
if cid == BROADCAST_CHANNEL_ID:
|
||||
# TODO handle exceptions, try-catch?
|
||||
await _handle_broadcast(iface, ctrl_byte, packet)
|
||||
continue
|
||||
|
||||
if cid in CHANNELS:
|
||||
channel = CHANNELS[cid]
|
||||
if cid in _CHANNEL_CONTEXTS:
|
||||
channel = _CHANNEL_CONTEXTS[cid]
|
||||
if channel is None:
|
||||
raise ThpError("Invalid state of a channel")
|
||||
# TODO if the channelContext interface is not None and is different from
|
||||
# the one used in the transmission of the packet, raise an exception
|
||||
# TODO add current wire interface to channelContext if its iface is None
|
||||
if channel.get_management_session_state != SessionState.UNALLOCATED:
|
||||
await channel.receive_packet(packet)
|
||||
continue
|
||||
|
||||
await _handle_unallocated(iface, cid)
|
||||
# TODO add cleaning sequence if no workflow/channel is active (or some condition like that)
|
||||
|
||||
|
||||
async def read_message_or_init_packet(
|
||||
@ -316,26 +322,29 @@ async def _handle_broadcast(
|
||||
) -> MessageWithId | None:
|
||||
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
|
||||
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
|
||||
|
||||
if __debug__:
|
||||
log.debug(__name__, "Received valid message on broadcast channel ")
|
||||
|
||||
length, nonce = ustruct.unpack(">H8s", packet[3:])
|
||||
header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length)
|
||||
|
||||
payload = _get_buffer_for_payload(length, packet[5:], _MAX_CID_REQ_PAYLOAD_LENGTH)
|
||||
|
||||
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
||||
raise ThpError("Checksum is not valid")
|
||||
|
||||
channel_id = _get_new_channel_id()
|
||||
THP.create_new_unauthenticated_session(iface, channel_id)
|
||||
deprecated_channel_id = _get_new_channel_id() # TODO remove
|
||||
THP.create_new_unauthenticated_session(iface, deprecated_channel_id) # TODO remove
|
||||
new_context: ChannelContext = ChannelContext.create_new_channel(iface)
|
||||
cid = int.from_bytes(new_context.channel_id, "big")
|
||||
_CHANNEL_CONTEXTS[cid] = new_context
|
||||
|
||||
response_data = thp_messages.get_channel_allocation_response(nonce, channel_id)
|
||||
response_data = thp_messages.get_channel_allocation_response(nonce, cid)
|
||||
response_header = InitHeader.get_channel_allocation_response_header(
|
||||
len(response_data) + CHECKSUM_LENGTH,
|
||||
)
|
||||
chksum = checksum.compute(response_header.to_bytes() + response_data)
|
||||
if __debug__:
|
||||
log.debug(__name__, "New channel allocated with id %d", channel_id)
|
||||
log.debug(__name__, "New channel allocated with id %d", cid)
|
||||
|
||||
await write_to_wire(iface, response_header, response_data + chksum)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user