diff --git a/core/src/apps/common/cbor.py b/core/src/apps/common/cbor.py index 3df990e0fb..b2bdc13330 100644 --- a/core/src/apps/common/cbor.py +++ b/core/src/apps/common/cbor.py @@ -10,10 +10,16 @@ from trezor import log, utils from . import readers if False: - from typing import Any, Union, Iterator, Tuple + from typing import Any, Generic, Iterator, Tuple, TypeVar, Union + K = TypeVar("K") + V = TypeVar("V") Value = Any CborSequence = Union[list[Value], Tuple[Value, ...]] +else: + # mypy cheat: Generic[K, V] will be `object` which is a valid parent type + Generic = {(0, 0): object} # type: ignore + K = V = 0 # type: ignore _CBOR_TYPE_MASK = const(0xE0) _CBOR_INFO_BITS = const(0x1F) @@ -82,6 +88,11 @@ def _cbor_encode(value: Value) -> Iterator[bytes]: for k, v in sorted_map: yield k yield from _cbor_encode(v) + elif isinstance(value, OrderedMap): + yield _header(_CBOR_MAP, len(value)) + for k, v in value: + yield encode(k) + yield from _cbor_encode(v) elif isinstance(value, Tagged): yield _header(_CBOR_TAG, value.tag) yield from _cbor_encode(value.value) @@ -226,6 +237,26 @@ class IndefiniteLengthArray: return False +class OrderedMap(Generic[K, V]): + """ + Items of an OrderedMap are included in CBOR as they are added without sorting them in any way. We also allow + duplicates since CBOR is also somewhat lenient in not allowing them. It is thus up to the client to make sure no + duplicates are inserted if it's desired. + """ + + def __init__(self) -> None: + self._internal_list: list[tuple[K, V]] = [] + + def __setitem__(self, key: K, value: V) -> None: + self._internal_list.append((key, value)) + + def __iter__(self) -> Iterator: + yield from self._internal_list + + def __len__(self) -> int: + return len(self._internal_list) + + def encode(value: Value) -> bytes: return b"".join(_cbor_encode(value)) diff --git a/core/tests/test_apps.common.cbor.py b/core/tests/test_apps.common.cbor.py index e029cefcbc..723491b24e 100644 --- a/core/tests/test_apps.common.cbor.py +++ b/core/tests/test_apps.common.cbor.py @@ -3,8 +3,9 @@ import math from common import * from apps.common.cbor import ( - Tagged, IndefiniteLengthArray, + OrderedMap, + Tagged, decode, encode, encode_chunked, @@ -59,6 +60,7 @@ class TestCardanoCbor(unittest.TestCase): # maps ({}, 'a0'), ({1: 2, 3: 4}, 'a201020304'), + ({3: 4, 1: 2}, 'a201020304'), # indefinite (IndefiniteLengthArray([]), '9fff'), @@ -94,6 +96,25 @@ class TestCardanoCbor(unittest.TestCase): self.assertEqual(encode(value_tuple), encoded) self.assertEqual(decode(encoded), val) + def test_cbor_ordered_map(self): + """ + OrderedMaps should be encoded as maps without any ordering and decoded back as dicts. + """ + test_vectors = [ + ({}, 'a0'), + ([[1, 2], [3, 4]], 'a201020304'), + ([[3, 4], [1, 2]], 'a203040102'), + ] + + for val, encoded_hex in test_vectors: + ordered_map = OrderedMap() + for key, value in val: + ordered_map[key] = value + + encoded = unhexlify(encoded_hex) + self.assertEqual(encode(ordered_map), encoded) + self.assertEqual(decode(encoded), {k: v for k, v in val}) + def test_encode_streamed(self): large_dict = {i: i for i in range(100)} encoded = encode(large_dict)