diff --git a/src/tests/test_crypto_base58.py b/src/tests/test_crypto_base58.py index 41bcdbe5d3..d2e3f3a3d5 100644 --- a/src/tests/test_crypto_base58.py +++ b/src/tests/test_crypto_base58.py @@ -5,6 +5,9 @@ import unittest from ubinascii import unhexlify from trezor.crypto import base58 +from trezor.crypto.hashlib import ripemd160 + +digestfunc_graphene = lambda x: ripemd160(x).digest()[:4] class TestCryptoBase58(unittest.TestCase): @@ -62,13 +65,24 @@ class TestCryptoBase58(unittest.TestCase): ('055ece0cadddc415b1980f001785947120acdb36fc', '3ALJH9Y951VCGcVZYAdpA3KchoP9McEj1G'), ] + vectors_graphene = [ + ('02e649f63f8e8121345fd7f47d0d185a3ccaa843115cd2e9392dcd9b82263bc680', '6dumtt9swxCqwdPZBGXh9YmHoEjFFnNfwHaTqRbQTghGAY2gRz'), + ('021c7359cd885c0e319924d97e3980206ad64387aff54908241125b3a88b55ca16', '5725vivYpuFWbeyTifZ5KevnHyqXCi5hwHbNU9cYz1FHbFXCxX'), + ('02f561e0b57a552df3fa1df2d87a906b7a9fc33a83d5d15fa68a644ecb0806b49a', '6kZKHSuxqAwdCYsMvwTcipoTsNE2jmEUNBQufGYywpniBKXWZK'), + ('03e7595c3e6b58f907bee951dc29796f3757307e700ecf3d09307a0cc4a564eba3', '8b82mpnH8YX1E9RHnU2a2YgLTZ8ooevEGP9N15c1yFqhoBvJur'), + ] + def test_decode_check(self): for a, b in self.vectors: self.assertEqual(base58.decode_check(b), unhexlify(a)) + for a, b in self.vectors_graphene: + self.assertEqual(base58.decode_check(b, digestfunc=digestfunc_graphene), unhexlify(a)) def test_encode_check(self): for a, b in self.vectors: self.assertEqual(base58.encode_check(unhexlify(a)), b) + for a, b in self.vectors_graphene: + self.assertEqual(base58.encode_check(unhexlify(a), digestfunc=digestfunc_graphene), b) if __name__ == '__main__': unittest.main() diff --git a/src/trezor/crypto/base58.py b/src/trezor/crypto/base58.py index 9b44325bd3..7e36e9f1cc 100644 --- a/src/trezor/crypto/base58.py +++ b/src/trezor/crypto/base58.py @@ -13,8 +13,6 @@ # This module adds shiny packaging and support for python3. # -from .hashlib import sha256 - # 58 character alphabet used _alphabet = '123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz' @@ -36,7 +34,7 @@ def encode(data: bytes) -> str: acc, mod = divmod(acc, 58) result += _alphabet[mod] - return ''.join([c for c in reversed(result + _alphabet[0] * (origlen - newlen))]) + return ''.join((c for c in reversed(result + _alphabet[0] * (origlen - newlen)))) def decode(string: str) -> bytes: @@ -57,25 +55,31 @@ def decode(string: str) -> bytes: acc, mod = divmod(acc, 256) result.append(mod) - return bytes([b for b in reversed(result +[0] * (origlen - newlen))]) + return bytes((b for b in reversed(result +[0] * (origlen - newlen)))) -def encode_check(data: bytes) -> str: +def encode_check(data: bytes, digestfunc=None) -> str: ''' Convert bytes to base58 encoded string, append checksum. ''' - digest = sha256(sha256(data).digest()).digest() - return encode(data + digest[:4]) + if digestfunc is None: + from .hashlib import sha256 + digestfunc = lambda x: sha256(sha256(x).digest()).digest()[:4] + return encode(data + digestfunc(data)) -def decode_check(string: str) -> bytes: +def decode_check(string: str, digestfunc=None) -> bytes: ''' Convert base58 encoded string to bytes and verify checksum. ''' result = decode(string) - result, check = result[:-4], result[-4:] - digest = sha256(sha256(result).digest()).digest() - if check != digest[:4]: + if digestfunc is None: + from .hashlib import sha256 + digestfunc = lambda x: sha256(sha256(x).digest()).digest()[:4] + digestlen = len(digestfunc(b'')) + result, check = result[:-digestlen], result[-digestlen:] + + if check != digestfunc(result): raise ValueError('Invalid checksum') return result