#!/usr/bin/env python3
from PIL import Image
import sys
import struct
import zlib


def process_rgb(w, h, data):
    """Expand 16-bit RGB565 pixels into 8-bit-per-channel RGB bytes.

    Each pixel is read from ``data`` as two bytes, high byte first, and
    split into 5 bits of red, 6 bits of green and 5 bits of blue.  Every
    channel is left-aligned in its output byte (low bits are zero).

    Args:
        w: image width in pixels.
        h: image height in pixels.
        data: indexable sequence of byte values holding at least
            ``w * h * 2`` entries.

    Returns:
        A ``bytes`` object of length ``w * h * 3`` in R, G, B order.

    Raises:
        IndexError: if ``data`` holds fewer than ``w * h * 2`` bytes.
    """
    pix = bytearray(w * h * 3)
    for i in range(w * h):
        c = (data[i * 2] << 8) | data[i * 2 + 1]
        pix[i * 3 + 0] = (c & 0xF800) >> 8
        # Green occupies bits 5-10 (six bits); 0x07E0 keeps all of them,
        # whereas the narrower 0x07C0 would discard the lowest green bit.
        pix[i * 3 + 1] = (c & 0x07E0) >> 3
        pix[i * 3 + 2] = (c & 0x001F) << 3
    return bytes(pix)


def process_grayscale(w, h, data):
    """Unpack 4-bit grayscale samples into one byte per pixel.

    Every input byte carries two pixels: the high nibble yields the first
    output pixel and the low nibble the second.  Each sample is placed in
    the upper four bits of its output byte.  When ``w * h`` is odd, the
    final output byte is left at zero.

    Args:
        w: image width in pixels.
        h: image height in pixels.
        data: indexable sequence of byte values holding at least
            ``w * h // 2`` entries.

    Returns:
        A ``bytes`` object of length ``w * h``.
    """
    total = w * h
    out = bytearray(total)
    for src in range(total // 2):
        packed = data[src]
        dst = 2 * src
        out[dst] = packed & 0xF0
        out[dst + 1] = (packed << 4) & 0xF0
    return bytes(out)


def process_image(ifn, ofn):
    """Convert a ``.toif`` / ``.toig`` file to an image file.

    The input layout read here is: a 4-byte magic (``TOIf`` or ``TOIg``),
    little-endian 16-bit width and height, a little-endian 32-bit length
    of the compressed payload, then the payload itself.

    Args:
        ifn: input path; its extension selects the expected format.
        ofn: output path, or ``None`` to derive it from ``ifn`` by
            replacing the 5-character extension with ``.png``.

    Returns:
        ``None`` on success, otherwise a status code:
        1 - bad TOIF magic, 2 - bad TOIG magic, 3 - unsupported extension,
        4 - compressed length mismatch, 5 - TOIF pixel data size mismatch,
        6 - TOIG pixel data size mismatch.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(ifn, "rb") as f:
        data = f.read()

    if ifn.endswith(".toif"):
        if data[:4] != b"TOIf":
            print("Unknown TOIF header")
            return 1
        is_rgb = True
    elif ifn.endswith(".toig"):
        if data[:4] != b"TOIg":
            print("Unknown TOIG header")
            return 2
        is_rgb = False
    else:
        print("Unsupported format")
        return 3
    if ofn is None:
        ofn = "%s.png" % ifn[:-5]

    w, h = struct.unpack("<HH", data[4:8])

    print("Opened %s ... %d x %d" % (ifn, w, h))

    (compressed_len,) = struct.unpack("<I", data[8:12])
    data = data[12:]
    if len(data) != compressed_len:
        print(
            "Compressed data length mismatch (%d vs %d)" % (len(data), compressed_len)
        )
        return 4
    # Negative wbits: raw deflate stream (no zlib header/trailer) with a
    # 2**10-byte window.
    data = zlib.decompress(data, -10)

    if is_rgb:
        # Two bytes per pixel.
        expected = w * h * 2
        if len(data) != expected:
            print("Uncompressed data length mismatch (%d vs %d)" % (len(data), expected))
            return 5
        pix = process_rgb(w, h, data)
        mode = "RGB"
    else:
        # Two pixels per byte.
        expected = w * h // 2
        if len(data) != expected:
            print("Uncompressed data length mismatch (%d vs %d)" % (len(data), expected))
            return 6
        pix = process_grayscale(w, h, data)
        mode = "L"

    img = Image.frombuffer(mode, (w, h), pix, "raw", mode, 0, 1)
    img.save(ofn)
    print("Written %s ..." % ofn)


def main():
    """Command-line entry point: ``toi2png image.toi[fg] [output]``.

    Reads the input path (and optional output path) from ``sys.argv``.

    Returns:
        1 if no input path was given, 2 if the input path does not end in
        ``.toif`` or ``.toig``; otherwise whatever ``process_image``
        returns, so conversion failures are visible to the caller rather
        than silently discarded.
    """
    if len(sys.argv) < 2:
        print("Usage: toi2png image.toi[fg] [output]")
        return 1

    ifn = sys.argv[1]
    if not ifn.endswith((".toif", ".toig")):
        print("Must provide TOIF/TOIG file")
        return 2

    ofn = sys.argv[2] if len(sys.argv) > 2 else None
    return process_image(ifn, ofn)


# Run the converter only when executed as a script, so importing this module
# has no side effects, and hand main()'s return value to the shell as the
# process exit status (None maps to 0).
if __name__ == "__main__":
    sys.exit(main())