diff --git a/embed/extmod/modtrezorio/modtrezorio-sdcard.h b/embed/extmod/modtrezorio/modtrezorio-sdcard.h index 3298615f9..f8006bd36 100644 --- a/embed/extmod/modtrezorio/modtrezorio-sdcard.h +++ b/embed/extmod/modtrezorio/modtrezorio-sdcard.h @@ -21,6 +21,9 @@ STATIC mp_obj_t mod_trezorio_SDCard_make_new(const mp_obj_type_t *type, size_t n mp_arg_check_num(n_args, n_kw, 0, 0, false); mp_obj_SDCard_t *o = m_new_obj(mp_obj_SDCard_t); o->base.type = type; +#if defined TREZOR_UNIX + sdcard_init(); +#endif return MP_OBJ_FROM_PTR(o); } diff --git a/embed/unix/sdcard.c b/embed/unix/sdcard.c index b3fcc0378..1f48a861f 100644 --- a/embed/unix/sdcard.c +++ b/embed/unix/sdcard.c @@ -5,31 +5,95 @@ * see LICENSE file for details */ +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" #include "sdcard.h" +#ifndef SDCARD_FILE +#define SDCARD_FILE "/var/tmp/trezor.sdcard" +#endif + +#define SDCARD_SIZE (32 * 1024 * 1024) + +static uint8_t *sdcard_buffer; +static secbool sdcard_powered; + +static void sdcard_exit(void) +{ + int r = munmap(sdcard_buffer, SDCARD_SIZE); + ensure(sectrue * (r == 0), "munmap failed"); +} + void sdcard_init(void) { + int r; + + // check whether the file exists and it has the correct size + struct stat sb; + r = stat(SDCARD_FILE, &sb); + + // (re)create if non existant or wrong size + if (r != 0 || sb.st_size != SDCARD_SIZE) { + int fd = open(SDCARD_FILE, O_RDWR | O_CREAT | O_TRUNC, (mode_t)0600); + ensure(sectrue * (fd >= 0), "open failed"); + for (int i = 0; i < SDCARD_SIZE / 16; i++) { + ssize_t s = write(fd, "\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF", 16); + ensure(sectrue * (s >= 0), "write failed"); + } + r = close(fd); + ensure(sectrue * (r == 0), "close failed"); + } + + // mmap file + int fd = open(SDCARD_FILE, O_RDWR); + ensure(sectrue * (fd >= 0), "open failed"); + + void *map = mmap(0, SDCARD_SIZE, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + ensure(sectrue * (map != MAP_FAILED), "mmap failed"); + + sdcard_buffer = (uint8_t *)map; + + sdcard_powered = secfalse; + + atexit(sdcard_exit); } secbool sdcard_is_present(void) { - return secfalse; + return sectrue; } secbool sdcard_power_on(void) { - return secfalse; + sdcard_powered = sectrue; + return sectrue; } secbool sdcard_power_off(void) { + sdcard_powered = secfalse; return sectrue; } uint64_t sdcard_get_capacity_in_bytes(void) { - return 0; + return sdcard_powered == sectrue ? SDCARD_SIZE : 0; } secbool sdcard_read_blocks(uint32_t *dest, uint32_t block_num, uint32_t num_blocks) { - return secfalse; + if (sectrue != sdcard_powered) { + return secfalse; + } + memcpy(dest, sdcard_buffer + block_num * SDCARD_BLOCK_SIZE, num_blocks * SDCARD_BLOCK_SIZE); + return sectrue; } secbool sdcard_write_blocks(const uint32_t *src, uint32_t block_num, uint32_t num_blocks) { - return secfalse; + if (sectrue != sdcard_powered) { + return secfalse; + } + memcpy(sdcard_buffer + block_num * SDCARD_BLOCK_SIZE, src, num_blocks * SDCARD_BLOCK_SIZE); + return sectrue; } diff --git a/tests/test_trezor.io.py b/tests/test_trezor.io.py index bd0f499ee..c960726f8 100644 --- a/tests/test_trezor.io.py +++ b/tests/test_trezor.io.py @@ -2,11 +2,47 @@ from common import * from trezor import io + class TestIo(unittest.TestCase): - def test_sdcard(self): + def test_sdcard_start(self): sd = io.SDCard() - sd.present() + assert sd.present() == True + + def test_sdcard_power(self): + sd = io.SDCard() + x = bytearray(8 * 512) + assert sd.capacity() == 0 + assert sd.read(0, x) == False + sd.power(True) + assert sd.capacity() > 0 + assert sd.read(0, x) == True + sd.power(False) + assert sd.capacity() == 0 + assert sd.read(0, x) == False + + def test_sdcard_read(self): + sd = io.SDCard() + x = bytearray(8 * 512) + sd.power(True) + assert sd.read(0, x) == True + sd.power(False) + assert sd.read(0, x) == False + + def test_sdcard_read_write(self): + sd = io.SDCard() + r = bytearray(8 * 512) + w0 = bytearray(b'0' * (8 * 512)) + w1 = bytearray(b'1' * (8 * 512)) + sd.power(True) + assert sd.write(0, w0) == True + assert sd.read(0, r) == True + assert r == w0 + assert sd.write(0, w1) == True + assert sd.read(0, r) == True + assert r == w1 + sd.power(False) + if __name__ == '__main__': unittest.main()