core/cbor: move CBOR from cardano to common

- add support for text strings and boolean values
- add support negative integers and decoding maps
- fix decoding of short strings
- encode maps canonically
- add unit tests for decoding
- sort maps lexicographically by encoded key
pull/213/head
Andrew Kozlik 5 years ago committed by Pavol Rusnak
parent 751715dc15
commit 31506d81e9
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D

@ -1,8 +1,7 @@
from trezor import log
from trezor.crypto import base58, crc, hashlib
from apps.cardano import cbor
from apps.common import HARDENED
from apps.common import HARDENED, cbor
from apps.common.seed import remove_ed25519_prefix

@ -7,13 +7,14 @@ from trezor.messages.CardanoSignedTx import CardanoSignedTx
from trezor.messages.CardanoTxRequest import CardanoTxRequest
from trezor.messages.MessageType import CardanoTxAck
from apps.cardano import CURVE, cbor, seed
from apps.cardano import CURVE, seed
from apps.cardano.address import (
derive_address_and_node,
is_safe_output_address,
validate_full_path,
)
from apps.cardano.layout import confirm_sending, confirm_transaction, progress
from apps.common import cbor
from apps.common.paths import validate_path
from apps.common.seed import remove_ed25519_prefix
from apps.homescreen.homescreen import display_homescreen

@ -12,7 +12,9 @@ _CBOR_TYPE_MASK = const(0xE0)
_CBOR_INFO_BITS = const(0x1F)
_CBOR_UNSIGNED_INT = const(0b000 << 5)
_CBOR_NEGATIVE_INT = const(0b001 << 5)
_CBOR_BYTE_STRING = const(0b010 << 5)
_CBOR_TEXT_STRING = const(0b011 << 5)
_CBOR_ARRAY = const(0b100 << 5)
_CBOR_MAP = const(0b101 << 5)
_CBOR_TAG = const(0b110 << 5)
@ -24,6 +26,8 @@ _CBOR_UINT32_FOLLOWS = const(0x1A)
_CBOR_UINT64_FOLLOWS = const(0x1B)
_CBOR_VAR_FOLLOWS = const(0x1F)
_CBOR_FALSE = const(0x14)
_CBOR_TRUE = const(0x15)
_CBOR_BREAK = const(0x1F)
_CBOR_RAW_TAG = const(0x18)
@ -45,13 +49,20 @@ def _header(typ, l: int):
def _cbor_encode(value):
if isinstance(value, int):
yield _header(_CBOR_UNSIGNED_INT, value)
if value >= 0:
yield _header(_CBOR_UNSIGNED_INT, value)
else:
yield _header(_CBOR_NEGATIVE_INT, -1 - value)
elif isinstance(value, bytes):
yield _header(_CBOR_BYTE_STRING, len(value))
yield value
elif isinstance(value, bytearray):
yield _header(_CBOR_BYTE_STRING, len(value))
yield bytes(value)
elif isinstance(value, str):
encoded_value = value.encode()
yield _header(_CBOR_TEXT_STRING, len(encoded_value))
yield encoded_value
elif isinstance(value, list):
# definite-length valued list
yield _header(_CBOR_ARRAY, len(value))
@ -59,8 +70,9 @@ def _cbor_encode(value):
yield from _cbor_encode(x)
elif isinstance(value, dict):
yield _header(_CBOR_MAP, len(value))
for k, v in value.items():
yield from _cbor_encode(k)
sorted_map = sorted((encode(k), v) for k, v in value.items())
for k, v in sorted_map:
yield k
yield from _cbor_encode(v)
elif isinstance(value, Tagged):
yield _header(_CBOR_TAG, value.tag)
@ -70,6 +82,11 @@ def _cbor_encode(value):
for x in value.array:
yield from _cbor_encode(x)
yield bytes([_CBOR_PRIMITIVE + 31])
elif isinstance(value, bool):
if value:
yield bytes([_CBOR_PRIMITIVE + _CBOR_TRUE])
else:
yield bytes([_CBOR_PRIMITIVE + _CBOR_FALSE])
elif isinstance(value, Raw):
yield value.value
else:
@ -79,7 +96,9 @@ def _cbor_encode(value):
def _read_length(cbor, aux):
if aux == _CBOR_UINT8_FOLLOWS:
if aux < _CBOR_UINT8_FOLLOWS:
return (aux, cbor)
elif aux == _CBOR_UINT8_FOLLOWS:
return (cbor[0], cbor[1:])
elif aux == _CBOR_UINT16_FOLLOWS:
res = cbor[1]
@ -111,14 +130,16 @@ def _cbor_decode(cbor):
fb_type = fb & _CBOR_TYPE_MASK
fb_aux = fb & _CBOR_INFO_BITS
if fb_type == _CBOR_UNSIGNED_INT:
if fb_aux < 0x18:
return (fb_aux, cbor[1:])
else:
val, data = _read_length(cbor[1:], fb_aux)
return (int(val), data)
return _read_length(cbor[1:], fb_aux)
elif fb_type == _CBOR_NEGATIVE_INT:
val, data = _read_length(cbor[1:], fb_aux)
return (-1 - val, data)
elif fb_type == _CBOR_BYTE_STRING:
ln, data = _read_length(cbor[1:], fb_aux)
return (data[0:ln], data[ln:])
elif fb_type == _CBOR_TEXT_STRING:
ln, data = _read_length(cbor[1:], fb_aux)
return (data[0:ln].decode(), data[ln:])
elif fb_type == _CBOR_ARRAY:
if fb_aux == _CBOR_VAR_FOLLOWS:
res = []
@ -130,25 +151,49 @@ def _cbor_decode(cbor):
res.append(item)
return (res, data)
else:
if fb_aux < _CBOR_UINT8_FOLLOWS:
ln = fb_aux
data = cbor[1:]
else:
ln, data = _read_length(cbor[1:], fb_aux)
ln, data = _read_length(cbor[1:], fb_aux)
res = []
for i in range(ln):
item, data = _cbor_decode(data)
res.append(item)
return (res, data)
elif fb_type == _CBOR_MAP:
return ({}, cbor[1:])
res = {}
if fb_aux == _CBOR_VAR_FOLLOWS:
data = cbor[1:]
while True:
key, data = _cbor_decode(data)
if key in res:
raise ValueError
if key == _CBOR_PRIMITIVE + _CBOR_BREAK:
break
value, data = _cbor_decode(data)
res[key] = value
else:
ln, data = _read_length(cbor[1:], fb_aux)
for i in range(ln):
key, data = _cbor_decode(data)
if key in res:
raise ValueError
value, data = _cbor_decode(data)
res[key] = value
return res, data
elif fb_type == _CBOR_TAG:
if cbor[1] == _CBOR_RAW_TAG: # only tag 24 (0x18) is supported
return _cbor_decode(cbor[2:])
val, data = _read_length(cbor[1:], fb_aux)
item, data = _cbor_decode(data)
if val == _CBOR_RAW_TAG: # only tag 24 (0x18) is supported
return item, data
else:
return Tagged(val, item), data
elif fb_type == _CBOR_PRIMITIVE:
if fb_aux == _CBOR_FALSE:
return (False, cbor[1:])
elif fb_aux == _CBOR_TRUE:
return (True, cbor[1:])
elif fb_aux == _CBOR_BREAK:
return (cbor[0], cbor[1:])
else:
raise NotImplementedError()
elif fb_type == _CBOR_PRIMITIVE: # only break code is supported
return (cbor[0], cbor[1:])
else:
if __debug__:
log.debug(__name__, "not implemented (decode): %s", cbor[0])
@ -160,6 +205,9 @@ class Tagged:
self.tag = tag
self.value = value
def __eq__(self, other):
return self.tag == other.tag and self.value == other.value
class Raw:
def __init__(self, value):
@ -171,6 +219,12 @@ class IndefiniteLengthArray:
ensure(isinstance(array, list))
self.array = array
def __eq__(self, other):
if isinstance(other, IndefiniteLengthArray):
return self.array == other.array
else:
return self.array == other
def encode(value):
return b"".join(_cbor_encode(value))

