#!/usr/bin/env python3
import sys
import binascii
import ed25519raw
import pyblake2


def hex(by):
    """Return the lowercase hexadecimal text for the bytes-like value *by*.

    Note: this module-level function shadows the ``hex`` builtin for the
    rest of the file.
    """
    encoded = binascii.hexlify(by)
    return encoded.decode('ascii')


def combine_keys(pks):
    """Add a sequence of encoded curve points and return the encoded sum.

    Each element of *pks* is decoded with ``ed25519raw.decodepoint``; the
    decoded points are added pairwise with ``ed25519raw.edwards`` and the
    result is re-encoded with ``ed25519raw.encodepoint``.  The function is
    used both for public keys and for nonce commitments.

    Args:
        pks: iterable of encoded points.

    Returns:
        The encoded sum of all points.

    Raises:
        ValueError: if *pks* yields no elements.
    """
    combined = None
    for pk in pks:
        point = ed25519raw.decodepoint(pk)
        # Compare against None explicitly (as combinePubkeys does) so the
        # accumulator is never reset because a decoded point happens to be
        # falsy.
        if combined is None:
            combined = point
        else:
            combined = ed25519raw.edwards(combined, point)
    if combined is None:
        # Fail with a clear message instead of passing None to encodepoint.
        raise ValueError('combine_keys requires at least one key')
    return ed25519raw.encodepoint(combined)


def combine_sig(R, sigs):
    """Build a combined signature from a commitment and partial signatures.

    Every entry of *sigs* is decoded with ``ed25519raw.decodeint``; the
    decoded integers are summed and reduced modulo ``ed25519raw.l``.  The
    returned value is *R* followed by the encoded reduced sum.
    """
    total = sum(ed25519raw.decodeint(partial) for partial in sigs)
    reduced = total % ed25519raw.l
    return R + ed25519raw.encodeint(reduced)


def binom(n, k):
    """Return the binomial coefficient "n choose k" using integer arithmetic.

    The product is built incrementally, multiplying by ``n - k + step`` and
    floor-dividing by ``step`` for ``step`` in ``1..k``.  When ``k`` is less
    than 1 the loop does not run and the result is 1.
    """
    offset = n - k
    value = 1
    for step in range(1, k + 1):
        value *= offset + step
        value //= step
    return value


def compute_mask(combination, m, n):
    """Translate a combination index into a bitmask of selected signers.

    For each of the *m* picks the function advances the signer position
    while *combination* is at least ``binom(n, remaining)``, subtracting
    that count each time, and then sets the bit of the position it stopped
    at.  ``remaining`` is the number of picks still to be made after the
    current one and ``n`` is decremented alongside the position.

    Args:
        combination: index of the combination to expand.
        m: number of bits (signers) to select.
        n: total number of signers.

    Returns:
        Integer with one bit set per pick; 0 when *m* is not positive.
    """
    mask = 0
    position = 0
    # remaining counts down m-1, m-2, ..., 0
    for remaining in reversed(range(m)):
        n -= 1
        block = binom(n, remaining)
        while combination >= block:
            # skip this position: none of the next `block` combinations use it
            combination -= block
            position += 1
            n -= 1
            block = binom(n, remaining)
        mask |= 1 << position
        position += 1
    return mask


def createPubkey():
    """Interactively derive and print a secret/public key pair.

    Reads one line from stdin, hashes its UTF-8 encoding with
    ``ed25519raw.H`` and keeps the first 32 bytes as the secret key.  The
    public key comes from ``ed25519raw.publickey``.  Both are printed as hex.
    """
    print('Enter randomness: ', end='')
    entropy = input().encode('utf-8')
    secret = ed25519raw.H(entropy)[0:32]
    public = ed25519raw.publickey(secret)
    print('Secret Key: %s' % hex(secret))
    print('Public Key: %s' % hex(public))


def combinePubkeys(m, n):
    """Read *n* public keys and print the combined key of every m-subset.

    Keys are read from stdin as hex, one per prompt.  For each combination
    index the signer bitmask is obtained from ``compute_mask`` and the
    selected points are added with ``ed25519raw.edwards``; the mask and the
    encoded sum are printed.

    Raises:
        Exception: if ``binom(n, m)`` exceeds 100.
    """
    total = binom(n, m)
    if total > 100:
        raise Exception("Too many keys")

    points = []
    for index in range(n):
        print('Enter pubkey %d: ' % index, end='')
        raw = binascii.unhexlify(input())
        # A 33-byte input loses its first byte (meant for a leading 00
        # prefix; the byte's value itself is not checked).
        if len(raw) == 33:
            raw = raw[1:]
        points.append(ed25519raw.decodepoint(raw))

    for combination in range(total):
        mask = compute_mask(combination, m, n)
        selected = [points[j] for j in range(n) if mask & (1 << j)]
        acc = None
        for point in selected:
            acc = point if acc is None else ed25519raw.edwards(acc, point)
        encoded = ed25519raw.encodepoint(acc)
        print('Key %02x: %s' % (mask, hex(encoded)))


