mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-22 12:32:02 +00:00
refactor(trezorlib): decouple protocol from handler
[no changelog]
This commit is contained in:
parent
ef33422ab3
commit
121ed1f530
@ -21,6 +21,7 @@ from typing import Any, Dict, Iterable, List, Optional
|
|||||||
|
|
||||||
from ..log import DUMP_PACKETS
|
from ..log import DUMP_PACKETS
|
||||||
from ..models import TREZOR_ONE, TrezorModel
|
from ..models import TREZOR_ONE, TrezorModel
|
||||||
|
from ..transport.protocol import Handle
|
||||||
from . import UDEV_RULES_STR, TransportException
|
from . import UDEV_RULES_STR, TransportException
|
||||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||||
|
|
||||||
@ -127,7 +128,10 @@ class HidTransport(ProtocolBasedTransport):
|
|||||||
self.device = device
|
self.device = device
|
||||||
self.handle = HidHandle(device["path"], device["serial_number"])
|
self.handle = HidHandle(device["path"], device["serial_number"])
|
||||||
|
|
||||||
super().__init__(protocol=ProtocolV1(self.handle))
|
super().__init__(protocol=ProtocolV1())
|
||||||
|
|
||||||
|
def get_handle(self) -> Handle:
|
||||||
|
return self.handle
|
||||||
|
|
||||||
def get_path(self) -> str:
|
def get_path(self) -> str:
|
||||||
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
||||||
|
@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from typing import Tuple
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
from typing_extensions import Protocol as StructuralType
|
from typing_extensions import Protocol as StructuralType
|
||||||
|
|
||||||
@ -71,25 +71,18 @@ class Protocol:
|
|||||||
its messages.
|
its messages.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, handle: Handle) -> None:
|
def __init__(self) -> None:
|
||||||
self.handle = handle
|
|
||||||
self.session_counter = 0
|
self.session_counter = 0
|
||||||
|
|
||||||
# XXX we might be able to remove this now that TrezorClient does session handling
|
def read(self, read_chunk: Callable[[], bytes]) -> MessagePayload:
|
||||||
def begin_session(self) -> None:
|
|
||||||
if self.session_counter == 0:
|
|
||||||
self.handle.open()
|
|
||||||
self.session_counter += 1
|
|
||||||
|
|
||||||
def end_session(self) -> None:
|
|
||||||
self.session_counter = max(self.session_counter - 1, 0)
|
|
||||||
if self.session_counter == 0:
|
|
||||||
self.handle.close()
|
|
||||||
|
|
||||||
def read(self) -> MessagePayload:
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def write(self, message_type: int, message_data: bytes) -> None:
|
def write(
|
||||||
|
self,
|
||||||
|
message_type: int,
|
||||||
|
message_data: bytes,
|
||||||
|
write_chunk: Callable[[bytes], None],
|
||||||
|
) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@ -102,18 +95,30 @@ class ProtocolBasedTransport(Transport):
|
|||||||
|
|
||||||
def __init__(self, protocol: Protocol) -> None:
|
def __init__(self, protocol: Protocol) -> None:
|
||||||
self.protocol = protocol
|
self.protocol = protocol
|
||||||
|
self.session_counter = 0
|
||||||
|
|
||||||
def write(self, message_type: int, message_data: bytes) -> None:
|
def write(self, message_type: int, message_data: bytes) -> None:
|
||||||
self.protocol.write(message_type, message_data)
|
self.protocol.write(
|
||||||
|
message_type,
|
||||||
|
message_data,
|
||||||
|
self.get_handle().write_chunk,
|
||||||
|
)
|
||||||
|
|
||||||
def read(self) -> MessagePayload:
|
def read(self) -> MessagePayload:
|
||||||
return self.protocol.read()
|
return self.protocol.read(self.get_handle().read_chunk)
|
||||||
|
|
||||||
def begin_session(self) -> None:
|
def begin_session(self) -> None:
|
||||||
self.protocol.begin_session()
|
if self.session_counter == 0:
|
||||||
|
self.get_handle().open()
|
||||||
|
self.session_counter += 1
|
||||||
|
|
||||||
def end_session(self) -> None:
|
def end_session(self) -> None:
|
||||||
self.protocol.end_session()
|
self.session_counter = max(self.session_counter - 1, 0)
|
||||||
|
if self.session_counter == 0:
|
||||||
|
self.get_handle().close()
|
||||||
|
|
||||||
|
def get_handle(self) -> Handle:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class ProtocolV1(Protocol):
|
class ProtocolV1(Protocol):
|
||||||
@ -123,7 +128,12 @@ class ProtocolV1(Protocol):
|
|||||||
|
|
||||||
HEADER_LEN = struct.calcsize(">HL")
|
HEADER_LEN = struct.calcsize(">HL")
|
||||||
|
|
||||||
def write(self, message_type: int, message_data: bytes) -> None:
|
def write(
|
||||||
|
self,
|
||||||
|
message_type: int,
|
||||||
|
message_data: bytes,
|
||||||
|
write_chunk: Callable[[bytes], None],
|
||||||
|
) -> None:
|
||||||
header = struct.pack(">HL", message_type, len(message_data))
|
header = struct.pack(">HL", message_type, len(message_data))
|
||||||
buffer = bytearray(b"##" + header + message_data)
|
buffer = bytearray(b"##" + header + message_data)
|
||||||
|
|
||||||
@ -131,23 +141,23 @@ class ProtocolV1(Protocol):
|
|||||||
# Report ID, data padded to 63 bytes
|
# Report ID, data padded to 63 bytes
|
||||||
chunk = b"?" + buffer[: REPLEN - 1]
|
chunk = b"?" + buffer[: REPLEN - 1]
|
||||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||||
self.handle.write_chunk(chunk)
|
write_chunk(chunk)
|
||||||
buffer = buffer[63:]
|
buffer = buffer[63:]
|
||||||
|
|
||||||
def read(self) -> MessagePayload:
|
def read(self, read_chunk: Callable[[], bytes]) -> MessagePayload:
|
||||||
buffer = bytearray()
|
buffer = bytearray()
|
||||||
# Read header with first part of message data
|
# Read header with first part of message data
|
||||||
msg_type, datalen, first_chunk = self.read_first()
|
msg_type, datalen, first_chunk = self.read_first(read_chunk)
|
||||||
buffer.extend(first_chunk)
|
buffer.extend(first_chunk)
|
||||||
|
|
||||||
# Read the rest of the message
|
# Read the rest of the message
|
||||||
while len(buffer) < datalen:
|
while len(buffer) < datalen:
|
||||||
buffer.extend(self.read_next())
|
buffer.extend(self.read_next(read_chunk))
|
||||||
|
|
||||||
return msg_type, buffer[:datalen]
|
return msg_type, buffer[:datalen]
|
||||||
|
|
||||||
def read_first(self) -> Tuple[int, int, bytes]:
|
def read_first(self, read_chunk: Callable[[], bytes]) -> Tuple[int, int, bytes]:
|
||||||
chunk = self.handle.read_chunk()
|
chunk = read_chunk()
|
||||||
if chunk[:3] != b"?##":
|
if chunk[:3] != b"?##":
|
||||||
raise RuntimeError("Unexpected magic characters")
|
raise RuntimeError("Unexpected magic characters")
|
||||||
try:
|
try:
|
||||||
@ -158,8 +168,8 @@ class ProtocolV1(Protocol):
|
|||||||
data = chunk[3 + self.HEADER_LEN :]
|
data = chunk[3 + self.HEADER_LEN :]
|
||||||
return msg_type, datalen, data
|
return msg_type, datalen, data
|
||||||
|
|
||||||
def read_next(self) -> bytes:
|
def read_next(self, read_chunk: Callable[[], bytes]) -> bytes:
|
||||||
chunk = self.handle.read_chunk()
|
chunk = read_chunk()
|
||||||
if chunk[:1] != b"?":
|
if chunk[:1] != b"?":
|
||||||
raise RuntimeError("Unexpected magic characters")
|
raise RuntimeError("Unexpected magic characters")
|
||||||
return chunk[1:]
|
return chunk[1:]
|
||||||
|
@ -20,6 +20,7 @@ import time
|
|||||||
from typing import TYPE_CHECKING, Iterable, Optional
|
from typing import TYPE_CHECKING, Iterable, Optional
|
||||||
|
|
||||||
from ..log import DUMP_PACKETS
|
from ..log import DUMP_PACKETS
|
||||||
|
from ..transport.protocol import Handle
|
||||||
from . import TransportException
|
from . import TransportException
|
||||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||||
|
|
||||||
@ -49,7 +50,10 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
self.device = (host, port)
|
self.device = (host, port)
|
||||||
self.socket: Optional[socket.socket] = None
|
self.socket: Optional[socket.socket] = None
|
||||||
|
|
||||||
super().__init__(protocol=ProtocolV1(self))
|
super().__init__(protocol=ProtocolV1())
|
||||||
|
|
||||||
|
def get_handle(self) -> Handle:
|
||||||
|
return self
|
||||||
|
|
||||||
def get_path(self) -> str:
|
def get_path(self) -> str:
|
||||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||||
|
@ -23,7 +23,7 @@ from typing import Iterable, List, Optional
|
|||||||
from ..log import DUMP_PACKETS
|
from ..log import DUMP_PACKETS
|
||||||
from ..models import TREZORS, TrezorModel
|
from ..models import TREZORS, TrezorModel
|
||||||
from . import UDEV_RULES_STR, DeviceIsBusy, TransportException
|
from . import UDEV_RULES_STR, DeviceIsBusy, TransportException
|
||||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
from .protocol import Handle, ProtocolBasedTransport, ProtocolV1
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -112,7 +112,10 @@ class WebUsbTransport(ProtocolBasedTransport):
|
|||||||
self.handle = handle
|
self.handle = handle
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
|
|
||||||
super().__init__(protocol=ProtocolV1(handle))
|
super().__init__(protocol=ProtocolV1())
|
||||||
|
|
||||||
|
def get_handle(self) -> Handle:
|
||||||
|
return self.handle
|
||||||
|
|
||||||
def get_path(self) -> str:
|
def get_path(self) -> str:
|
||||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||||
|
Loading…
Reference in New Issue
Block a user