# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.

import typing as t
from copy import copy
from dataclasses import asdict
from enum import Enum

import click
import construct as c
from construct_classes import Struct
from typing_extensions import Protocol, Self, runtime_checkable

from .. import cosi, firmware

SYM_OK = click.style("\u2714", fg="green")
SYM_FAIL = click.style("\u274c", fg="red")


class Status(Enum):
    """Outcome of a signature or hash check.

    Each member's value is the click-styled label printed to the terminal.
    """

    VALID = click.style("VALID", fg="green", bold=True)
    INVALID = click.style("INVALID", fg="red", bold=True)
    MISSING = click.style("MISSING", fg="blue", bold=True)
    DEVEL = click.style("DEVEL", fg="red", bold=True)

    def is_ok(self) -> bool:
        """Return True for VALID or DEVEL, False for INVALID or MISSING."""
        return self is Status.VALID or self is Status.DEVEL


# vhash() value identifying a development vendor header; firmware whose
# vendor header hashes to this is reported as DEVEL instead of VALID.
VHASH_DEVEL = bytes.fromhex(
    "c5b4d40cb76911392122c8d1c277937e49c69b2aaf818001ec5c7663fcce258f"
)


def _make_dev_keys(*key_bytes: bytes) -> t.Sequence[bytes]:
    """Return placeholder keys, each argument repeated 32 times.

    A single-byte argument such as b"\\x44" therefore yields a 32-byte key.
    """
    return [k * 32 for k in key_bytes]


def all_zero(data: bytes) -> bool:
    """Return True if every byte of *data* is zero (also True for empty input)."""
    return all(b == 0 for b in data)


def _check_signature_any(fw: "SignableImageProto", is_devel: bool = False) -> Status:
    """Classify the signature on *fw*.

    Returns:
        MISSING if ``fw.signature_present()`` is false;
        VALID (or DEVEL when *is_devel* is set) if ``fw.verify()`` succeeds
        with the default keys;
        DEVEL if verification succeeds with ``fw.public_keys(dev_keys=True)``;
        INVALID otherwise.

    Any exception raised by ``verify()`` is treated as a verification failure.
    That is deliberate, because different image types raise different
    exception types.
    """
    if not fw.signature_present():
        return Status.MISSING

    try:
        fw.verify()
        return Status.DEVEL if is_devel else Status.VALID
    except Exception:
        pass  # fall through and try the development keys

    try:
        fw.verify(public_keys=fw.public_keys(dev_keys=True))
        return Status.DEVEL
    except Exception:
        return Status.INVALID


# ====================== formatting functions ====================


class LiteralStr(str):
    """String that ``_format_container`` emits verbatim instead of via ``repr()``."""


def _format_container(
    pb: t.Union[c.Container, Struct, dict],
    indent: int = 0,
    sep: str = " " * 4,
    truncate_after: t.Optional[int] = 64,
    truncate_to: t.Optional[int] = 32,
) -> str:
    """Pretty-print a construct Container, construct_classes Struct or dict.

    Args:
        pb: the value to render. Structs are converted with ``asdict``.
        indent: starting nesting level.
        sep: the string repeated once per nesting level.
        truncate_after: byte strings longer than this are truncated and get a
            "..." suffix. A falsy value disables truncation.
        truncate_to: number of bytes kept when truncating. None or 0 keeps
            nothing.

    Keys starting with "_" are skipped, as are values that are None or ``[]``.
    Byte strings are shown as ``repr`` when more than 80 % of their bytes are
    printable ASCII, and as hex otherwise. The full length is always reported.
    """

    def mostly_printable(data: bytes) -> bool:
        if not data:
            return True
        printable = sum(1 for byte in data if 0x20 <= byte <= 0x7E)
        return printable / len(data) > 0.8

    def pformat(value: t.Any, indent: int) -> str:
        level = sep * indent
        leadin = sep * (indent + 1)

        if isinstance(value, LiteralStr):
            return value

        if isinstance(value, list):
            # short list of simple values
            if not value or isinstance(value[0], (int, bool, Enum)):
                return repr(value)

            # long list, one line per entry
            lines = ["[", level + "]"]
            lines[1:1] = [leadin + pformat(x, indent + 1) for x in value]
            return "\n".join(lines)

        if isinstance(value, Struct):
            value = asdict(value)

        if isinstance(value, dict):
            lines = ["{"]
            for key, val in value.items():
                if key.startswith("_"):
                    continue
                if val is None or val == []:
                    continue
                lines.append(leadin + key + ": " + pformat(val, indent + 1))
            lines.append(level + "}")
            return "\n".join(lines)

        if isinstance(value, (bytes, bytearray)):
            length = len(value)
            suffix = ""
            if truncate_after and length > truncate_after:
                suffix = "..."
                value = value[: truncate_to or 0]
            if mostly_printable(value):
                output = repr(value)
            else:
                output = value.hex()
            return f"{length} bytes {output}{suffix}"

        if isinstance(value, Enum):
            return str(value)

        return repr(value)

    return pformat(pb, indent)


