1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-22 23:48:12 +00:00

Support for HID debug_link

This commit is contained in:
slush0 2013-11-15 01:43:05 +01:00
parent 039bcee3f2
commit 12afba8385
2 changed files with 38 additions and 20 deletions

38
cmd.py
View File

@ -14,8 +14,8 @@ def parse_args(commands):
parser = argparse.ArgumentParser(description='Commandline tool for Trezor devices.') parser = argparse.ArgumentParser(description='Commandline tool for Trezor devices.')
parser.add_argument('-t', '--transport', dest='transport', choices=['usb', 'serial', 'pipe', 'socket'], default='usb', help="Transport used for talking with the device") parser.add_argument('-t', '--transport', dest='transport', choices=['usb', 'serial', 'pipe', 'socket'], default='usb', help="Transport used for talking with the device")
parser.add_argument('-p', '--path', dest='path', default='', help="Path used by the transport (usually serial port)") parser.add_argument('-p', '--path', dest='path', default='', help="Path used by the transport (usually serial port)")
parser.add_argument('-dt', '--debuglink-transport', dest='debuglink_transport', choices=['usb', 'serial', 'pipe', 'socket'], default='socket', help="Debuglink transport") parser.add_argument('-dt', '--debuglink-transport', dest='debuglink_transport', choices=['usb', 'serial', 'pipe', 'socket'], default='usb', help="Debuglink transport")
parser.add_argument('-dp', '--debuglink-path', dest='debuglink_path', default='127.0.0.1:2000', help="Path used by the transport (usually serial port)") parser.add_argument('-dp', '--debuglink-path', dest='debuglink_path', default='', help="Path used by the transport (usually serial port)")
parser.add_argument('-j', '--json', dest='json', action='store_true', help="Prints result as json object") parser.add_argument('-j', '--json', dest='json', action='store_true', help="Prints result as json object")
parser.add_argument('-d', '--debug', dest='debug', action='store_true', help='Enable low-level debugging') parser.add_argument('-d', '--debug', dest='debug', action='store_true', help='Enable low-level debugging')
@ -42,7 +42,7 @@ def parse_args(commands):
return parser.parse_args() return parser.parse_args()
def get_transport(transport_string, path): def get_transport(transport_string, path, **kwargs):
if transport_string == 'usb': if transport_string == 'usb':
from trezorlib.transport_hid import HidTransport from trezorlib.transport_hid import HidTransport
@ -52,23 +52,23 @@ def get_transport(transport_string, path):
except IndexError: except IndexError:
raise Exception("No Trezor found on USB") raise Exception("No Trezor found on USB")
return HidTransport(path) return HidTransport(path, **kwargs)
if transport_string == 'serial': if transport_string == 'serial':
from trezorlib.transport_serial import SerialTransport from trezorlib.transport_serial import SerialTransport
return SerialTransport(path) return SerialTransport(path, **kwargs)
if transport_string == 'pipe': if transport_string == 'pipe':
from trezorlib.transport_pipe import PipeTransport from trezorlib.transport_pipe import PipeTransport
return PipeTransport(path, is_device=False) return PipeTransport(path, is_device=False, **kwargs)
if transport_string == 'socket': if transport_string == 'socket':
from trezorlib.transport_socket import SocketTransportClient from trezorlib.transport_socket import SocketTransportClient
return SocketTransportClient(path) return SocketTransportClient(path, **kwargs)
if transport_string == 'fake': if transport_string == 'fake':
from trezorlib.transport_fake import FakeTransport from trezorlib.transport_fake import FakeTransport
return FakeTransport(path) return FakeTransport(path, **kwargs)
raise NotImplemented("Unknown transport") raise NotImplemented("Unknown transport")
@ -96,8 +96,8 @@ class Commands(object):
def ping(self, args): def ping(self, args):
return self.client.ping(args.msg) return self.client.ping(args.msg)
def get_master_public_key(self, args): def get_public_node(self, args):
return self.client.get_master_public_key() return self.client.get_public_node(args.n)
def get_serial_number(self, args): def get_serial_number(self, args):
return binascii.hexlify(self.client.get_serial_number()) return binascii.hexlify(self.client.get_serial_number())
@ -113,6 +113,9 @@ class Commands(object):
return self.client.load_device(seed, args.pin) return self.client.load_device(seed, args.pin)
def sign_message(self, args):
return self.client.sign_message(args.n, args.message)
def firmware_update(self, args): def firmware_update(self, args):
if not args.file: if not args.file:
raise Exception("Must provide firmware filename") raise Exception("Must provide firmware filename")
@ -129,10 +132,11 @@ class Commands(object):
get_entropy.help = 'Get example entropy' get_entropy.help = 'Get example entropy'
get_features.help = 'Retrieve device features and settings' get_features.help = 'Retrieve device features and settings'
get_serial_number.help = 'Get device\'s unique identifier' get_serial_number.help = 'Get device\'s unique identifier'
get_master_public_key.help = 'Get master public key' get_public_node.help = 'Get public node of given path'
set_label.help = 'Set new wallet label' set_label.help = 'Set new wallet label'
set_coin.help = 'Switch device to another crypto currency' set_coin.help = 'Switch device to another crypto currency'
load_device.help = 'Load custom configuration to the device' load_device.help = 'Load custom configuration to the device'
sign_message.help = 'Sign message using address of given path'
firmware_update.help = 'Upload new firmware to device (must be in bootloader mode)' firmware_update.help = 'Upload new firmware to device (must be in bootloader mode)'
get_address.arguments = ( get_address.arguments = (
@ -162,6 +166,15 @@ class Commands(object):
(('-n', '--pin'), {'type': str, 'default': ''}), (('-n', '--pin'), {'type': str, 'default': ''}),
) )
sign_message.arguments = (
(('n',), {'metavar': 'N', 'type': int, 'nargs': '+'}),
(('message',), {'type': str}),
)
get_public_node.arguments = (
(('n',), {'metavar': 'N', 'type': int, 'nargs': '+'}),
)
firmware_update.arguments = ( firmware_update.arguments = (
(('-f', '--file'), {'type': str}), (('-f', '--file'), {'type': str}),
) )
@ -239,6 +252,9 @@ def main():
transport = get_transport(args.transport, args.path) transport = get_transport(args.transport, args.path)
if args.debug: if args.debug:
if args.debuglink_transport == 'usb' and args.debuglink_path == '':
debuglink_transport = get_transport('usb', args.path, debug_link=True)
else:
debuglink_transport = get_transport(args.debuglink_transport, args.debuglink_path) debuglink_transport = get_transport(args.debuglink_transport, args.debuglink_path)
debuglink = DebugLink(debuglink_transport) debuglink = DebugLink(debuglink_transport)
else: else:

