1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-03 03:50:58 +00:00

feat(core): implement benchmark application

This commit is contained in:
Ondřej Vejpustek 2024-08-08 14:22:27 +02:00
parent b436b39091
commit 52d85d1f39
12 changed files with 424 additions and 0 deletions

View File

@ -0,0 +1 @@
Added benchmark application.

View File

@ -22,6 +22,7 @@ FEATURE_FLAGS = {
"RDI": True,
"SECP256K1_ZKP": True, # required for trezor.crypto.curve.bip340 (BIP340/Taproot)
"AES_GCM": False,
"AES_GCM": BENCHMARK,
}
FEATURES_WANTED = ["input", "sbu", "sd_card", "rgb_led", "dma2d", "consumption_mask", "usb" ,"optiga", "haptic"]

View File

View File

@ -0,0 +1,29 @@
import utime
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Protocol
from trezor.messages import BenchmarkResult
class Benchmark(Protocol):
def prepare(self) -> None: ...
def run(self) -> None: ...
def get_result(self, duration_us: int, repetitions: int) -> BenchmarkResult: ...
def run_benchmark(benchmark: Benchmark) -> BenchmarkResult:
minimum_duration_s = 1
minimum_duration_us = minimum_duration_s * 1000000
benchmark.prepare()
start_time_us = utime.ticks_us()
repetitions = 0
while True:
benchmark.run()
repetitions += 1
duration_us = utime.ticks_diff(utime.ticks_us(), start_time_us)
if duration_us > minimum_duration_us:
break
return benchmark.get_result(duration_us, repetitions)

View File

@ -0,0 +1,95 @@
from trezor.crypto import aes, aesgcm, chacha20poly1305
from trezor.crypto.curve import curve25519, ed25519, nist256p1, secp256k1
from trezor.crypto.hashlib import (
blake2b,
blake2s,
blake256,
groestl512,
ripemd160,
sha1,
sha3_256,
sha3_512,
sha256,
sha512,
)
from .cipher_benchmark import DecryptBenchmark, EncryptBenchmark
from .common import random_bytes
from .curve_benchmark import (
MultiplyBenchmark,
PublickeyBenchmark,
SignBenchmark,
VerifyBenchmark,
)
from .hash_benchmark import HashBenchmark
# This is a wrapper above the trezor.crypto.curve.ed25519 module that satisfies SignCurve protocol, the modules uses `message` instead of `digest` in `sign()` and `verify()`
class Ed25519:
def __init__(self):
pass
def generate_secret(self) -> bytes:
return ed25519.generate_secret()
def publickey(self, secret_key: bytes) -> bytes:
return ed25519.publickey(secret_key)
def sign(self, secret_key: bytes, digest: bytes) -> bytes:
# ed25519.sign(secret_key: bytes, message: bytes, hasher: str = "") -> bytes:
return ed25519.sign(secret_key, digest)
def verify(self, public_key: bytes, signature: bytes, digest: bytes) -> bool:
# ed25519.verify(public_key: bytes, signature: bytes, message: bytes) -> bool:
return ed25519.verify(public_key, signature, digest)
benchmarks = {
"crypto/hash/blake2b": HashBenchmark(lambda: blake2b()),
"crypto/hash/blake2s": HashBenchmark(lambda: blake2s()),
"crypto/hash/blake256": HashBenchmark(lambda: blake256()),
"crypto/hash/groestl512": HashBenchmark(lambda: groestl512()),
"crypto/hash/ripemd160": HashBenchmark(lambda: ripemd160()),
"crypto/hash/sha1": HashBenchmark(lambda: sha1()),
"crypto/hash/sha3_256": HashBenchmark(lambda: sha3_256()),
"crypto/hash/sha3_512": HashBenchmark(lambda: sha3_512()),
"crypto/hash/sha256": HashBenchmark(lambda: sha256()),
"crypto/hash/sha512": HashBenchmark(lambda: sha512()),
"crypto/cipher/aes128-ecb/encrypt": EncryptBenchmark(
lambda: aes(aes.ECB, random_bytes(16), random_bytes(16)), 16
),
"crypto/cipher/aes128-ecb/decrypt": DecryptBenchmark(
lambda: aes(aes.ECB, random_bytes(16), random_bytes(16)), 16
),
"crypto/cipher/aesgcm128/encrypt": EncryptBenchmark(
lambda: aesgcm(random_bytes(16), random_bytes(16)), 16
),
"crypto/cipher/aesgcm128/decrypt": DecryptBenchmark(
lambda: aesgcm(random_bytes(16), random_bytes(16)), 16
),
"crypto/cipher/aesgcm256/encrypt": EncryptBenchmark(
lambda: aesgcm(random_bytes(32), random_bytes(16)), 16
),
"crypto/cipher/aesgcm256/decrypt": DecryptBenchmark(
lambda: aesgcm(random_bytes(32), random_bytes(16)), 16
),
"crypto/cipher/chacha20poly1305/encrypt": EncryptBenchmark(
lambda: chacha20poly1305(random_bytes(32), random_bytes(12)), 64
),
"crypto/cipher/chacha20poly1305/decrypt": DecryptBenchmark(
lambda: chacha20poly1305(random_bytes(32), random_bytes(12)), 64
),
"crypto/curve/secp256k1/sign": SignBenchmark(secp256k1),
"crypto/curve/secp256k1/verify": VerifyBenchmark(secp256k1),
"crypto/curve/secp256k1/publickey": PublickeyBenchmark(secp256k1),
"crypto/curve/secp256k1/multiply": MultiplyBenchmark(secp256k1),
"crypto/curve/nist256p1/sign": SignBenchmark(nist256p1),
"crypto/curve/nist256p1/verify": VerifyBenchmark(nist256p1),
"crypto/curve/nist256p1/publickey": PublickeyBenchmark(nist256p1),
"crypto/curve/nist256p1/multiply": MultiplyBenchmark(nist256p1),
"crypto/curve/ed25519/sign": SignBenchmark(Ed25519()),
"crypto/curve/ed25519/verify": VerifyBenchmark(Ed25519()),
"crypto/curve/ed25519/publickey": PublickeyBenchmark(ed25519),
"crypto/curve/curve25519/publickey": PublickeyBenchmark(curve25519),
"crypto/curve/curve25519/multiply": MultiplyBenchmark(curve25519),
}

