storage: style

pull/159/head
Tomas Susanka 5 years ago
parent 99329fb30d
commit 18482a9c37

@ -10,7 +10,9 @@ class Storage:
self.lib = c.cdll.LoadLibrary(fname) self.lib = c.cdll.LoadLibrary(fname)
self.flash_size = c.cast(self.lib.FLASH_SIZE, c.POINTER(c.c_uint32))[0] self.flash_size = c.cast(self.lib.FLASH_SIZE, c.POINTER(c.c_uint32))[0]
self.flash_buffer = c.create_string_buffer(self.flash_size) self.flash_buffer = c.create_string_buffer(self.flash_size)
c.cast(self.lib.FLASH_BUFFER, c.POINTER(c.c_void_p))[0] = c.addressof(self.flash_buffer) c.cast(self.lib.FLASH_BUFFER, c.POINTER(c.c_void_p))[0] = c.addressof(
self.flash_buffer
)
def init(self, salt: bytes) -> None: def init(self, salt: bytes) -> None:
self.lib.storage_init(0, salt, c.c_uint16(len(salt))) self.lib.storage_init(0, salt, c.c_uint16(len(salt)))
@ -31,14 +33,18 @@ class Storage:
return self.lib.storage_get_pin_rem() return self.lib.storage_get_pin_rem()
def change_pin(self, oldpin: int, newpin: int) -> bool: def change_pin(self, oldpin: int, newpin: int) -> bool:
return sectrue == self.lib.storage_change_pin(c.c_uint32(oldpin), c.c_uint32(newpin)) return sectrue == self.lib.storage_change_pin(
c.c_uint32(oldpin), c.c_uint32(newpin)
)
def get(self, key: int) -> bytes: def get(self, key: int) -> bytes:
val_len = c.c_uint16() val_len = c.c_uint16()
if sectrue != self.lib.storage_get(c.c_uint16(key), None, 0, c.byref(val_len)): if sectrue != self.lib.storage_get(c.c_uint16(key), None, 0, c.byref(val_len)):
raise RuntimeError("Failed to find key in storage.") raise RuntimeError("Failed to find key in storage.")
s = c.create_string_buffer(val_len.value) s = c.create_string_buffer(val_len.value)
if sectrue != self.lib.storage_get(c.c_uint16(key), s, val_len, c.byref(val_len)): if sectrue != self.lib.storage_get(
c.c_uint16(key), s, val_len, c.byref(val_len)
):
raise RuntimeError("Failed to get value from storage.") raise RuntimeError("Failed to get value from storage.")
return s.raw return s.raw
@ -47,7 +53,9 @@ class Storage:
raise RuntimeError("Failed to set value in storage.") raise RuntimeError("Failed to set value in storage.")
def set_counter(self, key: int, count: int) -> bool: def set_counter(self, key: int, count: int) -> bool:
return sectrue == self.lib.storage_set_counter(c.c_uint16(key), c.c_uint32(count)) return sectrue == self.lib.storage_set_counter(
c.c_uint16(key), c.c_uint32(count)
)
def next_counter(self, key: int) -> int: def next_counter(self, key: int) -> int:
count = c.c_uint32() count = c.c_uint32()
@ -61,7 +69,10 @@ class Storage:
def _dump(self) -> bytes: def _dump(self) -> bytes:
# return just sectors 4 and 16 of the whole flash # return just sectors 4 and 16 of the whole flash
return [self.flash_buffer[0x010000:0x010000 + 0x10000], self.flash_buffer[0x110000:0x110000 + 0x10000]] return [
self.flash_buffer[0x010000 : 0x010000 + 0x10000],
self.flash_buffer[0x110000 : 0x110000 + 0x10000],
]
def _get_flash_buffer(self) -> bytes: def _get_flash_buffer(self) -> bytes:
return bytes(self.flash_buffer) return bytes(self.flash_buffer)

