1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-27 01:48:17 +00:00

ed25519: type hints

This commit is contained in:
matejcik 2018-05-28 14:17:11 +02:00
parent 2fdb5cd538
commit 0e8fe9e743

View File

@ -2,17 +2,20 @@
# modified for Python 3 by Jochen Hoenicke <hoenicke@gmail.com> # modified for Python 3 by Jochen Hoenicke <hoenicke@gmail.com>
import hashlib import hashlib
from typing import Tuple, NewType
Point = NewType("Point", Tuple[int, int])
b = 256 b = 256
q = 2 ** 255 - 19 q = 2 ** 255 - 19
l = 2 ** 252 + 27742317777372353535851937790883648493 l = 2 ** 252 + 27742317777372353535851937790883648493
def H(m): def H(m: bytes) -> bytes:
return hashlib.sha512(m).digest() return hashlib.sha512(m).digest()
def expmod(b, e, m): def expmod(b: int, e: int, m: int) -> int:
if e < 0: if e < 0:
raise ValueError('negative exponent') raise ValueError('negative exponent')
if e == 0: if e == 0:
@ -23,7 +26,7 @@ def expmod(b, e, m):
return t return t
def inv(x): def inv(x: int) -> int:
return expmod(x, q - 2, q) return expmod(x, q - 2, q)
@ -31,7 +34,7 @@ d = -121665 * inv(121666)
I = expmod(2, (q - 1) >> 2, q) I = expmod(2, (q - 1) >> 2, q)
def xrecover(y): def xrecover(y: int) -> int:
xx = (y * y - 1) * inv(d * y * y + 1) xx = (y * y - 1) * inv(d * y * y + 1)
x = expmod(xx, (q + 3) >> 3, q) x = expmod(xx, (q + 3) >> 3, q)
if (x * x - xx) % q != 0: if (x * x - xx) % q != 0:
@ -43,22 +46,22 @@ def xrecover(y):
By = 4 * inv(5) By = 4 * inv(5)
Bx = xrecover(By) Bx = xrecover(By)
B = [Bx % q, By % q] B = Point((Bx % q, By % q))
def edwards(P, Q): def edwards(P: Point, Q: Point) -> Point:
x1 = P[0] x1 = P[0]
y1 = P[1] y1 = P[1]
x2 = Q[0] x2 = Q[0]
y2 = Q[1] y2 = Q[1]
x3 = (x1 * y2 + x2 * y1) * inv(1 + d * x1 * x2 * y1 * y2) x3 = (x1 * y2 + x2 * y1) * inv(1 + d * x1 * x2 * y1 * y2)
y3 = (y1 * y2 + x1 * x2) * inv(1 - d * x1 * x2 * y1 * y2) y3 = (y1 * y2 + x1 * x2) * inv(1 - d * x1 * x2 * y1 * y2)
return [x3 % q, y3 % q] return Point((x3 % q, y3 % q))
def scalarmult(P, e): def scalarmult(P: Point, e: int) -> Point:
if e == 0: if e == 0:
return [0, 1] return Point((0, 1))
Q = scalarmult(P, e >> 1) Q = scalarmult(P, e >> 1)
Q = edwards(Q, Q) Q = edwards(Q, Q)
if e & 1: if e & 1:
@ -66,35 +69,35 @@ def scalarmult(P, e):
return Q return Q
def encodeint(y): def encodeint(y: int) -> bytes:
bits = [(y >> i) & 1 for i in range(b)] bits = [(y >> i) & 1 for i in range(b)]
return bytes([sum([bits[i * 8 + j] << j for j in range(8)]) for i in range(b >> 3)]) return bytes([sum([bits[i * 8 + j] << j for j in range(8)]) for i in range(b >> 3)])
def encodepoint(P): def encodepoint(P: Point) -> bytes:
x = P[0] x = P[0]
y = P[1] y = P[1]
bits = [(y >> i) & 1 for i in range(b - 1)] + [x & 1] bits = [(y >> i) & 1 for i in range(b - 1)] + [x & 1]
return bytes([sum([bits[i * 8 + j] << j for j in range(8)]) for i in range(b >> 3)]) return bytes([sum([bits[i * 8 + j] << j for j in range(8)]) for i in range(b >> 3)])
def bit(h, i): def bit(h: bytes, i: int) -> int:
return (h[i >> 3] >> (i & 7)) & 1 return (h[i >> 3] >> (i & 7)) & 1
def publickey(sk): def publickey(sk: bytes) -> bytes:
h = H(sk) h = H(sk)
a = 2 ** (b - 2) + sum(2 ** i * bit(h, i) for i in range(3, b - 2)) a = 2 ** (b - 2) + sum(2 ** i * bit(h, i) for i in range(3, b - 2))
A = scalarmult(B, a) A = scalarmult(B, a)
return encodepoint(A) return encodepoint(A)
def Hint(m): def Hint(m: bytes) -> int:
h = H(m) h = H(m)
return sum(2 ** i * bit(h, i) for i in range(2 * b)) return sum(2 ** i * bit(h, i) for i in range(2 * b))
def signature(m, sk, pk): def signature(m: bytes, sk: bytes, pk: bytes) -> bytes:
h = H(sk) h = H(sk)
a = 2 ** (b - 2) + sum(2 ** i * bit(h, i) for i in range(3, b - 2)) a = 2 ** (b - 2) + sum(2 ** i * bit(h, i) for i in range(3, b - 2))
r = Hint(bytes([h[i] for i in range(b >> 3, b >> 2)]) + m) r = Hint(bytes([h[i] for i in range(b >> 3, b >> 2)]) + m)
@ -103,28 +106,28 @@ def signature(m, sk, pk):
return encodepoint(R) + encodeint(S) return encodepoint(R) + encodeint(S)
def isoncurve(P): def isoncurve(P: Point) -> bool:
x = P[0] x = P[0]
y = P[1] y = P[1]
return (-x * x + y * y - 1 - d * x * x * y * y) % q == 0 return (-x * x + y * y - 1 - d * x * x * y * y) % q == 0
def decodeint(s): def decodeint(s: bytes) -> int:
return sum(2 ** i * bit(s, i) for i in range(0, b)) return sum(2 ** i * bit(s, i) for i in range(0, b))
def decodepoint(s): def decodepoint(s: bytes) -> Point:
y = sum(2 ** i * bit(s, i) for i in range(0, b - 1)) y = sum(2 ** i * bit(s, i) for i in range(0, b - 1))
x = xrecover(y) x = xrecover(y)
if x & 1 != bit(s, b - 1): if x & 1 != bit(s, b - 1):
x = q - x x = q - x
P = [x, y] P = Point((x, y))
if not isoncurve(P): if not isoncurve(P):
raise ValueError('decoding point that is not on curve') raise ValueError('decoding point that is not on curve')
return P return P
def checkvalid(s, m, pk): def checkvalid(s: bytes, m: bytes, pk: bytes) -> None:
if len(s) != b >> 2: if len(s) != b >> 2:
raise ValueError('signature length is wrong') raise ValueError('signature length is wrong')
if len(pk) != b >> 3: if len(pk) != b >> 3: