1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-22 07:28:10 +00:00

xmr: hf10 upgrades

- Deterministic output commitment masks, based on amount_key
- Bulletproof v2 serialization, EcdhInfo serialized as 8 B amount, XOR encrypted by a specific key derived from amount key
- Signing - pseudo_out recomputation on inputs, sign step
- Dummy encrypted payment ID (if applicable) for better transaction uniformity
This commit is contained in:
Dusan Klinec 2019-02-15 03:08:54 +01:00
parent 72631db462
commit bba8bf38eb
No known key found for this signature in database
GPG Key ID: 6337E118CCBCE103
21 changed files with 534 additions and 527 deletions

View File

@ -26,7 +26,7 @@ Pillow = ">=5.2.0"
Mako = ">=1.0.7" Mako = ">=1.0.7"
# monero # monero
monero_agent = {version = ">=1.7.1", extras = ["tcry", "dev"]} monero_agent = {version = ">=2.0.1", extras = ["tcry", "dev"]}
py_trezor_crypto_ph4 = {version = ">=0.1.1"} py_trezor_crypto_ph4 = {version = ">=0.1.1"}
[dev-packages] [dev-packages]

11
Pipfile.lock generated
View File

@ -290,17 +290,17 @@
"tcry" "tcry"
], ],
"hashes": [ "hashes": [
"sha256:229855aeffc2457c3cd20b30a41d8c31fc2898d9deb3667cfdb85ce5318aa218", "sha256:0e8ac7a9ff9512b9781deacf6a5ae8c53f49d0b8e43684278e4e57f328493ebb",
"sha256:898324657bf87c9f002dab5d2137565abda95950f52cbdbc6ccc659dd9d9910e" "sha256:e6e99d44a2d76cdc1addb9421d5039546d4ea0002691ca5186124aa4561f82f5"
], ],
"index": "pypi", "index": "pypi",
"version": "==1.7.6" "version": "==2.0.1"
}, },
"monero-serialize": { "monero-serialize": {
"hashes": [ "hashes": [
"sha256:81ae31a25901cf81969b48a14f9267775b9dbe4856f322807732bf03836773c7" "sha256:0bf42972a2b13b47c2b2bd42352006caf14117387c1fab09ab2ac0817c0dede5"
], ],
"version": "==2.1.0" "version": "==3.0.0"
}, },
"more-itertools": { "more-itertools": {
"hashes": [ "hashes": [
@ -433,6 +433,7 @@
"sha256:c523afd73949ac083bb7c5ef67416f0a971edc1921f4046a8cd40f87e43133ff" "sha256:c523afd73949ac083bb7c5ef67416f0a971edc1921f4046a8cd40f87e43133ff"
], ],
"index": "pypi", "index": "pypi",
"markers": "extra == 'tcry'",
"version": "==0.1.1" "version": "==0.1.1"
}, },
"pyblake2": { "pyblake2": {

View File

@ -64,7 +64,11 @@ async def require_confirm_transaction(ctx, tsx_data, network_type):
cur_payment = None cur_payment = None
await _require_confirm_output(ctx, dst, network_type, cur_payment) await _require_confirm_output(ctx, dst, network_type, cur_payment)
if has_payment and not has_integrated: if (
has_payment
and not has_integrated
and tsx_data.payment_id != b"\x00\x00\x00\x00\x00\x00\x00\x00"
):
await _require_confirm_payment_id(ctx, tsx_data.payment_id) await _require_confirm_payment_id(ctx, tsx_data.payment_id)
await _require_confirm_fee(ctx, tsx_data.fee) await _require_confirm_fee(ctx, tsx_data.fee)

View File

@ -67,12 +67,7 @@ async def sign_tx_dispatch(state, msg, keychain):
return ( return (
await step_04_input_vini.input_vini( await step_04_input_vini.input_vini(
state, state, msg.src_entr, msg.vini, msg.vini_hmac
msg.src_entr,
msg.vini,
msg.vini_hmac,
msg.pseudo_out,
msg.pseudo_out_hmac,
), ),
( (
MessageType.MoneroTransactionInputViniRequest, MessageType.MoneroTransactionInputViniRequest,
@ -91,11 +86,14 @@ async def sign_tx_dispatch(state, msg, keychain):
elif msg.MESSAGE_WIRE_TYPE == MessageType.MoneroTransactionSetOutputRequest: elif msg.MESSAGE_WIRE_TYPE == MessageType.MoneroTransactionSetOutputRequest:
from apps.monero.signing import step_06_set_output from apps.monero.signing import step_06_set_output
is_offloaded_bp = bool(msg.is_offloaded_bp)
dst, dst_hmac, rsig_data = msg.dst_entr, msg.dst_entr_hmac, msg.rsig_data dst, dst_hmac, rsig_data = msg.dst_entr, msg.dst_entr_hmac, msg.rsig_data
del msg del msg
return ( return (
await step_06_set_output.set_output(state, dst, dst_hmac, rsig_data), await step_06_set_output.set_output(
state, dst, dst_hmac, rsig_data, is_offloaded_bp
),
( (
MessageType.MoneroTransactionSetOutputRequest, MessageType.MoneroTransactionSetOutputRequest,
MessageType.MoneroTransactionAllOutSetRequest, MessageType.MoneroTransactionAllOutSetRequest,

View File

@ -43,18 +43,14 @@ class RsigType:
Bulletproof = 1 Bulletproof = 1
def get_monero_rct_type(rct_type, rsig_type): def get_monero_rct_type(bp_version=1):
""" """
This converts our internal representation of RctType and RsigType Returns transaction RctType according to the BP version.
into what is used in Monero: Only HP9+ is supported, thus Full and Simple variants are removed.
- Null = 0
- Full = 1
- Simple = 2
- Simple/Full with bulletproof = 3
""" """
if rsig_type == RsigType.Bulletproof: if bp_version == 1:
return 3 # Bulletproofs return 3 # TxRctType.Bulletproof
if rct_type == RctType.Simple: elif bp_version == 2:
return 2 # Simple return 4 # TxRctType.Bulletproof2
else: else:
return 1 # Full raise ValueError("Unsupported BP version")

View File

@ -2,14 +2,28 @@ from trezor import utils
from apps.monero.xmr import crypto from apps.monero.xmr import crypto
if False:
from apps.monero.xmr.types import *
def _build_key(secret, discriminator=None, index: int = None) -> bytes:
BUILD_KEY_BUFFER = bytearray(32 + 12 + 4) # key + disc + index
def _build_key(
secret, discriminator=None, index: int = None, out: bytes = None
) -> bytes:
""" """
Creates an unique-purpose key Creates an unique-purpose key
""" """
key_buff = bytearray(32 + 12 + 4) # key + disc + index key_buff = BUILD_KEY_BUFFER # bytearray(32 + 12 + 4) # key + disc + index
utils.ensure(len(secret) == 32, "Invalid key length")
utils.ensure(len(discriminator) <= 12, "Disc too long")
offset = 32 offset = 32
utils.memcpy(key_buff, 0, secret, 0, len(secret)) utils.memcpy(key_buff, 0, secret, 0, 32)
for i in range(32, len(key_buff)):
key_buff[i] = 0
if discriminator is not None: if discriminator is not None:
utils.memcpy(key_buff, offset, discriminator, 0, len(discriminator)) utils.memcpy(key_buff, offset, discriminator, 0, len(discriminator))
@ -24,7 +38,7 @@ def _build_key(secret, discriminator=None, index: int = None) -> bytes:
offset += 1 offset += 1
index = shifted index = shifted
return crypto.keccak_2hash(key_buff) return crypto.keccak_2hash(key_buff, out)
def hmac_key_txin(key_hmac, idx: int) -> bytes: def hmac_key_txin(key_hmac, idx: int) -> bytes:
@ -83,6 +97,13 @@ def enc_key_cout(key_enc, idx: int = None) -> bytes:
return _build_key(key_enc, b"cout", idx) return _build_key(key_enc, b"cout", idx)
def det_comm_masks(key_enc, idx: int) -> Sc25519:
"""
Deterministic output commitment masks
"""
return crypto.decodeint(_build_key(key_enc, b"out-mask", idx))
async def gen_hmac_vini(key, src_entr, vini_bin, idx: int) -> bytes: async def gen_hmac_vini(key, src_entr, vini_bin, idx: int) -> bytes:
""" """
Computes hmac (TxSourceEntry[i] || tx.vin[i]) Computes hmac (TxSourceEntry[i] || tx.vin[i])

View File

@ -60,12 +60,11 @@ class State:
""" """
self.need_additional_txkeys = False self.need_additional_txkeys = False
# Ring Confidential Transaction type # Connected client version
# allowed values: RctType.{Full, Simple} self.client_version = 0
self.rct_type = None
# Range Signature type (also called range proof) # Bulletproof version. Pre for <=HF9 is 1, for >HP10 is 2
# allowed values: RsigType.{Borromean, Bulletproof} self.bp_version = 1
self.rsig_type = None
self.input_count = 0 self.input_count = 0
self.output_count = 0 self.output_count = 0
@ -82,18 +81,22 @@ class State:
# currently processed input/output index # currently processed input/output index
self.current_input_index = -1 self.current_input_index = -1
self.current_output_index = -1 self.current_output_index = -1
self.is_processing_offloaded = False
# for pseudo_out recomputation from new mask
self.input_last_amount = 0
self.summary_inputs_money = 0 self.summary_inputs_money = 0
self.summary_outs_money = 0 self.summary_outs_money = 0
# output commitments # output commitments
self.output_pk_commitments = [] self.output_pk_commitments = []
# masks used in the output commitment
self.output_sk_masks = []
self.output_amounts = [] self.output_amounts = []
# output *range proof* masks # output *range proof* masks. HP10+ makes them deterministic.
self.output_masks = [] self.output_masks = []
# last output mask for client_version=0
self.output_last_mask = None
# the range proofs are calculated in batches, this denotes the grouping # the range proofs are calculated in batches, this denotes the grouping
self.rsig_grouping = [] self.rsig_grouping = []
@ -146,3 +149,9 @@ class State:
def change_address(self): def change_address(self):
return self.output_change.addr if self.output_change else None return self.output_change.addr if self.output_change else None
def is_bulletproof_v2(self):
return self.bp_version >= 2
def is_det_mask(self):
return self.bp_version >= 2 or self.client_version > 0

View File

@ -6,7 +6,6 @@ import gc
from apps.monero import misc, signing from apps.monero import misc, signing
from apps.monero.layout import confirms from apps.monero.layout import confirms
from apps.monero.signing import RctType, RsigType
from apps.monero.signing.state import State from apps.monero.signing.state import State
from apps.monero.xmr import crypto, monero from apps.monero.xmr import crypto, monero
@ -28,6 +27,7 @@ async def init_transaction(
await paths.validate_path(state.ctx, misc.validate_full_path, path=address_n) await paths.validate_path(state.ctx, misc.validate_full_path, path=address_n)
state.creds = misc.get_creds(keychain, address_n, network_type) state.creds = misc.get_creds(keychain, address_n, network_type)
state.client_version = tsx_data.client_version or 0
state.fee = state.fee if state.fee > 0 else 0 state.fee = state.fee if state.fee > 0 else 0
state.tx_priv = crypto.random_scalar() state.tx_priv = crypto.random_scalar()
state.tx_pub = crypto.scalarmult_base(state.tx_priv) state.tx_pub = crypto.scalarmult_base(state.tx_priv)
@ -74,9 +74,9 @@ async def init_transaction(
state.mem_trace(10, True) state.mem_trace(10, True)
# Final message hasher # Final message hasher
state.full_message_hasher.init(state.rct_type == RctType.Simple) state.full_message_hasher.init()
state.full_message_hasher.set_type_fee( state.full_message_hasher.set_type_fee(
signing.get_monero_rct_type(state.rct_type, state.rsig_type), state.fee signing.get_monero_rct_type(state.bp_version), state.fee
) )
# Sub address precomputation # Sub address precomputation
@ -167,31 +167,32 @@ def _get_primary_change_address(state: State):
def _check_rsig_data(state: State, rsig_data: MoneroTransactionRsigData): def _check_rsig_data(state: State, rsig_data: MoneroTransactionRsigData):
""" """
There are two types of monero ring confidential transactions: There are two types of monero ring confidential transactions:
1. RCTTypeFull = 1 (used if num_inputs == 1) 1. RCTTypeFull = 1 (used if num_inputs == 1 && Borromean)
2. RCTTypeSimple = 2 (for num_inputs > 1) 2. RCTTypeSimple = 2 (for num_inputs > 1 || !Borromean)
and four types of range proofs (set in `rsig_data.rsig_type`): and four types of range proofs (set in `rsig_data.rsig_type`):
1. RangeProofBorromean = 0 1. RangeProofBorromean = 0
2. RangeProofBulletproof = 1 2. RangeProofBulletproof = 1
3. RangeProofMultiOutputBulletproof = 2 3. RangeProofMultiOutputBulletproof = 2
4. RangeProofPaddedBulletproof = 3 4. RangeProofPaddedBulletproof = 3
The current code supports only HF9, HF10 thus TX type is always simple
and RCT algorithm is always Bulletproof.
""" """
state.rsig_grouping = rsig_data.grouping state.rsig_grouping = rsig_data.grouping
if rsig_data.rsig_type == 0: if rsig_data.rsig_type == 0:
state.rsig_type = RsigType.Borromean raise ValueError("Borromean range sig not supported")
elif rsig_data.rsig_type in (1, 2, 3): elif rsig_data.rsig_type in (1, 2, 3):
state.rsig_type = RsigType.Bulletproof state.bp_version = rsig_data.bp_version or 1
if state.bp_version not in (1, 2):
raise ValueError("Unknown BP version")
else: else:
raise ValueError("Unknown rsig type") raise ValueError("Unknown rsig type")
# unintuitively RctType.Simple is used for more inputs if state.output_count > 2:
if state.input_count > 1 or state.rsig_type == RsigType.Bulletproof:
state.rct_type = RctType.Simple
else:
state.rct_type = RctType.Full
if state.rsig_type == RsigType.Bulletproof and state.output_count > 2:
state.rsig_offload = True state.rsig_offload = True
_check_grouping(state) _check_grouping(state)
@ -293,18 +294,24 @@ def _process_payment_id(state: State, tsx_data: MoneroTransactionData):
therefore the TX_EXTRA_NONCE_ENCRYPTED_PAYMENT_ID = 0x01 tag is used. therefore the TX_EXTRA_NONCE_ENCRYPTED_PAYMENT_ID = 0x01 tag is used.
If it is not encrypted, we use TX_EXTRA_NONCE_PAYMENT_ID = 0x00. If it is not encrypted, we use TX_EXTRA_NONCE_PAYMENT_ID = 0x00.
Since Monero release 0.13 all 2 output payments have encrypted payment ID
to make BC more uniform.
See: See:
- https://github.com/monero-project/monero/blob/ff7dc087ae5f7de162131cea9dbcf8eac7c126a1/src/cryptonote_basic/tx_extra.h - https://github.com/monero-project/monero/blob/ff7dc087ae5f7de162131cea9dbcf8eac7c126a1/src/cryptonote_basic/tx_extra.h
""" """
# encrypted payment id / dummy payment ID
view_key_pub_enc = None
if not tsx_data.payment_id or len(tsx_data.payment_id) == 8:
view_key_pub_enc = _get_key_for_payment_id_encryption(
tsx_data, state.change_address(), state.client_version > 0
)
if not tsx_data.payment_id: if not tsx_data.payment_id:
return return
# encrypted payment id elif len(tsx_data.payment_id) == 8:
if len(tsx_data.payment_id) == 8:
view_key_pub_enc = _get_key_for_payment_id_encryption(
tsx_data.outputs, state.change_address()
)
view_key_pub = crypto.decodepoint(view_key_pub_enc) view_key_pub = crypto.decodepoint(view_key_pub_enc)
payment_id_encr = _encrypt_payment_id( payment_id_encr = _encrypt_payment_id(
tsx_data.payment_id, view_key_pub, state.tx_priv tsx_data.payment_id, view_key_pub, state.tx_priv
@ -334,10 +341,15 @@ def _process_payment_id(state: State, tsx_data: MoneroTransactionData):
state.extra_nonce = extra_buff state.extra_nonce = extra_buff
def _get_key_for_payment_id_encryption(destinations: list, change_addr=None): def _get_key_for_payment_id_encryption(
tsx_data: MoneroTransactionData,
change_addr=None,
add_dummy_payment_id: bool = False,
):
""" """
Returns destination address public view key to be used for Returns destination address public view key to be used for
payment id encryption. payment id encryption. If no encrypted payment ID is chosen,
dummy payment ID is set for better transaction uniformity if possible.
""" """
from apps.monero.xmr.addresses import addr_eq from apps.monero.xmr.addresses import addr_eq
from trezor.messages.MoneroAccountPublicAddress import MoneroAccountPublicAddress from trezor.messages.MoneroAccountPublicAddress import MoneroAccountPublicAddress
@ -346,20 +358,24 @@ def _get_key_for_payment_id_encryption(destinations: list, change_addr=None):
spend_public_key=crypto.NULL_KEY_ENC, view_public_key=crypto.NULL_KEY_ENC spend_public_key=crypto.NULL_KEY_ENC, view_public_key=crypto.NULL_KEY_ENC
) )
count = 0 count = 0
for dest in destinations: for dest in tsx_data.outputs:
if dest.amount == 0: if dest.amount == 0:
continue continue
if change_addr and addr_eq(dest.addr, change_addr): if change_addr and addr_eq(dest.addr, change_addr):
continue continue
if addr_eq(dest.addr, addr): if addr_eq(dest.addr, addr):
continue continue
if count > 0: if count > 0 and tsx_data.payment_id:
raise ValueError( raise ValueError(
"Destinations have to have exactly one output to support encrypted payment ids" "Destinations have to have exactly one output to support encrypted payment ids"
) )
addr = dest.addr addr = dest.addr
count += 1 count += 1
# Insert dummy payment id for transaction uniformity
if not tsx_data.payment_id and count <= 1 and add_dummy_payment_id:
tsx_data.payment_id = bytearray(8)
if count == 0 and change_addr: if count == 0 and change_addr:
return change_addr.view_public_key return change_addr.view_public_key
@ -380,6 +396,4 @@ def _encrypt_payment_id(payment_id, public_key, secret_key):
derivation[32] = 0x8D # ENCRYPTED_PAYMENT_ID_TAIL derivation[32] = 0x8D # ENCRYPTED_PAYMENT_ID_TAIL
hash = crypto.cn_fast_hash(derivation) hash = crypto.cn_fast_hash(derivation)
pm_copy = bytearray(payment_id) pm_copy = bytearray(payment_id)
for i in range(8): return crypto.xor8(pm_copy, hash)
pm_copy[i] ^= hash[i]
return pm_copy

View File

@ -14,7 +14,6 @@ key derived for exactly this purpose.
from .state import State from .state import State
from apps.monero.layout import confirms from apps.monero.layout import confirms
from apps.monero.signing import RctType
from apps.monero.xmr import crypto, monero, serialize from apps.monero.xmr import crypto, monero, serialize
if False: if False:
@ -95,27 +94,19 @@ async def set_input(state: State, src_entr: MoneroTransactionSourceEntry):
state.mem_trace(3, True) state.mem_trace(3, True)
# PseudoOuts commitment, alphas stored to state # PseudoOuts commitment, alphas stored to state
pseudo_out = None alpha, pseudo_out = _gen_commitment(state, src_entr.amount)
pseudo_out_hmac = None pseudo_out = crypto.encodepoint(pseudo_out)
alpha_enc = None
if state.rct_type == RctType.Simple: # In full version the alpha is encrypted and passed back for storage
alpha, pseudo_out = _gen_commitment(state, src_entr.amount) pseudo_out_hmac = crypto.compute_hmac(
pseudo_out = crypto.encodepoint(pseudo_out) offloading_keys.hmac_key_txin_comm(state.key_hmac, state.current_input_index),
pseudo_out,
)
# In full version the alpha is encrypted and passed back for storage alpha_enc = chacha_poly.encrypt_pack(
pseudo_out_hmac = crypto.compute_hmac( offloading_keys.enc_key_txin_alpha(state.key_enc, state.current_input_index),
offloading_keys.hmac_key_txin_comm( crypto.encodeint(alpha),
state.key_hmac, state.current_input_index )
),
pseudo_out,
)
alpha_enc = chacha_poly.encrypt_pack(
offloading_keys.enc_key_txin_alpha(
state.key_enc, state.current_input_index
),
crypto.encodeint(alpha),
)
spend_enc = chacha_poly.encrypt_pack( spend_enc = chacha_poly.encrypt_pack(
offloading_keys.enc_key_spend(state.key_enc, state.current_input_index), offloading_keys.enc_key_spend(state.key_enc, state.current_input_index),
@ -128,6 +119,7 @@ async def set_input(state: State, src_entr: MoneroTransactionSourceEntry):
the precomputed subaddresses so we clear them to save memory. the precomputed subaddresses so we clear them to save memory.
""" """
state.subaddresses = None state.subaddresses = None
state.input_last_amount = src_entr.amount
return MoneroTransactionSetInputAck( return MoneroTransactionSetInputAck(
vini=vini_bin, vini=vini_bin,

View File

@ -7,7 +7,7 @@ Also hashes `pseudo_out` to the final_message.
from .state import State from .state import State
from apps.monero.layout import confirms from apps.monero.layout import confirms
from apps.monero.signing import RctType, RsigType, offloading_keys from apps.monero.signing import offloading_keys
from apps.monero.xmr import crypto from apps.monero.xmr import crypto
if False: if False:
@ -21,9 +21,18 @@ async def input_vini(
src_entr: MoneroTransactionSourceEntry, src_entr: MoneroTransactionSourceEntry,
vini_bin: bytes, vini_bin: bytes,
vini_hmac: bytes, vini_hmac: bytes,
pseudo_out: bytes,
pseudo_out_hmac: bytes,
): ):
"""
This step serves for an incremental hashing of tx.vin[i] to the tx_prefix_hasher
after the sorting on tx.vin[i].ki.
Originally, this step also incrementaly hashed pseudo_output[i] to the full_message_hasher for
RctSimple transactions with Borromean proofs (HF8).
In later hard-forks, the pseudo_outputs were moved to the rctsig.prunable
which is not hashed to the final signature, thus pseudo_output hashing has been removed
(as we support only HF9 and HF10 now).
"""
from trezor.messages.MoneroTransactionInputViniAck import ( from trezor.messages.MoneroTransactionInputViniAck import (
MoneroTransactionInputViniAck, MoneroTransactionInputViniAck,
) )
@ -50,24 +59,4 @@ async def input_vini(
Incremental hasing of tx.vin[i] Incremental hasing of tx.vin[i]
""" """
state.tx_prefix_hasher.buffer(vini_bin) state.tx_prefix_hasher.buffer(vini_bin)
# in monero version >= 8 pseudo outs were moved to a different place
# bulletproofs imply version >= 8
if state.rct_type == RctType.Simple and state.rsig_type != RsigType.Bulletproof:
_hash_vini_pseudo_out(state, pseudo_out, pseudo_out_hmac)
return MoneroTransactionInputViniAck() return MoneroTransactionInputViniAck()
def _hash_vini_pseudo_out(state: State, pseudo_out: bytes, pseudo_out_hmac: bytes):
"""
Incremental hasing of pseudo output. Only applicable for simple rct.
"""
idx = state.source_permutation[state.current_input_index]
pseudo_out_hmac_comp = crypto.compute_hmac(
offloading_keys.hmac_key_txin_comm(state.key_hmac, idx), pseudo_out
)
if not crypto.ct_equals(pseudo_out_hmac, pseudo_out_hmac_comp):
raise ValueError("HMAC invalid for pseudo outs")
state.full_message_hasher.set_pseudo_out(pseudo_out)

View File

@ -3,12 +3,9 @@ All inputs set. Defining range signature parameters.
If in the applicable offloading mode, generate commitment masks. If in the applicable offloading mode, generate commitment masks.
""" """
from trezor import utils
from .state import State from .state import State
from apps.monero.layout import confirms from apps.monero.layout import confirms
from apps.monero.signing import RctType
from apps.monero.xmr import crypto from apps.monero.xmr import crypto
@ -20,39 +17,53 @@ async def all_inputs_set(state: State):
from trezor.messages.MoneroTransactionAllInputsSetAck import ( from trezor.messages.MoneroTransactionAllInputsSetAck import (
MoneroTransactionAllInputsSetAck, MoneroTransactionAllInputsSetAck,
) )
from trezor.messages.MoneroTransactionRsigData import MoneroTransactionRsigData
# Generate random commitment masks to be used in range proofs. # Generate random commitment masks to be used in range proofs.
# If SimpleRCT is used the sum of the masks must match the input masks sum. # If SimpleRCT is used the sum of the masks must match the input masks sum.
state.sumout = crypto.sc_init(0) state.sumout = crypto.sc_init(0)
for i in range(state.output_count): rsig_data = None
cur_mask = crypto.new_scalar() # new mask for each output
is_last = i + 1 == state.output_count
if is_last and state.rct_type == RctType.Simple:
# in SimpleRCT the last mask needs to be calculated as an offset of the sum
crypto.sc_sub_into(cur_mask, state.sumpouts_alphas, state.sumout)
else:
crypto.random_scalar(cur_mask)
crypto.sc_add_into(state.sumout, state.sumout, cur_mask) # Client 0, HF9. Non-deterministic masks
state.output_masks.append(cur_mask) if not state.is_det_mask():
rsig_data = await _compute_masks(state)
if state.rct_type == RctType.Simple: resp = MoneroTransactionAllInputsSetAck(rsig_data=rsig_data)
utils.ensure( return resp
crypto.sc_eq(state.sumout, state.sumpouts_alphas), "Invalid masks sum"
) # sum check
state.sumout = crypto.sc_init(0) async def _compute_masks(state: State):
"""
Output masks computed in advance. Used with client_version=0 && HF9.
After HF10 (included) masks are deterministic, computed from the amount_key.
After all client update to v1 this code will be removed.
In order to preserve client_version=0 compatibility the masks have to be adjusted.
"""
from trezor.messages.MoneroTransactionRsigData import MoneroTransactionRsigData
from apps.monero.signing import offloading_keys
rsig_data = MoneroTransactionRsigData() rsig_data = MoneroTransactionRsigData()
resp = MoneroTransactionAllInputsSetAck(rsig_data=rsig_data)
# If range proofs are being offloaded, we send the masks to the host, which uses them # If range proofs are being offloaded, we send the masks to the host, which uses them
# to create the range proof. If not, we do not send any and we use them in the following step. # to create the range proof. If not, we do not send any and we use them in the following step.
if state.rsig_offload: if state.rsig_offload:
tmp_buff = bytearray(32) rsig_data.mask = []
rsig_data.mask = bytearray(32 * state.output_count)
for i in range(state.output_count):
crypto.encodeint_into(tmp_buff, state.output_masks[i])
utils.memcpy(rsig_data.mask, 32 * i, tmp_buff, 0, 32)
return resp # Deterministic masks, the last one is computed to balance the sums
for i in range(state.output_count):
if i + 1 == state.output_count:
cur_mask = crypto.sc_sub(state.sumpouts_alphas, state.sumout)
state.output_last_mask = cur_mask
else:
cur_mask = offloading_keys.det_comm_masks(state.key_enc, i)
crypto.sc_add_into(state.sumout, state.sumout, cur_mask)
if state.rsig_offload:
rsig_data.mask.append(crypto.encodeint(cur_mask))
if not crypto.sc_eq(state.sumpouts_alphas, state.sumout):
raise ValueError("Sum eq error")
state.sumout = crypto.sc_init(0)
return rsig_data

View File

@ -10,53 +10,60 @@ from .state import State
from apps.monero import signing from apps.monero import signing
from apps.monero.layout import confirms from apps.monero.layout import confirms
from apps.monero.signing import RsigType, offloading_keys from apps.monero.signing import offloading_keys
from apps.monero.xmr import crypto, serialize from apps.monero.xmr import crypto, serialize
async def set_output(state: State, dst_entr, dst_entr_hmac, rsig_data): async def set_output(
state: State, dst_entr, dst_entr_hmac, rsig_data, is_offloaded_bp=False
):
state.mem_trace(0, True) state.mem_trace(0, True)
mods = utils.unimport_begin() mods = utils.unimport_begin()
await confirms.transaction_step( # Progress update only for master message (skip for offloaded BP msg)
state.ctx, state.STEP_OUT, state.current_output_index + 1, state.output_count if not is_offloaded_bp:
) await confirms.transaction_step(
state.mem_trace(1) state.ctx,
state.STEP_OUT,
state.current_output_index + 1,
state.output_count,
)
state.current_output_index += 1 state.mem_trace(1, True)
dst_entr = await _validate(state, dst_entr, dst_entr_hmac, is_offloaded_bp)
state.mem_trace(2, True) state.mem_trace(2, True)
await _validate(state, dst_entr, dst_entr_hmac)
# First output - we include the size of the container into the tx prefix hasher if not state.is_processing_offloaded:
if state.current_output_index == 0: # First output - we include the size of the container into the tx prefix hasher
state.tx_prefix_hasher.uvarint(state.output_count) if state.current_output_index == 0:
state.mem_trace(4, True) state.tx_prefix_hasher.uvarint(state.output_count)
state.mem_trace(4, True)
state.output_amounts.append(dst_entr.amount)
state.summary_outs_money += dst_entr.amount
state.output_amounts.append(dst_entr.amount)
state.summary_outs_money += dst_entr.amount
utils.unimport_end(mods) utils.unimport_end(mods)
state.mem_trace(5, True) state.mem_trace(5, True)
# Range proof first, memory intensive # Compute tx keys and masks if applicable
rsig, mask = _range_proof(state, dst_entr.amount, rsig_data) tx_out_key, amount_key = _compute_tx_keys(state, dst_entr)
utils.unimport_end(mods) utils.unimport_end(mods)
state.mem_trace(6, True) state.mem_trace(6, True)
# additional tx key if applicable # Range proof first, memory intensive (fragmentation)
additional_txkey_priv = _set_out_additional_keys(state, dst_entr) rsig_data_new, mask = _range_proof(state, rsig_data)
# derivation = a*R or r*A or s*C utils.unimport_end(mods)
derivation = _set_out_derivation(state, dst_entr, additional_txkey_priv)
# amount key = H_s(derivation || i)
amount_key = crypto.derivation_to_scalar(derivation, state.current_output_index)
# one-time destination address P = H_s(derivation || i)*G + B
tx_out_key = crypto.derive_public_key(
derivation,
state.current_output_index,
crypto.decodepoint(dst_entr.addr.spend_public_key),
)
del (derivation, additional_txkey_priv)
state.mem_trace(7, True) state.mem_trace(7, True)
# If det masks & offloading, return as we are handling offloaded BP.
if state.is_processing_offloaded:
from trezor.messages.MoneroTransactionSetOutputAck import (
MoneroTransactionSetOutputAck,
)
return MoneroTransactionSetOutputAck()
# Tx header prefix hashing, hmac dst_entr # Tx header prefix hashing, hmac dst_entr
tx_out_bin, hmac_vouti = await _set_out_tx_out(state, dst_entr, tx_out_key) tx_out_bin, hmac_vouti = await _set_out_tx_out(state, dst_entr, tx_out_key)
state.mem_trace(11, True) state.mem_trace(11, True)
@ -93,29 +100,102 @@ async def set_output(state: State, dst_entr, dst_entr_hmac, rsig_data):
return MoneroTransactionSetOutputAck( return MoneroTransactionSetOutputAck(
tx_out=tx_out_bin, tx_out=tx_out_bin,
vouti_hmac=hmac_vouti, vouti_hmac=hmac_vouti,
rsig_data=_return_rsig_data(rsig), rsig_data=rsig_data_new,
out_pk=out_pk_bin, out_pk=out_pk_bin,
ecdh_info=ecdh_info_bin, ecdh_info=ecdh_info_bin,
) )
async def _validate(state: State, dst_entr, dst_entr_hmac): async def _validate(state: State, dst_entr, dst_entr_hmac, is_offloaded_bp):
if state.current_input_index + 1 != state.input_count: # If offloading flag then it has to be det_masks and offloading enabled.
raise ValueError("Invalid number of inputs") # Using IF as it is easier to read.
if state.current_output_index >= state.output_count: if is_offloaded_bp and (not state.rsig_offload or not state.is_det_mask()):
raise ValueError("Invalid output index") raise ValueError("Extraneous offloaded msg")
if dst_entr.amount < 0:
raise ValueError("Destination with wrong amount: %s" % dst_entr.amount)
# HMAC check of the destination # State change according to the det-mask BP offloading.
dst_entr_hmac_computed = await offloading_keys.gen_hmac_tsxdest( if state.is_det_mask() and state.rsig_offload:
state.key_hmac, dst_entr, state.current_output_index bidx = _get_rsig_batch(state, state.current_output_index)
last_in_batch = _is_last_in_batch(state, state.current_output_index, bidx)
utils.ensure(
not last_in_batch or state.is_processing_offloaded != is_offloaded_bp,
"Offloaded BP out of order",
)
state.is_processing_offloaded = is_offloaded_bp
if not state.is_processing_offloaded:
state.current_output_index += 1
utils.ensure(
not dst_entr or dst_entr.amount >= 0, "Destination with negative amount"
) )
if not crypto.ct_equals(dst_entr_hmac, dst_entr_hmac_computed): utils.ensure(
raise ValueError("HMAC invalid") state.current_input_index + 1 == state.input_count, "Invalid number of inputs"
del (dst_entr_hmac, dst_entr_hmac_computed) )
utils.ensure(
state.current_output_index < state.output_count, "Invalid output index"
)
utils.ensure(
state.is_det_mask() or not state.is_processing_offloaded,
"Offloaded extra msg while not using det masks",
)
if not state.is_processing_offloaded:
# HMAC check of the destination
dst_entr_hmac_computed = await offloading_keys.gen_hmac_tsxdest(
state.key_hmac, dst_entr, state.current_output_index
)
utils.ensure(
crypto.ct_equals(dst_entr_hmac, dst_entr_hmac_computed), "HMAC failed"
)
del (dst_entr_hmac_computed)
else:
dst_entr = None
del (dst_entr_hmac)
state.mem_trace(3, True) state.mem_trace(3, True)
return dst_entr
def _compute_tx_keys(state: State, dst_entr):
"""Computes tx_out_key, amount_key"""
if state.is_processing_offloaded:
return None, None # no need to recompute
# additional tx key if applicable
additional_txkey_priv = _set_out_additional_keys(state, dst_entr)
# derivation = a*R or r*A or s*C
derivation = _set_out_derivation(state, dst_entr, additional_txkey_priv)
# amount key = H_s(derivation || i)
amount_key = crypto.derivation_to_scalar(derivation, state.current_output_index)
# one-time destination address P = H_s(derivation || i)*G + B
tx_out_key = crypto.derive_public_key(
derivation,
state.current_output_index,
crypto.decodepoint(dst_entr.addr.spend_public_key),
)
del (derivation, additional_txkey_priv)
# Computes the newest mask if applicable
if state.is_det_mask():
from apps.monero.xmr import monero
mask = monero.commitment_mask(crypto.encodeint(amount_key))
elif state.current_output_index + 1 < state.output_count:
mask = offloading_keys.det_comm_masks(state.key_enc, state.current_output_index)
else:
mask = state.output_last_mask
state.output_last_mask = None
state.output_masks.append(mask)
return tx_out_key, amount_key
async def _set_out_tx_out(state: State, dst_entr, tx_out_key): async def _set_out_tx_out(state: State, dst_entr, tx_out_key):
""" """
@ -139,109 +219,128 @@ async def _set_out_tx_out(state: State, dst_entr, tx_out_key):
return tx_out_bin, hmac_vouti return tx_out_bin, hmac_vouti
def _range_proof(state, amount, rsig_data): def _range_proof(state, rsig_data):
""" """
Computes rangeproof Computes rangeproof and handles range proof offloading logic.
In order to optimize incremental transaction build, the mask computation is changed compared
to the official Monero code. In the official code, the input pedersen commitments are computed
after range proof in such a way summed masks for commitments (alpha) and rangeproofs (ai) are equal.
In order to save roundtrips we compute commitments randomly and then for the last rangeproof
a[63] = (\\sum_{i=0}^{num_inp}alpha_i - \\sum_{i=0}^{num_outs-1} amasks_i) - \\sum_{i=0}^{62}a_i
Since HF10 the commitments are deterministic.
The range proof is incrementally hashed to the final_message. The range proof is incrementally hashed to the final_message.
""" """
from apps.monero.xmr import range_signatures
mask = state.output_masks[state.current_output_index]
provided_rsig = None provided_rsig = None
if rsig_data and rsig_data.rsig and len(rsig_data.rsig) > 0: if rsig_data and rsig_data.rsig and len(rsig_data.rsig) > 0:
provided_rsig = rsig_data.rsig provided_rsig = rsig_data.rsig
if not state.rsig_offload and provided_rsig: if not state.rsig_offload and provided_rsig:
raise signing.Error("Provided unexpected rsig") raise signing.Error("Provided unexpected rsig")
# Batching # Batching & validation
bidx = _get_rsig_batch(state, state.current_output_index) bidx = _get_rsig_batch(state, state.current_output_index)
batch_size = state.rsig_grouping[bidx]
last_in_batch = _is_last_in_batch(state, state.current_output_index, bidx) last_in_batch = _is_last_in_batch(state, state.current_output_index, bidx)
if state.rsig_offload and provided_rsig and not last_in_batch: if state.rsig_offload and provided_rsig and not last_in_batch:
raise signing.Error("Provided rsig too early") raise signing.Error("Provided rsig too early")
if state.rsig_offload and last_in_batch and not provided_rsig:
if (
state.rsig_offload
and last_in_batch
and not provided_rsig
and (not state.is_det_mask() or state.is_processing_offloaded)
):
raise signing.Error("Rsig expected, not provided") raise signing.Error("Rsig expected, not provided")
# Batch not finished, skip range sig generation now # Batch not finished, skip range sig generation now
mask = state.output_masks[-1] if not state.is_processing_offloaded else None
offload_mask = mask and state.is_det_mask() and state.rsig_offload
# If not last, do not proceed to the BP processing.
if not last_in_batch: if not last_in_batch:
return None, mask rsig_data_new = (
_return_rsig_data(mask=crypto.encodeint(mask)) if offload_mask else None
)
return rsig_data_new, mask
# Rangeproof # Rangeproof
# Pedersen commitment on the value, mask from the commitment, range signature. # Pedersen commitment on the value, mask from the commitment, range signature.
C, rsig = None, None rsig = None
state.mem_trace("pre-rproof" if __debug__ else None, collect=True) state.mem_trace("pre-rproof" if __debug__ else None, collect=True)
if state.rsig_type == RsigType.Bulletproof and not state.rsig_offload: if not state.rsig_offload:
"""Bulletproof calculation in trezor""" """Bulletproof calculation in Trezor"""
rsig = range_signatures.prove_range_bp_batch( rsig = _rsig_bp(state)
state.output_amounts, state.output_masks
)
state.mem_trace("post-bp" if __debug__ else None, collect=True)
# Incremental BP hashing elif state.is_det_mask() and not state.is_processing_offloaded:
# BP is hashed with raw=False as hash does not contain L, R """Bulletproof offloaded to the host, deterministic masks. Nothing here, waiting for offloaded BP."""
# array sizes compared to the serialized bulletproof format pass
# thus direct serialization cannot be used.
state.full_message_hasher.rsig_val(rsig, True, raw=False)
state.mem_trace("post-bp-hash" if __debug__ else None, collect=True)
rsig = _dump_rsig_bp(rsig) elif state.is_det_mask() and state.is_processing_offloaded:
state.mem_trace( """Bulletproof offloaded to the host, check BP, hash it."""
"post-bp-ser, size: %s" % len(rsig) if __debug__ else None, collect=True _rsig_process_bp(state, rsig_data)
)
elif state.rsig_type == RsigType.Borromean and not state.rsig_offload:
"""Borromean calculation in trezor"""
C, mask, rsig = range_signatures.prove_range_borromean(amount, mask)
del range_signatures
# Incremental hashing
state.full_message_hasher.rsig_val(rsig, False, raw=True)
_check_out_commitment(state, amount, mask, C)
elif state.rsig_type == RsigType.Bulletproof and state.rsig_offload:
"""Bulletproof calculated on host, verify in trezor"""
from apps.monero.xmr.serialize_messages.tx_rsig_bulletproof import Bulletproof
# TODO this should be tested
# last_in_batch = True (see above) so this is fine
masks = state.output_masks[
1 + state.current_output_index - batch_size : 1 + state.current_output_index
]
bp_obj = serialize.parse_msg(rsig_data.rsig, Bulletproof)
rsig_data.rsig = None
# BP is hashed with raw=False as hash does not contain L, R
# array sizes compared to the serialized bulletproof format
# thus direct serialization cannot be used.
state.full_message_hasher.rsig_val(bp_obj, True, raw=False)
res = range_signatures.verify_bp(bp_obj, state.output_amounts, masks)
utils.ensure(res, "BP verification fail")
state.mem_trace("BP verified" if __debug__ else None, collect=True)
del (bp_obj, range_signatures)
elif state.rsig_type == RsigType.Borromean and state.rsig_offload:
"""Borromean offloading not supported"""
raise signing.Error(
"Unsupported rsig state (Borromean offloaded is not supported)"
)
else: else:
raise signing.Error("Unexpected rsig state") """Bulletproof calculated on host, verify in Trezor"""
_rsig_process_bp(state, rsig_data)
state.mem_trace("rproof" if __debug__ else None, collect=True) state.mem_trace("rproof" if __debug__ else None, collect=True)
if state.current_output_index + 1 == state.output_count:
# Construct new rsig data to send back to the host.
rsig_data_new = _return_rsig_data(
rsig, crypto.encodeint(mask) if offload_mask else None
)
if state.current_output_index + 1 == state.output_count and (
not state.rsig_offload or state.is_processing_offloaded
):
# output masks and amounts are not needed anymore # output masks and amounts are not needed anymore
state.output_amounts = [] state.output_amounts = None
state.output_masks = [] state.output_masks = None
return rsig, mask
return rsig_data_new, mask
def _rsig_bp(state: State):
"""Bulletproof calculation in trezor"""
from apps.monero.xmr import range_signatures
rsig = range_signatures.prove_range_bp_batch(
state.output_amounts, state.output_masks
)
state.mem_trace("post-bp" if __debug__ else None, collect=True)
# Incremental BP hashing
# BP is hashed with raw=False as hash does not contain L, R
# array sizes compared to the serialized bulletproof format
# thus direct serialization cannot be used.
state.full_message_hasher.rsig_val(rsig, True, raw=False)
state.mem_trace("post-bp-hash" if __debug__ else None, collect=True)
rsig = _dump_rsig_bp(rsig)
state.mem_trace(
"post-bp-ser, size: %s" % len(rsig) if __debug__ else None, collect=True
)
# state cleanup
state.output_masks = []
state.output_amounts = []
return rsig
def _rsig_process_bp(state: State, rsig_data):
from apps.monero.xmr import range_signatures
from apps.monero.xmr.serialize_messages.tx_rsig_bulletproof import Bulletproof
bp_obj = serialize.parse_msg(rsig_data.rsig, Bulletproof)
rsig_data.rsig = None
# BP is hashed with raw=False as hash does not contain L, R
# array sizes compared to the serialized bulletproof format
# thus direct serialization cannot be used.
state.full_message_hasher.rsig_val(bp_obj, True, raw=False)
res = range_signatures.verify_bp(bp_obj, state.output_amounts, state.output_masks)
utils.ensure(res, "BP verification fail")
state.mem_trace("BP verified" if __debug__ else None, collect=True)
del (bp_obj, range_signatures)
# State cleanup after verification is finished
state.output_amounts = []
state.output_masks = []
def _dump_rsig_bp(rsig): def _dump_rsig_bp(rsig):
@ -286,15 +385,21 @@ def _dump_rsig_bp(rsig):
return buff return buff
def _return_rsig_data(rsig): def _return_rsig_data(rsig=None, mask=None):
if rsig is None: if rsig is None and mask is None:
return None return None
from trezor.messages.MoneroTransactionRsigData import MoneroTransactionRsigData from trezor.messages.MoneroTransactionRsigData import MoneroTransactionRsigData
if isinstance(rsig, list): rsig_data = MoneroTransactionRsigData()
return MoneroTransactionRsigData(rsig_parts=rsig)
else: if mask:
return MoneroTransactionRsigData(rsig=rsig) rsig_data.mask = mask
if rsig:
rsig_data.rsig = rsig
return rsig_data
def _get_ecdh_info_and_out_pk(state: State, tx_out_key, amount, mask, amount_key): def _get_ecdh_info_and_out_pk(state: State, tx_out_key, amount, mask, amount_key):
@ -305,23 +410,49 @@ def _get_ecdh_info_and_out_pk(state: State, tx_out_key, amount, mask, amount_key
""" """
out_pk_dest = crypto.encodepoint(tx_out_key) out_pk_dest = crypto.encodepoint(tx_out_key)
out_pk_commitment = crypto.encodepoint(crypto.gen_commitment(mask, amount)) out_pk_commitment = crypto.encodepoint(crypto.gen_commitment(mask, amount))
crypto.sc_add_into(state.sumout, state.sumout, mask)
state.sumout = crypto.sc_add(state.sumout, mask)
state.output_sk_masks.append(mask)
# masking of mask and amount # masking of mask and amount
ecdh_info = _ecdh_encode(mask, amount, crypto.encodeint(amount_key)) ecdh_info = _ecdh_encode(
mask, amount, crypto.encodeint(amount_key), state.is_bulletproof_v2()
)
# Manual ECDH info serialization # Manual ECDH info serialization
ecdh_info_bin = bytearray(64) ecdh_info_bin = _serialize_ecdh(ecdh_info, state.is_bulletproof_v2())
utils.memcpy(ecdh_info_bin, 0, ecdh_info.mask, 0, 32)
utils.memcpy(ecdh_info_bin, 32, ecdh_info.amount, 0, 32)
gc.collect() gc.collect()
return out_pk_dest, out_pk_commitment, ecdh_info_bin return out_pk_dest, out_pk_commitment, ecdh_info_bin
def _ecdh_encode(mask, amount, amount_key): def _serialize_ecdh(ecdh_info, v2=False):
"""
Serializes ECDH according to the current format defined by the hard fork version
or the signature format respectively.
"""
if v2:
# In HF10 the amount is serialized to 8B and mask is deterministic
ecdh_info_bin = bytearray(8)
ecdh_info_bin[:] = ecdh_info.amount[0:8]
return ecdh_info_bin
else:
ecdh_info_bin = bytearray(64)
utils.memcpy(ecdh_info_bin, 0, ecdh_info.mask, 0, 32)
utils.memcpy(ecdh_info_bin, 32, ecdh_info.amount, 0, 32)
return ecdh_info_bin
def _ecdh_hash(shared_sec):
"""
Generates ECDH hash for amount masking for Bulletproof2
"""
data = bytearray(38)
data[0:6] = b"amount"
data[6:] = shared_sec
return crypto.cn_fast_hash(data)
def _ecdh_encode(mask, amount, amount_key, v2=False):
""" """
Output recipients need be able to reconstruct the amount commitments. Output recipients need be able to reconstruct the amount commitments.
This means the blinding factor `mask` and `amount` must be communicated This means the blinding factor `mask` and `amount` must be communicated
@ -336,23 +467,27 @@ def _ecdh_encode(mask, amount, amount_key):
from apps.monero.xmr.serialize_messages.tx_ecdh import EcdhTuple from apps.monero.xmr.serialize_messages.tx_ecdh import EcdhTuple
ecdh_info = EcdhTuple(mask=mask, amount=crypto.sc_init(amount)) ecdh_info = EcdhTuple(mask=mask, amount=crypto.sc_init(amount))
amount_key_hash_single = crypto.hash_to_scalar(amount_key)
amount_key_hash_double = crypto.hash_to_scalar(
crypto.encodeint(amount_key_hash_single)
)
ecdh_info.mask = crypto.sc_add(ecdh_info.mask, amount_key_hash_single) if v2:
ecdh_info.amount = crypto.sc_add(ecdh_info.amount, amount_key_hash_double) amnt = ecdh_info.amount
return _recode_ecdh(ecdh_info) ecdh_info.mask = crypto.NULL_KEY_ENC
ecdh_info.amount = bytearray(32)
crypto.encodeint_into(ecdh_info.amount, amnt)
crypto.xor8(ecdh_info.amount, _ecdh_hash(amount_key))
return ecdh_info
else:
amount_key_hash_single = crypto.hash_to_scalar(amount_key)
amount_key_hash_double = crypto.hash_to_scalar(
crypto.encodeint(amount_key_hash_single)
)
def _recode_ecdh(ecdh_info): # Not modifying passed mask, is reused in BP.
""" ecdh_info.mask = crypto.sc_add(ecdh_info.mask, amount_key_hash_single)
In-place ecdh_info tuple recoding crypto.sc_add_into(ecdh_info.amount, ecdh_info.amount, amount_key_hash_double)
""" ecdh_info.mask = crypto.encodeint(ecdh_info.mask)
ecdh_info.mask = crypto.encodeint(ecdh_info.mask) ecdh_info.amount = crypto.encodeint(ecdh_info.amount)
ecdh_info.amount = crypto.encodeint(ecdh_info.amount) return ecdh_info
return ecdh_info
def _set_out_additional_keys(state: State, dst_entr): def _set_out_additional_keys(state: State, dst_entr):
@ -367,8 +502,9 @@ def _set_out_additional_keys(state: State, dst_entr):
if dst_entr.is_subaddress: if dst_entr.is_subaddress:
# R=r*D # R=r*D
additional_txkey = crypto.scalarmult( additional_txkey = crypto.decodepoint(dst_entr.addr.spend_public_key)
crypto.decodepoint(dst_entr.addr.spend_public_key), additional_txkey_priv crypto.scalarmult_into(
additional_txkey, additional_txkey, additional_txkey_priv
) )
else: else:
# R=r*G # R=r*G
@ -410,16 +546,6 @@ def _set_out_derivation(state: State, dst_entr, additional_txkey_priv):
return derivation return derivation
def _check_out_commitment(state: State, amount, mask, C):
utils.ensure(
crypto.point_eq(
C,
crypto.point_add(crypto.scalarmult_base(mask), crypto.scalarmult_h(amount)),
),
"OutC fail",
)
def _is_last_in_batch(state: State, idx, bidx): def _is_last_in_batch(state: State, idx, bidx):
""" """
Returns true if the current output is last in the rsig batch Returns true if the current output is last in the rsig batch

View File

@ -22,6 +22,7 @@ async def all_outputs_set(state: State):
state.mem_trace(1) state.mem_trace(1)
_validate(state) _validate(state)
state.is_processing_offloaded = False
state.mem_trace(2) state.mem_trace(2)
_set_tx_extra(state) _set_tx_extra(state)
@ -51,7 +52,7 @@ async def all_outputs_set(state: State):
rv_pb = MoneroRingCtSig( rv_pb = MoneroRingCtSig(
txn_fee=state.fee, txn_fee=state.fee,
message=state.tx_prefix_hash, message=state.tx_prefix_hash,
rv_type=get_monero_rct_type(state.rct_type, state.rsig_type), rv_type=get_monero_rct_type(state.bp_version),
) )
_out_pk(state) _out_pk(state)
@ -71,15 +72,9 @@ async def all_outputs_set(state: State):
def _validate(state: State): def _validate(state: State):
from apps.monero.signing import RctType
if state.current_output_index + 1 != state.output_count: if state.current_output_index + 1 != state.output_count:
raise ValueError("Invalid out num") raise ValueError("Invalid out num")
# Test if \sum Alpha == \sum A
if state.rct_type == RctType.Simple:
utils.ensure(crypto.sc_eq(state.sumout, state.sumpouts_alphas))
# Fee test # Fee test
if state.fee != (state.summary_inputs_money - state.summary_outs_money): if state.fee != (state.summary_inputs_money - state.summary_outs_money):
raise ValueError( raise ValueError(

View File

@ -9,7 +9,6 @@ from trezor import utils
from .state import State from .state import State
from apps.monero.layout import confirms from apps.monero.layout import confirms
from apps.monero.signing import RctType
from apps.monero.xmr import crypto from apps.monero.xmr import crypto
if False: if False:
@ -29,6 +28,16 @@ async def sign_input(
spend_enc: bytes, spend_enc: bytes,
): ):
""" """
Signing UTXO.
Mask Balancing.
Sum of input masks has to be equal to the sum of output masks.
As the output masks has been made deterministic in HF10 the mask sum equality is corrected
in this step. The last input mask (and thus pseudo_out) is recomputed so the sums equal.
If deterministic masks cannot be used (client_version=0), the balancing is done in step 5
on output masks as pseudo outputs have to remain same.
:param state: transaction state :param state: transaction state
:param src_entr: Source entry :param src_entr: Source entry
:param vini_bin: tx.vin[i] for the transaction. Contains key image, offsets, amount (usually zero) :param vini_bin: tx.vin[i] for the transaction. Contains key image, offsets, amount (usually zero)
@ -47,12 +56,10 @@ async def sign_input(
state.current_input_index += 1 state.current_input_index += 1
if state.current_input_index >= state.input_count: if state.current_input_index >= state.input_count:
raise ValueError("Invalid inputs count") raise ValueError("Invalid inputs count")
if state.rct_type == RctType.Simple and pseudo_out is None: if pseudo_out is None:
raise ValueError("SimpleRCT requires pseudo_out but none provided") raise ValueError("SimpleRCT requires pseudo_out but none provided")
if state.rct_type == RctType.Simple and pseudo_out_alpha_enc is None: if pseudo_out_alpha_enc is None:
raise ValueError("SimpleRCT requires pseudo_out's mask but none provided") raise ValueError("SimpleRCT requires pseudo_out's mask but none provided")
if state.current_input_index >= 1 and not state.rct_type == RctType.Simple:
raise ValueError("Two and more inputs must imply SimpleRCT")
input_position = state.source_permutation[state.current_input_index] input_position = state.source_permutation[state.current_input_index]
mods = utils.unimport_begin() mods = utils.unimport_begin()
@ -71,7 +78,27 @@ async def sign_input(
from apps.monero.xmr.crypto import chacha_poly from apps.monero.xmr.crypto import chacha_poly
if state.rct_type == RctType.Simple: pseudo_out_alpha = crypto.decodeint(
chacha_poly.decrypt_pack(
offloading_keys.enc_key_txin_alpha(state.key_enc, input_position),
bytes(pseudo_out_alpha_enc),
)
)
# Last pseud_out is recomputed so mask sums hold
if state.is_det_mask() and input_position + 1 == state.input_count:
# Recompute the lash alpha so the sum holds
state.mem_trace("Correcting alpha")
alpha_diff = crypto.sc_sub(state.sumout, state.sumpouts_alphas)
crypto.sc_add_into(pseudo_out_alpha, pseudo_out_alpha, alpha_diff)
pseudo_out_c = crypto.gen_commitment(pseudo_out_alpha, state.input_last_amount)
else:
if input_position + 1 == state.input_count:
utils.ensure(
crypto.sc_eq(state.sumpouts_alphas, state.sumout), "Sum eq error"
)
# both pseudo_out and its mask were offloaded so we need to # both pseudo_out and its mask were offloaded so we need to
# validate pseudo_out's HMAC and decrypt the alpha # validate pseudo_out's HMAC and decrypt the alpha
pseudo_out_hmac_comp = crypto.compute_hmac( pseudo_out_hmac_comp = crypto.compute_hmac(
@ -81,16 +108,10 @@ async def sign_input(
if not crypto.ct_equals(pseudo_out_hmac_comp, pseudo_out_hmac): if not crypto.ct_equals(pseudo_out_hmac_comp, pseudo_out_hmac):
raise ValueError("HMAC is not correct") raise ValueError("HMAC is not correct")
state.mem_trace(2, True)
pseudo_out_alpha = crypto.decodeint(
chacha_poly.decrypt_pack(
offloading_keys.enc_key_txin_alpha(state.key_enc, input_position),
bytes(pseudo_out_alpha_enc),
)
)
pseudo_out_c = crypto.decodepoint(pseudo_out) pseudo_out_c = crypto.decodepoint(pseudo_out)
state.mem_trace(2, True)
# Spending secret # Spending secret
spend_key = crypto.decodeint( spend_key = crypto.decodeint(
chacha_poly.decrypt_pack( chacha_poly.decrypt_pack(
@ -141,42 +162,24 @@ async def sign_input(
ring_pubkeys = [x.key for x in src_entr.outputs] ring_pubkeys = [x.key for x in src_entr.outputs]
del src_entr del src_entr
if state.rct_type == RctType.Simple: mlsag.generate_mlsag_simple(
mlsag.generate_mlsag_simple( state.full_message,
state.full_message, ring_pubkeys,
ring_pubkeys, input_secret_key,
input_secret_key, pseudo_out_alpha,
pseudo_out_alpha, pseudo_out_c,
pseudo_out_c, kLRki,
kLRki, index,
index, mg_buffer,
mg_buffer, )
)
del (input_secret_key, pseudo_out_alpha, pseudo_out_c) del (input_secret_key, pseudo_out_alpha, mlsag, ring_pubkeys)
else:
# Full RingCt, only one input
txn_fee_key = crypto.scalarmult_h(state.fee)
mlsag.generate_mlsag_full(
state.full_message,
ring_pubkeys,
input_secret_key,
state.output_sk_masks,
state.output_pk_commitments,
kLRki,
index,
txn_fee_key,
mg_buffer,
)
del (input_secret_key, txn_fee_key)
del (mlsag, ring_pubkeys)
state.mem_trace(5, True) state.mem_trace(5, True)
from trezor.messages.MoneroTransactionSignInputAck import ( from trezor.messages.MoneroTransactionSignInputAck import (
MoneroTransactionSignInputAck, MoneroTransactionSignInputAck,
) )
return MoneroTransactionSignInputAck(signature=mg_buffer) return MoneroTransactionSignInputAck(
signature=mg_buffer, pseudo_out=crypto.encodepoint(pseudo_out_c)
)

View File

@ -25,8 +25,11 @@ keccak_hash = tcry.xmr_fast_hash
keccak_hash_into = tcry.xmr_fast_hash keccak_hash_into = tcry.xmr_fast_hash
def keccak_2hash(inp): def keccak_2hash(inp, buff=None):
return keccak_hash(keccak_hash(inp)) buff = buff if buff else bytearray(32)
keccak_hash_into(buff, inp)
keccak_hash_into(buff, buff)
return buff
def compute_hmac(key, msg=None): def compute_hmac(key, msg=None):
@ -168,7 +171,6 @@ https://www.imperialviolet.org/2013/12/25/elligator.html
http://elligator.cr.yp.to/ http://elligator.cr.yp.to/
http://elligator.cr.yp.to/elligator-20130828.pdf http://elligator.cr.yp.to/elligator-20130828.pdf
""" """
ge_frombytes_vartime_check = tcry.ge25519_check
# #
# Monero specific # Monero specific
@ -226,7 +228,7 @@ def generate_key_derivation(pub, sec):
Key derivation: 8*(key2*key1) Key derivation: 8*(key2*key1)
""" """
sc_check(sec) # checks that the secret key is uniform enough... sc_check(sec) # checks that the secret key is uniform enough...
ge_frombytes_vartime_check(pub) check_ed25519point(pub)
return tcry.xmr_generate_key_derivation(pub, sec) return tcry.xmr_generate_key_derivation(pub, sec)
@ -242,9 +244,7 @@ def derive_public_key(derivation, output_index, B):
""" """
H_s(derivation || varint(output_index))G + B H_s(derivation || varint(output_index))G + B
""" """
ge_frombytes_vartime_check(B) # check some conditions on the point
check_ed25519point(B) check_ed25519point(B)
return tcry.xmr_derive_public_key(derivation, output_index, B) return tcry.xmr_derive_public_key(derivation, output_index, B)
@ -298,3 +298,9 @@ def check_signature(data, c, r, pub):
tmp_c = hash_to_scalar(buff) tmp_c = hash_to_scalar(buff)
res = sc_sub(tmp_c, c) res = sc_sub(tmp_c, c)
return not sc_isnonzero(res) return not sc_isnonzero(res)
def xor8(buff, key):
for i in range(8):
buff[i] ^= key[i]
return buff

View File

@ -61,7 +61,7 @@ def generate_ring_signature(prefix_hash, image, pubs, sec, sec_idx, test=False):
if not crypto.point_eq(k_i, image): if not crypto.point_eq(k_i, image):
raise ValueError("Key image invalid") raise ValueError("Key image invalid")
for k in pubs: for k in pubs:
crypto.ge_frombytes_vartime_check(k) crypto.check_ed25519point(k)
buff_off = len(prefix_hash) buff_off = len(prefix_hash)
buff = bytearray(buff_off + 2 * 32 * len(pubs)) buff = bytearray(buff_off + 2 * 32 * len(pubs))

View File

@ -49,68 +49,6 @@ from apps.monero.xmr import crypto
from apps.monero.xmr.serialize import int_serialize from apps.monero.xmr.serialize import int_serialize
def generate_mlsag_full(
message,
pubs,
in_sk,
out_sk_mask,
out_pk_commitments,
kLRki,
index,
txn_fee_key,
mg_buff,
):
cols = len(pubs)
if cols == 0:
raise ValueError("Empty pubs")
rows = 1 # Monero uses only one row
if len(out_sk_mask) != len(out_pk_commitments):
raise ValueError("Bad outsk/putpk size")
sk = _key_vector(rows + 1)
M = _key_matrix(rows + 1, cols)
tmp_mi_rows = crypto.new_point(None)
tmp_pt = crypto.new_point(None)
for i in range(cols):
crypto.identity_into(tmp_mi_rows) # M[i][rows]
# Should iterate over rows, simplified as rows == 1
M[i][0] = pubs[i].dest
crypto.point_add_into(
tmp_mi_rows,
tmp_mi_rows,
crypto.decodepoint_into(tmp_pt, pubs[i].commitment),
)
pubs[i] = None
for j in range(len(out_pk_commitments)):
crypto.point_sub_into(
tmp_mi_rows,
tmp_mi_rows,
crypto.decodepoint_into(tmp_pt, out_pk_commitments[j]),
) # subtract output Ci's in last row
# Subtract txn fee output in last row
crypto.point_sub_into(tmp_mi_rows, tmp_mi_rows, txn_fee_key)
M[i][rows] = crypto.encodepoint(tmp_mi_rows)
# Simplified as rows == 1
sk[0] = in_sk.dest
sk[rows] = in_sk.mask # originally: sum of all in_sk[0..rows] in sk[rows]
for j in range(len(out_pk_commitments)):
crypto.sc_sub_into(
sk[rows], sk[rows], out_sk_mask[j]
) # subtract output masks in last row
del (pubs, tmp_mi_rows, tmp_pt)
gc.collect()
return generate_mlsag(message, M, sk, kLRki, index, rows, mg_buff)
def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index, mg_buff): def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index, mg_buff):
""" """
MLSAG for RctType.Simple MLSAG for RctType.Simple

View File

@ -8,18 +8,16 @@ class PreMlsagHasher:
""" """
def __init__(self): def __init__(self):
self.is_simple = None
self.state = 0 self.state = 0
self.kc_master = crypto.get_keccak() self.kc_master = crypto.get_keccak()
self.rsig_hasher = crypto.get_keccak() self.rsig_hasher = crypto.get_keccak()
self.rtcsig_hasher = KeccakXmrArchive() self.rtcsig_hasher = KeccakXmrArchive()
def init(self, is_simple): def init(self):
if self.state != 0: if self.state != 0:
raise ValueError("State error") raise ValueError("State error")
self.state = 1 self.state = 1
self.is_simple = is_simple
def set_message(self, message): def set_message(self, message):
self.kc_master.update(message) self.kc_master.update(message)
@ -31,14 +29,6 @@ class PreMlsagHasher:
self.rtcsig_hasher.uint(rv_type, 1) # UInt8 self.rtcsig_hasher.uint(rv_type, 1) # UInt8
self.rtcsig_hasher.uvarint(fee) # UVarintType self.rtcsig_hasher.uvarint(fee) # UVarintType
def set_pseudo_out(self, out):
if self.state != 2 and self.state != 3:
raise ValueError("State error")
self.state = 3
# Manual serialization of the ECKey
self.rtcsig_hasher.buffer(out)
def set_ecdh(self, ecdh): def set_ecdh(self, ecdh):
if self.state != 2 and self.state != 3 and self.state != 4: if self.state != 2 and self.state != 3 and self.state != 4:
raise ValueError("State error") raise ValueError("State error")

View File

@ -259,3 +259,16 @@ def generate_sub_address_keys(view_sec, spend_pub, major, minor):
D = crypto.point_add(spend_pub, M) D = crypto.point_add(spend_pub, M)
C = crypto.scalarmult(D, view_sec) C = crypto.scalarmult(D, view_sec)
return D, C return D, C
def commitment_mask(key, buff=None):
"""
Generates deterministic commitment mask for Bulletproof2
"""
data = bytearray(15 + 32)
data[0:15] = b"commitment_mask"
data[15:] = key
if buff:
return crypto.hash_to_scalar_into(buff, data)
else:
return crypto.hash_to_scalar(data)

View File

@ -1,7 +1,7 @@
""" """
Computes range signature Computes range signature
Can compute Borromean range proof or Bulletproof. Can compute Bulletproof. Borromean support was discontinued.
Also can verify Bulletproof, in case the computation was offloaded. Also can verify Bulletproof, in case the computation was offloaded.
Mostly ported from official Monero client, but also inspired by Mininero. Mostly ported from official Monero client, but also inspired by Mininero.
@ -40,105 +40,3 @@ def verify_bp(bp_proof, amounts, masks):
res = bpi.verify(bp_proof) res = bpi.verify(bp_proof)
gc.collect() gc.collect()
return res return res
def prove_range_borromean(amount, last_mask):
"""Calculates Borromean range proof"""
# The large chunks allocated first to avoid potential memory fragmentation issues.
ai = bytearray(32 * 64)
alphai = bytearray(32 * 64)
Cis = bytearray(32 * 64)
s0s = bytearray(32 * 64)
s1s = bytearray(32 * 64)
buff = bytearray(32)
ee_bin = bytearray(32)
a = crypto.sc_init(0)
si = crypto.sc_init(0)
c = crypto.sc_init(0)
ee = crypto.sc_init(0)
tmp_ai = crypto.sc_init(0)
tmp_alpha = crypto.sc_init(0)
C_acc = crypto.identity()
C_h = crypto.xmr_H()
C_tmp = crypto.identity()
L = crypto.identity()
kck = crypto.get_keccak()
for ii in range(64):
crypto.random_scalar(tmp_ai)
if last_mask is not None and ii == 63:
crypto.sc_sub_into(tmp_ai, last_mask, a)
crypto.sc_add_into(a, a, tmp_ai)
crypto.random_scalar(tmp_alpha)
crypto.scalarmult_base_into(L, tmp_alpha)
crypto.scalarmult_base_into(C_tmp, tmp_ai)
# if 0: C_tmp += Zero (nothing is added)
# if 1: C_tmp += 2^i*H
# 2^i*H is already stored in C_h
if (amount >> ii) & 1 == 1:
crypto.point_add_into(C_tmp, C_tmp, C_h)
crypto.point_add_into(C_acc, C_acc, C_tmp)
# Set Ci[ii] to sigs
crypto.encodepoint_into(Cis, C_tmp, ii << 5)
crypto.encodeint_into(ai, tmp_ai, ii << 5)
crypto.encodeint_into(alphai, tmp_alpha, ii << 5)
if ((amount >> ii) & 1) == 0:
crypto.random_scalar(si)
crypto.encodepoint_into(buff, L)
crypto.hash_to_scalar_into(c, buff)
crypto.point_sub_into(C_tmp, C_tmp, C_h)
crypto.add_keys2_into(L, si, c, C_tmp)
crypto.encodeint_into(s1s, si, ii << 5)
crypto.encodepoint_into(buff, L)
kck.update(buff)
crypto.point_double_into(C_h, C_h)
# Compute ee
tmp_ee = kck.digest()
crypto.decodeint_into(ee, tmp_ee)
del (tmp_ee, kck)
C_h = crypto.xmr_H()
gc.collect()
# Second pass, s0, s1
for ii in range(64):
crypto.decodeint_into(tmp_alpha, alphai, ii << 5)
crypto.decodeint_into(tmp_ai, ai, ii << 5)
if ((amount >> ii) & 1) == 0:
crypto.sc_mulsub_into(si, tmp_ai, ee, tmp_alpha)
crypto.encodeint_into(s0s, si, ii << 5)
else:
crypto.random_scalar(si)
crypto.encodeint_into(s0s, si, ii << 5)
crypto.decodepoint_into(C_tmp, Cis, ii << 5)
crypto.add_keys2_into(L, si, ee, C_tmp)
crypto.encodepoint_into(buff, L)
crypto.hash_to_scalar_into(c, buff)
crypto.sc_mulsub_into(si, tmp_ai, c, tmp_alpha)
crypto.encodeint_into(s1s, si, ii << 5)
crypto.point_double_into(C_h, C_h)
crypto.encodeint_into(ee_bin, ee)
del (ai, alphai, buff, tmp_ai, tmp_alpha, si, c, ee, C_tmp, C_h, L)
gc.collect()
return C_acc, a, [s0s, s1s, ee_bin, Cis]

View File

@ -18,6 +18,9 @@ export EC_BACKEND_FORCE=1
export EC_BACKEND=1 export EC_BACKEND=1
export TREZOR_TEST_GET_TX=1 export TREZOR_TEST_GET_TX=1
export TREZOR_TEST_LIVE_REFRESH=1 export TREZOR_TEST_LIVE_REFRESH=1
export TREZOR_TEST_SIGN_CL0_HF9=1
export TREZOR_TEST_SIGN_CL1_HF9=1
export TREZOR_TEST_SIGN_CL1_HF10=1
python3 -m unittest trezor_monero_test.test_trezor python3 -m unittest trezor_monero_test.test_trezor
error=$? error=$?
kill $upy_pid kill $upy_pid