1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-22 23:48:12 +00:00

trezorlib: transport/protocol reshuffle

This commit breaks session handling (which matters with Bridge) and
regresses Bridge to an older code state. Both of these issues will be
rectified in subsequent commits.

Explanation of this big API reshuffle follows:

* protocols are moved to trezorlib.transport, and to a single common file.
* there is a cleaner definition of Transport and Protocol API (see below)
* fully valid mypy type hinting
* session handle counters and open handle counters mostly went away. Transports
  and Protocols are meant to be "raw" APIs; TrezorClient will implement
  context-handler-based sessions, session tracking, etc.

I'm calling this a "reshuffle" because it involved very small number of
code changes. Most of it is moving things around where they sit better.

The API changes are as follows.

Transport is now a thing that can:
* open and close sessions
* read and write protobuf messages
* enumerate and find devices

Some transports (all except bridge) are technically bytes-based and need
a separate protocol implementation (because we have two existing protocols,
although only the first one is actually used). Hence a protocol superclass.

Protocol is a thing that *also* can:
* open and close sessions
* read and write protobuf messages
For that, it requires a `handle`.

Handle is a physical layer for a protocol. It can:
* open and close some sort of device connection
  (this is distinct from session! Connection is a channel over which you can
  send data. Session is a logical arrangement on top of that; you can have
  multiple sessions on a single connection.)
* read and write 64-byte chunks of data

With that, we introduce ProtocolBasedTransport, which simply delegates
the appropriate Transport functionality to respective Protocol methods.

hid and webusb transports are ProtocolBasedTransport-s that provide separate
device handles. HidHandle and WebUsbHandle existed before, but the distinction
of functionality between a Transport and its Handle was unclear. Some methods
were moved and now the handles implement the Handle API, while the transports
provide the enumeration parts of the Transport API, as well as glue between
the respective Protocols and Handles.

udp transport is also a ProtocolBasedTransport, but it acts as its own handle.
(That might be changed. For now, I went with the pre-existing structure.)

In addition, session_begin/end is renamed to begin/end_session to keep
consistent verb_noun naming.
This commit is contained in:
matejcik 2018-11-08 15:24:28 +01:00
parent 560a5215c5
commit aac7726824
13 changed files with 619 additions and 558 deletions

View File

@ -5,3 +5,4 @@ click>=7,<8
pyblake2>=0.9.3 pyblake2>=0.9.3
libusb1>=1.6.4 libusb1>=1.6.4
construct>=2.9 construct>=2.9
typing_extensions>=3.6

View File

@ -25,10 +25,10 @@ from .tools import expect
class DebugLink: class DebugLink:
def __init__(self, transport): def __init__(self, transport):
self.transport = transport self.transport = transport
self.transport.session_begin() self.transport.begin_session()
def close(self): def close(self):
self.transport.session_end() self.transport.end_session()
def _call(self, msg, nowait=False): def _call(self, msg, nowait=False):
self.transport.write(msg) self.transport.write(msg)

View File

