diff --git a/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-monero.h b/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-monero.h index 37b8c16f7d..c48b07e876 100644 --- a/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-monero.h +++ b/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-monero.h @@ -972,13 +972,15 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN( mod_trezorcrypto_monero_xmr_random_scalar_obj, 0, 1, mod_trezorcrypto_monero_xmr_random_scalar); -/// def xmr_fast_hash(r: Optional[bytes], buff: bytes) -> bytes: +// clang-format off +/// def xmr_fast_hash(r: Optional[bytes], buff: bytes, length: int, offset: int) -> bytes: +// clang-format on /// """ /// XMR fast hash /// """ STATIC mp_obj_t mod_trezorcrypto_monero_xmr_fast_hash(size_t n_args, const mp_obj_t *args) { - const int off = n_args == 2 ? 0 : -1; + const int off = n_args >= 2 ? 0 : -1; uint8_t buff[32]; uint8_t *buff_use = buff; if (n_args > 1) { @@ -992,47 +994,76 @@ STATIC mp_obj_t mod_trezorcrypto_monero_xmr_fast_hash(size_t n_args, mp_buffer_info_t data; mp_get_buffer_raise(args[1 + off], &data, MP_BUFFER_READ); - xmr_fast_hash(buff_use, data.buf, data.len); - return n_args == 2 ? args[0] : mp_obj_new_bytes(buff, 32); + mp_int_t length = n_args >= 3 ? mp_obj_get_int(args[2]) : data.len; + mp_int_t offset = n_args >= 4 ? mp_obj_get_int(args[3]) : 0; + if (length < 0) length += data.len; + if (offset < 0) offset += data.len; + if (length < 0 || offset < 0 || offset + length > data.len) { + mp_raise_ValueError("Illegal offset/length"); + } + xmr_fast_hash(buff_use, (const char *)data.buf + offset, length); + return n_args >= 2 ? args[0] : mp_obj_new_bytes(buff, 32); } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN( - mod_trezorcrypto_monero_xmr_fast_hash_obj, 1, 2, + mod_trezorcrypto_monero_xmr_fast_hash_obj, 1, 4, mod_trezorcrypto_monero_xmr_fast_hash); -/// def xmr_hash_to_ec(r: Optional[Ge25519], buff: bytes) -> Ge25519: +// clang-format off +/// def xmr_hash_to_ec(r: Optional[Ge25519], buff: bytes, length: int, offset: +/// int) -> Ge25519: +// clang-format on /// """ /// XMR hashing to EC point /// """ STATIC mp_obj_t mod_trezorcrypto_monero_xmr_hash_to_ec(size_t n_args, const mp_obj_t *args) { - const bool res_arg = n_args == 2; + const bool res_arg = n_args >= 2; const int off = res_arg ? 0 : -1; mp_obj_t res = mp_obj_new_ge25519_r(res_arg ? args[0] : mp_const_none); mp_buffer_info_t data; mp_get_buffer_raise(args[1 + off], &data, MP_BUFFER_READ); - xmr_hash_to_ec(&MP_OBJ_GE25519(res), data.buf, data.len); + mp_int_t length = n_args >= 3 ? mp_obj_get_int(args[2]) : data.len; + mp_int_t offset = n_args >= 4 ? mp_obj_get_int(args[3]) : 0; + if (length < 0) length += data.len; + if (offset < 0) offset += data.len; + if (length < 0 || offset < 0 || offset + length > data.len) { + mp_raise_ValueError("Illegal offset/length"); + } + + xmr_hash_to_ec(&MP_OBJ_GE25519(res), (const char *)data.buf + offset, length); return res; } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN( - mod_trezorcrypto_monero_xmr_hash_to_ec_obj, 1, 2, + mod_trezorcrypto_monero_xmr_hash_to_ec_obj, 1, 4, mod_trezorcrypto_monero_xmr_hash_to_ec); -/// def xmr_hash_to_scalar(r: Optional[Sc25519], buff: bytes) -> Sc25519: +// clang-format off +/// def xmr_hash_to_scalar(r: Optional[Sc25519], buff: bytes, length: int, +/// offset: int) -> Sc25519: +// clang-format on /// """ /// XMR hashing to EC scalar /// """ STATIC mp_obj_t mod_trezorcrypto_monero_xmr_hash_to_scalar( size_t n_args, const mp_obj_t *args) { - const bool res_arg = n_args == 2; + const bool res_arg = n_args >= 2; const int off = res_arg ? 0 : -1; mp_obj_t res = mp_obj_new_scalar_r(res_arg ? args[0] : mp_const_none); mp_buffer_info_t data; mp_get_buffer_raise(args[1 + off], &data, MP_BUFFER_READ); - xmr_hash_to_scalar(MP_OBJ_SCALAR(res), data.buf, data.len); + mp_int_t length = n_args >= 3 ? mp_obj_get_int(args[2]) : data.len; + mp_int_t offset = n_args >= 4 ? mp_obj_get_int(args[3]) : 0; + if (length < 0) length += data.len; + if (offset < 0) offset += data.len; + if (length < 0 || offset < 0 || offset + length > data.len) { + mp_raise_ValueError("Illegal offset/length"); + } + xmr_hash_to_scalar(MP_OBJ_SCALAR(res), (const char *)data.buf + offset, + length); return res; } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN( - mod_trezorcrypto_monero_xmr_hash_to_scalar_obj, 1, 2, + mod_trezorcrypto_monero_xmr_hash_to_scalar_obj, 1, 4, mod_trezorcrypto_monero_xmr_hash_to_scalar); /// def xmr_derivation_to_scalar( diff --git a/core/mocks/generated/trezorcrypto/monero.pyi b/core/mocks/generated/trezorcrypto/monero.pyi index 9884eb02ff..a1bc84a390 100644 --- a/core/mocks/generated/trezorcrypto/monero.pyi +++ b/core/mocks/generated/trezorcrypto/monero.pyi @@ -292,21 +292,23 @@ def xmr_random_scalar(r: Optional[Sc25519] = None) -> Sc25519: # extmod/modtrezorcrypto/modtrezorcrypto-monero.h -def xmr_fast_hash(r: Optional[bytes], buff: bytes) -> bytes: +def xmr_fast_hash(r: Optional[bytes], buff: bytes, length: int, offset: int) -> bytes: """ XMR fast hash """ # extmod/modtrezorcrypto/modtrezorcrypto-monero.h -def xmr_hash_to_ec(r: Optional[Ge25519], buff: bytes) -> Ge25519: +def xmr_hash_to_ec(r: Optional[Ge25519], buff: bytes, length: int, offset: +int) -> Ge25519: """ XMR hashing to EC point """ # extmod/modtrezorcrypto/modtrezorcrypto-monero.h -def xmr_hash_to_scalar(r: Optional[Sc25519], buff: bytes) -> Sc25519: +def xmr_hash_to_scalar(r: Optional[Sc25519], buff: bytes, length: int, +offset: int) -> Sc25519: """ XMR hashing to EC scalar """ diff --git a/core/src/apps/monero/xmr/bulletproof.py b/core/src/apps/monero/xmr/bulletproof.py index 1b6d991ade..943da6c487 100644 --- a/core/src/apps/monero/xmr/bulletproof.py +++ b/core/src/apps/monero/xmr/bulletproof.py @@ -2,7 +2,7 @@ import gc from micropython import const from trezor import utils -from trezor.utils import memcpy as _memcpy +from trezor.utils import memcpy as tmemcpy from apps.monero.xmr import crypto from apps.monero.xmr.serialize.int_serialize import dump_uvarint_b_into, uvarint_size @@ -15,7 +15,7 @@ _BP_M = const(16) # maximal number of bulletproofs _ZERO = b"\x00" * 32 _ONE = b"\x01" + b"\x00" * 31 -# _TWO = b"\x02" + b"\x00" * 31 +_TWO = b"\x02" + b"\x00" * 31 _EIGHT = b"\x08" + b"\x00" * 31 _INV_EIGHT = crypto.INV_EIGHT _MINUS_ONE = b"\xec\xd3\xf5\x5c\x1a\x63\x12\x58\xd6\x9c\xf7\xa2\xde\xf9\xde\x14\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10" @@ -35,21 +35,20 @@ _BP_IP12 = b"\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x0 # Caution has to be exercised when using the registers and operations using the registers # -tmp_bf_0 = bytearray(32) -tmp_bf_1 = bytearray(32) -tmp_bf_2 = bytearray(32) -tmp_bf_exp = bytearray(11 + 32 + 4) -tmp_bf_exp_mv = memoryview(tmp_bf_exp) +_tmp_bf_0 = bytearray(32) +_tmp_bf_1 = bytearray(32) +_tmp_bf_2 = bytearray(32) +_tmp_bf_exp = bytearray(11 + 32 + 4) -tmp_pt_1 = crypto.new_point() -tmp_pt_2 = crypto.new_point() -tmp_pt_3 = crypto.new_point() -tmp_pt_4 = crypto.new_point() +_tmp_pt_1 = crypto.new_point() +_tmp_pt_2 = crypto.new_point() +_tmp_pt_3 = crypto.new_point() +_tmp_pt_4 = crypto.new_point() -tmp_sc_1 = crypto.new_scalar() -tmp_sc_2 = crypto.new_scalar() -tmp_sc_3 = crypto.new_scalar() -tmp_sc_4 = crypto.new_scalar() +_tmp_sc_1 = crypto.new_scalar() +_tmp_sc_2 = crypto.new_scalar() +_tmp_sc_3 = crypto.new_scalar() +_tmp_sc_4 = crypto.new_scalar() def _ensure_dst_key(dst=None): @@ -60,188 +59,199 @@ def _ensure_dst_key(dst=None): def memcpy(dst, dst_off, src, src_off, len): if dst is not None: - _memcpy(dst, dst_off, src, src_off, len) + tmemcpy(dst, dst_off, src, src_off, len) return dst -def alloc_scalars(num=1): +def _alloc_scalars(num=1): return (crypto.new_scalar() for _ in range(num)) -def copy_key(dst, src): +def _copy_key(dst, src): for i in range(32): dst[i] = src[i] return dst -def init_key(val, dst=None): +def _init_key(val, dst=None): dst = _ensure_dst_key(dst) - return copy_key(dst, val) + return _copy_key(dst, val) -def gc_iter(i): +def _gc_iter(i): if i & 127 == 0: gc.collect() -def invert(dst, x): +def _invert(dst, x): dst = _ensure_dst_key(dst) - crypto.decodeint_into_noreduce(tmp_sc_1, x) - crypto.sc_inv_into(tmp_sc_2, tmp_sc_1) - crypto.encodeint_into(dst, tmp_sc_2) + crypto.decodeint_into_noreduce(_tmp_sc_1, x) + crypto.sc_inv_into(_tmp_sc_2, _tmp_sc_1) + crypto.encodeint_into(dst, _tmp_sc_2) return dst -def scalarmult_key(dst, P, s): +def _scalarmult_key(dst, P, s, s_raw=None, tmp_pt=_tmp_pt_1): dst = _ensure_dst_key(dst) - crypto.decodepoint_into(tmp_pt_1, P) - crypto.decodeint_into_noreduce(tmp_sc_1, s) - crypto.scalarmult_into(tmp_pt_2, tmp_pt_1, tmp_sc_1) - crypto.encodepoint_into(dst, tmp_pt_2) + crypto.decodepoint_into(tmp_pt, P) + if s: + crypto.decodeint_into_noreduce(_tmp_sc_1, s) + crypto.scalarmult_into(tmp_pt, tmp_pt, _tmp_sc_1 if s else s_raw) + crypto.encodepoint_into(dst, tmp_pt) return dst -def scalarmultH(dst, x): +def _scalarmultH(dst, x): dst = _ensure_dst_key(dst) - crypto.decodeint_into(tmp_sc_1, x) - crypto.scalarmult_into(tmp_pt_1, _XMR_HP, tmp_sc_1) - crypto.encodepoint_into(dst, tmp_pt_1) + crypto.decodeint_into(_tmp_sc_1, x) + crypto.scalarmult_into(_tmp_pt_1, _XMR_HP, _tmp_sc_1) + crypto.encodepoint_into(dst, _tmp_pt_1) return dst -def scalarmult_base(dst, x): +def _scalarmult_base(dst, x): dst = _ensure_dst_key(dst) - crypto.decodeint_into_noreduce(tmp_sc_1, x) - crypto.scalarmult_base_into(tmp_pt_1, tmp_sc_1) - crypto.encodepoint_into(dst, tmp_pt_1) + crypto.decodeint_into_noreduce(_tmp_sc_1, x) + crypto.scalarmult_base_into(_tmp_pt_1, _tmp_sc_1) + crypto.encodepoint_into(dst, _tmp_pt_1) return dst -def sc_gen(dst=None): +def _sc_gen(dst=None): dst = _ensure_dst_key(dst) - crypto.random_scalar(tmp_sc_1) - crypto.encodeint_into(dst, tmp_sc_1) + crypto.random_scalar(_tmp_sc_1) + crypto.encodeint_into(dst, _tmp_sc_1) return dst -def sc_add(dst, a, b): +def _sc_add(dst, a, b): dst = _ensure_dst_key(dst) - crypto.decodeint_into_noreduce(tmp_sc_1, a) - crypto.decodeint_into_noreduce(tmp_sc_2, b) - crypto.sc_add_into(tmp_sc_3, tmp_sc_1, tmp_sc_2) - crypto.encodeint_into(dst, tmp_sc_3) + crypto.decodeint_into_noreduce(_tmp_sc_1, a) + crypto.decodeint_into_noreduce(_tmp_sc_2, b) + crypto.sc_add_into(_tmp_sc_3, _tmp_sc_1, _tmp_sc_2) + crypto.encodeint_into(dst, _tmp_sc_3) return dst -def sc_sub(dst, a, b): +def _sc_sub(dst, a, b, a_raw=None, b_raw=None): dst = _ensure_dst_key(dst) - crypto.decodeint_into_noreduce(tmp_sc_1, a) - crypto.decodeint_into_noreduce(tmp_sc_2, b) - crypto.sc_sub_into(tmp_sc_3, tmp_sc_1, tmp_sc_2) - crypto.encodeint_into(dst, tmp_sc_3) + if a: + crypto.decodeint_into_noreduce(_tmp_sc_1, a) + if b: + crypto.decodeint_into_noreduce(_tmp_sc_2, b) + crypto.sc_sub_into(_tmp_sc_3, _tmp_sc_1 if a else a_raw, _tmp_sc_2 if b else b_raw) + crypto.encodeint_into(dst, _tmp_sc_3) return dst -def sc_mul(dst, a, b): +def _sc_mul(dst, a, b=None, b_raw=None): dst = _ensure_dst_key(dst) - crypto.decodeint_into_noreduce(tmp_sc_1, a) - crypto.decodeint_into_noreduce(tmp_sc_2, b) - crypto.sc_mul_into(tmp_sc_3, tmp_sc_1, tmp_sc_2) - crypto.encodeint_into(dst, tmp_sc_3) + crypto.decodeint_into_noreduce(_tmp_sc_1, a) + if b: + crypto.decodeint_into_noreduce(_tmp_sc_2, b) + crypto.sc_mul_into(_tmp_sc_3, _tmp_sc_1, _tmp_sc_2 if b else b_raw) + crypto.encodeint_into(dst, _tmp_sc_3) return dst -def sc_muladd(dst, a, b, c): - dst = _ensure_dst_key(dst) - crypto.decodeint_into_noreduce(tmp_sc_1, a) - crypto.decodeint_into_noreduce(tmp_sc_2, b) - crypto.decodeint_into_noreduce(tmp_sc_3, c) - crypto.sc_muladd_into(tmp_sc_4, tmp_sc_1, tmp_sc_2, tmp_sc_3) - crypto.encodeint_into(dst, tmp_sc_4) +def _sc_muladd(dst, a, b, c, a_raw=None, b_raw=None, c_raw=None, raw=False): + dst = _ensure_dst_key(dst) if not raw else (dst if dst else crypto.new_scalar()) + if a: + crypto.decodeint_into_noreduce(_tmp_sc_1, a) + if b: + crypto.decodeint_into_noreduce(_tmp_sc_2, b) + if c: + crypto.decodeint_into_noreduce(_tmp_sc_3, c) + crypto.sc_muladd_into( + _tmp_sc_4 if not raw else dst, + _tmp_sc_1 if a else a_raw, + _tmp_sc_2 if b else b_raw, + _tmp_sc_3 if c else c_raw, + ) + if not raw: + crypto.encodeint_into(dst, _tmp_sc_4) return dst -def sc_mulsub(dst, a, b, c): +def _sc_mulsub(dst, a, b, c): dst = _ensure_dst_key(dst) - crypto.decodeint_into_noreduce(tmp_sc_1, a) - crypto.decodeint_into_noreduce(tmp_sc_2, b) - crypto.decodeint_into_noreduce(tmp_sc_3, c) - crypto.sc_mulsub_into(tmp_sc_4, tmp_sc_1, tmp_sc_2, tmp_sc_3) - crypto.encodeint_into(dst, tmp_sc_4) + crypto.decodeint_into_noreduce(_tmp_sc_1, a) + crypto.decodeint_into_noreduce(_tmp_sc_2, b) + crypto.decodeint_into_noreduce(_tmp_sc_3, c) + crypto.sc_mulsub_into(_tmp_sc_4, _tmp_sc_1, _tmp_sc_2, _tmp_sc_3) + crypto.encodeint_into(dst, _tmp_sc_4) return dst -def add_keys(dst, A, B): +def _add_keys(dst, A, B): dst = _ensure_dst_key(dst) - crypto.decodepoint_into(tmp_pt_1, A) - crypto.decodepoint_into(tmp_pt_2, B) - crypto.point_add_into(tmp_pt_3, tmp_pt_1, tmp_pt_2) - crypto.encodepoint_into(dst, tmp_pt_3) + crypto.decodepoint_into(_tmp_pt_1, A) + crypto.decodepoint_into(_tmp_pt_2, B) + crypto.point_add_into(_tmp_pt_3, _tmp_pt_1, _tmp_pt_2) + crypto.encodepoint_into(dst, _tmp_pt_3) return dst -def sub_keys(dst, A, B): +def _sub_keys(dst, A, B): dst = _ensure_dst_key(dst) - crypto.decodepoint_into(tmp_pt_1, A) - crypto.decodepoint_into(tmp_pt_2, B) - crypto.point_sub_into(tmp_pt_3, tmp_pt_1, tmp_pt_2) - crypto.encodepoint_into(dst, tmp_pt_3) + crypto.decodepoint_into(_tmp_pt_1, A) + crypto.decodepoint_into(_tmp_pt_2, B) + crypto.point_sub_into(_tmp_pt_3, _tmp_pt_1, _tmp_pt_2) + crypto.encodepoint_into(dst, _tmp_pt_3) return dst -def add_keys2(dst, a, b, B): +def _add_keys2(dst, a, b, B): dst = _ensure_dst_key(dst) - crypto.decodeint_into_noreduce(tmp_sc_1, a) - crypto.decodeint_into_noreduce(tmp_sc_2, b) - crypto.decodepoint_into(tmp_pt_1, B) - crypto.add_keys2_into(tmp_pt_2, tmp_sc_1, tmp_sc_2, tmp_pt_1) - crypto.encodepoint_into(dst, tmp_pt_2) + crypto.decodeint_into_noreduce(_tmp_sc_1, a) + crypto.decodeint_into_noreduce(_tmp_sc_2, b) + crypto.decodepoint_into(_tmp_pt_1, B) + crypto.add_keys2_into(_tmp_pt_2, _tmp_sc_1, _tmp_sc_2, _tmp_pt_1) + crypto.encodepoint_into(dst, _tmp_pt_2) return dst -def add_keys3(dst, a, A, b, B): +def _add_keys3(dst, a, A, b, B): dst = _ensure_dst_key(dst) - crypto.decodeint_into_noreduce(tmp_sc_1, a) - crypto.decodeint_into_noreduce(tmp_sc_2, b) - crypto.decodepoint_into(tmp_pt_1, A) - crypto.decodepoint_into(tmp_pt_2, B) - crypto.add_keys3_into(tmp_pt_3, tmp_sc_1, tmp_pt_1, tmp_sc_2, tmp_pt_2) - crypto.encodepoint_into(dst, tmp_pt_3) + crypto.decodeint_into_noreduce(_tmp_sc_1, a) + crypto.decodeint_into_noreduce(_tmp_sc_2, b) + crypto.decodepoint_into(_tmp_pt_1, A) + crypto.decodepoint_into(_tmp_pt_2, B) + crypto.add_keys3_into(_tmp_pt_3, _tmp_sc_1, _tmp_pt_1, _tmp_sc_2, _tmp_pt_2) + crypto.encodepoint_into(dst, _tmp_pt_3) return dst -def hash_to_scalar(dst, data): +def _hash_to_scalar(dst, data): dst = _ensure_dst_key(dst) - crypto.hash_to_scalar_into(tmp_sc_1, data) - crypto.encodeint_into(dst, tmp_sc_1) + crypto.hash_to_scalar_into(_tmp_sc_1, data) + crypto.encodeint_into(dst, _tmp_sc_1) return dst -def hash_vct_to_scalar(dst, data): +def _hash_vct_to_scalar(dst, data): dst = _ensure_dst_key(dst) ctx = crypto.get_keccak() for x in data: ctx.update(x) hsh = ctx.digest() - crypto.decodeint_into(tmp_sc_1, hsh) - crypto.encodeint_into(tmp_bf_1, tmp_sc_1) - copy_key(dst, tmp_bf_1) + crypto.decodeint_into(_tmp_sc_1, hsh) + crypto.encodeint_into(dst, _tmp_sc_1) return dst -def get_exponent(dst, base, idx): +def _get_exponent(dst, base, idx): dst = _ensure_dst_key(dst) salt = b"bulletproof" - idx_size = uvarint_size(idx) - final_size = len(salt) + 32 + idx_size - buff = tmp_bf_exp_mv - memcpy(buff, 0, base, 0, 32) - memcpy(buff, 32, salt, 0, len(salt)) - dump_uvarint_b_into(idx, buff, 32 + len(salt)) - crypto.keccak_hash_into(tmp_bf_1, buff[:final_size]) - crypto.hash_to_point_into(tmp_pt_1, tmp_bf_1) - crypto.encodepoint_into(dst, tmp_pt_1) + lsalt = const(11) # len(salt) + final_size = lsalt + 32 + uvarint_size(idx) + memcpy(_tmp_bf_exp, 0, base, 0, 32) + memcpy(_tmp_bf_exp, 32, salt, 0, lsalt) + dump_uvarint_b_into(idx, _tmp_bf_exp, 32 + lsalt) + crypto.keccak_hash_into(_tmp_bf_1, _tmp_bf_exp, final_size) + crypto.hash_to_point_into(_tmp_pt_4, _tmp_bf_1) + crypto.encodepoint_into(dst, _tmp_pt_4) return dst @@ -255,6 +265,8 @@ class KeyVBase: Base KeyVector object """ + __slots__ = ("current_idx", "size") + def __init__(self, elems=64): self.current_idx = 0 self.size = elems @@ -286,8 +298,9 @@ class KeyVBase: def __len__(self): return self.size - def to(self, idx, buff, offset=0): - return memcpy(buff, offset, self[self.idxize(idx)], 0, 32) + def to(self, idx, buff=None, offset=0): + buff = _ensure_dst_key(buff) + return memcpy(buff, offset, self.to(self.idxize(idx)), 0, 32) def read(self, idx, buff, offset=0): raise ValueError @@ -301,20 +314,26 @@ class KeyVBase: return KeyVSliced(self, start, stop) +_CHBITS = const(5) +_CHSIZE = const(1 << _CHBITS) + + class KeyV(KeyVBase): """ KeyVector abstraction Constant precomputed buffers = bytes, frozen. Same operation as normal. - Non-constant KeyVector is separated to 64 elements chunks to avoid problems with - the heap fragmentation. In this way the chunks are more probable to be correctly - allocated as smaller chunk of continuous memory is required. Chunk is assumed to - have 64 elements at all times to minimize corner cases handling. BP require either - multiple of 64 elements vectors or less than 64. + Non-constant KeyVector is separated to _CHSIZE elements chunks to avoid problems with + the heap fragmentation. In this it is more probable that the chunks are correctly + allocated as smaller continuous memory is required. Chunk is assumed to + have _CHSIZE elements at all times to minimize corner cases handling. BP require either + multiple of _CHSIZE elements vectors or less than _CHSIZE. Some chunk-dependent cases are not implemented as they are currently not needed in the BP. """ + __slots__ = ("current_idx", "size", "d", "mv", "const", "cur", "chunked") + def __init__(self, elems=64, buffer=None, const=False, no_init=False): super().__init__(elems) self.d = None @@ -334,10 +353,10 @@ class KeyV(KeyVBase): self._set_mv() def _set_d(self, elems): - if elems > 64 and elems % 64 == 0: + if elems > _CHSIZE and elems % _CHSIZE == 0: self.chunked = True gc.collect() - self.d = [bytearray(32 * 64) for _ in range(elems // 64)] + self.d = [bytearray(32 * _CHSIZE) for _ in range(elems // _CHSIZE)] else: self.chunked = False @@ -354,7 +373,7 @@ class KeyV(KeyVBase): Creates new memoryview on access. """ if self.chunked: - raise ValueError("Not supported") # not needed + return self.to(item) item = self.idxize(item) return self.mv[item * 32 : (item + 1) * 32] @@ -373,8 +392,8 @@ class KeyV(KeyVBase): memcpy( buff if buff else self.cur, offset, - self.d[idx >> 6], - (idx & 63) << 5, + self.d[idx >> _CHBITS], + (idx & (_CHSIZE - 1)) << 5, 32, ) else: @@ -384,7 +403,7 @@ class KeyV(KeyVBase): def read(self, idx, buff, offset=0): idx = self.idxize(idx) if self.chunked: - memcpy(self.d[idx >> 6], (idx & 63) << 5, buff, offset, 32) + memcpy(self.d[idx >> _CHBITS], (idx & (_CHSIZE - 1)) << 5, buff, offset, 32) else: memcpy(self.d, idx << 5, buff, offset, 32) @@ -392,7 +411,7 @@ class KeyV(KeyVBase): if self.size == nsize: return self - if self.chunked and nsize <= 64: + if self.chunked and nsize <= _CHSIZE: self.chunked = False # de-chunk if self.size > nsize and realloc: gc.collect() @@ -404,8 +423,20 @@ class KeyV(KeyVBase): gc.collect() self.d = bytearray(nsize << 5) + elif self.chunked and self.size < nsize: + if nsize % _CHSIZE != 0 or realloc or chop: + raise ValueError("Unsupported") # not needed + for i in range((nsize - self.size) // _CHSIZE): + self.d.append(bytearray(32 * _CHSIZE)) + elif self.chunked: - raise ValueError("Unsupported") # not needed + if nsize % _CHSIZE != 0: + raise ValueError("Unsupported") # not needed + for i in range((self.size - nsize) // _CHSIZE): + self.d.pop() + if realloc: + for i in range(nsize // _CHSIZE): + self.d[i] = bytearray(self.d[i]) else: if self.size > nsize and realloc: @@ -439,20 +470,18 @@ class KeyV(KeyVBase): if not self.chunked and not src.chunked: memcpy(self.d, 0, src.d, offset << 5, nsize << 5) - elif self.chunked and not src.chunked: - raise ValueError("Unsupported") # not needed - - elif self.chunked and src.chunked: - raise ValueError("Unsupported") # not needed + elif self.chunked and not src.chunked or self.chunked and src.chunked: + for i in range(nsize): + self.read(i, src.to(i + offset)) elif not self.chunked and src.chunked: - for i in range(nsize >> 6): + for i in range(nsize >> _CHBITS): memcpy( self.d, i << 11, - src.d[i + (offset >> 6)], - (offset & 63) << 5 if i == 0 else 0, - nsize << 5 if i <= nsize >> 6 else (nsize & 64) << 5, + src.d[i + (offset >> _CHBITS)], + (offset & (_CHSIZE - 1)) << 5 if i == 0 else 0, + nsize << 5 if i <= nsize >> _CHBITS else (nsize & _CHSIZE) << 5, ) @@ -461,18 +490,35 @@ class KeyVEval(KeyVBase): KeyVector computed / evaluated on demand """ - def __init__(self, elems=64, src=None): + __slots__ = ("current_idx", "size", "fnc", "raw", "scalar", "buff") + + def __init__(self, elems=64, src=None, raw=False, scalar=True): super().__init__(elems) self.fnc = src - self.buff = _ensure_dst_key() - self.mv = memoryview(self.buff) + self.raw = raw + self.scalar = scalar + self.buff = ( + _ensure_dst_key() + if not raw + else (crypto.new_scalar() if scalar else crypto.new_point()) + ) def __getitem__(self, item): return self.fnc(self.idxize(item), self.buff) def to(self, idx, buff=None, offset=0): self.fnc(self.idxize(idx), self.buff) - memcpy(buff, offset, self.buff, 0, 32) + if self.raw: + if offset != 0: + raise ValueError("Not supported") + if self.scalar and buff: + return crypto.sc_copy(buff, self.buff) + elif self.scalar: + return self.buff + else: + raise ValueError("Not supported") + else: + memcpy(buff, offset, self.buff, 0, 32) return buff if buff else self.buff @@ -482,6 +528,8 @@ class KeyVSized(KeyVBase): (e.g., precomputed, but has to have exact size for further computations) """ + __slots__ = ("current_idx", "size", "wrapped") + def __init__(self, wrapped, new_size): super().__init__(new_size) self.wrapped = wrapped @@ -494,9 +542,11 @@ class KeyVSized(KeyVBase): class KeyVConst(KeyVBase): + __slots__ = ("current_idx", "size", "elem") + def __init__(self, size, elem, copy=True): super().__init__(size) - self.elem = init_key(elem) if copy else elem + self.elem = _init_key(elem) if copy else elem def __getitem__(self, item): return self.elem @@ -513,6 +563,8 @@ class KeyVPrecomp(KeyVBase): but possible to compute further """ + __slots__ = ("current_idx", "size", "precomp_prefix", "aux_comp_fnc", "buff") + def __init__(self, size, precomp_prefix, aux_comp_fnc): super().__init__(size) self.precomp_prefix = precomp_prefix @@ -539,6 +591,8 @@ class KeyVSliced(KeyVBase): Sliced in-memory vector version, remapping """ + __slots__ = ("current_idx", "size", "wrapped", "offset") + def __init__(self, src, start, stop): super().__init__(stop - start) self.wrapped = src @@ -565,10 +619,13 @@ class KeyVPowers(KeyVBase): Vector of x^i. Allows only sequential access (no jumping). Resets on [0,1] access. """ - def __init__(self, size, x, **kwargs): + __slots__ = ("current_idx", "size", "x", "raw", "cur", "last_idx") + + def __init__(self, size, x, raw=False, **kwargs): super().__init__(size) - self.x = x - self.cur = bytearray(32) + self.x = x if not raw else crypto.decodeint_into_noreduce(None, x) + self.raw = raw + self.cur = bytearray(32) if not raw else crypto.new_scalar() self.last_idx = 0 def __getitem__(self, item): @@ -577,39 +634,130 @@ class KeyVPowers(KeyVBase): self.last_idx = item if item == 0: - return copy_key(self.cur, _ONE) + return ( + _copy_key(self.cur, _ONE) + if not self.raw + else crypto.decodeint_into_noreduce(None, _ONE) + ) elif item == 1: - return copy_key(self.cur, self.x) + return ( + _copy_key(self.cur, self.x) + if not self.raw + else crypto.sc_copy(self.cur, self.x) + ) + elif item == prev: + return self.cur elif item == prev + 1: - return sc_mul(self.cur, self.cur, self.x) + return ( + _sc_mul(self.cur, self.cur, self.x) + if not self.raw + else crypto.sc_mul_into(self.cur, self.cur, self.x) + ) + else: + raise IndexError("Only linear scan allowed: %s, %s" % (prev, item)) + + def set_state(self, idx, val): + self.item = idx + self.last_idx = idx + if self.raw: + return crypto.sc_copy(self.cur, val) + else: + return _copy_key(self.cur, val) + + +class KeyR0(KeyVBase): + """ + Vector r0. Allows only sequential access (no jumping). Resets on [0,1] access. + zt_i = z^{2 + \floor{i/N}} 2^{i % N} + r0_i = ((a_{Ri} + z) y^{i}) + zt_i + + Could be composed from smaller vectors, but RAW returns are required + """ + + __slots__ = ( + "current_idx", + "size", + "N", + "aR", + "raw", + "y", + "yp", + "z", + "zt", + "p2", + "res", + "cur", + "last_idx", + ) + + def __init__(self, size, N, aR, y, z, raw=False, **kwargs): + super().__init__(size) + self.N = N + self.aR = aR + self.raw = raw + self.y = crypto.decodeint_into_noreduce(None, y) + self.yp = crypto.new_scalar() # y^{i} + self.z = crypto.decodeint_into_noreduce(None, z) + self.zt = crypto.new_scalar() # z^{2 + \floor{i/N}} + self.p2 = crypto.new_scalar() # 2^{i \% N} + self.res = crypto.new_scalar() # tmp_sc_1 + + self.cur = bytearray(32) if not raw else None + self.last_idx = 0 + self.reset() + + def reset(self): + crypto.decodeint_into_noreduce(self.yp, _ONE) + crypto.decodeint_into_noreduce(self.p2, _ONE) + crypto.sc_mul_into(self.zt, self.z, self.z) + + def __getitem__(self, item): + prev = self.last_idx + item = self.idxize(item) + self.last_idx = item + + # Const init for eval + if item == 0: # Reset on first item access + self.reset() + + elif item == prev + 1: + crypto.sc_mul_into(self.yp, self.yp, self.y) # ypow + if item % self.N == 0: + crypto.sc_mul_into(self.zt, self.zt, self.z) # zt + crypto.decodeint_into_noreduce(self.p2, _ONE) # p2 reset + else: + crypto.decodeint_into_noreduce(self.res, _TWO) # p2 + crypto.sc_mul_into(self.p2, self.p2, self.res) # p2 + + elif item == prev: # No advancing + pass + else: raise IndexError("Only linear scan allowed") + # Eval r0[i] + if ( + item == 0 or item != prev + ): # if True not present, fails with cross dot product + crypto.decodeint_into_noreduce(self.res, self.aR.to(item)) # aR[i] + crypto.sc_add_into(self.res, self.res, self.z) # aR[i] + z + crypto.sc_mul_into(self.res, self.res, self.yp) # (aR[i] + z) * y^i + crypto.sc_muladd_into( + self.res, self.zt, self.p2, self.res + ) # (aR[i] + z) * y^i + z^{2 + \floor{i/N}} 2^{i \% N} -class KeyVZtwo(KeyVBase): - """ - Ztwo vector - see vector_z_two_i - """ - - def __init__(self, N, logN, M, zpow, twoN, raw=False): - super().__init__(N * M) - self.N = N - self.logN = logN - self.M = M - self.zpow = zpow - self.twoN = twoN - self.raw = raw - self.sc = crypto.new_scalar() - self.cur = bytearray(32) if not raw else None - - def __getitem__(self, item): - vector_z_two_i(self.logN, self.zpow, self.twoN, self.idxize(item), self.sc) if self.raw: - return self.sc + return self.res - crypto.encodeint_into(self.cur, self.sc) + crypto.encodeint_into(self.cur, self.res) return self.cur + def to(self, idx, buff=None, offset=0): + r = self[idx] + if buff is None: + return r + return memcpy(buff, offset, r, 0, 32) + def _ensure_dst_keyvect(dst=None, size=None): if dst is None: @@ -620,27 +768,41 @@ def _ensure_dst_keyvect(dst=None, size=None): return dst -def const_vector(val, elems=_BP_N, copy=True): +def _const_vector(val, elems=_BP_N, copy=True): return KeyVConst(elems, val, copy) -def vector_exponent_custom(A, B, a, b, dst=None): +def _vector_exponent_custom(A, B, a, b, dst=None, a_raw=None, b_raw=None): + """ + \\sum_{i=0}^{|A|} a_i A_i + b_i B_i + """ dst = _ensure_dst_key(dst) - crypto.identity_into(tmp_pt_2) + crypto.identity_into(_tmp_pt_2) - for i in range(len(a)): - crypto.decodeint_into_noreduce(tmp_sc_1, a.to(i)) - crypto.decodepoint_into(tmp_pt_3, A.to(i)) - crypto.decodeint_into_noreduce(tmp_sc_2, b.to(i)) - crypto.decodepoint_into(tmp_pt_4, B.to(i)) - crypto.add_keys3_into(tmp_pt_1, tmp_sc_1, tmp_pt_3, tmp_sc_2, tmp_pt_4) - crypto.point_add_into(tmp_pt_2, tmp_pt_2, tmp_pt_1) - gc_iter(i) - crypto.encodepoint_into(dst, tmp_pt_2) + for i in range(len(a or a_raw)): + if a: + crypto.decodeint_into_noreduce(_tmp_sc_1, a.to(i)) + crypto.decodepoint_into(_tmp_pt_3, A.to(i)) + if b: + crypto.decodeint_into_noreduce(_tmp_sc_2, b.to(i)) + crypto.decodepoint_into(_tmp_pt_4, B.to(i)) + crypto.add_keys3_into( + _tmp_pt_1, + _tmp_sc_1 if a else a_raw.to(i), + _tmp_pt_3, + _tmp_sc_2 if b else b_raw.to(i), + _tmp_pt_4, + ) + crypto.point_add_into(_tmp_pt_2, _tmp_pt_2, _tmp_pt_1) + _gc_iter(i) + crypto.encodepoint_into(dst, _tmp_pt_2) return dst -def vector_powers(x, n, dst=None, dynamic=False, **kwargs): +def _vector_powers(x, n, dst=None, dynamic=False, **kwargs): + """ + r_i = x^i + """ if dynamic: return KeyVPowers(n, x, **kwargs) dst = _ensure_dst_keyvect(dst, n) @@ -651,186 +813,181 @@ def vector_powers(x, n, dst=None, dynamic=False, **kwargs): return dst dst.read(1, x) - crypto.decodeint_into_noreduce(tmp_sc_1, x) - crypto.decodeint_into_noreduce(tmp_sc_2, x) + crypto.decodeint_into_noreduce(_tmp_sc_1, x) + crypto.decodeint_into_noreduce(_tmp_sc_2, x) for i in range(2, n): - crypto.sc_mul_into(tmp_sc_1, tmp_sc_1, tmp_sc_2) - crypto.encodeint_into(tmp_bf_0, tmp_sc_1) - dst.read(i, tmp_bf_0) - gc_iter(i) + crypto.sc_mul_into(_tmp_sc_1, _tmp_sc_1, _tmp_sc_2) + crypto.encodeint_into(_tmp_bf_0, _tmp_sc_1) + dst.read(i, _tmp_bf_0) + _gc_iter(i) return dst -def vector_power_sum(x, n, dst=None): +def _vector_power_sum(x, n, dst=None): + """ + \\sum_{i=0}^{n-1} x^i + """ dst = _ensure_dst_key(dst) if n == 0: - return copy_key(dst, _ZERO) - - copy_key(dst, _ONE) + return _copy_key(dst, _ZERO) if n == 1: - return dst + _copy_key(dst, _ONE) - prev = init_key(x) - for i in range(1, n): - if i > 1: - sc_mul(prev, prev, x) - sc_add(dst, dst, prev) - gc_iter(i) - return dst + crypto.decodeint_into_noreduce(_tmp_sc_1, x) + crypto.decodeint_into_noreduce(_tmp_sc_3, _ONE) + crypto.sc_add_into(_tmp_sc_3, _tmp_sc_3, _tmp_sc_1) + crypto.sc_copy(_tmp_sc_2, _tmp_sc_1) + + for i in range(2, n): + crypto.sc_mul_into(_tmp_sc_2, _tmp_sc_2, _tmp_sc_1) + crypto.sc_add_into(_tmp_sc_3, _tmp_sc_3, _tmp_sc_2) + _gc_iter(i) + + return crypto.encodeint_into(dst, _tmp_sc_3) -def inner_product(a, b, dst=None): +def _inner_product(a, b, dst=None): + """ + \\sum_{i=0}^{|a|} a_i b_i + """ if len(a) != len(b): raise ValueError("Incompatible sizes of a and b") dst = _ensure_dst_key(dst) - crypto.sc_init_into(tmp_sc_1, 0) + crypto.sc_init_into(_tmp_sc_1, 0) for i in range(len(a)): - crypto.decodeint_into_noreduce(tmp_sc_2, a.to(i)) - crypto.decodeint_into_noreduce(tmp_sc_3, b.to(i)) - crypto.sc_muladd_into(tmp_sc_1, tmp_sc_2, tmp_sc_3, tmp_sc_1) - gc_iter(i) + crypto.decodeint_into_noreduce(_tmp_sc_2, a.to(i)) + crypto.decodeint_into_noreduce(_tmp_sc_3, b.to(i)) + crypto.sc_muladd_into(_tmp_sc_1, _tmp_sc_2, _tmp_sc_3, _tmp_sc_1) + _gc_iter(i) - crypto.encodeint_into(dst, tmp_sc_1) + crypto.encodeint_into(dst, _tmp_sc_1) return dst -def hadamard(a, b, dst=None): - dst = _ensure_dst_keyvect(dst, len(a)) - for i in range(len(a)): - sc_mul(tmp_bf_1, a.to(i), b.to(i)) - dst.read(i, tmp_bf_1) - gc_iter(i) - return dst - - -def hadamard_fold(v, a, b, into=None, into_offset=0): +def _hadamard_fold(v, a, b, into=None, into_offset=0, vR=None, vRoff=0): """ Folds a curvepoint array using a two way scaled Hadamard product ln = len(v); h = ln // 2 - v[i] = a * v[i] + b * v[h + i] + v_i = a v_i + b v_{h + i} """ h = len(v) // 2 - crypto.decodeint_into_noreduce(tmp_sc_1, a) - crypto.decodeint_into_noreduce(tmp_sc_2, b) + crypto.decodeint_into_noreduce(_tmp_sc_1, a) + crypto.decodeint_into_noreduce(_tmp_sc_2, b) into = into if into else v for i in range(h): - crypto.decodepoint_into(tmp_pt_1, v.to(i)) - crypto.decodepoint_into(tmp_pt_2, v.to(h + i)) - crypto.add_keys3_into(tmp_pt_3, tmp_sc_1, tmp_pt_1, tmp_sc_2, tmp_pt_2) - crypto.encodepoint_into(tmp_bf_0, tmp_pt_3) - into.read(i + into_offset, tmp_bf_0) - gc_iter(i) + crypto.decodepoint_into(_tmp_pt_1, v.to(i)) + crypto.decodepoint_into(_tmp_pt_2, v.to(h + i) if not vR else vR.to(i + vRoff)) + crypto.add_keys3_into(_tmp_pt_3, _tmp_sc_1, _tmp_pt_1, _tmp_sc_2, _tmp_pt_2) + crypto.encodepoint_into(_tmp_bf_0, _tmp_pt_3) + into.read(i + into_offset, _tmp_bf_0) + _gc_iter(i) return into -def scalar_fold(v, a, b, into=None, into_offset=0): +def _hadamard_fold_linear(v, a, b, into=None, into_offset=0): + """ + Folds a curvepoint array using a two way scaled Hadamard product. + Iterates v linearly to support linear-scan evaluated vectors (on the fly) + + ln = len(v); h = ln // 2 + v_i = a v_i + b v_{h + i} + """ + h = len(v) // 2 + into = into if into else v + + crypto.decodeint_into_noreduce(_tmp_sc_1, a) + for i in range(h): + crypto.decodepoint_into(_tmp_pt_1, v.to(i)) + crypto.scalarmult_into(_tmp_pt_1, _tmp_pt_1, _tmp_sc_1) + crypto.encodepoint_into(_tmp_bf_0, _tmp_pt_1) + into.read(i + into_offset, _tmp_bf_0) + _gc_iter(i) + + crypto.decodeint_into_noreduce(_tmp_sc_1, b) + for i in range(h): + crypto.decodepoint_into(_tmp_pt_1, v.to(i + h)) + crypto.scalarmult_into(_tmp_pt_1, _tmp_pt_1, _tmp_sc_1) + crypto.decodepoint_into(_tmp_pt_2, into.to(i + into_offset)) + crypto.point_add_into(_tmp_pt_1, _tmp_pt_1, _tmp_pt_2) + crypto.encodepoint_into(_tmp_bf_0, _tmp_pt_1) + into.read(i + into_offset, _tmp_bf_0) + + _gc_iter(i) + return into + + +def _scalar_fold(v, a, b, into=None, into_offset=0): """ ln = len(v); h = ln // 2 - v[i] = v[i] * a + v[h+i] * b) + v_i = a v_i + b v_{h + i} """ h = len(v) // 2 - crypto.decodeint_into_noreduce(tmp_sc_1, a) - crypto.decodeint_into_noreduce(tmp_sc_2, b) + crypto.decodeint_into_noreduce(_tmp_sc_1, a) + crypto.decodeint_into_noreduce(_tmp_sc_2, b) into = into if into else v for i in range(h): - crypto.decodeint_into_noreduce(tmp_sc_3, v.to(i)) - crypto.decodeint_into_noreduce(tmp_sc_4, v.to(h + i)) - crypto.sc_mul_into(tmp_sc_3, tmp_sc_3, tmp_sc_1) - crypto.sc_mul_into(tmp_sc_4, tmp_sc_4, tmp_sc_2) - crypto.sc_add_into(tmp_sc_3, tmp_sc_3, tmp_sc_4) - crypto.encodeint_into(tmp_bf_0, tmp_sc_3) - into.read(i + into_offset, tmp_bf_0) - gc_iter(i) + crypto.decodeint_into_noreduce(_tmp_sc_3, v.to(i)) + crypto.decodeint_into_noreduce(_tmp_sc_4, v.to(h + i)) + crypto.sc_mul_into(_tmp_sc_3, _tmp_sc_3, _tmp_sc_1) + crypto.sc_mul_into(_tmp_sc_4, _tmp_sc_4, _tmp_sc_2) + crypto.sc_add_into(_tmp_sc_3, _tmp_sc_3, _tmp_sc_4) + crypto.encodeint_into(_tmp_bf_0, _tmp_sc_3) + into.read(i + into_offset, _tmp_bf_0) + _gc_iter(i) return into -def cross_inner_product(l0, r0, l1, r1): +def _cross_inner_product(l0, r0, l1, r1): """ - t1_1 = l0 . r1, t1_2 = l1 . r0 - t1 = t1_1 + t1_2, t2 = l1 . r1 + t1 = l0 . r1 + l1 . r0 + t2 = l1 . r1 """ - sc_t1_1, sc_t1_2, sc_t2 = alloc_scalars(3) - cl0, cr0, cl1, cr1 = alloc_scalars(4) + sc_t1 = crypto.new_scalar() + sc_t2 = crypto.new_scalar() + tl = crypto.new_scalar() + tr = crypto.new_scalar() for i in range(len(l0)): - crypto.decodeint_into_noreduce(cl0, l0.to(i)) - crypto.decodeint_into_noreduce(cr0, r0.to(i)) - crypto.decodeint_into_noreduce(cl1, l1.to(i)) - crypto.decodeint_into_noreduce(cr1, r1.to(i)) + crypto.decodeint_into_noreduce(tl, l0.to(i)) + crypto.decodeint_into_noreduce(tr, r1.to(i)) + crypto.sc_muladd_into(sc_t1, tl, tr, sc_t1) - crypto.sc_muladd_into(sc_t1_1, cl0, cr1, sc_t1_1) - crypto.sc_muladd_into(sc_t1_2, cl1, cr0, sc_t1_2) - crypto.sc_muladd_into(sc_t2, cl1, cr1, sc_t2) - gc_iter(i) + crypto.decodeint_into_noreduce(tl, l1.to(i)) + crypto.sc_muladd_into(sc_t2, tl, tr, sc_t2) - crypto.sc_add_into(sc_t1_1, sc_t1_1, sc_t1_2) - return crypto.encodeint(sc_t1_1), crypto.encodeint(sc_t2) + crypto.decodeint_into_noreduce(tr, r0.to(i)) + crypto.sc_muladd_into(sc_t1, tl, tr, sc_t1) + + _gc_iter(i) + + return crypto.encodeint(sc_t1), crypto.encodeint(sc_t2) -def vector_gen(dst, size, op): +def _vector_gen(dst, size, op): dst = _ensure_dst_keyvect(dst, size) for i in range(size): - dst.to(i, tmp_bf_0) - op(i, tmp_bf_0) - dst.read(i, tmp_bf_0) - gc_iter(i) + dst.to(i, _tmp_bf_0) + op(i, _tmp_bf_0) + dst.read(i, _tmp_bf_0) + _gc_iter(i) return dst -def vector_add(a, b, dst=None): - dst = _ensure_dst_keyvect(dst, len(a)) - for i in range(len(a)): - sc_add(tmp_bf_1, a.to(i), b.to(i)) - dst.read(i, tmp_bf_1) - gc_iter(i) - return dst - - -def vector_subtract(a, b, dst=None): - dst = _ensure_dst_keyvect(dst, len(a)) - for i in range(len(a)): - sc_sub(tmp_bf_1, a.to(i), b.to(i)) - dst.read(i, tmp_bf_1) - gc_iter(i) - return dst - - -def vector_dup(x, n, dst=None): +def _vector_dup(x, n, dst=None): dst = _ensure_dst_keyvect(dst, n) for i in range(n): dst[i] = x - gc_iter(i) + _gc_iter(i) return dst -def vector_z_two_i(logN, zpow, twoN, i, dst_sc=None): - """ - 0...N|N+1...2N|2N+1...3N|.... - zt[i] = z^b 2^c, where - b = 2 + blockNumber. BlockNumber is idx of N block - c = i % N = i - N * blockNumber - """ - j = i >> logN - crypto.decodeint_into_noreduce(tmp_sc_1, zpow.to(j + 2)) - crypto.decodeint_into_noreduce(tmp_sc_2, twoN.to(i & ((1 << logN) - 1))) - crypto.sc_mul_into(dst_sc, tmp_sc_1, tmp_sc_2) - return dst_sc - - -def vector_z_two(N, logN, M, zpow, twoN, zero_twos=None, dynamic=False, **kwargs): - if dynamic: - return KeyVZtwo(N, logN, M, zpow, twoN, **kwargs) - else: - raise NotImplementedError - - -def hash_cache_mash(dst, hash_cache, *args): +def _hash_cache_mash(dst, hash_cache, *args): dst = _ensure_dst_key(dst) ctx = crypto.get_keccak() ctx.update(hash_cache) @@ -841,16 +998,14 @@ def hash_cache_mash(dst, hash_cache, *args): ctx.update(x) hsh = ctx.digest() - crypto.decodeint_into(tmp_sc_1, hsh) - crypto.encodeint_into(tmp_bf_1, tmp_sc_1) - - copy_key(dst, tmp_bf_1) - copy_key(hash_cache, tmp_bf_1) + crypto.decodeint_into(_tmp_sc_1, hsh) + crypto.encodeint_into(hash_cache, _tmp_sc_1) + _copy_key(dst, hash_cache) return dst -def is_reduced(sc): - return crypto.encodeint(crypto.decodeint(sc)) == sc +def _is_reduced(sc): + return crypto.encodeint_into(_tmp_bf_0, crypto.decodeint_into(_tmp_sc_1, sc)) == sc class MultiExpSequential: @@ -891,10 +1046,10 @@ class MultiExpSequential: self._acc(scalar, self.get_point(self.current_idx)) def _acc(self, scalar, point): - crypto.decodeint_into_noreduce(tmp_sc_1, scalar) - crypto.decodepoint_into(tmp_pt_2, point) - crypto.scalarmult_into(tmp_pt_3, tmp_pt_2, tmp_sc_1) - crypto.point_add_into(self.acc, self.acc, tmp_pt_3) + crypto.decodeint_into_noreduce(_tmp_sc_1, scalar) + crypto.decodepoint_into(_tmp_pt_2, point) + crypto.scalarmult_into(_tmp_pt_3, _tmp_pt_2, _tmp_sc_1) + crypto.point_add_into(self.acc, self.acc, _tmp_pt_3) self.current_idx += 1 self.size += 1 @@ -903,7 +1058,7 @@ class MultiExpSequential: return crypto.encodepoint_into(dst, self.acc) -def multiexp(dst=None, data=None, GiHi=False): +def _multiexp(dst=None, data=None, GiHi=False): return data.eval(dst, GiHi) @@ -916,10 +1071,8 @@ class BulletProofBuilder: self.Gprec = KeyV(buffer=crypto.tcry.BP_GI_PRE, const=True) # BP_HI_PRE = get_exponent(Hi[i], _XMR_H, i * 2) self.Hprec = KeyV(buffer=crypto.tcry.BP_HI_PRE, const=True) - self.oneN = const_vector(_ONE, 64) # BP_TWO_N = vector_powers(_TWO, _BP_N); self.twoN = KeyV(buffer=crypto.tcry.BP_TWO_N, const=True) - self.ip12 = _BP_IP12 self.fnc_det_mask = None self.tmp_sc_1 = crypto.new_scalar() @@ -947,7 +1100,7 @@ class BulletProofBuilder: else: r = _ZERO if is_a else _MINUS_ONE if d: - memcpy(d, 0, r, 0, 32) + return memcpy(d, 0, r, 0, 32) return r aL = KeyVEval(MN, lambda i, d: e_xL(i, d, True)) @@ -970,12 +1123,12 @@ class BulletProofBuilder: def _gprec_aux(self, size): return KeyVPrecomp( - size, self.Gprec, lambda i, d: get_exponent(d, _XMR_H, i * 2 + 1) + size, self.Gprec, lambda i, d: _get_exponent(d, _XMR_H, i * 2 + 1) ) def _hprec_aux(self, size): return KeyVPrecomp( - size, self.Hprec, lambda i, d: get_exponent(d, _XMR_H, i * 2) + size, self.Hprec, lambda i, d: _get_exponent(d, _XMR_H, i * 2) ) def _two_aux(self, size): @@ -991,7 +1144,7 @@ class BulletProofBuilder: lw = pow_two(flr) rw = pow_two(flr + 1 if flr != i / 2.0 else lw) - return sc_mul(d, lw, rw) + return _sc_mul(d, lw, rw) return KeyVPrecomp(size, self.twoN, pow_two) @@ -1017,19 +1170,16 @@ class BulletProofBuilder: for i in range(ln): crypto.random_scalar(sc) crypto.encodeint_into(buff_mv[i * 32 : (i + 1) * 32], sc) - gc_iter(i) + _gc_iter(i) return KeyV(buffer=buff) - def vector_exponent(self, a, b, dst=None): - return vector_exponent_custom(self.Gprec, self.Hprec, a, b, dst) + def vector_exponent(self, a, b, dst=None, a_raw=None, b_raw=None): + return _vector_exponent_custom(self.Gprec, self.Hprec, a, b, dst, a_raw, b_raw) - def prove_testnet(self, sv, gamma): - return self.prove(sv, gamma, proof_v8=True) + def prove(self, sv, gamma): + return self.prove_batch([sv], [gamma]) - def prove(self, sv, gamma, proof_v8=False): - return self.prove_batch([sv], [gamma], proof_v8=proof_v8) - - def prove_setup(self, sv, gamma, proof_v8=False): + def prove_setup(self, sv, gamma): utils.ensure(len(sv) == len(gamma), "|sv| != |gamma|") utils.ensure(len(sv) > 0, "sv empty") @@ -1047,207 +1197,199 @@ class BulletProofBuilder: V = _ensure_dst_keyvect(None, len(sv)) for i in range(len(sv)): - add_keys2(tmp_bf_0, gamma[i], sv[i], _XMR_H) - if not proof_v8: - scalarmult_key(tmp_bf_0, tmp_bf_0, _INV_EIGHT) - V.read(i, tmp_bf_0) + _add_keys2(_tmp_bf_0, gamma[i], sv[i], _XMR_H) + _scalarmult_key(_tmp_bf_0, _tmp_bf_0, _INV_EIGHT) + V.read(i, _tmp_bf_0) aL, aR = self.aX_vcts(sv, MN) return M, logM, aL, aR, V, gamma - def prove_batch(self, sv, gamma, proof_v8=False): - M, logM, aL, aR, V, gamma = self.prove_setup(sv, gamma, proof_v8) + def prove_batch(self, sv, gamma): + M, logM, aL, aR, V, gamma = self.prove_setup(sv, gamma) hash_cache = _ensure_dst_key() while True: self.gc(10) r = self._prove_batch_main( - V, gamma, aL, aR, hash_cache, logM, _BP_LOG_N, M, _BP_N, proof_v8 + V, gamma, aL, aR, hash_cache, logM, _BP_LOG_N, M, _BP_N ) if r[0]: break return r[1] - def _prove_batch_main( - self, V, gamma, aL, aR, hash_cache, logM, logN, M, N, proof_v8=False - ): + def _prove_batch_main(self, V, gamma, aL, aR, hash_cache, logM, logN, M, N): logMN = logM + logN MN = M * N - hash_vct_to_scalar(hash_cache, V) + _hash_vct_to_scalar(hash_cache, V) # Extended precomputed GiHi Gprec = self._gprec_aux(MN) Hprec = self._hprec_aux(MN) - # PAPER LINES 38-39 - alpha = sc_gen() - ve = _ensure_dst_key() + # PHASE 1 + A, S, T1, T2, taux, mu, t, l, r, y, x_ip, hash_cache = self._prove_phase1( + N, M, logMN, V, gamma, aL, aR, hash_cache, Gprec, Hprec + ) + + # PHASE 2 + L, R, a, b = self._prove_loop( + MN, logMN, l, r, y, x_ip, hash_cache, Gprec, Hprec + ) + + from apps.monero.xmr.serialize_messages.tx_rsig_bulletproof import Bulletproof + + return ( + 1, + Bulletproof( + V=V, A=A, S=S, T1=T1, T2=T2, taux=taux, mu=mu, L=L, R=R, a=a, b=b, t=t + ), + ) + + def _prove_phase1(self, N, M, logMN, V, gamma, aL, aR, hash_cache, Gprec, Hprec): + MN = M * N + + # PAPER LINES 38-39, compute A = 8^{-1} ( \alpha G + \sum_{i=0}^{MN-1} a_{L,i} \Gi_i + a_{R,i} \Hi_i) + alpha = _sc_gen() A = _ensure_dst_key() - vector_exponent_custom(Gprec, Hprec, aL, aR, ve) - add_keys(A, ve, scalarmult_base(tmp_bf_1, alpha)) - if not proof_v8: - scalarmult_key(A, A, _INV_EIGHT) + _vector_exponent_custom(Gprec, Hprec, aL, aR, A) + _add_keys(A, A, _scalarmult_base(_tmp_bf_1, alpha)) + _scalarmult_key(A, A, _INV_EIGHT) self.gc(11) - # PAPER LINES 40-42 + # PAPER LINES 40-42, compute S = 8^{-1} ( \rho G + \sum_{i=0}^{MN-1} s_{L,i} \Gi_i + s_{R,i} \Hi_i) sL = self.sL_vct(MN) sR = self.sR_vct(MN) - rho = sc_gen() - vector_exponent_custom(Gprec, Hprec, sL, sR, ve) + rho = _sc_gen() S = _ensure_dst_key() - add_keys(S, ve, scalarmult_base(tmp_bf_1, rho)) - if not proof_v8: - scalarmult_key(S, S, _INV_EIGHT) - del ve + _vector_exponent_custom(Gprec, Hprec, sL, sR, S) + _add_keys(S, S, _scalarmult_base(_tmp_bf_1, rho)) + _scalarmult_key(S, S, _INV_EIGHT) self.gc(12) # PAPER LINES 43-45 y = _ensure_dst_key() - hash_cache_mash(y, hash_cache, A, S) + _hash_cache_mash(y, hash_cache, A, S) if y == _ZERO: return (0,) z = _ensure_dst_key() - hash_to_scalar(hash_cache, y) - copy_key(z, hash_cache) + _hash_to_scalar(hash_cache, y) + _copy_key(z, hash_cache) + zc = crypto.decodeint_into_noreduce(None, z) if z == _ZERO: return (0,) # Polynomial construction by coefficients - zMN = const_vector(z, MN) - l0 = _ensure_dst_keyvect(None, MN) - vector_subtract(aL, zMN, l0) + # l0 = aL - z r0 = ((aR + z) . ypow) + zt + # l1 = sL r1 = sR . ypow + l0 = KeyVEval( + MN, lambda i, d: _sc_sub(d, aL.to(i), None, None, zc) # noqa: F821 + ) l1 = sL self.gc(13) # This computes the ugly sum/concatenation from PAPER LINE 65 - # r0 = aR + z - r0 = vector_add(aR, zMN) - del zMN + # r0_i = ((a_{Ri} + z) y^{i}) + zt_i + # r1_i = s_{Ri} y^{i} + r0 = KeyR0(MN, N, aR, y, z) + ypow = KeyVPowers(MN, y, raw=True) + r1 = KeyVEval( + MN, lambda i, d: _sc_mul(d, sR.to(i), None, ypow[i]) # noqa: F821 + ) + del aR self.gc(14) - # r0 = r0 \odot yMN => r0[i] = r0[i] * y^i - # r1 = sR \odot yMN => r1[i] = sR[i] * y^i - yMN = vector_powers(y, MN, dynamic=False) - hadamard(r0, yMN, dst=r0) - self.gc(15) + # Evaluate per index + # - $t_1 = l_0 . r_1 + l_1 . r0$ + # - $t_2 = l_1 . r_1$ + # - compute then T1, T2, x + t1, t2 = _cross_inner_product(l0, r0, l1, r1) - # r0 = r0 + zero_twos - zpow = vector_powers(z, M + 2) - twoN = self._two_aux(MN) - zero_twos = vector_z_two(N, logN, M, zpow, twoN, dynamic=True, raw=True) - vector_gen( - r0, - len(r0), - lambda i, d: crypto.encodeint_into( - d, - crypto.sc_add_into( - tmp_sc_1, - zero_twos[i], # noqa: F821 - crypto.decodeint_into_noreduce(tmp_sc_2, r0.to(i)), # noqa: F821 - ), - ), - ) - - del (zero_twos, twoN) - self.gc(15) - - # Polynomial construction before PAPER LINE 46 - # r1 = KeyVEval(MN, lambda i, d: sc_mul(d, yMN[i], sR[i])) - # r1 optimization possible, but has clashing sc registers. - # Moreover, max memory complexity is 4MN as below (while loop). - r1 = hadamard(yMN, sR, yMN) # re-use yMN vector for r1 - del (yMN, sR) - self.gc(16) - - # Inner products - # l0 = aL - z r0 = ((aR + z) \cdot ypow) + zt - # l1 = sL r1 = sR \cdot ypow - # t1_1 = l0 . r1, t1_2 = l1 . r0 - # t1 = t1_1 + t1_2, t2 = l1 . r1 - # l = l0 \odot x*l1 r = r0 \odot x*r1 - t1, t2 = cross_inner_product(l0, r0, l1, r1) - self.gc(17) - - # PAPER LINES 47-48 - tau1, tau2 = sc_gen(), sc_gen() + # PAPER LINES 47-48, Compute: T1, T2 + # T1 = 8^{-1} (\tau_1G + t_1H ) + # T2 = 8^{-1} (\tau_2G + t_2H ) + tau1, tau2 = _sc_gen(), _sc_gen() T1, T2 = _ensure_dst_key(), _ensure_dst_key() - add_keys(T1, scalarmultH(tmp_bf_1, t1), scalarmult_base(tmp_bf_2, tau1)) - if not proof_v8: - scalarmult_key(T1, T1, _INV_EIGHT) + _add_keys2(T1, tau1, t1, _XMR_H) + _scalarmult_key(T1, T1, _INV_EIGHT) - add_keys(T2, scalarmultH(tmp_bf_1, t2), scalarmult_base(tmp_bf_2, tau2)) - if not proof_v8: - scalarmult_key(T2, T2, _INV_EIGHT) + _add_keys2(T2, tau2, t2, _XMR_H) + _scalarmult_key(T2, T2, _INV_EIGHT) del (t1, t2) - self.gc(17) + self.gc(16) - # PAPER LINES 49-51 + # PAPER LINES 49-51, compute x x = _ensure_dst_key() - hash_cache_mash(x, hash_cache, z, T1, T2) + _hash_cache_mash(x, hash_cache, z, T1, T2) if x == _ZERO: return (0,) - # PAPER LINES 52-53 + # Second pass, compute l, r + # Offloaded version does this incrementally and produces l, r outs in chunks + # Message offloaded sends blinded vectors with random constants. + # - $l_i = l_{0,i} + xl_{1,i} + # - $r_i = r_{0,i} + xr_{1,i} + # - $t = l . r$ + l = _ensure_dst_keyvect(None, MN) + r = _ensure_dst_keyvect(None, MN) + ts = crypto.new_scalar() + for i in range(MN): + _sc_muladd(_tmp_bf_0, x, l1.to(i), l0.to(i)) + l.read(i, _tmp_bf_0) + + _sc_muladd(_tmp_bf_1, x, r1.to(i), r0.to(i)) + r.read(i, _tmp_bf_1) + + _sc_muladd(ts, _tmp_bf_0, _tmp_bf_1, None, c_raw=ts, raw=True) + + t = crypto.encodeint(ts) + del (l0, l1, sL, sR, r0, r1, ypow, ts) + self.gc(17) + + # PAPER LINES 52-53, Compute \tau_x taux = _ensure_dst_key() - copy_key(taux, _ZERO) - sc_mul(taux, tau1, x) - xsq = _ensure_dst_key() - sc_mul(xsq, x, x) - sc_muladd(taux, tau2, xsq, taux) - del (xsq, tau1, tau2) + _sc_mul(taux, tau1, x) + _sc_mul(_tmp_bf_0, x, x) + _sc_muladd(taux, tau2, _tmp_bf_0, taux) + del (tau1, tau2) + + zpow = crypto.sc_mul_into(None, zc, zc) for j in range(1, len(V) + 1): - sc_muladd(taux, zpow.to(j + 1), gamma[j - 1], taux) - del zpow + _sc_muladd(taux, None, gamma[j - 1], taux, a_raw=zpow) + crypto.sc_mul_into(zpow, zpow, zc) + del (zc, zpow) self.gc(18) mu = _ensure_dst_key() - sc_muladd(mu, x, rho, alpha) + _sc_muladd(mu, x, rho, alpha) del (rho, alpha) - - # PAPER LINES 54-57 - # l = l0 \odot x*l1, has to evaluated as it becomes aprime in the loop - l = vector_gen( - l0, - len(l0), - lambda i, d: sc_add(d, d, sc_mul(tmp_bf_1, l1.to(i), x)), # noqa: F821 - ) - del (l0, l1, sL) - self.gc(19) - - # r = r0 \odot x*r1, has to evaluated as it becomes bprime in the loop - r = vector_gen( - r0, - len(r0), - lambda i, d: sc_add(d, d, sc_mul(tmp_bf_1, r1.to(i), x)), # noqa: F821 - ) - t = inner_product(l, r) - del (r1, r0) self.gc(19) # PAPER LINES 32-33 - x_ip = hash_cache_mash(None, hash_cache, x, taux, mu, t) + x_ip = _hash_cache_mash(None, hash_cache, x, taux, mu, t) if x_ip == _ZERO: return 0, None - # PHASE 2 - # These are used in the inner product rounds + return A, S, T1, T2, taux, mu, t, l, r, y, x_ip, hash_cache + + def _prove_loop(self, MN, logMN, l, r, y, x_ip, hash_cache, Gprec, Hprec): nprime = MN - Gprime = _ensure_dst_keyvect(None, MN) - Hprime = _ensure_dst_keyvect(None, MN) aprime = l bprime = r - yinv = invert(None, y) - yinvpow = init_key(_ONE) - self.gc(20) - for i in range(0, MN): - Gprime.read(i, Gprec.to(i)) - scalarmult_key(tmp_bf_0, Hprec.to(i), yinvpow) - Hprime.read(i, tmp_bf_0) - sc_mul(yinvpow, yinvpow, yinv) - gc_iter(i) - self.gc(21) + yinvpowL = KeyVPowers(MN, _invert(_tmp_bf_0, y), raw=True) + yinvpowR = KeyVPowers(MN, _tmp_bf_0, raw=True) + tmp_pt = crypto.new_point() + + Gprime = Gprec + HprimeL = KeyVEval( + MN, lambda i, d: _scalarmult_key(d, Hprec.to(i), None, yinvpowL[i]) + ) + HprimeR = KeyVEval( + MN, lambda i, d: _scalarmult_key(d, Hprec.to(i), None, yinvpowR[i], tmp_pt) + ) + Hprime = HprimeL + self.gc(20) L = _ensure_dst_keyvect(None, logMN) R = _ensure_dst_keyvect(None, logMN) @@ -1256,9 +1398,8 @@ class BulletProofBuilder: winv = _ensure_dst_key() w_round = _ensure_dst_key() tmp = _ensure_dst_key() - - round = 0 _tmp_k_1 = _ensure_dst_key() + round = 0 # PAPER LINE 13 while nprime > 1: @@ -1268,116 +1409,118 @@ class BulletProofBuilder: self.gc(22) # PAPER LINES 16-17 - inner_product( + # cL = \ap_{\left(\inta\right)} \cdot \bp_{\left(\intb\right)} + # cR = \ap_{\left(\intb\right)} \cdot \bp_{\left(\inta\right)} + _inner_product( aprime.slice_view(0, nprime), bprime.slice_view(nprime, npr2), cL ) - inner_product( + _inner_product( aprime.slice_view(nprime, npr2), bprime.slice_view(0, nprime), cR ) self.gc(23) # PAPER LINES 18-19 - vector_exponent_custom( + # Lc = 8^{-1} \left(\left( \sum_{i=0}^{\np} \ap_{i}\quad\Gp_{i+\np} + \bp_{i+\np}\Hp_{i} \right) + # + \left(c_L x_{ip}\right)H \right) + _vector_exponent_custom( Gprime.slice_view(nprime, npr2), Hprime.slice_view(0, nprime), aprime.slice_view(0, nprime), bprime.slice_view(nprime, npr2), - tmp_bf_0, + _tmp_bf_0, ) - sc_mul(tmp, cL, x_ip) - add_keys(tmp_bf_0, tmp_bf_0, scalarmultH(_tmp_k_1, tmp)) - if not proof_v8: - scalarmult_key(tmp_bf_0, tmp_bf_0, _INV_EIGHT) - L.read(round, tmp_bf_0) + # In round 0 backup the y^{prime - 1} + if round == 0: + yinvpowR.set_state(yinvpowL.last_idx, yinvpowL.cur) + + _sc_mul(tmp, cL, x_ip) + _add_keys(_tmp_bf_0, _tmp_bf_0, _scalarmultH(_tmp_k_1, tmp)) + _scalarmult_key(_tmp_bf_0, _tmp_bf_0, _INV_EIGHT) + L.read(round, _tmp_bf_0) self.gc(24) - vector_exponent_custom( + # Rc = 8^{-1} \left(\left( \sum_{i=0}^{\np} \ap_{i+\np}\Gp_{i}\quad + \bp_{i}\quad\Hp_{i+\np} \right) + # + \left(c_R x_{ip}\right)H \right) + _vector_exponent_custom( Gprime.slice_view(0, nprime), Hprime.slice_view(nprime, npr2), aprime.slice_view(nprime, npr2), bprime.slice_view(0, nprime), - tmp_bf_0, + _tmp_bf_0, ) - sc_mul(tmp, cR, x_ip) - add_keys(tmp_bf_0, tmp_bf_0, scalarmultH(_tmp_k_1, tmp)) - if not proof_v8: - scalarmult_key(tmp_bf_0, tmp_bf_0, _INV_EIGHT) - R.read(round, tmp_bf_0) + _sc_mul(tmp, cR, x_ip) + _add_keys(_tmp_bf_0, _tmp_bf_0, _scalarmultH(_tmp_k_1, tmp)) + _scalarmult_key(_tmp_bf_0, _tmp_bf_0, _INV_EIGHT) + R.read(round, _tmp_bf_0) self.gc(25) # PAPER LINES 21-22 - hash_cache_mash(w_round, hash_cache, L.to(round), R.to(round)) + _hash_cache_mash(w_round, hash_cache, L.to(round), R.to(round)) if w_round == _ZERO: return (0,) - # PAPER LINES 24-25 - invert(winv, w_round) + # PAPER LINES 24-25, fold {G~, H~} + _invert(winv, w_round) self.gc(26) - hadamard_fold(Gprime, winv, w_round) + # PAPER LINES 28-29, fold {a, b} vectors + # aprime's high part is used as a buffer for other operations + _scalar_fold(aprime, w_round, winv) + aprime.resize(nprime) self.gc(27) - hadamard_fold(Hprime, w_round, winv, Gprime, nprime) - Hprime.realloc_init_from(nprime, Gprime, nprime, round < 2) + _scalar_fold(bprime, winv, w_round) + bprime.resize(nprime) self.gc(28) - # PAPER LINES 28-29 - scalar_fold(aprime, w_round, winv, Gprime, nprime) - aprime.realloc_init_from(nprime, Gprime, nprime, round < 2) - self.gc(29) - - scalar_fold(bprime, winv, w_round, Gprime, nprime) - bprime.realloc_init_from(nprime, Gprime, nprime, round < 2) + # First fold produced to a new buffer, smaller one (G~ on-the-fly) + Gprime_new = KeyV(nprime) if round == 0 else Gprime + Gprime = _hadamard_fold(Gprime, winv, w_round, Gprime_new, 0) + Gprime.resize(nprime) self.gc(30) - # Finally resize Gprime which was buffer for all ops - Gprime.resize(nprime, realloc=True) + # Hadamard fold for H is special - linear scan only. + # Linear scan is slow, thus we have HprimeR. + if round == 0: + Hprime_new = KeyV(nprime) + Hprime = _hadamard_fold( + Hprime, w_round, winv, Hprime_new, 0, HprimeR, nprime + ) + # Hprime = _hadamard_fold_linear(Hprime, w_round, winv, Hprime_new, 0) + + else: + _hadamard_fold(Hprime, w_round, winv) + Hprime.resize(nprime) + + if round == 0: + # del (Gprec, Hprec, yinvpowL, HprimeL) + del (Gprec, Hprec, yinvpowL, yinvpowR, HprimeL, HprimeR, tmp_pt) + + self.gc(31) round += 1 - from apps.monero.xmr.serialize_messages.tx_rsig_bulletproof import Bulletproof + return L, R, aprime.to(0), bprime.to(0) - return ( - 1, - Bulletproof( - V=V, - A=A, - S=S, - T1=T1, - T2=T2, - taux=taux, - mu=mu, - L=L, - R=R, - a=aprime.to(0), - b=bprime.to(0), - t=t, - ), - ) + def verify(self, proof): + return self.verify_batch([proof]) - def verify_testnet(self, proof): - return self.verify(proof, proof_v8=True) - - def verify(self, proof, proof_v8=False): - return self.verify_batch([proof], proof_v8=proof_v8) - - def verify_batch(self, proofs, single_optim=True, proof_v8=False): + def verify_batch(self, proofs, single_optim=True): """ BP batch verification :param proofs: :param single_optim: single proof memory optimization - :param proof_v8: previous testnet version :return: """ max_length = 0 for proof in proofs: - utils.ensure(is_reduced(proof.taux), "Input scalar not in range") - utils.ensure(is_reduced(proof.mu), "Input scalar not in range") - utils.ensure(is_reduced(proof.a), "Input scalar not in range") - utils.ensure(is_reduced(proof.b), "Input scalar not in range") - utils.ensure(is_reduced(proof.t), "Input scalar not in range") + utils.ensure(_is_reduced(proof.taux), "Input scalar not in range") + utils.ensure(_is_reduced(proof.mu), "Input scalar not in range") + utils.ensure(_is_reduced(proof.a), "Input scalar not in range") + utils.ensure(_is_reduced(proof.b), "Input scalar not in range") + utils.ensure(_is_reduced(proof.t), "Input scalar not in range") utils.ensure(len(proof.V) >= 1, "V does not have at least one element") utils.ensure(len(proof.L) == len(proof.R), "|L| != |R|") utils.ensure(len(proof.L) > 0, "Empty proof") @@ -1392,13 +1535,13 @@ class BulletProofBuilder: # setup weighted aggregates is_single = len(proofs) == 1 and single_optim # ph4 - z1 = init_key(_ZERO) - z3 = init_key(_ZERO) - m_z4 = vector_dup(_ZERO, maxMN) if not is_single else None - m_z5 = vector_dup(_ZERO, maxMN) if not is_single else None - m_y0 = init_key(_ZERO) - y1 = init_key(_ZERO) - muex_acc = init_key(_ONE) + z1 = _init_key(_ZERO) + z3 = _init_key(_ZERO) + m_z4 = _vector_dup(_ZERO, maxMN) if not is_single else None + m_z5 = _vector_dup(_ZERO, maxMN) if not is_single else None + m_y0 = _init_key(_ZERO) + y1 = _init_key(_ZERO) + muex_acc = _init_key(_ONE) Gprec = self._gprec_aux(maxMN) Hprec = self._hprec_aux(maxMN) @@ -1416,62 +1559,60 @@ class BulletProofBuilder: weight_z = crypto.encodeint(crypto.random_scalar()) # Reconstruct the challenges - hash_cache = hash_vct_to_scalar(None, proof.V) - y = hash_cache_mash(None, hash_cache, proof.A, proof.S) + hash_cache = _hash_vct_to_scalar(None, proof.V) + y = _hash_cache_mash(None, hash_cache, proof.A, proof.S) utils.ensure(y != _ZERO, "y == 0") - z = hash_to_scalar(None, y) - copy_key(hash_cache, z) + z = _hash_to_scalar(None, y) + _copy_key(hash_cache, z) utils.ensure(z != _ZERO, "z == 0") - x = hash_cache_mash(None, hash_cache, z, proof.T1, proof.T2) + x = _hash_cache_mash(None, hash_cache, z, proof.T1, proof.T2) utils.ensure(x != _ZERO, "x == 0") - x_ip = hash_cache_mash(None, hash_cache, x, proof.taux, proof.mu, proof.t) + x_ip = _hash_cache_mash(None, hash_cache, x, proof.taux, proof.mu, proof.t) utils.ensure(x_ip != _ZERO, "x_ip == 0") # PAPER LINE 61 - sc_mulsub(m_y0, proof.taux, weight_y, m_y0) - zpow = vector_powers(z, M + 3) + _sc_mulsub(m_y0, proof.taux, weight_y, m_y0) + zpow = _vector_powers(z, M + 3) k = _ensure_dst_key() - ip1y = vector_power_sum(y, MN) - sc_mulsub(k, zpow[2], ip1y, _ZERO) + ip1y = _vector_power_sum(y, MN) + _sc_mulsub(k, zpow.to(2), ip1y, _ZERO) for j in range(1, M + 1): utils.ensure(j + 2 < len(zpow), "invalid zpow index") - sc_mulsub(k, zpow.to(j + 2), _BP_IP12, k) + _sc_mulsub(k, zpow.to(j + 2), _BP_IP12, k) # VERIFY_line_61rl_new - sc_muladd(tmp, z, ip1y, k) - sc_sub(tmp, proof.t, tmp) + _sc_muladd(tmp, z, ip1y, k) + _sc_sub(tmp, proof.t, tmp) - sc_muladd(y1, tmp, weight_y, y1) - weight_y8 = init_key(weight_y) - if not proof_v8: - weight_y8 = sc_mul(None, weight_y, _EIGHT) + _sc_muladd(y1, tmp, weight_y, y1) + weight_y8 = _init_key(weight_y) + weight_y8 = _sc_mul(None, weight_y, _EIGHT) muex = MultiExpSequential(points=[pt for pt in proof.V]) for j in range(len(proof.V)): - sc_mul(tmp, zpow[j + 2], weight_y8) - muex.add_scalar(init_key(tmp)) + _sc_mul(tmp, zpow.to(j + 2), weight_y8) + muex.add_scalar(_init_key(tmp)) - sc_mul(tmp, x, weight_y8) - muex.add_pair(init_key(tmp), proof.T1) + _sc_mul(tmp, x, weight_y8) + muex.add_pair(_init_key(tmp), proof.T1) xsq = _ensure_dst_key() - sc_mul(xsq, x, x) + _sc_mul(xsq, x, x) - sc_mul(tmp, xsq, weight_y8) - muex.add_pair(init_key(tmp), proof.T2) + _sc_mul(tmp, xsq, weight_y8) + muex.add_pair(_init_key(tmp), proof.T2) - weight_z8 = init_key(weight_z) - if not proof_v8: - weight_z8 = sc_mul(None, weight_z, _EIGHT) + weight_z8 = _init_key(weight_z) + weight_z8 = _sc_mul(None, weight_z, _EIGHT) muex.add_pair(weight_z8, proof.A) - sc_mul(tmp, x, weight_z8) - muex.add_pair(init_key(tmp), proof.S) + _sc_mul(tmp, x, weight_z8) + muex.add_pair(_init_key(tmp), proof.S) - multiexp(tmp, muex, False) - add_keys(muex_acc, muex_acc, tmp) + _multiexp(tmp, muex, False) + _add_keys(muex_acc, muex_acc, tmp) del muex # Compute the number of rounds for the inner product @@ -1482,97 +1623,101 @@ class BulletProofBuilder: # The inner product challenges are computed per round w = _ensure_dst_keyvect(None, rounds) for i in range(rounds): - hash_cache_mash(tmp_bf_0, hash_cache, proof.L[i], proof.R[i]) - w.read(i, tmp_bf_0) - utils.ensure(w[i] != _ZERO, "w[i] == 0") + _hash_cache_mash(_tmp_bf_0, hash_cache, proof.L[i], proof.R[i]) + w.read(i, _tmp_bf_0) + utils.ensure(w.to(i) != _ZERO, "w[i] == 0") # Basically PAPER LINES 24-25 # Compute the curvepoints from G[i] and H[i] - yinvpow = init_key(_ONE) - ypow = init_key(_ONE) - yinv = invert(None, y) + yinvpow = _init_key(_ONE) + ypow = _init_key(_ONE) + yinv = _invert(None, y) self.gc(61) winv = _ensure_dst_keyvect(None, rounds) for i in range(rounds): - invert(tmp_bf_0, w.to(i)) - winv.read(i, tmp_bf_0) + _invert(_tmp_bf_0, w.to(i)) + winv.read(i, _tmp_bf_0) self.gc(62) g_scalar = _ensure_dst_key() h_scalar = _ensure_dst_key() twoN = self._two_aux(N) for i in range(MN): - copy_key(g_scalar, proof.a) - sc_mul(h_scalar, proof.b, yinvpow) + _copy_key(g_scalar, proof.a) + _sc_mul(h_scalar, proof.b, yinvpow) for j in range(rounds - 1, -1, -1): J = len(w) - j - 1 if (i & (1 << j)) == 0: - sc_mul(g_scalar, g_scalar, winv.to(J)) - sc_mul(h_scalar, h_scalar, w.to(J)) + _sc_mul(g_scalar, g_scalar, winv.to(J)) + _sc_mul(h_scalar, h_scalar, w.to(J)) else: - sc_mul(g_scalar, g_scalar, w.to(J)) - sc_mul(h_scalar, h_scalar, winv.to(J)) + _sc_mul(g_scalar, g_scalar, w.to(J)) + _sc_mul(h_scalar, h_scalar, winv.to(J)) # Adjust the scalars using the exponents from PAPER LINE 62 - sc_add(g_scalar, g_scalar, z) + _sc_add(g_scalar, g_scalar, z) utils.ensure(2 + i // N < len(zpow), "invalid zpow index") utils.ensure(i % N < len(twoN), "invalid twoN index") - sc_mul(tmp, zpow.to(2 + i // N), twoN.to(i % N)) - sc_muladd(tmp, z, ypow, tmp) - sc_mulsub(h_scalar, tmp, yinvpow, h_scalar) + _sc_mul(tmp, zpow.to(2 + i // N), twoN.to(i % N)) + _sc_muladd(tmp, z, ypow, tmp) + _sc_mulsub(h_scalar, tmp, yinvpow, h_scalar) if not is_single: # ph4 - sc_mulsub(m_z4[i], g_scalar, weight_z, m_z4[i]) - sc_mulsub(m_z5[i], h_scalar, weight_z, m_z5[i]) + _sc_mulsub(m_z4[i], g_scalar, weight_z, m_z4[i]) + _sc_mulsub(m_z5[i], h_scalar, weight_z, m_z5[i]) else: - sc_mul(tmp, g_scalar, weight_z) - sub_keys(muex_acc, muex_acc, scalarmult_key(tmp, Gprec.to(i), tmp)) + _sc_mul(tmp, g_scalar, weight_z) + _sub_keys( + muex_acc, muex_acc, _scalarmult_key(tmp, Gprec.to(i), tmp) + ) - sc_mul(tmp, h_scalar, weight_z) - sub_keys(muex_acc, muex_acc, scalarmult_key(tmp, Hprec.to(i), tmp)) + _sc_mul(tmp, h_scalar, weight_z) + _sub_keys( + muex_acc, muex_acc, _scalarmult_key(tmp, Hprec.to(i), tmp) + ) if i != MN - 1: - sc_mul(yinvpow, yinvpow, yinv) - sc_mul(ypow, ypow, y) + _sc_mul(yinvpow, yinvpow, yinv) + _sc_mul(ypow, ypow, y) if i & 15 == 0: self.gc(62) del (g_scalar, h_scalar, twoN) self.gc(63) - sc_muladd(z1, proof.mu, weight_z, z1) + _sc_muladd(z1, proof.mu, weight_z, z1) muex = MultiExpSequential( point_fnc=lambda i, d: proof.L[i // 2] if i & 1 == 0 else proof.R[i // 2] ) for i in range(rounds): - sc_mul(tmp, w[i], w[i]) - sc_mul(tmp, tmp, weight_z8) + _sc_mul(tmp, w.to(i), w.to(i)) + _sc_mul(tmp, tmp, weight_z8) muex.add_scalar(tmp) - sc_mul(tmp, winv[i], winv[i]) - sc_mul(tmp, tmp, weight_z8) + _sc_mul(tmp, winv.to(i), winv.to(i)) + _sc_mul(tmp, tmp, weight_z8) muex.add_scalar(tmp) - acc = multiexp(None, muex, False) - add_keys(muex_acc, muex_acc, acc) + acc = _multiexp(None, muex, False) + _add_keys(muex_acc, muex_acc, acc) - sc_mulsub(tmp, proof.a, proof.b, proof.t) - sc_mul(tmp, tmp, x_ip) - sc_muladd(z3, tmp, weight_z, z3) + _sc_mulsub(tmp, proof.a, proof.b, proof.t) + _sc_mul(tmp, tmp, x_ip) + _sc_muladd(z3, tmp, weight_z, z3) - sc_sub(tmp, m_y0, z1) - z3p = sc_sub(None, z3, y1) + _sc_sub(tmp, m_y0, z1) + z3p = _sc_sub(None, z3, y1) check2 = crypto.encodepoint( crypto.ge25519_double_scalarmult_base_vartime( crypto.decodeint(z3p), crypto.xmr_H(), crypto.decodeint(tmp) ) ) - add_keys(muex_acc, muex_acc, check2) + _add_keys(muex_acc, muex_acc, check2) if not is_single: # ph4 muex = MultiExpSequential( @@ -1583,7 +1728,7 @@ class BulletProofBuilder: for i in range(maxMN): muex.add_scalar(m_z4[i]) muex.add_scalar(m_z5[i]) - add_keys(muex_acc, muex_acc, multiexp(None, muex, True)) + _add_keys(muex_acc, muex_acc, _multiexp(None, muex, True)) if muex_acc != _ONE: raise ValueError("Verification failure at step 2") diff --git a/core/tests/test_apps.monero.bulletproof.py b/core/tests/test_apps.monero.bulletproof.py index 467da8940b..8dcec9b00b 100644 --- a/core/tests/test_apps.monero.bulletproof.py +++ b/core/tests/test_apps.monero.bulletproof.py @@ -285,102 +285,27 @@ class TestMoneroBulletproof(unittest.TestCase): bpi.use_det_masks = False self.mask_consistency_check(bpi) - def test_verify_testnet(self): - bpi = bp.BulletProofBuilder() - - # fmt: off - bp_proof = Bulletproof( - V=[bytes( - [0x67, 0x54, 0xbf, 0x40, 0xcb, 0x45, 0x63, 0x0d, 0x4b, 0xea, 0x08, 0x9e, 0xd7, 0x86, 0xec, 0x3c, 0xe5, - 0xbd, 0x4e, 0xed, 0x8f, 0xf3, 0x25, 0x76, 0xae, 0xca, 0xb8, 0x9e, 0xf2, 0x5e, 0x41, 0x16])], - A=bytes( - [0x96, 0x10, 0x17, 0x66, 0x87, 0x7e, 0xef, 0x97, 0xb3, 0x82, 0xfb, 0x8e, 0x0c, 0x2a, 0x93, 0x68, 0x9e, - 0x05, 0x22, 0x07, 0xe3, 0x30, 0x94, 0x20, 0x58, 0x6f, 0x5d, 0x01, 0x6d, 0x4e, 0xd5, 0x88]), - S=bytes( - [0x50, 0x51, 0x38, 0x32, 0x96, 0x20, 0x7c, 0xc9, 0x60, 0x4d, 0xac, 0x7c, 0x7c, 0x21, 0xf9, 0xad, 0x1c, - 0xc2, 0x2d, 0xee, 0x88, 0x7b, 0xa2, 0xe2, 0x61, 0x81, 0x46, 0xf5, 0x99, 0xc3, 0x12, 0x57]), - T1=bytes( - [0x1a, 0x7d, 0x06, 0x51, 0x41, 0xe6, 0x12, 0xbe, 0xad, 0xd7, 0x68, 0x60, 0x85, 0xfc, 0xc4, 0x86, 0x0b, - 0x39, 0x4b, 0x06, 0xf7, 0xca, 0xb3, 0x29, 0xdf, 0x1d, 0xbf, 0x96, 0x5f, 0xbe, 0x8c, 0x87]), - T2=bytes( - [0x57, 0xae, 0x91, 0x04, 0xfa, 0xac, 0xf3, 0x73, 0x75, 0xf2, 0x83, 0xd6, 0x9a, 0xcb, 0xef, 0xe4, 0xfc, - 0xe5, 0x37, 0x55, 0x52, 0x09, 0xb5, 0x60, 0x6d, 0xab, 0x46, 0x85, 0x01, 0x23, 0x9e, 0x47]), - taux=bytes( - [0x44, 0x7a, 0x87, 0xd9, 0x5f, 0x1b, 0x17, 0xed, 0x53, 0x7f, 0xc1, 0x4f, 0x91, 0x9b, 0xca, 0x68, 0xce, - 0x20, 0x43, 0xc0, 0x88, 0xf1, 0xdf, 0x12, 0x7b, 0xd7, 0x7f, 0xe0, 0x27, 0xef, 0xef, 0x0d]), - mu=bytes( - [0x32, 0xf9, 0xe4, 0xe1, 0xc2, 0xd8, 0xe4, 0xb0, 0x0d, 0x49, 0xd1, 0x02, 0xbc, 0xcc, 0xf7, 0xa2, 0x5a, - 0xc7, 0x28, 0xf3, 0x05, 0xb5, 0x64, 0x2e, 0xde, 0xcf, 0x01, 0x61, 0xb8, 0x62, 0xfb, 0x0d]), - L=[ - bytes([0xde, 0x71, 0xca, 0x09, 0xf9, 0xd9, 0x1f, 0xa2, 0xae, 0xdf, 0x39, 0x49, 0x04, 0xaa, 0x6b, 0x58, - 0x67, 0x9d, 0x61, 0xa6, 0xfa, 0xec, 0x81, 0xf6, 0x4c, 0x15, 0x09, 0x9d, 0x10, 0x21, 0xff, 0x39]), - bytes([0x90, 0x47, 0xbf, 0xf0, 0x1f, 0x72, 0x47, 0x4e, 0xd5, 0x58, 0xfb, 0xc1, 0x16, 0x43, 0xb7, 0xd8, - 0xb1, 0x00, 0xa4, 0xa3, 0x19, 0x9b, 0xda, 0x5b, 0x27, 0xd3, 0x6c, 0x5a, 0x87, 0xf8, 0xf0, 0x28]), - bytes([0x03, 0x45, 0xef, 0x57, 0x19, 0x8b, 0xc7, 0x38, 0xb7, 0xcb, 0x9c, 0xe7, 0xe8, 0x23, 0x27, 0xbb, - 0xd3, 0x54, 0xcb, 0x38, 0x3c, 0x24, 0x8a, 0x60, 0x11, 0x20, 0x92, 0x99, 0xec, 0x35, 0x71, 0x9f]), - bytes([0x7a, 0xb6, 0x36, 0x42, 0x36, 0x83, 0xf3, 0xa6, 0xc1, 0x24, 0xc5, 0x63, 0xb0, 0x4c, 0x8b, 0xef, - 0x7c, 0x77, 0x25, 0x83, 0xa8, 0xbb, 0x8b, 0x57, 0x75, 0x1c, 0xb6, 0xd7, 0xca, 0xc9, 0x0d, 0x78]), - bytes([0x9d, 0x79, 0x66, 0x21, 0x64, 0x72, 0x97, 0x08, 0xa0, 0x5a, 0x94, 0x5a, 0x94, 0x7b, 0x11, 0xeb, - 0x4e, 0xe9, 0x43, 0x2f, 0x08, 0xa2, 0x57, 0xa5, 0xd5, 0x99, 0xb0, 0xa7, 0xde, 0x78, 0x80, 0xb7]), - bytes([0x9f, 0x88, 0x5c, 0xa5, 0xeb, 0x08, 0xef, 0x1a, 0xcf, 0xbb, 0x1d, 0x04, 0xc5, 0x47, 0x24, 0x37, - 0x49, 0xe4, 0x4e, 0x9c, 0x5d, 0x56, 0xd0, 0x97, 0xfd, 0x8a, 0xe3, 0x23, 0x1d, 0xab, 0x16, 0x03]), - ], - R=[ - bytes([0xae, 0x89, 0xeb, 0xa8, 0x5b, 0xd5, 0x65, 0xd6, 0x9f, 0x2a, 0xfd, 0x04, 0x66, 0xad, 0xb1, 0xf3, - 0x5e, 0xf6, 0x60, 0xa7, 0x26, 0x94, 0x3b, 0x72, 0x5a, 0x5c, 0x80, 0xfa, 0x0f, 0x75, 0x48, 0x27]), - bytes([0xc9, 0x1a, 0x61, 0x70, 0x6d, 0xea, 0xea, 0xb2, 0x42, 0xff, 0x27, 0x3b, 0x8e, 0x94, 0x07, 0x75, - 0x40, 0x7d, 0x33, 0xde, 0xfc, 0xbd, 0x53, 0xa0, 0x2a, 0xf9, 0x0c, 0x36, 0xb0, 0xdd, 0xbe, 0x8d]), - bytes([0xb7, 0x39, 0x7a, 0x0e, 0xa1, 0x42, 0x0f, 0x94, 0x62, 0x24, 0xcf, 0x54, 0x75, 0xe3, 0x0b, 0x0f, - 0xfb, 0xcb, 0x67, 0x7b, 0xbc, 0x98, 0x36, 0x01, 0x9f, 0x73, 0xa0, 0x70, 0xa1, 0x7e, 0xf0, 0xcf]), - bytes([0x40, 0x06, 0xd4, 0xfa, 0x22, 0x7c, 0x82, 0xbf, 0xe8, 0xe0, 0x35, 0x13, 0x28, 0xa2, 0xb9, 0x51, - 0xa3, 0x37, 0x34, 0xc0, 0xa6, 0x43, 0xd6, 0xb7, 0x7a, 0x40, 0xae, 0xf9, 0x36, 0x0e, 0xe3, 0xcc]), - bytes([0x88, 0x38, 0x64, 0xe9, 0x63, 0xe3, 0x33, 0xd9, 0xf6, 0xca, 0x47, 0xc4, 0xc7, 0x36, 0x70, 0x01, - 0xd2, 0xe4, 0x8c, 0x9f, 0x25, 0xc2, 0xce, 0xcf, 0x81, 0x89, 0x4f, 0x24, 0xcb, 0xb8, 0x40, 0x73]), - bytes([0xdc, 0x35, 0x65, 0xed, 0x6b, 0xb0, 0xa7, 0x1a, 0x1b, 0xf3, 0xd6, 0xfb, 0x47, 0x00, 0x48, 0x00, - 0x20, 0x6d, 0xd4, 0xeb, 0xff, 0xb9, 0xdc, 0x43, 0x30, 0x8a, 0x90, 0xfe, 0x43, 0x74, 0x75, 0x68]), - ], - a=bytes( - [0xb4, 0x8e, 0xc2, 0x31, 0xce, 0x05, 0x9a, 0x7a, 0xbc, 0x82, 0x8c, 0x30, 0xb3, 0xe3, 0x80, 0x86, 0x05, - 0xb8, 0x4c, 0x93, 0x9a, 0x8e, 0xce, 0x39, 0x0f, 0xb6, 0xee, 0x28, 0xf6, 0x7e, 0xd5, 0x07]), - b=bytes( - [0x47, 0x10, 0x62, 0xc2, 0xad, 0xc7, 0xe2, 0xc9, 0x14, 0x6f, 0xf4, 0xd1, 0xfe, 0x52, 0xa9, 0x1a, 0xe4, - 0xb6, 0xd0, 0x25, 0x4b, 0x19, 0x80, 0x7c, 0xcd, 0x62, 0x62, 0x1d, 0x97, 0x20, 0x71, 0x0b]), - t=bytes( - [0x47, 0x06, 0xea, 0x76, 0x8f, 0xdb, 0xa3, 0x15, 0xe0, 0x2c, 0x6b, 0x25, 0xa1, 0xf7, 0x3c, 0xc8, 0x1d, - 0x97, 0xa6, 0x52, 0x48, 0x75, 0x37, 0xf9, 0x1e, 0x14, 0xac, 0xb1, 0x2a, 0x34, 0xc6, 0x06]) - ) - # fmt: on - - self.assertTrue(bpi.verify_testnet(bp_proof)) - def test_verify(self): bpi = bp.BulletProofBuilder() self.assertTrue(bpi.verify(self.bproof_1())) self.assertTrue(bpi.verify(self.bproof_2())) self.assertTrue(bpi.verify(self.bproof_4())) - def test_prove_testnet(self): + def test_prove(self): bpi = bp.BulletProofBuilder() val = crypto.sc_init(123) mask = crypto.sc_init(432) - bp_res = bpi.prove_testnet(val, mask) - bpi.verify_testnet(bp_res) + bp_res = bpi.prove(val, mask) + bpi.verify(bp_res) - try: - bp_res.S[0] += 1 - bpi.verify(bp_res) - self.fail("Verification should have failed") - except: - pass - - def test_prove_testnet_2(self): + def test_prove_2(self): bpi = bp.BulletProofBuilder() val = crypto.sc_init((1 << 30) - 1 + 16) mask = crypto.random_scalar() - bp_res = bpi.prove_testnet(val, mask) - bpi.verify_testnet(bp_res) + bp_res = bpi.prove(val, mask) + bpi.verify(bp_res) def test_verify_batch_1(self): bpi = bp.BulletProofBuilder() @@ -403,15 +328,6 @@ class TestMoneroBulletproof(unittest.TestCase): bp_res = bpi.prove(val, mask) bpi.verify(bp_res) - def test_prove_testnet_random_masks(self): - bpi = bp.BulletProofBuilder() - bpi.use_det_masks = False # trully randomly generated mask vectors - val = crypto.sc_init((1 << 30) - 1 + 16) - mask = crypto.random_scalar() - - bp_res = bpi.prove_testnet(val, mask) - bpi.verify_testnet(bp_res) - def ctest_multiexp(self): scalars = [0, 1, 2, 3, 4, 99] point_base = [0, 2, 4, 7, 12, 18] @@ -438,6 +354,13 @@ class TestMoneroBulletproof(unittest.TestCase): proof = bpi.prove_batch(sv, gamma) bpi.verify_batch([proof]) + def test_prove_batch16(self): + bpi = bp.BulletProofBuilder() + sv = [crypto.sc_init(137*i) for i in range(16)] + gamma = [crypto.sc_init(991*i) for i in range(16)] + proof = bpi.prove_batch(sv, gamma) + bpi.verify_batch([proof]) + if __name__ == "__main__": unittest.main()