Extract duplicated write to wire

M1nd3r/thp2
M1nd3r 2 months ago
parent 90f908c542
commit 5c57329ca5

@ -223,6 +223,8 @@ trezor.wire.thp.thp_messages
import trezor.wire.thp.thp_messages import trezor.wire.thp.thp_messages
trezor.wire.thp.thp_session trezor.wire.thp.thp_session
import trezor.wire.thp.thp_session import trezor.wire.thp.thp_session
trezor.wire.thp.writer
import trezor.wire.thp.writer
trezor.wire.thp_v1 trezor.wire.thp_v1
import trezor.wire.thp_v1 import trezor.wire.thp_v1
trezor.workflow trezor.workflow

@ -116,6 +116,7 @@ def presize_module(modname: str, size: int) -> None:
if __debug__: if __debug__:
from ubinascii import hexlify # pyright: ignore[reportMissingModuleSource]
def mem_dump(filename: str) -> None: def mem_dump(filename: str) -> None:
from micropython import mem_info # pyright: ignore[reportMissingModuleSource] from micropython import mem_info # pyright: ignore[reportMissingModuleSource]
@ -132,6 +133,9 @@ if __debug__:
else: else:
mem_info(True) mem_info(True)
def get_bytes_as_str(a):
return hexlify(a).decode("utf-8")
def ensure(cond: bool, msg: str | None = None) -> None: def ensure(cond: bool, msg: str | None = None) -> None:
if not cond: if not cond:

@ -1,12 +1,11 @@
import ustruct # pyright: ignore[reportMissingModuleSource] import ustruct # pyright: ignore[reportMissingModuleSource]
from micropython import const # pyright: ignore[reportMissingModuleSource] from micropython import const # pyright: ignore[reportMissingModuleSource]
from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports] from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports]
from ubinascii import hexlify # pyright: ignore[reportMissingModuleSource]
import usb import usb
from storage import cache_thp from storage import cache_thp
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
from trezor import io, log, loop, protobuf, utils from trezor import log, loop, protobuf, utils
from trezor.enums import FailureType, MessageType from trezor.enums import FailureType, MessageType
from trezor.messages import Failure, ThpCreateNewSession from trezor.messages import Failure, ThpCreateNewSession
from trezor.wire import message_handler from trezor.wire import message_handler
@ -25,6 +24,7 @@ from .thp_messages import (
InitHeader, InitHeader,
) )
from .thp_session import ThpError from .thp_session import ThpError
from .writer import write_payload_to_wire
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface # pyright:ignore[reportMissingImports] from trezorio import WireInterface # pyright:ignore[reportMissingImports]
@ -102,7 +102,7 @@ class Channel(Context):
else: else:
await self._handle_init_packet(packet) await self._handle_init_packet(packet)
if __debug__: if __debug__:
log.debug(__name__, "self.buffer: %s", get_bytes_as_str(self.buffer)) log.debug(__name__, "self.buffer: %s", utils.get_bytes_as_str(self.buffer))
if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read: if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read:
self._finish_message() self._finish_message()
await self._handle_completed_message() await self._handle_completed_message()
@ -281,8 +281,8 @@ class Channel(Context):
log.debug( log.debug(
__name__, __name__,
"host static pubkey: %s, noise payload: %s", "host static pubkey: %s, noise payload: %s",
get_bytes_as_str(host_encrypted_static_pubkey), utils.get_bytes_as_str(host_encrypted_static_pubkey),
get_bytes_as_str(handshake_completion_request_noise_payload), utils.get_bytes_as_str(handshake_completion_request_noise_payload),
) )
# send hanshake completion response # send hanshake completion response
@ -416,7 +416,7 @@ class Channel(Context):
self.get_channel_id_int(), self.get_channel_id_int(),
ack_bit, ack_bit,
) )
await self._write_payload_to_wire(header, chksum) await write_payload_to_wire(self.iface, header, chksum)
def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit): def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit):
if sync_bit == 0: if sync_bit == 0:
@ -446,8 +446,8 @@ class Channel(Context):
) )
utils.memcpy(self.buffer, data_length, chksum, 0) utils.memcpy(self.buffer, data_length, chksum, 0)
await self._write_payload_to_wire( await write_payload_to_wire(
header, memoryview(self.buffer[: data_length + CHECKSUM_LENGTH]) self.iface, header, memoryview(self.buffer[: data_length + CHECKSUM_LENGTH])
) )
async def write_and_encrypt(self, payload: bytes) -> None: async def write_and_encrypt(self, payload: bytes) -> None:
@ -482,7 +482,7 @@ class Channel(Context):
(header.ctrl_byte & 0x10) >> 4, (header.ctrl_byte & 0x10) >> 4,
THP.sync_get_send_bit(self.channel_cache), THP.sync_get_send_bit(self.channel_cache),
) )
await self._write_payload_to_wire(header, payload) await write_payload_to_wire(self.iface, header, payload)
self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack()) self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
try: try:
await self.waiting_for_ack_timeout await self.waiting_for_ack_timeout
@ -490,40 +490,6 @@ class Channel(Context):
THP.sync_set_send_bit_to_opposite(self.channel_cache) THP.sync_set_send_bit_to_opposite(self.channel_cache)
break break
async def _write_payload_to_wire(self, header: InitHeader, payload: bytes):
if __debug__:
log.debug(__name__, "write_payload_to_wire")
# prepare the report buffer with header data
payload_len = len(payload)
report = bytearray(REPORT_LENGTH)
header.pack_to_buffer(report)
# write initial report
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
await self._write_report_to_wire(report)
# if we have more data to write, use continuation reports for it
if nwritten < payload_len:
header.pack_to_cont_buffer(report)
while nwritten < payload_len:
if nwritten >= payload_len - REPORT_LENGTH:
report = bytearray(REPORT_LENGTH)
header.pack_to_cont_buffer(report)
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
await self._write_report_to_wire(report)
async def _write_report_to_wire(self, report: utils.BufferType) -> None:
while True:
await loop.wait(self.iface.iface_num() | io.POLL_WRITE)
if __debug__:
log.debug(
__name__, "write_report_to_wire: %s", get_bytes_as_str(report)
)
n = self.iface.write(report)
if n == len(report):
return
async def _wait_for_ack(self) -> None: async def _wait_for_ack(self) -> None:
await loop.sleep(1000) await loop.sleep(1000)
@ -726,7 +692,3 @@ def _state_to_str(state: int) -> str:
if name is not None: if name is not None:
return name return name
return "UNKNOWN_STATE" return "UNKNOWN_STATE"
def get_bytes_as_str(a):
return hexlify(a).decode("utf-8")

