# trezor-firmware/core/src/apps/monero/xmr/bulletproof.py
import gc
from micropython import const
from trezor import utils
from trezor.utils import memcpy as tmemcpy
from apps.monero.xmr import crypto
from apps.monero.xmr.serialize.int_serialize import dump_uvarint_b_into, uvarint_size
# Constants
_BP_LOG_N = const(6)
_BP_N = const(64) # 1 << _BP_LOG_N
_BP_M = const(16) # maximal number of proved values (outputs) aggregated in one proof
_ZERO = b"\x00" * 32
_ONE = b"\x01" + b"\x00" * 31
_TWO = b"\x02" + b"\x00" * 31
_EIGHT = b"\x08" + b"\x00" * 31
_INV_EIGHT = crypto.INV_EIGHT
_MINUS_ONE = b"\xec\xd3\xf5\x5c\x1a\x63\x12\x58\xd6\x9c\xf7\xa2\xde\xf9\xde\x14\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10"
# _MINUS_INV_EIGHT = b"\x74\xa4\x19\x7a\xf0\x7d\x0b\xf7\x05\xc2\xda\x25\x2b\x5c\x0b\x0d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0a"
# Monero H point
_XMR_H = b"\x8b\x65\x59\x70\x15\x37\x99\xaf\x2a\xea\xdc\x9f\xf1\xad\xd0\xea\x6c\x72\x51\xd5\x41\x54\xcf\xa9\x2c\x17\x3a\x0d\xd3\x9c\x1f\x94"
_XMR_HP = crypto.xmr_H()
# ip12 = inner_product(oneN, twoN);
_BP_IP12 = b"\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
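# Note: _BP_IP12 equals <1^N, 2^N> for N = 64, i.e. 2^64 - 1 encoded as a
# little-endian scalar (see _inner_product below for a cross-check).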
#
# Rct key operations
# The _tmp_* buffers below are global working registers used to minimize memory
# allocations / heap fragmentation. Caution has to be exercised when using the
# registers: the helpers share them, so nested or interleaved calls can clobber their state.
#
_tmp_bf_0 = bytearray(32)
_tmp_bf_1 = bytearray(32)
_tmp_bf_2 = bytearray(32)
_tmp_bf_exp = bytearray(11 + 32 + 4)
_tmp_pt_1 = crypto.new_point()
_tmp_pt_2 = crypto.new_point()
_tmp_pt_3 = crypto.new_point()
_tmp_pt_4 = crypto.new_point()
_tmp_sc_1 = crypto.new_scalar()
_tmp_sc_2 = crypto.new_scalar()
_tmp_sc_3 = crypto.new_scalar()
_tmp_sc_4 = crypto.new_scalar()
def _ensure_dst_key(dst=None):
if dst is None:
dst = bytearray(32)
return dst
def memcpy(dst, dst_off, src, src_off, len):
if dst is not None:
tmemcpy(dst, dst_off, src, src_off, len)
return dst
def _alloc_scalars(num=1):
return (crypto.new_scalar() for _ in range(num))
def _copy_key(dst, src):
for i in range(32):
dst[i] = src[i]
return dst
def _init_key(val, dst=None):
dst = _ensure_dst_key(dst)
return _copy_key(dst, val)
def _gc_iter(i):
if i & 127 == 0:
gc.collect()
def _invert(dst, x):
dst = _ensure_dst_key(dst)
crypto.decodeint_into_noreduce(_tmp_sc_1, x)
crypto.sc_inv_into(_tmp_sc_2, _tmp_sc_1)
crypto.encodeint_into(dst, _tmp_sc_2)
return dst
def _scalarmult_key(dst, P, s, s_raw=None, tmp_pt=_tmp_pt_1):
dst = _ensure_dst_key(dst)
crypto.decodepoint_into(tmp_pt, P)
if s:
crypto.decodeint_into_noreduce(_tmp_sc_1, s)
crypto.scalarmult_into(tmp_pt, tmp_pt, _tmp_sc_1 if s else s_raw)
crypto.encodepoint_into(dst, tmp_pt)
return dst
def _scalarmultH(dst, x):
dst = _ensure_dst_key(dst)
crypto.decodeint_into(_tmp_sc_1, x)
crypto.scalarmult_into(_tmp_pt_1, _XMR_HP, _tmp_sc_1)
crypto.encodepoint_into(dst, _tmp_pt_1)
return dst
def _scalarmult_base(dst, x):
dst = _ensure_dst_key(dst)
crypto.decodeint_into_noreduce(_tmp_sc_1, x)
crypto.scalarmult_base_into(_tmp_pt_1, _tmp_sc_1)
crypto.encodepoint_into(dst, _tmp_pt_1)
return dst
def _sc_gen(dst=None):
dst = _ensure_dst_key(dst)
crypto.random_scalar(_tmp_sc_1)
crypto.encodeint_into(dst, _tmp_sc_1)
return dst
def _sc_add(dst, a, b):
dst = _ensure_dst_key(dst)
crypto.decodeint_into_noreduce(_tmp_sc_1, a)
crypto.decodeint_into_noreduce(_tmp_sc_2, b)
crypto.sc_add_into(_tmp_sc_3, _tmp_sc_1, _tmp_sc_2)
crypto.encodeint_into(dst, _tmp_sc_3)
return dst
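# Illustrative example (not executed here): the _sc_* helpers take and return
# 32-byte encoded scalars; passing dst=None allocates a fresh buffer.
#   two = _sc_add(None, _ONE, _ONE)          # encoding of 2, equals _TWO
#   one = _sc_mul(None, _EIGHT, _INV_EIGHT)  # 8 * 8^-1 = 1, equals _ONE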
def _sc_sub(dst, a, b, a_raw=None, b_raw=None):
dst = _ensure_dst_key(dst)
if a:
crypto.decodeint_into_noreduce(_tmp_sc_1, a)
if b:
crypto.decodeint_into_noreduce(_tmp_sc_2, b)
crypto.sc_sub_into(_tmp_sc_3, _tmp_sc_1 if a else a_raw, _tmp_sc_2 if b else b_raw)
crypto.encodeint_into(dst, _tmp_sc_3)
return dst
def _sc_mul(dst, a, b=None, b_raw=None):
dst = _ensure_dst_key(dst)
crypto.decodeint_into_noreduce(_tmp_sc_1, a)
if b:
crypto.decodeint_into_noreduce(_tmp_sc_2, b)
crypto.sc_mul_into(_tmp_sc_3, _tmp_sc_1, _tmp_sc_2 if b else b_raw)
crypto.encodeint_into(dst, _tmp_sc_3)
return dst
def _sc_muladd(dst, a, b, c, a_raw=None, b_raw=None, c_raw=None, raw=False):
dst = _ensure_dst_key(dst) if not raw else (dst if dst else crypto.new_scalar())
if a:
crypto.decodeint_into_noreduce(_tmp_sc_1, a)
if b:
crypto.decodeint_into_noreduce(_tmp_sc_2, b)
if c:
crypto.decodeint_into_noreduce(_tmp_sc_3, c)
crypto.sc_muladd_into(
_tmp_sc_4 if not raw else dst,
_tmp_sc_1 if a else a_raw,
_tmp_sc_2 if b else b_raw,
_tmp_sc_3 if c else c_raw,
)
if not raw:
crypto.encodeint_into(dst, _tmp_sc_4)
return dst
def _sc_mulsub(dst, a, b, c):
dst = _ensure_dst_key(dst)
crypto.decodeint_into_noreduce(_tmp_sc_1, a)
crypto.decodeint_into_noreduce(_tmp_sc_2, b)
crypto.decodeint_into_noreduce(_tmp_sc_3, c)
crypto.sc_mulsub_into(_tmp_sc_4, _tmp_sc_1, _tmp_sc_2, _tmp_sc_3)
crypto.encodeint_into(dst, _tmp_sc_4)
return dst
def _add_keys(dst, A, B):
dst = _ensure_dst_key(dst)
crypto.decodepoint_into(_tmp_pt_1, A)
crypto.decodepoint_into(_tmp_pt_2, B)
crypto.point_add_into(_tmp_pt_3, _tmp_pt_1, _tmp_pt_2)
crypto.encodepoint_into(dst, _tmp_pt_3)
return dst
def _sub_keys(dst, A, B):
dst = _ensure_dst_key(dst)
crypto.decodepoint_into(_tmp_pt_1, A)
crypto.decodepoint_into(_tmp_pt_2, B)
crypto.point_sub_into(_tmp_pt_3, _tmp_pt_1, _tmp_pt_2)
crypto.encodepoint_into(dst, _tmp_pt_3)
return dst
def _add_keys2(dst, a, b, B):
dst = _ensure_dst_key(dst)
crypto.decodeint_into_noreduce(_tmp_sc_1, a)
crypto.decodeint_into_noreduce(_tmp_sc_2, b)
crypto.decodepoint_into(_tmp_pt_1, B)
crypto.add_keys2_into(_tmp_pt_2, _tmp_sc_1, _tmp_sc_2, _tmp_pt_1)
crypto.encodepoint_into(dst, _tmp_pt_2)
return dst
def _add_keys3(dst, a, A, b, B):
dst = _ensure_dst_key(dst)
crypto.decodeint_into_noreduce(_tmp_sc_1, a)
crypto.decodeint_into_noreduce(_tmp_sc_2, b)
crypto.decodepoint_into(_tmp_pt_1, A)
crypto.decodepoint_into(_tmp_pt_2, B)
crypto.add_keys3_into(_tmp_pt_3, _tmp_sc_1, _tmp_pt_1, _tmp_sc_2, _tmp_pt_2)
crypto.encodepoint_into(dst, _tmp_pt_3)
return dst
def _hash_to_scalar(dst, data):
dst = _ensure_dst_key(dst)
crypto.hash_to_scalar_into(_tmp_sc_1, data)
crypto.encodeint_into(dst, _tmp_sc_1)
return dst
def _hash_vct_to_scalar(dst, data):
dst = _ensure_dst_key(dst)
ctx = crypto.get_keccak()
for x in data:
ctx.update(x)
hsh = ctx.digest()
crypto.decodeint_into(_tmp_sc_1, hsh)
crypto.encodeint_into(dst, _tmp_sc_1)
return dst
def _get_exponent(dst, base, idx):
dst = _ensure_dst_key(dst)
salt = b"bulletproof"
lsalt = const(11) # len(salt)
final_size = lsalt + 32 + uvarint_size(idx)
memcpy(_tmp_bf_exp, 0, base, 0, 32)
memcpy(_tmp_bf_exp, 32, salt, 0, lsalt)
dump_uvarint_b_into(idx, _tmp_bf_exp, 32 + lsalt)
crypto.keccak_hash_into(_tmp_bf_1, _tmp_bf_exp, final_size)
crypto.hash_to_point_into(_tmp_pt_4, _tmp_bf_1)
crypto.encodepoint_into(dst, _tmp_pt_4)
return dst
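# Illustrative example (not executed here): the Bulletproof generators are derived
# from the Monero H point with the "bulletproof" salt, matching the precomputed
# BP_GI_PRE / BP_HI_PRE tables used by BulletProofBuilder below:
#   Gi_i = _get_exponent(None, _XMR_H, i * 2 + 1)
#   Hi_i = _get_exponent(None, _XMR_H, i * 2)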
#
# Key Vectors
#
class KeyVBase:
"""
Base KeyVector object
"""
__slots__ = ("current_idx", "size")
def __init__(self, elems=64):
self.current_idx = 0
self.size = elems
def idxize(self, idx):
if idx < 0:
idx = self.size + idx
if idx >= self.size:
raise IndexError("Index out of bounds")
return idx
def __getitem__(self, item):
raise ValueError("Not supported")
def __setitem__(self, key, value):
raise ValueError("Not supported")
def __iter__(self):
self.current_idx = 0
return self
def __next__(self):
if self.current_idx >= self.size:
raise StopIteration
else:
self.current_idx += 1
return self[self.current_idx - 1]
def __len__(self):
return self.size
def to(self, idx, buff=None, offset=0):
buff = _ensure_dst_key(buff)
        return memcpy(buff, offset, self[self.idxize(idx)], 0, 32)
def read(self, idx, buff, offset=0):
raise ValueError
def slice(self, res, start, stop):
for i in range(start, stop):
res[i - start] = self[i]
return res
def slice_view(self, start, stop):
return KeyVSliced(self, start, stop)
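# Illustrative example (not executed here): all KeyV* variants share this base
# interface; e.g. the prover loop takes no-copy views (names refer to _prove_loop locals):
#   cL = _inner_product(aprime.slice_view(0, nprime), bprime.slice_view(nprime, npr2))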
_CHBITS = const(5)
_CHSIZE = const(1 << _CHBITS)
class KeyV(KeyVBase):
"""
    KeyVector abstraction.
    Constant precomputed buffers are bytes, frozen; they support the same operations as the mutable variant.
    A non-constant KeyVector is split into chunks of _CHSIZE elements to avoid heap
    fragmentation problems: smaller contiguous allocations are more likely to succeed.
    A chunk is assumed to hold exactly _CHSIZE elements at all times to minimize
    corner-case handling. BP requires vectors that are either a multiple of _CHSIZE
    elements long or shorter than _CHSIZE.
    Some chunk-dependent cases are not implemented as they are currently not needed in the BP.
"""
__slots__ = ("current_idx", "size", "d", "mv", "const", "cur", "chunked")
def __init__(self, elems=64, buffer=None, const=False, no_init=False):
super().__init__(elems)
self.d = None
self.mv = None
self.const = const
self.cur = _ensure_dst_key()
self.chunked = False
if no_init:
pass
elif buffer:
self.d = buffer # can be immutable (bytes)
self.size = len(buffer) // 32
else:
self._set_d(elems)
if not no_init:
self._set_mv()
def _set_d(self, elems):
if elems > _CHSIZE and elems % _CHSIZE == 0:
self.chunked = True
gc.collect()
self.d = [bytearray(32 * _CHSIZE) for _ in range(elems // _CHSIZE)]
else:
self.chunked = False
gc.collect()
self.d = bytearray(32 * elems)
def _set_mv(self):
if not self.chunked:
self.mv = memoryview(self.d)
def __getitem__(self, item):
"""
Returns corresponding 32 byte array.
Creates new memoryview on access.
"""
if self.chunked:
return self.to(item)
item = self.idxize(item)
return self.mv[item * 32 : (item + 1) * 32]
def __setitem__(self, key, value):
if self.chunked:
raise ValueError("Not supported") # not needed
if self.const:
raise ValueError("Constant KeyV")
ck = self[key]
for i in range(32):
ck[i] = value[i]
def to(self, idx, buff=None, offset=0):
idx = self.idxize(idx)
if self.chunked:
memcpy(
buff if buff else self.cur,
offset,
self.d[idx >> _CHBITS],
(idx & (_CHSIZE - 1)) << 5,
32,
)
else:
memcpy(buff if buff else self.cur, offset, self.d, idx << 5, 32)
return buff if buff else self.cur
def read(self, idx, buff, offset=0):
idx = self.idxize(idx)
if self.chunked:
memcpy(self.d[idx >> _CHBITS], (idx & (_CHSIZE - 1)) << 5, buff, offset, 32)
else:
memcpy(self.d, idx << 5, buff, offset, 32)
def resize(self, nsize, chop=False, realloc=False):
if self.size == nsize:
return
if self.chunked and nsize <= _CHSIZE:
self.chunked = False # de-chunk
if self.size > nsize and realloc:
gc.collect()
self.d = bytearray(self.d[0][: nsize << 5])
elif self.size > nsize and not chop:
gc.collect()
self.d = self.d[0][: nsize << 5]
else:
gc.collect()
self.d = bytearray(nsize << 5)
elif self.chunked and self.size < nsize:
if nsize % _CHSIZE != 0 or realloc or chop:
raise ValueError("Unsupported") # not needed
for i in range((nsize - self.size) // _CHSIZE):
self.d.append(bytearray(32 * _CHSIZE))
elif self.chunked:
if nsize % _CHSIZE != 0:
raise ValueError("Unsupported") # not needed
for i in range((self.size - nsize) // _CHSIZE):
self.d.pop()
if realloc:
for i in range(nsize // _CHSIZE):
self.d[i] = bytearray(self.d[i])
else:
if self.size > nsize and realloc:
gc.collect()
self.d = bytearray(self.d[: nsize << 5])
elif self.size > nsize and not chop:
gc.collect()
self.d = self.d[: nsize << 5]
else:
gc.collect()
self.d = bytearray(nsize << 5)
self.size = nsize
self._set_mv()
def realloc(self, nsize, collect=False):
self.d = None
self.mv = None
if collect:
gc.collect() # gc collect prev. allocation
self._set_d(nsize)
self.size = nsize
self._set_mv()
def realloc_init_from(self, nsize, src, offset=0, collect=False):
if not isinstance(src, KeyV):
raise ValueError("KeyV supported only")
self.realloc(nsize, collect)
if not self.chunked and not src.chunked:
memcpy(self.d, 0, src.d, offset << 5, nsize << 5)
elif self.chunked and not src.chunked or self.chunked and src.chunked:
for i in range(nsize):
self.read(i, src.to(i + offset))
elif not self.chunked and src.chunked:
for i in range(nsize >> _CHBITS):
memcpy(
self.d,
i << 11,
src.d[i + (offset >> _CHBITS)],
(offset & (_CHSIZE - 1)) << 5 if i == 0 else 0,
nsize << 5 if i <= nsize >> _CHBITS else (nsize & _CHSIZE) << 5,
)
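# Illustrative example (not executed here), assuming a small vector:
#   kv = KeyV(4)          # one contiguous 4 * 32 B buffer (not chunked)
#   kv.read(0, _ONE)      # store the encoded scalar 1 at index 0
#   first = kv.to(0)      # copy element 0 into the vector's shared register
# Larger vectors (multiples of _CHSIZE) are split into _CHSIZE-element chunks so
# that each allocation stays small.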
class KeyVEval(KeyVBase):
"""
KeyVector computed / evaluated on demand
"""
__slots__ = ("current_idx", "size", "fnc", "raw", "scalar", "buff")
def __init__(self, elems=64, src=None, raw=False, scalar=True):
super().__init__(elems)
self.fnc = src
self.raw = raw
self.scalar = scalar
self.buff = (
_ensure_dst_key()
if not raw
else (crypto.new_scalar() if scalar else crypto.new_point())
)
def __getitem__(self, item):
return self.fnc(self.idxize(item), self.buff)
def to(self, idx, buff=None, offset=0):
self.fnc(self.idxize(idx), self.buff)
if self.raw:
if offset != 0:
raise ValueError("Not supported")
if self.scalar and buff:
return crypto.sc_copy(buff, self.buff)
elif self.scalar:
return self.buff
else:
raise ValueError("Not supported")
else:
memcpy(buff, offset, self.buff, 0, 32)
return buff if buff else self.buff
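# Illustrative example (not executed here), mirroring how l0 is built in
# _prove_phase1: each element is evaluated into a shared buffer on access, so a
# vector of MN elements is never materialized in RAM (names refer to its locals).
#   l0 = KeyVEval(MN, lambda i, d: _sc_sub(d, aL.to(i), None, None, zc))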
class KeyVSized(KeyVBase):
"""
Resized vector, wrapping possibly larger vector
(e.g., precomputed, but has to have exact size for further computations)
"""
__slots__ = ("current_idx", "size", "wrapped")
def __init__(self, wrapped, new_size):
super().__init__(new_size)
self.wrapped = wrapped
def __getitem__(self, item):
return self.wrapped[self.idxize(item)]
def __setitem__(self, key, value):
self.wrapped[self.idxize(key)] = value
class KeyVConst(KeyVBase):
__slots__ = ("current_idx", "size", "elem")
def __init__(self, size, elem, copy=True):
super().__init__(size)
self.elem = _init_key(elem) if copy else elem
def __getitem__(self, item):
return self.elem
def to(self, idx, buff=None, offset=0):
memcpy(buff, offset, self.elem, 0, 32)
return buff if buff else self.elem
class KeyVPrecomp(KeyVBase):
"""
Vector with possibly large size and some precomputed prefix.
Usable for Gi vector with precomputed usual sizes (i.e., 2 output transactions)
but possible to compute further
"""
__slots__ = ("current_idx", "size", "precomp_prefix", "aux_comp_fnc", "buff")
def __init__(self, size, precomp_prefix, aux_comp_fnc):
super().__init__(size)
self.precomp_prefix = precomp_prefix
self.aux_comp_fnc = aux_comp_fnc
self.buff = _ensure_dst_key()
def __getitem__(self, item):
item = self.idxize(item)
if item < len(self.precomp_prefix):
return self.precomp_prefix[item]
return self.aux_comp_fnc(item, self.buff)
def to(self, idx, buff=None, offset=0):
item = self.idxize(idx)
if item < len(self.precomp_prefix):
return self.precomp_prefix.to(item, buff if buff else self.buff, offset)
self.aux_comp_fnc(item, self.buff)
memcpy(buff, offset, self.buff, 0, 32)
return buff if buff else self.buff
class KeyVSliced(KeyVBase):
"""
Sliced in-memory vector version, remapping
"""
__slots__ = ("current_idx", "size", "wrapped", "offset")
def __init__(self, src, start, stop):
super().__init__(stop - start)
self.wrapped = src
self.offset = start
def __getitem__(self, item):
return self.wrapped[self.offset + self.idxize(item)]
def __setitem__(self, key, value):
self.wrapped[self.offset + self.idxize(key)] = value
def resize(self, nsize, chop=False):
raise ValueError("Not supported")
def to(self, idx, buff=None, offset=0):
return self.wrapped.to(self.offset + self.idxize(idx), buff, offset)
def read(self, idx, buff, offset=0):
return self.wrapped.read(self.offset + self.idxize(idx), buff, offset)
class KeyVPowers(KeyVBase):
"""
Vector of x^i. Allows only sequential access (no jumping). Resets on [0,1] access.
"""
__slots__ = ("current_idx", "size", "x", "raw", "cur", "last_idx")
def __init__(self, size, x, raw=False, **kwargs):
super().__init__(size)
self.x = x if not raw else crypto.decodeint_into_noreduce(None, x)
self.raw = raw
self.cur = bytearray(32) if not raw else crypto.new_scalar()
self.last_idx = 0
def __getitem__(self, item):
prev = self.last_idx
item = self.idxize(item)
self.last_idx = item
if item == 0:
return (
_copy_key(self.cur, _ONE)
if not self.raw
else crypto.decodeint_into_noreduce(None, _ONE)
)
elif item == 1:
return (
_copy_key(self.cur, self.x)
if not self.raw
else crypto.sc_copy(self.cur, self.x)
)
elif item == prev:
return self.cur
elif item == prev + 1:
return (
_sc_mul(self.cur, self.cur, self.x)
if not self.raw
else crypto.sc_mul_into(self.cur, self.cur, self.x)
)
else:
raise IndexError(f"Only linear scan allowed: {prev}, {item}")
def set_state(self, idx, val):
self.last_idx = idx
if self.raw:
return crypto.sc_copy(self.cur, val)
else:
return _copy_key(self.cur, val)
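# Illustrative example (not executed here): sequential powers of an encoded scalar y.
#   yp = KeyVPowers(8, y)
#   yp[0]   # 1
#   yp[1]   # y
#   yp[2]   # y^2 -- only the previous index or previous index + 1 may be requested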
class KeyR0(KeyVBase):
"""
Vector r0. Allows only sequential access (no jumping). Resets on [0,1] access.
    zt_i = z^{2 + floor(i/N)} * 2^{i % N}
    r0_i = ((a_{R,i} + z) * y^i) + zt_i
Could be composed from smaller vectors, but RAW returns are required
"""
__slots__ = (
"current_idx",
"size",
"N",
"aR",
"raw",
"y",
"yp",
"z",
"zt",
"p2",
"res",
"cur",
"last_idx",
)
def __init__(self, size, N, aR, y, z, raw=False, **kwargs):
super().__init__(size)
self.N = N
self.aR = aR
self.raw = raw
self.y = crypto.decodeint_into_noreduce(None, y)
self.yp = crypto.new_scalar() # y^{i}
self.z = crypto.decodeint_into_noreduce(None, z)
self.zt = crypto.new_scalar() # z^{2 + \floor{i/N}}
self.p2 = crypto.new_scalar() # 2^{i \% N}
self.res = crypto.new_scalar() # tmp_sc_1
self.cur = bytearray(32) if not raw else None
self.last_idx = 0
self.reset()
def reset(self):
crypto.decodeint_into_noreduce(self.yp, _ONE)
crypto.decodeint_into_noreduce(self.p2, _ONE)
crypto.sc_mul_into(self.zt, self.z, self.z)
def __getitem__(self, item):
prev = self.last_idx
item = self.idxize(item)
self.last_idx = item
# Const init for eval
if item == 0: # Reset on first item access
self.reset()
elif item == prev + 1:
crypto.sc_mul_into(self.yp, self.yp, self.y) # ypow
if item % self.N == 0:
crypto.sc_mul_into(self.zt, self.zt, self.z) # zt
crypto.decodeint_into_noreduce(self.p2, _ONE) # p2 reset
else:
crypto.decodeint_into_noreduce(self.res, _TWO) # p2
crypto.sc_mul_into(self.p2, self.p2, self.res) # p2
elif item == prev: # No advancing
pass
else:
raise IndexError("Only linear scan allowed")
# Eval r0[i]
        if (
            item == 0 or item != prev
        ):  # must also evaluate for item == 0, otherwise the cross inner product fails
crypto.decodeint_into_noreduce(self.res, self.aR.to(item)) # aR[i]
crypto.sc_add_into(self.res, self.res, self.z) # aR[i] + z
crypto.sc_mul_into(self.res, self.res, self.yp) # (aR[i] + z) * y^i
crypto.sc_muladd_into(
self.res, self.zt, self.p2, self.res
) # (aR[i] + z) * y^i + z^{2 + \floor{i/N}} 2^{i \% N}
if self.raw:
return self.res
crypto.encodeint_into(self.cur, self.res)
return self.cur
def to(self, idx, buff=None, offset=0):
r = self[idx]
if buff is None:
return r
return memcpy(buff, offset, r, 0, 32)
def _ensure_dst_keyvect(dst=None, size=None):
if dst is None:
dst = KeyV(elems=size)
return dst
if size is not None and size != len(dst):
dst.resize(size)
return dst
def _const_vector(val, elems=_BP_N, copy=True):
return KeyVConst(elems, val, copy)
def _vector_exponent_custom(A, B, a, b, dst=None, a_raw=None, b_raw=None):
"""
\\sum_{i=0}^{|A|} a_i A_i + b_i B_i
"""
dst = _ensure_dst_key(dst)
crypto.identity_into(_tmp_pt_2)
for i in range(len(a or a_raw)):
if a:
crypto.decodeint_into_noreduce(_tmp_sc_1, a.to(i))
crypto.decodepoint_into(_tmp_pt_3, A.to(i))
if b:
crypto.decodeint_into_noreduce(_tmp_sc_2, b.to(i))
crypto.decodepoint_into(_tmp_pt_4, B.to(i))
crypto.add_keys3_into(
_tmp_pt_1,
_tmp_sc_1 if a else a_raw.to(i),
_tmp_pt_3,
_tmp_sc_2 if b else b_raw.to(i),
_tmp_pt_4,
)
crypto.point_add_into(_tmp_pt_2, _tmp_pt_2, _tmp_pt_1)
_gc_iter(i)
crypto.encodepoint_into(dst, _tmp_pt_2)
return dst
def _vector_powers(x, n, dst=None, dynamic=False, **kwargs):
"""
r_i = x^i
"""
if dynamic:
return KeyVPowers(n, x, **kwargs)
dst = _ensure_dst_keyvect(dst, n)
if n == 0:
return dst
dst.read(0, _ONE)
if n == 1:
return dst
dst.read(1, x)
crypto.decodeint_into_noreduce(_tmp_sc_1, x)
crypto.decodeint_into_noreduce(_tmp_sc_2, x)
for i in range(2, n):
crypto.sc_mul_into(_tmp_sc_1, _tmp_sc_1, _tmp_sc_2)
crypto.encodeint_into(_tmp_bf_0, _tmp_sc_1)
dst.read(i, _tmp_bf_0)
_gc_iter(i)
return dst
def _vector_power_sum(x, n, dst=None):
"""
\\sum_{i=0}^{n-1} x^i
"""
dst = _ensure_dst_key(dst)
if n == 0:
return _copy_key(dst, _ZERO)
    if n == 1:
        return _copy_key(dst, _ONE)
crypto.decodeint_into_noreduce(_tmp_sc_1, x)
crypto.decodeint_into_noreduce(_tmp_sc_3, _ONE)
crypto.sc_add_into(_tmp_sc_3, _tmp_sc_3, _tmp_sc_1)
crypto.sc_copy(_tmp_sc_2, _tmp_sc_1)
for i in range(2, n):
crypto.sc_mul_into(_tmp_sc_2, _tmp_sc_2, _tmp_sc_1)
crypto.sc_add_into(_tmp_sc_3, _tmp_sc_3, _tmp_sc_2)
_gc_iter(i)
return crypto.encodeint_into(dst, _tmp_sc_3)
def _inner_product(a, b, dst=None):
"""
\\sum_{i=0}^{|a|} a_i b_i
"""
if len(a) != len(b):
raise ValueError("Incompatible sizes of a and b")
dst = _ensure_dst_key(dst)
crypto.sc_init_into(_tmp_sc_1, 0)
for i in range(len(a)):
crypto.decodeint_into_noreduce(_tmp_sc_2, a.to(i))
crypto.decodeint_into_noreduce(_tmp_sc_3, b.to(i))
crypto.sc_muladd_into(_tmp_sc_1, _tmp_sc_2, _tmp_sc_3, _tmp_sc_1)
_gc_iter(i)
crypto.encodeint_into(dst, _tmp_sc_1)
return dst
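# Illustrative cross-check (not executed here): the constant _BP_IP12 defined above is
#   _inner_product(_const_vector(_ONE), KeyV(buffer=crypto.tcry.BP_TWO_N, const=True))
# i.e. the inner product of the all-ones vector and (1, 2, 4, ..., 2^63) = 2^64 - 1.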
def _hadamard_fold(v, a, b, into=None, into_offset=0, vR=None, vRoff=0):
"""
Folds a curvepoint array using a two way scaled Hadamard product
ln = len(v); h = ln // 2
v_i = a v_i + b v_{h + i}
"""
h = len(v) // 2
crypto.decodeint_into_noreduce(_tmp_sc_1, a)
crypto.decodeint_into_noreduce(_tmp_sc_2, b)
into = into if into else v
for i in range(h):
crypto.decodepoint_into(_tmp_pt_1, v.to(i))
crypto.decodepoint_into(_tmp_pt_2, v.to(h + i) if not vR else vR.to(i + vRoff))
crypto.add_keys3_into(_tmp_pt_3, _tmp_sc_1, _tmp_pt_1, _tmp_sc_2, _tmp_pt_2)
crypto.encodepoint_into(_tmp_bf_0, _tmp_pt_3)
into.read(i + into_offset, _tmp_bf_0)
_gc_iter(i)
return into
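# Illustrative example (not executed here), as used in the prover loop
# (names refer to _prove_loop locals); folding in place halves the vector logically,
# and the explicit resize happens afterwards:
#   Gprime = _hadamard_fold(Gprime, winv, w_round)   # G'_i = w^-1 * G_i + w * G_{h+i}
#   Gprime.resize(nprime)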
def _hadamard_fold_linear(v, a, b, into=None, into_offset=0):
"""
Folds a curvepoint array using a two way scaled Hadamard product.
Iterates v linearly to support linear-scan evaluated vectors (on the fly)
ln = len(v); h = ln // 2
v_i = a v_i + b v_{h + i}
"""
h = len(v) // 2
into = into if into else v
crypto.decodeint_into_noreduce(_tmp_sc_1, a)
for i in range(h):
crypto.decodepoint_into(_tmp_pt_1, v.to(i))
crypto.scalarmult_into(_tmp_pt_1, _tmp_pt_1, _tmp_sc_1)
crypto.encodepoint_into(_tmp_bf_0, _tmp_pt_1)
into.read(i + into_offset, _tmp_bf_0)
_gc_iter(i)
crypto.decodeint_into_noreduce(_tmp_sc_1, b)
for i in range(h):
crypto.decodepoint_into(_tmp_pt_1, v.to(i + h))
crypto.scalarmult_into(_tmp_pt_1, _tmp_pt_1, _tmp_sc_1)
crypto.decodepoint_into(_tmp_pt_2, into.to(i + into_offset))
crypto.point_add_into(_tmp_pt_1, _tmp_pt_1, _tmp_pt_2)
crypto.encodepoint_into(_tmp_bf_0, _tmp_pt_1)
into.read(i + into_offset, _tmp_bf_0)
_gc_iter(i)
return into
def _scalar_fold(v, a, b, into=None, into_offset=0):
"""
ln = len(v); h = ln // 2
v_i = a v_i + b v_{h + i}
"""
h = len(v) // 2
crypto.decodeint_into_noreduce(_tmp_sc_1, a)
crypto.decodeint_into_noreduce(_tmp_sc_2, b)
into = into if into else v
for i in range(h):
crypto.decodeint_into_noreduce(_tmp_sc_3, v.to(i))
crypto.decodeint_into_noreduce(_tmp_sc_4, v.to(h + i))
crypto.sc_mul_into(_tmp_sc_3, _tmp_sc_3, _tmp_sc_1)
crypto.sc_mul_into(_tmp_sc_4, _tmp_sc_4, _tmp_sc_2)
crypto.sc_add_into(_tmp_sc_3, _tmp_sc_3, _tmp_sc_4)
crypto.encodeint_into(_tmp_bf_0, _tmp_sc_3)
into.read(i + into_offset, _tmp_bf_0)
_gc_iter(i)
return into
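# Illustrative example (not executed here), as used in the prover loop
# (names refer to _prove_loop locals): fold the scalar vector in place, then
# shrink it to the new half size.
#   _scalar_fold(aprime, w_round, winv)
#   aprime.resize(nprime)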
def _cross_inner_product(l0, r0, l1, r1):
"""
t1 = l0 . r1 + l1 . r0
t2 = l1 . r1
"""
sc_t1 = crypto.new_scalar()
sc_t2 = crypto.new_scalar()
tl = crypto.new_scalar()
tr = crypto.new_scalar()
for i in range(len(l0)):
crypto.decodeint_into_noreduce(tl, l0.to(i))
crypto.decodeint_into_noreduce(tr, r1.to(i))
crypto.sc_muladd_into(sc_t1, tl, tr, sc_t1)
crypto.decodeint_into_noreduce(tl, l1.to(i))
crypto.sc_muladd_into(sc_t2, tl, tr, sc_t2)
crypto.decodeint_into_noreduce(tr, r0.to(i))
crypto.sc_muladd_into(sc_t1, tl, tr, sc_t1)
_gc_iter(i)
return crypto.encodeint(sc_t1), crypto.encodeint(sc_t2)
def _vector_gen(dst, size, op):
dst = _ensure_dst_keyvect(dst, size)
for i in range(size):
dst.to(i, _tmp_bf_0)
op(i, _tmp_bf_0)
dst.read(i, _tmp_bf_0)
_gc_iter(i)
return dst
def _vector_dup(x, n, dst=None):
dst = _ensure_dst_keyvect(dst, n)
for i in range(n):
dst[i] = x
_gc_iter(i)
return dst
def _hash_cache_mash(dst, hash_cache, *args):
dst = _ensure_dst_key(dst)
ctx = crypto.get_keccak()
ctx.update(hash_cache)
for x in args:
if x is None:
break
ctx.update(x)
hsh = ctx.digest()
crypto.decodeint_into(_tmp_sc_1, hsh)
crypto.encodeint_into(hash_cache, _tmp_sc_1)
_copy_key(dst, hash_cache)
return dst
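# Illustrative example (not executed here): Fiat-Shamir challenge chaining as in
# verify_batch -- each mash absorbs the running hash_cache and updates it in place
# (proof refers to a Bulletproof message):
#   hash_cache = _hash_vct_to_scalar(None, proof.V)
#   y = _hash_cache_mash(None, hash_cache, proof.A, proof.S)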
def _is_reduced(sc):
return crypto.encodeint_into(_tmp_bf_0, crypto.decodeint_into(_tmp_sc_1, sc)) == sc
class MultiExpSequential:
"""
    MultiExp object, similar to a MultiExp array of [(scalar, point), ...].
    MultiExp simply computes: res = \\sum_i scalar_i * point_i
    Straus / Pippenger algorithms are implemented in the original Monero C++ code for
    speed, but their memory cost is around 1 MB, which is not affordable on HW devices.
    Moreover, Monero needs very fast verification for blockchain validation, which is
    not a priority in this use case.
    MultiExp holder with sequential evaluation.
"""
def __init__(self, size=None, points=None, point_fnc=None):
self.current_idx = 0
self.size = size if size else None
self.points = points if points else []
self.point_fnc = point_fnc
if points and size is None:
self.size = len(points) if points else 0
else:
self.size = 0
self.acc = crypto.identity()
self.tmp = _ensure_dst_key()
def get_point(self, idx):
return (
self.point_fnc(idx, None) if idx >= len(self.points) else self.points[idx]
)
def add_pair(self, scalar, point):
self._acc(scalar, point)
def add_scalar(self, scalar):
self._acc(scalar, self.get_point(self.current_idx))
def _acc(self, scalar, point):
crypto.decodeint_into_noreduce(_tmp_sc_1, scalar)
crypto.decodepoint_into(_tmp_pt_2, point)
crypto.scalarmult_into(_tmp_pt_3, _tmp_pt_2, _tmp_sc_1)
crypto.point_add_into(self.acc, self.acc, _tmp_pt_3)
self.current_idx += 1
self.size += 1
def eval(self, dst, GiHi=False):
dst = _ensure_dst_key(dst)
return crypto.encodepoint_into(dst, self.acc)
def _multiexp(dst=None, data=None, GiHi=False):
return data.eval(dst, GiHi)
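# Illustrative example (not executed here), as used in verify_batch
# (tmp, tmp2 stand for any 32-byte encoded scalars):
#   muex = MultiExpSequential(points=[pt for pt in proof.V])
#   muex.add_scalar(tmp)            # pairs tmp with the next unused point
#   muex.add_pair(tmp2, proof.T1)   # explicit (scalar, point) pair
#   _multiexp(dst, muex)            # dst = sum_i scalar_i * point_i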
class BulletProofBuilder:
def __init__(self):
self.use_det_masks = True
self.proof_sec = None
# BP_GI_PRE = get_exponent(Gi[i], _XMR_H, i * 2 + 1)
self.Gprec = KeyV(buffer=crypto.tcry.BP_GI_PRE, const=True)
# BP_HI_PRE = get_exponent(Hi[i], _XMR_H, i * 2)
self.Hprec = KeyV(buffer=crypto.tcry.BP_HI_PRE, const=True)
# BP_TWO_N = vector_powers(_TWO, _BP_N);
self.twoN = KeyV(buffer=crypto.tcry.BP_TWO_N, const=True)
self.fnc_det_mask = None
self.tmp_sc_1 = crypto.new_scalar()
self.tmp_det_buff = bytearray(64 + 1 + 4)
self.gc_fnc = gc.collect
self.gc_trace = None
def gc(self, *args):
if self.gc_trace:
self.gc_trace(*args)
if self.gc_fnc:
self.gc_fnc()
def aX_vcts(self, sv, MN):
num_inp = len(sv)
def e_xL(idx, d=None, is_a=True):
j, i = idx // _BP_N, idx % _BP_N
r = None
if j >= num_inp:
r = _ZERO if is_a else _MINUS_ONE
elif sv[j][i // 8] & (1 << i % 8):
r = _ONE if is_a else _ZERO
else:
r = _ZERO if is_a else _MINUS_ONE
if d:
return memcpy(d, 0, r, 0, 32)
return r
aL = KeyVEval(MN, lambda i, d: e_xL(i, d, True))
aR = KeyVEval(MN, lambda i, d: e_xL(i, d, False))
return aL, aR
def _det_mask_init(self):
memcpy(self.tmp_det_buff, 0, self.proof_sec, 0, len(self.proof_sec))
def _det_mask(self, i, is_sL=True, dst=None):
dst = _ensure_dst_key(dst)
if self.fnc_det_mask:
return self.fnc_det_mask(i, is_sL, dst)
self.tmp_det_buff[64] = int(is_sL)
memcpy(self.tmp_det_buff, 65, _ZERO, 0, 4)
dump_uvarint_b_into(i, self.tmp_det_buff, 65)
crypto.hash_to_scalar_into(self.tmp_sc_1, self.tmp_det_buff)
crypto.encodeint_into(dst, self.tmp_sc_1)
return dst
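    # Illustrative note: the deterministic mask is H_s(proof_sec || flag || varint(i)),
    # zero-padded to the tmp_det_buff size, where flag is 1 for sL and 0 for sR.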
def _gprec_aux(self, size):
return KeyVPrecomp(
size, self.Gprec, lambda i, d: _get_exponent(d, _XMR_H, i * 2 + 1)
)
def _hprec_aux(self, size):
return KeyVPrecomp(
size, self.Hprec, lambda i, d: _get_exponent(d, _XMR_H, i * 2)
)
def _two_aux(self, size):
# Simple recursive exponentiation from precomputed results
lx = len(self.twoN)
def pow_two(i, d=None):
if i < lx:
return self.twoN[i]
d = _ensure_dst_key(d)
flr = i // 2
lw = pow_two(flr)
            rw = pow_two(flr + 1) if flr != i / 2.0 else lw
return _sc_mul(d, lw, rw)
return KeyVPrecomp(size, self.twoN, pow_two)
def sL_vct(self, ln=_BP_N):
return (
KeyVEval(ln, lambda i, dst: self._det_mask(i, True, dst))
if self.use_det_masks
else self.sX_gen(ln)
)
def sR_vct(self, ln=_BP_N):
return (
KeyVEval(ln, lambda i, dst: self._det_mask(i, False, dst))
if self.use_det_masks
else self.sX_gen(ln)
)
def sX_gen(self, ln=_BP_N):
gc.collect()
buff = bytearray(ln * 32)
buff_mv = memoryview(buff)
sc = crypto.new_scalar()
for i in range(ln):
crypto.random_scalar(sc)
crypto.encodeint_into(buff_mv[i * 32 : (i + 1) * 32], sc)
_gc_iter(i)
return KeyV(buffer=buff)
def vector_exponent(self, a, b, dst=None, a_raw=None, b_raw=None):
return _vector_exponent_custom(self.Gprec, self.Hprec, a, b, dst, a_raw, b_raw)
def prove(self, sv, gamma):
return self.prove_batch([sv], [gamma])
def prove_setup(self, sv, gamma):
utils.ensure(len(sv) == len(gamma), "|sv| != |gamma|")
utils.ensure(len(sv) > 0, "sv empty")
self.proof_sec = crypto.random_bytes(64)
self._det_mask_init()
gc.collect()
sv = [crypto.encodeint(x) for x in sv]
gamma = [crypto.encodeint(x) for x in gamma]
M, logM = 1, 0
while M <= _BP_M and M < len(sv):
logM += 1
M = 1 << logM
MN = M * _BP_N
V = _ensure_dst_keyvect(None, len(sv))
for i in range(len(sv)):
_add_keys2(_tmp_bf_0, gamma[i], sv[i], _XMR_H)
_scalarmult_key(_tmp_bf_0, _tmp_bf_0, _INV_EIGHT)
V.read(i, _tmp_bf_0)
aL, aR = self.aX_vcts(sv, MN)
return M, logM, aL, aR, V, gamma
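    # Illustrative note: each commitment above is V_i = 8^-1 * (gamma_i * G + sv_i * H);
    # the 1/8 factor is compensated during verification by the *_EIGHT weight scalars.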
def prove_batch(self, sv, gamma):
M, logM, aL, aR, V, gamma = self.prove_setup(sv, gamma)
hash_cache = _ensure_dst_key()
while True:
self.gc(10)
r = self._prove_batch_main(
V, gamma, aL, aR, hash_cache, logM, _BP_LOG_N, M, _BP_N
)
if r[0]:
break
return r[1]
def _prove_batch_main(self, V, gamma, aL, aR, hash_cache, logM, logN, M, N):
logMN = logM + logN
MN = M * N
_hash_vct_to_scalar(hash_cache, V)
# Extended precomputed GiHi
Gprec = self._gprec_aux(MN)
Hprec = self._hprec_aux(MN)
# PHASE 1
A, S, T1, T2, taux, mu, t, l, r, y, x_ip, hash_cache = self._prove_phase1(
N, M, logMN, V, gamma, aL, aR, hash_cache, Gprec, Hprec
)
# PHASE 2
L, R, a, b = self._prove_loop(
MN, logMN, l, r, y, x_ip, hash_cache, Gprec, Hprec
)
from apps.monero.xmr.serialize_messages.tx_rsig_bulletproof import Bulletproof
return (
1,
Bulletproof(
V=V, A=A, S=S, T1=T1, T2=T2, taux=taux, mu=mu, L=L, R=R, a=a, b=b, t=t
),
)
def _prove_phase1(self, N, M, logMN, V, gamma, aL, aR, hash_cache, Gprec, Hprec):
MN = M * N
# PAPER LINES 38-39, compute A = 8^{-1} ( \alpha G + \sum_{i=0}^{MN-1} a_{L,i} \Gi_i + a_{R,i} \Hi_i)
alpha = _sc_gen()
A = _ensure_dst_key()
_vector_exponent_custom(Gprec, Hprec, aL, aR, A)
_add_keys(A, A, _scalarmult_base(_tmp_bf_1, alpha))
_scalarmult_key(A, A, _INV_EIGHT)
self.gc(11)
# PAPER LINES 40-42, compute S = 8^{-1} ( \rho G + \sum_{i=0}^{MN-1} s_{L,i} \Gi_i + s_{R,i} \Hi_i)
sL = self.sL_vct(MN)
sR = self.sR_vct(MN)
rho = _sc_gen()
S = _ensure_dst_key()
_vector_exponent_custom(Gprec, Hprec, sL, sR, S)
_add_keys(S, S, _scalarmult_base(_tmp_bf_1, rho))
_scalarmult_key(S, S, _INV_EIGHT)
self.gc(12)
# PAPER LINES 43-45
y = _ensure_dst_key()
_hash_cache_mash(y, hash_cache, A, S)
if y == _ZERO:
return (0,)
z = _ensure_dst_key()
_hash_to_scalar(hash_cache, y)
_copy_key(z, hash_cache)
zc = crypto.decodeint_into_noreduce(None, z)
if z == _ZERO:
return (0,)
# Polynomial construction by coefficients
        # l0 = aL - z
        # r0 = ((aR + z) . ypow) + zt
        # l1 = sL
        # r1 = sR . ypow
l0 = KeyVEval(
MN, lambda i, d: _sc_sub(d, aL.to(i), None, None, zc) # noqa: F821
)
l1 = sL
self.gc(13)
# This computes the ugly sum/concatenation from PAPER LINE 65
# r0_i = ((a_{Ri} + z) y^{i}) + zt_i
# r1_i = s_{Ri} y^{i}
r0 = KeyR0(MN, N, aR, y, z)
ypow = KeyVPowers(MN, y, raw=True)
r1 = KeyVEval(
MN, lambda i, d: _sc_mul(d, sR.to(i), None, ypow[i]) # noqa: F821
)
del aR
self.gc(14)
# Evaluate per index
# - $t_1 = l_0 . r_1 + l_1 . r0$
# - $t_2 = l_1 . r_1$
# - compute then T1, T2, x
t1, t2 = _cross_inner_product(l0, r0, l1, r1)
# PAPER LINES 47-48, Compute: T1, T2
# T1 = 8^{-1} (\tau_1G + t_1H )
# T2 = 8^{-1} (\tau_2G + t_2H )
tau1, tau2 = _sc_gen(), _sc_gen()
T1, T2 = _ensure_dst_key(), _ensure_dst_key()
_add_keys2(T1, tau1, t1, _XMR_H)
_scalarmult_key(T1, T1, _INV_EIGHT)
_add_keys2(T2, tau2, t2, _XMR_H)
_scalarmult_key(T2, T2, _INV_EIGHT)
del (t1, t2)
self.gc(16)
# PAPER LINES 49-51, compute x
x = _ensure_dst_key()
_hash_cache_mash(x, hash_cache, z, T1, T2)
if x == _ZERO:
return (0,)
# Second pass, compute l, r
# Offloaded version does this incrementally and produces l, r outs in chunks
# Message offloaded sends blinded vectors with random constants.
# - $l_i = l_{0,i} + xl_{1,i}
# - $r_i = r_{0,i} + xr_{1,i}
# - $t = l . r$
l = _ensure_dst_keyvect(None, MN)
r = _ensure_dst_keyvect(None, MN)
ts = crypto.new_scalar()
for i in range(MN):
_sc_muladd(_tmp_bf_0, x, l1.to(i), l0.to(i))
l.read(i, _tmp_bf_0)
_sc_muladd(_tmp_bf_1, x, r1.to(i), r0.to(i))
r.read(i, _tmp_bf_1)
_sc_muladd(ts, _tmp_bf_0, _tmp_bf_1, None, c_raw=ts, raw=True)
t = crypto.encodeint(ts)
del (l0, l1, sL, sR, r0, r1, ypow, ts)
self.gc(17)
# PAPER LINES 52-53, Compute \tau_x
taux = _ensure_dst_key()
_sc_mul(taux, tau1, x)
_sc_mul(_tmp_bf_0, x, x)
_sc_muladd(taux, tau2, _tmp_bf_0, taux)
del (tau1, tau2)
zpow = crypto.sc_mul_into(None, zc, zc)
for j in range(1, len(V) + 1):
_sc_muladd(taux, None, gamma[j - 1], taux, a_raw=zpow)
crypto.sc_mul_into(zpow, zpow, zc)
del (zc, zpow)
self.gc(18)
mu = _ensure_dst_key()
_sc_muladd(mu, x, rho, alpha)
del (rho, alpha)
self.gc(19)
# PAPER LINES 32-33
x_ip = _hash_cache_mash(None, hash_cache, x, taux, mu, t)
if x_ip == _ZERO:
return 0, None
return A, S, T1, T2, taux, mu, t, l, r, y, x_ip, hash_cache
def _prove_loop(self, MN, logMN, l, r, y, x_ip, hash_cache, Gprec, Hprec):
nprime = MN
aprime = l
bprime = r
yinvpowL = KeyVPowers(MN, _invert(_tmp_bf_0, y), raw=True)
yinvpowR = KeyVPowers(MN, _tmp_bf_0, raw=True)
tmp_pt = crypto.new_point()
Gprime = Gprec
HprimeL = KeyVEval(
MN, lambda i, d: _scalarmult_key(d, Hprec.to(i), None, yinvpowL[i])
)
HprimeR = KeyVEval(
MN, lambda i, d: _scalarmult_key(d, Hprec.to(i), None, yinvpowR[i], tmp_pt)
)
Hprime = HprimeL
self.gc(20)
L = _ensure_dst_keyvect(None, logMN)
R = _ensure_dst_keyvect(None, logMN)
cL = _ensure_dst_key()
cR = _ensure_dst_key()
winv = _ensure_dst_key()
w_round = _ensure_dst_key()
tmp = _ensure_dst_key()
_tmp_k_1 = _ensure_dst_key()
round = 0
# PAPER LINE 13
while nprime > 1:
# PAPER LINE 15
npr2 = nprime
nprime >>= 1
self.gc(22)
# PAPER LINES 16-17
# cL = \ap_{\left(\inta\right)} \cdot \bp_{\left(\intb\right)}
# cR = \ap_{\left(\intb\right)} \cdot \bp_{\left(\inta\right)}
_inner_product(
aprime.slice_view(0, nprime), bprime.slice_view(nprime, npr2), cL
)
_inner_product(
aprime.slice_view(nprime, npr2), bprime.slice_view(0, nprime), cR
)
self.gc(23)
# PAPER LINES 18-19
# Lc = 8^{-1} \left(\left( \sum_{i=0}^{\np} \ap_{i}\quad\Gp_{i+\np} + \bp_{i+\np}\Hp_{i} \right)
# + \left(c_L x_{ip}\right)H \right)
_vector_exponent_custom(
Gprime.slice_view(nprime, npr2),
Hprime.slice_view(0, nprime),
aprime.slice_view(0, nprime),
bprime.slice_view(nprime, npr2),
_tmp_bf_0,
)
            # In round 0, back up the current y^{-(nprime-1)} state so yinvpowR / HprimeR can continue from it
if round == 0:
yinvpowR.set_state(yinvpowL.last_idx, yinvpowL.cur)
_sc_mul(tmp, cL, x_ip)
_add_keys(_tmp_bf_0, _tmp_bf_0, _scalarmultH(_tmp_k_1, tmp))
_scalarmult_key(_tmp_bf_0, _tmp_bf_0, _INV_EIGHT)
L.read(round, _tmp_bf_0)
self.gc(24)
# Rc = 8^{-1} \left(\left( \sum_{i=0}^{\np} \ap_{i+\np}\Gp_{i}\quad + \bp_{i}\quad\Hp_{i+\np} \right)
# + \left(c_R x_{ip}\right)H \right)
_vector_exponent_custom(
Gprime.slice_view(0, nprime),
Hprime.slice_view(nprime, npr2),
aprime.slice_view(nprime, npr2),
bprime.slice_view(0, nprime),
_tmp_bf_0,
)
_sc_mul(tmp, cR, x_ip)
_add_keys(_tmp_bf_0, _tmp_bf_0, _scalarmultH(_tmp_k_1, tmp))
_scalarmult_key(_tmp_bf_0, _tmp_bf_0, _INV_EIGHT)
R.read(round, _tmp_bf_0)
self.gc(25)
# PAPER LINES 21-22
_hash_cache_mash(w_round, hash_cache, L.to(round), R.to(round))
if w_round == _ZERO:
return (0,)
# PAPER LINES 24-25, fold {G~, H~}
_invert(winv, w_round)
self.gc(26)
# PAPER LINES 28-29, fold {a, b} vectors
# aprime's high part is used as a buffer for other operations
_scalar_fold(aprime, w_round, winv)
aprime.resize(nprime)
self.gc(27)
_scalar_fold(bprime, winv, w_round)
bprime.resize(nprime)
self.gc(28)
# First fold produced to a new buffer, smaller one (G~ on-the-fly)
Gprime_new = KeyV(nprime) if round == 0 else Gprime
Gprime = _hadamard_fold(Gprime, winv, w_round, Gprime_new, 0)
Gprime.resize(nprime)
self.gc(30)
# Hadamard fold for H is special - linear scan only.
# Linear scan is slow, thus we have HprimeR.
if round == 0:
Hprime_new = KeyV(nprime)
Hprime = _hadamard_fold(
Hprime, w_round, winv, Hprime_new, 0, HprimeR, nprime
)
# Hprime = _hadamard_fold_linear(Hprime, w_round, winv, Hprime_new, 0)
else:
_hadamard_fold(Hprime, w_round, winv)
Hprime.resize(nprime)
if round == 0:
# del (Gprec, Hprec, yinvpowL, HprimeL)
del (Gprec, Hprec, yinvpowL, yinvpowR, HprimeL, HprimeR, tmp_pt)
self.gc(31)
round += 1
return L, R, aprime.to(0), bprime.to(0)
def verify(self, proof):
return self.verify_batch([proof])
def verify_batch(self, proofs, single_optim=True):
"""
BP batch verification
:param proofs:
:param single_optim: single proof memory optimization
:return:
"""
max_length = 0
for proof in proofs:
utils.ensure(_is_reduced(proof.taux), "Input scalar not in range")
utils.ensure(_is_reduced(proof.mu), "Input scalar not in range")
utils.ensure(_is_reduced(proof.a), "Input scalar not in range")
utils.ensure(_is_reduced(proof.b), "Input scalar not in range")
utils.ensure(_is_reduced(proof.t), "Input scalar not in range")
utils.ensure(len(proof.V) >= 1, "V does not have at least one element")
utils.ensure(len(proof.L) == len(proof.R), "|L| != |R|")
utils.ensure(len(proof.L) > 0, "Empty proof")
max_length = max(max_length, len(proof.L))
utils.ensure(max_length < 32, "At least one proof is too large")
maxMN = 1 << max_length
logN = 6
N = 1 << logN
tmp = _ensure_dst_key()
# setup weighted aggregates
is_single = len(proofs) == 1 and single_optim # ph4
z1 = _init_key(_ZERO)
z3 = _init_key(_ZERO)
m_z4 = _vector_dup(_ZERO, maxMN) if not is_single else None
m_z5 = _vector_dup(_ZERO, maxMN) if not is_single else None
m_y0 = _init_key(_ZERO)
y1 = _init_key(_ZERO)
muex_acc = _init_key(_ONE)
Gprec = self._gprec_aux(maxMN)
Hprec = self._hprec_aux(maxMN)
for proof in proofs:
M = 1
logM = 0
while M <= _BP_M and M < len(proof.V):
logM += 1
M = 1 << logM
utils.ensure(len(proof.L) == 6 + logM, "Proof is not the expected size")
MN = M * N
weight_y = crypto.encodeint(crypto.random_scalar())
weight_z = crypto.encodeint(crypto.random_scalar())
# Reconstruct the challenges
hash_cache = _hash_vct_to_scalar(None, proof.V)
y = _hash_cache_mash(None, hash_cache, proof.A, proof.S)
utils.ensure(y != _ZERO, "y == 0")
z = _hash_to_scalar(None, y)
_copy_key(hash_cache, z)
utils.ensure(z != _ZERO, "z == 0")
x = _hash_cache_mash(None, hash_cache, z, proof.T1, proof.T2)
utils.ensure(x != _ZERO, "x == 0")
x_ip = _hash_cache_mash(None, hash_cache, x, proof.taux, proof.mu, proof.t)
utils.ensure(x_ip != _ZERO, "x_ip == 0")
# PAPER LINE 61
_sc_mulsub(m_y0, proof.taux, weight_y, m_y0)
zpow = _vector_powers(z, M + 3)
k = _ensure_dst_key()
ip1y = _vector_power_sum(y, MN)
_sc_mulsub(k, zpow.to(2), ip1y, _ZERO)
for j in range(1, M + 1):
utils.ensure(j + 2 < len(zpow), "invalid zpow index")
_sc_mulsub(k, zpow.to(j + 2), _BP_IP12, k)
# VERIFY_line_61rl_new
_sc_muladd(tmp, z, ip1y, k)
_sc_sub(tmp, proof.t, tmp)
_sc_muladd(y1, tmp, weight_y, y1)
            weight_y8 = _sc_mul(None, weight_y, _EIGHT)
muex = MultiExpSequential(points=[pt for pt in proof.V])
for j in range(len(proof.V)):
_sc_mul(tmp, zpow.to(j + 2), weight_y8)
muex.add_scalar(_init_key(tmp))
_sc_mul(tmp, x, weight_y8)
muex.add_pair(_init_key(tmp), proof.T1)
xsq = _ensure_dst_key()
_sc_mul(xsq, x, x)
_sc_mul(tmp, xsq, weight_y8)
muex.add_pair(_init_key(tmp), proof.T2)
            weight_z8 = _sc_mul(None, weight_z, _EIGHT)
muex.add_pair(weight_z8, proof.A)
_sc_mul(tmp, x, weight_z8)
muex.add_pair(_init_key(tmp), proof.S)
_multiexp(tmp, muex, False)
_add_keys(muex_acc, muex_acc, tmp)
del muex
# Compute the number of rounds for the inner product
rounds = logM + logN
utils.ensure(rounds > 0, "Zero rounds")
# PAPER LINES 21-22
# The inner product challenges are computed per round
w = _ensure_dst_keyvect(None, rounds)
for i in range(rounds):
_hash_cache_mash(_tmp_bf_0, hash_cache, proof.L[i], proof.R[i])
w.read(i, _tmp_bf_0)
utils.ensure(w.to(i) != _ZERO, "w[i] == 0")
# Basically PAPER LINES 24-25
# Compute the curvepoints from G[i] and H[i]
yinvpow = _init_key(_ONE)
ypow = _init_key(_ONE)
yinv = _invert(None, y)
self.gc(61)
winv = _ensure_dst_keyvect(None, rounds)
for i in range(rounds):
_invert(_tmp_bf_0, w.to(i))
winv.read(i, _tmp_bf_0)
self.gc(62)
g_scalar = _ensure_dst_key()
h_scalar = _ensure_dst_key()
twoN = self._two_aux(N)
for i in range(MN):
_copy_key(g_scalar, proof.a)
_sc_mul(h_scalar, proof.b, yinvpow)
for j in range(rounds - 1, -1, -1):
J = len(w) - j - 1
if (i & (1 << j)) == 0:
_sc_mul(g_scalar, g_scalar, winv.to(J))
_sc_mul(h_scalar, h_scalar, w.to(J))
else:
_sc_mul(g_scalar, g_scalar, w.to(J))
_sc_mul(h_scalar, h_scalar, winv.to(J))
# Adjust the scalars using the exponents from PAPER LINE 62
_sc_add(g_scalar, g_scalar, z)
utils.ensure(2 + i // N < len(zpow), "invalid zpow index")
utils.ensure(i % N < len(twoN), "invalid twoN index")
_sc_mul(tmp, zpow.to(2 + i // N), twoN.to(i % N))
_sc_muladd(tmp, z, ypow, tmp)
_sc_mulsub(h_scalar, tmp, yinvpow, h_scalar)
if not is_single: # ph4
m_z4.read(i, _sc_mulsub(_tmp_bf_0, g_scalar, weight_z, m_z4[i]))
m_z5.read(i, _sc_mulsub(_tmp_bf_0, h_scalar, weight_z, m_z5[i]))
else:
_sc_mul(tmp, g_scalar, weight_z)
_sub_keys(
muex_acc, muex_acc, _scalarmult_key(tmp, Gprec.to(i), tmp)
)
_sc_mul(tmp, h_scalar, weight_z)
_sub_keys(
muex_acc, muex_acc, _scalarmult_key(tmp, Hprec.to(i), tmp)
)
if i != MN - 1:
_sc_mul(yinvpow, yinvpow, yinv)
_sc_mul(ypow, ypow, y)
if i & 15 == 0:
self.gc(62)
del (g_scalar, h_scalar, twoN)
self.gc(63)
_sc_muladd(z1, proof.mu, weight_z, z1)
muex = MultiExpSequential(
point_fnc=lambda i, d: proof.L[i // 2]
if i & 1 == 0
else proof.R[i // 2]
)
for i in range(rounds):
_sc_mul(tmp, w.to(i), w.to(i))
_sc_mul(tmp, tmp, weight_z8)
muex.add_scalar(tmp)
_sc_mul(tmp, winv.to(i), winv.to(i))
_sc_mul(tmp, tmp, weight_z8)
muex.add_scalar(tmp)
acc = _multiexp(None, muex, False)
_add_keys(muex_acc, muex_acc, acc)
_sc_mulsub(tmp, proof.a, proof.b, proof.t)
_sc_mul(tmp, tmp, x_ip)
_sc_muladd(z3, tmp, weight_z, z3)
_sc_sub(tmp, m_y0, z1)
z3p = _sc_sub(None, z3, y1)
check2 = crypto.encodepoint(
crypto.ge25519_double_scalarmult_base_vartime(
crypto.decodeint(z3p), crypto.xmr_H(), crypto.decodeint(tmp)
)
)
_add_keys(muex_acc, muex_acc, check2)
if not is_single: # ph4
muex = MultiExpSequential(
point_fnc=lambda i, d: Gprec.to(i // 2)
if i & 1 == 0
else Hprec.to(i // 2)
)
for i in range(maxMN):
muex.add_scalar(m_z4[i])
muex.add_scalar(m_z5[i])
_add_keys(muex_acc, muex_acc, _multiexp(None, muex, True))
if muex_acc != _ONE:
raise ValueError("Verification failure at step 2")
return True