1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 14:28:07 +00:00

protobuf: properly implement signed types (fixes #249)

This commit is contained in:
matejcik 2018-04-11 11:15:38 +02:00
parent b156ec9757
commit df8c3da1a2
4 changed files with 136 additions and 17 deletions

View File

@ -70,8 +70,8 @@ def process_message(descriptor, protobuf_module, msg_id, indexfile, is_upy):
field.TYPE_UINT64: 'p.UVarintType', field.TYPE_UINT64: 'p.UVarintType',
field.TYPE_UINT32: 'p.UVarintType', field.TYPE_UINT32: 'p.UVarintType',
field.TYPE_ENUM: 'p.UVarintType', field.TYPE_ENUM: 'p.UVarintType',
field.TYPE_SINT32: 'p.Sint32Type', field.TYPE_SINT32: 'p.SVarintType',
field.TYPE_SINT64: 'p.Sint64Type', field.TYPE_SINT64: 'p.SVarintType',
field.TYPE_STRING: 'p.UnicodeType', field.TYPE_STRING: 'p.UnicodeType',
field.TYPE_BOOL: 'p.BoolType', field.TYPE_BOOL: 'p.BoolType',
field.TYPE_BYTES: 'p.BytesType' field.TYPE_BYTES: 'p.BytesType'

View File

@ -6,5 +6,5 @@ from .NEMCosignatoryModification import NEMCosignatoryModification
class NEMAggregateModification(p.MessageType): class NEMAggregateModification(p.MessageType):
FIELDS = { FIELDS = {
1: ('modifications', NEMCosignatoryModification, p.FLAG_REPEATED), 1: ('modifications', NEMCosignatoryModification, p.FLAG_REPEATED),
2: ('relative_change', p.Sint32Type, 0), 2: ('relative_change', p.SVarintType, 0),
} }

View File

@ -59,6 +59,8 @@ def load_uvarint(reader):
def dump_uvarint(writer, n): def dump_uvarint(writer, n):
if n < 0:
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
buffer = _UVARINT_BUFFER buffer = _UVARINT_BUFFER
shifted = True shifted = True
while shifted: while shifted:
@ -68,15 +70,47 @@ def dump_uvarint(writer, n):
n = shifted n = shifted
# protobuf interleaved signed encoding:
# https://developers.google.com/protocol-buffers/docs/encoding#structure
# the idea is to save the sign in LSbit instead of twos-complement.
# so counting up, you go: 0, -1, 1, -2, 2, ... (as the first bit changes, sign flips)
#
# To achieve this with a twos-complement number:
# 1. shift left by 1, leaving LSbit free
# 2. XOR with "all sign bits" - 0s for positive, 1s for negative
# This keeps positive number the same, and converts negative from twos-complement
# to the appropriate value, while setting the sign bit. Cute and efficient.
#
# The original algorithm makes use of the fact that arithmetic (signed) shift
# keeps the sign bits, so for a n-bit number, (x >> n+1) gets us the "all sign bits".
#
# But this is harder in Python because we don't know the bit size of the number.
# We could simply shift by 65, relying on the fact that the biggest type for other
# languages is sint64. Or we could shift by 1000 to be extra sure.
#
# But instead, we'll do it less elegantly, with an if branch:
# if the number is negative, do bitwise negation (which is the same as "xor all ones").
def sint_to_uint(sint):
res = sint << 1
if sint < 0:
res = ~res
return res
def uint_to_sint(uint):
sign = uint & 1
res = uint >> 1
if sign:
res = ~res
return res
class UVarintType: class UVarintType:
WIRE_TYPE = 0 WIRE_TYPE = 0
class Sint32Type: class SVarintType:
WIRE_TYPE = 0
class Sint64Type:
WIRE_TYPE = 0 WIRE_TYPE = 0
@ -237,10 +271,8 @@ def load_message(reader, msg_type):
if ftype is UVarintType: if ftype is UVarintType:
fvalue = ivalue fvalue = ivalue
elif ftype is Sint32Type: elif ftype is SVarintType:
fvalue = (ivalue >> 1) ^ ((ivalue << 31) & 0xffffffff) fvalue = uint_to_sint(ivalue)
elif ftype is Sint64Type:
fvalue = (ivalue >> 1) ^ ((ivalue << 63) & 0xffffffffffffffff)
elif ftype is BoolType: elif ftype is BoolType:
fvalue = bool(ivalue) fvalue = bool(ivalue)
elif ftype is BytesType: elif ftype is BytesType:
@ -289,11 +321,8 @@ def dump_message(writer, msg):
if ftype is UVarintType: if ftype is UVarintType:
dump_uvarint(writer, svalue) dump_uvarint(writer, svalue)
elif ftype is Sint32Type: elif ftype is SVarintType:
dump_uvarint(writer, ((svalue << 1) & 0xffffffff) ^ (svalue >> 31)) dump_uvarint(writer, sint_to_uint(svalue))
elif ftype is Sint64Type:
dump_uvarint(writer, ((svalue << 1) & 0xffffffffffffffff) ^ (svalue >> 63))
elif ftype is BoolType: elif ftype is BoolType:
dump_uvarint(writer, int(svalue)) dump_uvarint(writer, int(svalue))

View File

@ -0,0 +1,90 @@
from io import BytesIO
import pytest
from trezorlib import protobuf
class PrimitiveMessage(protobuf.MessageType):
FIELDS = {
0: ("uvarint", protobuf.UVarintType, 0),
1: ("svarint", protobuf.SVarintType, 0),
2: ("bool", protobuf.BoolType, 0),
3: ("bytes", protobuf.BytesType, 0),
4: ("unicode", protobuf.UnicodeType, 0),
}
def load_uvarint(buffer):
reader = BytesIO(buffer)
return protobuf.load_uvarint(reader)
def dump_uvarint(value):
writer = BytesIO()
protobuf.dump_uvarint(writer, value)
return writer.getvalue()
def test_dump_uvarint():
assert dump_uvarint(0) == b'\x00'
assert dump_uvarint(1) == b'\x01'
assert dump_uvarint(0xff) == b'\xff\x01'
assert dump_uvarint(123456) == b'\xc0\xc4\x07'
with pytest.raises(ValueError):
dump_uvarint(-1)
def test_load_uvarint():
assert load_uvarint(b'\x00') == 0
assert load_uvarint(b'\x01') == 1
assert load_uvarint(b'\xff\x01') == 0xff
assert load_uvarint(b'\xc0\xc4\x07') == 123456
def test_sint_uint():
"""
Protobuf interleaved signed encoding
https://developers.google.com/protocol-buffers/docs/encoding#structure
LSbit is sign, rest is shifted absolute value.
Or, by example, you count like so: 0, -1, 1, -2, 2, -3 ...
"""
assert protobuf.sint_to_uint(0) == 0
assert protobuf.uint_to_sint(0) == 0
assert protobuf.sint_to_uint(-1) == 1
assert protobuf.sint_to_uint(1) == 2
assert protobuf.uint_to_sint(1) == -1
assert protobuf.uint_to_sint(2) == 1
# roundtrip:
assert protobuf.uint_to_sint(
protobuf.sint_to_uint(1234567891011)
) == 1234567891011
assert protobuf.uint_to_sint(
protobuf.sint_to_uint(- 2 ** 32)
) == - 2 ** 32
def test_simple_message():
msg = PrimitiveMessage(
uvarint=12345678910,
svarint=-12345678910,
bool=True,
bytes=b'\xDE\xAD\xCA\xFE',
unicode="Příliš žluťoučký kůň úpěl ďábelské ódy 😊",
)
buf = BytesIO()
protobuf.dump_message(buf, msg)
buf.seek(0)
retr = protobuf.load_message(buf, PrimitiveMessage)
assert msg == retr
assert retr.uvarint == 12345678910
assert retr.svarint == -12345678910
assert retr.bool is True
assert retr.bytes == b'\xDE\xAD\xCA\xFE'
assert retr.unicode == "Příliš žluťoučký kůň úpěl ďábelské ódy 😊"