# This file is part of the Trezor project.
#
# Copyright (C) 2012-2019 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.

from unittest.mock import patch
from types import SimpleNamespace

import pytest

from trezorlib import protobuf

# Stand-in for an enum definition: a namespace whose attributes are the
# named integer values.
SimpleEnum = SimpleNamespace(FOO=0, BAR=5, QUUX=13)
# Enum field type, given the enum's name and the tuple of its values.
SimpleEnumType = protobuf.EnumType(
    "SimpleEnum", (SimpleEnum.FOO, SimpleEnum.BAR, SimpleEnum.QUUX)
)

# Decorator that installs SimpleEnum as ``trezorlib.messages.SimpleEnum`` while
# the decorated test runs; ``create=True`` lets the attribute be absent
# beforehand.
with_simple_enum = patch(
    target="trezorlib.messages.SimpleEnum", new=SimpleEnum, create=True
)


class SimpleMessage(protobuf.MessageType):
    """Flat test message: six singular fields followed by three repeated ones."""

    @classmethod
    def get_fields(cls):
        """Return the field table keyed by tag number.

        Each value is a ``(name, type, flags)`` triple; in this message the
        third slot is either ``None`` or ``protobuf.FLAG_REPEATED``.
        """
        repeated = protobuf.FLAG_REPEATED
        fields = {
            1: ("uvarint", protobuf.UVarintType, None),
            2: ("svarint", protobuf.SVarintType, None),
            3: ("bool", protobuf.BoolType, None),
            4: ("bytes", protobuf.BytesType, None),
            5: ("unicode", protobuf.UnicodeType, None),
            6: ("enum", SimpleEnumType, None),
        }
        # Repeated counterparts of the integer, string and enum fields.
        fields.update(
            {
                7: ("rep_int", protobuf.UVarintType, repeated),
                8: ("rep_str", protobuf.UnicodeType, repeated),
                9: ("rep_enum", SimpleEnumType, repeated),
            }
        )
        return fields


class NestedMessage(protobuf.MessageType):
    """Test message embedding SimpleMessage both singly and as a repeated field."""

    @classmethod
    def get_fields(cls):
        """Return the field table keyed by tag number."""
        scalar_field = ("scalar", protobuf.UVarintType, 0)
        nested_field = ("nested", SimpleMessage, 0)
        repeated_field = ("repeated", SimpleMessage, protobuf.FLAG_REPEATED)
        return {1: scalar_field, 2: nested_field, 3: repeated_field}


class RequiredFields(protobuf.MessageType):
    """Test message whose only field carries ``protobuf.FLAG_REQUIRED``."""

    @classmethod
    def get_fields(cls):
        """Return the field table keyed by tag number."""
        required_scalar = ("scalar", protobuf.UVarintType, protobuf.FLAG_REQUIRED)
        return {1: required_scalar}


def test_get_field_type():
    """get_field_type maps every declared field name to its declared type."""
    # smoke test on a single, well-known field
    assert SimpleMessage.get_field_type("bool") is protobuf.BoolType

    # exhaustive check against the field table itself
    for spec in SimpleMessage.get_fields().values():
        field_name = spec[0]
        expected_type = spec[1]
        assert SimpleMessage.get_field_type(field_name) is expected_type


@with_simple_enum
def test_enum_to_str():
    """to_str / from_str convert between enum values and names, rejecting unknowns."""
    # smoke test
    assert SimpleEnumType.to_str(5) == "BAR"

    # every member must survive both directions of the conversion
    members = vars(SimpleEnum)
    for member_name, member_value in members.items():
        assert SimpleEnumType.to_str(member_value) == member_name
        assert SimpleEnumType.from_str(member_name) == member_value

    # unknown name
    with pytest.raises(TypeError):
        SimpleEnumType.from_str("NotAValidValue")

    # unknown value
    with pytest.raises(TypeError):
        SimpleEnumType.to_str(999)


@with_simple_enum
def test_dict_roundtrip():
    """A fully populated SimpleMessage survives to_dict followed by dict_to_proto."""
    field_values = {
        "uvarint": 5,
        "svarint": -13,
        "bool": False,
        "bytes": b"\xca\xfe\x00\xfe",
        "unicode": "žluťoučký kůň",
        "enum": 5,
        "rep_int": [1, 2, 3],
        "rep_str": ["a", "b", "c"],
        "rep_enum": [0, 5, 13],
    }
    msg = SimpleMessage(**field_values)

    recovered = protobuf.dict_to_proto(SimpleMessage, protobuf.to_dict(msg))

    assert recovered == msg


@with_simple_enum
def test_to_dict():
    """to_dict emits every declared field, with bytes as hex and enums as names."""
    field_values = {
        "uvarint": 5,
        "svarint": -13,
        "bool": False,
        "bytes": b"\xca\xfe\x00\xfe",
        "unicode": "žluťoučký kůň",
        "enum": 5,
        "rep_int": [1, 2, 3],
        "rep_str": ["a", "b", "c"],
        "rep_enum": [0, 5, 13],
    }
    msg = SimpleMessage(**field_values)

    converted = protobuf.to_dict(msg)

    # the output keys are exactly the declared field names
    declared_names = sorted(spec[0] for spec in msg.get_fields().values())
    assert sorted(converted.keys()) == declared_names

    # the boolean must come back as the real False singleton, not merely falsy
    assert converted["bool"] is False

    expected = {
        "uvarint": 5,
        "svarint": -13,
        "bytes": "cafe00fe",
        "unicode": "žluťoučký kůň",
        "enum": "BAR",
        "rep_int": [1, 2, 3],
        "rep_str": ["a", "b", "c"],
        "rep_enum": ["FOO", "BAR", "QUUX"],
    }
    for key, value in expected.items():
        assert converted[key] == value


