diff --git a/src/protobuf.py b/src/protobuf.py index 7096fb7739..a7d03a41a0 100644 --- a/src/protobuf.py +++ b/src/protobuf.py @@ -40,6 +40,8 @@ async def load_uvarint(reader): async def dump_uvarint(writer, n): + if n < 0: + raise ValueError("Cannot dump signed value, convert it to unsigned first.") buffer = _UVARINT_BUFFER shifted = True while shifted: @@ -49,15 +51,45 @@ async def dump_uvarint(writer, n): n = shifted +# protobuf interleaved ("ZigZag") signed encoding: +# https://developers.google.com/protocol-buffers/docs/encoding#structure +# The idea is to store the sign in the least significant bit (LSbit) instead of using two's complement, +# so counting up, you go: 0, -1, 1, -2, 2, ... (each time the LSbit changes, the sign flips). +# +# To achieve this with a two's-complement number: +# 1. Shift left by 1, leaving the LSbit free. +# 2. If the number is negative, do a bitwise negation. +# This keeps positive numbers the same, and converts negative numbers from two's complement +# to the appropriate value, while setting the sign bit. +# +# The original algorithm makes use of the fact that an arithmetic (signed) shift +# replicates the sign bit, so for an n-bit number, (x >> (n - 1)) gets us "all sign bits". +# Then you can take "number XOR all-sign-bits", which is XOR with all zeros (identity) for positive +# and XOR with all ones (bitwise negation) for negative. Cute and efficient. +# +# But this is harder in Python because its integers are arbitrary-precision, so a number has no fixed bit size. +# So we have to branch on whether the number is negative.
+ +def sint_to_uint(sint): + res = sint << 1 + if sint < 0: + res = ~res + return res + + +def uint_to_sint(uint): + sign = uint & 1 + res = uint >> 1 + if sign: + res = ~res + return res + + class UVarintType: WIRE_TYPE = 0 -class Sint32Type: - WIRE_TYPE = 0 - - -class Sint64Type: +class SVarintType: WIRE_TYPE = 0 @@ -149,10 +181,8 @@ async def load_message(reader, msg_type): if ftype is UVarintType: fvalue = ivalue - elif ftype is Sint32Type: - fvalue = (ivalue >> 1) ^ ((ivalue << 31) & 0xffffffff) - elif ftype is Sint64Type: - fvalue = (ivalue >> 1) ^ ((ivalue << 63) & 0xffffffffffffffff) + elif ftype is SVarintType: + fvalue = uint_to_sint(ivalue) elif ftype is BoolType: fvalue = bool(ivalue) elif ftype is BytesType: @@ -188,10 +218,7 @@ async def dump_message(writer, msg): fields = mtype.FIELDS for ftag in fields: - field = fields[ftag] - fname = field[0] - ftype = field[1] - fflags = field[2] + fname, ftype, fflags = fields[ftag] fvalue = getattr(msg, fname, None) if fvalue is None: @@ -209,11 +236,8 @@ async def dump_message(writer, msg): if ftype is UVarintType: await dump_uvarint(writer, svalue) - elif ftype is Sint32Type: - await dump_uvarint(writer, ((svalue << 1) & 0xffffffff) ^ (svalue >> 31)) - - elif ftype is Sint64Type: - await dump_uvarint(writer, ((svalue << 1) & 0xffffffffffffffff) ^ (svalue >> 63)) + elif ftype is SVarintType: + await dump_uvarint(writer, sint_to_uint(svalue)) elif ftype is BoolType: await dump_uvarint(writer, int(svalue))