mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
fix(core): fix debug log, crashing and mock noise tags
This commit is contained in:
parent
8d346120f4
commit
504a3bfe98
2
core/src/all_modules.py
generated
2
core/src/all_modules.py
generated
@ -217,6 +217,8 @@ trezor.wire.thp.channel
|
||||
import trezor.wire.thp.channel
|
||||
trezor.wire.thp.checksum
|
||||
import trezor.wire.thp.checksum
|
||||
trezor.wire.thp.crypto
|
||||
import trezor.wire.thp.crypto
|
||||
trezor.wire.thp.session_context
|
||||
import trezor.wire.thp.session_context
|
||||
trezor.wire.thp.thp_messages
|
||||
|
@ -180,7 +180,7 @@ def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> Non
|
||||
|
||||
|
||||
def get_new_session(channel: ChannelCache):
|
||||
|
||||
print("---------------get new session")
|
||||
new_sid = get_next_session_id(channel)
|
||||
index = _get_next_session_index()
|
||||
|
||||
@ -194,6 +194,8 @@ def get_new_session(channel: ChannelCache):
|
||||
_SESSIONS[index].state[:] = bytearray(
|
||||
_UNALLOCATED_STATE.to_bytes(_SESSION_STATE_LENGTH, "big")
|
||||
)
|
||||
for s in _SESSIONS:
|
||||
print(s)
|
||||
return _SESSIONS[index]
|
||||
|
||||
|
||||
|
@ -12,7 +12,7 @@ from trezor.wire import message_handler
|
||||
from trezor.wire.thp import thp_messages
|
||||
|
||||
from ..protocol_common import Context, MessageWithType
|
||||
from . import ChannelState, SessionState, checksum
|
||||
from . import ChannelState, SessionState, checksum, crypto
|
||||
from . import thp_session as THP
|
||||
from .checksum import CHECKSUM_LENGTH
|
||||
from .thp_messages import (
|
||||
@ -90,7 +90,8 @@ class Channel(Context):
|
||||
await self._handle_cont_packet(packet)
|
||||
else:
|
||||
await self._handle_init_packet(packet)
|
||||
|
||||
print("receive packet", self.expected_payload_length, self.bytes_read)
|
||||
printBytes(self.buffer)
|
||||
if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read:
|
||||
self._finish_message()
|
||||
await self._handle_completed_message()
|
||||
@ -103,7 +104,7 @@ class Channel(Context):
|
||||
# If the channel does not "own" the buffer lock, decrypt first packet
|
||||
# TODO do it only when needed!
|
||||
if _is_ctrl_byte_encrypted_transport(ctrl_byte):
|
||||
packet_payload = self._decrypt(packet_payload)
|
||||
packet_payload = self._decrypt_single_packet_payload(packet_payload)
|
||||
|
||||
state = self.get_channel_state()
|
||||
|
||||
@ -254,7 +255,7 @@ class Channel(Context):
|
||||
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||
|
||||
def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None:
|
||||
self._decrypt_buffer()
|
||||
self._decrypt_buffer(message_length)
|
||||
session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:])
|
||||
if session_id == 0:
|
||||
self._handle_channel_comms(message_length, message_type)
|
||||
@ -293,18 +294,43 @@ class Channel(Context):
|
||||
bufferrone = bytearray(2)
|
||||
message_size: int = thp_messages.get_new_session_message(bufferrone)
|
||||
print(message_size) # TODO adjust
|
||||
loop.schedule(self._write_encrypted_payload_loop(bufferrone))
|
||||
loop.schedule(self.write_and_encrypt(bufferrone))
|
||||
except Exception as e:
|
||||
print("Proč??")
|
||||
print(e)
|
||||
# TODO not finished
|
||||
|
||||
def _decrypt(self, payload) -> bytes:
|
||||
return payload # TODO add decryption process
|
||||
def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray:
|
||||
payload_buffer = bytearray(payload)
|
||||
crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload))
|
||||
return payload_buffer
|
||||
|
||||
def _decrypt_buffer(self) -> None:
|
||||
pass
|
||||
# TODO decode buffer in place
|
||||
def _decrypt_buffer(self, message_length: int) -> None:
|
||||
if not isinstance(self.buffer, bytearray):
|
||||
self.buffer = bytearray(self.buffer)
|
||||
crypto.decrypt(
|
||||
b"\x00",
|
||||
b"\x00",
|
||||
self.buffer,
|
||||
INIT_DATA_OFFSET,
|
||||
message_length - INIT_DATA_OFFSET - CHECKSUM_LENGTH,
|
||||
)
|
||||
|
||||
def _encrypt(self, buffer: bytearray, noise_payload_len: int) -> None:
|
||||
print("\n Encrypting ")
|
||||
min_required_length = noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
|
||||
if len(buffer) < min_required_length or not isinstance(buffer, bytearray):
|
||||
new_buffer = bytearray(min_required_length)
|
||||
utils.memcpy(new_buffer, 0, buffer, 0)
|
||||
buffer = new_buffer
|
||||
tag = crypto.encrypt(
|
||||
b"\x00",
|
||||
b"\x00",
|
||||
buffer,
|
||||
0,
|
||||
noise_payload_len,
|
||||
)
|
||||
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
|
||||
|
||||
async def _buffer_packet_data(
|
||||
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
|
||||
@ -327,7 +353,7 @@ class Channel(Context):
|
||||
log.debug(
|
||||
__name__,
|
||||
"Writing ACK message to a channel with id: %d, sync bit: %d",
|
||||
self.channel_id,
|
||||
int.from_bytes(self.channel_id, "big"),
|
||||
ack_bit,
|
||||
)
|
||||
await self._write_payload_to_wire(header, chksum, CHECKSUM_LENGTH)
|
||||
@ -343,15 +369,18 @@ class Channel(Context):
|
||||
|
||||
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
||||
print("write")
|
||||
|
||||
noise_payload_len = self._encode_into_buffer(msg, session_id)
|
||||
await self.write_and_encrypt(self.buffer[:noise_payload_len])
|
||||
|
||||
# trezor.crypto.noise.encode(key, payload=self.buffer)
|
||||
async def write_and_encrypt(self, payload: bytes) -> None:
|
||||
payload_length = len(payload)
|
||||
|
||||
# TODO payload_len should be output from trezor.crypto.noise.encode, I guess
|
||||
payload_len = noise_payload_len # + TAG_LENGTH # TODO
|
||||
if not isinstance(self.buffer, bytearray):
|
||||
self.buffer = bytearray(self.buffer)
|
||||
self._encrypt(self.buffer, payload_length)
|
||||
payload_length = payload_length + TAG_LENGTH
|
||||
|
||||
loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_len]))
|
||||
loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_length]))
|
||||
|
||||
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
|
||||
print("write loop before while")
|
||||
@ -419,10 +448,13 @@ class Channel(Context):
|
||||
msg_size = protobuf.encoded_length(msg)
|
||||
offset = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
|
||||
payload_size = offset + msg_size
|
||||
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
|
||||
|
||||
if payload_size > len(self.buffer) or not isinstance(self.buffer, bytearray):
|
||||
if required_min_size > len(self.buffer) or not isinstance(
|
||||
self.buffer, bytearray
|
||||
):
|
||||
# message is too big or buffer is not bytearray, we need to allocate a new buffer
|
||||
self.buffer = bytearray(payload_size)
|
||||
self.buffer = bytearray(required_min_size)
|
||||
|
||||
buffer = self.buffer
|
||||
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
|
||||
@ -445,6 +477,7 @@ class Channel(Context):
|
||||
self.sessions[session.session_id] = session
|
||||
loop.schedule(session.handle())
|
||||
print("new session created. Session id:", session.session_id)
|
||||
print(self.sessions)
|
||||
|
||||
def _todo_clear_buffer(self):
|
||||
# TODO Buffer clearing not implemented
|
||||
|
34
core/src/trezor/wire/thp/crypto.py
Normal file
34
core/src/trezor/wire/thp/crypto.py
Normal file
@ -0,0 +1,34 @@
|
||||
DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5"
|
||||
|
||||
|
||||
# TODO implement
|
||||
|
||||
|
||||
def encrypt(
|
||||
key: bytes,
|
||||
nonce: bytes,
|
||||
buffer: bytearray,
|
||||
init_offset: int = 0,
|
||||
payload_length: int = 0,
|
||||
) -> bytes:
|
||||
"""
|
||||
Returns a 16-byte long encryption tag, the encryption itself is performed on the buffer provided.
|
||||
"""
|
||||
return DUMMY_TAG
|
||||
|
||||
|
||||
def decrypt(
|
||||
key: bytes,
|
||||
nonce: bytes,
|
||||
buffer: bytearray,
|
||||
init_offset: int = 0,
|
||||
payload_length: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
Decryption in place.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def is_tag_valid(key: bytes, nonce: bytes, payload: bytes, noise_tag: bytes) -> bool:
|
||||
return True
|
@ -139,10 +139,18 @@ class SessionContext(Context):
|
||||
|
||||
|
||||
def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO
|
||||
print("start loading sessions from cache")
|
||||
sessions: dict[int, SessionContext] = {}
|
||||
cached_sessions = cache_thp.get_all_allocated_sessions()
|
||||
print(
|
||||
"loaded a total of ",
|
||||
len(cached_sessions),
|
||||
"sessions from cache",
|
||||
)
|
||||
for session in cached_sessions:
|
||||
if session.channel_id == channel.channel_id:
|
||||
sid = int.from_bytes(session.session_id, "big")
|
||||
sessions[sid] = SessionContext(channel, session)
|
||||
for i in sessions:
|
||||
print("session", i)
|
||||
return sessions
|
||||
|
@ -44,34 +44,39 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
|
||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||
|
||||
while True:
|
||||
print("main loop")
|
||||
packet = await read
|
||||
ctrl_byte, cid = ustruct.unpack(">BH", packet)
|
||||
try:
|
||||
print("main loop")
|
||||
packet = await read
|
||||
ctrl_byte, cid = ustruct.unpack(">BH", packet)
|
||||
|
||||
if ctrl_byte == CODEC_V1:
|
||||
pass
|
||||
# TODO add handling of (unsupported) codec_v1 packets
|
||||
# possibly ignore continuation packets, i.e. if the
|
||||
# following bytes are not "##"", do not respond
|
||||
if ctrl_byte == CODEC_V1:
|
||||
pass
|
||||
# TODO add handling of (unsupported) codec_v1 packets
|
||||
# possibly ignore continuation packets, i.e. if the
|
||||
# following bytes are not "##"", do not respond
|
||||
|
||||
if cid == BROADCAST_CHANNEL_ID:
|
||||
# TODO handle exceptions, try-catch?
|
||||
await _handle_broadcast(iface, ctrl_byte, packet)
|
||||
continue
|
||||
|
||||
if cid in _CHANNEL_CONTEXTS:
|
||||
channel = _CHANNEL_CONTEXTS[cid]
|
||||
if channel is None:
|
||||
raise ThpError("Invalid state of a channel")
|
||||
if channel.iface is not iface:
|
||||
raise ThpError("Channel has different WireInterface")
|
||||
|
||||
if channel.get_channel_state() != ChannelState.UNALLOCATED:
|
||||
print("packet type in loop:", type(packet))
|
||||
await channel.receive_packet(packet)
|
||||
if cid == BROADCAST_CHANNEL_ID:
|
||||
# TODO handle exceptions, try-catch?
|
||||
await _handle_broadcast(iface, ctrl_byte, packet)
|
||||
continue
|
||||
|
||||
await _handle_unallocated(iface, cid)
|
||||
if cid in _CHANNEL_CONTEXTS:
|
||||
channel = _CHANNEL_CONTEXTS[cid]
|
||||
if channel is None:
|
||||
raise ThpError("Invalid state of a channel")
|
||||
if channel.iface is not iface:
|
||||
raise ThpError("Channel has different WireInterface")
|
||||
|
||||
if channel.get_channel_state() != ChannelState.UNALLOCATED:
|
||||
print("packet type in loop:", type(packet))
|
||||
await channel.receive_packet(packet)
|
||||
continue
|
||||
await _handle_unallocated(iface, cid)
|
||||
|
||||
except ThpError as e:
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
|
||||
# TODO add cleaning sequence if no workflow/channel is active (or some condition like that)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user