1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-23 07:58:09 +00:00

cosi: replace slow djb implementation of ed25519 with an optimized one

from https://github.com/pyca/ed25519

This makes the calculations several orders of magnitude faster, which
allows us to run the CoSi test in Travis. It also doesn't stop firmware
update for several seconds while we validate the CoSi signatures.

It's still essentially the same insecure implementation, fallible to all
the same timing attacks, and it shouldn't be used for anything except
validating public signatures of public data. But now it also takes about
as much time as it should on modern hardware.
This commit is contained in:
matejcik 2018-10-12 12:20:41 +02:00
parent 3d5fa7a2f6
commit ba365b5486
3 changed files with 258 additions and 105 deletions

View File

@ -1,140 +1,299 @@
# orignal version downloaded from https://ed25519.cr.yp.to/python/ed25519.py # ed25519.py - Optimized version of the reference implementation of Ed25519
# modified for Python 3 by Jochen Hoenicke <hoenicke@gmail.com> # downloaded from https://github.com/pyca/ed25519
#
# Written in 2011? by Daniel J. Bernstein <djb@cr.yp.to>
# 2013 by Donald Stufft <donald@stufft.io>
# 2013 by Alex Gaynor <alex.gaynor@gmail.com>
# 2013 by Greg Price <price@mit.edu>
#
# To the extent possible under law, the author(s) have dedicated all copyright
# and related and neighboring rights to this software to the public domain
# worldwide. This software is distributed without any warranty.
#
# You should have received a copy of the CC0 Public Domain Dedication along
# with this software. If not, see
# <http://creativecommons.org/publicdomain/zero/1.0/>.
"""
NB: This code is not safe for use with secret keys or secret data.
The only safe use of this code is for verifying signatures on public messages.
Functions for computing the public key of a secret key and for signing
a message are included, namely publickey_unsafe and signature_unsafe,
for testing purposes only.
The root of the problem is that Python's long-integer arithmetic is
not designed for use in cryptography. Specifically, it may take more
or less time to execute an operation depending on the values of the
inputs, and its memory access patterns may also depend on the inputs.
This opens it to timing and cache side-channel attacks which can
disclose data to an attacker. We rely on Python's long-integer
arithmetic, so we cannot handle secrets without risking their disclosure.
"""
import hashlib import hashlib
from typing import NewType, Tuple from typing import NewType, Tuple
Point = NewType("Point", Tuple[int, int]) Point = NewType("Point", Tuple[int, int, int, int])
__version__ = "1.0.dev1"
b = 256 b = 256
q = 2 ** 255 - 19 q = 2 ** 255 - 19
l = 2 ** 252 + 27742317777372353535851937790883648493 l = 2 ** 252 + 27742317777372353535851937790883648493
COORD_MASK = ~(1 + 2 + 4 + (1 << b - 1))
COORD_HIGH_BIT = 1 << b - 2
def H(m: bytes) -> bytes: def H(m: bytes) -> bytes:
return hashlib.sha512(m).digest() return hashlib.sha512(m).digest()
def expmod(b: int, e: int, m: int) -> int: def pow2(x: int, p: int) -> int:
if e < 0: """== pow(x, 2**p, q)"""
raise ValueError("negative exponent") while p > 0:
if e == 0: x = x * x % q
return 1 p -= 1
t = expmod(b, e >> 1, m) ** 2 % m return x
if e & 1:
t = (t * b) % m
return t
def inv(x: int) -> int: def inv(z: int) -> int:
return expmod(x, q - 2, q) """$= z^{-1} mod q$, for z != 0"""
# Adapted from curve25519_athlon.c in djb's Curve25519.
z2 = z * z % q # 2
z9 = pow2(z2, 2) * z % q # 9
z11 = z9 * z2 % q # 11
z2_5_0 = (z11 * z11) % q * z9 % q # 31 == 2^5 - 2^0
z2_10_0 = pow2(z2_5_0, 5) * z2_5_0 % q # 2^10 - 2^0
z2_20_0 = pow2(z2_10_0, 10) * z2_10_0 % q # ...
z2_40_0 = pow2(z2_20_0, 20) * z2_20_0 % q
z2_50_0 = pow2(z2_40_0, 10) * z2_10_0 % q
z2_100_0 = pow2(z2_50_0, 50) * z2_50_0 % q
z2_200_0 = pow2(z2_100_0, 100) * z2_100_0 % q
z2_250_0 = pow2(z2_200_0, 50) * z2_50_0 % q # 2^250 - 2^0
return pow2(z2_250_0, 5) * z11 % q # 2^255 - 2^5 + 11 = q - 2
d = -121665 * inv(121666) d = -121665 * inv(121666) % q
I = expmod(2, (q - 1) >> 2, q) I = pow(2, (q - 1) // 4, q)
def xrecover(y: int) -> int: def xrecover(y: int) -> int:
xx = (y * y - 1) * inv(d * y * y + 1) xx = (y * y - 1) * inv(d * y * y + 1)
x = expmod(xx, (q + 3) >> 3, q) x = pow(xx, (q + 3) // 8, q)
if (x * x - xx) % q != 0: if (x * x - xx) % q != 0:
x = (x * I) % q x = (x * I) % q
if x % 2 != 0: if x % 2 != 0:
x = q - x x = q - x
return x return x
By = 4 * inv(5) By = 4 * inv(5)
Bx = xrecover(By) Bx = xrecover(By)
B = Point((Bx % q, By % q)) B = Point((Bx % q, By % q, 1, (Bx * By) % q))
ident = Point((0, 1, 1, 0))
def edwards(P: Point, Q: Point) -> Point: def edwards_add(P: Point, Q: Point) -> Point:
x1 = P[0] # This is formula sequence 'addition-add-2008-hwcd-3' from
y1 = P[1] # http://www.hyperelliptic.org/EFD/g1p/auto-twisted-extended-1.html
x2 = Q[0] (x1, y1, z1, t1) = P
y2 = Q[1] (x2, y2, z2, t2) = Q
x3 = (x1 * y2 + x2 * y1) * inv(1 + d * x1 * x2 * y1 * y2)
y3 = (y1 * y2 + x1 * x2) * inv(1 - d * x1 * x2 * y1 * y2) a = (y1 - x1) * (y2 - x2) % q
return Point((x3 % q, y3 % q)) b = (y1 + x1) * (y2 + x2) % q
c = t1 * 2 * d * t2 % q
dd = z1 * 2 * z2 % q
e = b - a
f = dd - c
g = dd + c
h = b + a
x3 = e * f
y3 = g * h
t3 = e * h
z3 = f * g
return Point((x3 % q, y3 % q, z3 % q, t3 % q))
def edwards_double(P: Point) -> Point:
# This is formula sequence 'dbl-2008-hwcd' from
# http://www.hyperelliptic.org/EFD/g1p/auto-twisted-extended-1.html
(x1, y1, z1, _) = P
a = x1 * x1 % q
b = y1 * y1 % q
c = 2 * z1 * z1 % q
# dd = -a
e = ((x1 + y1) * (x1 + y1) - a - b) % q
g = -a + b # dd + b
f = g - c
h = -a - b # dd - b
x3 = e * f
y3 = g * h
t3 = e * h
z3 = f * g
return Point((x3 % q, y3 % q, z3 % q, t3 % q))
def scalarmult(P: Point, e: int) -> Point: def scalarmult(P: Point, e: int) -> Point:
if e == 0: if e == 0:
return Point((0, 1)) return ident
Q = scalarmult(P, e >> 1) Q = scalarmult(P, e // 2)
Q = edwards(Q, Q) Q = edwards_double(Q)
if e & 1: if e & 1:
Q = edwards(Q, P) Q = edwards_add(Q, P)
return Q return Q
# Bpow[i] == scalarmult(B, 2**i)
Bpow = [] # type: List[Point]
def make_Bpow() -> None:
P = B
for _ in range(253):
Bpow.append(P)
P = edwards_double(P)
make_Bpow()
def scalarmult_B(e: int) -> Point:
"""
Implements scalarmult(B, e) more efficiently.
"""
# scalarmult(B, l) is the identity
e = e % l
P = ident
for i in range(253):
if e & 1:
P = edwards_add(P, Bpow[i])
e = e // 2
assert e == 0, e
return P
def encodeint(y: int) -> bytes: def encodeint(y: int) -> bytes:
bits = [(y >> i) & 1 for i in range(b)] return y.to_bytes(b // 8, "little")
return bytes([sum([bits[i * 8 + j] << j for j in range(8)]) for i in range(b >> 3)])
def encodepoint(P: Point) -> bytes: def encodepoint(P: Point) -> bytes:
x = P[0] (x, y, z, _) = P
y = P[1] zi = inv(z)
bits = [(y >> i) & 1 for i in range(b - 1)] + [x & 1] x = (x * zi) % q
return bytes([sum([bits[i * 8 + j] << j for j in range(8)]) for i in range(b >> 3)]) y = (y * zi) % q
xbit = (x & 1) << (b - 1)
def bit(h: bytes, i: int) -> int: y_result = y & ~xbit # clear x bit
return (h[i >> 3] >> (i & 7)) & 1 y_result |= xbit # set corret x bit value
return encodeint(y_result)
def publickey(sk: bytes) -> bytes:
h = H(sk)
a = 2 ** (b - 2) + sum(2 ** i * bit(h, i) for i in range(3, b - 2))
A = scalarmult(B, a)
return encodepoint(A)
def Hint(m: bytes) -> int:
h = H(m)
return sum(2 ** i * bit(h, i) for i in range(2 * b))
def signature(m: bytes, sk: bytes, pk: bytes) -> bytes:
h = H(sk)
a = 2 ** (b - 2) + sum(2 ** i * bit(h, i) for i in range(3, b - 2))
r = Hint(bytes([h[i] for i in range(b >> 3, b >> 2)]) + m)
R = scalarmult(B, r)
S = (r + Hint(encodepoint(R) + pk + m) * a) % l
return encodepoint(R) + encodeint(S)
def isoncurve(P: Point) -> bool:
x = P[0]
y = P[1]
return (-x * x + y * y - 1 - d * x * x * y * y) % q == 0
def decodeint(s: bytes) -> int: def decodeint(s: bytes) -> int:
return sum(2 ** i * bit(s, i) for i in range(0, b)) return int.from_bytes(s, "little")
def decodepoint(s: bytes) -> Point: def decodepoint(s: bytes) -> Point:
y = sum(2 ** i * bit(s, i) for i in range(0, b - 1)) y = decodeint(s) & ~(1 << b - 1) # y without the highest bit
x = xrecover(y) x = xrecover(y)
if x & 1 != bit(s, b - 1): if x & 1 != bit(s, b - 1):
x = q - x x = q - x
P = Point((x, y)) P = Point((x, y, 1, (x * y) % q))
if not isoncurve(P): if not isoncurve(P):
raise ValueError("decoding point that is not on curve") raise ValueError("decoding point that is not on curve")
return P return P
def decodecoord(s: bytes) -> int:
a = decodeint(s[: b // 8])
# clear mask bits
a &= COORD_MASK
# set high bit
a |= COORD_HIGH_BIT
return a
def bit(h: bytes, i: int) -> int:
return (h[i // 8] >> (i % 8)) & 1
def publickey_unsafe(sk: bytes) -> bytes:
"""
Not safe to use with secret keys or secret data.
See module docstring. This function should be used for testing only.
"""
h = H(sk)
a = decodecoord(h)
A = scalarmult_B(a)
return encodepoint(A)
def Hint(m: bytes) -> int:
return decodeint(H(m))
def signature_unsafe(m: bytes, sk: bytes, pk: bytes) -> bytes:
"""
Not safe to use with secret keys or secret data.
See module docstring. This function should be used for testing only.
"""
h = H(sk)
a = decodecoord(h)
r = Hint(h[b // 8 : b // 4] + m)
R = scalarmult_B(r)
S = (r + Hint(encodepoint(R) + pk + m) * a) % l
return encodepoint(R) + encodeint(S)
def isoncurve(P: Point) -> bool:
(x, y, z, t) = P
return (
z % q != 0
and x * y % q == z * t % q
and (y * y - x * x - z * z - d * t * t) % q == 0
)
class SignatureMismatch(Exception):
pass
def checkvalid(s: bytes, m: bytes, pk: bytes) -> None: def checkvalid(s: bytes, m: bytes, pk: bytes) -> None:
if len(s) != b >> 2: """
Not safe to use when any argument is secret.
See module docstring. This function should be used only for
verifying public signatures of public messages.
"""
if len(s) != b // 4:
raise ValueError("signature length is wrong") raise ValueError("signature length is wrong")
if len(pk) != b >> 3:
if len(pk) != b // 8:
raise ValueError("public-key length is wrong") raise ValueError("public-key length is wrong")
R = decodepoint(s[0 : b >> 3])
R = decodepoint(s[: b // 8])
A = decodepoint(pk) A = decodepoint(pk)
S = decodeint(s[b >> 3 : b >> 2]) S = decodeint(s[b // 8 : b // 4])
h = Hint(encodepoint(R) + pk + m) h = Hint(encodepoint(R) + pk + m)
if scalarmult(B, S) != edwards(R, scalarmult(A, h)):
raise ValueError("signature does not pass verification") (x1, y1, z1, _) = P = scalarmult_B(S)
(x2, y2, z2, _) = Q = edwards_add(R, scalarmult(A, h))
if (
not isoncurve(P)
or not isoncurve(Q)
or (x1 * z2 - x2 * z1) % q != 0
or (y1 * z2 - y2 * z1) % q != 0
):
raise SignatureMismatch("signature does not pass verification")

View File

@ -15,7 +15,7 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from functools import reduce from functools import reduce
from typing import Iterable, Tuple from typing import Iterable, List, Tuple
from . import _ed25519, messages from . import _ed25519, messages
from .tools import expect from .tools import expect
@ -30,7 +30,7 @@ Ed25519Signature = bytes
def combine_keys(pks: Iterable[Ed25519PublicPoint]) -> Ed25519PublicPoint: def combine_keys(pks: Iterable[Ed25519PublicPoint]) -> Ed25519PublicPoint:
"""Combine a list of Ed25519 points into a "global" CoSi key.""" """Combine a list of Ed25519 points into a "global" CoSi key."""
P = [_ed25519.decodepoint(pk) for pk in pks] P = [_ed25519.decodepoint(pk) for pk in pks]
combine = reduce(_ed25519.edwards, P) combine = reduce(_ed25519.edwards_add, P)
return Ed25519PublicPoint(_ed25519.encodepoint(combine)) return Ed25519PublicPoint(_ed25519.encodepoint(combine))
@ -75,11 +75,28 @@ def verify(
_ed25519.checkvalid(signature, digest, pub_key) _ed25519.checkvalid(signature, digest, pub_key)
def verify_m_of_n(
signature: Ed25519Signature,
digest: bytes,
m: int,
n: int,
mask: int,
keys: List[Ed25519PublicPoint],
) -> None:
selected_keys = [keys[i] for i in range(n) if mask & (1 << i)]
if len(selected_keys) < m:
raise ValueError(
"Not enough signers ({} required, {} found)".format(m, len(selected_keys))
)
global_pk = combine_keys(selected_keys)
return verify(signature, digest, global_pk)
def pubkey_from_privkey(privkey: Ed25519PrivateKey) -> Ed25519PublicPoint: def pubkey_from_privkey(privkey: Ed25519PrivateKey) -> Ed25519PublicPoint:
"""Interpret 32 bytes of data as an Ed25519 private key. """Interpret 32 bytes of data as an Ed25519 private key.
Calculate and return the corresponding public key. Calculate and return the corresponding public key.
""" """
return Ed25519PublicPoint(_ed25519.publickey(privkey)) return Ed25519PublicPoint(_ed25519.publickey_unsafe(privkey))
def sign_with_privkey( def sign_with_privkey(
@ -92,16 +109,8 @@ def sign_with_privkey(
"""Create a CoSi signature of `digest` with the supplied private key. """Create a CoSi signature of `digest` with the supplied private key.
This function needs to know the global public key and global commitment. This function needs to know the global public key and global commitment.
""" """
b = _ed25519.b
h = _ed25519.H(privkey) h = _ed25519.H(privkey)
# curvepoint preparation: a = _ed25519.decodecoord(h)
# 1. take lowest b bits of h
a = int.from_bytes(h[: b // 8], "little")
# 2. clear lowest three and highest bit
bitmask = 1 + 2 + 4 + (1 << b - 1)
a &= ~bitmask
# 3. set next-highest bit
a |= 1 << b - 2
S = (nonce + _ed25519.Hint(global_commit + global_pubkey + digest) * a) % _ed25519.l S = (nonce + _ed25519.Hint(global_commit + global_pubkey + digest) * a) % _ed25519.l
return Ed25519Signature(_ed25519.encodeint(S)) return Ed25519Signature(_ed25519.encodeint(S))

View File

@ -18,22 +18,7 @@ import hashlib
import pytest import pytest
from trezorlib import cosi from trezorlib import _ed25519, cosi
# These tests calculate Ed25519 signatures in pure Python.
# In addition to being Python, this is also DJB's proof-of-concept, unoptimized code.
# As a result, it is actually very noticeably slow. On a gen8 Core i5, this takes around 40 seconds.
# To skip the test, run `pytest -m "not slow_cosi"`.
# Therefore, the tests are skipped by default.
# Run `pytest -m slow_cosi` to explicitly enable.
pytestmark = pytest.mark.slow_cosi
if "slow_cosi" not in pytest.config.getoption("-m"):
pytestmark = pytest.mark.skip(
"Skipping slow CoSi tests. 'pytest -m slow_cosi' to run."
)
RFC8032_VECTORS = ( RFC8032_VECTORS = (
( # test 1 ( # test 1
@ -123,8 +108,8 @@ def test_single_eddsa_vector(privkey, pubkey, message, signature):
except ValueError: except ValueError:
pytest.fail("Signature does not verify.") pytest.fail("Signature does not verify.")
fake_signature = b"\xf1" + signature[1:] fake_signature = signature[:37] + b"\xf0" + signature[38:]
with pytest.raises(ValueError): with pytest.raises(_ed25519.SignatureMismatch):
cosi.verify(fake_signature, message, pubkey) cosi.verify(fake_signature, message, pubkey)