def _format_version(version: t.Tuple[int, ...]) -> str:
    """Join version components with dots, e.g. (2, 5, 1) -> "2.5.1"."""
    return ".".join(str(i) for i in version)


def format_header(
    header: firmware.core.FirmwareHeader,
    code_hashes: t.Sequence[bytes],
    digest: bytes,
    sig_status: Status,
) -> str:
    """Render a firmware header together with its hash and signature status.

    Args:
        header: the parsed header. Its ``hashes`` are the expected code hashes.
        code_hashes: hashes computed from the actual image code.
        digest: header digest, printed as the fingerprint.
        sig_status: result of the signature check.

    Hash status is MISSING when every expected hash is all-zero, INVALID when
    the expected and computed hash sequences differ, and VALID otherwise.
    """
    header_out = asdict(header)  # asdict already returns a fresh dict
    # Iterate over a snapshot because values are replaced in the loop.
    for key, val in list(header_out.items()):
        if "version" in key:
            header_out[key] = LiteralStr(_format_version(val))

    hashes_out = []
    for expected, actual in zip(header.hashes, code_hashes):
        status = SYM_OK if expected == actual else SYM_FAIL
        hashes_out.append(LiteralStr(f"{status} {expected.hex()}"))

    if all(all_zero(h) for h in header.hashes):
        hash_status = Status.MISSING
    # Normalize both sides to lists: comparing a list with a tuple (or another
    # sequence type) would report unequal even when the contents match.
    elif list(header.hashes) != list(code_hashes):
        hash_status = Status.INVALID
    else:
        hash_status = Status.VALID

    header_out["hashes"] = hashes_out

    all_ok = SYM_OK if hash_status.is_ok() and sig_status.is_ok() else SYM_FAIL

    output = [
        "Firmware Header " + _format_container(header_out),
        f"Fingerprint: {click.style(digest.hex(), bold=True)}",
        f"{all_ok} Signature is {sig_status.value}, hashes are {hash_status.value}",
    ]
    return "\n".join(output)


# =========================== functionality implementations ===============


class SignableImageProto(Protocol):
    """Structural interface shared by every image type this module handles."""

    NAME: t.ClassVar[str]

    @classmethod
    def parse(cls, data: bytes) -> Self:
        """Parse raw image bytes into an instance."""
        ...

    def digest(self) -> bytes:
        """Return the digest that is signed."""
        ...

    def verify(self, public_keys: t.Sequence[bytes] = ...) -> None:
        """Verify the signature, raising on failure."""
        ...

    def build(self) -> bytes:
        """Serialize the image back to bytes."""
        ...

    def format(self, verbose: bool = False) -> str:
        """Return a human-readable description of the image."""
        ...

    def signature_present(self) -> bool:
        """Return True if any signature data is filled in."""
        ...

    def public_keys(self, dev_keys: bool = False) -> t.Sequence[bytes]:
        """Return the keys used for verification (dev keys if requested)."""
        ...
@runtime_checkable
class CosiSignedImage(SignableImageProto, Protocol):
    """Image signed with a CoSi multisignature plus a signer bitmask."""

    # Immutable default: a list here would be shared by every class that
    # inherits it without overriding.
    DEV_KEYS: t.ClassVar[t.Sequence[bytes]] = ()

    def insert_signature(self, signature: bytes, sigmask: int) -> None:
        ...


@runtime_checkable
class LegacySignedImage(SignableImageProto, Protocol):
    """Image carrying legacy (v1) per-slot signatures."""

    def slots(self) -> t.Iterable[int]:
        ...

    def insert_signature(self, slot: int, key_index: int, signature: bytes) -> None:
        ...


class CosiSignatureHeaderProto(Protocol):
    """Any header object exposing CoSi ``signature`` and ``sigmask`` fields."""

    signature: bytes
    sigmask: int


class CosiSignedMixin:
    """Shared signature helpers for CoSi-signed images.

    Subclasses must implement ``get_header()`` and return the object that
    holds ``signature`` and ``sigmask``.
    """

    def signature_present(self) -> bool:
        """Return True if the signature is non-zero or any sigmask bit is set."""
        header = self.get_header()
        return not all_zero(header.signature) or header.sigmask != 0

    def insert_signature(self, signature: bytes, sigmask: int) -> None:
        """Store *signature* and *sigmask* in the signature header."""
        header = self.get_header()
        header.signature = signature
        header.sigmask = sigmask

    def get_header(self) -> CosiSignatureHeaderProto:
        raise NotImplementedError


