1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-30 03:18:20 +00:00

chore(trezorlib): adjust benchmark feature for THP

[no changelog]
This commit is contained in:
M1nd3r 2024-10-14 18:37:29 +02:00
parent b4689a1191
commit 571898bc0f
3 changed files with 17 additions and 19 deletions

View File

@ -20,17 +20,17 @@ from . import messages
from .tools import expect from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType from .protobuf import MessageType
from .transport.session import Session
@expect(messages.BenchmarkNames) @expect(messages.BenchmarkNames)
def list_names( def list_names(
client: "TrezorClient", session: "Session",
) -> "MessageType": ) -> "MessageType":
return client.call(messages.BenchmarkListNames()) return session.call(messages.BenchmarkListNames())
@expect(messages.BenchmarkResult) @expect(messages.BenchmarkResult)
def run(client: "TrezorClient", name: str) -> "MessageType": def run(session: "Session", name: str) -> "MessageType":
return client.call(messages.BenchmarkRun(name=name)) return session.call(messages.BenchmarkRun(name=name))

View File

@ -20,17 +20,15 @@ from typing import TYPE_CHECKING, List, Optional
import click import click
from .. import benchmark from .. import benchmark
from . import with_client from . import with_session
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..transport.session import Session
def list_names_patern( def list_names_patern(session: "Session", pattern: Optional[str] = None) -> List[str]:
client: "TrezorClient", pattern: Optional[str] = None names = list(benchmark.list_names(session).names)
) -> List[str]:
names = list(benchmark.list_names(client).names)
if pattern is None: if pattern is None:
return names return names
return [name for name in names if fnmatch(name, pattern)] return [name for name in names if fnmatch(name, pattern)]
@ -43,10 +41,10 @@ def cli() -> None:
@cli.command() @cli.command()
@click.argument("pattern", required=False) @click.argument("pattern", required=False)
@with_client @with_session
def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None: def list_names(session: "Session", pattern: Optional[str] = None) -> None:
"""List names of all supported benchmarks""" """List names of all supported benchmarks"""
names = list_names_patern(client, pattern) names = list_names_patern(session, pattern)
if len(names) == 0: if len(names) == 0:
click.echo("No benchmark satisfies the pattern.") click.echo("No benchmark satisfies the pattern.")
else: else:
@ -56,13 +54,13 @@ def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None:
@cli.command() @cli.command()
@click.argument("pattern", required=False) @click.argument("pattern", required=False)
@with_client @with_session
def run(client: "TrezorClient", pattern: Optional[str]) -> None: def run(session: "Session", pattern: Optional[str]) -> None:
"""Run benchmark""" """Run benchmark"""
names = list_names_patern(client, pattern) names = list_names_patern(session, pattern)
if len(names) == 0: if len(names) == 0:
click.echo("No benchmark satisfies the pattern.") click.echo("No benchmark satisfies the pattern.")
else: else:
for name in names: for name in names:
result = benchmark.run(client, name) result = benchmark.run(session, name)
click.echo(f"{name}: {result.value} {result.unit}") click.echo(f"{name}: {result.value} {result.unit}")

View File

@ -32,8 +32,8 @@ from ..transport.session import Session
from ..transport.udp import UdpTransport from ..transport.udp import UdpTransport
from . import ( from . import (
AliasedGroup, AliasedGroup,
benchmark,
NewTrezorConnection, NewTrezorConnection,
benchmark,
binance, binance,
btc, btc,
cardano, cardano,