@ -0,0 +1,51 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from trezor import io, log, loop, utils
from trezor.wire.thp.channel import CONT_DATA_OFFSET, INIT_DATA_OFFSET, REPORT_LENGTH
from trezor.wire.thp.thp_messages import InitHeader
if TYPE_CHECKING:
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
async def write_payload_to_wire(
iface: WireInterface, header: InitHeader, payload: bytes
):
if __debug__:
log.debug(__name__, "write_payload_to_wire")
# prepare the report buffer with header data
payload_len = len(payload)
# prepare the report buffer with header data
report = bytearray(REPORT_LENGTH)
header.pack_to_buffer(report)
# write initial report
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
await _write_report_to_wire(iface, report)
# if we have more data to write, use continuation reports for it
if nwritten < payload_len:
header.pack_to_cont_buffer(report)
while nwritten < payload_len:
if nwritten >= payload_len - REPORT_LENGTH:
# Sanitation of last report
report = bytearray(REPORT_LENGTH)
header.pack_to_cont_buffer(report)
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
await _write_report_to_wire(iface, report)
async def _write_report_to_wire(iface: WireInterface, report: utils.BufferType) -> None:
while True:
await loop.wait(iface.iface_num() | io.POLL_WRITE)
if __debug__:
log.debug(
__name__, "write_report_to_wire: %s", utils.get_bytes_as_str(report)
)
n = iface.write(report)
if n == len(report):
return

