diff --git a/python/src/trezorlib/tools.py b/python/src/trezorlib/tools.py index e5c32b4ded..84cd5bf064 100644 --- a/python/src/trezorlib/tools.py +++ b/python/src/trezorlib/tools.py @@ -126,59 +126,48 @@ __b58chars = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz" __b58base = len(__b58chars) +def b58encode_int(i: int) -> str: + """Encode an integer using Base58""" + digits = [] + while i: + i, idx = divmod(i, __b58base) + digits.append(__b58chars[idx]) + return "".join(reversed(digits)) + + def b58encode(v: bytes) -> str: """encode v, which is a string of bytes, to base58.""" + origlen = len(v) + v = v.lstrip(b"\0") + newlen = len(v) - long_value = 0 - for c in v: - long_value = long_value * 256 + c + acc = int.from_bytes(v, byteorder="big") # first byte is most significant - result = "" - while long_value >= __b58base: - div, mod = divmod(long_value, __b58base) - result = __b58chars[mod] + result - long_value = div - result = __b58chars[long_value] + result + result = b58encode_int(acc) + return __b58chars[0] * (origlen - newlen) + result - # Bitcoin does a little leading-zero-compression: - # leading 0-bytes in the input become leading-1s - nPad = 0 - for c in v: - if c == 0: - nPad += 1 - else: - break - return (__b58chars[0] * nPad) + result +def b58decode_int(v: str) -> int: + """Decode a Base58 encoded string as an integer""" + decimal = 0 + try: + for char in v: + decimal = decimal * __b58base + __b58chars.index(char) + except KeyError: + raise ValueError(f"Invalid character {char!r}") from None # type: ignore [possibly unbound] + return decimal def b58decode(v: AnyStr, length: Optional[int] = None) -> bytes: """decode v into a string of len bytes.""" - str_v = v.decode() if isinstance(v, bytes) else v + v_str = v if isinstance(v, str) else v.decode() + origlen = len(v_str) + v_str = v_str.lstrip(__b58chars[0]) + newlen = len(v_str) - for c in str_v: - if c not in __b58chars: - raise ValueError("invalid Base58 string") + acc = b58decode_int(v_str) - long_value = 0 - for (i, c) in enumerate(str_v[::-1]): - long_value += __b58chars.find(c) * (__b58base**i) - - result = b"" - while long_value >= 256: - div, mod = divmod(long_value, 256) - result = struct.pack("B", mod) + result - long_value = div - result = struct.pack("B", long_value) + result - - nPad = 0 - for c in str_v: - if c == __b58chars[0]: - nPad += 1 - else: - break - - result = b"\x00" * nPad + result + result = acc.to_bytes(origlen - newlen + (acc.bit_length() + 7) // 8, "big") if length is not None and len(result) != length: raise ValueError("Result length does not match expected_length") diff --git a/python/tests/test_tools.py b/python/tests/test_tools.py index 3a0325ea8d..3bdda1fe97 100644 --- a/python/tests/test_tools.py +++ b/python/tests/test_tools.py @@ -49,3 +49,41 @@ VECTORS = ( # descriptor, checksum @pytest.mark.parametrize("descriptor, checksum", VECTORS) def test_descriptor_checksum(descriptor, checksum): assert tools.descriptor_checksum(descriptor) == checksum + + +BASE58_VECTORS = ( # data_hex, encoding_b58 + ("", ""), + ("61", "2g"), + ("626262", "a3gV"), + ("636363", "aPEr"), + ("73696d706c792061206c6f6e6720737472696e67", "2cFupjhnEsSn59qHXstmK2ffpLv2"), + ( + "00eb15231dfceb60925886b67d065299925915aeb172c06647", + "1NS17iag9jJgTHD1VXjvLCEnZuQ3rJDE9L", + ), + ("516b6fcd0f", "ABnLTmg"), + ("bf4f89001e670274dd", "3SEo3LWLoPntC"), + ("572e4794", "3EFU7m"), + ("ecac89cad93923c02321", "EJDM8drfXA6uyA"), + ("10c8511e", "Rt5zm"), + ("00000000000000000000", "1111111111"), + ("00" * 32, "11111111111111111111111111111111"), + ( + "000111d38e5fc9071ffcd20b4a763cc9ae4f252bb4e48fd66a835e252ada93ff480d6dd43dc62a641155a5", + "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz", + ), + ( + "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff", + "1cWB5HCBdLjAuqGGReWE3R3CguuwSjw6RHn39s2yuDRTS5NsBgNiFpWgAnEx6VQi8csexkgYw3mdYrMHr8x9i7aEwP8kZ7vccXWqKDvGv3u1GxFKPuAkn8JCPPGDMf3vMMnbzm6Nh9zh1gcNsMvH3ZNLmP5fSG6DGbbi2tuwMWPthr4boWwCxf7ewSgNQeacyozhKDDQQ1qL5fQFUW52QKUZDZ5fw3KXNQJMcNTcaB723LchjeKun7MuGW5qyCBZYzA1KjofN1gYBV3NqyhQJ3Ns746GNuf9N2pQPmHz4xpnSrrfCvy6TVVz5d4PdrjeshsWQwpZsZGzvbdAdN8MKV5QsBDY", + ), +) + + +@pytest.mark.parametrize("data_hex,encoding_b58", BASE58_VECTORS) +def test_b58encode(data_hex, encoding_b58): + assert tools.b58encode(bytes.fromhex(data_hex)) == encoding_b58 + + +@pytest.mark.parametrize("data_hex,encoding_b58", BASE58_VECTORS) +def test_b58decode(data_hex, encoding_b58): + assert tools.b58decode(encoding_b58).hex() == data_hex