python: unify protobuf-encoding code paths

Protobuf encoding now happens in TrezorClient, and transports get encoded blobs
to (chunkify and) send.  This is a better design because transports don't need
to know about protobuf.

It also lays groundwork for sending raw bytes feature (#116)

This commit also removes all vestiges of ProtocolV2 which was never used and
will probably need to be redesigned from the ground up anyway. The code is
still ready for protocol flexibility.
pull/919/head
matejcik 4 years ago
parent 22b167a961
commit 9a330f3475

@ -29,6 +29,8 @@ _At the moment, the project does **not** adhere to [Semantic Versioning](https:/
- `get_default_client` respects `TREZOR_PATH` environment variable
- UI callback `get_passphrase` has an additional argument `available_on_device`,
indicating that the connected Trezor is capable of on-device entry
- `Transport.write` and `read` method signatures changed to accept bytes instead of
protobuf messages
### Fixed

@ -21,7 +21,8 @@ import warnings
from mnemonic import Mnemonic
from . import MINIMUM_FIRMWARE_VERSION, exceptions, messages, tools
from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages, tools
from .log import DUMP_BYTES
from .messages import Capability
if sys.version_info.major < 3:
@ -114,11 +115,34 @@ class TrezorClient:
def _raw_write(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
self.transport.write(msg)
LOG.debug(
"sending message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
msg_type, msg_bytes = mapping.encode(msg)
LOG.log(
DUMP_BYTES,
"encoded as type {} ({} bytes): {}".format(
msg_type, len(msg_bytes), msg_bytes.hex()
),
)
self.transport.write(msg_type, msg_bytes)
def _raw_read(self):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
return self.transport.read()
msg_type, msg_bytes = self.transport.read()
LOG.log(
DUMP_BYTES,
"received type {} ({} bytes): {}".format(
msg_type, len(msg_bytes), msg_bytes.hex()
),
)
msg = mapping.decode(msg_type, msg_bytes)
LOG.debug(
"received message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
return msg
def _callback_pin(self, msg):
try:

@ -19,7 +19,7 @@ from copy import deepcopy
from mnemonic import Mnemonic
from . import messages as proto, protobuf
from . import mapping, messages as proto, protobuf
from .client import TrezorClient
from .tools import expect
@ -45,11 +45,12 @@ class DebugLink:
self.transport.end_session()
def _call(self, msg, nowait=False):
self.transport.write(msg)
msg_type, msg_bytes = mapping.encode(msg)
self.transport.write(msg_type, msg_bytes)
if nowait:
return None
ret = self.transport.read()
return ret
ret_type, ret_bytes = self.transport.read()
return mapping.decode(ret_type, ret_bytes)
def state(self):
return self._call(proto.DebugLinkGetState())

@ -22,8 +22,10 @@ from . import protobuf
OMITTED_MESSAGES = set() # type: Set[Type[protobuf.MessageType]]
DUMP_BYTES = 5
DUMP_PACKETS = 4
logging.addLevelName(DUMP_BYTES, "BYTES")
logging.addLevelName(DUMP_PACKETS, "PACKETS")
class PrettyProtobufFormatter(logging.Formatter):
@ -54,6 +56,8 @@ def enable_debug_output(verbosity: int = 1, handler: Optional[logging.Handler] =
level = logging.DEBUG
if verbosity > 1:
level = DUMP_BYTES
if verbosity > 2:
level = DUMP_PACKETS
logger = logging.getLogger("trezorlib")
logger.setLevel(level)

@ -14,7 +14,10 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from . import messages
import io
from typing import Tuple
from . import messages, protobuf
map_type_to_class = {}
map_class_to_type = {}
@ -59,4 +62,17 @@ def get_class(t):
return map_type_to_class[t]
def encode(msg: protobuf.MessageType) -> Tuple[int, bytes]:
message_type = msg.MESSAGE_WIRE_TYPE
buf = io.BytesIO()
protobuf.dump_message(buf, msg)
return message_type, buf.getvalue()
def decode(message_type: int, message_bytes: bytes) -> protobuf.MessageType:
cls = get_class(message_type)
buf = io.BytesIO(message_bytes)
return protobuf.load_message(buf, cls)
build_map()

@ -15,10 +15,9 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
from typing import Iterable, List, Type
from typing import Iterable, List, Tuple, Type
from ..exceptions import TrezorException
from ..protobuf import MessageType
LOG = logging.getLogger(__name__)
@ -35,6 +34,9 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
""".strip()
MessagePayload = Tuple[int, bytes]
class TransportException(TrezorException):
pass
@ -71,10 +73,10 @@ class Transport:
def end_session(self) -> None:
raise NotImplementedError
def read(self) -> MessageType:
def read(self) -> MessagePayload:
raise NotImplementedError
def write(self, message: MessageType) -> None:
def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError
@classmethod

@ -16,14 +16,12 @@
import logging
import struct
from io import BytesIO
from typing import Any, Dict, Iterable, Optional
import requests
from .. import mapping, protobuf
from ..log import DUMP_BYTES
from . import Transport, TransportException
from ..log import DUMP_PACKETS
from . import MessagePayload, Transport, TransportException
LOG = logging.getLogger(__name__)
@ -66,10 +64,12 @@ class BridgeHandle:
class BridgeHandleModern(BridgeHandle):
def write_buf(self, buf: bytes) -> None:
LOG.log(DUMP_PACKETS, "sending message: {}".format(buf.hex()))
self.transport._call("post", data=buf.hex())
def read_buf(self) -> bytes:
data = self.transport._call("read")
LOG.log(DUMP_PACKETS, "received message: {}".format(data.text))
return bytes.fromhex(data.text)
@ -87,7 +87,9 @@ class BridgeHandleLegacy(BridgeHandle):
if self.request is None:
raise TransportException("Can't read without write on legacy Bridge")
try:
LOG.log(DUMP_PACKETS, "calling with message: {}".format(self.request))
data = self.transport._call("call", data=self.request)
LOG.log(DUMP_PACKETS, "received response: {}".format(data.text))
return bytes.fromhex(data.text)
finally:
self.request = None
@ -152,29 +154,12 @@ class BridgeTransport(Transport):
self._call("release")
self.session = None
def write(self, msg: protobuf.MessageType) -> None:
LOG.debug(
"sending message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
buffer = BytesIO()
protobuf.dump_message(buffer, msg)
ser = buffer.getvalue()
LOG.log(DUMP_BYTES, "sending bytes: {}".format(ser.hex()))
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
self.handle.write_buf(header + ser)
def write(self, message_type: int, message_data: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data))
self.handle.write_buf(header + message_data)
def read(self) -> protobuf.MessageType:
def read(self) -> MessagePayload:
data = self.handle.read_buf()
headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
ser = data[headerlen : headerlen + datalen]
LOG.log(DUMP_BYTES, "received bytes: {}".format(ser.hex()))
buffer = BytesIO(ser)
msg = protobuf.load_message(buffer, mapping.get_class(msg_type))
LOG.debug(
"received message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
return msg
return msg_type, data[headerlen : headerlen + datalen]

@ -19,6 +19,7 @@ import sys
import time
from typing import Any, Dict, Iterable
from ..log import DUMP_PACKETS
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
@ -82,9 +83,10 @@ class HidHandle:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
if self.hid_version == 2:
self.handle.write(b"\0" + bytearray(chunk))
else:
self.handle.write(chunk)
chunk = b"\x00" + chunk
LOG.log(DUMP_PACKETS, "writing packet: {}".format(chunk.hex()))
self.handle.write(chunk)
def read_chunk(self) -> bytes:
while True:
@ -93,6 +95,8 @@ class HidHandle:
break
else:
time.sleep(0.001)
LOG.log(DUMP_PACKETS, "read packet: {}".format(chunk.hex()))
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
return bytes(chunk)
@ -119,8 +123,7 @@ class HidTransport(ProtocolBasedTransport):
self.device = device
self.handle = HidHandle(device["path"], device["serial_number"])
protocol = ProtocolV1(self.handle)
super().__init__(protocol=protocol)
super().__init__(protocol=ProtocolV1(self.handle))
def get_path(self) -> str:
return "%s:%s" % (self.PATH_PREFIX, self.device["path"].decode())
@ -142,15 +145,11 @@ class HidTransport(ProtocolBasedTransport):
return devices
def find_debug(self) -> "HidTransport":
if self.protocol.VERSION >= 2:
# use the same device
return self
else:
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def is_wirelink(dev: HidDevice) -> bool:

@ -15,16 +15,12 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import os
import struct
from io import BytesIO
from typing import Tuple
from typing_extensions import Protocol as StructuralType
from .. import mapping, protobuf
from ..log import DUMP_BYTES
from . import Transport
from . import MessagePayload, Transport
REPLEN = 64
@ -72,7 +68,6 @@ class Protocol:
- open and close physical connections,
- and send and receive binary chunks.
We declare a protocol version (we have implementations of v1 and v2).
For now, the class also handles session counting and opening the underlying Handle.
This will probably be removed in the future.
@ -80,8 +75,6 @@ class Protocol:
its messages.
"""
VERSION = None # type: int
def __init__(self, handle: Handle) -> None:
self.handle = handle
self.session_counter = 0
@ -97,10 +90,10 @@ class Protocol:
if self.session_counter == 0:
self.handle.close()
def read(self) -> protobuf.MessageType:
def read(self) -> MessagePayload:
raise NotImplementedError
def write(self, message: protobuf.MessageType) -> None:
def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError
@ -114,10 +107,10 @@ class ProtocolBasedTransport(Transport):
def __init__(self, protocol: Protocol) -> None:
self.protocol = protocol
def write(self, message: protobuf.MessageType) -> None:
self.protocol.write(message)
def write(self, message_type: int, message_data: bytes) -> None:
self.protocol.write(message_type, message_data)
def read(self) -> protobuf.MessageType:
def read(self) -> MessagePayload:
return self.protocol.read()
def begin_session(self) -> None:
@ -132,19 +125,11 @@ class ProtocolV1(Protocol):
Does not understand sessions.
"""
VERSION = 1
HEADER_LEN = struct.calcsize(">HL")
def write(self, msg: protobuf.MessageType) -> None:
LOG.debug(
"sending message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
data = BytesIO()
protobuf.dump_message(data, msg)
ser = data.getvalue()
LOG.log(DUMP_BYTES, "sending bytes: {}".format(ser.hex()))
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
buffer = bytearray(b"##" + header + ser)
def write(self, message_type: int, message_data: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data))
buffer = bytearray(b"##" + header + message_data)
while buffer:
# Report ID, data padded to 63 bytes
@ -153,7 +138,7 @@ class ProtocolV1(Protocol):
self.handle.write_chunk(chunk)
buffer = buffer[63:]
def read(self) -> protobuf.MessageType:
def read(self) -> MessagePayload:
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
@ -163,30 +148,18 @@ class ProtocolV1(Protocol):
while len(buffer) < datalen:
buffer.extend(self.read_next())
# Strip padding
ser = buffer[:datalen]
data = BytesIO(ser)
LOG.log(DUMP_BYTES, "received bytes: {}".format(ser.hex()))
# Parse to protobuf
msg = protobuf.load_message(data, mapping.get_class(msg_type))
LOG.debug(
"received message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
return msg
return msg_type, buffer[:datalen]
def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError("Unexpected magic characters")
try:
headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + headerlen])
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[3 + headerlen :]
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self) -> bytes:
@ -194,160 +167,3 @@ class ProtocolV1(Protocol):
if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters")
return chunk[1:]
class ProtocolV2(Protocol):
"""Protocol version 2. Currently (11/2018) not used.
Intended to mimic U2F/WebAuthN session handling.
"""
VERSION = 2
def __init__(self, handle: Handle) -> None:
self.session = None
super().__init__(handle)
def begin_session(self) -> None:
# ensure open connection
super().begin_session()
# initiate session
chunk = struct.pack(">B", V2_BEGIN_SESSION)
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
# get session identifier
resp = self.handle.read_chunk()
try:
headerlen = struct.calcsize(">BL")
magic, session = struct.unpack(">BL", resp[:headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
if magic != V2_BEGIN_SESSION:
raise RuntimeError("Unexpected magic character")
self.session = session
LOG.debug("[session {}] session started".format(self.session))
def end_session(self) -> None:
if not self.session:
return
try:
chunk = struct.pack(">BL", V2_END_SESSION, self.session)
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
resp = self.handle.read_chunk()
(magic,) = struct.unpack(">B", resp[:1])
if magic != V2_END_SESSION:
raise RuntimeError("Expected session close")
LOG.debug("[session {}] session ended".format(self.session))
finally:
self.session = None
# close connection if appropriate
super().end_session()
def write(self, msg: protobuf.MessageType) -> None:
if not self.session:
raise RuntimeError("Missing session for v2 protocol")
LOG.debug(
"[session {}] sending message: {}".format(
self.session, msg.__class__.__name__
),
extra={"protobuf": msg},
)
# Serialize whole message
data = BytesIO()
protobuf.dump_message(data, msg)
data = data.getvalue()
dataheader = struct.pack(">LL", mapping.get_type(msg), len(data))
data = dataheader + data
seq = -1
# Write it out
while data:
if seq < 0:
repheader = struct.pack(">BL", V2_FIRST_CHUNK, self.session)
else:
repheader = struct.pack(">BLL", V2_NEXT_CHUNK, self.session, seq)
datalen = REPLEN - len(repheader)
chunk = repheader + data[:datalen]
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
data = data[datalen:]
seq += 1
def read(self) -> protobuf.MessageType:
if not self.session:
raise RuntimeError("Missing session for v2 protocol")
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, chunk = self.read_first()
buffer.extend(chunk)
# Read the rest of the message
while len(buffer) < datalen:
next_chunk = self.read_next()
buffer.extend(next_chunk)
# Strip padding
buffer = BytesIO(buffer[:datalen])
# Parse to protobuf
msg = protobuf.load_message(buffer, mapping.get_class(msg_type))
LOG.debug(
"[session {}] received message: {}".format(
self.session, msg.__class__.__name__
),
extra={"protobuf": msg},
)
return msg
def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
try:
headerlen = struct.calcsize(">BLLL")
magic, session, msg_type, datalen = struct.unpack(
">BLLL", chunk[:headerlen]
)
except Exception:
raise RuntimeError("Cannot parse header")
if magic != V2_FIRST_CHUNK:
raise RuntimeError("Unexpected magic character")
if session != self.session:
raise RuntimeError("Session id mismatch")
return msg_type, datalen, chunk[headerlen:]
def read_next(self) -> bytes:
chunk = self.handle.read_chunk()
try:
headerlen = struct.calcsize(">BLL")
magic, session, sequence = struct.unpack(">BLL", chunk[:headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
if magic != V2_NEXT_CHUNK:
raise RuntimeError("Unexpected magic characters")
if session != self.session:
raise RuntimeError("Session id mismatch")
return chunk[headerlen:]
def get_protocol(handle: Handle, want_v2: bool) -> Protocol:
"""Make a Protocol instance for the given handle.
Each transport can have a preference for using a particular protocol version.
This preference is overridable through `TREZOR_PROTOCOL_V1` environment variable,
which forces the library to use V1 anyways.
As of 11/2018, no devices support V2, so we enforce V1 here. It is still possible
to set `TREZOR_PROTOCOL_V1=0` and thus enable V2 protocol for transports that ask
for it (i.e., USB transports for Trezor T).
"""
force_v1 = int(os.environ.get("TREZOR_PROTOCOL_V1", 1))
if want_v2 and not force_v1:
return ProtocolV2(handle)
else:
return ProtocolV1(handle)

@ -14,15 +14,19 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import socket
import time
from typing import Iterable, Optional, cast
from ..log import DUMP_PACKETS
from . import TransportException
from .protocol import ProtocolBasedTransport, get_protocol
from .protocol import ProtocolBasedTransport, ProtocolV1
SOCKET_TIMEOUT = 10
LOG = logging.getLogger(__name__)
class UdpTransport(ProtocolBasedTransport):
@ -42,8 +46,7 @@ class UdpTransport(ProtocolBasedTransport):
self.device = (host, port)
self.socket = None # type: Optional[socket.socket]
protocol = get_protocol(self, want_v2=False)
super().__init__(protocol=protocol)
super().__init__(protocol=ProtocolV1(self))
def get_path(self) -> str:
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
@ -126,6 +129,7 @@ class UdpTransport(ProtocolBasedTransport):
assert self.socket is not None
if len(chunk) != 64:
raise TransportException("Unexpected data length")
LOG.log(DUMP_PACKETS, "sending packet: {}".format(chunk.hex()))
self.socket.sendall(chunk)
def read_chunk(self) -> bytes:
@ -136,6 +140,7 @@ class UdpTransport(ProtocolBasedTransport):
break
except socket.timeout:
continue
LOG.log(DUMP_PACKETS, "received packet: {}".format(chunk.hex()))
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
return bytearray(chunk)

@ -20,6 +20,7 @@ import sys
import time
from typing import Iterable, Optional
from ..log import DUMP_PACKETS
from . import TREZORS, UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
@ -65,6 +66,7 @@ class WebUsbHandle:
assert self.handle is not None
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
LOG.log(DUMP_PACKETS, "writing packet: {}".format(chunk.hex()))
self.handle.interruptWrite(self.endpoint, chunk)
def read_chunk(self) -> bytes:
@ -76,6 +78,7 @@ class WebUsbHandle:
break
else:
time.sleep(0.001)
LOG.log(DUMP_PACKETS, "read packet: {}".format(chunk.hex()))
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
return chunk
@ -136,14 +139,8 @@ class WebUsbTransport(ProtocolBasedTransport):
return devices
def find_debug(self) -> "WebUsbTransport":
if self.protocol.VERSION >= 2:
# TODO test this
# XXX this is broken right now because sessions don't really work
# For v2 protocol, use the same WebUSB interface with a different session
return WebUsbTransport(self.device, self.handle)
else:
# For v1 protocol, find debug USB interface for the same serial number
return WebUsbTransport(self.device, debug=True)
# For v1 protocol, find debug USB interface for the same serial number
return WebUsbTransport(self.device, debug=True)
def is_vendor_class(dev: "usb1.USBDevice") -> bool:

@ -59,9 +59,9 @@ def test_cancel_message_via_initialize(client, message):
resp = client.call_raw(message)
assert isinstance(resp, m.ButtonRequest)
client.transport.write(m.ButtonAck())
client.transport.write(m.Initialize())
client._raw_write(m.ButtonAck())
client._raw_write(m.Initialize())
resp = client.transport.read()
resp = client._raw_read()
assert isinstance(resp, m.Features)

@ -69,10 +69,10 @@ class TestMsgRecoverydeviceT2:
# Enter mnemonic words
assert ret == proto.ButtonRequest(code=proto.ButtonRequestType.MnemonicInput)
client.transport.write(proto.ButtonAck())
client._raw_write(proto.ButtonAck())
for word in mnemonic:
client.debug.input(word)
ret = client.transport.read()
ret = client._raw_read()
# Confirm success
assert isinstance(ret, proto.ButtonRequest)
@ -125,10 +125,10 @@ class TestMsgRecoverydeviceT2:
# Enter mnemonic words
assert ret == proto.ButtonRequest(code=proto.ButtonRequestType.MnemonicInput)
client.transport.write(proto.ButtonAck())
client._raw_write(proto.ButtonAck())
for word in mnemonic:
client.debug.input(word)
ret = client.transport.read()
ret = client._raw_read()
# Confirm success
assert isinstance(ret, proto.ButtonRequest)

Loading…
Cancel
Save