#!/usr/bin/env python3

from __future__ import print_function

import sys
import struct
import binascii

import pyblake2

from trezorlib import cosi


def format_sigmask(sigmask):
    """Render a signature mask as hex plus a list of 1-based signer slots.

    Each of the low eight bits becomes its 1-based position when set and a
    "." placeholder when clear, e.g. 0x05 -> "0x05 = [1 . 3 . . . . .]".
    """
    slots = []
    for position in range(8):
        if sigmask & (1 << position):
            slots.append(str(position + 1))
        else:
            slots.append(".")
    return "0x%02x = [%s]" % (sigmask, " ".join(slots))


def format_vtrust(vtrust):
    """Render the vendor trust word as "<value> = [<bits>] = [<flags>]".

    The bits are treated as inverted: a *cleared* bit is shown by its index
    and contributes a flag, a set bit is shown as ".".

    The flag list is built from:
      * the low four bits, inverted, as ``WAIT_<n>`` when non-zero
      * bit 4 cleared -> ``RED``
      * bit 5 cleared -> ``CLICK``
      * bit 6 cleared -> ``STRING``

    Flags are joined with single spaces, so the description never starts
    with a stray space when there is no WAIT component.
    """
    bits = [str(b) if (vtrust & (1 << b)) == 0 else "." for b in range(16)]
    # see docs/bootloader.md for vtrust constants
    flags = []
    wait = (vtrust & 0x000F) ^ 0x000F
    if wait > 0:
        flags.append("WAIT_%d" % wait)
    for mask, label in ((0x0010, "RED"), (0x0020, "CLICK"), (0x0040, "STRING")):
        if (vtrust & mask) == 0:
            flags.append(label)
    return "%d = [%s] = [%s]" % (vtrust, " ".join(bits), " ".join(flags))


# bootloader/firmware headers specification: https://github.com/trezor/trezor-core/blob/master/docs/bootloader.md

# Size in bytes of the image header unpacked by BinImage; the header's own
# hdrlen field is asserted to equal this value.
IMAGE_HEADER_SIZE = 1024
# Size in bytes of the trailing signature block of a header:
# 1-byte sigmask followed by a 64-byte signature.
IMAGE_SIG_SIZE = 65
# Span of image data covered by each of the 16 hashes stored in the header;
# also the unit in which the size limits below are expressed.
IMAGE_CHUNK_SIZE = 128 * 1024
# Maximum image sizes, in multiples of IMAGE_CHUNK_SIZE.
BOOTLOADER_SECTORS_COUNT = 1
FIRMWARE_SECTORS_COUNT = 6 + 7


