trezorlib: transport/protocol reshuffle

This commit breaks session handling (which matters with Bridge) and
regresses Bridge to an older code state. Both of these issues will be
rectified in subsequent commits.

Explanation of this big API reshuffle follows:

* protocols are moved to trezorlib.transport, and to a single common file.
* there is a cleaner definition of Transport and Protocol API (see below)
* fully valid mypy type hinting
* session handle counters and open handle counters mostly went away. Transports
  and Protocols are meant to be "raw" APIs; TrezorClient will implement
  context-handler-based sessions, session tracking, etc.

I'm calling this a "reshuffle" because it involved very small number of
code changes. Most of it is moving things around where they sit better.

The API changes are as follows.

Transport is now a thing that can:
* open and close sessions
* read and write protobuf messages
* enumerate and find devices

Some transports (all except bridge) are technically bytes-based and need
a separate protocol implementation (because we have two existing protocols,
although only the first one is actually used). Hence a protocol superclass.

Protocol is a thing that *also* can:
* open and close sessions
* read and write protobuf messages
For that, it requires a `handle`.

Handle is a physical layer for a protocol. It can:
* open and close some sort of device connection
  (this is distinct from session! Connection is a channel over which you can
  send data. Session is a logical arrangement on top of that; you can have
  multiple sessions on a single connection.)
* read and write 64-byte chunks of data

With that, we introduce ProtocolBasedTransport, which simply delegates
the appropriate Transport functionality to respective Protocol methods.

hid and webusb transports are ProtocolBasedTransport-s that provide separate
device handles. HidHandle and WebUsbHandle existed before, but the distinction
of functionality between a Transport and its Handle was unclear. Some methods
were moved and now the handles implement the Handle API, while the transports
provide the enumeration parts of the Transport API, as well as glue between
the respective Protocols and Handles.

udp transport is also a ProtocolBasedTransport, but it acts as its own handle.
(That might be changed. For now, I went with the pre-existing structure.)

In addition, session_begin/end is renamed to begin/end_session to keep
consistent verb_noun naming.
pull/25/head
matejcik 6 years ago
parent 560a5215c5
commit aac7726824

@ -5,3 +5,4 @@ click>=7,<8
pyblake2>=0.9.3
libusb1>=1.6.4
construct>=2.9
typing_extensions>=3.6

@ -25,10 +25,10 @@ from .tools import expect
class DebugLink:
def __init__(self, transport):
self.transport = transport
self.transport.session_begin()
self.transport.begin_session()
def close(self):
self.transport.session_end()
self.transport.end_session()
def _call(self, msg, nowait=False):
self.transport.write(msg)