def get_nonce(sk, data, ctr):
    """Derive a deterministic nonce and its commitment for one signer.

    The nonce is ``ed25519raw.Hint`` of three concatenated parts: bytes
    ``b/8 .. b/4 - 1`` of ``ed25519raw.H(sk)`` (``b`` being
    ``ed25519raw.b``), the message *data*, and the counter.  The counter is
    rendered as hex zero-padded to at least eight digits, i.e. four
    big-endian bytes for values that fit in 32 bits.

    Args:
        sk: secret key bytes.
        data: bytes mixed into the nonce (callers pass the message digest).
        ctr: integer counter that lets the same key/message pair yield
            different nonces.

    Returns:
        Tuple ``(r, R)``: the integer nonce and the encoded point obtained
        from ``ed25519raw.scalarmult(ed25519raw.B, r)``.
    """
    h = ed25519raw.H(sk)
    b = ed25519raw.b
    # The secret scalar is not needed to derive the nonce, so it is no
    # longer computed here.
    prefix = bytes(h[i] for i in range(b >> 3, b >> 2))
    counter = binascii.unhexlify('%08x' % ctr)
    r = ed25519raw.Hint(prefix + data + counter)
    R = ed25519raw.scalarmult(ed25519raw.B, r)
    return (r, ed25519raw.encodepoint(R))


def phase1(data):
    """Run signing phase 1: print this signer's nonce commitment.

    Prints the BLAKE2s digest of *data*, then reads a counter and a
    hex-encoded private key from stdin and prints the commitment returned
    by ``get_nonce`` for that digest.
    """
    message_digest = pyblake2.blake2s(data).digest()
    print('Digest: %s' % hex(message_digest))

    print('Enter counter (small integer): ', end='')
    counter = int(input())
    print('Enter privkey: ', end='')
    secret = binascii.unhexlify(input())

    # only the commitment is shown; the nonce itself is discarded
    commitment = get_nonce(secret, message_digest, counter)[1]
    print('Local commit: %s' % hex(commitment))


def combinePhase1(m):
    """Read *m* hex-encoded commitments from stdin and print their sum.

    The commitments are added with ``combine_keys`` and the result is
    printed as the global commitment.
    """
    collected = []
    for index in range(m):
        print('Enter commit %d: ' % index, end='')
        entry = binascii.unhexlify(input())
        collected.append(entry)
    combined = combine_keys(collected)
    print('Global commit: %s' % hex(combined))


def phase2(data):
    """Run signing phase 2: print this signer's partial signature.

    Prints the BLAKE2s digest of *data*, then reads from stdin (all hex
    except the counter): the combined commitment, the combined public key,
    the counter and the private key.  The nonce is re-derived with
    ``get_nonce`` and the partial signature is
    ``(r + Hint(R + pk + digest) * a) mod l``, where ``a`` is the scalar
    built from the bits of ``ed25519raw.H(seckey)``.
    """
    message_digest = pyblake2.blake2s(data).digest()
    print('Digest: %s' % hex(message_digest))

    print('Enter combined commitment: ', end='')
    global_commit = binascii.unhexlify(input())
    print('Enter combined public key: ', end='')
    global_pk = binascii.unhexlify(input())
    print('Enter counter: ', end='')
    counter = int(input())
    print('Enter privkey: ', end='')
    secret = binascii.unhexlify(input())

    nonce, local_commit = get_nonce(secret, message_digest, counter)

    # Scalar from the hashed key: bit b-2 set, bits 3..b-3 copied from hash.
    key_hash = ed25519raw.H(secret)
    width = ed25519raw.b
    scalar = 1 << (width - 2)
    for position in range(3, width - 2):
        scalar += ed25519raw.bit(key_hash, position) << position

    challenge = ed25519raw.Hint(global_commit + global_pk + message_digest)
    partial = (nonce + challenge * scalar) % ed25519raw.l

    print('Local commit: %s' % hex(local_commit))
    print('Local sig: %s' % hex(ed25519raw.encodeint(partial)))


def combinePhase2(m):
    """Read the global commitment and *m* partial signatures; print the result.

    All values are read from stdin as hex.  The partial signatures are
    merged with ``combine_sig`` and the combined signature is printed.
    """
    print('Enter global commit: ', end='')
    global_commit = binascii.unhexlify(input())

    partials = []
    for index in range(m):
        print('Enter sig %d: ' % index, end='')
        partials.append(binascii.unhexlify(input()))

    print('Combined sig: %s' % hex(combine_sig(global_commit, partials)))


def checkSignature(data):
    """Check a signature over the BLAKE2s digest of *data*.

    Prints the digest, reads a hex public key and a hex signature from
    stdin and passes them to ``ed25519raw.checkvalid``.  The success message
    is printed only if that call returns normally (presumably it raises on
    an invalid signature -- confirm against ed25519raw).
    """
    message_digest = pyblake2.blake2s(data).digest()
    print('Digest: %s' % hex(message_digest))

    print('Enter Public Key: ', end='')
    public = binascii.unhexlify(input())
    print('Enter Sig: ', end='')
    signature = binascii.unhexlify(input())

    ed25519raw.checkvalid(signature, message_digest, public)
    print('Valid Signature!')


