# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.

from enum import IntEnum
from io import BytesIO
import logging

import pytest

from trezorlib import messages, protobuf


class SomeEnum(IntEnum):
    Zero = 0
    Five = 5
    TwentyFive = 25


class WiderEnum(IntEnum):
    One = 1
    Two = 2
    Three = 3
    Four = 4
    Five = 5


class NarrowerEnum(IntEnum):
    One = 1
    Five = 5


class PrimitiveMessage(protobuf.MessageType):
    FIELDS = {
        1: protobuf.Field("uvarint", "uint64"),
        2: protobuf.Field("svarint", "sint64"),
        3: protobuf.Field("bool", "bool"),
        4: protobuf.Field("bytes", "bytes"),
        5: protobuf.Field("unicode", "string"),
        6: protobuf.Field("enum", "SomeEnum"),
    }


class EnumMessageMoreValues(protobuf.MessageType):
    FIELDS = {1: protobuf.Field("enum", "WiderEnum")}


class EnumMessageLessValues(protobuf.MessageType):
    FIELDS = {1: protobuf.Field("enum", "NarrowerEnum")}


class RepeatedFields(protobuf.MessageType):
    FIELDS = {
        1: protobuf.Field("uintlist", "uint64", repeated=True),
        2: protobuf.Field("enumlist", "SomeEnum", repeated=True),
        3: protobuf.Field("strlist", "string", repeated=True),
    }


class RequiredFields(protobuf.MessageType):
    FIELDS = {
        1: protobuf.Field("uvarint", "uint64", required=True),
        2: protobuf.Field("nested", "PrimitiveMessage", required=True),
    }


class DefaultFields(protobuf.MessageType):
    FIELDS = {
        1: protobuf.Field("uvarint", "uint32", default=42),
        2: protobuf.Field("svarint", "sint32", default=-42),
        3: protobuf.Field("bool", "bool", default=True),
        4: protobuf.Field("bytes", "bytes", default=b"hello"),
        5: protobuf.Field("unicode", "string", default="hello"),
        6: protobuf.Field("enum", "SomeEnum", default=SomeEnum.Five),
    }


class RecursiveMessage(protobuf.MessageType):
    FIELDS = {
        1: protobuf.Field("uvarint", "uint64"),
        2: protobuf.Field("recursivefield", "RecursiveMessage", required=False)
    }


# message types are read from the messages module so we need to "include" these messages there for now
messages.SomeEnum = SomeEnum
messages.WiderEnum = WiderEnum
messages.NarrowerEnum = NarrowerEnum
messages.PrimitiveMessage = PrimitiveMessage
messages.EnumMessageMoreValues = EnumMessageMoreValues
messages.EnumMessageLessValues = EnumMessageLessValues
messages.RepeatedFields = RepeatedFields
messages.RequiredFields = RequiredFields
messages.DefaultFields = DefaultFields
messages.RecursiveMessage = RecursiveMessage


def load_uvarint(buffer):
    reader = BytesIO(buffer)
    return protobuf.load_uvarint(reader)


def dump_uvarint(value):
    writer = BytesIO()
    protobuf.dump_uvarint(writer, value)
    return writer.getvalue()


def load_message(buffer, msg_type):
    reader = BytesIO(buffer)
    return protobuf.load_message(reader, msg_type)


def dump_message(msg):
    writer = BytesIO()
    protobuf.dump_message(writer, msg)
    return writer.getvalue()


def test_dump_uvarint():
    assert dump_uvarint(0) == b"\x00"
    assert dump_uvarint(1) == b"\x01"
    assert dump_uvarint(0xFF) == b"\xff\x01"
    assert dump_uvarint(123456) == b"\xc0\xc4\x07"

    with pytest.raises(ValueError):
        dump_uvarint(-1)


def test_load_uvarint():
    assert load_uvarint(b"\x00") == 0
    assert load_uvarint(b"\x01") == 1
    assert load_uvarint(b"\xff\x01") == 0xFF
    assert load_uvarint(b"\xc0\xc4\x07") == 123456
    assert load_uvarint(b"\x80\x80\x80\x80\x00") == 0


def test_broken_uvarint():
    with pytest.raises(IOError):
        load_uvarint(b"\x80\x80")


def test_sint_uint():
    """
    Protobuf interleaved signed encoding
    https://developers.google.com/protocol-buffers/docs/encoding#structure
    LSbit is sign, rest is shifted absolute value.
    Or, by example, you count like so: 0, -1, 1, -2, 2, -3 ...
    """
    assert protobuf.sint_to_uint(0) == 0
    assert protobuf.uint_to_sint(0) == 0

    assert protobuf.sint_to_uint(-1) == 1
    assert protobuf.sint_to_uint(1) == 2

    assert protobuf.uint_to_sint(1) == -1
    assert protobuf.uint_to_sint(2) == 1

    # roundtrip:
    assert protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)) == 1234567891011
    assert protobuf.uint_to_sint(protobuf.sint_to_uint(-(2 ** 32))) == -(2 ** 32)


