#!/usr/bin/env python3
import sys
import traceback

import click
import Pyro4
from trezorlib import cosi
from trezorlib.client import get_default_client
from trezorlib.tools import parse_path
from trezorlib._internal.firmware_headers import (
    parse_image,
    VendorHeader,
    BootloaderImage,
    FirmwareImage,
)

from typing import Tuple

# Select the "marshal" serializer in Pyro4's global configuration.
Pyro4.config.SERIALIZER = "marshal"

# TCP port the Pyro4 daemon is bound to in cli().
PORT = 5001
# Image classes keyed by the names accepted by the -t/--header-type option.
indexmap = {
    "bootloader": BootloaderImage,
    "vendorheader": VendorHeader,
    "firmware": FirmwareImage,
}

# Derivation path template; the placeholder is filled with an image type's
# BIP32_INDEX before being handed to parse_path().
PATH = "10018h/{}h"

# Trezor client shared by make_commit() and KeyctlProxy; assigned in cli().
TREZOR = None


def make_commit(fw_or_type, digest, public_keys):
    """Ask the connected Trezor for a commitment to `digest`, retrying on errors.

    Prints the device label and transport path once, then repeatedly shows the
    digest and derivation path and calls ``cosi.commit`` until it succeeds.

    Args:
        fw_or_type: object (image instance or image class) providing the
            ``BIP32_INDEX`` and ``NAME`` attributes.
        digest: digest to commit to, as bytes. It is displayed as four
            8-byte chunks in alternating colours.
        public_keys: collection of acceptable public keys, or ``None`` to
            skip the check. When the device's key is not in it, the user is
            offered a retry (after re-initializing the device); declining the
            retry accepts the unknown key.

    Returns:
        Tuple ``(pubkey, commitment)`` taken from the ``cosi.commit`` result.

    Raises:
        click.Abort: if the user aborts (Ctrl+C / EOF) at the retry prompt.
    """
    path = PATH.format(fw_or_type.BIP32_INDEX)
    address_n = parse_path(path)

    # device information - show only first time
    click.echo(
        f"\nUsing device {click.style(TREZOR.features.label, bold=True)} "
        f"at path {TREZOR.transport.get_path()}"
    )

    while True:
        # signing information - repeat every time
        click.echo(f"Commiting to {click.style(fw_or_type.NAME, bold=True)} hash:")
        for partid in range(4):
            digest_part = digest[partid * 8 : (partid + 1) * 8]
            color = "red" if partid % 2 else "cyan"
            digest_str = click.style(digest_part.hex().upper(), fg=color)
            click.echo("\t" + digest_str)
        click.echo(f"Using path: {click.style(path, bold=True)}")

        try:
            commit = cosi.commit(TREZOR, address_n, digest)
            if public_keys is not None and commit.pubkey not in public_keys:
                click.echo(f"\n\nPublic key {commit.pubkey.hex()} is unknown.")
                if click.confirm("Retry with a different passphrase?", default=True):
                    # Restart the session so the next attempt starts fresh.
                    TREZOR.init_device()
                    continue

            return commit.pubkey, commit.commitment
        except click.Abort:
            # click.confirm() converts Ctrl+C / EOF into click.Abort, which is a
            # RuntimeError subclass. Without this clause the generic handler
            # below would swallow it and the user could never abort the prompt.
            raise
        except Exception as e:
            # Best-effort: report any device/transport failure and try again.
            click.echo(e)
            traceback.print_exc()
            click.echo("Trying again ...\n\n")


