"""Prototype tool for m-of-n combined Ed25519 keys and signatures.

Each signer derives a nonce commitment (phase 1) and a partial signature
(phase 2); commitments, public keys and partial signatures are combined
into a single signature that ``ed25519raw.checkvalid`` is asked to verify.
"""

import sys
import binascii

import ed25519raw
import pyblake2


def hex(by):
    """Return the lowercase hex string for the bytes-like object *by*.

    NOTE: intentionally keeps the original name even though it shadows the
    builtin ``hex`` inside this module.
    """
    return binascii.hexlify(by).decode()


def _secret_scalar(seckey):
    """Return the integer scalar derived from ``ed25519raw.H(seckey)``.

    Bit ``b - 2`` is always set and bits ``3 .. b - 3`` are copied from the
    hash, where ``b`` is ``ed25519raw.b``.
    """
    h = ed25519raw.H(seckey)
    b = ed25519raw.b
    return 2 ** (b - 2) + sum(
        2 ** i * ed25519raw.bit(h, i) for i in range(3, b - 2)
    )


def _read_file(filename):
    """Return the full binary contents of *filename*, closing the file."""
    with open(filename, 'rb') as f:
        return f.read()


def combine_keys(pks):
    """Combine encoded points into one encoded point.

    Every element of *pks* is decoded with ``ed25519raw.decodepoint`` and
    folded together with ``ed25519raw.edwards``; the result is re-encoded.
    The same routine is used for public keys and for nonce commitments.
    """
    combine = None
    for pk in pks:
        P = ed25519raw.decodepoint(pk)
        # Explicit None check: "no point accumulated yet", not truthiness.
        if combine is None:
            combine = P
        else:
            combine = ed25519raw.edwards(combine, P)
    return ed25519raw.encodepoint(combine)


def combine_sig(R, sigs):
    """Return ``R`` followed by the encoded sum of *sigs* modulo ``l``.

    *R* is the encoded global commitment; *sigs* are encoded integers as
    produced by ``ed25519raw.encodeint``.
    """
    s = sum(ed25519raw.decodeint(si) for si in sigs) % ed25519raw.l
    return R + ed25519raw.encodeint(s)


def binom(n, k):
    """Return the binomial coefficient "n choose k" using integer math."""
    b = 1
    for i in range(1, k + 1):
        # The running product is always divisible by i, so // is exact.
        b = b * (n - k + i) // i
    return b


def compute_mask(combination, m, n):
    """Return the signer bitmask for the *combination*-th m-of-n subset.

    Bit ``j`` of the result is set when signer ``j`` belongs to the subset.
    For example, with ``m=2`` and ``n=3`` the combinations 0, 1, 2 map to
    the masks 0b011, 0b101, 0b110.
    """
    result = 0
    signer = 0
    while m > 0:
        m = m - 1
        n = n - 1
        # Number of subsets that include the current signer.
        numst = binom(n, m)
        while combination >= numst:
            # Skip the current signer and every subset containing it.
            combination -= numst
            signer = signer + 1
            n = n - 1
            numst = binom(n, m)
        result |= 1 << signer
        signer = signer + 1
    return result


def createPubkey():
    """Prompt for randomness and print the derived secret/public key pair.

    The secret key is the first 32 bytes of ``ed25519raw.H`` applied to the
    entered text.
    """
    print('Enter randomness: ', end='')
    seckey = ed25519raw.H(input().encode())[0:32]
    pubkey = ed25519raw.publickey(seckey)
    print('Secret Key: %s' % hex(seckey))
    print('Public Key: %s' % hex(pubkey))


def combinePubkeys(m, n):
    """Prompt for *n* public keys and print every m-of-n combined key.

    Raises:
        Exception: if there are more than 100 possible combinations.
    """
    if binom(n, m) > 100:
        raise Exception('Too many keys')
    pks = []
    for i in range(0, n):
        print('Enter pubkey %d: ' % i, end='')
        pk = binascii.unhexlify(input())
        # remove 00 prefix if present
        if len(pk) == 33:
            pk = pk[1:]
        pks.append(ed25519raw.decodepoint(pk))
    for i in range(0, binom(n, m)):
        mask = compute_mask(i, m, n)
        # Named "combined" rather than "sum" to avoid shadowing the builtin.
        combined = None
        for j in range(0, n):
            if mask & (1 << j) != 0:
                if combined is None:
                    combined = pks[j]
                else:
                    combined = ed25519raw.edwards(combined, pks[j])
        pk = ed25519raw.encodepoint(combined)
        print('Key %02x: %s' % (mask, hex(pk)))