View File

@ -5,7 +5,7 @@ import time
from transport import Transport, NotImplementedException from transport import Transport, NotImplementedException
DEVICE_IDS = [ DEVICE_IDS = [
(0x10c4, 0xea80), # Trezor Pi (0x10c4, 0xea80), # Shield
(0x534c, 0x0001), # Trezor (0x534c, 0x0001), # Trezor
] ]
@ -21,6 +21,8 @@ class HidTransport(Transport):
def __init__(self, device, *args, **kwargs): def __init__(self, device, *args, **kwargs):
self.hid = None self.hid = None
self.buffer = '' self.buffer = ''
if bool(kwargs.get('debug_link')):
device = device[:-2] + '01'
super(HidTransport, self).__init__(device, *args, **kwargs) super(HidTransport, self).__init__(device, *args, **kwargs)
@classmethod @classmethod
@ -29,18 +31,18 @@ class HidTransport(Transport):
for d in hid.enumerate(0, 0): for d in hid.enumerate(0, 0):
vendor_id = d.get('vendor_id') vendor_id = d.get('vendor_id')
product_id = d.get('product_id') product_id = d.get('product_id')
serial_number = d.get('serial_number') path = d.get('path')
if (vendor_id, product_id) in DEVICE_IDS: if (vendor_id, product_id) in DEVICE_IDS and path.endswith(':00'):
devices.append("0x%04x:0x%04x:%s" % (vendor_id, product_id, serial_number)) devices.append(path)
return devices return devices
def _open(self): def _open(self):
self.buffer = '' self.buffer = ''
path = self.device.split(':') print self.device
self.hid = hid.device() self.hid = hid.device()
self.hid.open(int(path[0], 16), int(path[1], 16)) self.hid.open_path(self.device)
self.hid.set_nonblocking(True) self.hid.set_nonblocking(True)
self.hid.send_feature_report([0x41, 0x01]) # enable UART self.hid.send_feature_report([0x41, 0x01]) # enable UART
self.hid.send_feature_report([0x43, 0x03]) # purge TX/RX FIFOs self.hid.send_feature_report([0x43, 0x03]) # purge TX/RX FIFOs