1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

fix(core): fix debug log, crashing and mock noise tags

This commit is contained in:
M1nd3r 2024-04-03 13:56:26 +02:00
parent 8d346120f4
commit 504a3bfe98
6 changed files with 127 additions and 43 deletions

View File

@ -217,6 +217,8 @@ trezor.wire.thp.channel
import trezor.wire.thp.channel
trezor.wire.thp.checksum
import trezor.wire.thp.checksum
trezor.wire.thp.crypto
import trezor.wire.thp.crypto
trezor.wire.thp.session_context
import trezor.wire.thp.session_context
trezor.wire.thp.thp_messages

View File

@ -180,7 +180,7 @@ def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> Non
def get_new_session(channel: ChannelCache):
print("---------------get new session")
new_sid = get_next_session_id(channel)
index = _get_next_session_index()
@ -194,6 +194,8 @@ def get_new_session(channel: ChannelCache):
_SESSIONS[index].state[:] = bytearray(
_UNALLOCATED_STATE.to_bytes(_SESSION_STATE_LENGTH, "big")
)
for s in _SESSIONS:
print(s)
return _SESSIONS[index]

View File

@ -12,7 +12,7 @@ from trezor.wire import message_handler
from trezor.wire.thp import thp_messages
from ..protocol_common import Context, MessageWithType
from . import ChannelState, SessionState, checksum
from . import ChannelState, SessionState, checksum, crypto
from . import thp_session as THP
from .checksum import CHECKSUM_LENGTH
from .thp_messages import (
@ -90,7 +90,8 @@ class Channel(Context):
await self._handle_cont_packet(packet)
else:
await self._handle_init_packet(packet)
print("receive packet", self.expected_payload_length, self.bytes_read)
printBytes(self.buffer)
if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read:
self._finish_message()
await self._handle_completed_message()
@ -103,7 +104,7 @@ class Channel(Context):
# If the channel does not "own" the buffer lock, decrypt first packet
# TODO do it only when needed!
if _is_ctrl_byte_encrypted_transport(ctrl_byte):
packet_payload = self._decrypt(packet_payload)
packet_payload = self._decrypt_single_packet_payload(packet_payload)
state = self.get_channel_state()
@ -254,7 +255,7 @@ class Channel(Context):
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None:
self._decrypt_buffer()
self._decrypt_buffer(message_length)
session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:])
if session_id == 0:
self._handle_channel_comms(message_length, message_type)
@ -293,18 +294,43 @@ class Channel(Context):
bufferrone = bytearray(2)
message_size: int = thp_messages.get_new_session_message(bufferrone)
print(message_size) # TODO adjust
loop.schedule(self._write_encrypted_payload_loop(bufferrone))
loop.schedule(self.write_and_encrypt(bufferrone))
except Exception as e:
print("Proč??")
print(e)
# TODO not finished
def _decrypt(self, payload) -> bytes:
return payload # TODO add decryption process
def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray:
payload_buffer = bytearray(payload)
crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload))
return payload_buffer
def _decrypt_buffer(self) -> None:
pass
# TODO decode buffer in place
def _decrypt_buffer(self, message_length: int) -> None:
if not isinstance(self.buffer, bytearray):
self.buffer = bytearray(self.buffer)
crypto.decrypt(
b"\x00",
b"\x00",
self.buffer,
INIT_DATA_OFFSET,
message_length - INIT_DATA_OFFSET - CHECKSUM_LENGTH,
)
def _encrypt(self, buffer: bytearray, noise_payload_len: int) -> None:
print("\n Encrypting ")
min_required_length = noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
if len(buffer) < min_required_length or not isinstance(buffer, bytearray):
new_buffer = bytearray(min_required_length)
utils.memcpy(new_buffer, 0, buffer, 0)
buffer = new_buffer
tag = crypto.encrypt(
b"\x00",
b"\x00",
buffer,
0,
noise_payload_len,
)
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
async def _buffer_packet_data(
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
@ -327,7 +353,7 @@ class Channel(Context):
log.debug(
__name__,
"Writing ACK message to a channel with id: %d, sync bit: %d",
self.channel_id,
int.from_bytes(self.channel_id, "big"),
ack_bit,
)
await self._write_payload_to_wire(header, chksum, CHECKSUM_LENGTH)
@ -343,15 +369,18 @@ class Channel(Context):
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
print("write")
noise_payload_len = self._encode_into_buffer(msg, session_id)
await self.write_and_encrypt(self.buffer[:noise_payload_len])
# trezor.crypto.noise.encode(key, payload=self.buffer)
async def write_and_encrypt(self, payload: bytes) -> None:
payload_length = len(payload)
# TODO payload_len should be output from trezor.crypto.noise.encode, I guess
payload_len = noise_payload_len # + TAG_LENGTH # TODO
if not isinstance(self.buffer, bytearray):
self.buffer = bytearray(self.buffer)
self._encrypt(self.buffer, payload_length)
payload_length = payload_length + TAG_LENGTH
loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_len]))
loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_length]))
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
print("write loop before while")
@ -419,10 +448,13 @@ class Channel(Context):
msg_size = protobuf.encoded_length(msg)
offset = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
payload_size = offset + msg_size
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
if payload_size > len(self.buffer) or not isinstance(self.buffer, bytearray):
if required_min_size > len(self.buffer) or not isinstance(
self.buffer, bytearray
):
# message is too big or buffer is not bytearray, we need to allocate a new buffer
self.buffer = bytearray(payload_size)
self.buffer = bytearray(required_min_size)
buffer = self.buffer
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
@ -445,6 +477,7 @@ class Channel(Context):
self.sessions[session.session_id] = session
loop.schedule(session.handle())
print("new session created. Session id:", session.session_id)
print(self.sessions)
def _todo_clear_buffer(self):
# TODO Buffer clearing not implemented