def get_nonce(sk, data, ctr):
    """Derive the nonce for (*sk*, *data*, *ctr*).

    The nonce ``r`` is ``ed25519raw.Hint`` of bytes ``b/8 .. b/4`` of
    ``H(sk)``, followed by *data* and the counter rendered as at least
    eight hex digits.

    Returns:
        A tuple ``(r, encoded_R)`` where ``encoded_R`` is the encoding of
        ``ed25519raw.scalarmult(ed25519raw.B, r)``.
    """
    h = ed25519raw.H(sk)
    b = ed25519raw.b
    prefix = bytes(h[i] for i in range(b >> 3, b >> 2))
    r = ed25519raw.Hint(prefix + data + binascii.unhexlify('%08x' % ctr))
    R = ed25519raw.scalarmult(ed25519raw.B, r)
    return r, ed25519raw.encodepoint(R)


def phase1(data):
    """Prompt for counter and private key; print the local commitment.

    The commitment is computed over the BLAKE2s digest of *data*.
    """
    digest = pyblake2.blake2s(data).digest()
    print('Digest: %s' % hex(digest))
    print('Enter counter (small integer): ', end='')
    ctr = int(input())
    print('Enter privkey: ', end='')
    seckey = binascii.unhexlify(input())
    _, R = get_nonce(seckey, digest, ctr)
    print('Local commit: %s' % hex(R))


def combinePhase1(m):
    """Prompt for *m* local commitments and print the global commitment."""
    commits = []
    for i in range(0, m):
        print('Enter commit %d: ' % i, end='')
        commits.append(binascii.unhexlify(input()))
    print('Global commit: %s' % hex(combine_keys(commits)))


def phase2(data):
    """Prompt for the combined values and print the local signature share.

    The share is ``(r + Hint(R + pk + digest) * a) mod l`` where ``r`` is
    the local nonce, ``R`` the combined commitment, ``pk`` the combined
    public key and ``a`` the scalar derived from the private key.
    """
    digest = pyblake2.blake2s(data).digest()
    print('Digest: %s' % hex(digest))
    print('Enter combined commitment: ', end='')
    R = binascii.unhexlify(input())
    print('Enter combined public key: ', end='')
    pk = binascii.unhexlify(input())
    print('Enter counter: ', end='')
    ctr = int(input())
    print('Enter privkey: ', end='')
    seckey = binascii.unhexlify(input())
    (r, Ri) = get_nonce(seckey, digest, ctr)
    a = _secret_scalar(seckey)
    S = (r + ed25519raw.Hint(R + pk + digest) * a) % ed25519raw.l
    print('Local commit: %s' % hex(Ri))
    print('Local sig: %s' % hex(ed25519raw.encodeint(S)))


def combinePhase2(m):
    """Prompt for the global commitment and *m* shares; print the signature."""
    sigs = []
    print('Enter global commit: ', end='')
    R = binascii.unhexlify(input())
    for i in range(0, m):
        print('Enter sig %d: ' % i, end='')
        sigs.append(binascii.unhexlify(input()))
    sig = combine_sig(R, sigs)
    print('Combined sig: %s' % hex(sig))


def checkSignature(data):
    """Prompt for a public key and signature and check them against *data*.

    'Valid Signature!' is printed only if ``ed25519raw.checkvalid`` returns
    without raising; any exception it raises propagates to the caller.
    """
    digest = pyblake2.blake2s(data).digest()
    print('Digest: %s' % hex(digest))
    print('Enter Public Key: ', end='')
    pubkey = binascii.unhexlify(input())
    print('Enter Sig: ', end='')
    sig = binascii.unhexlify(input())
    ed25519raw.checkvalid(sig, digest, pubkey)
    print('Valid Signature!')


