mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-08 14:31:06 +00:00
trezorlib: transport/protocol reshuffle
This commit breaks session handling (which matters with Bridge) and regresses Bridge to an older code state. Both of these issues will be rectified in subsequent commits. Explanation of this big API reshuffle follows: * protocols are moved to trezorlib.transport, and to a single common file. * there is a cleaner definition of Transport and Protocol API (see below) * fully valid mypy type hinting * session handle counters and open handle counters mostly went away. Transports and Protocols are meant to be "raw" APIs; TrezorClient will implement context-handler-based sessions, session tracking, etc. I'm calling this a "reshuffle" because it involved very small number of code changes. Most of it is moving things around where they sit better. The API changes are as follows. Transport is now a thing that can: * open and close sessions * read and write protobuf messages * enumerate and find devices Some transports (all except bridge) are technically bytes-based and need a separate protocol implementation (because we have two existing protocols, although only the first one is actually used). Hence a protocol superclass. Protocol is a thing that *also* can: * open and close sessions * read and write protobuf messages For that, it requires a `handle`. Handle is a physical layer for a protocol. It can: * open and close some sort of device connection (this is distinct from session! Connection is a channel over which you can send data. Session is a logical arrangement on top of that; you can have multiple sessions on a single connection.) * read and write 64-byte chunks of data With that, we introduce ProtocolBasedTransport, which simply delegates the appropriate Transport functionality to respective Protocol methods. hid and webusb transports are ProtocolBasedTransport-s that provide separate device handles. HidHandle and WebUsbHandle existed before, but the distinction of functionality between a Transport and its Handle was unclear. Some methods were moved and now the handles implement the Handle API, while the transports provide the enumeration parts of the Transport API, as well as glue between the respective Protocols and Handles. udp transport is also a ProtocolBasedTransport, but it acts as its own handle. (That might be changed. For now, I went with the pre-existing structure.) In addition, session_begin/end is renamed to begin/end_session to keep consistent verb_noun naming.
This commit is contained in:
parent
560a5215c5
commit
aac7726824
@ -5,3 +5,4 @@ click>=7,<8
|
||||
pyblake2>=0.9.3
|
||||
libusb1>=1.6.4
|
||||
construct>=2.9
|
||||
typing_extensions>=3.6
|
||||
|
@ -25,10 +25,10 @@ from .tools import expect
|
||||
class DebugLink:
|
||||
def __init__(self, transport):
|
||||
self.transport = transport
|
||||
self.transport.session_begin()
|
||||
self.transport.begin_session()
|
||||
|
||||
def close(self):
|
||||
self.transport.session_end()
|
||||
self.transport.end_session()
|
||||
|
||||
def _call(self, msg, nowait=False):
|
||||
self.transport.write(msg)
|
||||
|
@ -1,91 +0,0 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2018 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
# as published by the Free Software Foundation.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import logging
|
||||
import struct
|
||||
from io import BytesIO
|
||||
from typing import Tuple
|
||||
|
||||
from . import mapping, protobuf
|
||||
from .transport import Transport
|
||||
|
||||
REPLEN = 64
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtocolV1:
|
||||
def session_begin(self, transport: Transport) -> None:
|
||||
pass
|
||||
|
||||
def session_end(self, transport: Transport) -> None:
|
||||
pass
|
||||
|
||||
def write(self, transport: Transport, msg: protobuf.MessageType) -> None:
|
||||
LOG.debug(
|
||||
"sending message: {}".format(msg.__class__.__name__),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
data = BytesIO()
|
||||
protobuf.dump_message(data, msg)
|
||||
ser = data.getvalue()
|
||||
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
|
||||
data = bytearray(b"##" + header + ser)
|
||||
|
||||
while data:
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b"?" + data[: REPLEN - 1]
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
transport.write_chunk(chunk)
|
||||
data = data[63:]
|
||||
|
||||
def read(self, transport: Transport) -> protobuf.MessageType:
|
||||
# Read header with first part of message data
|
||||
chunk = transport.read_chunk()
|
||||
msg_type, datalen, data = self.parse_first(chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(data) < datalen:
|
||||
chunk = transport.read_chunk()
|
||||
data.extend(self.parse_next(chunk))
|
||||
|
||||
# Strip padding
|
||||
data = BytesIO(data[:datalen])
|
||||
|
||||
# Parse to protobuf
|
||||
msg = protobuf.load_message(data, mapping.get_class(msg_type))
|
||||
LOG.debug(
|
||||
"received message: {}".format(msg.__class__.__name__),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
return msg
|
||||
|
||||
def parse_first(self, chunk: bytes) -> Tuple[int, int, bytes]:
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
try:
|
||||
headerlen = struct.calcsize(">HL")
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + headerlen])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
|
||||
data = chunk[3 + headerlen :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def parse_next(self, chunk: bytes) -> bytes:
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
return chunk[1:]
|
@ -1,147 +0,0 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2018 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
# as published by the Free Software Foundation.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import logging
|
||||
import struct
|
||||
from io import BytesIO
|
||||
from typing import Tuple
|
||||
|
||||
from . import mapping, protobuf
|
||||
from .transport import Transport
|
||||
|
||||
REPLEN = 64
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtocolV2:
|
||||
def __init__(self) -> None:
|
||||
self.session = None
|
||||
|
||||
def session_begin(self, transport: Transport) -> None:
|
||||
chunk = struct.pack(">B", 0x03)
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
transport.write_chunk(chunk)
|
||||
resp = transport.read_chunk()
|
||||
self.session = self.parse_session_open(resp)
|
||||
LOG.debug("[session {}] session started".format(self.session))
|
||||
|
||||
def session_end(self, transport: Transport) -> None:
|
||||
if not self.session:
|
||||
return
|
||||
chunk = struct.pack(">BL", 0x04, self.session)
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
transport.write_chunk(chunk)
|
||||
resp = transport.read_chunk()
|
||||
(magic,) = struct.unpack(">B", resp[:1])
|
||||
if magic != 0x04:
|
||||
raise RuntimeError("Expected session close")
|
||||
LOG.debug("[session {}] session ended".format(self.session))
|
||||
self.session = None
|
||||
|
||||
def write(self, transport: Transport, msg: protobuf.MessageType) -> None:
|
||||
if not self.session:
|
||||
raise RuntimeError("Missing session for v2 protocol")
|
||||
|
||||
LOG.debug(
|
||||
"[session {}] sending message: {}".format(
|
||||
self.session, msg.__class__.__name__
|
||||
),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
# Serialize whole message
|
||||
data = BytesIO()
|
||||
protobuf.dump_message(data, msg)
|
||||
data = data.getvalue()
|
||||
dataheader = struct.pack(">LL", mapping.get_type(msg), len(data))
|
||||
data = dataheader + data
|
||||
seq = -1
|
||||
|
||||
# Write it out
|
||||
while data:
|
||||
if seq < 0:
|
||||
repheader = struct.pack(">BL", 0x01, self.session)
|
||||
else:
|
||||
repheader = struct.pack(">BLL", 0x02, self.session, seq)
|
||||
datalen = REPLEN - len(repheader)
|
||||
chunk = repheader + data[:datalen]
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
transport.write_chunk(chunk)
|
||||
data = data[datalen:]
|
||||
seq += 1
|
||||
|
||||
def read(self, transport: Transport) -> protobuf.MessageType:
|
||||
if not self.session:
|
||||
raise RuntimeError("Missing session for v2 protocol")
|
||||
|
||||
# Read header with first part of message data
|
||||
chunk = transport.read_chunk()
|
||||
msg_type, datalen, data = self.parse_first(chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(data) < datalen:
|
||||
chunk = transport.read_chunk()
|
||||
next_data = self.parse_next(chunk)
|
||||
data.extend(next_data)
|
||||
|
||||
# Strip padding
|
||||
data = BytesIO(data[:datalen])
|
||||
|
||||
# Parse to protobuf
|
||||
msg = protobuf.load_message(data, mapping.get_class(msg_type))
|
||||
LOG.debug(
|
||||
"[session {}] received message: {}".format(
|
||||
self.session, msg.__class__.__name__
|
||||
),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
return msg
|
||||
|
||||
def parse_first(self, chunk: bytes) -> Tuple[int, int, bytes]:
|
||||
try:
|
||||
headerlen = struct.calcsize(">BLLL")
|
||||
magic, session, msg_type, datalen = struct.unpack(
|
||||
">BLLL", chunk[:headerlen]
|
||||
)
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
if magic != 0x01:
|
||||
raise RuntimeError("Unexpected magic character")
|
||||
if session != self.session:
|
||||
raise RuntimeError("Session id mismatch")
|
||||
return msg_type, datalen, chunk[headerlen:]
|
||||
|
||||
def parse_next(self, chunk: bytes) -> bytes:
|
||||
try:
|
||||
headerlen = struct.calcsize(">BLL")
|
||||
magic, session, sequence = struct.unpack(">BLL", chunk[:headerlen])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
if magic != 0x02:
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
if session != self.session:
|
||||
raise RuntimeError("Session id mismatch")
|
||||
return chunk[headerlen:]
|
||||
|
||||
def parse_session_open(self, chunk: bytes) -> int:
|
||||
try:
|
||||
headerlen = struct.calcsize(">BL")
|
||||
magic, session = struct.unpack(">BL", chunk[:headerlen])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
if magic != 0x03:
|
||||
raise RuntimeError("Unexpected magic character")
|
||||
return session
|
@ -41,10 +41,10 @@ class TrezorTest:
|
||||
# self.client.set_buttonwait(3)
|
||||
|
||||
device.wipe(self.client)
|
||||
self.client.transport.session_begin()
|
||||
self.client.transport.begin_session()
|
||||
|
||||
def teardown_method(self, method):
|
||||
self.client.transport.session_end()
|
||||
self.client.transport.end_session()
|
||||
self.client.close()
|
||||
|
||||
def _setup_mnemonic(self, mnemonic=None, pin="", passphrase=False):
|
||||
|
@ -55,11 +55,11 @@ def client():
|
||||
wirelink = get_device()
|
||||
client = TrezorClientDebugLink(wirelink)
|
||||
wipe_device(client)
|
||||
client.transport.session_begin()
|
||||
client.transport.begin_session()
|
||||
|
||||
yield client
|
||||
|
||||
client.transport.session_end()
|
||||
client.transport.end_session()
|
||||
|
||||
# XXX debuglink session must also be closed
|
||||
# client.close accomplishes that for now; going forward, there should
|
||||
|
@ -222,14 +222,13 @@ def session(f):
|
||||
# Decorator wraps a BaseClient method
|
||||
# with session activation / deactivation
|
||||
@functools.wraps(f)
|
||||
def wrapped_f(*args, **kwargs):
|
||||
def wrapped_f(client, *args, **kwargs):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
client = args[0]
|
||||
client.transport.session_begin()
|
||||
client.transport.begin_session()
|
||||
try:
|
||||
return f(*args, **kwargs)
|
||||
return f(client, *args, **kwargs)
|
||||
finally:
|
||||
client.transport.session_end()
|
||||
client.transport.begin_session()
|
||||
|
||||
return wrapped_f
|
||||
|
||||
|
@ -18,6 +18,8 @@ import importlib
|
||||
import logging
|
||||
from typing import Iterable, Type
|
||||
|
||||
from ..protobuf import MessageType
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -25,38 +27,49 @@ class TransportException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Transport(object):
|
||||
def __init__(self):
|
||||
self.session_counter = 0
|
||||
class Transport:
|
||||
"""Raw connection to a Trezor device.
|
||||
|
||||
def __str__(self):
|
||||
Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB
|
||||
or USB-HID connection, or UDP socket of listening emulator(s).
|
||||
It can also enumerate devices available over this communication link, and return
|
||||
them as instances.
|
||||
|
||||
Transport instance is a thing that:
|
||||
- can be identified and requested by a string URI-like path
|
||||
- can open and close sessions, which enclose related operations
|
||||
- can read and write protobuf messages
|
||||
|
||||
You need to implement a new Transport subclass if you invent a new way to connect
|
||||
a Trezor device to a computer.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = None # type: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.get_path()
|
||||
|
||||
def get_path(self):
|
||||
return "{}:{}".format(self.PATH_PREFIX, self.device)
|
||||
|
||||
def session_begin(self):
|
||||
if self.session_counter == 0:
|
||||
self.open()
|
||||
self.session_counter += 1
|
||||
|
||||
def session_end(self):
|
||||
self.session_counter = max(self.session_counter - 1, 0)
|
||||
if self.session_counter == 0:
|
||||
self.close()
|
||||
|
||||
def open(self):
|
||||
def get_path(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self):
|
||||
def begin_session(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def end_session(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def read(self) -> MessageType:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, message: MessageType) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def enumerate(cls):
|
||||
def enumerate(cls) -> Iterable["Transport"]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls, path, prefix_search=False):
|
||||
def find_by_path(cls, path: str, prefix_search: bool = False) -> "Transport":
|
||||
for device in cls.enumerate():
|
||||
if (
|
||||
path is None
|
||||
|
@ -14,9 +14,11 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import binascii
|
||||
import logging
|
||||
import struct
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
import requests
|
||||
|
||||
@ -28,6 +30,10 @@ LOG = logging.getLogger(__name__)
|
||||
TREZORD_HOST = "http://127.0.0.1:21325"
|
||||
|
||||
|
||||
def get_error(resp: requests.Response) -> str:
|
||||
return " (error=%d str=%s)" % (resp.status_code, resp.json()["error"])
|
||||
|
||||
|
||||
class BridgeTransport(Transport):
|
||||
"""
|
||||
BridgeTransport implements transport through TREZOR Bridge (aka trezord).
|
||||
@ -36,91 +42,81 @@ class BridgeTransport(Transport):
|
||||
PATH_PREFIX = "bridge"
|
||||
HEADERS = {"Origin": "https://python.trezor.io"}
|
||||
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
|
||||
def __init__(self, device: Dict[str, Any]) -> None:
|
||||
self.device = device
|
||||
self.conn = requests.Session()
|
||||
self.session = None
|
||||
self.request = None
|
||||
self.session = None # type: Optional[str]
|
||||
self.response = None # type: Optional[str]
|
||||
|
||||
def get_path(self):
|
||||
def get_path(self) -> str:
|
||||
return "%s:%s" % (self.PATH_PREFIX, self.device["path"])
|
||||
|
||||
@classmethod
|
||||
def _call(cls, action, data=None, uri_suffix=None, session=None):
|
||||
if uri_suffix is not None:
|
||||
uri_suffix = "/" + uri_suffix
|
||||
elif session is not None:
|
||||
uri_suffix = "/{}".format(session)
|
||||
else:
|
||||
uri_suffix = ""
|
||||
|
||||
url = "{}/{}{}".format(TREZORD_HOST, action, uri_suffix)
|
||||
r = requests.post(url, headers=cls.HEADERS, data=data)
|
||||
|
||||
if r.status_code != 200:
|
||||
raise TransportException(
|
||||
"trezord: '{}' action failed with code {}: {}".format(
|
||||
action, r.status_code, r.json().get("error", "(no error message)")
|
||||
)
|
||||
)
|
||||
return r
|
||||
|
||||
@classmethod
|
||||
def enumerate(cls):
|
||||
def enumerate(cls) -> Iterable["BridgeTransport"]:
|
||||
try:
|
||||
r = cls._call("enumerate")
|
||||
r = requests.post(TREZORD_HOST + "/enumerate", headers=cls.HEADERS)
|
||||
if r.status_code != 200:
|
||||
raise TransportException(
|
||||
"trezord: Could not enumerate devices" + get_error(r)
|
||||
)
|
||||
return [BridgeTransport(dev) for dev in r.json()]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def open(self):
|
||||
r = self._call("acquire", uri_suffix="{}/null".format(self.device["path"]))
|
||||
def begin_session(self) -> None:
|
||||
r = self.conn.post(
|
||||
TREZORD_HOST + "/acquire/%s/null" % self.device["path"],
|
||||
headers=self.HEADERS,
|
||||
)
|
||||
if r.status_code != 200:
|
||||
raise TransportException(
|
||||
"trezord: Could not acquire session" + get_error(r)
|
||||
)
|
||||
self.session = r.json()["session"]
|
||||
|
||||
def close(self):
|
||||
def end_session(self) -> None:
|
||||
if not self.session:
|
||||
return
|
||||
self._call("release", session=self.session)
|
||||
r = self.conn.post(
|
||||
TREZORD_HOST + "/release/%s" % self.session, headers=self.HEADERS
|
||||
)
|
||||
if r.status_code != 200:
|
||||
raise TransportException(
|
||||
"trezord: Could not release session" + get_error(r)
|
||||
)
|
||||
self.session = None
|
||||
|
||||
def write(self, msg):
|
||||
if self.request is not None:
|
||||
raise TransportException("trezord can't perform two writes without a read")
|
||||
|
||||
def write(self, msg: protobuf.MessageType) -> None:
|
||||
LOG.debug(
|
||||
"preparing message: {}".format(msg.__class__.__name__),
|
||||
"sending message: {}".format(msg.__class__.__name__),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
# encode the message
|
||||
data = BytesIO()
|
||||
protobuf.dump_message(data, msg)
|
||||
ser = data.getvalue()
|
||||
buffer = BytesIO()
|
||||
protobuf.dump_message(buffer, msg)
|
||||
ser = buffer.getvalue()
|
||||
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
|
||||
# store for later
|
||||
self.request = (header + ser).hex()
|
||||
data = binascii.hexlify(header + ser).decode()
|
||||
r = self.conn.post( # type: ignore # typeshed bug
|
||||
TREZORD_HOST + "/call/%s" % self.session, data=data, headers=self.HEADERS
|
||||
)
|
||||
if r.status_code != 200:
|
||||
raise TransportException("trezord: Could not write message" + get_error(r))
|
||||
self.response = r.text
|
||||
|
||||
def read(self):
|
||||
if self.request is None:
|
||||
raise TransportException("trezord: no request in queue")
|
||||
|
||||
try:
|
||||
LOG.debug("sending prepared message")
|
||||
r = self._call("call", data=self.request, session=self.session)
|
||||
|
||||
data = bytes.fromhex(r.text)
|
||||
headerlen = struct.calcsize(">HL")
|
||||
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
|
||||
data = BytesIO(data[headerlen : headerlen + datalen])
|
||||
msg = protobuf.load_message(data, mapping.get_class(msg_type))
|
||||
LOG.debug(
|
||||
"received message: {}".format(msg.__class__.__name__),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
return msg
|
||||
finally:
|
||||
self.request = None
|
||||
def read(self) -> protobuf.MessageType:
|
||||
if self.response is None:
|
||||
raise TransportException("No response stored")
|
||||
data = binascii.unhexlify(self.response)
|
||||
headerlen = struct.calcsize(">HL")
|
||||
(msg_type, datalen) = struct.unpack(">HL", data[:headerlen])
|
||||
buffer = BytesIO(data[headerlen : headerlen + datalen])
|
||||
msg = protobuf.load_message(buffer, mapping.get_class(msg_type))
|
||||
LOG.debug(
|
||||
"received message: {}".format(msg.__class__.__name__),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
self.response = None
|
||||
return msg
|
||||
|
||||
|
||||
TRANSPORT = BridgeTransport
|
||||
|
@ -16,77 +16,99 @@
|
||||
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
import hid
|
||||
|
||||
from . import Transport, TransportException
|
||||
from ..protocol_v1 import ProtocolV1
|
||||
from ..protocol_v2 import ProtocolV2
|
||||
from . import TransportException
|
||||
from .protocol import ProtocolBasedTransport, get_protocol
|
||||
|
||||
DEV_TREZOR1 = (0x534C, 0x0001)
|
||||
DEV_TREZOR2 = (0x1209, 0x53C1)
|
||||
DEV_TREZOR2_BL = (0x1209, 0x53C0)
|
||||
|
||||
HidDevice = Dict[str, Any]
|
||||
HidDeviceHandle = Any
|
||||
|
||||
|
||||
class HidHandle:
|
||||
def __init__(self, path):
|
||||
def __init__(self, path: str, probe_hid_version: bool = False) -> None:
|
||||
self.path = path
|
||||
self.count = 0
|
||||
self.handle = None # type: HidDeviceHandle
|
||||
self.hid_version = None if probe_hid_version else 2
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = hid.device()
|
||||
try:
|
||||
self.handle.open_path(self.path)
|
||||
except (IOError, OSError) as e:
|
||||
if sys.platform.startswith("linux"):
|
||||
e.args = e.args + (
|
||||
"Do you have udev rules installed? https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules",
|
||||
)
|
||||
raise e
|
||||
self.handle.set_nonblocking(True)
|
||||
|
||||
if self.hid_version is None:
|
||||
self.hid_version = self.probe_hid_version()
|
||||
|
||||
def close(self) -> None:
|
||||
if self.handle is not None:
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
|
||||
def open(self):
|
||||
if self.count == 0:
|
||||
self.handle = hid.device()
|
||||
try:
|
||||
self.handle.open_path(self.path)
|
||||
except (IOError, OSError) as e:
|
||||
if sys.platform.startswith("linux"):
|
||||
e.args = e.args + (
|
||||
"Do you have udev rules installed? https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules",
|
||||
)
|
||||
raise e
|
||||
self.handle.set_nonblocking(True)
|
||||
self.count += 1
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||
|
||||
def close(self):
|
||||
if self.count == 1:
|
||||
self.handle.close()
|
||||
if self.count > 0:
|
||||
self.count -= 1
|
||||
if self.hid_version == 2:
|
||||
self.handle.write(b"\0" + bytearray(chunk))
|
||||
else:
|
||||
self.handle.write(chunk)
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
while True:
|
||||
chunk = self.handle.read(64)
|
||||
if chunk:
|
||||
break
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||
return chunk
|
||||
|
||||
def probe_hid_version(self) -> int:
|
||||
n = self.handle.write([0, 63] + [0xFF] * 63)
|
||||
if n == 65:
|
||||
return 2
|
||||
n = self.handle.write([63] + [0xFF] * 63)
|
||||
if n == 64:
|
||||
return 1
|
||||
raise TransportException("Unknown HID version")
|
||||
|
||||
|
||||
class HidTransport(Transport):
|
||||
class HidTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
HidTransport implements transport over USB HID interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "hid"
|
||||
|
||||
def __init__(self, device, protocol=None, hid_handle=None):
|
||||
super(HidTransport, self).__init__()
|
||||
|
||||
def __init__(self, device: HidDevice, hid_handle: HidHandle = None) -> None:
|
||||
if hid_handle is None:
|
||||
hid_handle = HidHandle(device["path"])
|
||||
|
||||
if protocol is None:
|
||||
# force_v1 = os.environ.get('TREZOR_TRANSPORT_V1', '0')
|
||||
force_v1 = True
|
||||
|
||||
if is_trezor2(device) and not int(force_v1):
|
||||
protocol = ProtocolV2()
|
||||
else:
|
||||
protocol = ProtocolV1()
|
||||
|
||||
self.device = device
|
||||
self.protocol = protocol
|
||||
self.hid = hid_handle
|
||||
self.hid_version = None
|
||||
|
||||
def get_path(self):
|
||||
protocol = get_protocol(hid_handle, is_trezor2(device))
|
||||
super().__init__(protocol=protocol)
|
||||
|
||||
def get_path(self) -> str:
|
||||
return "%s:%s" % (self.PATH_PREFIX, self.device["path"].decode())
|
||||
|
||||
@staticmethod
|
||||
def enumerate(debug=False):
|
||||
@classmethod
|
||||
def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]:
|
||||
devices = []
|
||||
for dev in hid.enumerate(0, 0):
|
||||
if not (is_trezor1(dev) or is_trezor2(dev) or is_trezor2_bl(dev)):
|
||||
@ -100,84 +122,35 @@ class HidTransport(Transport):
|
||||
devices.append(HidTransport(dev))
|
||||
return devices
|
||||
|
||||
def find_debug(self):
|
||||
if isinstance(self.protocol, ProtocolV2):
|
||||
# For v2 protocol, lets use the same HID interface, but with a different session
|
||||
protocol = ProtocolV2()
|
||||
debug = HidTransport(self.device, protocol, self.hid)
|
||||
return debug
|
||||
if isinstance(self.protocol, ProtocolV1):
|
||||
def find_debug(self) -> "HidTransport":
|
||||
if self.protocol.VERSION >= 2:
|
||||
# use the same device
|
||||
return self
|
||||
else:
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||
return debug
|
||||
raise TransportException("Debug HID device not found")
|
||||
|
||||
def open(self):
|
||||
self.hid.open()
|
||||
if is_trezor1(self.device):
|
||||
self.hid_version = self.probe_hid_version()
|
||||
else:
|
||||
self.hid_version = 2
|
||||
self.protocol.session_begin(self)
|
||||
|
||||
def close(self):
|
||||
self.protocol.session_end(self)
|
||||
self.hid.close()
|
||||
self.hid_version = None
|
||||
|
||||
def read(self):
|
||||
return self.protocol.read(self)
|
||||
|
||||
def write(self, msg):
|
||||
return self.protocol.write(self, msg)
|
||||
|
||||
def write_chunk(self, chunk):
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||
if self.hid_version == 2:
|
||||
self.hid.handle.write(b"\0" + bytearray(chunk))
|
||||
else:
|
||||
self.hid.handle.write(chunk)
|
||||
|
||||
def read_chunk(self):
|
||||
while True:
|
||||
chunk = self.hid.handle.read(64)
|
||||
if chunk:
|
||||
break
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||
return bytearray(chunk)
|
||||
|
||||
def probe_hid_version(self):
|
||||
n = self.hid.handle.write([0, 63] + [0xFF] * 63)
|
||||
if n == 65:
|
||||
return 2
|
||||
n = self.hid.handle.write([63] + [0xFF] * 63)
|
||||
if n == 64:
|
||||
return 1
|
||||
raise TransportException("Unknown HID version")
|
||||
raise TransportException("Debug HID device not found")
|
||||
|
||||
|
||||
def is_trezor1(dev):
|
||||
def is_trezor1(dev: HidDevice) -> bool:
|
||||
return (dev["vendor_id"], dev["product_id"]) == DEV_TREZOR1
|
||||
|
||||
|
||||
def is_trezor2(dev):
|
||||
def is_trezor2(dev: HidDevice) -> bool:
|
||||
return (dev["vendor_id"], dev["product_id"]) == DEV_TREZOR2
|
||||
|
||||
|
||||
def is_trezor2_bl(dev):
|
||||
def is_trezor2_bl(dev: HidDevice) -> bool:
|
||||
return (dev["vendor_id"], dev["product_id"]) == DEV_TREZOR2_BL
|
||||
|
||||
|
||||
def is_wirelink(dev):
|
||||
def is_wirelink(dev: HidDevice) -> bool:
|
||||
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0
|
||||
|
||||
|
||||
def is_debuglink(dev):
|
||||
def is_debuglink(dev: HidDevice) -> bool:
|
||||
return dev["usage_page"] == 0xFF01 or dev["interface_number"] == 1
|
||||
|
||||
|
||||
|
348
trezorlib/transport/protocol.py
Normal file
348
trezorlib/transport/protocol.py
Normal file
@ -0,0 +1,348 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2018 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
# as published by the Free Software Foundation.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
from io import BytesIO
|
||||
from typing import Tuple
|
||||
|
||||
from typing_extensions import Protocol as StructuralType
|
||||
|
||||
from . import Transport
|
||||
from .. import mapping, protobuf
|
||||
|
||||
REPLEN = 64
|
||||
|
||||
V2_FIRST_CHUNK = 0x01
|
||||
V2_NEXT_CHUNK = 0x02
|
||||
V2_BEGIN_SESSION = 0x03
|
||||
V2_END_SESSION = 0x04
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Handle(StructuralType):
|
||||
"""PEP 544 structural type for Handle functionality.
|
||||
(called a "Protocol" in the proposed PEP, name which is impractical here)
|
||||
|
||||
Handle is a "physical" layer for a protocol.
|
||||
It can open/close a connection and read/write bare data in 64-byte chunks.
|
||||
|
||||
Functionally we gain nothing from making this an (abstract) base class for handle
|
||||
implementations, so this definition is for type hinting purposes only. You can,
|
||||
but don't have to, inherit from it.
|
||||
"""
|
||||
|
||||
def open(self) -> None:
|
||||
...
|
||||
|
||||
def close(self) -> None:
|
||||
...
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
...
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
...
|
||||
|
||||
|
||||
class Protocol:
|
||||
"""Wire protocol that can communicate with a Trezor device, given a Handle.
|
||||
|
||||
A Protocol implements the part of the Transport API that relates to communicating
|
||||
logical messages over a physical layer. It is a thing that can:
|
||||
- open and close sessions,
|
||||
- send and receive protobuf messages,
|
||||
given the ability to:
|
||||
- open and close physical connections,
|
||||
- and send and receive binary chunks.
|
||||
|
||||
We declare a protocol version (we have implementations of v1 and v2).
|
||||
For now, the class also handles session counting and opening the underlying Handle.
|
||||
This will probably be removed in the future.
|
||||
|
||||
We will need a new Protocol class if we change the way a Trezor device encapsulates
|
||||
its messages.
|
||||
"""
|
||||
|
||||
VERSION = None # type: int
|
||||
|
||||
def __init__(self, handle: Handle) -> None:
|
||||
self.handle = handle
|
||||
self.session_counter = 0
|
||||
|
||||
def begin_session(self) -> None:
|
||||
if self.session_counter == 0:
|
||||
self.handle.open()
|
||||
self.session_counter += 1
|
||||
|
||||
def end_session(self) -> None:
|
||||
self.session_counter = max(self.session_counter - 1, 0)
|
||||
if self.session_counter == 0:
|
||||
self.handle.close()
|
||||
|
||||
def read(self) -> protobuf.MessageType:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, message: protobuf.MessageType) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ProtocolBasedTransport(Transport):
|
||||
"""Transport that implements its communications through a Protocol.
|
||||
|
||||
Intended as a base class for implementations that proxy their communication
|
||||
operations to a Protocol.
|
||||
"""
|
||||
|
||||
def __init__(self, protocol: Protocol) -> None:
|
||||
self.protocol = protocol
|
||||
|
||||
def write(self, message: protobuf.MessageType) -> None:
|
||||
self.protocol.write(message)
|
||||
|
||||
def read(self) -> protobuf.MessageType:
|
||||
return self.protocol.read()
|
||||
|
||||
def begin_session(self) -> None:
|
||||
self.protocol.begin_session()
|
||||
|
||||
def end_session(self) -> None:
|
||||
self.protocol.end_session()
|
||||
|
||||
|
||||
class ProtocolV1(Protocol):
|
||||
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
|
||||
Does not understand sessions.
|
||||
"""
|
||||
|
||||
VERSION = 1
|
||||
|
||||
def write(self, msg: protobuf.MessageType) -> None:
|
||||
LOG.debug(
|
||||
"sending message: {}".format(msg.__class__.__name__),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
data = BytesIO()
|
||||
protobuf.dump_message(data, msg)
|
||||
ser = data.getvalue()
|
||||
header = struct.pack(">HL", mapping.get_type(msg), len(ser))
|
||||
buffer = bytearray(b"##" + header + ser)
|
||||
|
||||
while buffer:
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b"?" + buffer[: REPLEN - 1]
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
self.handle.write_chunk(chunk)
|
||||
buffer = buffer[63:]
|
||||
|
||||
def read(self) -> protobuf.MessageType:
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first()
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next())
|
||||
|
||||
# Strip padding
|
||||
data = BytesIO(buffer[:datalen])
|
||||
|
||||
# Parse to protobuf
|
||||
msg = protobuf.load_message(data, mapping.get_class(msg_type))
|
||||
LOG.debug(
|
||||
"received message: {}".format(msg.__class__.__name__),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
return msg
|
||||
|
||||
def read_first(self) -> Tuple[int, int, bytes]:
|
||||
chunk = self.handle.read_chunk()
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
try:
|
||||
headerlen = struct.calcsize(">HL")
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + headerlen])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
|
||||
data = chunk[3 + headerlen :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self) -> bytes:
|
||||
chunk = self.handle.read_chunk()
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
return chunk[1:]
|
||||
|
||||
|
||||
class ProtocolV2(Protocol):
|
||||
"""Protocol version 2. Currently (11/2018) not used.
|
||||
Intended to mimic U2F/WebAuthN session handling.
|
||||
"""
|
||||
|
||||
VERSION = 2
|
||||
|
||||
def __init__(self, handle: Handle) -> None:
|
||||
self.session = None
|
||||
super().__init__(handle)
|
||||
|
||||
def begin_session(self) -> None:
|
||||
# ensure open connection
|
||||
super().begin_session()
|
||||
|
||||
# initiate session
|
||||
chunk = struct.pack(">B", V2_BEGIN_SESSION)
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
self.handle.write_chunk(chunk)
|
||||
|
||||
# get session identifier
|
||||
resp = self.handle.read_chunk()
|
||||
try:
|
||||
headerlen = struct.calcsize(">BL")
|
||||
magic, session = struct.unpack(">BL", resp[:headerlen])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
if magic != V2_BEGIN_SESSION:
|
||||
raise RuntimeError("Unexpected magic character")
|
||||
self.session = session
|
||||
|
||||
LOG.debug("[session {}] session started".format(self.session))
|
||||
|
||||
def end_session(self) -> None:
|
||||
if not self.session:
|
||||
return
|
||||
|
||||
try:
|
||||
chunk = struct.pack(">BL", V2_END_SESSION, self.session)
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
self.handle.write_chunk(chunk)
|
||||
resp = self.handle.read_chunk()
|
||||
(magic,) = struct.unpack(">B", resp[:1])
|
||||
if magic != V2_END_SESSION:
|
||||
raise RuntimeError("Expected session close")
|
||||
LOG.debug("[session {}] session ended".format(self.session))
|
||||
finally:
|
||||
self.session = None
|
||||
# close connection if appropriate
|
||||
super().end_session()
|
||||
|
||||
def write(self, msg: protobuf.MessageType) -> None:
|
||||
if not self.session:
|
||||
raise RuntimeError("Missing session for v2 protocol")
|
||||
|
||||
LOG.debug(
|
||||
"[session {}] sending message: {}".format(
|
||||
self.session, msg.__class__.__name__
|
||||
),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
# Serialize whole message
|
||||
data = BytesIO()
|
||||
protobuf.dump_message(data, msg)
|
||||
data = data.getvalue()
|
||||
dataheader = struct.pack(">LL", mapping.get_type(msg), len(data))
|
||||
data = dataheader + data
|
||||
seq = -1
|
||||
|
||||
# Write it out
|
||||
while data:
|
||||
if seq < 0:
|
||||
repheader = struct.pack(">BL", V2_FIRST_CHUNK, self.session)
|
||||
else:
|
||||
repheader = struct.pack(">BLL", V2_NEXT_CHUNK, self.session, seq)
|
||||
datalen = REPLEN - len(repheader)
|
||||
chunk = repheader + data[:datalen]
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
self.handle.write_chunk(chunk)
|
||||
data = data[datalen:]
|
||||
seq += 1
|
||||
|
||||
def read(self) -> protobuf.MessageType:
|
||||
if not self.session:
|
||||
raise RuntimeError("Missing session for v2 protocol")
|
||||
|
||||
buffer = bytearray()
|
||||
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, chunk = self.read_first()
|
||||
buffer.extend(chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
next_chunk = self.read_next()
|
||||
buffer.extend(next_chunk)
|
||||
|
||||
# Strip padding
|
||||
buffer = BytesIO(buffer[:datalen])
|
||||
|
||||
# Parse to protobuf
|
||||
msg = protobuf.load_message(buffer, mapping.get_class(msg_type))
|
||||
LOG.debug(
|
||||
"[session {}] received message: {}".format(
|
||||
self.session, msg.__class__.__name__
|
||||
),
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
return msg
|
||||
|
||||
def read_first(self) -> Tuple[int, int, bytes]:
|
||||
chunk = self.handle.read_chunk()
|
||||
try:
|
||||
headerlen = struct.calcsize(">BLLL")
|
||||
magic, session, msg_type, datalen = struct.unpack(
|
||||
">BLLL", chunk[:headerlen]
|
||||
)
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
if magic != V2_FIRST_CHUNK:
|
||||
raise RuntimeError("Unexpected magic character")
|
||||
if session != self.session:
|
||||
raise RuntimeError("Session id mismatch")
|
||||
return msg_type, datalen, chunk[headerlen:]
|
||||
|
||||
def read_next(self) -> bytes:
|
||||
chunk = self.handle.read_chunk()
|
||||
try:
|
||||
headerlen = struct.calcsize(">BLL")
|
||||
magic, session, sequence = struct.unpack(">BLL", chunk[:headerlen])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
if magic != V2_NEXT_CHUNK:
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
if session != self.session:
|
||||
raise RuntimeError("Session id mismatch")
|
||||
return chunk[headerlen:]
|
||||
|
||||
|
||||
def get_protocol(handle: Handle, want_v2: bool) -> Protocol:
|
||||
"""Make a Protocol instance for the given handle.
|
||||
|
||||
Each transport can have a preference for using a particular protocol version.
|
||||
This preference is overridable through `TREZOR_PROTOCOL_V1` environment variable,
|
||||
which forces the library to use V1 anyways.
|
||||
|
||||
As of 11/2018, no devices support V2, so we enforce V1 here. It is still possible
|
||||
to set `TREZOR_PROTOCOL_V1=0` and thus enable V2 protocol for transports that ask
|
||||
for it (i.e., USB transports for Trezor T).
|
||||
"""
|
||||
force_v1 = int(os.environ.get("TREZOR_PROTOCOL_V1", 1))
|
||||
if want_v2 and not force_v1:
|
||||
return ProtocolV2(handle)
|
||||
else:
|
||||
return ProtocolV1(handle)
|
@ -15,20 +15,19 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import socket
|
||||
from typing import Iterable, cast
|
||||
|
||||
from . import Transport, TransportException
|
||||
from ..protocol_v1 import ProtocolV1
|
||||
from . import TransportException
|
||||
from .protocol import ProtocolBasedTransport, get_protocol
|
||||
|
||||
|
||||
class UdpTransport(Transport):
|
||||
class UdpTransport(ProtocolBasedTransport):
|
||||
|
||||
DEFAULT_HOST = "127.0.0.1"
|
||||
DEFAULT_PORT = 21324
|
||||
PATH_PREFIX = "udp"
|
||||
|
||||
def __init__(self, device=None, protocol=None):
|
||||
super(UdpTransport, self).__init__()
|
||||
|
||||
def __init__(self, device: str = None) -> None:
|
||||
if not device:
|
||||
host = UdpTransport.DEFAULT_HOST
|
||||
port = UdpTransport.DEFAULT_PORT
|
||||
@ -36,21 +35,21 @@ class UdpTransport(Transport):
|
||||
devparts = device.split(":")
|
||||
host = devparts[0]
|
||||
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
|
||||
if not protocol:
|
||||
protocol = ProtocolV1()
|
||||
self.device = (host, port)
|
||||
self.protocol = protocol
|
||||
self.socket = None
|
||||
self.socket = None # type: Optional[socket.socket]
|
||||
|
||||
def get_path(self):
|
||||
return "%s:%s:%s" % ((self.PATH_PREFIX,) + self.device)
|
||||
protocol = get_protocol(self, want_v2=False)
|
||||
super().__init__(protocol=protocol)
|
||||
|
||||
def find_debug(self):
|
||||
def get_path(self) -> str:
|
||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||
|
||||
def find_debug(self) -> "UdpTransport":
|
||||
host, port = self.device
|
||||
return UdpTransport("{}:{}".format(host, port + 1), self.protocol)
|
||||
return UdpTransport("{}:{}".format(host, port + 1))
|
||||
|
||||
@classmethod
|
||||
def _try_path(cls, path):
|
||||
def _try_path(cls, path: str) -> "UdpTransport":
|
||||
d = cls(path)
|
||||
try:
|
||||
d.open()
|
||||
@ -64,7 +63,7 @@ class UdpTransport(Transport):
|
||||
d.close()
|
||||
|
||||
@classmethod
|
||||
def enumerate(cls):
|
||||
def enumerate(cls) -> Iterable["UdpTransport"]:
|
||||
default_path = "{}:{}".format(cls.DEFAULT_HOST, cls.DEFAULT_PORT)
|
||||
try:
|
||||
return [cls._try_path(default_path)]
|
||||
@ -72,27 +71,29 @@ class UdpTransport(Transport):
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls, path, prefix_search=False):
|
||||
def find_by_path(cls, path: str, prefix_search: bool = False) -> "UdpTransport":
|
||||
if prefix_search:
|
||||
return super().find_by_path(path, prefix_search)
|
||||
return cast(UdpTransport, super().find_by_path(path, prefix_search))
|
||||
# This is *technically* type-able: mark `find_by_path` as returning
|
||||
# the same type from which `cls` comes from.
|
||||
# Mypy can't handle that though, so here we are.
|
||||
else:
|
||||
path = path.replace("{}:".format(cls.PATH_PREFIX), "")
|
||||
return cls._try_path(path)
|
||||
|
||||
def open(self):
|
||||
def open(self) -> None:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.socket.connect(self.device)
|
||||
self.socket.settimeout(10)
|
||||
self.protocol.session_begin(self)
|
||||
|
||||
def close(self):
|
||||
if self.socket:
|
||||
self.protocol.session_end(self)
|
||||
def close(self) -> None:
|
||||
if self.socket is not None:
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
self.socket = None
|
||||
|
||||
def _ping(self):
|
||||
def _ping(self) -> bool:
|
||||
"""Test if the device is listening."""
|
||||
assert self.socket is not None
|
||||
resp = None
|
||||
try:
|
||||
self.socket.sendall(b"PINGPING")
|
||||
@ -101,18 +102,14 @@ class UdpTransport(Transport):
|
||||
pass
|
||||
return resp == b"PONGPONG"
|
||||
|
||||
def read(self):
|
||||
return self.protocol.read(self)
|
||||
|
||||
def write(self, msg):
|
||||
return self.protocol.write(self, msg)
|
||||
|
||||
def write_chunk(self, chunk):
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
assert self.socket is not None
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected data length")
|
||||
self.socket.sendall(chunk)
|
||||
|
||||
def read_chunk(self):
|
||||
def read_chunk(self) -> bytes:
|
||||
assert self.socket is not None
|
||||
while True:
|
||||
try:
|
||||
chunk = self.socket.recv(64)
|
||||
|
@ -17,12 +17,12 @@
|
||||
import atexit
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable
|
||||
|
||||
import usb1
|
||||
|
||||
from . import Transport, TransportException
|
||||
from ..protocol_v1 import ProtocolV1
|
||||
from ..protocol_v2 import ProtocolV2
|
||||
from . import TransportException
|
||||
from .protocol import ProtocolBasedTransport, get_protocol
|
||||
|
||||
DEV_TREZOR1 = (0x534C, 0x0001)
|
||||
DEV_TREZOR2 = (0x1209, 0x53C1)
|
||||
@ -35,34 +35,52 @@ DEBUG_ENDPOINT = 2
|
||||
|
||||
|
||||
class WebUsbHandle:
|
||||
def __init__(self, device):
|
||||
def __init__(self, device: usb1.USBDevice, debug: bool = False) -> None:
|
||||
self.device = device
|
||||
self.interface = DEBUG_INTERFACE if debug else INTERFACE
|
||||
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
|
||||
self.count = 0
|
||||
self.handle = None # type: Optional[usb1.USBDeviceHandle]
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = self.device.open()
|
||||
if self.handle is None:
|
||||
if sys.platform.startswith("linux"):
|
||||
args = (
|
||||
"Do you have udev rules installed? https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules",
|
||||
)
|
||||
else:
|
||||
args = ()
|
||||
raise IOError("Cannot open device", *args)
|
||||
self.handle.claimInterface(self.interface)
|
||||
|
||||
def close(self) -> None:
|
||||
if self.handle is not None:
|
||||
self.handle.releaseInterface(self.interface)
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
|
||||
def open(self, interface):
|
||||
if self.count == 0:
|
||||
self.handle = self.device.open()
|
||||
if self.handle is None:
|
||||
if sys.platform.startswith("linux"):
|
||||
args = (
|
||||
"Do you have udev rules installed? https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules",
|
||||
)
|
||||
else:
|
||||
args = ()
|
||||
raise IOError("Cannot open device", *args)
|
||||
self.handle.claimInterface(interface)
|
||||
self.count += 1
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
assert self.handle is not None
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||
self.handle.interruptWrite(self.endpoint, chunk)
|
||||
|
||||
def close(self, interface):
|
||||
if self.count == 1:
|
||||
self.handle.releaseInterface(interface)
|
||||
self.handle.close()
|
||||
if self.count > 0:
|
||||
self.count -= 1
|
||||
def read_chunk(self) -> bytes:
|
||||
assert self.handle is not None
|
||||
endpoint = 0x80 | self.endpoint
|
||||
while True:
|
||||
chunk = self.handle.interruptRead(endpoint, 64)
|
||||
if chunk:
|
||||
break
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||
return chunk
|
||||
|
||||
|
||||
class WebUsbTransport(Transport):
|
||||
class WebUsbTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
WebUsbTransport implements transport over WebUSB interface.
|
||||
"""
|
||||
@ -70,31 +88,24 @@ class WebUsbTransport(Transport):
|
||||
PATH_PREFIX = "webusb"
|
||||
context = None
|
||||
|
||||
def __init__(self, device, protocol=None, handle=None, debug=False):
|
||||
super(WebUsbTransport, self).__init__()
|
||||
|
||||
def __init__(
|
||||
self, device: str, handle: WebUsbHandle = None, debug: bool = False
|
||||
) -> None:
|
||||
if handle is None:
|
||||
handle = WebUsbHandle(device)
|
||||
|
||||
if protocol is None:
|
||||
# force_v1 = os.environ.get('TREZOR_TRANSPORT_V1', '0')
|
||||
force_v1 = True
|
||||
|
||||
if is_trezor2(device) and not int(force_v1):
|
||||
protocol = ProtocolV2()
|
||||
else:
|
||||
protocol = ProtocolV1()
|
||||
handle = WebUsbHandle(device, debug)
|
||||
|
||||
self.device = device
|
||||
self.protocol = protocol
|
||||
self.handle = handle
|
||||
self.debug = debug
|
||||
|
||||
def get_path(self):
|
||||
protocol = get_protocol(handle, is_trezor2(device))
|
||||
super().__init__(protocol=protocol)
|
||||
|
||||
def get_path(self) -> str:
|
||||
return "%s:%s" % (self.PATH_PREFIX, dev_to_str(self.device))
|
||||
|
||||
@classmethod
|
||||
def enumerate(cls):
|
||||
def enumerate(cls) -> Iterable["WebUsbTransport"]:
|
||||
if cls.context is None:
|
||||
cls.context = usb1.USBContext()
|
||||
cls.context.open()
|
||||
@ -117,69 +128,30 @@ class WebUsbTransport(Transport):
|
||||
pass
|
||||
return devices
|
||||
|
||||
def find_debug(self):
|
||||
if isinstance(self.protocol, ProtocolV2):
|
||||
def find_debug(self) -> "WebUsbTransport":
|
||||
if self.protocol.VERSION >= 2:
|
||||
# TODO test this
|
||||
# For v2 protocol, lets use the same WebUSB interface, but with a different session
|
||||
protocol = ProtocolV2()
|
||||
debug = WebUsbTransport(self.device, protocol, self.handle)
|
||||
return debug
|
||||
if isinstance(self.protocol, ProtocolV1):
|
||||
# XXX this is broken right now because sessions don't really work
|
||||
# For v2 protocol, use the same WebUSB interface with a different session
|
||||
return WebUsbTransport(self.device, self.handle)
|
||||
else:
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
protocol = ProtocolV1()
|
||||
debug = WebUsbTransport(self.device, protocol, None, True)
|
||||
return debug
|
||||
raise TransportException("Debug WebUSB device not found")
|
||||
|
||||
def open(self):
|
||||
interface = DEBUG_INTERFACE if self.debug else INTERFACE
|
||||
self.handle.open(interface)
|
||||
self.protocol.session_begin(self)
|
||||
|
||||
def close(self):
|
||||
interface = DEBUG_INTERFACE if self.debug else INTERFACE
|
||||
self.protocol.session_end(self)
|
||||
self.handle.close(interface)
|
||||
|
||||
def read(self):
|
||||
return self.protocol.read(self)
|
||||
|
||||
def write(self, msg):
|
||||
return self.protocol.write(self, msg)
|
||||
|
||||
def write_chunk(self, chunk):
|
||||
endpoint = DEBUG_ENDPOINT if self.debug else ENDPOINT
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||
self.handle.handle.interruptWrite(endpoint, chunk)
|
||||
|
||||
def read_chunk(self):
|
||||
endpoint = DEBUG_ENDPOINT if self.debug else ENDPOINT
|
||||
endpoint = 0x80 | endpoint
|
||||
while True:
|
||||
chunk = self.handle.handle.interruptRead(endpoint, 64)
|
||||
if chunk:
|
||||
break
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||
return bytearray(chunk)
|
||||
return WebUsbTransport(self.device, debug=True)
|
||||
|
||||
|
||||
def is_trezor1(dev):
|
||||
def is_trezor1(dev: usb1.USBDevice) -> bool:
|
||||
return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR1
|
||||
|
||||
|
||||
def is_trezor2(dev):
|
||||
def is_trezor2(dev: usb1.USBDevice) -> bool:
|
||||
return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR2
|
||||
|
||||
|
||||
def is_trezor2_bl(dev):
|
||||
def is_trezor2_bl(dev: usb1.USBDevice) -> bool:
|
||||
return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR2_BL
|
||||
|
||||
|
||||
def is_vendor_class(dev):
|
||||
def is_vendor_class(dev: usb1.USBDevice) -> bool:
|
||||
configurationId = 0
|
||||
altSettingId = 0
|
||||
return (
|
||||
@ -188,7 +160,7 @@ def is_vendor_class(dev):
|
||||
)
|
||||
|
||||
|
||||
def dev_to_str(dev):
|
||||
def dev_to_str(dev: usb1.USBDevice) -> str:
|
||||
return ":".join(
|
||||
str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList()
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user