1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 22:38:08 +00:00

Adapted python unit test to new API

This commit is contained in:
Jochen Hoenicke 2016-04-25 17:37:43 +02:00
parent 269b779ead
commit 490fbed289
2 changed files with 7 additions and 5 deletions

2
c.pxd
View File

@ -5,7 +5,7 @@ cdef extern from "bip32.h":
ctypedef struct HDNode: ctypedef struct HDNode:
uint8_t public_key[33] uint8_t public_key[33]
int hdnode_from_seed(const uint8_t *seed, int seed_len, HDNode *out) int hdnode_from_seed(const uint8_t *seed, int seed_len, const char *curve, HDNode *out)
int hdnode_private_ckd(HDNode *inout, uint32_t i) int hdnode_private_ckd(HDNode *inout, uint32_t i)

View File

@ -41,11 +41,13 @@ random_iters = int(os.environ.get('ITERS', 1))
lib = c.cdll.LoadLibrary('./libtrezor-crypto.so') lib = c.cdll.LoadLibrary('./libtrezor-crypto.so')
lib.get_curve_by_name.restype = c.c_void_p class curve_info(c.Structure):
_fields_ = [("bip32_name", c.c_char_p),
("params", c.c_void_p)]
lib.get_curve_by_name.restype = c.POINTER(curve_info)
BIGNUM = c.c_uint32 * 9 BIGNUM = c.c_uint32 * 9
class Random(random.Random): class Random(random.Random):
def randbytes(self, n): def randbytes(self, n):
buf = (c.c_uint8 * n)() buf = (c.c_uint8 * n)()
@ -83,7 +85,7 @@ def r(request):
@pytest.fixture(params=list(sorted(curves))) @pytest.fixture(params=list(sorted(curves)))
def curve(request): def curve(request):
name = request.param name = request.param
curve_ptr = lib.get_curve_by_name(name) curve_ptr = lib.get_curve_by_name(name).contents.params
assert curve_ptr, 'curve {} not found'.format(name) assert curve_ptr, 'curve {} not found'.format(name)
curve_obj = curves[name] curve_obj = curves[name]
curve_obj.ptr = c.c_void_p(curve_ptr) curve_obj.ptr = c.c_void_p(curve_ptr)
@ -93,7 +95,7 @@ def curve(request):
@pytest.fixture(params=points) @pytest.fixture(params=points)
def point(request): def point(request):
name = request.param.curve name = request.param.curve
curve_ptr = lib.get_curve_by_name(name) curve_ptr = lib.get_curve_by_name(name).contents.params
assert curve_ptr, 'curve {} not found'.format(name) assert curve_ptr, 'curve {} not found'.format(name)
curve_obj = curves[name] curve_obj = curves[name]
curve_obj.ptr = c.c_void_p(curve_ptr) curve_obj.ptr = c.c_void_p(curve_ptr)