1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-10 15:30:55 +00:00

xmr: mlsag memory optimizations, in-place computation

- pub key matrix is not ge25519 as it consumes high amount of memory
- in-place computation used to reduce fragmentation overhead
This commit is contained in:
Dusan Klinec 2018-11-01 17:18:40 +01:00
parent a2b32115b2
commit 90fd0bb67a
No known key found for this signature in database
GPG Key ID: 6337E118CCBCE103
2 changed files with 81 additions and 41 deletions

View File

@ -107,6 +107,7 @@ def sc_init_into(r, x):
return tcry.init256_modm(r, x) return tcry.init256_modm(r, x)
sc_copy = tcry.init256_modm
sc_get64 = tcry.get256_modm sc_get64 = tcry.get256_modm
sc_check = tcry.check256_modm sc_check = tcry.check256_modm
check_sc = tcry.check256_modm check_sc = tcry.check256_modm

View File

@ -42,6 +42,9 @@ and `sk` is equal to:
Mostly ported from official Monero client, but also inspired by Mininero. Mostly ported from official Monero client, but also inspired by Mininero.
Author: Dusan Klinec, ph4r05, 2018 Author: Dusan Klinec, ph4r05, 2018
""" """
import gc
from apps.monero.xmr import crypto from apps.monero.xmr import crypto
@ -68,33 +71,45 @@ def generate_mlsag_full(
for i in range(rows + 1): for i in range(rows + 1):
sk[i] = crypto.sc_0() sk[i] = crypto.sc_0()
tmp_mi_rows = crypto.new_point(None)
tmp_pt = crypto.new_point(None)
for i in range(cols): for i in range(cols):
M[i][rows] = crypto.identity() crypto.identity_into(tmp_mi_rows) # M[i][rows]
for j in range(rows): for j in range(rows):
M[i][j] = crypto.decodepoint(pubs[i][j].dest) M[i][j] = pubs[i][j].dest
M[i][rows] = crypto.point_add( crypto.point_add_into(
M[i][rows], crypto.decodepoint(pubs[i][j].commitment) tmp_mi_rows,
tmp_mi_rows,
crypto.decodepoint_into(tmp_pt, pubs[i][j].commitment),
) )
pubs[i] = None
for j in range(len(out_pk_commitments)):
crypto.point_sub_into(
tmp_mi_rows,
tmp_mi_rows,
crypto.decodepoint_into(tmp_pt, out_pk_commitments[j]),
) # subtract output Ci's in last row
# Subtract txn fee output in last row
crypto.point_sub_into(tmp_mi_rows, tmp_mi_rows, txn_fee_key)
M[i][rows] = crypto.encodepoint(tmp_mi_rows)
sk[rows] = crypto.sc_0() sk[rows] = crypto.sc_0()
for j in range(rows): for j in range(rows):
sk[j] = in_sk[j].dest sk[j] = in_sk[j].dest
sk[rows] = crypto.sc_add(sk[rows], in_sk[j].mask) # add masks in last row crypto.sc_add_into(sk[rows], sk[rows], in_sk[j].mask) # add masks in last row
for i in range(cols):
for j in range(len(out_pk_commitments)):
M[i][rows] = crypto.point_sub(
M[i][rows], crypto.decodepoint(out_pk_commitments[j])
) # subtract output Ci's in last row
# Subtract txn fee output in last row
M[i][rows] = crypto.point_sub(M[i][rows], txn_fee_key)
for j in range(len(out_pk_commitments)): for j in range(len(out_pk_commitments)):
sk[rows] = crypto.sc_sub( crypto.sc_sub_into(
sk[rows], out_sk_mask[j] sk[rows], sk[rows], out_sk_mask[j]
) # subtract output masks in last row ) # subtract output masks in last row
del (pubs, tmp_mi_rows, tmp_pt)
gc.collect()
return generate_mlsag(message, M, sk, kLRki, index, rows) return generate_mlsag(message, M, sk, kLRki, index, rows)
@ -123,10 +138,19 @@ def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index):
sk[0] = in_sk.dest sk[0] = in_sk.dest
sk[1] = crypto.sc_sub(in_sk.mask, a) sk[1] = crypto.sc_sub(in_sk.mask, a)
tmp_pt = crypto.new_point()
for i in range(cols): for i in range(cols):
M[i][0] = crypto.decodepoint(pubs[i].dest) crypto.point_sub_into(
M[i][1] = crypto.point_sub(crypto.decodepoint(pubs[i].commitment), cout) tmp_pt, crypto.decodepoint_into(tmp_pt, pubs[i].commitment), cout
)
M[i][0] = pubs[i].dest
M[i][1] = crypto.encodepoint(tmp_pt)
pubs[i] = None
del (pubs)
gc.collect()
return generate_mlsag(message, M, sk, kLRki, index, dsRows) return generate_mlsag(message, M, sk, kLRki, index, dsRows)
@ -174,18 +198,19 @@ def generate_first_c_and_key_images(
:param rows: total number of rows :param rows: total number of rows
:param cols: size of ring :param cols: size of ring
""" """
Ip = _key_vector(dsRows)
rv.II = _key_vector(dsRows) rv.II = _key_vector(dsRows)
alpha = _key_vector(rows) alpha = _key_vector(rows)
rv.ss = _key_matrix(rows, cols)
tmp_buff = bytearray(32) tmp_buff = bytearray(32)
Hi = crypto.new_point()
aGi = crypto.new_point()
aHPi = crypto.new_point()
hasher = _hasher_message(message) hasher = _hasher_message(message)
for i in range(dsRows): for i in range(dsRows):
# this is somewhat extra as compared to the Ring Confidential Tx paper # this is somewhat extra as compared to the Ring Confidential Tx paper
# see footnote in From Zero to Monero section 3.3 # see footnote in From Zero to Monero section 3.3
hasher.update(crypto.encodepoint(pk[index][i])) hasher.update(pk[index][i])
if kLRki: if kLRki:
raise NotImplementedError("Multisig not implemented") raise NotImplementedError("Multisig not implemented")
# alpha[i] = kLRki.k # alpha[i] = kLRki.k
@ -194,33 +219,31 @@ def generate_first_c_and_key_images(
# hash_point(hasher, kLRki.R, tmp_buff) # hash_point(hasher, kLRki.R, tmp_buff)
else: else:
Hi = crypto.hash_to_point(crypto.encodepoint(pk[index][i])) crypto.hash_to_point_into(Hi, pk[index][i])
alpha[i] = crypto.random_scalar() alpha[i] = crypto.random_scalar()
# L = alpha_i * G # L = alpha_i * G
aGi = crypto.scalarmult_base(alpha[i]) crypto.scalarmult_base_into(aGi, alpha[i])
# Ri = alpha_i * H(P_i) # Ri = alpha_i * H(P_i)
aHPi = crypto.scalarmult(Hi, alpha[i]) crypto.scalarmult_into(aHPi, Hi, alpha[i])
# key image # key image
rv.II[i] = crypto.scalarmult(Hi, xx[i]) rv.II[i] = crypto.scalarmult(Hi, xx[i])
_hash_point(hasher, aGi, tmp_buff) _hash_point(hasher, aGi, tmp_buff)
_hash_point(hasher, aHPi, tmp_buff) _hash_point(hasher, aHPi, tmp_buff)
Ip[i] = rv.II[i]
for i in range(dsRows, rows): for i in range(dsRows, rows):
alpha[i] = crypto.random_scalar() alpha[i] = crypto.random_scalar()
# L = alpha_i * G # L = alpha_i * G
aGi = crypto.scalarmult_base(alpha[i]) crypto.scalarmult_base_into(aGi, alpha[i])
# for some reasons we omit calculating R here, which seems # for some reasons we omit calculating R here, which seems
# contrary to the paper, but it is in the Monero official client # contrary to the paper, but it is in the Monero official client
# see https://github.com/monero-project/monero/blob/636153b2050aa0642ba86842c69ac55a5d81618d/src/ringct/rctSigs.cpp#L191 # see https://github.com/monero-project/monero/blob/636153b2050aa0642ba86842c69ac55a5d81618d/src/ringct/rctSigs.cpp#L191
_hash_point(hasher, pk[index][i], tmp_buff) hasher.update(pk[index][i])
_hash_point(hasher, aGi, tmp_buff) _hash_point(hasher, aGi, tmp_buff)
# the first c # the first c
c_old = hasher.digest() c_old = hasher.digest()
c_old = crypto.decodeint(c_old) c_old = crypto.decodeint(c_old)
return c_old, Ip, alpha return c_old, alpha
def generate_mlsag(message, pk, xx, kLRki, index, dsRows): def generate_mlsag(message, pk, xx, kLRki, index, dsRows):
@ -240,17 +263,22 @@ def generate_mlsag(message, pk, xx, kLRki, index, dsRows):
rows, cols = gen_mlsag_assert(pk, xx, kLRki, index, dsRows) rows, cols = gen_mlsag_assert(pk, xx, kLRki, index, dsRows)
rv = MgSig() rv = MgSig()
c, L, R, Hi = 0, None, None, None rv.cc = crypto.new_scalar()
c = crypto.new_scalar()
L = crypto.new_point()
R = crypto.new_point()
Hi = crypto.new_point()
# calculates the "first" c, key images and random scalars alpha # calculates the "first" c, key images and random scalars alpha
c_old, Ip, alpha = generate_first_c_and_key_images( c_old, alpha = generate_first_c_and_key_images(
message, rv, pk, xx, kLRki, index, dsRows, rows, cols message, rv, pk, xx, kLRki, index, dsRows, rows, cols
) )
i = (index + 1) % cols i = (index + 1) % cols
if i == 0: if i == 0:
rv.cc = c_old crypto.sc_copy(rv.cc, c_old)
rv.ss = [None] * cols
tmp_buff = bytearray(32) tmp_buff = bytearray(32)
while i != index: while i != index:
rv.ss[i] = _generate_random_vector(rows) rv.ss[i] = _generate_random_vector(rows)
@ -258,27 +286,38 @@ def generate_mlsag(message, pk, xx, kLRki, index, dsRows):
for j in range(dsRows): for j in range(dsRows):
# L = rv.ss[i][j] * G + c_old * pk[i][j] # L = rv.ss[i][j] * G + c_old * pk[i][j]
L = crypto.add_keys2(rv.ss[i][j], c_old, pk[i][j]) crypto.add_keys2_into(
Hi = crypto.hash_to_point(crypto.encodepoint(pk[i][j])) L, rv.ss[i][j], c_old, crypto.decodepoint_into(Hi, pk[i][j])
)
crypto.hash_to_point_into(Hi, pk[i][j])
# R = rv.ss[i][j] * H(pk[i][j]) + c_old * Ip[j] # R = rv.ss[i][j] * H(pk[i][j]) + c_old * Ip[j]
R = crypto.add_keys3(rv.ss[i][j], Hi, c_old, rv.II[j]) crypto.add_keys3_into(R, rv.ss[i][j], Hi, c_old, rv.II[j])
_hash_point(hasher, pk[i][j], tmp_buff)
hasher.update(pk[i][j])
_hash_point(hasher, L, tmp_buff) _hash_point(hasher, L, tmp_buff)
_hash_point(hasher, R, tmp_buff) _hash_point(hasher, R, tmp_buff)
for j in range(dsRows, rows): for j in range(dsRows, rows):
# again, omitting R here as discussed above # again, omitting R here as discussed above
L = crypto.add_keys2(rv.ss[i][j], c_old, pk[i][j]) crypto.add_keys2_into(
_hash_point(hasher, pk[i][j], tmp_buff) L, rv.ss[i][j], c_old, crypto.decodepoint_into(Hi, pk[i][j])
)
hasher.update(pk[i][j])
_hash_point(hasher, L, tmp_buff) _hash_point(hasher, L, tmp_buff)
c = crypto.decodeint(hasher.digest()) crypto.decodeint_into(c, hasher.digest())
c_old = c crypto.sc_copy(c_old, c)
pk[i] = None
i = (i + 1) % cols i = (i + 1) % cols
if i == 0: if i == 0:
rv.cc = c_old crypto.sc_copy(rv.cc, c_old)
gc.collect()
del rv.II
rv.ss[index] = [None] * rows
for j in range(rows): for j in range(rows):
rv.ss[index][j] = crypto.sc_mulsub(c, xx[j], alpha[j]) rv.ss[index][j] = crypto.sc_mulsub(c, xx[j], alpha[j])