class BinImage(object):
    """A Trezor binary image: a fixed-size header followed by the code.

    The header is unpacked with the little-endian layout
    ``<4sIIIBBBBBBBB8s512s415sB64s``: magic, hdrlen, expiry, codelen,
    version (4 bytes), fix version (4 bytes), 8 reserved bytes, sixteen
    32-byte chunk hashes, 415 reserved bytes, sigmask and a 64-byte
    signature.
    """

    # Length of the raw signature; matches the "64s" field of the header.
    _SIG_LEN = 64

    def __init__(self, data, magic, max_size):
        """Parse and validate an image.

        data: bytes starting at the image header.
        magic: the 4-byte magic the header is required to carry.
        max_size: upper bound for hdrlen + codelen.

        Raises struct.error when ``data`` is shorter than the header and
        AssertionError when any consistency check fails.
        """
        header = struct.unpack("<4sIIIBBBBBBBB8s512s415sB64s", data[:IMAGE_HEADER_SIZE])
        (
            self.magic,
            self.hdrlen,
            self.expiry,
            self.codelen,
            self.vmajor,
            self.vminor,
            self.vpatch,
            self.vbuild,
            self.fix_vmajor,
            self.fix_vminor,
            self.fix_vpatch,
            self.fix_vbuild,
            self.reserved1,
            self.hashes,
            self.reserved2,
            self.sigmask,
            self.sig,
        ) = header
        assert self.magic == magic
        assert self.hdrlen == IMAGE_HEADER_SIZE
        total_len = self.hdrlen + self.codelen
        assert total_len % 512 == 0
        assert total_len >= 4 * 1024
        assert total_len <= max_size
        assert self.reserved1 == 8 * b"\x00"
        assert self.reserved2 == 415 * b"\x00"
        self.code = data[self.hdrlen :]
        assert len(self.code) == self.codelen

    def _prefix_len(self):
        """Return the number of bytes that precede the image header.

        Only firmware images (magic ``TRZF``) can be preceded by a vendor
        header; its length is read from ``vhdrlen`` when a subclass has set
        it and taken as 0 otherwise.
        """
        if self.magic == b"TRZF":
            return getattr(self, "vhdrlen", 0)
        return 0

    def print(self):
        """Print a human-readable dump of the header fields to stdout."""
        if self.magic == b"TRZF":
            print("Trezor Firmware Image")
        elif self.magic == b"TRZB":
            print("Trezor Bootloader Image")
        else:
            print("Trezor Unknown Image")
        # Computed for every magic so the "total" line below never hits an
        # unbound local.
        total_len = self._prefix_len() + self.hdrlen + self.codelen
        print("  * magic   :", self.magic.decode())
        print("  * hdrlen  :", self.hdrlen)
        print("  * expiry  :", self.expiry)
        print("  * codelen :", self.codelen)
        print(
            "  * version : %d.%d.%d.%d"
            % (self.vmajor, self.vminor, self.vpatch, self.vbuild)
        )
        print(
            "  * fixver  : %d.%d.%d.%d"
            % (self.fix_vmajor, self.fix_vminor, self.fix_vpatch, self.fix_vbuild)
        )
        print("  * hashes: %s" % ("OK" if self.check_hashes() else "INCORRECT"))
        for i in range(16):
            print(
                "    - %02d : %s"
                % (i, binascii.hexlify(self.hashes[i * 32 : i * 32 + 32]).decode())
            )
        print("  * sigmask :", format_sigmask(self.sigmask))
        print("  * sig     :", binascii.hexlify(self.sig).decode())
        print("  * total   : %d bytes" % total_len)
        print("  * fngprnt :", self.fingerprint())
        print()

    def compute_hashes(self):
        """Return the 512-byte hash table computed from the current code.

        The code is split into sixteen pieces; the first one is shortened by
        the header bytes in front of the code so that every later piece
        starts on an IMAGE_CHUNK_SIZE boundary of the whole file. A non-empty
        piece contributes its BLAKE2s digest, an empty one 32 zero bytes.
        """
        hdrlen = self._prefix_len() + self.hdrlen
        digests = []
        for i in range(16):
            if i == 0:
                d = self.code[: IMAGE_CHUNK_SIZE - hdrlen]
            else:
                s = IMAGE_CHUNK_SIZE - hdrlen + (i - 1) * IMAGE_CHUNK_SIZE
                d = self.code[s : s + IMAGE_CHUNK_SIZE]
            if len(d) > 0:
                digests.append(pyblake2.blake2s(d).digest())
            else:
                digests.append(32 * b"\x00")
        return b"".join(digests)

    def check_hashes(self):
        """Return True when the stored hash table matches the code."""
        return self.hashes == self.compute_hashes()

    def update_hashes(self):
        """Replace the stored hash table with one computed from the code."""
        self.hashes = self.compute_hashes()

    def serialize_header(self, sig=True):
        """Return the packed header.

        With ``sig=False`` the sigmask and signature are replaced by zero
        bytes, which is the form used for fingerprinting.

        Raises struct.error when a field does not fit its slot and
        AssertionError when the result is not ``hdrlen`` bytes long.
        """
        header = struct.pack(
            "<4sIIIBBBBBBBB8s512s415s",
            self.magic,
            self.hdrlen,
            self.expiry,
            self.codelen,
            self.vmajor,
            self.vminor,
            self.vpatch,
            self.vbuild,
            self.fix_vmajor,
            self.fix_vminor,
            self.fix_vpatch,
            self.fix_vbuild,
            self.reserved1,
            self.hashes,
            self.reserved2,
        )
        if sig:
            header += struct.pack("<B64s", self.sigmask, self.sig)
        else:
            header += IMAGE_SIG_SIZE * b"\x00"
        assert len(header) == self.hdrlen
        return header

    def fingerprint(self):
        """Return the hex BLAKE2s digest of the header with the signature zeroed."""
        return pyblake2.blake2s(self.serialize_header(sig=False)).hexdigest()

    def sign(self, sigmask, signature):
        """Store ``sigmask`` and ``signature`` in the header.

        The signature is not verified here. Its shape is checked because
        struct's "64s" field would otherwise silently pad or truncate it,
        and an out-of-range sigmask would only fail later, while writing.

        Raises ValueError when sigmask does not fit one byte or the
        signature is not exactly 64 bytes long.
        """
        if not 0 <= sigmask <= 0xFF:
            raise ValueError("sigmask must fit in one byte, got %r" % (sigmask,))
        if len(signature) != self._SIG_LEN:
            raise ValueError(
                "signature must be %d bytes, got %d" % (self._SIG_LEN, len(signature))
            )
        assert len(self.code) == self.codelen
        self.sigmask = sigmask
        self.sig = signature

    def write(self, filename):
        """Write header and code to ``filename``.

        The header is serialized before the file is opened: opening in "wb"
        mode truncates the target, so a serialization error must be raised
        while the existing file is still intact.
        """
        header = self.serialize_header()
        with open(filename, "wb") as f:
            f.write(header)
            f.write(self.code)


