1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-07-30 18:38:27 +00:00

feat(core): Include address_n in address MAC.

This commit is contained in:
Andrew Kozlik 2025-04-15 09:56:19 +02:00 committed by Ioan Bizău
parent f699ecf65e
commit cd1b194be8
7 changed files with 36 additions and 19 deletions

View File

@ -102,7 +102,7 @@ async def get_address(msg: GetAddress, keychain: Keychain, coin: CoinInfo) -> Ad
keychain.is_in_keychain(address_n)
and validate_path_against_script_type(coin, msg)
):
mac = get_address_mac(address, coin.slip44, keychain)
mac = get_address_mac(address, coin.slip44, address_n, keychain)
if msg.show_display:
path = address_n_to_str(address_n)

View File

@ -4,22 +4,25 @@ from trezor import utils
if TYPE_CHECKING:
from apps.common.keychain import Keychain
from apps.common.paths import Bip32Path
_ADDRESS_MAC_KEY_PATH = [b"SLIP-0024", b"Address MAC key"]
def check_address_mac(
address: str, mac: bytes, slip44: int, keychain: Keychain
address: str, mac: bytes, slip44: int, address_n: Bip32Path, keychain: Keychain
) -> None:
from trezor import wire
from trezor.crypto import hashlib
expected_mac = get_address_mac(address, slip44, keychain)
expected_mac = get_address_mac(address, slip44, address_n, keychain)
if len(mac) != hashlib.sha256.digest_size or not utils.consteq(expected_mac, mac):
raise wire.DataError("Invalid address MAC.")
def get_address_mac(address: str, slip44: int, keychain: Keychain) -> bytes:
def get_address_mac(
address: str, slip44: int, address_n: Bip32Path, keychain: Keychain
) -> bytes:
from trezor.crypto import hmac
from .writers import write_bytes_unchecked, write_compact_size, write_uint32_le
@ -31,6 +34,9 @@ def get_address_mac(address: str, slip44: int, keychain: Keychain) -> bytes:
mac = utils.HashWriter(hmac(hmac.SHA256, node.key()))
address_bytes = address.encode()
write_uint32_le(mac, slip44)
write_compact_size(mac, len(address_n))
for n in address_n:
write_uint32_le(mac, n)
write_compact_size(mac, len(address_bytes))
write_bytes_unchecked(mac, address_bytes)
return mac.get_digest()

View File

@ -59,12 +59,16 @@ class PaymentRequestVerifier:
elif m.refund_memo is not None:
memo = m.refund_memo
# Unlike in a coin purchase memo, the coin type is implied by the payment request.
check_address_mac(memo.address, memo.mac, slip44_id, keychain)
check_address_mac(
memo.address, memo.mac, slip44_id, memo.address_n, keychain
)
writers.write_uint32_le(self.h_pr, _MEMO_TYPE_REFUND)
writers.write_bytes_prefixed(self.h_pr, memo.address.encode())
elif m.coin_purchase_memo is not None:
memo = m.coin_purchase_memo
check_address_mac(memo.address, memo.mac, memo.coin_type, keychain)
check_address_mac(
memo.address, memo.mac, memo.coin_type, memo.address_n, keychain
)
writers.write_uint32_le(self.h_pr, _MEMO_TYPE_COIN_PURCHASE)
writers.write_uint32_le(self.h_pr, memo.coin_type)
writers.write_bytes_prefixed(self.h_pr, memo.amount.encode())

View File

