diff --git a/src/apps/monero/signing/step_09_sign_input.py b/src/apps/monero/signing/step_09_sign_input.py index 8d4249545b..73c7566eb4 100644 --- a/src/apps/monero/signing/step_09_sign_input.py +++ b/src/apps/monero/signing/step_09_sign_input.py @@ -134,14 +134,14 @@ async def sign_input( ) state.mem_trace(4, True) - mg_buffer = [] from apps.monero.xmr import mlsag - if state.rct_type == RctType.Simple: - ring_pubkeys = [x.key for x in src_entr.outputs] - src_entr = None + mg_buffer = [] + ring_pubkeys = [x.key for x in src_entr.outputs] + del src_entr + if state.rct_type == RctType.Simple: mlsag.generate_mlsag_simple( state.full_message, ring_pubkeys, @@ -153,18 +153,15 @@ async def sign_input( mg_buffer, ) - del (ring_pubkeys, input_secret_key, pseudo_out_alpha, pseudo_out_c) + del (input_secret_key, pseudo_out_alpha, pseudo_out_c) else: # Full RingCt, only one input txn_fee_key = crypto.scalarmult_h(state.fee) - ring_pubkeys = [[x.key] for x in src_entr.outputs] - src_entr = None - mlsag.generate_mlsag_full( state.full_message, ring_pubkeys, - [input_secret_key], + input_secret_key, state.output_sk_masks, state.output_pk_commitments, kLRki, @@ -173,9 +170,9 @@ async def sign_input( mg_buffer, ) - del (ring_pubkeys, input_secret_key, txn_fee_key) + del (input_secret_key, txn_fee_key) - del (mlsag, src_entr) + del (mlsag, ring_pubkeys) state.mem_trace(5, True) from trezor.messages.MoneroTransactionSignInputAck import ( diff --git a/src/apps/monero/xmr/mlsag.py b/src/apps/monero/xmr/mlsag.py index 944c1b8afe..06b47fd1b9 100644 --- a/src/apps/monero/xmr/mlsag.py +++ b/src/apps/monero/xmr/mlsag.py @@ -63,22 +63,12 @@ def generate_mlsag_full( cols = len(pubs) if cols == 0: raise ValueError("Empty pubs") - rows = len(pubs[0]) - if rows == 0: - raise ValueError("Empty pub row") - for i in range(cols): - if len(pubs[i]) != rows: - raise ValueError("pub is not rectangular") - - if len(in_sk) != rows: - raise ValueError("Bad inSk size") + rows = 1 # Monero uses only one row if len(out_sk_mask) != len(out_pk_commitments): raise ValueError("Bad outsk/putpk size") sk = _key_vector(rows + 1) M = _key_matrix(rows + 1, cols) - for i in range(rows + 1): - sk[i] = crypto.sc_0() tmp_mi_rows = crypto.new_point(None) tmp_pt = crypto.new_point(None) @@ -86,13 +76,13 @@ def generate_mlsag_full( for i in range(cols): crypto.identity_into(tmp_mi_rows) # M[i][rows] - for j in range(rows): - M[i][j] = pubs[i][j].dest - crypto.point_add_into( - tmp_mi_rows, - tmp_mi_rows, - crypto.decodepoint_into(tmp_pt, pubs[i][j].commitment), - ) + # Should iterate over rows, simplified as rows == 1 + M[i][0] = pubs[i].dest + crypto.point_add_into( + tmp_mi_rows, + tmp_mi_rows, + crypto.decodepoint_into(tmp_pt, pubs[i].commitment), + ) pubs[i] = None for j in range(len(out_pk_commitments)): @@ -106,10 +96,9 @@ def generate_mlsag_full( crypto.point_sub_into(tmp_mi_rows, tmp_mi_rows, txn_fee_key) M[i][rows] = crypto.encodepoint(tmp_mi_rows) - sk[rows] = crypto.sc_0() - for j in range(rows): - sk[j] = in_sk[j].dest - crypto.sc_add_into(sk[rows], sk[rows], in_sk[j].mask) # add masks in last row + # Simplified as rows == 1 + sk[0] = in_sk.dest + sk[rows] = in_sk.mask # originally: sum of all in_sk[0..rows] in sk[rows] for j in range(len(out_pk_commitments)): crypto.sc_sub_into(