class FirmwareImage(BinImage):
    """Firmware image (magic ``TRZF``), optionally preceded by a vendor header."""

    def __init__(self, data, vhdrlen):
        """Parse a firmware image.

        data: the whole file content, vendor header included.
        vhdrlen: length of the vendor header at the start of ``data``
            (0 when the file starts directly with the firmware header).
        """
        super(FirmwareImage, self).__init__(
            data[vhdrlen:],
            magic=b"TRZF",
            max_size=FIRMWARE_SECTORS_COUNT * IMAGE_CHUNK_SIZE,
        )
        self.vhdrlen = vhdrlen
        # Raw vendor header bytes, kept so write() can emit them unchanged.
        self.vheader = data[:vhdrlen]

    def write(self, filename):
        """Write vendor header, firmware header and code to ``filename``.

        The firmware header is serialized before the file is opened: opening
        in "wb" mode truncates the target, so a serialization error must be
        raised while the existing file is still intact.
        """
        header = self.serialize_header()
        with open(filename, "wb") as f:
            f.write(self.vheader)
            f.write(header)
            f.write(self.code)


class BootloaderImage(BinImage):
    """Bootloader image: a BinImage whose header magic must be ``TRZB``."""

    def __init__(self, data):
        """Parse ``data`` as a bootloader image, enforcing the bootloader size limit."""
        size_limit = BOOTLOADER_SECTORS_COUNT * IMAGE_CHUNK_SIZE
        super(BootloaderImage, self).__init__(data, magic=b"TRZB", max_size=size_limit)


