mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
refactor(core): extract duplicated write_to_wire
This commit is contained in:
parent
360d3afa23
commit
31ad84133c
2
core/src/all_modules.py
generated
2
core/src/all_modules.py
generated
@ -223,6 +223,8 @@ trezor.wire.thp.thp_messages
|
||||
import trezor.wire.thp.thp_messages
|
||||
trezor.wire.thp.thp_session
|
||||
import trezor.wire.thp.thp_session
|
||||
trezor.wire.thp.writer
|
||||
import trezor.wire.thp.writer
|
||||
trezor.wire.thp_v1
|
||||
import trezor.wire.thp_v1
|
||||
trezor.workflow
|
||||
|
@ -117,6 +117,7 @@ def presize_module(modname: str, size: int) -> None:
|
||||
|
||||
|
||||
if __debug__:
|
||||
from ubinascii import hexlify # pyright: ignore[reportMissingModuleSource]
|
||||
|
||||
def mem_dump(filename: str) -> None:
|
||||
from micropython import mem_info # pyright: ignore[reportMissingModuleSource]
|
||||
@ -133,6 +134,9 @@ if __debug__:
|
||||
else:
|
||||
mem_info(True)
|
||||
|
||||
def get_bytes_as_str(a):
|
||||
return hexlify(a).decode("utf-8")
|
||||
|
||||
|
||||
def ensure(cond: bool, msg: str | None = None) -> None:
|
||||
if not cond:
|
||||
|
@ -1,12 +1,11 @@
|
||||
import ustruct # pyright: ignore[reportMissingModuleSource]
|
||||
from micropython import const # pyright: ignore[reportMissingModuleSource]
|
||||
from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports]
|
||||
from ubinascii import hexlify # pyright: ignore[reportMissingModuleSource]
|
||||
|
||||
import usb
|
||||
from storage import cache_thp
|
||||
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
|
||||
from trezor import io, log, loop, protobuf, utils
|
||||
from trezor import log, loop, protobuf, utils
|
||||
from trezor.enums import FailureType, MessageType
|
||||
from trezor.messages import Failure, ThpCreateNewSession
|
||||
from trezor.wire import message_handler
|
||||
@ -25,6 +24,7 @@ from .thp_messages import (
|
||||
InitHeader,
|
||||
)
|
||||
from .thp_session import ThpError
|
||||
from .writer import write_payload_to_wire
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface # pyright:ignore[reportMissingImports]
|
||||
@ -102,7 +102,7 @@ class Channel(Context):
|
||||
else:
|
||||
await self._handle_init_packet(packet)
|
||||
if __debug__:
|
||||
log.debug(__name__, "self.buffer: %s", get_bytes_as_str(self.buffer))
|
||||
log.debug(__name__, "self.buffer: %s", utils.get_bytes_as_str(self.buffer))
|
||||
if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read:
|
||||
self._finish_message()
|
||||
await self._handle_completed_message()
|
||||
@ -281,8 +281,8 @@ class Channel(Context):
|
||||
log.debug(
|
||||
__name__,
|
||||
"host static pubkey: %s, noise payload: %s",
|
||||
get_bytes_as_str(host_encrypted_static_pubkey),
|
||||
get_bytes_as_str(handshake_completion_request_noise_payload),
|
||||
utils.get_bytes_as_str(host_encrypted_static_pubkey),
|
||||
utils.get_bytes_as_str(handshake_completion_request_noise_payload),
|
||||
)
|
||||
|
||||
# send hanshake completion response
|
||||
@ -416,7 +416,7 @@ class Channel(Context):
|
||||
self.get_channel_id_int(),
|
||||
ack_bit,
|
||||
)
|
||||
await self._write_payload_to_wire(header, chksum)
|
||||
await write_payload_to_wire(self.iface, header, chksum)
|
||||
|
||||
def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit):
|
||||
if sync_bit == 0:
|
||||
@ -446,8 +446,8 @@ class Channel(Context):
|
||||
)
|
||||
|
||||
utils.memcpy(self.buffer, data_length, chksum, 0)
|
||||
await self._write_payload_to_wire(
|
||||
header, memoryview(self.buffer[: data_length + CHECKSUM_LENGTH])
|
||||
await write_payload_to_wire(
|
||||
self.iface, header, memoryview(self.buffer[: data_length + CHECKSUM_LENGTH])
|
||||
)
|
||||
|
||||
async def write_and_encrypt(self, payload: bytes) -> None:
|
||||
@ -482,7 +482,7 @@ class Channel(Context):
|
||||
(header.ctrl_byte & 0x10) >> 4,
|
||||
THP.sync_get_send_bit(self.channel_cache),
|
||||
)
|
||||
await self._write_payload_to_wire(header, payload)
|
||||
await write_payload_to_wire(self.iface, header, payload)
|
||||
self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
|
||||
try:
|
||||
await self.waiting_for_ack_timeout
|
||||
@ -490,40 +490,6 @@ class Channel(Context):
|
||||
THP.sync_set_send_bit_to_opposite(self.channel_cache)
|
||||
break
|
||||
|
||||
async def _write_payload_to_wire(self, header: InitHeader, payload: bytes):
|
||||
if __debug__:
|
||||
log.debug(__name__, "write_payload_to_wire")
|
||||
# prepare the report buffer with header data
|
||||
payload_len = len(payload)
|
||||
report = bytearray(REPORT_LENGTH)
|
||||
header.pack_to_buffer(report)
|
||||
|
||||
# write initial report
|
||||
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
|
||||
|
||||
await self._write_report_to_wire(report)
|
||||
|
||||
# if we have more data to write, use continuation reports for it
|
||||
if nwritten < payload_len:
|
||||
header.pack_to_cont_buffer(report)
|
||||
while nwritten < payload_len:
|
||||
if nwritten >= payload_len - REPORT_LENGTH:
|
||||
report = bytearray(REPORT_LENGTH)
|
||||
header.pack_to_cont_buffer(report)
|
||||
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
|
||||
await self._write_report_to_wire(report)
|
||||
|
||||
async def _write_report_to_wire(self, report: utils.BufferType) -> None:
|
||||
while True:
|
||||
await loop.wait(self.iface.iface_num() | io.POLL_WRITE)
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__, "write_report_to_wire: %s", get_bytes_as_str(report)
|
||||
)
|
||||
n = self.iface.write(report)
|
||||
if n == len(report):
|
||||
return
|
||||
|
||||
async def _wait_for_ack(self) -> None:
|
||||
await loop.sleep(1000)
|
||||
|
||||
@ -726,7 +692,3 @@ def _state_to_str(state: int) -> str:
|
||||
if name is not None:
|
||||
return name
|
||||
return "UNKNOWN_STATE"
|
||||
|
||||
|
||||
def get_bytes_as_str(a):
|
||||
return hexlify(a).decode("utf-8")
|
||||
|
51
core/src/trezor/wire/thp/writer.py
Normal file
51
core/src/trezor/wire/thp/writer.py
Normal file
@ -0,0 +1,51 @@
|
||||
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
||||
|
||||
from trezor import io, log, loop, utils
|
||||
from trezor.wire.thp.channel import CONT_DATA_OFFSET, INIT_DATA_OFFSET, REPORT_LENGTH
|
||||
from trezor.wire.thp.thp_messages import InitHeader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
|
||||
|
||||
|
||||
async def write_payload_to_wire(
|
||||
iface: WireInterface, header: InitHeader, payload: bytes
|
||||
):
|
||||
if __debug__:
|
||||
log.debug(__name__, "write_payload_to_wire")
|
||||
# prepare the report buffer with header data
|
||||
payload_len = len(payload)
|
||||
|
||||
# prepare the report buffer with header data
|
||||
report = bytearray(REPORT_LENGTH)
|
||||
header.pack_to_buffer(report)
|
||||
|
||||
# write initial report
|
||||
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
|
||||
|
||||
await _write_report_to_wire(iface, report)
|
||||
|
||||
# if we have more data to write, use continuation reports for it
|
||||
if nwritten < payload_len:
|
||||
header.pack_to_cont_buffer(report)
|
||||
|
||||
while nwritten < payload_len:
|
||||
if nwritten >= payload_len - REPORT_LENGTH:
|
||||
# Sanitation of last report
|
||||
report = bytearray(REPORT_LENGTH)
|
||||
header.pack_to_cont_buffer(report)
|
||||
|
||||
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
|
||||
await _write_report_to_wire(iface, report)
|
||||
|
||||
|
||||
async def _write_report_to_wire(iface: WireInterface, report: utils.BufferType) -> None:
|
||||
while True:
|
||||
await loop.wait(iface.iface_num() | io.POLL_WRITE)
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__, "write_report_to_wire: %s", utils.get_bytes_as_str(report)
|
||||
)
|
||||
n = iface.write(report)
|
||||
if n == len(report):
|
||||
return
|
@ -7,17 +7,11 @@ from trezor import io, log, loop, utils
|
||||
|
||||
from .protocol_common import MessageWithId
|
||||
from .thp import ChannelState, checksum, thp_messages
|
||||
from .thp.channel import (
|
||||
CONT_DATA_OFFSET,
|
||||
INIT_DATA_OFFSET,
|
||||
MAX_PAYLOAD_LEN,
|
||||
REPORT_LENGTH,
|
||||
Channel,
|
||||
load_cached_channels,
|
||||
)
|
||||
from .thp.channel import MAX_PAYLOAD_LEN, REPORT_LENGTH, Channel, load_cached_channels
|
||||
from .thp.checksum import CHECKSUM_LENGTH
|
||||
from .thp.thp_messages import CHANNEL_ALLOCATION_REQ, CODEC_V1, InitHeader
|
||||
from .thp.thp_session import ThpError
|
||||
from .thp.writer import write_payload_to_wire
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
|
||||
@ -64,8 +58,10 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
|
||||
if cid in _CHANNEL_CONTEXTS:
|
||||
channel = _CHANNEL_CONTEXTS[cid]
|
||||
if channel is None:
|
||||
# TODO send error message to wire
|
||||
raise ThpError("Invalid state of a channel")
|
||||
if channel.iface is not iface:
|
||||
# TODO send error message to wire
|
||||
raise ThpError("Channel has different WireInterface")
|
||||
|
||||
if channel.get_channel_state() != ChannelState.UNALLOCATED:
|
||||
@ -80,63 +76,6 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
|
||||
# TODO add cleaning sequence if no workflow/channel is active (or some condition like that)
|
||||
|
||||
|
||||
def _get_buffer_for_payload(
|
||||
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
|
||||
) -> utils.BufferType:
|
||||
if payload_length > max_length:
|
||||
raise ThpError("Message too large")
|
||||
if payload_length > len(existing_buffer):
|
||||
return _try_allocate_new_buffer(payload_length)
|
||||
return _reuse_existing_buffer(payload_length, existing_buffer)
|
||||
|
||||
|
||||
def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType:
|
||||
try:
|
||||
payload: utils.BufferType = bytearray(payload_length)
|
||||
except MemoryError:
|
||||
payload = bytearray(REPORT_LENGTH)
|
||||
raise ThpError("Message too large")
|
||||
return payload
|
||||
|
||||
|
||||
def _reuse_existing_buffer(
|
||||
payload_length: int, existing_buffer: utils.BufferType
|
||||
) -> utils.BufferType:
|
||||
return memoryview(existing_buffer)[:payload_length]
|
||||
|
||||
|
||||
async def write_to_wire(
|
||||
iface: WireInterface, header: InitHeader, payload: bytes
|
||||
) -> None:
|
||||
loop_write = loop.wait(iface.iface_num() | io.POLL_WRITE)
|
||||
|
||||
payload_length = len(payload)
|
||||
|
||||
# prepare the report buffer with header data
|
||||
report = bytearray(REPORT_LENGTH)
|
||||
header.pack_to_buffer(report)
|
||||
|
||||
# write initial report
|
||||
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
|
||||
await _write_report(loop_write, iface, report)
|
||||
|
||||
# if we have more data to write, use continuation reports for it
|
||||
if nwritten < payload_length:
|
||||
header.pack_to_cont_buffer(report)
|
||||
|
||||
while nwritten < payload_length:
|
||||
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
|
||||
await _write_report(loop_write, iface, report)
|
||||
|
||||
|
||||
async def _write_report(write, iface: WireInterface, report: bytearray) -> None:
|
||||
while True:
|
||||
await write
|
||||
n = iface.write(report)
|
||||
if n == len(report):
|
||||
return
|
||||
|
||||
|
||||
async def _handle_broadcast(
|
||||
iface: WireInterface, ctrl_byte, packet
|
||||
) -> MessageWithId | None:
|
||||
@ -167,14 +106,39 @@ async def _handle_broadcast(
|
||||
if __debug__:
|
||||
log.debug(__name__, "New channel allocated with id %d", cid)
|
||||
|
||||
await write_to_wire(iface, response_header, response_data + chksum)
|
||||
await write_payload_to_wire(iface, response_header, response_data + chksum)
|
||||
|
||||
|
||||
async def _handle_unallocated(iface, cid) -> MessageWithId | None:
|
||||
data = thp_messages.get_error_unallocated_channel()
|
||||
header = InitHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)
|
||||
chksum = checksum.compute(header.to_bytes() + data)
|
||||
await write_to_wire(iface, header, data + chksum)
|
||||
await write_payload_to_wire(iface, header, data + chksum)
|
||||
|
||||
|
||||
def _get_buffer_for_payload(
|
||||
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
|
||||
) -> utils.BufferType:
|
||||
if payload_length > max_length:
|
||||
raise ThpError("Message too large")
|
||||
if payload_length > len(existing_buffer):
|
||||
return _try_allocate_new_buffer(payload_length)
|
||||
return _reuse_existing_buffer(payload_length, existing_buffer)
|
||||
|
||||
|
||||
def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType:
|
||||
try:
|
||||
payload: utils.BufferType = bytearray(payload_length)
|
||||
except MemoryError:
|
||||
payload = bytearray(REPORT_LENGTH)
|
||||
raise ThpError("Message too large")
|
||||
return payload
|
||||
|
||||
|
||||
def _reuse_existing_buffer(
|
||||
payload_length: int, existing_buffer: utils.BufferType
|
||||
) -> utils.BufferType:
|
||||
return memoryview(existing_buffer)[:payload_length]
|
||||
|
||||
|
||||
async def deprecated_read_message(
|
||||
|
Loading…
Reference in New Issue
Block a user