mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-03-03 00:36:06 +00:00
python: unify protobuf-encoding code paths
Protobuf encoding now happens in TrezorClient, and transports get encoded blobs to (chunkify and) send. This is a better design because transports don't need to know about protobuf. It also lays groundwork for sending raw bytes feature (#116) This commit also removes all vestiges of ProtocolV2 which was never used and will probably need to be redesigned from the ground up anyway. The code is still ready for protocol flexibility.
This commit is contained in:
parent
22b167a961
commit
9a330f3475
@ -29,6 +29,8 @@ _At the moment, the project does **not** adhere to [Semantic Versioning](https:/
|
||||
- `get_default_client` respects `TREZOR_PATH` environment variable
|
||||
- UI callback `get_passphrase` has an additional argument `available_on_device`,
|
||||
indicating that the connected Trezor is capable of on-device entry
|
||||
- `Transport.write` and `read` method signatures changed to accept bytes instead of
|
||||
protobuf messages
|
||||
|
||||
### Fixed
|
||||
|
||||
|
@ -21,7 +21,8 @@ import warnings
|
||||
|
||||
from mnemonic import Mnemonic
|
||||
|
||||
from . import MINIMUM_FIRMWARE_VERSION, exceptions, messages, tools
|
||||
from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages, tools
|
||||
from .log import DUMP_BYTES
|
||||
from .messages import Capability
|
||||
|
||||
if sys.version_info.major < 3:
|
||||
@ -114,11 +115,34 @@ class TrezorClient:
|
||||
|
||||
def _raw_write(self, msg):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
self.transport.write(msg)
|
||||
LOG.debug(
|
||||
"sending message: {}".format(msg.__class__.__name__),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
msg_type, msg_bytes = mapping.encode(msg)
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
"encoded as type {} ({} bytes): {}".format(
|
||||
msg_type, len(msg_bytes), msg_bytes.hex()
|
||||
),
|
||||
)
|
||||
self.transport.write(msg_type, msg_bytes)
|
||||
|
||||
def _raw_read(self):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
return self.transport.read()
|
||||
msg_type, msg_bytes = self.transport.read()
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
"received type {} ({} bytes): {}".format(
|
||||
msg_type, len(msg_bytes), msg_bytes.hex()
|
||||
),
|
||||
)
|
||||
msg = mapping.decode(msg_type, msg_bytes)
|
||||
LOG.debug(
|
||||
"received message: {}".format(msg.__class__.__name__),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
return msg
|
||||
|
||||
def _callback_pin(self, msg):
|
||||
try:
|
||||
|
@ -19,7 +19,7 @@ from copy import deepcopy
|
||||
|
||||
from mnemonic import Mnemonic
|
||||
|
||||
from . import messages as proto, protobuf
|
||||
from . import mapping, messages as proto, protobuf
|
||||
from .client import TrezorClient
|
||||
from .tools import expect
|
||||
|
||||
@ -45,11 +45,12 @@ class DebugLink:
|
||||
self.transport.end_session()
|
||||
|
||||
def _call(self, msg, nowait=False):
|
||||
self.transport.write(msg)
|
||||
msg_type, msg_bytes = mapping.encode(msg)
|
||||
self.transport.write(msg_type, msg_bytes)
|
||||
if nowait:
|
||||
return None
|
||||
ret = self.transport.read()
|
||||
return ret
|
||||
ret_type, ret_bytes = self.transport.read()
|
||||
return mapping.decode(ret_type, ret_bytes)
|
||||
|
||||
def state(self):
|
||||
return self._call(proto.DebugLinkGetState())
|
||||
|
@ -22,8 +22,10 @@ from . import protobuf
|
||||
OMITTED_MESSAGES = set() # type: Set[Type[protobuf.MessageType]]
|
||||
|
||||
DUMP_BYTES = 5
|
||||
DUMP_PACKETS = 4
|
||||
|
||||
logging.addLevelName(DUMP_BYTES, "BYTES")
|
||||
logging.addLevelName(DUMP_PACKETS, "PACKETS")
|
||||
|
||||
|
||||
class PrettyProtobufFormatter(logging.Formatter):
|
||||
@ -54,6 +56,8 @@ def enable_debug_output(verbosity: int = 1, handler: Optional[logging.Handler] =
|
||||
level = logging.DEBUG
|
||||
if verbosity > 1:
|
||||
level = DUMP_BYTES
|
||||
if verbosity > 2:
|
||||
level = DUMP_PACKETS
|
||||
|
||||
logger = logging.getLogger("trezorlib")
|
||||
logger.setLevel(level)
|
||||
|
@ -14,7 +14,10 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from . import messages
|
||||
import io
|
||||
from typing import Tuple
|
||||
|
||||
from . import messages, protobuf
|
||||
|
||||
map_type_to_class = {}
|
||||
map_class_to_type = {}
|
||||
@ -59,4 +62,17 @@ def get_class(t):
|
||||
return map_type_to_class[t]
|
||||
|
||||
|
||||
def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]:
|
||||
message_type = msg.MESSAGE_WIRE_TYPE
|
||||
buf = io.BytesIO()
|
||||
protobuf.dump_message(buf, msg)
|
||||
return message_type, buf.getvalue()
|
||||
|
||||
|
||||
def decode(message_type: int, message_bytes: bytes) -> protobuf.MessageType:
|
||||
cls = get_class(message_type)
|
||||
buf = io.BytesIO(message_bytes)
|
||||
return protobuf.load_message(buf, cls)
|
||||
|
||||
|
||||
build_map()
|
||||
|
@ -15,10 +15,9 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import logging
|
||||
from typing import Iterable, List, Type
|
||||
from typing import Iterable, List, Tuple, Type
|
||||
|
||||
from ..exceptions import TrezorException
|
||||
from ..protobuf import MessageType
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -35,6 +34,9 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
|
||||
""".strip()
|
||||
|
||||
|
||||
MessagePayload = Tuple[int, bytes]
|
||||
|
||||
|
||||
class TransportException(TrezorException):
|
||||
pass
|
||||
|
||||
@ -71,10 +73,10 @@ class Transport:
|
||||
def end_session(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def read(self) -> MessageType:
|
||||
def read(self) -> MessagePayload:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, message: MessageType) -> None:
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
|
@ -16,14 +16,12 @@
|
||||
|
||||
import logging
|
||||
import struct
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from .. import mapping, protobuf
|
||||
from ..log import DUMP_BYTES
|
||||
from . import Transport, TransportException
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import MessagePayload, Transport, TransportException
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -66,10 +64,12 @@ class BridgeHandle:
|
||||
|
||||
class BridgeHandleModern(BridgeHandle):
|
||||
def write_buf(self, buf: bytes) -> None:
|
||||
LOG.log(DUMP_PACKETS, "sending message: {}".format(buf.hex()))
|
||||
self.transport._call("post", data=buf.hex())
|
||||
|
||||
def read_buf(self) -> bytes:
|
||||
data = self.transport._call("read")
|
||||
LOG.log(DUMP_PACKETS, "received message: {}".format(data.text))
|
||||
return bytes.fromhex(data.text)
|
||||
|
||||
|
||||
@ -87,7 +87,9 @@ class BridgeHandleLegacy(BridgeHandle):
|
||||
if self.request is None:
|
||||
raise TransportException("Can't read without write on legacy Bridge")
|
||||
try:
|
||||
LOG.log(DUMP_PACKETS, "calling with message: {}".format(self.request))
|
||||
data = self.transport._call("call", data=self.request)
|
||||
LOG.log(DUMP_PACKETS, "received response: {}".format(data.text))
|
||||
return bytes.fromhex(data.text)
|
||||
finally:
|
||||
self.request = None
|
||||
@ -152,29 +154,12 @@ class BridgeTransport(Transport):
|
||||
self._call("release")
|
||||
self.session = None
|
||||
|
||||
def write(self, msg: protobuf.MessageType) -> None:
|
||||
LOG.debug(
|
||||
"sending message: {}".format(msg.__class__.__name__),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
buffer = BytesIO()
|
||||
protobuf.dump_message(buffer, msg)
|
||||
ser = buffer.getvalue()
|
||||
LOG.log(DUMP_BYTES, "sending bytes: {}".format(ser.hex()))
|
||||
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
self.handle.write_buf(header + message_data)
|
||||
|
||||
self.handle.write_buf(header + ser)
|
||||
|
||||
def read(self) -> protobuf.MessageType:
|
||||
def read(self) -> MessagePayload:
|
||||
data = self.handle.read_buf()
|
||||
headerlen = struct.calcsize(">HL")
|
||||
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
|
||||
ser = data[headerlen : headerlen + datalen]
|
||||
LOG.log(DUMP_BYTES, "received bytes: {}".format(ser.hex()))
|
||||
buffer = BytesIO(ser)
|
||||
msg = protobuf.load_message(buffer, mapping.get_class(msg_type))
|
||||
LOG.debug(
|
||||
"received message: {}".format(msg.__class__.__name__),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
return msg
|
||||
return msg_type, data[headerlen : headerlen + datalen]
|
||||
|
@ -19,6 +19,7 @@ import sys
|
||||
import time
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
|
||||
@ -82,9 +83,10 @@ class HidHandle:
|
||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||
|
||||
if self.hid_version == 2:
|
||||
self.handle.write(b"\0" + bytearray(chunk))
|
||||
else:
|
||||
self.handle.write(chunk)
|
||||
chunk = b"\x00" + chunk
|
||||
|
||||
LOG.log(DUMP_PACKETS, "writing packet: {}".format(chunk.hex()))
|
||||
self.handle.write(chunk)
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
while True:
|
||||
@ -93,6 +95,8 @@ class HidHandle:
|
||||
break
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
|
||||
LOG.log(DUMP_PACKETS, "read packet: {}".format(chunk.hex()))
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||
return bytes(chunk)
|
||||
@ -119,8 +123,7 @@ class HidTransport(ProtocolBasedTransport):
|
||||
self.device = device
|
||||
self.handle = HidHandle(device["path"], device["serial_number"])
|
||||
|
||||
protocol = ProtocolV1(self.handle)
|
||||
super().__init__(protocol=protocol)
|
||||
super().__init__(protocol=ProtocolV1(self.handle))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return "%s:%s" % (self.PATH_PREFIX, self.device["path"].decode())
|
||||
@ -142,15 +145,11 @@ class HidTransport(ProtocolBasedTransport):
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "HidTransport":
|
||||
if self.protocol.VERSION >= 2:
|
||||
# use the same device
|
||||
return self
|
||||
else:
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||
return debug
|
||||
raise TransportException("Debug HID device not found")
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||
return debug
|
||||
raise TransportException("Debug HID device not found")
|
||||
|
||||
|
||||
def is_wirelink(dev: HidDevice) -> bool:
|
||||
|
@ -15,16 +15,12 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
from io import BytesIO
|
||||
from typing import Tuple
|
||||
|
||||
from typing_extensions import Protocol as StructuralType
|
||||
|
||||
from .. import mapping, protobuf
|
||||
from ..log import DUMP_BYTES
|
||||
from . import Transport
|
||||
from . import MessagePayload, Transport
|
||||
|
||||
REPLEN = 64
|
||||
|
||||
@ -72,7 +68,6 @@ class Protocol:
|
||||
- open and close physical connections,
|
||||
- and send and receive binary chunks.
|
||||
|
||||
We declare a protocol version (we have implementations of v1 and v2).
|
||||
For now, the class also handles session counting and opening the underlying Handle.
|
||||
This will probably be removed in the future.
|
||||
|
||||
@ -80,8 +75,6 @@ class Protocol:
|
||||
its messages.
|
||||
"""
|
||||
|
||||
VERSION = None # type: int
|
||||
|
||||
def __init__(self, handle: Handle) -> None:
|
||||
self.handle = handle
|
||||
self.session_counter = 0
|
||||
@ -97,10 +90,10 @@ class Protocol:
|
||||
if self.session_counter == 0:
|
||||
self.handle.close()
|
||||
|
||||
def read(self) -> protobuf.MessageType:
|
||||
def read(self) -> MessagePayload:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, message: protobuf.MessageType) -> None:
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -114,10 +107,10 @@ class ProtocolBasedTransport(Transport):
|
||||
def __init__(self, protocol: Protocol) -> None:
|
||||
self.protocol = protocol
|
||||
|
||||
def write(self, message: protobuf.MessageType) -> None:
|
||||
self.protocol.write(message)
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
self.protocol.write(message_type, message_data)
|
||||
|
||||
def read(self) -> protobuf.MessageType:
|
||||
def read(self) -> MessagePayload:
|
||||
return self.protocol.read()
|
||||
|
||||
def begin_session(self) -> None:
|
||||
@ -132,19 +125,11 @@ class ProtocolV1(Protocol):
|
||||
Does not understand sessions.
|
||||
"""
|
||||
|
||||
VERSION = 1
|
||||
HEADER_LEN = struct.calcsize(">HL")
|
||||
|
||||
def write(self, msg: protobuf.MessageType) -> None:
|
||||
LOG.debug(
|
||||
"sending message: {}".format(msg.__class__.__name__),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
data = BytesIO()
|
||||
protobuf.dump_message(data, msg)
|
||||
ser = data.getvalue()
|
||||
LOG.log(DUMP_BYTES, "sending bytes: {}".format(ser.hex()))
|
||||
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
|
||||
buffer = bytearray(b"##" + header + ser)
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
buffer = bytearray(b"##" + header + message_data)
|
||||
|
||||
while buffer:
|
||||
# Report ID, data padded to 63 bytes
|
||||
@ -153,7 +138,7 @@ class ProtocolV1(Protocol):
|
||||
self.handle.write_chunk(chunk)
|
||||
buffer = buffer[63:]
|
||||
|
||||
def read(self) -> protobuf.MessageType:
|
||||
def read(self) -> MessagePayload:
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first()
|
||||
@ -163,30 +148,18 @@ class ProtocolV1(Protocol):
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next())
|
||||
|
||||
# Strip padding
|
||||
ser = buffer[:datalen]
|
||||
data = BytesIO(ser)
|
||||
LOG.log(DUMP_BYTES, "received bytes: {}".format(ser.hex()))
|
||||
|
||||
# Parse to protobuf
|
||||
msg = protobuf.load_message(data, mapping.get_class(msg_type))
|
||||
LOG.debug(
|
||||
"received message: {}".format(msg.__class__.__name__),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
return msg
|
||||
return msg_type, buffer[:datalen]
|
||||
|
||||
def read_first(self) -> Tuple[int, int, bytes]:
|
||||
chunk = self.handle.read_chunk()
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
try:
|
||||
headerlen = struct.calcsize(">HL")
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + headerlen])
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
|
||||
data = chunk[3 + headerlen :]
|
||||
data = chunk[3 + self.HEADER_LEN :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self) -> bytes:
|
||||
@ -194,160 +167,3 @@ class ProtocolV1(Protocol):
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
return chunk[1:]
|
||||
|
||||
|
||||
class ProtocolV2(Protocol):
|
||||
"""Protocol version 2. Currently (11/2018) not used.
|
||||
Intended to mimic U2F/WebAuthN session handling.
|
||||
"""
|
||||
|
||||
VERSION = 2
|
||||
|
||||
def __init__(self, handle: Handle) -> None:
|
||||
self.session = None
|
||||
super().__init__(handle)
|
||||
|
||||
def begin_session(self) -> None:
|
||||
# ensure open connection
|
||||
super().begin_session()
|
||||
|
||||
# initiate session
|
||||
chunk = struct.pack(">B", V2_BEGIN_SESSION)
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
self.handle.write_chunk(chunk)
|
||||
|
||||
# get session identifier
|
||||
resp = self.handle.read_chunk()
|
||||
try:
|
||||
headerlen = struct.calcsize(">BL")
|
||||
magic, session = struct.unpack(">BL", resp[:headerlen])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
if magic != V2_BEGIN_SESSION:
|
||||
raise RuntimeError("Unexpected magic character")
|
||||
self.session = session
|
||||
|
||||
LOG.debug("[session {}] session started".format(self.session))
|
||||
|
||||
def end_session(self) -> None:
|
||||
if not self.session:
|
||||
return
|
||||
|
||||
try:
|
||||
chunk = struct.pack(">BL", V2_END_SESSION, self.session)
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
self.handle.write_chunk(chunk)
|
||||
resp = self.handle.read_chunk()
|
||||
(magic,) = struct.unpack(">B", resp[:1])
|
||||
if magic != V2_END_SESSION:
|
||||
raise RuntimeError("Expected session close")
|
||||
LOG.debug("[session {}] session ended".format(self.session))
|
||||
finally:
|
||||
self.session = None
|
||||
# close connection if appropriate
|
||||
super().end_session()
|
||||
|
||||
def write(self, msg: protobuf.MessageType) -> None:
|
||||
if not self.session:
|
||||
raise RuntimeError("Missing session for v2 protocol")
|
||||
|
||||
LOG.debug(
|
||||
"[session {}] sending message: {}".format(
|
||||
self.session, msg.__class__.__name__
|
||||
),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
# Serialize whole message
|
||||
data = BytesIO()
|
||||
protobuf.dump_message(data, msg)
|
||||
data = data.getvalue()
|
||||
dataheader = struct.pack(">LL", mapping.get_type(msg), len(data))
|
||||
data = dataheader + data
|
||||
seq = -1
|
||||
|
||||
# Write it out
|
||||
while data:
|
||||
if seq < 0:
|
||||
repheader = struct.pack(">BL", V2_FIRST_CHUNK, self.session)
|
||||
else:
|
||||
repheader = struct.pack(">BLL", V2_NEXT_CHUNK, self.session, seq)
|
||||
datalen = REPLEN - len(repheader)
|
||||
chunk = repheader + data[:datalen]
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
self.handle.write_chunk(chunk)
|
||||
data = data[datalen:]
|
||||
seq += 1
|
||||
|
||||
def read(self) -> protobuf.MessageType:
|
||||
if not self.session:
|
||||
raise RuntimeError("Missing session for v2 protocol")
|
||||
|
||||
buffer = bytearray()
|
||||
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, chunk = self.read_first()
|
||||
buffer.extend(chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
next_chunk = self.read_next()
|
||||
buffer.extend(next_chunk)
|
||||
|
||||
# Strip padding
|
||||
buffer = BytesIO(buffer[:datalen])
|
||||
|
||||
# Parse to protobuf
|
||||
msg = protobuf.load_message(buffer, mapping.get_class(msg_type))
|
||||
LOG.debug(
|
||||
"[session {}] received message: {}".format(
|
||||
self.session, msg.__class__.__name__
|
||||
),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
return msg
|
||||
|
||||
def read_first(self) -> Tuple[int, int, bytes]:
|
||||
chunk = self.handle.read_chunk()
|
||||
try:
|
||||
headerlen = struct.calcsize(">BLLL")
|
||||
magic, session, msg_type, datalen = struct.unpack(
|
||||
">BLLL", chunk[:headerlen]
|
||||
)
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
if magic != V2_FIRST_CHUNK:
|
||||
raise RuntimeError("Unexpected magic character")
|
||||
if session != self.session:
|
||||
raise RuntimeError("Session id mismatch")
|
||||
return msg_type, datalen, chunk[headerlen:]
|
||||
|
||||
def read_next(self) -> bytes:
|
||||
chunk = self.handle.read_chunk()
|
||||
try:
|
||||
headerlen = struct.calcsize(">BLL")
|
||||
magic, session, sequence = struct.unpack(">BLL", chunk[:headerlen])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
if magic != V2_NEXT_CHUNK:
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
if session != self.session:
|
||||
raise RuntimeError("Session id mismatch")
|
||||
return chunk[headerlen:]
|
||||
|
||||
|
||||
def get_protocol(handle: Handle, want_v2: bool) -> Protocol:
|
||||
"""Make a Protocol instance for the given handle.
|
||||
|
||||
Each transport can have a preference for using a particular protocol version.
|
||||
This preference is overridable through `TREZOR_PROTOCOL_V1` environment variable,
|
||||
which forces the library to use V1 anyways.
|
||||
|
||||
As of 11/2018, no devices support V2, so we enforce V1 here. It is still possible
|
||||
to set `TREZOR_PROTOCOL_V1=0` and thus enable V2 protocol for transports that ask
|
||||
for it (i.e., USB transports for Trezor T).
|
||||
"""
|
||||
force_v1 = int(os.environ.get("TREZOR_PROTOCOL_V1", 1))
|
||||
if want_v2 and not force_v1:
|
||||
return ProtocolV2(handle)
|
||||
else:
|
||||
return ProtocolV1(handle)
|
||||
|
@ -14,15 +14,19 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import Iterable, Optional, cast
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import TransportException
|
||||
from .protocol import ProtocolBasedTransport, get_protocol
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
|
||||
SOCKET_TIMEOUT = 10
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UdpTransport(ProtocolBasedTransport):
|
||||
|
||||
@ -42,8 +46,7 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
self.device = (host, port)
|
||||
self.socket = None # type: Optional[socket.socket]
|
||||
|
||||
protocol = get_protocol(self, want_v2=False)
|
||||
super().__init__(protocol=protocol)
|
||||
super().__init__(protocol=ProtocolV1(self))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||
@ -126,6 +129,7 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
assert self.socket is not None
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected data length")
|
||||
LOG.log(DUMP_PACKETS, "sending packet: {}".format(chunk.hex()))
|
||||
self.socket.sendall(chunk)
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
@ -136,6 +140,7 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
break
|
||||
except socket.timeout:
|
||||
continue
|
||||
LOG.log(DUMP_PACKETS, "received packet: {}".format(chunk.hex()))
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||
return bytearray(chunk)
|
||||
|
@ -20,6 +20,7 @@ import sys
|
||||
import time
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import TREZORS, UDEV_RULES_STR, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
|
||||
@ -65,6 +66,7 @@ class WebUsbHandle:
|
||||
assert self.handle is not None
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||
LOG.log(DUMP_PACKETS, "writing packet: {}".format(chunk.hex()))
|
||||
self.handle.interruptWrite(self.endpoint, chunk)
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
@ -76,6 +78,7 @@ class WebUsbHandle:
|
||||
break
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
LOG.log(DUMP_PACKETS, "read packet: {}".format(chunk.hex()))
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||
return chunk
|
||||
@ -136,14 +139,8 @@ class WebUsbTransport(ProtocolBasedTransport):
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "WebUsbTransport":
|
||||
if self.protocol.VERSION >= 2:
|
||||
# TODO test this
|
||||
# XXX this is broken right now because sessions don't really work
|
||||
# For v2 protocol, use the same WebUSB interface with a different session
|
||||
return WebUsbTransport(self.device, self.handle)
|
||||
else:
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
return WebUsbTransport(self.device, debug=True)
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
return WebUsbTransport(self.device, debug=True)
|
||||
|
||||
|
||||
def is_vendor_class(dev: "usb1.USBDevice") -> bool:
|
||||
|
@ -59,9 +59,9 @@ def test_cancel_message_via_initialize(client, message):
|
||||
resp = client.call_raw(message)
|
||||
assert isinstance(resp, m.ButtonRequest)
|
||||
|
||||
client.transport.write(m.ButtonAck())
|
||||
client.transport.write(m.Initialize())
|
||||
client._raw_write(m.ButtonAck())
|
||||
client._raw_write(m.Initialize())
|
||||
|
||||
resp = client.transport.read()
|
||||
resp = client._raw_read()
|
||||
|
||||
assert isinstance(resp, m.Features)
|
||||
|
@ -69,10 +69,10 @@ class TestMsgRecoverydeviceT2:
|
||||
|
||||
# Enter mnemonic words
|
||||
assert ret == proto.ButtonRequest(code=proto.ButtonRequestType.MnemonicInput)
|
||||
client.transport.write(proto.ButtonAck())
|
||||
client._raw_write(proto.ButtonAck())
|
||||
for word in mnemonic:
|
||||
client.debug.input(word)
|
||||
ret = client.transport.read()
|
||||
ret = client._raw_read()
|
||||
|
||||
# Confirm success
|
||||
assert isinstance(ret, proto.ButtonRequest)
|
||||
@ -125,10 +125,10 @@ class TestMsgRecoverydeviceT2:
|
||||
|
||||
# Enter mnemonic words
|
||||
assert ret == proto.ButtonRequest(code=proto.ButtonRequestType.MnemonicInput)
|
||||
client.transport.write(proto.ButtonAck())
|
||||
client._raw_write(proto.ButtonAck())
|
||||
for word in mnemonic:
|
||||
client.debug.input(word)
|
||||
ret = client.transport.read()
|
||||
ret = client._raw_read()
|
||||
|
||||
# Confirm success
|
||||
assert isinstance(ret, proto.ButtonRequest)
|
||||
|
Loading…
Reference in New Issue
Block a user