class VendorHeader(object):
    """A Trezor vendor header (magic ``TRZV``).

    Layout as parsed below: an 18-byte fixed part (``<4sIIBBBBH``), padding
    up to offset 32, ``vsig_n`` 32-byte public keys, a length-prefixed vendor
    string padded to a multiple of four bytes, the vendor image, and finally
    a 1-byte sigmask with a 64-byte signature.
    """

    # Length of the raw signature; matches the "64s" field used when packing.
    _SIG_LEN = 64

    def __init__(self, data):
        """Parse a vendor header from the start of ``data``.

        Bytes after ``hdrlen`` (e.g. a firmware image) are ignored.

        Raises struct.error / IndexError on data too short to parse and
        AssertionError when a consistency check fails.
        """
        header = struct.unpack("<4sIIBBBBH", data[:18])
        (
            self.magic,
            self.hdrlen,
            self.expiry,
            self.vmajor,
            self.vminor,
            self.vsig_m,
            self.vsig_n,
            self.vtrust,
        ) = header
        assert self.magic == b"TRZV"
        data = data[: self.hdrlen]  # strip remaining data (firmware)
        # A shorter slice means the input ends before the declared header
        # does; serialize_header() could never reproduce such a header.
        assert len(data) == self.hdrlen
        assert self.vsig_m > 0 and self.vsig_m <= self.vsig_n
        assert self.vsig_n > 0 and self.vsig_n <= 8
        p = 32
        self.vpub = []
        for _ in range(self.vsig_n):
            self.vpub.append(data[p : p + 32])
            p += 32
        self.vstr_len = data[p]
        p += 1
        self.vstr = data[p : p + self.vstr_len]
        p += self.vstr_len
        vstr_pad = -p & 3
        p += vstr_pad
        # The vendor image fills whatever lies between the padded string and
        # the trailing signature block.
        self.vimg_len = len(data) - IMAGE_SIG_SIZE - p
        # A negative length means keys and string alone overrun the header.
        assert self.vimg_len >= 0
        self.vimg = data[p : p + self.vimg_len]
        p += self.vimg_len
        self.sigmask = data[p]
        p += 1
        self.sig = data[p : p + 64]
        assert len(data) % 512 == 0

    def print(self):
        """Print a human-readable dump of the header fields to stdout."""
        print("Trezor Vendor Header")
        print("  * magic   :", self.magic.decode())
        print("  * hdrlen  :", self.hdrlen)
        print("  * expiry  :", self.expiry)
        print("  * version : %d.%d" % (self.vmajor, self.vminor))
        print("  * scheme  : %d out of %d" % (self.vsig_m, self.vsig_n))
        print("  * trust   :", format_vtrust(self.vtrust))
        for i in range(self.vsig_n):
            print("  * vpub #%d :" % (i + 1), binascii.hexlify(self.vpub[i]).decode())
        print("  * vstr    :", self.vstr.decode())
        print("  * vhash   :", binascii.hexlify(self.vhash()).decode())
        print("  * vimg    : (%d bytes)" % len(self.vimg))
        print("  * sigmask :", format_sigmask(self.sigmask))
        print("  * sig     :", binascii.hexlify(self.sig).decode())
        print("  * fngprnt :", self.fingerprint())
        print()

    def serialize_header(self, sig=True):
        """Return the packed vendor header.

        With ``sig=False`` the sigmask and signature are replaced by zero
        bytes, which is the form used for fingerprinting.

        Raises struct.error when a field does not fit its slot and
        AssertionError when the result is not ``hdrlen`` bytes long.
        """
        header = struct.pack(
            "<4sIIBBBBH",
            self.magic,
            self.hdrlen,
            self.expiry,
            self.vmajor,
            self.vminor,
            self.vsig_m,
            self.vsig_n,
            self.vtrust,
        )
        header += 14 * b"\x00"  # pad the 18-byte fixed part to offset 32
        for i in range(self.vsig_n):
            header += self.vpub[i]
        header += struct.pack("<B", self.vstr_len) + self.vstr
        header += (-len(header) & 3) * b"\x00"  # vstr_pad
        header += self.vimg
        if sig:
            header += struct.pack("<B64s", self.sigmask, self.sig)
        else:
            header += IMAGE_SIG_SIZE * b"\x00"
        assert len(header) == self.hdrlen
        return header

    def fingerprint(self):
        """Return the hex BLAKE2s digest of the header with the signature zeroed."""
        return pyblake2.blake2s(self.serialize_header(sig=False)).hexdigest()

    def vhash(self):
        """Return the BLAKE2s digest over vsig_m, vsig_n and eight key slots.

        Slots beyond ``vsig_n`` are hashed as 32 zero bytes.
        """
        h = pyblake2.blake2s()
        h.update(struct.pack("<BB", self.vsig_m, self.vsig_n))
        for i in range(8):
            if i < self.vsig_n:
                h.update(self.vpub[i])
            else:
                h.update(b"\x00" * 32)
        return h.digest()

    def sign(self, sigmask, signature):
        """Store ``sigmask`` and ``signature`` in the header.

        The signature is not verified here. Its shape is checked because
        struct's "64s" field would otherwise silently pad or truncate it,
        and an out-of-range sigmask would only fail later, while writing.

        Raises ValueError when sigmask does not fit one byte or the
        signature is not exactly 64 bytes long.
        """
        if not 0 <= sigmask <= 0xFF:
            raise ValueError("sigmask must fit in one byte, got %r" % (sigmask,))
        if len(signature) != self._SIG_LEN:
            raise ValueError(
                "signature must be %d bytes, got %d" % (self._SIG_LEN, len(signature))
            )
        self.sigmask = sigmask
        self.sig = signature

    def write(self, filename):
        """Write the serialized header to ``filename``.

        The header is serialized before the file is opened: opening in "wb"
        mode truncates the target, so a serialization error must be raised
        while the existing file is still intact.
        """
        header = self.serialize_header()
        with open(filename, "wb") as f:
            f.write(header)


