#!/usr/bin/env python3

from __future__ import print_function

import sys
import struct
import binascii
import pyblake2

from trezorlib import ed25519raw, ed25519cosi


# Python 2 compatibility: rebind input() to raw_input(), which returns the
# entered line as a string instead of evaluating it.
# NOTE(review): input() is not called anywhere in the visible code.
if sys.version_info.major < 3:
    input = raw_input


def get_trezor(index):
    """Return a TrezorClient for the HID device at position ``index``.

    The trezorlib imports are performed only when this function is called.

    Raises:
        Exception: if the enumerated device list has no entry at ``index``.
    """
    from trezorlib.client import TrezorClient
    from trezorlib.transport_hid import HidTransport

    found = HidTransport.enumerate()
    if index >= len(found):
        raise Exception('TREZOR with such index not found')
    return TrezorClient(found[index])


def sign_data(seckeys, data):
    """Sign the BLAKE2s digest of ``data`` and return the signature.

    With exactly one entry in ``seckeys`` the digest is signed directly via
    ``ed25519raw.signature``.  With several entries, every signer contributes
    a public key, a commitment and a partial signature; these are merged with
    the ``ed25519cosi`` helpers and the combined signature is passed to
    ``ed25519raw.checkvalid`` before being returned.

    Args:
        seckeys: list of secret keys.  In the multi-key path an entry equal to
            the string 'trezor' means that signer is a hardware device,
            obtained with ``get_trezor(i)`` where ``i`` is the entry's
            position in the list.
        data: the bytes to sign (hashed first; the digest is what is signed).

    Returns:
        The signature produced by ``ed25519raw.signature`` (single key) or
        ``ed25519cosi.combine_sig`` (multiple keys).
    """
    digest = pyblake2.blake2s(data).digest()
    if len(seckeys) == 1:
        # Single signer: plain signature over the digest.
        # NOTE(review): a lone 'trezor' entry is not special-cased here.
        sk = seckeys[0]
        pk = ed25519raw.publickey(sk)
        return ed25519raw.signature(digest, sk, pk)
    else:
        # Counter handed to get_nonce(); it is never incremented in this function.
        ctr = 0
        pubkeys = []
        nonces = []
        commits = []
        # Round 1: collect each signer's public key and commitment.
        for i, sk in enumerate(seckeys):
            if sk == 'trezor':
                t = get_trezor(i)
                # FIXME: path below should change according to what is being signed
                commit = t.cosi_commit(t.expand_path("10018'/0'"), digest)
                pk = commit.pubkey
                # No local nonce for a device signer; nonces[i] is only read
                # for software keys in the second loop.
                r = None
                R = commit.commitment
            else:
                pk = ed25519raw.publickey(sk)
                r, R = ed25519cosi.get_nonce(sk, digest, ctr)
            pubkeys.append(pk)
            nonces.append(r)
            commits.append(R)
        # Both the public keys and the commitments are merged with combine_keys().
        global_pk = ed25519cosi.combine_keys(pubkeys)
        global_R = ed25519cosi.combine_keys(commits)
        sigs = []
        # Round 2: collect each signer's partial signature over the combined values.
        for i, sk in enumerate(seckeys):
            if sk == 'trezor':
                # A new client is created for the same device index as in round 1.
                t = get_trezor(i)
                # FIXME: path below should change according to what is being signed
                signature = t.cosi_sign(t.expand_path("10018'/0'"), digest, global_R, global_pk)
                sig = signature.signature
            else:
                r = nonces[i]
                # NOTE(review): R is assigned but not used in this branch.
                R = commits[i]
                h = ed25519raw.H(sk)
                b = ed25519raw.b
                # Scalar built from the key hash: bit (b - 2) is set and bits
                # 3 .. b-3 are copied from h.  Presumably mirrors the secret
                # scalar derivation inside ed25519raw -- TODO confirm.
                a = 2 ** (b - 2) + sum(2 ** i * ed25519raw.bit(h, i) for i in range(3, b - 2))
                # Partial signature: S = (r + Hint(R || pk || digest) * a) mod l
                S = (r + ed25519raw.Hint(global_R + global_pk + digest) * a) % ed25519raw.l
                sig = ed25519raw.encodeint(S)
            sigs.append(sig)
        sig = ed25519cosi.combine_sig(global_R, sigs)
        # Self-check of the combined signature; presumably raises when it is
        # invalid -- confirm against ed25519raw.checkvalid.
        ed25519raw.checkvalid(sig, digest, global_pk)
        return sig


