1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-15 01:40:57 +00:00

Merge pull request #394 from ph4r05/xmr-mg

xmr: MLSAG computation optimized
This commit is contained in:
Tomas Susanka 2018-11-02 15:26:48 +01:00 committed by GitHub
commit d919e99255
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 183 additions and 130 deletions

View File

@ -10,7 +10,7 @@ from .state import State
from apps.monero.layout import confirms
from apps.monero.signing import RctType
from apps.monero.xmr import crypto, serialize
from apps.monero.xmr import crypto
if False:
from trezor.messages.MoneroTransactionSourceEntry import (
@ -40,8 +40,6 @@ async def sign_input(
:param spend_enc: one time address spending private key. Encrypted.
:return: Generated signature MGs[i]
"""
from apps.monero.signing import offloading_keys
await confirms.transaction_step(
state.ctx, state.STEP_SIGN, state.current_input_index + 1, state.input_count
)
@ -57,8 +55,11 @@ async def sign_input(
raise ValueError("Two and more inputs must imply SimpleRCT")
input_position = state.source_permutation[state.current_input_index]
mods = utils.unimport_begin()
# Check input's HMAC
from apps.monero.signing import offloading_keys
vini_hmac_comp = await offloading_keys.gen_hmac_vini(
state.key_hmac, src_entr, vini_bin, input_position
)
@ -66,7 +67,9 @@ async def sign_input(
raise ValueError("HMAC is not correct")
gc.collect()
state.mem_trace(1)
state.mem_trace(1, True)
from apps.monero.xmr.crypto import chacha_poly
if state.rct_type == RctType.Simple:
# both pseudo_out and its mask were offloaded so we need to
@ -78,10 +81,7 @@ async def sign_input(
if not crypto.ct_equals(pseudo_out_hmac_comp, pseudo_out_hmac):
raise ValueError("HMAC is not correct")
gc.collect()
state.mem_trace(2)
from apps.monero.xmr.crypto import chacha_poly
state.mem_trace(2, True)
pseudo_out_alpha = crypto.decodeint(
chacha_poly.decrypt_pack(
@ -92,9 +92,6 @@ async def sign_input(
pseudo_out_c = crypto.decodepoint(pseudo_out)
# Spending secret
from apps.monero.xmr.crypto import chacha_poly
from apps.monero.xmr.serialize_messages.ct_keys import CtKey
spend_key = crypto.decodeint(
chacha_poly.decrypt_pack(
offloading_keys.enc_key_spend(state.key_enc, input_position),
@ -102,8 +99,18 @@ async def sign_input(
)
)
gc.collect()
state.mem_trace(3)
del (
offloading_keys,
chacha_poly,
pseudo_out,
pseudo_out_hmac,
pseudo_out_alpha_enc,
spend_enc,
)
utils.unimport_end(mods)
state.mem_trace(3, True)
from apps.monero.xmr.serialize_messages.ct_keys import CtKey
# Basic setup, sanity check
index = src_entr.real_output
@ -126,14 +133,16 @@ async def sign_input(
"Real source entry's mask does not equal spend key's",
)
gc.collect()
state.mem_trace(4)
state.mem_trace(4, True)
from apps.monero.xmr import mlsag
mg_buffer = []
ring_pubkeys = [x.key for x in src_entr.outputs]
del src_entr
if state.rct_type == RctType.Simple:
ring_pubkeys = [x.key for x in src_entr.outputs]
mg = mlsag.generate_mlsag_simple(
mlsag.generate_mlsag_simple(
state.full_message,
ring_pubkeys,
input_secret_key,
@ -141,52 +150,33 @@ async def sign_input(
pseudo_out_c,
kLRki,
index,
mg_buffer,
)
del (input_secret_key, pseudo_out_alpha, pseudo_out_c)
else:
# Full RingCt, only one input
txn_fee_key = crypto.scalarmult_h(state.fee)
ring_pubkeys = [[x.key] for x in src_entr.outputs]
mg = mlsag.generate_mlsag_full(
mlsag.generate_mlsag_full(
state.full_message,
ring_pubkeys,
[input_secret_key],
input_secret_key,
state.output_sk_masks,
state.output_pk_commitments,
kLRki,
index,
txn_fee_key,
mg_buffer,
)
gc.collect()
state.mem_trace(5)
del (input_secret_key, txn_fee_key)
# Encode
mgs = _recode_msg([mg])
gc.collect()
state.mem_trace(6)
del (mlsag, ring_pubkeys)
state.mem_trace(5, True)
from trezor.messages.MoneroTransactionSignInputAck import (
MoneroTransactionSignInputAck,
)
return MoneroTransactionSignInputAck(
signature=serialize.dump_msg_gc(mgs[0], preallocate=488)
)
def _recode_msg(mgs):
"""
Recodes MGs signatures from raw forms to bytearrays so it works with serialization
"""
for idx in range(len(mgs)):
mgs[idx].cc = crypto.encodeint(mgs[idx].cc)
if hasattr(mgs[idx], "II") and mgs[idx].II:
for i in range(len(mgs[idx].II)):
mgs[idx].II[i] = crypto.encodepoint(mgs[idx].II[i])
for i in range(len(mgs[idx].ss)):
for j in range(len(mgs[idx].ss[i])):
mgs[idx].ss[i][j] = crypto.encodeint(mgs[idx].ss[i][j])
return mgs
return MoneroTransactionSignInputAck(signature=mg_buffer)

View File

@ -107,6 +107,7 @@ def sc_init_into(r, x):
return tcry.init256_modm(r, x)
sc_copy = tcry.init256_modm
sc_get64 = tcry.get256_modm
sc_check = tcry.check256_modm
check_sc = tcry.check256_modm

View File

@ -42,63 +42,76 @@ and `sk` is equal to:
Mostly ported from official Monero client, but also inspired by Mininero.
Author: Dusan Klinec, ph4r05, 2018
"""
import gc
from apps.monero.xmr import crypto
from apps.monero.xmr.serialize import int_serialize
def generate_mlsag_full(
message, pubs, in_sk, out_sk_mask, out_pk_commitments, kLRki, index, txn_fee_key
message,
pubs,
in_sk,
out_sk_mask,
out_pk_commitments,
kLRki,
index,
txn_fee_key,
mg_buff,
):
cols = len(pubs)
if cols == 0:
raise ValueError("Empty pubs")
rows = len(pubs[0])
if rows == 0:
raise ValueError("Empty pub row")
for i in range(cols):
if len(pubs[i]) != rows:
raise ValueError("pub is not rectangular")
if len(in_sk) != rows:
raise ValueError("Bad inSk size")
rows = 1 # Monero uses only one row
if len(out_sk_mask) != len(out_pk_commitments):
raise ValueError("Bad outsk/putpk size")
sk = _key_vector(rows + 1)
M = _key_matrix(rows + 1, cols)
for i in range(rows + 1):
sk[i] = crypto.sc_0()
tmp_mi_rows = crypto.new_point(None)
tmp_pt = crypto.new_point(None)
for i in range(cols):
M[i][rows] = crypto.identity()
for j in range(rows):
M[i][j] = crypto.decodepoint(pubs[i][j].dest)
M[i][rows] = crypto.point_add(
M[i][rows], crypto.decodepoint(pubs[i][j].commitment)
)
crypto.identity_into(tmp_mi_rows) # M[i][rows]
sk[rows] = crypto.sc_0()
for j in range(rows):
sk[j] = in_sk[j].dest
sk[rows] = crypto.sc_add(sk[rows], in_sk[j].mask) # add masks in last row
# Should iterate over rows, simplified as rows == 1
M[i][0] = pubs[i].dest
crypto.point_add_into(
tmp_mi_rows,
tmp_mi_rows,
crypto.decodepoint_into(tmp_pt, pubs[i].commitment),
)
pubs[i] = None
for i in range(cols):
for j in range(len(out_pk_commitments)):
M[i][rows] = crypto.point_sub(
M[i][rows], crypto.decodepoint(out_pk_commitments[j])
crypto.point_sub_into(
tmp_mi_rows,
tmp_mi_rows,
crypto.decodepoint_into(tmp_pt, out_pk_commitments[j]),
) # subtract output Ci's in last row
# Subtract txn fee output in last row
M[i][rows] = crypto.point_sub(M[i][rows], txn_fee_key)
crypto.point_sub_into(tmp_mi_rows, tmp_mi_rows, txn_fee_key)
M[i][rows] = crypto.encodepoint(tmp_mi_rows)
# Simplified as rows == 1
sk[0] = in_sk.dest
sk[rows] = in_sk.mask # originally: sum of all in_sk[0..rows] in sk[rows]
for j in range(len(out_pk_commitments)):
sk[rows] = crypto.sc_sub(
sk[rows], out_sk_mask[j]
crypto.sc_sub_into(
sk[rows], sk[rows], out_sk_mask[j]
) # subtract output masks in last row
return generate_mlsag(message, M, sk, kLRki, index, rows)
del (pubs, tmp_mi_rows, tmp_pt)
gc.collect()
return generate_mlsag(message, M, sk, kLRki, index, rows, mg_buff)
def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index):
def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index, mg_buff):
"""
MLSAG for RctType.Simple
:param message: the full message to be signed (actually its hash)
@ -108,7 +121,7 @@ def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index):
:param cout: pseudo output commitment; point, decoded; better name: pseudo_out_c
:param kLRki: used only in multisig, currently not implemented
:param index: specifies corresponding public key to the `in_sk` in the pubs array
:return: MgSig
:param mg_buff: buffer to store the signature to
"""
# Monero signs inputs separately, so `rows` always equals 2 (pubkey, commitment)
# and `dsRows` is always 1 (denotes where the pubkeys "end")
@ -123,12 +136,21 @@ def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index):
sk[0] = in_sk.dest
sk[1] = crypto.sc_sub(in_sk.mask, a)
tmp_pt = crypto.new_point()
for i in range(cols):
M[i][0] = crypto.decodepoint(pubs[i].dest)
M[i][1] = crypto.point_sub(crypto.decodepoint(pubs[i].commitment), cout)
crypto.point_sub_into(
tmp_pt, crypto.decodepoint_into(tmp_pt, pubs[i].commitment), cout
)
return generate_mlsag(message, M, sk, kLRki, index, dsRows)
M[i][0] = pubs[i].dest
M[i][1] = crypto.encodepoint(tmp_pt)
pubs[i] = None
del (pubs)
gc.collect()
return generate_mlsag(message, M, sk, kLRki, index, dsRows, mg_buff)
def gen_mlsag_assert(pk, xx, kLRki, index, dsRows):
@ -159,13 +181,10 @@ def gen_mlsag_assert(pk, xx, kLRki, index, dsRows):
return rows, cols
def generate_first_c_and_key_images(
message, rv, pk, xx, kLRki, index, dsRows, rows, cols
):
def generate_first_c_and_key_images(message, pk, xx, kLRki, index, dsRows, rows, cols):
"""
MLSAG computation - the part with secret keys
:param message: the full message to be signed (actually its hash)
:param rv: MgSig
:param pk: matrix of public keys and commitments
:param xx: input secret array composed of a private key and commitment mask
:param kLRki: used only in multisig, currently not implemented
@ -174,18 +193,19 @@ def generate_first_c_and_key_images(
:param rows: total number of rows
:param cols: size of ring
"""
Ip = _key_vector(dsRows)
rv.II = _key_vector(dsRows)
II = _key_vector(dsRows)
alpha = _key_vector(rows)
rv.ss = _key_matrix(rows, cols)
tmp_buff = bytearray(32)
Hi = crypto.new_point()
aGi = crypto.new_point()
aHPi = crypto.new_point()
hasher = _hasher_message(message)
for i in range(dsRows):
# this is somewhat extra as compared to the Ring Confidential Tx paper
# see footnote in From Zero to Monero section 3.3
hasher.update(crypto.encodepoint(pk[index][i]))
hasher.update(pk[index][i])
if kLRki:
raise NotImplementedError("Multisig not implemented")
# alpha[i] = kLRki.k
@ -194,36 +214,34 @@ def generate_first_c_and_key_images(
# hash_point(hasher, kLRki.R, tmp_buff)
else:
Hi = crypto.hash_to_point(crypto.encodepoint(pk[index][i]))
crypto.hash_to_point_into(Hi, pk[index][i])
alpha[i] = crypto.random_scalar()
# L = alpha_i * G
aGi = crypto.scalarmult_base(alpha[i])
crypto.scalarmult_base_into(aGi, alpha[i])
# Ri = alpha_i * H(P_i)
aHPi = crypto.scalarmult(Hi, alpha[i])
crypto.scalarmult_into(aHPi, Hi, alpha[i])
# key image
rv.II[i] = crypto.scalarmult(Hi, xx[i])
II[i] = crypto.scalarmult(Hi, xx[i])
_hash_point(hasher, aGi, tmp_buff)
_hash_point(hasher, aHPi, tmp_buff)
Ip[i] = rv.II[i]
for i in range(dsRows, rows):
alpha[i] = crypto.random_scalar()
# L = alpha_i * G
aGi = crypto.scalarmult_base(alpha[i])
crypto.scalarmult_base_into(aGi, alpha[i])
# for some reasons we omit calculating R here, which seems
# contrary to the paper, but it is in the Monero official client
# see https://github.com/monero-project/monero/blob/636153b2050aa0642ba86842c69ac55a5d81618d/src/ringct/rctSigs.cpp#L191
_hash_point(hasher, pk[index][i], tmp_buff)
hasher.update(pk[index][i])
_hash_point(hasher, aGi, tmp_buff)
# the first c
c_old = hasher.digest()
c_old = crypto.decodeint(c_old)
return c_old, Ip, alpha
return c_old, II, alpha
def generate_mlsag(message, pk, xx, kLRki, index, dsRows):
def generate_mlsag(message, pk, xx, kLRki, index, dsRows, mg_buff):
"""
Multilayered Spontaneous Anonymous Group Signatures (MLSAG signatures)
@ -233,56 +251,89 @@ def generate_mlsag(message, pk, xx, kLRki, index, dsRows):
:param kLRki: used only in multisig, currently not implemented
:param index: specifies corresponding public key to the `xx`'s private key in the `pk` array
:param dsRows: separates pubkeys from commitment
:return MgSig
:param mg_buff: mg signature buffer
"""
from apps.monero.xmr.serialize_messages.tx_full import MgSig
rows, cols = gen_mlsag_assert(pk, xx, kLRki, index, dsRows)
rows_b_size = int_serialize.uvarint_size(rows)
rv = MgSig()
c, L, R, Hi = 0, None, None, None
# Preallocation of the chunked buffer, len + cols + cc
for _ in range(1 + cols + 1):
mg_buff.append(None)
mg_buff[0] = int_serialize.dump_uvarint_b(cols)
cc = crypto.new_scalar() # rv.cc
c = crypto.new_scalar()
L = crypto.new_point()
R = crypto.new_point()
Hi = crypto.new_point()
# calculates the "first" c, key images and random scalars alpha
c_old, Ip, alpha = generate_first_c_and_key_images(
message, rv, pk, xx, kLRki, index, dsRows, rows, cols
c_old, II, alpha = generate_first_c_and_key_images(
message, pk, xx, kLRki, index, dsRows, rows, cols
)
i = (index + 1) % cols
if i == 0:
rv.cc = c_old
crypto.sc_copy(cc, c_old)
ss = [crypto.new_scalar() for _ in range(rows)]
tmp_buff = bytearray(32)
while i != index:
rv.ss[i] = _generate_random_vector(rows)
hasher = _hasher_message(message)
# Serialize size of the row
mg_buff[i + 1] = bytearray(rows_b_size + 32 * rows)
int_serialize.dump_uvarint_b_into(rows, mg_buff[i + 1])
for x in ss:
crypto.random_scalar(x)
for j in range(dsRows):
# L = rv.ss[i][j] * G + c_old * pk[i][j]
L = crypto.add_keys2(rv.ss[i][j], c_old, pk[i][j])
Hi = crypto.hash_to_point(crypto.encodepoint(pk[i][j]))
crypto.add_keys2_into(
L, ss[j], c_old, crypto.decodepoint_into(Hi, pk[i][j])
)
crypto.hash_to_point_into(Hi, pk[i][j])
# R = rv.ss[i][j] * H(pk[i][j]) + c_old * Ip[j]
R = crypto.add_keys3(rv.ss[i][j], Hi, c_old, rv.II[j])
_hash_point(hasher, pk[i][j], tmp_buff)
crypto.add_keys3_into(R, ss[j], Hi, c_old, II[j])
hasher.update(pk[i][j])
_hash_point(hasher, L, tmp_buff)
_hash_point(hasher, R, tmp_buff)
for j in range(dsRows, rows):
# again, omitting R here as discussed above
L = crypto.add_keys2(rv.ss[i][j], c_old, pk[i][j])
_hash_point(hasher, pk[i][j], tmp_buff)
crypto.add_keys2_into(
L, ss[j], c_old, crypto.decodepoint_into(Hi, pk[i][j])
)
hasher.update(pk[i][j])
_hash_point(hasher, L, tmp_buff)
c = crypto.decodeint(hasher.digest())
c_old = c
for si in range(rows):
crypto.encodeint_into(mg_buff[i + 1], ss[si], rows_b_size + 32 * si)
crypto.decodeint_into(c, hasher.digest())
crypto.sc_copy(c_old, c)
pk[i] = None
i = (i + 1) % cols
if i == 0:
rv.cc = c_old
crypto.sc_copy(cc, c_old)
gc.collect()
del II
# Finalizing rv.ss by processing rv.ss[index]
mg_buff[index + 1] = bytearray(rows_b_size + 32 * rows)
int_serialize.dump_uvarint_b_into(rows, mg_buff[index + 1])
for j in range(rows):
rv.ss[index][j] = crypto.sc_mulsub(c, xx[j], alpha[j])
crypto.sc_mulsub_into(ss[j], c, xx[j], alpha[j])
crypto.encodeint_into(mg_buff[index + 1], ss[j], rows_b_size + 32 * j)
return rv
# rv.cc
mg_buff[-1] = crypto.encodeint(cc)
def _key_vector(rows):

View File

@ -1,11 +0,0 @@
from apps.monero.xmr.serialize.message_types import MessageType
from apps.monero.xmr.serialize_messages.base import ECKey
from apps.monero.xmr.serialize_messages.ct_keys import KeyM
class MgSig(MessageType):
__slots__ = ("ss", "cc", "II")
@classmethod
def f_specs(cls):
return (("ss", KeyM), ("cc", ECKey))

View File

@ -275,6 +275,11 @@ async def dump_message(writer, msg, fields=None):
elif ftype is BoolType:
await dump_uvarint(writer, int(svalue))
elif ftype is BytesType and is_chunked(svalue):
await dump_uvarint(writer, len_list_bytes(svalue))
for sub_svalue in svalue:
await writer.awrite(sub_svalue)
elif ftype is BytesType:
await dump_uvarint(writer, len(svalue))
await writer.awrite(svalue)
@ -329,7 +334,9 @@ def count_message(msg, fields=None):
elif ftype is BytesType:
for svalue in fvalue:
svalue = len(svalue)
svalue = (
len(svalue) if not is_chunked(svalue) else len_list_bytes(svalue)
)
nbytes += count_uvarint(svalue)
nbytes += svalue
@ -351,3 +358,18 @@ def count_message(msg, fields=None):
raise TypeError
return nbytes
def is_chunked(svalue):
return (
isinstance(svalue, list)
and len(svalue) > 0
and not isinstance(svalue[0], (int, bool))
)
def len_list_bytes(svalue):
res = 0
for x in svalue:
res += len(x)
return res