diff --git a/micropython/boardloader/main.c b/micropython/boardloader/main.c index 32d23d474..85bd5ae4b 100644 --- a/micropython/boardloader/main.c +++ b/micropython/boardloader/main.c @@ -51,27 +51,18 @@ bool check_sdcard(void) } } +static void progress_callback(void) { + display_printf("."); +} + bool copy_sdcard(void) { display_printf("erasing flash "); // erase flash (except boardloader) - HAL_FLASH_Unlock(); - FLASH_EraseInitTypeDef EraseInitStruct; - __HAL_FLASH_CLEAR_FLAG(FLASH_FLAG_EOP | FLASH_FLAG_OPERR | FLASH_FLAG_WRPERR | - FLASH_FLAG_PGAERR | FLASH_FLAG_PGPERR | FLASH_FLAG_PGSERR); - EraseInitStruct.TypeErase = FLASH_TYPEERASE_SECTORS; - EraseInitStruct.VoltageRange = FLASH_VOLTAGE_RANGE_3; - EraseInitStruct.NbSectors = 1; - uint32_t SectorError = 0; - for (int i = 2; i < 12; i++) { - EraseInitStruct.Sector = i; - if (HAL_FLASHEx_Erase(&EraseInitStruct, &SectorError) != HAL_OK) { - HAL_FLASH_Lock(); - display_printf(" failed\n"); - return false; - } - display_printf("."); + if (0 != flash_erase_sectors(FLASH_SECTOR_BOARDLOADER_END + 1, FLASH_SECTOR_FIRMWARE_END, progress_callback)) { + display_printf(" failed\n"); + return false; } display_printf(" done\n"); @@ -87,10 +78,10 @@ bool copy_sdcard(void) if (!image_parse_header((const uint8_t *)buf, IMAGE_MAGIC, IMAGE_MAXSIZE, &hdr)) { display_printf("invalid header\n"); sdcard_power_off(); - HAL_FLASH_Lock(); return false; } + HAL_FLASH_Unlock(); for (int i = 0; i < (HEADER_SIZE + hdr.codelen) / SDCARD_BLOCK_SIZE; i++) { sdcard_read_blocks((uint8_t *)buf, i, 1); for (int j = 0; j < SDCARD_BLOCK_SIZE / sizeof(uint32_t); j++) { diff --git a/micropython/bootloader/messages.c b/micropython/bootloader/messages.c index 9c8f9bf3a..071d527ee 100644 --- a/micropython/bootloader/messages.c +++ b/micropython/bootloader/messages.c @@ -1,9 +1,13 @@ +#include STM32_HAL_H + #include #include #include #include "messages.pb.h" +#include "common.h" +#include "flash.h" #include "usb.h" #include "version.h" @@ -229,6 +233,9 @@ void process_msg_FirmwareErase(uint8_t iface_num, uint32_t msg_size, uint8_t *bu firmware_size = msg_recv.has_length ? msg_recv.length : 0; if (firmware_size > 0 && firmware_size % 4 == 0) { + // erase flash + flash_erase_sectors(FLASH_SECTOR_FIRMWARE_START, FLASH_SECTOR_FIRMWARE_END, NULL); + // request new firmware chunk_requested = (firmware_size > FIRMWARE_CHUNK_SIZE) ? FIRMWARE_CHUNK_SIZE : firmware_size; MSG_SEND_INIT(FirmwareRequest); MSG_SEND_ASSIGN_VALUE(offset, 0); @@ -247,21 +254,33 @@ static uint32_t chunk_size = 0; static bool _read_payload(pb_istream_t *stream, const pb_field_t *field, void **arg) { #define BUFSIZE 1024 - pb_byte_t buf[BUFSIZE]; + uint32_t buf[BUFSIZE / sizeof(uint32_t)]; + uint32_t chunk_written = 0; chunk_size = stream->bytes_left; while (stream->bytes_left) { - if (!pb_read(stream, buf, (stream->bytes_left > BUFSIZE) ? BUFSIZE : stream->bytes_left)) { + memset(buf, 0xFF, sizeof(buf)); + // read data + if (!pb_read(stream, (pb_byte_t *)buf, (stream->bytes_left > BUFSIZE) ? BUFSIZE : stream->bytes_left)) { return false; } + // write data + for (int i = 0; i < BUFSIZE / sizeof(uint32_t); i++) { + if (HAL_FLASH_Program(FLASH_TYPEPROGRAM_WORD, FIRMWARE_START + firmware_flashed + chunk_written + i * sizeof(uint32_t), buf[i]) != HAL_OK) { + return false; + } + } + chunk_written += BUFSIZE; } return true; } void process_msg_FirmwareUpload(uint8_t iface_num, uint32_t msg_size, uint8_t *buf) { + HAL_FLASH_Unlock(); MSG_RECV_INIT(FirmwareUpload); MSG_RECV_CALLBACK(payload, _read_payload); MSG_RECV(FirmwareUpload); + HAL_FLASH_Lock(); if (chunk_size != chunk_requested) { MSG_SEND_INIT(Failure); diff --git a/micropython/trezorhal/flash.c b/micropython/trezorhal/flash.c index 0ce0a2aeb..234d59681 100644 --- a/micropython/trezorhal/flash.c +++ b/micropython/trezorhal/flash.c @@ -38,3 +38,27 @@ void flash_set_option_bytes(void) HAL_FLASHEx_OBProgram(&opts); } } + +int flash_erase_sectors(int start, int end, void (*progress)(void)) +{ + HAL_FLASH_Unlock(); + FLASH_EraseInitTypeDef EraseInitStruct; + __HAL_FLASH_CLEAR_FLAG(FLASH_FLAG_EOP | FLASH_FLAG_OPERR | FLASH_FLAG_WRPERR | + FLASH_FLAG_PGAERR | FLASH_FLAG_PGPERR | FLASH_FLAG_PGSERR); + EraseInitStruct.TypeErase = FLASH_TYPEERASE_SECTORS; + EraseInitStruct.VoltageRange = FLASH_VOLTAGE_RANGE_3; + EraseInitStruct.NbSectors = 1; + uint32_t SectorError = 0; + for (int i = start; i <= end; i++) { + EraseInitStruct.Sector = i; + if (HAL_FLASHEx_Erase(&EraseInitStruct, &SectorError) != HAL_OK) { + HAL_FLASH_Lock(); + return 0; + } + if (progress) { + progress(); + } + } + HAL_FLASH_Lock(); + return 1; +} diff --git a/micropython/trezorhal/flash.h b/micropython/trezorhal/flash.h index 015436541..670b7f11e 100644 --- a/micropython/trezorhal/flash.h +++ b/micropython/trezorhal/flash.h @@ -5,4 +5,18 @@ int flash_init(void); void flash_set_option_bytes(void); +#define FLASH_SECTOR_BOARDLOADER_START 0 +#define FLASH_SECTOR_BOARDLOADER_END 1 + +#define FLASH_SECTOR_STORAGE_START 2 +#define FLASH_SECTOR_STORAGE_END 3 + +#define FLASH_SECTOR_BOOTLOADER_START 4 +#define FLASH_SECTOR_BOOTLOADER_END 4 + +#define FLASH_SECTOR_FIRMWARE_START 5 +#define FLASH_SECTOR_FIRMWARE_END 11 + +int flash_erase_sectors(int start, int end, void (*progress)(void)); + #endif