@Pyro4.expose
class KeyctlProxy:
    """Pyro4-exposed object serving one commitment and its signature.

    An instance is bound to a single image name and digest. Every remote
    request must repeat that name and digest; a mismatch is reported on the
    console and rejected with ``ValueError``.
    """

    def __init__(
        self, daemon, image_type, digest: bytes, commit: Tuple[bytes, bytes]
    ) -> None:
        """Remember the daemon, the signing target and the prepared commitment.

        ``image_type`` must provide ``NAME`` and ``BIP32_INDEX``; the latter is
        substituted into ``PATH`` to build the derivation path used for signing.
        """
        self.daemon = daemon
        self.name = image_type.NAME
        derivation_path = PATH.format(image_type.BIP32_INDEX)
        self.address_n = parse_path(derivation_path)
        self.digest = digest
        self.commit = commit
        # Last produced signature and the (global_R, global_pk) pair it was
        # made for, so that repeated identical requests are not re-signed.
        self.signature = None
        self.global_params = None

    def _check_name_digest(self, name, digest):
        """Raise ``ValueError`` unless the request matches our name and digest."""
        if name == self.name and digest == self.digest:
            return

        click.echo(f"ERROR! Remote wants to sign {name} with digest {digest.hex()}")
        click.echo(f"Expected: {self.name} with digest {self.digest.hex()}")
        raise ValueError("Unexpected index/digest")

    def get_commit(self, name, digest):
        """Return the stored commitment after validating the request."""
        self._check_name_digest(name, digest)
        click.echo("Sending commitment!")
        return self.commit

    def _make_signature(self, global_R, global_pk):
        """Call ``cosi.sign`` on the device until it succeeds; return the signature."""
        while True:
            try:
                click.echo("\n\n\nSigning...")
                return cosi.sign(
                    TREZOR, self.address_n, self.digest, global_R, global_pk
                ).signature
            except Exception as e:
                # Report the failure and loop around for another attempt.
                click.echo(e)
                traceback.print_exc()
                click.echo("Trying again ...")

    def get_signature(self, name, digest, global_R, global_pk):
        """Return the signature for the given global parameters.

        The device is only asked to sign when the parameters differ from the
        ones used for the cached signature.
        """
        self._check_name_digest(name, digest)
        requested = (global_R, global_pk)
        if requested != self.global_params:
            self.signature = self._make_signature(global_R, global_pk)
            self.global_params = requested
        click.echo("Sending signature!")
        return self.signature

    @Pyro4.oneway
    def finish(self):
        """Announce completion and shut the daemon down."""
        click.echo("Done! \\(^o^)/")
        self.daemon.shutdown()


@click.command()
@click.option(
    "-l", "--listen", "ipaddr", default="0.0.0.0", help="Bind to particular ip address"
)
@click.option("-t", "--header-type", "fw_or_type", type=click.Choice(list(indexmap)))
@click.option("-d", "--digest", help="Hex-encoded digest to sign")
@click.argument("fw_file", type=click.File("rb"), required=False)
def cli(ipaddr, fw_file, fw_or_type, digest):
    """Participate in signing of firmware.

    Specify either fw_file to auto-detect type and digest, or use -t and -d to specify
    the type and digest manually.
    """
    # NOTE: the docstring above is shown by click as the --help text, so
    # implementation notes live in comments instead.
    #
    # Flow: determine the image type and digest (from fw_file, or from -t/-d),
    # connect to the Trezor, obtain a commitment via make_commit(), then serve
    # a KeyctlProxy over Pyro4 on ipaddr:PORT until the daemon is shut down.
    global TREZOR

    public_keys = None
    if fw_file:
        if fw_or_type or digest:
            raise click.ClickException("Do not specify fw_file together with -t/-d")

        fw_or_type = parse_image(fw_file.read())
        digest = fw_or_type.digest()
        public_keys = fw_or_type.public_keys

        click.echo(fw_or_type.format())
    else:
        if not fw_or_type or not digest:
            raise click.ClickException("Please specify either fw_file or -t and -d")

        # click hands over the *name* chosen with -t; the rest of the program
        # reads .NAME / .BIP32_INDEX from it, so resolve it to the image class.
        fw_or_type = indexmap[fw_or_type]

        # -d arrives as text, while the digest is sliced, hex-printed and sent
        # to the device as bytes.
        try:
            digest = bytes.fromhex(digest)
        except ValueError as e:
            raise click.ClickException(f"Digest is not valid hex: {e}") from e

    try:
        TREZOR = get_default_client()
        TREZOR.ui.always_prompt = True
    except Exception as e:
        raise click.ClickException("Please connect a Trezor and retry.") from e

    pubkey, R = make_commit(fw_or_type, digest, public_keys)

    daemon = Pyro4.Daemon(host=ipaddr, port=PORT)
    proxy = KeyctlProxy(daemon, fw_or_type, digest, (pubkey, R))
    uri = daemon.register(proxy, "keyctl")
    click.echo(f"keyctl-proxy running at URI: {uri}")
    click.echo("Press Ctrl+C to abort.")
    daemon.requestLoop()


# Run the click command only when executed as a script, not on import.
if __name__ == "__main__":
    cli()