mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-16 03:18:09 +00:00
protobuf: support signed ints properly (fixes #189)
This commit is contained in:
parent
af7a66697b
commit
88ea30b746
@ -40,6 +40,8 @@ async def load_uvarint(reader):
|
||||
|
||||
|
||||
async def dump_uvarint(writer, n):
|
||||
if n < 0:
|
||||
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
|
||||
buffer = _UVARINT_BUFFER
|
||||
shifted = True
|
||||
while shifted:
|
||||
@ -49,15 +51,45 @@ async def dump_uvarint(writer, n):
|
||||
n = shifted
|
||||
|
||||
|
||||
# protobuf interleaved signed encoding:
|
||||
# https://developers.google.com/protocol-buffers/docs/encoding#structure
|
||||
# the idea is to save the sign in LSbit instead of twos-complement.
|
||||
# so counting up, you go: 0, -1, 1, -2, 2, ... (as the first bit changes, sign flips)
|
||||
#
|
||||
# To achieve this with a twos-complement number:
|
||||
# 1. shift left by 1, leaving LSbit free
|
||||
# 2. if the number is negative, do bitwise negation.
|
||||
# This keeps positive number the same, and converts negative from twos-complement
|
||||
# to the appropriate value, while setting the sign bit.
|
||||
#
|
||||
# The original algorithm makes use of the fact that arithmetic (signed) shift
|
||||
# keeps the sign bits, so for a n-bit number, (x >> n) gets us "all sign bits".
|
||||
# Then you can take "number XOR all-sign-bits", which is XOR 0 (identity) for positive
|
||||
# and XOR 1 (bitwise negation) for negative. Cute and efficient.
|
||||
#
|
||||
# But this is harder in Python because we don't natively know the bit size of the number.
|
||||
# So we have to branch on whether the number is negative.
|
||||
|
||||
def sint_to_uint(sint):
|
||||
res = sint << 1
|
||||
if sint < 0:
|
||||
res = ~res
|
||||
return res
|
||||
|
||||
|
||||
def uint_to_sint(uint):
|
||||
sign = uint & 1
|
||||
res = uint >> 1
|
||||
if sign:
|
||||
res = ~res
|
||||
return res
|
||||
|
||||
|
||||
class UVarintType:
|
||||
WIRE_TYPE = 0
|
||||
|
||||
|
||||
class Sint32Type:
|
||||
WIRE_TYPE = 0
|
||||
|
||||
|
||||
class Sint64Type:
|
||||
class SVarintType:
|
||||
WIRE_TYPE = 0
|
||||
|
||||
|
||||
@ -149,10 +181,8 @@ async def load_message(reader, msg_type):
|
||||
|
||||
if ftype is UVarintType:
|
||||
fvalue = ivalue
|
||||
elif ftype is Sint32Type:
|
||||
fvalue = (ivalue >> 1) ^ ((ivalue << 31) & 0xffffffff)
|
||||
elif ftype is Sint64Type:
|
||||
fvalue = (ivalue >> 1) ^ ((ivalue << 63) & 0xffffffffffffffff)
|
||||
elif ftype is SVarintType:
|
||||
fvalue = uint_to_sint(ivalue)
|
||||
elif ftype is BoolType:
|
||||
fvalue = bool(ivalue)
|
||||
elif ftype is BytesType:
|
||||
@ -188,10 +218,7 @@ async def dump_message(writer, msg):
|
||||
fields = mtype.FIELDS
|
||||
|
||||
for ftag in fields:
|
||||
field = fields[ftag]
|
||||
fname = field[0]
|
||||
ftype = field[1]
|
||||
fflags = field[2]
|
||||
fname, ftype, fflags = fields[ftag]
|
||||
|
||||
fvalue = getattr(msg, fname, None)
|
||||
if fvalue is None:
|
||||
@ -209,11 +236,8 @@ async def dump_message(writer, msg):
|
||||
if ftype is UVarintType:
|
||||
await dump_uvarint(writer, svalue)
|
||||
|
||||
elif ftype is Sint32Type:
|
||||
await dump_uvarint(writer, ((svalue << 1) & 0xffffffff) ^ (svalue >> 31))
|
||||
|
||||
elif ftype is Sint64Type:
|
||||
await dump_uvarint(writer, ((svalue << 1) & 0xffffffffffffffff) ^ (svalue >> 63))
|
||||
elif ftype is SVarintType:
|
||||
await dump_uvarint(writer, sint_to_uint(svalue))
|
||||
|
||||
elif ftype is BoolType:
|
||||
await dump_uvarint(writer, int(svalue))
|
||||
|
Loading…
Reference in New Issue
Block a user