From 5884d1c03fd11c90464289267a35c8e41931fed5 Mon Sep 17 00:00:00 2001 From: Pavol Rusnak Date: Tue, 3 Oct 2017 23:48:12 +0200 Subject: [PATCH] tools: update ed25519{cosi,raw}.py to work in both py2 and py3 --- tools/ed25519cosi.py | 10 +++++++--- tools/ed25519raw.py | 32 +++++++++++++++++++++----------- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/tools/ed25519cosi.py b/tools/ed25519cosi.py index 4636faed3f..849819c56f 100644 --- a/tools/ed25519cosi.py +++ b/tools/ed25519cosi.py @@ -1,5 +1,7 @@ +import sys from functools import reduce import binascii + import ed25519raw @@ -19,7 +21,10 @@ def combine_sig(R, sigs): def get_nonce(sk, data, ctr): h = ed25519raw.H(sk) b = ed25519raw.b - r = ed25519raw.Hint(bytes([h[i] for i in range(b >> 3, b >> 2)]) + data + binascii.unhexlify('%08x' % ctr)) + if sys.version_info.major < 3: + r = ed25519raw.Hint(''.join([h[i] for i in range(b >> 3, b >> 2)]) + data + binascii.unhexlify('%08x' % ctr)) + else: + r = ed25519raw.Hint(bytes([h[i] for i in range(b >> 3, b >> 2)]) + data + binascii.unhexlify('%08x' % ctr)) R = ed25519raw.scalarmult(ed25519raw.B, r) return r, ed25519raw.encodepoint(R) @@ -41,7 +46,7 @@ def self_test(digest): sigs = [] for i in range(0, N): print('----- Key %d ------' % (i + 1)) - seckey = bytes([0x41 + i]) * 32 + seckey = (chr(0x41 + i) * 32).encode() pubkey = ed25519raw.publickey(seckey) print('Secret Key: %s' % to_hex(seckey)) print('Public Key: %s' % to_hex(pubkey)) @@ -81,7 +86,6 @@ def self_test(digest): if __name__ == '__main__': - import sys if len(sys.argv) > 1: self_test(digest=sys.argv[1]) else: diff --git a/tools/ed25519raw.py b/tools/ed25519raw.py index 24c7d68711..1e75b09269 100644 --- a/tools/ed25519raw.py +++ b/tools/ed25519raw.py @@ -1,11 +1,12 @@ # orignal version downloaded from https://ed25519.cr.yp.to/python/ed25519.py # modified for Python 3 by Jochen Hoenicke +import sys import hashlib b = 256 -q = 2**255 - 19 -l = 2**252 + 27742317777372353535851937790883648493 +q = 2 ** 255 - 19 +l = 2 ** 252 + 27742317777372353535851937790883648493 def H(m): @@ -17,7 +18,7 @@ def expmod(b, e, m): raise Exception('negative exponent') if e == 0: return 1 - t = expmod(b, e >> 1, m)**2 % m + t = expmod(b, e >> 1, m) ** 2 % m if e & 1: t = (t * b) % m return t @@ -68,35 +69,44 @@ def scalarmult(P, e): def encodeint(y): bits = [(y >> i) & 1 for i in range(b)] - return bytes([sum([bits[i * 8 + j] << j for j in range(8)]) for i in range(b >> 3)]) + if sys.version_info.major < 3: + return ''.join([chr(sum([bits[i * 8 + j] << j for j in range(8)])) for i in range(b >> 3)]) + else: + return bytes([sum([bits[i * 8 + j] << j for j in range(8)]) for i in range(b >> 3)]) def encodepoint(P): x = P[0] y = P[1] bits = [(y >> i) & 1 for i in range(b - 1)] + [x & 1] - return bytes([sum([bits[i * 8 + j] << j for j in range(8)]) for i in range(b >> 3)]) + if sys.version_info.major < 3: + return ''.join([chr(sum([bits[i * 8 + j] << j for j in range(8)])) for i in range(b >> 3)]) + else: + return bytes([sum([bits[i * 8 + j] << j for j in range(8)]) for i in range(b >> 3)]) def bit(h, i): - return (h[i >> 3] >> (i & 7)) & 1 + if sys.version_info.major < 3: + return (ord(h[i >> 3]) >> (i & 7)) & 1 + else: + return (h[i >> 3] >> (i & 7)) & 1 def publickey(sk): h = H(sk) - a = 2**(b - 2) + sum(2**i * bit(h, i) for i in range(3, b - 2)) + a = 2 ** (b - 2) + sum(2 ** i * bit(h, i) for i in range(3, b - 2)) A = scalarmult(B, a) return encodepoint(A) def Hint(m): h = H(m) - return sum(2**i * bit(h, i) for i in range(2 * b)) + return sum(2 ** i * bit(h, i) for i in range(2 * b)) def signature(m, sk, pk): h = H(sk) - a = 2**(b - 2) + sum(2**i * bit(h, i) for i in range(3, b - 2)) + a = 2 ** (b - 2) + sum(2 ** i * bit(h, i) for i in range(3, b - 2)) r = Hint(bytes([h[i] for i in range(b >> 3, b >> 2)]) + m) R = scalarmult(B, r) S = (r + Hint(encodepoint(R) + pk + m) * a) % l @@ -110,11 +120,11 @@ def isoncurve(P): def decodeint(s): - return sum(2**i * bit(s, i) for i in range(0, b)) + return sum(2 ** i * bit(s, i) for i in range(0, b)) def decodepoint(s): - y = sum(2**i * bit(s, i) for i in range(0, b - 1)) + y = sum(2 ** i * bit(s, i) for i in range(0, b - 1)) x = xrecover(y) if x & 1 != bit(s, b - 1): x = q - x