parent
abbe5535ad
commit
e2d600389b
@ -0,0 +1,145 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import typing as t
|
||||
from hashlib import sha256
|
||||
|
||||
from trezorlib import cosi, definitions, messages, protobuf
|
||||
|
||||
PRIVATE_KEYS_DEV = [byte * 32 for byte in (b"\xdd", b"\xde", b"\xdf")]
|
||||
|
||||
|
||||
def sign_with_privkeys(digest: bytes, privkeys: t.Sequence[bytes]) -> bytes:
|
||||
"""Locally produce a CoSi signature."""
|
||||
pubkeys = [cosi.pubkey_from_privkey(sk) for sk in privkeys]
|
||||
nonces = [cosi.get_nonce(sk, digest, i) for i, sk in enumerate(privkeys)]
|
||||
|
||||
global_pk = cosi.combine_keys(pubkeys)
|
||||
global_R = cosi.combine_keys(R for _, R in nonces)
|
||||
|
||||
sigs = [
|
||||
cosi.sign_with_privkey(digest, sk, global_pk, r, global_R)
|
||||
for sk, (r, _) in zip(privkeys, nonces)
|
||||
]
|
||||
|
||||
return cosi.combine_sig(global_R, sigs)
|
||||
|
||||
|
||||
def make_network(
|
||||
chain_id: int = 0,
|
||||
slip44: int = 0,
|
||||
symbol: str = "FAKE",
|
||||
name: str = "Fake network",
|
||||
) -> messages.EthereumNetworkInfo:
|
||||
return messages.EthereumNetworkInfo(
|
||||
chain_id=chain_id,
|
||||
slip44=slip44,
|
||||
symbol=symbol,
|
||||
name=name,
|
||||
)
|
||||
|
||||
|
||||
def make_token(
|
||||
symbol: str = "FAKE",
|
||||
decimals: int = 18,
|
||||
address: bytes = b"",
|
||||
chain_id: int = 0,
|
||||
name: str = "Fake token",
|
||||
) -> messages.EthereumTokenInfo:
|
||||
return messages.EthereumTokenInfo(
|
||||
symbol=symbol,
|
||||
decimals=decimals,
|
||||
address=address,
|
||||
chain_id=chain_id,
|
||||
name=name,
|
||||
)
|
||||
|
||||
|
||||
def make_payload(
|
||||
data_type: messages.EthereumDefinitionType = messages.EthereumDefinitionType.NETWORK,
|
||||
timestamp: int = 0xFFFF_FFFF,
|
||||
message: messages.EthereumNetworkInfo
|
||||
| messages.EthereumTokenInfo
|
||||
| bytes = make_network(),
|
||||
) -> bytes:
|
||||
if isinstance(message, bytes):
|
||||
message_bytes = message
|
||||
else:
|
||||
writer = io.BytesIO()
|
||||
protobuf.dump_message(writer, message)
|
||||
message_bytes = writer.getvalue()
|
||||
|
||||
payload = definitions.DefinitionPayload(
|
||||
magic=b"trzd1",
|
||||
data_type=data_type,
|
||||
timestamp=timestamp,
|
||||
data=message_bytes,
|
||||
)
|
||||
return payload.build()
|
||||
|
||||
|
||||
def sign_payload(
|
||||
payload: bytes,
|
||||
merkle_neighbors: list[bytes],
|
||||
threshold: int = 3,
|
||||
) -> tuple[bytes, bytes]:
|
||||
digest = sha256(b"\x00" + payload).digest()
|
||||
merkle_proof = []
|
||||
for item in merkle_neighbors:
|
||||
left, right = min(digest, item), max(digest, item)
|
||||
digest = sha256(b"\x01" + left + right).digest()
|
||||
merkle_proof.append(digest)
|
||||
|
||||
merkle_proof = len(merkle_proof).to_bytes(1, "little") + b"".join(merkle_proof)
|
||||
signature = sign_with_privkeys(digest, PRIVATE_KEYS_DEV[:threshold])
|
||||
sigmask = 0
|
||||
for i in range(threshold):
|
||||
sigmask |= 1 << i
|
||||
sigmask_byte = sigmask.to_bytes(1, "little")
|
||||
return merkle_proof, sigmask_byte + signature
|
||||
|
||||
|
||||
def encode_network(
|
||||
network: messages.EthereumNetworkInfo | None = None,
|
||||
chain_id: int = 0,
|
||||
slip44: int = 0,
|
||||
symbol: str = "FAKE",
|
||||
name: str = "Fake network",
|
||||
) -> bytes:
|
||||
if network is None:
|
||||
network = make_network(chain_id, slip44, symbol, name)
|
||||
payload = make_payload(
|
||||
data_type=messages.EthereumDefinitionType.NETWORK, message=network
|
||||
)
|
||||
proof, signature = sign_payload(payload, [])
|
||||
return payload + proof + signature
|
||||
|
||||
|
||||
def encode_token(
|
||||
token: messages.EthereumTokenInfo | None = None,
|
||||
symbol: str = "FakeTok",
|
||||
decimals: int = 18,
|
||||
address: t.AnyStr = b"",
|
||||
chain_id: int = 0,
|
||||
name: str = "Fake token",
|
||||
) -> bytes:
|
||||
if token is None:
|
||||
if isinstance(address, str):
|
||||
if address.startswith("0x"):
|
||||
address = address[2:]
|
||||
address = bytes.fromhex(address) # type: ignore (typechecker is lying)
|
||||
token = make_token(symbol, decimals, address, chain_id, name) # type: ignore (typechecker is lying)
|
||||
payload = make_payload(
|
||||
data_type=messages.EthereumDefinitionType.TOKEN, message=token
|
||||
)
|
||||
proof, signature = sign_payload(payload, [])
|
||||
return payload + proof + signature
|
||||
|
||||
|
||||
def make_defs(
|
||||
network: bytes | None, token: bytes | None
|
||||
) -> messages.EthereumDefinitions:
|
||||
return messages.EthereumDefinitions(
|
||||
encoded_network=network,
|
||||
encoded_token=token,
|
||||
)
|
@ -0,0 +1,231 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from trezorlib import ethereum
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.exceptions import TrezorFailure
|
||||
from trezorlib.tools import parse_path
|
||||
|
||||
from . import common
|
||||
from .test_sign_typed_data import DATA as TYPED_DATA
|
||||
|
||||
pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum]
|
||||
|
||||
ERC20_OPERATION = "a9059cbb000000000000000000000000574bbb36871ba6b78e27f4b4dcfb76ea0091880b0000000000000000000000000000000000000000000000000000000000000123"
|
||||
ERC20_BUILTIN_TOKEN = "0xdac17f958d2ee523a2206206994597c13d831ec7" # USDT
|
||||
ERC20_FAKE_ADDRESS = "0xdddddddddddddddddddddddddddddddddddddddd"
|
||||
|
||||
DEFAULT_TX_PARAMS = {
|
||||
"nonce": 0x0,
|
||||
"gas_price": 0x4A817C800,
|
||||
"gas_limit": 0x5208,
|
||||
"value": 0x2540BE400,
|
||||
"to": "0x1d1c328764a41bda0492b66baa30c4a339ff85ef",
|
||||
"chain_id": 1,
|
||||
"n": parse_path("m/44h/60h/0h/0/0"),
|
||||
}
|
||||
|
||||
DEFAULT_ERC20_PARAMS = {
|
||||
"nonce": 0x0,
|
||||
"gas_price": 0x4A817C800,
|
||||
"gas_limit": 0x5208,
|
||||
"value": 0x0,
|
||||
"chain_id": 1,
|
||||
"n": parse_path("m/44h/60h/0h/0/0"),
|
||||
"data": bytes.fromhex(ERC20_OPERATION),
|
||||
}
|
||||
|
||||
|
||||
def test_builtin(client: Client) -> None:
|
||||
# Ethereum (SLIP-44 60, chain_id 1) will sign without any definitions provided
|
||||
ethereum.sign_tx(client, **DEFAULT_TX_PARAMS)
|
||||
|
||||
|
||||
def test_chain_id_allowed(client: Client) -> None:
|
||||
# Any chain id is allowed as long as the SLIP44 stays the same
|
||||
params = DEFAULT_TX_PARAMS.copy()
|
||||
params.update(chain_id=222222)
|
||||
ethereum.sign_tx(client, **params)
|
||||
|
||||
|
||||
def test_slip44_disallowed(client: Client) -> None:
|
||||
# SLIP44 is not allowed without a valid network definition
|
||||
params = DEFAULT_TX_PARAMS.copy()
|
||||
params.update(n=parse_path("m/44h/66666h/0h/0/0"))
|
||||
with pytest.raises(TrezorFailure, match="Forbidden key path"):
|
||||
ethereum.sign_tx(client, **params)
|
||||
|
||||
|
||||
def test_slip44_external(client: Client) -> None:
|
||||
# to use a non-default SLIP44, a valid network definition must be provided
|
||||
network = common.encode_network(chain_id=66666, slip44=66666)
|
||||
params = DEFAULT_TX_PARAMS.copy()
|
||||
params.update(n=parse_path("m/44h/66666h/0h/0/0"), chain_id=66666)
|
||||
ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None))
|
||||
|
||||
|
||||
def test_slip44_external_disallowed(client: Client) -> None:
|
||||
# network definition does not allow a different SLIP44
|
||||
network = common.encode_network(chain_id=66666, slip44=66666)
|
||||
params = DEFAULT_TX_PARAMS.copy()
|
||||
params.update(n=parse_path("m/44h/55555h/0h/0/0"), chain_id=66666)
|
||||
with pytest.raises(TrezorFailure, match="Forbidden key path"):
|
||||
ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None))
|
||||
|
||||
|
||||
def test_chain_id_mismatch(client: Client) -> None:
|
||||
# network definition for a different chain id will be rejected
|
||||
network = common.encode_network(chain_id=66666, slip44=60)
|
||||
params = DEFAULT_TX_PARAMS.copy()
|
||||
params.update(chain_id=55555)
|
||||
with pytest.raises(TrezorFailure, match="Network definition mismatch"):
|
||||
ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None))
|
||||
|
||||
|
||||
def test_definition_does_not_override_builtin(client: Client) -> None:
|
||||
# The builtin definition for Ethereum (SLIP44 60, chain_id 1) will be used
|
||||
# even if a valid definition with a different SLIP44 is provided
|
||||
network = common.encode_network(chain_id=1, slip44=66666)
|
||||
params = DEFAULT_TX_PARAMS.copy()
|
||||
params.update(n=parse_path("m/44h/66666h/0h/0/0"), chain_id=1)
|
||||
with pytest.raises(TrezorFailure, match="Forbidden key path"):
|
||||
ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None))
|
||||
|
||||
# TODO: test that the builtin definition will not show different symbol
|
||||
|
||||
|
||||
# TODO: figure out how to test acceptance of a token definition
|
||||
# all tokens are currently accepted, we would need to check the screenshots
|
||||
|
||||
|
||||
def test_builtin_token(client: Client) -> None:
|
||||
# The builtin definition for USDT (ERC20) will be used even if not provided
|
||||
params = DEFAULT_ERC20_PARAMS.copy()
|
||||
params.update(to=ERC20_BUILTIN_TOKEN)
|
||||
ethereum.sign_tx(client, **params)
|
||||
# TODO check that USDT symbol is shown
|
||||
|
||||
|
||||
# TODO: test_builtin_token_not_overriden (builtin definition is used even if a custom one is provided)
|
||||
|
||||
|
||||
def test_external_token(client: Client) -> None:
|
||||
# A valid token definition must be provided to use a non-builtin token
|
||||
token = common.encode_token(address=ERC20_FAKE_ADDRESS, chain_id=1, decimals=8)
|
||||
params = DEFAULT_ERC20_PARAMS.copy()
|
||||
params.update(to=ERC20_FAKE_ADDRESS)
|
||||
ethereum.sign_tx(client, **params, definitions=common.make_defs(None, token))
|
||||
# TODO check that FakeTok symbol is shown
|
||||
|
||||
|
||||
def test_external_chain_without_token(client: Client) -> None:
|
||||
# when using an external chains, unknown tokens are allowed
|
||||
network = common.encode_network(chain_id=66666, slip44=60)
|
||||
params = DEFAULT_ERC20_PARAMS.copy()
|
||||
params.update(to=ERC20_BUILTIN_TOKEN, chain_id=66666)
|
||||
ethereum.sign_tx(client, **params, definitions=common.make_defs(network, None))
|
||||
# TODO check that UNKN token is used, FAKE network
|
||||
|
||||
|
||||
def test_external_chain_token_ok(client: Client) -> None:
|
||||
# when providing an external chain and matching token, everything works
|
||||
network = common.encode_network(chain_id=66666, slip44=60)
|
||||
token = common.encode_token(address=ERC20_FAKE_ADDRESS, chain_id=66666, decimals=8)
|
||||
params = DEFAULT_ERC20_PARAMS.copy()
|
||||
params.update(to=ERC20_FAKE_ADDRESS, chain_id=66666)
|
||||
ethereum.sign_tx(client, **params, definitions=common.make_defs(network, token))
|
||||
# TODO check that FakeTok is used, FAKE network
|
||||
|
||||
|
||||
def test_external_chain_token_mismatch(client: Client) -> None:
|
||||
# when providing external defs, we explicitly allow, but not use, tokens
|
||||
# from other chains
|
||||
network = common.encode_network(chain_id=66666, slip44=60)
|
||||
token = common.encode_token(address=ERC20_FAKE_ADDRESS, chain_id=55555, decimals=8)
|
||||
params = DEFAULT_ERC20_PARAMS.copy()
|
||||
params.update(to=ERC20_FAKE_ADDRESS, chain_id=66666)
|
||||
ethereum.sign_tx(client, **params, definitions=common.make_defs(network, token))
|
||||
# TODO check that UNKN is used for token, FAKE for network
|
||||
|
||||
|
||||
def _call_getaddress(client: Client, slip44: int, network: bytes | None) -> None:
|
||||
ethereum.get_address(
|
||||
client,
|
||||
parse_path(f"m/44h/{slip44}h/0h"),
|
||||
show_display=False,
|
||||
encoded_network=network,
|
||||
)
|
||||
|
||||
|
||||
def _call_signmessage(client: Client, slip44: int, network: bytes | None) -> None:
|
||||
ethereum.sign_message(
|
||||
client,
|
||||
parse_path(f"m/44h/{slip44}h/0h"),
|
||||
b"hello",
|
||||
encoded_network=network,
|
||||
)
|
||||
|
||||
|
||||
def _call_sign_typed_data(client: Client, slip44: int, network: bytes | None) -> None:
|
||||
ethereum.sign_typed_data(
|
||||
client,
|
||||
parse_path(f"m/44h/{slip44}h/0h/0/0"),
|
||||
TYPED_DATA,
|
||||
metamask_v4_compat=True,
|
||||
definitions=common.make_defs(network, None),
|
||||
)
|
||||
|
||||
|
||||
def _call_sign_typed_data_hash(
|
||||
client: Client, slip44: int, network: bytes | None
|
||||
) -> None:
|
||||
ethereum.sign_typed_data_hash(
|
||||
client,
|
||||
parse_path(f"m/44h/{slip44}h/0h/0/0"),
|
||||
b"\x00" * 32,
|
||||
b"\xff" * 32,
|
||||
encoded_network=network,
|
||||
)
|
||||
|
||||
|
||||
MethodType = Callable[[Client, int, "bytes | None"], None]
|
||||
|
||||
|
||||
METHODS = (
|
||||
_call_getaddress,
|
||||
_call_signmessage,
|
||||
pytest.param(_call_sign_typed_data, marks=pytest.mark.skip_t1),
|
||||
pytest.param(_call_sign_typed_data_hash, marks=pytest.mark.skip_t2),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", METHODS)
|
||||
def test_method_builtin(client: Client, method: MethodType) -> None:
|
||||
# calling a method with a builtin slip44 will work
|
||||
method(client, 60, None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", METHODS)
|
||||
def test_method_def_missing(client: Client, method: MethodType) -> None:
|
||||
# calling a method with a slip44 that has no definition will fail
|
||||
with pytest.raises(TrezorFailure, match="Forbidden key path"):
|
||||
method(client, 66666, None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", METHODS)
|
||||
def test_method_external(client: Client, method: MethodType) -> None:
|
||||
# calling a method with a slip44 that has an external definition will work
|
||||
network = common.encode_network(slip44=66666)
|
||||
method(client, 66666, network)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", METHODS)
|
||||
def test_method_external_mismatch(client: Client, method: MethodType) -> None:
|
||||
# calling a method with a slip44 that has an external definition that does not match
|
||||
# the slip44 will fail
|
||||
network = common.encode_network(slip44=77777)
|
||||
with pytest.raises(TrezorFailure, match="Network definition mismatch"):
|
||||
method(client, 66666, network)
|
@ -0,0 +1,131 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from hashlib import sha256
|
||||
|
||||
import pytest
|
||||
|
||||
from trezorlib import ethereum
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.exceptions import TrezorFailure
|
||||
from trezorlib.messages import EthereumDefinitionType
|
||||
from trezorlib.tools import parse_path
|
||||
|
||||
from .common import make_defs, make_network, make_payload, make_token, sign_payload
|
||||
from .test_definitions import DEFAULT_ERC20_PARAMS, ERC20_FAKE_ADDRESS
|
||||
|
||||
pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum]
|
||||
|
||||
|
||||
def fails(client: Client, network: bytes, match: str) -> None:
|
||||
with pytest.raises(TrezorFailure, match=match):
|
||||
ethereum.get_address(
|
||||
client,
|
||||
parse_path("m/44h/666666h/0h"),
|
||||
show_display=False,
|
||||
encoded_network=network,
|
||||
)
|
||||
|
||||
|
||||
def test_short_message(client: Client) -> None:
|
||||
fails(client, b"\x00", "Invalid Ethereum definition")
|
||||
|
||||
|
||||
def test_mangled_signature(client: Client) -> None:
|
||||
payload = make_payload()
|
||||
proof, signature = sign_payload(payload, [])
|
||||
bad_signature = signature[:-1] + b"\xff"
|
||||
fails(client, payload + proof + bad_signature, "Invalid definition signature")
|
||||
|
||||
|
||||
def test_not_enough_signatures(client: Client) -> None:
|
||||
payload = make_payload()
|
||||
proof, signature = sign_payload(payload, [], threshold=1)
|
||||
fails(client, payload + proof + signature, "Invalid definition signature")
|
||||
|
||||
|
||||
def test_missing_signature(client: Client) -> None:
|
||||
payload = make_payload()
|
||||
proof, _ = sign_payload(payload, [])
|
||||
fails(client, payload + proof, "Invalid Ethereum definition")
|
||||
|
||||
|
||||
def test_mangled_payload(client: Client) -> None:
|
||||
payload = make_payload()
|
||||
proof, signature = sign_payload(payload, [])
|
||||
bad_payload = payload[:-1] + b"\xff"
|
||||
fails(client, bad_payload + proof + signature, "Invalid definition signature")
|
||||
|
||||
|
||||
def test_proof_length_mismatch(client: Client) -> None:
|
||||
payload = make_payload()
|
||||
_, signature = sign_payload(payload, [])
|
||||
bad_proof = b"\x01"
|
||||
fails(client, payload + bad_proof + signature, "Invalid Ethereum definition")
|
||||
|
||||
|
||||
def test_bad_proof(client: Client) -> None:
|
||||
payload = make_payload()
|
||||
proof, signature = sign_payload(payload, [sha256(b"x").digest()])
|
||||
bad_proof = proof[:-1] + b"\xff"
|
||||
fails(client, payload + bad_proof + signature, "Invalid definition signature")
|
||||
|
||||
|
||||
def test_trimmed_proof(client: Client) -> None:
|
||||
payload = make_payload()
|
||||
proof, signature = sign_payload(payload, [])
|
||||
bad_proof = proof[:-1]
|
||||
fails(client, payload + bad_proof + signature, "Invalid Ethereum definition")
|
||||
|
||||
|
||||
def test_bad_prefix(client: Client) -> None:
|
||||
payload = make_payload()
|
||||
payload = b"trzd2" + payload[5:]
|
||||
proof, signature = sign_payload(payload, [])
|
||||
fails(client, payload + proof + signature, "Invalid Ethereum definition")
|
||||
|
||||
|
||||
def test_bad_type(client: Client) -> None:
|
||||
# assuming we expect a network definition
|
||||
payload = make_payload(data_type=EthereumDefinitionType.TOKEN, message=make_token())
|
||||
proof, signature = sign_payload(payload, [])
|
||||
fails(client, payload + proof + signature, "Definition type mismatch")
|
||||
|
||||
|
||||
def test_outdated(client: Client) -> None:
|
||||
payload = make_payload(timestamp=0)
|
||||
proof, signature = sign_payload(payload, [])
|
||||
fails(client, payload + proof + signature, "Definition is outdated")
|
||||
|
||||
|
||||
def test_malformed_protobuf(client: Client) -> None:
|
||||
payload = make_payload(message=b"\x00")
|
||||
proof, signature = sign_payload(payload, [])
|
||||
fails(client, payload + proof + signature, "Invalid Ethereum definition")
|
||||
|
||||
|
||||
def test_protobuf_mismatch(client: Client) -> None:
|
||||
payload = make_payload(
|
||||
data_type=EthereumDefinitionType.NETWORK, message=make_token()
|
||||
)
|
||||
proof, signature = sign_payload(payload, [])
|
||||
fails(client, payload + proof + signature, "Invalid Ethereum definition")
|
||||
|
||||
payload = make_payload(
|
||||
data_type=EthereumDefinitionType.TOKEN, message=make_network()
|
||||
)
|
||||
proof, signature = sign_payload(payload, [])
|
||||
# have to do this manually to invoke a method that eats token definitions
|
||||
with pytest.raises(TrezorFailure, match="Invalid Ethereum definition"):
|
||||
params = DEFAULT_ERC20_PARAMS.copy()
|
||||
params.update(to=ERC20_FAKE_ADDRESS)
|
||||
ethereum.sign_tx(
|
||||
client,
|
||||
**params,
|
||||
definitions=make_defs(None, payload + proof + signature),
|
||||
)
|
||||
|
||||
|
||||
def test_trailing_garbage(client: Client) -> None:
|
||||
payload = make_payload()
|
||||
proof, signature = sign_payload(payload, [])
|
||||
fails(client, payload + proof + signature + b"\x00", "Invalid Ethereum definition")
|
Loading…
Reference in new issue