1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-07-15 19:18:11 +00:00
trezor-firmware/core/src/apps/common/cbor.py
2021-02-10 23:20:56 +01:00

239 lines
7.0 KiB
Python

"""
Minimalistic CBOR implementation that supports only what we need in Cardano.
"""
import ustruct as struct
from micropython import const
from trezor import log, utils
from . import readers
if False:
from typing import Any, Iterable, List, Tuple, Union
Value = Any
CborSequence = Union[List[Value], Tuple[Value, ...]]
# Masks that split the first byte of a data item: the top three bits select
# the type, the low five bits carry the additional information.
_CBOR_TYPE_MASK = const(0xE0)
_CBOR_INFO_BITS = const(0x1F)
# Type codes, already shifted into the top three bits of the first byte.
_CBOR_UNSIGNED_INT = const(0b000 << 5)
_CBOR_NEGATIVE_INT = const(0b001 << 5)
_CBOR_BYTE_STRING = const(0b010 << 5)
_CBOR_TEXT_STRING = const(0b011 << 5)
_CBOR_ARRAY = const(0b100 << 5)
_CBOR_MAP = const(0b101 << 5)
_CBOR_TAG = const(0b110 << 5)
_CBOR_PRIMITIVE = const(0b111 << 5)
# Additional-information values telling the decoder where the argument is:
# an unsigned integer of the named width follows the first byte.
_CBOR_UINT8_FOLLOWS = const(0x18)
_CBOR_UINT16_FOLLOWS = const(0x19)
_CBOR_UINT32_FOLLOWS = const(0x1A)
_CBOR_UINT64_FOLLOWS = const(0x1B)
# Marks an array or map whose items run until a break is read.
_CBOR_VAR_FOLLOWS = const(0x1F)
# Additional-information values of the primitives handled by this module.
_CBOR_FALSE = const(0x14)
_CBOR_TRUE = const(0x15)
_CBOR_NULL = const(0x16)
_CBOR_BREAK = const(0x1F)
# Tag number whose content the decoder returns without a Tagged wrapper.
_CBOR_RAW_TAG = const(0x18)
def _header(typ: int, l: int) -> bytes:
if l < 24:
return struct.pack(">B", typ + l)
elif l < 2 ** 8:
return struct.pack(">BB", typ + 24, l)
elif l < 2 ** 16:
return struct.pack(">BH", typ + 25, l)
elif l < 2 ** 32:
return struct.pack(">BI", typ + 26, l)
elif l < 2 ** 64:
return struct.pack(">BQ", typ + 27, l)
else:
raise NotImplementedError("Length %d not suppported" % l)
def _cbor_encode(value: Value) -> Iterable[bytes]:
    """Yield the CBOR encoding of `value` as a series of byte chunks.

    Handled inputs: bool, int, bytes, bytearray, str, list/tuple (encoded as
    a definite-length array), dict, Tagged, IndefiniteLengthArray, Raw and
    None.  Any other type raises NotImplementedError.
    """
    # bool has to be tested before int: where bool is a subclass of int
    # (as in CPython) True/False would otherwise be emitted as integers 1/0.
    if isinstance(value, bool):
        if value:
            yield bytes([_CBOR_PRIMITIVE + _CBOR_TRUE])
        else:
            yield bytes([_CBOR_PRIMITIVE + _CBOR_FALSE])
    elif isinstance(value, int):
        if value >= 0:
            yield _header(_CBOR_UNSIGNED_INT, value)
        else:
            # negative integers store the magnitude of (-1 - value)
            yield _header(_CBOR_NEGATIVE_INT, -1 - value)
    elif isinstance(value, bytes):
        yield _header(_CBOR_BYTE_STRING, len(value))
        yield value
    elif isinstance(value, bytearray):
        yield _header(_CBOR_BYTE_STRING, len(value))
        yield bytes(value)
    elif isinstance(value, str):
        # the header carries the length in bytes, not in characters
        encoded_value = value.encode()
        yield _header(_CBOR_TEXT_STRING, len(encoded_value))
        yield encoded_value
    elif isinstance(value, list) or isinstance(value, tuple):
        # definite-length valued list
        yield _header(_CBOR_ARRAY, len(value))
        for x in value:
            yield from _cbor_encode(x)
    elif isinstance(value, dict):
        yield _header(_CBOR_MAP, len(value))
        # entries are emitted in ascending order of their encoded keys
        sorted_map = sorted((encode(k), v) for k, v in value.items())
        for k, v in sorted_map:
            yield k
            yield from _cbor_encode(v)
    elif isinstance(value, Tagged):
        yield _header(_CBOR_TAG, value.tag)
        yield from _cbor_encode(value.value)
    elif isinstance(value, IndefiniteLengthArray):
        # opening marker, the items, then the closing break byte
        yield bytes([_CBOR_ARRAY + 31])
        for x in value.array:
            yield from _cbor_encode(x)
        yield bytes([_CBOR_PRIMITIVE + 31])
    elif isinstance(value, Raw):
        # Raw content is passed through untouched
        yield value.value
    elif value is None:
        yield bytes([_CBOR_PRIMITIVE + _CBOR_NULL])
    else:
        if __debug__:
            log.debug(__name__, "not implemented (encode): %s", type(value))
        raise NotImplementedError
