Lower the number of style prebuild errors 3

M1nd3r/thp5
M1nd3r 2 months ago committed by M1nd3r
parent c205d80762
commit e7ff87f9bd

@ -177,20 +177,13 @@ def get_features() -> Features:
# handle_Initialize should not be used with THP to start a new session
async def handle_Initialize(
msg: Initialize, message_session_id: bytearray | None = None
) -> Features:
if message_session_id is None and utils.USE_THP:
async def handle_Initialize(msg: Initialize) -> Features:
if utils.USE_THP:
raise ValueError("With THP enabled, a session id must be provided in args")
if utils.USE_THP:
session_id = storage_thp_cache.start_existing_session(msg.session_id)
else:
session_id = storage_cache.start_session(msg.session_id)
session_id = storage_cache.start_session(msg.session_id)
if not utils.BITCOIN_ONLY:
# TODO this block should be changed in THP
derive_cardano = storage_cache.get(storage_cache.APP_COMMON_DERIVE_CARDANO)
have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED)
@ -212,7 +205,7 @@ async def handle_Initialize(
)
features = get_features()
features.session_id = session_id # not important in THP
features.session_id = session_id
return features

@ -56,7 +56,7 @@ class SessionThpCache(DataCache): # TODO implement, this is just copied Session
def clear(self) -> None:
super().clear()
self.state = 0 # Set state to UNALLOCATED
self.state = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED
self.last_usage = 0
self.session_id[:] = b""
@ -175,6 +175,7 @@ def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache:
_session_usage_counter += 1
_SESSIONS[new_auth_session_index].last_usage = _session_usage_counter
return _SESSIONS[new_auth_session_index]
def get_least_recently_used_authetnicated_session_index() -> int:
@ -216,7 +217,7 @@ def start_session(session_id: bytes | None) -> bytes: # TODO incomplete
return new_session_id
def start_existing_session(session_id: bytearray) -> bytes:
def start_existing_session(session_id: bytes) -> bytes:
if session_id is None:
raise ValueError("session_id cannot be None")
if get_active_session_id() == session_id:

@ -47,7 +47,6 @@ if TYPE_CHECKING:
Msg = TypeVar("Msg", bound=protobuf.MessageType)
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
Handler = Callable[[Msg], HandlerTask]
HandlerWithSessionId = Callable[[Msg, bytes | None], HandlerTask]
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
@ -57,7 +56,9 @@ EXPERIMENTAL_ENABLED = False
def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
"""Initialize the wire stack on passed USB interface."""
loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session))
loop.schedule(
handle_session(iface, codec_v1.SESSION_ID.to_bytes(4, "big"), is_debug_session)
)
def wrap_protobuf_load(
@ -145,11 +146,7 @@ async def _handle_single_message(
req_msg = wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
if msg.type is MT.Initialize:
# Special case for handle_initialize to have access to the verified session_id
task = handler(req_msg, ctx.session_id)
else:
task = handler(req_msg)
task = handler(req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
@ -268,9 +265,7 @@ async def handle_session(
log.exception(__name__, exc)
def _find_handler_placeholder(
iface: WireInterface, msg_type: int
) -> Handler | HandlerWithSessionId | None:
def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None:
"""Placeholder handler lookup before a proper one is registered."""
return None

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
from trezor import log, loop, protobuf
from .protocol import WireProtocol
import trezor.wire.protocol
from .protocol_common import Message
if TYPE_CHECKING:
@ -69,11 +69,11 @@ class Context:
) -> None:
self.iface = iface
self.buffer = buffer
self.session_id: session_id
self.session_id = session_id
def read_from_wire(self) -> Awaitable[Message]:
"""Read a whole message from the wire without parsing it."""
return WireProtocol.read_message(self, self.iface, self.buffer)
return protocol.read_message(self, self.iface, self.buffer)
if TYPE_CHECKING:
@ -160,7 +160,7 @@ class Context:
msg_size = protobuf.encode(buffer, msg)
await WireProtocol.write_message(
await protocol.write_message(
self.iface,
Message(
message_type=msg.MESSAGE_WIRE_TYPE,

@ -7,17 +7,15 @@ if TYPE_CHECKING:
from trezorio import WireInterface
class WireProtocol:
async def read_message(
self, iface: WireInterface, buffer: utils.BufferType
) -> Message:
if utils.USE_THP:
return await thp_v1.read_message(iface, buffer)
return await codec_v1.read_message(iface, buffer)
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
if utils.USE_THP:
return await thp_v1.read_message(iface, buffer)
return await codec_v1.read_message(iface, buffer)
async def write_message(self, iface: WireInterface, message: Message) -> None:
if utils.USE_THP:
await thp_v1.write_to_wire(iface, message) # TODO incomplete
return
await codec_v1.write_message(iface, message.type, message.data)
async def write_message(iface: WireInterface, message: Message) -> None:
if utils.USE_THP:
await thp_v1.write_message(iface, message)
return
await codec_v1.write_message(iface, message.type, message.data)
return

@ -105,7 +105,7 @@ def is_active_session(session: SessionThpCache):
def set_session_state(session: SessionThpCache, new_state: SessionState):
session.state = new_state.to_bytes(1, "big")
session.state = bytearray(new_state.to_bytes(1, "big"))
def _get_id(iface: WireInterface, cid: int) -> bytes:

@ -86,9 +86,8 @@ async def read_message_or_init_packet(
report = firstReport
while True:
# Wait for an initial report
if firstReport is None:
if report is None:
report = await _get_loop_wait_read(iface)
if report is None:
raise ThpError("Reading failed unexpectedly, report is None.")
@ -96,7 +95,7 @@ async def read_message_or_init_packet(
ctrl_byte, cid = ustruct.unpack(">BH", report)
if cid == BROADCAST_CHANNEL_ID:
await _handle_broadcast(iface, ctrl_byte, report)
await _handle_broadcast(iface, ctrl_byte, report) # TODO await
report = None
continue
@ -258,7 +257,7 @@ async def write_message(
async def write_to_wire(
iface: WireInterface, header: InitHeader, payload: bytes
) -> None:
write = loop.wait(iface.iface_num() | io.POLL_WRITE)
loop_write = loop.wait(iface.iface_num() | io.POLL_WRITE)
payload_length = len(payload)
@ -268,7 +267,7 @@ async def write_to_wire(
# write initial report
nwritten = utils.memcpy(report, _REPORT_INIT_DATA_OFFSET, payload, 0)
await _write_report(write, iface, report)
await _write_report(loop_write, iface, report)
# if we have more data to write, use continuation reports for it
if nwritten < payload_length:
@ -276,7 +275,7 @@ async def write_to_wire(
while nwritten < payload_length:
nwritten += utils.memcpy(report, _REPORT_CONT_DATA_OFFSET, payload, nwritten)
await _write_report(write, iface, report)
await _write_report(loop_write, iface, report)
async def _write_report(write, iface: WireInterface, report: bytearray) -> None:
@ -287,7 +286,7 @@ async def _write_report(write, iface: WireInterface, report: bytearray) -> None:
return
async def _handle_broadcast(iface: WireIntreface, ctrl_byte, report) -> Message | None:
async def _handle_broadcast(iface: WireInterface, ctrl_byte, report) -> Message | None:
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
length, nonce, checksum = ustruct.unpack(">H8s4s", report[3:])

Loading…
Cancel
Save