mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-05-07 01:19:04 +00:00
refactor(core): improve type signature of bech32.convertbits
This commit is contained in:
parent
9fc5bb546b
commit
c3f2db3be5
@ -91,7 +91,6 @@ def address_from_public_key(pubkey: bytes, hrp: str) -> str:
|
|||||||
|
|
||||||
h = sha256_ripemd160(pubkey).digest()
|
h = sha256_ripemd160(pubkey).digest()
|
||||||
|
|
||||||
convertedbits = bech32.convertbits(h, 8, 5, False)
|
assert (len(h) * 8) % 5 == 0 # no padding will be added by convertbits
|
||||||
assert convertedbits is not None
|
convertedbits = bech32.convertbits(h, 8, 5)
|
||||||
|
|
||||||
return bech32.bech32_encode(hrp, convertedbits, bech32.Encoding.BECH32)
|
return bech32.bech32_encode(hrp, convertedbits, bech32.Encoding.BECH32)
|
||||||
|
@ -16,8 +16,6 @@ HRP_SHARED_KEY_HASH = "addr_shared_vkh"
|
|||||||
|
|
||||||
def encode(hrp: str, data: bytes) -> str:
|
def encode(hrp: str, data: bytes) -> str:
|
||||||
converted_bits = bech32.convertbits(data, 8, 5)
|
converted_bits = bech32.convertbits(data, 8, 5)
|
||||||
if converted_bits is None:
|
|
||||||
raise ValueError
|
|
||||||
return bech32.bech32_encode(hrp, converted_bits, bech32.Encoding.BECH32)
|
return bech32.bech32_encode(hrp, converted_bits, bech32.Encoding.BECH32)
|
||||||
|
|
||||||
|
|
||||||
@ -40,7 +38,4 @@ def decode(hrp: str, bech: str) -> bytes:
|
|||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
decoded = bech32.convertbits(data, 5, 8, False)
|
decoded = bech32.convertbits(data, 5, 8, False)
|
||||||
if decoded is None:
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
return bytes(decoded)
|
return bytes(decoded)
|
||||||
|
@ -24,7 +24,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import Iterable, Union, TypeVar
|
from typing import Sequence, Union, TypeVar
|
||||||
|
|
||||||
A = TypeVar("A")
|
A = TypeVar("A")
|
||||||
B = TypeVar("B")
|
B = TypeVar("B")
|
||||||
@ -112,9 +112,22 @@ def bech32_decode(
|
|||||||
|
|
||||||
|
|
||||||
def convertbits(
|
def convertbits(
|
||||||
data: Iterable[int], frombits: int, tobits: int, pad: bool = True
|
data: Sequence[int], frombits: int, tobits: int, arbitrary_input: bool = True
|
||||||
) -> list[int] | None:
|
) -> list[int]:
|
||||||
"""General power-of-2 base conversion."""
|
"""General power-of-2 base conversion.
|
||||||
|
|
||||||
|
The `arbitrary_input` parameter specifies what happens when the total length
|
||||||
|
of input bits is not a multiple of `tobits`.
|
||||||
|
If True (default), the overflowing bits are zero-padded to the right.
|
||||||
|
If False, the input must must be a valid output of `convertbits()` in the opposite
|
||||||
|
direction.
|
||||||
|
Namely:
|
||||||
|
(a) the overflow must only be the zero padding
|
||||||
|
(b) length of the overflow is less than `frombits`, meaning that there is no
|
||||||
|
additional all-zero `frombits`-sized group at the end.
|
||||||
|
If both conditions hold, the all-zero overflow is discarded.
|
||||||
|
Otherwise a ValueError is raised.
|
||||||
|
"""
|
||||||
acc = 0
|
acc = 0
|
||||||
bits = 0
|
bits = 0
|
||||||
ret = []
|
ret = []
|
||||||
@ -122,17 +135,22 @@ def convertbits(
|
|||||||
max_acc = (1 << (frombits + tobits - 1)) - 1
|
max_acc = (1 << (frombits + tobits - 1)) - 1
|
||||||
for value in data:
|
for value in data:
|
||||||
if value < 0 or (value >> frombits):
|
if value < 0 or (value >> frombits):
|
||||||
return None
|
raise ValueError # input value does not match `frombits` size
|
||||||
acc = ((acc << frombits) | value) & max_acc
|
acc = ((acc << frombits) | value) & max_acc
|
||||||
bits += frombits
|
bits += frombits
|
||||||
while bits >= tobits:
|
while bits >= tobits:
|
||||||
bits -= tobits
|
bits -= tobits
|
||||||
ret.append((acc >> bits) & maxv)
|
ret.append((acc >> bits) & maxv)
|
||||||
if pad:
|
|
||||||
|
if arbitrary_input:
|
||||||
if bits:
|
if bits:
|
||||||
|
# append remaining bits, zero-padded from right
|
||||||
ret.append((acc << (tobits - bits)) & maxv)
|
ret.append((acc << (tobits - bits)) & maxv)
|
||||||
elif bits >= frombits or ((acc << (tobits - bits)) & maxv):
|
elif bits >= frombits or ((acc << (tobits - bits)) & maxv):
|
||||||
return None
|
# (1) either there is a superfluous group at end of input, and/or
|
||||||
|
# (2) the remainder is nonzero
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@ -145,12 +163,15 @@ def decode(hrp: str, addr: str) -> OptionalTuple2[int, list[int]]:
|
|||||||
return (None, None)
|
return (None, None)
|
||||||
if hrpgot != hrp:
|
if hrpgot != hrp:
|
||||||
return (None, None)
|
return (None, None)
|
||||||
decoded = convertbits(data[1:], 5, 8, False)
|
try:
|
||||||
if decoded is None or len(decoded) < 2 or len(decoded) > 40:
|
decoded = convertbits(data[1:], 5, 8, False)
|
||||||
|
except ValueError:
|
||||||
|
return (None, None)
|
||||||
|
if not 2 <= len(decoded) <= 40:
|
||||||
return (None, None)
|
return (None, None)
|
||||||
if data[0] > 16:
|
if data[0] > 16:
|
||||||
return (None, None)
|
return (None, None)
|
||||||
if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32:
|
if data[0] == 0 and len(decoded) not in (20, 32):
|
||||||
return (None, None)
|
return (None, None)
|
||||||
if (
|
if (
|
||||||
data[0] == 0
|
data[0] == 0
|
||||||
@ -162,11 +183,9 @@ def decode(hrp: str, addr: str) -> OptionalTuple2[int, list[int]]:
|
|||||||
return (data[0], decoded)
|
return (data[0], decoded)
|
||||||
|
|
||||||
|
|
||||||
def encode(hrp: str, witver: int, witprog: Iterable[int]) -> str | None:
|
def encode(hrp: str, witver: int, witprog: bytes) -> str | None:
|
||||||
"""Encode a segwit address."""
|
"""Encode a segwit address."""
|
||||||
data = convertbits(witprog, 8, 5)
|
data = convertbits(witprog, 8, 5)
|
||||||
if data is None:
|
|
||||||
return None
|
|
||||||
spec = Encoding.BECH32 if witver == 0 else Encoding.BECH32M
|
spec = Encoding.BECH32 if witver == 0 else Encoding.BECH32M
|
||||||
ret = bech32_encode(hrp, [witver] + data, spec)
|
ret = bech32_encode(hrp, [witver] + data, spec)
|
||||||
if decode(hrp, ret) == (None, None):
|
if decode(hrp, ret) == (None, None):
|
||||||
|
@ -20,8 +20,7 @@
|
|||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||||
# THE SOFTWARE.
|
# THE SOFTWARE.
|
||||||
|
|
||||||
if False:
|
from .bech32 import convertbits
|
||||||
from typing import Iterable
|
|
||||||
|
|
||||||
CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
|
CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
|
||||||
ADDRESS_TYPE_P2KH = 0
|
ADDRESS_TYPE_P2KH = 0
|
||||||
@ -75,30 +74,6 @@ def b32encode(inputs: list[int]) -> str:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def convertbits(
|
|
||||||
data: Iterable[int], frombits: int, tobits: int, pad: bool = True
|
|
||||||
) -> list[int]:
|
|
||||||
acc = 0
|
|
||||||
bits = 0
|
|
||||||
ret = []
|
|
||||||
maxv = (1 << tobits) - 1
|
|
||||||
max_acc = (1 << (frombits + tobits - 1)) - 1
|
|
||||||
for value in data:
|
|
||||||
if value < 0 or (value >> frombits):
|
|
||||||
raise ValueError
|
|
||||||
acc = ((acc << frombits) | value) & max_acc
|
|
||||||
bits += frombits
|
|
||||||
while bits >= tobits:
|
|
||||||
bits -= tobits
|
|
||||||
ret.append((acc >> bits) & maxv)
|
|
||||||
if pad:
|
|
||||||
if bits:
|
|
||||||
ret.append((acc << (tobits - bits)) & maxv)
|
|
||||||
elif bits >= frombits or ((acc << (tobits - bits)) & maxv):
|
|
||||||
raise ValueError
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def encode(prefix: str, version: int, payload_bytes: bytes) -> str:
|
def encode(prefix: str, version: int, payload_bytes: bytes) -> str:
|
||||||
payload_bytes = bytes([version]) + payload_bytes
|
payload_bytes = bytes([version]) + payload_bytes
|
||||||
payload = convertbits(payload_bytes, 8, 5)
|
payload = convertbits(payload_bytes, 8, 5)
|
||||||
|
Loading…
Reference in New Issue
Block a user