View File

@ -0,0 +1,69 @@
from typing import TYPE_CHECKING, Callable
from trezor.messages import BenchmarkResult
from .common import format_float, maximum_used_memory_in_bytes, random_bytes
if TYPE_CHECKING:
from typing import Protocol
class CipherCtx(Protocol):
def encrypt(self, data: bytes) -> bytes: ...
def decrypt(self, data: bytes) -> bytes: ...
class EncryptBenchmark:
def __init__(
self, cipher_ctx_constructor: Callable[[], CipherCtx], block_size: int
):
self.cipher_ctx_constructor = cipher_ctx_constructor
self.block_size = block_size
def prepare(self):
self.cipher_ctx = self.cipher_ctx_constructor()
self.blocks_count = maximum_used_memory_in_bytes // self.block_size
self.iterations_count = 100
self.data = random_bytes(self.blocks_count * self.block_size)
def run(self):
for _ in range(self.iterations_count):
self.cipher_ctx.encrypt(self.data)
def get_result(self, duration_us: int, repetitions: int) -> BenchmarkResult:
value = (repetitions * self.iterations_count * len(self.data) * 1000 * 1000) / (
duration_us * 1024 * 1024
)
return BenchmarkResult(
value=format_float(value),
unit="MB/s",
)
class DecryptBenchmark:
def __init__(
self, cipher_ctx_constructor: Callable[[], CipherCtx], block_size: int
):
self.cipher_ctx_constructor = cipher_ctx_constructor
self.block_size = block_size
def prepare(self):
self.cipher_ctx = self.cipher_ctx_constructor()
self.blocks_count = maximum_used_memory_in_bytes // self.block_size
self.iterations_count = 100
self.data = random_bytes(self.blocks_count * self.block_size)
def run(self):
for _ in range(self.iterations_count):
self.cipher_ctx.decrypt(self.data)
def get_result(self, duration_us: int, repetitions: int) -> BenchmarkResult:
value = (repetitions * self.iterations_count * len(self.data) * 1000 * 1000) / (
duration_us * 1024 * 1024
)
return BenchmarkResult(
value=format_float(value),
unit="MB/s",
)

View File

@ -0,0 +1,43 @@
from trezor.crypto import random
maximum_used_memory_in_bytes = 10 * 1024
# Round a float to 2 significant digits and return it as a string, do not use scientific notation
def format_float(value: float) -> str:
def get_magnitude(value: float) -> int:
if value == 0:
return 0
if value < 0:
value = -value
magnitude = 0
if value < 1:
while value < 1:
value = 10 * value
magnitude -= 1
else:
while value >= 10:
value = value / 10
magnitude += 1
return magnitude
significant_digits = 2
precision_digits = significant_digits - get_magnitude(value) - 1
rounded_value = round(value, precision_digits)
return f"{rounded_value:.{max(0, precision_digits)}f}"
def random_bytes(length: int) -> bytes:
# Fast linear congruential generator from Numerical Recipes
def lcg(seed: int) -> int:
return (1664525 * seed + 1013904223) & 0xFFFFFFFF
array = bytearray(length)
seed = random.uniform(0xFFFFFFFF)
for i in range(length):
seed = lcg(seed)
array[i] = seed & 0xFF
return bytes(array)

View File