@ -1,16 +1,17 @@
from common import *
from apps.cardano.cbor import (
from apps.common.cbor import (
Tagged,
IndefiniteLengthArray,
encode
decode,
encode,
)
from ubinascii import unhexlify
class TestCardanoCbor(unittest.TestCase):
def test_cbor_encoding(self):
test_vectors = [
# integers
# unsigned integers
(0, '00'),
(1, '01'),
(10, '0a'),
@ -22,10 +23,26 @@ class TestCardanoCbor(unittest.TestCase):
(1000000, '1a000f4240'),
(1000000000000, '1b000000e8d4a51000'),
# negative integers
(-1, '20'),
(-10, '29'),
(-24, '37'),
(-25, '3818'),
(-26, '3819'),
(-100, '3863'),
(-1000, '3903E7'),
(-1000000, '3A000F423F'),
(-1000000000000, '3B000000E8D4A50FFF'),
# binary strings
(b'', '40'),
(unhexlify('01020304'), '4401020304'),
# text strings
('', '60'),
('Fun', '6346756e'),
(u'P\u0159\xed\u0161ern\u011b \u017elu\u0165ou\u010dk\xfd k\u016f\u0148 \xfap\u011bl \u010f\xe1belsk\xe9 \xf3dy z\xe1ke\u0159n\xfd u\u010de\u0148 b\u011b\u017e\xed pod\xe9l z\xf3ny \xfal\u016f', '786550c599c3adc5a165726ec49b20c5be6c75c5a56f75c48d6bc3bd206bc5afc58820c3ba70c49b6c20c48fc3a162656c736bc3a920c3b36479207ac3a16b65c5996ec3bd2075c48d65c5882062c49bc5bec3ad20706f64c3a96c207ac3b36e7920c3ba6cc5af'),
# tags
(Tagged(1, 1363896240), 'c11a514b67b0'),
(Tagged(23, unhexlify('01020304')), 'd74401020304'),
@ -38,19 +55,21 @@ class TestCardanoCbor(unittest.TestCase):
# maps
({}, 'a0'),
# Note: normal python dict doesn't have a fixed item ordering
({1: 2, 3: 4}, 'a203040102'),
({1: 2, 3: 4}, 'a201020304'),
# indefinite
(IndefiniteLengthArray([]), '9fff'),
(IndefiniteLengthArray([1, [2, 3], [4, 5]]), '9f01820203820405ff'),
(IndefiniteLengthArray([1, [2, 3], IndefiniteLengthArray([4, 5])]),
'9f018202039f0405ffff'),
# boolean
(True, 'f5'),
(False, 'f4'),
]
for val, expected in test_vectors:
encoded = encode(val)
self.assertEqual(unhexlify(expected), encoded)
for val, encoded in test_vectors:
self.assertEqual(unhexlify(encoded), encode(val))
self.assertEqual(val, decode(unhexlify(encoded)))
if __name__ == '__main__':
unittest.main()
Loading…
Cancel
Save