mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
feat(core): add sending of unallocated session error message, debug log improvements
This commit is contained in:
parent
07c935f989
commit
b262a59d9b
@ -7,8 +7,10 @@ import usb
|
||||
from storage import cache_thp
|
||||
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
|
||||
from trezor import io, log, loop, protobuf, utils
|
||||
from trezor.messages import ThpCreateNewSession
|
||||
from trezor.enums import FailureType, MessageType
|
||||
from trezor.messages import Failure, ThpCreateNewSession
|
||||
from trezor.wire import message_handler
|
||||
from trezor.wire.errors import Error
|
||||
from trezor.wire.thp import thp_messages
|
||||
|
||||
from ..protocol_common import Context, MessageWithType
|
||||
@ -19,6 +21,7 @@ from .thp_messages import (
|
||||
ACK_MESSAGE,
|
||||
CONTINUATION_PACKET,
|
||||
ENCRYPTED_TRANSPORT,
|
||||
ERROR,
|
||||
HANDSHAKE_INIT,
|
||||
InitHeader,
|
||||
)
|
||||
@ -72,27 +75,35 @@ class Channel(Context):
|
||||
|
||||
def get_channel_state(self) -> int:
|
||||
state = int.from_bytes(self.channel_cache.state, "big")
|
||||
print("channel.get_ch_state:", _state_to_str(state))
|
||||
if __debug__:
|
||||
log.debug(__name__, "get_channel_state: %s", _state_to_str(state))
|
||||
return state
|
||||
|
||||
def get_channel_id_int(self) -> int:
|
||||
return int.from_bytes(self.channel_id, "big")
|
||||
|
||||
def set_channel_state(self, state: ChannelState) -> None:
|
||||
print("channel.set_ch_state:", _state_to_str(state))
|
||||
if __debug__:
|
||||
log.debug(__name__, "set_channel_state: %s", _state_to_str(state))
|
||||
self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
|
||||
|
||||
def set_buffer(self, buffer: utils.BufferType) -> None:
|
||||
self.buffer = buffer
|
||||
print("channel.set_buffer:", type(self.buffer))
|
||||
if __debug__:
|
||||
log.debug(__name__, "set_buffer: %s", type(self.buffer))
|
||||
|
||||
# CALLED BY THP_MAIN_LOOP
|
||||
|
||||
async def receive_packet(self, packet: utils.BufferType):
|
||||
print("channel.receive_packet")
|
||||
if __debug__:
|
||||
log.debug(__name__, "receive_packet")
|
||||
ctrl_byte = packet[0]
|
||||
if _is_ctrl_byte_continuation(ctrl_byte):
|
||||
await self._handle_cont_packet(packet)
|
||||
else:
|
||||
await self._handle_init_packet(packet)
|
||||
printBytes(self.buffer)
|
||||
if __debug__:
|
||||
log.debug(__name__, "self.buffer: %s", get_bytes_as_str(self.buffer))
|
||||
if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read:
|
||||
self._finish_message()
|
||||
await self._handle_completed_message()
|
||||
@ -104,7 +115,8 @@ class Channel(Context):
|
||||
)
|
||||
|
||||
async def _handle_init_packet(self, packet: utils.BufferType):
|
||||
print("channel._handle_init_packet")
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_init_packet")
|
||||
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet)
|
||||
self.expected_payload_length = payload_length
|
||||
packet_payload = packet[5:]
|
||||
@ -133,20 +145,27 @@ class Channel(Context):
|
||||
payload_length, self.buffer
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("channel._handle_init_packet - payload len", payload_length)
|
||||
print("channel._handle_init_packet - buffer len", len(self.buffer))
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_init_packet - payload len: %d", payload_length)
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_init_packet - buffer len: %d", len(self.buffer))
|
||||
|
||||
await self._buffer_packet_data(self.buffer, packet, 0)
|
||||
print("channel._handle_init_packet - end")
|
||||
if __debug__:
|
||||
log.debug(__name__, "channel._handle_init_packet - end")
|
||||
|
||||
async def _handle_cont_packet(self, packet: utils.BufferType):
|
||||
print("channel._handle_cont_packet")
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_cont_packet")
|
||||
if not self.is_cont_packet_expected:
|
||||
raise ThpError("Continuation packet is not expected, ignoring")
|
||||
await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET)
|
||||
|
||||
async def _handle_completed_message(self) -> None:
|
||||
print("channel._handle_completed_message")
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_completed_message")
|
||||
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
|
||||
message_length = payload_length + INIT_DATA_OFFSET
|
||||
|
||||
@ -154,7 +173,12 @@ class Channel(Context):
|
||||
|
||||
# Synchronization process
|
||||
sync_bit = (ctrl_byte & 0x10) >> 4
|
||||
print("channel._handle_completed_message - sync bit of message:", sync_bit)
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"handle_completed_message - sync bit of message: %d",
|
||||
sync_bit,
|
||||
)
|
||||
|
||||
# 1: Handle ACKs
|
||||
if _is_ctrl_byte_ack(ctrl_byte):
|
||||
@ -179,15 +203,19 @@ class Channel(Context):
|
||||
await self._handle_message_to_app_or_channel(
|
||||
payload_length, message_length, ctrl_byte, sync_bit
|
||||
)
|
||||
print("channel._handle_completed_message - end")
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_completed_message - end")
|
||||
|
||||
def _check_checksum(self, message_length: int):
|
||||
print("channel._check_checksum")
|
||||
if __debug__:
|
||||
log.debug(__name__, "check_checksum")
|
||||
if not checksum.is_valid(
|
||||
checksum=self.buffer[message_length - CHECKSUM_LENGTH : message_length],
|
||||
data=self.buffer[: message_length - CHECKSUM_LENGTH],
|
||||
):
|
||||
self._todo_clear_buffer()
|
||||
if __debug__:
|
||||
log.debug(__name__, "Invalid checksum, ignoring message.")
|
||||
raise ThpError("Invalid checksum, ignoring message.")
|
||||
|
||||
async def _handle_message_to_app_or_channel(
|
||||
@ -195,7 +223,7 @@ class Channel(Context):
|
||||
) -> None:
|
||||
state = self.get_channel_state()
|
||||
if __debug__:
|
||||
log.debug(__name__, "state: " + _state_to_str(state))
|
||||
log.debug(__name__, "state: %s", _state_to_str(state))
|
||||
|
||||
if state is ChannelState.TH1:
|
||||
await self._handle_state_TH1(payload_length, message_length, sync_bit)
|
||||
@ -206,7 +234,7 @@ class Channel(Context):
|
||||
raise ThpError("Message is not encrypted. Ignoring")
|
||||
|
||||
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||
self._handle_state_ENCRYPTED_TRANSPORT(message_length)
|
||||
await self._handle_state_ENCRYPTED_TRANSPORT(message_length)
|
||||
return
|
||||
|
||||
if state is ChannelState.TH2:
|
||||
@ -239,7 +267,8 @@ class Channel(Context):
|
||||
return
|
||||
|
||||
async def _handle_state_TH2(self, message_length: int, sync_bit: int) -> None:
|
||||
print("channel._handle_state_TH2")
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_state_TH2")
|
||||
host_encrypted_static_pubkey = self.buffer[
|
||||
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
|
||||
]
|
||||
@ -249,10 +278,13 @@ class Channel(Context):
|
||||
+ TAG_LENGTH : message_length
|
||||
- CHECKSUM_LENGTH
|
||||
]
|
||||
print(
|
||||
host_encrypted_static_pubkey,
|
||||
handshake_completion_request_noise_payload,
|
||||
) # TODO remove
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"host static pubkey: %s, noise payload: %s",
|
||||
get_bytes_as_str(host_encrypted_static_pubkey),
|
||||
get_bytes_as_str(handshake_completion_request_noise_payload),
|
||||
)
|
||||
|
||||
# send hanshake completion response
|
||||
loop.schedule(
|
||||
@ -262,8 +294,9 @@ class Channel(Context):
|
||||
)
|
||||
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||
|
||||
def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None:
|
||||
print("channel._handle_state_ENCRYPTED_TRANSPORT")
|
||||
async def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None:
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
|
||||
self._decrypt_buffer(message_length)
|
||||
session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:])
|
||||
if session_id == 0:
|
||||
@ -271,10 +304,16 @@ class Channel(Context):
|
||||
return
|
||||
|
||||
if session_id not in self.sessions:
|
||||
await self.write_error(
|
||||
FailureType.ThpUnallocatedSession, "Unallocated session"
|
||||
)
|
||||
raise ThpError("Unalloacted session")
|
||||
|
||||
session_state = self.sessions[session_id].get_session_state()
|
||||
if session_state is SessionState.UNALLOCATED:
|
||||
await self.write_error(
|
||||
FailureType.ThpUnallocatedSession, "Unallocated session"
|
||||
)
|
||||
raise ThpError("Unalloacted session")
|
||||
|
||||
self.sessions[session_id].incoming_message.publish(
|
||||
@ -296,11 +335,17 @@ class Channel(Context):
|
||||
|
||||
expected_type = protobuf.type_for_wire(message_type)
|
||||
message = message_handler.wrap_protobuf_load(buf, expected_type)
|
||||
print("channel._handle_channel_message:", message)
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_channel_message: %s", message)
|
||||
# TODO handle other messages than CreateNewSession
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(message, ThpCreateNewSession)
|
||||
print("channel._handle_channel_message - passphrase:", message.passphrase)
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"handle_channel_message - passphrase: %s",
|
||||
message.passphrase,
|
||||
)
|
||||
# await thp_messages.handle_CreateNewSession(message)
|
||||
if message.passphrase is not None:
|
||||
new_session_id: int = self.create_new_session(message.passphrase)
|
||||
@ -311,7 +356,10 @@ class Channel(Context):
|
||||
message_size: int = thp_messages.get_new_session_message(
|
||||
bufferrone, new_session_id
|
||||
)
|
||||
print(message_size) # TODO adjust
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__, "handle_channel_message - message size: %d", message_size
|
||||
)
|
||||
loop.schedule(self.write_and_encrypt(bufferrone))
|
||||
# TODO not finished
|
||||
|
||||
@ -332,7 +380,8 @@ class Channel(Context):
|
||||
)
|
||||
|
||||
def _encrypt(self, buffer: bytearray, noise_payload_len: int) -> None:
|
||||
print("channel._encrypt")
|
||||
if __debug__:
|
||||
log.debug(__name__, "encrypt")
|
||||
min_required_length = noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
|
||||
if len(buffer) < min_required_length or not isinstance(buffer, bytearray):
|
||||
new_buffer = bytearray(min_required_length)
|
||||
@ -359,18 +408,16 @@ class Channel(Context):
|
||||
|
||||
async def _send_ack(self, ack_bit: int) -> None:
|
||||
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
|
||||
header = InitHeader(
|
||||
ctrl_byte, int.from_bytes(self.channel_id, "big"), CHECKSUM_LENGTH
|
||||
)
|
||||
header = InitHeader(ctrl_byte, self.get_channel_id_int(), CHECKSUM_LENGTH)
|
||||
chksum = checksum.compute(header.to_bytes())
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"Writing ACK message to a channel with id: %d, sync bit: %d",
|
||||
int.from_bytes(self.channel_id, "big"),
|
||||
self.get_channel_id_int(),
|
||||
ack_bit,
|
||||
)
|
||||
await self._write_payload_to_wire(header, chksum, CHECKSUM_LENGTH)
|
||||
await self._write_payload_to_wire(header, chksum)
|
||||
|
||||
def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit):
|
||||
if sync_bit == 0:
|
||||
@ -382,10 +429,28 @@ class Channel(Context):
|
||||
# CALLED BY WORKFLOW / SESSION CONTEXT
|
||||
|
||||
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
||||
print("channel.write:" + msg.MESSAGE_NAME)
|
||||
if __debug__:
|
||||
log.debug(__name__, "channel.write: %s", msg.MESSAGE_NAME)
|
||||
noise_payload_len = self._encode_into_buffer(msg, session_id)
|
||||
await self.write_and_encrypt(self.buffer[:noise_payload_len])
|
||||
|
||||
async def write_error(self, err_type: FailureType, message: str) -> None:
|
||||
if __debug__:
|
||||
log.debug(__name__, "write_error")
|
||||
msg_size = self._encode_error_into_buffer(err_type, message)
|
||||
data_length = MESSAGE_TYPE_LENGTH + msg_size
|
||||
header: InitHeader = InitHeader(
|
||||
ERROR, self.get_channel_id_int(), data_length + CHECKSUM_LENGTH
|
||||
)
|
||||
chksum = checksum.compute(
|
||||
header.to_bytes() + memoryview(self.buffer[:data_length])
|
||||
)
|
||||
|
||||
utils.memcpy(self.buffer, data_length, chksum, 0)
|
||||
await self._write_payload_to_wire(
|
||||
header, memoryview(self.buffer[: data_length + CHECKSUM_LENGTH])
|
||||
)
|
||||
|
||||
async def write_and_encrypt(self, payload: bytes) -> None:
|
||||
payload_length = len(payload)
|
||||
|
||||
@ -394,29 +459,31 @@ class Channel(Context):
|
||||
self._encrypt(self.buffer, payload_length)
|
||||
payload_length = payload_length + TAG_LENGTH
|
||||
|
||||
loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_length]))
|
||||
loop.schedule(
|
||||
self._write_encrypted_payload_loop(memoryview(self.buffer[:payload_length]))
|
||||
)
|
||||
|
||||
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
|
||||
print("channel._write_encrypted_payload_loop")
|
||||
if __debug__:
|
||||
log.debug(__name__, "write_encrypted_payload_loop")
|
||||
payload_len = len(payload) + CHECKSUM_LENGTH
|
||||
sync_bit = THP.sync_get_send_bit(self.channel_cache)
|
||||
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit)
|
||||
header = InitHeader(
|
||||
ctrl_byte, int.from_bytes(self.channel_id, "big"), payload_len
|
||||
)
|
||||
header = InitHeader(ctrl_byte, self.get_channel_id_int(), payload_len)
|
||||
chksum = checksum.compute(header.to_bytes() + payload)
|
||||
payload = payload + chksum
|
||||
|
||||
# TODO add condition that disallows to write when can_send_message is false
|
||||
THP.sync_set_can_send_message(self.channel_cache, False)
|
||||
while True:
|
||||
print(
|
||||
"channel._write_encrypted_payload_loop - loop start, sync_bit:",
|
||||
(header.ctrl_byte & 0x10) >> 4,
|
||||
" sync_send_bit:",
|
||||
THP.sync_get_send_bit(self.channel_cache),
|
||||
)
|
||||
await self._write_payload_to_wire(header, payload, payload_len)
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"write_encrypted_payload_loop - loop start, sync_bit: %d, sync_send_bit: %d",
|
||||
(header.ctrl_byte & 0x10) >> 4,
|
||||
THP.sync_get_send_bit(self.channel_cache),
|
||||
)
|
||||
await self._write_payload_to_wire(header, payload)
|
||||
self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
|
||||
try:
|
||||
await self.waiting_for_ack_timeout
|
||||
@ -424,29 +491,36 @@ class Channel(Context):
|
||||
THP.sync_set_send_bit_to_opposite(self.channel_cache)
|
||||
break
|
||||
|
||||
async def _write_payload_to_wire(
|
||||
self, header: InitHeader, payload: bytes, payload_len: int
|
||||
):
|
||||
print("chanel._write_payload_to_wire")
|
||||
async def _write_payload_to_wire(self, header: InitHeader, payload: bytes):
|
||||
if __debug__:
|
||||
log.debug(__name__, "write_payload_to_wire")
|
||||
# prepare the report buffer with header data
|
||||
payload_len = len(payload)
|
||||
report = bytearray(REPORT_LENGTH)
|
||||
header.pack_to_buffer(report)
|
||||
|
||||
# write initial report
|
||||
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
|
||||
|
||||
await self._write_report_to_wire(report)
|
||||
|
||||
# if we have more data to write, use continuation reports for it
|
||||
if nwritten < payload_len:
|
||||
header.pack_to_cont_buffer(report)
|
||||
while nwritten < payload_len:
|
||||
if nwritten >= payload_len - REPORT_LENGTH:
|
||||
report = bytearray(REPORT_LENGTH)
|
||||
header.pack_to_cont_buffer(report)
|
||||
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
|
||||
await self._write_report_to_wire(report)
|
||||
|
||||
async def _write_report_to_wire(self, report: utils.BufferType) -> None:
|
||||
while True:
|
||||
await loop.wait(self.iface.iface_num() | io.POLL_WRITE)
|
||||
printBytes(report) # TODO remove
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__, "write_report_to_wire: %s", get_bytes_as_str(report)
|
||||
)
|
||||
n = self.iface.write(report)
|
||||
if n == len(report):
|
||||
return
|
||||
@ -460,41 +534,52 @@ class Channel(Context):
|
||||
assert msg.MESSAGE_WIRE_TYPE is not None
|
||||
|
||||
msg_size = protobuf.encoded_length(msg)
|
||||
offset = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
|
||||
payload_size = offset + msg_size
|
||||
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
|
||||
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
|
||||
|
||||
if required_min_size > len(self.buffer) or not isinstance(
|
||||
self.buffer, bytearray
|
||||
):
|
||||
# message is too big or buffer is not bytearray, we need to allocate a new buffer
|
||||
if required_min_size > len(self.buffer):
|
||||
# message is too big, we need to allocate a new buffer
|
||||
self.buffer = bytearray(required_min_size)
|
||||
|
||||
buffer = self.buffer
|
||||
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
|
||||
msg_type_bytes = int.to_bytes(msg.MESSAGE_WIRE_TYPE, MESSAGE_TYPE_LENGTH, "big")
|
||||
|
||||
utils.memcpy(buffer, 0, session_id_bytes, 0)
|
||||
utils.memcpy(buffer, SESSION_ID_LENGTH, msg_type_bytes, 0)
|
||||
assert isinstance(buffer, bytearray)
|
||||
msg_size = protobuf.encode(buffer[offset:], msg)
|
||||
_encode_session_into_buffer(memoryview(buffer), session_id)
|
||||
_encode_message_type_into_buffer(
|
||||
memoryview(buffer), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH
|
||||
)
|
||||
_encode_message_into_buffer(
|
||||
memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
|
||||
)
|
||||
|
||||
return payload_size
|
||||
|
||||
def _encode_error_into_buffer(self, err_code: FailureType, message: str) -> int:
|
||||
error_message: protobuf.MessageType = Failure(code=err_code, message=message)
|
||||
_encode_message_type_into_buffer(memoryview(self.buffer), MessageType.Failure)
|
||||
_encode_message_into_buffer(
|
||||
memoryview(self.buffer), error_message, MESSAGE_TYPE_LENGTH
|
||||
)
|
||||
return protobuf.encoded_length(error_message)
|
||||
|
||||
def create_new_session(
|
||||
self,
|
||||
passphrase="",
|
||||
) -> int:
|
||||
print("channel.create_new_session")
|
||||
if __debug__:
|
||||
log.debug(__name__, " create_new_session")
|
||||
from trezor.wire.thp.session_context import SessionContext
|
||||
|
||||
session = SessionContext.create_new_session(self)
|
||||
self.sessions[session.session_id] = session
|
||||
loop.schedule(session.handle())
|
||||
print(
|
||||
"channel.create_new_session - new session created. Session id:",
|
||||
session.session_id,
|
||||
)
|
||||
print(self.sessions)
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"create_new_session - new session created. Session id: %d",
|
||||
session.session_id,
|
||||
)
|
||||
if __debug__:
|
||||
print(self.sessions)
|
||||
return session.session_id
|
||||
|
||||
def _todo_clear_buffer(self):
|
||||
@ -504,14 +589,21 @@ class Channel(Context):
|
||||
# TODO add debug logging to ACK handling
|
||||
def _handle_received_ACK(self, sync_bit: int) -> None:
|
||||
if self._ack_is_not_expected():
|
||||
print("channel._handle_received_ACK - ack is not expected")
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_received_ACK - ack is not expected")
|
||||
return
|
||||
if self._ack_has_incorrect_sync_bit(sync_bit):
|
||||
print("channel._handle_received_ACK - ack has incorrect sync bit")
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"handle_received_ACK - ack has incorrect sync bit",
|
||||
)
|
||||
return
|
||||
|
||||
if self.waiting_for_ack_timeout is not None:
|
||||
self.waiting_for_ack_timeout.close()
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_received_ACK - closed waiting for ack")
|
||||
|
||||
THP.sync_set_can_send_message(self.channel_cache, True)
|
||||
|
||||
@ -557,10 +649,13 @@ def _get_buffer_for_message(
|
||||
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
|
||||
) -> utils.BufferType:
|
||||
length = payload_length + INIT_DATA_OFFSET
|
||||
print("channel._get_buffer_for_message - length", length)
|
||||
print(
|
||||
"channel._get_buffer_for_message - existing buffer type", type(existing_buffer)
|
||||
)
|
||||
if __debug__:
|
||||
log.debug(__name__, "get_buffer_for_message - length: %d", length)
|
||||
log.debug(
|
||||
__name__,
|
||||
"get_buffer_for_message - existing buffer type: %s",
|
||||
type(existing_buffer),
|
||||
)
|
||||
if length > max_length:
|
||||
raise ThpError("Message too large")
|
||||
|
||||
@ -605,6 +700,26 @@ def is_channel_state_pairing(state: int) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _encode_session_into_buffer(
|
||||
buffer: memoryview, session_id: int, buffer_offset: int = 0
|
||||
) -> None:
|
||||
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
|
||||
utils.memcpy(buffer, buffer_offset, session_id_bytes, 0)
|
||||
|
||||
|
||||
def _encode_message_type_into_buffer(
|
||||
buffer: memoryview, message_type: int, offset: int = 0
|
||||
) -> None:
|
||||
msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big")
|
||||
utils.memcpy(buffer, offset, msg_type_bytes, 0)
|
||||
|
||||
|
||||
def _encode_message_into_buffer(
|
||||
buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0
|
||||
) -> None:
|
||||
protobuf.encode(memoryview(buffer[buffer_offset:]), message)
|
||||
|
||||
|
||||
def _state_to_str(state: int) -> str:
|
||||
name = {
|
||||
v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__")
|
||||
@ -614,5 +729,22 @@ def _state_to_str(state: int) -> str:
|
||||
return "UNKNOWN_STATE"
|
||||
|
||||
|
||||
def printBytes(a):
|
||||
print(hexlify(a).decode("utf-8"))
|
||||
def get_bytes_as_str(a):
|
||||
return hexlify(a).decode("utf-8")
|
||||
|
||||
|
||||
def failure(exc: BaseException) -> Failure:
|
||||
if isinstance(exc, Error):
|
||||
return Failure(code=exc.code, message=exc.message)
|
||||
elif isinstance(exc, loop.TaskClosed):
|
||||
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
|
||||
elif isinstance(exc, ThpError):
|
||||
return Failure(code=FailureType.InvalidSession, message="Invalid session")
|
||||
else:
|
||||
# NOTE: when receiving generic `FirmwareError` on non-debug build,
|
||||
# change the `if __debug__` to `if True` to get the full error message.
|
||||
if __debug__:
|
||||
message = str(exc)
|
||||
else:
|
||||
message = "Firmware error"
|
||||
return Failure(code=FailureType.FirmwareError, message=message)
|
||||
|
@ -12,7 +12,7 @@ CONTINUATION_PACKET = 0x80
|
||||
ENCRYPTED_TRANSPORT = 0x02
|
||||
HANDSHAKE_INIT = 0x00
|
||||
ACK_MESSAGE = 0x20
|
||||
_ERROR = 0x42
|
||||
ERROR = 0x42
|
||||
CHANNEL_ALLOCATION_REQ = 0x40
|
||||
_CHANNEL_ALLOCATION_RES = 0x41
|
||||
|
||||
@ -23,7 +23,7 @@ TREZOR_STATE_PAIRED = b"\x01"
|
||||
class InitHeader:
|
||||
format_str = ">BHH"
|
||||
|
||||
def __init__(self, ctrl_byte, cid: int, length: int) -> None:
|
||||
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
|
||||
self.ctrl_byte = ctrl_byte
|
||||
self.cid = cid
|
||||
self.length = length
|
||||
@ -48,7 +48,7 @@ class InitHeader:
|
||||
|
||||
@classmethod
|
||||
def get_error_header(cls, cid, length):
|
||||
return cls(_ERROR, cid, length)
|
||||
return cls(ERROR, cid, length)
|
||||
|
||||
@classmethod
|
||||
def get_channel_allocation_response_header(cls, length):
|
||||
|
@ -45,7 +45,8 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
|
||||
|
||||
while True:
|
||||
try:
|
||||
print("thp_v1.thp_main_loop")
|
||||
if __debug__:
|
||||
log.debug(__name__, "thp_main_loop")
|
||||
packet = await read
|
||||
ctrl_byte, cid = ustruct.unpack(">BH", packet)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user