diff --git a/embed/extmod/modtrezorcrypto/modtrezorcrypto-bip32.h b/embed/extmod/modtrezorcrypto/modtrezorcrypto-bip32.h index 43901bedf..f70f32680 100644 --- a/embed/extmod/modtrezorcrypto/modtrezorcrypto-bip32.h +++ b/embed/extmod/modtrezorcrypto/modtrezorcrypto-bip32.h @@ -8,6 +8,7 @@ #include "py/objstr.h" #include "trezor-crypto/bip32.h" +#include "trezor-crypto/curves.h" /// class HDNode: /// ''' @@ -24,6 +25,91 @@ STATIC const mp_obj_type_t mod_trezorcrypto_HDNode_type; #define XPUB_MAXLEN 128 #define ADDRESS_MAXLEN 36 +/// def __init__(self, +/// depth: int, +/// fingerprint: int, +/// child_num: int, +/// chain_code: bytes, +/// private_key: bytes = None, +/// public_key: bytes = None, +/// curve_name: str = None) -> None: +/// ''' +/// ''' +STATIC mp_obj_t mod_trezorcrypto_HDNode_make_new(const mp_obj_type_t *type, size_t n_args, size_t n_kw, const mp_obj_t *args) { + + STATIC const mp_arg_t allowed_args[] = { + { MP_QSTR_depth, MP_ARG_REQUIRED | MP_ARG_KW_ONLY | MP_ARG_INT, {.u_int = 0} }, + { MP_QSTR_fingerprint, MP_ARG_REQUIRED | MP_ARG_KW_ONLY | MP_ARG_INT, {.u_int = 0} }, + { MP_QSTR_child_num, MP_ARG_REQUIRED | MP_ARG_KW_ONLY | MP_ARG_INT, {.u_int = 0} }, + { MP_QSTR_chain_code, MP_ARG_REQUIRED | MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_obj = mp_const_empty_bytes} }, + { MP_QSTR_private_key, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_obj = mp_const_empty_bytes} }, + { MP_QSTR_public_key, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_obj = mp_const_empty_bytes} }, + { MP_QSTR_curve_name, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_obj = mp_const_empty_bytes} }, + }; + mp_arg_val_t vals[MP_ARRAY_SIZE(allowed_args)]; + mp_arg_parse_all_kw_array(n_args, n_kw, args, MP_ARRAY_SIZE(allowed_args), allowed_args, vals); + + mp_buffer_info_t chain_code; + mp_buffer_info_t private_key; + mp_buffer_info_t public_key; + mp_buffer_info_t curve_name; + const mp_int_t depth = vals[0].u_int; + const mp_int_t fingerprint = vals[1].u_int; + const mp_int_t child_num = vals[2].u_int; + mp_get_buffer_raise(vals[3].u_obj, &chain_code, MP_BUFFER_READ); + mp_get_buffer_raise(vals[4].u_obj, &private_key, MP_BUFFER_READ); + mp_get_buffer_raise(vals[5].u_obj, &public_key, MP_BUFFER_READ); + mp_get_buffer_raise(vals[6].u_obj, &curve_name, MP_BUFFER_READ); + + if (NULL == chain_code.buf || 32 != chain_code.len) { + mp_raise_ValueError("chain_code is invalid"); + } + if (NULL == public_key.buf && NULL == private_key.buf) { + mp_raise_ValueError("either public_key or private_key is required"); + } + if (NULL != private_key.buf && 32 != private_key.len) { + mp_raise_ValueError("private_key is invalid"); + } + if (NULL != public_key.buf && 33 != public_key.len) { + mp_raise_ValueError("public_key is invalid"); + } + + const curve_info *curve = NULL; + if (NULL == curve_name.buf) { + curve = get_curve_by_name(SECP256K1_NAME); + } else { + curve = get_curve_by_name(curve_name.buf); + } + if (NULL == curve) { + mp_raise_ValueError("curve_name is invalid"); + } + + mp_obj_HDNode_t *o = m_new_obj(mp_obj_HDNode_t); + o->base.type = type; + + o->fingerprint = (uint32_t)fingerprint; + o->hdnode.depth = (uint32_t)depth; + o->hdnode.child_num = (uint32_t)child_num; + if (NULL != chain_code.buf && 32 == chain_code.len) { + memcpy(o->hdnode.chain_code, chain_code.buf, 32); + } else { + memset(o->hdnode.chain_code, 0, 32); + } + if (NULL != private_key.buf && 32 == private_key.len) { + memcpy(o->hdnode.private_key, private_key.buf, 32); + } else { + memset(o->hdnode.private_key, 0, 32); + } + if (NULL != public_key.buf && 33 == public_key.len) { + memcpy(o->hdnode.public_key, public_key.buf, 33); + } else { + memset(o->hdnode.public_key, 0, 33); + } + o->hdnode.curve = curve; + + return MP_OBJ_FROM_PTR(o); +} + /// def derive(self, index: int) -> None: /// ''' /// Derive a BIP0032 child node in place. @@ -232,6 +318,7 @@ STATIC MP_DEFINE_CONST_DICT(mod_trezorcrypto_HDNode_locals_dict, mod_trezorcrypt STATIC const mp_obj_type_t mod_trezorcrypto_HDNode_type = { { &mp_type_type }, .name = MP_QSTR_HDNode, + .make_new = mod_trezorcrypto_HDNode_make_new, .locals_dict = (void*)&mod_trezorcrypto_HDNode_locals_dict, };