mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-10 15:30:55 +00:00
xmr: mlsag memory optimizations, in-place computation
- pub key matrix is not ge25519 as it consumes high amount of memory - in-place computation used to reduce fragmentation overhead
This commit is contained in:
parent
a2b32115b2
commit
90fd0bb67a
@ -107,6 +107,7 @@ def sc_init_into(r, x):
|
|||||||
return tcry.init256_modm(r, x)
|
return tcry.init256_modm(r, x)
|
||||||
|
|
||||||
|
|
||||||
|
sc_copy = tcry.init256_modm
|
||||||
sc_get64 = tcry.get256_modm
|
sc_get64 = tcry.get256_modm
|
||||||
sc_check = tcry.check256_modm
|
sc_check = tcry.check256_modm
|
||||||
check_sc = tcry.check256_modm
|
check_sc = tcry.check256_modm
|
||||||
|
@ -42,6 +42,9 @@ and `sk` is equal to:
|
|||||||
Mostly ported from official Monero client, but also inspired by Mininero.
|
Mostly ported from official Monero client, but also inspired by Mininero.
|
||||||
Author: Dusan Klinec, ph4r05, 2018
|
Author: Dusan Klinec, ph4r05, 2018
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
|
||||||
from apps.monero.xmr import crypto
|
from apps.monero.xmr import crypto
|
||||||
|
|
||||||
|
|
||||||
@ -68,33 +71,45 @@ def generate_mlsag_full(
|
|||||||
for i in range(rows + 1):
|
for i in range(rows + 1):
|
||||||
sk[i] = crypto.sc_0()
|
sk[i] = crypto.sc_0()
|
||||||
|
|
||||||
|
tmp_mi_rows = crypto.new_point(None)
|
||||||
|
tmp_pt = crypto.new_point(None)
|
||||||
|
|
||||||
for i in range(cols):
|
for i in range(cols):
|
||||||
M[i][rows] = crypto.identity()
|
crypto.identity_into(tmp_mi_rows) # M[i][rows]
|
||||||
|
|
||||||
for j in range(rows):
|
for j in range(rows):
|
||||||
M[i][j] = crypto.decodepoint(pubs[i][j].dest)
|
M[i][j] = pubs[i][j].dest
|
||||||
M[i][rows] = crypto.point_add(
|
crypto.point_add_into(
|
||||||
M[i][rows], crypto.decodepoint(pubs[i][j].commitment)
|
tmp_mi_rows,
|
||||||
|
tmp_mi_rows,
|
||||||
|
crypto.decodepoint_into(tmp_pt, pubs[i][j].commitment),
|
||||||
)
|
)
|
||||||
|
pubs[i] = None
|
||||||
|
|
||||||
|
for j in range(len(out_pk_commitments)):
|
||||||
|
crypto.point_sub_into(
|
||||||
|
tmp_mi_rows,
|
||||||
|
tmp_mi_rows,
|
||||||
|
crypto.decodepoint_into(tmp_pt, out_pk_commitments[j]),
|
||||||
|
) # subtract output Ci's in last row
|
||||||
|
|
||||||
|
# Subtract txn fee output in last row
|
||||||
|
crypto.point_sub_into(tmp_mi_rows, tmp_mi_rows, txn_fee_key)
|
||||||
|
M[i][rows] = crypto.encodepoint(tmp_mi_rows)
|
||||||
|
|
||||||
sk[rows] = crypto.sc_0()
|
sk[rows] = crypto.sc_0()
|
||||||
for j in range(rows):
|
for j in range(rows):
|
||||||
sk[j] = in_sk[j].dest
|
sk[j] = in_sk[j].dest
|
||||||
sk[rows] = crypto.sc_add(sk[rows], in_sk[j].mask) # add masks in last row
|
crypto.sc_add_into(sk[rows], sk[rows], in_sk[j].mask) # add masks in last row
|
||||||
|
|
||||||
for i in range(cols):
|
|
||||||
for j in range(len(out_pk_commitments)):
|
|
||||||
M[i][rows] = crypto.point_sub(
|
|
||||||
M[i][rows], crypto.decodepoint(out_pk_commitments[j])
|
|
||||||
) # subtract output Ci's in last row
|
|
||||||
|
|
||||||
# Subtract txn fee output in last row
|
|
||||||
M[i][rows] = crypto.point_sub(M[i][rows], txn_fee_key)
|
|
||||||
|
|
||||||
for j in range(len(out_pk_commitments)):
|
for j in range(len(out_pk_commitments)):
|
||||||
sk[rows] = crypto.sc_sub(
|
crypto.sc_sub_into(
|
||||||
sk[rows], out_sk_mask[j]
|
sk[rows], sk[rows], out_sk_mask[j]
|
||||||
) # subtract output masks in last row
|
) # subtract output masks in last row
|
||||||
|
|
||||||
|
del (pubs, tmp_mi_rows, tmp_pt)
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
return generate_mlsag(message, M, sk, kLRki, index, rows)
|
return generate_mlsag(message, M, sk, kLRki, index, rows)
|
||||||
|
|
||||||
|
|
||||||
@ -123,10 +138,19 @@ def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index):
|
|||||||
|
|
||||||
sk[0] = in_sk.dest
|
sk[0] = in_sk.dest
|
||||||
sk[1] = crypto.sc_sub(in_sk.mask, a)
|
sk[1] = crypto.sc_sub(in_sk.mask, a)
|
||||||
|
tmp_pt = crypto.new_point()
|
||||||
|
|
||||||
for i in range(cols):
|
for i in range(cols):
|
||||||
M[i][0] = crypto.decodepoint(pubs[i].dest)
|
crypto.point_sub_into(
|
||||||
M[i][1] = crypto.point_sub(crypto.decodepoint(pubs[i].commitment), cout)
|
tmp_pt, crypto.decodepoint_into(tmp_pt, pubs[i].commitment), cout
|
||||||
|
)
|
||||||
|
|
||||||
|
M[i][0] = pubs[i].dest
|
||||||
|
M[i][1] = crypto.encodepoint(tmp_pt)
|
||||||
|
pubs[i] = None
|
||||||
|
|
||||||
|
del (pubs)
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
return generate_mlsag(message, M, sk, kLRki, index, dsRows)
|
return generate_mlsag(message, M, sk, kLRki, index, dsRows)
|
||||||
|
|
||||||
@ -174,18 +198,19 @@ def generate_first_c_and_key_images(
|
|||||||
:param rows: total number of rows
|
:param rows: total number of rows
|
||||||
:param cols: size of ring
|
:param cols: size of ring
|
||||||
"""
|
"""
|
||||||
Ip = _key_vector(dsRows)
|
|
||||||
rv.II = _key_vector(dsRows)
|
rv.II = _key_vector(dsRows)
|
||||||
alpha = _key_vector(rows)
|
alpha = _key_vector(rows)
|
||||||
rv.ss = _key_matrix(rows, cols)
|
|
||||||
|
|
||||||
tmp_buff = bytearray(32)
|
tmp_buff = bytearray(32)
|
||||||
|
Hi = crypto.new_point()
|
||||||
|
aGi = crypto.new_point()
|
||||||
|
aHPi = crypto.new_point()
|
||||||
hasher = _hasher_message(message)
|
hasher = _hasher_message(message)
|
||||||
|
|
||||||
for i in range(dsRows):
|
for i in range(dsRows):
|
||||||
# this is somewhat extra as compared to the Ring Confidential Tx paper
|
# this is somewhat extra as compared to the Ring Confidential Tx paper
|
||||||
# see footnote in From Zero to Monero section 3.3
|
# see footnote in From Zero to Monero section 3.3
|
||||||
hasher.update(crypto.encodepoint(pk[index][i]))
|
hasher.update(pk[index][i])
|
||||||
if kLRki:
|
if kLRki:
|
||||||
raise NotImplementedError("Multisig not implemented")
|
raise NotImplementedError("Multisig not implemented")
|
||||||
# alpha[i] = kLRki.k
|
# alpha[i] = kLRki.k
|
||||||
@ -194,33 +219,31 @@ def generate_first_c_and_key_images(
|
|||||||
# hash_point(hasher, kLRki.R, tmp_buff)
|
# hash_point(hasher, kLRki.R, tmp_buff)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
Hi = crypto.hash_to_point(crypto.encodepoint(pk[index][i]))
|
crypto.hash_to_point_into(Hi, pk[index][i])
|
||||||
alpha[i] = crypto.random_scalar()
|
alpha[i] = crypto.random_scalar()
|
||||||
# L = alpha_i * G
|
# L = alpha_i * G
|
||||||
aGi = crypto.scalarmult_base(alpha[i])
|
crypto.scalarmult_base_into(aGi, alpha[i])
|
||||||
# Ri = alpha_i * H(P_i)
|
# Ri = alpha_i * H(P_i)
|
||||||
aHPi = crypto.scalarmult(Hi, alpha[i])
|
crypto.scalarmult_into(aHPi, Hi, alpha[i])
|
||||||
# key image
|
# key image
|
||||||
rv.II[i] = crypto.scalarmult(Hi, xx[i])
|
rv.II[i] = crypto.scalarmult(Hi, xx[i])
|
||||||
_hash_point(hasher, aGi, tmp_buff)
|
_hash_point(hasher, aGi, tmp_buff)
|
||||||
_hash_point(hasher, aHPi, tmp_buff)
|
_hash_point(hasher, aHPi, tmp_buff)
|
||||||
|
|
||||||
Ip[i] = rv.II[i]
|
|
||||||
|
|
||||||
for i in range(dsRows, rows):
|
for i in range(dsRows, rows):
|
||||||
alpha[i] = crypto.random_scalar()
|
alpha[i] = crypto.random_scalar()
|
||||||
# L = alpha_i * G
|
# L = alpha_i * G
|
||||||
aGi = crypto.scalarmult_base(alpha[i])
|
crypto.scalarmult_base_into(aGi, alpha[i])
|
||||||
# for some reasons we omit calculating R here, which seems
|
# for some reasons we omit calculating R here, which seems
|
||||||
# contrary to the paper, but it is in the Monero official client
|
# contrary to the paper, but it is in the Monero official client
|
||||||
# see https://github.com/monero-project/monero/blob/636153b2050aa0642ba86842c69ac55a5d81618d/src/ringct/rctSigs.cpp#L191
|
# see https://github.com/monero-project/monero/blob/636153b2050aa0642ba86842c69ac55a5d81618d/src/ringct/rctSigs.cpp#L191
|
||||||
_hash_point(hasher, pk[index][i], tmp_buff)
|
hasher.update(pk[index][i])
|
||||||
_hash_point(hasher, aGi, tmp_buff)
|
_hash_point(hasher, aGi, tmp_buff)
|
||||||
|
|
||||||
# the first c
|
# the first c
|
||||||
c_old = hasher.digest()
|
c_old = hasher.digest()
|
||||||
c_old = crypto.decodeint(c_old)
|
c_old = crypto.decodeint(c_old)
|
||||||
return c_old, Ip, alpha
|
return c_old, alpha
|
||||||
|
|
||||||
|
|
||||||
def generate_mlsag(message, pk, xx, kLRki, index, dsRows):
|
def generate_mlsag(message, pk, xx, kLRki, index, dsRows):
|
||||||
@ -240,17 +263,22 @@ def generate_mlsag(message, pk, xx, kLRki, index, dsRows):
|
|||||||
rows, cols = gen_mlsag_assert(pk, xx, kLRki, index, dsRows)
|
rows, cols = gen_mlsag_assert(pk, xx, kLRki, index, dsRows)
|
||||||
|
|
||||||
rv = MgSig()
|
rv = MgSig()
|
||||||
c, L, R, Hi = 0, None, None, None
|
rv.cc = crypto.new_scalar()
|
||||||
|
c = crypto.new_scalar()
|
||||||
|
L = crypto.new_point()
|
||||||
|
R = crypto.new_point()
|
||||||
|
Hi = crypto.new_point()
|
||||||
|
|
||||||
# calculates the "first" c, key images and random scalars alpha
|
# calculates the "first" c, key images and random scalars alpha
|
||||||
c_old, Ip, alpha = generate_first_c_and_key_images(
|
c_old, alpha = generate_first_c_and_key_images(
|
||||||
message, rv, pk, xx, kLRki, index, dsRows, rows, cols
|
message, rv, pk, xx, kLRki, index, dsRows, rows, cols
|
||||||
)
|
)
|
||||||
|
|
||||||
i = (index + 1) % cols
|
i = (index + 1) % cols
|
||||||
if i == 0:
|
if i == 0:
|
||||||
rv.cc = c_old
|
crypto.sc_copy(rv.cc, c_old)
|
||||||
|
|
||||||
|
rv.ss = [None] * cols
|
||||||
tmp_buff = bytearray(32)
|
tmp_buff = bytearray(32)
|
||||||
while i != index:
|
while i != index:
|
||||||
rv.ss[i] = _generate_random_vector(rows)
|
rv.ss[i] = _generate_random_vector(rows)
|
||||||
@ -258,27 +286,38 @@ def generate_mlsag(message, pk, xx, kLRki, index, dsRows):
|
|||||||
|
|
||||||
for j in range(dsRows):
|
for j in range(dsRows):
|
||||||
# L = rv.ss[i][j] * G + c_old * pk[i][j]
|
# L = rv.ss[i][j] * G + c_old * pk[i][j]
|
||||||
L = crypto.add_keys2(rv.ss[i][j], c_old, pk[i][j])
|
crypto.add_keys2_into(
|
||||||
Hi = crypto.hash_to_point(crypto.encodepoint(pk[i][j]))
|
L, rv.ss[i][j], c_old, crypto.decodepoint_into(Hi, pk[i][j])
|
||||||
|
)
|
||||||
|
crypto.hash_to_point_into(Hi, pk[i][j])
|
||||||
|
|
||||||
# R = rv.ss[i][j] * H(pk[i][j]) + c_old * Ip[j]
|
# R = rv.ss[i][j] * H(pk[i][j]) + c_old * Ip[j]
|
||||||
R = crypto.add_keys3(rv.ss[i][j], Hi, c_old, rv.II[j])
|
crypto.add_keys3_into(R, rv.ss[i][j], Hi, c_old, rv.II[j])
|
||||||
_hash_point(hasher, pk[i][j], tmp_buff)
|
|
||||||
|
hasher.update(pk[i][j])
|
||||||
_hash_point(hasher, L, tmp_buff)
|
_hash_point(hasher, L, tmp_buff)
|
||||||
_hash_point(hasher, R, tmp_buff)
|
_hash_point(hasher, R, tmp_buff)
|
||||||
|
|
||||||
for j in range(dsRows, rows):
|
for j in range(dsRows, rows):
|
||||||
# again, omitting R here as discussed above
|
# again, omitting R here as discussed above
|
||||||
L = crypto.add_keys2(rv.ss[i][j], c_old, pk[i][j])
|
crypto.add_keys2_into(
|
||||||
_hash_point(hasher, pk[i][j], tmp_buff)
|
L, rv.ss[i][j], c_old, crypto.decodepoint_into(Hi, pk[i][j])
|
||||||
|
)
|
||||||
|
hasher.update(pk[i][j])
|
||||||
_hash_point(hasher, L, tmp_buff)
|
_hash_point(hasher, L, tmp_buff)
|
||||||
|
|
||||||
c = crypto.decodeint(hasher.digest())
|
crypto.decodeint_into(c, hasher.digest())
|
||||||
c_old = c
|
crypto.sc_copy(c_old, c)
|
||||||
|
pk[i] = None
|
||||||
i = (i + 1) % cols
|
i = (i + 1) % cols
|
||||||
|
|
||||||
if i == 0:
|
if i == 0:
|
||||||
rv.cc = c_old
|
crypto.sc_copy(rv.cc, c_old)
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
del rv.II
|
||||||
|
|
||||||
|
rv.ss[index] = [None] * rows
|
||||||
for j in range(rows):
|
for j in range(rows):
|
||||||
rv.ss[index][j] = crypto.sc_mulsub(c, xx[j], alpha[j])
|
rv.ss[index][j] = crypto.sc_mulsub(c, xx[j], alpha[j])
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user