feat(core): make protobuf.dump_uvarint more reusable

release/21.02
matejcik 4 years ago committed by Tomas Susanka
parent e9c227f623
commit 6acc1cd6ab

@ -4,7 +4,7 @@ bytes, string, embedded message and repeated fields.
""" """
if False: if False:
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, TypeVar, Union
from typing_extensions import Protocol from typing_extensions import Protocol
class Reader(Protocol): class Reader(Protocol):
@ -19,6 +19,8 @@ if False:
Writes all bytes from `buf`, or raises `EOFError`. Writes all bytes from `buf`, or raises `EOFError`.
""" """
WriteMethod = Callable[[bytes], Any]
_UVARINT_BUFFER = bytearray(1) _UVARINT_BUFFER = bytearray(1)
@ -36,7 +38,7 @@ def load_uvarint(reader: Reader) -> int:
return result return result
def dump_uvarint(writer: Writer, n: int) -> None: def dump_uvarint(write: WriteMethod, n: int) -> None:
if n < 0: if n < 0:
raise ValueError("Cannot dump signed value, convert it to unsigned first.") raise ValueError("Cannot dump signed value, convert it to unsigned first.")
buffer = _UVARINT_BUFFER buffer = _UVARINT_BUFFER
@ -44,7 +46,7 @@ def dump_uvarint(writer: Writer, n: int) -> None:
while shifted: while shifted:
shifted = n >> 7 shifted = n >> 7
buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00) buffer[0] = (n & 0x7F) | (0x80 if shifted else 0x00)
writer.write(buffer) write(buffer)
n = shifted n = shifted
@ -318,32 +320,32 @@ def dump_message(
fvalue = repvalue fvalue = repvalue
for svalue in fvalue: for svalue in fvalue:
dump_uvarint(writer, fkey) dump_uvarint(writer.write, fkey)
if ftype is UVarintType: if ftype is UVarintType:
dump_uvarint(writer, svalue) dump_uvarint(writer.write, svalue)
elif ftype is SVarintType: elif ftype is SVarintType:
dump_uvarint(writer, sint_to_uint(svalue)) dump_uvarint(writer.write, sint_to_uint(svalue))
elif ftype is BoolType: elif ftype is BoolType:
dump_uvarint(writer, int(svalue)) dump_uvarint(writer.write, int(svalue))
elif isinstance(ftype, EnumType): elif isinstance(ftype, EnumType):
dump_uvarint(writer, svalue) dump_uvarint(writer.write, svalue)
elif ftype is BytesType: elif ftype is BytesType:
if isinstance(svalue, list): if isinstance(svalue, list):
dump_uvarint(writer, _count_bytes_list(svalue)) dump_uvarint(writer.write, _count_bytes_list(svalue))
for sub_svalue in svalue: for sub_svalue in svalue:
writer.write(sub_svalue) writer.write(sub_svalue)
else: else:
dump_uvarint(writer, len(svalue)) dump_uvarint(writer.write, len(svalue))
writer.write(svalue) writer.write(svalue)
elif ftype is UnicodeType: elif ftype is UnicodeType:
svalue = svalue.encode() svalue = svalue.encode()
dump_uvarint(writer, len(svalue)) dump_uvarint(writer.write, len(svalue))
writer.write(svalue) writer.write(svalue)
elif issubclass(ftype, MessageType): elif issubclass(ftype, MessageType):
@ -351,7 +353,7 @@ def dump_message(
if ffields is None: if ffields is None:
ffields = ftype.get_fields() ffields = ftype.get_fields()
field_cache[ftype] = ffields field_cache[ftype] = ffields
dump_uvarint(writer, count_message(svalue, field_cache)) dump_uvarint(writer.write, count_message(svalue, field_cache))
dump_message(writer, svalue, field_cache) dump_message(writer, svalue, field_cache)
else: else:

@ -36,9 +36,9 @@ def load_uvarint(data: bytes) -> int:
def dump_uvarint(value: int) -> bytearray: def dump_uvarint(value: int) -> bytearray:
writer = BufferWriter(bytearray(16)) w = bytearray()
protobuf.dump_uvarint(writer, value) protobuf.dump_uvarint(w.extend, value)
return memoryview(writer.buffer)[: writer.offset] return w
def dump_message(msg: protobuf.MessageType) -> bytearray: def dump_message(msg: protobuf.MessageType) -> bytearray:

Loading…
Cancel
Save