mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-03-23 11:35:42 +00:00
fix(python): improve protocol_v2, remove its channel database
This commit is contained in:
parent
b5b0a354a9
commit
50c1299fb3
@ -12,11 +12,9 @@ from ... import exceptions, messages, protobuf
|
||||
from ...mapping import ProtobufMapping
|
||||
from .. import Transport
|
||||
from ..thp import checksum, thp_io
|
||||
from ..thp.channel_data import ChannelData
|
||||
from ..thp.checksum import CHECKSUM_LENGTH
|
||||
from ..thp.message_header import MessageHeader
|
||||
from . import control_byte
|
||||
from .channel_database import ChannelDatabase, get_channel_db
|
||||
from .protocol_and_channel import Channel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -30,11 +28,6 @@ MT = t.TypeVar("MT", bound=protobuf.MessageType)
|
||||
|
||||
class ProtocolV2Channel(Channel):
|
||||
channel_id: int
|
||||
channel_database: ChannelDatabase
|
||||
key_request: bytes
|
||||
key_response: bytes
|
||||
nonce_request: int
|
||||
nonce_response: int
|
||||
sync_bit_send: int
|
||||
sync_bit_receive: int
|
||||
handshake_hash: bytes
|
||||
@ -46,52 +39,23 @@ class ProtocolV2Channel(Channel):
|
||||
self,
|
||||
transport: Transport,
|
||||
mapping: ProtobufMapping,
|
||||
channel_data: ChannelData | None = None,
|
||||
) -> None:
|
||||
self.channel_database: ChannelDatabase = get_channel_db()
|
||||
super().__init__(transport, mapping)
|
||||
if channel_data is not None:
|
||||
self.channel_id = channel_data.channel_id
|
||||
self.key_request = bytes.fromhex(channel_data.key_request)
|
||||
self.key_response = bytes.fromhex(channel_data.key_response)
|
||||
self.nonce_request = channel_data.nonce_request
|
||||
self.nonce_response = channel_data.nonce_response
|
||||
self.sync_bit_receive = channel_data.sync_bit_receive
|
||||
self.sync_bit_send = channel_data.sync_bit_send
|
||||
self.handshake_hash = bytes.fromhex(channel_data.handshake_hash)
|
||||
self._has_valid_channel = True
|
||||
|
||||
def get_channel(self, helper_debug: DebugLink | None = None) -> ProtocolV2Channel:
|
||||
if not self._has_valid_channel:
|
||||
self._establish_new_channel(helper_debug)
|
||||
return self
|
||||
|
||||
def get_channel_data(self) -> ChannelData:
|
||||
return ChannelData(
|
||||
protocol_version_major=2,
|
||||
protocol_version_minor=2,
|
||||
transport_path=self.transport.get_path(),
|
||||
channel_id=self.channel_id,
|
||||
key_request=self.noise.noise_protocol.cipher_state_encrypt.k,
|
||||
key_response=self.noise.noise_protocol.cipher_state_decrypt.k,
|
||||
nonce_request=self.nonce_request,
|
||||
nonce_response=self.nonce_response,
|
||||
sync_bit_receive=self.sync_bit_receive,
|
||||
sync_bit_send=self.sync_bit_send,
|
||||
handshake_hash=self.handshake_hash,
|
||||
)
|
||||
|
||||
def read(self, session_id: int) -> t.Any:
|
||||
sid, msg_type, msg_data = self.read_and_decrypt()
|
||||
if sid != session_id:
|
||||
raise Exception("Received messsage on a different session.")
|
||||
self.channel_database.save_channel(self)
|
||||
return self.mapping.decode(msg_type, msg_data)
|
||||
|
||||
def write(self, session_id: int, msg: t.Any) -> None:
|
||||
msg_type, msg_data = self.mapping.encode(msg)
|
||||
self._encrypt_and_write(session_id, msg_type, msg_data)
|
||||
self.channel_database.save_channel(self)
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
if not self._has_valid_channel:
|
||||
@ -166,13 +130,13 @@ class ProtocolV2Channel(Channel):
|
||||
return (channel_id, device_properties)
|
||||
|
||||
def _init_noise(self, randomness_static: bytes) -> None:
|
||||
self.noise = NoiseConnection.from_name(b"Noise_XX_25519_AESGCM_SHA256")
|
||||
self.noise.set_as_initiator()
|
||||
self.noise.set_keypair_from_private_bytes(Keypair.STATIC, randomness_static)
|
||||
self._noise = NoiseConnection.from_name(b"Noise_XX_25519_AESGCM_SHA256")
|
||||
self._noise.set_as_initiator()
|
||||
self._noise.set_keypair_from_private_bytes(Keypair.STATIC, randomness_static)
|
||||
|
||||
prologue = bytes(self.device_properties)
|
||||
self.noise.set_prologue(prologue)
|
||||
self.noise.start_handshake()
|
||||
self._noise.set_prologue(prologue)
|
||||
self._noise.start_handshake()
|
||||
|
||||
def _do_handshake(
|
||||
self,
|
||||
@ -191,14 +155,10 @@ class ProtocolV2Channel(Channel):
|
||||
)
|
||||
self._read_ack()
|
||||
self._read_handshake_completion_response()
|
||||
self.key_request = self.noise.noise_protocol.cipher_state_encrypt.k
|
||||
self.key_response = self.noise.noise_protocol.cipher_state_decrypt.k
|
||||
self.nonce_request = 0
|
||||
self.nonce_response = 1
|
||||
|
||||
def _send_handshake_init_request(self) -> None:
|
||||
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
|
||||
host_ephemeral_pubkey = self.noise.write_message()
|
||||
host_ephemeral_pubkey = self._noise.write_message()
|
||||
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport, ha_init_req_header, host_ephemeral_pubkey
|
||||
@ -219,7 +179,7 @@ class ProtocolV2Channel(Channel):
|
||||
"Received message is not a valid handshake init response message",
|
||||
err=True,
|
||||
)
|
||||
self.noise.read_message(payload)
|
||||
self._noise.read_message(payload)
|
||||
return payload
|
||||
|
||||
def _send_handshake_completion_request(
|
||||
@ -237,7 +197,7 @@ class ProtocolV2Channel(Channel):
|
||||
host_pairing_credential=credential,
|
||||
)
|
||||
)
|
||||
message2 = self.noise.write_message(payload=msg_data)
|
||||
message2 = self._noise.write_message(payload=msg_data)
|
||||
|
||||
ha_completion_req_header = MessageHeader(
|
||||
0x12,
|
||||
@ -249,7 +209,7 @@ class ProtocolV2Channel(Channel):
|
||||
ha_completion_req_header,
|
||||
message2, # encrypted_host_static_pubkey + encrypted_payload,
|
||||
)
|
||||
self.handshake_hash = self.noise.get_handshake_hash()
|
||||
self.handshake_hash = self._noise.get_handshake_hash()
|
||||
|
||||
def _read_handshake_completion_response(self) -> None:
|
||||
# Read handshake completion response, ignore payload as we do not care about the state
|
||||
@ -259,9 +219,10 @@ class ProtocolV2Channel(Channel):
|
||||
"Received message is not a valid handshake completion response",
|
||||
err=True,
|
||||
)
|
||||
trezor_state = self.noise.decrypt(bytes(data))
|
||||
trezor_state = self._noise.decrypt(bytes(data))
|
||||
# TODO handle trezor_state
|
||||
print("trezor state:", trezor_state)
|
||||
assert trezor_state == b"\x00" or trezor_state == b"\x01"
|
||||
self._send_ack_1()
|
||||
|
||||
def _do_pairing(self, helper_debug: DebugLink | None):
|
||||
@ -305,7 +266,6 @@ class ProtocolV2Channel(Channel):
|
||||
message_data: bytes,
|
||||
ctrl_byte: int | None = None,
|
||||
) -> None:
|
||||
assert self.key_request is not None
|
||||
|
||||
if ctrl_byte is None:
|
||||
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send)
|
||||
@ -315,7 +275,7 @@ class ProtocolV2Channel(Channel):
|
||||
msg_type = message_type.to_bytes(2, "big")
|
||||
data = sid + msg_type + message_data
|
||||
|
||||
encrypted_message = self.noise.encrypt(data)
|
||||
encrypted_message = self._noise.encrypt(data)
|
||||
|
||||
header = MessageHeader(
|
||||
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
|
||||
@ -354,7 +314,7 @@ class ProtocolV2Channel(Channel):
|
||||
else:
|
||||
self._send_ack_0()
|
||||
|
||||
message = self.noise.decrypt(bytes(raw_payload))
|
||||
message = self._noise.decrypt(bytes(raw_payload))
|
||||
session_id = message[0]
|
||||
message_type = message[1:3]
|
||||
message_data = message[3:]
|
||||
|
@ -302,7 +302,7 @@ def test_credential_phase(client: Client) -> None:
|
||||
|
||||
# Delete channel from the device by sending badly encrypted message
|
||||
# This is done to prevent channel replacement and trigerring of autoconnect false -> true
|
||||
protocol.noise.noise_protocol.cipher_state_encrypt.n = 250
|
||||
protocol._noise.noise_protocol.cipher_state_encrypt.n = 250
|
||||
|
||||
protocol._send_message(ButtonAck())
|
||||
with pytest.raises(Exception) as e:
|
||||
@ -351,7 +351,7 @@ def test_credential_phase(client: Client) -> None:
|
||||
|
||||
# Delete channel from the device by sending badly encrypted message
|
||||
# This is done to prevent channel replacement and trigerring of autoconnect false -> true
|
||||
protocol.noise.noise_protocol.cipher_state_encrypt.n = 100
|
||||
protocol._noise.noise_protocol.cipher_state_encrypt.n = 100
|
||||
|
||||
protocol._send_message(ButtonAck())
|
||||
with pytest.raises(Exception) as e:
|
||||
|
Loading…
Reference in New Issue
Block a user