diff --git a/python/src/trezorlib/misc.py b/python/src/trezorlib/misc.py index a9de39af82..6537392986 100644 --- a/python/src/trezorlib/misc.py +++ b/python/src/trezorlib/misc.py @@ -14,21 +14,28 @@ # You should have received a copy of the License along with this library. # If not, see . -from . import messages as proto -from .tools import expect +from . import messages +from .tools import expect, Address + +if False: + from .client import TrezorClient -@expect(proto.Entropy, field="entropy") -def get_entropy(client, size): - return client.call(proto.GetEntropy(size=size)) +@expect(messages.Entropy, field="entropy") +def get_entropy(client: "TrezorClient", size: int) -> messages.Entropy: + return client.call(messages.GetEntropy(size=size)) -@expect(proto.SignedIdentity) +@expect(messages.SignedIdentity) def sign_identity( - client, identity, challenge_hidden, challenge_visual, ecdsa_curve_name=None -): + client: "TrezorClient", + identity: messages.IdentityType, + challenge_hidden: bytes, + challenge_visual: str, + ecdsa_curve_name: str = None, +) -> messages.SignedIdentity: return client.call( - proto.SignIdentity( + messages.SignIdentity( identity=identity, challenge_hidden=challenge_hidden, challenge_visual=challenge_visual, @@ -37,10 +44,15 @@ def sign_identity( ) -@expect(proto.ECDHSessionKey) -def get_ecdh_session_key(client, identity, peer_public_key, ecdsa_curve_name=None): +@expect(messages.ECDHSessionKey) +def get_ecdh_session_key( + client: "TrezorClient", + identity: messages.IdentityType, + peer_public_key: bytes, + ecdsa_curve_name: str = None, +) -> messages.ECDHSessionKey: return client.call( - proto.GetECDHSessionKey( + messages.GetECDHSessionKey( identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name=ecdsa_curve_name, @@ -48,12 +60,18 @@ def get_ecdh_session_key(client, identity, peer_public_key, ecdsa_curve_name=Non ) -@expect(proto.CipheredKeyValue, field="value") +@expect(messages.CipheredKeyValue, field="value") def encrypt_keyvalue( - client, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b"" -): + client: "TrezorClient", + n: Address, + key: str, + value: bytes, + ask_on_encrypt: bool = True, + ask_on_decrypt: bool = True, + iv: bytes = b"", +) -> messages.CipheredKeyValue: return client.call( - proto.CipherKeyValue( + messages.CipherKeyValue( address_n=n, key=key, value=value, @@ -65,12 +83,18 @@ def encrypt_keyvalue( ) -@expect(proto.CipheredKeyValue, field="value") +@expect(messages.CipheredKeyValue, field="value") def decrypt_keyvalue( - client, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b"" -): + client: "TrezorClient", + n: Address, + key: str, + value: bytes, + ask_on_encrypt: bool = True, + ask_on_decrypt: bool = True, + iv: bytes = b"", +) -> messages.CipheredKeyValue: return client.call( - proto.CipherKeyValue( + messages.CipherKeyValue( address_n=n, key=key, value=value, diff --git a/python/src/trezorlib/protobuf.py b/python/src/trezorlib/protobuf.py index 3d9e8434c8..296a7597f8 100644 --- a/python/src/trezorlib/protobuf.py +++ b/python/src/trezorlib/protobuf.py @@ -411,11 +411,9 @@ def dump_message(writer: Writer, msg: MessageType) -> None: writer.write(svalue) elif ftype is UnicodeType: - if not isinstance(svalue, bytes): - svalue = svalue.encode() - - dump_uvarint(writer, len(svalue)) - writer.write(svalue) + svalue_bytes = svalue.encode() + dump_uvarint(writer, len(svalue_bytes)) + writer.write(svalue_bytes) elif issubclass(ftype, MessageType): counter = CountingWriter() diff --git a/tests/device_tests/test_msg_cipherkeyvalue.py b/tests/device_tests/test_msg_cipherkeyvalue.py index e09aa530a4..c42b75bf3e 100644 --- a/tests/device_tests/test_msg_cipherkeyvalue.py +++ b/tests/device_tests/test_msg_cipherkeyvalue.py @@ -27,7 +27,7 @@ class TestMsgCipherkeyvalue(TrezorTest): res = misc.encrypt_keyvalue( client, [0, 1, 2], - b"test", + "test", b"testing message!", ask_on_encrypt=True, ask_on_decrypt=True, @@ -37,7 +37,7 @@ class TestMsgCipherkeyvalue(TrezorTest): res = misc.encrypt_keyvalue( client, [0, 1, 2], - b"test", + "test", b"testing message!", ask_on_encrypt=True, ask_on_decrypt=False, @@ -47,7 +47,7 @@ class TestMsgCipherkeyvalue(TrezorTest): res = misc.encrypt_keyvalue( client, [0, 1, 2], - b"test", + "test", b"testing message!", ask_on_encrypt=False, ask_on_decrypt=True, @@ -57,7 +57,7 @@ class TestMsgCipherkeyvalue(TrezorTest): res = misc.encrypt_keyvalue( client, [0, 1, 2], - b"test", + "test", b"testing message!", ask_on_encrypt=False, ask_on_decrypt=False, @@ -68,7 +68,7 @@ class TestMsgCipherkeyvalue(TrezorTest): res = misc.encrypt_keyvalue( client, [0, 1, 2], - b"test2", + "test2", b"testing message!", ask_on_encrypt=True, ask_on_decrypt=True, @@ -79,7 +79,7 @@ class TestMsgCipherkeyvalue(TrezorTest): res = misc.encrypt_keyvalue( client, [0, 1, 2], - b"test", + "test", b"testing message! it is different", ask_on_encrypt=True, ask_on_decrypt=True, @@ -93,7 +93,7 @@ class TestMsgCipherkeyvalue(TrezorTest): res = misc.encrypt_keyvalue( client, [0, 1, 3], - b"test", + "test", b"testing message!", ask_on_encrypt=True, ask_on_decrypt=True, @@ -105,7 +105,7 @@ class TestMsgCipherkeyvalue(TrezorTest): res = misc.decrypt_keyvalue( client, [0, 1, 2], - b"test", + "test", bytes.fromhex("676faf8f13272af601776bc31bc14e8f"), ask_on_encrypt=True, ask_on_decrypt=True, @@ -115,7 +115,7 @@ class TestMsgCipherkeyvalue(TrezorTest): res = misc.decrypt_keyvalue( client, [0, 1, 2], - b"test", + "test", bytes.fromhex("5aa0fbcb9d7fa669880745479d80c622"), ask_on_encrypt=True, ask_on_decrypt=False, @@ -125,7 +125,7 @@ class TestMsgCipherkeyvalue(TrezorTest): res = misc.decrypt_keyvalue( client, [0, 1, 2], - b"test", + "test", bytes.fromhex("958d4f63269b61044aaedc900c8d6208"), ask_on_encrypt=False, ask_on_decrypt=True, @@ -135,7 +135,7 @@ class TestMsgCipherkeyvalue(TrezorTest): res = misc.decrypt_keyvalue( client, [0, 1, 2], - b"test", + "test", bytes.fromhex("e0cf0eb0425947000eb546cc3994bc6c"), ask_on_encrypt=False, ask_on_decrypt=False, @@ -146,7 +146,7 @@ class TestMsgCipherkeyvalue(TrezorTest): res = misc.decrypt_keyvalue( client, [0, 1, 2], - b"test2", + "test2", bytes.fromhex("de247a6aa6be77a134bb3f3f925f13af"), ask_on_encrypt=True, ask_on_decrypt=True, @@ -157,7 +157,7 @@ class TestMsgCipherkeyvalue(TrezorTest): res = misc.decrypt_keyvalue( client, [0, 1, 2], - b"test", + "test", bytes.fromhex( "676faf8f13272af601776bc31bc14e8f3ae1c88536bf18f1b44f1e4c2c4a613d" ), @@ -170,7 +170,7 @@ class TestMsgCipherkeyvalue(TrezorTest): res = misc.decrypt_keyvalue( client, [0, 1, 3], - b"test", + "test", bytes.fromhex("b4811a9d492f5355a5186ddbfccaae7b"), ask_on_encrypt=True, ask_on_decrypt=True, @@ -179,8 +179,8 @@ class TestMsgCipherkeyvalue(TrezorTest): def test_encrypt_badlen(self, client): with pytest.raises(Exception): - misc.encrypt_keyvalue(client, [0, 1, 2], b"test", b"testing") + misc.encrypt_keyvalue(client, [0, 1, 2], "test", b"testing") def test_decrypt_badlen(self, client): with pytest.raises(Exception): - misc.decrypt_keyvalue(client, [0, 1, 2], b"test", b"testing") + misc.decrypt_keyvalue(client, [0, 1, 2], "test", b"testing")