mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-22 22:38:08 +00:00
cosi: replace slow djb implementation of ed25519 with an optimized one
from https://github.com/pyca/ed25519 This makes the calculations several orders of magnitude faster, which allows us to run the CoSi test in Travis. It also doesn't stop firmware update for several seconds while we validate the CoSi signatures. It's still essentially the same insecure implementation, fallible to all the same timing attacks, and it shouldn't be used for anything except validating public signatures of public data. But now it also takes about as much time as it should on modern hardware.
This commit is contained in:
parent
3d5fa7a2f6
commit
ba365b5486
@ -1,140 +1,299 @@
|
||||
# orignal version downloaded from https://ed25519.cr.yp.to/python/ed25519.py
|
||||
# modified for Python 3 by Jochen Hoenicke <hoenicke@gmail.com>
|
||||
# ed25519.py - Optimized version of the reference implementation of Ed25519
|
||||
# downloaded from https://github.com/pyca/ed25519
|
||||
#
|
||||
# Written in 2011? by Daniel J. Bernstein <djb@cr.yp.to>
|
||||
# 2013 by Donald Stufft <donald@stufft.io>
|
||||
# 2013 by Alex Gaynor <alex.gaynor@gmail.com>
|
||||
# 2013 by Greg Price <price@mit.edu>
|
||||
#
|
||||
# To the extent possible under law, the author(s) have dedicated all copyright
|
||||
# and related and neighboring rights to this software to the public domain
|
||||
# worldwide. This software is distributed without any warranty.
|
||||
#
|
||||
# You should have received a copy of the CC0 Public Domain Dedication along
|
||||
# with this software. If not, see
|
||||
# <http://creativecommons.org/publicdomain/zero/1.0/>.
|
||||
|
||||
"""
|
||||
NB: This code is not safe for use with secret keys or secret data.
|
||||
The only safe use of this code is for verifying signatures on public messages.
|
||||
|
||||
Functions for computing the public key of a secret key and for signing
|
||||
a message are included, namely publickey_unsafe and signature_unsafe,
|
||||
for testing purposes only.
|
||||
|
||||
The root of the problem is that Python's long-integer arithmetic is
|
||||
not designed for use in cryptography. Specifically, it may take more
|
||||
or less time to execute an operation depending on the values of the
|
||||
inputs, and its memory access patterns may also depend on the inputs.
|
||||
This opens it to timing and cache side-channel attacks which can
|
||||
disclose data to an attacker. We rely on Python's long-integer
|
||||
arithmetic, so we cannot handle secrets without risking their disclosure.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from typing import NewType, Tuple
|
||||
|
||||
Point = NewType("Point", Tuple[int, int])
|
||||
Point = NewType("Point", Tuple[int, int, int, int])
|
||||
|
||||
|
||||
__version__ = "1.0.dev1"
|
||||
|
||||
|
||||
b = 256
|
||||
q = 2 ** 255 - 19
|
||||
l = 2 ** 252 + 27742317777372353535851937790883648493
|
||||
|
||||
COORD_MASK = ~(1 + 2 + 4 + (1 << b - 1))
|
||||
COORD_HIGH_BIT = 1 << b - 2
|
||||
|
||||
|
||||
def H(m: bytes) -> bytes:
|
||||
return hashlib.sha512(m).digest()
|
||||
|
||||
|
||||
def expmod(b: int, e: int, m: int) -> int:
|
||||
if e < 0:
|
||||
raise ValueError("negative exponent")
|
||||
if e == 0:
|
||||
return 1
|
||||
t = expmod(b, e >> 1, m) ** 2 % m
|
||||
if e & 1:
|
||||
t = (t * b) % m
|
||||
return t
|
||||
def pow2(x: int, p: int) -> int:
|
||||
"""== pow(x, 2**p, q)"""
|
||||
while p > 0:
|
||||
x = x * x % q
|
||||
p -= 1
|
||||
return x
|
||||
|
||||
|
||||
def inv(x: int) -> int:
|
||||
return expmod(x, q - 2, q)
|
||||
def inv(z: int) -> int:
|
||||
"""$= z^{-1} mod q$, for z != 0"""
|
||||
# Adapted from curve25519_athlon.c in djb's Curve25519.
|
||||
z2 = z * z % q # 2
|
||||
z9 = pow2(z2, 2) * z % q # 9
|
||||
z11 = z9 * z2 % q # 11
|
||||
z2_5_0 = (z11 * z11) % q * z9 % q # 31 == 2^5 - 2^0
|
||||
z2_10_0 = pow2(z2_5_0, 5) * z2_5_0 % q # 2^10 - 2^0
|
||||
z2_20_0 = pow2(z2_10_0, 10) * z2_10_0 % q # ...
|
||||
z2_40_0 = pow2(z2_20_0, 20) * z2_20_0 % q
|
||||
z2_50_0 = pow2(z2_40_0, 10) * z2_10_0 % q
|
||||
z2_100_0 = pow2(z2_50_0, 50) * z2_50_0 % q
|
||||
z2_200_0 = pow2(z2_100_0, 100) * z2_100_0 % q
|
||||
z2_250_0 = pow2(z2_200_0, 50) * z2_50_0 % q # 2^250 - 2^0
|
||||
return pow2(z2_250_0, 5) * z11 % q # 2^255 - 2^5 + 11 = q - 2
|
||||
|
||||
|
||||
d = -121665 * inv(121666)
|
||||
I = expmod(2, (q - 1) >> 2, q)
|
||||
d = -121665 * inv(121666) % q
|
||||
I = pow(2, (q - 1) // 4, q)
|
||||
|
||||
|
||||
def xrecover(y: int) -> int:
|
||||
xx = (y * y - 1) * inv(d * y * y + 1)
|
||||
x = expmod(xx, (q + 3) >> 3, q)
|
||||
x = pow(xx, (q + 3) // 8, q)
|
||||
|
||||
if (x * x - xx) % q != 0:
|
||||
x = (x * I) % q
|
||||
|
||||
if x % 2 != 0:
|
||||
x = q - x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
By = 4 * inv(5)
|
||||
Bx = xrecover(By)
|
||||
B = Point((Bx % q, By % q))
|
||||
B = Point((Bx % q, By % q, 1, (Bx * By) % q))
|
||||
ident = Point((0, 1, 1, 0))
|
||||
|
||||
|
||||
def edwards(P: Point, Q: Point) -> Point:
|
||||
x1 = P[0]
|
||||
y1 = P[1]
|
||||
x2 = Q[0]
|
||||
y2 = Q[1]
|
||||
x3 = (x1 * y2 + x2 * y1) * inv(1 + d * x1 * x2 * y1 * y2)
|
||||
y3 = (y1 * y2 + x1 * x2) * inv(1 - d * x1 * x2 * y1 * y2)
|
||||
return Point((x3 % q, y3 % q))
|
||||
def edwards_add(P: Point, Q: Point) -> Point:
|
||||
# This is formula sequence 'addition-add-2008-hwcd-3' from
|
||||
# http://www.hyperelliptic.org/EFD/g1p/auto-twisted-extended-1.html
|
||||
(x1, y1, z1, t1) = P
|
||||
(x2, y2, z2, t2) = Q
|
||||
|
||||
a = (y1 - x1) * (y2 - x2) % q
|
||||
b = (y1 + x1) * (y2 + x2) % q
|
||||
c = t1 * 2 * d * t2 % q
|
||||
dd = z1 * 2 * z2 % q
|
||||
e = b - a
|
||||
f = dd - c
|
||||
g = dd + c
|
||||
h = b + a
|
||||
x3 = e * f
|
||||
y3 = g * h
|
||||
t3 = e * h
|
||||
z3 = f * g
|
||||
|
||||
return Point((x3 % q, y3 % q, z3 % q, t3 % q))
|
||||
|
||||
|
||||
def edwards_double(P: Point) -> Point:
|
||||
# This is formula sequence 'dbl-2008-hwcd' from
|
||||
# http://www.hyperelliptic.org/EFD/g1p/auto-twisted-extended-1.html
|
||||
(x1, y1, z1, _) = P
|
||||
|
||||
a = x1 * x1 % q
|
||||
b = y1 * y1 % q
|
||||
c = 2 * z1 * z1 % q
|
||||
# dd = -a
|
||||
e = ((x1 + y1) * (x1 + y1) - a - b) % q
|
||||
g = -a + b # dd + b
|
||||
f = g - c
|
||||
h = -a - b # dd - b
|
||||
x3 = e * f
|
||||
y3 = g * h
|
||||
t3 = e * h
|
||||
z3 = f * g
|
||||
|
||||
return Point((x3 % q, y3 % q, z3 % q, t3 % q))
|
||||
|
||||
|
||||
def scalarmult(P: Point, e: int) -> Point:
|
||||
if e == 0:
|
||||
return Point((0, 1))
|
||||
Q = scalarmult(P, e >> 1)
|
||||
Q = edwards(Q, Q)
|
||||
return ident
|
||||
Q = scalarmult(P, e // 2)
|
||||
Q = edwards_double(Q)
|
||||
if e & 1:
|
||||
Q = edwards(Q, P)
|
||||
Q = edwards_add(Q, P)
|
||||
return Q
|
||||
|
||||
|
||||
# Bpow[i] == scalarmult(B, 2**i)
|
||||
Bpow = [] # type: List[Point]
|
||||
|
||||
|
||||
def make_Bpow() -> None:
|
||||
P = B
|
||||
for _ in range(253):
|
||||
Bpow.append(P)
|
||||
P = edwards_double(P)
|
||||
|
||||
|
||||
make_Bpow()
|
||||
|
||||
|
||||
def scalarmult_B(e: int) -> Point:
|
||||
"""
|
||||
Implements scalarmult(B, e) more efficiently.
|
||||
"""
|
||||
# scalarmult(B, l) is the identity
|
||||
e = e % l
|
||||
P = ident
|
||||
for i in range(253):
|
||||
if e & 1:
|
||||
P = edwards_add(P, Bpow[i])
|
||||
e = e // 2
|
||||
assert e == 0, e
|
||||
return P
|
||||
|
||||
|
||||
def encodeint(y: int) -> bytes:
|
||||
bits = [(y >> i) & 1 for i in range(b)]
|
||||
return bytes([sum([bits[i * 8 + j] << j for j in range(8)]) for i in range(b >> 3)])
|
||||
return y.to_bytes(b // 8, "little")
|
||||
|
||||
|
||||
def encodepoint(P: Point) -> bytes:
|
||||
x = P[0]
|
||||
y = P[1]
|
||||
bits = [(y >> i) & 1 for i in range(b - 1)] + [x & 1]
|
||||
return bytes([sum([bits[i * 8 + j] << j for j in range(8)]) for i in range(b >> 3)])
|
||||
(x, y, z, _) = P
|
||||
zi = inv(z)
|
||||
x = (x * zi) % q
|
||||
y = (y * zi) % q
|
||||
|
||||
|
||||
def bit(h: bytes, i: int) -> int:
|
||||
return (h[i >> 3] >> (i & 7)) & 1
|
||||
|
||||
|
||||
def publickey(sk: bytes) -> bytes:
|
||||
h = H(sk)
|
||||
a = 2 ** (b - 2) + sum(2 ** i * bit(h, i) for i in range(3, b - 2))
|
||||
A = scalarmult(B, a)
|
||||
return encodepoint(A)
|
||||
|
||||
|
||||
def Hint(m: bytes) -> int:
|
||||
h = H(m)
|
||||
return sum(2 ** i * bit(h, i) for i in range(2 * b))
|
||||
|
||||
|
||||
def signature(m: bytes, sk: bytes, pk: bytes) -> bytes:
|
||||
h = H(sk)
|
||||
a = 2 ** (b - 2) + sum(2 ** i * bit(h, i) for i in range(3, b - 2))
|
||||
r = Hint(bytes([h[i] for i in range(b >> 3, b >> 2)]) + m)
|
||||
R = scalarmult(B, r)
|
||||
S = (r + Hint(encodepoint(R) + pk + m) * a) % l
|
||||
return encodepoint(R) + encodeint(S)
|
||||
|
||||
|
||||
def isoncurve(P: Point) -> bool:
|
||||
x = P[0]
|
||||
y = P[1]
|
||||
return (-x * x + y * y - 1 - d * x * x * y * y) % q == 0
|
||||
xbit = (x & 1) << (b - 1)
|
||||
y_result = y & ~xbit # clear x bit
|
||||
y_result |= xbit # set corret x bit value
|
||||
return encodeint(y_result)
|
||||
|
||||
|
||||
def decodeint(s: bytes) -> int:
|
||||
return sum(2 ** i * bit(s, i) for i in range(0, b))
|
||||
return int.from_bytes(s, "little")
|
||||
|
||||
|
||||
def decodepoint(s: bytes) -> Point:
|
||||
y = sum(2 ** i * bit(s, i) for i in range(0, b - 1))
|
||||
y = decodeint(s) & ~(1 << b - 1) # y without the highest bit
|
||||
x = xrecover(y)
|
||||
if x & 1 != bit(s, b - 1):
|
||||
x = q - x
|
||||
P = Point((x, y))
|
||||
P = Point((x, y, 1, (x * y) % q))
|
||||
if not isoncurve(P):
|
||||
raise ValueError("decoding point that is not on curve")
|
||||
return P
|
||||
|
||||
|
||||
def decodecoord(s: bytes) -> int:
|
||||
a = decodeint(s[: b // 8])
|
||||
# clear mask bits
|
||||
a &= COORD_MASK
|
||||
# set high bit
|
||||
a |= COORD_HIGH_BIT
|
||||
return a
|
||||
|
||||
|
||||
def bit(h: bytes, i: int) -> int:
|
||||
return (h[i // 8] >> (i % 8)) & 1
|
||||
|
||||
|
||||
def publickey_unsafe(sk: bytes) -> bytes:
|
||||
"""
|
||||
Not safe to use with secret keys or secret data.
|
||||
|
||||
See module docstring. This function should be used for testing only.
|
||||
"""
|
||||
h = H(sk)
|
||||
a = decodecoord(h)
|
||||
A = scalarmult_B(a)
|
||||
return encodepoint(A)
|
||||
|
||||
|
||||
def Hint(m: bytes) -> int:
|
||||
return decodeint(H(m))
|
||||
|
||||
|
||||
def signature_unsafe(m: bytes, sk: bytes, pk: bytes) -> bytes:
|
||||
"""
|
||||
Not safe to use with secret keys or secret data.
|
||||
|
||||
See module docstring. This function should be used for testing only.
|
||||
"""
|
||||
h = H(sk)
|
||||
a = decodecoord(h)
|
||||
r = Hint(h[b // 8 : b // 4] + m)
|
||||
R = scalarmult_B(r)
|
||||
S = (r + Hint(encodepoint(R) + pk + m) * a) % l
|
||||
return encodepoint(R) + encodeint(S)
|
||||
|
||||
|
||||
def isoncurve(P: Point) -> bool:
|
||||
(x, y, z, t) = P
|
||||
return (
|
||||
z % q != 0
|
||||
and x * y % q == z * t % q
|
||||
and (y * y - x * x - z * z - d * t * t) % q == 0
|
||||
)
|
||||
|
||||
|
||||
class SignatureMismatch(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def checkvalid(s: bytes, m: bytes, pk: bytes) -> None:
|
||||
if len(s) != b >> 2:
|
||||
"""
|
||||
Not safe to use when any argument is secret.
|
||||
|
||||
See module docstring. This function should be used only for
|
||||
verifying public signatures of public messages.
|
||||
"""
|
||||
if len(s) != b // 4:
|
||||
raise ValueError("signature length is wrong")
|
||||
if len(pk) != b >> 3:
|
||||
|
||||
if len(pk) != b // 8:
|
||||
raise ValueError("public-key length is wrong")
|
||||
R = decodepoint(s[0 : b >> 3])
|
||||
|
||||
R = decodepoint(s[: b // 8])
|
||||
A = decodepoint(pk)
|
||||
S = decodeint(s[b >> 3 : b >> 2])
|
||||
S = decodeint(s[b // 8 : b // 4])
|
||||
h = Hint(encodepoint(R) + pk + m)
|
||||
if scalarmult(B, S) != edwards(R, scalarmult(A, h)):
|
||||
raise ValueError("signature does not pass verification")
|
||||
|
||||
(x1, y1, z1, _) = P = scalarmult_B(S)
|
||||
(x2, y2, z2, _) = Q = edwards_add(R, scalarmult(A, h))
|
||||
|
||||
if (
|
||||
not isoncurve(P)
|
||||
or not isoncurve(Q)
|
||||
or (x1 * z2 - x2 * z1) % q != 0
|
||||
or (y1 * z2 - y2 * z1) % q != 0
|
||||
):
|
||||
raise SignatureMismatch("signature does not pass verification")
|
||||
|
@ -15,7 +15,7 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from functools import reduce
|
||||
from typing import Iterable, Tuple
|
||||
from typing import Iterable, List, Tuple
|
||||
|
||||
from . import _ed25519, messages
|
||||
from .tools import expect
|
||||
@ -30,7 +30,7 @@ Ed25519Signature = bytes
|
||||
def combine_keys(pks: Iterable[Ed25519PublicPoint]) -> Ed25519PublicPoint:
|
||||
"""Combine a list of Ed25519 points into a "global" CoSi key."""
|
||||
P = [_ed25519.decodepoint(pk) for pk in pks]
|
||||
combine = reduce(_ed25519.edwards, P)
|
||||
combine = reduce(_ed25519.edwards_add, P)
|
||||
return Ed25519PublicPoint(_ed25519.encodepoint(combine))
|
||||
|
||||
|
||||
@ -75,11 +75,28 @@ def verify(
|
||||
_ed25519.checkvalid(signature, digest, pub_key)
|
||||
|
||||
|
||||
def verify_m_of_n(
|
||||
signature: Ed25519Signature,
|
||||
digest: bytes,
|
||||
m: int,
|
||||
n: int,
|
||||
mask: int,
|
||||
keys: List[Ed25519PublicPoint],
|
||||
) -> None:
|
||||
selected_keys = [keys[i] for i in range(n) if mask & (1 << i)]
|
||||
if len(selected_keys) < m:
|
||||
raise ValueError(
|
||||
"Not enough signers ({} required, {} found)".format(m, len(selected_keys))
|
||||
)
|
||||
global_pk = combine_keys(selected_keys)
|
||||
return verify(signature, digest, global_pk)
|
||||
|
||||
|
||||
def pubkey_from_privkey(privkey: Ed25519PrivateKey) -> Ed25519PublicPoint:
|
||||
"""Interpret 32 bytes of data as an Ed25519 private key.
|
||||
Calculate and return the corresponding public key.
|
||||
"""
|
||||
return Ed25519PublicPoint(_ed25519.publickey(privkey))
|
||||
return Ed25519PublicPoint(_ed25519.publickey_unsafe(privkey))
|
||||
|
||||
|
||||
def sign_with_privkey(
|
||||
@ -92,16 +109,8 @@ def sign_with_privkey(
|
||||
"""Create a CoSi signature of `digest` with the supplied private key.
|
||||
This function needs to know the global public key and global commitment.
|
||||
"""
|
||||
b = _ed25519.b
|
||||
h = _ed25519.H(privkey)
|
||||
# curvepoint preparation:
|
||||
# 1. take lowest b bits of h
|
||||
a = int.from_bytes(h[: b // 8], "little")
|
||||
# 2. clear lowest three and highest bit
|
||||
bitmask = 1 + 2 + 4 + (1 << b - 1)
|
||||
a &= ~bitmask
|
||||
# 3. set next-highest bit
|
||||
a |= 1 << b - 2
|
||||
a = _ed25519.decodecoord(h)
|
||||
|
||||
S = (nonce + _ed25519.Hint(global_commit + global_pubkey + digest) * a) % _ed25519.l
|
||||
return Ed25519Signature(_ed25519.encodeint(S))
|
||||
|
@ -18,22 +18,7 @@ import hashlib
|
||||
|
||||
import pytest
|
||||
|
||||
from trezorlib import cosi
|
||||
|
||||
# These tests calculate Ed25519 signatures in pure Python.
|
||||
# In addition to being Python, this is also DJB's proof-of-concept, unoptimized code.
|
||||
# As a result, it is actually very noticeably slow. On a gen8 Core i5, this takes around 40 seconds.
|
||||
# To skip the test, run `pytest -m "not slow_cosi"`.
|
||||
|
||||
# Therefore, the tests are skipped by default.
|
||||
# Run `pytest -m slow_cosi` to explicitly enable.
|
||||
|
||||
pytestmark = pytest.mark.slow_cosi
|
||||
if "slow_cosi" not in pytest.config.getoption("-m"):
|
||||
pytestmark = pytest.mark.skip(
|
||||
"Skipping slow CoSi tests. 'pytest -m slow_cosi' to run."
|
||||
)
|
||||
|
||||
from trezorlib import _ed25519, cosi
|
||||
|
||||
RFC8032_VECTORS = (
|
||||
( # test 1
|
||||
@ -123,8 +108,8 @@ def test_single_eddsa_vector(privkey, pubkey, message, signature):
|
||||
except ValueError:
|
||||
pytest.fail("Signature does not verify.")
|
||||
|
||||
fake_signature = b"\xf1" + signature[1:]
|
||||
with pytest.raises(ValueError):
|
||||
fake_signature = signature[:37] + b"\xf0" + signature[38:]
|
||||
with pytest.raises(_ed25519.SignatureMismatch):
|
||||
cosi.verify(fake_signature, message, pubkey)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user