from common import *

from trezor import protobuf
from trezor.messages import WebAuthnCredential, Failure, SignMessage, DebugLinkMemoryRead


def load_uvarint32(data: bytes) -> int:
    # use known uint32 field in an all-optional message
    buffer = bytearray(len(data) + 1)
    buffer[1:] = data
    buffer[0] = (1 << 3) | 0  # field number 1, wire type 0
    msg = protobuf.decode(buffer, WebAuthnCredential, False)
    return msg.index


def load_uvarint64(data: bytes) -> int:
    # use known uint64 field in an all-optional message
    buffer = bytearray(len(data) + 1)
    buffer[1:] = data
    buffer[0] = (2 << 3) | 0  # field number 1, wire type 0
    msg = protobuf.decode(buffer, DebugLinkMemoryRead, False)
    return msg.length


def dump_uvarint32(value: int) -> bytearray:
    # use known uint32 field in an all-optional message
    msg = WebAuthnCredential(index=value)
    length = protobuf.encoded_length(msg)
    buffer = bytearray(length)
    protobuf.encode(buffer, msg)
    assert buffer[0] == (1 << 3) | 0  # field number 1, wire type 0
    return buffer[1:]


def dump_uvarint64(value: int) -> bytearray:
    # use known uint64 field in an all-optional message
    msg = DebugLinkMemoryRead(length=value)
    length = protobuf.encoded_length(msg)
    buffer = bytearray(length)
    protobuf.encode(buffer, msg)
    assert buffer[0] == (2 << 3) | 0  # field number 2, wire type 0
    return buffer[1:]


def dump_message(msg: protobuf.MessageType) -> bytearray:
    length = protobuf.encoded_length(msg)
    buffer = bytearray(length)
    protobuf.encode(buffer, msg)
    return buffer


def load_message(msg_type: Type[protobuf.MessageType], buffer: bytes) -> protobuf.MessageType:
    return protobuf.decode(buffer, msg_type, False)


class TestProtobuf(unittest.TestCase):
    def test_dump_uvarint(self):
        for dump_uvarint in (dump_uvarint32, dump_uvarint64):
            self.assertEqual(dump_uvarint(0), b"\x00")
            self.assertEqual(dump_uvarint(1), b"\x01")
            self.assertEqual(dump_uvarint(0xFF), b"\xff\x01")
            self.assertEqual(dump_uvarint(123456), b"\xc0\xc4\x07")
            with self.assertRaises(OverflowError):
                dump_uvarint(-1)

    def test_load_uvarint(self):
        for load_uvarint in (load_uvarint32, load_uvarint64):
            self.assertEqual(load_uvarint(b"\x00"), 0)
            self.assertEqual(load_uvarint(b"\x01"), 1)
            self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF)
            self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456)

    def test_validate_enum(self):
        # ok message:
        msg = Failure(code=7)
        msg_encoded = dump_message(msg)
        nmsg = load_message(Failure, msg_encoded)

        self.assertEqual(msg.code, nmsg.code)

        # bad enum value:
        msg = Failure(code=1000)
        msg_encoded = dump_message(msg)
        with self.assertRaises(ValueError):
            load_message(Failure, msg_encoded)

    def test_required(self):
        msg = SignMessage(message=b"hello", coin_name="foo", script_type=1)
        msg_encoded = dump_message(msg)
        nmsg = load_message(SignMessage, msg_encoded)

        self.assertEqual(nmsg.message, b"hello")
        self.assertEqual(nmsg.coin_name, "foo")
        self.assertEqual(nmsg.script_type, 1)

        # try a message without the required_field
        msg = SignMessage(message=None)
        # encoding always succeeds
        msg_encoded = dump_message(msg)
        with self.assertRaises(ValueError):
            load_message(SignMessage, msg_encoded)

        # try a message without the default field
        msg = SignMessage(message=b"hello")
        msg.coin_name = None
        msg_encoded = dump_message(msg)
        nmsg = load_message(SignMessage, msg_encoded)

        self.assertEqual(nmsg.message, b"hello")
        self.assertEqual(nmsg.coin_name, "Bitcoin")



if __name__ == "__main__":
    unittest.main()