@ -34,7 +34,7 @@ async def get_address(
address = address_from_bytes(node.ethereum_pubkeyhash(), defs.network)
slip44_id = address_n[1] # it depends on the network (ETH vs ETC...)
mac = get_address_mac(address, paths.unharden(slip44_id), keychain)
mac = get_address_mac(address, paths.unharden(slip44_id), address_n, keychain)
if msg.show_display:
coin = "ETH"

View File

@ -328,18 +328,20 @@ class TestAddress(unittest.TestCase):
VECTORS = (
(
"Bitcoin",
[H_(44), H_(0), H_(0), 1, 0],
"1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE",
"9cf7c230041d6ed95b8273bd32e023d3f227ec8c44257f6463c743a4b4add028",
"158dd8df21894cc1cb01a33736a50884ecd6d5c2bcc2ffd2398f4d147d19c191",
),
(
"Testnet",
[H_(44), H_(0), H_(0), 1, 0],
"mm6kLYbGEL1tGe4ZA8xacfgRPdW1NLjCbZ",
"4375089e50423505dc3480e6e85b0ba37a52bd1e009db5d260b6329f22c950d9",
"0b1048fcf82a0a08cffc87a8db2e2512e0d1379eb8d15c9adae8672ba2e00be0",
),
)
seed = bip39.seed(" ".join(["all"] * 12), "")
for coin_name, address, mac in VECTORS:
for coin_name, address_n, address, mac in VECTORS:
coin = coins.by_name(coin_name)
mac = unhexlify(mac)
keychain = Keychain(
@ -348,11 +350,13 @@ class TestAddress(unittest.TestCase):
[AlwaysMatchingSchema],
slip21_namespaces=[[b"SLIP-0024"]],
)
self.assertEqual(get_address_mac(address, coin.slip44, keychain), mac)
check_address_mac(address, mac, coin.slip44, keychain)
self.assertEqual(
get_address_mac(address, coin.slip44, address_n, keychain), mac
)
check_address_mac(address, mac, coin.slip44, address_n, keychain)
with self.assertRaises(wire.DataError):
mac = bytes([mac[0] ^ 1]) + mac[1:]
check_address_mac(address, mac, coin.slip44, keychain)
check_address_mac(address, mac, coin.slip44, address_n, keychain)
if __name__ == "__main__":

View File

@ -146,7 +146,7 @@ def test_address_mac(client: Client):
assert resp.address == "1DyHzbQUoQEsLxJn6M7fMD8Xdt1XvNiwNE"
assert (
resp.mac.hex()
== "9cf7c230041d6ed95b8273bd32e023d3f227ec8c44257f6463c743a4b4add028"
== "158dd8df21894cc1cb01a33736a50884ecd6d5c2bcc2ffd2398f4d147d19c191"
)
resp = btc.get_authenticated_address(
@ -155,7 +155,7 @@ def test_address_mac(client: Client):
assert resp.address == "mm6kLYbGEL1tGe4ZA8xacfgRPdW1NLjCbZ"
assert (
resp.mac.hex()
== "4375089e50423505dc3480e6e85b0ba37a52bd1e009db5d260b6329f22c950d9"
== "bc555ef1d74814b26a7a2c8039c87414cdd0730027aee8e038ad236ce57875c1"
)
# Script type mismatch.
@ -174,7 +174,7 @@ def test_altcoin_address_mac(client: Client):
assert resp.address == "LWj6ApswZxay4cJEJES2sGe7fLMLRvvv8h"
assert (
resp.mac.hex()
== "eaf47182d7ae17d2046ec2e204bc5b67477db20a5eaea3cec5393c25664bc4d2"
== "13ae756e50735626639dcaf037e65d62b152593e9f088f8bbf452a3c148b5ea6"
)
resp = btc.get_authenticated_address(
@ -183,7 +183,7 @@ def test_altcoin_address_mac(client: Client):
assert resp.address == "bitcoincash:qzc5q87w069lzg7g3gzx0c8dz83mn7l02scej5aluw"
assert (
resp.mac.hex()
== "46d8e369b499a9dc62eb9e4472f4a12640ae0fb7a63c1a4dde6752123b2b7274"
== "d5f54aea2200f50d1a0419d2faae62a82198f27df63ae4eac56d0e625de174f5"
)
resp = btc.get_authenticated_address(
@ -192,7 +192,7 @@ def test_altcoin_address_mac(client: Client):
assert resp.address == "Fmhtxeh7YdCBkyQF7AQG4QnY8y3rJg89di"
assert (
resp.mac.hex()
== "08d67c5f1ee20fd03f3e5aa26f798574716c122238ac280e33a6f3787d531552"
== "c988b6f968a55b5be918d97f0f672acfd3bdbc91c148bc1a04a0ae4e8753464c"
)

View File

@ -71,7 +71,9 @@ def make_payment_request(
hash_bytes_prefixed(h_pr, memo.text.encode())
elif isinstance(memo, RefundMemo):
msg_memo = messages.RefundMemo(
address=memo.address_resp.address, mac=memo.address_resp.mac
address=memo.address_resp.address,
address_n=memo.address_n,
mac=memo.address_resp.mac,
)
msg_memos.append(messages.PaymentRequestMemo(refund_memo=msg_memo))
memo_type = 2
@ -82,6 +84,7 @@ def make_payment_request(
coin_type=memo.slip44,
amount=memo.amount,
address=memo.address_resp.address,
address_n=memo.address_n,
mac=memo.address_resp.mac,
)
msg_memos.append(messages.PaymentRequestMemo(coin_purchase_memo=msg_memo))