@ -1,91 +0,0 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import struct
from io import BytesIO
from typing import Tuple
from . import mapping, protobuf
from .transport import Transport
REPLEN = 64
LOG = logging.getLogger(__name__)
class ProtocolV1:
def session_begin(self, transport: Transport) -> None:
pass
def session_end(self, transport: Transport) -> None:
pass
def write(self, transport: Transport, msg: protobuf.MessageType) -> None:
LOG.debug(
"sending message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
data = BytesIO()
protobuf.dump_message(data, msg)
ser = data.getvalue()
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
data = bytearray(b"##" + header + ser)
while data:
# Report ID, data padded to 63 bytes
chunk = b"?" + data[: REPLEN - 1]
chunk = chunk.ljust(REPLEN, b"\x00")
transport.write_chunk(chunk)
data = data[63:]
def read(self, transport: Transport) -> protobuf.MessageType:
# Read header with first part of message data
chunk = transport.read_chunk()
msg_type, datalen, data = self.parse_first(chunk)
# Read the rest of the message
while len(data) < datalen:
chunk = transport.read_chunk()
data.extend(self.parse_next(chunk))
# Strip padding
data = BytesIO(data[:datalen])
# Parse to protobuf
msg = protobuf.load_message(data, mapping.get_class(msg_type))
LOG.debug(
"received message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
return msg
def parse_first(self, chunk: bytes) -> Tuple[int, int, bytes]:
if chunk[:3] != b"?##":
raise RuntimeError("Unexpected magic characters")
try:
headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[3 + headerlen :]
return msg_type, datalen, data
def parse_next(self, chunk: bytes) -> bytes:
if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters")
return chunk[1:]

View File

@ -1,147 +0,0 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import struct
from io import BytesIO
from typing import Tuple
from . import mapping, protobuf
from .transport import Transport
REPLEN = 64
LOG = logging.getLogger(__name__)
class ProtocolV2:
def __init__(self) -> None:
self.session = None
def session_begin(self, transport: Transport) -> None:
chunk = struct.pack(">B", 0x03)
chunk = chunk.ljust(REPLEN, b"\x00")
transport.write_chunk(chunk)
resp = transport.read_chunk()
self.session = self.parse_session_open(resp)
LOG.debug("[session {}] session started".format(self.session))
def session_end(self, transport: Transport) -> None:
if not self.session:
return
chunk = struct.pack(">BL", 0x04, self.session)
chunk = chunk.ljust(REPLEN, b"\x00")
transport.write_chunk(chunk)
resp = transport.read_chunk()
(magic,) = struct.unpack(">B", resp[:1])
if magic != 0x04:
raise RuntimeError("Expected session close")
LOG.debug("[session {}] session ended".format(self.session))
self.session = None
def write(self, transport: Transport, msg: protobuf.MessageType) -> None:
if not self.session:
raise RuntimeError("Missing session for v2 protocol")
LOG.debug(
"[session {}] sending message: {}".format(
self.session, msg.__class__.__name__
),
extra={"protobuf": msg},
)
# Serialize whole message
data = BytesIO()
protobuf.dump_message(data, msg)
data = data.getvalue()
dataheader = struct.pack(">LL", mapping.get_type(msg), len(data))
data = dataheader + data
seq = -1
# Write it out
while data:
if seq < 0:
repheader = struct.pack(">BL", 0x01, self.session)
else:
repheader = struct.pack(">BLL", 0x02, self.session, seq)
datalen = REPLEN - len(repheader)
chunk = repheader + data[:datalen]
chunk = chunk.ljust(REPLEN, b"\x00")
transport.write_chunk(chunk)
data = data[datalen:]
seq += 1
def read(self, transport: Transport) -> protobuf.MessageType:
if not self.session:
raise RuntimeError("Missing session for v2 protocol")
# Read header with first part of message data
chunk = transport.read_chunk()
msg_type, datalen, data = self.parse_first(chunk)
# Read the rest of the message
while len(data) < datalen:
chunk = transport.read_chunk()
next_data = self.parse_next(chunk)
data.extend(next_data)
# Strip padding
data = BytesIO(data[:datalen])
# Parse to protobuf
msg = protobuf.load_message(data, mapping.get_class(msg_type))
LOG.debug(
"[session {}] received message: {}".format(
self.session, msg.__class__.__name__
),
extra={"protobuf": msg},
)
return msg
def parse_first(self, chunk: bytes) -> Tuple[int, int, bytes]:
try:
headerlen = struct.calcsize(">BLLL")
magic, session, msg_type, datalen = struct.unpack(
">BLLL", chunk[:headerlen]
)
except Exception:
raise RuntimeError("Cannot parse header")
if magic != 0x01:
raise RuntimeError("Unexpected magic character")
if session != self.session:
raise RuntimeError("Session id mismatch")
return msg_type, datalen, chunk[headerlen:]
def parse_next(self, chunk: bytes) -> bytes:
try:
headerlen = struct.calcsize(">BLL")
magic, session, sequence = struct.unpack(">BLL", chunk[:headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
if magic != 0x02:
raise RuntimeError("Unexpected magic characters")
if session != self.session:
raise RuntimeError("Session id mismatch")
return chunk[headerlen:]
def parse_session_open(self, chunk: bytes) -> int:
try:
headerlen = struct.calcsize(">BL")
magic, session = struct.unpack(">BL", chunk[:headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
if magic != 0x03:
raise RuntimeError("Unexpected magic character")
return session

View File

@ -41,10 +41,10 @@ class TrezorTest:
# self.client.set_buttonwait(3) # self.client.set_buttonwait(3)
device.wipe(self.client) device.wipe(self.client)
self.client.transport.session_begin() self.client.transport.begin_session()
def teardown_method(self, method): def teardown_method(self, method):
self.client.transport.session_end() self.client.transport.end_session()
self.client.close() self.client.close()
def _setup_mnemonic(self, mnemonic=None, pin="", passphrase=False): def _setup_mnemonic(self, mnemonic=None, pin="", passphrase=False):

View File

@ -55,11 +55,11 @@ def client():
wirelink = get_device() wirelink = get_device()
client = TrezorClientDebugLink(wirelink) client = TrezorClientDebugLink(wirelink)
wipe_device(client) wipe_device(client)
client.transport.session_begin() client.transport.begin_session()
yield client yield client
client.transport.session_end() client.transport.end_session()
# XXX debuglink session must also be closed # XXX debuglink session must also be closed
# client.close accomplishes that for now; going forward, there should # client.close accomplishes that for now; going forward, there should

View File

@ -222,14 +222,13 @@ def session(f):
# Decorator wraps a BaseClient method # Decorator wraps a BaseClient method
# with session activation / deactivation # with session activation / deactivation
@functools.wraps(f) @functools.wraps(f)
def wrapped_f(*args, **kwargs): def wrapped_f(client, *args, **kwargs):
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
client = args[0] client.transport.begin_session()
client.transport.session_begin()
try: try:
return f(*args, **kwargs) return f(client, *args, **kwargs)
finally: finally:
client.transport.session_end() client.transport.begin_session()
return wrapped_f return wrapped_f

View File

@ -18,6 +18,8 @@ import importlib
import logging import logging
from typing import Iterable, Type from typing import Iterable, Type
from ..protobuf import MessageType
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -25,38 +27,49 @@ class TransportException(Exception):
pass pass
class Transport(object): class Transport:
def __init__(self): """Raw connection to a Trezor device.
self.session_counter = 0
def __str__(self): Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB
or USB-HID connection, or UDP socket of listening emulator(s).
It can also enumerate devices available over this communication link, and return
them as instances.
Transport instance is a thing that:
- can be identified and requested by a string URI-like path
- can open and close sessions, which enclose related operations
- can read and write protobuf messages
You need to implement a new Transport subclass if you invent a new way to connect
a Trezor device to a computer.
"""
PATH_PREFIX = None # type: str
def __str__(self) -> str:
return self.get_path() return self.get_path()
def get_path(self): def get_path(self) -> str:
return "{}:{}".format(self.PATH_PREFIX, self.device)
def session_begin(self):
if self.session_counter == 0:
self.open()
self.session_counter += 1
def session_end(self):
self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
self.close()
def open(self):
raise NotImplementedError raise NotImplementedError
def close(self): def begin_session(self) -> None:
raise NotImplementedError
def end_session(self) -> None:
raise NotImplementedError
def read(self) -> MessageType:
raise NotImplementedError
def write(self, message: MessageType) -> None:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def enumerate(cls): def enumerate(cls) -> Iterable["Transport"]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def find_by_path(cls, path, prefix_search=False): def find_by_path(cls, path: str, prefix_search: bool = False) -> "Transport":
for device in cls.enumerate(): for device in cls.enumerate():
if ( if (
path is None path is None

View File

@ -14,9 +14,11 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import binascii
import logging import logging
import struct import struct
from io import BytesIO from io import BytesIO
from typing import Any, Dict, Iterable
import requests import requests
@ -28,6 +30,10 @@ LOG = logging.getLogger(__name__)
TREZORD_HOST = "http://127.0.0.1:21325" TREZORD_HOST = "http://127.0.0.1:21325"
def get_error(resp: requests.Response) -> str:
return " (error=%d str=%s)" % (resp.status_code, resp.json()["error"])
class BridgeTransport(Transport): class BridgeTransport(Transport):
""" """
BridgeTransport implements transport through TREZOR Bridge (aka trezord). BridgeTransport implements transport through TREZOR Bridge (aka trezord).
@ -36,91 +42,81 @@ class BridgeTransport(Transport):
PATH_PREFIX = "bridge" PATH_PREFIX = "bridge"
HEADERS = {"Origin": "https://python.trezor.io"} HEADERS = {"Origin": "https://python.trezor.io"}
def __init__(self, device): def __init__(self, device: Dict[str, Any]) -> None:
super().__init__()
self.device = device self.device = device
self.conn = requests.Session() self.conn = requests.Session()
self.session = None self.session = None # type: Optional[str]
self.request = None self.response = None # type: Optional[str]
def get_path(self): def get_path(self) -> str:
return "%s:%s" % (self.PATH_PREFIX, self.device["path"]) return "%s:%s" % (self.PATH_PREFIX, self.device["path"])
@classmethod @classmethod
def _call(cls, action, data=None, uri_suffix=None, session=None): def enumerate(cls) -> Iterable["BridgeTransport"]:
if uri_suffix is not None: try:
uri_suffix = "/" + uri_suffix r = requests.post(TREZORD_HOST + "/enumerate", headers=cls.HEADERS)
elif session is not None:
uri_suffix = "/{}".format(session)
else:
uri_suffix = ""
url = "{}/{}{}".format(TREZORD_HOST, action, uri_suffix)
r = requests.post(url, headers=cls.HEADERS, data=data)
if r.status_code != 200: if r.status_code != 200:
raise TransportException( raise TransportException(
"trezord: '{}' action failed with code {}: {}".format( "trezord: Could not enumerate devices" + get_error(r)
action, r.status_code, r.json().get("error", "(no error message)")
) )
)
return r
@classmethod
def enumerate(cls):
try:
r = cls._call("enumerate")
return [BridgeTransport(dev) for dev in r.json()] return [BridgeTransport(dev) for dev in r.json()]
except Exception: except Exception:
return [] return []
def open(self): def begin_session(self) -> None:
r = self._call("acquire", uri_suffix="{}/null".format(self.device["path"])) r = self.conn.post(
TREZORD_HOST + "/acquire/%s/null" % self.device["path"],
headers=self.HEADERS,
)
if r.status_code != 200:
raise TransportException(
"trezord: Could not acquire session" + get_error(r)
)
self.session = r.json()["session"] self.session = r.json()["session"]
def close(self): def end_session(self) -> None:
if not self.session: if not self.session:
return return
self._call("release", session=self.session) r = self.conn.post(
TREZORD_HOST + "/release/%s" % self.session, headers=self.HEADERS
)
if r.status_code != 200:
raise TransportException(
"trezord: Could not release session" + get_error(r)
)
self.session = None self.session = None
def write(self, msg): def write(self, msg: protobuf.MessageType) -> None:
if self.request is not None:
raise TransportException("trezord can't perform two writes without a read")
LOG.debug( LOG.debug(
"preparing message: {}".format(msg.__class__.__name__), "sending message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg}, extra={"protobuf": msg},
) )
# encode the message buffer = BytesIO()
data = BytesIO() protobuf.dump_message(buffer, msg)
protobuf.dump_message(data, msg) ser = buffer.getvalue()
ser = data.getvalue()
header = struct.pack(">HL", mapping.get_type(msg), len(ser)) header = struct.pack(">HL", mapping.get_type(msg), len(ser))
# store for later data = binascii.hexlify(header + ser).decode()
self.request = (header + ser).hex() r = self.conn.post( # type: ignore # typeshed bug
TREZORD_HOST + "/call/%s" % self.session, data=data, headers=self.HEADERS
)
if r.status_code != 200:
raise TransportException("trezord: Could not write message" + get_error(r))
self.response = r.text
def read(self): def read(self) -> protobuf.MessageType:
if self.request is None: if self.response is None:
raise TransportException("trezord: no request in queue") raise TransportException("No response stored")
data = binascii.unhexlify(self.response)
try:
LOG.debug("sending prepared message")
r = self._call("call", data=self.request, session=self.session)
data = bytes.fromhex(r.text)
headerlen = struct.calcsize(">HL") headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", data[:headerlen]) (msg_type, datalen) = struct.unpack(">HL", data[:headerlen])
data = BytesIO(data[headerlen : headerlen + datalen]) buffer = BytesIO(data[headerlen : headerlen + datalen])
msg = protobuf.load_message(data, mapping.get_class(msg_type)) msg = protobuf.load_message(buffer, mapping.get_class(msg_type))
LOG.debug( LOG.debug(
"received message: {}".format(msg.__class__.__name__), "received message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg}, extra={"protobuf": msg},
) )
self.response = None
return msg return msg
finally:
self.request = None
TRANSPORT = BridgeTransport TRANSPORT = BridgeTransport

View File

@ -16,26 +16,28 @@
import sys import sys
import time import time
from typing import Any, Dict, Iterable
import hid import hid
from . import Transport, TransportException from . import TransportException
from ..protocol_v1 import ProtocolV1 from .protocol import ProtocolBasedTransport, get_protocol
from ..protocol_v2 import ProtocolV2
DEV_TREZOR1 = (0x534C, 0x0001) DEV_TREZOR1 = (0x534C, 0x0001)
DEV_TREZOR2 = (0x1209, 0x53C1) DEV_TREZOR2 = (0x1209, 0x53C1)
DEV_TREZOR2_BL = (0x1209, 0x53C0) DEV_TREZOR2_BL = (0x1209, 0x53C0)
HidDevice = Dict[str, Any]
HidDeviceHandle = Any
class HidHandle: class HidHandle:
def __init__(self, path): def __init__(self, path: str, probe_hid_version: bool = False) -> None:
self.path = path self.path = path
self.count = 0 self.handle = None # type: HidDeviceHandle
self.handle = None self.hid_version = None if probe_hid_version else 2
def open(self): def open(self) -> None:
if self.count == 0:
self.handle = hid.device() self.handle = hid.device()
try: try:
self.handle.open_path(self.path) self.handle.open_path(self.path)
@ -46,47 +48,67 @@ class HidHandle:
) )
raise e raise e
self.handle.set_nonblocking(True) self.handle.set_nonblocking(True)
self.count += 1
def close(self): if self.hid_version is None:
if self.count == 1: self.hid_version = self.probe_hid_version()
def close(self) -> None:
if self.handle is not None:
self.handle.close() self.handle.close()
if self.count > 0: self.handle = None
self.count -= 1
def write_chunk(self, chunk: bytes) -> None:
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
if self.hid_version == 2:
self.handle.write(b"\0" + bytearray(chunk))
else:
self.handle.write(chunk)
def read_chunk(self) -> bytes:
while True:
chunk = self.handle.read(64)
if chunk:
break
else:
time.sleep(0.001)
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
return chunk
def probe_hid_version(self) -> int:
n = self.handle.write([0, 63] + [0xFF] * 63)
if n == 65:
return 2
n = self.handle.write([63] + [0xFF] * 63)
if n == 64:
return 1
raise TransportException("Unknown HID version")
class HidTransport(Transport): class HidTransport(ProtocolBasedTransport):
""" """
HidTransport implements transport over USB HID interface. HidTransport implements transport over USB HID interface.
""" """
PATH_PREFIX = "hid" PATH_PREFIX = "hid"
def __init__(self, device, protocol=None, hid_handle=None): def __init__(self, device: HidDevice, hid_handle: HidHandle = None) -> None:
super(HidTransport, self).__init__()
if hid_handle is None: if hid_handle is None:
hid_handle = HidHandle(device["path"]) hid_handle = HidHandle(device["path"])
if protocol is None:
# force_v1 = os.environ.get('TREZOR_TRANSPORT_V1', '0')
force_v1 = True
if is_trezor2(device) and not int(force_v1):
protocol = ProtocolV2()
else:
protocol = ProtocolV1()
self.device = device self.device = device
self.protocol = protocol
self.hid = hid_handle self.hid = hid_handle
self.hid_version = None
def get_path(self): protocol = get_protocol(hid_handle, is_trezor2(device))
super().__init__(protocol=protocol)
def get_path(self) -> str:
return "%s:%s" % (self.PATH_PREFIX, self.device["path"].decode()) return "%s:%s" % (self.PATH_PREFIX, self.device["path"].decode())
@staticmethod @classmethod
def enumerate(debug=False): def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]:
devices = [] devices = []
for dev in hid.enumerate(0, 0): for dev in hid.enumerate(0, 0):
if not (is_trezor1(dev) or is_trezor2(dev) or is_trezor2_bl(dev)): if not (is_trezor1(dev) or is_trezor2(dev) or is_trezor2_bl(dev)):
@ -100,84 +122,35 @@ class HidTransport(Transport):
devices.append(HidTransport(dev)) devices.append(HidTransport(dev))
return devices return devices
def find_debug(self): def find_debug(self) -> "HidTransport":
if isinstance(self.protocol, ProtocolV2): if self.protocol.VERSION >= 2:
# For v2 protocol, lets use the same HID interface, but with a different session # use the same device
protocol = ProtocolV2() return self
debug = HidTransport(self.device, protocol, self.hid) else:
return debug
if isinstance(self.protocol, ProtocolV1):
# For v1 protocol, find debug USB interface for the same serial number # For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True): for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]: if debug.device["serial_number"] == self.device["serial_number"]:
return debug return debug
raise TransportException("Debug HID device not found") raise TransportException("Debug HID device not found")
def open(self):
self.hid.open()
if is_trezor1(self.device):
self.hid_version = self.probe_hid_version()
else:
self.hid_version = 2
self.protocol.session_begin(self)
def close(self): def is_trezor1(dev: HidDevice) -> bool:
self.protocol.session_end(self)
self.hid.close()
self.hid_version = None
def read(self):
return self.protocol.read(self)
def write(self, msg):
return self.protocol.write(self, msg)
def write_chunk(self, chunk):
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
if self.hid_version == 2:
self.hid.handle.write(b"\0" + bytearray(chunk))
else:
self.hid.handle.write(chunk)
def read_chunk(self):
while True:
chunk = self.hid.handle.read(64)
if chunk:
break
else:
time.sleep(0.001)
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
return bytearray(chunk)
def probe_hid_version(self):
n = self.hid.handle.write([0, 63] + [0xFF] * 63)
if n == 65:
return 2
n = self.hid.handle.write([63] + [0xFF] * 63)
if n == 64:
return 1
raise TransportException("Unknown HID version")
def is_trezor1(dev):
return (dev["vendor_id"], dev["product_id"]) == DEV_TREZOR1 return (dev["vendor_id"], dev["product_id"]) == DEV_TREZOR1
def is_trezor2(dev): def is_trezor2(dev: HidDevice) -> bool:
return (dev["vendor_id"], dev["product_id"]) == DEV_TREZOR2 return (dev["vendor_id"], dev["product_id"]) == DEV_TREZOR2
def is_trezor2_bl(dev): def is_trezor2_bl(dev: HidDevice) -> bool:
return (dev["vendor_id"], dev["product_id"]) == DEV_TREZOR2_BL return (dev["vendor_id"], dev["product_id"]) == DEV_TREZOR2_BL
def is_wirelink(dev): def is_wirelink(dev: HidDevice) -> bool:
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0 return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0
def is_debuglink(dev): def is_debuglink(dev: HidDevice) -> bool:
return dev["usage_page"] == 0xFF01 or dev["interface_number"] == 1 return dev["usage_page"] == 0xFF01 or dev["interface_number"] == 1

View File

@ -0,0 +1,348 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import os
import struct
from io import BytesIO
from typing import Tuple
from typing_extensions import Protocol as StructuralType
from . import Transport
from .. import mapping, protobuf
REPLEN = 64
V2_FIRST_CHUNK = 0x01
V2_NEXT_CHUNK = 0x02
V2_BEGIN_SESSION = 0x03
V2_END_SESSION = 0x04
LOG = logging.getLogger(__name__)
class Handle(StructuralType):
"""PEP 544 structural type for Handle functionality.
(called a "Protocol" in the proposed PEP, name which is impractical here)
Handle is a "physical" layer for a protocol.
It can open/close a connection and read/write bare data in 64-byte chunks.
Functionally we gain nothing from making this an (abstract) base class for handle
implementations, so this definition is for type hinting purposes only. You can,
but don't have to, inherit from it.
"""
def open(self) -> None:
...
def close(self) -> None:
...
def read_chunk(self) -> bytes:
...
def write_chunk(self, chunk: bytes) -> None:
...
class Protocol:
"""Wire protocol that can communicate with a Trezor device, given a Handle.
A Protocol implements the part of the Transport API that relates to communicating
logical messages over a physical layer. It is a thing that can:
- open and close sessions,
- send and receive protobuf messages,
given the ability to:
- open and close physical connections,
- and send and receive binary chunks.
We declare a protocol version (we have implementations of v1 and v2).
For now, the class also handles session counting and opening the underlying Handle.
This will probably be removed in the future.
We will need a new Protocol class if we change the way a Trezor device encapsulates
its messages.
"""
VERSION = None # type: int
def __init__(self, handle: Handle) -> None:
self.handle = handle
self.session_counter = 0
def begin_session(self) -> None:
if self.session_counter == 0:
self.handle.open()
self.session_counter += 1
def end_session(self) -> None:
self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
self.handle.close()
def read(self) -> protobuf.MessageType:
raise NotImplementedError
def write(self, message: protobuf.MessageType) -> None:
raise NotImplementedError
class ProtocolBasedTransport(Transport):
"""Transport that implements its communications through a Protocol.
Intended as a base class for implementations that proxy their communication
operations to a Protocol.
"""
def __init__(self, protocol: Protocol) -> None:
self.protocol = protocol
def write(self, message: protobuf.MessageType) -> None:
self.protocol.write(message)
def read(self) -> protobuf.MessageType:
return self.protocol.read()
def begin_session(self) -> None:
self.protocol.begin_session()
def end_session(self) -> None:
self.protocol.end_session()
class ProtocolV1(Protocol):
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
Does not understand sessions.
"""
VERSION = 1
def write(self, msg: protobuf.MessageType) -> None:
LOG.debug(
"sending message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
data = BytesIO()
protobuf.dump_message(data, msg)
ser = data.getvalue()
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
buffer = bytearray(b"##" + header + ser)
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: REPLEN - 1]
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
buffer = buffer[63:]
def read(self) -> protobuf.MessageType:
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
# Strip padding
data = BytesIO(buffer[:datalen])
# Parse to protobuf
msg = protobuf.load_message(data, mapping.get_class(msg_type))
LOG.debug(
"received message: {}".format(msg.__class__.__name__),
extra={"protobuf": msg},
)
return msg
def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError("Unexpected magic characters")
try:
headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[3 + headerlen :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.handle.read_chunk()
if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters")
return chunk[1:]
class ProtocolV2(Protocol):
"""Protocol version 2. Currently (11/2018) not used.
Intended to mimic U2F/WebAuthN session handling.
"""
VERSION = 2
def __init__(self, handle: Handle) -> None:
self.session = None
super().__init__(handle)
def begin_session(self) -> None:
# ensure open connection
super().begin_session()
# initiate session
chunk = struct.pack(">B", V2_BEGIN_SESSION)
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
# get session identifier
resp = self.handle.read_chunk()
try:
headerlen = struct.calcsize(">BL")
magic, session = struct.unpack(">BL", resp[:headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
if magic != V2_BEGIN_SESSION:
raise RuntimeError("Unexpected magic character")
self.session = session
LOG.debug("[session {}] session started".format(self.session))
def end_session(self) -> None:
if not self.session:
return
try:
chunk = struct.pack(">BL", V2_END_SESSION, self.session)
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
resp = self.handle.read_chunk()
(magic,) = struct.unpack(">B", resp[:1])
if magic != V2_END_SESSION:
raise RuntimeError("Expected session close")
LOG.debug("[session {}] session ended".format(self.session))
finally:
self.session = None
# close connection if appropriate
super().end_session()
def write(self, msg: protobuf.MessageType) -> None:
if not self.session:
raise RuntimeError("Missing session for v2 protocol")
LOG.debug(
"[session {}] sending message: {}".format(
self.session, msg.__class__.__name__
),
extra={"protobuf": msg},
)
# Serialize whole message
data = BytesIO()
protobuf.dump_message(data, msg)
data = data.getvalue()
dataheader = struct.pack(">LL", mapping.get_type(msg), len(data))
data = dataheader + data
seq = -1
# Write it out
while data:
if seq < 0:
repheader = struct.pack(">BL", V2_FIRST_CHUNK, self.session)
else:
repheader = struct.pack(">BLL", V2_NEXT_CHUNK, self.session, seq)
datalen = REPLEN - len(repheader)
chunk = repheader + data[:datalen]
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
data = data[datalen:]
seq += 1
def read(self) -> protobuf.MessageType:
if not self.session:
raise RuntimeError("Missing session for v2 protocol")
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, chunk = self.read_first()
buffer.extend(chunk)
# Read the rest of the message
while len(buffer) < datalen:
next_chunk = self.read_next()
buffer.extend(next_chunk)
# Strip padding
buffer = BytesIO(buffer[:datalen])
# Parse to protobuf
msg = protobuf.load_message(buffer, mapping.get_class(msg_type))
LOG.debug(
"[session {}] received message: {}".format(
self.session, msg.__class__.__name__
),
extra={"protobuf": msg},
)
return msg
def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
try:
headerlen = struct.calcsize(">BLLL")
magic, session, msg_type, datalen = struct.unpack(
">BLLL", chunk[:headerlen]
)
except Exception:
raise RuntimeError("Cannot parse header")
if magic != V2_FIRST_CHUNK:
raise RuntimeError("Unexpected magic character")
if session != self.session:
raise RuntimeError("Session id mismatch")
return msg_type, datalen, chunk[headerlen:]
def read_next(self) -> bytes:
chunk = self.handle.read_chunk()
try:
headerlen = struct.calcsize(">BLL")
magic, session, sequence = struct.unpack(">BLL", chunk[:headerlen])
except Exception:
raise RuntimeError("Cannot parse header")
if magic != V2_NEXT_CHUNK:
raise RuntimeError("Unexpected magic characters")
if session != self.session:
raise RuntimeError("Session id mismatch")
return chunk[headerlen:]
def get_protocol(handle: Handle, want_v2: bool) -> Protocol:
"""Make a Protocol instance for the given handle.
Each transport can have a preference for using a particular protocol version.
This preference is overridable through `TREZOR_PROTOCOL_V1` environment variable,
which forces the library to use V1 anyways.
As of 11/2018, no devices support V2, so we enforce V1 here. It is still possible
to set `TREZOR_PROTOCOL_V1=0` and thus enable V2 protocol for transports that ask
for it (i.e., USB transports for Trezor T).
"""
force_v1 = int(os.environ.get("TREZOR_PROTOCOL_V1", 1))
if want_v2 and not force_v1:
return ProtocolV2(handle)
else:
return ProtocolV1(handle)

View File

@ -15,20 +15,19 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import socket import socket
from typing import Iterable, cast
from . import Transport, TransportException from . import TransportException
from ..protocol_v1 import ProtocolV1 from .protocol import ProtocolBasedTransport, get_protocol
class UdpTransport(Transport): class UdpTransport(ProtocolBasedTransport):
DEFAULT_HOST = "127.0.0.1" DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 21324 DEFAULT_PORT = 21324
PATH_PREFIX = "udp" PATH_PREFIX = "udp"
def __init__(self, device=None, protocol=None): def __init__(self, device: str = None) -> None:
super(UdpTransport, self).__init__()
if not device: if not device:
host = UdpTransport.DEFAULT_HOST host = UdpTransport.DEFAULT_HOST
port = UdpTransport.DEFAULT_PORT port = UdpTransport.DEFAULT_PORT
@ -36,21 +35,21 @@ class UdpTransport(Transport):
devparts = device.split(":") devparts = device.split(":")
host = devparts[0] host = devparts[0]
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
if not protocol:
protocol = ProtocolV1()
self.device = (host, port) self.device = (host, port)
self.protocol = protocol self.socket = None # type: Optional[socket.socket]
self.socket = None
def get_path(self): protocol = get_protocol(self, want_v2=False)
return "%s:%s:%s" % ((self.PATH_PREFIX,) + self.device) super().__init__(protocol=protocol)
def find_debug(self): def get_path(self) -> str:
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
def find_debug(self) -> "UdpTransport":
host, port = self.device host, port = self.device
return UdpTransport("{}:{}".format(host, port + 1), self.protocol) return UdpTransport("{}:{}".format(host, port + 1))
@classmethod @classmethod
def _try_path(cls, path): def _try_path(cls, path: str) -> "UdpTransport":
d = cls(path) d = cls(path)
try: try:
d.open() d.open()
@ -64,7 +63,7 @@ class UdpTransport(Transport):
d.close() d.close()
@classmethod @classmethod
def enumerate(cls): def enumerate(cls) -> Iterable["UdpTransport"]:
default_path = "{}:{}".format(cls.DEFAULT_HOST, cls.DEFAULT_PORT) default_path = "{}:{}".format(cls.DEFAULT_HOST, cls.DEFAULT_PORT)
try: try:
return [cls._try_path(default_path)] return [cls._try_path(default_path)]
@ -72,27 +71,29 @@ class UdpTransport(Transport):
return [] return []
@classmethod @classmethod
def find_by_path(cls, path, prefix_search=False): def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport":
if prefix_search: if prefix_search:
return super().find_by_path(path, prefix_search) return cast(UdpTransport, super().find_by_path(path, prefix_search))
# This is *technically* type-able: mark `find_by_path` as returning
# the same type from which `cls` comes from.
# Mypy can't handle that though, so here we are.
else: else:
path = path.replace("{}:".format(cls.PATH_PREFIX), "") path = path.replace("{}:".format(cls.PATH_PREFIX), "")
return cls._try_path(path) return cls._try_path(path)
def open(self): def open(self) -> None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.connect(self.device) self.socket.connect(self.device)
self.socket.settimeout(10) self.socket.settimeout(10)
self.protocol.session_begin(self)
def close(self): def close(self) -> None:
if self.socket: if self.socket is not None:
self.protocol.session_end(self)
self.socket.close() self.socket.close()
self.socket = None self.socket = None
def _ping(self): def _ping(self) -> bool:
"""Test if the device is listening.""" """Test if the device is listening."""
assert self.socket is not None
resp = None resp = None
try: try:
self.socket.sendall(b"PINGPING") self.socket.sendall(b"PINGPING")
@ -101,18 +102,14 @@ class UdpTransport(Transport):
pass pass
return resp == b"PONGPONG" return resp == b"PONGPONG"
def read(self): def write_chunk(self, chunk: bytes) -> None:
return self.protocol.read(self) assert self.socket is not None
def write(self, msg):
return self.protocol.write(self, msg)
def write_chunk(self, chunk):
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException("Unexpected data length") raise TransportException("Unexpected data length")
self.socket.sendall(chunk) self.socket.sendall(chunk)
def read_chunk(self): def read_chunk(self) -> bytes:
assert self.socket is not None
while True: while True:
try: try:
chunk = self.socket.recv(64) chunk = self.socket.recv(64)

View File

@ -17,12 +17,12 @@
import atexit import atexit
import sys import sys
import time import time
from typing import Iterable
import usb1 import usb1
from . import Transport, TransportException from . import TransportException
from ..protocol_v1 import ProtocolV1 from .protocol import ProtocolBasedTransport, get_protocol
from ..protocol_v2 import ProtocolV2
DEV_TREZOR1 = (0x534C, 0x0001) DEV_TREZOR1 = (0x534C, 0x0001)
DEV_TREZOR2 = (0x1209, 0x53C1) DEV_TREZOR2 = (0x1209, 0x53C1)
@ -35,13 +35,14 @@ DEBUG_ENDPOINT = 2
class WebUsbHandle: class WebUsbHandle:
def __init__(self, device): def __init__(self, device: usb1.USBDevice, debug: bool = False) -> None:
self.device = device self.device = device
self.interface = DEBUG_INTERFACE if debug else INTERFACE
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
self.count = 0 self.count = 0
self.handle = None self.handle = None # type: Optional[usb1.USBDeviceHandle]
def open(self, interface): def open(self) -> None:
if self.count == 0:
self.handle = self.device.open() self.handle = self.device.open()
if self.handle is None: if self.handle is None:
if sys.platform.startswith("linux"): if sys.platform.startswith("linux"):
@ -51,18 +52,35 @@ class WebUsbHandle:
else: else:
args = () args = ()
raise IOError("Cannot open device", *args) raise IOError("Cannot open device", *args)
self.handle.claimInterface(interface) self.handle.claimInterface(self.interface)
self.count += 1
def close(self, interface): def close(self) -> None:
if self.count == 1: if self.handle is not None:
self.handle.releaseInterface(interface) self.handle.releaseInterface(self.interface)
self.handle.close() self.handle.close()
if self.count > 0: self.handle = None
self.count -= 1
def write_chunk(self, chunk: bytes) -> None:
assert self.handle is not None
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
self.handle.interruptWrite(self.endpoint, chunk)
def read_chunk(self) -> bytes:
assert self.handle is not None
endpoint = 0x80 | self.endpoint
while True:
chunk = self.handle.interruptRead(endpoint, 64)
if chunk:
break
else:
time.sleep(0.001)
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
return chunk
class WebUsbTransport(Transport): class WebUsbTransport(ProtocolBasedTransport):
""" """
WebUsbTransport implements transport over WebUSB interface. WebUsbTransport implements transport over WebUSB interface.
""" """
@ -70,31 +88,24 @@ class WebUsbTransport(Transport):
PATH_PREFIX = "webusb" PATH_PREFIX = "webusb"
context = None context = None
def __init__(self, device, protocol=None, handle=None, debug=False): def __init__(
super(WebUsbTransport, self).__init__() self, device: str, handle: WebUsbHandle = None, debug: bool = False
) -> None:
if handle is None: if handle is None:
handle = WebUsbHandle(device) handle = WebUsbHandle(device, debug)
if protocol is None:
# force_v1 = os.environ.get('TREZOR_TRANSPORT_V1', '0')
force_v1 = True
if is_trezor2(device) and not int(force_v1):
protocol = ProtocolV2()
else:
protocol = ProtocolV1()
self.device = device self.device = device
self.protocol = protocol
self.handle = handle self.handle = handle
self.debug = debug self.debug = debug
def get_path(self): protocol = get_protocol(handle, is_trezor2(device))
super().__init__(protocol=protocol)
def get_path(self) -> str:
return "%s:%s" % (self.PATH_PREFIX, dev_to_str(self.device)) return "%s:%s" % (self.PATH_PREFIX, dev_to_str(self.device))
@classmethod @classmethod
def enumerate(cls): def enumerate(cls) -> Iterable["WebUsbTransport"]:
if cls.context is None: if cls.context is None:
cls.context = usb1.USBContext() cls.context = usb1.USBContext()
cls.context.open() cls.context.open()
@ -117,69 +128,30 @@ class WebUsbTransport(Transport):
pass pass
return devices return devices
def find_debug(self): def find_debug(self) -> "WebUsbTransport":
if isinstance(self.protocol, ProtocolV2): if self.protocol.VERSION >= 2:
# TODO test this # TODO test this
# For v2 protocol, lets use the same WebUSB interface, but with a different session # XXX this is broken right now because sessions don't really work
protocol = ProtocolV2() # For v2 protocol, use the same WebUSB interface with a different session
debug = WebUsbTransport(self.device, protocol, self.handle) return WebUsbTransport(self.device, self.handle)
return debug
if isinstance(self.protocol, ProtocolV1):
# For v1 protocol, find debug USB interface for the same serial number
protocol = ProtocolV1()
debug = WebUsbTransport(self.device, protocol, None, True)
return debug
raise TransportException("Debug WebUSB device not found")
def open(self):
interface = DEBUG_INTERFACE if self.debug else INTERFACE
self.handle.open(interface)
self.protocol.session_begin(self)
def close(self):
interface = DEBUG_INTERFACE if self.debug else INTERFACE
self.protocol.session_end(self)
self.handle.close(interface)
def read(self):
return self.protocol.read(self)
def write(self, msg):
return self.protocol.write(self, msg)
def write_chunk(self, chunk):
endpoint = DEBUG_ENDPOINT if self.debug else ENDPOINT
if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk))
self.handle.handle.interruptWrite(endpoint, chunk)
def read_chunk(self):
endpoint = DEBUG_ENDPOINT if self.debug else ENDPOINT
endpoint = 0x80 | endpoint
while True:
chunk = self.handle.handle.interruptRead(endpoint, 64)
if chunk:
break
else: else:
time.sleep(0.001) # For v1 protocol, find debug USB interface for the same serial number
if len(chunk) != 64: return WebUsbTransport(self.device, debug=True)
raise TransportException("Unexpected chunk size: %d" % len(chunk))
return bytearray(chunk)
def is_trezor1(dev): def is_trezor1(dev: usb1.USBDevice) -> bool:
return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR1 return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR1
def is_trezor2(dev): def is_trezor2(dev: usb1.USBDevice) -> bool:
return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR2 return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR2
def is_trezor2_bl(dev): def is_trezor2_bl(dev: usb1.USBDevice) -> bool:
return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR2_BL return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR2_BL
def is_vendor_class(dev): def is_vendor_class(dev: usb1.USBDevice) -> bool:
configurationId = 0 configurationId = 0
altSettingId = 0 altSettingId = 0
return ( return (
@ -188,7 +160,7 @@ def is_vendor_class(dev):
) )
def dev_to_str(dev): def dev_to_str(dev: usb1.USBDevice) -> str:
return ":".join( return ":".join(
str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList() str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList()
) )