def test_simple_message():
    msg = PrimitiveMessage(
        uvarint=12345678910,
        svarint=-12345678910,
        bool=True,
        bytes=b"\xDE\xAD\xCA\xFE",
        unicode="Příliš žluťoučký kůň úpěl ďábelské ódy 😊",
        enum=SomeEnum.Five,
    )

    buf = dump_message(msg)
    retr = load_message(buf, PrimitiveMessage)

    assert msg == retr
    assert retr.uvarint == 12345678910
    assert retr.svarint == -12345678910
    assert retr.bool is True
    assert retr.bytes == b"\xDE\xAD\xCA\xFE"
    assert retr.unicode == "Příliš žluťoučký kůň úpěl ďábelské ódy 😊"
    assert retr.enum == SomeEnum.Five
    assert retr.enum == 5


def test_validate_enum(caplog):
    caplog.set_level(logging.INFO)
    # round-trip of a valid value
    msg = EnumMessageMoreValues(enum=WiderEnum.Five)
    buf = dump_message(msg)
    retr = load_message(buf, EnumMessageLessValues)
    assert retr.enum == msg.enum

    assert not caplog.records

    # dumping an invalid enum value fails
    msg.enum = 19
    with pytest.raises(
        ValueError, match="Value 19 in field enum unknown for WiderEnum"
    ):
        dump_message(msg)

    msg.enum = WiderEnum.Three
    buf = dump_message(msg)
    retr = load_message(buf, EnumMessageLessValues)

    assert len(caplog.records) == 1
    record = caplog.records.pop(0)
    assert record.levelname == "INFO"
    assert record.getMessage() == "On field enum: 3 is not a valid NarrowerEnum"
    assert retr.enum == 3


def test_repeated():
    msg = RepeatedFields(
        uintlist=[1, 2, 3], enumlist=[0, 5, 0, 5], strlist=["hello", "world"]
    )
    buf = dump_message(msg)
    retr = load_message(buf, RepeatedFields)

    assert retr == msg


def test_packed():
    values = [4, 44, 444]
    packed_values = b"".join(dump_uvarint(v) for v in values)
    field_id = 1 << 3 | 2  # field number 1, wire type 2
    field_len = len(packed_values)
    message_bytes = dump_uvarint(field_id) + dump_uvarint(field_len) + packed_values

    msg = load_message(message_bytes, RepeatedFields)
    assert msg
    assert msg.uintlist == values
    assert not msg.enumlist
    assert not msg.strlist


def test_packed_enum():
    values = [0, 0, 0, 0]
    packed_values = b"".join(dump_uvarint(v) for v in values)
    field_id = 2 << 3 | 2  # field number 2, wire type 2
    field_len = len(packed_values)
    message_bytes = dump_uvarint(field_id) + dump_uvarint(field_len) + packed_values

    msg = load_message(message_bytes, RepeatedFields)
    assert msg
    assert msg.enumlist == values
    assert not msg.uintlist
    assert not msg.strlist


def test_required():
    msg = RequiredFields(uvarint=3, nested=PrimitiveMessage())
    buf = dump_message(msg)
    msg_ok = load_message(buf, RequiredFields)

    assert msg_ok == msg

    with pytest.deprecated_call():
        msg = RequiredFields(uvarint=3)
    with pytest.raises(ValueError):
        # cannot encode instance without the required fields
        dump_message(msg)

    msg = RequiredFields(uvarint=3, nested=None)
    # we can always encode an invalid message
    buf = dump_message(msg)
    with pytest.raises(ValueError):
        # required field `nested` is also not sent
        load_message(buf, RequiredFields)

    msg = RequiredFields(uvarint=None, nested=PrimitiveMessage())
    buf = dump_message(msg)
    with pytest.raises(ValueError):
        # required field `uvarint` is not sent
        load_message(buf, RequiredFields)


def test_default():
    # load empty message
    retr = load_message(b"", DefaultFields)
    assert retr.uvarint == 42
    assert retr.svarint == -42
    assert retr.bool is True
    assert retr.bytes == b"hello"
    assert retr.unicode == "hello"
    assert retr.enum == SomeEnum.Five

    msg = DefaultFields(uvarint=0)
    buf = dump_message(msg)
    retr = load_message(buf, DefaultFields)
    assert retr.uvarint == 0

    msg = DefaultFields(uvarint=None)
    buf = dump_message(msg)
    retr = load_message(buf, DefaultFields)
    assert retr.uvarint == 42


def test_recursive():
    msg = RecursiveMessage(
        uvarint=1,
        recursivefield=RecursiveMessage(
            uvarint=2,
            recursivefield=RecursiveMessage(
                uvarint=3
            )
        )
    )

    buf = dump_message(msg)
    retr = load_message(buf, RecursiveMessage)

    assert msg == retr
    assert retr.uvarint == 1
    assert type(retr.recursivefield) == RecursiveMessage
    assert retr.recursivefield.uvarint == 2
    assert type(retr.recursivefield.recursivefield) == RecursiveMessage
    assert retr.recursivefield.recursivefield.uvarint == 3