def _read_length(r: utils.BufferReader, aux: int) -> int:
    """Return the argument selected by the additional information `aux`.

    Values below _CBOR_UINT8_FOLLOWS are the argument itself.  The four
    *_FOLLOWS values make the function read the argument from `r`: a single
    byte via `r.get()`, or a wider integer via the matching
    `readers.read_uint*_be` helper.

    Raises NotImplementedError for any other `aux`.
    """
    if aux < _CBOR_UINT8_FOLLOWS:
        return aux
    if aux == _CBOR_UINT8_FOLLOWS:
        return r.get()
    if aux == _CBOR_UINT16_FOLLOWS:
        return readers.read_uint16_be(r)
    if aux == _CBOR_UINT32_FOLLOWS:
        return readers.read_uint32_be(r)
    if aux == _CBOR_UINT64_FOLLOWS:
        return readers.read_uint64_be(r)
    raise NotImplementedError("Length %d not supported" % aux)
# Unique marker returned when the break byte is read.  Being a dedicated
# object, it cannot be confused with any decoded value -- in particular not
# with the integer 255, which has the same numeric value as the break byte.
_CBOR_BREAK_MARKER = object()


def _cbor_decode(r: utils.BufferReader, allow_break: bool = False) -> Value:
    """Decode one CBOR data item from `r` and return it as a Python value.

    Integers become int, byte strings are returned as read from `r`, text
    strings are decoded to str, arrays become list, maps become dict, and
    false/true/null become False/True/None.  A tagged item is returned as
    Tagged, except for tag _CBOR_RAW_TAG whose content is returned directly.

    `allow_break` is set only when reading the items of an indefinite-length
    array or the keys of an indefinite-length map; in that case a break byte
    yields _CBOR_BREAK_MARKER.

    Raises ValueError on a duplicate map key or on a break byte found where
    it is not allowed, and NotImplementedError on unsupported input.
    """
    fb = r.get()
    fb_type = fb & _CBOR_TYPE_MASK
    fb_aux = fb & _CBOR_INFO_BITS
    if fb_type == _CBOR_UNSIGNED_INT:
        return _read_length(r, fb_aux)
    elif fb_type == _CBOR_NEGATIVE_INT:
        val = _read_length(r, fb_aux)
        return -1 - val
    elif fb_type == _CBOR_BYTE_STRING:
        ln = _read_length(r, fb_aux)
        return r.read(ln)
    elif fb_type == _CBOR_TEXT_STRING:
        ln = _read_length(r, fb_aux)
        return r.read(ln).decode()
    elif fb_type == _CBOR_ARRAY:
        array: Value = []
        if fb_aux == _CBOR_VAR_FOLLOWS:
            # indefinite length: keep reading items until the break
            while True:
                item = _cbor_decode(r, True)
                if item is _CBOR_BREAK_MARKER:
                    break
                array.append(item)
        else:
            ln = _read_length(r, fb_aux)
            for _ in range(ln):
                array.append(_cbor_decode(r))
        return array
    elif fb_type == _CBOR_MAP:
        mapping: Value = {}
        if fb_aux == _CBOR_VAR_FOLLOWS:
            while True:
                # only the key position may hold the terminating break
                key = _cbor_decode(r, True)
                if key is _CBOR_BREAK_MARKER:
                    break
                if key in mapping:
                    raise ValueError
                mapping[key] = _cbor_decode(r)
        else:
            ln = _read_length(r, fb_aux)
            for _ in range(ln):
                key = _cbor_decode(r)
                if key in mapping:
                    raise ValueError
                mapping[key] = _cbor_decode(r)
        return mapping
    elif fb_type == _CBOR_TAG:
        val = _read_length(r, fb_aux)
        item = _cbor_decode(r)
        if val == _CBOR_RAW_TAG:  # only tag 24 (0x18) is supported
            return item
        else:
            return Tagged(val, item)
    elif fb_type == _CBOR_PRIMITIVE:
        if fb_aux == _CBOR_FALSE:
            return False
        elif fb_aux == _CBOR_TRUE:
            return True
        elif fb_aux == _CBOR_NULL:
            return None
        elif fb_aux == _CBOR_BREAK:
            if not allow_break:
                # a break that does not terminate an indefinite-length
                # container makes the input malformed
                raise ValueError
            return _CBOR_BREAK_MARKER
        else:
            raise NotImplementedError
    else:
        if __debug__:
            log.debug(__name__, "not implemented (decode): %s", fb)
        raise NotImplementedError
class Tagged:
    """A value paired with the CBOR tag number it is wrapped in."""

    def __init__(self, tag: int, value: Value) -> None:
        self.tag = tag
        self.value = value

    def __eq__(self, other: object) -> bool:
        # Only another Tagged with the same tag and the same value is equal.
        if not isinstance(other, Tagged):
            return False
        return self.tag == other.tag and self.value == other.value
class Raw:
    """Holder for data that the encoder emits verbatim.

    `value` is yielded by the encoder exactly as given, without any header,
    so it is expected to already be valid CBOR bytes.
    """

    def __init__(self, value: Value) -> None:
        self.value = value
class IndefiniteLengthArray:
    """A list that the encoder writes out as an indefinite-length array."""

    def __init__(self, array: List[Value]) -> None:
        self.array = array

    def __eq__(self, other: object) -> bool:
        # Comparable with another IndefiniteLengthArray or with a plain list;
        # anything else is never equal.
        if isinstance(other, IndefiniteLengthArray):
            items = other.array
        elif isinstance(other, list):
            items = other
        else:
            return False
        return self.array == items
def encode(value: Value) -> bytes:
    """Serialize `value` to CBOR and return the result as one bytes object."""
    chunks = _cbor_encode(value)
    return b"".join(chunks)
def decode(cbor: bytes) -> Value:
    """Decode `cbor`, which must hold exactly one data item.

    Raises ValueError when any bytes are left over after the item.
    """
    reader = utils.BufferReader(cbor)
    decoded = _cbor_decode(reader)
    if reader.remaining_count():
        # trailing data after the first item is not accepted
        raise ValueError
    return decoded