diff --git a/python/src/trezorlib/tools.py b/python/src/trezorlib/tools.py index 13a4d6d5cb..6ba8c64dba 100644 --- a/python/src/trezorlib/tools.py +++ b/python/src/trezorlib/tools.py @@ -16,8 +16,10 @@ from __future__ import annotations +import copy import functools import hashlib +import inspect import re import struct import unicodedata @@ -46,6 +48,7 @@ if TYPE_CHECKING: from typing_extensions import Concatenate, ParamSpec from . import client + from .messages import Success from .protobuf import MessageType MT = TypeVar("MT", bound=MessageType) @@ -310,6 +313,82 @@ def expect( return decorator +def _deprecation_retval_helper(value: Any, stacklevel: int = 0) -> Any: + stack = inspect.stack() + func_name = stack[stacklevel + 1].function + + warning_text = ( + f"The return value {value!r} of function {func_name}() " + "is deprecated and it will be removed in a future version." + ) + + # start with warnings disabled, otherwise we emit a lot of warnings while still + # constructing the deprecation warnings helper + warning_enabled = False + + def deprecation_warning_wrapper(orig_value: Callable[P, R]) -> Callable[P, R]: + def emit(*args: P.args, **kwargs: P.kwargs) -> R: + nonlocal warning_enabled + + if warning_enabled: + warnings.warn(warning_text, DeprecationWarning, stacklevel=2) + # only warn once per use + warning_enabled = False + return orig_value(*args, **kwargs) + + return emit + + # Deprecation wrapper class. + # Defined as empty at start. + class Deprecated(value.__class__): + pass + + # Here we install the deprecation_warning_wrapper for all dunder methods. + # This implicitly includes __getattribute__, which causes all non-dunder attribute + # accesses to also raise the warning. + for key in dir(value.__class__): + if not key.startswith("__"): + # skip non-dunder methods + continue + if key in ("__new__", "__init__", "__class__"): + # skip some problematic items + continue + orig_value = getattr(value.__class__, key) + if not callable(orig_value): + # skip non-functions + continue + # replace the method with a wrapper that emits a warning + setattr(Deprecated, key, deprecation_warning_wrapper(orig_value)) + + from .protobuf import MessageType + + # construct an instance: + if isinstance(value, str): + # for str, invoke the copy constructor + ret = Deprecated(value) + elif isinstance(value, MessageType): + # MessageTypes don't have a copy constructor, so + # 1. we make an explicit copy + value = copy.copy(value) + # 2. we change the class of the copy + value.__class__ = Deprecated + # note: we don't need deep copy because all accesses to inner objects already + # trigger the warning via __getattribute__ + ret = value + else: + # we don't support other types currently + raise NotImplementedError + + # enable warnings + warning_enabled = True + + return ret + + +def _return_success(msg: "Success") -> str | None: + return _deprecation_retval_helper(msg.message, stacklevel=1) + + def session( f: "Callable[Concatenate[TrezorClient, P], R]", ) -> "Callable[Concatenate[TrezorClient, P], R]": diff --git a/python/tests/test_tools.py b/python/tests/test_tools.py index 3bdda1fe97..a89d5a91f7 100644 --- a/python/tests/test_tools.py +++ b/python/tests/test_tools.py @@ -16,7 +16,7 @@ import pytest -from trezorlib import tools +from trezorlib import messages, tools VECTORS = ( # descriptor, checksum ( @@ -87,3 +87,50 @@ def test_b58encode(data_hex, encoding_b58): @pytest.mark.parametrize("data_hex,encoding_b58", BASE58_VECTORS) def test_b58decode(data_hex, encoding_b58): assert tools.b58decode(encoding_b58).hex() == data_hex + + +def test_return_success_deprecation(recwarn): + def mkfoo() -> str: + ret = tools._return_success(messages.Success(message="foo")) + assert ret is not None # too bad we can't hook "is None" check + return ret + + # check that just returning success will not cause a warning + mkfoo() + assert len(recwarn) == 0 + + with pytest.deprecated_call(): + # equality is deprecated + assert mkfoo() == "foo" + with pytest.deprecated_call(): + # comparison is deprecated + assert mkfoo() < "fooa" + with pytest.deprecated_call(): + # truthiness is deprecated + assert mkfoo() + with pytest.deprecated_call(): + # addition is deprecated (and hopefully all other operators) + assert mkfoo() + "a" == "fooa" + with pytest.deprecated_call(): + # indexing is deprecated + assert mkfoo()[0] == "f" + with pytest.deprecated_call(): + # methods are deprecated + assert mkfoo().startswith("f") + + +def test_deprecation_helper(recwarn): + def mkfoo() -> messages.Success: + return tools._deprecation_retval_helper(messages.Success(message="foo")) + + # check that just returning success will not cause a warning + mkfoo() + assert len(recwarn) == 0 + + with pytest.deprecated_call(): + # attributes are deprecated + assert mkfoo().message == "foo" + + with pytest.deprecated_call(): + # equality is deprecated (along with other operators hopefully) + assert mkfoo() != "foo"