You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
3126 lines
101 KiB
3126 lines
101 KiB
import gc
|
|
import math
|
|
from micropython import const
|
|
from typing import TYPE_CHECKING
|
|
|
|
from trezor import utils
|
|
from trezor.crypto import random
|
|
from trezor.utils import memcpy as tmemcpy
|
|
|
|
from apps.monero.xmr import crypto, crypto_helpers
|
|
from apps.monero.xmr.serialize.int_serialize import dump_uvarint_b_into, uvarint_size
|
|
|
|
if TYPE_CHECKING:
    from typing import Iterator, TypeVar, Generic

    from .serialize_messages.tx_rsig_bulletproof import Bulletproof, BulletproofPlus

    T = TypeVar("T")
    ScalarDst = TypeVar("ScalarDst", bytearray, crypto.Scalar)

else:
    # Runtime (MicroPython) stand-ins -- there is no `typing` module on device.
    # `Generic` is a 1-tuple and `T == 0`, so `Generic[T]` evaluates to
    # `(object,)[0] == object`, which makes `class X(Generic[T])` valid at runtime.
    Generic = (object,)
    T = 0  # type: ignore
|
|
|
|
# Constants
TBYTES = (bytes, bytearray, memoryview)  # byte-like types accepted as raw key buffers
_BP_LOG_N = const(6)
_BP_N = const(64)  # 1 << _BP_LOG_N
_BP_M = const(16)  # maximal number of bulletproofs

# Scalar constants as 32-byte little-endian encodings
_ZERO = b"\x00" * 32
_ONE = b"\x01" + b"\x00" * 31
_TWO = b"\x02" + b"\x00" * 31
_EIGHT = b"\x08" + b"\x00" * 31
_INV_EIGHT = crypto_helpers.INV_EIGHT
_MINUS_ONE = b"\xec\xd3\xf5\x5c\x1a\x63\x12\x58\xd6\x9c\xf7\xa2\xde\xf9\xde\x14\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10"

# Monero H point
_XMR_H = b"\x8b\x65\x59\x70\x15\x37\x99\xaf\x2a\xea\xdc\x9f\xf1\xad\xd0\xea\x6c\x72\x51\xd5\x41\x54\xcf\xa9\x2c\x17\x3a\x0d\xd3\x9c\x1f\x94"
_XMR_HP = crypto.xmr_H()  # decoded crypto.Point form of H
_XMR_G = b"\x58\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66\x66"

# ip12 = inner_product(oneN, twoN);
_BP_IP12 = b"\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"

# NOTE(review): presumably the precomputed initial BP+ transcript hash -- confirm against the reference implementation
_INITIAL_TRANSCRIPT = b"\x4a\x67\x7c\x90\xeb\x73\x05\x1e\x79\x0d\xa4\x55\x91\x10\x7f\x6e\xe1\x05\x90\x4d\x91\x87\xc5\xd3\x54\x71\x09\x6c\x44\x5a\x22\x75"
_TWO_SIXTY_FOUR_MINUS_ONE = b"\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"

#
# Rct keys operations
# tmp_x are global working registers to minimize memory allocations / heap fragmentation.
# Caution has to be exercised when using the registers and operations using the registers
#

_tmp_bf_0 = bytearray(32)
_tmp_bf_1 = bytearray(32)
_tmp_bf_exp = bytearray(16 + 32 + 4)  # scratch for base || salt || varint(idx)

_tmp_pt_1 = crypto.Point()
_tmp_pt_2 = crypto.Point()
_tmp_pt_3 = crypto.Point()
_tmp_pt_4 = crypto.Point()

_tmp_sc_1 = crypto.Scalar()
_tmp_sc_2 = crypto.Scalar()
_tmp_sc_3 = crypto.Scalar()
_tmp_sc_4 = crypto.Scalar()
_tmp_sc_5 = crypto.Scalar()
|
|
|
|
|
|
def _ensure_dst_key(dst: bytearray | None = None) -> bytearray:
|
|
if dst is None:
|
|
dst = bytearray(32)
|
|
return dst
|
|
|
|
|
|
def memcpy(
    dst: bytearray, dst_off: int, src: bytes, src_off: int, len: int
) -> bytearray:
    """None-tolerant wrapper around trezor.utils.memcpy.

    Copies *len* bytes from src[src_off:] into dst[dst_off:]; silently no-ops
    when *dst* is None (callers may pass optional buffers). Returns *dst*.
    NOTE: the `len` parameter shadows the builtin; kept for interface compatibility.
    """
    if dst is not None:
        tmemcpy(dst, dst_off, src, src_off, len)
    return dst
|
|
|
|
|
|
def _copy_key(dst: bytearray | None, src: bytes) -> bytearray:
    """Copy the 32-byte key *src* into *dst* (allocated when None) and return it."""
    dst = _ensure_dst_key(dst)
    pos = 0
    while pos < 32:
        dst[pos] = src[pos]
        pos += 1
    return dst
|
|
|
|
|
|
def _init_key(val: bytes, dst: bytearray | None = None) -> bytearray:
    """Allocate a 32-byte key buffer if needed and fill it from *val*."""
    return _copy_key(_ensure_dst_key(dst), val)
|
|
|
|
|
|
def _load_scalar(dst: crypto.Scalar | None, a: ScalarDst) -> crypto.Scalar:
    """Load *a* -- either a crypto.Scalar or a 32B encoding -- into *dst*."""
    if isinstance(a, crypto.Scalar):
        return crypto.sc_copy(dst, a)
    return crypto.decodeint_into_noreduce(dst, a)
|
|
|
|
|
|
def _gc_iter(i: int) -> None:
|
|
if i & 127 == 0:
|
|
gc.collect()
|
|
|
|
|
|
def _invert(dst: bytearray | None, x: bytes) -> bytearray:
    """Write the modular inverse of scalar *x* (32B encoding) into *dst*."""
    dst = _ensure_dst_key(dst)
    crypto.decodeint_into_noreduce(_tmp_sc_1, x)
    crypto.sc_inv_into(_tmp_sc_2, _tmp_sc_1)
    crypto.encodeint_into(dst, _tmp_sc_2)
    return dst
|
|
|
|
|
|
def _scalarmult_key(
    dst: bytearray,
    P,
    s: bytes | None,
    s_raw: int | None = None,
    tmp_pt: crypto.Point = _tmp_pt_1,
):
    """dst = s * P for encoded point *P*.

    Exactly one of *s* (32B scalar encoding) or *s_raw* (decoded scalar)
    must be provided. *tmp_pt* is the working register used for the point.
    """
    # TODO: two functions based on s/s_raw ?
    dst = _ensure_dst_key(dst)
    crypto.decodepoint_into(tmp_pt, P)
    if s:
        crypto.decodeint_into_noreduce(_tmp_sc_1, s)
        crypto.scalarmult_into(tmp_pt, tmp_pt, _tmp_sc_1)
    else:
        assert s_raw is not None
        crypto.scalarmult_into(tmp_pt, tmp_pt, s_raw)
    crypto.encodepoint_into(dst, tmp_pt)
    return dst
|
|
|
|
|
|
def _scalarmult8(dst: bytearray | None, P, tmp_pt: crypto.Point = _tmp_pt_1):
    """dst = 8 * P (cofactor multiplication) for encoded point *P*."""
    dst = _ensure_dst_key(dst)
    crypto.decodepoint_into(tmp_pt, P)
    crypto.ge25519_mul8(tmp_pt, tmp_pt)
    crypto.encodepoint_into(dst, tmp_pt)
    return dst
|
|
|
|
|
|
def _scalarmultH(dst: bytearray, x: bytes) -> bytearray:
    """dst = x * H, where H is the Monero alternate generator (_XMR_HP)."""
    dst = _ensure_dst_key(dst)
    crypto.decodeint_into(_tmp_sc_1, x)
    crypto.scalarmult_into(_tmp_pt_1, _XMR_HP, _tmp_sc_1)
    crypto.encodepoint_into(dst, _tmp_pt_1)
    return dst
|
|
|
|
|
|
def _scalarmult_base(dst: bytearray, x: bytes) -> bytearray:
    """dst = x * G, scalar multiplication with the ed25519 base point."""
    dst = _ensure_dst_key(dst)
    crypto.decodeint_into_noreduce(_tmp_sc_1, x)
    crypto.scalarmult_base_into(_tmp_pt_1, _tmp_sc_1)
    crypto.encodepoint_into(dst, _tmp_pt_1)
    return dst
|
|
|
|
|
|
def _sc_gen(dst: bytearray | None = None) -> bytearray:
    """Generate a random scalar and write its 32B encoding into *dst*."""
    dst = _ensure_dst_key(dst)
    crypto.random_scalar(_tmp_sc_1)
    crypto.encodeint_into(dst, _tmp_sc_1)
    return dst
|
|
|
|
|
|
def _sc_add(dst: bytearray | None, a: bytes, b: bytes) -> bytearray:
    """dst = a + b (mod l) on 32B scalar encodings."""
    dst = _ensure_dst_key(dst)
    crypto.decodeint_into_noreduce(_tmp_sc_1, a)
    crypto.decodeint_into_noreduce(_tmp_sc_2, b)
    crypto.sc_add_into(_tmp_sc_3, _tmp_sc_1, _tmp_sc_2)
    crypto.encodeint_into(dst, _tmp_sc_3)
    return dst
|
|
|
|
|
|
def _sc_sub(
    dst: bytearray | None,
    a: bytes | crypto.Scalar,
    b: bytes | crypto.Scalar,
) -> bytearray:
    """dst = a - b (mod l); operands may be Scalars or 32B encodings."""
    dst = _ensure_dst_key(dst)
    if not isinstance(a, crypto.Scalar):
        crypto.decodeint_into_noreduce(_tmp_sc_1, a)
        a = _tmp_sc_1
    if not isinstance(b, crypto.Scalar):
        crypto.decodeint_into_noreduce(_tmp_sc_2, b)
        b = _tmp_sc_2
    crypto.sc_sub_into(_tmp_sc_3, a, b)
    crypto.encodeint_into(dst, _tmp_sc_3)
    return dst
|
|
|
|
|
|
def _sc_mul(dst: bytearray | None, a: bytes, b: bytes | crypto.Scalar) -> bytearray:
    """dst = a * b (mod l); *b* may be a Scalar or a 32B encoding."""
    dst = _ensure_dst_key(dst)
    crypto.decodeint_into_noreduce(_tmp_sc_1, a)
    if not isinstance(b, crypto.Scalar):
        crypto.decodeint_into_noreduce(_tmp_sc_2, b)
        b = _tmp_sc_2
    crypto.sc_mul_into(_tmp_sc_3, _tmp_sc_1, b)
    crypto.encodeint_into(dst, _tmp_sc_3)
    return dst
|
|
|
|
|
|
def _sc_mul8(dst: bytearray | None, a: bytes) -> bytearray:
    """dst = 8 * a (mod l) on a 32B scalar encoding."""
    dst = _ensure_dst_key(dst)
    crypto.decodeint_into_noreduce(_tmp_sc_1, a)
    crypto.decodeint_into_noreduce(_tmp_sc_2, _EIGHT)
    crypto.sc_mul_into(_tmp_sc_3, _tmp_sc_1, _tmp_sc_2)
    crypto.encodeint_into(dst, _tmp_sc_3)
    return dst
|
|
|
|
|
|
def _sc_muladd(
    dst: ScalarDst | None,
    a: bytes | crypto.Scalar,
    b: bytes | crypto.Scalar,
    c: bytes | crypto.Scalar,
) -> ScalarDst:
    """dst = a * b + c (mod l).

    Operands may be Scalars or 32B encodings; the result type follows *dst*
    (crypto.Scalar stays raw, anything else is encoded into a 32B buffer).
    """
    if isinstance(dst, crypto.Scalar):
        dst_sc = dst
    else:
        # compute in a working register, encode into the byte buffer at the end
        dst_sc = _tmp_sc_4
    if not isinstance(a, crypto.Scalar):
        crypto.decodeint_into_noreduce(_tmp_sc_1, a)
        a = _tmp_sc_1
    if not isinstance(b, crypto.Scalar):
        crypto.decodeint_into_noreduce(_tmp_sc_2, b)
        b = _tmp_sc_2
    if not isinstance(c, crypto.Scalar):
        crypto.decodeint_into_noreduce(_tmp_sc_3, c)
        c = _tmp_sc_3
    crypto.sc_muladd_into(dst_sc, a, b, c)
    if not isinstance(dst, crypto.Scalar):
        dst = _ensure_dst_key(dst)
        crypto.encodeint_into(dst, dst_sc)
    return dst
|
|
|
|
|
|
def _sc_mulsub(dst: bytearray | None, a: bytes, b: bytes, c: bytes) -> bytearray:
    """dst = c - a * b (mod l) on 32B scalar encodings."""
    dst = _ensure_dst_key(dst)
    crypto.decodeint_into_noreduce(_tmp_sc_1, a)
    crypto.decodeint_into_noreduce(_tmp_sc_2, b)
    crypto.decodeint_into_noreduce(_tmp_sc_3, c)
    crypto.sc_mulsub_into(_tmp_sc_4, _tmp_sc_1, _tmp_sc_2, _tmp_sc_3)
    crypto.encodeint_into(dst, _tmp_sc_4)
    return dst
|
|
|
|
|
|
def _add_keys(dst: bytearray | None, A: bytes, B: bytes) -> bytearray:
    """dst = A + B, point addition on encoded points."""
    dst = _ensure_dst_key(dst)
    crypto.decodepoint_into(_tmp_pt_1, A)
    crypto.decodepoint_into(_tmp_pt_2, B)
    crypto.point_add_into(_tmp_pt_3, _tmp_pt_1, _tmp_pt_2)
    crypto.encodepoint_into(dst, _tmp_pt_3)
    return dst
|
|
|
|
|
|
def _sub_keys(dst: bytearray | None, A: bytes, B: bytes) -> bytearray:
    """dst = A - B, point subtraction on encoded points."""
    dst = _ensure_dst_key(dst)
    crypto.decodepoint_into(_tmp_pt_1, A)
    crypto.decodepoint_into(_tmp_pt_2, B)
    crypto.point_sub_into(_tmp_pt_3, _tmp_pt_1, _tmp_pt_2)
    crypto.encodepoint_into(dst, _tmp_pt_3)
    return dst
|
|
|
|
|
|
def _add_keys2(dst: bytearray | None, a: bytes, b: bytes, B: bytes) -> bytearray:
    """dst = a*G + b*B for scalars a, b and encoded point B."""
    dst = _ensure_dst_key(dst)
    crypto.decodeint_into_noreduce(_tmp_sc_1, a)
    crypto.decodeint_into_noreduce(_tmp_sc_2, b)
    crypto.decodepoint_into(_tmp_pt_1, B)
    crypto.add_keys2_into(_tmp_pt_2, _tmp_sc_1, _tmp_sc_2, _tmp_pt_1)
    crypto.encodepoint_into(dst, _tmp_pt_2)
    return dst
|
|
|
|
|
|
def _hash_to_scalar(dst, data):
    """dst = Hs(data): hash *data* to a scalar, encoded into 32B *dst*."""
    dst = _ensure_dst_key(dst)
    crypto.hash_to_scalar_into(_tmp_sc_1, data)
    crypto.encodeint_into(dst, _tmp_sc_1)
    return dst
|
|
|
|
|
|
def _hash_vct_to_scalar(dst, data):
    """dst = Hs(data[0] || data[1] || ...): Keccak over all chunks, reduced to a scalar."""
    dst = _ensure_dst_key(dst)
    ctx = crypto_helpers.get_keccak()
    for x in data:
        ctx.update(x)
    hsh = ctx.digest()

    # decodeint_into (reducing variant) maps the 32B digest into the scalar field
    crypto.decodeint_into(_tmp_sc_1, hsh)
    crypto.encodeint_into(dst, _tmp_sc_1)
    return dst
|
|
|
|
|
|
def _get_exponent(dst, base, idx):
    """Derive the idx-th BP generator from *base* (salt "bulletproof")."""
    return _get_exponent_univ(dst, base, idx, b"bulletproof")
|
|
|
|
|
|
def _get_exponent_plus(dst, base, idx):
    """Derive the idx-th BP+ generator from *base* (salt "bulletproof_plus")."""
    return _get_exponent_univ(dst, base, idx, b"bulletproof_plus")
|
|
|
|
|
|
def _get_exponent_univ(dst, base, idx, salt):
    """Deterministic generator derivation: dst = Hp(keccak(base || salt || varint(idx)))."""
    dst = _ensure_dst_key(dst)
    lsalt = len(salt)
    final_size = lsalt + 32 + uvarint_size(idx)
    # assemble base || salt || varint(idx) in the shared scratch buffer
    memcpy(_tmp_bf_exp, 0, base, 0, 32)
    memcpy(_tmp_bf_exp, 32, salt, 0, lsalt)
    dump_uvarint_b_into(idx, _tmp_bf_exp, 32 + lsalt)
    crypto.fast_hash_into(_tmp_bf_1, _tmp_bf_exp, final_size)
    crypto.hash_to_point_into(_tmp_pt_4, _tmp_bf_1)
    crypto.encodepoint_into(dst, _tmp_pt_4)
    return dst
|
|
|
|
|
|
def _sc_square_mult(
    dst: crypto.Scalar | None, x: crypto.Scalar, n: int
) -> crypto.Scalar:
    """dst = x**n using square-and-multiply (O(log n) scalar multiplications)."""
    if n == 0:
        return crypto.decodeint_into_noreduce(dst, _ONE)

    # NOTE(review): int(math.log(n, 2)) relies on float log returning an exact
    # floor; fine for the small BP sizes used here, but confirm for large n.
    lg = int(math.log(n, 2))
    dst = crypto.sc_copy(dst, x)
    for i in range(1, lg + 1):
        crypto.sc_mul_into(dst, dst, dst)  # square
        if n & (1 << (lg - i)) > 0:
            crypto.sc_mul_into(dst, dst, x)  # multiply when exponent bit is set
    return dst
