diff --git a/trezorctl b/trezorctl index 32df97c6a..3c80cd42b 100755 --- a/trezorctl +++ b/trezorctl @@ -516,53 +516,53 @@ def backup_device(connect): # -def validate_firmware_v1(fw, expected_fingerprint=None): - click.echo("Trezor One firmware image.") - distinct_sig_slots = set(i for i in fw.key_indexes if i != 0) - if not distinct_sig_slots: - if not click.confirm("No signatures found. Continue?", default=False): - sys.exit(1) - elif len(distinct_sig_slots) < 3: - click.echo("Badly signed image (need 3 distinct signatures), aborting.") - sys.exit(1) - else: - all_valid = True - for i in range(len(fw.key_indexes)): - if not firmware.check_sig_v1(fw, i): - click.echo("INVALID signature in slot {}".format(i)) - all_valid = False - - if all_valid: - click.echo("Signatures are valid.") - else: - click.echo("Invalid signature detected, aborting.") - sys.exit(4) +ALLOWED_FIRMWARE_FORMATS = { + 1: (firmware.FirmwareFormat.TREZOR_ONE, firmware.FirmwareFormat.TREZOR_ONE_V2), + 2: (firmware.FirmwareFormat.TREZOR_T), +} - fingerprint = firmware.digest_v1(fw).hex() - click.echo("Firmware fingerprint: {}".format(fingerprint)) - if expected_fingerprint and fingerprint != expected_fingerprint: - click.echo("Expected fingerprint: {}".format(expected_fingerprint)) - click.echo("Fingerprints do not match, aborting.") - sys.exit(5) +def _print_version(version): + vstr = "Firmware version {major}.{minor}.{patch} build {build}".format(**version) + click.echo(vstr) + + +def validate_firmware(version, fw, expected_fingerprint=None): + if version == firmware.FirmwareFormat.TREZOR_ONE: + if fw.embedded_onev2: + click.echo("Trezor One firmware with embedded v2 image (1.8.0 or later)") + _print_version(fw.embedded_onev2.firmware_header.version) + else: + click.echo("Trezor One firmware image.") + elif version == firmware.FirmwareFormat.TREZOR_ONE_V2: + click.echo("Trezor One v2 firmware (1.8.0 or later)") + _print_version(fw.firmware_header.version) + elif version == firmware.FirmwareFormat.TREZOR_T: + click.echo("Trezor T firmware image.") + vendor = fw.vendor_header.vendor_string + vendor_version = "{major}.{minor}".format(**fw.vendor_header.version) + click.echo("Vendor header from {}, version {}".format(vendor, vendor_version)) + _print_version(fw.firmware_header.version) -def validate_firmware_v2(fw, expected_fingerprint=None, skip_vendor_header=False): - click.echo("Trezor T firmware image.") - vendor = fw.vendor_header.vendor_string - vendor_version = "{major}.{minor}".format(**fw.vendor_header.version) - version = fw.firmware_header.version - click.echo("Vendor header from {}, version {}".format(vendor, vendor_version)) - click.echo( - "Firmware version {major}.{minor}.{patch} build {build}".format(**version) - ) try: - firmware.validate(fw, skip_vendor_header) + firmware.validate(version, fw) click.echo("Signatures are valid.") - except Exception as e: + except firmware.Unsigned: + if not click.confirm("No signatures found. Continue?", default=False): + sys.exit(1) + try: + firmware.validate(version, fw, allow_unsigned=True) + click.echo("Unsigned firmware looking OK.") + except firmware.FirmwareIntegrityError as e: + click.echo(e) + click.echo("Firmware validation failed, aborting.") + sys.exit(4) + except firmware.FirmwareIntegrityError as e: click.echo(e) + click.echo("Firmware validation failed, aborting.") sys.exit(4) - fingerprint = firmware.digest(fw).hex() + fingerprint = firmware.digest(version, fw).hex() click.echo("Firmware fingerprint: {}".format(fingerprint)) if expected_fingerprint and fingerprint != expected_fingerprint: click.echo("Expected fingerprint: {}".format(expected_fingerprint)) @@ -628,15 +628,18 @@ def find_best_firmware_version(bootloader_version, requested_version=None): @cli.command() +# fmt: off @click.option("-f", "--filename") @click.option("-u", "--url") @click.option("-v", "--version") -@click.option("-s", "--skip-check", is_flag=True) +@click.option("-s", "--skip-check", is_flag=True, help="Do not validate firmware integrity") +@click.option("--raw", is_flag=True, help="Push raw data to Trezor") @click.option("--fingerprint", help="Expected firmware fingerprint in hex") @click.option("--skip-vendor-header", help="Skip vendor header validation on Trezor T") +# fmt: on @click.pass_obj def firmware_update( - connect, filename, url, version, skip_check, fingerprint, skip_vendor_header + connect, filename, url, version, skip_check, fingerprint, skip_vendor_header, raw ): """Upload new firmware to device. @@ -663,13 +666,13 @@ def firmware_update( click.echo("Please switch your device to bootloader mode.") sys.exit(1) - firmware_version = client.features.major_version + f = client.features + bootloader_onev2 = f.major_version == 1 and f.minor_version >= 8 if filename: data = open(filename, "rb").read() else: if not url: - f = client.features bootloader_version = [f.major_version, f.minor_version, f.patch_version] version_list = [int(x) for x in version.split(".")] if version else None url, fp = find_best_firmware_version(bootloader_version, version_list) @@ -680,26 +683,41 @@ def firmware_update( r = requests.get(url) data = r.content - if not skip_check: + if not raw and not skip_check: try: version, fw = firmware.parse(data) except Exception as e: click.echo(e) sys.exit(2) - if version == firmware.FirmwareFormat.TREZOR_ONE: - validate_firmware_v1(fw, fingerprint) - elif version == firmware.FirmwareFormat.TREZOR_T: - validate_firmware_v2(fw, fingerprint) - else: - click.echo("Unrecognized firmware version.") + validate_firmware(version, fw, fingerprint) + if ( + bootloader_onev2 + and version == firmware.FirmwareFormat.TREZOR_ONE + and not fw.embedded_onev2 + ): + click.echo("Firmware is too old for your device. Aborting.") + sys.exit(3) + elif not bootloader_onev2 and version == firmware.FirmwareFormat.TREZOR_ONE_V2: + click.echo("You need to upgrade to bootloader 1.8.0 first.") + sys.exit(3) - if firmware_version != version.value: + if f.major_version not in ALLOWED_FIRMWARE_FORMATS: + click.echo("trezorctl doesn't know your device version. Aborting.") + sys.exit(3) + elif version not in ALLOWED_FIRMWARE_FORMATS[f.major_version]: click.echo("Firmware does not match your device, aborting.") sys.exit(3) + if not raw: + # special handling for embedded-OneV2 format: + # for bootloader < 1.8, keep the embedding + # for bootloader 1.8.0 and up, strip the old OneV1 header + if bootloader_onev2 and data[:4] == b"TRZR" and data[256 : 256 + 4] == b"TRZF": + data = data[256:] + try: - if firmware_version == 1: + if f.major_version == 1 and f.firmware_present: # Trezor One does not send ButtonRequest click.echo("Please confirm action on your Trezor device") return firmware.update(client, data) diff --git a/trezorlib/firmware.py b/trezorlib/firmware.py index e80a5af55..a99a6c36e 100644 --- a/trezorlib/firmware.py +++ b/trezorlib/firmware.py @@ -16,7 +16,7 @@ import hashlib from enum import Enum -from typing import NewType, Tuple +from typing import Callable, List, NewType, Tuple import construct as c import ecdsa @@ -41,6 +41,7 @@ V2_BOOTLOADER_KEYS = [ V2_BOOTLOADER_M = 2 V2_BOOTLOADER_N = 3 +ONEV2_CHUNK_SIZE = 1024 * 64 V2_CHUNK_SIZE = 1024 * 128 @@ -57,6 +58,18 @@ def _transform_vendor_trust(data: bytes) -> bytes: return bytes(~b & 0xFF for b in data)[::-1] +class FirmwareIntegrityError(Exception): + pass + + +class InvalidSignatureError(FirmwareIntegrityError): + pass + + +class Unsigned(FirmwareIntegrityError): + pass + + # fmt: off Toif = c.Struct( "magic" / c.Const(b"TOI"), @@ -117,7 +130,7 @@ VersionLong = c.Struct( FirmwareHeader = c.Struct( "_start_offset" / c.Tell, "magic" / c.Const(b"TRZF"), - "_header_len" / c.Padding(4), + "header_len" / c.Int32ul, "expiry" / c.Int32ul, "code_length" / c.Rebuild( c.Int32ul, @@ -130,14 +143,21 @@ FirmwareHeader = c.Struct( "reserved" / c.Padding(8), "hashes" / c.Bytes(32)[16], - "reserved" / c.Padding(415), + "v1_signatures" / c.Bytes(64)[V1_SIGNATURE_SLOTS], + "v1_key_indexes" / c.Int8ul[V1_SIGNATURE_SLOTS], # pylint: disable=E1136 + + "reserved" / c.Padding(220), "sigmask" / c.Byte, "signature" / c.Bytes(64), "_end_offset" / c.Tell, - "header_len" / c.Pointer( - c.this._start_offset + 4, - c.Rebuild(c.Int32ul, c.this._end_offset - c.this._start_offset) + + "_rebuild_header_len" / c.If( + c.this.version.major > 1, + c.Pointer( + c.this._start_offset + 4, + c.Rebuild(c.Int32ul, c.this._end_offset - c.this._start_offset) + ), ), ) @@ -151,7 +171,15 @@ Firmware = c.Struct( ) -FirmwareV1 = c.Struct( +FirmwareOneV2 = c.Struct( + "firmware_header" / FirmwareHeader, + "_code_offset" / c.Tell, + "code" / c.Bytes(c.this.firmware_header.code_length), + c.Terminated, +) + + +FirmwareOne = c.Struct( "magic" / c.Const(b"TRZR"), "code_length" / c.Rebuild(c.Int32ul, c.len_(c.this.code)), "key_indexes" / c.Int8ul[V1_SIGNATURE_SLOTS], # pylint: disable=E1136 @@ -163,6 +191,8 @@ FirmwareV1 = c.Struct( "signatures" / c.Bytes(64)[V1_SIGNATURE_SLOTS], "code" / c.Bytes(c.this.code_length), c.Terminated, + + "embedded_onev2" / c.RestreamData(c.this.code, c.Optional(FirmwareOneV2)), ) # fmt: on @@ -171,6 +201,7 @@ FirmwareV1 = c.Struct( class FirmwareFormat(Enum): TREZOR_ONE = 1 TREZOR_T = 2 + TREZOR_ONE_V2 = 3 FirmwareType = NewType("FirmwareType", c.Container) @@ -180,62 +211,137 @@ ParsedFirmware = Tuple[FirmwareFormat, FirmwareType] def parse(data: bytes) -> ParsedFirmware: if data[:4] == b"TRZR": version = FirmwareFormat.TREZOR_ONE - cls = FirmwareV1 + cls = FirmwareOne elif data[:4] == b"TRZV": version = FirmwareFormat.TREZOR_T cls = Firmware + elif data[:4] == b"TRZF": + version = FirmwareFormat.TREZOR_ONE_V2 + cls = FirmwareOneV2 else: raise ValueError("Unrecognized firmware image type") try: fw = cls.parse(data) except Exception as e: - raise ValueError("Invalid firmware image") from e + raise FirmwareIntegrityError("Invalid firmware image") from e return version, FirmwareType(fw) -def digest_v1(fw: FirmwareType) -> bytes: +def digest_onev1(fw: FirmwareType) -> bytes: return hashlib.sha256(fw.code).digest() -def check_sig_v1(fw: FirmwareType, idx: int) -> bool: - key_idx = fw.key_indexes[idx] - signature = fw.signatures[idx] +def check_sig_v1( + digest: bytes, key_indexes: List[int], signatures: List[bytes] +) -> None: + distinct_key_indexes = set(i for i in key_indexes if i != 0) + if not distinct_key_indexes: + raise Unsigned - if key_idx == 0: - # no signature = invalid signature - return False + if len(distinct_key_indexes) < len(key_indexes): + raise InvalidSignatureError( + "Not enough distinct signatures (found {}, need {})".format( + len(distinct_key_indexes), len(key_indexes) + ) + ) - if key_idx not in V1_BOOTLOADER_KEYS: - # unknown pubkey - return False + for i in range(len(key_indexes)): + key_idx = key_indexes[i] + signature = signatures[i] - pubkey = bytes.fromhex(V1_BOOTLOADER_KEYS[key_idx])[1:] - verify = ecdsa.VerifyingKey.from_string( - pubkey, curve=ecdsa.curves.SECP256k1, hashfunc=hashlib.sha256 - ) - try: - verify.verify(signature, fw.code) - return True - except ecdsa.BadSignatureError: - return False + if key_idx not in V1_BOOTLOADER_KEYS: + # unknown pubkey + raise InvalidSignatureError("Unknown key in slot {}".format(i)) + + pubkey = bytes.fromhex(V1_BOOTLOADER_KEYS[key_idx])[1:] + verify = ecdsa.VerifyingKey.from_string(pubkey, curve=ecdsa.curves.SECP256k1) + try: + verify.verify_digest(signature, digest) + except ecdsa.BadSignatureError as e: + raise InvalidSignatureError("Invalid signature in slot {}".format(i)) from e -def _header_digest(header: c.Container, header_type: c.Construct) -> bytes: +def _header_digest( + header: c.Container, + header_type: c.Construct, + hash_function: Callable = pyblake2.blake2s, +) -> bytes: stripped_header = header.copy() stripped_header.sigmask = 0 stripped_header.signature = b"\0" * 64 + stripped_header.v1_key_indexes = [0, 0, 0] + stripped_header.v1_signatures = [b"\0" * 64] * 3 header_bytes = header_type.build(stripped_header) - return pyblake2.blake2s(header_bytes).digest() + return hash_function(header_bytes).digest() + + +def digest_v2(fw: FirmwareType) -> bytes: + return _header_digest(fw.firmware_header, FirmwareHeader, pyblake2.blake2s) + + +def digest_onev2(fw: FirmwareType) -> bytes: + return _header_digest(fw.firmware_header, FirmwareHeader, hashlib.sha256) -def digest(fw: FirmwareType) -> bytes: - return _header_digest(fw.firmware_header, FirmwareHeader) +def validate_code_hashes( + fw: FirmwareType, + hash_function: Callable = pyblake2.blake2s, + chunk_size: int = V2_CHUNK_SIZE, + padding_byte: bytes = None, +) -> None: + for i, expected_hash in enumerate(fw.firmware_header.hashes): + if i == 0: + # Because first chunk is sent along with headers, there is less code in it. + chunk = fw.code[: chunk_size - fw._code_offset] + else: + # Subsequent chunks are shifted by the "missing header" size. + ptr = i * chunk_size - fw._code_offset + chunk = fw.code[ptr : ptr + chunk_size] + + # padding for last chunk + if padding_byte is not None and i > 1 and chunk and len(chunk) < chunk_size: + chunk += padding_byte[0:1] * (chunk_size - len(chunk)) + + if not chunk and expected_hash == b"\0" * 32: + continue + chunk_hash = hash_function(chunk).digest() + if chunk_hash != expected_hash: + raise FirmwareIntegrityError("Invalid firmware data.") + + +def validate_onev2(fw: FirmwareType, allow_unsigned: bool = False) -> None: + try: + check_sig_v1( + digest_onev2(fw), + fw.firmware_header.v1_key_indexes, + fw.firmware_header.v1_signatures, + ) + except Unsigned: + if not allow_unsigned: + raise + + validate_code_hashes( + fw, + hash_function=hashlib.sha256, + chunk_size=ONEV2_CHUNK_SIZE, + padding_byte=b"\xFF", + ) -def validate(fw: FirmwareType, skip_vendor_header=False) -> bool: +def validate_onev1(fw: FirmwareType, allow_unsigned: bool = False) -> None: + try: + check_sig_v1(digest_onev1(fw), fw.key_indexes, fw.signatures) + except Unsigned: + if not allow_unsigned: + raise + if fw.embedded_onev2: + validate_onev2(fw.embedded_onev2, allow_unsigned) + + +def validate_v2(fw: FirmwareType, skip_vendor_header=False) -> None: vendor_fingerprint = _header_digest(fw.vendor_header, VendorHeader) - fingerprint = digest(fw) + fingerprint = digest_v2(fw) if not skip_vendor_header: try: @@ -250,7 +356,7 @@ def validate(fw: FirmwareType, skip_vendor_header=False) -> bool: V2_BOOTLOADER_KEYS, ) except Exception: - raise ValueError("Invalid vendor header signature.") + raise InvalidSignatureError("Invalid vendor header signature.") # XXX expiry is not used now # now = time.gmtime() @@ -267,28 +373,39 @@ def validate(fw: FirmwareType, skip_vendor_header=False) -> bool: fw.vendor_header.pubkeys, ) except Exception: - raise ValueError("Invalid firmware signature.") + raise InvalidSignatureError("Invalid firmware signature.") # XXX expiry is not used now # if time.gmtime(fw.firmware_header.expiry) < now: # raise ValueError("Firmware header expired.") + validate_code_hashes(fw) - for i, expected_hash in enumerate(fw.firmware_header.hashes): - if i == 0: - # Because first chunk is sent along with headers, there is less code in it. - chunk = fw.code[: V2_CHUNK_SIZE - fw._code_offset] - else: - # Subsequent chunks are shifted by the "missing header" size. - ptr = i * V2_CHUNK_SIZE - fw._code_offset - chunk = fw.code[ptr : ptr + V2_CHUNK_SIZE] - if not chunk and expected_hash == b"\0" * 32: - continue - chunk_hash = pyblake2.blake2s(chunk).digest() - if chunk_hash != expected_hash: - raise ValueError("Invalid firmware data.") - - return True +def digest(version: FirmwareFormat, fw: FirmwareType) -> bytes: + if version == FirmwareFormat.TREZOR_ONE: + if fw.embedded_onev2: + return digest_onev2(fw.embedded_onev2) + else: + return digest_onev1(fw) + elif version == FirmwareFormat.TREZOR_ONE_V2: + return digest_onev2(fw) + elif version == FirmwareFormat.TREZOR_T: + return digest_v2(fw) + else: + raise ValueError("Unrecognized firmware version") + + +def validate( + version: FirmwareFormat, fw: FirmwareType, allow_unsigned: bool = True +) -> None: + if version == FirmwareFormat.TREZOR_ONE: + return validate_onev1(fw, allow_unsigned) + elif version == FirmwareFormat.TREZOR_ONE_V2: + return validate_onev2(fw, allow_unsigned) + elif version == FirmwareFormat.TREZOR_T: + return validate_v2(fw) + else: + raise ValueError("Unrecognized firmware version") # ====== Client functions ====== #