def format_sigmask(sigmask):
    """Render a signature bitmask as hex plus a list of signer slots.

    Each of the eight low bits is shown as its 1-based position when set and
    as '.' when clear, e.g. ``0x05 = [1 . 3 . . . . .]``.
    """
    slots = []
    for position in range(8):
        if (sigmask >> position) & 1:
            slots.append(str(position + 1))
        else:
            slots.append('.')
    return '0x%02x = [%s]' % (sigmask, ' '.join(slots))


# bootloader/firmware headers specification: https://github.com/trezor/trezor-core/blob/master/docs/bootloader.md


class BinImage(object):
    """A TREZOR binary image: a 512-byte header followed by the code.

    Header layout (little-endian), as unpacked in ``__init__``:
    magic (4 bytes), hdrlen, expiry, codelen (uint32 each), vmajor, vminor,
    vpatch, vbuild (1 byte each), 427 reserved zero bytes, sigmask (1 byte)
    and sig (64 bytes).
    """

    def __init__(self, data, magic, max_size):
        """Parse and validate an image.

        Args:
            data: image bytes, starting at the image header.
            magic: the 4-byte magic the header must carry.
            max_size: upper bound for ``hdrlen + codelen``.

        Raises:
            struct.error: if ``data`` holds fewer than 512 bytes.
            AssertionError: if any of the header consistency checks fails.
        """
        header = struct.unpack('<4sIIIBBBB427sB64s', data[:512])
        (self.magic,
         self.hdrlen,
         self.expiry,
         self.codelen,
         self.vmajor,
         self.vminor,
         self.vpatch,
         self.vbuild,
         self.reserved,
         self.sigmask,
         self.sig) = header
        assert self.magic == magic
        assert self.hdrlen == 512
        total_len = self.hdrlen + self.codelen
        assert total_len % 512 == 0
        assert total_len >= 4 * 1024
        assert total_len <= max_size
        assert self.reserved == 427 * b'\x00'
        self.code = data[self.hdrlen:]
        assert len(self.code) == self.codelen

    def print(self):
        """Print a human-readable summary of the header to stdout."""
        if self.magic == b'TRZF':
            print('TREZOR Firmware Image')
        elif self.magic == b'TRZB':
            print('TREZOR Bootloader Image')
        else:
            print('TREZOR Unknown Image')
        # Only images that carry a ``vhdrlen`` attribute have a vendor header
        # in front of them; for everything else the prefix length is 0.  This
        # also keeps total_len defined for an unknown magic.
        total_len = getattr(self, 'vhdrlen', 0) + self.hdrlen + self.codelen
        print('  * magic   :', self.magic.decode('ascii'))
        print('  * hdrlen  :', self.hdrlen)
        print('  * expiry  :', self.expiry)
        print('  * codelen :', self.codelen)
        print('  * version : %d.%d.%d.%d' % (self.vmajor, self.vminor, self.vpatch, self.vbuild))
        print('  * sigmask :', format_sigmask(self.sigmask))
        print('  * sig     :', binascii.hexlify(self.sig).decode('ascii'))
        print('  * total   : %d bytes' % total_len)
        print()

    def serialize_header(self, sig=True):
        """Return the 512-byte header.

        With ``sig=False`` the trailing sigmask and signature (65 bytes) are
        replaced by zeros; this is the form that gets signed.
        """
        header = struct.pack('<4sIIIBBBB427s',
                             self.magic, self.hdrlen, self.expiry, self.codelen,
                             self.vmajor, self.vminor, self.vpatch, self.vbuild,
                             self.reserved)
        if sig:
            header += struct.pack('<B64s', self.sigmask, self.sig)
        else:
            header += 65 * b'\x00'
        assert len(header) == self.hdrlen
        return header

    def sign(self, sigmask, seckeys):
        """Sign header (with zeroed signature fields) plus code; store the result."""
        header = self.serialize_header(sig=False)
        data = header + self.code
        assert len(data) == self.hdrlen + self.codelen
        self.sigmask = sigmask
        self.sig = sign_data(seckeys, data)

    def write(self, filename):
        """Write header and code to ``filename``.

        The header is serialized before the file is opened, because opening
        with 'wb' truncates an existing file; a serialization failure must
        not destroy it.
        """
        header = self.serialize_header()
        with open(filename, 'wb') as f:
            f.write(header)
            f.write(self.code)


class FirmwareImage(BinImage):
    """Firmware image (magic b'TRZF'), optionally preceded by a vendor header.

    ``vheader`` keeps the first ``vhdrlen`` bytes of the input verbatim so
    that write() can emit them unchanged in front of the image.
    """

    def __init__(self, data, vhdrlen):
        """Parse the image found ``vhdrlen`` bytes into ``data``."""
        super(FirmwareImage, self).__init__(data[vhdrlen:], magic=b'TRZF', max_size=7 * 128 * 1024)
        self.vhdrlen = vhdrlen
        self.vheader = data[:vhdrlen]

    def write(self, filename):
        """Write vendor header bytes, image header and code to ``filename``.

        The header is serialized before the file is opened, because opening
        with 'wb' truncates an existing file; a serialization failure must
        not destroy it.
        """
        header = self.serialize_header()
        with open(filename, 'wb') as f:
            f.write(self.vheader)
            f.write(header)
            f.write(self.code)