def _print_firmware_signature_status(vheader, firmware, subdata):
    """Print whether the firmware signature verifies against the vendor keys.

    vheader: the parsed vendor header supplying the public keys.
    firmware: the parsed firmware image supplying sigmask and signature.
    subdata: the file content starting at the firmware header.
    """
    if not firmware.sigmask > 0:
        print("No firmware signature")
        return
    signers = [i for i in range(8) if firmware.sigmask & (1 << i)]
    # A sigmask bit without a corresponding vendor key cannot be verified;
    # report it as a bad signature instead of failing with an IndexError.
    if any(i >= len(vheader.vpub) for i in signers):
        print("Firmware signature INCORRECT")
        return
    pk = [vheader.vpub[i] for i in signers]
    global_pk = cosi.combine_keys(pk)
    # The signed digest covers the firmware header with the signature block
    # zeroed out.
    hdr = subdata[: IMAGE_HEADER_SIZE - IMAGE_SIG_SIZE] + IMAGE_SIG_SIZE * b"\x00"
    digest = pyblake2.blake2s(hdr).digest()
    try:
        cosi.verify(firmware.sig, digest, global_pk)
        print("Firmware signature OK")
    except ValueError:
        print("Firmware signature INCORRECT")


def binopen(filename):
    """Read ``filename`` and return the object matching its leading magic.

    * ``TRZB`` -> BootloaderImage
    * ``TRZV`` -> VendorHeader when the file holds nothing else; when a
      ``TRZF`` image follows, the vendor header is printed, the firmware
      signature status is printed, and a FirmwareImage is returned
    * ``TRZF`` -> FirmwareImage without vendor header

    Raises Exception("Unknown file format") for anything else.
    """
    # Read through a context manager so the handle is closed deterministically.
    with open(filename, "rb") as f:
        data = f.read()
    magic = data[:4]
    if magic == b"TRZB":
        return BootloaderImage(data)
    if magic == b"TRZV":
        vheader = VendorHeader(data)
        if len(data) == vheader.hdrlen:
            return vheader
        vheader.print()
        subdata = data[vheader.hdrlen :]
        if subdata[:4] == b"TRZF":
            firmware = FirmwareImage(data, vheader.hdrlen)
            # check signatures against signing keys in the vendor header
            _print_firmware_signature_status(vheader, firmware, subdata)
            return firmware
    if magic == b"TRZF":
        return FirmwareImage(data, 0)
    raise Exception("Unknown file format")


def main():
    """Command-line entry point.

    Usage: ``binctl file.bin [-s sigmask signature] [-h]``

    * no option: print the parsed file
    * ``-s sigmask signature``: store the signature and rewrite the file;
      ``sigmask`` is a 1-based signer index or several joined by ":",
      ``signature`` is hex-encoded
    * ``-h``: recompute the chunk hashes and rewrite the file

    Returns 1 after printing the usage line when the arguments are
    incomplete, and None otherwise.
    """
    usage = "Usage: binctl file.bin [-s sigmask signature] [-h]"
    if len(sys.argv) < 2:
        print(usage)
        return 1
    fn = sys.argv[1]
    sign = len(sys.argv) > 2 and sys.argv[2] == "-s"
    rehash = len(sys.argv) == 3 and sys.argv[2] == "-h"
    # "-s" needs both a sigmask and a signature; check before touching the
    # file instead of failing with an IndexError further down.
    if sign and len(sys.argv) < 5:
        print(usage)
        return 1
    b = binopen(fn)
    if sign:
        # "3" and "1:3" are handled alike: split(":") yields a single item
        # when no colon is present.
        sigmask = 0
        for idx in sys.argv[3].split(":"):
            sigmask |= 1 << (int(idx) - 1)
        signature = binascii.unhexlify(sys.argv[4])
        b.sign(sigmask, signature)
        b.write(fn)
    if rehash:
        b.update_hashes()
        b.write(fn)
    b.print()


if __name__ == "__main__":
    # Propagate main()'s return value (1 on a usage error, None on success)
    # as the process exit status.
    sys.exit(main())