config: always use bytes() for default value

pull/25/head
Jan Pochyla 8 years ago
parent 13533d9156
commit be7ee61ddd

@ -32,7 +32,7 @@ STATIC mp_obj_t mod_TrezorConfig_Config_make_new(const mp_obj_type_t *type, size
/// def trezor.config.get(app: int, key: int) -> bytes:
/// '''
/// Gets a value of given key for given app (or None if not set).
/// Gets a value of given key for given app (or empty bytes if not set).
/// '''
STATIC mp_obj_t mod_TrezorConfig_Config_get(mp_obj_t self, mp_obj_t app, mp_obj_t key) {
uint8_t a = mp_obj_get_int(app);
@ -41,8 +41,8 @@ STATIC mp_obj_t mod_TrezorConfig_Config_get(mp_obj_t self, mp_obj_t app, mp_obj_
const void *val;
uint32_t len;
bool r = norcow_get(appkey, &val, &len);
if (!r) {
return mp_const_none;
if (!r || len == 0) {
return mp_const_empty_bytes;
}
vstr_t vstr;
vstr_init_len(&vstr, len);

@ -6,12 +6,14 @@ else:
from TrezorConfig import Config
_config = Config()
def get(app, key, default=None):
v = _config.get(app, key)
return v if v else default
def set(app, key, value):
def get(app: int, key: int) -> bytes:
return _config.get(app, key)
def set(app: int, key: int, value: bytes):
return _config.set(app, key, value)
def wipe():
return _config.wipe()

@ -2,7 +2,8 @@
import ustruct
def Config():
class Config:
def __init__(self, filename):
self._data = {}
@ -25,8 +26,8 @@ def Config():
f.write(ustruct.pack('<HH', k, len(v)))
f.write(v)
def get(self, app_id, key, default=None):
return self._data.get((app_id << 8) | key, default)
def get(self, app_id, key):
return self._data.get((app_id << 8) | key, bytes())
def set(self, app_id, key, value):
self._data[(app_id << 8) | key] = value

@ -23,19 +23,18 @@ class TestConfig(unittest.TestCase):
def test_set_get(self):
config.wipe()
for _ in range(64):
appid, key = random.uniform(256), random.uniform(256)
value = random.bytes(128)
config.set(appid, key, value)
value2 = config.get(appid, key)
self.assertEqual(value, value2)
appid, key = random.uniform(256), random.uniform(256)
value = random.bytes(128)
config.set(appid, key, value)
value2 = config.get(appid, key)
self.assertEqual(value, value2)
def test_get_default(self):
config.wipe()
for _ in range(64):
appid, key = random.uniform(256), random.uniform(256)
value = random.bytes(128)
value2 = config.get(appid, key, value)
self.assertEqual(value, value2)
appid, key = random.uniform(256), random.uniform(256)
value = config.get(appid, key)
self.assertEqual(value, bytes())
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save