mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-22 12:32:02 +00:00
wip
This commit is contained in:
parent
02312699d1
commit
06cc68cc46
@ -268,6 +268,7 @@ async def _handle_qr_code_tag(
|
||||
) -> protobuf.MessageType:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(message, ThpQrCodeTag)
|
||||
assert ctx.display_data.code_qr_code is not None
|
||||
expected_tag = sha256(ctx.display_data.code_qr_code).digest()
|
||||
if expected_tag != message.tag:
|
||||
print(
|
||||
|
@ -37,6 +37,7 @@ if __debug__:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import Awaitable
|
||||
|
||||
from .pairing_context import PairingContext
|
||||
from .session_context import GenericSessionContext
|
||||
@ -113,7 +114,7 @@ class Channel:
|
||||
|
||||
# CALLED BY THP_MAIN_LOOP
|
||||
|
||||
async def receive_packet(self, packet: utils.BufferType):
|
||||
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
@ -121,7 +122,7 @@ class Channel:
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
)
|
||||
|
||||
await self._handle_received_packet(packet)
|
||||
self._handle_received_packet(packet)
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
@ -133,7 +134,7 @@ class Channel:
|
||||
|
||||
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
|
||||
self._finish_message()
|
||||
await received_message_handler.handle_received_message(self, self.buffer)
|
||||
return received_message_handler.handle_received_message(self, self.buffer)
|
||||
elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read:
|
||||
self.is_cont_packet_expected = True
|
||||
else:
|
||||
@ -141,14 +142,13 @@ class Channel:
|
||||
"Read more bytes than is the expected length of the message!"
|
||||
)
|
||||
|
||||
async def _handle_received_packet(self, packet: utils.BufferType) -> None:
|
||||
def _handle_received_packet(self, packet: utils.BufferType) -> None:
|
||||
ctrl_byte = packet[0]
|
||||
if control_byte.is_continuation(ctrl_byte):
|
||||
await self._handle_cont_packet(packet)
|
||||
else:
|
||||
await self._handle_init_packet(packet)
|
||||
return self._handle_cont_packet(packet)
|
||||
return self._handle_init_packet(packet)
|
||||
|
||||
async def _handle_init_packet(self, packet: utils.BufferType) -> None:
|
||||
def _handle_init_packet(self, packet: utils.BufferType) -> None:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
@ -175,7 +175,6 @@ class Channel:
|
||||
packet_payload,
|
||||
payload_length,
|
||||
)
|
||||
await self._buffer_packet_data(self.buffer, packet, 0)
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
@ -190,8 +189,9 @@ class Channel:
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
len(self.buffer),
|
||||
)
|
||||
return self._buffer_packet_data(self.buffer, packet, 0)
|
||||
|
||||
async def _handle_cont_packet(self, packet: utils.BufferType) -> None:
|
||||
def _handle_cont_packet(self, packet: utils.BufferType) -> None:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
@ -200,7 +200,7 @@ class Channel:
|
||||
)
|
||||
if not self.is_cont_packet_expected:
|
||||
raise ThpError("Continuation packet is not expected, ignoring")
|
||||
await self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
|
||||
return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
|
||||
|
||||
def _decrypt_single_packet_payload(
|
||||
self, payload: utils.BufferType
|
||||
@ -297,7 +297,7 @@ class Channel:
|
||||
|
||||
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
|
||||
|
||||
async def _buffer_packet_data(
|
||||
def _buffer_packet_data(
|
||||
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
|
||||
):
|
||||
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
||||
@ -309,7 +309,7 @@ class Channel:
|
||||
|
||||
# CALLED BY WORKFLOW / SESSION CONTEXT
|
||||
|
||||
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
||||
def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
||||
if __debug__ and utils.EMULATOR:
|
||||
log.debug(
|
||||
__name__,
|
||||
@ -323,15 +323,15 @@ class Channel:
|
||||
noise_payload_len = memory_manager.encode_into_buffer(
|
||||
self.buffer, msg, session_id
|
||||
)
|
||||
await self.write_and_encrypt(self.buffer[:noise_payload_len])
|
||||
return self.write_and_encrypt(self.buffer[:noise_payload_len])
|
||||
|
||||
async def write_error(self, err_type: int):
|
||||
def write_error(self, err_type: int) -> Awaitable[None]:
|
||||
msg_data = err_type.to_bytes(1, "big")
|
||||
length = len(msg_data) + CHECKSUM_LENGTH
|
||||
header = PacketHeader.get_error_header(self.get_channel_id_int(), length)
|
||||
await write_payload_to_wire_and_add_checksum(self.iface, header, msg_data)
|
||||
return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data)
|
||||
|
||||
async def write_and_encrypt(self, payload: bytes) -> None:
|
||||
def write_and_encrypt(self, payload: bytes) -> None:
|
||||
payload_length = len(payload)
|
||||
self._encrypt(self.buffer, payload_length)
|
||||
payload_length = payload_length + TAG_LENGTH
|
||||
@ -346,7 +346,7 @@ class Channel:
|
||||
)
|
||||
)
|
||||
|
||||
async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
|
||||
def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
|
||||
self._prepare_write()
|
||||
self.write_task_spawn = loop.spawn(
|
||||
self._write_encrypted_payload_loop(ctrl_byte, payload)
|
||||
|
@ -166,7 +166,7 @@ class PairingContext(Context):
|
||||
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
return await self.channel_ctx.write(msg)
|
||||
return self.channel_ctx.write(msg)
|
||||
|
||||
async def call(
|
||||
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
|
||||
|
@ -47,6 +47,8 @@ from .writer import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Awaitable
|
||||
|
||||
from trezor.messages import ThpHandshakeCompletionReqNoisePayload
|
||||
|
||||
from .channel import Channel
|
||||
@ -68,7 +70,9 @@ async def handle_received_message(
|
||||
import micropython
|
||||
|
||||
micropython.mem_info()
|
||||
print("Allocation count:", micropython.alloc_count())
|
||||
print(
|
||||
"Allocation count:", micropython.alloc_count() # type: ignore ["alloc_count" is not a known attribute of module "micropython"]
|
||||
)
|
||||
except AttributeError:
|
||||
print("To show allocation count, create the build with TREZOR_MEMPERF=1")
|
||||
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer)
|
||||
@ -117,7 +121,7 @@ async def handle_received_message(
|
||||
)
|
||||
except ThpUnallocatedSessionError as e:
|
||||
error_message = Failure(code=FailureType.ThpUnallocatedSession)
|
||||
await ctx.write(error_message, e.session_id)
|
||||
ctx.write(error_message, e.session_id)
|
||||
except ThpDecryptionError:
|
||||
await ctx.write_error(ThpErrorType.DECRYPTION_FAILED)
|
||||
ctx.clear()
|
||||
@ -128,7 +132,7 @@ async def handle_received_message(
|
||||
log.debug(__name__, "handle_received_message - end")
|
||||
|
||||
|
||||
async def _send_ack(ctx: Channel, ack_bit: int) -> None:
|
||||
def _send_ack(ctx: Channel, ack_bit: int) -> Awaitable[None]:
|
||||
ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
|
||||
header = PacketHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH)
|
||||
if __debug__:
|
||||
@ -138,7 +142,7 @@ async def _send_ack(ctx: Channel, ack_bit: int) -> None:
|
||||
ctx.get_channel_id_int(),
|
||||
ack_bit,
|
||||
)
|
||||
await write_payload_to_wire_and_add_checksum(ctx.iface, header, b"")
|
||||
return write_payload_to_wire_and_add_checksum(ctx.iface, header, b"")
|
||||
|
||||
|
||||
def _check_checksum(message_length: int, message_buffer: utils.BufferType):
|
||||
@ -175,31 +179,27 @@ async def _handle_ack(ctx: Channel, ack_bit: int):
|
||||
# this await might not be executed
|
||||
|
||||
|
||||
async def _handle_message_to_app_or_channel(
|
||||
def _handle_message_to_app_or_channel(
|
||||
ctx: Channel,
|
||||
payload_length: int,
|
||||
message_length: int,
|
||||
ctrl_byte: int,
|
||||
) -> None:
|
||||
) -> Awaitable[None]:
|
||||
state = ctx.get_channel_state()
|
||||
if __debug__:
|
||||
log.debug(__name__, "state: %s", state_to_str(state))
|
||||
|
||||
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||
await _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
|
||||
return
|
||||
return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
|
||||
|
||||
if state is ChannelState.TH1:
|
||||
await _handle_state_TH1(ctx, payload_length, message_length, ctrl_byte)
|
||||
return
|
||||
return _handle_state_TH1(ctx, payload_length, message_length, ctrl_byte)
|
||||
|
||||
if state is ChannelState.TH2:
|
||||
await _handle_state_TH2(ctx, message_length, ctrl_byte)
|
||||
return
|
||||
return _handle_state_TH2(ctx, message_length, ctrl_byte)
|
||||
|
||||
if is_channel_state_pairing(state):
|
||||
await _handle_pairing(ctx, message_length)
|
||||
return
|
||||
return _handle_pairing(ctx, message_length)
|
||||
|
||||
raise ThpError("Unimplemented channel state")
|
||||
|
||||
@ -244,7 +244,7 @@ async def _handle_state_TH1(
|
||||
payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag
|
||||
|
||||
# send handshake init response message
|
||||
await ctx.write_handshake_message(HANDSHAKE_INIT_RES, payload)
|
||||
ctx.write_handshake_message(HANDSHAKE_INIT_RES, payload)
|
||||
ctx.set_channel_state(ChannelState.TH2)
|
||||
return
|
||||
|
||||
@ -323,7 +323,7 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -
|
||||
if paired:
|
||||
trezor_state = thp_messages.TREZOR_STATE_PAIRED
|
||||
# send hanshake completion response
|
||||
await ctx.write_handshake_message(
|
||||
ctx.write_handshake_message(
|
||||
HANDSHAKE_COMP_RES,
|
||||
ctx.handshake.get_handshake_completion_response(trezor_state),
|
||||
)
|
||||
|
@ -148,7 +148,7 @@ class GenericSessionContext(Context):
|
||||
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
return await self.channel.write(msg, self.session_id)
|
||||
return self.channel.write(msg, self.session_id)
|
||||
|
||||
def get_session_state(self) -> SessionState: ...
|
||||
|
||||
|
@ -14,18 +14,18 @@ MESSAGE_TYPE_LENGTH = const(2)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import Sequence
|
||||
from typing import Awaitable, Sequence
|
||||
|
||||
|
||||
async def write_payload_to_wire_and_add_checksum(
|
||||
def write_payload_to_wire_and_add_checksum(
|
||||
iface: WireInterface, header: PacketHeader, transport_payload: bytes
|
||||
):
|
||||
) -> Awaitable[None]:
|
||||
header_checksum: int = crc.crc32(header.to_bytes())
|
||||
checksum: bytes = crc.crc32(transport_payload, header_checksum).to_bytes(
|
||||
CHECKSUM_LENGTH, "big"
|
||||
)
|
||||
data = (transport_payload, checksum)
|
||||
await write_payloads_to_wire(iface, header, data)
|
||||
return write_payloads_to_wire(iface, header, data)
|
||||
|
||||
|
||||
async def write_payloads_to_wire(
|
||||
@ -67,7 +67,16 @@ async def write_payloads_to_wire(
|
||||
raise Exception("Should not happen!!!")
|
||||
packet_number += 1
|
||||
packet_offset = CONT_HEADER_LENGTH
|
||||
await write_packet_to_wire(iface, packet)
|
||||
|
||||
# write packet to wire (in-lined)
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
|
||||
)
|
||||
written_by_iface: int = 0
|
||||
while written_by_iface < len(packet):
|
||||
await loop.wait(iface.iface_num() | io.POLL_WRITE)
|
||||
written_by_iface = iface.write(packet)
|
||||
|
||||
|
||||
async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None:
|
||||
@ -77,6 +86,6 @@ async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None:
|
||||
log.debug(
|
||||
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
|
||||
)
|
||||
n = iface.write(packet)
|
||||
if n == len(packet):
|
||||
n_written = iface.write(packet)
|
||||
if n_written == len(packet):
|
||||
return
|
||||
|
@ -143,7 +143,9 @@ async def _handle_allocated(
|
||||
raise ThpError("Channel has different WireInterface")
|
||||
|
||||
if channel.get_channel_state() != ChannelState.UNALLOCATED:
|
||||
await channel.receive_packet(packet)
|
||||
x = channel.receive_packet(packet)
|
||||
if x is not None:
|
||||
await x
|
||||
|
||||
|
||||
async def _handle_unallocated(iface, cid) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user