class VendorHeader(firmware.VendorHeader, CosiSignedMixin):
    """Standalone vendor header, verified against the V2 bootloader keys."""

    NAME = "vendorheader"
    DEV_KEYS = _make_dev_keys(b"\x44", b"\x45")

    # Terminated makes parsing fail unless the input is exactly one header
    # with no trailing data.
    SUBCON = c.Struct(*firmware.VendorHeader.SUBCON.subcons, c.Terminated)

    def get_header(self) -> CosiSignatureHeaderProto:
        return self

    def _format(self, terse: bool) -> str:
        """Describe the header.

        The terse form is a one-line summary. The full form adds the field
        dump, the pubkey bundle hash and the fingerprint. Both forms end with
        the signature status line.
        """
        if not terse:
            output = [
                "Vendor Header " + _format_container(self),
                f"Pubkey bundle hash: {self.vhash().hex()}",
            ]
        else:
            output = [
                "Vendor Header for {vendor} version {version} ({size} bytes)".format(
                    vendor=click.style(self.text, bold=True),
                    version=_format_version(self.version),
                    size=self.header_len,
                ),
            ]

        if not terse:
            output.append(f"Fingerprint: {click.style(self.digest().hex(), bold=True)}")

        sig_status = _check_signature_any(self)
        sym = SYM_OK if sig_status.is_ok() else SYM_FAIL
        output.append(f"{sym} Signature is {sig_status.value}")

        return "\n".join(output)

    def format(self, verbose: bool = False) -> str:
        # A standalone vendor header is always shown in full; *verbose* is
        # accepted only for interface compatibility.
        return self._format(terse=False)

    def public_keys(self, dev_keys: bool = False) -> t.Sequence[bytes]:
        if not dev_keys:
            return firmware.V2_BOOTLOADER_KEYS
        else:
            return firmware.V2_BOOTLOADER_DEV_KEYS


class VendorFirmware(firmware.VendorFirmware, CosiSignedMixin):
    """Firmware image prefixed by a vendor header, signed with vendor keys."""

    NAME = "firmware"
    DEV_KEYS = _make_dev_keys(b"\x47", b"\x48")

    def get_header(self) -> CosiSignatureHeaderProto:
        return self.firmware.header

    def format(self, verbose: bool = False) -> str:
        # Re-class a shallow copy so the vendor-header formatter can be reused
        # without touching the original object.
        vh = copy(self.vendor_header)
        vh.__class__ = VendorHeader
        assert isinstance(vh, VendorHeader)
        is_devel = self.vendor_header.vhash() == VHASH_DEVEL

        return (
            vh._format(terse=not verbose)
            + "\n"
            + format_header(
                self.firmware.header,
                self.firmware.code_hashes(),
                self.digest(),
                _check_signature_any(self, is_devel),
            )
        )

    def public_keys(self, dev_keys: bool = False) -> t.Sequence[bytes]:
        # Firmware is verified against the keys in its own vendor header;
        # dev_keys has no effect here.
        return self.vendor_header.pubkeys


class BootloaderImage(firmware.FirmwareImage, CosiSignedMixin):
    """Bootloader image, verified against the V2 boardloader keys."""

    NAME = "bootloader"
    DEV_KEYS = _make_dev_keys(b"\x41", b"\x42")

    def get_header(self) -> CosiSignatureHeaderProto:
        return self.header

    def format(self, verbose: bool = False) -> str:
        return format_header(
            self.header,
            self.code_hashes(),
            self.digest(),
            _check_signature_any(self),
        )

    def verify(self, public_keys: t.Sequence[bytes] = ()) -> None:
        """Check code hashes, then the CoSi signature.

        Falls back to ``self.public_keys()`` when *public_keys* is empty.
        Any failure inside the signature check is re-raised as
        ``firmware.InvalidSignatureError``. Failures from
        ``validate_code_hashes()`` propagate unchanged.
        """
        self.validate_code_hashes()
        if not public_keys:
            public_keys = self.public_keys()
        try:
            cosi.verify(
                self.header.signature,
                self.digest(),
                firmware.V2_SIGS_REQUIRED,
                public_keys,
                self.header.sigmask,
            )
        except Exception as e:
            raise firmware.InvalidSignatureError("Invalid bootloader signature") from e

    def public_keys(self, dev_keys: bool = False) -> t.Sequence[bytes]:
        if not dev_keys:
            return firmware.V2_BOARDLOADER_KEYS
        else:
            return firmware.V2_BOARDLOADER_DEV_KEYS


def _check_v1_signature_slot(slot: int, key_index: int) -> None:
    """Validate a legacy v1 signature slot and 1-based key index.

    Raises:
        ValueError: if *slot* is outside ``[0, V1_SIGNATURE_SLOTS)`` or
            *key_index* is outside ``[1, len(V1_BOOTLOADER_KEYS)]``.
    """
    if not 0 <= slot < firmware.V1_SIGNATURE_SLOTS:
        raise ValueError("Invalid slot number")
    if not 0 < key_index <= len(firmware.V1_BOOTLOADER_KEYS):
        raise ValueError("Invalid key index")