|
|
|
|
|
|
def _invert_batch(x):
    """Invert every scalar in vector *x* in place (Montgomery batch inversion).

    Uses a single modular inversion plus O(len(x)) multiplications.
    Raises via utils.ensure when any element is zero.
    """
    # scratch[n] holds the running prefix product x[0] * ... * x[n-1]
    # (_ensure_dst_keyvect is defined elsewhere in this file; presumably
    # returns a KeyV whose __setitem__ copies bytes -- confirm)
    scratch = _ensure_dst_keyvect(None, len(x))
    acc = bytearray(_ONE)
    for n in range(len(x)):
        utils.ensure(x[n] != _ZERO, "cannot invert zero")
        scratch[n] = acc
        if n == 0:
            memcpy(acc, 0, x[0], 0, 32)  # acc = x[0]
        else:
            _sc_mul(acc, acc, x[n])

    # acc = 1 / (x[0] * ... * x[-1])
    _invert(acc, acc)
    tmp = _ensure_dst_key(None)

    # Walk backwards, peeling one factor off acc per step
    for i in range(len(x) - 1, -1, -1):
        _sc_mul(tmp, acc, x[i])  # tmp = inverse of the product up to x[i-1]
        x[i] = _sc_mul(x[i], acc, scratch[i])  # x[i] = acc * prefix = 1/x[i]
        memcpy(acc, 0, tmp, 0, 32)
    return x
|
|
|
|
|
|
def _sum_of_even_powers(res, x, n):
    """
    Given a scalar, construct the sum of its powers from 2 to n (where n is a power of 2):

    Output x**2 + x**4 + x**6 + ... + x**n

    *res* may be None (a fresh 32B buffer is allocated); returns *res*.
    Raises via utils.ensure when n is 0 or not a power of two.
    """
    utils.ensure(n & (n - 1) == 0, "n is not 2^x")
    utils.ensure(n != 0, "n == 0")

    x1 = bytearray(x)
    _sc_mul(x1, x1, x1)  # x1 = x**2

    res = _ensure_dst_key(res)
    memcpy(res, 0, x1, 0, len(x1))
    while n > 2:
        # res = res + x1 * res = (1 + x1) * res; then square x1
        _sc_muladd(res, x1, res, res)
        _sc_mul(x1, x1, x1)
        n >>= 1  # fix: was `n /= 2`, which turned the counter into a float
    return res
|
|
|
|
|
|
def _sum_of_scalar_powers(res, x, n):
    """
    Given a scalar, return the sum of its powers from 1 to n

    Output x**1 + x**2 + x**3 + ... + x**n

    *res* may be None (a fresh 32B buffer is allocated); returns *res*.
    Raises via utils.ensure when n == 0.
    """
    utils.ensure(n != 0, "n == 0")
    res = _ensure_dst_key(res)
    memcpy(res, 0, _ONE, 0, len(_ONE))

    if n == 1:
        # The sum is just x itself
        memcpy(res, 0, x, 0, len(x))
        return res

    # Compute 1 + x + ... + x**n (n+1 terms), subtract the leading 1 at the end
    n += 1
    x1 = bytearray(x)
    is_power_of_2 = (n & (n - 1)) == 0
    if is_power_of_2:
        # Doubling aggregation of the geometric series
        _sc_add(res, res, x1)
        while n > 2:
            _sc_mul(x1, x1, x1)
            _sc_muladd(res, x1, res, res)
            n >>= 1  # fix: was `n /= 2`, which turned the counter into a float
    else:
        # Naive accumulation of successive powers
        prev = bytearray(x1)
        for i in range(1, n):
            if i > 1:
                _sc_mul(prev, prev, x1)
            _sc_add(res, res, prev)

    _sc_sub(res, res, _ONE)  # drop the x**0 term
    return res
|
|
|
|
|
|
#
|
|
# Key Vectors
|
|
#
|
|
|
|
|
|
class KeyVBase(Generic[T]):
    """
    Base KeyVector object

    Abstract fixed-size vector of 32-byte elements with iteration,
    bounds-checked indexing and copy-in/copy-out helpers.
    """

    __slots__ = ("current_idx", "size")

    def __init__(self, elems: int = 64) -> None:
        # current_idx: iteration cursor; size: number of elements
        self.current_idx = 0
        self.size = elems

    def idxize(self, idx: int) -> int:
        """Normalize *idx* (negative indices wrap from the end) and bounds-check."""
        if idx < 0:
            idx = self.size + idx
        if idx >= self.size:
            raise IndexError(f"Index out of bounds {idx} vs {self.size}")
        return idx

    def __getitem__(self, item: int) -> T:
        raise NotImplementedError

    def __setitem__(self, key: int, value: T) -> None:
        raise NotImplementedError

    def __iter__(self) -> Iterator[T]:
        # The vector is its own iterator; iteration resets the shared cursor,
        # so only one active iteration per vector is supported.
        self.current_idx = 0
        return self

    def __next__(self) -> T:
        if self.current_idx >= self.size:
            raise StopIteration
        else:
            self.current_idx += 1
            return self[self.current_idx - 1]

    def __len__(self) -> int:
        return self.size

    def to(self, idx: int, buff: bytearray | None = None, offset: int = 0) -> bytearray:
        """Copy element *idx* into *buff* (allocated when None) at *offset*."""
        buff = _ensure_dst_key(buff)
        return memcpy(buff, offset, self.__getitem__(self.idxize(idx)), 0, 32)

    def read(self, idx: int, buff: bytes, offset: int = 0) -> bytes:
        """Write 32 bytes from buff[offset:] into element *idx* (subclass hook)."""
        raise NotImplementedError

    def slice(self, res, start: int, stop: int):
        """Copy elements [start, stop) into *res*, starting at res[0]."""
        for i in range(start, stop):
            res[i - start] = self[i]
        return res

    def slice_view(self, start: int, stop: int) -> "KeyVSliced":
        """Zero-copy remapped view of elements [start, stop)."""
        return KeyVSliced(self, start, stop)
|
|
|
|
|
|
# Chunking parameters for KeyV: vectors longer than _CHSIZE elements are stored
# as a list of _CHSIZE-element bytearray chunks (32 * _CHSIZE bytes each).
_CHBITS = const(5)
_CHSIZE = const(1 << _CHBITS)


if TYPE_CHECKING:
    KeyVBaseType = KeyVBase
else:
    # Runtime trick: T == 0 (see the module header), so `KeyVBaseType[T]`
    # indexes this tuple and yields KeyVBase -- subscripting stays valid
    # without the typing module.
    KeyVBaseType = (KeyVBase,)