def usage():
    """Print the command-line help listing every keyctl phase."""
    help_lines = (
        'Usage: keyctl phase options',
        'Phases:',
        '  keyctl create_pub: create single public keys',
        '  keyctl combine_pub m n: create combined public keys',
        '  keyctl ph1 file.bin: compute partial commitment',
        '  keyctl combine_ph1 m: combine commitments',
        '  keyctl ph2 file.bin: compute partial signature',
        '  keyctl combine_ph2 m: combine signatures',
        '  keyctl check_sig file.bin: check signature',
    )
    for line in help_lines:
        print(line)


def _read_binary(filename):
    """Return the full contents of *filename*, closing the file afterwards."""
    with open(filename, 'rb') as handle:
        return handle.read()


def main():
    """Dispatch the sub-command named in ``sys.argv[1]``.

    Returns:
        1 after printing the usage text when the command is missing,
        unknown, or lacks required positional arguments; ``None`` after a
        sub-command has run.

    Raises:
        ValueError: if an ``m``/``n`` argument is not an integer.
        OSError: if the input file cannot be read.
    """
    args = sys.argv[1:]
    if not args:
        usage()
        return 1
    func, params = args[0], args[1:]

    # Number of positional arguments each sub-command requires.
    arity = {
        'create_pub': 0,
        'combine_pub': 2,
        'ph1': 1,
        'combine_ph1': 1,
        'ph2': 1,
        'combine_ph2': 1,
        'check_sig': 1,
    }
    # Unknown command or too few arguments: show help instead of failing
    # with an IndexError further down.
    if func not in arity or len(params) < arity[func]:
        usage()
        return 1

    if func == 'create_pub':
        createPubkey()
    elif func == 'combine_pub':
        combinePubkeys(int(params[0]), int(params[1]))
    elif func == 'ph1':
        phase1(_read_binary(params[0]))
    elif func == 'combine_ph1':
        combinePhase1(int(params[0]))
    elif func == 'ph2':
        phase2(_read_binary(params[0]))
    elif func == 'combine_ph2':
        combinePhase2(int(params[0]))
    elif func == 'check_sig':
        checkSignature(_read_binary(params[0]))


def test():
    """Self-test: sign ``sys.argv[1]`` with a 3-of-5 key set and verify it.

    Five deterministic key pairs are derived from the strings ``key1`` ..
    ``key5``.  The public keys and commitments of the signers listed in
    ``keyset`` are combined, every signer's partial signature is printed,
    and the partial signatures of the ``keyset`` signers are merged and
    checked with ``ed25519raw.checkvalid``.

    Raises:
        IndexError: if no command-line argument is given.
    """
    data = sys.argv[1].encode('utf-8')
    N = 5
    keyset = [1, 3, 4]

    digest = pyblake2.blake2s(data).digest()
    print('Digest: %s' % hex(digest))
    sks = []
    pks = []
    nonces = []
    commits = []
    sigs = []
    for i in range(N):
        print('----- Key %d ------' % (i + 1))
        seckey = ed25519raw.H(("key%d" % (i + 1)).encode('utf-8'))[0:32]
        pubkey = ed25519raw.publickey(seckey)
        print('Secret Key: %s' % hex(seckey))
        print('Public Key: %s' % hex(pubkey))
        sks.append(seckey)
        pks.append(pubkey)
        ctr = 0
        (r, R) = get_nonce(seckey, digest, ctr)
        print('Local nonce:  %s' % hex(ed25519raw.encodeint(r)))
        print('Local commit: %s' % hex(R))
        nonces.append(r)
        commits.append(R)

    globalPk = combine_keys([pks[i] for i in keyset])
    globalR = combine_keys([commits[i] for i in keyset])
    print('-----------------')
    print('Global pubkey: %s' % hex(globalPk))
    print('Global commit: %s' % hex(globalR))
    print('-----------------')

    # The challenge depends only on global values, so compute it once.
    challenge = ed25519raw.Hint(globalR + globalPk + digest)
    b = ed25519raw.b
    # Iterate over N (not a literal 5) so the loop follows the key count.
    for i in range(N):
        h = ed25519raw.H(sks[i])
        a = 2**(b - 2) + sum(2**j * ed25519raw.bit(h, j)
                             for j in range(3, b - 2))
        S = (nonces[i] + challenge * a) % ed25519raw.l
        encoded = ed25519raw.encodeint(S)
        print('Local sig %d: %s' % (i + 1, hex(encoded)))
        sigs.append(encoded)

    print('-----------------')
    # Use the same signer set that produced globalPk and globalR.
    sig = combine_sig(globalR, [sigs[i] for i in keyset])
    print('Global sig: %s' % hex(sig))
    ed25519raw.checkvalid(sig, digest, globalPk)
    print('Valid Signature!')


if __name__ == '__main__':
    # NOTE(review): running the script executes the self-test, which signs
    # sys.argv[1]; main() and the CLI described by usage() are never reached
    # from here -- confirm whether main() was the intended entry point.
    test()