@ -0,0 +1,108 @@
from typing import TYPE_CHECKING
from trezor.messages import BenchmarkResult
from .common import format_float, random_bytes
if TYPE_CHECKING:
from typing import Protocol
class Curve(Protocol):
def generate_secret(self) -> bytes: ...
def publickey(self, secret_key: bytes) -> bytes: ...
class SignCurve(Curve, Protocol):
def sign(self, secret_key: bytes, digest: bytes) -> bytes: ...
def verify(
self, public_key: bytes, signature: bytes, digest: bytes
) -> bool: ...
def generate_secret(self) -> bytes: ...
def publickey(self, secret_key: bytes) -> bytes: ...
class MultiplyCurve(Curve, Protocol):
def generate_secret(self) -> bytes: ...
def publickey(self, secret_key: bytes) -> bytes: ...
def multiply(self, secret_key: bytes, public_key: bytes) -> bytes: ...
class SignBenchmark:
def __init__(self, curve: SignCurve):
self.curve = curve
def prepare(self):
self.iterations_count = 10
self.secret_key = self.curve.generate_secret()
self.digest = random_bytes(32)
def run(self):
for _ in range(self.iterations_count):
self.curve.sign(self.secret_key, self.digest)
def get_result(self, duration_us: int, repetitions: int) -> BenchmarkResult:
value = duration_us / (repetitions * self.iterations_count * 1000)
return BenchmarkResult(value=format_float(value), unit="ms")
class VerifyBenchmark:
def __init__(self, curve: SignCurve):
self.curve = curve
def prepare(self):
self.iterations_count = 10
self.secret_key = self.curve.generate_secret()
self.public_key = self.curve.publickey(self.secret_key)
self.digest = random_bytes(32)
self.signature = self.curve.sign(self.secret_key, self.digest)
def run(self):
for _ in range(self.iterations_count):
self.curve.verify(self.public_key, self.signature, self.digest)
def get_result(self, duration_us: int, repetitions: int) -> BenchmarkResult:
value = duration_us / (repetitions * self.iterations_count * 1000)
return BenchmarkResult(value=format_float(value), unit="ms")
class MultiplyBenchmark:
def __init__(self, curve: MultiplyCurve):
self.curve = curve
def prepare(self):
self.secret_key = self.curve.generate_secret()
self.public_key = self.curve.publickey(self.curve.generate_secret())
self.iterations_count = 10
def run(self):
for _ in range(self.iterations_count):
self.curve.multiply(self.secret_key, self.public_key)
def get_result(self, duration_us: int, repetitions: int) -> BenchmarkResult:
value = duration_us / (repetitions * self.iterations_count * 1000)
return BenchmarkResult(value=format_float(value), unit="ms")
class PublickeyBenchmark:
def __init__(self, curve: Curve):
self.curve = curve
def prepare(self):
self.iterations_count = 10
self.secret_key = self.curve.generate_secret()
def run(self):
for _ in range(self.iterations_count):
self.curve.publickey(self.secret_key)
def get_result(self, duration_us: int, repetitions: int) -> BenchmarkResult:
value = duration_us / (repetitions * self.iterations_count * 1000)
return BenchmarkResult(value=format_float(value), unit="ms")

View File

@ -0,0 +1,38 @@
from typing import TYPE_CHECKING, Callable
from trezor.messages import BenchmarkResult
from .common import format_float, maximum_used_memory_in_bytes, random_bytes
if TYPE_CHECKING:
from typing import Protocol
class HashCtx(Protocol):
block_size: int
def update(self, __buf: bytes) -> None: ...
class HashBenchmark:
def __init__(self, hash_ctx_constructor: Callable[[], HashCtx]):
self.hash_ctx_constructor = hash_ctx_constructor
def prepare(self):
self.hash_ctx = self.hash_ctx_constructor()
self.blocks_count = maximum_used_memory_in_bytes // self.hash_ctx.block_size
self.iterations_count = 100
self.data = random_bytes(self.blocks_count * self.hash_ctx.block_size)
def run(self):
for _ in range(self.iterations_count):
self.hash_ctx.update(self.data)
def get_result(self, duration_us: int, repetitions: int) -> BenchmarkResult:
value = (repetitions * self.iterations_count * len(self.data) * 1000 * 1000) / (
duration_us * 1024 * 1024
)
return BenchmarkResult(
value=format_float(value),
unit="MB/s",
)

View File

@ -0,0 +1,15 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from trezor.messages import BenchmarkListNames, BenchmarkNames
from .benchmarks import benchmarks
async def list_names(msg: BenchmarkListNames) -> "BenchmarkNames":
from trezor.messages import BenchmarkNames
names = list(benchmarks.keys())
sorted_names = sorted(names)
return BenchmarkNames(names=sorted_names)

View File

@ -0,0 +1,19 @@
from typing import TYPE_CHECKING
from .benchmark import run_benchmark
from .benchmarks import benchmarks
if TYPE_CHECKING:
from trezor.messages import BenchmarkResult, BenchmarkRun
async def run(msg: BenchmarkRun) -> BenchmarkResult:
benchmark_name = msg.name
if benchmark_name not in benchmarks:
raise ValueError("Benchmark not found")
benchmark = benchmarks[benchmark_name]
result = run_benchmark(benchmark)
return result

View File

@ -206,6 +206,12 @@ def _find_message_handler_module(msg_type: int) -> str:
if msg_type == MessageType.SolanaSignTx:
return "apps.solana.sign_tx"
# benchmark
if msg_type == MessageType.BenchmarkListNames:
return "apps.benchmark.list_names"
if msg_type == MessageType.BenchmarkRun:
return "apps.benchmark.run"
raise ValueError