From c3f2db3be5351f997eb8fef32fb1c92bae5b7d2c Mon Sep 17 00:00:00 2001 From: matejcik Date: Tue, 2 Nov 2021 13:54:18 +0100 Subject: [PATCH] refactor(core): improve type signature of bech32.convertbits --- core/src/apps/binance/helpers.py | 5 ++- core/src/apps/cardano/helpers/bech32.py | 5 --- core/src/trezor/crypto/bech32.py | 45 ++++++++++++++++++------- core/src/trezor/crypto/cashaddr.py | 27 +-------------- 4 files changed, 35 insertions(+), 47 deletions(-) diff --git a/core/src/apps/binance/helpers.py b/core/src/apps/binance/helpers.py index d2467a0c2..742e675f0 100644 --- a/core/src/apps/binance/helpers.py +++ b/core/src/apps/binance/helpers.py @@ -91,7 +91,6 @@ def address_from_public_key(pubkey: bytes, hrp: str) -> str: h = sha256_ripemd160(pubkey).digest() - convertedbits = bech32.convertbits(h, 8, 5, False) - assert convertedbits is not None - + assert (len(h) * 8) % 5 == 0 # no padding will be added by convertbits + convertedbits = bech32.convertbits(h, 8, 5) return bech32.bech32_encode(hrp, convertedbits, bech32.Encoding.BECH32) diff --git a/core/src/apps/cardano/helpers/bech32.py b/core/src/apps/cardano/helpers/bech32.py index 7ab266841..5b70ad38f 100644 --- a/core/src/apps/cardano/helpers/bech32.py +++ b/core/src/apps/cardano/helpers/bech32.py @@ -16,8 +16,6 @@ HRP_SHARED_KEY_HASH = "addr_shared_vkh" def encode(hrp: str, data: bytes) -> str: converted_bits = bech32.convertbits(data, 8, 5) - if converted_bits is None: - raise ValueError return bech32.bech32_encode(hrp, converted_bits, bech32.Encoding.BECH32) @@ -40,7 +38,4 @@ def decode(hrp: str, bech: str) -> bytes: raise ValueError decoded = bech32.convertbits(data, 5, 8, False) - if decoded is None: - raise ValueError - return bytes(decoded) diff --git a/core/src/trezor/crypto/bech32.py b/core/src/trezor/crypto/bech32.py index 1d206258a..caaf40ff7 100644 --- a/core/src/trezor/crypto/bech32.py +++ b/core/src/trezor/crypto/bech32.py @@ -24,7 +24,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from enum import IntEnum - from typing import Iterable, Union, TypeVar + from typing import Sequence, Union, TypeVar A = TypeVar("A") B = TypeVar("B") @@ -112,9 +112,22 @@ def bech32_decode( def convertbits( - data: Iterable[int], frombits: int, tobits: int, pad: bool = True -) -> list[int] | None: - """General power-of-2 base conversion.""" + data: Sequence[int], frombits: int, tobits: int, arbitrary_input: bool = True +) -> list[int]: + """General power-of-2 base conversion. + + The `arbitrary_input` parameter specifies what happens when the total length + of input bits is not a multiple of `tobits`. + If True (default), the overflowing bits are zero-padded to the right. + If False, the input must must be a valid output of `convertbits()` in the opposite + direction. + Namely: + (a) the overflow must only be the zero padding + (b) length of the overflow is less than `frombits`, meaning that there is no + additional all-zero `frombits`-sized group at the end. + If both conditions hold, the all-zero overflow is discarded. + Otherwise a ValueError is raised. + """ acc = 0 bits = 0 ret = [] @@ -122,17 +135,22 @@ def convertbits( max_acc = (1 << (frombits + tobits - 1)) - 1 for value in data: if value < 0 or (value >> frombits): - return None + raise ValueError # input value does not match `frombits` size acc = ((acc << frombits) | value) & max_acc bits += frombits while bits >= tobits: bits -= tobits ret.append((acc >> bits) & maxv) - if pad: + + if arbitrary_input: if bits: + # append remaining bits, zero-padded from right ret.append((acc << (tobits - bits)) & maxv) elif bits >= frombits or ((acc << (tobits - bits)) & maxv): - return None + # (1) either there is a superfluous group at end of input, and/or + # (2) the remainder is nonzero + raise ValueError + return ret @@ -145,12 +163,15 @@ def decode(hrp: str, addr: str) -> OptionalTuple2[int, list[int]]: return (None, None) if hrpgot != hrp: return (None, None) - decoded = convertbits(data[1:], 5, 8, False) - if decoded is None or len(decoded) < 2 or len(decoded) > 40: + try: + decoded = convertbits(data[1:], 5, 8, False) + except ValueError: + return (None, None) + if not 2 <= len(decoded) <= 40: return (None, None) if data[0] > 16: return (None, None) - if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32: + if data[0] == 0 and len(decoded) not in (20, 32): return (None, None) if ( data[0] == 0 @@ -162,11 +183,9 @@ def decode(hrp: str, addr: str) -> OptionalTuple2[int, list[int]]: return (data[0], decoded) -def encode(hrp: str, witver: int, witprog: Iterable[int]) -> str | None: +def encode(hrp: str, witver: int, witprog: bytes) -> str | None: """Encode a segwit address.""" data = convertbits(witprog, 8, 5) - if data is None: - return None spec = Encoding.BECH32 if witver == 0 else Encoding.BECH32M ret = bech32_encode(hrp, [witver] + data, spec) if decode(hrp, ret) == (None, None): diff --git a/core/src/trezor/crypto/cashaddr.py b/core/src/trezor/crypto/cashaddr.py index 26b21f9b0..9ebc45949 100644 --- a/core/src/trezor/crypto/cashaddr.py +++ b/core/src/trezor/crypto/cashaddr.py @@ -20,8 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -if False: - from typing import Iterable +from .bech32 import convertbits CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l" ADDRESS_TYPE_P2KH = 0 @@ -75,30 +74,6 @@ def b32encode(inputs: list[int]) -> str: return out -def convertbits( - data: Iterable[int], frombits: int, tobits: int, pad: bool = True -) -> list[int]: - acc = 0 - bits = 0 - ret = [] - maxv = (1 << tobits) - 1 - max_acc = (1 << (frombits + tobits - 1)) - 1 - for value in data: - if value < 0 or (value >> frombits): - raise ValueError - acc = ((acc << frombits) | value) & max_acc - bits += frombits - while bits >= tobits: - bits -= tobits - ret.append((acc >> bits) & maxv) - if pad: - if bits: - ret.append((acc << (tobits - bits)) & maxv) - elif bits >= frombits or ((acc << (tobits - bits)) & maxv): - raise ValueError - return ret - - def encode(prefix: str, version: int, payload_bytes: bytes) -> str: payload_bytes = bytes([version]) + payload_bytes payload = convertbits(payload_bytes, 8, 5)