1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-04 20:01:18 +00:00

fix(core): update packet length handling

[no changelog]
This commit is contained in:
M1nd3r 2025-01-24 12:22:29 +01:00
parent 2f7426bbe8
commit 9a2b8c3904
4 changed files with 16 additions and 26 deletions

View File

@ -33,7 +33,6 @@ from .writer import (
CONT_HEADER_LENGTH, CONT_HEADER_LENGTH,
INIT_HEADER_LENGTH, INIT_HEADER_LENGTH,
MESSAGE_TYPE_LENGTH, MESSAGE_TYPE_LENGTH,
PACKET_LENGTH,
write_payload_to_wire_and_add_checksum, write_payload_to_wire_and_add_checksum,
) )
@ -65,7 +64,7 @@ class Channel:
self.channel_id: bytes = channel_cache.channel_id self.channel_id: bytes = channel_cache.channel_id
# Shared variables # Shared variables
self.buffer: utils.BufferType = bytearray(PACKET_LENGTH) self.buffer: utils.BufferType = bytearray(self.iface.TX_PACKET_LEN)
self.fallback_decrypt: bool = False self.fallback_decrypt: bool = False
self.bytes_read: int = 0 self.bytes_read: int = 0
self.expected_payload_length: int = 0 self.expected_payload_length: int = 0

View File

