1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

fix(core): partial fix for bug in session creation response

note: fails GL checks
This commit is contained in:
M1nd3r 2024-04-08 19:11:47 +02:00
parent 0c156c94a0
commit 30fd1fe5c3

View File

@ -6,8 +6,12 @@ import usb
from storage import cache_thp
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
from trezor import log, loop, protobuf, utils
from trezor.enums import FailureType, MessageType
from trezor.messages import Failure, ThpCreateNewSession
from trezor.enums import FailureType, MessageType # , ThpPairingMethod
from trezor.messages import (
Failure,
ThpCreateNewSession,
ThpNewSession,
)
from trezor.wire import message_handler
from trezor.wire.thp import ack_handler, thp_messages
@ -56,9 +60,14 @@ class Channel(Context):
self.waiting_for_ack_timeout: loop.spawn | None = None
self.is_cont_packet_expected: bool = False
self.expected_payload_length: int = 0
self.bytes_read = 0
self.bytes_read: int = 0
self.selected_pairing_methods = (
[]
) # TODO better # ThpPairingMethod.PairingMethod_NoMethod
from trezor.wire.thp.session_context import load_cached_sessions
self.connection_context = None
self.sessions = load_cached_sessions(self)
@classmethod
@ -233,14 +242,14 @@ class Channel(Context):
if __debug__:
log.debug(__name__, "state: %s", _state_to_str(state))
if state is ChannelState.TH1:
await self._handle_state_TH1(payload_length, message_length, sync_bit)
return
if state is ChannelState.ENCRYPTED_TRANSPORT:
await self._handle_state_ENCRYPTED_TRANSPORT(message_length)
return
if state is ChannelState.TH1:
await self._handle_state_TH1(payload_length, message_length, sync_bit)
return
if state is ChannelState.TH2:
await self._handle_state_TH2(message_length, sync_bit)
return
@ -322,12 +331,38 @@ class Channel(Context):
MessageWithType(
message_type,
self.buffer[
INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH
INIT_DATA_OFFSET
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
async def _handle_pairing(self, message_length: int) -> None:
from .pairing_context import PairingContext
if self.connection_context is None:
self.connection_context = PairingContext(self)
self._decrypt_buffer(message_length)
message_type = ustruct.unpack(">H", self.buffer[INIT_DATA_OFFSET:])[0]
self.connection_context.incoming_message.publish(
MessageWithType(
message_type,
self.buffer[
INIT_DATA_OFFSET
+ MESSAGE_TYPE_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
# 1. Check that message is expected with respect to the current state
# 2. Handle the message
pass
def _should_be_encrypted(self) -> bool:
@ -361,7 +396,9 @@ class Channel(Context):
else:
new_session_id: int = self.create_new_session()
# TODO reuse existing buffer and compute size dynamically
bufferrone = bytearray(2)
bufferrone = bytearray(5)
msg = ThpNewSession(new_session_id=new_session_id)
message_size: int = thp_messages.get_new_session_message(
bufferrone, new_session_id
)
@ -369,7 +406,17 @@ class Channel(Context):
log.debug(
__name__, "handle_channel_message - message size: %d", message_size
)
await self.write_and_encrypt(bufferrone)
_encode_session_into_buffer(memoryview(bufferrone), 0)
if TYPE_CHECKING:
assert msg.MESSAGE_WIRE_TYPE is not None
_encode_message_type_into_buffer(
memoryview(bufferrone), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH
)
_encode_message_into_buffer(
memoryview(bufferrone), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
)
await self.write(ThpNewSession(new_session_id=new_session_id))
# TODO not finished
def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray:
@ -643,7 +690,7 @@ def is_channel_state_pairing(state: int) -> bool:
ChannelState.TP2,
ChannelState.TP3,
ChannelState.TP4,
ChannelState.TP5,
ChannelState.TC1,
):
return True
return False