View File

@ -0,0 +1,34 @@
DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5"
# TODO implement
def encrypt(
key: bytes,
nonce: bytes,
buffer: bytearray,
init_offset: int = 0,
payload_length: int = 0,
) -> bytes:
"""
Returns a 16-byte long encryption tag, the encryption itself is performed on the buffer provided.
"""
return DUMMY_TAG
def decrypt(
key: bytes,
nonce: bytes,
buffer: bytearray,
init_offset: int = 0,
payload_length: int = 0,
) -> None:
"""
Decryption in place.
"""
pass
def is_tag_valid(key: bytes, nonce: bytes, payload: bytes, noise_tag: bytes) -> bool:
return True

View File

@ -139,10 +139,18 @@ class SessionContext(Context):
def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO
print("start loading sessions from cache")
sessions: dict[int, SessionContext] = {}
cached_sessions = cache_thp.get_all_allocated_sessions()
print(
"loaded a total of ",
len(cached_sessions),
"sessions from cache",
)
for session in cached_sessions:
if session.channel_id == channel.channel_id:
sid = int.from_bytes(session.session_id, "big")
sessions[sid] = SessionContext(channel, session)
for i in sessions:
print("session", i)
return sessions

View File

@ -44,34 +44,39 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
read = loop.wait(iface.iface_num() | io.POLL_READ)
while True:
print("main loop")
packet = await read
ctrl_byte, cid = ustruct.unpack(">BH", packet)
try:
print("main loop")
packet = await read
ctrl_byte, cid = ustruct.unpack(">BH", packet)
if ctrl_byte == CODEC_V1:
pass
# TODO add handling of (unsupported) codec_v1 packets
# possibly ignore continuation packets, i.e. if the
# following bytes are not "##"", do not respond
if ctrl_byte == CODEC_V1:
pass
# TODO add handling of (unsupported) codec_v1 packets
# possibly ignore continuation packets, i.e. if the
# following bytes are not "##"", do not respond
if cid == BROADCAST_CHANNEL_ID:
# TODO handle exceptions, try-catch?
await _handle_broadcast(iface, ctrl_byte, packet)
continue
if cid in _CHANNEL_CONTEXTS:
channel = _CHANNEL_CONTEXTS[cid]
if channel is None:
raise ThpError("Invalid state of a channel")
if channel.iface is not iface:
raise ThpError("Channel has different WireInterface")
if channel.get_channel_state() != ChannelState.UNALLOCATED:
print("packet type in loop:", type(packet))
await channel.receive_packet(packet)
if cid == BROADCAST_CHANNEL_ID:
# TODO handle exceptions, try-catch?
await _handle_broadcast(iface, ctrl_byte, packet)
continue
await _handle_unallocated(iface, cid)
if cid in _CHANNEL_CONTEXTS:
channel = _CHANNEL_CONTEXTS[cid]
if channel is None:
raise ThpError("Invalid state of a channel")
if channel.iface is not iface:
raise ThpError("Channel has different WireInterface")
if channel.get_channel_state() != ChannelState.UNALLOCATED:
print("packet type in loop:", type(packet))
await channel.receive_packet(packet)
continue
await _handle_unallocated(iface, cid)
except ThpError as e:
if __debug__:
log.exception(__name__, e)
# TODO add cleaning sequence if no workflow/channel is active (or some condition like that)