mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-07-01 12:22:34 +00:00
wip
This commit is contained in:
parent
02312699d1
commit
06cc68cc46
@ -268,6 +268,7 @@ async def _handle_qr_code_tag(
|
|||||||
) -> protobuf.MessageType:
|
) -> protobuf.MessageType:
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
assert isinstance(message, ThpQrCodeTag)
|
assert isinstance(message, ThpQrCodeTag)
|
||||||
|
assert ctx.display_data.code_qr_code is not None
|
||||||
expected_tag = sha256(ctx.display_data.code_qr_code).digest()
|
expected_tag = sha256(ctx.display_data.code_qr_code).digest()
|
||||||
if expected_tag != message.tag:
|
if expected_tag != message.tag:
|
||||||
print(
|
print(
|
||||||
|
@ -37,6 +37,7 @@ if __debug__:
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
|
from typing import Awaitable
|
||||||
|
|
||||||
from .pairing_context import PairingContext
|
from .pairing_context import PairingContext
|
||||||
from .session_context import GenericSessionContext
|
from .session_context import GenericSessionContext
|
||||||
@ -113,7 +114,7 @@ class Channel:
|
|||||||
|
|
||||||
# CALLED BY THP_MAIN_LOOP
|
# CALLED BY THP_MAIN_LOOP
|
||||||
|
|
||||||
async def receive_packet(self, packet: utils.BufferType):
|
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
@ -121,7 +122,7 @@ class Channel:
|
|||||||
utils.get_bytes_as_str(self.channel_id),
|
utils.get_bytes_as_str(self.channel_id),
|
||||||
)
|
)
|
||||||
|
|
||||||
await self._handle_received_packet(packet)
|
self._handle_received_packet(packet)
|
||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
@ -133,7 +134,7 @@ class Channel:
|
|||||||
|
|
||||||
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
|
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
|
||||||
self._finish_message()
|
self._finish_message()
|
||||||
await received_message_handler.handle_received_message(self, self.buffer)
|
return received_message_handler.handle_received_message(self, self.buffer)
|
||||||
elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read:
|
elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read:
|
||||||
self.is_cont_packet_expected = True
|
self.is_cont_packet_expected = True
|
||||||
else:
|
else:
|
||||||
@ -141,14 +142,13 @@ class Channel:
|
|||||||
"Read more bytes than is the expected length of the message!"
|
"Read more bytes than is the expected length of the message!"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_received_packet(self, packet: utils.BufferType) -> None:
|
def _handle_received_packet(self, packet: utils.BufferType) -> None:
|
||||||
ctrl_byte = packet[0]
|
ctrl_byte = packet[0]
|
||||||
if control_byte.is_continuation(ctrl_byte):
|
if control_byte.is_continuation(ctrl_byte):
|
||||||
await self._handle_cont_packet(packet)
|
return self._handle_cont_packet(packet)
|
||||||
else:
|
return self._handle_init_packet(packet)
|
||||||
await self._handle_init_packet(packet)
|
|
||||||
|
|
||||||
async def _handle_init_packet(self, packet: utils.BufferType) -> None:
|
def _handle_init_packet(self, packet: utils.BufferType) -> None:
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
@ -175,7 +175,6 @@ class Channel:
|
|||||||
packet_payload,
|
packet_payload,
|
||||||
payload_length,
|
payload_length,
|
||||||
)
|
)
|
||||||
await self._buffer_packet_data(self.buffer, packet, 0)
|
|
||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
@ -190,8 +189,9 @@ class Channel:
|
|||||||
utils.get_bytes_as_str(self.channel_id),
|
utils.get_bytes_as_str(self.channel_id),
|
||||||
len(self.buffer),
|
len(self.buffer),
|
||||||
)
|
)
|
||||||
|
return self._buffer_packet_data(self.buffer, packet, 0)
|
||||||
|
|
||||||
async def _handle_cont_packet(self, packet: utils.BufferType) -> None:
|
def _handle_cont_packet(self, packet: utils.BufferType) -> None:
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
@ -200,7 +200,7 @@ class Channel:
|
|||||||
)
|
)
|
||||||
if not self.is_cont_packet_expected:
|
if not self.is_cont_packet_expected:
|
||||||
raise ThpError("Continuation packet is not expected, ignoring")
|
raise ThpError("Continuation packet is not expected, ignoring")
|
||||||
await self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
|
return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
|
||||||
|
|
||||||
def _decrypt_single_packet_payload(
|
def _decrypt_single_packet_payload(
|
||||||
self, payload: utils.BufferType
|
self, payload: utils.BufferType
|
||||||
@ -297,7 +297,7 @@ class Channel:
|
|||||||
|
|
||||||
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
|
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
|
||||||
|
|
||||||
async def _buffer_packet_data(
|
def _buffer_packet_data(
|
||||||
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
|
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
|
||||||
):
|
):
|
||||||
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
||||||
@ -309,7 +309,7 @@ class Channel:
|
|||||||
|
|
||||||
# CALLED BY WORKFLOW / SESSION CONTEXT
|
# CALLED BY WORKFLOW / SESSION CONTEXT
|
||||||
|
|
||||||
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
||||||
if __debug__ and utils.EMULATOR:
|
if __debug__ and utils.EMULATOR:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
@ -323,15 +323,15 @@ class Channel:
|
|||||||
noise_payload_len = memory_manager.encode_into_buffer(
|
noise_payload_len = memory_manager.encode_into_buffer(
|
||||||
self.buffer, msg, session_id
|
self.buffer, msg, session_id
|
||||||
)
|
)
|
||||||
await self.write_and_encrypt(self.buffer[:noise_payload_len])
|
return self.write_and_encrypt(self.buffer[:noise_payload_len])
|
||||||
|
|
||||||
async def write_error(self, err_type: int):
|
def write_error(self, err_type: int) -> Awaitable[None]:
|
||||||
msg_data = err_type.to_bytes(1, "big")
|
msg_data = err_type.to_bytes(1, "big")
|
||||||
length = len(msg_data) + CHECKSUM_LENGTH
|
length = len(msg_data) + CHECKSUM_LENGTH
|
||||||
header = PacketHeader.get_error_header(self.get_channel_id_int(), length)
|
header = PacketHeader.get_error_header(self.get_channel_id_int(), length)
|
||||||
await write_payload_to_wire_and_add_checksum(self.iface, header, msg_data)
|
return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data)
|
||||||
|
|
||||||
async def write_and_encrypt(self, payload: bytes) -> None:
|
def write_and_encrypt(self, payload: bytes) -> None:
|
||||||
payload_length = len(payload)
|
payload_length = len(payload)
|
||||||
self._encrypt(self.buffer, payload_length)
|
self._encrypt(self.buffer, payload_length)
|
||||||
payload_length = payload_length + TAG_LENGTH
|
payload_length = payload_length + TAG_LENGTH
|
||||||
@ -346,7 +346,7 @@ class Channel:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
|
def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
|
||||||
self._prepare_write()
|
self._prepare_write()
|
||||||
self.write_task_spawn = loop.spawn(
|
self.write_task_spawn = loop.spawn(
|
||||||
self._write_encrypted_payload_loop(ctrl_byte, payload)
|
self._write_encrypted_payload_loop(ctrl_byte, payload)
|
||||||
|
@ -166,7 +166,7 @@ class PairingContext(Context):
|
|||||||
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
||||||
|
|
||||||
async def write(self, msg: protobuf.MessageType) -> None:
|
async def write(self, msg: protobuf.MessageType) -> None:
|
||||||
return await self.channel_ctx.write(msg)
|
return self.channel_ctx.write(msg)
|
||||||
|
|
||||||
async def call(
|
async def call(
|
||||||
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
|
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
|
||||||
|
@ -47,6 +47,8 @@ from .writer import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from typing import Awaitable
|
||||||
|
|
||||||
from trezor.messages import ThpHandshakeCompletionReqNoisePayload
|
from trezor.messages import ThpHandshakeCompletionReqNoisePayload
|
||||||
|
|
||||||
from .channel import Channel
|
from .channel import Channel
|
||||||
@ -68,7 +70,9 @@ async def handle_received_message(
|
|||||||
import micropython
|
import micropython
|
||||||
|
|
||||||
micropython.mem_info()
|
micropython.mem_info()
|
||||||
print("Allocation count:", micropython.alloc_count())
|
print(
|
||||||
|
"Allocation count:", micropython.alloc_count() # type: ignore ["alloc_count" is not a known attribute of module "micropython"]
|
||||||
|
)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
print("To show allocation count, create the build with TREZOR_MEMPERF=1")
|
print("To show allocation count, create the build with TREZOR_MEMPERF=1")
|
||||||
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer)
|
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer)
|
||||||
@ -117,7 +121,7 @@ async def handle_received_message(
|
|||||||
)
|
)
|
||||||
except ThpUnallocatedSessionError as e:
|
except ThpUnallocatedSessionError as e:
|
||||||
error_message = Failure(code=FailureType.ThpUnallocatedSession)
|
error_message = Failure(code=FailureType.ThpUnallocatedSession)
|
||||||
await ctx.write(error_message, e.session_id)
|
ctx.write(error_message, e.session_id)
|
||||||
except ThpDecryptionError:
|
except ThpDecryptionError:
|
||||||
await ctx.write_error(ThpErrorType.DECRYPTION_FAILED)
|
await ctx.write_error(ThpErrorType.DECRYPTION_FAILED)
|
||||||
ctx.clear()
|
ctx.clear()
|
||||||
@ -128,7 +132,7 @@ async def handle_received_message(
|
|||||||
log.debug(__name__, "handle_received_message - end")
|
log.debug(__name__, "handle_received_message - end")
|
||||||
|
|
||||||
|
|
||||||
async def _send_ack(ctx: Channel, ack_bit: int) -> None:
|
def _send_ack(ctx: Channel, ack_bit: int) -> Awaitable[None]:
|
||||||
ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
|
ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
|
||||||
header = PacketHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH)
|
header = PacketHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH)
|
||||||
if __debug__:
|
if __debug__:
|
||||||
@ -138,7 +142,7 @@ async def _send_ack(ctx: Channel, ack_bit: int) -> None:
|
|||||||
ctx.get_channel_id_int(),
|
ctx.get_channel_id_int(),
|
||||||
ack_bit,
|
ack_bit,
|
||||||
)
|
)
|
||||||
await write_payload_to_wire_and_add_checksum(ctx.iface, header, b"")
|
return write_payload_to_wire_and_add_checksum(ctx.iface, header, b"")
|
||||||
|
|
||||||
|
|
||||||
def _check_checksum(message_length: int, message_buffer: utils.BufferType):
|
def _check_checksum(message_length: int, message_buffer: utils.BufferType):
|
||||||
@ -175,31 +179,27 @@ async def _handle_ack(ctx: Channel, ack_bit: int):
|
|||||||
# this await might not be executed
|
# this await might not be executed
|
||||||
|
|
||||||
|
|
||||||
async def _handle_message_to_app_or_channel(
|
def _handle_message_to_app_or_channel(
|
||||||
ctx: Channel,
|
ctx: Channel,
|
||||||
payload_length: int,
|
payload_length: int,
|
||||||
message_length: int,
|
message_length: int,
|
||||||
ctrl_byte: int,
|
ctrl_byte: int,
|
||||||
) -> None:
|
) -> Awaitable[None]:
|
||||||
state = ctx.get_channel_state()
|
state = ctx.get_channel_state()
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(__name__, "state: %s", state_to_str(state))
|
log.debug(__name__, "state: %s", state_to_str(state))
|
||||||
|
|
||||||
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||||
await _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
|
return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
|
||||||
return
|
|
||||||
|
|
||||||
if state is ChannelState.TH1:
|
if state is ChannelState.TH1:
|
||||||
await _handle_state_TH1(ctx, payload_length, message_length, ctrl_byte)
|
return _handle_state_TH1(ctx, payload_length, message_length, ctrl_byte)
|
||||||
return
|
|
||||||
|
|
||||||
if state is ChannelState.TH2:
|
if state is ChannelState.TH2:
|
||||||
await _handle_state_TH2(ctx, message_length, ctrl_byte)
|
return _handle_state_TH2(ctx, message_length, ctrl_byte)
|
||||||
return
|
|
||||||
|
|
||||||
if is_channel_state_pairing(state):
|
if is_channel_state_pairing(state):
|
||||||
await _handle_pairing(ctx, message_length)
|
return _handle_pairing(ctx, message_length)
|
||||||
return
|
|
||||||
|
|
||||||
raise ThpError("Unimplemented channel state")
|
raise ThpError("Unimplemented channel state")
|
||||||
|
|
||||||
@ -244,7 +244,7 @@ async def _handle_state_TH1(
|
|||||||
payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag
|
payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag
|
||||||
|
|
||||||
# send handshake init response message
|
# send handshake init response message
|
||||||
await ctx.write_handshake_message(HANDSHAKE_INIT_RES, payload)
|
ctx.write_handshake_message(HANDSHAKE_INIT_RES, payload)
|
||||||
ctx.set_channel_state(ChannelState.TH2)
|
ctx.set_channel_state(ChannelState.TH2)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -323,7 +323,7 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -
|
|||||||
if paired:
|
if paired:
|
||||||
trezor_state = thp_messages.TREZOR_STATE_PAIRED
|
trezor_state = thp_messages.TREZOR_STATE_PAIRED
|
||||||
# send hanshake completion response
|
# send hanshake completion response
|
||||||
await ctx.write_handshake_message(
|
ctx.write_handshake_message(
|
||||||
HANDSHAKE_COMP_RES,
|
HANDSHAKE_COMP_RES,
|
||||||
ctx.handshake.get_handshake_completion_response(trezor_state),
|
ctx.handshake.get_handshake_completion_response(trezor_state),
|
||||||
)
|
)
|
||||||
|
@ -148,7 +148,7 @@ class GenericSessionContext(Context):
|
|||||||
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
||||||
|
|
||||||
async def write(self, msg: protobuf.MessageType) -> None:
|
async def write(self, msg: protobuf.MessageType) -> None:
|
||||||
return await self.channel.write(msg, self.session_id)
|
return self.channel.write(msg, self.session_id)
|
||||||
|
|
||||||
def get_session_state(self) -> SessionState: ...
|
def get_session_state(self) -> SessionState: ...
|
||||||
|
|
||||||
|
@ -14,18 +14,18 @@ MESSAGE_TYPE_LENGTH = const(2)
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
from typing import Sequence
|
from typing import Awaitable, Sequence
|
||||||
|
|
||||||
|
|
||||||
async def write_payload_to_wire_and_add_checksum(
|
def write_payload_to_wire_and_add_checksum(
|
||||||
iface: WireInterface, header: PacketHeader, transport_payload: bytes
|
iface: WireInterface, header: PacketHeader, transport_payload: bytes
|
||||||
):
|
) -> Awaitable[None]:
|
||||||
header_checksum: int = crc.crc32(header.to_bytes())
|
header_checksum: int = crc.crc32(header.to_bytes())
|
||||||
checksum: bytes = crc.crc32(transport_payload, header_checksum).to_bytes(
|
checksum: bytes = crc.crc32(transport_payload, header_checksum).to_bytes(
|
||||||
CHECKSUM_LENGTH, "big"
|
CHECKSUM_LENGTH, "big"
|
||||||
)
|
)
|
||||||
data = (transport_payload, checksum)
|
data = (transport_payload, checksum)
|
||||||
await write_payloads_to_wire(iface, header, data)
|
return write_payloads_to_wire(iface, header, data)
|
||||||
|
|
||||||
|
|
||||||
async def write_payloads_to_wire(
|
async def write_payloads_to_wire(
|
||||||
@ -67,7 +67,16 @@ async def write_payloads_to_wire(
|
|||||||
raise Exception("Should not happen!!!")
|
raise Exception("Should not happen!!!")
|
||||||
packet_number += 1
|
packet_number += 1
|
||||||
packet_offset = CONT_HEADER_LENGTH
|
packet_offset = CONT_HEADER_LENGTH
|
||||||
await write_packet_to_wire(iface, packet)
|
|
||||||
|
# write packet to wire (in-lined)
|
||||||
|
if __debug__:
|
||||||
|
log.debug(
|
||||||
|
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
|
||||||
|
)
|
||||||
|
written_by_iface: int = 0
|
||||||
|
while written_by_iface < len(packet):
|
||||||
|
await loop.wait(iface.iface_num() | io.POLL_WRITE)
|
||||||
|
written_by_iface = iface.write(packet)
|
||||||
|
|
||||||
|
|
||||||
async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None:
|
async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None:
|
||||||
@ -77,6 +86,6 @@ async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None:
|
|||||||
log.debug(
|
log.debug(
|
||||||
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
|
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
|
||||||
)
|
)
|
||||||
n = iface.write(packet)
|
n_written = iface.write(packet)
|
||||||
if n == len(packet):
|
if n_written == len(packet):
|
||||||
return
|
return
|
||||||
|
@ -143,7 +143,9 @@ async def _handle_allocated(
|
|||||||
raise ThpError("Channel has different WireInterface")
|
raise ThpError("Channel has different WireInterface")
|
||||||
|
|
||||||
if channel.get_channel_state() != ChannelState.UNALLOCATED:
|
if channel.get_channel_state() != ChannelState.UNALLOCATED:
|
||||||
await channel.receive_packet(packet)
|
x = channel.receive_packet(packet)
|
||||||
|
if x is not None:
|
||||||
|
await x
|
||||||
|
|
||||||
|
|
||||||
async def _handle_unallocated(iface, cid) -> None:
|
async def _handle_unallocated(iface, cid) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user