1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-11 16:00:57 +00:00

python: do not accept bytes for str protobuf fields

fixes #283

also adds typing information to misc.py
This commit is contained in:
matejcik 2019-08-29 13:56:09 +02:00
parent ab74f55a95
commit 5b8f542436
3 changed files with 63 additions and 41 deletions

View File

@ -14,21 +14,28 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from . import messages as proto
from .tools import expect
from . import messages
from .tools import expect, Address
if False:
from .client import TrezorClient
@expect(proto.Entropy, field="entropy")
def get_entropy(client, size):
return client.call(proto.GetEntropy(size=size))
@expect(messages.Entropy, field="entropy")
def get_entropy(client: "TrezorClient", size: int) -> messages.Entropy:
return client.call(messages.GetEntropy(size=size))
@expect(proto.SignedIdentity)
@expect(messages.SignedIdentity)
def sign_identity(
client, identity, challenge_hidden, challenge_visual, ecdsa_curve_name=None
):
client: "TrezorClient",
identity: messages.IdentityType,
challenge_hidden: bytes,
challenge_visual: str,
ecdsa_curve_name: str = None,
) -> messages.SignedIdentity:
return client.call(
proto.SignIdentity(
messages.SignIdentity(
identity=identity,
challenge_hidden=challenge_hidden,
challenge_visual=challenge_visual,
@ -37,10 +44,15 @@ def sign_identity(
)
@expect(proto.ECDHSessionKey)
def get_ecdh_session_key(client, identity, peer_public_key, ecdsa_curve_name=None):
@expect(messages.ECDHSessionKey)
def get_ecdh_session_key(
client: "TrezorClient",
identity: messages.IdentityType,
peer_public_key: bytes,
ecdsa_curve_name: str = None,
) -> messages.ECDHSessionKey:
return client.call(
proto.GetECDHSessionKey(
messages.GetECDHSessionKey(
identity=identity,
peer_public_key=peer_public_key,
ecdsa_curve_name=ecdsa_curve_name,
@ -48,12 +60,18 @@ def get_ecdh_session_key(client, identity, peer_public_key, ecdsa_curve_name=Non
)
@expect(proto.CipheredKeyValue, field="value")
@expect(messages.CipheredKeyValue, field="value")
def encrypt_keyvalue(
client, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b""
):
client: "TrezorClient",
n: Address,
key: str,
value: bytes,
ask_on_encrypt: bool = True,
ask_on_decrypt: bool = True,
iv: bytes = b"",
) -> messages.CipheredKeyValue:
return client.call(
proto.CipherKeyValue(
messages.CipherKeyValue(
address_n=n,
key=key,
value=value,
@ -65,12 +83,18 @@ def encrypt_keyvalue(
)
@expect(proto.CipheredKeyValue, field="value")
@expect(messages.CipheredKeyValue, field="value")
def decrypt_keyvalue(
client, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b""
):
client: "TrezorClient",
n: Address,
key: str,
value: bytes,
ask_on_encrypt: bool = True,
ask_on_decrypt: bool = True,
iv: bytes = b"",
) -> messages.CipheredKeyValue:
return client.call(
proto.CipherKeyValue(
messages.CipherKeyValue(
address_n=n,
key=key,
value=value,

View File

@ -411,11 +411,9 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
writer.write(svalue)
elif ftype is UnicodeType:
if not isinstance(svalue, bytes):
svalue = svalue.encode()
dump_uvarint(writer, len(svalue))
writer.write(svalue)
svalue_bytes = svalue.encode()
dump_uvarint(writer, len(svalue_bytes))
writer.write(svalue_bytes)
elif issubclass(ftype, MessageType):
counter = CountingWriter()

View File

@ -27,7 +27,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.encrypt_keyvalue(
client,
[0, 1, 2],
b"test",
"test",
b"testing message!",
ask_on_encrypt=True,
ask_on_decrypt=True,
@ -37,7 +37,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.encrypt_keyvalue(
client,
[0, 1, 2],
b"test",
"test",
b"testing message!",
ask_on_encrypt=True,
ask_on_decrypt=False,
@ -47,7 +47,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.encrypt_keyvalue(
client,
[0, 1, 2],
b"test",
"test",
b"testing message!",
ask_on_encrypt=False,
ask_on_decrypt=True,
@ -57,7 +57,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.encrypt_keyvalue(
client,
[0, 1, 2],
b"test",
"test",
b"testing message!",
ask_on_encrypt=False,
ask_on_decrypt=False,
@ -68,7 +68,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.encrypt_keyvalue(
client,
[0, 1, 2],
b"test2",
"test2",
b"testing message!",
ask_on_encrypt=True,
ask_on_decrypt=True,
@ -79,7 +79,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.encrypt_keyvalue(
client,
[0, 1, 2],
b"test",
"test",
b"testing message! it is different",
ask_on_encrypt=True,
ask_on_decrypt=True,
@ -93,7 +93,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.encrypt_keyvalue(
client,
[0, 1, 3],
b"test",
"test",
b"testing message!",
ask_on_encrypt=True,
ask_on_decrypt=True,
@ -105,7 +105,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.decrypt_keyvalue(
client,
[0, 1, 2],
b"test",
"test",
bytes.fromhex("676faf8f13272af601776bc31bc14e8f"),
ask_on_encrypt=True,
ask_on_decrypt=True,
@ -115,7 +115,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.decrypt_keyvalue(
client,
[0, 1, 2],
b"test",
"test",
bytes.fromhex("5aa0fbcb9d7fa669880745479d80c622"),
ask_on_encrypt=True,
ask_on_decrypt=False,
@ -125,7 +125,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.decrypt_keyvalue(
client,
[0, 1, 2],
b"test",
"test",
bytes.fromhex("958d4f63269b61044aaedc900c8d6208"),
ask_on_encrypt=False,
ask_on_decrypt=True,
@ -135,7 +135,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.decrypt_keyvalue(
client,
[0, 1, 2],
b"test",
"test",
bytes.fromhex("e0cf0eb0425947000eb546cc3994bc6c"),
ask_on_encrypt=False,
ask_on_decrypt=False,
@ -146,7 +146,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.decrypt_keyvalue(
client,
[0, 1, 2],
b"test2",
"test2",
bytes.fromhex("de247a6aa6be77a134bb3f3f925f13af"),
ask_on_encrypt=True,
ask_on_decrypt=True,
@ -157,7 +157,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.decrypt_keyvalue(
client,
[0, 1, 2],
b"test",
"test",
bytes.fromhex(
"676faf8f13272af601776bc31bc14e8f3ae1c88536bf18f1b44f1e4c2c4a613d"
),
@ -170,7 +170,7 @@ class TestMsgCipherkeyvalue(TrezorTest):
res = misc.decrypt_keyvalue(
client,
[0, 1, 3],
b"test",
"test",
bytes.fromhex("b4811a9d492f5355a5186ddbfccaae7b"),
ask_on_encrypt=True,
ask_on_decrypt=True,
@ -179,8 +179,8 @@ class TestMsgCipherkeyvalue(TrezorTest):
def test_encrypt_badlen(self, client):
with pytest.raises(Exception):
misc.encrypt_keyvalue(client, [0, 1, 2], b"test", b"testing")
misc.encrypt_keyvalue(client, [0, 1, 2], "test", b"testing")
def test_decrypt_badlen(self, client):
with pytest.raises(Exception):
misc.decrypt_keyvalue(client, [0, 1, 2], b"test", b"testing")
misc.decrypt_keyvalue(client, [0, 1, 2], "test", b"testing")