|
|
@ -10,7 +10,7 @@ from trezor import loop, protobuf, utils
|
|
|
|
from trezor.messages import ThpCreateNewSession
|
|
|
|
from trezor.messages import ThpCreateNewSession
|
|
|
|
from trezor.wire import message_handler
|
|
|
|
from trezor.wire import message_handler
|
|
|
|
|
|
|
|
|
|
|
|
from ..protocol_common import Context
|
|
|
|
from ..protocol_common import Context, MessageWithType
|
|
|
|
from . import ChannelState, SessionState, checksum
|
|
|
|
from . import ChannelState, SessionState, checksum
|
|
|
|
from . import thp_session as THP
|
|
|
|
from . import thp_session as THP
|
|
|
|
from .checksum import CHECKSUM_LENGTH
|
|
|
|
from .checksum import CHECKSUM_LENGTH
|
|
|
@ -39,7 +39,7 @@ REPORT_LENGTH = const(64)
|
|
|
|
MAX_PAYLOAD_LEN = const(60000)
|
|
|
|
MAX_PAYLOAD_LEN = const(60000)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChannelContext(Context):
|
|
|
|
class Channel(Context):
|
|
|
|
def __init__(self, channel_cache: ChannelCache) -> None:
|
|
|
|
def __init__(self, channel_cache: ChannelCache) -> None:
|
|
|
|
iface = _decode_iface(channel_cache.iface)
|
|
|
|
iface = _decode_iface(channel_cache.iface)
|
|
|
|
super().__init__(iface, channel_cache.channel_id)
|
|
|
|
super().__init__(iface, channel_cache.channel_id)
|
|
|
@ -56,7 +56,7 @@ class ChannelContext(Context):
|
|
|
|
@classmethod
|
|
|
|
@classmethod
|
|
|
|
def create_new_channel(
|
|
|
|
def create_new_channel(
|
|
|
|
cls, iface: WireInterface, buffer: utils.BufferType
|
|
|
|
cls, iface: WireInterface, buffer: utils.BufferType
|
|
|
|
) -> "ChannelContext":
|
|
|
|
) -> "Channel":
|
|
|
|
channel_cache = cache_thp.get_new_unauthenticated_channel(_encode_iface(iface))
|
|
|
|
channel_cache = cache_thp.get_new_unauthenticated_channel(_encode_iface(iface))
|
|
|
|
r = cls(channel_cache)
|
|
|
|
r = cls(channel_cache)
|
|
|
|
r.set_buffer(buffer)
|
|
|
|
r.set_buffer(buffer)
|
|
|
@ -217,9 +217,11 @@ class ChannelContext(Context):
|
|
|
|
if session_state is SessionState.UNALLOCATED:
|
|
|
|
if session_state is SessionState.UNALLOCATED:
|
|
|
|
raise Exception("Unalloacted session")
|
|
|
|
raise Exception("Unalloacted session")
|
|
|
|
|
|
|
|
|
|
|
|
await self.sessions[session_id].receive_message(
|
|
|
|
self.sessions[session_id].incoming_message.publish(
|
|
|
|
message_type,
|
|
|
|
MessageWithType(
|
|
|
|
self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH],
|
|
|
|
message_type,
|
|
|
|
|
|
|
|
self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH],
|
|
|
|
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if state is ChannelState.TH2:
|
|
|
|
if state is ChannelState.TH2:
|
|
|
@ -275,6 +277,7 @@ class ChannelContext(Context):
|
|
|
|
session = SessionContext.create_new_session(self)
|
|
|
|
session = SessionContext.create_new_session(self)
|
|
|
|
print("help")
|
|
|
|
print("help")
|
|
|
|
self.sessions[session.session_id] = session
|
|
|
|
self.sessions[session.session_id] = session
|
|
|
|
|
|
|
|
loop.schedule(session.handle())
|
|
|
|
print("new session created. Session id:", session.session_id)
|
|
|
|
print("new session created. Session id:", session.session_id)
|
|
|
|
|
|
|
|
|
|
|
|
def _todo_clear_buffer(self):
|
|
|
|
def _todo_clear_buffer(self):
|
|
|
@ -300,11 +303,11 @@ class ChannelContext(Context):
|
|
|
|
return THP.sync_get_send_bit(self.channel_cache) != sync_bit
|
|
|
|
return THP.sync_get_send_bit(self.channel_cache) != sync_bit
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_cached_channels(buffer: utils.BufferType) -> dict[int, ChannelContext]: # TODO
|
|
|
|
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO
|
|
|
|
channels: dict[int, ChannelContext] = {}
|
|
|
|
channels: dict[int, Channel] = {}
|
|
|
|
cached_channels = cache_thp.get_all_allocated_channels()
|
|
|
|
cached_channels = cache_thp.get_all_allocated_channels()
|
|
|
|
for c in cached_channels:
|
|
|
|
for c in cached_channels:
|
|
|
|
channels[int.from_bytes(c.channel_id, "big")] = ChannelContext(c)
|
|
|
|
channels[int.from_bytes(c.channel_id, "big")] = Channel(c)
|
|
|
|
for c in channels.values():
|
|
|
|
for c in channels.values():
|
|
|
|
c.set_buffer(buffer)
|
|
|
|
c.set_buffer(buffer)
|
|
|
|
return channels
|
|
|
|
return channels
|