|
|
|
@ -6,8 +6,12 @@ import usb
|
|
|
|
|
from storage import cache_thp
|
|
|
|
|
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
|
|
|
|
|
from trezor import log, loop, protobuf, utils
|
|
|
|
|
from trezor.enums import FailureType, MessageType
|
|
|
|
|
from trezor.messages import Failure, ThpCreateNewSession
|
|
|
|
|
from trezor.enums import FailureType, MessageType # , ThpPairingMethod
|
|
|
|
|
from trezor.messages import (
|
|
|
|
|
Failure,
|
|
|
|
|
ThpCreateNewSession,
|
|
|
|
|
ThpNewSession,
|
|
|
|
|
)
|
|
|
|
|
from trezor.wire import message_handler
|
|
|
|
|
from trezor.wire.thp import ack_handler, thp_messages
|
|
|
|
|
|
|
|
|
@ -56,9 +60,14 @@ class Channel(Context):
|
|
|
|
|
self.waiting_for_ack_timeout: loop.spawn | None = None
|
|
|
|
|
self.is_cont_packet_expected: bool = False
|
|
|
|
|
self.expected_payload_length: int = 0
|
|
|
|
|
self.bytes_read = 0
|
|
|
|
|
self.bytes_read: int = 0
|
|
|
|
|
self.selected_pairing_methods = (
|
|
|
|
|
[]
|
|
|
|
|
) # TODO better # ThpPairingMethod.PairingMethod_NoMethod
|
|
|
|
|
from trezor.wire.thp.session_context import load_cached_sessions
|
|
|
|
|
|
|
|
|
|
self.connection_context = None
|
|
|
|
|
|
|
|
|
|
self.sessions = load_cached_sessions(self)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
@ -233,14 +242,14 @@ class Channel(Context):
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(__name__, "state: %s", _state_to_str(state))
|
|
|
|
|
|
|
|
|
|
if state is ChannelState.TH1:
|
|
|
|
|
await self._handle_state_TH1(payload_length, message_length, sync_bit)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
|
|
|
|
await self._handle_state_ENCRYPTED_TRANSPORT(message_length)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if state is ChannelState.TH1:
|
|
|
|
|
await self._handle_state_TH1(payload_length, message_length, sync_bit)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if state is ChannelState.TH2:
|
|
|
|
|
await self._handle_state_TH2(message_length, sync_bit)
|
|
|
|
|
return
|
|
|
|
@ -322,12 +331,38 @@ class Channel(Context):
|
|
|
|
|
MessageWithType(
|
|
|
|
|
message_type,
|
|
|
|
|
self.buffer[
|
|
|
|
|
INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH
|
|
|
|
|
INIT_DATA_OFFSET
|
|
|
|
|
+ MESSAGE_TYPE_LENGTH
|
|
|
|
|
+ SESSION_ID_LENGTH : message_length
|
|
|
|
|
- CHECKSUM_LENGTH
|
|
|
|
|
- TAG_LENGTH
|
|
|
|
|
],
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def _handle_pairing(self, message_length: int) -> None:
|
|
|
|
|
from .pairing_context import PairingContext
|
|
|
|
|
|
|
|
|
|
if self.connection_context is None:
|
|
|
|
|
self.connection_context = PairingContext(self)
|
|
|
|
|
|
|
|
|
|
self._decrypt_buffer(message_length)
|
|
|
|
|
|
|
|
|
|
message_type = ustruct.unpack(">H", self.buffer[INIT_DATA_OFFSET:])[0]
|
|
|
|
|
|
|
|
|
|
self.connection_context.incoming_message.publish(
|
|
|
|
|
MessageWithType(
|
|
|
|
|
message_type,
|
|
|
|
|
self.buffer[
|
|
|
|
|
INIT_DATA_OFFSET
|
|
|
|
|
+ MESSAGE_TYPE_LENGTH : message_length
|
|
|
|
|
- CHECKSUM_LENGTH
|
|
|
|
|
- TAG_LENGTH
|
|
|
|
|
],
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
# 1. Check that message is expected with respect to the current state
|
|
|
|
|
# 2. Handle the message
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def _should_be_encrypted(self) -> bool:
|
|
|
|
@ -361,7 +396,9 @@ class Channel(Context):
|
|
|
|
|
else:
|
|
|
|
|
new_session_id: int = self.create_new_session()
|
|
|
|
|
# TODO reuse existing buffer and compute size dynamically
|
|
|
|
|
bufferrone = bytearray(2)
|
|
|
|
|
bufferrone = bytearray(5)
|
|
|
|
|
|
|
|
|
|
msg = ThpNewSession(new_session_id=new_session_id)
|
|
|
|
|
message_size: int = thp_messages.get_new_session_message(
|
|
|
|
|
bufferrone, new_session_id
|
|
|
|
|
)
|
|
|
|
@ -369,7 +406,17 @@ class Channel(Context):
|
|
|
|
|
log.debug(
|
|
|
|
|
__name__, "handle_channel_message - message size: %d", message_size
|
|
|
|
|
)
|
|
|
|
|
await self.write_and_encrypt(bufferrone)
|
|
|
|
|
|
|
|
|
|
_encode_session_into_buffer(memoryview(bufferrone), 0)
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
assert msg.MESSAGE_WIRE_TYPE is not None
|
|
|
|
|
_encode_message_type_into_buffer(
|
|
|
|
|
memoryview(bufferrone), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH
|
|
|
|
|
)
|
|
|
|
|
_encode_message_into_buffer(
|
|
|
|
|
memoryview(bufferrone), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
|
|
|
|
|
)
|
|
|
|
|
await self.write(ThpNewSession(new_session_id=new_session_id))
|
|
|
|
|
# TODO not finished
|
|
|
|
|
|
|
|
|
|
def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray:
|
|
|
|
@ -643,7 +690,7 @@ def is_channel_state_pairing(state: int) -> bool:
|
|
|
|
|
ChannelState.TP2,
|
|
|
|
|
ChannelState.TP3,
|
|
|
|
|
ChannelState.TP4,
|
|
|
|
|
ChannelState.TP5,
|
|
|
|
|
ChannelState.TC1,
|
|
|
|
|
):
|
|
|
|
|
return True
|
|
|
|
|
return False
|
|
|
|
|