mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-03-03 16:56:07 +00:00
python: unify protobuf-encoding code paths
Protobuf encoding now happens in TrezorClient, and transports get encoded blobs to (chunkify and) send. This is a better design because transports don't need to know about protobuf. It also lays groundwork for sending raw bytes feature (#116) This commit also removes all vestiges of ProtocolV2 which was never used and will probably need to be redesigned from the ground up anyway. The code is still ready for protocol flexibility.
This commit is contained in:
parent
22b167a961
commit
9a330f3475
@ -29,6 +29,8 @@ _At the moment, the project does **not** adhere to [Semantic Versioning](https:/
|
|||||||
- `get_default_client` respects `TREZOR_PATH` environment variable
|
- `get_default_client` respects `TREZOR_PATH` environment variable
|
||||||
- UI callback `get_passphrase` has an additional argument `available_on_device`,
|
- UI callback `get_passphrase` has an additional argument `available_on_device`,
|
||||||
indicating that the connected Trezor is capable of on-device entry
|
indicating that the connected Trezor is capable of on-device entry
|
||||||
|
- `Transport.write` and `read` method signatures changed to accept bytes instead of
|
||||||
|
protobuf messages
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
||||||
|
@ -21,7 +21,8 @@ import warnings
|
|||||||
|
|
||||||
from mnemonic import Mnemonic
|
from mnemonic import Mnemonic
|
||||||
|
|
||||||
from . import MINIMUM_FIRMWARE_VERSION, exceptions, messages, tools
|
from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages, tools
|
||||||
|
from .log import DUMP_BYTES
|
||||||
from .messages import Capability
|
from .messages import Capability
|
||||||
|
|
||||||
if sys.version_info.major < 3:
|
if sys.version_info.major < 3:
|
||||||
@ -114,11 +115,34 @@ class TrezorClient:
|
|||||||
|
|
||||||
def _raw_write(self, msg):
|
def _raw_write(self, msg):
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
self.transport.write(msg)
|
LOG.debug(
|
||||||
|
"sending message: {}".format(msg.__class__.__name__),
|
||||||
|
extra={"protobuf": msg},
|
||||||
|
)
|
||||||
|
msg_type, msg_bytes = mapping.encode(msg)
|
||||||
|
LOG.log(
|
||||||
|
DUMP_BYTES,
|
||||||
|
"encoded as type {} ({} bytes): {}".format(
|
||||||
|
msg_type, len(msg_bytes), msg_bytes.hex()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.transport.write(msg_type, msg_bytes)
|
||||||
|
|
||||||
def _raw_read(self):
|
def _raw_read(self):
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
return self.transport.read()
|
msg_type, msg_bytes = self.transport.read()
|
||||||
|
LOG.log(
|
||||||
|
DUMP_BYTES,
|
||||||
|
"received type {} ({} bytes): {}".format(
|
||||||
|
msg_type, len(msg_bytes), msg_bytes.hex()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
msg = mapping.decode(msg_type, msg_bytes)
|
||||||
|
LOG.debug(
|
||||||
|
"received message: {}".format(msg.__class__.__name__),
|
||||||
|
extra={"protobuf": msg},
|
||||||
|
)
|
||||||
|
return msg
|
||||||
|
|
||||||
def _callback_pin(self, msg):
|
def _callback_pin(self, msg):
|
||||||
try:
|
try:
|
||||||
|
@ -19,7 +19,7 @@ from copy import deepcopy
|
|||||||
|
|
||||||
from mnemonic import Mnemonic
|
from mnemonic import Mnemonic
|
||||||
|
|
||||||
from . import messages as proto, protobuf
|
from . import mapping, messages as proto, protobuf
|
||||||
from .client import TrezorClient
|
from .client import TrezorClient
|
||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
|
||||||
@ -45,11 +45,12 @@ class DebugLink:
|
|||||||
self.transport.end_session()
|
self.transport.end_session()
|
||||||
|
|
||||||
def _call(self, msg, nowait=False):
|
def _call(self, msg, nowait=False):
|
||||||
self.transport.write(msg)
|
msg_type, msg_bytes = mapping.encode(msg)
|
||||||
|
self.transport.write(msg_type, msg_bytes)
|
||||||
if nowait:
|
if nowait:
|
||||||
return None
|
return None
|
||||||
ret = self.transport.read()
|
ret_type, ret_bytes = self.transport.read()
|
||||||
return ret
|
return mapping.decode(ret_type, ret_bytes)
|
||||||
|
|
||||||
def state(self):
|
def state(self):
|
||||||
return self._call(proto.DebugLinkGetState())
|
return self._call(proto.DebugLinkGetState())
|
||||||
|
@ -22,8 +22,10 @@ from . import protobuf
|
|||||||
OMITTED_MESSAGES = set() # type: Set[Type[protobuf.MessageType]]
|
OMITTED_MESSAGES = set() # type: Set[Type[protobuf.MessageType]]
|
||||||
|
|
||||||
DUMP_BYTES = 5
|
DUMP_BYTES = 5
|
||||||
|
DUMP_PACKETS = 4
|
||||||
|
|
||||||
logging.addLevelName(DUMP_BYTES, "BYTES")
|
logging.addLevelName(DUMP_BYTES, "BYTES")
|
||||||
|
logging.addLevelName(DUMP_PACKETS, "PACKETS")
|
||||||
|
|
||||||
|
|
||||||
class PrettyProtobufFormatter(logging.Formatter):
|
class PrettyProtobufFormatter(logging.Formatter):
|
||||||
@ -54,6 +56,8 @@ def enable_debug_output(verbosity: int = 1, handler: Optional[logging.Handler] =
|
|||||||
level = logging.DEBUG
|
level = logging.DEBUG
|
||||||
if verbosity > 1:
|
if verbosity > 1:
|
||||||
level = DUMP_BYTES
|
level = DUMP_BYTES
|
||||||
|
if verbosity > 2:
|
||||||
|
level = DUMP_PACKETS
|
||||||
|
|
||||||
logger = logging.getLogger("trezorlib")
|
logger = logging.getLogger("trezorlib")
|
||||||
logger.setLevel(level)
|
logger.setLevel(level)
|
||||||
|
@ -14,7 +14,10 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
from . import messages
|
import io
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from . import messages, protobuf
|
||||||
|
|
||||||
map_type_to_class = {}
|
map_type_to_class = {}
|
||||||
map_class_to_type = {}
|
map_class_to_type = {}
|
||||||
@ -59,4 +62,17 @@ def get_class(t):
|
|||||||
return map_type_to_class[t]
|
return map_type_to_class[t]
|
||||||
|
|
||||||
|
|
||||||
|
def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]:
|
||||||
|
message_type = msg.MESSAGE_WIRE_TYPE
|
||||||
|
buf = io.BytesIO()
|
||||||
|
protobuf.dump_message(buf, msg)
|
||||||
|
return message_type, buf.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def decode(message_type: int, message_bytes: bytes) -> protobuf.MessageType:
|
||||||
|
cls = get_class(message_type)
|
||||||
|
buf = io.BytesIO(message_bytes)
|
||||||
|
return protobuf.load_message(buf, cls)
|
||||||
|
|
||||||
|
|
||||||
build_map()
|
build_map()
|
||||||
|
@ -15,10 +15,9 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Iterable, List, Type
|
from typing import Iterable, List, Tuple, Type
|
||||||
|
|
||||||
from ..exceptions import TrezorException
|
from ..exceptions import TrezorException
|
||||||
from ..protobuf import MessageType
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -35,6 +34,9 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
|
|||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
|
MessagePayload = Tuple[int, bytes]
|
||||||
|
|
||||||
|
|
||||||
class TransportException(TrezorException):
|
class TransportException(TrezorException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -71,10 +73,10 @@ class Transport:
|
|||||||
def end_session(self) -> None:
|
def end_session(self) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def read(self) -> MessageType:
|
def read(self) -> MessagePayload:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def write(self, message: MessageType) -> None:
|
def write(self, message_type: int, message_data: bytes) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -16,14 +16,12 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from io import BytesIO
|
|
||||||
from typing import Any, Dict, Iterable, Optional
|
from typing import Any, Dict, Iterable, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from .. import mapping, protobuf
|
from ..log import DUMP_PACKETS
|
||||||
from ..log import DUMP_BYTES
|
from . import MessagePayload, Transport, TransportException
|
||||||
from . import Transport, TransportException
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -66,10 +64,12 @@ class BridgeHandle:
|
|||||||
|
|
||||||
class BridgeHandleModern(BridgeHandle):
|
class BridgeHandleModern(BridgeHandle):
|
||||||
def write_buf(self, buf: bytes) -> None:
|
def write_buf(self, buf: bytes) -> None:
|
||||||
|
LOG.log(DUMP_PACKETS, "sending message: {}".format(buf.hex()))
|
||||||
self.transport._call("post", data=buf.hex())
|
self.transport._call("post", data=buf.hex())
|
||||||
|
|
||||||
def read_buf(self) -> bytes:
|
def read_buf(self) -> bytes:
|
||||||
data = self.transport._call("read")
|
data = self.transport._call("read")
|
||||||
|
LOG.log(DUMP_PACKETS, "received message: {}".format(data.text))
|
||||||
return bytes.fromhex(data.text)
|
return bytes.fromhex(data.text)
|
||||||
|
|
||||||
|
|
||||||
@ -87,7 +87,9 @@ class BridgeHandleLegacy(BridgeHandle):
|
|||||||
if self.request is None:
|
if self.request is None:
|
||||||
raise TransportException("Can't read without write on legacy Bridge")
|
raise TransportException("Can't read without write on legacy Bridge")
|
||||||
try:
|
try:
|
||||||
|
LOG.log(DUMP_PACKETS, "calling with message: {}".format(self.request))
|
||||||
data = self.transport._call("call", data=self.request)
|
data = self.transport._call("call", data=self.request)
|
||||||
|
LOG.log(DUMP_PACKETS, "received response: {}".format(data.text))
|
||||||
return bytes.fromhex(data.text)
|
return bytes.fromhex(data.text)
|
||||||
finally:
|
finally:
|
||||||
self.request = None
|
self.request = None
|
||||||
@ -152,29 +154,12 @@ class BridgeTransport(Transport):
|
|||||||
self._call("release")
|
self._call("release")
|
||||||
self.session = None
|
self.session = None
|
||||||
|
|
||||||
def write(self, msg: protobuf.MessageType) -> None:
|
def write(self, message_type: int, message_data: bytes) -> None:
|
||||||
LOG.debug(
|
header = struct.pack(">HL", message_type, len(message_data))
|
||||||
"sending message: {}".format(msg.__class__.__name__),
|
self.handle.write_buf(header + message_data)
|
||||||
extra={"protobuf": msg},
|
|
||||||
)
|
|
||||||
buffer = BytesIO()
|
|
||||||
protobuf.dump_message(buffer, msg)
|
|
||||||
ser = buffer.getvalue()
|
|
||||||
LOG.log(DUMP_BYTES, "sending bytes: {}".format(ser.hex()))
|
|
||||||
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
|
|
||||||
|
|
||||||
self.handle.write_buf(header + ser)
|
def read(self) -> MessagePayload:
|
||||||
|
|
||||||
def read(self) -> protobuf.MessageType:
|
|
||||||
data = self.handle.read_buf()
|
data = self.handle.read_buf()
|
||||||
headerlen = struct.calcsize(">HL")
|
headerlen = struct.calcsize(">HL")
|
||||||
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
|
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
|
||||||
ser = data[headerlen : headerlen + datalen]
|
return msg_type, data[headerlen : headerlen + datalen]
|
||||||
LOG.log(DUMP_BYTES, "received bytes: {}".format(ser.hex()))
|
|
||||||
buffer = BytesIO(ser)
|
|
||||||
msg = protobuf.load_message(buffer, mapping.get_class(msg_type))
|
|
||||||
LOG.debug(
|
|
||||||
"received message: {}".format(msg.__class__.__name__),
|
|
||||||
extra={"protobuf": msg},
|
|
||||||
)
|
|
||||||
return msg
|
|
||||||
|
@ -19,6 +19,7 @@ import sys
|
|||||||
import time
|
import time
|
||||||
from typing import Any, Dict, Iterable
|
from typing import Any, Dict, Iterable
|
||||||
|
|
||||||
|
from ..log import DUMP_PACKETS
|
||||||
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
|
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
|
||||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||||
|
|
||||||
@ -82,9 +83,10 @@ class HidHandle:
|
|||||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||||
|
|
||||||
if self.hid_version == 2:
|
if self.hid_version == 2:
|
||||||
self.handle.write(b"\0" + bytearray(chunk))
|
chunk = b"\x00" + chunk
|
||||||
else:
|
|
||||||
self.handle.write(chunk)
|
LOG.log(DUMP_PACKETS, "writing packet: {}".format(chunk.hex()))
|
||||||
|
self.handle.write(chunk)
|
||||||
|
|
||||||
def read_chunk(self) -> bytes:
|
def read_chunk(self) -> bytes:
|
||||||
while True:
|
while True:
|
||||||
@ -93,6 +95,8 @@ class HidHandle:
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
time.sleep(0.001)
|
time.sleep(0.001)
|
||||||
|
|
||||||
|
LOG.log(DUMP_PACKETS, "read packet: {}".format(chunk.hex()))
|
||||||
if len(chunk) != 64:
|
if len(chunk) != 64:
|
||||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||||
return bytes(chunk)
|
return bytes(chunk)
|
||||||
@ -119,8 +123,7 @@ class HidTransport(ProtocolBasedTransport):
|
|||||||
self.device = device
|
self.device = device
|
||||||
self.handle = HidHandle(device["path"], device["serial_number"])
|
self.handle = HidHandle(device["path"], device["serial_number"])
|
||||||
|
|
||||||
protocol = ProtocolV1(self.handle)
|
super().__init__(protocol=ProtocolV1(self.handle))
|
||||||
super().__init__(protocol=protocol)
|
|
||||||
|
|
||||||
def get_path(self) -> str:
|
def get_path(self) -> str:
|
||||||
return "%s:%s" % (self.PATH_PREFIX, self.device["path"].decode())
|
return "%s:%s" % (self.PATH_PREFIX, self.device["path"].decode())
|
||||||
@ -142,15 +145,11 @@ class HidTransport(ProtocolBasedTransport):
|
|||||||
return devices
|
return devices
|
||||||
|
|
||||||
def find_debug(self) -> "HidTransport":
|
def find_debug(self) -> "HidTransport":
|
||||||
if self.protocol.VERSION >= 2:
|
# For v1 protocol, find debug USB interface for the same serial number
|
||||||
# use the same device
|
for debug in HidTransport.enumerate(debug=True):
|
||||||
return self
|
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||||
else:
|
return debug
|
||||||
# For v1 protocol, find debug USB interface for the same serial number
|
raise TransportException("Debug HID device not found")
|
||||||
for debug in HidTransport.enumerate(debug=True):
|
|
||||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
|
||||||
return debug
|
|
||||||
raise TransportException("Debug HID device not found")
|
|
||||||
|
|
||||||
|
|
||||||
def is_wirelink(dev: HidDevice) -> bool:
|
def is_wirelink(dev: HidDevice) -> bool:
|
||||||
|
@ -15,16 +15,12 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import struct
|
import struct
|
||||||
from io import BytesIO
|
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from typing_extensions import Protocol as StructuralType
|
from typing_extensions import Protocol as StructuralType
|
||||||
|
|
||||||
from .. import mapping, protobuf
|
from . import MessagePayload, Transport
|
||||||
from ..log import DUMP_BYTES
|
|
||||||
from . import Transport
|
|
||||||
|
|
||||||
REPLEN = 64
|
REPLEN = 64
|
||||||
|
|
||||||
@ -72,7 +68,6 @@ class Protocol:
|
|||||||
- open and close physical connections,
|
- open and close physical connections,
|
||||||
- and send and receive binary chunks.
|
- and send and receive binary chunks.
|
||||||
|
|
||||||
We declare a protocol version (we have implementations of v1 and v2).
|
|
||||||
For now, the class also handles session counting and opening the underlying Handle.
|
For now, the class also handles session counting and opening the underlying Handle.
|
||||||
This will probably be removed in the future.
|
This will probably be removed in the future.
|
||||||
|
|
||||||
@ -80,8 +75,6 @@ class Protocol:
|
|||||||
its messages.
|
its messages.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
VERSION = None # type: int
|
|
||||||
|
|
||||||
def __init__(self, handle: Handle) -> None:
|
def __init__(self, handle: Handle) -> None:
|
||||||
self.handle = handle
|
self.handle = handle
|
||||||
self.session_counter = 0
|
self.session_counter = 0
|
||||||
@ -97,10 +90,10 @@ class Protocol:
|
|||||||
if self.session_counter == 0:
|
if self.session_counter == 0:
|
||||||
self.handle.close()
|
self.handle.close()
|
||||||
|
|
||||||
def read(self) -> protobuf.MessageType:
|
def read(self) -> MessagePayload:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def write(self, message: protobuf.MessageType) -> None:
|
def write(self, message_type: int, message_data: bytes) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@ -114,10 +107,10 @@ class ProtocolBasedTransport(Transport):
|
|||||||
def __init__(self, protocol: Protocol) -> None:
|
def __init__(self, protocol: Protocol) -> None:
|
||||||
self.protocol = protocol
|
self.protocol = protocol
|
||||||
|
|
||||||
def write(self, message: protobuf.MessageType) -> None:
|
def write(self, message_type: int, message_data: bytes) -> None:
|
||||||
self.protocol.write(message)
|
self.protocol.write(message_type, message_data)
|
||||||
|
|
||||||
def read(self) -> protobuf.MessageType:
|
def read(self) -> MessagePayload:
|
||||||
return self.protocol.read()
|
return self.protocol.read()
|
||||||
|
|
||||||
def begin_session(self) -> None:
|
def begin_session(self) -> None:
|
||||||
@ -132,19 +125,11 @@ class ProtocolV1(Protocol):
|
|||||||
Does not understand sessions.
|
Does not understand sessions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
VERSION = 1
|
HEADER_LEN = struct.calcsize(">HL")
|
||||||
|
|
||||||
def write(self, msg: protobuf.MessageType) -> None:
|
def write(self, message_type: int, message_data: bytes) -> None:
|
||||||
LOG.debug(
|
header = struct.pack(">HL", message_type, len(message_data))
|
||||||
"sending message: {}".format(msg.__class__.__name__),
|
buffer = bytearray(b"##" + header + message_data)
|
||||||
extra={"protobuf": msg},
|
|
||||||
)
|
|
||||||
data = BytesIO()
|
|
||||||
protobuf.dump_message(data, msg)
|
|
||||||
ser = data.getvalue()
|
|
||||||
LOG.log(DUMP_BYTES, "sending bytes: {}".format(ser.hex()))
|
|
||||||
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
|
|
||||||
buffer = bytearray(b"##" + header + ser)
|
|
||||||
|
|
||||||
while buffer:
|
while buffer:
|
||||||
# Report ID, data padded to 63 bytes
|
# Report ID, data padded to 63 bytes
|
||||||
@ -153,7 +138,7 @@ class ProtocolV1(Protocol):
|
|||||||
self.handle.write_chunk(chunk)
|
self.handle.write_chunk(chunk)
|
||||||
buffer = buffer[63:]
|
buffer = buffer[63:]
|
||||||
|
|
||||||
def read(self) -> protobuf.MessageType:
|
def read(self) -> MessagePayload:
|
||||||
buffer = bytearray()
|
buffer = bytearray()
|
||||||
# Read header with first part of message data
|
# Read header with first part of message data
|
||||||
msg_type, datalen, first_chunk = self.read_first()
|
msg_type, datalen, first_chunk = self.read_first()
|
||||||
@ -163,30 +148,18 @@ class ProtocolV1(Protocol):
|
|||||||
while len(buffer) < datalen:
|
while len(buffer) < datalen:
|
||||||
buffer.extend(self.read_next())
|
buffer.extend(self.read_next())
|
||||||
|
|
||||||
# Strip padding
|
return msg_type, buffer[:datalen]
|
||||||
ser = buffer[:datalen]
|
|
||||||
data = BytesIO(ser)
|
|
||||||
LOG.log(DUMP_BYTES, "received bytes: {}".format(ser.hex()))
|
|
||||||
|
|
||||||
# Parse to protobuf
|
|
||||||
msg = protobuf.load_message(data, mapping.get_class(msg_type))
|
|
||||||
LOG.debug(
|
|
||||||
"received message: {}".format(msg.__class__.__name__),
|
|
||||||
extra={"protobuf": msg},
|
|
||||||
)
|
|
||||||
return msg
|
|
||||||
|
|
||||||
def read_first(self) -> Tuple[int, int, bytes]:
|
def read_first(self) -> Tuple[int, int, bytes]:
|
||||||
chunk = self.handle.read_chunk()
|
chunk = self.handle.read_chunk()
|
||||||
if chunk[:3] != b"?##":
|
if chunk[:3] != b"?##":
|
||||||
raise RuntimeError("Unexpected magic characters")
|
raise RuntimeError("Unexpected magic characters")
|
||||||
try:
|
try:
|
||||||
headerlen = struct.calcsize(">HL")
|
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + headerlen])
|
|
||||||
except Exception:
|
except Exception:
|
||||||
raise RuntimeError("Cannot parse header")
|
raise RuntimeError("Cannot parse header")
|
||||||
|
|
||||||
data = chunk[3 + headerlen :]
|
data = chunk[3 + self.HEADER_LEN :]
|
||||||
return msg_type, datalen, data
|
return msg_type, datalen, data
|
||||||
|
|
||||||
def read_next(self) -> bytes:
|
def read_next(self) -> bytes:
|
||||||
@ -194,160 +167,3 @@ class ProtocolV1(Protocol):
|
|||||||
if chunk[:1] != b"?":
|
if chunk[:1] != b"?":
|
||||||
raise RuntimeError("Unexpected magic characters")
|
raise RuntimeError("Unexpected magic characters")
|
||||||
return chunk[1:]
|
return chunk[1:]
|
||||||
|
|
||||||
|
|
||||||
class ProtocolV2(Protocol):
|
|
||||||
"""Protocol version 2. Currently (11/2018) not used.
|
|
||||||
Intended to mimic U2F/WebAuthN session handling.
|
|
||||||
"""
|
|
||||||
|
|
||||||
VERSION = 2
|
|
||||||
|
|
||||||
def __init__(self, handle: Handle) -> None:
|
|
||||||
self.session = None
|
|
||||||
super().__init__(handle)
|
|
||||||
|
|
||||||
def begin_session(self) -> None:
|
|
||||||
# ensure open connection
|
|
||||||
super().begin_session()
|
|
||||||
|
|
||||||
# initiate session
|
|
||||||
chunk = struct.pack(">B", V2_BEGIN_SESSION)
|
|
||||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
|
||||||
self.handle.write_chunk(chunk)
|
|
||||||
|
|
||||||
# get session identifier
|
|
||||||
resp = self.handle.read_chunk()
|
|
||||||
try:
|
|
||||||
headerlen = struct.calcsize(">BL")
|
|
||||||
magic, session = struct.unpack(">BL", resp[:headerlen])
|
|
||||||
except Exception:
|
|
||||||
raise RuntimeError("Cannot parse header")
|
|
||||||
if magic != V2_BEGIN_SESSION:
|
|
||||||
raise RuntimeError("Unexpected magic character")
|
|
||||||
self.session = session
|
|
||||||
|
|
||||||
LOG.debug("[session {}] session started".format(self.session))
|
|
||||||
|
|
||||||
def end_session(self) -> None:
|
|
||||||
if not self.session:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
chunk = struct.pack(">BL", V2_END_SESSION, self.session)
|
|
||||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
|
||||||
self.handle.write_chunk(chunk)
|
|
||||||
resp = self.handle.read_chunk()
|
|
||||||
(magic,) = struct.unpack(">B", resp[:1])
|
|
||||||
if magic != V2_END_SESSION:
|
|
||||||
raise RuntimeError("Expected session close")
|
|
||||||
LOG.debug("[session {}] session ended".format(self.session))
|
|
||||||
finally:
|
|
||||||
self.session = None
|
|
||||||
# close connection if appropriate
|
|
||||||
super().end_session()
|
|
||||||
|
|
||||||
def write(self, msg: protobuf.MessageType) -> None:
|
|
||||||
if not self.session:
|
|
||||||
raise RuntimeError("Missing session for v2 protocol")
|
|
||||||
|
|
||||||
LOG.debug(
|
|
||||||
"[session {}] sending message: {}".format(
|
|
||||||
self.session, msg.__class__.__name__
|
|
||||||
),
|
|
||||||
extra={"protobuf": msg},
|
|
||||||
)
|
|
||||||
# Serialize whole message
|
|
||||||
data = BytesIO()
|
|
||||||
protobuf.dump_message(data, msg)
|
|
||||||
data = data.getvalue()
|
|
||||||
dataheader = struct.pack(">LL", mapping.get_type(msg), len(data))
|
|
||||||
data = dataheader + data
|
|
||||||
seq = -1
|
|
||||||
|
|
||||||
# Write it out
|
|
||||||
while data:
|
|
||||||
if seq < 0:
|
|
||||||
repheader = struct.pack(">BL", V2_FIRST_CHUNK, self.session)
|
|
||||||
else:
|
|
||||||
repheader = struct.pack(">BLL", V2_NEXT_CHUNK, self.session, seq)
|
|
||||||
datalen = REPLEN - len(repheader)
|
|
||||||
chunk = repheader + data[:datalen]
|
|
||||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
|
||||||
self.handle.write_chunk(chunk)
|
|
||||||
data = data[datalen:]
|
|
||||||
seq += 1
|
|
||||||
|
|
||||||
def read(self) -> protobuf.MessageType:
|
|
||||||
if not self.session:
|
|
||||||
raise RuntimeError("Missing session for v2 protocol")
|
|
||||||
|
|
||||||
buffer = bytearray()
|
|
||||||
|
|
||||||
# Read header with first part of message data
|
|
||||||
msg_type, datalen, chunk = self.read_first()
|
|
||||||
buffer.extend(chunk)
|
|
||||||
|
|
||||||
# Read the rest of the message
|
|
||||||
while len(buffer) < datalen:
|
|
||||||
next_chunk = self.read_next()
|
|
||||||
buffer.extend(next_chunk)
|
|
||||||
|
|
||||||
# Strip padding
|
|
||||||
buffer = BytesIO(buffer[:datalen])
|
|
||||||
|
|
||||||
# Parse to protobuf
|
|
||||||
msg = protobuf.load_message(buffer, mapping.get_class(msg_type))
|
|
||||||
LOG.debug(
|
|
||||||
"[session {}] received message: {}".format(
|
|
||||||
self.session, msg.__class__.__name__
|
|
||||||
),
|
|
||||||
extra={"protobuf": msg},
|
|
||||||
)
|
|
||||||
return msg
|
|
||||||
|
|
||||||
def read_first(self) -> Tuple[int, int, bytes]:
|
|
||||||
chunk = self.handle.read_chunk()
|
|
||||||
try:
|
|
||||||
headerlen = struct.calcsize(">BLLL")
|
|
||||||
magic, session, msg_type, datalen = struct.unpack(
|
|
||||||
">BLLL", chunk[:headerlen]
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
raise RuntimeError("Cannot parse header")
|
|
||||||
if magic != V2_FIRST_CHUNK:
|
|
||||||
raise RuntimeError("Unexpected magic character")
|
|
||||||
if session != self.session:
|
|
||||||
raise RuntimeError("Session id mismatch")
|
|
||||||
return msg_type, datalen, chunk[headerlen:]
|
|
||||||
|
|
||||||
def read_next(self) -> bytes:
|
|
||||||
chunk = self.handle.read_chunk()
|
|
||||||
try:
|
|
||||||
headerlen = struct.calcsize(">BLL")
|
|
||||||
magic, session, sequence = struct.unpack(">BLL", chunk[:headerlen])
|
|
||||||
except Exception:
|
|
||||||
raise RuntimeError("Cannot parse header")
|
|
||||||
if magic != V2_NEXT_CHUNK:
|
|
||||||
raise RuntimeError("Unexpected magic characters")
|
|
||||||
if session != self.session:
|
|
||||||
raise RuntimeError("Session id mismatch")
|
|
||||||
return chunk[headerlen:]
|
|
||||||
|
|
||||||
|
|
||||||
def get_protocol(handle: Handle, want_v2: bool) -> Protocol:
|
|
||||||
"""Make a Protocol instance for the given handle.
|
|
||||||
|
|
||||||
Each transport can have a preference for using a particular protocol version.
|
|
||||||
This preference is overridable through `TREZOR_PROTOCOL_V1` environment variable,
|
|
||||||
which forces the library to use V1 anyways.
|
|
||||||
|
|
||||||
As of 11/2018, no devices support V2, so we enforce V1 here. It is still possible
|
|
||||||
to set `TREZOR_PROTOCOL_V1=0` and thus enable V2 protocol for transports that ask
|
|
||||||
for it (i.e., USB transports for Trezor T).
|
|
||||||
"""
|
|
||||||
force_v1 = int(os.environ.get("TREZOR_PROTOCOL_V1", 1))
|
|
||||||
if want_v2 and not force_v1:
|
|
||||||
return ProtocolV2(handle)
|
|
||||||
else:
|
|
||||||
return ProtocolV1(handle)
|
|
||||||
|
@ -14,15 +14,19 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
import logging
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
from typing import Iterable, Optional, cast
|
from typing import Iterable, Optional, cast
|
||||||
|
|
||||||
|
from ..log import DUMP_PACKETS
|
||||||
from . import TransportException
|
from . import TransportException
|
||||||
from .protocol import ProtocolBasedTransport, get_protocol
|
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||||
|
|
||||||
SOCKET_TIMEOUT = 10
|
SOCKET_TIMEOUT = 10
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UdpTransport(ProtocolBasedTransport):
|
class UdpTransport(ProtocolBasedTransport):
|
||||||
|
|
||||||
@ -42,8 +46,7 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
self.device = (host, port)
|
self.device = (host, port)
|
||||||
self.socket = None # type: Optional[socket.socket]
|
self.socket = None # type: Optional[socket.socket]
|
||||||
|
|
||||||
protocol = get_protocol(self, want_v2=False)
|
super().__init__(protocol=ProtocolV1(self))
|
||||||
super().__init__(protocol=protocol)
|
|
||||||
|
|
||||||
def get_path(self) -> str:
|
def get_path(self) -> str:
|
||||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||||
@ -126,6 +129,7 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
assert self.socket is not None
|
assert self.socket is not None
|
||||||
if len(chunk) != 64:
|
if len(chunk) != 64:
|
||||||
raise TransportException("Unexpected data length")
|
raise TransportException("Unexpected data length")
|
||||||
|
LOG.log(DUMP_PACKETS, "sending packet: {}".format(chunk.hex()))
|
||||||
self.socket.sendall(chunk)
|
self.socket.sendall(chunk)
|
||||||
|
|
||||||
def read_chunk(self) -> bytes:
|
def read_chunk(self) -> bytes:
|
||||||
@ -136,6 +140,7 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
break
|
break
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
continue
|
continue
|
||||||
|
LOG.log(DUMP_PACKETS, "received packet: {}".format(chunk.hex()))
|
||||||
if len(chunk) != 64:
|
if len(chunk) != 64:
|
||||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||||
return bytearray(chunk)
|
return bytearray(chunk)
|
||||||
|
@ -20,6 +20,7 @@ import sys
|
|||||||
import time
|
import time
|
||||||
from typing import Iterable, Optional
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
|
from ..log import DUMP_PACKETS
|
||||||
from . import TREZORS, UDEV_RULES_STR, TransportException
|
from . import TREZORS, UDEV_RULES_STR, TransportException
|
||||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||||
|
|
||||||
@ -65,6 +66,7 @@ class WebUsbHandle:
|
|||||||
assert self.handle is not None
|
assert self.handle is not None
|
||||||
if len(chunk) != 64:
|
if len(chunk) != 64:
|
||||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||||
|
LOG.log(DUMP_PACKETS, "writing packet: {}".format(chunk.hex()))
|
||||||
self.handle.interruptWrite(self.endpoint, chunk)
|
self.handle.interruptWrite(self.endpoint, chunk)
|
||||||
|
|
||||||
def read_chunk(self) -> bytes:
|
def read_chunk(self) -> bytes:
|
||||||
@ -76,6 +78,7 @@ class WebUsbHandle:
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
time.sleep(0.001)
|
time.sleep(0.001)
|
||||||
|
LOG.log(DUMP_PACKETS, "read packet: {}".format(chunk.hex()))
|
||||||
if len(chunk) != 64:
|
if len(chunk) != 64:
|
||||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||||
return chunk
|
return chunk
|
||||||
@ -136,14 +139,8 @@ class WebUsbTransport(ProtocolBasedTransport):
|
|||||||
return devices
|
return devices
|
||||||
|
|
||||||
def find_debug(self) -> "WebUsbTransport":
|
def find_debug(self) -> "WebUsbTransport":
|
||||||
if self.protocol.VERSION >= 2:
|
# For v1 protocol, find debug USB interface for the same serial number
|
||||||
# TODO test this
|
return WebUsbTransport(self.device, debug=True)
|
||||||
# XXX this is broken right now because sessions don't really work
|
|
||||||
# For v2 protocol, use the same WebUSB interface with a different session
|
|
||||||
return WebUsbTransport(self.device, self.handle)
|
|
||||||
else:
|
|
||||||
# For v1 protocol, find debug USB interface for the same serial number
|
|
||||||
return WebUsbTransport(self.device, debug=True)
|
|
||||||
|
|
||||||
|
|
||||||
def is_vendor_class(dev: "usb1.USBDevice") -> bool:
|
def is_vendor_class(dev: "usb1.USBDevice") -> bool:
|
||||||
|
@ -59,9 +59,9 @@ def test_cancel_message_via_initialize(client, message):
|
|||||||
resp = client.call_raw(message)
|
resp = client.call_raw(message)
|
||||||
assert isinstance(resp, m.ButtonRequest)
|
assert isinstance(resp, m.ButtonRequest)
|
||||||
|
|
||||||
client.transport.write(m.ButtonAck())
|
client._raw_write(m.ButtonAck())
|
||||||
client.transport.write(m.Initialize())
|
client._raw_write(m.Initialize())
|
||||||
|
|
||||||
resp = client.transport.read()
|
resp = client._raw_read()
|
||||||
|
|
||||||
assert isinstance(resp, m.Features)
|
assert isinstance(resp, m.Features)
|
||||||
|
@ -69,10 +69,10 @@ class TestMsgRecoverydeviceT2:
|
|||||||
|
|
||||||
# Enter mnemonic words
|
# Enter mnemonic words
|
||||||
assert ret == proto.ButtonRequest(code=proto.ButtonRequestType.MnemonicInput)
|
assert ret == proto.ButtonRequest(code=proto.ButtonRequestType.MnemonicInput)
|
||||||
client.transport.write(proto.ButtonAck())
|
client._raw_write(proto.ButtonAck())
|
||||||
for word in mnemonic:
|
for word in mnemonic:
|
||||||
client.debug.input(word)
|
client.debug.input(word)
|
||||||
ret = client.transport.read()
|
ret = client._raw_read()
|
||||||
|
|
||||||
# Confirm success
|
# Confirm success
|
||||||
assert isinstance(ret, proto.ButtonRequest)
|
assert isinstance(ret, proto.ButtonRequest)
|
||||||
@ -125,10 +125,10 @@ class TestMsgRecoverydeviceT2:
|
|||||||
|
|
||||||
# Enter mnemonic words
|
# Enter mnemonic words
|
||||||
assert ret == proto.ButtonRequest(code=proto.ButtonRequestType.MnemonicInput)
|
assert ret == proto.ButtonRequest(code=proto.ButtonRequestType.MnemonicInput)
|
||||||
client.transport.write(proto.ButtonAck())
|
client._raw_write(proto.ButtonAck())
|
||||||
for word in mnemonic:
|
for word in mnemonic:
|
||||||
client.debug.input(word)
|
client.debug.input(word)
|
||||||
ret = client.transport.read()
|
ret = client._raw_read()
|
||||||
|
|
||||||
# Confirm success
|
# Confirm success
|
||||||
assert isinstance(ret, proto.ButtonRequest)
|
assert isinstance(ret, proto.ButtonRequest)
|
||||||
|
Loading…
Reference in New Issue
Block a user