@with_simple_enum
def test_recover_mismatch():
    """dict_to_proto copes with a dict that does not match the message exactly.

    The input omits most fields, contains a key that is not a field of
    SimpleMessage, and spells enum values both by name and by number.
    """
    dictdata = {
        "bool": True,
        "enum": "FOO",
        "another_field": "hello",
        "rep_enum": ["FOO", 5, 5],
    }
    recovered = protobuf.dict_to_proto(SimpleMessage, dictdata)

    assert recovered.bool is True
    # Enum members are plain ints, so compare by value: whether two equal ints
    # are the same object is an interpreter implementation detail.
    assert recovered.enum == SimpleEnum.FOO
    # keys that are not declared fields must not leak onto the message
    assert not hasattr(recovered, "another_field")
    assert recovered.rep_enum == [SimpleEnum.FOO, SimpleEnum.BAR, SimpleEnum.BAR]

    # fields absent from the input come back empty
    for name, _, flags in SimpleMessage.get_fields().values():
        if name in dictdata:
            continue
        value = getattr(recovered, name)
        if flags == protobuf.FLAG_REPEATED:
            assert value == []
        else:
            assert value is None


@with_simple_enum
def test_hexlify():
    """hexlify_bytes only changes how bytes fields are emitted; both forms recover."""
    raw = b"\xca\xfe\x00\x12\x34"
    text = "žluťoučký kůň"
    msg = SimpleMessage(bytes=raw, unicode=text)

    as_raw = protobuf.to_dict(msg, hexlify_bytes=False)
    as_hex = protobuf.to_dict(msg, hexlify_bytes=True)

    # bytes stay untouched without hexlify; text is never affected
    assert as_raw["bytes"] == raw
    assert as_raw["unicode"] == text
    assert as_hex["bytes"] == "cafe001234"
    assert as_hex["unicode"] == text

    # either representation converts back to the original bytes
    for converted in (as_raw, as_hex):
        recovered = protobuf.dict_to_proto(SimpleMessage, converted)
        assert recovered.bytes == msg.bytes


@with_simple_enum
def test_nested_round_trip():
    """A message with nested and repeated sub-messages survives a dict round trip."""
    single = SimpleMessage(uvarint=4, enum=SimpleEnum.FOO)
    many = [
        SimpleMessage(),
        SimpleMessage(rep_enum=[SimpleEnum.BAR, SimpleEnum.BAR]),
        SimpleMessage(bytes=b"\xca\xfe"),
    ]
    msg = NestedMessage(scalar=9, nested=single, repeated=many)

    recovered = protobuf.dict_to_proto(NestedMessage, protobuf.to_dict(msg))

    assert msg == recovered


@with_simple_enum
def test_nested_to_dict():
    """to_dict recurses into nested and repeated sub-messages."""
    single = SimpleMessage(uvarint=4, enum=SimpleEnum.FOO)
    many = [
        SimpleMessage(),
        SimpleMessage(rep_enum=[SimpleEnum.BAR, SimpleEnum.BAR]),
        SimpleMessage(bytes=b"\xca\xfe"),
    ]
    msg = NestedMessage(scalar=9, nested=single, repeated=many)

    converted = protobuf.to_dict(msg)
    assert converted["scalar"] == 9
    assert isinstance(converted["nested"], dict)
    assert isinstance(converted["repeated"], list)

    # each repeated sub-message is expected to contain only the fields set on it
    rep = converted["repeated"]
    expected_items = [{}, {"rep_enum": ["BAR", "BAR"]}, {"bytes": "cafe"}]
    for index, expected in enumerate(expected_items):
        assert rep[index] == expected


@with_simple_enum
def test_nested_recover():
    """An empty dict under a message-typed field is recovered as that message type."""
    recovered = protobuf.dict_to_proto(NestedMessage, {"nested": {}})
    nested = recovered.nested
    assert isinstance(nested, SimpleMessage)


@with_simple_enum
def test_unknown_enum_to_str():
    """format_message names known enum values and prints unknown ones as numbers."""
    # known value: rendered as "NAME (number)"
    known_text = protobuf.format_message(SimpleMessage(enum=SimpleEnum.QUUX))
    assert "enum: QUUX (13)" in known_text

    # value outside the enum: rendered as the bare number
    unknown_text = protobuf.format_message(SimpleMessage(enum=6000))
    assert "enum: 6000" in unknown_text


@with_simple_enum
def test_unknown_enum_to_dict():
    """to_dict passes an enum value with no matching name through unchanged."""
    unknown_value = 6000
    as_dict = protobuf.to_dict(SimpleMessage(enum=unknown_value))
    assert as_dict["enum"] == 6000


def test_constructor_deprecations():
    """Legacy constructor call styles warn as deprecated; invalid ones also raise."""
    # keyword construction with the required field present is the supported form
    RequiredFields(scalar=0)

    # passing the field positionally is expected to warn
    with pytest.deprecated_call():
        RequiredFields(0)

    # omitting the required field is expected to warn
    with pytest.deprecated_call():
        RequiredFields()

    # more positional arguments than fields: warns and raises
    with pytest.deprecated_call():
        with pytest.raises(TypeError):
            RequiredFields(0, 0)

    # same field given both positionally and by keyword: warns and raises
    with pytest.deprecated_call():
        with pytest.raises(TypeError):
            RequiredFields(0, scalar=0)