mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-21 21:00:58 +00:00
protobuf: properly implement signed types (fixes #249)
This commit is contained in:
parent
b156ec9757
commit
df8c3da1a2
@ -70,8 +70,8 @@ def process_message(descriptor, protobuf_module, msg_id, indexfile, is_upy):
|
||||
field.TYPE_UINT64: 'p.UVarintType',
|
||||
field.TYPE_UINT32: 'p.UVarintType',
|
||||
field.TYPE_ENUM: 'p.UVarintType',
|
||||
field.TYPE_SINT32: 'p.Sint32Type',
|
||||
field.TYPE_SINT64: 'p.Sint64Type',
|
||||
field.TYPE_SINT32: 'p.SVarintType',
|
||||
field.TYPE_SINT64: 'p.SVarintType',
|
||||
field.TYPE_STRING: 'p.UnicodeType',
|
||||
field.TYPE_BOOL: 'p.BoolType',
|
||||
field.TYPE_BYTES: 'p.BytesType'
|
||||
|
@ -6,5 +6,5 @@ from .NEMCosignatoryModification import NEMCosignatoryModification
|
||||
class NEMAggregateModification(p.MessageType):
|
||||
FIELDS = {
|
||||
1: ('modifications', NEMCosignatoryModification, p.FLAG_REPEATED),
|
||||
2: ('relative_change', p.Sint32Type, 0),
|
||||
2: ('relative_change', p.SVarintType, 0),
|
||||
}
|
||||
|
@ -59,6 +59,8 @@ def load_uvarint(reader):
|
||||
|
||||
|
||||
def dump_uvarint(writer, n):
|
||||
if n < 0:
|
||||
raise ValueError("Cannot dump signed value, convert it to unsigned first.")
|
||||
buffer = _UVARINT_BUFFER
|
||||
shifted = True
|
||||
while shifted:
|
||||
@ -68,15 +70,47 @@ def dump_uvarint(writer, n):
|
||||
n = shifted
|
||||
|
||||
|
||||
# protobuf interleaved signed encoding:
|
||||
# https://developers.google.com/protocol-buffers/docs/encoding#structure
|
||||
# the idea is to save the sign in LSbit instead of twos-complement.
|
||||
# so counting up, you go: 0, -1, 1, -2, 2, ... (as the first bit changes, sign flips)
|
||||
#
|
||||
# To achieve this with a twos-complement number:
|
||||
# 1. shift left by 1, leaving LSbit free
|
||||
# 2. XOR with "all sign bits" - 0s for positive, 1s for negative
|
||||
# This keeps positive number the same, and converts negative from twos-complement
|
||||
# to the appropriate value, while setting the sign bit. Cute and efficient.
|
||||
#
|
||||
# The original algorithm makes use of the fact that arithmetic (signed) shift
|
||||
# keeps the sign bits, so for a n-bit number, (x >> n+1) gets us the "all sign bits".
|
||||
#
|
||||
# But this is harder in Python because we don't know the bit size of the number.
|
||||
# We could simply shift by 65, relying on the fact that the biggest type for other
|
||||
# languages is sint64. Or we could shift by 1000 to be extra sure.
|
||||
#
|
||||
# But instead, we'll do it less elegantly, with an if branch:
|
||||
# if the number is negative, do bitwise negation (which is the same as "xor all ones").
|
||||
|
||||
def sint_to_uint(sint):
|
||||
res = sint << 1
|
||||
if sint < 0:
|
||||
res = ~res
|
||||
return res
|
||||
|
||||
|
||||
def uint_to_sint(uint):
|
||||
sign = uint & 1
|
||||
res = uint >> 1
|
||||
if sign:
|
||||
res = ~res
|
||||
return res
|
||||
|
||||
|
||||
class UVarintType:
|
||||
WIRE_TYPE = 0
|
||||
|
||||
|
||||
class Sint32Type:
|
||||
WIRE_TYPE = 0
|
||||
|
||||
|
||||
class Sint64Type:
|
||||
class SVarintType:
|
||||
WIRE_TYPE = 0
|
||||
|
||||
|
||||
@ -237,10 +271,8 @@ def load_message(reader, msg_type):
|
||||
|
||||
if ftype is UVarintType:
|
||||
fvalue = ivalue
|
||||
elif ftype is Sint32Type:
|
||||
fvalue = (ivalue >> 1) ^ ((ivalue << 31) & 0xffffffff)
|
||||
elif ftype is Sint64Type:
|
||||
fvalue = (ivalue >> 1) ^ ((ivalue << 63) & 0xffffffffffffffff)
|
||||
elif ftype is SVarintType:
|
||||
fvalue = uint_to_sint(ivalue)
|
||||
elif ftype is BoolType:
|
||||
fvalue = bool(ivalue)
|
||||
elif ftype is BytesType:
|
||||
@ -289,11 +321,8 @@ def dump_message(writer, msg):
|
||||
if ftype is UVarintType:
|
||||
dump_uvarint(writer, svalue)
|
||||
|
||||
elif ftype is Sint32Type:
|
||||
dump_uvarint(writer, ((svalue << 1) & 0xffffffff) ^ (svalue >> 31))
|
||||
|
||||
elif ftype is Sint64Type:
|
||||
dump_uvarint(writer, ((svalue << 1) & 0xffffffffffffffff) ^ (svalue >> 63))
|
||||
elif ftype is SVarintType:
|
||||
dump_uvarint(writer, sint_to_uint(svalue))
|
||||
|
||||
elif ftype is BoolType:
|
||||
dump_uvarint(writer, int(svalue))
|
||||
|
90
trezorlib/tests/unit_tests/test_protobuf.py
Normal file
90
trezorlib/tests/unit_tests/test_protobuf.py
Normal file
@ -0,0 +1,90 @@
|
||||
from io import BytesIO
|
||||
import pytest
|
||||
|
||||
from trezorlib import protobuf
|
||||
|
||||
|
||||
class PrimitiveMessage(protobuf.MessageType):
|
||||
FIELDS = {
|
||||
0: ("uvarint", protobuf.UVarintType, 0),
|
||||
1: ("svarint", protobuf.SVarintType, 0),
|
||||
2: ("bool", protobuf.BoolType, 0),
|
||||
3: ("bytes", protobuf.BytesType, 0),
|
||||
4: ("unicode", protobuf.UnicodeType, 0),
|
||||
}
|
||||
|
||||
|
||||
def load_uvarint(buffer):
|
||||
reader = BytesIO(buffer)
|
||||
return protobuf.load_uvarint(reader)
|
||||
|
||||
|
||||
def dump_uvarint(value):
|
||||
writer = BytesIO()
|
||||
protobuf.dump_uvarint(writer, value)
|
||||
return writer.getvalue()
|
||||
|
||||
|
||||
def test_dump_uvarint():
|
||||
assert dump_uvarint(0) == b'\x00'
|
||||
assert dump_uvarint(1) == b'\x01'
|
||||
assert dump_uvarint(0xff) == b'\xff\x01'
|
||||
assert dump_uvarint(123456) == b'\xc0\xc4\x07'
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
dump_uvarint(-1)
|
||||
|
||||
|
||||
def test_load_uvarint():
|
||||
assert load_uvarint(b'\x00') == 0
|
||||
assert load_uvarint(b'\x01') == 1
|
||||
assert load_uvarint(b'\xff\x01') == 0xff
|
||||
assert load_uvarint(b'\xc0\xc4\x07') == 123456
|
||||
|
||||
|
||||
def test_sint_uint():
|
||||
"""
|
||||
Protobuf interleaved signed encoding
|
||||
https://developers.google.com/protocol-buffers/docs/encoding#structure
|
||||
LSbit is sign, rest is shifted absolute value.
|
||||
Or, by example, you count like so: 0, -1, 1, -2, 2, -3 ...
|
||||
"""
|
||||
assert protobuf.sint_to_uint(0) == 0
|
||||
assert protobuf.uint_to_sint(0) == 0
|
||||
|
||||
assert protobuf.sint_to_uint(-1) == 1
|
||||
assert protobuf.sint_to_uint(1) == 2
|
||||
|
||||
assert protobuf.uint_to_sint(1) == -1
|
||||
assert protobuf.uint_to_sint(2) == 1
|
||||
|
||||
# roundtrip:
|
||||
assert protobuf.uint_to_sint(
|
||||
protobuf.sint_to_uint(1234567891011)
|
||||
) == 1234567891011
|
||||
assert protobuf.uint_to_sint(
|
||||
protobuf.sint_to_uint(- 2 ** 32)
|
||||
) == - 2 ** 32
|
||||
|
||||
|
||||
def test_simple_message():
|
||||
msg = PrimitiveMessage(
|
||||
uvarint=12345678910,
|
||||
svarint=-12345678910,
|
||||
bool=True,
|
||||
bytes=b'\xDE\xAD\xCA\xFE',
|
||||
unicode="Příliš žluťoučký kůň úpěl ďábelské ódy 😊",
|
||||
)
|
||||
|
||||
buf = BytesIO()
|
||||
|
||||
protobuf.dump_message(buf, msg)
|
||||
buf.seek(0)
|
||||
retr = protobuf.load_message(buf, PrimitiveMessage)
|
||||
|
||||
assert msg == retr
|
||||
assert retr.uvarint == 12345678910
|
||||
assert retr.svarint == -12345678910
|
||||
assert retr.bool is True
|
||||
assert retr.bytes == b'\xDE\xAD\xCA\xFE'
|
||||
assert retr.unicode == "Příliš žluťoučký kůň úpěl ďábelské ódy 😊"
|
Loading…
Reference in New Issue
Block a user