#!/usr/bin/env python3

import binascii
import click
import pyblake2

from trezorlib import ed25519raw, ed25519cosi


# Maps each image type name accepted by the `commit` and `sign` commands to
# the number those commands substitute into the device key path
# "10018'/<index>'". The keys double as the allowed values of INDEX.
indexmap = dict(
    bootloader=0,
    vendorheader=1,
    firmware=2,
)


def get_trezor():
    """Return a TrezorClient built on the first enumerated HID device.

    The trezorlib client and HID transport are imported inside the function,
    so they are only loaded when a command actually needs a device.

    Raises:
        Exception: if HidTransport.enumerate() returns no devices.
    """
    from trezorlib.client import TrezorClient
    from trezorlib.transport_hid import HidTransport

    found = HidTransport.enumerate()
    if len(found) == 0:
        raise Exception('No TREZOR found')
    # Only the first enumerated device is ever used.
    return TrezorClient(found[0])


# Root click command group; the subcommands in this file register themselves
# on it through @cli.command. Documented with a comment rather than a
# docstring because click falls back to the function docstring for --help
# text when no help= is given, which would change the program's output.
@click.group()
def cli():
    pass


@cli.command(help='')
@click.argument('index', type=click.Choice(list(indexmap)))
@click.argument('filename')
@click.argument('seckey', required=False)
def commit(index, filename, seckey):
    """Print this signer's public key and nonce commitment for FILENAME.

    The file contents are hashed with BLAKE2s. When SECKEY is given, the
    public key and the commitment are computed locally from it; otherwise
    both are requested from a connected TREZOR under the key path
    "10018'/<index>'".

    The result is printed as one line, "<pubkey hex>+<commitment hex>",
    which is the format global_commit parses.

    Args:
        index: image type name, one of the keys of ``indexmap``.
        filename: path of the file whose digest is committed to.
        seckey: optional hex-encoded secret key; if omitted a device is used.
    """
    index = indexmap[index]
    # Context manager so the file handle is closed as soon as it is read.
    with open(filename, 'rb') as f:
        data = f.read()
    digest = pyblake2.blake2s(data).digest()
    ctr = 0  # counter argument for get_nonce; always 0 here
    if seckey:
        # Local key: `index` is not used on this path.
        sk = binascii.unhexlify(seckey)
        pk = ed25519raw.publickey(sk)
        # get_nonce returns a pair; only the second element is published.
        _, R = ed25519cosi.get_nonce(sk, digest, ctr)
    else:
        device = get_trezor()
        path = "10018'/%d'" % index
        print('commiting to hash %s with path %s' % (binascii.hexlify(digest).decode(), path))
        # Named `reply` so the local does not shadow this function's name.
        reply = device.cosi_commit(device.expand_path(path), digest)
        pk = reply.pubkey
        R = reply.commitment
    print('%s+%s' % (binascii.hexlify(pk).decode(), binascii.hexlify(R).decode()))


@cli.command(help='')
@click.argument('commits', nargs=-1)
def global_commit(commits):
    """Combine individual commits into one global commit and print it.

    Each element of COMMITS is a "<pubkey hex>+<commitment hex>" string, as
    printed by the commit command. The public keys and the commitments are
    each merged with ed25519cosi.combine_keys, and the result is printed in
    the same "<hex>+<hex>" form.

    Raises:
        Exception: if no commit is supplied.
    """
    if len(commits) == 0:
        raise Exception('Need to provide at least one commit')

    # Decode entry by entry so a malformed commit fails at the same point
    # as it would when processing the inputs strictly in order.
    decoded = []
    for entry in commits:
        pk_hex, r_hex = entry.split('+')
        decoded.append((binascii.unhexlify(pk_hex), binascii.unhexlify(r_hex)))
    pubkeys = [pubkey for pubkey, _ in decoded]
    nonces = [nonce for _, nonce in decoded]

    # combine_keys is applied to both lists: keys and commitments alike.
    combined_pk = ed25519cosi.combine_keys(pubkeys)
    combined_R = ed25519cosi.combine_keys(nonces)

    pk_text = binascii.hexlify(combined_pk).decode()
    r_text = binascii.hexlify(combined_R).decode()
    print('%s+%s' % (pk_text, r_text))


