1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-17 10:51:00 +00:00

fix(python): replace base58 implementation with a more correct one

based on https://github.com/keis/base58/blob/master/base58/__init__.py
This commit is contained in:
gabrielkerekes 2023-12-01 16:17:16 +01:00 committed by matejcik
parent 76490e6e5f
commit 0dff9390db
2 changed files with 68 additions and 41 deletions

View File

@ -126,59 +126,48 @@ __b58chars = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
__b58base = len(__b58chars)
def b58encode_int(i: int) -> str:
"""Encode an integer using Base58"""
digits = []
while i:
i, idx = divmod(i, __b58base)
digits.append(__b58chars[idx])
return "".join(reversed(digits))
def b58encode(v: bytes) -> str:
"""encode v, which is a string of bytes, to base58."""
origlen = len(v)
v = v.lstrip(b"\0")
newlen = len(v)
long_value = 0
for c in v:
long_value = long_value * 256 + c
acc = int.from_bytes(v, byteorder="big") # first byte is most significant
result = ""
while long_value >= __b58base:
div, mod = divmod(long_value, __b58base)
result = __b58chars[mod] + result
long_value = div
result = __b58chars[long_value] + result
result = b58encode_int(acc)
return __b58chars[0] * (origlen - newlen) + result
# Bitcoin does a little leading-zero-compression:
# leading 0-bytes in the input become leading-1s
nPad = 0
for c in v:
if c == 0:
nPad += 1
else:
break
return (__b58chars[0] * nPad) + result
def b58decode_int(v: str) -> int:
"""Decode a Base58 encoded string as an integer"""
decimal = 0
try:
for char in v:
decimal = decimal * __b58base + __b58chars.index(char)
except KeyError:
raise ValueError(f"Invalid character {char!r}") from None # type: ignore [possibly unbound]
return decimal
def b58decode(v: AnyStr, length: Optional[int] = None) -> bytes:
"""decode v into a string of len bytes."""
str_v = v.decode() if isinstance(v, bytes) else v
v_str = v if isinstance(v, str) else v.decode()
origlen = len(v_str)
v_str = v_str.lstrip(__b58chars[0])
newlen = len(v_str)
for c in str_v:
if c not in __b58chars:
raise ValueError("invalid Base58 string")
acc = b58decode_int(v_str)
long_value = 0
for (i, c) in enumerate(str_v[::-1]):
long_value += __b58chars.find(c) * (__b58base**i)
result = b""
while long_value >= 256:
div, mod = divmod(long_value, 256)
result = struct.pack("B", mod) + result
long_value = div
result = struct.pack("B", long_value) + result
nPad = 0
for c in str_v:
if c == __b58chars[0]:
nPad += 1
else:
break
result = b"\x00" * nPad + result
result = acc.to_bytes(origlen - newlen + (acc.bit_length() + 7) // 8, "big")
if length is not None and len(result) != length:
raise ValueError("Result length does not match expected_length")

View File

@ -49,3 +49,41 @@ VECTORS = ( # descriptor, checksum
@pytest.mark.parametrize("descriptor, checksum", VECTORS)
def test_descriptor_checksum(descriptor, checksum):
assert tools.descriptor_checksum(descriptor) == checksum
BASE58_VECTORS = ( # data_hex, encoding_b58
("", ""),
("61", "2g"),
("626262", "a3gV"),
("636363", "aPEr"),
("73696d706c792061206c6f6e6720737472696e67", "2cFupjhnEsSn59qHXstmK2ffpLv2"),
(
"00eb15231dfceb60925886b67d065299925915aeb172c06647",
"1NS17iag9jJgTHD1VXjvLCEnZuQ3rJDE9L",
),
("516b6fcd0f", "ABnLTmg"),
("bf4f89001e670274dd", "3SEo3LWLoPntC"),
("572e4794", "3EFU7m"),
("ecac89cad93923c02321", "EJDM8drfXA6uyA"),
("10c8511e", "Rt5zm"),
("00000000000000000000", "1111111111"),
("00" * 32, "11111111111111111111111111111111"),
(
"000111d38e5fc9071ffcd20b4a763cc9ae4f252bb4e48fd66a835e252ada93ff480d6dd43dc62a641155a5",
"123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz",
),
(
"000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff",
"1cWB5HCBdLjAuqGGReWE3R3CguuwSjw6RHn39s2yuDRTS5NsBgNiFpWgAnEx6VQi8csexkgYw3mdYrMHr8x9i7aEwP8kZ7vccXWqKDvGv3u1GxFKPuAkn8JCPPGDMf3vMMnbzm6Nh9zh1gcNsMvH3ZNLmP5fSG6DGbbi2tuwMWPthr4boWwCxf7ewSgNQeacyozhKDDQQ1qL5fQFUW52QKUZDZ5fw3KXNQJMcNTcaB723LchjeKun7MuGW5qyCBZYzA1KjofN1gYBV3NqyhQJ3Ns746GNuf9N2pQPmHz4xpnSrrfCvy6TVVz5d4PdrjeshsWQwpZsZGzvbdAdN8MKV5QsBDY",
),
)
@pytest.mark.parametrize("data_hex,encoding_b58", BASE58_VECTORS)
def test_b58encode(data_hex, encoding_b58):
assert tools.b58encode(bytes.fromhex(data_hex)) == encoding_b58
@pytest.mark.parametrize("data_hex,encoding_b58", BASE58_VECTORS)
def test_b58decode(data_hex, encoding_b58):
assert tools.b58decode(encoding_b58).hex() == data_hex