1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-16 03:18:09 +00:00

Merge pull request #394 from ph4r05/xmr-mg

xmr: MLSAG computation optimized
This commit is contained in:
Tomas Susanka 2018-11-02 15:26:48 +01:00 committed by GitHub
commit d919e99255
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 183 additions and 130 deletions

View File

@ -10,7 +10,7 @@ from .state import State
from apps.monero.layout import confirms from apps.monero.layout import confirms
from apps.monero.signing import RctType from apps.monero.signing import RctType
from apps.monero.xmr import crypto, serialize from apps.monero.xmr import crypto
if False: if False:
from trezor.messages.MoneroTransactionSourceEntry import ( from trezor.messages.MoneroTransactionSourceEntry import (
@ -40,8 +40,6 @@ async def sign_input(
:param spend_enc: one time address spending private key. Encrypted. :param spend_enc: one time address spending private key. Encrypted.
:return: Generated signature MGs[i] :return: Generated signature MGs[i]
""" """
from apps.monero.signing import offloading_keys
await confirms.transaction_step( await confirms.transaction_step(
state.ctx, state.STEP_SIGN, state.current_input_index + 1, state.input_count state.ctx, state.STEP_SIGN, state.current_input_index + 1, state.input_count
) )
@ -57,8 +55,11 @@ async def sign_input(
raise ValueError("Two and more inputs must imply SimpleRCT") raise ValueError("Two and more inputs must imply SimpleRCT")
input_position = state.source_permutation[state.current_input_index] input_position = state.source_permutation[state.current_input_index]
mods = utils.unimport_begin()
# Check input's HMAC # Check input's HMAC
from apps.monero.signing import offloading_keys
vini_hmac_comp = await offloading_keys.gen_hmac_vini( vini_hmac_comp = await offloading_keys.gen_hmac_vini(
state.key_hmac, src_entr, vini_bin, input_position state.key_hmac, src_entr, vini_bin, input_position
) )
@ -66,7 +67,9 @@ async def sign_input(
raise ValueError("HMAC is not correct") raise ValueError("HMAC is not correct")
gc.collect() gc.collect()
state.mem_trace(1) state.mem_trace(1, True)
from apps.monero.xmr.crypto import chacha_poly
if state.rct_type == RctType.Simple: if state.rct_type == RctType.Simple:
# both pseudo_out and its mask were offloaded so we need to # both pseudo_out and its mask were offloaded so we need to
@ -78,10 +81,7 @@ async def sign_input(
if not crypto.ct_equals(pseudo_out_hmac_comp, pseudo_out_hmac): if not crypto.ct_equals(pseudo_out_hmac_comp, pseudo_out_hmac):
raise ValueError("HMAC is not correct") raise ValueError("HMAC is not correct")
gc.collect() state.mem_trace(2, True)
state.mem_trace(2)
from apps.monero.xmr.crypto import chacha_poly
pseudo_out_alpha = crypto.decodeint( pseudo_out_alpha = crypto.decodeint(
chacha_poly.decrypt_pack( chacha_poly.decrypt_pack(
@ -92,9 +92,6 @@ async def sign_input(
pseudo_out_c = crypto.decodepoint(pseudo_out) pseudo_out_c = crypto.decodepoint(pseudo_out)
# Spending secret # Spending secret
from apps.monero.xmr.crypto import chacha_poly
from apps.monero.xmr.serialize_messages.ct_keys import CtKey
spend_key = crypto.decodeint( spend_key = crypto.decodeint(
chacha_poly.decrypt_pack( chacha_poly.decrypt_pack(
offloading_keys.enc_key_spend(state.key_enc, input_position), offloading_keys.enc_key_spend(state.key_enc, input_position),
@ -102,8 +99,18 @@ async def sign_input(
) )
) )
gc.collect() del (
state.mem_trace(3) offloading_keys,
chacha_poly,
pseudo_out,
pseudo_out_hmac,
pseudo_out_alpha_enc,
spend_enc,
)
utils.unimport_end(mods)
state.mem_trace(3, True)
from apps.monero.xmr.serialize_messages.ct_keys import CtKey
# Basic setup, sanity check # Basic setup, sanity check
index = src_entr.real_output index = src_entr.real_output
@ -126,14 +133,16 @@ async def sign_input(
"Real source entry's mask does not equal spend key's", "Real source entry's mask does not equal spend key's",
) )
gc.collect() state.mem_trace(4, True)
state.mem_trace(4)
from apps.monero.xmr import mlsag from apps.monero.xmr import mlsag
if state.rct_type == RctType.Simple: mg_buffer = []
ring_pubkeys = [x.key for x in src_entr.outputs] ring_pubkeys = [x.key for x in src_entr.outputs]
mg = mlsag.generate_mlsag_simple( del src_entr
if state.rct_type == RctType.Simple:
mlsag.generate_mlsag_simple(
state.full_message, state.full_message,
ring_pubkeys, ring_pubkeys,
input_secret_key, input_secret_key,
@ -141,52 +150,33 @@ async def sign_input(
pseudo_out_c, pseudo_out_c,
kLRki, kLRki,
index, index,
mg_buffer,
) )
del (input_secret_key, pseudo_out_alpha, pseudo_out_c)
else: else:
# Full RingCt, only one input # Full RingCt, only one input
txn_fee_key = crypto.scalarmult_h(state.fee) txn_fee_key = crypto.scalarmult_h(state.fee)
ring_pubkeys = [[x.key] for x in src_entr.outputs] mlsag.generate_mlsag_full(
mg = mlsag.generate_mlsag_full(
state.full_message, state.full_message,
ring_pubkeys, ring_pubkeys,
[input_secret_key], input_secret_key,
state.output_sk_masks, state.output_sk_masks,
state.output_pk_commitments, state.output_pk_commitments,
kLRki, kLRki,
index, index,
txn_fee_key, txn_fee_key,
mg_buffer,
) )
gc.collect() del (input_secret_key, txn_fee_key)
state.mem_trace(5)
# Encode del (mlsag, ring_pubkeys)
mgs = _recode_msg([mg]) state.mem_trace(5, True)
gc.collect()
state.mem_trace(6)
from trezor.messages.MoneroTransactionSignInputAck import ( from trezor.messages.MoneroTransactionSignInputAck import (
MoneroTransactionSignInputAck, MoneroTransactionSignInputAck,
) )
return MoneroTransactionSignInputAck( return MoneroTransactionSignInputAck(signature=mg_buffer)
signature=serialize.dump_msg_gc(mgs[0], preallocate=488)
)
def _recode_msg(mgs):
"""
Recodes MGs signatures from raw forms to bytearrays so it works with serialization
"""
for idx in range(len(mgs)):
mgs[idx].cc = crypto.encodeint(mgs[idx].cc)
if hasattr(mgs[idx], "II") and mgs[idx].II:
for i in range(len(mgs[idx].II)):
mgs[idx].II[i] = crypto.encodepoint(mgs[idx].II[i])
for i in range(len(mgs[idx].ss)):
for j in range(len(mgs[idx].ss[i])):
mgs[idx].ss[i][j] = crypto.encodeint(mgs[idx].ss[i][j])
return mgs

View File

@ -107,6 +107,7 @@ def sc_init_into(r, x):
return tcry.init256_modm(r, x) return tcry.init256_modm(r, x)
sc_copy = tcry.init256_modm
sc_get64 = tcry.get256_modm sc_get64 = tcry.get256_modm
sc_check = tcry.check256_modm sc_check = tcry.check256_modm
check_sc = tcry.check256_modm check_sc = tcry.check256_modm

View File

@ -42,63 +42,76 @@ and `sk` is equal to:
Mostly ported from official Monero client, but also inspired by Mininero. Mostly ported from official Monero client, but also inspired by Mininero.
Author: Dusan Klinec, ph4r05, 2018 Author: Dusan Klinec, ph4r05, 2018
""" """
import gc
from apps.monero.xmr import crypto from apps.monero.xmr import crypto
from apps.monero.xmr.serialize import int_serialize
def generate_mlsag_full( def generate_mlsag_full(
message, pubs, in_sk, out_sk_mask, out_pk_commitments, kLRki, index, txn_fee_key message,
pubs,
in_sk,
out_sk_mask,
out_pk_commitments,
kLRki,
index,
txn_fee_key,
mg_buff,
): ):
cols = len(pubs) cols = len(pubs)
if cols == 0: if cols == 0:
raise ValueError("Empty pubs") raise ValueError("Empty pubs")
rows = len(pubs[0]) rows = 1 # Monero uses only one row
if rows == 0:
raise ValueError("Empty pub row")
for i in range(cols):
if len(pubs[i]) != rows:
raise ValueError("pub is not rectangular")
if len(in_sk) != rows:
raise ValueError("Bad inSk size")
if len(out_sk_mask) != len(out_pk_commitments): if len(out_sk_mask) != len(out_pk_commitments):
raise ValueError("Bad outsk/putpk size") raise ValueError("Bad outsk/putpk size")
sk = _key_vector(rows + 1) sk = _key_vector(rows + 1)
M = _key_matrix(rows + 1, cols) M = _key_matrix(rows + 1, cols)
for i in range(rows + 1):
sk[i] = crypto.sc_0() tmp_mi_rows = crypto.new_point(None)
tmp_pt = crypto.new_point(None)
for i in range(cols): for i in range(cols):
M[i][rows] = crypto.identity() crypto.identity_into(tmp_mi_rows) # M[i][rows]
for j in range(rows):
M[i][j] = crypto.decodepoint(pubs[i][j].dest) # Should iterate over rows, simplified as rows == 1
M[i][rows] = crypto.point_add( M[i][0] = pubs[i].dest
M[i][rows], crypto.decodepoint(pubs[i][j].commitment) crypto.point_add_into(
tmp_mi_rows,
tmp_mi_rows,
crypto.decodepoint_into(tmp_pt, pubs[i].commitment),
) )
pubs[i] = None
sk[rows] = crypto.sc_0()
for j in range(rows):
sk[j] = in_sk[j].dest
sk[rows] = crypto.sc_add(sk[rows], in_sk[j].mask) # add masks in last row
for i in range(cols):
for j in range(len(out_pk_commitments)): for j in range(len(out_pk_commitments)):
M[i][rows] = crypto.point_sub( crypto.point_sub_into(
M[i][rows], crypto.decodepoint(out_pk_commitments[j]) tmp_mi_rows,
tmp_mi_rows,
crypto.decodepoint_into(tmp_pt, out_pk_commitments[j]),
) # subtract output Ci's in last row ) # subtract output Ci's in last row
# Subtract txn fee output in last row # Subtract txn fee output in last row
M[i][rows] = crypto.point_sub(M[i][rows], txn_fee_key) crypto.point_sub_into(tmp_mi_rows, tmp_mi_rows, txn_fee_key)
M[i][rows] = crypto.encodepoint(tmp_mi_rows)
# Simplified as rows == 1
sk[0] = in_sk.dest
sk[rows] = in_sk.mask # originally: sum of all in_sk[0..rows] in sk[rows]
for j in range(len(out_pk_commitments)): for j in range(len(out_pk_commitments)):
sk[rows] = crypto.sc_sub( crypto.sc_sub_into(
sk[rows], out_sk_mask[j] sk[rows], sk[rows], out_sk_mask[j]
) # subtract output masks in last row ) # subtract output masks in last row
return generate_mlsag(message, M, sk, kLRki, index, rows) del (pubs, tmp_mi_rows, tmp_pt)
gc.collect()
return generate_mlsag(message, M, sk, kLRki, index, rows, mg_buff)
def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index): def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index, mg_buff):
""" """
MLSAG for RctType.Simple MLSAG for RctType.Simple
:param message: the full message to be signed (actually its hash) :param message: the full message to be signed (actually its hash)
@ -108,7 +121,7 @@ def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index):
:param cout: pseudo output commitment; point, decoded; better name: pseudo_out_c :param cout: pseudo output commitment; point, decoded; better name: pseudo_out_c
:param kLRki: used only in multisig, currently not implemented :param kLRki: used only in multisig, currently not implemented
:param index: specifies corresponding public key to the `in_sk` in the pubs array :param index: specifies corresponding public key to the `in_sk` in the pubs array
:return: MgSig :param mg_buff: buffer to store the signature to
""" """
# Monero signs inputs separately, so `rows` always equals 2 (pubkey, commitment) # Monero signs inputs separately, so `rows` always equals 2 (pubkey, commitment)
# and `dsRows` is always 1 (denotes where the pubkeys "end") # and `dsRows` is always 1 (denotes where the pubkeys "end")
@ -123,12 +136,21 @@ def generate_mlsag_simple(message, pubs, in_sk, a, cout, kLRki, index):
sk[0] = in_sk.dest sk[0] = in_sk.dest
sk[1] = crypto.sc_sub(in_sk.mask, a) sk[1] = crypto.sc_sub(in_sk.mask, a)
tmp_pt = crypto.new_point()
for i in range(cols): for i in range(cols):
M[i][0] = crypto.decodepoint(pubs[i].dest) crypto.point_sub_into(
M[i][1] = crypto.point_sub(crypto.decodepoint(pubs[i].commitment), cout) tmp_pt, crypto.decodepoint_into(tmp_pt, pubs[i].commitment), cout
)
return generate_mlsag(message, M, sk, kLRki, index, dsRows) M[i][0] = pubs[i].dest
M[i][1] = crypto.encodepoint(tmp_pt)
pubs[i] = None
del (pubs)
gc.collect()
return generate_mlsag(message, M, sk, kLRki, index, dsRows, mg_buff)
def gen_mlsag_assert(pk, xx, kLRki, index, dsRows): def gen_mlsag_assert(pk, xx, kLRki, index, dsRows):
@ -159,13 +181,10 @@ def gen_mlsag_assert(pk, xx, kLRki, index, dsRows):
return rows, cols return rows, cols
def generate_first_c_and_key_images( def generate_first_c_and_key_images(message, pk, xx, kLRki, index, dsRows, rows, cols):
message, rv, pk, xx, kLRki, index, dsRows, rows, cols
):
""" """
MLSAG computation - the part with secret keys MLSAG computation - the part with secret keys
:param message: the full message to be signed (actually its hash) :param message: the full message to be signed (actually its hash)
:param rv: MgSig
:param pk: matrix of public keys and commitments :param pk: matrix of public keys and commitments
:param xx: input secret array composed of a private key and commitment mask :param xx: input secret array composed of a private key and commitment mask
:param kLRki: used only in multisig, currently not implemented :param kLRki: used only in multisig, currently not implemented
@ -174,18 +193,19 @@ def generate_first_c_and_key_images(
:param rows: total number of rows :param rows: total number of rows
:param cols: size of ring :param cols: size of ring
""" """
Ip = _key_vector(dsRows) II = _key_vector(dsRows)
rv.II = _key_vector(dsRows)
alpha = _key_vector(rows) alpha = _key_vector(rows)
rv.ss = _key_matrix(rows, cols)
tmp_buff = bytearray(32) tmp_buff = bytearray(32)
Hi = crypto.new_point()
aGi = crypto.new_point()
aHPi = crypto.new_point()
hasher = _hasher_message(message) hasher = _hasher_message(message)
for i in range(dsRows): for i in range(dsRows):
# this is somewhat extra as compared to the Ring Confidential Tx paper # this is somewhat extra as compared to the Ring Confidential Tx paper
# see footnote in From Zero to Monero section 3.3 # see footnote in From Zero to Monero section 3.3
hasher.update(crypto.encodepoint(pk[index][i])) hasher.update(pk[index][i])
if kLRki: if kLRki:
raise NotImplementedError("Multisig not implemented") raise NotImplementedError("Multisig not implemented")
# alpha[i] = kLRki.k # alpha[i] = kLRki.k
@ -194,36 +214,34 @@ def generate_first_c_and_key_images(
# hash_point(hasher, kLRki.R, tmp_buff) # hash_point(hasher, kLRki.R, tmp_buff)
else: else:
Hi = crypto.hash_to_point(crypto.encodepoint(pk[index][i])) crypto.hash_to_point_into(Hi, pk[index][i])
alpha[i] = crypto.random_scalar() alpha[i] = crypto.random_scalar()
# L = alpha_i * G # L = alpha_i * G
aGi = crypto.scalarmult_base(alpha[i]) crypto.scalarmult_base_into(aGi, alpha[i])
# Ri = alpha_i * H(P_i) # Ri = alpha_i * H(P_i)
aHPi = crypto.scalarmult(Hi, alpha[i]) crypto.scalarmult_into(aHPi, Hi, alpha[i])
# key image # key image
rv.II[i] = crypto.scalarmult(Hi, xx[i]) II[i] = crypto.scalarmult(Hi, xx[i])
_hash_point(hasher, aGi, tmp_buff) _hash_point(hasher, aGi, tmp_buff)
_hash_point(hasher, aHPi, tmp_buff) _hash_point(hasher, aHPi, tmp_buff)
Ip[i] = rv.II[i]
for i in range(dsRows, rows): for i in range(dsRows, rows):
alpha[i] = crypto.random_scalar() alpha[i] = crypto.random_scalar()
# L = alpha_i * G # L = alpha_i * G
aGi = crypto.scalarmult_base(alpha[i]) crypto.scalarmult_base_into(aGi, alpha[i])
# for some reasons we omit calculating R here, which seems # for some reasons we omit calculating R here, which seems
# contrary to the paper, but it is in the Monero official client # contrary to the paper, but it is in the Monero official client
# see https://github.com/monero-project/monero/blob/636153b2050aa0642ba86842c69ac55a5d81618d/src/ringct/rctSigs.cpp#L191 # see https://github.com/monero-project/monero/blob/636153b2050aa0642ba86842c69ac55a5d81618d/src/ringct/rctSigs.cpp#L191
_hash_point(hasher, pk[index][i], tmp_buff) hasher.update(pk[index][i])
_hash_point(hasher, aGi, tmp_buff) _hash_point(hasher, aGi, tmp_buff)
# the first c # the first c
c_old = hasher.digest() c_old = hasher.digest()
c_old = crypto.decodeint(c_old) c_old = crypto.decodeint(c_old)
return c_old, Ip, alpha return c_old, II, alpha
def generate_mlsag(message, pk, xx, kLRki, index, dsRows): def generate_mlsag(message, pk, xx, kLRki, index, dsRows, mg_buff):
""" """
Multilayered Spontaneous Anonymous Group Signatures (MLSAG signatures) Multilayered Spontaneous Anonymous Group Signatures (MLSAG signatures)
@ -233,56 +251,89 @@ def generate_mlsag(message, pk, xx, kLRki, index, dsRows):
:param kLRki: used only in multisig, currently not implemented :param kLRki: used only in multisig, currently not implemented
:param index: specifies corresponding public key to the `xx`'s private key in the `pk` array :param index: specifies corresponding public key to the `xx`'s private key in the `pk` array
:param dsRows: separates pubkeys from commitment :param dsRows: separates pubkeys from commitment
:return MgSig :param mg_buff: mg signature buffer
""" """
from apps.monero.xmr.serialize_messages.tx_full import MgSig
rows, cols = gen_mlsag_assert(pk, xx, kLRki, index, dsRows) rows, cols = gen_mlsag_assert(pk, xx, kLRki, index, dsRows)
rows_b_size = int_serialize.uvarint_size(rows)
rv = MgSig() # Preallocation of the chunked buffer, len + cols + cc
c, L, R, Hi = 0, None, None, None for _ in range(1 + cols + 1):
mg_buff.append(None)
mg_buff[0] = int_serialize.dump_uvarint_b(cols)
cc = crypto.new_scalar() # rv.cc
c = crypto.new_scalar()
L = crypto.new_point()
R = crypto.new_point()
Hi = crypto.new_point()
# calculates the "first" c, key images and random scalars alpha # calculates the "first" c, key images and random scalars alpha
c_old, Ip, alpha = generate_first_c_and_key_images( c_old, II, alpha = generate_first_c_and_key_images(
message, rv, pk, xx, kLRki, index, dsRows, rows, cols message, pk, xx, kLRki, index, dsRows, rows, cols
) )
i = (index + 1) % cols i = (index + 1) % cols
if i == 0: if i == 0:
rv.cc = c_old crypto.sc_copy(cc, c_old)
ss = [crypto.new_scalar() for _ in range(rows)]
tmp_buff = bytearray(32) tmp_buff = bytearray(32)
while i != index: while i != index:
rv.ss[i] = _generate_random_vector(rows)
hasher = _hasher_message(message) hasher = _hasher_message(message)
# Serialize size of the row
mg_buff[i + 1] = bytearray(rows_b_size + 32 * rows)
int_serialize.dump_uvarint_b_into(rows, mg_buff[i + 1])
for x in ss:
crypto.random_scalar(x)
for j in range(dsRows): for j in range(dsRows):
# L = rv.ss[i][j] * G + c_old * pk[i][j] # L = rv.ss[i][j] * G + c_old * pk[i][j]
L = crypto.add_keys2(rv.ss[i][j], c_old, pk[i][j]) crypto.add_keys2_into(
Hi = crypto.hash_to_point(crypto.encodepoint(pk[i][j])) L, ss[j], c_old, crypto.decodepoint_into(Hi, pk[i][j])
)
crypto.hash_to_point_into(Hi, pk[i][j])
# R = rv.ss[i][j] * H(pk[i][j]) + c_old * Ip[j] # R = rv.ss[i][j] * H(pk[i][j]) + c_old * Ip[j]
R = crypto.add_keys3(rv.ss[i][j], Hi, c_old, rv.II[j]) crypto.add_keys3_into(R, ss[j], Hi, c_old, II[j])
_hash_point(hasher, pk[i][j], tmp_buff)
hasher.update(pk[i][j])
_hash_point(hasher, L, tmp_buff) _hash_point(hasher, L, tmp_buff)
_hash_point(hasher, R, tmp_buff) _hash_point(hasher, R, tmp_buff)
for j in range(dsRows, rows): for j in range(dsRows, rows):
# again, omitting R here as discussed above # again, omitting R here as discussed above
L = crypto.add_keys2(rv.ss[i][j], c_old, pk[i][j]) crypto.add_keys2_into(
_hash_point(hasher, pk[i][j], tmp_buff) L, ss[j], c_old, crypto.decodepoint_into(Hi, pk[i][j])
)
hasher.update(pk[i][j])
_hash_point(hasher, L, tmp_buff) _hash_point(hasher, L, tmp_buff)
c = crypto.decodeint(hasher.digest()) for si in range(rows):
c_old = c crypto.encodeint_into(mg_buff[i + 1], ss[si], rows_b_size + 32 * si)
crypto.decodeint_into(c, hasher.digest())
crypto.sc_copy(c_old, c)
pk[i] = None
i = (i + 1) % cols i = (i + 1) % cols
if i == 0: if i == 0:
rv.cc = c_old crypto.sc_copy(cc, c_old)
gc.collect()
del II
# Finalizing rv.ss by processing rv.ss[index]
mg_buff[index + 1] = bytearray(rows_b_size + 32 * rows)
int_serialize.dump_uvarint_b_into(rows, mg_buff[index + 1])
for j in range(rows): for j in range(rows):
rv.ss[index][j] = crypto.sc_mulsub(c, xx[j], alpha[j]) crypto.sc_mulsub_into(ss[j], c, xx[j], alpha[j])
crypto.encodeint_into(mg_buff[index + 1], ss[j], rows_b_size + 32 * j)
return rv # rv.cc
mg_buff[-1] = crypto.encodeint(cc)
def _key_vector(rows): def _key_vector(rows):

View File

@ -1,11 +0,0 @@
from apps.monero.xmr.serialize.message_types import MessageType
from apps.monero.xmr.serialize_messages.base import ECKey
from apps.monero.xmr.serialize_messages.ct_keys import KeyM
class MgSig(MessageType):
__slots__ = ("ss", "cc", "II")
@classmethod
def f_specs(cls):
return (("ss", KeyM), ("cc", ECKey))

View File

@ -275,6 +275,11 @@ async def dump_message(writer, msg, fields=None):
elif ftype is BoolType: elif ftype is BoolType:
await dump_uvarint(writer, int(svalue)) await dump_uvarint(writer, int(svalue))
elif ftype is BytesType and is_chunked(svalue):
await dump_uvarint(writer, len_list_bytes(svalue))
for sub_svalue in svalue:
await writer.awrite(sub_svalue)
elif ftype is BytesType: elif ftype is BytesType:
await dump_uvarint(writer, len(svalue)) await dump_uvarint(writer, len(svalue))
await writer.awrite(svalue) await writer.awrite(svalue)
@ -329,7 +334,9 @@ def count_message(msg, fields=None):
elif ftype is BytesType: elif ftype is BytesType:
for svalue in fvalue: for svalue in fvalue:
svalue = len(svalue) svalue = (
len(svalue) if not is_chunked(svalue) else len_list_bytes(svalue)
)
nbytes += count_uvarint(svalue) nbytes += count_uvarint(svalue)
nbytes += svalue nbytes += svalue
@ -351,3 +358,18 @@ def count_message(msg, fields=None):
raise TypeError raise TypeError
return nbytes return nbytes
def is_chunked(svalue):
return (
isinstance(svalue, list)
and len(svalue) > 0
and not isinstance(svalue[0], (int, bool))
)
def len_list_bytes(svalue):
res = 0
for x in svalue:
res += len(x)
return res