|
|
|
@ -6,7 +6,7 @@ from ubinascii import hexlify
|
|
|
|
|
import usb
|
|
|
|
|
from storage import cache_thp
|
|
|
|
|
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
|
|
|
|
|
from trezor import io, loop, protobuf, utils
|
|
|
|
|
from trezor import io, log, loop, protobuf, utils
|
|
|
|
|
from trezor.messages import ThpCreateNewSession
|
|
|
|
|
from trezor.wire import message_handler
|
|
|
|
|
from trezor.wire.thp import thp_messages
|
|
|
|
@ -41,6 +41,8 @@ MESSAGE_TYPE_LENGTH = const(2)
|
|
|
|
|
REPORT_LENGTH = const(64)
|
|
|
|
|
MAX_PAYLOAD_LEN = const(60000)
|
|
|
|
|
|
|
|
|
|
_ACK_MESSAGE = 0x20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Channel(Context):
|
|
|
|
|
def __init__(self, channel_cache: ChannelCache) -> None:
|
|
|
|
@ -139,10 +141,13 @@ class Channel(Context):
|
|
|
|
|
|
|
|
|
|
async def _handle_completed_message(self):
|
|
|
|
|
print("handling completed message")
|
|
|
|
|
print("send snyc bit::", THP.sync_get_send_bit(self.channel_cache))
|
|
|
|
|
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
|
|
|
|
|
msg_len = payload_length + INIT_DATA_OFFSET
|
|
|
|
|
|
|
|
|
|
print("checksum check")
|
|
|
|
|
printBytes(self.buffer)
|
|
|
|
|
# printBytes(self.buffer)
|
|
|
|
|
|
|
|
|
|
if not checksum.is_valid(
|
|
|
|
|
checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len],
|
|
|
|
|
data=self.buffer[: msg_len - CHECKSUM_LENGTH],
|
|
|
|
@ -150,16 +155,40 @@ class Channel(Context):
|
|
|
|
|
# checksum is not valid -> ignore message
|
|
|
|
|
self._todo_clear_buffer()
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# Synchronization process
|
|
|
|
|
sync_bit = (ctrl_byte & 0x10) >> 4
|
|
|
|
|
print("sync bit:", sync_bit)
|
|
|
|
|
|
|
|
|
|
# 1: Handle ACKs
|
|
|
|
|
if _is_ctrl_byte_ack(ctrl_byte):
|
|
|
|
|
self._handle_received_ACK(sync_bit)
|
|
|
|
|
self._todo_clear_buffer()
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# 2: Handle message with unexpected synchronization bit
|
|
|
|
|
if sync_bit != THP.sync_get_receive_expected_bit(self.channel_cache):
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(
|
|
|
|
|
__name__, "Received message with an unexpected synchronization bit"
|
|
|
|
|
)
|
|
|
|
|
raise ThpError("Received message with an unexpected synchronization bit")
|
|
|
|
|
|
|
|
|
|
# 3: Send ACK in response
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(
|
|
|
|
|
__name__,
|
|
|
|
|
"Writing ACK message to a channel with id: %d, sync bit: %d",
|
|
|
|
|
self.channel_id,
|
|
|
|
|
sync_bit,
|
|
|
|
|
)
|
|
|
|
|
await self._sendAck(sync_bit)
|
|
|
|
|
print("___set receive bit to", 1 - sync_bit)
|
|
|
|
|
THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit)
|
|
|
|
|
|
|
|
|
|
state = self.get_channel_state()
|
|
|
|
|
_print_state(state)
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(__name__, _state_to_str(state))
|
|
|
|
|
|
|
|
|
|
if state is ChannelState.TH1:
|
|
|
|
|
if not _is_ctrl_byte_handshake_init:
|
|
|
|
@ -270,6 +299,21 @@ class Channel(Context):
|
|
|
|
|
self.expected_payload_length = 0
|
|
|
|
|
self.is_cont_packet_expected = False
|
|
|
|
|
|
|
|
|
|
async def _sendAck(self, ack_bit: int) -> None:
|
|
|
|
|
ctrl_byte = self._add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
|
|
|
|
|
header = InitHeader(
|
|
|
|
|
ctrl_byte, int.from_bytes(self.channel_id, "big"), CHECKSUM_LENGTH
|
|
|
|
|
)
|
|
|
|
|
chksum = checksum.compute(header.to_bytes())
|
|
|
|
|
await self._write_encrypted_payload(header, chksum, CHECKSUM_LENGTH)
|
|
|
|
|
|
|
|
|
|
def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit):
|
|
|
|
|
if sync_bit == 0:
|
|
|
|
|
return ctrl_byte & 0xEF
|
|
|
|
|
if sync_bit == 1:
|
|
|
|
|
return ctrl_byte | 0x10
|
|
|
|
|
raise ThpError("Unexpected synchronization bit")
|
|
|
|
|
|
|
|
|
|
# CALLED BY WORKFLOW / SESSION CONTEXT
|
|
|
|
|
|
|
|
|
|
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
|
|
|
@ -287,17 +331,26 @@ class Channel(Context):
|
|
|
|
|
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
|
|
|
|
|
print("write loop before while")
|
|
|
|
|
payload_len = len(payload)
|
|
|
|
|
sync_bit = THP.sync_get_send_bit(self.channel_cache)
|
|
|
|
|
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit)
|
|
|
|
|
header = InitHeader(
|
|
|
|
|
ENCRYPTED_TRANSPORT, int.from_bytes(self.channel_id, "big"), payload_len
|
|
|
|
|
ctrl_byte, int.from_bytes(self.channel_id, "big"), payload_len
|
|
|
|
|
)
|
|
|
|
|
# TODO add condition that disallows to write when can_send_message is false
|
|
|
|
|
THP.sync_set_can_send_message(self.channel_cache, False)
|
|
|
|
|
while True:
|
|
|
|
|
print("write encrypted payload loop - start")
|
|
|
|
|
print(
|
|
|
|
|
"write encrypted payload loop - start, sync_bit:",
|
|
|
|
|
header.ctrl_byte & 0x10,
|
|
|
|
|
" send_sync_bit:",
|
|
|
|
|
THP.sync_get_send_bit(self.channel_cache),
|
|
|
|
|
)
|
|
|
|
|
await self._write_encrypted_payload(header, payload, payload_len)
|
|
|
|
|
self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
|
|
|
|
|
try:
|
|
|
|
|
await self.waiting_for_ack_timeout
|
|
|
|
|
except loop.TaskClosed:
|
|
|
|
|
THP.sync_set_send_bit_to_opposite(self.channel_cache)
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
async def _write_encrypted_payload(
|
|
|
|
@ -459,30 +512,29 @@ def _is_ctrl_byte_ack(ctrl_byte: int) -> bool:
|
|
|
|
|
return ctrl_byte & 0xEF == ACK_MESSAGE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _print_state(cs: int) -> None:
|
|
|
|
|
if cs == ChannelState.ENCRYPTED_TRANSPORT:
|
|
|
|
|
print("state: encrypted transport")
|
|
|
|
|
elif cs == ChannelState.TH1:
|
|
|
|
|
print("state: th1")
|
|
|
|
|
elif cs == ChannelState.TH2:
|
|
|
|
|
print("state: th2")
|
|
|
|
|
elif cs == ChannelState.TP1:
|
|
|
|
|
print("state: tp1")
|
|
|
|
|
elif cs == ChannelState.TP2:
|
|
|
|
|
print("state: tp2")
|
|
|
|
|
elif cs == ChannelState.TP3:
|
|
|
|
|
print("state: tp3")
|
|
|
|
|
elif cs == ChannelState.TP4:
|
|
|
|
|
print("state: tp4")
|
|
|
|
|
elif cs == ChannelState.TP5:
|
|
|
|
|
print("state: tp5")
|
|
|
|
|
elif cs == ChannelState.UNALLOCATED:
|
|
|
|
|
print("state: unallocated")
|
|
|
|
|
elif cs == ChannelState.UNAUTHENTICATED:
|
|
|
|
|
print("state: unauthenticated")
|
|
|
|
|
def _state_to_str(state: int) -> str:
|
|
|
|
|
if state == ChannelState.ENCRYPTED_TRANSPORT:
|
|
|
|
|
return "state: encrypted transport"
|
|
|
|
|
elif state == ChannelState.TH1:
|
|
|
|
|
return "state: th1"
|
|
|
|
|
elif state == ChannelState.TH2:
|
|
|
|
|
return "state: th2"
|
|
|
|
|
elif state == ChannelState.TP1:
|
|
|
|
|
return "state: tp1"
|
|
|
|
|
elif state == ChannelState.TP2:
|
|
|
|
|
return "state: tp2"
|
|
|
|
|
elif state == ChannelState.TP3:
|
|
|
|
|
return "state: tp3"
|
|
|
|
|
elif state == ChannelState.TP4:
|
|
|
|
|
return "state: tp4"
|
|
|
|
|
elif state == ChannelState.TP5:
|
|
|
|
|
return "state: tp5"
|
|
|
|
|
elif state == ChannelState.UNALLOCATED:
|
|
|
|
|
return "state: unallocated"
|
|
|
|
|
elif state == ChannelState.UNAUTHENTICATED:
|
|
|
|
|
return "state: unauthenticated"
|
|
|
|
|
else:
|
|
|
|
|
print(cs)
|
|
|
|
|
print("state: <not implemented printout>")
|
|
|
|
|
return "state: <not implemented>"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def printBytes(a):
|
|
|
|
|