diff --git a/src/apps/monero/signing/step_09_sign_input.py b/src/apps/monero/signing/step_09_sign_input.py index c9ea302e1..73c7566eb 100644 --- a/src/apps/monero/signing/step_09_sign_input.py +++ b/src/apps/monero/signing/step_09_sign_input.py @@ -10,7 +10,7 @@ from .state import State from apps.monero.layout import confirms from apps.monero.signing import RctType -from apps.monero.xmr import crypto, serialize +from apps.monero.xmr import crypto if False: from trezor.messages.MoneroTransactionSourceEntry import ( @@ -40,8 +40,6 @@ async def sign_input( :param spend_enc: one time address spending private key. Encrypted. :return: Generated signature MGs[i] """ - from apps.monero.signing import offloading_keys - await confirms.transaction_step( state.ctx, state.STEP_SIGN, state.current_input_index + 1, state.input_count ) @@ -57,8 +55,11 @@ async def sign_input( raise ValueError("Two and more inputs must imply SimpleRCT") input_position = state.source_permutation[state.current_input_index] + mods = utils.unimport_begin() # Check input's HMAC + from apps.monero.signing import offloading_keys + vini_hmac_comp = await offloading_keys.gen_hmac_vini( state.key_hmac, src_entr, vini_bin, input_position ) @@ -66,7 +67,9 @@ async def sign_input( raise ValueError("HMAC is not correct") gc.collect() - state.mem_trace(1) + state.mem_trace(1, True) + + from apps.monero.xmr.crypto import chacha_poly if state.rct_type == RctType.Simple: # both pseudo_out and its mask were offloaded so we need to @@ -78,10 +81,7 @@ async def sign_input( if not crypto.ct_equals(pseudo_out_hmac_comp, pseudo_out_hmac): raise ValueError("HMAC is not correct") - gc.collect() - state.mem_trace(2) - - from apps.monero.xmr.crypto import chacha_poly + state.mem_trace(2, True) pseudo_out_alpha = crypto.decodeint( chacha_poly.decrypt_pack( @@ -92,9 +92,6 @@ async def sign_input( pseudo_out_c = crypto.decodepoint(pseudo_out) # Spending secret - from apps.monero.xmr.crypto import chacha_poly - from apps.monero.xmr.serialize_messages.ct_keys import CtKey - spend_key = crypto.decodeint( chacha_poly.decrypt_pack( offloading_keys.enc_key_spend(state.key_enc, input_position), @@ -102,8 +99,18 @@ async def sign_input( ) ) - gc.collect() - state.mem_trace(3) + del ( + offloading_keys, + chacha_poly, + pseudo_out, + pseudo_out_hmac, + pseudo_out_alpha_enc, + spend_enc, + ) + utils.unimport_end(mods) + state.mem_trace(3, True) + + from apps.monero.xmr.serialize_messages.ct_keys import CtKey # Basic setup, sanity check index = src_entr.real_output @@ -126,14 +133,16 @@ async def sign_input( "Real source entry's mask does not equal spend key's", ) - gc.collect() - state.mem_trace(4) + state.mem_trace(4, True) from apps.monero.xmr import mlsag + mg_buffer = [] + ring_pubkeys = [x.key for x in src_entr.outputs] + del src_entr + if state.rct_type == RctType.Simple: - ring_pubkeys = [x.key for x in src_entr.outputs] - mg = mlsag.generate_mlsag_simple( + mlsag.generate_mlsag_simple( state.full_message, ring_pubkeys, input_secret_key, @@ -141,52 +150,33 @@ async def sign_input( pseudo_out_c, kLRki, index, + mg_buffer, ) + del (input_secret_key, pseudo_out_alpha, pseudo_out_c) + else: # Full RingCt, only one input txn_fee_key = crypto.scalarmult_h(state.fee) - ring_pubkeys = [[x.key] for x in src_entr.outputs] - mg = mlsag.generate_mlsag_full( + mlsag.generate_mlsag_full( state.full_message, ring_pubkeys, - [input_secret_key], + input_secret_key, state.output_sk_masks, state.output_pk_commitments, kLRki, index, txn_fee_key, + mg_buffer, ) - gc.collect() - state.mem_trace(5) + del (input_secret_key, txn_fee_key) - # Encode - mgs = _recode_msg([mg]) - - gc.collect() - state.mem_trace(6) + del (mlsag, ring_pubkeys) + state.mem_trace(5, True) from trezor.messages.MoneroTransactionSignInputAck import ( MoneroTransactionSignInputAck, ) - return MoneroTransactionSignInputAck( - signature=serialize.dump_msg_gc(mgs[0], preallocate=488) - ) - - -def _recode_msg(mgs): - """ - Recodes MGs signatures from raw forms to bytearrays so it works with serialization - """ - for idx in range(len(mgs)): - mgs[idx].cc = crypto.encodeint(mgs[idx].cc) - if hasattr(mgs[idx], "II") and mgs[idx].II: - for i in range(len(mgs[idx].II)): - mgs[idx].II[i] = crypto.encodepoint(mgs[idx].II[i]) - - for i in range(len(mgs[idx].ss)): - for j in range(len(mgs[idx].ss[i])): - mgs[idx].ss[i][j] = crypto.encodeint(mgs[idx].ss[i][j]) - return mgs + return MoneroTransactionSignInputAck(signature=mg_buffer) diff --git a/src/apps/monero/xmr/crypto/__init__.py b/src/apps/monero/xmr/crypto/__init__.py index 9f746eec9..70fde1c46 100644 --- a/src/apps/monero/xmr/crypto/__init__.py +++ b/src/apps/monero/xmr/crypto/__init__.py @@ -107,6 +107,7 @@ def sc_init_into(r, x): return tcry.init256_modm(r, x) +sc_copy = tcry.init256_modm sc_get64 = tcry.get256_modm sc_check = tcry.check256_modm check_sc = tcry.check256_modm diff --git a/src/apps/monero/xmr/mlsag.py b/src/apps/monero/xmr/mlsag.py index 9aef687cb..06b47fd1b 100644 --- a/src/apps/monero/xmr/mlsag.py +++ b/src/apps/monero/xmr/mlsag.py @@ -42,63 +42,76 @@ and `sk` is equal to: Mostly ported from official Monero client, but also inspired by Mininero. Author: Dusan Klinec, ph4r05, 2018 """ + +import gc + from apps.monero.xmr import crypto +from apps.monero.xmr.serialize import int_serialize def generate_mlsag_full( - message, pubs, in_sk, out_sk_mask, out_pk_commitments, kLRki, index, txn_fee_key + message, + pubs, + in_sk, + out_sk_mask, + out_pk_commitments, + kLRki, + index, + txn_fee_key, + mg_buff, ): cols = len(pubs) if cols == 0: raise ValueError("Empty pubs") - rows = len(pubs[0]) - if rows == 0: - raise ValueError("Empty pub row") - for i in range(cols): - if len(pubs[i]) != rows: - raise ValueError("pub is not rectangular") - - if len(in_sk) != rows: - raise ValueError("Bad inSk size") + rows = 1 # Monero uses only one row if len(out_sk_mask) != len(out_pk_commitments): raise ValueError("Bad outsk/putpk size") sk = _key_vector(rows + 1) M = _key_matrix(rows + 1, cols) - for i in range(rows + 1): - sk[i] = crypto.sc_0() + + tmp_mi_rows = crypto.new_point(None) + tmp_pt = crypto.new_point(None) for i in range(cols): - M[i][rows] = crypto.identity() - for j in range(rows): - M[i][j] = crypto.decodepoint(pubs[i][j].dest) - M[i][rows] = crypto.point_add( - M[i][rows], crypto.decodepoint(pubs[i][j].commitment) - ) + crypto.identity_into(tmp_mi_rows) # M[i][rows] - sk[rows] = crypto.sc_0() - for j in range(rows): - sk[j] = in_sk[j].dest - sk[rows] = crypto.sc_add(sk[rows], in_sk[j].mask) # add masks in last row + # Should iterate over rows, simplified as rows == 1 + M[i][0] = pubs[i].dest + crypto.point_add_into( + tmp_mi_rows, + tmp_mi_rows, + crypto.decodepoint_into(tmp_pt, pubs[i].commitment), + ) + pubs[i] = None - for i in range(cols): for j in range(len(out_pk_commitments)): - M[i][rows] = crypto.point_sub( - M[i][rows], crypto.decodepoint(out_pk_commitments[j]) + crypto.point_sub_into( + tmp_mi_rows, + tmp_mi_rows, + crypto.decodepoint_into(tmp_pt, out_pk_commitments[j]), ) # subtract output Ci's in last row # Subtract txn fee output in last row - M[i][rows] = crypto.point_sub(M[i][rows], txn_fee_key) + crypto.point_sub_into(tmp_mi_rows, tmp_mi_rows, txn_fee_key) + M[i][rows] = crypto.encodepoint(tmp_mi_rows) + + # Simplified as rows == 1 + sk[0] = in_sk.dest + sk[rows] = in_sk.mask # originally: sum of all in_sk[0..rows] in sk[rows] for j in range(len(out_pk_commitments)): - sk[rows] = crypto.sc_sub( - sk[rows], out_sk_mask[j] + crypto.sc_sub_into( + sk[rows], sk[rows], out_sk_mask[j] ) # subtract output masks in last row - return generate_mlsag(message, M, sk, kLRki, index, rows) + del (pubs, tmp_mi_rows, tmp_pt) + gc.collect() + return generate_mlsag(message, M, sk, kLRki, index, rows, mg_buff) -def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index): + +def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index, mg_buff): """ MLSAG for RctType.Simple :param message: the full message to be signed (actually its hash) @@ -108,7 +121,7 @@ def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index): :param cout: pseudo output commitment; point, decoded; better name: pseudo_out_c :param kLRki: used only in multisig, currently not implemented :param index: specifies corresponding public key to the `in_sk` in the pubs array - :return: MgSig + :param mg_buff: buffer to store the signature to """ # Monero signs inputs separately, so `rows` always equals 2 (pubkey, commitment) # and `dsRows` is always 1 (denotes where the pubkeys "end") @@ -123,12 +136,21 @@ def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index): sk[0] = in_sk.dest sk[1] = crypto.sc_sub(in_sk.mask, a) + tmp_pt = crypto.new_point() for i in range(cols): - M[i][0] = crypto.decodepoint(pubs[i].dest) - M[i][1] = crypto.point_sub(crypto.decodepoint(pubs[i].commitment), cout) + crypto.point_sub_into( + tmp_pt, crypto.decodepoint_into(tmp_pt, pubs[i].commitment), cout + ) + + M[i][0] = pubs[i].dest + M[i][1] = crypto.encodepoint(tmp_pt) + pubs[i] = None + + del (pubs) + gc.collect() - return generate_mlsag(message, M, sk, kLRki, index, dsRows) + return generate_mlsag(message, M, sk, kLRki, index, dsRows, mg_buff) def gen_mlsag_assert(pk, xx, kLRki, index, dsRows): @@ -159,13 +181,10 @@ def gen_mlsag_assert(pk, xx, kLRki, index, dsRows): return rows, cols -def generate_first_c_and_key_images( - message, rv, pk, xx, kLRki, index, dsRows, rows, cols -): +def generate_first_c_and_key_images(message, pk, xx, kLRki, index, dsRows, rows, cols): """ MLSAG computation - the part with secret keys :param message: the full message to be signed (actually its hash) - :param rv: MgSig :param pk: matrix of public keys and commitments :param xx: input secret array composed of a private key and commitment mask :param kLRki: used only in multisig, currently not implemented @@ -174,18 +193,19 @@ def generate_first_c_and_key_images( :param rows: total number of rows :param cols: size of ring """ - Ip = _key_vector(dsRows) - rv.II = _key_vector(dsRows) + II = _key_vector(dsRows) alpha = _key_vector(rows) - rv.ss = _key_matrix(rows, cols) tmp_buff = bytearray(32) + Hi = crypto.new_point() + aGi = crypto.new_point() + aHPi = crypto.new_point() hasher = _hasher_message(message) for i in range(dsRows): # this is somewhat extra as compared to the Ring Confidential Tx paper # see footnote in From Zero to Monero section 3.3 - hasher.update(crypto.encodepoint(pk[index][i])) + hasher.update(pk[index][i]) if kLRki: raise NotImplementedError("Multisig not implemented") # alpha[i] = kLRki.k @@ -194,36 +214,34 @@ def generate_first_c_and_key_images( # hash_point(hasher, kLRki.R, tmp_buff) else: - Hi = crypto.hash_to_point(crypto.encodepoint(pk[index][i])) + crypto.hash_to_point_into(Hi, pk[index][i]) alpha[i] = crypto.random_scalar() # L = alpha_i * G - aGi = crypto.scalarmult_base(alpha[i]) + crypto.scalarmult_base_into(aGi, alpha[i]) # Ri = alpha_i * H(P_i) - aHPi = crypto.scalarmult(Hi, alpha[i]) + crypto.scalarmult_into(aHPi, Hi, alpha[i]) # key image - rv.II[i] = crypto.scalarmult(Hi, xx[i]) + II[i] = crypto.scalarmult(Hi, xx[i]) _hash_point(hasher, aGi, tmp_buff) _hash_point(hasher, aHPi, tmp_buff) - Ip[i] = rv.II[i] - for i in range(dsRows, rows): alpha[i] = crypto.random_scalar() # L = alpha_i * G - aGi = crypto.scalarmult_base(alpha[i]) + crypto.scalarmult_base_into(aGi, alpha[i]) # for some reasons we omit calculating R here, which seems # contrary to the paper, but it is in the Monero official client # see https://github.com/monero-project/monero/blob/636153b2050aa0642ba86842c69ac55a5d81618d/src/ringct/rctSigs.cpp#L191 - _hash_point(hasher, pk[index][i], tmp_buff) + hasher.update(pk[index][i]) _hash_point(hasher, aGi, tmp_buff) # the first c c_old = hasher.digest() c_old = crypto.decodeint(c_old) - return c_old, Ip, alpha + return c_old, II, alpha -def generate_mlsag(message, pk, xx, kLRki, index, dsRows): +def generate_mlsag(message, pk, xx, kLRki, index, dsRows, mg_buff): """ Multilayered Spontaneous Anonymous Group Signatures (MLSAG signatures) @@ -233,56 +251,89 @@ def generate_mlsag(message, pk, xx, kLRki, index, dsRows): :param kLRki: used only in multisig, currently not implemented :param index: specifies corresponding public key to the `xx`'s private key in the `pk` array :param dsRows: separates pubkeys from commitment - :return MgSig + :param mg_buff: mg signature buffer """ - from apps.monero.xmr.serialize_messages.tx_full import MgSig - rows, cols = gen_mlsag_assert(pk, xx, kLRki, index, dsRows) + rows_b_size = int_serialize.uvarint_size(rows) + + # Preallocation of the chunked buffer, len + cols + cc + for _ in range(1 + cols + 1): + mg_buff.append(None) - rv = MgSig() - c, L, R, Hi = 0, None, None, None + mg_buff[0] = int_serialize.dump_uvarint_b(cols) + cc = crypto.new_scalar() # rv.cc + c = crypto.new_scalar() + L = crypto.new_point() + R = crypto.new_point() + Hi = crypto.new_point() # calculates the "first" c, key images and random scalars alpha - c_old, Ip, alpha = generate_first_c_and_key_images( - message, rv, pk, xx, kLRki, index, dsRows, rows, cols + c_old, II, alpha = generate_first_c_and_key_images( + message, pk, xx, kLRki, index, dsRows, rows, cols ) i = (index + 1) % cols if i == 0: - rv.cc = c_old + crypto.sc_copy(cc, c_old) + ss = [crypto.new_scalar() for _ in range(rows)] tmp_buff = bytearray(32) + while i != index: - rv.ss[i] = _generate_random_vector(rows) hasher = _hasher_message(message) + # Serialize size of the row + mg_buff[i + 1] = bytearray(rows_b_size + 32 * rows) + int_serialize.dump_uvarint_b_into(rows, mg_buff[i + 1]) + + for x in ss: + crypto.random_scalar(x) + for j in range(dsRows): # L = rv.ss[i][j] * G + c_old * pk[i][j] - L = crypto.add_keys2(rv.ss[i][j], c_old, pk[i][j]) - Hi = crypto.hash_to_point(crypto.encodepoint(pk[i][j])) + crypto.add_keys2_into( + L, ss[j], c_old, crypto.decodepoint_into(Hi, pk[i][j]) + ) + crypto.hash_to_point_into(Hi, pk[i][j]) + # R = rv.ss[i][j] * H(pk[i][j]) + c_old * Ip[j] - R = crypto.add_keys3(rv.ss[i][j], Hi, c_old, rv.II[j]) - _hash_point(hasher, pk[i][j], tmp_buff) + crypto.add_keys3_into(R, ss[j], Hi, c_old, II[j]) + + hasher.update(pk[i][j]) _hash_point(hasher, L, tmp_buff) _hash_point(hasher, R, tmp_buff) for j in range(dsRows, rows): # again, omitting R here as discussed above - L = crypto.add_keys2(rv.ss[i][j], c_old, pk[i][j]) - _hash_point(hasher, pk[i][j], tmp_buff) + crypto.add_keys2_into( + L, ss[j], c_old, crypto.decodepoint_into(Hi, pk[i][j]) + ) + hasher.update(pk[i][j]) _hash_point(hasher, L, tmp_buff) - c = crypto.decodeint(hasher.digest()) - c_old = c + for si in range(rows): + crypto.encodeint_into(mg_buff[i + 1], ss[si], rows_b_size + 32 * si) + + crypto.decodeint_into(c, hasher.digest()) + crypto.sc_copy(c_old, c) + pk[i] = None i = (i + 1) % cols if i == 0: - rv.cc = c_old + crypto.sc_copy(cc, c_old) + gc.collect() + + del II + # Finalizing rv.ss by processing rv.ss[index] + mg_buff[index + 1] = bytearray(rows_b_size + 32 * rows) + int_serialize.dump_uvarint_b_into(rows, mg_buff[index + 1]) for j in range(rows): - rv.ss[index][j] = crypto.sc_mulsub(c, xx[j], alpha[j]) + crypto.sc_mulsub_into(ss[j], c, xx[j], alpha[j]) + crypto.encodeint_into(mg_buff[index + 1], ss[j], rows_b_size + 32 * j) - return rv + # rv.cc + mg_buff[-1] = crypto.encodeint(cc) def _key_vector(rows): diff --git a/src/apps/monero/xmr/serialize_messages/tx_full.py b/src/apps/monero/xmr/serialize_messages/tx_full.py deleted file mode 100644 index 1c98619b6..000000000 --- a/src/apps/monero/xmr/serialize_messages/tx_full.py +++ /dev/null @@ -1,11 +0,0 @@ -from apps.monero.xmr.serialize.message_types import MessageType -from apps.monero.xmr.serialize_messages.base import ECKey -from apps.monero.xmr.serialize_messages.ct_keys import KeyM - - -class MgSig(MessageType): - __slots__ = ("ss", "cc", "II") - - @classmethod - def f_specs(cls): - return (("ss", KeyM), ("cc", ECKey)) diff --git a/src/protobuf.py b/src/protobuf.py index 654c71e98..414f5c0a9 100644 --- a/src/protobuf.py +++ b/src/protobuf.py @@ -275,6 +275,11 @@ async def dump_message(writer, msg, fields=None): elif ftype is BoolType: await dump_uvarint(writer, int(svalue)) + elif ftype is BytesType and is_chunked(svalue): + await dump_uvarint(writer, len_list_bytes(svalue)) + for sub_svalue in svalue: + await writer.awrite(sub_svalue) + elif ftype is BytesType: await dump_uvarint(writer, len(svalue)) await writer.awrite(svalue) @@ -329,7 +334,9 @@ def count_message(msg, fields=None): elif ftype is BytesType: for svalue in fvalue: - svalue = len(svalue) + svalue = ( + len(svalue) if not is_chunked(svalue) else len_list_bytes(svalue) + ) nbytes += count_uvarint(svalue) nbytes += svalue @@ -351,3 +358,18 @@ def count_message(msg, fields=None): raise TypeError return nbytes + + +def is_chunked(svalue): + return ( + isinstance(svalue, list) + and len(svalue) > 0 + and not isinstance(svalue[0], (int, bool)) + ) + + +def len_list_bytes(svalue): + res = 0 + for x in svalue: + res += len(x) + return res