1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-22 12:32:02 +00:00

refactor(trezorlib): extract protocol_v1 and protocol_v2

[no changelog]
This commit is contained in:
M1nd3r 2024-08-12 19:54:47 +02:00
parent 3cab8fbb79
commit 43ed952900
5 changed files with 100 additions and 72 deletions

View File

@ -22,6 +22,7 @@ import requests
from ..log import DUMP_PACKETS
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
from .protocol import PROTOCOL_VERSION_1, PROTOCOL_VERSION_2
if TYPE_CHECKING:
from ..models import TrezorModel
@ -34,8 +35,6 @@ TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"}
TREZORD_VERSION_MODERN = (2, 0, 25)
TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value
PROTOCOL_VERSION_1 = 1
PROTOCOL_VERSION_2 = 2
CONNECTION = requests.Session()
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)

View File

@ -15,14 +15,16 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import struct
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar
from typing import TYPE_CHECKING, Optional, TypeVar
from typing_extensions import Protocol as StructuralType
from ..mapping import ProtobufMapping
from . import MessagePayload, Transport
PROTOCOL_VERSION_1 = 1
PROTOCOL_VERSION_2 = 2
REPLEN = 64
V2_FIRST_CHUNK = 0x01
@ -154,9 +156,13 @@ class ProtocolBasedTransport(Transport):
def deprecated_end_session(self) -> None:
self.protocol.deprecated_end_session()
def get_protocol(self) -> Protocol:
def get_protocol(self, version: Optional[int] = None) -> Protocol:
if version is not None:
return _get_protocol(version, self.handle)
from .. import mapping, messages
from ..messages import FailureType
from .protocol_v1 import ProtocolV1
request_type, request_data = mapping.DEFAULT_MAPPING.encode(
messages.Initialize()
@ -168,6 +174,8 @@ class ProtocolBasedTransport(Transport):
response = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
self.handle.close()
if isinstance(response, messages.Failure):
from .protocol_v2 import ProtocolV2
if (
response.code == FailureType.UnexpectedMessage
and response.message == "Invalid protocol"
@ -178,72 +186,15 @@ class ProtocolBasedTransport(Transport):
return protocol
class ProtocolV1(Protocol):
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
Does not understand sessions.
"""
def _get_protocol(version: int, handle: Handle) -> Protocol:
if version == PROTOCOL_VERSION_1:
from .protocol_v1 import ProtocolV1
HEADER_LEN = struct.calcsize(">HL")
return ProtocolV1(handle)
def initialize_connection(
self,
mapping: "ProtobufMapping",
session_id: Optional[bytes] = None,
derive_caradano: Optional[bool] = None,
):
from .. import messages
if version == PROTOCOL_VERSION_2:
from .protocol_v2 import ProtocolV2
msg = messages.Initialize(
session_id=session_id,
derive_cardano=derive_caradano,
)
msg_type, msg_data = mapping.encode(msg)
self.write(msg_type, msg_data)
(resp_type, resp_data) = self.read()
return mapping.decode(resp_type, resp_data)
return ProtocolV2(handle)
def write(self, message_type: int, message_data: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data))
buffer = bytearray(b"##" + header + message_data)
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: REPLEN - 1]
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
buffer = buffer[63:]
def read(self) -> MessagePayload:
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
return msg_type, buffer[:datalen]
def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError("Unexpected magic characters")
try:
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.handle.read_chunk()
if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters")
return chunk[1:]
class ProtocolV2(Protocol):
def __init__(self, handle: Handle) -> None:
super().__init__(handle)
raise NotImplementedError

View File

@ -0,0 +1,72 @@
import struct
from typing import Optional, Tuple
from ..mapping import ProtobufMapping
from ..transport import MessagePayload
from ..transport.protocol import REPLEN, Protocol
class ProtocolV1(Protocol):
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
Does not understand sessions.
"""
HEADER_LEN = struct.calcsize(">HL")
def initialize_connection(
self,
mapping: "ProtobufMapping",
session_id: Optional[bytes] = None,
derive_caradano: Optional[bool] = None,
):
from .. import messages
msg = messages.Initialize(
session_id=session_id,
derive_cardano=derive_caradano,
)
msg_type, msg_data = mapping.encode(msg)
self.write(msg_type, msg_data)
(resp_type, resp_data) = self.read()
return mapping.decode(resp_type, resp_data)
def write(self, message_type: int, message_data: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data))
buffer = bytearray(b"##" + header + message_data)
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: REPLEN - 1]
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
buffer = buffer[63:]
def read(self) -> MessagePayload:
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
return msg_type, buffer[:datalen]
def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError("Unexpected magic characters")
try:
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.handle.read_chunk()
if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters")
return chunk[1:]

View File

@ -0,0 +1,6 @@
from ..transport.protocol import Handle, Protocol
class ProtocolV2(Protocol):
def __init__(self, handle: Handle) -> None:
super().__init__(handle)

View File

@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Iterable, Optional, Tuple
from ..log import DUMP_PACKETS
from . import TransportException
from .protocol import Protocol, ProtocolBasedTransport, ProtocolV1
from .protocol import PROTOCOL_VERSION_1, Protocol, ProtocolBasedTransport
if TYPE_CHECKING:
from ..models import TrezorModel
@ -103,7 +103,7 @@ class UdpTransport(ProtocolBasedTransport):
if protocol is None and not skip_protocol_detection:
protocol = self.get_protocol()
elif protocol is None:
protocol = ProtocolV1(self.handle)
protocol = self.get_protocol(version=PROTOCOL_VERSION_1)
super().__init__(protocol)