class BootloaderImage(BinImage):
    """Bootloader image: a BinImage whose header magic must be b'TRZB'."""

    def __init__(self, data):
        """Parse ``data`` as a bootloader image."""
        size_limit = 64 * 1024 + 7 * 128 * 1024
        super(BootloaderImage, self).__init__(data, magic=b'TRZB', max_size=size_limit)


class VendorHeader(object):
    """A TREZOR vendor header.

    Layout, as parsed in ``__init__``: 17 bytes of fixed fields (magic,
    hdrlen, expiry, vmajor, vminor, vsig_m, vsig_n, vtrust) padded to 32
    bytes, ``vsig_n`` public keys of 32 bytes each, a 1-byte string length
    followed by the vendor string, zero padding to a 4-byte boundary, the
    vendor image, a 1-byte sigmask and a 64-byte signature.
    """

    def __init__(self, data):
        """Parse the vendor header at the start of ``data``.

        ``data`` may contain more than the header (for example a firmware
        image following it); only the first ``hdrlen`` bytes are interpreted.
        ``data`` must yield integers when indexed (bytearray, or bytes on
        Python 3).

        Raises:
            struct.error: if ``data`` holds fewer than 17 bytes.
            AssertionError: if a consistency check fails.
        """
        header = struct.unpack('<4sIIBBBBB', data[:17])
        (self.magic,
         self.hdrlen,
         self.expiry,
         self.vmajor,
         self.vminor,
         self.vsig_m,
         self.vsig_n,
         self.vtrust) = header
        assert self.magic == b'TRZV'
        assert self.vsig_m > 0 and self.vsig_m <= self.vsig_n
        assert self.vsig_n > 0 and self.vsig_n <= 8
        # Restrict parsing to the header itself.  The sizes below are derived
        # from the remaining length, so trailing bytes (a firmware image)
        # would otherwise be swallowed into vimg and the signature.
        assert self.hdrlen <= len(data)
        data = data[:self.hdrlen]
        p = 32
        self.vpub = []
        for _ in range(self.vsig_n):
            self.vpub.append(data[p:p + 32])
            p += 32
        self.vstr_len = data[p]
        p += 1
        self.vstr = data[p:p + self.vstr_len]
        p += self.vstr_len
        vstr_pad = -p & 3  # bytes needed to reach the next multiple of 4
        p += vstr_pad
        # Everything between the padded string and the final 65 bytes
        # (sigmask + signature) is the vendor image.
        self.vimg_len = len(data) - 65 - p
        assert self.vimg_len >= 0
        self.vimg = data[p:p + self.vimg_len]
        p += self.vimg_len
        self.sigmask = data[p]
        p += 1
        self.sig = data[p:p + 64]
        assert len(data) == 4 + 4 + 4 + 1 + 1 + 1 + 1 + 1 + 15 + \
                            32 * len(self.vpub) + \
                            1 + self.vstr_len + vstr_pad + \
                            self.vimg_len + \
                            1 + 64

    def print(self):
        """Print a human-readable summary of the vendor header to stdout."""
        print('TREZOR Vendor Header')
        print('  * magic   :', self.magic.decode('ascii'))
        print('  * hdrlen  :', self.hdrlen)
        print('  * expiry  :', self.expiry)
        print('  * version : %d.%d' % (self.vmajor, self.vminor))
        print('  * scheme  : %d out of %d' % (self.vsig_m, self.vsig_n))
        print('  * trust   :', self.vtrust)
        for i in range(self.vsig_n):
            print('  * vpub #%d :' % (i + 1), binascii.hexlify(self.vpub[i]).decode('ascii'))
        print('  * vstr    :', self.vstr.decode('ascii'))
        print('  * vimg    : (%d bytes)' % len(self.vimg))
        print('  * sigmask :', format_sigmask(self.sigmask))
        print('  * sig     :', binascii.hexlify(self.sig).decode('ascii'))

    def serialize_header(self, sig=True):
        """Return the header as bytes.

        With ``sig=False`` the trailing sigmask and signature (65 bytes) are
        replaced by zeros; this is the form that gets signed.
        """
        header = struct.pack('<4sIIBBBBB',
                             self.magic, self.hdrlen, self.expiry,
                             self.vmajor, self.vminor,
                             self.vsig_m, self.vsig_n, self.vtrust)
        header += 15 * b'\x00'
        for i in range(self.vsig_n):
            header += self.vpub[i]
        header += struct.pack('<B', self.vstr_len) + self.vstr
        header += (-len(header) & 3) * b'\x00'  # vstr_pad
        header += self.vimg
        if sig:
            header += struct.pack('<B64s', self.sigmask, self.sig)
        else:
            header += 65 * b'\x00'
        assert len(header) == self.hdrlen
        return header

    def sign(self, sigmask, seckeys):
        """Sign the header (with zeroed signature fields) and store the result.

        Raises:
            Exception: if the number of keys or the bits set in ``sigmask``
                do not match the header's ``vsig_m`` / ``vsig_n`` scheme.
        """
        # check whether provided arguments match vsig_m/vsig_n
        if len(seckeys) != self.vsig_m:
            raise Exception('invalid number of signatures (vsig_m expected %d, got %d)' % (self.vsig_m, len(seckeys)))
        if sigmask >= (1 << self.vsig_n):
            raise Exception('signature index is higher than vsig_n (%d)' % self.vsig_n)
        index_count = bin(sigmask).count('1')
        if index_count != self.vsig_m:
            raise Exception('invalid number of indexes (vsig_m expected %d, got %d)' % (self.vsig_m, index_count))
        # sign
        header = self.serialize_header(sig=False)
        self.sigmask = sigmask
        self.sig = sign_data(seckeys, header)

    def write(self, filename):
        """Write the header to ``filename``.

        The header is serialized before the file is opened, because opening
        with 'wb' truncates an existing file; a serialization failure must
        not destroy it.
        """
        header = self.serialize_header()
        with open(filename, 'wb') as f:
            f.write(header)


