mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-22 14:28:07 +00:00
protobuf: properly implement signed types (fixes #249)
This commit is contained in:
parent
b156ec9757
commit
df8c3da1a2
@ -70,8 +70,8 @@ def process_message(descriptor, protobuf_module, msg_id, indexfile, is_upy):
|
|||||||
field.TYPE_UINT64: 'p.UVarintType',
|
field.TYPE_UINT64: 'p.UVarintType',
|
||||||
field.TYPE_UINT32: 'p.UVarintType',
|
field.TYPE_UINT32: 'p.UVarintType',
|
||||||
field.TYPE_ENUM: 'p.UVarintType',
|
field.TYPE_ENUM: 'p.UVarintType',
|
||||||
field.TYPE_SINT32: 'p.Sint32Type',
|
field.TYPE_SINT32: 'p.SVarintType',
|
||||||
field.TYPE_SINT64: 'p.Sint64Type',
|
field.TYPE_SINT64: 'p.SVarintType',
|
||||||
field.TYPE_STRING: 'p.UnicodeType',
|
field.TYPE_STRING: 'p.UnicodeType',
|
||||||
field.TYPE_BOOL: 'p.BoolType',
|
field.TYPE_BOOL: 'p.BoolType',
|
||||||
field.TYPE_BYTES: 'p.BytesType'
|
field.TYPE_BYTES: 'p.BytesType'
|
||||||
|
@ -6,5 +6,5 @@ from .NEMCosignatoryModification import NEMCosignatoryModification
|
|||||||
class NEMAggregateModification(p.MessageType):
|
class NEMAggregateModification(p.MessageType):
|
||||||
FIELDS = {
|
FIELDS = {
|
||||||
1: ('modifications', NEMCosignatoryModification, p.FLAG_REPEATED),
|
1: ('modifications', NEMCosignatoryModification, p.FLAG_REPEATED),
|
||||||
2: ('relative_change', p.Sint32Type, 0),
|
2: ('relative_change', p.SVarintType, 0),
|
||||||
}
|
}
|
||||||
|
@ -59,6 +59,8 @@ def load_uvarint(reader):
|
|||||||
|
|
||||||
|
|
||||||
def dump_uvarint(writer, n):
|
def dump_uvarint(writer, n):
|
||||||
|
if n < 0:
|
||||||
|
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
|
||||||
buffer = _UVARINT_BUFFER
|
buffer = _UVARINT_BUFFER
|
||||||
shifted = True
|
shifted = True
|
||||||
while shifted:
|
while shifted:
|
||||||
@ -68,15 +70,47 @@ def dump_uvarint(writer, n):
|
|||||||
n = shifted
|
n = shifted
|
||||||
|
|
||||||
|
|
||||||
|
# protobuf interleaved signed encoding:
|
||||||
|
# https://developers.google.com/protocol-buffers/docs/encoding#structure
|
||||||
|
# the idea is to save the sign in LSbit instead of twos-complement.
|
||||||
|
# so counting up, you go: 0, -1, 1, -2, 2, ... (as the first bit changes, sign flips)
|
||||||
|
#
|
||||||
|
# To achieve this with a twos-complement number:
|
||||||
|
# 1. shift left by 1, leaving LSbit free
|
||||||
|
# 2. XOR with "all sign bits" - 0s for positive, 1s for negative
|
||||||
|
# This keeps positive number the same, and converts negative from twos-complement
|
||||||
|
# to the appropriate value, while setting the sign bit. Cute and efficient.
|
||||||
|
#
|
||||||
|
# The original algorithm makes use of the fact that arithmetic (signed) shift
|
||||||
|
# keeps the sign bits, so for a n-bit number, (x >> n+1) gets us the "all sign bits".
|
||||||
|
#
|
||||||
|
# But this is harder in Python because we don't know the bit size of the number.
|
||||||
|
# We could simply shift by 65, relying on the fact that the biggest type for other
|
||||||
|
# languages is sint64. Or we could shift by 1000 to be extra sure.
|
||||||
|
#
|
||||||
|
# But instead, we'll do it less elegantly, with an if branch:
|
||||||
|
# if the number is negative, do bitwise negation (which is the same as "xor all ones").
|
||||||
|
|
||||||
|
def sint_to_uint(sint):
|
||||||
|
res = sint << 1
|
||||||
|
if sint < 0:
|
||||||
|
res = ~res
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def uint_to_sint(uint):
|
||||||
|
sign = uint & 1
|
||||||
|
res = uint >> 1
|
||||||
|
if sign:
|
||||||
|
res = ~res
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
class UVarintType:
|
class UVarintType:
|
||||||
WIRE_TYPE = 0
|
WIRE_TYPE = 0
|
||||||
|
|
||||||
|
|
||||||
class Sint32Type:
|
class SVarintType:
|
||||||
WIRE_TYPE = 0
|
|
||||||
|
|
||||||
|
|
||||||
class Sint64Type:
|
|
||||||
WIRE_TYPE = 0
|
WIRE_TYPE = 0
|
||||||
|
|
||||||
|
|
||||||
@ -237,10 +271,8 @@ def load_message(reader, msg_type):
|
|||||||
|
|
||||||
if ftype is UVarintType:
|
if ftype is UVarintType:
|
||||||
fvalue = ivalue
|
fvalue = ivalue
|
||||||
elif ftype is Sint32Type:
|
elif ftype is SVarintType:
|
||||||
fvalue = (ivalue >> 1) ^ ((ivalue << 31) & 0xffffffff)
|
fvalue = uint_to_sint(ivalue)
|
||||||
elif ftype is Sint64Type:
|
|
||||||
fvalue = (ivalue >> 1) ^ ((ivalue << 63) & 0xffffffffffffffff)
|
|
||||||
elif ftype is BoolType:
|
elif ftype is BoolType:
|
||||||
fvalue = bool(ivalue)
|
fvalue = bool(ivalue)
|
||||||
elif ftype is BytesType:
|
elif ftype is BytesType:
|
||||||
@ -289,11 +321,8 @@ def dump_message(writer, msg):
|
|||||||
if ftype is UVarintType:
|
if ftype is UVarintType:
|
||||||
dump_uvarint(writer, svalue)
|
dump_uvarint(writer, svalue)
|
||||||
|
|
||||||
elif ftype is Sint32Type:
|
elif ftype is SVarintType:
|
||||||
dump_uvarint(writer, ((svalue << 1) & 0xffffffff) ^ (svalue >> 31))
|
dump_uvarint(writer, sint_to_uint(svalue))
|
||||||
|
|
||||||
elif ftype is Sint64Type:
|
|
||||||
dump_uvarint(writer, ((svalue << 1) & 0xffffffffffffffff) ^ (svalue >> 63))
|
|
||||||
|
|
||||||
elif ftype is BoolType:
|
elif ftype is BoolType:
|
||||||
dump_uvarint(writer, int(svalue))
|
dump_uvarint(writer, int(svalue))
|
||||||
|
90
trezorlib/tests/unit_tests/test_protobuf.py
Normal file
90
trezorlib/tests/unit_tests/test_protobuf.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
from io import BytesIO
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from trezorlib import protobuf
|
||||||
|
|
||||||
|
|
||||||
|
class PrimitiveMessage(protobuf.MessageType):
|
||||||
|
FIELDS = {
|
||||||
|
0: ("uvarint", protobuf.UVarintType, 0),
|
||||||
|
1: ("svarint", protobuf.SVarintType, 0),
|
||||||
|
2: ("bool", protobuf.BoolType, 0),
|
||||||
|
3: ("bytes", protobuf.BytesType, 0),
|
||||||
|
4: ("unicode", protobuf.UnicodeType, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_uvarint(buffer):
|
||||||
|
reader = BytesIO(buffer)
|
||||||
|
return protobuf.load_uvarint(reader)
|
||||||
|
|
||||||
|
|
||||||
|
def dump_uvarint(value):
|
||||||
|
writer = BytesIO()
|
||||||
|
protobuf.dump_uvarint(writer, value)
|
||||||
|
return writer.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def test_dump_uvarint():
|
||||||
|
assert dump_uvarint(0) == b'\x00'
|
||||||
|
assert dump_uvarint(1) == b'\x01'
|
||||||
|
assert dump_uvarint(0xff) == b'\xff\x01'
|
||||||
|
assert dump_uvarint(123456) == b'\xc0\xc4\x07'
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
dump_uvarint(-1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_uvarint():
|
||||||
|
assert load_uvarint(b'\x00') == 0
|
||||||
|
assert load_uvarint(b'\x01') == 1
|
||||||
|
assert load_uvarint(b'\xff\x01') == 0xff
|
||||||
|
assert load_uvarint(b'\xc0\xc4\x07') == 123456
|
||||||
|
|
||||||
|
|
||||||
|
def test_sint_uint():
|
||||||
|
"""
|
||||||
|
Protobuf interleaved signed encoding
|
||||||
|
https://developers.google.com/protocol-buffers/docs/encoding#structure
|
||||||
|
LSbit is sign, rest is shifted absolute value.
|
||||||
|
Or, by example, you count like so: 0, -1, 1, -2, 2, -3 ...
|
||||||
|
"""
|
||||||
|
assert protobuf.sint_to_uint(0) == 0
|
||||||
|
assert protobuf.uint_to_sint(0) == 0
|
||||||
|
|
||||||
|
assert protobuf.sint_to_uint(-1) == 1
|
||||||
|
assert protobuf.sint_to_uint(1) == 2
|
||||||
|
|
||||||
|
assert protobuf.uint_to_sint(1) == -1
|
||||||
|
assert protobuf.uint_to_sint(2) == 1
|
||||||
|
|
||||||
|
# roundtrip:
|
||||||
|
assert protobuf.uint_to_sint(
|
||||||
|
protobuf.sint_to_uint(1234567891011)
|
||||||
|
) == 1234567891011
|
||||||
|
assert protobuf.uint_to_sint(
|
||||||
|
protobuf.sint_to_uint(- 2 ** 32)
|
||||||
|
) == - 2 ** 32
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_message():
|
||||||
|
msg = PrimitiveMessage(
|
||||||
|
uvarint=12345678910,
|
||||||
|
svarint=-12345678910,
|
||||||
|
bool=True,
|
||||||
|
bytes=b'\xDE\xAD\xCA\xFE',
|
||||||
|
unicode="Příliš žluťoučký kůň úpěl ďábelské ódy 😊",
|
||||||
|
)
|
||||||
|
|
||||||
|
buf = BytesIO()
|
||||||
|
|
||||||
|
protobuf.dump_message(buf, msg)
|
||||||
|
buf.seek(0)
|
||||||
|
retr = protobuf.load_message(buf, PrimitiveMessage)
|
||||||
|
|
||||||
|
assert msg == retr
|
||||||
|
assert retr.uvarint == 12345678910
|
||||||
|
assert retr.svarint == -12345678910
|
||||||
|
assert retr.bool is True
|
||||||
|
assert retr.bytes == b'\xDE\xAD\xCA\xFE'
|
||||||
|
assert retr.unicode == "Příliš žluťoučký kůň úpěl ďábelské ódy 😊"
|
Loading…
Reference in New Issue
Block a user