diff --git a/requirements.txt b/requirements.txt index 80ea723fb..9af85f089 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ hidapi>=0.7.99.post20 requests>=2.4.0 click>=6.2 pyblake2>=0.9.3 +libusb1>=1.6.4 diff --git a/trezorctl b/trezorctl index 527677a82..01693d2c2 100755 --- a/trezorctl +++ b/trezorctl @@ -68,6 +68,10 @@ def get_transport_class_by_name(name): from trezorlib.transport_hid import HidTransport return HidTransport + if name == 'webusb': + from trezorlib.transport_webusb import WebUsbTransport + return WebUsbTransport + if name == 'udp': from trezorlib.transport_udp import UdpTransport return UdpTransport @@ -90,7 +94,7 @@ def get_transport(transport_name, path): @click.group() -@click.option('-t', '--transport', type=click.Choice(['usb', 'udp', 'pipe', 'bridge']), default='usb', help='Select transport used for communication.') +@click.option('-t', '--transport', type=click.Choice(['usb', 'webusb', 'udp', 'pipe', 'bridge']), default='usb', help='Select transport used for communication.') @click.option('-p', '--path', help='Select device by transport-specific path.') @click.option('-v', '--verbose', is_flag=True, help='Show communication messages.') @click.option('-j', '--json', 'is_json', is_flag=True, help='Print result as JSON object') diff --git a/trezorlib/tests/device_tests/common.py b/trezorlib/tests/device_tests/common.py index ac835dd72..c709e6822 100644 --- a/trezorlib/tests/device_tests/common.py +++ b/trezorlib/tests/device_tests/common.py @@ -36,6 +36,13 @@ except ImportError as e: print('HID transport disabled:', e) HID_ENABLED = False +try: + from trezorlib.transport_webusb import WebUsbTransport + WEBUSB_ENABLED = True +except ImportError as e: + print('WebUsb transport disabled:', e) + WEBUSB_ENABLED = False + try: from trezorlib.transport_pipe import PipeTransport PIPE_ENABLED = True @@ -66,6 +73,11 @@ def get_transport(): wirelink = devices[0] debuglink = devices[0].find_debug() + elif WEBUSB_ENABLED and WebUsbTransport.enumerate(): + devices = WebUsbTransport.enumerate() + wirelink = devices[0] + debuglink = devices[0].find_debug() + elif PIPE_ENABLED and pipe_exists('/tmp/pipe.trezor.to'): wirelink = PipeTransport('/tmp/pipe.trezor', False) debuglink = PipeTransport('/tmp/pipe.trezor_debug', False) @@ -79,6 +91,8 @@ def get_transport(): if HID_ENABLED and HidTransport.enumerate(): print('Using TREZOR') +elif WEBUSB_ENABLED and WebUsbTransport.enumerate(): + print('Using TREZOR via WebUSB') elif PIPE_ENABLED and pipe_exists('/tmp/pipe.trezor.to'): print('Using Emulator (v1=pipe)') elif UDP_ENABLED: diff --git a/trezorlib/transport_webusb.py b/trezorlib/transport_webusb.py new file mode 100644 index 000000000..f2a6de834 --- /dev/null +++ b/trezorlib/transport_webusb.py @@ -0,0 +1,177 @@ +# This file is part of the TREZOR project. +# +# Copyright (C) 2012-2016 Marek Palatinus +# Copyright (C) 2012-2016 Pavol Rusnak +# +# This library is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library. If not, see . + +from __future__ import absolute_import + +import time +import usb1 +import os + +from .protocol_v1 import ProtocolV1 +from .protocol_v2 import ProtocolV2 +from .transport import Transport, TransportException + +DEV_TREZOR1 = (0x534c, 0x0001) +DEV_TREZOR2 = (0x1209, 0x53c1) +DEV_TREZOR2_BL = (0x1209, 0x53c0) + +INTERFACE = 0 +ENDPOINT = 1 +DEBUG_INTERFACE = 1 +DEBUG_ENDPOINT = 2 + +context = usb1.USBContext() +context.open() + +import atexit + +def exit_handler(): + context.close() + +atexit.register(exit_handler) + +class WebUsbHandle(object): + + def __init__(self, device): + self.device = device + self.count = 0 + self.handle = None + + def open(self, interface): + if self.count == 0: + self.handle = self.device.open() + if self.handle is None: + raise Exception('Cannot open device') + self.handle.claimInterface(interface) + self.count += 1 + + def close(self, interface): + if self.count == 1: + self.handle.releaseInterface(interface) + self.handle.close() + if self.count > 0: + self.count -= 1 + + +class WebUsbTransport(Transport): + ''' + HidTransport implements transport over USB HID interface. + ''' + + def __init__(self, device, protocol=None, handle=None, debug=False): + super(WebUsbTransport, self).__init__() + + if handle is None: + handle = WebUsbHandle(device) + + if protocol is None: + force_v1 = os.environ.get('TREZOR_TRANSPORT_V1', '0') + + if is_trezor2(device) and not int(force_v1): + protocol = ProtocolV2() + else: + protocol = ProtocolV1() + + self.device = device + self.protocol = protocol + self.handle = handle + self.debug = debug + + def __str__(self): + return dev_to_str(self.device) + + @staticmethod + def enumerate(): + devices = [] + for dev in context.getDeviceIterator(skip_on_error=True): + if not (is_trezor1(dev) or is_trezor2(dev) or is_trezor2_bl(dev)): + continue + devices.append(WebUsbTransport(dev)) + return devices + + @staticmethod + def find_by_path(path=None): + for transport in WebUsbTransport.enumerate(): + if path is None or dev_to_str(transport.device) == path: + return transport + raise TransportException('HID device not found') + + def find_debug(self): + if isinstance(self.protocol, ProtocolV2): + # TODO test this + # For v2 protocol, lets use the same HID interface, but with a different session + protocol = ProtocolV2() + debug = WebUsbTransport(self.device, protocol, self.handle) + return debug + if isinstance(self.protocol, ProtocolV1): + # For v1 protocol, find debug USB interface for the same serial number + protocol = ProtocolV1() + debug = WebUsbTransport(self.device, protocol, None, True) + return debug + raise TransportException('Debug HID device not found') + + def open(self): + interface = DEBUG_INTERFACE if self.debug else INTERFACE + self.handle.open(interface) + self.protocol.session_begin(self) + + def close(self): + interface = DEBUG_INTERFACE if self.debug else INTERFACE + self.protocol.session_end(self) + self.handle.close(interface) + + def read(self): + return self.protocol.read(self) + + def write(self, msg): + return self.protocol.write(self, msg) + + def write_chunk(self, chunk): + endpoint = DEBUG_ENDPOINT if self.debug else ENDPOINT + if len(chunk) != 64: + raise TransportException('Unexpected chunk size: %d' % len(chunk)) + self.handle.handle.interruptWrite(endpoint, chunk) + + def read_chunk(self): + endpoint = DEBUG_ENDPOINT if self.debug else ENDPOINT + endpoint = 0x80 | endpoint + while True: + chunk = self.handle.handle.interruptRead(endpoint, 64) + if chunk: + break + else: + time.sleep(0.001) + if len(chunk) != 64: + raise TransportException('Unexpected chunk size: %d' % len(chunk)) + return bytearray(chunk) + + +def is_trezor1(dev): + return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR1 + + +def is_trezor2(dev): + return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR2 + + +def is_trezor2_bl(dev): + return (dev.getVendorID(), dev.getProductID()) == DEV_TREZOR2_BL + + +def dev_to_str(dev): + return ':'.join(str(x) for x in ['%03i' % (dev.getBusNumber(), )] + dev.getPortNumberList())