def binopen(filename):
    """Load ``filename`` and return the object matching its leading magic.

    Returns:
        BootloaderImage for b'TRZB', FirmwareImage for b'TRZF', and for
        b'TRZV' either a VendorHeader (the file is exactly the header) or a
        FirmwareImage (the header is followed by a b'TRZF' image).

    Raises:
        Exception: if the content matches none of the above.
    """
    # The context manager closes the handle as soon as the content is read.
    with open(filename, 'rb') as f:
        data = bytearray(f.read())  # python2/3 compatibility
    magic = data[:4]
    if magic == b'TRZB':
        return BootloaderImage(data)
    if magic == b'TRZV':
        vheader = VendorHeader(data)
        if len(data) == vheader.hdrlen:
            return vheader
        # Vendor header followed by something else: accept only a firmware image.
        if data[vheader.hdrlen:vheader.hdrlen + 4] == b'TRZF':
            return FirmwareImage(data, vheader.hdrlen)
    if magic == b'TRZF':
        return FirmwareImage(data, 0)
    raise Exception('Unknown file format')


def _parse_sigmask(spec):
    """Turn ':'-separated 1-based signer indexes into a signature bitmask."""
    sigmask = 0
    for idx in spec.split(':'):
        position = int(idx)
        # The sigmask is stored in a single header byte, so only slots 1..8 exist.
        if not 1 <= position <= 8:
            raise ValueError('signature index must be between 1 and 8, got %d' % position)
        sigmask |= 1 << (position - 1)
    return sigmask


def _parse_seckeys(spec):
    """Decode ':'-separated hex secret keys; 'trezor' is kept as-is in a multi-key list."""
    if ':' not in spec:
        return [binascii.unhexlify(spec)]
    return [key if key == 'trezor' else binascii.unhexlify(key) for key in spec.split(':')]


def main():
    """Command-line entry point: ``binctl file.bin [-s index seckey]``.

    Without ``-s`` the file is parsed and its header printed.  With ``-s``
    the image is signed, written back to the same file and then printed;
    ``index`` is a ':'-separated list of 1-based signer indexes and
    ``seckey`` a ':'-separated list of hex secret keys (or 'trezor' inside a
    multi-key list).

    Returns:
        1 when the arguments are incomplete, otherwise None.
    """
    argv = sys.argv
    sign = len(argv) > 2 and argv[2] == '-s'
    if len(argv) < 2 or (sign and len(argv) < 5):
        print('Usage: binctl file.bin [-s index seckey]')
        return 1
    fn = argv[1]
    if sign:
        # Validate the signing arguments before touching the file, so that
        # malformed input fails early and cannot lead to a partial rewrite.
        sigmask = _parse_sigmask(argv[3])
        seckeys = _parse_seckeys(argv[4])
    b = binopen(fn)
    if sign:
        b.sign(sigmask, seckeys)
        print()
        b.write(fn)
    b.print()


if __name__ == '__main__':
    # Propagate main()'s return value (1 on a usage error, None otherwise)
    # as the process exit status instead of discarding it.
    sys.exit(main())