mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-06-18 05:58:45 +00:00
fix(python): improve protocol_v2, remove its channel database
This commit is contained in:
parent
174c5597e4
commit
4f20c2883d
@ -12,11 +12,9 @@ from ... import exceptions, messages, protobuf
|
|||||||
from ...mapping import ProtobufMapping
|
from ...mapping import ProtobufMapping
|
||||||
from .. import Transport
|
from .. import Transport
|
||||||
from ..thp import checksum, thp_io
|
from ..thp import checksum, thp_io
|
||||||
from ..thp.channel_data import ChannelData
|
|
||||||
from ..thp.checksum import CHECKSUM_LENGTH
|
from ..thp.checksum import CHECKSUM_LENGTH
|
||||||
from ..thp.message_header import MessageHeader
|
from ..thp.message_header import MessageHeader
|
||||||
from . import control_byte
|
from . import control_byte
|
||||||
from .channel_database import ChannelDatabase, get_channel_db
|
|
||||||
from .protocol_and_channel import Channel
|
from .protocol_and_channel import Channel
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@ -30,11 +28,6 @@ MT = t.TypeVar("MT", bound=protobuf.MessageType)
|
|||||||
|
|
||||||
class ProtocolV2Channel(Channel):
|
class ProtocolV2Channel(Channel):
|
||||||
channel_id: int
|
channel_id: int
|
||||||
channel_database: ChannelDatabase
|
|
||||||
key_request: bytes
|
|
||||||
key_response: bytes
|
|
||||||
nonce_request: int
|
|
||||||
nonce_response: int
|
|
||||||
sync_bit_send: int
|
sync_bit_send: int
|
||||||
sync_bit_receive: int
|
sync_bit_receive: int
|
||||||
handshake_hash: bytes
|
handshake_hash: bytes
|
||||||
@ -46,52 +39,23 @@ class ProtocolV2Channel(Channel):
|
|||||||
self,
|
self,
|
||||||
transport: Transport,
|
transport: Transport,
|
||||||
mapping: ProtobufMapping,
|
mapping: ProtobufMapping,
|
||||||
channel_data: ChannelData | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.channel_database: ChannelDatabase = get_channel_db()
|
|
||||||
super().__init__(transport, mapping)
|
super().__init__(transport, mapping)
|
||||||
if channel_data is not None:
|
|
||||||
self.channel_id = channel_data.channel_id
|
|
||||||
self.key_request = bytes.fromhex(channel_data.key_request)
|
|
||||||
self.key_response = bytes.fromhex(channel_data.key_response)
|
|
||||||
self.nonce_request = channel_data.nonce_request
|
|
||||||
self.nonce_response = channel_data.nonce_response
|
|
||||||
self.sync_bit_receive = channel_data.sync_bit_receive
|
|
||||||
self.sync_bit_send = channel_data.sync_bit_send
|
|
||||||
self.handshake_hash = bytes.fromhex(channel_data.handshake_hash)
|
|
||||||
self._has_valid_channel = True
|
|
||||||
|
|
||||||
def get_channel(self, helper_debug: DebugLink | None = None) -> ProtocolV2Channel:
|
def get_channel(self, helper_debug: DebugLink | None = None) -> ProtocolV2Channel:
|
||||||
if not self._has_valid_channel:
|
if not self._has_valid_channel:
|
||||||
self._establish_new_channel(helper_debug)
|
self._establish_new_channel(helper_debug)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def get_channel_data(self) -> ChannelData:
|
|
||||||
return ChannelData(
|
|
||||||
protocol_version_major=2,
|
|
||||||
protocol_version_minor=2,
|
|
||||||
transport_path=self.transport.get_path(),
|
|
||||||
channel_id=self.channel_id,
|
|
||||||
key_request=self.noise.noise_protocol.cipher_state_encrypt.k,
|
|
||||||
key_response=self.noise.noise_protocol.cipher_state_decrypt.k,
|
|
||||||
nonce_request=self.nonce_request,
|
|
||||||
nonce_response=self.nonce_response,
|
|
||||||
sync_bit_receive=self.sync_bit_receive,
|
|
||||||
sync_bit_send=self.sync_bit_send,
|
|
||||||
handshake_hash=self.handshake_hash,
|
|
||||||
)
|
|
||||||
|
|
||||||
def read(self, session_id: int) -> t.Any:
|
def read(self, session_id: int) -> t.Any:
|
||||||
sid, msg_type, msg_data = self.read_and_decrypt()
|
sid, msg_type, msg_data = self.read_and_decrypt()
|
||||||
if sid != session_id:
|
if sid != session_id:
|
||||||
raise Exception("Received messsage on a different session.")
|
raise Exception("Received messsage on a different session.")
|
||||||
self.channel_database.save_channel(self)
|
|
||||||
return self.mapping.decode(msg_type, msg_data)
|
return self.mapping.decode(msg_type, msg_data)
|
||||||
|
|
||||||
def write(self, session_id: int, msg: t.Any) -> None:
|
def write(self, session_id: int, msg: t.Any) -> None:
|
||||||
msg_type, msg_data = self.mapping.encode(msg)
|
msg_type, msg_data = self.mapping.encode(msg)
|
||||||
self._encrypt_and_write(session_id, msg_type, msg_data)
|
self._encrypt_and_write(session_id, msg_type, msg_data)
|
||||||
self.channel_database.save_channel(self)
|
|
||||||
|
|
||||||
def get_features(self) -> messages.Features:
|
def get_features(self) -> messages.Features:
|
||||||
if not self._has_valid_channel:
|
if not self._has_valid_channel:
|
||||||
@ -166,13 +130,13 @@ class ProtocolV2Channel(Channel):
|
|||||||
return (channel_id, device_properties)
|
return (channel_id, device_properties)
|
||||||
|
|
||||||
def _init_noise(self, randomness_static: bytes) -> None:
|
def _init_noise(self, randomness_static: bytes) -> None:
|
||||||
self.noise = NoiseConnection.from_name(b"Noise_XX_25519_AESGCM_SHA256")
|
self._noise = NoiseConnection.from_name(b"Noise_XX_25519_AESGCM_SHA256")
|
||||||
self.noise.set_as_initiator()
|
self._noise.set_as_initiator()
|
||||||
self.noise.set_keypair_from_private_bytes(Keypair.STATIC, randomness_static)
|
self._noise.set_keypair_from_private_bytes(Keypair.STATIC, randomness_static)
|
||||||
|
|
||||||
prologue = bytes(self.device_properties)
|
prologue = bytes(self.device_properties)
|
||||||
self.noise.set_prologue(prologue)
|
self._noise.set_prologue(prologue)
|
||||||
self.noise.start_handshake()
|
self._noise.start_handshake()
|
||||||
|
|
||||||
def _do_handshake(
|
def _do_handshake(
|
||||||
self,
|
self,
|
||||||
@ -191,14 +155,10 @@ class ProtocolV2Channel(Channel):
|
|||||||
)
|
)
|
||||||
self._read_ack()
|
self._read_ack()
|
||||||
self._read_handshake_completion_response()
|
self._read_handshake_completion_response()
|
||||||
self.key_request = self.noise.noise_protocol.cipher_state_encrypt.k
|
|
||||||
self.key_response = self.noise.noise_protocol.cipher_state_decrypt.k
|
|
||||||
self.nonce_request = 0
|
|
||||||
self.nonce_response = 1
|
|
||||||
|
|
||||||
def _send_handshake_init_request(self) -> None:
|
def _send_handshake_init_request(self) -> None:
|
||||||
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
|
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
|
||||||
host_ephemeral_pubkey = self.noise.write_message()
|
host_ephemeral_pubkey = self._noise.write_message()
|
||||||
|
|
||||||
thp_io.write_payload_to_wire_and_add_checksum(
|
thp_io.write_payload_to_wire_and_add_checksum(
|
||||||
self.transport, ha_init_req_header, host_ephemeral_pubkey
|
self.transport, ha_init_req_header, host_ephemeral_pubkey
|
||||||
@ -219,7 +179,7 @@ class ProtocolV2Channel(Channel):
|
|||||||
"Received message is not a valid handshake init response message",
|
"Received message is not a valid handshake init response message",
|
||||||
err=True,
|
err=True,
|
||||||
)
|
)
|
||||||
self.noise.read_message(payload)
|
self._noise.read_message(payload)
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
def _send_handshake_completion_request(
|
def _send_handshake_completion_request(
|
||||||
@ -237,7 +197,7 @@ class ProtocolV2Channel(Channel):
|
|||||||
host_pairing_credential=credential,
|
host_pairing_credential=credential,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
message2 = self.noise.write_message(payload=msg_data)
|
message2 = self._noise.write_message(payload=msg_data)
|
||||||
|
|
||||||
ha_completion_req_header = MessageHeader(
|
ha_completion_req_header = MessageHeader(
|
||||||
0x12,
|
0x12,
|
||||||
@ -249,7 +209,7 @@ class ProtocolV2Channel(Channel):
|
|||||||
ha_completion_req_header,
|
ha_completion_req_header,
|
||||||
message2, # encrypted_host_static_pubkey + encrypted_payload,
|
message2, # encrypted_host_static_pubkey + encrypted_payload,
|
||||||
)
|
)
|
||||||
self.handshake_hash = self.noise.get_handshake_hash()
|
self.handshake_hash = self._noise.get_handshake_hash()
|
||||||
|
|
||||||
def _read_handshake_completion_response(self) -> None:
|
def _read_handshake_completion_response(self) -> None:
|
||||||
# Read handshake completion response, ignore payload as we do not care about the state
|
# Read handshake completion response, ignore payload as we do not care about the state
|
||||||
@ -259,9 +219,10 @@ class ProtocolV2Channel(Channel):
|
|||||||
"Received message is not a valid handshake completion response",
|
"Received message is not a valid handshake completion response",
|
||||||
err=True,
|
err=True,
|
||||||
)
|
)
|
||||||
trezor_state = self.noise.decrypt(bytes(data))
|
trezor_state = self._noise.decrypt(bytes(data))
|
||||||
# TODO handle trezor_state
|
# TODO handle trezor_state
|
||||||
print("trezor state:", trezor_state)
|
print("trezor state:", trezor_state)
|
||||||
|
assert trezor_state == b"\x00" or trezor_state == b"\x01"
|
||||||
self._send_ack_1()
|
self._send_ack_1()
|
||||||
|
|
||||||
def _do_pairing(self, helper_debug: DebugLink | None):
|
def _do_pairing(self, helper_debug: DebugLink | None):
|
||||||
@ -305,7 +266,6 @@ class ProtocolV2Channel(Channel):
|
|||||||
message_data: bytes,
|
message_data: bytes,
|
||||||
ctrl_byte: int | None = None,
|
ctrl_byte: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert self.key_request is not None
|
|
||||||
|
|
||||||
if ctrl_byte is None:
|
if ctrl_byte is None:
|
||||||
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send)
|
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send)
|
||||||
@ -315,7 +275,7 @@ class ProtocolV2Channel(Channel):
|
|||||||
msg_type = message_type.to_bytes(2, "big")
|
msg_type = message_type.to_bytes(2, "big")
|
||||||
data = sid + msg_type + message_data
|
data = sid + msg_type + message_data
|
||||||
|
|
||||||
encrypted_message = self.noise.encrypt(data)
|
encrypted_message = self._noise.encrypt(data)
|
||||||
|
|
||||||
header = MessageHeader(
|
header = MessageHeader(
|
||||||
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
|
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
|
||||||
@ -354,7 +314,7 @@ class ProtocolV2Channel(Channel):
|
|||||||
else:
|
else:
|
||||||
self._send_ack_0()
|
self._send_ack_0()
|
||||||
|
|
||||||
message = self.noise.decrypt(bytes(raw_payload))
|
message = self._noise.decrypt(bytes(raw_payload))
|
||||||
session_id = message[0]
|
session_id = message[0]
|
||||||
message_type = message[1:3]
|
message_type = message[1:3]
|
||||||
message_data = message[3:]
|
message_data = message[3:]
|
||||||
|
@ -302,7 +302,7 @@ def test_credential_phase(client: Client) -> None:
|
|||||||
|
|
||||||
# Delete channel from the device by sending badly encrypted message
|
# Delete channel from the device by sending badly encrypted message
|
||||||
# This is done to prevent channel replacement and trigerring of autoconnect false -> true
|
# This is done to prevent channel replacement and trigerring of autoconnect false -> true
|
||||||
protocol.noise.noise_protocol.cipher_state_encrypt.n = 250
|
protocol._noise.noise_protocol.cipher_state_encrypt.n = 250
|
||||||
|
|
||||||
protocol._send_message(ButtonAck())
|
protocol._send_message(ButtonAck())
|
||||||
with pytest.raises(Exception) as e:
|
with pytest.raises(Exception) as e:
|
||||||
@ -351,7 +351,7 @@ def test_credential_phase(client: Client) -> None:
|
|||||||
|
|
||||||
# Delete channel from the device by sending badly encrypted message
|
# Delete channel from the device by sending badly encrypted message
|
||||||
# This is done to prevent channel replacement and trigerring of autoconnect false -> true
|
# This is done to prevent channel replacement and trigerring of autoconnect false -> true
|
||||||
protocol.noise.noise_protocol.cipher_state_encrypt.n = 100
|
protocol._noise.noise_protocol.cipher_state_encrypt.n = 100
|
||||||
|
|
||||||
protocol._send_message(ButtonAck())
|
protocol._send_message(ButtonAck())
|
||||||
with pytest.raises(Exception) as e:
|
with pytest.raises(Exception) as e:
|
||||||
|
Loading…
Reference in New Issue
Block a user