1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

feat(core): expect channel messages with noise tag

This commit is contained in:
M1nd3r 2024-04-03 17:58:38 +02:00
parent 5306a8b55e
commit c0e8342272

View File

@ -44,7 +44,8 @@ MAX_PAYLOAD_LEN = const(60000)
class Channel(Context):
def __init__(self, channel_cache: ChannelCache) -> None:
print("channel.__init__")
if __debug__:
log.debug(__name__, "channel initialization")
iface = _decode_iface(channel_cache.iface)
super().__init__(iface, channel_cache.channel_id)
self.channel_cache = channel_cache
@ -279,7 +280,9 @@ class Channel(Context):
self.sessions[session_id].incoming_message.publish(
MessageWithType(
message_type,
self.buffer[INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH],
self.buffer[
INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH
],
)
)
@ -287,13 +290,16 @@ class Channel(Context):
pass
def _handle_channel_message(self, message_length: int, message_type: int) -> None:
buf = self.buffer[INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH]
buf = self.buffer[
INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH
]
expected_type = protobuf.type_for_wire(message_type)
message = message_handler.wrap_protobuf_load(buf, expected_type)
print("channel._handle_channel_message:", message)
# TODO handle other messages than CreateNewSession
assert isinstance(message, ThpCreateNewSession)
if TYPE_CHECKING:
assert isinstance(message, ThpCreateNewSession)
print("channel._handle_channel_message - passphrase:", message.passphrase)
# await thp_messages.handle_CreateNewSession(message)
if message.passphrase is not None:
@ -376,7 +382,7 @@ class Channel(Context):
# CALLED BY WORKFLOW / SESSION CONTEXT
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
print("channel.write")
print("channel.write:" + msg.MESSAGE_NAME)
noise_payload_len = self._encode_into_buffer(msg, session_id)
await self.write_and_encrypt(self.buffer[:noise_payload_len])
@ -600,8 +606,12 @@ def is_channel_state_pairing(state: int) -> bool:
def _state_to_str(state: int) -> str:
names = {v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__")}
return names.get(state)
name = {
v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__")
}.get(state)
if name is not None:
return name
return "UNKNOWN_STATE"
def printBytes(a):