@ -1,91 +0,0 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import struct
from io import BytesIO
from typing import Tuple
from . import mapping, protobuf
from .transport import Transport
REPLEN = 64
LOG = logging.getLogger(__name__)
class ProtocolV1:
def session_begin(self, transport: Transport) -> None:
pass
def session_end(self, transport: Transport) -> None:
pass
def write(self, transport: Transport, msg: protobuf.MessageType) -> None:
LOG.debug(
"sending message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
data = BytesIO()
protobuf.dump_message(data, msg)
ser = data.getvalue()
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
data = bytearray(b"##" + header + ser)
while data:
# Report ID, data padded to 63 bytes
chunk = b"?" + data[: REPLEN - 1]
chunk = chunk.ljust(REPLEN, b"\x00")
transport.write_chunk(chunk)
data = data[63:]
def read(self, transport: Transport) -> protobuf.MessageType:
# Read header with first part of message data
chunk = transport.read_chunk()
msg_type, datalen, data = self.parse_first(chunk)
# Read the rest of the message
while len(data) < datalen:
chunk = transport.read_chunk()
data.extend(self.parse_next(chunk))
# Strip padding
data = BytesIO(data[:datalen])
# Parse to protobuf
msg = protobuf.load_message(data, mapping.get_class(msg_type))
LOG.debug(
"received message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
return msg
def parse_first(self, chunk: bytes) -> Tuple[int, int, bytes]:
if chunk[:3] != b"?##":
raise RuntimeError("Unexpected magic characters")
try:
headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[3 + headerlen :]
return msg_type, datalen, data
def parse_next(self, chunk: bytes) -> bytes:
if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters")
return chunk[1:]

@ -1,147 +0,0 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import struct
from io import BytesIO
from typing import Tuple
from . import mapping, protobuf
from .transport import Transport
REPLEN = 64
LOG = logging.getLogger(__name__)
class ProtocolV2:
def __init__(self) -> None:
self.session = None
def session_begin(self, transport: Transport) -> None:
chunk = struct.pack(">B", 0x03)
chunk = chunk.ljust(REPLEN, b"\x00")
transport.write_chunk(chunk)
resp = transport.read_chunk()
self.session = self.parse_session_open(resp)
LOG.debug("[session {}] session started".format(self.session))
def session_end(self, transport: Transport) -> None:
if not self.session:
return
chunk = struct.pack(">BL", 0x04, self.session)
chunk = chunk.ljust(REPLEN, b"\x00")
transport.write_chunk(chunk)
resp = transport.read_chunk()
(magic,) = struct.unpack(">B", resp[:1])
if magic != 0x04:
raise RuntimeError("Expected session close")
LOG.debug("[session {}] session ended".format(self.session))
self.session = None
def write(self, transport: Transport, msg: protobuf.MessageType) -> None:
if not self.session:
raise RuntimeError("Missing session for v2 protocol")
LOG.debug(
"[session {}] sending message: {}".format(
self.session, msg.__class__.__name__
),
extra={"protobuf": msg},
)
# Serialize whole message
data = BytesIO()
protobuf.dump_message(data, msg)
data = data.getvalue()
dataheader = struct.pack(">LL", mapping.get_type(msg), len(data))
data = dataheader + data
seq = -1
# Write it out
while data:
if seq < 0:
repheader = struct.pack(">BL", 0x01, self.session)
else:
repheader = struct.pack(">BLL", 0x02, self.session, seq)
datalen = REPLEN - len(repheader)
chunk = repheader + data[:datalen]
chunk = chunk.ljust(REPLEN, b"\x00")
transport.write_chunk(chunk)
data = data[datalen:]
seq += 1
def read(self, transport: Transport) -> protobuf.MessageType:
if not self.session:
raise RuntimeError("Missing session for v2 protocol")
# Read header with first part of message data
chunk = transport.read_chunk()
msg_type, datalen, data = self.parse_first(chunk)
# Read the rest of the message
while len(data) < datalen:
chunk = transport.read_chunk()
next_data = self.parse_next(chunk)
data.extend(next_data)
# Strip padding
data = BytesIO(data[:datalen])
# Parse to protobuf
msg = protobuf.load_message(data, mapping.get_class(msg_type))
LOG.debug(
"[session {}] received message: {}".format(
self.session, msg.__class__.__name__
),
extra={"protobuf": msg},
)
return msg
def parse_first(self, chunk: bytes) -> Tuple[int, int, bytes]:
try:
headerlen = struct.calcsize(">BLLL")
magic, session, msg_type, datalen = struct.unpack(
">BLLL", chunk[:headerlen]
)
except Exception:
raise RuntimeError("Cannot parse header")
if magic != 0x01:
raise RuntimeError("Unexpected magic character")
if session != self.session:
raise RuntimeError("Session id mismatch")
return msg_type, datalen, chunk[headerlen:]
def parse_next(self, chunk: bytes) -> bytes:
try:
headerlen = struct.calcsize(">BLL")
magic, session, sequence = struct.unpack(">BLL", chunk[:headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
if magic != 0x02:
raise RuntimeError("Unexpected magic characters")
if session != self.session:
raise RuntimeError("Session id mismatch")
return chunk[headerlen:]
def parse_session_open(self, chunk: bytes) -> int:
try:
headerlen = struct.calcsize(">BL")
magic, session = struct.unpack(">BL", chunk[:headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
if magic != 0x03:
raise RuntimeError("Unexpected magic character")
return session

@ -41,10 +41,10 @@ class TrezorTest:
# self.client.set_buttonwait(3)
device.wipe(self.client)
self.client.transport.session_begin()
self.client.transport.begin_session()
def teardown_method(self, method):
self.client.transport.session_end()
self.client.transport.end_session()
self.client.close()
def _setup_mnemonic(self, mnemonic=None, pin="", passphrase=False):

@ -55,11 +55,11 @@ def client():
wirelink = get_device()
client = TrezorClientDebugLink(wirelink)
wipe_device(client)
client.transport.session_begin()
client.transport.begin_session()
yield client
client.transport.session_end()
client.transport.end_session()
# XXX debuglink session must also be closed
# client.close accomplishes that for now; going forward, there should

@ -222,14 +222,13 @@ def session(f):
# Decorator wraps a BaseClient method
# with session activation / deactivation
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
def wrapped_f(client, *args, **kwargs):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
client = args[0]
client.transport.session_begin()
client.transport.begin_session()
try:
return f(*args, **kwargs)
return f(client, *args, **kwargs)
finally:
client.transport.session_end()
client.transport.begin_session()
return wrapped_f

@ -18,6 +18,8 @@ import importlib
import logging
from typing import Iterable, Type
from ..protobuf import MessageType
LOG = logging.getLogger(__name__)
@ -25,38 +27,49 @@ class TransportException(Exception):
pass
class Transport(object):
def __init__(self):
self.session_counter = 0
class Transport:
"""Raw connection to a Trezor device.
Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB
or USB-HID connection, or UDP socket of listening emulator(s).
It can also enumerate devices available over this communication link, and return
them as instances.
Transport instance is a thing that:
- can be identified and requested by a string URI-like path
- can open and close sessions, which enclose related operations
- can read and write protobuf messages
You need to implement a new Transport subclass if you invent a new way to connect
a Trezor device to a computer.
"""
def __str__(self):
PATH_PREFIX = None # type: str
def __str__(self) -> str:
return self.get_path()
def get_path(self):
return "{}:{}".format(self.PATH_PREFIX, self.device)
def get_path(self) -> str:
raise NotImplementedError
def session_begin(self):
if self.session_counter == 0:
self.open()
self.session_counter += 1
def begin_session(self) -> None:
raise NotImplementedError
def session_end(self):
self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
self.close()
def end_session(self) -> None:
raise NotImplementedError
def open(self):
def read(self) -> MessageType:
raise NotImplementedError
def close(self):
def write(self, message: MessageType) -> None:
raise NotImplementedError
@classmethod
def enumerate(cls):
def enumerate(cls) -> Iterable["Transport"]:
raise NotImplementedError
@classmethod
def find_by_path(cls, path, prefix_search=False):
def find_by_path(cls, path: str, prefix_search: bool = False) -> "Transport":
for device in cls.enumerate():
if (
path is None

@ -14,9 +14,11 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import binascii
import logging
import struct
from io import BytesIO
from typing import Any, Dict, Iterable
import requests
@ -28,6 +30,10 @@ LOG = logging.getLogger(__name__)
TREZORD_HOST = "http://127.0.0.1:21325"
def get_error(resp: requests.Response) -> str:
return " (error=%d str=%s)" % (resp.status_code, resp.json()["error"])
class BridgeTransport(Transport):
"""
BridgeTransport implements transport through TREZOR Bridge (aka trezord).
@ -36,91 +42,81 @@ class BridgeTransport(Transport):
PATH_PREFIX = "bridge"
HEADERS = {"Origin": "https://python.trezor.io"}
def __init__(self, device):
super().__init__()
def __init__(self, device: Dict[str, Any]) -> None:
self.device = device
self.conn = requests.Session()
self.session = None
self.request = None
self.session = None # type: Optional[str]
self.response = None # type: Optional[str]
def get_path(self):
def get_path(self) -> str:
return "%s:%s" % (self.PATH_PREFIX, self.device["path"])
@classmethod
def _call(cls, action, data=None, uri_suffix=None, session=None):
if uri_suffix is not None:
uri_suffix = "/" + uri_suffix
elif session is not None:
uri_suffix = "/{}".format(session)
else:
uri_suffix = ""
url = "{}/{}{}".format(TREZORD_HOST, action, uri_suffix)
r = requests.post(url, headers=cls.HEADERS, data=data)
if r.status_code != 200:
raise TransportException(
"trezord: '{}' action failed with code {}: {}".format(
action, r.status_code, r.json().get("error", "(no error message)")
)
)
return r
@classmethod
def enumerate(cls):
def enumerate(cls) -> Iterable["BridgeTransport"]:
try:
r = cls._call("enumerate")
r = requests.post(TREZORD_HOST + "/enumerate", headers=cls.HEADERS)
if r.status_code != 200:
raise TransportException(
"trezord: Could not enumerate devices" + get_error(r)
)
return [BridgeTransport(dev) for dev in r.json()]
except Exception:
return []
def open(self):
r = self._call("acquire", uri_suffix="{}/null".format(self.device["path"]))
def begin_session(self) -> None:
r = self.conn.post(
TREZORD_HOST + "/acquire/%s/null" % self.device["path"],
headers=self.HEADERS,
)
if r.status_code != 200:
raise TransportException(
"trezord: Could not acquire session" + get_error(r)
)
self.session = r.json()["session"]
def close(self):
def end_session(self) -> None:
if not self.session:
return
self._call("release", session=self.session)
r = self.conn.post(
TREZORD_HOST + "/release/%s" % self.session, headers=self.HEADERS
)
if r.status_code != 200:
raise TransportException(
"trezord: Could not release session" + get_error(r)
)
self.session = None
def write(self, msg):
if self.request is not None:
raise TransportException("trezord can't perform two writes without a read")
def write(self, msg: protobuf.MessageType) -> None:
LOG.debug(
"preparing message: {}".format(msg.__class__.__name__),
"sending message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
# encode the message
data = BytesIO()
protobuf.dump_message(data, msg)
ser = data.getvalue()
buffer = BytesIO()
protobuf.dump_message(buffer, msg)
ser = buffer.getvalue()
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
# store for later
self.request = (header + ser).hex()
def read(self):
if self.request is None:
raise TransportException("trezord: no request in queue")
try:
LOG.debug("sending prepared message")
r = self._call("call", data=self.request, session=self.session)
data = bytes.fromhex(r.text)
headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
data = BytesIO(data[headerlen : headerlen + datalen])
msg = protobuf.load_message(data, mapping.get_class(msg_type))
LOG.debug(
"received message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
return msg
finally:
self.request = None
data = binascii.hexlify(header + ser).decode()
r = self.conn.post( # type: ignore # typeshed bug
TREZORD_HOST + "/call/%s" % self.session, data=data, headers=self.HEADERS
)
if r.status_code != 200:
raise TransportException("trezord: Could not write message" + get_error(r))
self.response = r.text
def read(self) -> protobuf.MessageType:
if self.response is None:
raise TransportException("No response stored")
data = binascii.unhexlify(self.response)
headerlen = struct.calcsize(">HL")
(msg_type, datalen) = struct.unpack(">HL", data[:headerlen])
buffer = BytesIO(data[headerlen : headerlen + datalen])
msg = protobuf.load_message(buffer, mapping.get_class(msg_type))
LOG.debug(
"received message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
self.response = None
return msg
TRANSPORT = BridgeTransport

@ -16,77 +16,99 @@
import sys
import time
from typing import Any, Dict, Iterable
import hid
from . import Transport, TransportException
from ..protocol_v1 import ProtocolV1
from ..protocol_v2 import ProtocolV2
from . import TransportException
from .protocol import ProtocolBasedTransport, get_protocol
DEV_TREZOR1 = (0x534C, 0x0001)
DEV_TREZOR2 = (0x1209, 0x53C1)
DEV_TREZOR2_BL = (0x1209, 0x53C0)
HidDevice = Dict[str, Any]
HidDeviceHandle = Any
class HidHandle:
def __init__(self, path):
def __init__(self, path: str, probe_hid_version: bool = False) -> None:
self.path = path
self.count = 0
self.handle = None
self.handle = None # type: HidDeviceHandle
self.hid_version = None if probe_hid_version else 2
def open(self) -> None:
self.handle = hid.device()
try:
self.handle.open_path(self.path)
except (IOError, OSError) as e:
if sys.platform.startswith("linux"):
e.args = e.args + (
"Do you have udev rules installed? https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules",
)
raise e
self.handle.set_nonblocking(True)
if self.hid_version is None:
self.hid_version = self.probe_hid_version()
def open(self):
if self.count == 0:
self.handle = hid.device()
try:
self.handle.open_path(self.path)
except (IOError, OSError) as e:
if sys.platform.startswith("linux"):
e.args = e.args + (
"Do you have udev rules installed? https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules",
)
raise e
self.handle.set_nonblocking(True)
self.count += 1
def close(self):
if self.count == 1:
def close(self) -> None:
if self.handle is not None:
self.handle.close()
if self.count > 0:
self.count -= 1
self.handle = None
def write_chunk(self, chunk: bytes) -> None:
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
if self.hid_version == 2:
self.handle.write(b"\0" + bytearray(chunk))
else:
self.handle.write(chunk)
def read_chunk(self) -> bytes:
while True:
chunk = self.handle.read(64)
if chunk:
break
else:
time.sleep(0.001)
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
return chunk
def probe_hid_version(self) -> int:
n = self.handle.write([0, 63] + [0xFF] * 63)
if n == 65:
return 2
n = self.handle.write([63] + [0xFF] * 63)
if n == 64:
return 1
raise TransportException("Unknown HID version")
class HidTransport(Transport):
class HidTransport(ProtocolBasedTransport):
"""
HidTransport implements transport over USB HID interface.
"""
PATH_PREFIX = "hid"
def __init__(self, device, protocol=None, hid_handle=None):
super(HidTransport, self).__init__()
def __init__(self, device: HidDevice, hid_handle: HidHandle = None) -> None:
if hid_handle is None:
hid_handle = HidHandle(device["path"])
if protocol is None:
# force_v1 = os.environ.get('TREZOR_TRANSPORT_V1', '0')
force_v1 = True
if is_trezor2(device) and not int(force_v1):
protocol = ProtocolV2()
else:
protocol = ProtocolV1()
self.device = device
self.protocol = protocol
self.hid = hid_handle
self.hid_version = None
def get_path(self):
protocol = get_protocol(hid_handle, is_trezor2(device))
super().__init__(protocol=protocol)
def get_path(self) -> str:
return "%s:%s" % (self.PATH_PREFIX, self.device["path"].decode())
@staticmethod
def enumerate(debug=False):
@classmethod
def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]:
devices = []
for dev in hid.enumerate(0, 0):
if not (is_trezor1(dev) or is_trezor2(dev) or is_trezor2_bl(dev)):
@ -100,84 +122,35 @@ class HidTransport(Transport):
devices.append(HidTransport(dev))
return devices
def find_debug(self):
if isinstance(self.protocol, ProtocolV2):
# For v2 protocol, lets use the same HID interface, but with a different session
protocol = ProtocolV2()
debug = HidTransport(self.device, protocol, self.hid)
return debug
if isinstance(self.protocol, ProtocolV1):
def find_debug(self) -> "HidTransport":
if self.protocol.VERSION >= 2:
# use the same device
return self
else:
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def open(self):
self.hid.open()
if is_trezor1(self.device):
self.hid_version = self.probe_hid_version()
else:
self.hid_version = 2
self.protocol.session_begin(self)
def close(self):
self.protocol.session_end(self)
self.hid.close()
self.hid_version = None
def read(self):
return self.protocol.read(self)
def write(self, msg):
return self.protocol.write(self, msg)
def write_chunk(self, chunk):
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
if self.hid_version == 2:
self.hid.handle.write(b"\0" + bytearray(chunk))
else:
self.hid.handle.write(chunk)
def read_chunk(self):
while True:
chunk = self.hid.handle.read(64)
if chunk:
break
else:
time.sleep(0.001)
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
return bytearray(chunk)
def probe_hid_version(self):
n = self.hid.handle.write([0, 63] + [0xFF] * 63)
if n == 65:
return 2
n = self.hid.handle.write([63] + [0xFF] * 63)
if n == 64:
return 1
raise TransportException("Unknown HID version")
raise TransportException("Debug HID device not found")
def is_trezor1(dev):
def is_trezor1(dev: HidDevice) -> bool:
return (dev["vendor_id"], dev["product_id"]) == DEV_TREZOR1
def is_trezor2(dev):
def is_trezor2(dev: HidDevice) -> bool:
return (dev["vendor_id"], dev["product_id"]) == DEV_TREZOR2
def is_trezor2_bl(dev):
def is_trezor2_bl(dev: HidDevice) -> bool:
return (dev["vendor_id"], dev["product_id"]) == DEV_TREZOR2_BL
def is_wirelink(dev):
def is_wirelink(dev: HidDevice) -> bool:
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0
def is_debuglink(dev):
def is_debuglink(dev: HidDevice) -> bool:
return dev["usage_page"] == 0xFF01 or dev["interface_number"] == 1

@ -0,0 +1,348 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import os
import struct
from io import BytesIO
from typing import Tuple
from typing_extensions import Protocol as StructuralType
from . import Transport
from .. import mapping, protobuf
REPLEN = 64
V2_FIRST_CHUNK = 0x01
V2_NEXT_CHUNK = 0x02
V2_BEGIN_SESSION = 0x03
V2_END_SESSION = 0x04
LOG = logging.getLogger(__name__)
class Handle(StructuralType):
"""PEP 544 structural type for Handle functionality.
(called a "Protocol" in the proposed PEP, name which is impractical here)
Handle is a "physical" layer for a protocol.
It can open/close a connection and read/write bare data in 64-byte chunks.
Functionally we gain nothing from making this an (abstract) base class for handle
implementations, so this definition is for type hinting purposes only. You can,
but don't have to, inherit from it.
"""
def open(self) -> None:
...
def close(self) -> None:
...
def read_chunk(self) -> bytes:
...
def write_chunk(self, chunk: bytes) -> None:
...
class Protocol:
"""Wire protocol that can communicate with a Trezor device, given a Handle.
A Protocol implements the part of the Transport API that relates to communicating
logical messages over a physical layer. It is a thing that can:
- open and close sessions,
- send and receive protobuf messages,
given the ability to:
- open and close physical connections,
- and send and receive binary chunks.
We declare a protocol version (we have implementations of v1 and v2).
For now, the class also handles session counting and opening the underlying Handle.
This will probably be removed in the future.
We will need a new Protocol class if we change the way a Trezor device encapsulates
its messages.
"""
VERSION = None # type: int
def __init__(self, handle: Handle) -> None:
self.handle = handle
self.session_counter = 0
def begin_session(self) -> None:
if self.session_counter == 0:
self.handle.open()
self.session_counter += 1
def end_session(self) -> None:
self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
self.handle.close()
def read(self) -> protobuf.MessageType:
raise NotImplementedError
def write(self, message: protobuf.MessageType) -> None:
raise NotImplementedError
class ProtocolBasedTransport(Transport):
"""Transport that implements its communications through a Protocol.
Intended as a base class for implementations that proxy their communication
operations to a Protocol.
"""
def __init__(self, protocol: Protocol) -> None:
self.protocol = protocol
def write(self, message: protobuf.MessageType) -> None:
self.protocol.write(message)
def read(self) -> protobuf.MessageType:
return self.protocol.read()
def begin_session(self) -> None:
self.protocol.begin_session()
def end_session(self) -> None:
self.protocol.end_session()
class ProtocolV1(Protocol):
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
Does not understand sessions.
"""
VERSION = 1
def write(self, msg: protobuf.MessageType) -> None:
LOG.debug(
"sending message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
data = BytesIO()
protobuf.dump_message(data, msg)
ser = data.getvalue()
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
buffer = bytearray(b"##" + header + ser)
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: REPLEN - 1]
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
buffer = buffer[63:]
def read(self) -> protobuf.MessageType:
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
# Strip padding
data = BytesIO(buffer[:datalen])
# Parse to protobuf
msg = protobuf.load_message(data, mapping.get_class(msg_type))
LOG.debug(
"received message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
return msg
def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError("Unexpected magic characters")
try:
headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[3 + headerlen :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.handle.read_chunk()
if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters")
return chunk[1:]
class ProtocolV2(Protocol):
"""Protocol version 2. Currently (11/2018) not used.
Intended to mimic U2F/WebAuthN session handling.
"""
VERSION = 2
def __init__(self, handle: Handle) -> None:
self.session = None
super().__init__(handle)
def begin_session(self) -> None:
# ensure open connection
super().begin_session()
# initiate session
chunk = struct.pack(">B", V2_BEGIN_SESSION)
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
# get session identifier
resp = self.handle.read_chunk()
try:
headerlen = struct.calcsize(">BL")
magic, session = struct.unpack(">BL", resp[:headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
if magic != V2_BEGIN_SESSION:
raise RuntimeError("Unexpected magic character")
self.session = session
LOG.debug("[session {}] session started".format(self.session))
def end_session(self) -> None:
if not self.session:
return
try:
chunk = struct.pack(">BL", V2_END_SESSION, self.session)
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
resp = self.handle.read_chunk()
(magic,) = struct.unpack(">B", resp[:1])
if magic != V2_END_SESSION:
raise RuntimeError("Expected session close")
LOG.debug("[session {}] session ended".format(self.session))
finally:
self.session = None
# close connection if appropriate
super().end_session()
def write(self, msg: protobuf.MessageType) -> None:
if not self.session:
raise RuntimeError("Missing session for v2 protocol")
LOG.debug(
"[session {}] sending message: {}".format(
self.session, msg.__class__.__name__
),
extra={"protobuf": msg},
)
# Serialize whole message
data = BytesIO()
protobuf.dump_message(data, msg)
data = data.getvalue()
dataheader = struct.pack(">LL", mapping.get_type(msg), len(data))
data = dataheader + data
seq = -1
# Write it out
while data:
if seq < 0:
repheader = struct.pack(">BL", V2_FIRST_CHUNK, self.session)
else:
repheader = struct.pack(">BLL", V2_NEXT_CHUNK, self.session, seq)
datalen = REPLEN - len(repheader)
chunk = repheader + data[:datalen]
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
data = data[datalen:]
seq += 1
def read(self) -> protobuf.MessageType:
if not self.session:
raise RuntimeError("Missing session for v2 protocol")
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, chunk = self.read_first()
buffer.extend(chunk)
# Read the rest of the message
while len(buffer) < datalen:
next_chunk = self.read_next()
buffer.extend(next_chunk)
# Strip padding
buffer = BytesIO(buffer[:datalen])
# Parse to protobuf
msg = protobuf.load_message(buffer, mapping.get_class(msg_type))
LOG.debug(
"[session {}] received message: {}".format(
self.session, msg.__class__.__name__
),
extra={"protobuf": msg},
)
return msg
def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
try:
headerlen = struct.calcsize(">BLLL")
magic, session, msg_type, datalen = struct.unpack(
">BLLL", chunk[:headerlen]
)
except Exception:
raise RuntimeError("Cannot parse header")
if magic != V2_FIRST_CHUNK:
raise RuntimeError("Unexpected magic character")
if session != self.session:
raise RuntimeError("Session id mismatch")
return msg_type, datalen, chunk[headerlen:]
def read_next(self) -> bytes:
chunk = self.handle.read_chunk()
try:
headerlen = struct.calcsize(">BLL")
magic, session, sequence = struct.unpack(">BLL", chunk[:headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
if magic != V2_NEXT_CHUNK:
raise RuntimeError("Unexpected magic characters")
if session != self.session:
raise RuntimeError("Session id mismatch")
return chunk[headerlen:]
def get_protocol(handle: Handle, want_v2: bool) -> Protocol:
"""Make a Protocol instance for the given handle.
Each transport can have a preference for using a particular protocol version.
This preference is overridable through `TREZOR_PROTOCOL_V1` environment variable,
which forces the library to use V1 anyways.
As of 11/2018, no devices support V2, so we enforce V1 here. It is still possible
to set `TREZOR_PROTOCOL_V1=0` and thus enable V2 protocol for transports that ask
for it (i.e., USB transports for Trezor T).
"""
force_v1 = int(os.environ.get("TREZOR_PROTOCOL_V1", 1))
if want_v2 and not force_v1:
return ProtocolV2(handle)
else:
return ProtocolV1(handle)

@ -15,20 +15,19 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import socket
from typing import Iterable, cast
from . import Transport, TransportException
from ..protocol_v1 import ProtocolV1
from . import TransportException
from .protocol import ProtocolBasedTransport, get_protocol
class UdpTransport(Transport):
class UdpTransport(ProtocolBasedTransport):
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 21324
PATH_PREFIX = "udp"
def __init__(self, device=None, protocol=None):
super(UdpTransport, self).__init__()
def __init__(self, device: str = None) -> None:
if not device:
host = UdpTransport.DEFAULT_HOST
port = UdpTransport.DEFAULT_PORT
@ -36,21 +35,21 @@ class UdpTransport(Transport):
devparts = device.split(":")
host = devparts[0]
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
if not protocol:
protocol = ProtocolV1()
self.device = (host, port)
self.protocol = protocol
self.socket = None
self.socket = None # type: Optional[socket.socket]
protocol = get_protocol(self, want_v2=False)
super().__init__(protocol=protocol)
def get_path(self):
return "%s:%s:%s" % ((self.PATH_PREFIX,) + self.device)
def get_path(self) -> str:
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
def find_debug(self):
def find_debug(self) -> "UdpTransport":
host, port = self.device
return UdpTransport("{}:{}".format(host, port + 1), self.protocol)
return UdpTransport("{}:{}".format(host, port + 1))
@classmethod
def _try_path(cls, path):
def _try_path(cls, path: str) -> "UdpTransport":
d = cls(path)
try:
d.open()
@ -64,7 +63,7 @@ class UdpTransport(Transport):
d.close()
@classmethod
def enumerate(cls):
def enumerate(cls) -> Iterable["UdpTransport"]:
default_path = "{}:{}".format(cls.DEFAULT_HOST, cls.DEFAULT_PORT)
try:
return [cls._try_path(default_path)]
@ -72,27 +71,29 @@ class UdpTransport(Transport):
return []
@classmethod
def find_by_path(cls, path, prefix_search=False):
def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport":
if prefix_search:
return super().find_by_path(path, prefix_search)
return cast(UdpTransport, super().find_by_path(path, prefix_search))
# This is *technically* type-able: mark `find_by_path` as returning
# the same type from which `cls` comes from.
# Mypy can't handle that though, so here we are.
else:
path = path.replace("{}:".format(cls.PATH_PREFIX), "")
return cls._try_path(path)
def open(self):
def open(self) -> None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.connect(self.device)
self.socket.settimeout(10)
self.protocol.session_begin(self)
def close(self):
if self.socket:
self.protocol.session_end(self)
def close(self) -> None:
if self.socket is not None:
self.socket.close()
self.socket = None
self.socket = None
def _ping(self):
def _ping(self) -> bool:
"""Test if the device is listening."""
assert self.socket is not None
resp = None
try:
self.socket.sendall(b"PINGPING")
@ -101,18 +102,14 @@ class UdpTransport(Transport):
pass
return resp == b"PONGPONG"
def read(self):
return self.protocol.read(self)
def write(self, msg):
return self.protocol.write(self, msg)
def write_chunk(self, chunk):
def write_chunk(self, chunk: bytes) -> None:
assert self.socket is not None
if len(chunk) != 64:
raise TransportException("Unexpected data length")
self.socket.sendall(chunk)
def read_chunk(self):
def read_chunk(self) -> bytes:
assert self.socket is not None
while True:
try:
chunk = self.socket.recv(64)

@ -17,12 +17,12 @@
import atexit
import sys
import time
from typing import Iterable
import usb1
from . import Transport, TransportException
from ..protocol_v1 import ProtocolV1
from ..protocol_v2 import ProtocolV2
from . import TransportException
from .protocol import ProtocolBasedTransport, get_protocol
DEV_TREZOR1 = (0x534C, 0x0001)
DEV_TREZOR2 = (0x1209, 0x53C1)
@ -35,34 +35,52 @@ DEBUG_ENDPOINT = 2
class WebUsbHandle:
def __init__(self, device):
def __init__(self, device: usb1.USBDevice, debug: bool = False) -> None:
self.device = device
self.interface = DEBUG_INTERFACE if debug else INTERFACE
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
self.count = 0
self.handle = None
self.handle = None # type: Optional[usb1.USBDeviceHandle]
def open(self) -> None:
self.handle = self.device.open()
if self.handle is None:
if sys.platform.startswith("linux"):
args = (
"Do you have udev rules installed? https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules",
)
else:
args = ()
raise IOError("Cannot open device", *args)
self.handle.claimInterface(self.interface)
def open(self, interface):
if self.count == 0:
self.handle = self.device.open()
if self.handle is None:
if sys.platform.startswith("linux"):
args = (
"Do you have udev rules installed? https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules",
)
else:
args = ()
raise IOError("Cannot open device", *args)
self.handle.claimInterface(interface)
self.count += 1
def close(self, interface):
if self.count == 1:
self.handle.releaseInterface(interface)
def close(self) -> None:
if self.handle is not None:
self.handle.releaseInterface(self.interface)
self.handle.close()
if self.count > 0:
self.count -= 1
self.handle = None
def write_chunk(self, chunk: bytes) -> None:
assert self.handle is not None
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
self.handle.interruptWrite(self.endpoint, chunk)
class WebUsbTransport(Transport):
def read_chunk(self) -> bytes:
assert self.handle is not None
endpoint = 0x80 | self.endpoint
while True:
chunk = self.handle.interruptRead(endpoint, 64)
if chunk:
break
else:
time.sleep(0.001)
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
return chunk
class WebUsbTransport(ProtocolBasedTransport):
"""
WebUsbTransport implements transport over WebUSB interface.
"""
@ -70,31 +88,24 @@ class WebUsbTransport(Transport):
PATH_PREFIX = "webusb"
context = None
def __init__(self, device, protocol=None, handle=None, debug=False):
super(WebUsbTransport, self).__init__()
def __init__(
self, device: str, handle: WebUsbHandle = None, debug: bool = False
) -> None:
if handle is None:
handle = WebUsbHandle(device)
if protocol is None:
# force_v1 = os.environ.get('TREZOR_TRANSPORT_V1', '0')
force_v1 = True
if is_trezor2(device) and not int(force_v1):
protocol = ProtocolV2()
else:
protocol = ProtocolV1()
handle = WebUsbHandle(device, debug)
self.device = device
self.protocol = protocol
self.handle = handle
self.debug = debug
def get_path(self):
protocol = get_protocol(handle, is_trezor2(device))
super().__init__(protocol=protocol)
def get_path(self) -> str:
return "%s:%s" % (self.PATH_PREFIX, dev_to_str(self.device))
@classmethod
def enumerate(cls):
def enumerate(cls) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
@ -117,69 +128,30 @@ class WebUsbTransport(Transport):
pass
return devices
def find_debug(self):
if isinstance(self.protocol, ProtocolV2):
def find_debug(self) -> "WebUsbTransport":
if self.protocol.VERSION >= 2:
# TODO test this
# For v2 protocol, lets use the same WebUSB interface, but with a different session
protocol = ProtocolV2()
debug = WebUsbTransport(self.device, protocol, self.handle)
return debug
if isinstance(self.protocol, ProtocolV1):
# XXX this is broken right now because sessions don't really work
# For v2 protocol, use the same WebUSB interface with a different session
return WebUsbTransport(self.device, self.handle)
else:
# For v1 protocol, find debug USB interface for the same serial number
protocol = ProtocolV1()
debug = WebUsbTransport(self.device, protocol, None, True)
return debug
raise TransportException("Debug WebUSB device not found")
def open(self):
interface = DEBUG_INTERFACE if self.debug else INTERFACE
self.handle.open(interface)
self.protocol.session_begin(self)
def close(self):
interface = DEBUG_INTERFACE if self.debug else INTERFACE
self.protocol.session_end(self)
self.handle.close(interface)
def read(self):
return self.protocol.read(self)
def write(self, msg):
return self.protocol.write(self, msg)
def write_chunk(self, chunk):
endpoint = DEBUG_ENDPOINT if self.debug else ENDPOINT
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
self.handle.handle.interruptWrite(endpoint, chunk)
def read_chunk(self):
endpoint = DEBUG_ENDPOINT if self.debug else ENDPOINT
endpoint = 0x80 | endpoint
while True:
chunk = self.handle.handle.interruptRead(endpoint, 64)
if chunk:
break
else:
time.sleep(0.001)
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
return bytearray(chunk)
return WebUsbTransport(self.device, debug=True)
def is_trezor1(dev):
def is_trezor1(dev: usb1.USBDevice) -> bool:
return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR1
def is_trezor2(dev):
def is_trezor2(dev: usb1.USBDevice) -> bool:
return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR2
def is_trezor2_bl(dev):
def is_trezor2_bl(dev: usb1.USBDevice) -> bool:
return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR2_BL
def is_vendor_class(dev):
def is_vendor_class(dev: usb1.USBDevice) -> bool:
configurationId = 0
altSettingId = 0
return (
@ -188,7 +160,7 @@ def is_vendor_class(dev):
)
def dev_to_str(dev):
def dev_to_str(dev: usb1.USBDevice) -> str:
return ":".join(
str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList()
)

Loading…
Cancel
Save