mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-23 14:58:09 +00:00
refactor(core): improve type signature of bech32.convertbits
This commit is contained in:
parent
9fc5bb546b
commit
c3f2db3be5
@ -91,7 +91,6 @@ def address_from_public_key(pubkey: bytes, hrp: str) -> str:
|
||||
|
||||
h = sha256_ripemd160(pubkey).digest()
|
||||
|
||||
convertedbits = bech32.convertbits(h, 8, 5, False)
|
||||
assert convertedbits is not None
|
||||
|
||||
assert (len(h) * 8) % 5 == 0 # no padding will be added by convertbits
|
||||
convertedbits = bech32.convertbits(h, 8, 5)
|
||||
return bech32.bech32_encode(hrp, convertedbits, bech32.Encoding.BECH32)
|
||||
|
@ -16,8 +16,6 @@ HRP_SHARED_KEY_HASH = "addr_shared_vkh"
|
||||
|
||||
def encode(hrp: str, data: bytes) -> str:
|
||||
converted_bits = bech32.convertbits(data, 8, 5)
|
||||
if converted_bits is None:
|
||||
raise ValueError
|
||||
return bech32.bech32_encode(hrp, converted_bits, bech32.Encoding.BECH32)
|
||||
|
||||
|
||||
@ -40,7 +38,4 @@ def decode(hrp: str, bech: str) -> bytes:
|
||||
raise ValueError
|
||||
|
||||
decoded = bech32.convertbits(data, 5, 8, False)
|
||||
if decoded is None:
|
||||
raise ValueError
|
||||
|
||||
return bytes(decoded)
|
||||
|
@ -24,7 +24,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enum import IntEnum
|
||||
from typing import Iterable, Union, TypeVar
|
||||
from typing import Sequence, Union, TypeVar
|
||||
|
||||
A = TypeVar("A")
|
||||
B = TypeVar("B")
|
||||
@ -112,9 +112,22 @@ def bech32_decode(
|
||||
|
||||
|
||||
def convertbits(
|
||||
data: Iterable[int], frombits: int, tobits: int, pad: bool = True
|
||||
) -> list[int] | None:
|
||||
"""General power-of-2 base conversion."""
|
||||
data: Sequence[int], frombits: int, tobits: int, arbitrary_input: bool = True
|
||||
) -> list[int]:
|
||||
"""General power-of-2 base conversion.
|
||||
|
||||
The `arbitrary_input` parameter specifies what happens when the total length
|
||||
of input bits is not a multiple of `tobits`.
|
||||
If True (default), the overflowing bits are zero-padded to the right.
|
||||
If False, the input must must be a valid output of `convertbits()` in the opposite
|
||||
direction.
|
||||
Namely:
|
||||
(a) the overflow must only be the zero padding
|
||||
(b) length of the overflow is less than `frombits`, meaning that there is no
|
||||
additional all-zero `frombits`-sized group at the end.
|
||||
If both conditions hold, the all-zero overflow is discarded.
|
||||
Otherwise a ValueError is raised.
|
||||
"""
|
||||
acc = 0
|
||||
bits = 0
|
||||
ret = []
|
||||
@ -122,17 +135,22 @@ def convertbits(
|
||||
max_acc = (1 << (frombits + tobits - 1)) - 1
|
||||
for value in data:
|
||||
if value < 0 or (value >> frombits):
|
||||
return None
|
||||
raise ValueError # input value does not match `frombits` size
|
||||
acc = ((acc << frombits) | value) & max_acc
|
||||
bits += frombits
|
||||
while bits >= tobits:
|
||||
bits -= tobits
|
||||
ret.append((acc >> bits) & maxv)
|
||||
if pad:
|
||||
|
||||
if arbitrary_input:
|
||||
if bits:
|
||||
# append remaining bits, zero-padded from right
|
||||
ret.append((acc << (tobits - bits)) & maxv)
|
||||
elif bits >= frombits or ((acc << (tobits - bits)) & maxv):
|
||||
return None
|
||||
# (1) either there is a superfluous group at end of input, and/or
|
||||
# (2) the remainder is nonzero
|
||||
raise ValueError
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
@ -145,12 +163,15 @@ def decode(hrp: str, addr: str) -> OptionalTuple2[int, list[int]]:
|
||||
return (None, None)
|
||||
if hrpgot != hrp:
|
||||
return (None, None)
|
||||
decoded = convertbits(data[1:], 5, 8, False)
|
||||
if decoded is None or len(decoded) < 2 or len(decoded) > 40:
|
||||
try:
|
||||
decoded = convertbits(data[1:], 5, 8, False)
|
||||
except ValueError:
|
||||
return (None, None)
|
||||
if not 2 <= len(decoded) <= 40:
|
||||
return (None, None)
|
||||
if data[0] > 16:
|
||||
return (None, None)
|
||||
if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32:
|
||||
if data[0] == 0 and len(decoded) not in (20, 32):
|
||||
return (None, None)
|
||||
if (
|
||||
data[0] == 0
|
||||
@ -162,11 +183,9 @@ def decode(hrp: str, addr: str) -> OptionalTuple2[int, list[int]]:
|
||||
return (data[0], decoded)
|
||||
|
||||
|
||||
def encode(hrp: str, witver: int, witprog: Iterable[int]) -> str | None:
|
||||
def encode(hrp: str, witver: int, witprog: bytes) -> str | None:
|
||||
"""Encode a segwit address."""
|
||||
data = convertbits(witprog, 8, 5)
|
||||
if data is None:
|
||||
return None
|
||||
spec = Encoding.BECH32 if witver == 0 else Encoding.BECH32M
|
||||
ret = bech32_encode(hrp, [witver] + data, spec)
|
||||
if decode(hrp, ret) == (None, None):
|
||||
|
@ -20,8 +20,7 @@
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
# THE SOFTWARE.
|
||||
|
||||
if False:
|
||||
from typing import Iterable
|
||||
from .bech32 import convertbits
|
||||
|
||||
CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
|
||||
ADDRESS_TYPE_P2KH = 0
|
||||
@ -75,30 +74,6 @@ def b32encode(inputs: list[int]) -> str:
|
||||
return out
|
||||
|
||||
|
||||
def convertbits(
|
||||
data: Iterable[int], frombits: int, tobits: int, pad: bool = True
|
||||
) -> list[int]:
|
||||
acc = 0
|
||||
bits = 0
|
||||
ret = []
|
||||
maxv = (1 << tobits) - 1
|
||||
max_acc = (1 << (frombits + tobits - 1)) - 1
|
||||
for value in data:
|
||||
if value < 0 or (value >> frombits):
|
||||
raise ValueError
|
||||
acc = ((acc << frombits) | value) & max_acc
|
||||
bits += frombits
|
||||
while bits >= tobits:
|
||||
bits -= tobits
|
||||
ret.append((acc >> bits) & maxv)
|
||||
if pad:
|
||||
if bits:
|
||||
ret.append((acc << (tobits - bits)) & maxv)
|
||||
elif bits >= frombits or ((acc << (tobits - bits)) & maxv):
|
||||
raise ValueError
|
||||
return ret
|
||||
|
||||
|
||||
def encode(prefix: str, version: int, payload_bytes: bytes) -> str:
|
||||
payload_bytes = bytes([version]) + payload_bytes
|
||||
payload = convertbits(payload_bytes, 8, 5)
|
||||
|
Loading…
Reference in New Issue
Block a user