mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-07-06 06:42:33 +00:00
113 lines
3.5 KiB
Python
113 lines
3.5 KiB
Python
from typing import TYPE_CHECKING
|
|
|
|
from apps.monero.xmr import crypto_helpers
|
|
from apps.monero.xmr.keccak_hasher import KeccakXmrArchive
|
|
|
|
if TYPE_CHECKING:
|
|
from trezor.utils import HashContext
|
|
from .serialize_messages.tx_rsig_bulletproof import Bulletproof, BulletproofPlus
|
|
|
|
|
|
class PreMlsagHasher:
|
|
"""
|
|
Iterative construction of the pre_mlsag_hash
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
self.state = 0
|
|
self.kc_master: HashContext = crypto_helpers.get_keccak()
|
|
self.rsig_hasher: HashContext = crypto_helpers.get_keccak()
|
|
self.rtcsig_hasher: KeccakXmrArchive = KeccakXmrArchive()
|
|
|
|
def init(self) -> None:
|
|
if self.state != 0:
|
|
raise ValueError("State error")
|
|
|
|
self.state = 1
|
|
|
|
def set_message(self, message: bytes) -> None:
|
|
self.kc_master.update(message)
|
|
|
|
def set_type_fee(self, rv_type: int, fee: int) -> None:
|
|
if self.state != 1:
|
|
raise ValueError("State error")
|
|
self.state = 2
|
|
self.rtcsig_hasher.uint(rv_type, 1) # UInt8
|
|
self.rtcsig_hasher.uvarint(fee) # UVarintType
|
|
|
|
def set_ecdh(self, ecdh: bytes) -> None:
|
|
if self.state not in (2, 3, 4):
|
|
raise ValueError("State error")
|
|
self.state = 4
|
|
self.rtcsig_hasher.buffer(ecdh)
|
|
|
|
def set_out_pk_commitment(self, out_pk_commitment: bytes) -> None:
|
|
if self.state not in (4, 5):
|
|
raise ValueError("State error")
|
|
self.state = 5
|
|
self.rtcsig_hasher.buffer(out_pk_commitment) # ECKey
|
|
|
|
def rctsig_base_done(self) -> None:
|
|
if self.state != 5:
|
|
raise ValueError("State error")
|
|
self.state = 6
|
|
|
|
c_hash = self.rtcsig_hasher.get_digest()
|
|
self.kc_master.update(c_hash)
|
|
self.rtcsig_hasher = None # type: ignore
|
|
|
|
def rsig_val(
|
|
self, p: bytes | list[bytes] | Bulletproof | BulletproofPlus, raw: bool = False
|
|
) -> None:
|
|
if self.state == 8:
|
|
raise ValueError("State error")
|
|
|
|
if raw:
|
|
# Avoiding problem with the memory fragmentation.
|
|
# If the range proof is passed as a list, hash each element
|
|
# as the range proof is split to multiple byte arrays while
|
|
# preserving the byte ordering
|
|
if isinstance(p, list):
|
|
for x in p:
|
|
self.rsig_hasher.update(x)
|
|
else:
|
|
assert isinstance(p, bytes)
|
|
self.rsig_hasher.update(p)
|
|
return
|
|
|
|
from .serialize_messages.tx_rsig_bulletproof import Bulletproof, BulletproofPlus
|
|
|
|
is_plus = isinstance(p, BulletproofPlus)
|
|
assert isinstance(p, Bulletproof) or is_plus
|
|
|
|
# Hash Bulletproof
|
|
fields = (
|
|
(p.A, p.A1, p.B, p.r1, p.s1, p.d1)
|
|
if is_plus
|
|
else (p.A, p.S, p.T1, p.T2, p.taux, p.mu)
|
|
)
|
|
for fld in fields:
|
|
self.rsig_hasher.update(fld)
|
|
|
|
del (fields,)
|
|
for i in range(len(p.L)):
|
|
self.rsig_hasher.update(p.L[i])
|
|
for i in range(len(p.R)):
|
|
self.rsig_hasher.update(p.R[i])
|
|
|
|
if not is_plus:
|
|
self.rsig_hasher.update(p.a)
|
|
self.rsig_hasher.update(p.b)
|
|
self.rsig_hasher.update(p.t)
|
|
|
|
def get_digest(self) -> bytes:
|
|
if self.state != 6:
|
|
raise ValueError("State error")
|
|
self.state = 8
|
|
|
|
c_hash = self.rsig_hasher.digest()
|
|
self.rsig_hasher = None # type: ignore
|
|
|
|
self.kc_master.update(c_hash)
|
|
return self.kc_master.digest()
|