@ -23,7 +23,6 @@ from .checksum import CHECKSUM_LENGTH
from .writer import ( from .writer import (
INIT_HEADER_LENGTH, INIT_HEADER_LENGTH,
MAX_PAYLOAD_LEN, MAX_PAYLOAD_LEN,
PACKET_LENGTH,
write_payload_to_wire_and_add_checksum, write_payload_to_wire_and_add_checksum,
) )
@ -39,7 +38,7 @@ async def thp_main_loop(iface: WireInterface) -> None:
_CHANNELS = channel_manager.load_cached_channels() _CHANNELS = channel_manager.load_cached_channels()
read = loop.wait(iface.iface_num() | io.POLL_READ) read = loop.wait(iface.iface_num() | io.POLL_READ)
packet = bytearray(PACKET_LENGTH) packet = bytearray(iface.RX_PACKET_LEN)
while True: while True:
try: try:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
@ -141,19 +140,14 @@ def _get_buffer_for_payload(
if payload_length > max_length: if payload_length > max_length:
raise ThpError("Message too large") raise ThpError("Message too large")
if payload_length > len(existing_buffer): if payload_length > len(existing_buffer):
return _try_allocate_new_buffer(payload_length) try:
new_buffer = bytearray(payload_length)
except MemoryError:
raise ThpError("Message too large")
return new_buffer
return _reuse_existing_buffer(payload_length, existing_buffer) return _reuse_existing_buffer(payload_length, existing_buffer)
def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType:
try:
payload: utils.BufferType = bytearray(payload_length)
except MemoryError:
payload = bytearray(PACKET_LENGTH) # TODO ???
raise ThpError("Message too large")
return payload
def _reuse_existing_buffer( def _reuse_existing_buffer(
payload_length: int, existing_buffer: utils.BufferType payload_length: int, existing_buffer: utils.BufferType
) -> utils.BufferType: ) -> utils.BufferType:

View File

@ -12,8 +12,6 @@ CHECKSUM_LENGTH = const(4)
MAX_PAYLOAD_LEN = const(60000) MAX_PAYLOAD_LEN = const(60000)
MESSAGE_TYPE_LENGTH = const(2) MESSAGE_TYPE_LENGTH = const(2)
PACKET_LENGTH = io.WebUSB.PACKET_LEN
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface from trezorio import WireInterface
from typing import Awaitable, Sequence from typing import Awaitable, Sequence
@ -39,7 +37,7 @@ async def write_payloads_to_wire(
current_data_idx = 0 current_data_idx = 0
current_data_offset = 0 current_data_offset = 0
packet = bytearray(PACKET_LENGTH) packet = bytearray(iface.TX_PACKET_LEN)
header.pack_to_init_buffer(packet) header.pack_to_init_buffer(packet)
packet_offset: int = INIT_HEADER_LENGTH packet_offset: int = INIT_HEADER_LENGTH
packet_number = 0 packet_number = 0
@ -47,8 +45,8 @@ async def write_payloads_to_wire(
while nwritten < total_length: while nwritten < total_length:
if packet_number == 1: if packet_number == 1:
header.pack_to_cont_buffer(packet) header.pack_to_cont_buffer(packet)
if packet_number >= 1 and nwritten >= total_length - PACKET_LENGTH: if packet_number >= 1 and nwritten >= total_length - iface.TX_PACKET_LEN:
packet[:] = bytearray(PACKET_LENGTH) packet[:] = bytearray(iface.TX_PACKET_LEN)
header.pack_to_cont_buffer(packet) header.pack_to_cont_buffer(packet)
while True: while True:
n = utils.memcpy( n = utils.memcpy(
@ -58,12 +56,12 @@ async def write_payloads_to_wire(
current_data_offset += n current_data_offset += n
nwritten += n nwritten += n
if packet_offset < PACKET_LENGTH: if packet_offset < iface.TX_PACKET_LEN:
current_data_idx += 1 current_data_idx += 1
current_data_offset = 0 current_data_offset = 0
if current_data_idx >= n_of_data: if current_data_idx >= n_of_data:
break break
elif packet_offset == PACKET_LENGTH: elif packet_offset == iface.TX_PACKET_LEN:
break break
else: else:
raise Exception("Should not happen!!!") raise Exception("Should not happen!!!")

View File

@ -15,7 +15,6 @@ if utils.USE_THP:
from trezor.wire.thp import alternating_bit_protocol as ABP from trezor.wire.thp import alternating_bit_protocol as ABP
from trezor.wire.thp import checksum, thp_main from trezor.wire.thp import checksum, thp_main
from trezor.wire.thp.checksum import CHECKSUM_LENGTH from trezor.wire.thp.checksum import CHECKSUM_LENGTH
from trezor.wire.thp.writer import PACKET_LENGTH
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface from trezorio import WireInterface
@ -32,7 +31,7 @@ CONT = 0x80
HEADER_INIT_LENGTH = 5 HEADER_INIT_LENGTH = 5
HEADER_CONT_LENGTH = 3 HEADER_CONT_LENGTH = 3
if utils.USE_THP: if utils.USE_THP:
INIT_MESSAGE_DATA_LENGTH = PACKET_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN INIT_MESSAGE_DATA_LENGTH = HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN # + PACKET_LENGTH
def make_header(ctrl_byte, cid, length): def make_header(ctrl_byte, cid, length):
@ -198,7 +197,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
) )
# make sure we fit into one packet, to make this easier # make sure we fit into one packet, to make this easier
self.assertTrue(len(packet) <= thp_main.PACKET_LENGTH) # self.assertTrue(len(packet) <= thp_main.PACKET_LENGTH)
buffer = bytearray(1) buffer = bytearray(1)
self.assertTrue(len(buffer) <= len(packet)) self.assertTrue(len(buffer) <= len(packet))
@ -284,7 +283,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
cont_header + chunk cont_header + chunk
for chunk in chunks( for chunk in chunks(
message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum, message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum,
thp_main.PACKET_LENGTH - HEADER_CONT_LENGTH, HEADER_CONT_LENGTH, # + PACKET_LENGTH
) )
] ]
@ -310,7 +309,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
PACKET_COUNT = 1180 PACKET_COUNT = 1180
# message that takes up 1 180 USB packets # message that takes up 1 180 USB packets
message_size = (PACKET_COUNT - 1) * ( message_size = (PACKET_COUNT - 1) * (
PACKET_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN # + PACKET_LENGTH
) + INIT_MESSAGE_DATA_LENGTH ) + INIT_MESSAGE_DATA_LENGTH
# ensure that a message this big won't fit into memory # ensure that a message this big won't fit into memory