def usage():
    """Print the command-line help text."""
    print('Usage: keyctl phase options')
    print('Phases:')
    print(' keyctl create_pub: create single public keys')
    print(' keyctl combine_pub m n: create combined public keys')
    print(' keyctl ph1 file.bin: compute partial commitment')
    print(' keyctl combine_ph1 m: combine commitments')
    print(' keyctl ph2 file.bin: compute partial signature')
    print(' keyctl combine_ph2 m: combine signatures')
    print(' keyctl check_sig file.bin: check signature')


def main():
    """Dispatch on ``sys.argv[1]`` to the matching phase.

    Returns:
        1 when no phase is given or the phase is missing arguments (usage
        is printed); otherwise None.
    """
    if len(sys.argv) < 2:
        usage()
        return 1

    func = sys.argv[1]
    args = sys.argv[2:]

    # Number of positional arguments each phase needs after its name.
    required = {
        'create_pub': 0,
        'combine_pub': 2,
        'ph1': 1,
        'combine_ph1': 1,
        'ph2': 1,
        'combine_ph2': 1,
        'check_sig': 1,
    }
    # Report missing arguments via usage instead of crashing on IndexError.
    if len(args) < required.get(func, 0):
        usage()
        return 1

    if func == 'create_pub':
        createPubkey()
    elif func == 'combine_pub':
        m = int(args[0])
        n = int(args[1])
        combinePubkeys(m, n)
    elif func == 'ph1':
        phase1(_read_file(args[0]))
    elif func == 'combine_ph1':
        combinePhase1(int(args[0]))
    elif func == 'ph2':
        phase2(_read_file(args[0]))
    elif func == 'combine_ph2':
        combinePhase2(int(args[0]))
    elif func == 'check_sig':
        checkSignature(_read_file(args[0]))
    else:
        usage()


def test(data):
    """Run a 2-of-3 signing round over *data* with fixed test keys.

    Three keys are built from the repeated bytes 0x41, 0x42 and 0x43; the
    keys at indices 0 and 2 form the signing set. Every intermediate value
    is printed and the combined signature is passed to
    ``ed25519raw.checkvalid``.
    """
    N = 3
    keyset = [0, 2]

    digest = pyblake2.blake2s(data).digest()
    print('Digest: %s' % hex(digest))

    sks = []
    pks = []
    nonces = []
    commits = []
    sigs = []
    for i in range(0, N):
        print('----- Key %d ------' % (i + 1))
        seckey = bytes([0x41 + i]) * 32
        pubkey = ed25519raw.publickey(seckey)
        print('Secret Key: %s' % hex(seckey))
        print('Public Key: %s' % hex(pubkey))
        sks.append(seckey)
        pks.append(pubkey)
        ctr = 0
        r, R = get_nonce(seckey, digest, ctr)
        print('Local nonce: %s' % hex(ed25519raw.encodeint(r)))
        print('Local commit: %s' % hex(R))
        nonces.append(r)
        commits.append(R)

    global_pk = combine_keys([pks[i] for i in keyset])
    global_R = combine_keys([commits[i] for i in keyset])
    print('-----------------')
    print('Global pubkey: %s' % hex(global_pk))
    print('Global commit: %s' % hex(global_R))
    print('-----------------')

    # Shares are computed for all N keys so that sigs can be indexed by
    # the positions listed in keyset below.
    for i, (seckey, r) in enumerate(zip(sks, nonces)):
        a = _secret_scalar(seckey)
        S = (r + ed25519raw.Hint(global_R + global_pk + digest) * a) % ed25519raw.l
        print('Local sig %d: %s' % (i + 1, hex(ed25519raw.encodeint(S))))
        sigs.append(ed25519raw.encodeint(S))

    print('-----------------')
    sig = combine_sig(global_R, [sigs[i] for i in keyset])
    print('Global sig: %s' % hex(sig))
    ed25519raw.checkvalid(sig, digest, global_pk)
    print('Valid Signature!')


# NOTE(review): the entry point runs the self-test, not main(); confirm
# this is intended before wiring the CLI phases in.
if __name__ == '__main__':
    if len(sys.argv) > 1:
        test(data=sys.argv[1].encode())
    else:
        test(data=b'test')