class LegacyFirmware(firmware.LegacyFirmware):
    """Legacy (v1) firmware with per-slot signatures, possibly embedding a V2 image."""

    NAME = "legacy_firmware_v1"
    BIP32_INDEX = None

    def signature_present(self) -> bool:
        """Return True if any key index is set or any signature is non-zero."""
        return any(i != 0 for i in self.key_indexes) or any(
            not all_zero(sig) for sig in self.signatures
        )

    def insert_signature(self, slot: int, key_index: int, signature: bytes) -> None:
        """Store *signature* made by key *key_index* into *slot*.

        Raises:
            ValueError: on an out-of-range slot or key index.
        """
        _check_v1_signature_slot(slot, key_index)
        # Match LegacyV2Firmware: item assignment needs mutable lists, and
        # the fields may hold tuples or other immutable sequences.
        if not isinstance(self.key_indexes, list):
            self.key_indexes = list(self.key_indexes)
        if not isinstance(self.signatures, list):
            self.signatures = list(self.signatures)
        self.key_indexes[slot] = key_index
        self.signatures[slot] = signature

    def format(self, verbose: bool = False) -> str:
        contents = asdict(self)  # asdict already returns a fresh dict
        del contents["embedded_v2"]  # rendered separately below
        if self.embedded_v2:
            em = copy(self.embedded_v2)
            em.__class__ = LegacyV2Firmware
            assert isinstance(em, LegacyV2Firmware)
            embedded_content = "\nEmbedded V2 header: " + em.format(verbose=verbose)
        else:
            embedded_content = ""

        return _format_container(contents) + embedded_content

    def public_keys(self, dev_keys: bool = False) -> t.Sequence[bytes]:
        return firmware.V1_BOOTLOADER_KEYS

    def slots(self) -> t.Iterable[int]:
        return self.key_indexes


class LegacyV2Firmware(firmware.LegacyV2Firmware):
    """V2-format legacy firmware carrying v1-style signatures in its header."""

    NAME = "legacy_firmware_v2"
    BIP32_INDEX = 5

    def signature_present(self) -> bool:
        """Return True if any v1 key index is set or any v1 signature is non-zero."""
        return any(i != 0 for i in self.header.v1_key_indexes) or any(
            not all_zero(sig) for sig in self.header.v1_signatures
        )

    def insert_signature(self, slot: int, key_index: int, signature: bytes) -> None:
        """Store *signature* made by key *key_index* into header slot *slot*.

        Raises:
            ValueError: on an out-of-range slot or key index.
        """
        _check_v1_signature_slot(slot, key_index)
        if not isinstance(self.header.v1_key_indexes, list):
            self.header.v1_key_indexes = list(self.header.v1_key_indexes)
        if not isinstance(self.header.v1_signatures, list):
            self.header.v1_signatures = list(self.header.v1_signatures)
        self.header.v1_key_indexes[slot] = key_index
        self.header.v1_signatures[slot] = signature

    def format(self, verbose: bool = False) -> str:
        return format_header(
            self.header,
            self.code_hashes(),
            self.digest(),
            _check_signature_any(self),
        )

    def public_keys(self, dev_keys: bool = False) -> t.Sequence[bytes]:
        return firmware.V1_BOOTLOADER_KEYS

    def slots(self) -> t.Iterable[int]:
        return self.header.v1_key_indexes


def parse_image(image: bytes) -> SignableImageProto:
    """Detect the image type of *image* and parse it.

    Formats are tried in order: vendor firmware, standalone vendor header,
    V2 image (bootloader or legacy-v2 firmware, chosen by header magic), and
    finally legacy v1 firmware. Only ``construct.ConstructError`` moves on to
    the next format.

    Raises:
        ValueError: if no format matches. Also raised as soon as a V2 header
            parses but has unknown magic; in that case legacy v1 parsing is
            not attempted.
    """
    try:
        return VendorFirmware.parse(image)
    except c.ConstructError:
        pass

    try:
        return VendorHeader.parse(image)
    except c.ConstructError:
        pass

    try:
        firmware_img = firmware.core.FirmwareImage.parse(image)
        if firmware_img.header.magic == firmware.core.HeaderType.BOOTLOADER:
            return BootloaderImage.parse(image)
        if firmware_img.header.magic == firmware.core.HeaderType.FIRMWARE:
            return LegacyV2Firmware.parse(image)
        raise ValueError("Unrecognized firmware header magic")
    except c.ConstructError:
        pass

    try:
        return LegacyFirmware.parse(image)
    except c.ConstructError:
        pass

    raise ValueError("Unrecognized firmware type")