diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index f3e4468df8..995c694dcd 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -22,6 +22,7 @@ import requests from ..log import DUMP_PACKETS from . import DeviceIsBusy, MessagePayload, Transport, TransportException +from .protocol import PROTOCOL_VERSION_1, PROTOCOL_VERSION_2 if TYPE_CHECKING: from ..models import TrezorModel @@ -34,8 +35,6 @@ TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"} TREZORD_VERSION_MODERN = (2, 0, 25) TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value -PROTOCOL_VERSION_1 = 1 -PROTOCOL_VERSION_2 = 2 CONNECTION = requests.Session() CONNECTION.headers.update(TREZORD_ORIGIN_HEADER) diff --git a/python/src/trezorlib/transport/protocol.py b/python/src/trezorlib/transport/protocol.py index 3010e11043..dcad8378e0 100644 --- a/python/src/trezorlib/transport/protocol.py +++ b/python/src/trezorlib/transport/protocol.py @@ -15,14 +15,16 @@ # If not, see . import logging -import struct -from typing import TYPE_CHECKING, Optional, Tuple, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar from typing_extensions import Protocol as StructuralType from ..mapping import ProtobufMapping from . import MessagePayload, Transport +PROTOCOL_VERSION_1 = 1 +PROTOCOL_VERSION_2 = 2 + REPLEN = 64 V2_FIRST_CHUNK = 0x01 @@ -154,9 +156,13 @@ class ProtocolBasedTransport(Transport): def deprecated_end_session(self) -> None: self.protocol.deprecated_end_session() - def get_protocol(self) -> Protocol: + def get_protocol(self, version: Optional[int] = None) -> Protocol: + if version is not None: + return _get_protocol(version, self.handle) + from .. import mapping, messages from ..messages import FailureType + from .protocol_v1 import ProtocolV1 request_type, request_data = mapping.DEFAULT_MAPPING.encode( messages.Initialize() @@ -168,6 +174,8 @@ class ProtocolBasedTransport(Transport): response = mapping.DEFAULT_MAPPING.decode(response_type, response_data) self.handle.close() if isinstance(response, messages.Failure): + from .protocol_v2 import ProtocolV2 + if ( response.code == FailureType.UnexpectedMessage and response.message == "Invalid protocol" @@ -178,72 +186,15 @@ class ProtocolBasedTransport(Transport): return protocol -class ProtocolV1(Protocol): - """Protocol version 1. Currently (11/2018) in use on all Trezors. - Does not understand sessions. - """ +def _get_protocol(version: int, handle: Handle) -> Protocol: + if version == PROTOCOL_VERSION_1: + from .protocol_v1 import ProtocolV1 - HEADER_LEN = struct.calcsize(">HL") + return ProtocolV1(handle) - def initialize_connection( - self, - mapping: "ProtobufMapping", - session_id: Optional[bytes] = None, - derive_caradano: Optional[bool] = None, - ): - from .. import messages + if version == PROTOCOL_VERSION_2: + from .protocol_v2 import ProtocolV2 - msg = messages.Initialize( - session_id=session_id, - derive_cardano=derive_caradano, - ) - msg_type, msg_data = mapping.encode(msg) - self.write(msg_type, msg_data) - (resp_type, resp_data) = self.read() - return mapping.decode(resp_type, resp_data) + return ProtocolV2(handle) - def write(self, message_type: int, message_data: bytes) -> None: - header = struct.pack(">HL", message_type, len(message_data)) - buffer = bytearray(b"##" + header + message_data) - - while buffer: - # Report ID, data padded to 63 bytes - chunk = b"?" + buffer[: REPLEN - 1] - chunk = chunk.ljust(REPLEN, b"\x00") - self.handle.write_chunk(chunk) - buffer = buffer[63:] - - def read(self) -> MessagePayload: - buffer = bytearray() - # Read header with first part of message data - msg_type, datalen, first_chunk = self.read_first() - buffer.extend(first_chunk) - - # Read the rest of the message - while len(buffer) < datalen: - buffer.extend(self.read_next()) - - return msg_type, buffer[:datalen] - - def read_first(self) -> Tuple[int, int, bytes]: - chunk = self.handle.read_chunk() - if chunk[:3] != b"?##": - raise RuntimeError("Unexpected magic characters") - try: - msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) - except Exception: - raise RuntimeError("Cannot parse header") - - data = chunk[3 + self.HEADER_LEN :] - return msg_type, datalen, data - - def read_next(self) -> bytes: - chunk = self.handle.read_chunk() - if chunk[:1] != b"?": - raise RuntimeError("Unexpected magic characters") - return chunk[1:] - - -class ProtocolV2(Protocol): - def __init__(self, handle: Handle) -> None: - super().__init__(handle) + raise NotImplementedError diff --git a/python/src/trezorlib/transport/protocol_v1.py b/python/src/trezorlib/transport/protocol_v1.py new file mode 100644 index 0000000000..c5d18c0042 --- /dev/null +++ b/python/src/trezorlib/transport/protocol_v1.py @@ -0,0 +1,72 @@ +import struct +from typing import Optional, Tuple + +from ..mapping import ProtobufMapping +from ..transport import MessagePayload +from ..transport.protocol import REPLEN, Protocol + + +class ProtocolV1(Protocol): + """Protocol version 1. Currently (11/2018) in use on all Trezors. + Does not understand sessions. + """ + + HEADER_LEN = struct.calcsize(">HL") + + def initialize_connection( + self, + mapping: "ProtobufMapping", + session_id: Optional[bytes] = None, + derive_caradano: Optional[bool] = None, + ): + from .. import messages + + msg = messages.Initialize( + session_id=session_id, + derive_cardano=derive_caradano, + ) + msg_type, msg_data = mapping.encode(msg) + self.write(msg_type, msg_data) + (resp_type, resp_data) = self.read() + return mapping.decode(resp_type, resp_data) + + def write(self, message_type: int, message_data: bytes) -> None: + header = struct.pack(">HL", message_type, len(message_data)) + buffer = bytearray(b"##" + header + message_data) + + while buffer: + # Report ID, data padded to 63 bytes + chunk = b"?" + buffer[: REPLEN - 1] + chunk = chunk.ljust(REPLEN, b"\x00") + self.handle.write_chunk(chunk) + buffer = buffer[63:] + + def read(self) -> MessagePayload: + buffer = bytearray() + # Read header with first part of message data + msg_type, datalen, first_chunk = self.read_first() + buffer.extend(first_chunk) + + # Read the rest of the message + while len(buffer) < datalen: + buffer.extend(self.read_next()) + + return msg_type, buffer[:datalen] + + def read_first(self) -> Tuple[int, int, bytes]: + chunk = self.handle.read_chunk() + if chunk[:3] != b"?##": + raise RuntimeError("Unexpected magic characters") + try: + msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) + except Exception: + raise RuntimeError("Cannot parse header") + + data = chunk[3 + self.HEADER_LEN :] + return msg_type, datalen, data + + def read_next(self) -> bytes: + chunk = self.handle.read_chunk() + if chunk[:1] != b"?": + raise RuntimeError("Unexpected magic characters") + return chunk[1:] diff --git a/python/src/trezorlib/transport/protocol_v2.py b/python/src/trezorlib/transport/protocol_v2.py new file mode 100644 index 0000000000..455d4e1a01 --- /dev/null +++ b/python/src/trezorlib/transport/protocol_v2.py @@ -0,0 +1,6 @@ +from ..transport.protocol import Handle, Protocol + + +class ProtocolV2(Protocol): + def __init__(self, handle: Handle) -> None: + super().__init__(handle) diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index cb35402990..c3a00514ca 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Iterable, Optional, Tuple from ..log import DUMP_PACKETS from . import TransportException -from .protocol import Protocol, ProtocolBasedTransport, ProtocolV1 +from .protocol import PROTOCOL_VERSION_1, Protocol, ProtocolBasedTransport if TYPE_CHECKING: from ..models import TrezorModel @@ -103,7 +103,7 @@ class UdpTransport(ProtocolBasedTransport): if protocol is None and not skip_protocol_detection: protocol = self.get_protocol() elif protocol is None: - protocol = ProtocolV1(self.handle) + protocol = self.get_protocol(version=PROTOCOL_VERSION_1) super().__init__(protocol)