mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-22 22:38:08 +00:00
xmr: mlsag memory optimizations, in-place computation
- pub key matrix is not ge25519 as it consumes high amount of memory - in-place computation used to reduce fragmentation overhead
This commit is contained in:
parent
a2b32115b2
commit
90fd0bb67a
@ -107,6 +107,7 @@ def sc_init_into(r, x):
|
||||
return tcry.init256_modm(r, x)
|
||||
|
||||
|
||||
sc_copy = tcry.init256_modm
|
||||
sc_get64 = tcry.get256_modm
|
||||
sc_check = tcry.check256_modm
|
||||
check_sc = tcry.check256_modm
|
||||
|
@ -42,6 +42,9 @@ and `sk` is equal to:
|
||||
Mostly ported from official Monero client, but also inspired by Mininero.
|
||||
Author: Dusan Klinec, ph4r05, 2018
|
||||
"""
|
||||
|
||||
import gc
|
||||
|
||||
from apps.monero.xmr import crypto
|
||||
|
||||
|
||||
@ -68,33 +71,45 @@ def generate_mlsag_full(
|
||||
for i in range(rows + 1):
|
||||
sk[i] = crypto.sc_0()
|
||||
|
||||
tmp_mi_rows = crypto.new_point(None)
|
||||
tmp_pt = crypto.new_point(None)
|
||||
|
||||
for i in range(cols):
|
||||
M[i][rows] = crypto.identity()
|
||||
crypto.identity_into(tmp_mi_rows) # M[i][rows]
|
||||
|
||||
for j in range(rows):
|
||||
M[i][j] = crypto.decodepoint(pubs[i][j].dest)
|
||||
M[i][rows] = crypto.point_add(
|
||||
M[i][rows], crypto.decodepoint(pubs[i][j].commitment)
|
||||
M[i][j] = pubs[i][j].dest
|
||||
crypto.point_add_into(
|
||||
tmp_mi_rows,
|
||||
tmp_mi_rows,
|
||||
crypto.decodepoint_into(tmp_pt, pubs[i][j].commitment),
|
||||
)
|
||||
pubs[i] = None
|
||||
|
||||
for j in range(len(out_pk_commitments)):
|
||||
crypto.point_sub_into(
|
||||
tmp_mi_rows,
|
||||
tmp_mi_rows,
|
||||
crypto.decodepoint_into(tmp_pt, out_pk_commitments[j]),
|
||||
) # subtract output Ci's in last row
|
||||
|
||||
# Subtract txn fee output in last row
|
||||
crypto.point_sub_into(tmp_mi_rows, tmp_mi_rows, txn_fee_key)
|
||||
M[i][rows] = crypto.encodepoint(tmp_mi_rows)
|
||||
|
||||
sk[rows] = crypto.sc_0()
|
||||
for j in range(rows):
|
||||
sk[j] = in_sk[j].dest
|
||||
sk[rows] = crypto.sc_add(sk[rows], in_sk[j].mask) # add masks in last row
|
||||
|
||||
for i in range(cols):
|
||||
for j in range(len(out_pk_commitments)):
|
||||
M[i][rows] = crypto.point_sub(
|
||||
M[i][rows], crypto.decodepoint(out_pk_commitments[j])
|
||||
) # subtract output Ci's in last row
|
||||
|
||||
# Subtract txn fee output in last row
|
||||
M[i][rows] = crypto.point_sub(M[i][rows], txn_fee_key)
|
||||
crypto.sc_add_into(sk[rows], sk[rows], in_sk[j].mask) # add masks in last row
|
||||
|
||||
for j in range(len(out_pk_commitments)):
|
||||
sk[rows] = crypto.sc_sub(
|
||||
sk[rows], out_sk_mask[j]
|
||||
crypto.sc_sub_into(
|
||||
sk[rows], sk[rows], out_sk_mask[j]
|
||||
) # subtract output masks in last row
|
||||
|
||||
del (pubs, tmp_mi_rows, tmp_pt)
|
||||
gc.collect()
|
||||
|
||||
return generate_mlsag(message, M, sk, kLRki, index, rows)
|
||||
|
||||
|
||||
@ -123,10 +138,19 @@ def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index):
|
||||
|
||||
sk[0] = in_sk.dest
|
||||
sk[1] = crypto.sc_sub(in_sk.mask, a)
|
||||
tmp_pt = crypto.new_point()
|
||||
|
||||
for i in range(cols):
|
||||
M[i][0] = crypto.decodepoint(pubs[i].dest)
|
||||
M[i][1] = crypto.point_sub(crypto.decodepoint(pubs[i].commitment), cout)
|
||||
crypto.point_sub_into(
|
||||
tmp_pt, crypto.decodepoint_into(tmp_pt, pubs[i].commitment), cout
|
||||
)
|
||||
|
||||
M[i][0] = pubs[i].dest
|
||||
M[i][1] = crypto.encodepoint(tmp_pt)
|
||||
pubs[i] = None
|
||||
|
||||
del (pubs)
|
||||
gc.collect()
|
||||
|
||||
return generate_mlsag(message, M, sk, kLRki, index, dsRows)
|
||||
|
||||
@ -174,18 +198,19 @@ def generate_first_c_and_key_images(
|
||||
:param rows: total number of rows
|
||||
:param cols: size of ring
|
||||
"""
|
||||
Ip = _key_vector(dsRows)
|
||||
rv.II = _key_vector(dsRows)
|
||||
alpha = _key_vector(rows)
|
||||
rv.ss = _key_matrix(rows, cols)
|
||||
|
||||
tmp_buff = bytearray(32)
|
||||
Hi = crypto.new_point()
|
||||
aGi = crypto.new_point()
|
||||
aHPi = crypto.new_point()
|
||||
hasher = _hasher_message(message)
|
||||
|
||||
for i in range(dsRows):
|
||||
# this is somewhat extra as compared to the Ring Confidential Tx paper
|
||||
# see footnote in From Zero to Monero section 3.3
|
||||
hasher.update(crypto.encodepoint(pk[index][i]))
|
||||
hasher.update(pk[index][i])
|
||||
if kLRki:
|
||||
raise NotImplementedError("Multisig not implemented")
|
||||
# alpha[i] = kLRki.k
|
||||
@ -194,33 +219,31 @@ def generate_first_c_and_key_images(
|
||||
# hash_point(hasher, kLRki.R, tmp_buff)
|
||||
|
||||
else:
|
||||
Hi = crypto.hash_to_point(crypto.encodepoint(pk[index][i]))
|
||||
crypto.hash_to_point_into(Hi, pk[index][i])
|
||||
alpha[i] = crypto.random_scalar()
|
||||
# L = alpha_i * G
|
||||
aGi = crypto.scalarmult_base(alpha[i])
|
||||
crypto.scalarmult_base_into(aGi, alpha[i])
|
||||
# Ri = alpha_i * H(P_i)
|
||||
aHPi = crypto.scalarmult(Hi, alpha[i])
|
||||
crypto.scalarmult_into(aHPi, Hi, alpha[i])
|
||||
# key image
|
||||
rv.II[i] = crypto.scalarmult(Hi, xx[i])
|
||||
_hash_point(hasher, aGi, tmp_buff)
|
||||
_hash_point(hasher, aHPi, tmp_buff)
|
||||
|
||||
Ip[i] = rv.II[i]
|
||||
|
||||
for i in range(dsRows, rows):
|
||||
alpha[i] = crypto.random_scalar()
|
||||
# L = alpha_i * G
|
||||
aGi = crypto.scalarmult_base(alpha[i])
|
||||
crypto.scalarmult_base_into(aGi, alpha[i])
|
||||
# for some reasons we omit calculating R here, which seems
|
||||
# contrary to the paper, but it is in the Monero official client
|
||||
# see https://github.com/monero-project/monero/blob/636153b2050aa0642ba86842c69ac55a5d81618d/src/ringct/rctSigs.cpp#L191
|
||||
_hash_point(hasher, pk[index][i], tmp_buff)
|
||||
hasher.update(pk[index][i])
|
||||
_hash_point(hasher, aGi, tmp_buff)
|
||||
|
||||
# the first c
|
||||
c_old = hasher.digest()
|
||||
c_old = crypto.decodeint(c_old)
|
||||
return c_old, Ip, alpha
|
||||
return c_old, alpha
|
||||
|
||||
|
||||
def generate_mlsag(message, pk, xx, kLRki, index, dsRows):
|
||||
@ -240,17 +263,22 @@ def generate_mlsag(message, pk, xx, kLRki, index, dsRows):
|
||||
rows, cols = gen_mlsag_assert(pk, xx, kLRki, index, dsRows)
|
||||
|
||||
rv = MgSig()
|
||||
c, L, R, Hi = 0, None, None, None
|
||||
rv.cc = crypto.new_scalar()
|
||||
c = crypto.new_scalar()
|
||||
L = crypto.new_point()
|
||||
R = crypto.new_point()
|
||||
Hi = crypto.new_point()
|
||||
|
||||
# calculates the "first" c, key images and random scalars alpha
|
||||
c_old, Ip, alpha = generate_first_c_and_key_images(
|
||||
c_old, alpha = generate_first_c_and_key_images(
|
||||
message, rv, pk, xx, kLRki, index, dsRows, rows, cols
|
||||
)
|
||||
|
||||
i = (index + 1) % cols
|
||||
if i == 0:
|
||||
rv.cc = c_old
|
||||
crypto.sc_copy(rv.cc, c_old)
|
||||
|
||||
rv.ss = [None] * cols
|
||||
tmp_buff = bytearray(32)
|
||||
while i != index:
|
||||
rv.ss[i] = _generate_random_vector(rows)
|
||||
@ -258,27 +286,38 @@ def generate_mlsag(message, pk, xx, kLRki, index, dsRows):
|
||||
|
||||
for j in range(dsRows):
|
||||
# L = rv.ss[i][j] * G + c_old * pk[i][j]
|
||||
L = crypto.add_keys2(rv.ss[i][j], c_old, pk[i][j])
|
||||
Hi = crypto.hash_to_point(crypto.encodepoint(pk[i][j]))
|
||||
crypto.add_keys2_into(
|
||||
L, rv.ss[i][j], c_old, crypto.decodepoint_into(Hi, pk[i][j])
|
||||
)
|
||||
crypto.hash_to_point_into(Hi, pk[i][j])
|
||||
|
||||
# R = rv.ss[i][j] * H(pk[i][j]) + c_old * Ip[j]
|
||||
R = crypto.add_keys3(rv.ss[i][j], Hi, c_old, rv.II[j])
|
||||
_hash_point(hasher, pk[i][j], tmp_buff)
|
||||
crypto.add_keys3_into(R, rv.ss[i][j], Hi, c_old, rv.II[j])
|
||||
|
||||
hasher.update(pk[i][j])
|
||||
_hash_point(hasher, L, tmp_buff)
|
||||
_hash_point(hasher, R, tmp_buff)
|
||||
|
||||
for j in range(dsRows, rows):
|
||||
# again, omitting R here as discussed above
|
||||
L = crypto.add_keys2(rv.ss[i][j], c_old, pk[i][j])
|
||||
_hash_point(hasher, pk[i][j], tmp_buff)
|
||||
crypto.add_keys2_into(
|
||||
L, rv.ss[i][j], c_old, crypto.decodepoint_into(Hi, pk[i][j])
|
||||
)
|
||||
hasher.update(pk[i][j])
|
||||
_hash_point(hasher, L, tmp_buff)
|
||||
|
||||
c = crypto.decodeint(hasher.digest())
|
||||
c_old = c
|
||||
crypto.decodeint_into(c, hasher.digest())
|
||||
crypto.sc_copy(c_old, c)
|
||||
pk[i] = None
|
||||
i = (i + 1) % cols
|
||||
|
||||
if i == 0:
|
||||
rv.cc = c_old
|
||||
crypto.sc_copy(rv.cc, c_old)
|
||||
gc.collect()
|
||||
|
||||
del rv.II
|
||||
|
||||
rv.ss[index] = [None] * rows
|
||||
for j in range(rows):
|
||||
rv.ss[index][j] = crypto.sc_mulsub(c, xx[j], alpha[j])
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user