@ -4,13 +4,15 @@ import os
sectrue = -1431655766 # 0xAAAAAAAAA sectrue = -1431655766 # 0xAAAAAAAAA
fname = os.path.join(os.path.dirname(__file__), "libtrezor-storage0.so") fname = os.path.join(os.path.dirname(__file__), "libtrezor-storage0.so")
class Storage:
class Storage:
def __init__(self) -> None: def __init__(self) -> None:
self.lib = c.cdll.LoadLibrary(fname) self.lib = c.cdll.LoadLibrary(fname)
self.flash_size = c.cast(self.lib.FLASH_SIZE, c.POINTER(c.c_uint32))[0] self.flash_size = c.cast(self.lib.FLASH_SIZE, c.POINTER(c.c_uint32))[0]
self.flash_buffer = c.create_string_buffer(self.flash_size) self.flash_buffer = c.create_string_buffer(self.flash_size)
c.cast(self.lib.FLASH_BUFFER, c.POINTER(c.c_void_p))[0] = c.addressof(self.flash_buffer) c.cast(self.lib.FLASH_BUFFER, c.POINTER(c.c_void_p))[0] = c.addressof(
self.flash_buffer
)
def init(self) -> None: def init(self) -> None:
self.lib.storage_init(0) self.lib.storage_init(0)
@ -28,12 +30,16 @@ class Storage:
return sectrue == self.lib.storage_has_pin() return sectrue == self.lib.storage_has_pin()
def change_pin(self, oldpin: int, newpin: int) -> bool: def change_pin(self, oldpin: int, newpin: int) -> bool:
return sectrue == self.lib.storage_change_pin(c.c_uint32(oldpin), c.c_uint32(newpin)) return sectrue == self.lib.storage_change_pin(
c.c_uint32(oldpin), c.c_uint32(newpin)
)
def get(self, key: int) -> bytes: def get(self, key: int) -> bytes:
val_ptr = c.c_void_p() val_ptr = c.c_void_p()
val_len = c.c_uint16() val_len = c.c_uint16()
if sectrue != self.lib.storage_get(c.c_uint16(key), c.byref(val_ptr), c.byref(val_len)): if sectrue != self.lib.storage_get(
c.c_uint16(key), c.byref(val_ptr), c.byref(val_len)
):
raise RuntimeError("Failed to find key in storage.") raise RuntimeError("Failed to find key in storage.")
return c.string_at(val_ptr, size=val_len.value) return c.string_at(val_ptr, size=val_len.value)
@ -43,7 +49,10 @@ class Storage:
def _dump(self) -> bytes: def _dump(self) -> bytes:
# return just sectors 4 and 16 of the whole flash # return just sectors 4 and 16 of the whole flash
return [self.flash_buffer[0x010000:0x010000 + 0x10000], self.flash_buffer[0x110000:0x110000 + 0x10000]] return [
self.flash_buffer[0x010000 : 0x010000 + 0x10000],
self.flash_buffer[0x110000 : 0x110000 + 0x10000],
]
def _get_flash_buffer(self) -> bytes: def _get_flash_buffer(self) -> bytes:
return bytes(self.flash_buffer) return bytes(self.flash_buffer)

@ -1,7 +1,7 @@
import pytest import pytest
from . import common
from ..src import consts, norcow from ..src import consts, norcow
from . import common
def test_norcow_set(): def test_norcow_set():

@ -1,18 +1,22 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from hashlib import sha256
from c.storage import Storage as StorageC from c.storage import Storage as StorageC
from c0.storage import Storage as StorageC0
from python.src.storage import Storage as StoragePy from python.src.storage import Storage as StoragePy
from hashlib import sha256
def hash(data): def hash(data):
return sha256(data).hexdigest()[:16] return sha256(data).hexdigest()[:16]
# Strings for testing ChaCha20 encryption. # Strings for testing ChaCha20 encryption.
test_strings = [b"Short string.", b"", b"Although ChaCha20 is a stream cipher, it operates on blocks of 64 bytes. This string is over 152 bytes in length so that we test multi-block encryption.", b"This string is exactly 64 bytes long, that is exactly one block."] test_strings = [
b"Short string.",
b"",
b"Although ChaCha20 is a stream cipher, it operates on blocks of 64 bytes. This string is over 152 bytes in length so that we test multi-block encryption.",
b"This string is exactly 64 bytes long, that is exactly one block.",
]
# Unique device ID for testing. # Unique device ID for testing.
uid = b"\x67\xce\x6a\xe8\xf7\x9b\x73\x96\x83\x88\x21\x5e" uid = b"\x67\xce\x6a\xe8\xf7\x9b\x73\x96\x83\x88\x21\x5e"
@ -24,12 +28,12 @@ a = []
for s in [sc, sp]: for s in [sc, sp]:
print(s.__class__) print(s.__class__)
s.init(uid) s.init(uid)
assert s.unlock(3) == False assert s.unlock(3) is False
assert s.unlock(1) == True assert s.unlock(1) is True
s.set(0xbeef, b"hello") s.set(0xBEEF, b"hello")
s.set(0x03fe, b"world!") s.set(0x03FE, b"world!")
s.set(0xbeef, b"satoshi") s.set(0xBEEF, b"satoshi")
s.set(0xbeef, b"Satoshi") s.set(0xBEEF, b"Satoshi")
for value in test_strings: for value in test_strings:
s.set(0x0301, value) s.set(0x0301, value)
assert s.get(0x0301) == value assert s.get(0x0301) == value

@ -2,3 +2,4 @@
^\./core/src/ ^\./core/src/
^\./crypto/ ^\./crypto/
^\./legacy/ ^\./legacy/
^\./storage/

Loading…
Cancel
Save