1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-23 14:58:09 +00:00

refactor(core): improve type signature of bech32.convertbits

This commit is contained in:
matejcik 2021-11-02 13:54:18 +01:00 committed by matejcik
parent 9fc5bb546b
commit c3f2db3be5
4 changed files with 35 additions and 47 deletions

View File

@ -91,7 +91,6 @@ def address_from_public_key(pubkey: bytes, hrp: str) -> str:
h = sha256_ripemd160(pubkey).digest()
convertedbits = bech32.convertbits(h, 8, 5, False)
assert convertedbits is not None
assert (len(h) * 8) % 5 == 0 # no padding will be added by convertbits
convertedbits = bech32.convertbits(h, 8, 5)
return bech32.bech32_encode(hrp, convertedbits, bech32.Encoding.BECH32)

View File

@ -16,8 +16,6 @@ HRP_SHARED_KEY_HASH = "addr_shared_vkh"
def encode(hrp: str, data: bytes) -> str:
converted_bits = bech32.convertbits(data, 8, 5)
if converted_bits is None:
raise ValueError
return bech32.bech32_encode(hrp, converted_bits, bech32.Encoding.BECH32)
@ -40,7 +38,4 @@ def decode(hrp: str, bech: str) -> bytes:
raise ValueError
decoded = bech32.convertbits(data, 5, 8, False)
if decoded is None:
raise ValueError
return bytes(decoded)

View File

@ -24,7 +24,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from enum import IntEnum
from typing import Iterable, Union, TypeVar
from typing import Sequence, Union, TypeVar
A = TypeVar("A")
B = TypeVar("B")
@ -112,9 +112,22 @@ def bech32_decode(
def convertbits(
data: Iterable[int], frombits: int, tobits: int, pad: bool = True
) -> list[int] | None:
"""General power-of-2 base conversion."""
data: Sequence[int], frombits: int, tobits: int, arbitrary_input: bool = True
) -> list[int]:
"""General power-of-2 base conversion.
The `arbitrary_input` parameter specifies what happens when the total length
of input bits is not a multiple of `tobits`.
If True (default), the overflowing bits are zero-padded to the right.
If False, the input must must be a valid output of `convertbits()` in the opposite
direction.
Namely:
(a) the overflow must only be the zero padding
(b) length of the overflow is less than `frombits`, meaning that there is no
additional all-zero `frombits`-sized group at the end.
If both conditions hold, the all-zero overflow is discarded.
Otherwise a ValueError is raised.
"""
acc = 0
bits = 0
ret = []
@ -122,17 +135,22 @@ def convertbits(
max_acc = (1 << (frombits + tobits - 1)) - 1
for value in data:
if value < 0 or (value >> frombits):
return None
raise ValueError # input value does not match `frombits` size
acc = ((acc << frombits) | value) & max_acc
bits += frombits
while bits >= tobits:
bits -= tobits
ret.append((acc >> bits) & maxv)
if pad:
if arbitrary_input:
if bits:
# append remaining bits, zero-padded from right
ret.append((acc << (tobits - bits)) & maxv)
elif bits >= frombits or ((acc << (tobits - bits)) & maxv):
return None
# (1) either there is a superfluous group at end of input, and/or
# (2) the remainder is nonzero
raise ValueError
return ret
@ -145,12 +163,15 @@ def decode(hrp: str, addr: str) -> OptionalTuple2[int, list[int]]:
return (None, None)
if hrpgot != hrp:
return (None, None)
decoded = convertbits(data[1:], 5, 8, False)
if decoded is None or len(decoded) < 2 or len(decoded) > 40:
try:
decoded = convertbits(data[1:], 5, 8, False)
except ValueError:
return (None, None)
if not 2 <= len(decoded) <= 40:
return (None, None)
if data[0] > 16:
return (None, None)
if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32:
if data[0] == 0 and len(decoded) not in (20, 32):
return (None, None)
if (
data[0] == 0
@ -162,11 +183,9 @@ def decode(hrp: str, addr: str) -> OptionalTuple2[int, list[int]]:
return (data[0], decoded)
def encode(hrp: str, witver: int, witprog: Iterable[int]) -> str | None:
def encode(hrp: str, witver: int, witprog: bytes) -> str | None:
"""Encode a segwit address."""
data = convertbits(witprog, 8, 5)
if data is None:
return None
spec = Encoding.BECH32 if witver == 0 else Encoding.BECH32M
ret = bech32_encode(hrp, [witver] + data, spec)
if decode(hrp, ret) == (None, None):

View File

@ -20,8 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
if False:
from typing import Iterable
from .bech32 import convertbits
CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
ADDRESS_TYPE_P2KH = 0
@ -75,30 +74,6 @@ def b32encode(inputs: list[int]) -> str:
return out
def convertbits(
data: Iterable[int], frombits: int, tobits: int, pad: bool = True
) -> list[int]:
acc = 0
bits = 0
ret = []
maxv = (1 << tobits) - 1
max_acc = (1 << (frombits + tobits - 1)) - 1
for value in data:
if value < 0 or (value >> frombits):
raise ValueError
acc = ((acc << frombits) | value) & max_acc
bits += frombits
while bits >= tobits:
bits -= tobits
ret.append((acc >> bits) & maxv)
if pad:
if bits:
ret.append((acc << (tobits - bits)) & maxv)
elif bits >= frombits or ((acc << (tobits - bits)) & maxv):
raise ValueError
return ret
def encode(prefix: str, version: int, payload_bytes: bytes) -> str:
payload_bytes = bytes([version]) + payload_bytes
payload = convertbits(payload_bytes, 8, 5)