@ -7,17 +7,11 @@ from trezor import io, log, loop, utils
from .protocol_common import MessageWithId from .protocol_common import MessageWithId
from .thp import ChannelState, checksum, thp_messages from .thp import ChannelState, checksum, thp_messages
from .thp.channel import ( from .thp.channel import MAX_PAYLOAD_LEN, REPORT_LENGTH, Channel, load_cached_channels
CONT_DATA_OFFSET,
INIT_DATA_OFFSET,
MAX_PAYLOAD_LEN,
REPORT_LENGTH,
Channel,
load_cached_channels,
)
from .thp.checksum import CHECKSUM_LENGTH from .thp.checksum import CHECKSUM_LENGTH
from .thp.thp_messages import CHANNEL_ALLOCATION_REQ, CODEC_V1, InitHeader from .thp.thp_messages import CHANNEL_ALLOCATION_REQ, CODEC_V1, InitHeader
from .thp.thp_session import ThpError from .thp.thp_session import ThpError
from .thp.writer import write_payload_to_wire
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface # pyright: ignore[reportMissingImports] from trezorio import WireInterface # pyright: ignore[reportMissingImports]
@ -64,8 +58,10 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
if cid in _CHANNEL_CONTEXTS: if cid in _CHANNEL_CONTEXTS:
channel = _CHANNEL_CONTEXTS[cid] channel = _CHANNEL_CONTEXTS[cid]
if channel is None: if channel is None:
# TODO send error message to wire
raise ThpError("Invalid state of a channel") raise ThpError("Invalid state of a channel")
if channel.iface is not iface: if channel.iface is not iface:
# TODO send error message to wire
raise ThpError("Channel has different WireInterface") raise ThpError("Channel has different WireInterface")
if channel.get_channel_state() != ChannelState.UNALLOCATED: if channel.get_channel_state() != ChannelState.UNALLOCATED:
@ -80,63 +76,6 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
# TODO add cleaning sequence if no workflow/channel is active (or some condition like that) # TODO add cleaning sequence if no workflow/channel is active (or some condition like that)
def _get_buffer_for_payload(
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType:
if payload_length > max_length:
raise ThpError("Message too large")
if payload_length > len(existing_buffer):
return _try_allocate_new_buffer(payload_length)
return _reuse_existing_buffer(payload_length, existing_buffer)
def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType:
try:
payload: utils.BufferType = bytearray(payload_length)
except MemoryError:
payload = bytearray(REPORT_LENGTH)
raise ThpError("Message too large")
return payload
def _reuse_existing_buffer(
payload_length: int, existing_buffer: utils.BufferType
) -> utils.BufferType:
return memoryview(existing_buffer)[:payload_length]
async def write_to_wire(
iface: WireInterface, header: InitHeader, payload: bytes
) -> None:
loop_write = loop.wait(iface.iface_num() | io.POLL_WRITE)
payload_length = len(payload)
# prepare the report buffer with header data
report = bytearray(REPORT_LENGTH)
header.pack_to_buffer(report)
# write initial report
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
await _write_report(loop_write, iface, report)
# if we have more data to write, use continuation reports for it
if nwritten < payload_length:
header.pack_to_cont_buffer(report)
while nwritten < payload_length:
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
await _write_report(loop_write, iface, report)
async def _write_report(write, iface: WireInterface, report: bytearray) -> None:
while True:
await write
n = iface.write(report)
if n == len(report):
return
async def _handle_broadcast( async def _handle_broadcast(
iface: WireInterface, ctrl_byte, packet iface: WireInterface, ctrl_byte, packet
) -> MessageWithId | None: ) -> MessageWithId | None:
@ -167,14 +106,39 @@ async def _handle_broadcast(
if __debug__: if __debug__:
log.debug(__name__, "New channel allocated with id %d", cid) log.debug(__name__, "New channel allocated with id %d", cid)
await write_to_wire(iface, response_header, response_data + chksum) await write_payload_to_wire(iface, response_header, response_data + chksum)
async def _handle_unallocated(iface, cid) -> MessageWithId | None: async def _handle_unallocated(iface, cid) -> MessageWithId | None:
data = thp_messages.get_error_unallocated_channel() data = thp_messages.get_error_unallocated_channel()
header = InitHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH) header = InitHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)
chksum = checksum.compute(header.to_bytes() + data) chksum = checksum.compute(header.to_bytes() + data)
await write_to_wire(iface, header, data + chksum) await write_payload_to_wire(iface, header, data + chksum)
def _get_buffer_for_payload(
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType:
if payload_length > max_length:
raise ThpError("Message too large")
if payload_length > len(existing_buffer):
return _try_allocate_new_buffer(payload_length)
return _reuse_existing_buffer(payload_length, existing_buffer)
def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType:
try:
payload: utils.BufferType = bytearray(payload_length)
except MemoryError:
payload = bytearray(REPORT_LENGTH)
raise ThpError("Message too large")
return payload
def _reuse_existing_buffer(
payload_length: int, existing_buffer: utils.BufferType
) -> utils.BufferType:
return memoryview(existing_buffer)[:payload_length]
async def deprecated_read_message( async def deprecated_read_message(

Loading…
Cancel
Save