1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-23 14:58:09 +00:00

core/monero: use const where possible

This commit is contained in:
Pavol Rusnak 2019-09-30 12:27:23 +00:00
parent ed0336c0a9
commit 04466402ce
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D
2 changed files with 78 additions and 86 deletions

View File

@ -1,4 +1,5 @@
import gc
from micropython import const
from trezor import utils
from trezor.utils import memcpy as _memcpy
@ -8,33 +9,24 @@ from apps.monero.xmr.serialize.int_serialize import dump_uvarint_b_into, uvarint
# Constants
BP_LOG_N = 6
BP_N = 64 # 1 << BP_LOG_N
BP_M = 16 # maximal number of bulletproofs
_BP_LOG_N = const(6)
_BP_N = const(64) # 1 << _BP_LOG_N
_BP_M = const(16) # maximal number of bulletproofs
ZERO = b"\x00" * 32
ONE = b"\x01" + b"\x00" * 31
TWO = b"\x02" + b"\x00" * 31
EIGHT = b"\x08" + b"\x00" * 31
INV_EIGHT = crypto.INV_EIGHT
MINUS_ONE = b"\xec\xd3\xf5\x5c\x1a\x63\x12\x58\xd6\x9c\xf7\xa2\xde\xf9\xde\x14\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10"
MINUS_INV_EIGHT = b"\x74\xa4\x19\x7a\xf0\x7d\x0b\xf7\x05\xc2\xda\x25\x2b\x5c\x0b\x0d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0a"
_ZERO = b"\x00" * 32
_ONE = b"\x01" + b"\x00" * 31
# _TWO = b"\x02" + b"\x00" * 31
_EIGHT = b"\x08" + b"\x00" * 31
_INV_EIGHT = crypto.INV_EIGHT
_MINUS_ONE = b"\xec\xd3\xf5\x5c\x1a\x63\x12\x58\xd6\x9c\xf7\xa2\xde\xf9\xde\x14\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10"
# _MINUS_INV_EIGHT = b"\x74\xa4\x19\x7a\xf0\x7d\x0b\xf7\x05\xc2\xda\x25\x2b\x5c\x0b\x0d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0a"
# Monero H point
XMR_H = b"\x8b\x65\x59\x70\x15\x37\x99\xaf\x2a\xea\xdc\x9f\xf1\xad\xd0\xea\x6c\x72\x51\xd5\x41\x54\xcf\xa9\x2c\x17\x3a\x0d\xd3\x9c\x1f\x94"
XMR_HP = crypto.xmr_H()
# get_exponent(Gi[i], XMR_H, i * 2 + 1)
BP_GI_PRE = crypto.tcry.BP_GI_PRE
# get_exponent(Hi[i], XMR_H, i * 2)
BP_HI_PRE = crypto.tcry.BP_HI_PRE
# twoN = vector_powers(TWO, BP_N);
BP_TWO_N = crypto.tcry.BP_TWO_N
_XMR_H = b"\x8b\x65\x59\x70\x15\x37\x99\xaf\x2a\xea\xdc\x9f\xf1\xad\xd0\xea\x6c\x72\x51\xd5\x41\x54\xcf\xa9\x2c\x17\x3a\x0d\xd3\x9c\x1f\x94"
_XMR_HP = crypto.xmr_H()
# ip12 = inner_product(oneN, twoN);
BP_IP12 = b"\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
_BP_IP12 = b"\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
#
@ -112,7 +104,7 @@ def scalarmult_key(dst, P, s):
def scalarmultH(dst, x):
dst = _ensure_dst_key(dst)
crypto.decodeint_into(tmp_sc_1, x)
crypto.scalarmult_into(tmp_pt_1, XMR_HP, tmp_sc_1)
crypto.scalarmult_into(tmp_pt_1, _XMR_HP, tmp_sc_1)
crypto.encodepoint_into(dst, tmp_pt_1)
return dst
@ -585,7 +577,7 @@ class KeyVPowers(KeyVBase):
self.last_idx = item
if item == 0:
return copy_key(self.cur, ONE)
return copy_key(self.cur, _ONE)
elif item == 1:
return copy_key(self.cur, self.x)
elif item == prev + 1:
@ -628,7 +620,7 @@ def _ensure_dst_keyvect(dst=None, size=None):
return dst
def const_vector(val, elems=BP_N, copy=True):
def const_vector(val, elems=_BP_N, copy=True):
return KeyVConst(elems, val, copy)
@ -654,7 +646,7 @@ def vector_powers(x, n, dst=None, dynamic=False, **kwargs):
dst = _ensure_dst_keyvect(dst, n)
if n == 0:
return dst
dst.read(0, ONE)
dst.read(0, _ONE)
if n == 1:
return dst
dst.read(1, x)
@ -672,9 +664,9 @@ def vector_powers(x, n, dst=None, dynamic=False, **kwargs):
def vector_power_sum(x, n, dst=None):
dst = _ensure_dst_key(dst)
if n == 0:
return copy_key(dst, ZERO)
return copy_key(dst, _ZERO)
copy_key(dst, ONE)
copy_key(dst, _ONE)
if n == 1:
return dst
@ -920,11 +912,14 @@ class BulletProofBuilder:
self.use_det_masks = True
self.proof_sec = None
self.Gprec = KeyV(buffer=BP_GI_PRE, const=True)
self.Hprec = KeyV(buffer=BP_HI_PRE, const=True)
self.oneN = const_vector(ONE, 64)
self.twoN = KeyV(buffer=BP_TWO_N, const=True)
self.ip12 = BP_IP12
# BP_GI_PRE = get_exponent(Gi[i], _XMR_H, i * 2 + 1)
self.Gprec = KeyV(buffer=crypto.tcry.BP_GI_PRE, const=True)
# BP_HI_PRE = get_exponent(Hi[i], _XMR_H, i * 2)
self.Hprec = KeyV(buffer=crypto.tcry.BP_HI_PRE, const=True)
self.oneN = const_vector(_ONE, 64)
# BP_TWO_N = vector_powers(_TWO, _BP_N);
self.twoN = KeyV(buffer=crypto.tcry.BP_TWO_N, const=True)
self.ip12 = _BP_IP12
self.fnc_det_mask = None
self.tmp_sc_1 = crypto.new_scalar()
@ -943,14 +938,14 @@ class BulletProofBuilder:
num_inp = len(sv)
def e_xL(idx, d=None, is_a=True):
j, i = idx // BP_N, idx % BP_N
j, i = idx // _BP_N, idx % _BP_N
r = None
if j >= num_inp:
r = ZERO if is_a else MINUS_ONE
r = _ZERO if is_a else _MINUS_ONE
elif sv[j][i // 8] & (1 << i % 8):
r = ONE if is_a else ZERO
r = _ONE if is_a else _ZERO
else:
r = ZERO if is_a else MINUS_ONE
r = _ZERO if is_a else _MINUS_ONE
if d:
memcpy(d, 0, r, 0, 32)
return r
@ -967,7 +962,7 @@ class BulletProofBuilder:
if self.fnc_det_mask:
return self.fnc_det_mask(i, is_sL, dst)
self.tmp_det_buff[64] = int(is_sL)
memcpy(self.tmp_det_buff, 65, ZERO, 0, 4)
memcpy(self.tmp_det_buff, 65, _ZERO, 0, 4)
dump_uvarint_b_into(i, self.tmp_det_buff, 65)
crypto.hash_to_scalar_into(self.tmp_sc_1, self.tmp_det_buff)
crypto.encodeint_into(dst, self.tmp_sc_1)
@ -975,11 +970,13 @@ class BulletProofBuilder:
def _gprec_aux(self, size):
return KeyVPrecomp(
size, self.Gprec, lambda i, d: get_exponent(d, XMR_H, i * 2 + 1)
size, self.Gprec, lambda i, d: get_exponent(d, _XMR_H, i * 2 + 1)
)
def _hprec_aux(self, size):
return KeyVPrecomp(size, self.Hprec, lambda i, d: get_exponent(d, XMR_H, i * 2))
return KeyVPrecomp(
size, self.Hprec, lambda i, d: get_exponent(d, _XMR_H, i * 2)
)
def _two_aux(self, size):
# Simple recursive exponentiation from precomputed results
@ -998,21 +995,21 @@ class BulletProofBuilder:
return KeyVPrecomp(size, self.twoN, pow_two)
def sL_vct(self, ln=BP_N):
def sL_vct(self, ln=_BP_N):
return (
KeyVEval(ln, lambda i, dst: self._det_mask(i, True, dst))
if self.use_det_masks
else self.sX_gen(ln)
)
def sR_vct(self, ln=BP_N):
def sR_vct(self, ln=_BP_N):
return (
KeyVEval(ln, lambda i, dst: self._det_mask(i, False, dst))
if self.use_det_masks
else self.sX_gen(ln)
)
def sX_gen(self, ln=BP_N):
def sX_gen(self, ln=_BP_N):
gc.collect()
buff = bytearray(ln * 32)
buff_mv = memoryview(buff)
@ -1043,16 +1040,16 @@ class BulletProofBuilder:
gamma = [crypto.encodeint(x) for x in gamma]
M, logM = 1, 0
while M <= BP_M and M < len(sv):
while M <= _BP_M and M < len(sv):
logM += 1
M = 1 << logM
MN = M * BP_N
MN = M * _BP_N
V = _ensure_dst_keyvect(None, len(sv))
for i in range(len(sv)):
add_keys2(tmp_bf_0, gamma[i], sv[i], XMR_H)
add_keys2(tmp_bf_0, gamma[i], sv[i], _XMR_H)
if not proof_v8:
scalarmult_key(tmp_bf_0, tmp_bf_0, INV_EIGHT)
scalarmult_key(tmp_bf_0, tmp_bf_0, _INV_EIGHT)
V.read(i, tmp_bf_0)
aL, aR = self.aX_vcts(sv, MN)
@ -1064,7 +1061,7 @@ class BulletProofBuilder:
while True:
self.gc(10)
r = self._prove_batch_main(
V, gamma, aL, aR, hash_cache, logM, BP_LOG_N, M, BP_N, proof_v8
V, gamma, aL, aR, hash_cache, logM, _BP_LOG_N, M, _BP_N, proof_v8
)
if r[0]:
break
@ -1088,7 +1085,7 @@ class BulletProofBuilder:
vector_exponent_custom(Gprec, Hprec, aL, aR, ve)
add_keys(A, ve, scalarmult_base(tmp_bf_1, alpha))
if not proof_v8:
scalarmult_key(A, A, INV_EIGHT)
scalarmult_key(A, A, _INV_EIGHT)
self.gc(11)
# PAPER LINES 40-42
@ -1099,20 +1096,20 @@ class BulletProofBuilder:
S = _ensure_dst_key()
add_keys(S, ve, scalarmult_base(tmp_bf_1, rho))
if not proof_v8:
scalarmult_key(S, S, INV_EIGHT)
scalarmult_key(S, S, _INV_EIGHT)
del ve
self.gc(12)
# PAPER LINES 43-45
y = _ensure_dst_key()
hash_cache_mash(y, hash_cache, A, S)
if y == ZERO:
if y == _ZERO:
return (0,)
z = _ensure_dst_key()
hash_to_scalar(hash_cache, y)
copy_key(z, hash_cache)
if z == ZERO:
if z == _ZERO:
return (0,)
# Polynomial construction by coefficients
@ -1177,23 +1174,23 @@ class BulletProofBuilder:
add_keys(T1, scalarmultH(tmp_bf_1, t1), scalarmult_base(tmp_bf_2, tau1))
if not proof_v8:
scalarmult_key(T1, T1, INV_EIGHT)
scalarmult_key(T1, T1, _INV_EIGHT)
add_keys(T2, scalarmultH(tmp_bf_1, t2), scalarmult_base(tmp_bf_2, tau2))
if not proof_v8:
scalarmult_key(T2, T2, INV_EIGHT)
scalarmult_key(T2, T2, _INV_EIGHT)
del (t1, t2)
self.gc(17)
# PAPER LINES 49-51
x = _ensure_dst_key()
hash_cache_mash(x, hash_cache, z, T1, T2)
if x == ZERO:
if x == _ZERO:
return (0,)
# PAPER LINES 52-53
taux = _ensure_dst_key()
copy_key(taux, ZERO)
copy_key(taux, _ZERO)
sc_mul(taux, tau1, x)
xsq = _ensure_dst_key()
sc_mul(xsq, x, x)
@ -1230,7 +1227,7 @@ class BulletProofBuilder:
# PAPER LINES 32-33
x_ip = hash_cache_mash(None, hash_cache, x, taux, mu, t)
if x_ip == ZERO:
if x_ip == _ZERO:
return 0, None
# PHASE 2
@ -1241,7 +1238,7 @@ class BulletProofBuilder:
aprime = l
bprime = r
yinv = invert(None, y)
yinvpow = init_key(ONE)
yinvpow = init_key(_ONE)
self.gc(20)
for i in range(0, MN):
@ -1292,7 +1289,7 @@ class BulletProofBuilder:
sc_mul(tmp, cL, x_ip)
add_keys(tmp_bf_0, tmp_bf_0, scalarmultH(_tmp_k_1, tmp))
if not proof_v8:
scalarmult_key(tmp_bf_0, tmp_bf_0, INV_EIGHT)
scalarmult_key(tmp_bf_0, tmp_bf_0, _INV_EIGHT)
L.read(round, tmp_bf_0)
self.gc(24)
@ -1307,13 +1304,13 @@ class BulletProofBuilder:
sc_mul(tmp, cR, x_ip)
add_keys(tmp_bf_0, tmp_bf_0, scalarmultH(_tmp_k_1, tmp))
if not proof_v8:
scalarmult_key(tmp_bf_0, tmp_bf_0, INV_EIGHT)
scalarmult_key(tmp_bf_0, tmp_bf_0, _INV_EIGHT)
R.read(round, tmp_bf_0)
self.gc(25)
# PAPER LINES 21-22
hash_cache_mash(w_round, hash_cache, L.to(round), R.to(round))
if w_round == ZERO:
if w_round == _ZERO:
return (0,)
# PAPER LINES 24-25
@ -1395,13 +1392,13 @@ class BulletProofBuilder:
# setup weighted aggregates
is_single = len(proofs) == 1 and single_optim # ph4
z1 = init_key(ZERO)
z3 = init_key(ZERO)
m_z4 = vector_dup(ZERO, maxMN) if not is_single else None
m_z5 = vector_dup(ZERO, maxMN) if not is_single else None
m_y0 = init_key(ZERO)
y1 = init_key(ZERO)
muex_acc = init_key(ONE)
z1 = init_key(_ZERO)
z3 = init_key(_ZERO)
m_z4 = vector_dup(_ZERO, maxMN) if not is_single else None
m_z5 = vector_dup(_ZERO, maxMN) if not is_single else None
m_y0 = init_key(_ZERO)
y1 = init_key(_ZERO)
muex_acc = init_key(_ONE)
Gprec = self._gprec_aux(maxMN)
Hprec = self._hprec_aux(maxMN)
@ -1409,7 +1406,7 @@ class BulletProofBuilder:
for proof in proofs:
M = 1
logM = 0
while M <= BP_M and M < len(proof.V):
while M <= _BP_M and M < len(proof.V):
logM += 1
M = 1 << logM
@ -1421,15 +1418,15 @@ class BulletProofBuilder:
# Reconstruct the challenges
hash_cache = hash_vct_to_scalar(None, proof.V)
y = hash_cache_mash(None, hash_cache, proof.A, proof.S)
utils.ensure(y != ZERO, "y == 0")
utils.ensure(y != _ZERO, "y == 0")
z = hash_to_scalar(None, y)
copy_key(hash_cache, z)
utils.ensure(z != ZERO, "z == 0")
utils.ensure(z != _ZERO, "z == 0")
x = hash_cache_mash(None, hash_cache, z, proof.T1, proof.T2)
utils.ensure(x != ZERO, "x == 0")
utils.ensure(x != _ZERO, "x == 0")
x_ip = hash_cache_mash(None, hash_cache, x, proof.taux, proof.mu, proof.t)
utils.ensure(x_ip != ZERO, "x_ip == 0")
utils.ensure(x_ip != _ZERO, "x_ip == 0")
# PAPER LINE 61
sc_mulsub(m_y0, proof.taux, weight_y, m_y0)
@ -1437,10 +1434,10 @@ class BulletProofBuilder:
k = _ensure_dst_key()
ip1y = vector_power_sum(y, MN)
sc_mulsub(k, zpow[2], ip1y, ZERO)
sc_mulsub(k, zpow[2], ip1y, _ZERO)
for j in range(1, M + 1):
utils.ensure(j + 2 < len(zpow), "invalid zpow index")
sc_mulsub(k, zpow.to(j + 2), BP_IP12, k)
sc_mulsub(k, zpow.to(j + 2), _BP_IP12, k)
# VERIFY_line_61rl_new
sc_muladd(tmp, z, ip1y, k)
@ -1449,7 +1446,7 @@ class BulletProofBuilder:
sc_muladd(y1, tmp, weight_y, y1)
weight_y8 = init_key(weight_y)
if not proof_v8:
weight_y8 = sc_mul(None, weight_y, EIGHT)
weight_y8 = sc_mul(None, weight_y, _EIGHT)
muex = MultiExpSequential(points=[pt for pt in proof.V])
for j in range(len(proof.V)):
@ -1467,7 +1464,7 @@ class BulletProofBuilder:
weight_z8 = init_key(weight_z)
if not proof_v8:
weight_z8 = sc_mul(None, weight_z, EIGHT)
weight_z8 = sc_mul(None, weight_z, _EIGHT)
muex.add_pair(weight_z8, proof.A)
sc_mul(tmp, x, weight_z8)
@ -1487,12 +1484,12 @@ class BulletProofBuilder:
for i in range(rounds):
hash_cache_mash(tmp_bf_0, hash_cache, proof.L[i], proof.R[i])
w.read(i, tmp_bf_0)
utils.ensure(w[i] != ZERO, "w[i] == 0")
utils.ensure(w[i] != _ZERO, "w[i] == 0")
# Basically PAPER LINES 24-25
# Compute the curvepoints from G[i] and H[i]
yinvpow = init_key(ONE)
ypow = init_key(ONE)
yinvpow = init_key(_ONE)
ypow = init_key(_ONE)
yinv = invert(None, y)
self.gc(61)
@ -1588,6 +1585,6 @@ class BulletProofBuilder:
muex.add_scalar(m_z5[i])
add_keys(muex_acc, muex_acc, multiexp(None, muex, True))
if muex_acc != ONE:
if muex_acc != _ONE:
raise ValueError("Verification failure at step 2")
return True

View File

@ -1,5 +1,3 @@
from micropython import const
from apps.monero.xmr import crypto
if False:
@ -7,9 +5,6 @@ if False:
from apps.monero.xmr.types import Ge25519, Sc25519
DISPLAY_DECIMAL_POINT = const(12)
class XmrException(Exception):
pass