|
|
|
|
|
|
class KeyV(KeyVBaseType[T]):
    """
    KeyVector abstraction
    Constant precomputed buffers = bytes, frozen. Same operation as normal.

    Non-constant KeyVector is separated to _CHSIZE elements chunks to avoid problems with
    the heap fragmentation. In this it is more probable that the chunks are correctly
    allocated as smaller continuous memory is required. Chunk is assumed to
    have _CHSIZE elements at all times to minimize corner cases handling. BP require either
    multiple of _CHSIZE elements vectors or less than _CHSIZE.

    Some chunk-dependent cases are not implemented as they are currently not needed in the BP.
    """

    __slots__ = ("current_idx", "size", "d", "mv", "const", "cur", "chunked")

    def __init__(
        self,
        elems: int = 64,
        buffer: bytes | None = None,
        const: bool = False,
        no_init: bool = False,
    ) -> None:
        super().__init__(elems)
        # d: backing storage -- bytes/bytearray when contiguous, list of
        # bytearray chunks when chunked
        self.d: bytes | bytearray | list[bytearray] | None = None
        # mv: memoryview over d in the contiguous case (None when chunked)
        self.mv: memoryview | None = None
        self.const = const
        # cur: shared scratch element buffer used by chunked to()/getitem
        self.cur = _ensure_dst_key()
        self.chunked = False
        if no_init:
            pass
        elif buffer:
            self.d = buffer  # can be immutable (bytes)
            self.size = len(buffer) // 32
        else:
            self._set_d(elems)

        if not no_init:
            self._set_mv()

    def _set_d(self, elems: int) -> None:
        # Allocate backing storage; large multiples of _CHSIZE are split into
        # per-chunk bytearrays to reduce heap fragmentation.
        if elems > _CHSIZE and elems % _CHSIZE == 0:
            self.chunked = True
            gc.collect()
            self.d = [bytearray(32 * _CHSIZE) for _ in range(elems // _CHSIZE)]

        else:
            self.chunked = False
            gc.collect()
            self.d = bytearray(32 * elems)

    def _set_mv(self) -> None:
        # memoryview only valid for contiguous storage
        if not self.chunked:
            assert isinstance(self.d, TBYTES)
            self.mv = memoryview(self.d)

    def __getitem__(self, item):
        """
        Returns corresponding 32 byte array.
        Creates new memoryview on access.
        """
        if self.chunked:
            return self.to(item)  # chunked: returns a copy held in self.cur
        item = self.idxize(item)
        assert self.mv is not None
        return self.mv[item * 32 : (item + 1) * 32]

    def __setitem__(self, key, value):
        if self.chunked:
            # NOTE(review): there is no early return here -- after read() the
            # code below also copies into the chunked scratch (self[key] is
            # self.cur), which looks redundant though harmless. Confirm whether
            # a `return` after read() was intended.
            self.read(key, value)
        if self.const:
            raise ValueError("Constant KeyV")
        ck = self[key]
        for i in range(32):
            ck[i] = value[i]

    def to(self, idx, buff: bytearray | None = None, offset: int = 0):
        # Copy element idx into buff (or the shared scratch self.cur)
        idx = self.idxize(idx)
        if self.chunked:
            assert isinstance(self.d, list)
            memcpy(
                buff if buff else self.cur,
                offset,
                self.d[idx >> _CHBITS],  # chunk holding the element
                (idx & (_CHSIZE - 1)) << 5,  # byte offset within the chunk
                32,
            )
        else:
            assert isinstance(self.d, (bytes, bytearray))
            memcpy(buff if buff else self.cur, offset, self.d, idx << 5, 32)
        return buff if buff else self.cur

    def read(self, idx: int, buff: bytes, offset: int = 0) -> bytes:
        # Write 32 bytes from buff[offset:] into element idx
        idx = self.idxize(idx)
        if self.chunked:
            assert isinstance(self.d, list)
            memcpy(self.d[idx >> _CHBITS], (idx & (_CHSIZE - 1)) << 5, buff, offset, 32)
        else:
            assert isinstance(self.d, bytearray)
            memcpy(self.d, idx << 5, buff, offset, 32)

    def resize(self, nsize, chop: bool = False, realloc: bool = False):
        # Resize to nsize elements. chop: drop data without copying the prefix;
        # realloc: force a fresh allocation even when a slice would do.
        if self.size == nsize:
            return

        if self.chunked and nsize <= _CHSIZE:
            assert isinstance(self.d, list)
            self.chunked = False  # de-chunk
            if self.size > nsize and realloc:
                gc.collect()
                self.d = bytearray(self.d[0][: nsize << 5])
            elif self.size > nsize and not chop:
                gc.collect()
                self.d = self.d[0][: nsize << 5]
            else:
                gc.collect()
                self.d = bytearray(nsize << 5)

        elif self.chunked and self.size < nsize:
            assert isinstance(self.d, list)
            if nsize % _CHSIZE != 0 or realloc or chop:
                raise ValueError("Unsupported")  # not needed
            for i in range((nsize - self.size) // _CHSIZE):
                self.d.append(bytearray(32 * _CHSIZE))

        elif self.chunked:
            # chunked shrink, still above _CHSIZE
            assert isinstance(self.d, list)
            if nsize % _CHSIZE != 0:
                raise ValueError("Unsupported")  # not needed
            for i in range((self.size - nsize) // _CHSIZE):
                self.d.pop()
            if realloc:
                for i in range(nsize // _CHSIZE):
                    self.d[i] = bytearray(self.d[i])

        else:
            # contiguous resize
            assert isinstance(self.d, (bytes, bytearray))
            if self.size > nsize and realloc:
                gc.collect()
                self.d = bytearray(self.d[: nsize << 5])
            elif self.size > nsize and not chop:
                gc.collect()
                self.d = self.d[: nsize << 5]
            else:
                gc.collect()
                self.d = bytearray(nsize << 5)

        self.size = nsize
        self._set_mv()

    def realloc(self, nsize, collect: bool = False):
        """Discard the current storage and allocate fresh storage for nsize elements."""
        self.d = None
        self.mv = None
        if collect:
            gc.collect()  # gc collect prev. allocation

        self._set_d(nsize)
        self.size = nsize
        self._set_mv()

    def realloc_init_from(self, nsize, src, offset: int = 0, collect: bool = False):
        """Reallocate to nsize elements and copy them from src[offset:]."""
        if not isinstance(src, KeyV):
            raise ValueError("KeyV supported only")
        self.realloc(nsize, collect)

        if not self.chunked and not src.chunked:
            # contiguous -> contiguous: single bulk copy
            assert isinstance(self.d, bytearray)
            assert isinstance(src.d, (bytes, bytearray))
            memcpy(self.d, 0, src.d, offset << 5, nsize << 5)

        elif self.chunked and not src.chunked or self.chunked and src.chunked:
            # destination chunked: element-by-element copy
            for i in range(nsize):
                self.read(i, src.to(i + offset))

        elif not self.chunked and src.chunked:
            # chunked -> contiguous: copy chunk by chunk
            assert isinstance(self.d, bytearray)
            assert isinstance(src.d, list)
            for i in range(nsize >> _CHBITS):
                memcpy(
                    self.d,
                    i << 11,  # i * 32 * _CHSIZE destination offset
                    src.d[i + (offset >> _CHBITS)],
                    (offset & (_CHSIZE - 1)) << 5 if i == 0 else 0,
                    nsize << 5 if i <= nsize >> _CHBITS else (nsize & _CHSIZE) << 5,
                )
|
|
|
|
|
|
class KeyVEval(KeyVBase):
    """
    KeyVector computed / evaluated on demand
    """

    __slots__ = ("current_idx", "size", "fnc", "raw", "scalar", "buff")

    def __init__(self, elems=64, src=None, raw: bool = False, scalar=True):
        # src: callable (idx, buff) -> element value
        # raw: yield crypto objects instead of 32B encodings
        # scalar: raw elements are crypto.Scalar, otherwise crypto.Point
        super().__init__(elems)
        self.fnc = src
        self.raw = raw
        self.scalar = scalar
        # shared output buffer -- each access overwrites the previous value
        self.buff = (
            _ensure_dst_key()
            if not raw
            else (crypto.Scalar() if scalar else crypto.Point())
        )

    def __getitem__(self, item):
        return self.fnc(self.idxize(item), self.buff)

    def to(self, idx, buff: bytearray | None = None, offset: int = 0):
        self.fnc(self.idxize(idx), self.buff)
        if self.raw:
            if offset != 0:
                raise ValueError("Not supported")
            if self.scalar and buff:
                return crypto.sc_copy(buff, self.buff)
            elif self.scalar:
                return self.buff
            else:
                # raw point copy-out is not implemented
                raise ValueError("Not supported")
        else:
            memcpy(buff, offset, self.buff, 0, 32)
            return buff if buff else self.buff
|
|
|
|
|
|
class KeyVSized(KeyVBase):
    """
    Resized vector, wrapping possibly larger vector
    (e.g., precomputed, but has to have exact size for further computations)
    """

    __slots__ = ("current_idx", "size", "wrapped")

    def __init__(self, wrapped, new_size):
        super().__init__(new_size)
        self.wrapped = wrapped

    def __getitem__(self, item):
        idx = self.idxize(item)
        return self.wrapped[idx]

    def __setitem__(self, key, value):
        idx = self.idxize(key)
        self.wrapped[idx] = value
|
|
|
|
|
|
class KeyVConst(KeyVBase):
    """Constant vector: every index yields the same 32-byte element."""

    __slots__ = ("current_idx", "size", "elem")

    def __init__(self, size, elem, copy=True):
        # copy: snapshot elem into a private buffer, otherwise alias it
        super().__init__(size)
        self.elem = _init_key(elem) if copy else elem

    def __getitem__(self, item):
        return self.elem

    def to(self, idx: int, buff: bytearray | None, offset: int = 0):
        # buff may be None; memcpy() is None-tolerant, self.elem is returned then
        memcpy(buff, offset, self.elem, 0, 32)
        return buff if buff else self.elem
|
|
|
|
|
|
class KeyVPrecomp(KeyVBase):
    """
    Vector with possibly large size and some precomputed prefix.
    Usable for Gi vector with precomputed usual sizes (i.e., 2 output transactions)
    but possible to compute further
    """

    __slots__ = ("current_idx", "size", "precomp_prefix", "aux_comp_fnc", "buff")

    def __init__(self, size, precomp_prefix, aux_comp_fnc) -> None:
        # aux_comp_fnc: callable (idx, buff) -> element, used past the prefix
        super().__init__(size)
        self.precomp_prefix = precomp_prefix
        self.aux_comp_fnc = aux_comp_fnc
        # shared output buffer for computed (non-precomputed) elements
        self.buff = _ensure_dst_key()

    def __getitem__(self, item):
        item = self.idxize(item)
        if item < len(self.precomp_prefix):
            return self.precomp_prefix[item]
        return self.aux_comp_fnc(item, self.buff)

    def to(self, idx: int, buff: bytearray | None = None, offset: int = 0) -> bytearray:
        item = self.idxize(idx)
        if item < len(self.precomp_prefix):
            return self.precomp_prefix.to(item, buff if buff else self.buff, offset)
        self.aux_comp_fnc(item, self.buff)
        memcpy(buff, offset, self.buff, 0, 32)
        return buff if buff else self.buff
|
|
|
|
|
|
class KeyVSliced(KeyVBase):
    """
    Sliced in-memory vector version, remapping

    Zero-copy window [start, stop) over a source vector; reads and writes
    pass through to the wrapped vector.
    """

    __slots__ = ("current_idx", "size", "wrapped", "offset")

    def __init__(self, src, start, stop):
        super().__init__(stop - start)
        self.wrapped = src
        self.offset = start

    def __getitem__(self, item):
        return self.wrapped[self.offset + self.idxize(item)]

    def __setitem__(self, key, value) -> None:
        self.wrapped[self.offset + self.idxize(key)] = value

    def resize(self, nsize: int, chop: bool = False) -> None:
        raise ValueError("Not supported")

    def to(self, idx, buff: bytearray | None = None, offset: int = 0):
        return self.wrapped.to(self.offset + self.idxize(idx), buff, offset)

    def read(self, idx, buff, offset: int = 0):
        return self.wrapped.read(self.offset + self.idxize(idx), buff, offset)
|
|
|
|
|
|
class KeyVPowers(KeyVBase):
    """
    Vector of x^i. Allows only sequential access (no jumping). Resets on [0,1] access.
    """

    __slots__ = ("current_idx", "size", "x", "raw", "cur", "last_idx")

    def __init__(self, size, x, raw: bool = False):
        # raw: hold values as crypto.Scalar instead of 32B encodings
        super().__init__(size)
        self.x = x if not raw else crypto.decodeint_into_noreduce(None, x)
        self.raw = raw
        self.cur = bytearray(32) if not raw else crypto.Scalar()
        self.last_idx = 0  # exponent currently held in self.cur

    def __getitem__(self, item):
        prev = self.last_idx
        item = self.idxize(item)
        self.last_idx = item

        if item == 0:
            # NOTE(review): in raw mode this allocates a fresh scalar and leaves
            # self.cur untouched, unlike all other branches -- confirm callers
            # never resume the sequential scan from [0] in raw mode.
            return (
                _copy_key(self.cur, _ONE)
                if not self.raw
                else crypto.decodeint_into_noreduce(None, _ONE)
            )
        elif item == 1:
            return (
                _copy_key(self.cur, self.x)
                if not self.raw
                else crypto.sc_copy(self.cur, self.x)
            )
        elif item == prev:
            return self.cur
        elif item == prev + 1:
            # sequential step: cur *= x
            return (
                _sc_mul(self.cur, self.cur, self.x)
                if not self.raw
                else crypto.sc_mul_into(self.cur, self.cur, self.x)
            )
        else:
            raise IndexError(f"KeyVPowers: Only linear scan allowed: {prev}, {item}")

    def set_state(self, idx: int, val):
        """Seed the sequential scan: declare *val* to be x**idx."""
        self.last_idx = idx
        if self.raw:
            return crypto.sc_copy(self.cur, val)
        else:
            return _copy_key(self.cur, val)
|
|
|
|
|
|
class KeyVPowersBackwards(KeyVBase):
    """
    Vector of x^i.

    Used with BP+
    Allows arbitrary jumps as it is used in the folding mechanism. However, sequential access is the fastest.
    Optimized for backward iteration: starts at x**(size-1) and steps down by
    multiplying with x**-1.
    """

    __slots__ = (
        "current_idx",
        "size",
        "x",
        "x_inv",
        "x_max",
        "cur_sc",
        "tmp_sc",
        "raw",
        "cur",
        "last_idx",
    )

    def __init__(
        self,
        size: int,
        x: ScalarDst,
        x_inv: ScalarDst | None = None,
        x_max: ScalarDst | None = None,
        raw: bool = False,
    ):
        # x_inv / x_max: optional precomputed x**-1 and x**(size-1);
        # computed here when not supplied.
        super().__init__(size)
        self.raw = raw
        self.cur = bytearray(32) if not raw else crypto.Scalar()
        self.cur_sc = crypto.Scalar()  # current power as a raw scalar
        self.last_idx = 0

        self.x = _load_scalar(None, x)
        self.x_inv = crypto.Scalar()
        self.x_max = crypto.Scalar()
        self.tmp_sc = crypto.Scalar()  # TODO: use static helper when everything works
        if x_inv:
            _load_scalar(self.x_inv, x_inv)
        else:
            crypto.sc_inv_into(self.x_inv, self.x)

        if x_max:
            _load_scalar(self.x_max, x_max)
        else:
            _sc_square_mult(self.x_max, self.x, size - 1)

        self.reset()

    def reset(self):
        """Reposition to the highest power x**(size-1)."""
        self.last_idx = self.size - 1
        crypto.sc_copy(self.cur_sc, self.x_max)

    def move_more(self, item: int, prev: int):
        """Jump backwards from power *prev* to *item* by multiplying x**-(prev-item)."""
        sdiff = prev - item
        if sdiff < 0:
            raise ValueError("Not supported")

        _sc_square_mult(self.tmp_sc, self.x_inv, sdiff)
        crypto.sc_mul_into(self.cur_sc, self.cur_sc, self.tmp_sc)

    def __getitem__(self, item):
        prev = self.last_idx
        item = self.idxize(item)
        self.last_idx = item

        if item == 0:
            # x**0 == 1; raw path returns cur_sc as-is below
            return self.cur_sc if self.raw else _copy_key(self.cur, _ONE)
        elif item == 1:
            crypto.sc_copy(self.cur_sc, self.x)
        elif item == self.size - 1:  # reset
            self.reset()
        elif item == prev:
            pass
        elif (
            item == prev - 1
        ):  # backward step, mult inverse to decrease acc power by one
            crypto.sc_mul_into(self.cur_sc, self.cur_sc, self.x_inv)
        elif item < prev:  # jump backward
            self.move_more(item, prev)
        else:  # arbitrary jump
            # forward jumps go through a reset to the top, then backward
            self.reset()
            self.move_more(item, self.last_idx)
            self.last_idx = item

        return self.cur_sc if self.raw else crypto.encodeint_into(self.cur, self.cur_sc)
|
|
|
|
|
|
class VctD(KeyVBase):
    """
    Vector of d[j*N+i] = z**(2*(j+1)) * 2**i, i \\in [0,N), j \\in [0,M)

    Used with BP+.
    Allows arbitrary jumps as it is used in the folding mechanism. However, sequential access is the fastest.
    """

    __slots__ = (
        "current_idx",
        "size",
        "N",
        "z_sq",
        "z_last",
        "two",
        "cur_sc",
        "tmp_sc",
        "cur",
        "last_idx",
        "current_n",
        "raw",
    )

    def __init__(self, N: int, M: int, z_sq: bytearray, raw: bool = False):
        super().__init__(N * M)
        self.N = N
        self.raw = raw
        self.z_sq = crypto.decodeint_into_noreduce(None, z_sq)  # z**2
        self.z_last = crypto.Scalar()  # z**(2*(j+1)) for the current row j
        self.two = crypto.decodeint_into_noreduce(None, _TWO)
        self.cur_sc = crypto.Scalar()  # current element as a raw scalar
        self.tmp_sc = crypto.Scalar()  # TODO: use static helper when everything works
        self.cur = _ensure_dst_key() if not self.raw else None
        self.last_idx = 0
        self.current_n = 0  # i component (power of two) of the current position
        self.reset()

    def reset(self):
        """Reposition to element 0: z**2 * 2**0."""
        self.current_idx = 0
        self.current_n = 0
        crypto.sc_copy(self.z_last, self.z_sq)
        crypto.sc_copy(self.cur_sc, self.z_sq)
        if not self.raw:
            crypto.encodeint_into(self.cur, self.cur_sc)  # z**2 + 2**0

    def move_one(self, item: int):
        """Fast linear jump step"""
        self.current_n += 1
        if item != 0 and self.current_n >= self.N:  # reset 2**i part,
            # row boundary: advance z**2 factor, restart the power-of-two run
            self.current_n = 0
            crypto.sc_mul_into(self.z_last, self.z_last, self.z_sq)
            crypto.sc_copy(self.cur_sc, self.z_last)
        else:
            crypto.sc_mul_into(self.cur_sc, self.cur_sc, self.two)
        if not self.raw:
            crypto.encodeint_into(self.cur, self.cur_sc)

    def move_more(self, item: int, prev: int):
        """More costly but required arbitrary jump forward"""
        sdiff = item - prev
        if sdiff < 0:
            raise ValueError("Not supported")

        self.current_n = item % self.N  # reset for move_one incremental move
        same_2 = sdiff % self.N == 0  # same 2**i component? simpler move
        z_squares_to_mul = (item // self.N) - (prev // self.N)

        # If z component needs to be updated, compute update and add it
        if z_squares_to_mul > 0:
            _sc_square_mult(self.tmp_sc, self.z_sq, z_squares_to_mul)
            crypto.sc_mul_into(self.z_last, self.z_last, self.tmp_sc)
            if same_2:
                crypto.sc_mul_into(self.cur_sc, self.cur_sc, self.tmp_sc)
                return

        # Optimal jump is complicated as due to 2**(i%64), power2 component can be lower in the new position
        # Thus reset and rebuild from z_last
        if not same_2:
            crypto.sc_copy(self.cur_sc, self.z_last)
            _sc_square_mult(self.tmp_sc, self.two, item % self.N)
            crypto.sc_mul_into(self.cur_sc, self.cur_sc, self.tmp_sc)

    def __getitem__(self, item):
        prev = self.last_idx
        item = self.idxize(item)
        self.last_idx = item

        if item == 0:
            self.reset()
        elif item == prev:
            pass
        elif item == prev + 1:
            self.move_one(item)
        elif item > prev:
            self.move_more(item, prev)
        else:
            # backward jump: restart from element 0 and walk forward
            self.reset()
            self.move_more(item, 0)
        return self.cur if not self.raw else self.cur_sc
|
|
|
|
|
|
class KeyHadamardFoldedVct(KeyVBase):
    """
    Hadamard-folded evaluated vector

    Element i is add_keys3(a, src[i], b, src[size + i]) computed on demand
    (presumably a*src[i] + b*src[size+i] -- confirm add_keys3_into semantics);
    size is half the source length.
    """

    __slots__ = (
        "current_idx",
        "size",
        "src",
        "a",
        "b",
        "raw",
        "gc_fnc",
        "cur_pt",
        "tmp_pt",
        "cur",
    )

    def __init__(
        self, src: KeyVBase, a: ScalarDst, b: ScalarDst, raw: bool = False, gc_fnc=None
    ):
        # raw: return crypto.Point instead of a 32B encoding
        # gc_fnc: optional per-access callback (e.g. _gc_iter) for GC pacing
        super().__init__(len(src) >> 1)
        self.src = src
        self.raw = raw
        self.gc_fnc = gc_fnc
        self.a = _load_scalar(None, a)
        self.b = _load_scalar(None, b)
        self.cur_pt = crypto.Point()
        self.tmp_pt = crypto.Point()
        self.cur = _ensure_dst_key() if not self.raw else None

    def __getitem__(self, item):
        i = self.idxize(item)
        crypto.decodepoint_into(self.cur_pt, self.src.to(i))
        crypto.decodepoint_into(self.tmp_pt, self.src.to(self.size + i))
        crypto.add_keys3_into(self.cur_pt, self.a, self.cur_pt, self.b, self.tmp_pt)
        if self.gc_fnc:
            self.gc_fnc(i)
        if not self.raw:
            return crypto.encodepoint_into(self.cur, self.cur_pt)
        else:
            return self.cur_pt
|
|
|
|
|
|
class KeyScalarFoldedVct(KeyVBase):
    """
    Scalar-folded evaluated vector

    Element i is a * src[i] + b * src[size + i] computed on demand;
    size is half the source length.
    """

    __slots__ = (
        "current_idx",
        "size",
        "src",
        "a",
        "b",
        "raw",
        "gc_fnc",
        "cur_sc",
        "tmp_sc",
        "cur",
    )

    def __init__(
        self, src: KeyVBase, a: ScalarDst, b: ScalarDst, raw: bool = False, gc_fnc=None
    ):
        # raw: return crypto.Scalar instead of a 32B encoding
        # gc_fnc: optional per-access callback (e.g. _gc_iter) for GC pacing
        super().__init__(len(src) >> 1)
        self.src = src
        self.raw = raw
        self.gc_fnc = gc_fnc
        self.a = _load_scalar(None, a)
        self.b = _load_scalar(None, b)
        self.cur_sc = crypto.Scalar()
        self.tmp_sc = crypto.Scalar()
        self.cur = _ensure_dst_key() if not self.raw else None

    def __getitem__(self, item):
        i = self.idxize(item)

        # cur = src[size + i] * b + (src[i] * a)
        crypto.decodeint_into_noreduce(self.tmp_sc, self.src.to(i))
        crypto.sc_mul_into(self.tmp_sc, self.tmp_sc, self.a)
        crypto.decodeint_into_noreduce(self.cur_sc, self.src.to(self.size + i))
        crypto.sc_muladd_into(self.cur_sc, self.cur_sc, self.b, self.tmp_sc)

        if self.gc_fnc:
            self.gc_fnc(i)
        if not self.raw:
            return crypto.encodeint_into(self.cur, self.cur_sc)
        else:
            return self.cur_sc
|
|
|
|
|
|
class KeyPow2Vct(KeyVBase):
    """
    2**i vector, note that Curve25519 has scalar order 2 ** 252 + 27742317777372353535851937790883648493
    """

    __slots__ = (
        "size",
        "raw",
        "cur",
        "cur_sc",
    )

    def __init__(self, size: int, raw: bool = False):
        super().__init__(size)
        self.raw = raw  # True: return crypto.Scalar, False: return 32B encoding
        self.cur = _ensure_dst_key()
        self.cur_sc = crypto.Scalar()

    def __getitem__(self, item):
        i = self.idxize(item)
        if i == 0:
            _copy_key(self.cur, _ONE)
        elif i == 1:
            _copy_key(self.cur, _TWO)
        else:
            # Set bit i directly in the little-endian 32B encoding
            _copy_key(self.cur, _ZERO)
            self.cur[i >> 3] = 1 << (i & 7)

        # 2**i for i < 252 is below the scalar order; no reduction needed
        if i < 252 and self.raw:
            return crypto.decodeint_into_noreduce(self.cur_sc, self.cur)

        if i > 252:  # reduction, costly
            crypto.decodeint_into(self.cur_sc, self.cur)
            if not self.raw:
                return crypto.encodeint_into(self.cur, self.cur_sc)

        # NOTE(review): for i == 252 with raw=True this returns cur_sc without
        # reloading it from cur; instances in this file are sized <= 250 so the
        # path is not exercised here — confirm before generalizing.
        return self.cur_sc if self.raw else self.cur
|
|
|
|
|
|
class KeyChallengeCacheVct(KeyVBase):
    """
    Challenge cache vector for BP+ verification
    More on this in the verification code, near "challenge_cache" definition

    Element i is the product over all rounds j of either ch_[j] or chi[j],
    chosen by bit (nbits - 1 - j) of i. An optional precomputed table of
    partial products covers the top `precomp_depth` bits.
    """

    __slots__ = (
        "nbits",
        "ch_",
        "chi",
        "precomp",
        "precomp_depth",
        "cur",
    )

    def __init__(
        self, nbits: int, ch_: KeyVBase, chi: KeyVBase, precomputed: KeyVBase | None
    ):
        super().__init__(1 << nbits)
        self.nbits = nbits
        self.ch_ = ch_  # per-round challenges
        self.chi = chi  # per-round challenge inverses
        self.precomp = precomputed
        self.precomp_depth = 0
        self.cur = _ensure_dst_key()
        if not precomputed:
            return

        # precomp_depth = ceil(log2(len(precomputed)))
        while (1 << self.precomp_depth) < len(precomputed):
            self.precomp_depth += 1

    def __getitem__(self, item):
        i = self.idxize(item)
        bits_done = 0

        if self.precomp:
            # Start from the precomputed partial product for the top bits
            _copy_key(self.cur, self.precomp[i >> (self.nbits - self.precomp_depth)])
            bits_done += self.precomp_depth
        else:
            _copy_key(self.cur, _ONE)

        # Multiply in the remaining (lower) bits one round at a time
        for j in range(self.nbits - 1, bits_done - 1, -1):
            if i & (1 << (self.nbits - 1 - j)) > 0:
                _sc_mul(self.cur, self.cur, self.ch_[j])
            else:
                _sc_mul(self.cur, self.cur, self.chi[j])
        return self.cur
|
|
|
|
|
|
class KeyR0(KeyVBase):
    """
    Vector r0. Allows only sequential access (no jumping). Resets on [0,1] access.
    zt_i = z^{2 + \floor{i/N}} 2^{i % N}
    r0_i = ((a_{Ri} + z) y^{i}) + zt_i

    Could be composed from smaller vectors, but RAW returns are required
    """

    __slots__ = (
        "current_idx",
        "size",
        "N",
        "aR",
        "raw",
        "y",
        "yp",
        "z",
        "zt",
        "p2",
        "res",
        "cur",
        "last_idx",
    )

    def __init__(self, size, N, aR, y, z, raw: bool = False, **kwargs) -> None:
        super().__init__(size)
        self.N = N
        self.aR = aR
        self.raw = raw  # True: return crypto.Scalar, False: return 32B encoding
        self.y = crypto.decodeint_into_noreduce(None, y)
        self.yp = crypto.Scalar()  # y^{i}
        self.z = crypto.decodeint_into_noreduce(None, z)
        self.zt = crypto.Scalar()  # z^{2 + \floor{i/N}}
        self.p2 = crypto.Scalar()  # 2^{i \% N}
        self.res = crypto.Scalar()  # tmp_sc_1

        self.cur = bytearray(32) if not raw else None
        self.last_idx = 0
        self.reset()

    def reset(self) -> None:
        # Running state for index 0: y^0 = 1, 2^0 = 1, z^2
        crypto.decodeint_into_noreduce(self.yp, _ONE)
        crypto.decodeint_into_noreduce(self.p2, _ONE)
        crypto.sc_mul_into(self.zt, self.z, self.z)

    def __getitem__(self, item):
        prev = self.last_idx
        item = self.idxize(item)
        self.last_idx = item

        # Const init for eval
        if item == 0:  # Reset on first item access
            self.reset()

        elif item == prev + 1:
            # Advance the running powers by one step
            crypto.sc_mul_into(self.yp, self.yp, self.y)  # ypow
            if item % self.N == 0:
                crypto.sc_mul_into(self.zt, self.zt, self.z)  # zt
                crypto.decodeint_into_noreduce(self.p2, _ONE)  # p2 reset
            else:
                crypto.decodeint_into_noreduce(self.res, _TWO)  # p2
                crypto.sc_mul_into(self.p2, self.p2, self.res)  # p2

        elif item == prev:  # No advancing
            pass

        else:
            raise IndexError("Only linear scan allowed")

        # Eval r0[i]
        if (
            item == 0 or item != prev
        ):  # if True not present, fails with cross dot product
            crypto.decodeint_into_noreduce(self.res, self.aR.to(item))  # aR[i]
            crypto.sc_add_into(self.res, self.res, self.z)  # aR[i] + z
            crypto.sc_mul_into(self.res, self.res, self.yp)  # (aR[i] + z) * y^i
            crypto.sc_muladd_into(
                self.res, self.zt, self.p2, self.res
            )  # (aR[i] + z) * y^i + z^{2 + \floor{i/N}} 2^{i \% N}

        if self.raw:
            return self.res

        crypto.encodeint_into(self.cur, self.res)
        return self.cur
|
|
|
|
|
|
def _ensure_dst_keyvect(dst=None, size: int | None = None):
    """Return a key vector of `size` elements, allocating or resizing `dst` as needed."""
    if dst is None:
        return KeyV(elems=size)
    if size is not None and size != len(dst):
        dst.resize(size)
    return dst
|
|
|
|
|
|
def _const_vector(val, elems=_BP_N, copy: bool = True) -> KeyVConst:
    # Constant vector: every element equals `val`; `copy` is forwarded to KeyVConst
    return KeyVConst(elems, val, copy)
|
|
|
|
|
|
def _vector_exponent_custom(A, B, a, b, dst=None, a_raw=None, b_raw=None):
    """
    \\sum_{i=0}^{|A|} a_i A_i + b_i B_i

    :param A, B: curve point vectors
    :param a, b: encoded scalar vectors; pass a_raw/b_raw instead for raw crypto.Scalar vectors
    :param dst: optional 32B destination; allocated when None
    :return: dst holding the encoded sum
    """
    dst = _ensure_dst_key(dst)
    crypto.identity_into(_tmp_pt_2)  # running accumulator

    for i in range(len(a or a_raw)):
        if a:
            crypto.decodeint_into_noreduce(_tmp_sc_1, a.to(i))
        crypto.decodepoint_into(_tmp_pt_3, A.to(i))
        if b:
            crypto.decodeint_into_noreduce(_tmp_sc_2, b.to(i))
        crypto.decodepoint_into(_tmp_pt_4, B.to(i))
        # _tmp_pt_1 = a_i * A_i + b_i * B_i
        crypto.add_keys3_into(
            _tmp_pt_1,
            _tmp_sc_1 if a else a_raw.to(i),
            _tmp_pt_3,
            _tmp_sc_2 if b else b_raw.to(i),
            _tmp_pt_4,
        )
        crypto.point_add_into(_tmp_pt_2, _tmp_pt_2, _tmp_pt_1)
        _gc_iter(i)
    crypto.encodepoint_into(dst, _tmp_pt_2)
    return dst
|
|
|
|
|
|
def _vector_powers(x, n, dst=None, dynamic: bool = False):
    """
    r_i = x^i

    :param x: encoded scalar (32B)
    :param n: number of powers
    :param dst: optional destination key vector
    :param dynamic: when True, return a lazily-evaluated KeyVPowers instead
    """
    if dynamic:
        return KeyVPowers(n, x)
    dst = _ensure_dst_keyvect(dst, n)
    if n == 0:
        return dst
    dst.read(0, _ONE)  # x^0
    if n == 1:
        return dst
    dst.read(1, x)  # x^1

    # _tmp_sc_1 accumulates x^i, _tmp_sc_2 holds the constant multiplier x
    crypto.decodeint_into_noreduce(_tmp_sc_1, x)
    crypto.decodeint_into_noreduce(_tmp_sc_2, x)
    for i in range(2, n):
        crypto.sc_mul_into(_tmp_sc_1, _tmp_sc_1, _tmp_sc_2)
        crypto.encodeint_into(_tmp_bf_0, _tmp_sc_1)
        dst.read(i, _tmp_bf_0)
        _gc_iter(i)
    return dst
|
|
|
|
|
|
def _vector_power_sum(x, n, dst=None):
    """
    \\sum_{i=0}^{n-1} x^i

    :param x: encoded scalar (32B)
    :param n: number of terms
    :param dst: optional 32B destination buffer
    :return: dst holding the encoded power sum
    """
    dst = _ensure_dst_key(dst)
    if n == 0:
        return _copy_key(dst, _ZERO)
    if n == 1:
        # Fix: the sum for n == 1 is exactly 1. Previously this branch did
        # not return, fell through, and the function returned 1 + x.
        return _copy_key(dst, _ONE)

    # _tmp_sc_3 accumulates the sum (seeded with 1 + x), _tmp_sc_2 tracks x^i
    crypto.decodeint_into_noreduce(_tmp_sc_1, x)
    crypto.decodeint_into_noreduce(_tmp_sc_3, _ONE)
    crypto.sc_add_into(_tmp_sc_3, _tmp_sc_3, _tmp_sc_1)
    crypto.sc_copy(_tmp_sc_2, _tmp_sc_1)

    for i in range(2, n):
        crypto.sc_mul_into(_tmp_sc_2, _tmp_sc_2, _tmp_sc_1)
        crypto.sc_add_into(_tmp_sc_3, _tmp_sc_3, _tmp_sc_2)
        _gc_iter(i)

    return crypto.encodeint_into(dst, _tmp_sc_3)
|
|
|
|
|
|
def _inner_product(a, b, dst=None):
    """
    \\sum_{i=0}^{|a|} a_i b_i

    :param a, b: encoded scalar vectors of equal length
    :param dst: optional 32B destination buffer
    :raises ValueError: when |a| != |b|
    """
    if len(a) != len(b):
        raise ValueError("Incompatible sizes of a and b")
    dst = _ensure_dst_key(dst)
    # NOTE(review): sc_copy is called with the literal 0 as source here —
    # presumably the crypto API accepts it as "zero the scalar"; confirm.
    crypto.sc_copy(_tmp_sc_1, 0)

    for i in range(len(a)):
        crypto.decodeint_into_noreduce(_tmp_sc_2, a.to(i))
        crypto.decodeint_into_noreduce(_tmp_sc_3, b.to(i))
        crypto.sc_muladd_into(_tmp_sc_1, _tmp_sc_2, _tmp_sc_3, _tmp_sc_1)
        _gc_iter(i)

    crypto.encodeint_into(dst, _tmp_sc_1)
    return dst
|
|
|
|
|
|
def _weighted_inner_product(
    dst: bytearray | None, a: KeyVBase, b: KeyVBase, y: bytearray
):
    """
    Output a_0*b_0*y**1 + a_1*b_1*y**2 + ... + a_{n-1}*b_{n-1}*y**n

    :raises ValueError: when |a| != |b|
    """
    if len(a) != len(b):
        raise ValueError("Incompatible sizes of a and b")
    dst = _ensure_dst_key(dst)
    # y_pow starts at y^1 and is advanced once per term
    y_sc = crypto.decodeint_into_noreduce(_tmp_sc_4, y)
    y_pow = crypto.sc_copy(_tmp_sc_5, _tmp_sc_4)
    crypto.decodeint_into_noreduce(_tmp_sc_1, _ZERO)  # accumulator

    for i in range(len(a)):
        crypto.decodeint_into_noreduce(_tmp_sc_2, a.to(i))
        crypto.decodeint_into_noreduce(_tmp_sc_3, b.to(i))
        crypto.sc_mul_into(_tmp_sc_2, _tmp_sc_2, _tmp_sc_3)
        crypto.sc_muladd_into(_tmp_sc_1, _tmp_sc_2, y_pow, _tmp_sc_1)
        crypto.sc_mul_into(y_pow, y_pow, y_sc)
        _gc_iter(i)

    crypto.encodeint_into(dst, _tmp_sc_1)
    return dst
|
|
|
|
|
|
def _hadamard_fold(v, a, b, into=None, into_offset: int = 0, vR=None, vRoff=0):
    """
    Folds a curvepoint array using a two way scaled Hadamard product

    ln = len(v); h = ln // 2
    v_i = a v_i + b v_{h + i}

    :param into: optional destination vector; folds in place into v when None
    :param vR: optional separate source for the right half (read at vRoff + i)
    """
    h = len(v) // 2
    crypto.decodeint_into_noreduce(_tmp_sc_1, a)
    crypto.decodeint_into_noreduce(_tmp_sc_2, b)
    into = into if into else v

    for i in range(h):
        crypto.decodepoint_into(_tmp_pt_1, v.to(i))
        crypto.decodepoint_into(_tmp_pt_2, v.to(h + i) if not vR else vR.to(i + vRoff))
        crypto.add_keys3_into(_tmp_pt_3, _tmp_sc_1, _tmp_pt_1, _tmp_sc_2, _tmp_pt_2)
        crypto.encodepoint_into(_tmp_bf_0, _tmp_pt_3)
        into.read(i + into_offset, _tmp_bf_0)
        _gc_iter(i)

    return into
|
|
|
|
|
|
def _scalar_fold(v, a, b, into=None, into_offset: int = 0):
    """
    Scalar analogue of _hadamard_fold:

    ln = len(v); h = ln // 2
    v_i = a v_i + b v_{h + i}
    """
    h = len(v) // 2
    crypto.decodeint_into_noreduce(_tmp_sc_1, a)
    crypto.decodeint_into_noreduce(_tmp_sc_2, b)
    into = into if into else v

    for i in range(h):
        crypto.decodeint_into_noreduce(_tmp_sc_3, v.to(i))
        crypto.decodeint_into_noreduce(_tmp_sc_4, v.to(h + i))
        crypto.sc_mul_into(_tmp_sc_3, _tmp_sc_3, _tmp_sc_1)
        crypto.sc_muladd_into(_tmp_sc_3, _tmp_sc_4, _tmp_sc_2, _tmp_sc_3)
        crypto.encodeint_into(_tmp_bf_0, _tmp_sc_3)
        into.read(i + into_offset, _tmp_bf_0)
        _gc_iter(i)

    return into
|
|
|
|
|
|
def _cross_inner_product(l0, r0, l1, r1):
    """
    t1 = l0 . r1 + l1 . r0
    t2 = l1 . r1

    Evaluated in a single pass so the lazily-computed vectors (KeyR0 in
    particular allows only linear scans) are each read exactly once per index.
    :return: (t1, t2) as encoded scalars
    """
    sc_t1 = crypto.Scalar()
    sc_t2 = crypto.Scalar()
    tl = crypto.Scalar()
    tr = crypto.Scalar()

    for i in range(len(l0)):
        crypto.decodeint_into_noreduce(tl, l0.to(i))
        crypto.decodeint_into_noreduce(tr, r1.to(i))
        crypto.sc_muladd_into(sc_t1, tl, tr, sc_t1)  # += l0_i * r1_i

        crypto.decodeint_into_noreduce(tl, l1.to(i))
        crypto.sc_muladd_into(sc_t2, tl, tr, sc_t2)  # += l1_i * r1_i

        crypto.decodeint_into_noreduce(tr, r0.to(i))
        crypto.sc_muladd_into(sc_t1, tl, tr, sc_t1)  # += l1_i * r0_i

        _gc_iter(i)

    return crypto_helpers.encodeint(sc_t1), crypto_helpers.encodeint(sc_t2)
|
|
|
|
|
|
def _vector_gen(dst, size, op):
    # Fill dst[i] with op(i, buffer) for each index, reusing one shared buffer
    dst = _ensure_dst_keyvect(dst, size)
    for i in range(size):
        dst.to(i, _tmp_bf_0)  # load current element into the scratch buffer
        op(i, _tmp_bf_0)  # op mutates the buffer in place
        dst.read(i, _tmp_bf_0)  # store it back
        _gc_iter(i)
    return dst
|
|
|
|
|
|
def _vector_dup(x, n, dst=None):
    """Return an n-element key vector with every slot set to `x`."""
    out = _ensure_dst_keyvect(dst, n)
    for idx in range(n):
        out[idx] = x
        _gc_iter(idx)
    return out
|
|
|
|
|
|
def _hash_cache_mash(dst, hash_cache, *args):
    """
    Fiat-Shamir transcript step: keccak(hash_cache || args...), reduced to a
    scalar. The result is written back into hash_cache (advancing the
    transcript) and copied into dst.
    """
    dst = _ensure_dst_key(dst)
    ctx = crypto_helpers.get_keccak()
    ctx.update(hash_cache)

    for x in args:
        if x is None:
            break  # first None terminates the argument list
        ctx.update(x)
    hsh = ctx.digest()

    crypto.decodeint_into(_tmp_sc_1, hsh)  # reduce mod group order
    crypto.encodeint_into(hash_cache, _tmp_sc_1)
    _copy_key(dst, hash_cache)
    return dst
|
|
|
|
|
|
def _is_reduced(sc) -> bool:
    """Return True when the 32B encoding `sc` is already reduced mod the group order."""
    crypto.decodeint_into(_tmp_sc_1, sc)
    return crypto.encodeint_into(_tmp_bf_0, _tmp_sc_1) == sc
|
|
|
|
|
|
class MultiExpSequential:
    """
    MultiExp object similar to MultiExp array of [(scalar, point), ]
    MultiExp computes simply: res = \\sum_i scalar_i * point_i
    Straus / Pippenger algorithms are implemented in the original Monero C++ code for the speed
    but the memory cost is around 1 MB which is not affordable here in HW devices.

    Moreover, Monero needs speed for very fast verification for blockchain verification which is not
    priority in this use case.

    MultiExp holder with sequential evaluation
    """

    def __init__(
        self, size: int | None = None, points: list | None = None, point_fnc=None
    ) -> None:
        self.current_idx = 0
        self.points = points if points else []
        self.point_fnc = point_fnc  # lazy point source for indices >= len(points)
        # Fix: honor an explicitly passed `size`. Previously the if/else
        # below unconditionally overwrote it (with len(points) or 0).
        if size is not None:
            self.size = size
        else:
            self.size = len(points) if points else 0

        self.acc = crypto.Point()  # running sum, updated per pair
        self.tmp = _ensure_dst_key()

    def get_point(self, idx):
        # Stored point when available, otherwise generated on demand
        return (
            self.point_fnc(idx, None) if idx >= len(self.points) else self.points[idx]
        )

    def add_pair(self, scalar, point) -> None:
        self._acc(scalar, point)

    def add_scalar(self, scalar) -> None:
        # Pairs the scalar with the next sequential point
        self._acc(scalar, self.get_point(self.current_idx))

    def add_scalar_idx(self, scalar, idx: int) -> None:
        self._acc(scalar, self.get_point(idx))

    def _acc(self, scalar, point) -> None:
        # acc += scalar * point
        crypto.decodeint_into_noreduce(_tmp_sc_1, scalar)
        crypto.decodepoint_into(_tmp_pt_2, point)
        crypto.scalarmult_into(_tmp_pt_3, _tmp_pt_2, _tmp_sc_1)
        crypto.point_add_into(self.acc, self.acc, _tmp_pt_3)
        self.current_idx += 1
        self.size += 1

    def eval(self, dst):
        # Encode the accumulated sum into dst (32B)
        dst = _ensure_dst_key(dst)
        return crypto.encodepoint_into(dst, self.acc)
|
|
|
|
|
|
def _multiexp(dst=None, data=None):
    # Evaluate the accumulated multi-exponentiation `data` into dst (32B point)
    return data.eval(dst)
|
|
|
|
|
|
class BulletProofGenException(Exception):
    """Raised when a derived challenge is zero; the prover restarts with fresh masks."""

    pass
|
|
|
|
|
|
class BulletProofBuilder:
|
|
    def __init__(self) -> None:
        """Initialize precomputed generator tables and prover state."""
        self.use_det_masks = True  # deterministic sL/sR masks derived from proof_sec
        self.proof_sec = None  # per-proof 64B secret, set in prove_setup

        # BP_GI_PRE = get_exponent(Gi[i], _XMR_H, i * 2 + 1)
        self.Gprec = KeyV(buffer=crypto.BP_GI_PRE, const=True)
        # BP_HI_PRE = get_exponent(Hi[i], _XMR_H, i * 2)
        self.Hprec = KeyV(buffer=crypto.BP_HI_PRE, const=True)
        # BP_TWO_N = vector_powers(_TWO, _BP_N);
        self.twoN = KeyPow2Vct(250)
        self.fnc_det_mask = None  # optional override for _det_mask

        # aL, aR amount bitmasks, can be freed once not needed
        self.aL = None
        self.aR = None

        self.tmp_sc_1 = crypto.Scalar()
        self.tmp_det_buff = bytearray(64 + 1 + 4)  # proof_sec || is_sL || varint(i)

        self.gc_fnc = gc.collect
        self.gc_trace = None  # optional tracing hook, called before each collect
|
|
|
|
def gc(self, *args) -> None:
|
|
if self.gc_trace:
|
|
self.gc_trace(*args)
|
|
if self.gc_fnc:
|
|
self.gc_fnc()
|
|
|
|
    def aX_vcts(self, sv, MN) -> tuple:
        """
        Build lazy aL/aR bit-decomposition vectors for the encoded amounts sv.
        aL_i is the i-th bit (0/1); aR_i = aL_i - 1 (0 or -1).
        """
        num_inp = len(sv)

        def e_xL(idx, d=None, is_a=True):
            # idx -> (input j, bit i within that input)
            j, i = idx // _BP_N, idx % _BP_N
            r = None
            if j < num_inp and sv[j][i // 8] & (1 << i % 8):
                r = _ONE if is_a else _ZERO
            else:
                r = _ZERO if is_a else _MINUS_ONE
            if d:
                return _copy_key(d, r)
            return r

        aL = KeyVEval(MN, lambda i, d: e_xL(i, d, True))
        aR = KeyVEval(MN, lambda i, d: e_xL(i, d, False))
        return aL, aR
|
|
|
|
    def _det_mask_init(self) -> None:
        # Seed the deterministic-mask buffer with the per-proof secret.
        # NOTE(review): `memcpy` (unaliased) is expected to be defined earlier
        # in this file — only `tmemcpy` is visible in the header imports; confirm.
        memcpy(self.tmp_det_buff, 0, self.proof_sec, 0, len(self.proof_sec))
|
|
|
|
    def _det_mask(self, i, is_sL: bool = True, dst: bytearray | None = None):
        """
        Deterministic mask sX_i = H_s(proof_sec || is_sL || varint(i)).
        Allows regenerating the sL/sR vectors without storing them.
        """
        dst = _ensure_dst_key(dst)
        if self.fnc_det_mask:
            return self.fnc_det_mask(i, is_sL, dst)
        self.tmp_det_buff[64] = int(is_sL)
        memcpy(self.tmp_det_buff, 65, _ZERO, 0, 4)  # clear the varint area
        dump_uvarint_b_into(i, self.tmp_det_buff, 65)
        crypto.hash_to_scalar_into(self.tmp_sc_1, self.tmp_det_buff)
        crypto.encodeint_into(dst, self.tmp_sc_1)
        return dst
|
|
|
|
    def _gprec_aux(self, size: int) -> KeyVPrecomp:
        # G generators: precomputed table extended on demand with odd exponents
        return KeyVPrecomp(
            size, self.Gprec, lambda i, d: _get_exponent(d, _XMR_H, i * 2 + 1)
        )
|
|
|
|
    def _hprec_aux(self, size: int) -> KeyVPrecomp:
        # H generators: precomputed table extended on demand with even exponents
        return KeyVPrecomp(
            size, self.Hprec, lambda i, d: _get_exponent(d, _XMR_H, i * 2)
        )
|
|
|
|
def _two_aux(self, size: int) -> KeyVPrecomp:
|
|
# Simple recursive exponentiation from precomputed results
|
|
lx = len(self.twoN)
|
|
|
|
def pow_two(i: int, d=None):
|
|
if i < lx:
|
|
return self.twoN[i]
|
|
|
|
d = _ensure_dst_key(d)
|
|
flr = i // 2
|
|
|
|
lw = pow_two(flr)
|
|
rw = pow_two(flr + 1 if flr != i / 2.0 else lw)
|
|
return _sc_mul(d, lw, rw)
|
|
|
|
return KeyVPrecomp(size, self.twoN, pow_two)
|
|
|
|
def sL_vct(self, ln=_BP_N):
|
|
return (
|
|
KeyVEval(ln, lambda i, dst: self._det_mask(i, True, dst))
|
|
if self.use_det_masks
|
|
else self.sX_gen(ln)
|
|
)
|
|
|
|
def sR_vct(self, ln=_BP_N):
|
|
return (
|
|
KeyVEval(ln, lambda i, dst: self._det_mask(i, False, dst))
|
|
if self.use_det_masks
|
|
else self.sX_gen(ln)
|
|
)
|
|
|
|
    def sX_gen(self, ln=_BP_N) -> KeyV:
        """Generate ln fresh random scalars packed into a single 32B-per-element buffer."""
        gc.collect()
        buff = bytearray(ln * 32)
        buff_mv = memoryview(buff)  # zero-copy slicing into the backing buffer
        sc = crypto.Scalar()
        for i in range(ln):
            crypto.random_scalar(sc)
            crypto.encodeint_into(buff_mv[i * 32 : (i + 1) * 32], sc)
            _gc_iter(i)
        return KeyV(buffer=buff)
|
|
|
|
    def vector_exponent(self, a, b, dst=None, a_raw=None, b_raw=None):
        # \sum a_i Gprec_i + b_i Hprec_i over the builder's precomputed generators
        return _vector_exponent_custom(self.Gprec, self.Hprec, a, b, dst, a_raw, b_raw)
|
|
|
|
    def prove(self, sv: crypto.Scalar, gamma: crypto.Scalar):
        # Single value/blinding proof via the batch path
        return self.prove_batch([sv], [gamma])
|
|
|
|
    def prove_setup(self, sv: list[crypto.Scalar], gamma: list[crypto.Scalar]) -> tuple:
        """
        Common proving setup: derive per-proof secret, compute commitments
        V_i = 8^{-1}(gamma_i G + sv_i H) and the aL/aR bitmask vectors.
        :return: (M, logM, V, gamma) with sv/gamma re-encoded to 32B keys
        """
        utils.ensure(len(sv) == len(gamma), "|sv| != |gamma|")
        utils.ensure(len(sv) > 0, "sv empty")

        self.proof_sec = random.bytes(64)  # seeds deterministic masks
        self._det_mask_init()
        gc.collect()
        sv = [crypto_helpers.encodeint(x) for x in sv]
        gamma = [crypto_helpers.encodeint(x) for x in gamma]

        # M = smallest power of two >= len(sv), capped by _BP_M
        M, logM = 1, 0
        while M <= _BP_M and M < len(sv):
            logM += 1
            M = 1 << logM
        MN = M * _BP_N

        V = _ensure_dst_keyvect(None, len(sv))
        for i in range(len(sv)):
            _add_keys2(_tmp_bf_0, gamma[i], sv[i], _XMR_H)
            _scalarmult_key(_tmp_bf_0, _tmp_bf_0, _INV_EIGHT)
            V.read(i, _tmp_bf_0)

        self.prove_setup_aLaR(MN, None, sv)
        return M, logM, V, gamma
|
|
|
|
    def prove_setup_aLaR(self, MN, sv, sv_vct=None):
        # (Re)build the aL/aR bit vectors; sv_vct takes precedence when given,
        # otherwise sv (crypto.Scalar list) is encoded first
        sv_vct = sv_vct if sv_vct else [crypto_helpers.encodeint(x) for x in sv]
        self.aL, self.aR = self.aX_vcts(sv_vct, MN)
|
|
|
|
    def prove_batch(
        self, sv: list[crypto.Scalar], gamma: list[crypto.Scalar]
    ) -> Bulletproof:
        """
        Batch-prove sv/gamma pairs. Retries the whole proof whenever a
        derived challenge turns out zero (BulletProofGenException), after
        regenerating the aL/aR vectors.
        """
        M, logM, V, gamma = self.prove_setup(sv, gamma)
        hash_cache = _ensure_dst_key()
        while True:
            self.gc(10)
            try:
                return self._prove_batch_main(V, gamma, hash_cache, logM, M)
            except BulletProofGenException:
                self.prove_setup_aLaR(M * _BP_N, sv)
                continue
|
|
|
|
    def _prove_batch_main(self, V, gamma, hash_cache, logM, M) -> Bulletproof:
        """Single proving attempt: phase 1 (commitments, l/r) + phase 2 (inner-product folding)."""
        N = _BP_N
        logN = _BP_LOG_N
        logMN = logM + logN
        MN = M * N
        _hash_vct_to_scalar(hash_cache, V)  # transcript seeded with V

        # PHASE 1
        A, S, T1, T2, taux, mu, t, l, r, y, x_ip, hash_cache = self._prove_phase1(
            N, M, V, gamma, hash_cache
        )

        # PHASE 2
        L, R, a, b = self._prove_loop(MN, logMN, l, r, y, x_ip, hash_cache)

        # Imported lazily to keep module import cost low on the device
        from apps.monero.xmr.serialize_messages.tx_rsig_bulletproof import Bulletproof

        return Bulletproof(
            V=V, A=A, S=S, T1=T1, T2=T2, taux=taux, mu=mu, L=L, R=R, a=a, b=b, t=t
        )
|
|
|
|
    def _prove_phase1(self, N, M, V, gamma, hash_cache) -> tuple:
        """
        Phase 1 of the BP prover: commitments A, S, T1, T2, the challenges
        y, z, x, x_ip, the scalars taux, mu, t and the materialized l, r
        vectors for the folding loop.
        :raises BulletProofGenException: when any derived challenge is zero
        """
        MN = M * N
        aL = self.aL
        aR = self.aR
        Gprec = self._gprec_aux(MN)
        Hprec = self._hprec_aux(MN)

        # PAPER LINES 38-39, compute A = 8^{-1} ( \alpha G + \sum_{i=0}^{MN-1} a_{L,i} \Gi_i + a_{R,i} \Hi_i)
        alpha = _sc_gen()
        A = _ensure_dst_key()
        _vector_exponent_custom(Gprec, Hprec, aL, aR, A)
        _add_keys(A, A, _scalarmult_base(_tmp_bf_1, alpha))
        _scalarmult_key(A, A, _INV_EIGHT)
        self.gc(11)

        # PAPER LINES 40-42, compute S = 8^{-1} ( \rho G + \sum_{i=0}^{MN-1} s_{L,i} \Gi_i + s_{R,i} \Hi_i)
        sL = self.sL_vct(MN)
        sR = self.sR_vct(MN)
        rho = _sc_gen()
        S = _ensure_dst_key()
        _vector_exponent_custom(Gprec, Hprec, sL, sR, S)
        _add_keys(S, S, _scalarmult_base(_tmp_bf_1, rho))
        _scalarmult_key(S, S, _INV_EIGHT)
        self.gc(12)

        # PAPER LINES 43-45
        y = _ensure_dst_key()
        _hash_cache_mash(y, hash_cache, A, S)
        if y == _ZERO:
            raise BulletProofGenException()

        z = _ensure_dst_key()
        _hash_to_scalar(hash_cache, y)
        _copy_key(z, hash_cache)
        zc = crypto.decodeint_into_noreduce(None, z)
        if z == _ZERO:
            raise BulletProofGenException()

        # Polynomial construction by coefficients
        # l0 = aL - z     r0 = ((aR + z) . ypow) + zt
        # l1 = sL         r1 = sR . ypow
        l0 = KeyVEval(MN, lambda i, d: _sc_sub(d, aL.to(i), zc))  # noqa: F821
        l1 = sL
        self.gc(13)

        # This computes the ugly sum/concatenation from PAPER LINE 65
        # r0_i = ((a_{Ri} + z) y^{i}) + zt_i
        # r1_i = s_{Ri} y^{i}
        r0 = KeyR0(MN, N, aR, y, z)
        ypow = KeyVPowers(MN, y, raw=True)
        r1 = KeyVEval(MN, lambda i, d: _sc_mul(d, sR.to(i), ypow[i]))  # noqa: F821
        del aR
        self.gc(14)

        # Evaluate per index
        # - $t_1 = l_0 . r_1 + l_1 . r0$
        # - $t_2 = l_1 . r_1$
        # - compute then T1, T2, x
        t1, t2 = _cross_inner_product(l0, r0, l1, r1)

        # PAPER LINES 47-48, Compute: T1, T2
        # T1 = 8^{-1} (\tau_1G + t_1H )
        # T2 = 8^{-1} (\tau_2G + t_2H )
        tau1, tau2 = _sc_gen(), _sc_gen()
        T1, T2 = _ensure_dst_key(), _ensure_dst_key()

        _add_keys2(T1, tau1, t1, _XMR_H)
        _scalarmult_key(T1, T1, _INV_EIGHT)

        _add_keys2(T2, tau2, t2, _XMR_H)
        _scalarmult_key(T2, T2, _INV_EIGHT)
        del (t1, t2)
        self.gc(16)

        # PAPER LINES 49-51, compute x
        x = _ensure_dst_key()
        _hash_cache_mash(x, hash_cache, z, T1, T2)
        if x == _ZERO:
            raise BulletProofGenException()

        # Second pass, compute l, r
        # Offloaded version does this incrementally and produces l, r outs in chunks
        # Message offloaded sends blinded vectors with random constants.
        # - $l_i = l_{0,i} + xl_{1,i}
        # - $r_i = r_{0,i} + xr_{1,i}
        # - $t = l . r$
        l = _ensure_dst_keyvect(None, MN)
        r = _ensure_dst_keyvect(None, MN)
        ts = crypto.Scalar()
        for i in range(MN):
            _sc_muladd(_tmp_bf_0, x, l1.to(i), l0.to(i))
            l.read(i, _tmp_bf_0)

            _sc_muladd(_tmp_bf_1, x, r1.to(i), r0.to(i))
            r.read(i, _tmp_bf_1)

            _sc_muladd(ts, _tmp_bf_0, _tmp_bf_1, ts)

        t = crypto_helpers.encodeint(ts)
        self.aL = None
        self.aR = None
        del (l0, l1, sL, sR, r0, r1, ypow, ts, aL)
        self.gc(17)

        # PAPER LINES 52-53, Compute \tau_x
        taux = _ensure_dst_key()
        _sc_mul(taux, tau1, x)
        _sc_mul(_tmp_bf_0, x, x)
        _sc_muladd(taux, tau2, _tmp_bf_0, taux)
        del (tau1, tau2)

        # taux += z^{j+1} * gamma_j for each committed value
        zpow = crypto.sc_mul_into(None, zc, zc)
        for j in range(1, len(V) + 1):
            _sc_muladd(taux, zpow, gamma[j - 1], taux)
            crypto.sc_mul_into(zpow, zpow, zc)
        del (zc, zpow)

        self.gc(18)
        mu = _ensure_dst_key()
        _sc_muladd(mu, x, rho, alpha)
        del (rho, alpha)
        self.gc(19)

        # PAPER LINES 32-33
        x_ip = _hash_cache_mash(None, hash_cache, x, taux, mu, t)
        if x_ip == _ZERO:
            raise BulletProofGenException()

        return A, S, T1, T2, taux, mu, t, l, r, y, x_ip, hash_cache
|
|
|
|
    def _prove_loop(self, MN, logMN, l, r, y, x_ip, hash_cache) -> tuple:
        """
        Phase 2: logarithmic inner-product argument. Folds aprime/bprime and
        the Gprime/Hprime generator vectors in half each round, emitting the
        round commitments L, R.
        :return: (L, R, a, b) with a, b the final single-element scalars
        :raises BulletProofGenException: when a round challenge is zero
        """
        nprime = MN
        aprime = l
        bprime = r

        Hprec = self._hprec_aux(MN)

        # Shared backing buffer: _tmp_bf_0 holds y^{-1}, powered lazily
        yinvpowL = KeyVPowers(MN, _invert(_tmp_bf_0, y), raw=True)
        yinvpowR = KeyVPowers(MN, _tmp_bf_0, raw=True)
        tmp_pt = crypto.Point()

        Gprime = self._gprec_aux(MN)
        HprimeL = KeyVEval(
            MN, lambda i, d: _scalarmult_key(d, Hprec.to(i), None, yinvpowL[i])
        )
        HprimeR = KeyVEval(
            MN, lambda i, d: _scalarmult_key(d, Hprec.to(i), None, yinvpowR[i], tmp_pt)
        )
        Hprime = HprimeL
        self.gc(20)

        L = _ensure_dst_keyvect(None, logMN)
        R = _ensure_dst_keyvect(None, logMN)
        cL = _ensure_dst_key()
        cR = _ensure_dst_key()
        winv = _ensure_dst_key()
        w_round = _ensure_dst_key()
        tmp = _ensure_dst_key()
        _tmp_k_1 = _ensure_dst_key()
        round = 0

        # PAPER LINE 13
        while nprime > 1:
            # PAPER LINE 15
            npr2 = nprime
            nprime >>= 1
            self.gc(22)

            # PAPER LINES 16-17
            # cL = \ap_{\left(\inta\right)} \cdot \bp_{\left(\intb\right)}
            # cR = \ap_{\left(\intb\right)} \cdot \bp_{\left(\inta\right)}
            _inner_product(
                aprime.slice_view(0, nprime), bprime.slice_view(nprime, npr2), cL
            )

            _inner_product(
                aprime.slice_view(nprime, npr2), bprime.slice_view(0, nprime), cR
            )
            self.gc(23)

            # PAPER LINES 18-19
            # Lc = 8^{-1} \left(\left( \sum_{i=0}^{\np} \ap_{i}\quad\Gp_{i+\np} + \bp_{i+\np}\Hp_{i} \right)
            #      + \left(c_L x_{ip}\right)H \right)
            _vector_exponent_custom(
                Gprime.slice_view(nprime, npr2),
                Hprime.slice_view(0, nprime),
                aprime.slice_view(0, nprime),
                bprime.slice_view(nprime, npr2),
                _tmp_bf_0,
            )

            # In round 0 backup the y^{prime - 1}
            if round == 0:
                yinvpowR.set_state(yinvpowL.last_idx, yinvpowL.cur)

            _sc_mul(tmp, cL, x_ip)
            _add_keys(_tmp_bf_0, _tmp_bf_0, _scalarmultH(_tmp_k_1, tmp))
            _scalarmult_key(_tmp_bf_0, _tmp_bf_0, _INV_EIGHT)
            L.read(round, _tmp_bf_0)
            self.gc(24)

            # Rc = 8^{-1} \left(\left( \sum_{i=0}^{\np} \ap_{i+\np}\Gp_{i}\quad + \bp_{i}\quad\Hp_{i+\np} \right)
            #      + \left(c_R x_{ip}\right)H \right)
            _vector_exponent_custom(
                Gprime.slice_view(0, nprime),
                Hprime.slice_view(nprime, npr2),
                aprime.slice_view(nprime, npr2),
                bprime.slice_view(0, nprime),
                _tmp_bf_0,
            )

            _sc_mul(tmp, cR, x_ip)
            _add_keys(_tmp_bf_0, _tmp_bf_0, _scalarmultH(_tmp_k_1, tmp))
            _scalarmult_key(_tmp_bf_0, _tmp_bf_0, _INV_EIGHT)
            R.read(round, _tmp_bf_0)
            self.gc(25)

            # PAPER LINES 21-22
            _hash_cache_mash(w_round, hash_cache, L.to(round), R.to(round))
            if w_round == _ZERO:
                raise BulletProofGenException()

            # PAPER LINES 24-25, fold {G~, H~}
            _invert(winv, w_round)
            self.gc(26)

            # PAPER LINES 28-29, fold {a, b} vectors
            # aprime's high part is used as a buffer for other operations
            _scalar_fold(aprime, w_round, winv)
            aprime.resize(nprime)
            self.gc(27)

            _scalar_fold(bprime, winv, w_round)
            bprime.resize(nprime)
            self.gc(28)

            # First fold produced to a new buffer, smaller one (G~ on-the-fly)
            Gprime_new = KeyV(nprime) if round == 0 else Gprime
            Gprime = _hadamard_fold(Gprime, winv, w_round, Gprime_new, 0)
            Gprime.resize(nprime)
            self.gc(30)

            # Hadamard fold for H is special - linear scan only.
            # Linear scan is slow, thus we have HprimeR.
            if round == 0:
                Hprime_new = KeyV(nprime)
                Hprime = _hadamard_fold(
                    Hprime, w_round, winv, Hprime_new, 0, HprimeR, nprime
                )
                # Hprime = _hadamard_fold_linear(Hprime, w_round, winv, Hprime_new, 0)

            else:
                _hadamard_fold(Hprime, w_round, winv)
                Hprime.resize(nprime)

            if round == 0:
                # del (Gprec, Hprec, yinvpowL, HprimeL)
                del (Hprec, yinvpowL, yinvpowR, HprimeL, HprimeR, tmp_pt)

            self.gc(31)
            round += 1

        return L, R, aprime.to(0), bprime.to(0)
|
|
|
|
    def verify(self, proof) -> bool:
        # Single-proof verification via the batch path
        return self.verify_batch([proof])
|
|
|
|
    def verify_batch(self, proofs: list[Bulletproof], single_optim: bool = True):
        """
        BP batch verification
        :param proofs:
        :param single_optim: single proof memory optimization
        :return: True on success
        :raises ValueError: when the final aggregate check fails
        """
        max_length = 0
        for proof in proofs:
            utils.ensure(_is_reduced(proof.taux), "Input scalar not in range")
            utils.ensure(_is_reduced(proof.mu), "Input scalar not in range")
            utils.ensure(_is_reduced(proof.a), "Input scalar not in range")
            utils.ensure(_is_reduced(proof.b), "Input scalar not in range")
            utils.ensure(_is_reduced(proof.t), "Input scalar not in range")
            utils.ensure(len(proof.V) >= 1, "V does not have at least one element")
            utils.ensure(len(proof.L) == len(proof.R), "|L| != |R|")
            utils.ensure(len(proof.L) > 0, "Empty proof")
            max_length = max(max_length, len(proof.L))

        utils.ensure(max_length < 32, "At least one proof is too large")

        maxMN = 1 << max_length
        logN = 6
        N = 1 << logN
        tmp = _ensure_dst_key()

        # setup weighted aggregates
        is_single = len(proofs) == 1 and single_optim  # ph4
        z1 = _init_key(_ZERO)
        z3 = _init_key(_ZERO)
        m_z4 = _vector_dup(_ZERO, maxMN) if not is_single else None
        m_z5 = _vector_dup(_ZERO, maxMN) if not is_single else None
        m_y0 = _init_key(_ZERO)
        y1 = _init_key(_ZERO)
        muex_acc = _init_key(_ONE)

        Gprec = self._gprec_aux(maxMN)
        Hprec = self._hprec_aux(maxMN)

        for proof in proofs:
            M = 1
            logM = 0
            while M <= _BP_M and M < len(proof.V):
                logM += 1
                M = 1 << logM

            utils.ensure(len(proof.L) == 6 + logM, "Proof is not the expected size")
            MN = M * N
            # Random per-proof weights so all proofs fold into one check
            weight_y = crypto_helpers.encodeint(crypto.random_scalar())
            weight_z = crypto_helpers.encodeint(crypto.random_scalar())

            # Reconstruct the challenges
            hash_cache = _hash_vct_to_scalar(None, proof.V)
            y = _hash_cache_mash(None, hash_cache, proof.A, proof.S)
            utils.ensure(y != _ZERO, "y == 0")
            z = _hash_to_scalar(None, y)
            _copy_key(hash_cache, z)
            utils.ensure(z != _ZERO, "z == 0")

            x = _hash_cache_mash(None, hash_cache, z, proof.T1, proof.T2)
            utils.ensure(x != _ZERO, "x == 0")
            x_ip = _hash_cache_mash(None, hash_cache, x, proof.taux, proof.mu, proof.t)
            utils.ensure(x_ip != _ZERO, "x_ip == 0")

            # PAPER LINE 61
            _sc_mulsub(m_y0, proof.taux, weight_y, m_y0)
            zpow = _vector_powers(z, M + 3)

            k = _ensure_dst_key()
            ip1y = _vector_power_sum(y, MN)
            _sc_mulsub(k, zpow.to(2), ip1y, _ZERO)
            for j in range(1, M + 1):
                utils.ensure(j + 2 < len(zpow), "invalid zpow index")
                _sc_mulsub(k, zpow.to(j + 2), _BP_IP12, k)

            # VERIFY_line_61rl_new
            _sc_muladd(tmp, z, ip1y, k)
            _sc_sub(tmp, proof.t, tmp)

            _sc_muladd(y1, tmp, weight_y, y1)
            weight_y8 = _init_key(weight_y)
            _sc_mul(weight_y8, weight_y, _EIGHT)

            muex = MultiExpSequential(points=[pt for pt in proof.V])
            for j in range(len(proof.V)):
                _sc_mul(tmp, zpow.to(j + 2), weight_y8)
                muex.add_scalar(_init_key(tmp))

            _sc_mul(tmp, x, weight_y8)
            muex.add_pair(_init_key(tmp), proof.T1)

            xsq = _ensure_dst_key()
            _sc_mul(xsq, x, x)

            _sc_mul(tmp, xsq, weight_y8)
            muex.add_pair(_init_key(tmp), proof.T2)

            weight_z8 = _init_key(weight_z)
            _sc_mul(weight_z8, weight_z, _EIGHT)

            muex.add_pair(weight_z8, proof.A)
            _sc_mul(tmp, x, weight_z8)
            muex.add_pair(_init_key(tmp), proof.S)

            _multiexp(tmp, muex)
            _add_keys(muex_acc, muex_acc, tmp)
            del muex

            # Compute the number of rounds for the inner product
            rounds = logM + logN
            utils.ensure(rounds > 0, "Zero rounds")

            # PAPER LINES 21-22
            # The inner product challenges are computed per round
            w = _ensure_dst_keyvect(None, rounds)
            for i in range(rounds):
                _hash_cache_mash(_tmp_bf_0, hash_cache, proof.L[i], proof.R[i])
                w.read(i, _tmp_bf_0)
                utils.ensure(w.to(i) != _ZERO, "w[i] == 0")

            # Basically PAPER LINES 24-25
            # Compute the curvepoints from G[i] and H[i]
            yinvpow = _init_key(_ONE)
            ypow = _init_key(_ONE)
            yinv = _invert(None, y)
            self.gc(61)

            winv = _ensure_dst_keyvect(None, rounds)
            for i in range(rounds):
                _invert(_tmp_bf_0, w.to(i))
                winv.read(i, _tmp_bf_0)
                self.gc(62)

            g_scalar = _ensure_dst_key()
            h_scalar = _ensure_dst_key()
            twoN = self._two_aux(N)
            for i in range(MN):
                _copy_key(g_scalar, proof.a)
                _sc_mul(h_scalar, proof.b, yinvpow)

                for j in range(rounds - 1, -1, -1):
                    J = len(w) - j - 1

                    if (i & (1 << j)) == 0:
                        _sc_mul(g_scalar, g_scalar, winv.to(J))
                        _sc_mul(h_scalar, h_scalar, w.to(J))
                    else:
                        _sc_mul(g_scalar, g_scalar, w.to(J))
                        _sc_mul(h_scalar, h_scalar, winv.to(J))

                # Adjust the scalars using the exponents from PAPER LINE 62
                _sc_add(g_scalar, g_scalar, z)
                utils.ensure(2 + i // N < len(zpow), "invalid zpow index")
                utils.ensure(i % N < len(twoN), "invalid twoN index")
                _sc_mul(tmp, zpow.to(2 + i // N), twoN.to(i % N))
                _sc_muladd(tmp, z, ypow, tmp)
                _sc_mulsub(h_scalar, tmp, yinvpow, h_scalar)

                if not is_single:  # ph4
                    assert m_z4 is not None and m_z5 is not None
                    m_z4.read(i, _sc_mulsub(_tmp_bf_0, g_scalar, weight_z, m_z4[i]))
                    m_z5.read(i, _sc_mulsub(_tmp_bf_0, h_scalar, weight_z, m_z5[i]))
                else:
                    _sc_mul(tmp, g_scalar, weight_z)
                    _sub_keys(
                        muex_acc, muex_acc, _scalarmult_key(tmp, Gprec.to(i), tmp)
                    )

                    _sc_mul(tmp, h_scalar, weight_z)
                    _sub_keys(
                        muex_acc, muex_acc, _scalarmult_key(tmp, Hprec.to(i), tmp)
                    )

                if i != MN - 1:
                    _sc_mul(yinvpow, yinvpow, yinv)
                    _sc_mul(ypow, ypow, y)
                if i & 15 == 0:
                    self.gc(62)

            del (g_scalar, h_scalar, twoN)
            self.gc(63)

            _sc_muladd(z1, proof.mu, weight_z, z1)
            muex = MultiExpSequential(
                point_fnc=lambda i, d: proof.L[i // 2]
                if i & 1 == 0
                else proof.R[i // 2]
            )
            for i in range(rounds):
                _sc_mul(tmp, w.to(i), w.to(i))
                _sc_mul(tmp, tmp, weight_z8)
                muex.add_scalar(tmp)
                _sc_mul(tmp, winv.to(i), winv.to(i))
                _sc_mul(tmp, tmp, weight_z8)
                muex.add_scalar(tmp)

            acc = _multiexp(None, muex)
            _add_keys(muex_acc, muex_acc, acc)

            _sc_mulsub(tmp, proof.a, proof.b, proof.t)
            _sc_mul(tmp, tmp, x_ip)
            _sc_muladd(z3, tmp, weight_z, z3)

        _sc_sub(tmp, m_y0, z1)
        z3p = _sc_sub(None, z3, y1)

        # check2 = z3p * H + tmp * G
        check2 = crypto_helpers.encodepoint(
            crypto.ge25519_double_scalarmult_vartime_into(
                None,
                crypto.xmr_H(),
                crypto_helpers.decodeint(z3p),
                crypto_helpers.decodeint(tmp),
            )
        )
        _add_keys(muex_acc, muex_acc, check2)

        if not is_single:  # ph4
            assert m_z4 is not None and m_z5 is not None
            muex = MultiExpSequential(
                point_fnc=lambda i, d: Gprec.to(i // 2)
                if i & 1 == 0
                else Hprec.to(i // 2)
            )
            for i in range(maxMN):
                muex.add_scalar(m_z4[i])
                muex.add_scalar(m_z5[i])
            _add_keys(muex_acc, muex_acc, _multiexp(None, muex))

        # Aggregate must collapse to the identity (encoded as _ONE)
        if muex_acc != _ONE:
            raise ValueError("Verification failure at step 2")
        return True
|
|
|
|
|
|
def _compute_LR(
    size: int,
    y: bytearray,
    G: KeyVBase,
    G0: int,
    H: KeyVBase,
    H0: int,
    a: KeyVBase,
    a0: int,
    b: KeyVBase,
    b0: int,
    c: bytearray,
    d: bytearray,
    tmp: bytearray = _tmp_bf_0,
) -> bytearray:
    """
    LR computation for BP+.

    Computes one L (or R) round element of the inner-product argument as a
    single multi-scalar multiplication:

    returns:
        c * 8^{-1} * H +
        d * 8^{-1} * G +
        \\sum_i a_{a0 + i} * 8^{-1} * y * G_{G0+i} +
                b_{b0 + i} * 8^{-1} * H_{H0+i}

    G0/H0/a0/b0 are start offsets into the respective vectors; `size` pairs
    are consumed.  The 8^{-1} factor pre-clears the cofactor, as everywhere
    else in this module.  `tmp` is a caller-provided 32-byte scratch buffer
    (defaults to the module-wide shared scratch `_tmp_bf_0`); the result is
    written into it and returned.
    """
    muex = MultiExpSequential()
    for i in range(size):
        # a-term: a_{a0+i} * y * 8^{-1} * G_{G0+i}
        _sc_mul(tmp, a.to(a0 + i), y)
        _sc_mul(tmp, tmp, _INV_EIGHT)
        muex.add_pair(tmp, G.to(G0 + i))

        # b-term: b_{b0+i} * 8^{-1} * H_{H0+i}
        _sc_mul(tmp, b.to(b0 + i), _INV_EIGHT)
        muex.add_pair(tmp, H.to(H0 + i))

    # blinding terms on the fixed generators
    muex.add_pair(_sc_mul(tmp, c, _INV_EIGHT), _XMR_H)
    muex.add_pair(_sc_mul(tmp, d, _INV_EIGHT), _XMR_G)
    return _multiexp(tmp, muex)
|
|
|
|
|
|
class BulletProofPlusData:
    """Per-proof scratch state used by the BP+ batch verifier.

    All attributes start as ``None`` and are filled in while the verifier
    reconstructs the Fiat-Shamir transcript of a single proof.
    """

    def __init__(self):
        # Transcript challenges: y, z and the final challenge e.
        self.y = self.z = self.e = None
        # Per-round inner-product challenges.
        self.challenges = None
        # log2 of the aggregation factor M for this proof.
        self.logM = None
        # Start index of this proof's inverses in the batched inversion vector.
        self.inv_offset = None
|
|
|
|
|
|
class BulletProofPlusBuilder:
|
|
"""
|
|
Bulletproof+
|
|
https://eprint.iacr.org/2020/735.pdf
|
|
https://github.com/monero-project/monero/blob/67e5ca9ad6f1c861ad315476a88f9d36c38a0abb/src/ringct/bulletproofs_plus.cc
|
|
"""
|
|
|
|
def __init__(self, save_mem=True) -> None:
    """
    Bulletproof+ builder.

    :param save_mem: when True, the first inner-product folding round uses
        lazy folded-vector views instead of materializing new vectors,
        trading CPU for peak memory.
    """
    self.save_mem = save_mem

    # Precomputed Gi generator table (const, backed by a flash buffer).
    # BP_GI_PRE = _get_exponent_plus(Gi[i], _XMR_H, i * 2 + 1)
    self.Gprec = KeyV(buffer=crypto.BP_PLUS_GI_PRE, const=True)

    # Precomputed Hi generator table (const, backed by a flash buffer).
    # BP_HI_PRE = None #_get_exponent_plus(Hi[i], _XMR_H, i * 2)
    self.Hprec = KeyV(buffer=crypto.BP_PLUS_HI_PRE, const=True)

    # aL, aR amount bitmasks, can be freed once not needed
    self.aL = None
    self.aR = None

    # GC hooks: gc_fnc is invoked by self.gc(); gc_trace is an optional
    # tracing callback receiving the gc() call-site markers.
    self.gc_fnc = gc.collect
    self.gc_trace = None
|
|
|
def gc(self, *args) -> None:
    """Invoke the optional GC trace hook, then the optional collector.

    ``*args`` are opaque call-site markers forwarded to the trace hook.
    """
    trace = self.gc_trace
    if trace:
        trace(*args)
    collect = self.gc_fnc
    if collect:
        collect()
|
|
|
|
def aX_vcts(self, sv, MN) -> tuple:
    """
    Build the lazy aL/aR bit-decomposition vectors for the committed amounts.

    :param sv: list of 32-byte little-endian encoded amounts.
    :param MN: total vector length (aggregated proofs * 64 bits).
    :return: (aL, aR) as raw-scalar KeyVEval views of length MN, where
        aL[i] is the i-th amount bit (0/1) and aR[i] = aL[i] - 1 (0/-1).

    NOTE: the returned views hand out the three shared Scalar singletons
    below (or copy into the caller-provided destination `d`); callers must
    not mutate a returned raw scalar in place.
    """
    num_inp = len(sv)
    # Shared constant scalars; allocated once, reused for every element.
    sc_zero = crypto.decodeint_into_noreduce(None, _ZERO)
    sc_one = crypto.decodeint_into_noreduce(None, _ONE)
    sc_mone = crypto.decodeint_into_noreduce(None, _MINUS_ONE)

    def e_xL(idx, d=None, is_a=True):
        # j = amount index, i = bit index inside that amount
        j, i = idx // _BP_N, idx % _BP_N
        r = None
        # Bit test on the little-endian byte encoding of amount j;
        # indices past the last real amount behave as zero amounts.
        if j < num_inp and sv[j][i // 8] & (1 << i % 8):
            r = sc_one if is_a else sc_zero
        else:
            r = sc_zero if is_a else sc_mone
        if d:
            return crypto.sc_copy(d, r)
        return r

    aL = KeyVEval(MN, lambda i, d: e_xL(i, d, True), raw=True)
    aR = KeyVEval(MN, lambda i, d: e_xL(i, d, False), raw=True)
    return aL, aR
|
|
|
|
def _gprec_aux(self, size: int) -> KeyVPrecomp:
    """Return a ``size``-element view over the Gi generators, extending the
    precomputed table on demand (odd domain-separation indices)."""

    def derive_gi(i, d):
        # Generators past the precomputed table are derived on the fly.
        return _get_exponent_plus(d, _XMR_H, i * 2 + 1)

    return KeyVPrecomp(size, self.Gprec, derive_gi)
|
|
|
|
def _hprec_aux(self, size: int) -> KeyVPrecomp:
    """Return a ``size``-element view over the Hi generators, extending the
    precomputed table on demand (even domain-separation indices)."""

    def derive_hi(i, d):
        # Generators past the precomputed table are derived on the fly.
        return _get_exponent_plus(d, _XMR_H, i * 2)

    return KeyVPrecomp(size, self.Hprec, derive_hi)
|
|
|
|
def vector_exponent(self, a, b, dst=None, a_raw=None, b_raw=None):
    """Compute sum_i a_i * Gprec_i + b_i * Hprec_i over the builder's
    precomputed generator tables (delegates to the module helper)."""
    return _vector_exponent_custom(
        self.Gprec, self.Hprec, a, b, dst, a_raw, b_raw
    )
|
|
|
|
def prove(
    self, sv: list[crypto.Scalar], gamma: list[crypto.Scalar]
) -> BulletproofPlus:
    """Prove a single (amount, blinding) pair via the batch code path."""
    return self.prove_batch([sv], [gamma])
|
|
|
|
def prove_setup(self, sv: list[crypto.Scalar], gamma: list[crypto.Scalar]) -> tuple:
    """
    Common prover setup: encode inputs, build commitments and bit vectors.

    :param sv: amounts as Scalars.
    :param gamma: blinding factors as Scalars, same length as sv.
    :return: (M, logM, V, gamma) where M = padded aggregation factor
        (power of two), V = vector of 8^{-1}-scaled Pedersen commitments,
        and gamma re-encoded as 32-byte strings.
    Side effect: populates self.aL / self.aR via prove_setup_aLaR.
    """
    utils.ensure(len(sv) == len(gamma), "|sv| != |gamma|")
    utils.ensure(len(sv) > 0, "sv empty")

    gc.collect()
    # Work on byte encodings from here on.
    sv = [crypto_helpers.encodeint(x) for x in sv]
    gamma = [crypto_helpers.encodeint(x) for x in gamma]

    # Smallest power of two M >= len(sv), capped at _BP_M.
    M, logM = 1, 0
    while M <= _BP_M and M < len(sv):
        logM += 1
        M = 1 << logM
    MN = M * _BP_N

    # V_i = 8^{-1} * (gamma_i * G + sv_i * H)
    V = _ensure_dst_keyvect(None, len(sv))
    for i in range(len(sv)):
        _add_keys2(_tmp_bf_0, gamma[i], sv[i], _XMR_H)
        _scalarmult_key(_tmp_bf_0, _tmp_bf_0, _INV_EIGHT)
        V.read(i, _tmp_bf_0)

    # Build lazy aL/aR bit vectors from the already-encoded amounts.
    self.prove_setup_aLaR(MN, None, sv)
    return M, logM, V, gamma
|
|
|
|
def prove_setup_aLaR(self, MN, sv, sv_vct=None):
    """(Re)build the aL/aR bit-decomposition views of length MN.

    Either pass already-encoded amounts via ``sv_vct`` or raw Scalars via
    ``sv`` (encoded here on demand).
    """
    if not sv_vct:
        sv_vct = [crypto_helpers.encodeint(x) for x in sv]
    self.aL, self.aR = self.aX_vcts(sv_vct, MN)
|
|
|
|
def prove_batch(
    self, sv: list[crypto.Scalar], gamma: list[crypto.Scalar]
) -> BulletproofPlus:
    """
    Produce an aggregated BP+ range proof for the given amounts/blindings.

    Retries the whole proof if a zero Fiat-Shamir challenge is ever drawn
    (BulletProofGenException) — a fresh run re-randomizes the transcript.
    """
    M, logM, V, gamma = self.prove_setup(sv, gamma)
    hash_cache = _ensure_dst_key()
    while True:
        self.gc(10)
        try:
            return self._prove_batch_main(
                V, gamma, hash_cache, logM, _BP_LOG_N, M, _BP_N
            )
        except BulletProofGenException:
            # aL/aR views were consumed/freed by the failed attempt —
            # rebuild them before retrying.
            self.prove_setup_aLaR(M * _BP_N, sv)
            continue
|
|
|
|
def _prove_batch_main(
    self,
    V: KeyVBase,
    gamma: list[crypto.Scalar],
    hash_cache: bytearray,
    logM: int,
    logN: int,
    M: int,
    N: int,
) -> BulletproofPlus:
    """
    BP+ prover main body: transcript setup, vector construction and the
    logMN-round inner-product folding loop.

    :param V: 8^{-1}-scaled output commitments.
    :param gamma: encoded blinding factors (32-byte strings despite the
        annotation; prove_setup encodes them — TODO confirm annotation).
    :param hash_cache: transcript state buffer, overwritten here.
    :raises BulletProofGenException: if any derived challenge is zero;
        caller retries with fresh randomness.

    NOTE: statement order is significant throughout — module-level scratch
    buffers (_tmp_bf_0/1, _tmp_sc_1/4/5) are shared and reused, and several
    vectors are lazy views over state mutated later.
    """
    _hash_vct_to_scalar(hash_cache, V)

    MN = M * N
    logMN = logM + logN

    tmp = _ensure_dst_key()
    tmp2 = _ensure_dst_key()
    # Fresh transcript: fixed domain separator, then commitment hash.
    memcpy(hash_cache, 0, _INITIAL_TRANSCRIPT, 0, len(_INITIAL_TRANSCRIPT))
    _hash_cache_mash(hash_cache, hash_cache, _hash_vct_to_scalar(tmp, V))

    # compute A = 8^{-1} ( \alpha G + \sum_{i=0}^{MN-1} a_{L,i} \Gi_i + a_{R,i} \Hi_i)
    aL = self.aL
    aR = self.aR
    inv_8_sc = crypto.decodeint_into_noreduce(None, _INV_EIGHT)
    # Lazy 8^{-1}-scaled views over the bit vectors.
    aL8 = KeyVEval(
        len(aL),
        lambda i, d: crypto.sc_mul_into(d, aL[i], inv_8_sc),  # noqa: F821
        raw=True,
    )
    aR8 = KeyVEval(
        len(aL),
        lambda i, d: crypto.sc_mul_into(d, aR[i], inv_8_sc),  # noqa: F821
        raw=True,
    )
    alpha = _sc_gen()

    A = _ensure_dst_key()
    Gprec = self._gprec_aux(MN)  # Extended precomputed GiHi
    Hprec = self._hprec_aux(MN)
    _vector_exponent_custom(
        Gprec, Hprec, a=None, b=None, a_raw=aL8, b_raw=aR8, dst=A
    )
    _sc_mul(tmp, alpha, _INV_EIGHT)
    _add_keys(A, A, _scalarmult_base(_tmp_bf_1, tmp))
    del (aL8, aR8, inv_8_sc)
    self.gc(11)

    # Challenges
    y = _hash_cache_mash(None, hash_cache, A)
    if y == _ZERO:
        raise BulletProofGenException()
    z = _hash_to_scalar(None, y)
    if z == _ZERO:
        raise BulletProofGenException()
    _copy_key(hash_cache, z)
    self.gc(12)

    zc = crypto.decodeint_into_noreduce(None, z)
    z_squared = crypto.encodeint_into(None, crypto.sc_mul_into(_tmp_sc_1, zc, zc))
    # d[j*N+i] = z^(2(j+1)) * 2^i, computed lazily.
    d_vct = VctD(N, M, z_squared, raw=True)
    del (z,)

    # aL1 = aL - z
    aL1_sc = crypto.Scalar()

    def aL1_fnc(i, d):
        return crypto.encodeint_into(d, crypto.sc_sub_into(aL1_sc, aL.to(i), zc))

    aprime = KeyVEval(MN, aL1_fnc, raw=False)  # aL1

    # aR1[i] = (aR[i] - z) + d[i] * y**(MN-i)
    y_sc = crypto.decodeint_into_noreduce(None, y)
    yinv = crypto.sc_inv_into(None, y_sc)
    _sc_square_mult(_tmp_sc_5, y_sc, MN - 1)  # y**(MN-1)
    crypto.sc_mul_into(_tmp_sc_5, _tmp_sc_5, y_sc)  # y**MN

    # Backwards power iterator: ypow_back[k] = y**k, walked from y**MN down.
    ypow_back = KeyVPowersBackwards(
        MN + 1, y, x_inv=yinv, x_max=_tmp_sc_5, raw=True
    )
    aR1_sc1 = crypto.Scalar()

    def aR1_fnc(i, d):
        crypto.sc_add_into(aR1_sc1, aR.to(i), zc)
        crypto.sc_muladd_into(aR1_sc1, d_vct[i], ypow_back[MN - i], aR1_sc1)
        return crypto.encodeint_into(d, aR1_sc1)

    bprime = KeyVEval(MN, aR1_fnc, raw=False)  # aR1

    self.gc(13)
    # alpha1 = alpha + sum_j z^(2(j+1)) * y^(MN+1) * gamma_j
    _copy_key(tmp, _ONE)
    alpha1 = _copy_key(None, alpha)
    crypto.sc_mul_into(_tmp_sc_4, ypow_back.x_max, y_sc)
    crypto.encodeint_into(_tmp_bf_0, _tmp_sc_4)  # compute y**(MN+1)
    for j in range(len(V)):
        _sc_mul(tmp, tmp, z_squared)
        _sc_mul(tmp2, _tmp_bf_0, tmp)
        _sc_muladd(alpha1, tmp2, gamma[j], alpha1)

    # y, y**-1 powers
    ypow = _sc_square_mult(None, y_sc, MN >> 1)
    yinvpow = _sc_square_mult(None, yinv, MN >> 1)
    del (z_squared, alpha)

    # Proof loop phase
    challenge = _ensure_dst_key()
    challenge_inv = _ensure_dst_key()
    rnd = 0
    nprime = MN
    Gprime = Gprec
    Hprime = Hprec
    L = _ensure_dst_keyvect(None, logMN)
    R = _ensure_dst_keyvect(None, logMN)
    tmp_sc_1 = crypto.Scalar()
    del (logMN,)
    if not self.save_mem:
        del (Gprec, Hprec)

    dL = _ensure_dst_key()
    dR = _ensure_dst_key()
    cL = _ensure_dst_key()
    cR = _ensure_dst_key()
    # Halve the problem size each round until a single pair remains.
    while nprime > 1:
        npr2 = nprime
        nprime >>= 1
        self.gc(22)

        # Compute cL, cR
        # cL = \\sum_i y^{i+1} * aprime_i * bprime_{i + nprime}
        # cL = \\sum_i y^{i+1} * aprime_{i + nprime} * y^{nprime} * bprime_{i}
        _weighted_inner_product(
            cL, aprime.slice_view(0, nprime), bprime.slice_view(nprime, npr2), y
        )

        def vec_sc_fnc(i, d):
            # aprime_{i+nprime} * y^{nprime}, as a lazy vector element
            crypto.decodeint_into_noreduce(tmp_sc_1, aprime.to(i + nprime))
            crypto.sc_mul_into(tmp_sc_1, tmp_sc_1, ypow)
            crypto.encodeint_into(d, tmp_sc_1)

        vec_aprime_x_ypownprime = KeyVEval(nprime, vec_sc_fnc)
        _weighted_inner_product(
            cR, vec_aprime_x_ypownprime, bprime.slice_view(0, nprime), y
        )
        del (vec_aprime_x_ypownprime,)
        self.gc(25)

        _sc_gen(dL)
        _sc_gen(dR)

        # Compute L[r], R[r]
        # L[r] = cL * 8^{-1} * H + dL * 8^{-1} * G +
        #        \\sum_i aprime_{i} * 8^{-1} * y^{-nprime} * Gprime_{nprime + i} +
        #               bprime_{nprime + i} * 8^{-1} * Hprime_{i}
        #
        # R[r] = cR * 8^{-1} * H + dR * 8^{-1} * G +
        #        \\sum_i aprime_{nprime + i} * 8^{-1} * y^{nprime} * Gprime_{i} +
        #               bprime_{i} * 8^{-1} * Hprime_{nprime + i}
        _compute_LR(
            size=nprime,
            y=yinvpow,
            G=Gprime,
            G0=nprime,
            H=Hprime,
            H0=0,
            a=aprime,
            a0=0,
            b=bprime,
            b0=nprime,
            c=cL,
            d=dL,
            tmp=tmp,
        )
        L.read(rnd, tmp)

        _compute_LR(
            size=nprime,
            y=ypow,
            G=Gprime,
            G0=0,
            H=Hprime,
            H0=nprime,
            a=aprime,
            a0=nprime,
            b=bprime,
            b0=0,
            c=cR,
            d=dR,
            tmp=tmp,
        )
        R.read(rnd, tmp)
        self.gc(26)

        # Round challenge from the transcript.
        _hash_cache_mash(challenge, hash_cache, L[rnd], R[rnd])
        if challenge == _ZERO:
            raise BulletProofGenException()

        _invert(challenge_inv, challenge)
        # tmp = y^{-nprime} * challenge (Gprime fold coefficient b)
        _sc_mul(tmp, crypto.encodeint_into(_tmp_bf_0, yinvpow), challenge)
        self.gc(27)

        # Hadamard fold Gprime, Hprime
        # When memory saving is enabled, Gprime and Hprime vectors are folded in-memory for round=1
        # Depth 2 in-memory folding would be also possible if needed: np2 = nprime // 2
        # Gprime_new[i] = c * (a * Gprime[i] + b * Gprime[i+nprime]) +
        #                 d * (a * Gprime[np2 + i] + b * Gprime[i+nprime + np2])
        Gprime_new = Gprime
        if self.save_mem and rnd == 0:
            # Round 0 with save_mem: lazy folded view, no allocation.
            Gprime = KeyHadamardFoldedVct(
                Gprime, a=challenge_inv, b=tmp, gc_fnc=_gc_iter
            )
        elif (self.save_mem and rnd == 1) or (not self.save_mem and rnd == 0):
            # First materializing round: allocate the destination vector.
            Gprime_new = KeyV(nprime)

        if not self.save_mem or rnd != 0:
            Gprime = _hadamard_fold(Gprime, challenge_inv, tmp, into=Gprime_new)
            Gprime.resize(nprime)
        del (Gprime_new,)
        self.gc(30)

        Hprime_new = Hprime
        if self.save_mem and rnd == 0:
            Hprime = KeyHadamardFoldedVct(
                Hprime, a=challenge, b=challenge_inv, gc_fnc=_gc_iter
            )
        elif (self.save_mem and rnd == 1) or (not self.save_mem and rnd == 0):
            Hprime_new = KeyV(nprime)

        if not self.save_mem or rnd != 0:
            Hprime = _hadamard_fold(
                Hprime, challenge, challenge_inv, into=Hprime_new
            )
            Hprime.resize(nprime)
        del (Hprime_new,)
        self.gc(30)

        # Scalar fold aprime, bprime
        # aprime[i] = challenge * aprime[i] + tmp * aprime[i + nprime]
        # bprime[i] = challenge_inv * bprime[i] + challenge * bprime[i + nprime]
        # When memory saving is enabled, aprime vector is folded in-memory for round=1
        _sc_mul(tmp, challenge_inv, ypow)

        aprime_new = aprime
        if self.save_mem and rnd == 0:
            aprime = KeyScalarFoldedVct(aprime, a=challenge, b=tmp, gc_fnc=_gc_iter)
        elif (self.save_mem and rnd == 1) or (not self.save_mem and rnd == 0):
            aprime_new = KeyV(nprime)

        if not self.save_mem or rnd != 0:
            for i in range(nprime):
                _sc_mul(tmp2, aprime.to(i), challenge)
                aprime_new.read(
                    i, _sc_muladd(_tmp_bf_0, aprime.to(i + nprime), tmp, tmp2)
                )

            aprime = aprime_new
            aprime.resize(nprime)

        if (self.save_mem and rnd == 1) or (not self.save_mem and rnd == 0):
            pass
            # self.aL = None
            # del (aL1_fnc, aL1_sc, aL)
        self.gc(31)

        bprime_new = KeyV(nprime) if rnd == 0 else bprime
        if rnd == 0:
            # Two-phase folding for bprime, so it can be linearly scanned (faster) for r=0 (eval vector)
            for i in range(nprime):
                bprime_new.read(i, _sc_mul(tmp, bprime[i], challenge_inv))
            for i in range(nprime):
                _sc_muladd(tmp, bprime[i + nprime], challenge, bprime_new[i])
                bprime_new.read(i, tmp)

            # bprime is materialized now; free the eval-vector closures.
            self.aR = None
            del (aR1_fnc, aR1_sc1, aR, d_vct, ypow_back)
            self.gc(31)

        else:
            for i in range(nprime):
                _sc_mul(tmp2, bprime.to(i), challenge_inv)
                bprime_new.read(
                    i, _sc_muladd(_tmp_bf_0, bprime.to(i + nprime), challenge, tmp2)
                )

        bprime = bprime_new
        bprime.resize(nprime)
        self.gc(32)

        # Accumulate blinding: alpha1 += dL*c^2 + dR*c^{-2}
        _sc_muladd(alpha1, dL, _sc_mul(tmp, challenge, challenge), alpha1)
        _sc_muladd(alpha1, dR, _sc_mul(tmp, challenge_inv, challenge_inv), alpha1)

        # end: update ypow, yinvpow; reduce by halves
        nnprime = nprime >> 1
        if nnprime > 0:
            crypto.sc_mul_into(
                ypow, ypow, _sc_square_mult(_tmp_sc_1, yinv, nnprime)
            )
            crypto.sc_mul_into(
                yinvpow, yinvpow, _sc_square_mult(_tmp_sc_1, y_sc, nnprime)
            )

        self.gc(49)
        rnd += 1

    # Final round computations
    del (cL, cR, dL, dR)
    self.gc(50)

    r = _sc_gen()
    s = _sc_gen()
    d_ = _sc_gen()
    eta = _sc_gen()

    # A1 = 8^{-1}(r*Gprime_0 + s*Hprime_0 + d_*G + (r*y*bprime_0 + s*y*aprime_0)*H)
    muex = MultiExpSequential()
    muex.add_pair(_sc_mul(tmp, r, _INV_EIGHT), Gprime.to(0))
    muex.add_pair(_sc_mul(tmp, s, _INV_EIGHT), Hprime.to(0))
    muex.add_pair(_sc_mul(tmp, d_, _INV_EIGHT), _XMR_G)

    _sc_mul(tmp, r, y)
    _sc_mul(tmp, tmp, bprime[0])
    _sc_mul(tmp2, s, y)
    _sc_mul(tmp2, tmp2, aprime[0])
    _sc_add(tmp, tmp, tmp2)
    muex.add_pair(_sc_mul(tmp2, tmp, _INV_EIGHT), _XMR_H)
    A1 = _multiexp(None, muex)

    # B = 8^{-1}(eta*G + r*y*s*H)
    _sc_mul(tmp, r, y)
    _sc_mul(tmp, tmp, s)
    _sc_mul(tmp, tmp, _INV_EIGHT)
    _sc_mul(tmp2, eta, _INV_EIGHT)
    B = _add_keys2(None, tmp2, tmp, _XMR_H)

    # Final challenge and response scalars.
    e = _hash_cache_mash(None, hash_cache, A1, B)
    if e == _ZERO:
        raise BulletProofGenException()

    e_squared = _sc_mul(None, e, e)
    r1 = _sc_muladd(None, aprime[0], e, r)
    s1 = _sc_muladd(None, bprime[0], e, s)
    d1 = _sc_muladd(None, d_, e, eta)
    _sc_muladd(d1, alpha1, e_squared, d1)

    from .serialize_messages.tx_rsig_bulletproof import BulletproofPlus

    return BulletproofPlus(V=V, A=A, A1=A1, B=B, r1=r1, s1=s1, d1=d1, L=L, R=R)
|
|
|
|
def verify(self, proof: BulletproofPlus) -> bool:
    """Verify a single BP+ proof as a one-element batch."""
    return self.verify_batch([proof])
|
|
|
|
def verify_batch(self, proofs: list[BulletproofPlus]):
    """
    BP+ batch verification.

    Reconstructs each proof's Fiat-Shamir transcript, weights every proof's
    verification equation by a fresh random scalar, and checks the single
    combined multiscalar multiplication against the identity.

    :raises ValueError: if the combined check fails.
    :return: True on success.

    Fix vs. previous revision: the y-power loop counter used float
    true-division (``temp_MN /= 2``); it only worked because MN is a small
    power of two exactly representable as a float.  Replaced with an integer
    shift (``temp_MN >>= 1``), matching the integer semantics of the
    reference C++ implementation and avoiding float allocation on embedded
    targets.
    """
    # --- Structural sanity checks on every proof -------------------------
    max_length = 0
    for proof in proofs:
        utils.ensure(_is_reduced(proof.r1), "Input scalar not in range")
        utils.ensure(_is_reduced(proof.s1), "Input scalar not in range")
        utils.ensure(_is_reduced(proof.d1), "Input scalar not in range")
        utils.ensure(len(proof.V) >= 1, "V does not have at least one element")
        utils.ensure(len(proof.L) == len(proof.R), "|L| != |R|")
        utils.ensure(len(proof.L) > 0, "Empty proof")
        max_length = max(max_length, len(proof.L))

    utils.ensure(max_length < 32, "At least one proof is too large")
    self.gc(1)

    logN = 6
    N = 1 << logN
    tmp = _ensure_dst_key()

    max_length = 0  # size of each of the longest proof's inner-product vectors
    nV = 0  # number of output commitments across all proofs
    inv_offset = 0
    max_logm = 0

    # --- Phase 1: transcript reconstruction + batched inversion setup ----
    proof_data = []
    to_invert_offset = 0
    to_invert = _ensure_dst_keyvect(None, 11 * len(proofs))
    for proof in proofs:
        max_length = max(max_length, len(proof.L))
        nV += len(proof.V)
        pd = BulletProofPlusData()
        proof_data.append(pd)

        # Reconstruct the challenges
        transcript = bytearray(_INITIAL_TRANSCRIPT)
        _hash_cache_mash(transcript, transcript, _hash_vct_to_scalar(tmp, proof.V))

        pd.y = _hash_cache_mash(None, transcript, proof.A)
        utils.ensure(not (pd.y == _ZERO), "y == 0")
        pd.z = _hash_to_scalar(None, pd.y)
        _copy_key(transcript, pd.z)

        # Determine the number of inner-product rounds based on proof size
        pd.logM = 0
        while True:
            M = 1 << pd.logM
            if M > _BP_M or M >= len(proof.V):
                break
            pd.logM += 1

        max_logm = max(max_logm, pd.logM)
        rounds = pd.logM + logN
        utils.ensure(rounds > 0, "zero rounds")

        # The inner-product challenges are computed per round
        pd.challenges = _ensure_dst_keyvect(None, rounds)
        for j in range(rounds):
            pd.challenges[j] = _hash_cache_mash(
                pd.challenges[j], transcript, proof.L[j], proof.R[j]
            )
            utils.ensure(pd.challenges[j] != _ZERO, "challenges[j] == 0")

        # Final challenge
        pd.e = _hash_cache_mash(None, transcript, proof.A1, proof.B)
        utils.ensure(pd.e != _ZERO, "e == 0")

        # Collect scalars for a single batched inversion.
        pd.inv_offset = inv_offset
        for j in range(rounds):  # max rounds is 10 = lg(16*64) = lg(1024)
            to_invert.read(to_invert_offset, pd.challenges[j])
            to_invert_offset += 1

        to_invert.read(to_invert_offset, pd.y)
        to_invert_offset += 1
        inv_offset += rounds + 1
        self.gc(2)

    to_invert.resize(inv_offset)
    self.gc(2)

    utils.ensure(max_length < 32, "At least one proof is too large")
    maxMN = 1 << max_length
    tmp2 = _ensure_dst_key()

    # multiexp_size = nV + (2 * (max_logm + logN) + 3) * len(proofs) + 2 * maxMN
    Gprec = self._gprec_aux(maxMN)  # Extended precomputed GiHi
    Hprec = self._hprec_aux(maxMN)
    muex_expl = MultiExpSequential()
    muex_gh = MultiExpSequential(
        point_fnc=lambda i, d: Gprec[i >> 1] if i & 1 == 0 else Hprec[i >> 1]
    )

    inverses = _invert_batch(to_invert)
    del (to_invert,)
    self.gc(3)

    # Weights and aggregates
    #
    # The idea is to take the single multiscalar multiplication used in the verification
    # of each proof in the batch and weight it using a random weighting factor, resulting
    # in just one multiscalar multiplication check to zero for the entire batch.
    # We can further simplify the verifier complexity by including common group elements
    # only once in this single multiscalar multiplication.
    # Common group elements' weighted scalar sums are tracked across proofs for this reason.
    #
    # To build a multiscalar multiplication for each proof, we use the method described in
    # Section 6.1 of the preprint. Note that the result given there does not account for
    # the construction of the inner-product inputs that are produced in the range proof
    # verifier algorithm; we have done so here.

    G_scalar = bytearray(_ZERO)
    H_scalar = bytearray(_ZERO)

    # --- Phase 2: per-proof weighted scalar accumulation -----------------
    proof_data_index = 0
    for proof in proofs:
        self.gc(4)
        pd = proof_data[proof_data_index]  # type: BulletProofPlusData
        proof_data_index += 1

        utils.ensure(len(proof.L) == 6 + pd.logM, "Proof is not the expected size")
        M = 1 << pd.logM
        MN = M * N
        # Nonzero random weight for this proof's equation.
        weight = bytearray(_ZERO)
        while weight == _ZERO:
            _sc_gen(weight)

        # Rescale previously offset proof elements
        #
        # Compute necessary powers of the y-challenge: y_MN = y**MN (MN is a
        # power of two, so log2(MN) squarings suffice).
        y_MN = bytearray(pd.y)
        y_MN_1 = _ensure_dst_key(None)
        temp_MN = MN
        while temp_MN > 1:
            _sc_mul(y_MN, y_MN, y_MN)
            temp_MN >>= 1  # integer halving (was float `/= 2`)

        _sc_mul(y_MN_1, y_MN, pd.y)  # y**(MN+1)

        # V_j: -e**2 * z**(2*j+1) * y**(MN+1) * weight
        e_squared = _ensure_dst_key(None)
        _sc_mul(e_squared, pd.e, pd.e)

        z_squared = _ensure_dst_key(None)
        _sc_mul(z_squared, pd.z, pd.z)

        _sc_sub(tmp, _ZERO, e_squared)
        _sc_mul(tmp, tmp, y_MN_1)
        _sc_mul(tmp, tmp, weight)

        for j in range(len(proof.V)):
            _sc_mul(tmp, tmp, z_squared)
            # This ensures that all such group elements are in the prime-order subgroup.
            muex_expl.add_pair(tmp, _scalarmult8(tmp2, proof.V[j]))

        # B: -weight
        _sc_mul(tmp, _MINUS_ONE, weight)
        muex_expl.add_pair(tmp, _scalarmult8(tmp2, proof.B))

        # A1: -weight * e
        _sc_mul(tmp, tmp, pd.e)
        muex_expl.add_pair(tmp, _scalarmult8(tmp2, proof.A1))

        # A: -weight * e * e
        minus_weight_e_squared = _sc_mul(None, tmp, pd.e)
        muex_expl.add_pair(minus_weight_e_squared, _scalarmult8(tmp2, proof.A))

        # G: weight * d1
        _sc_muladd(G_scalar, weight, proof.d1, G_scalar)
        self.gc(5)

        # Windowed vector
        # d[j*N+i] = z **(2*(j+1)) * 2**i
        # d is being read iteratively from [0..MN) only once.
        # Can be computed on the fly: hold last z and 2**i, add together
        d = VctD(N, M, z_squared)

        # More efficient computation of sum(d)
        sum_d = _ensure_dst_key(None)
        _sc_mul(
            sum_d, _TWO_SIXTY_FOUR_MINUS_ONE, _sum_of_even_powers(None, pd.z, 2 * M)
        )

        # H: weight*( r1*y*s1 + e**2*( y**(MN+1)*z*sum(d) + (z**2-z)*sum(y) ) )
        sum_y = _sum_of_scalar_powers(None, pd.y, MN)
        _sc_sub(tmp, z_squared, pd.z)
        _sc_mul(tmp, tmp, sum_y)

        _sc_mul(tmp2, y_MN_1, pd.z)
        _sc_mul(tmp2, tmp2, sum_d)
        _sc_add(tmp, tmp, tmp2)
        _sc_mul(tmp, tmp, e_squared)
        _sc_mul(tmp2, proof.r1, pd.y)
        _sc_mul(tmp2, tmp2, proof.s1)
        _sc_add(tmp, tmp, tmp2)
        _sc_muladd(H_scalar, tmp, weight, H_scalar)

        # Compute the number of rounds for the inner-product argument
        rounds = pd.logM + logN
        utils.ensure(rounds > 0, "zero rounds")

        yinv = inverses[pd.inv_offset + rounds]
        self.gc(6)

        # Description of challenges_cache:
        # Let ch_[i] = pd.challenges[i] and chi[i] = pd.challenges[i]^{-1}.
        # Define b_j[i] = i-th bit of integer j (0 is MSB), encoded in
        # {rounds} bits.  Then
        #   challenges_cache[i] = \\mult_{j in [0, rounds)} (b_i[j] * ch_[j]) +
        #                         (1 - b_i[j]) * chi[j]
        # Originally constructed iteratively (1 bit, 2 bits, ...).  We cannot
        # afford to precompute it all, so it is precomputed only up to
        # challenges_cache_depth_lim bits; the rest is evaluated on the fly.
        challenges_cache_depth_lim = const(8)
        challenges_cache_depth = min(rounds, challenges_cache_depth_lim)
        challenges_cache = _ensure_dst_keyvect(None, 1 << challenges_cache_depth)

        challenges_cache[0] = inverses[pd.inv_offset]
        challenges_cache[1] = pd.challenges[0]

        for j in range(1, challenges_cache_depth):
            slots = 1 << (j + 1)
            for s in range(slots - 1, -1, -2):
                challenges_cache.read(
                    s,
                    _sc_mul(
                        _tmp_bf_0,
                        challenges_cache[s // 2],
                        pd.challenges[j],  # even s
                    ),
                )
                challenges_cache.read(
                    s - 1,
                    _sc_mul(
                        _tmp_bf_0,
                        challenges_cache[s // 2],
                        inverses[pd.inv_offset + j],  # odd s
                    ),
                )

        if rounds > challenges_cache_depth:
            # Deep proofs: wrap the precomputed prefix in a lazy evaluator.
            challenges_cache = KeyChallengeCacheVct(
                rounds,
                pd.challenges,
                inverses.slice_view(pd.inv_offset, pd.inv_offset + rounds + 1),
                challenges_cache,
            )

        # Gi and Hi
        self.gc(7)
        e_r1_w_y = _ensure_dst_key()
        _sc_mul(e_r1_w_y, pd.e, proof.r1)
        _sc_mul(e_r1_w_y, e_r1_w_y, weight)
        e_s1_w = _ensure_dst_key()
        _sc_mul(e_s1_w, pd.e, proof.s1)
        _sc_mul(e_s1_w, e_s1_w, weight)
        e_squared_z_w = _ensure_dst_key()
        _sc_mul(e_squared_z_w, e_squared, pd.z)
        _sc_mul(e_squared_z_w, e_squared_z_w, weight)
        minus_e_squared_z_w = _ensure_dst_key()
        _sc_sub(minus_e_squared_z_w, _ZERO, e_squared_z_w)
        minus_e_squared_w_y = _ensure_dst_key()
        _sc_sub(minus_e_squared_w_y, _ZERO, e_squared)
        _sc_mul(minus_e_squared_w_y, minus_e_squared_w_y, weight)
        _sc_mul(minus_e_squared_w_y, minus_e_squared_w_y, y_MN)

        g_scalar = _ensure_dst_key()
        h_scalar = _ensure_dst_key()
        for i in range(MN):
            if i % 8 == 0:
                self.gc(8)
            _copy_key(g_scalar, e_r1_w_y)

            # Use the binary decomposition of the index
            _sc_muladd(g_scalar, g_scalar, challenges_cache[i], e_squared_z_w)
            _sc_muladd(
                h_scalar,
                e_s1_w,
                challenges_cache[(~i) & (MN - 1)],
                minus_e_squared_z_w,
            )

            # Complete the scalar derivation
            _sc_muladd(h_scalar, minus_e_squared_w_y, d[i], h_scalar)

            muex_gh.add_scalar_idx(g_scalar, 2 * i)
            muex_gh.add_scalar_idx(h_scalar, 2 * i + 1)

            # Update iterated values
            _sc_mul(e_r1_w_y, e_r1_w_y, yinv)
            _sc_mul(minus_e_squared_w_y, minus_e_squared_w_y, yinv)
        del (challenges_cache, d)
        self.gc(9)

        # L_j: -weight*e*e*challenges[j]**2
        # R_j: -weight*e*e*challenges[j]**(-2)
        for j in range(rounds):
            _sc_mul(tmp, pd.challenges[j], pd.challenges[j])
            _sc_mul(tmp, tmp, minus_weight_e_squared)
            muex_expl.add_pair(tmp, _scalarmult8(tmp2, proof.L[j]))

            _sc_mul(tmp, inverses[pd.inv_offset + j], inverses[pd.inv_offset + j])
            _sc_mul(tmp, tmp, minus_weight_e_squared)
            muex_expl.add_pair(tmp, _scalarmult8(tmp2, proof.R[j]))
        # Free this proof's reconstructed state before the next iteration.
        proof_data[proof_data_index - 1] = None
        del (pd,)
    del (inverses,)
    self.gc(10)

    # --- Phase 3: single combined multiexp check -------------------------
    muex_expl.add_pair(G_scalar, _XMR_G)
    muex_expl.add_pair(H_scalar, _XMR_H)

    m1 = _multiexp(tmp, muex_gh)
    m2 = _multiexp(tmp2, muex_expl)
    muex = _add_keys(tmp, m1, m2)

    # The weighted sum must collapse to the identity element.
    if muex != _ONE:
        raise ValueError("Verification error")

    return True
|