1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-26 23:32:03 +00:00

python: do not accept bytes for str protobuf fields

fixes #283

also adds typing information to misc.py
This commit is contained in:
matejcik 2019-08-29 13:56:09 +02:00
parent ab74f55a95
commit 5b8f542436
3 changed files with 63 additions and 41 deletions

View File

@ -14,21 +14,28 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from . import messages as proto from . import messages
from .tools import expect from .tools import expect, Address
if False:
from .client import TrezorClient
@expect(proto.Entropy, field="entropy") @expect(messages.Entropy, field="entropy")
def get_entropy(client, size): def get_entropy(client: "TrezorClient", size: int) -> messages.Entropy:
return client.call(proto.GetEntropy(size=size)) return client.call(messages.GetEntropy(size=size))
@expect(proto.SignedIdentity) @expect(messages.SignedIdentity)
def sign_identity( def sign_identity(
client, identity, challenge_hidden, challenge_visual, ecdsa_curve_name=None client: "TrezorClient",
): identity: messages.IdentityType,
challenge_hidden: bytes,
challenge_visual: str,
ecdsa_curve_name: str = None,
) -> messages.SignedIdentity:
return client.call( return client.call(
proto.SignIdentity( messages.SignIdentity(
identity=identity, identity=identity,
challenge_hidden=challenge_hidden, challenge_hidden=challenge_hidden,
challenge_visual=challenge_visual, challenge_visual=challenge_visual,
@ -37,10 +44,15 @@ def sign_identity(
) )
@expect(proto.ECDHSessionKey) @expect(messages.ECDHSessionKey)
def get_ecdh_session_key(client, identity, peer_public_key, ecdsa_curve_name=None): def get_ecdh_session_key(
client: "TrezorClient",
identity: messages.IdentityType,
peer_public_key: bytes,
ecdsa_curve_name: str = None,
) -> messages.ECDHSessionKey:
return client.call( return client.call(
proto.GetECDHSessionKey( messages.GetECDHSessionKey(
identity=identity, identity=identity,
peer_public_key=peer_public_key, peer_public_key=peer_public_key,
ecdsa_curve_name=ecdsa_curve_name, ecdsa_curve_name=ecdsa_curve_name,
@ -48,12 +60,18 @@ def get_ecdh_session_key(client, identity, peer_public_key, ecdsa_curve_name=Non
) )
@expect(proto.CipheredKeyValue, field="value") @expect(messages.CipheredKeyValue, field="value")
def encrypt_keyvalue( def encrypt_keyvalue(
client, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b"" client: "TrezorClient",
): n: Address,
key: str,
value: bytes,
ask_on_encrypt: bool = True,
ask_on_decrypt: bool = True,
iv: bytes = b"",
) -> messages.CipheredKeyValue:
return client.call( return client.call(
proto.CipherKeyValue( messages.CipherKeyValue(
address_n=n, address_n=n,
key=key, key=key,
value=value, value=value,
@ -65,12 +83,18 @@ def encrypt_keyvalue(
) )
@expect(proto.CipheredKeyValue, field="value") @expect(messages.CipheredKeyValue, field="value")
def decrypt_keyvalue( def decrypt_keyvalue(
client, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b"" client: "TrezorClient",
): n: Address,
key: str,
value: bytes,
ask_on_encrypt: bool = True,
ask_on_decrypt: bool = True,
iv: bytes = b"",
) -> messages.CipheredKeyValue:
return client.call( return client.call(
proto.CipherKeyValue( messages.CipherKeyValue(
address_n=n, address_n=n,
key=key, key=key,
value=value, value=value,

View File

@ -411,11 +411,9 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
writer.write(svalue) writer.write(svalue)
elif ftype is UnicodeType: elif ftype is UnicodeType:
if not isinstance(svalue, bytes): svalue_bytes = svalue.encode()
svalue = svalue.encode() dump_uvarint(writer, len(svalue_bytes))
writer.write(svalue_bytes)
dump_uvarint(writer, len(svalue))
writer.write(svalue)
elif issubclass(ftype, MessageType): elif issubclass(ftype, MessageType):
counter = CountingWriter() counter = CountingWriter()

View File

@ -27,7 +27,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.encrypt_keyvalue( res = misc.encrypt_keyvalue(
client, client,
[0, 1, 2], [0, 1, 2],
b"test", "test",
b"testing message!", b"testing message!",
ask_on_encrypt=True, ask_on_encrypt=True,
ask_on_decrypt=True, ask_on_decrypt=True,
@ -37,7 +37,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.encrypt_keyvalue( res = misc.encrypt_keyvalue(
client, client,
[0, 1, 2], [0, 1, 2],
b"test", "test",
b"testing message!", b"testing message!",
ask_on_encrypt=True, ask_on_encrypt=True,
ask_on_decrypt=False, ask_on_decrypt=False,
@ -47,7 +47,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.encrypt_keyvalue( res = misc.encrypt_keyvalue(
client, client,
[0, 1, 2], [0, 1, 2],
b"test", "test",
b"testing message!", b"testing message!",
ask_on_encrypt=False, ask_on_encrypt=False,
ask_on_decrypt=True, ask_on_decrypt=True,
@ -57,7 +57,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.encrypt_keyvalue( res = misc.encrypt_keyvalue(
client, client,
[0, 1, 2], [0, 1, 2],
b"test", "test",
b"testing message!", b"testing message!",
ask_on_encrypt=False, ask_on_encrypt=False,
ask_on_decrypt=False, ask_on_decrypt=False,
@ -68,7 +68,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.encrypt_keyvalue( res = misc.encrypt_keyvalue(
client, client,
[0, 1, 2], [0, 1, 2],
b"test2", "test2",
b"testing message!", b"testing message!",
ask_on_encrypt=True, ask_on_encrypt=True,
ask_on_decrypt=True, ask_on_decrypt=True,
@ -79,7 +79,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.encrypt_keyvalue( res = misc.encrypt_keyvalue(
client, client,
[0, 1, 2], [0, 1, 2],
b"test", "test",
b"testing message! it is different", b"testing message! it is different",
ask_on_encrypt=True, ask_on_encrypt=True,
ask_on_decrypt=True, ask_on_decrypt=True,
@ -93,7 +93,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.encrypt_keyvalue( res = misc.encrypt_keyvalue(
client, client,
[0, 1, 3], [0, 1, 3],
b"test", "test",
b"testing message!", b"testing message!",
ask_on_encrypt=True, ask_on_encrypt=True,
ask_on_decrypt=True, ask_on_decrypt=True,
@ -105,7 +105,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.decrypt_keyvalue( res = misc.decrypt_keyvalue(
client, client,
[0, 1, 2], [0, 1, 2],
b"test", "test",
bytes.fromhex("676faf8f13272af601776bc31bc14e8f"), bytes.fromhex("676faf8f13272af601776bc31bc14e8f"),
ask_on_encrypt=True, ask_on_encrypt=True,
ask_on_decrypt=True, ask_on_decrypt=True,
@ -115,7 +115,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.decrypt_keyvalue( res = misc.decrypt_keyvalue(
client, client,
[0, 1, 2], [0, 1, 2],
b"test", "test",
bytes.fromhex("5aa0fbcb9d7fa669880745479d80c622"), bytes.fromhex("5aa0fbcb9d7fa669880745479d80c622"),
ask_on_encrypt=True, ask_on_encrypt=True,
ask_on_decrypt=False, ask_on_decrypt=False,
@ -125,7 +125,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.decrypt_keyvalue( res = misc.decrypt_keyvalue(
client, client,
[0, 1, 2], [0, 1, 2],
b"test", "test",
bytes.fromhex("958d4f63269b61044aaedc900c8d6208"), bytes.fromhex("958d4f63269b61044aaedc900c8d6208"),
ask_on_encrypt=False, ask_on_encrypt=False,
ask_on_decrypt=True, ask_on_decrypt=True,
@ -135,7 +135,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.decrypt_keyvalue( res = misc.decrypt_keyvalue(
client, client,
[0, 1, 2], [0, 1, 2],
b"test", "test",
bytes.fromhex("e0cf0eb0425947000eb546cc3994bc6c"), bytes.fromhex("e0cf0eb0425947000eb546cc3994bc6c"),
ask_on_encrypt=False, ask_on_encrypt=False,
ask_on_decrypt=False, ask_on_decrypt=False,
@ -146,7 +146,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.decrypt_keyvalue( res = misc.decrypt_keyvalue(
client, client,
[0, 1, 2], [0, 1, 2],
b"test2", "test2",
bytes.fromhex("de247a6aa6be77a134bb3f3f925f13af"), bytes.fromhex("de247a6aa6be77a134bb3f3f925f13af"),
ask_on_encrypt=True, ask_on_encrypt=True,
ask_on_decrypt=True, ask_on_decrypt=True,
@ -157,7 +157,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.decrypt_keyvalue( res = misc.decrypt_keyvalue(
client, client,
[0, 1, 2], [0, 1, 2],
b"test", "test",
bytes.fromhex( bytes.fromhex(
"676faf8f13272af601776bc31bc14e8f3ae1c88536bf18f1b44f1e4c2c4a613d" "676faf8f13272af601776bc31bc14e8f3ae1c88536bf18f1b44f1e4c2c4a613d"
), ),
@ -170,7 +170,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.decrypt_keyvalue( res = misc.decrypt_keyvalue(
client, client,
[0, 1, 3], [0, 1, 3],
b"test", "test",
bytes.fromhex("b4811a9d492f5355a5186ddbfccaae7b"), bytes.fromhex("b4811a9d492f5355a5186ddbfccaae7b"),
ask_on_encrypt=True, ask_on_encrypt=True,
ask_on_decrypt=True, ask_on_decrypt=True,
@ -179,8 +179,8 @@ class TestMsgCipherkeyvalue(TrezorTest):
def test_encrypt_badlen(self, client): def test_encrypt_badlen(self, client):
with pytest.raises(Exception): with pytest.raises(Exception):
misc.encrypt_keyvalue(client, [0, 1, 2], b"test", b"testing") misc.encrypt_keyvalue(client, [0, 1, 2], "test", b"testing")
def test_decrypt_badlen(self, client): def test_decrypt_badlen(self, client):
with pytest.raises(Exception): with pytest.raises(Exception):
misc.decrypt_keyvalue(client, [0, 1, 2], b"test", b"testing") misc.decrypt_keyvalue(client, [0, 1, 2], "test", b"testing")