@cli.command(help='')
@click.argument('index', type=click.Choice(list(indexmap)))
@click.argument('filename')
@click.argument('global_commit')
@click.argument('seckey', required=False)
def sign(index, filename, global_commit, seckey):
    """Print this signer's signature share over FILENAME as hex.

    The file contents are hashed with BLAKE2s. GLOBAL_COMMIT is the
    "<global pubkey hex>+<global commitment hex>" line printed by the
    global_commit command. When SECKEY is given, the share is computed
    locally; otherwise a connected TREZOR is asked to sign under the key
    path "10018'/<index>'". The printed shares are what global_sign takes
    as its SIGNATURES arguments.

    Args:
        index: image type name, one of the keys of ``indexmap``.
        filename: path of the file whose digest is signed.
        global_commit: combined public key and commitment, "+"-separated hex.
        seckey: optional hex-encoded secret key; if omitted a device is used.
    """
    index = indexmap[index]
    # Context manager so the file handle is closed as soon as it is read.
    with open(filename, 'rb') as f:
        data = f.read()
    digest = pyblake2.blake2s(data).digest()
    # Must split into exactly two parts, otherwise unpacking raises ValueError.
    global_pk, global_R = [binascii.unhexlify(x) for x in global_commit.split('+')]
    ctr = 0  # same counter value the commit command passes to get_nonce
    if seckey:
        # Local key: `index` is not used on this path.
        sk = binascii.unhexlify(seckey)
        # First element of get_nonce's result; commit publishes the second.
        r, _ = ed25519cosi.get_nonce(sk, digest, ctr)
        h = ed25519raw.H(sk)
        b = ed25519raw.b
        # Scalar built from the hash of sk: bit b-2 is set and bits 3..b-3
        # are copied from h, so the low three bits are always zero.
        a = 2 ** (b - 2) + sum(2 ** i * ed25519raw.bit(h, i) for i in range(3, b - 2))
        # Share = r + Hint(global_R || global_pk || digest) * a, reduced mod l.
        S = (r + ed25519raw.Hint(global_R + global_pk + digest) * a) % ed25519raw.l
        sig = ed25519raw.encodeint(S)
    else:
        device = get_trezor()
        path = "10018'/%d'" % index
        print('signing hash %s with path %s' % (binascii.hexlify(digest).decode(), path))
        # Note the argument order: commitment first, then public key.
        signature = device.cosi_sign(device.expand_path(path), digest, global_R, global_pk)
        sig = signature.signature
    print(binascii.hexlify(sig).decode())


@cli.command(help='')
@click.argument('filename')
@click.argument('global_commit')
@click.argument('signatures', nargs=-1)
def global_sign(filename, global_commit, signatures):
    """Combine signature shares into one signature, check it, and print it.

    The file contents are hashed with BLAKE2s. GLOBAL_COMMIT is the
    "<global pubkey hex>+<global commitment hex>" line printed by the
    global_commit command, and each element of SIGNATURES is a hex share
    printed by the sign command. The shares are merged with
    ed25519cosi.combine_sig, the result is passed to ed25519raw.checkvalid
    together with the digest and the global public key, and the combined
    signature is then printed as hex.

    Args:
        filename: path of the file that was signed.
        global_commit: combined public key and commitment, "+"-separated hex.
        signatures: hex-encoded signature shares to combine.
    """
    # Context manager so the file handle is closed as soon as it is read.
    with open(filename, 'rb') as f:
        data = f.read()
    digest = pyblake2.blake2s(data).digest()
    # Must split into exactly two parts, otherwise unpacking raises ValueError.
    global_pk, global_R = [binascii.unhexlify(x) for x in global_commit.split('+')]
    signatures = [binascii.unhexlify(x) for x in signatures]
    sig = ed25519cosi.combine_sig(global_R, signatures)
    # NOTE(review): the return value is ignored, so checkvalid presumably
    # reports a bad signature by raising -- confirm against ed25519raw.
    ed25519raw.checkvalid(sig, digest, global_pk)
    print(binascii.hexlify(sig).decode())


# Dispatch to the click command group only when run as a script, so that
# importing this module does not parse command-line arguments.
if __name__ == '__main__':
    cli()