Denote to-be-replaced functions as deprecated

M1nd3r/thp5
M1nd3r 1 month ago
parent 86ad10e8bd
commit 4a76216cbf

@ -41,16 +41,6 @@ _BUFFER_LOCK = None
_CHANNEL_CONTEXTS: dict[int, Channel] = {}
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
msg = await read_message_or_init_packet(iface, buffer)
while type(msg) is not MessageWithId:
if isinstance(msg, InterruptingInitPacket):
msg = await read_message_or_init_packet(iface, buffer, msg.initReport)
else:
raise ThpError("Unexpected output of read_message_or_init_packet:")
return msg
def set_buffer(buffer):
global _BUFFER
_BUFFER = buffer
@ -96,14 +86,28 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
# TODO add cleaning sequence if no workflow/channel is active (or some condition like that)
async def read_message_or_init_packet(
async def deprecated_read_message(
iface: WireInterface, buffer: utils.BufferType
) -> MessageWithId:
msg = await deprecated_read_message_or_init_packet(iface, buffer)
while type(msg) is not MessageWithId:
if isinstance(msg, InterruptingInitPacket):
msg = await deprecated_read_message_or_init_packet(
iface, buffer, msg.initReport
)
else:
raise ThpError("Unexpected output of read_message_or_init_packet:")
return msg
async def deprecated_read_message_or_init_packet(
iface: WireInterface, buffer: utils.BufferType, firstReport: bytes | None = None
) -> MessageWithId | InterruptingInitPacket:
report = firstReport
while True:
# Wait for an initial report
if report is None:
report = await _get_loop_wait_read(iface)
report = await loop.wait(iface.iface_num() | io.POLL_READ)
if report is None:
raise ThpError("Reading failed unexpectedly, report is None.")
@ -129,7 +133,9 @@ async def read_message_or_init_packet(
header = InitHeader(ctrl_byte, cid, payload_length)
# buffer the received data
interruptingPacket = await _buffer_received_data(payload, header, iface, report)
interruptingPacket = await _deprecated_buffer_received_data(
payload, header, iface, report
)
if interruptingPacket is not None:
return interruptingPacket
@ -191,10 +197,6 @@ async def read_message_or_init_packet(
return await _handle_allocated(ctrl_byte, session, payload)
def _get_loop_wait_read(iface: WireInterface):
return loop.wait(iface.iface_num() | io.POLL_READ)
def _get_buffer_for_payload(
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType:
@ -213,14 +215,14 @@ def _get_buffer_for_payload(
return memoryview(existing_buffer)[:payload_length]
async def _buffer_received_data(
async def _deprecated_buffer_received_data(
payload: utils.BufferType, header: InitHeader, iface, report
) -> None | InterruptingInitPacket:
# buffer the initial data
nread = utils.memcpy(payload, 0, report, INIT_DATA_OFFSET)
while nread < header.length:
# wait for continuation report
report = await _get_loop_wait_read(iface)
report = await loop.wait(iface.iface_num() | io.POLL_READ)
# channel multiplexing
cont_ctrl_byte, cont_cid = ustruct.unpack(">BH", report)

@ -97,7 +97,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
)
buffer = bytearray(64)
gen = thp_v1.read_message(self.interface, buffer)
gen = thp_v1.deprecated_read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(cid_req_message)
@ -126,7 +126,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
message = header + MESSAGE_TYPE_BYTES + chksum
buffer = bytearray(64)
gen = thp_v1.read_message(self.interface, buffer)
gen = thp_v1.deprecated_read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
@ -163,7 +163,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
)
]
buffer = bytearray(262)
gen = thp_v1.read_message(self.interface, buffer)
gen = thp_v1.deprecated_read_message(self.interface, buffer)
query = gen.send(None)
for packet in packets:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
@ -203,7 +203,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
buffer = bytearray(1)
self.assertTrue(len(buffer) <= len(packet))
gen = thp_v1.read_message(self.interface, buffer)
gen = thp_v1.deprecated_read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(packet)
@ -230,7 +230,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
buffer = bytearray(1024)
gen = thp_v1.read_message(self.interface, buffer)
gen = thp_v1.deprecated_read_message(self.interface, buffer)
query = gen.send(None)
for packet in self.interface.data:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
@ -327,7 +327,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)
buffer = bytearray(65536)
gen = thp_v1.read_message(self.interface, buffer)
gen = thp_v1.deprecated_read_message(self.interface, buffer)
query = gen.send(None)

Loading…
Cancel
Save