from common import *

if not utils.BITCOIN_ONLY:
    from apps.monero.xmr import crypto, crypto_helpers, clsag
    from apps.monero.xmr.serialize_messages.tx_ct_key import CtKey
    from trezor.crypto import monero as tcry
    from trezor.crypto import random
    import ubinascii

    point_mul8_into = tcry.ge25519_mul8


class TmpKey:
    """Lightweight ring member holding an encoded destination key and commitment.

    Both fields are stored exactly as passed in; in this module they are the
    byte encodings of points (from ``encodepoint`` or hex test vectors).
    """

    def __init__(self, d, c):
        self.commitment = c
        self.dest = d


@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestMoneroClsag(unittest.TestCase):
    """Tests for Monero CLSAG ring signatures.

    Signatures come either from ``clsag.generate_clsag_simple`` over random
    rings (``gen_clsag_sig``) or from fixed hex test vectors, and are checked
    with the test-local verifier ``verify_clsag``.
    """

    def verify_clsag(self, msg, ss, sc1, sI, sD, pubs, C_offset):
        """Verify a CLSAG signature; raise ``ValueError`` if it is invalid.

        :param msg: message bytes bound into every round challenge
        :param ss: response scalars, one per ring member
        :param sc1: initial challenge scalar
        :param sI: key image point
        :param sD: auxiliary commitment key image point (used as ``8*D``)
        :param pubs: ring members with encoded ``dest`` and ``commitment``
        :param C_offset: pseudo-output commitment point
        :raises ValueError: if the ring is empty, ``ss`` does not hold exactly
            one scalar per ring member, or the challenge chain does not close
        """
        n = len(pubs)
        if n == 0:
            # With no rounds the challenge would trivially equal sc1.
            raise ValueError("Empty ring")
        if len(ss) != n:
            # Otherwise a short ``ss`` raises IndexError and a long one has its
            # extra scalars silently ignored.
            raise ValueError("Signature length mismatch")

        c = crypto.Scalar()
        D_8 = crypto.Point()
        tmp_bf = bytearray(32)
        C_offset_bf = crypto_helpers.encodepoint(C_offset)

        crypto.sc_copy(c, sc1)
        point_mul8_into(D_8, sD)

        # Aggregation hashes, order: domain, P_0..P_{n-1}, C_0..C_{n-1}, I, D, C_offset
        hsh_P = crypto_helpers.get_keccak()
        hsh_C = crypto_helpers.get_keccak()
        hsh_P.update(clsag._HASH_KEY_CLSAG_AGG_0)
        hsh_C.update(clsag._HASH_KEY_CLSAG_AGG_1)

        def hsh_PC(x):
            # Both aggregation hashes absorb identical data after the domain tag.
            hsh_P.update(x)
            hsh_C.update(x)

        for x in pubs:
            hsh_PC(x.dest)
        for x in pubs:
            hsh_PC(x.commitment)

        hsh_PC(crypto.encodepoint_into(tmp_bf, sI))
        hsh_PC(crypto.encodepoint_into(tmp_bf, sD))
        hsh_PC(C_offset_bf)
        mu_P = crypto_helpers.decodeint(hsh_P.digest())
        mu_C = crypto_helpers.decodeint(hsh_C.digest())

        # Round-hash prefix shared by every round:
        # domain, P_0..P_{n-1}, C_0..C_{n-1}, C_offset, message (then L, R per round)
        c_to_hash = crypto_helpers.get_keccak()
        c_to_hash.update(clsag._HASH_KEY_CLSAG_ROUND)
        for x in pubs:
            c_to_hash.update(x.dest)
        for x in pubs:
            c_to_hash.update(x.commitment)
        c_to_hash.update(C_offset_bf)
        c_to_hash.update(msg)

        c_p = crypto.Scalar()
        c_c = crypto.Scalar()
        L = crypto.Point()
        R = crypto.Point()
        tmp_pt = crypto.Point()
        for pub, s in zip(pubs, ss):
            crypto.sc_mul_into(c_p, mu_P, c)
            crypto.sc_mul_into(c_c, mu_C, c)

            # L = s*G + c_p*P_i + c_c*(C_i - C_offset)
            C_P = crypto.point_sub_into(
                None, crypto.decodepoint_into(tmp_pt, pub.commitment), C_offset
            )
            crypto.add_keys2_into(
                L, s, c_p, crypto.decodepoint_into(tmp_pt, pub.dest)
            )
            crypto.point_add_into(L, L, crypto.scalarmult_into(tmp_pt, C_P, c_c))

            # R = s*Hp(P_i) + c_p*I + c_c*(8*D)
            HP = crypto.hash_to_point_into(None, pub.dest)
            crypto.add_keys3_into(R, s, HP, c_p, sI)
            crypto.point_add_into(R, R, crypto.scalarmult_into(tmp_pt, D_8, c_c))

            # Next challenge: c = H(prefix || L || R)
            chasher = c_to_hash.copy()
            chasher.update(crypto.encodepoint_into(tmp_bf, L))
            chasher.update(crypto.encodepoint_into(tmp_bf, R))
            crypto.decodeint_into(c, chasher.digest())

        # After one full pass around the ring the challenge must return to sc1.
        res = crypto.sc_sub_into(None, c, sc1)
        if not crypto.sc_eq(res, crypto.Scalar(0)):
            raise ValueError("Signature error")

    def gen_clsag_test(self, ring_size=11, index=None):
        """Sign over a fresh random ring of ``ring_size`` members and verify it."""
        res = self.gen_clsag_sig(ring_size=ring_size, index=index)
        msg, scalars, sc1, sI, sD, ring2, Cp = res
        self.verify_clsag(msg, scalars, sc1, sI, sD, ring2, Cp)

    def gen_clsag_sig(self, ring_size=11, index=None):
        """Build a random ring, sign a random message and return the parts.

        The real member (key ``priv``, commitment ``msk*G + amnt*H``) is
        inserted at ``index`` among ``ring_size - 1`` random decoys; when
        ``index`` is ``None`` a random slot is chosen.

        :returns: ``(msg, scalars, sc1, sI, sD, ring, Cp)`` suitable for
            ``verify_clsag``
        """
        msg = random.bytes(32)
        amnt = crypto.Scalar(random.uniform(0xFFFFFF) + 12)
        priv = crypto.random_scalar()
        msk = crypto.random_scalar()
        alpha = crypto.random_scalar()
        P = crypto.scalarmult_base_into(None, priv)
        C = crypto.add_keys2_into(None, msk, amnt, crypto.xmr_H())
        Cp = crypto.add_keys2_into(None, alpha, amnt, crypto.xmr_H())

        ring = []
        for _ in range(ring_size - 1):
            tk = TmpKey(
                crypto_helpers.encodepoint(
                    crypto.scalarmult_base_into(None, crypto.random_scalar())
                ),
                crypto_helpers.encodepoint(
                    crypto.scalarmult_base_into(None, crypto.random_scalar())
                ),
            )
            ring.append(tk)

        if index is None:
            # Insertion positions are 0..len(ring) inclusive, so the real key
            # must also be able to land in the last slot; uniform(n) -> 0..n-1.
            index = random.uniform(len(ring) + 1)
        ring.insert(index, TmpKey(crypto_helpers.encodepoint(P), crypto_helpers.encodepoint(C)))
        ring2 = list(ring)
        mg_buffer = []

        # Sanity checks: the real member matches priv, and C - Cp == (msk - alpha)*G.
        self.assertTrue(
            crypto.point_eq(
                crypto.scalarmult_base_into(None, priv),
                crypto_helpers.decodepoint(ring[index].dest),
            )
        )
        self.assertTrue(
            crypto.point_eq(
                crypto.scalarmult_base_into(None, crypto.sc_sub_into(None, msk, alpha)),
                crypto.point_sub_into(
                    None, crypto_helpers.decodepoint(ring[index].commitment), Cp
                ),
            )
        )

        clsag.generate_clsag_simple(
            msg, ring, CtKey(priv, msk), alpha, Cp, index, mg_buffer,
        )

        # mg_buffer layout as consumed here: [0] is skipped, [1:-2] are the
        # response scalars, [-2] is c1 and [-1] is D.
        # NOTE(review): meaning of mg_buffer[0] not visible here — confirm.
        sD = crypto_helpers.decodepoint(mg_buffer[-1])
        sc1 = crypto_helpers.decodeint(mg_buffer[-2])
        scalars = [crypto_helpers.decodeint(x) for x in mg_buffer[1:-2]]
        H = crypto.Point()
        sI = crypto.Point()

        crypto.hash_to_point_into(H, crypto_helpers.encodepoint(P))
        crypto.scalarmult_into(sI, H, priv)  # I = p*H
        return msg, scalars, sc1, sI, sD, ring2, Cp

    def verify_monero_generated(self, clsag):
        """Decode a hex CLSAG test vector and verify it with ``verify_clsag``.

        :param clsag: dict with hex strings under ``msg``, ``sI``, ``sD``,
            ``sc1``, ``cout``, a list ``ss`` of hex scalars and a list ``ring``
            of ``[dest, commitment]`` hex pairs. The parameter name shadows the
            ``clsag`` module only within this method, which does not use it.
        """
        msg = ubinascii.unhexlify(clsag["msg"])
        sI = crypto_helpers.decodepoint(ubinascii.unhexlify(clsag["sI"]))
        sD = crypto_helpers.decodepoint(ubinascii.unhexlify(clsag["sD"]))
        sc1 = crypto_helpers.decodeint(ubinascii.unhexlify(clsag["sc1"]))
        Cout = crypto_helpers.decodepoint(ubinascii.unhexlify(clsag["cout"]))
        scalars = [crypto_helpers.decodeint(ubinascii.unhexlify(x)) for x in clsag["ss"]]
        ring = [
            TmpKey(ubinascii.unhexlify(dest), ubinascii.unhexlify(commitment))
            for dest, commitment in clsag["ring"]
        ]

        self.verify_clsag(msg, scalars, sc1, sI, sD, ring, Cout)

    def test_monero_generated_clsag_01(self):
        vector = {
            "msg": "0100000000000000000000000000000000000000000000000000000000000000",
            "cout": "8e3afb92d8ae1264417489259e38f7205a62baea86ae9592cd91988b9cc48102",
            "sI": "a1c7f4a316ddd16374fe495d402be60566047ae5a1352554e98ebff118705303",
            "sD": "cd80b5c7f3f597de6e20bcef669a4ba9eb3eb89ead12ab1c24c92acd609afcb2",
            "sc1": "cf4f48ed60771d4e8d02e9e0af37281ceeb66573bd528ac256a7e17794a75602",
            "ss": [
                "aaeffa564b5b0ff1e4ed72c9b595cd0241ac64eeb41b902a35688e369922d704",
                "1defc134a853252d734d19b29d8f2fabc85a8ae24ebcf8f050d4daf8a335e901",
                "cdf9ac576f0c7ceb7eb22c1a1254a801d0d2915e59870be8b1ab68cd1281120d",
                "d1973493d8224aaa9732878b9a88d448ea16185f94e5bafd82816277682fa108",
                "a130e076845e512687575942bf3694bcb44eb19eb1181af9a1fc2254949b7c0f",
                "26f5b6ea154d6bd4a969c742563d75f1bfcd5ded3af78669e45ba95e76c48605",
                "5b695d3be46b826fd11e043028dee2aa25cf36910e86537fcd1cd3f5cb49650e",
                "37e811ebb4a2b9c35556b4af911a03a93468f599956c034092c3ece9e1169208",
                "a361ceec9aacd65da6d3e686fbcd0c1aef26096321be7f01653157ee6096a201",
                "f9b762ef1df69bb12ca76a97dce11f7840b8ec63c3dc2683f7ae71cb79c49103",
                "ea010fa6a35f3bd3d7899a7a2a8df4d3ef9c9dfbbd56fe43ff5c7442821d3508",
            ],
            "ring": [
                [
                    "241c0295b4c3a149e5ac7997963e125d0fc6cc8adad9349df3b01ff611936c87",
                    "3a24a4c418ccb2ceb83672d01534a73ff1e9f548937d5ddd7f1971c9b398868c",
                ],
                [
                    "ec432ccfbf730077cb2d8c59968e2796148a590eec7928ecf268d883ced0de5b",
                    "2973d6e9c27538fd0f7c003e014311e9403dcb6e7d86b66df65176a579943bda",
                ],
                [
                    "0cfeafc313a6a2e60110778d53d61fa1705e9049b8afba0f51c1127f6855c07f",
                    "ffa4d4c77202907832294243a96886920017b67fbe5b3800bcc1457c4a4a1ff0",
                ],
                [
                    "bd4eca22dc010a214524901b88bdda27e427217ff784c47520ee76743caba036",
                    "e07135f8398459133c2969184e70610b9b995f73e44acf54b6eaed6227e68bbc",
                ],
                [
                    "73c8d57d0128c99fc2ab0be8cee5fe5c1288b98e51822a6681846035fcc53fea",
                    "2987499fde3f4353013206d89fe2d7c6ad3cd9a66c9a36d17749e39112513572",
                ],
                [
                    "385c538901b79c6bd2ddea5191e808b1414c9dfdcaf424841d843dd788cb89ad",
                    "ec5f987fe138c6cb1d47ff75d77852b7c0a94ba1f0b93d22c0463f75986605bd",
                ],
                [
                    "fed06cb761745a6f087d1af13f84670ecbf1523d72b46e8bd0698d1cdfb398bc",
                    "5d81df981fb885f947b9404cb63cb06fe4e001be281f2bdfb3c638d54ec6e49e",
                ],
                [
                    "667d1edfb83a17bd81fcf7831362b6c9038f26340ee1fe56d41f62cb0b32e989",
                    "e9ceba97867b43cd5420c94fa61cc5f11e440e261df74dfc8b1c07ec4b13aa3c",
                ],
                [
                    "e1e76da5bd52fc065f9af40efde5f733f9673974d14c6af8d200d8576ac3a90d",
                    "97358d6ddad38b2707fb864bfcaaab935851af66d50bcbac569d159d740bdf71",
                ],
                [
                    "4fd5d0db88283c63905d5095a76b11a75337e43f403f8469175ba9c49741552e",
                    "af0ab85872a6355d5c82c1f9a2a41488146e19b272887a1f7385cc26bef3f1d8",
                ],
                [
                    "37e1a4c49a22340fa5ac2c22c1b7a891e7191cdc53911700a317c0d8b92bbf4e",
                    "5c89d29dad77de7d76ece8bb81c7c8cd15008f63c5a14ab1c984b3833e7bbce3",
                ],
            ],
        }
        self.verify_monero_generated(vector)

    def test_monero_generated_clsag_02(self):
        vector = {
            "msg": "0100000000000000000000000000000000000000000000000000000000000000",
            "cout": "fdf2503d3217dbf73ababd16f5ab5a63d64c047db1d02b0888a50d2570f3a793",
            "sI": "917fdd3086c056503ffdb1840f03c78d48bfe6d9d60b4efb194bd9798d03acaa",
            "sD": "769d0ca9b272ac02c5efad7df6b5c00f2995c99ca80f4597136decba9a0dd36f",
            "sc1": "fe5c7eb39a32d2aea12e6d127d847b72ea810bfbf3d5bbe23c40e7abdd12900e",
            "ss": [
                "da2940c66cc2405032d959325c8804e216f76b36e71b2ae6b76417ed9c10a80a",
                "ca763505c2e5ebacf72098f8cba89ea6826aa448501f03d439c7a838a88bba0e",
                "b2eadee4c121e85b0c2a09d56c665ba19ee8ebc451f1e9e96cf72c874f945104",
                "5a79523fdc0df9a54ab3937c878bd5a02e62bff77efc338728deb060ecda4509",
                "dfadddc51866cde5206269270f44ca2f6350ca0b1328a968773fcacf57031502",
                "a964f3549a10fc8bdb2f8217df0e9b08e90477be19a665b94b73ce417622450b",
                "48e805427109268b04bf378c869501dbebb79c0cbe664bf7eb0ca222376d1c0f",
                "33f36d9a699e92a66d4b9fdf6c1123ae99701b117fbe8f0af9faec51e45eb409",
                "25ef746a03aaf59701d1d47ea3b9e9f092662cebc9d44902ce18e81cc5035f01",
                "2ba3022d4f9b57da7429499715592073f1608cf270318840a5fd3890bbf5950a",
                "8149ec0d965c9881d6a4adedca7d3c9090359dbfae56dbab526be102722aab09",
            ],
            "ring": [
                [
                    "081b048be784e1ff6f3b7ebe602690c27723b5d9952405bcdcbed31d16125067",
                    "6090eccb73d2e1fc7bc7644a4fad04e5fe93d953a1258307c44d5b23cd636bf9",
                ],
                [
                    "e2f0f100f1634d7c7dd5a09bc6dd7ee53506d73536aa743e8ea049528e4cb2aa",
                    "632438f9aeda72eb9c6c434391cf9fa2f71788bea598a5d5729a5d502865932a",
                ],
                [
                    "6744197cfde37ad1901d518f112c0f4d820c23122a016949e300eec2ab88916c",
                    "1b251d5b32e22de29a4f99a0ed1de32754636175075e21b25d7283036eb85541",
                ],
                [
                    "0e86bb7ee0b4728f2fedde7ac5019b54de7b2bb19b44d1864e6346dac6c171ab",
                    "5a3c85e93890f802d4148140733dcdcd676353fce1bd774ce28034fc2ec00253",
                ],
                [
                    "1847ce49d9552651395b2fa80637c131a31036f0bfc5abb63526701cd1a32320",
                    "a9cb55bc24e6e1fb894c511f2edd4b7bda4c75a608657d952e85bab83ec98a52",
                ],
                [
                    "5c5d0b678f5045b0304e3c48027bd7e9ccaee1dac4449ed1f34b204868ca5651",
                    "badf83ccba38f2194f924a4f7fb7c2fd966b1e16c1fddeb3658033aa009febe0",
                ],
                [
                    "81961aa4c241a91d498d8f3057b31373d9fc72b6e7d7f98bf497e3dfe705eeaa",
                    "a0e632fbb801d6bce99ef97d7bb6acd945aff5cd7fab56c0e6fec6900a3babd7",
                ],
                [
                    "cbd89f10ddf152bd9c756d145ef4cda1d56a31f1e1936759bee04b7a8a815c76",
                    "8b835b8180f36e79ba79528e0d3401f439cc1c7f99e4bcfb3cb4aa2b60b1afc1",
                ],
                [
                    "a7bc55e955a825730f5dcdc3f8126717d7647cbca8a6b90e08b77269aeed3533",
                    "8da31e80698c9b5181b2e8d9773136083a34e3e72c92134d8201d9c368d89284",
                ],
                [
                    "a7902cec90d3f2de25c8ddc87075159fd00f219a51a1e7dcac17c2b8a91887e9",
                    "2b1e848b6649abefbd6b399504a169252358e7ff6bde8fa7a773b9cf0a167069",
                ],
                [
                    "9fc3d5fb7de8cfc59982f7b20f3f5c145ad191088e7f59c10908dc5d55863bee",
                    "b8de2bc9bb46d475007230a92af14afb6f9dd2804b5c31355a282b40ccdadc92",
                ],
            ],
        }

        self.verify_monero_generated(vector)

    def test_clsag(self):
        # Random positions plus the first, penultimate and last slots, and a
        # minimal two-member ring (ring_size is now honoured by gen_clsag_test).
        self.gen_clsag_test(ring_size=11, index=None)
        self.gen_clsag_test(ring_size=11, index=None)
        self.gen_clsag_test(ring_size=11, index=None)
        self.gen_clsag_test(ring_size=11, index=0)
        self.gen_clsag_test(ring_size=11, index=9)
        self.gen_clsag_test(ring_size=11, index=10)
        self.gen_clsag_test(ring_size=2, index=0)

    # The negative tests below tamper with one input outside ``assertRaises`` so
    # that only the verifier call itself may produce the expected ValueError.

    def test_clsag_invalid_sI(self):
        res = self.gen_clsag_sig(ring_size=11, index=5)
        msg, scalars, sc1, sI, sD, ring2, Cp = res
        sI = point_mul8_into(None, sI)
        with self.assertRaises(ValueError):
            self.verify_clsag(msg, scalars, sc1, sI, sD, ring2, Cp)

    def test_clsag_invalid_sD(self):
        res = self.gen_clsag_sig(ring_size=11, index=5)
        msg, scalars, sc1, sI, sD, ring2, Cp = res
        sD = crypto.scalarmult_base_into(None, crypto.random_scalar())
        with self.assertRaises(ValueError):
            self.verify_clsag(msg, scalars, sc1, sI, sD, ring2, Cp)

    def test_clsag_invalid_P(self):
        res = self.gen_clsag_sig(ring_size=11, index=5)
        msg, scalars, sc1, sI, sD, ring2, Cp = res
        ring2[5].dest = crypto_helpers.encodepoint(
            point_mul8_into(None, crypto_helpers.decodepoint(ring2[5].dest))
        )
        with self.assertRaises(ValueError):
            self.verify_clsag(msg, scalars, sc1, sI, sD, ring2, Cp)

    def test_clsag_invalid_C(self):
        res = self.gen_clsag_sig(ring_size=11, index=5)
        msg, scalars, sc1, sI, sD, ring2, Cp = res
        ring2[5].commitment = crypto_helpers.encodepoint(
            point_mul8_into(None, crypto_helpers.decodepoint(ring2[5].commitment))
        )
        with self.assertRaises(ValueError):
            self.verify_clsag(msg, scalars, sc1, sI, sD, ring2, Cp)

    def test_clsag_invalid_Cp(self):
        res = self.gen_clsag_sig(ring_size=11, index=5)
        msg, scalars, sc1, sI, sD, ring2, Cp = res
        Cp = crypto.point_add_into(
            None, Cp, crypto.scalarmult_base_into(None, crypto.Scalar(1))
        )
        with self.assertRaises(ValueError):
            self.verify_clsag(msg, scalars, sc1, sI, sD, ring2, Cp)

    def test_clsag_invalid_index(self):
        res = self.gen_clsag_sig(ring_size=11, index=5)
        msg, scalars, sc1, sI, sD, ring2, Cp = res
        ring2[5], ring2[6] = ring2[6], ring2[5]
        with self.assertRaises(ValueError):
            self.verify_clsag(msg, scalars, sc1, sI, sD, ring2, Cp)


# Allow running this test module directly (outside a test runner).
if __name__ == "__main__":
    unittest.main()