diff --git a/embed/extmod/modtrezorcrypto/modtrezorcrypto-bip32.h b/embed/extmod/modtrezorcrypto/modtrezorcrypto-bip32.h index 0574ba3e0..58cfcee9f 100644 --- a/embed/extmod/modtrezorcrypto/modtrezorcrypto-bip32.h +++ b/embed/extmod/modtrezorcrypto/modtrezorcrypto-bip32.h @@ -31,6 +31,7 @@ typedef struct _mp_obj_HDNode_t { mp_obj_base_t base; uint32_t fingerprint; HDNode hdnode; + bool private_key_set; } mp_obj_HDNode_t; STATIC const mp_obj_type_t mod_trezorcrypto_HDNode_type; @@ -110,8 +111,10 @@ STATIC mp_obj_t mod_trezorcrypto_HDNode_make_new(const mp_obj_type_t *type, size } if (32 == private_key.len) { memcpy(o->hdnode.private_key, private_key.buf, 32); + o->private_key_set = true; } else { memzero(o->hdnode.private_key, 32); + o->private_key_set = false; } if (33 == public_key.len) { memcpy(o->hdnode.public_key, public_key.buf, 33); @@ -123,16 +126,27 @@ STATIC mp_obj_t mod_trezorcrypto_HDNode_make_new(const mp_obj_type_t *type, size return MP_OBJ_FROM_PTR(o); } -/// def derive(self, index: int) -> None: +/// def derive(self, index: int, public: bool=False) -> None: /// ''' /// Derive a BIP0032 child node in place. /// ''' -STATIC mp_obj_t mod_trezorcrypto_HDNode_derive(mp_obj_t self, mp_obj_t index) { - mp_obj_HDNode_t *o = MP_OBJ_TO_PTR(self); - uint32_t i = mp_obj_get_int_truncated(index); +STATIC mp_obj_t mod_trezorcrypto_HDNode_derive(size_t n_args, const mp_obj_t *args) { + mp_obj_HDNode_t *o = MP_OBJ_TO_PTR(args[0]); + uint32_t i = mp_obj_get_int_truncated(args[1]); uint32_t fp = hdnode_fingerprint(&o->hdnode); + bool public = n_args > 2 && args[2] == mp_const_true; - if (!hdnode_private_ckd(&o->hdnode, i)) { + int res; + if (public) { + res = hdnode_public_ckd(&o->hdnode, i); + } else { + if (!o->private_key_set) { + memzero(&o->hdnode, sizeof(o->hdnode)); + mp_raise_ValueError("Failed to derive, private key not set"); + } + res = hdnode_private_ckd(&o->hdnode, i); + } + if (!res) { memzero(&o->hdnode, sizeof(o->hdnode)); mp_raise_ValueError("Failed to derive"); } @@ -140,7 +154,7 @@ STATIC mp_obj_t mod_trezorcrypto_HDNode_derive(mp_obj_t self, mp_obj_t index) { return mp_const_none; } -STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_HDNode_derive_obj, mod_trezorcrypto_HDNode_derive); +STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorcrypto_HDNode_derive_obj, 2, 3, mod_trezorcrypto_HDNode_derive); /// def derive_path(self, path: List[int]) -> None: /// '''