1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-16 03:18:09 +00:00

protobuf: support signed ints properly (fixes #189)

This commit is contained in:
matejcik 2018-05-09 13:25:07 +02:00 committed by Pavol Rusnak
parent af7a66697b
commit 88ea30b746
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D

View File

@ -40,6 +40,8 @@ async def load_uvarint(reader):
async def dump_uvarint(writer, n):
if n < 0:
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
buffer = _UVARINT_BUFFER
shifted = True
while shifted:
@ -49,15 +51,45 @@ async def dump_uvarint(writer, n):
n = shifted
# protobuf interleaved signed encoding:
# https://developers.google.com/protocol-buffers/docs/encoding#structure
# the idea is to save the sign in LSbit instead of twos-complement.
# so counting up, you go: 0, -1, 1, -2, 2, ... (as the first bit changes, sign flips)
#
# To achieve this with a twos-complement number:
# 1. shift left by 1, leaving LSbit free
# 2. if the number is negative, do bitwise negation.
# This keeps positive number the same, and converts negative from twos-complement
# to the appropriate value, while setting the sign bit.
#
# The original algorithm makes use of the fact that arithmetic (signed) shift
# keeps the sign bits, so for a n-bit number, (x >> n) gets us "all sign bits".
# Then you can take "number XOR all-sign-bits", which is XOR 0 (identity) for positive
# and XOR 1 (bitwise negation) for negative. Cute and efficient.
#
# But this is harder in Python because we don't natively know the bit size of the number.
# So we have to branch on whether the number is negative.
def sint_to_uint(sint):
res = sint << 1
if sint < 0:
res = ~res
return res
def uint_to_sint(uint):
sign = uint & 1
res = uint >> 1
if sign:
res = ~res
return res
class UVarintType:
WIRE_TYPE = 0
class Sint32Type:
WIRE_TYPE = 0
class Sint64Type:
class SVarintType:
WIRE_TYPE = 0
@ -149,10 +181,8 @@ async def load_message(reader, msg_type):
if ftype is UVarintType:
fvalue = ivalue
elif ftype is Sint32Type:
fvalue = (ivalue >> 1) ^ ((ivalue << 31) & 0xffffffff)
elif ftype is Sint64Type:
fvalue = (ivalue >> 1) ^ ((ivalue << 63) & 0xffffffffffffffff)
elif ftype is SVarintType:
fvalue = uint_to_sint(ivalue)
elif ftype is BoolType:
fvalue = bool(ivalue)
elif ftype is BytesType:
@ -188,10 +218,7 @@ async def dump_message(writer, msg):
fields = mtype.FIELDS
for ftag in fields:
field = fields[ftag]
fname = field[0]
ftype = field[1]
fflags = field[2]
fname, ftype, fflags = fields[ftag]
fvalue = getattr(msg, fname, None)
if fvalue is None:
@ -209,11 +236,8 @@ async def dump_message(writer, msg):
if ftype is UVarintType:
await dump_uvarint(writer, svalue)
elif ftype is Sint32Type:
await dump_uvarint(writer, ((svalue << 1) & 0xffffffff) ^ (svalue >> 31))
elif ftype is Sint64Type:
await dump_uvarint(writer, ((svalue << 1) & 0xffffffffffffffff) ^ (svalue >> 63))
elif ftype is SVarintType:
await dump_uvarint(writer, sint_to_uint(svalue))
elif ftype is BoolType:
await dump_uvarint(writer, int(svalue))