mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-22 12:32:02 +00:00
refactor(trezorlib): extract protocol_v1 and protocol_v2
[no changelog]
This commit is contained in:
parent
3cab8fbb79
commit
43ed952900
@ -22,6 +22,7 @@ import requests
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
|
||||
from .protocol import PROTOCOL_VERSION_1, PROTOCOL_VERSION_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
@ -34,8 +35,6 @@ TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"}
|
||||
TREZORD_VERSION_MODERN = (2, 0, 25)
|
||||
TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value
|
||||
|
||||
PROTOCOL_VERSION_1 = 1
|
||||
PROTOCOL_VERSION_2 = 2
|
||||
CONNECTION = requests.Session()
|
||||
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
|
||||
|
||||
|
@ -15,14 +15,16 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import logging
|
||||
import struct
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar
|
||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||
|
||||
from typing_extensions import Protocol as StructuralType
|
||||
|
||||
from ..mapping import ProtobufMapping
|
||||
from . import MessagePayload, Transport
|
||||
|
||||
PROTOCOL_VERSION_1 = 1
|
||||
PROTOCOL_VERSION_2 = 2
|
||||
|
||||
REPLEN = 64
|
||||
|
||||
V2_FIRST_CHUNK = 0x01
|
||||
@ -154,9 +156,13 @@ class ProtocolBasedTransport(Transport):
|
||||
def deprecated_end_session(self) -> None:
|
||||
self.protocol.deprecated_end_session()
|
||||
|
||||
def get_protocol(self) -> Protocol:
|
||||
def get_protocol(self, version: Optional[int] = None) -> Protocol:
|
||||
if version is not None:
|
||||
return _get_protocol(version, self.handle)
|
||||
|
||||
from .. import mapping, messages
|
||||
from ..messages import FailureType
|
||||
from .protocol_v1 import ProtocolV1
|
||||
|
||||
request_type, request_data = mapping.DEFAULT_MAPPING.encode(
|
||||
messages.Initialize()
|
||||
@ -168,6 +174,8 @@ class ProtocolBasedTransport(Transport):
|
||||
response = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
|
||||
self.handle.close()
|
||||
if isinstance(response, messages.Failure):
|
||||
from .protocol_v2 import ProtocolV2
|
||||
|
||||
if (
|
||||
response.code == FailureType.UnexpectedMessage
|
||||
and response.message == "Invalid protocol"
|
||||
@ -178,72 +186,15 @@ class ProtocolBasedTransport(Transport):
|
||||
return protocol
|
||||
|
||||
|
||||
class ProtocolV1(Protocol):
|
||||
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
|
||||
Does not understand sessions.
|
||||
"""
|
||||
def _get_protocol(version: int, handle: Handle) -> Protocol:
|
||||
if version == PROTOCOL_VERSION_1:
|
||||
from .protocol_v1 import ProtocolV1
|
||||
|
||||
HEADER_LEN = struct.calcsize(">HL")
|
||||
return ProtocolV1(handle)
|
||||
|
||||
def initialize_connection(
|
||||
self,
|
||||
mapping: "ProtobufMapping",
|
||||
session_id: Optional[bytes] = None,
|
||||
derive_caradano: Optional[bool] = None,
|
||||
):
|
||||
from .. import messages
|
||||
if version == PROTOCOL_VERSION_2:
|
||||
from .protocol_v2 import ProtocolV2
|
||||
|
||||
msg = messages.Initialize(
|
||||
session_id=session_id,
|
||||
derive_cardano=derive_caradano,
|
||||
)
|
||||
msg_type, msg_data = mapping.encode(msg)
|
||||
self.write(msg_type, msg_data)
|
||||
(resp_type, resp_data) = self.read()
|
||||
return mapping.decode(resp_type, resp_data)
|
||||
return ProtocolV2(handle)
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
buffer = bytearray(b"##" + header + message_data)
|
||||
|
||||
while buffer:
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b"?" + buffer[: REPLEN - 1]
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
self.handle.write_chunk(chunk)
|
||||
buffer = buffer[63:]
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first()
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next())
|
||||
|
||||
return msg_type, buffer[:datalen]
|
||||
|
||||
def read_first(self) -> Tuple[int, int, bytes]:
|
||||
chunk = self.handle.read_chunk()
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
try:
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
|
||||
data = chunk[3 + self.HEADER_LEN :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self) -> bytes:
|
||||
chunk = self.handle.read_chunk()
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
return chunk[1:]
|
||||
|
||||
|
||||
class ProtocolV2(Protocol):
|
||||
def __init__(self, handle: Handle) -> None:
|
||||
super().__init__(handle)
|
||||
raise NotImplementedError
|
||||
|
72
python/src/trezorlib/transport/protocol_v1.py
Normal file
72
python/src/trezorlib/transport/protocol_v1.py
Normal file
@ -0,0 +1,72 @@
|
||||
import struct
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from ..mapping import ProtobufMapping
|
||||
from ..transport import MessagePayload
|
||||
from ..transport.protocol import REPLEN, Protocol
|
||||
|
||||
|
||||
class ProtocolV1(Protocol):
|
||||
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
|
||||
Does not understand sessions.
|
||||
"""
|
||||
|
||||
HEADER_LEN = struct.calcsize(">HL")
|
||||
|
||||
def initialize_connection(
|
||||
self,
|
||||
mapping: "ProtobufMapping",
|
||||
session_id: Optional[bytes] = None,
|
||||
derive_caradano: Optional[bool] = None,
|
||||
):
|
||||
from .. import messages
|
||||
|
||||
msg = messages.Initialize(
|
||||
session_id=session_id,
|
||||
derive_cardano=derive_caradano,
|
||||
)
|
||||
msg_type, msg_data = mapping.encode(msg)
|
||||
self.write(msg_type, msg_data)
|
||||
(resp_type, resp_data) = self.read()
|
||||
return mapping.decode(resp_type, resp_data)
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
buffer = bytearray(b"##" + header + message_data)
|
||||
|
||||
while buffer:
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b"?" + buffer[: REPLEN - 1]
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
self.handle.write_chunk(chunk)
|
||||
buffer = buffer[63:]
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first()
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next())
|
||||
|
||||
return msg_type, buffer[:datalen]
|
||||
|
||||
def read_first(self) -> Tuple[int, int, bytes]:
|
||||
chunk = self.handle.read_chunk()
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
try:
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
|
||||
data = chunk[3 + self.HEADER_LEN :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self) -> bytes:
|
||||
chunk = self.handle.read_chunk()
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
return chunk[1:]
|
6
python/src/trezorlib/transport/protocol_v2.py
Normal file
6
python/src/trezorlib/transport/protocol_v2.py
Normal file
@ -0,0 +1,6 @@
|
||||
from ..transport.protocol import Handle, Protocol
|
||||
|
||||
|
||||
class ProtocolV2(Protocol):
|
||||
def __init__(self, handle: Handle) -> None:
|
||||
super().__init__(handle)
|
@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Iterable, Optional, Tuple
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import TransportException
|
||||
from .protocol import Protocol, ProtocolBasedTransport, ProtocolV1
|
||||
from .protocol import PROTOCOL_VERSION_1, Protocol, ProtocolBasedTransport
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
@ -103,7 +103,7 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
if protocol is None and not skip_protocol_detection:
|
||||
protocol = self.get_protocol()
|
||||
elif protocol is None:
|
||||
protocol = ProtocolV1(self.handle)
|
||||
protocol = self.get_protocol(version=PROTOCOL_VERSION_1)
|
||||
|
||||
super().__init__(protocol)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user