1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-08 13:42:41 +00:00

feat(python): allow trezorctl firmware commands to work with unknown models

[no changelog]
This commit is contained in:
matejcik 2025-01-20 11:34:12 +01:00
parent 9862bc7123
commit 138a6410fc
4 changed files with 34 additions and 5 deletions

View File

@ -92,7 +92,13 @@ def _print_firmware_model(hw_model: Union[bytes, fw_models.Model]) -> None:
click.echo(f"{model_name} firmware image.") click.echo(f"{model_name} firmware image.")
return return
except ValueError: except ValueError:
pass assert isinstance(hw_model, bytes)
if hw_model.isascii():
model_name = hw_model.decode("ascii")
click.echo(f"Unrecognized hardware model: {model_name}")
return
else:
click.echo(f"Invalid model field: {hw_model.hex()}")
assert isinstance(hw_model, bytes) assert isinstance(hw_model, bytes)
if all(0x20 <= b < 0x80 for b in hw_model): # isascii if all(0x20 <= b < 0x80 for b in hw_model): # isascii
@ -404,7 +410,7 @@ def validate_firmware(
fingerprint: Optional[str] = None, fingerprint: Optional[str] = None,
model: Optional[TrezorModel] = None, model: Optional[TrezorModel] = None,
bootloader_onev2: Optional[bool] = None, bootloader_onev2: Optional[bool] = None,
prompt_unsigned: bool = True, verify_only: bool = False,
) -> None: ) -> None:
"""Validate the firmware through multiple tests. """Validate the firmware through multiple tests.
@ -419,8 +425,14 @@ def validate_firmware(
sys.exit(2) sys.exit(2)
print_firmware_version(fw) print_firmware_version(fw)
if not fw.model():
click.echo("Cannot validate firmware for unrecognized model.")
if not verify_only:
click.echo("(Hint: use --skip-check to skip validation.)")
sys.exit(3)
validate_fingerprint(fw, fingerprint) validate_fingerprint(fw, fingerprint)
validate_signatures(fw, prompt_unsigned=prompt_unsigned) validate_signatures(fw, prompt_unsigned=not verify_only)
if model is not None and bootloader_onev2 is not None: if model is not None and bootloader_onev2 is not None:
check_device_match(fw, model, bootloader_onev2) check_device_match(fw, model, bootloader_onev2)
@ -548,7 +560,7 @@ def verify(
fingerprint=fingerprint, fingerprint=fingerprint,
bootloader_onev2=bootloader_onev2, bootloader_onev2=bootloader_onev2,
model=model, model=model,
prompt_unsigned=False, verify_only=True,
) )

View File

@ -21,6 +21,7 @@ from typing_extensions import Protocol, TypeGuard
from .. import messages from .. import messages
from ..tools import session from ..tools import session
from .models import Model
from .core import VendorFirmware from .core import VendorFirmware
from .legacy import LegacyFirmware, LegacyV2Firmware from .legacy import LegacyFirmware, LegacyV2Firmware
@ -50,6 +51,8 @@ if t.TYPE_CHECKING:
def digest(self) -> bytes: ... def digest(self) -> bytes: ...
def model(self) -> Model | None: ...
def parse(data: bytes) -> "FirmwareType": def parse(data: bytes) -> "FirmwareType":
try: try:

View File

@ -126,7 +126,10 @@ class FirmwareImage(Struct):
@staticmethod @staticmethod
def calc_padding(hw_model: bytes, len: int) -> int: def calc_padding(hw_model: bytes, len: int) -> int:
try:
alignment = Model.from_hw_model(hw_model).code_alignment() alignment = Model.from_hw_model(hw_model).code_alignment()
except ValueError:
alignment = Model.T3W1.code_alignment()
return ((len + alignment - 1) & ~(alignment - 1)) - len return ((len + alignment - 1) & ~(alignment - 1)) - len
def get_hash_params(self) -> "util.FirmwareHashParameters": def get_hash_params(self) -> "util.FirmwareHashParameters":
@ -176,6 +179,11 @@ class FirmwareImage(Struct):
header.v1_signatures = [b"\x00" * 64] * consts.V1_SIGNATURE_SLOTS header.v1_signatures = [b"\x00" * 64] * consts.V1_SIGNATURE_SLOTS
return hash_params.hash_function(header.build()).digest() return hash_params.hash_function(header.build()).digest()
def model(self) -> Model | None:
if isinstance(self.header.hw_model, Model):
return self.header.hw_model
return None
class VendorFirmware(Struct): class VendorFirmware(Struct):
"""Firmware image prefixed by a vendor header. """Firmware image prefixed by a vendor header.
@ -214,3 +222,6 @@ class VendorFirmware(Struct):
# now = time.gmtime() # now = time.gmtime()
# if time.gmtime(fw.vendor_header.expiry) < now: # if time.gmtime(fw.vendor_header.expiry) < now:
# raise ValueError("Vendor header expired.") # raise ValueError("Vendor header expired.")
def model(self) -> Model | None:
return self.firmware.model()

View File

@ -202,3 +202,6 @@ class LegacyFirmware(Struct):
if self.embedded_v2: if self.embedded_v2:
self.embedded_v2.verify() self.embedded_v2.verify()
def model(self) -> Model | None:
return Model.T1B1