From 6475f98b1ec8d2f1bd9740d516014d9c4dbf5a6b Mon Sep 17 00:00:00 2001 From: Chris Rico Date: Fri, 4 Sep 2015 13:41:54 -0500 Subject: [PATCH] Allow firmware update by version or latest from releases.json --- trezorctl | 65 ++++++++++++++++++++++++++++++++----------------------- 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/trezorctl b/trezorctl index 55b389efad..dae0c20122 100755 --- a/trezorctl +++ b/trezorctl @@ -20,22 +20,22 @@ def parse_args(commands): # parser.add_argument('-d', '--debug', dest='debug', action='store_true', help='Enable low-level debugging') cmdparser = parser.add_subparsers(title='Available commands') - + for cmd in commands._list_commands(): func = object.__getattribute__(commands, cmd) - + try: arguments = func.arguments except AttributeError: arguments = ((('params',), {'nargs': '*'}),) - + item = cmdparser.add_parser(cmd, help=func.help) for arg in arguments: item.add_argument(*arg[0], **arg[1]) - + item.set_defaults(func=func) item.set_defaults(cmd=cmd) - + return parser.parse_args() def get_transport(transport_string, path, **kwargs): @@ -54,7 +54,7 @@ def get_transport(transport_string, path, **kwargs): return HidTransport(d, **kwargs) raise Exception("Device not found") - + if transport_string == 'serial': from trezorlib.transport_serial import SerialTransport return SerialTransport(path, **kwargs) @@ -62,7 +62,7 @@ def get_transport(transport_string, path, **kwargs): if transport_string == 'pipe': from trezorlib.transport_pipe import PipeTransport return PipeTransport(path, is_device=False, **kwargs) - + if transport_string == 'socket': from trezorlib.transport_socket import SocketTransportClient return SocketTransportClient(path, **kwargs) @@ -70,29 +70,29 @@ def get_transport(transport_string, path, **kwargs): if transport_string == 'bridge': from trezorlib.transport_bridge import BridgeTransport return BridgeTransport(path, **kwargs) - + if transport_string == 'fake': from trezorlib.transport_fake import FakeTransport return FakeTransport(path, **kwargs) - + raise NotImplemented("Unknown transport") class Commands(object): def __init__(self, client): self.client = client - + @classmethod def _list_commands(cls): return [ x for x in dir(cls) if not x.startswith('_') ] - + def list(self, args): # Fake method for advertising 'list' command pass - + def get_address(self, args): address_n = self.client.expand_path(args.n) return self.client.get_address(args.coin, address_n, args.show_display) - + def get_entropy(self, args): return binascii.hexlify(self.client.get_entropy(args.size)) @@ -108,7 +108,7 @@ class Commands(object): def get_public_node(self, args): address_n = self.client.expand_path(args.n) return self.client.get_public_node(address_n) - + def set_label(self, args): return self.client.apply_settings(label=args.label) @@ -203,9 +203,6 @@ class Commands(object): return ret def firmware_update(self, args): - if not args.file and not args.url: - raise Exception("Must provide firmware filename or URL") - if args.file: fp = open(args.file, 'r') elif args.url: @@ -213,18 +210,31 @@ class Commands(object): resp = urllib.urlretrieve(args.url) fp = open(resp[0], 'r') urllib.urlcleanup() # We still keep file pointer open - + else: + resp = urllib.urlopen("https://mytrezor.com/data/firmware/releases.json") + releases = json.load(resp) + version = lambda r: r['version'] + version_string = lambda r: ".".join(map(str, version(r))) + if args.version: + release = next((r for r in releases if version_string(r) == args.version)) + else: + release = max(releases, key=version) + print "No file, url, or version given. Fetching latest version: %s" % version_string(release) + print "Firmware fingerprint: %s" % release['fingerprint'] + args.url = release['url'] + return self.firmware_update(args) + if fp.read(8) == '54525a52': print "Converting firmware to binary" fp.seek(0) fp_old = fp - + fp = tempfile.TemporaryFile() fp.write(binascii.unhexlify(fp_old.read())) fp_old.close() - + fp.seek(0) if fp.read(4) != 'TRZR': raise Exception("Trezor firmware header expected") @@ -262,7 +272,7 @@ class Commands(object): (('-n', '-address'), {'type': str}), (('-d', '--show-display'), {'action': 'store_true', 'default': False}), ) - + get_entropy.arguments = ( (('size',), {'type': int}), ) @@ -277,7 +287,7 @@ class Commands(object): (('-p', '--pin-protection'), {'action': 'store_true', 'default': False}), (('-r', '--passphrase-protection'), {'action': 'store_true', 'default': False}), ) - + set_label.arguments = ( (('-l', '--label',), {'type': str, 'default': ''}), # (('-c', '--clear'), {'action': 'store_true', 'default': False}) @@ -289,7 +299,7 @@ class Commands(object): change_pin.arguments = ( (('-r', '--remove'), {'action': 'store_true', 'default': False}), ) - + wipe_device.arguments = () recovery_device.arguments = ( @@ -358,6 +368,7 @@ class Commands(object): firmware_update.arguments = ( (('-f', '--file'), {'type': str}), (('-u', '--url'), {'type': str}), + (('-n', '--version'), {'type': str}), ) def list_usb(): @@ -418,7 +429,7 @@ def qt_pin_func(input_text, message=None): # let's fallback to default pin_func implementation return pin_func(input_text, message) ''' - + def main(): args = parse_args(Commands) @@ -441,13 +452,13 @@ def main(): client = TrezorClient(transport) cmds = Commands(client) - + res = args.func(cmds, args) - + if args.json: print json.dumps(res, sort_keys=True, indent=4) else: print res - + if __name__ == '__main__': main()