#!/usr/bin/env python3
from PIL import Image
import sys
import struct
import zlib
from os.path import basename


def process_rgb(w, h, pix):
    """Pack RGB pixels into big-endian 16-bit RGB565 words.

    Args:
        w: Number of columns to read.
        h: Number of rows to read.
        pix: Pixel accessor indexed as ``pix[x, y]`` that yields an
            ``(r, g, b)`` triple for each coordinate.

    Returns:
        ``bytes`` of length ``2 * w * h``; pixels are emitted row by row
        (all columns of row 0, then row 1, ...).
    """
    words = []
    for j in range(h):
        for i in range(w):
            r, g, b = pix[i, j]
            # Keep the top 5 bits of red, 6 of green and 5 of blue and place
            # them in bits 15-11, 10-5 and 4-0 respectively.
            words.append(((r & 0xF8) << 8) | ((g & 0xFC) << 3) | ((b & 0xF8) >> 3))
    # Pack everything in one call; growing a bytes object with += per pixel
    # copies the whole buffer each time.
    return struct.pack(">%dH" % len(words), *words)


def process_grayscale(w, h, pix):
    """Pack grayscale pixels two per byte, 4 bits each.

    Args:
        w: Number of columns; pixels are consumed in horizontal pairs, so
            only the first ``2 * (w // 2)`` columns of each row are read.
        h: Number of rows to read.
        pix: Pixel accessor indexed as ``pix[x, y]`` that yields an integer
            luminance value for each coordinate.

    Returns:
        ``bytes`` of length ``(w // 2) * h``; each byte holds the high nibble
        of the left pixel in bits 7-4 and the high nibble of the right pixel
        in bits 3-0. Rows are emitted top to bottom.
    """
    packed = []
    for j in range(h):
        for i in range(w // 2):
            left, right = pix[i * 2, j], pix[i * 2 + 1, j]
            packed.append((left & 0xF0) | (right >> 4))
    # Pack everything in one call; growing a bytes object with += per pixel
    # pair copies the whole buffer each time.
    return struct.pack(">%dB" % len(packed), *packed)


def process_image(ifn, savefiles=False):
    """Convert an image file into a TOI blob and optionally write it out.

    The blob layout is: 4-byte magic (``TOIf`` for RGB, ``TOIg`` for
    grayscale), width and height as little-endian 16-bit integers, the
    compressed data length as a little-endian 32-bit integer, then the
    deflate-compressed pixel data with the zlib header and checksum removed.

    Args:
        ifn: Path of the input image. The last four characters are dropped to
            derive the output base name, so a 4-character extension such as
            ``.png`` is expected.
        savefiles: When true, write ``<name>.toif``/``<name>.toig``,
            ``<name>.py`` and ``<name>.h`` into the current directory.

    Returns:
        The TOI blob as ``bytes`` on success. On failure an integer code is
        returned instead: 3 for a grayscale image with odd width, 4 for an
        unsupported image mode.
    """
    # Use the image as a context manager so the underlying file is released on
    # every path, including the early error returns.
    with Image.open(ifn) as im:
        w, h = im.size
        mode = im.mode
        print("Opened %s ... %d x %d @ %s" % (ifn, w, h, mode))

        if mode == "RGB":
            print("Detected RGB mode")
        elif mode == "L":
            # Grayscale pixels are packed two per byte.
            if w % 2 > 0:
                print("PNG file must have width divisible by 2")
                return 3
            print("Detected GRAYSCALE mode")
        else:
            print("Unknown mode:", mode)
            return 4

        pix = im.load()
        if mode == "RGB":
            pixeldata = process_rgb(w, h, pix)
        else:
            pixeldata = process_grayscale(w, h, pix)

    is_rgb = mode == "RGB"

    bname = basename(ifn[:-4])
    ofn = "%s.%s" % (bname, "toif" if is_rgb else "toig")
    ofn_h = "%s.h" % bname
    ofn_py = "%s.py" % bname

    z = zlib.compressobj(level=9, wbits=10)
    zdata = z.compress(pixeldata) + z.flush()
    zdata = zdata[2:-4]  # strip header and checksum

    data = b"TOIf" if is_rgb else b"TOIg"
    data += struct.pack("<HH", w, h)
    data += struct.pack("<I", len(zdata))
    data += zdata

    if savefiles:
        with open(ofn, "wb") as f:
            f.write(data)
            print("Written %s ... %d bytes" % (ofn, len(data)))

        # The .py file holds the blob as a Python bytes literal assignment.
        py_source = ("%s = %s\n" % (bname, data)).encode()
        with open(ofn_py, "wb") as f:
            f.write(py_source)
            # Report the size of what was actually written, not the blob size.
            print("Written %s ... %d bytes" % (ofn_py, len(py_source)))

        with open(ofn_h, "wt") as f:
            f.write("// clang-format off\n")
            f.write("static const uint8_t toi_%s[] = {\n" % bname)
            f.write("    // magic\n")
            if is_rgb:
                f.write("    'T', 'O', 'I', 'f',\n")
            else:
                f.write("    'T', 'O', 'I', 'g',\n")
            f.write("    // width (16-bit), height (16-bit)\n")
            f.write(
                "    0x%02x, 0x%02x, 0x%02x, 0x%02x,\n"
                % (w & 0xFF, w >> 8, h & 0xFF, h >> 8)
            )
            zlen = len(zdata)
            f.write("    // compressed data length (32-bit)\n")
            f.write(
                "    0x%02x, 0x%02x, 0x%02x, 0x%02x,\n"
                % (zlen & 0xFF, (zlen >> 8) & 0xFF, (zlen >> 16) & 0xFF, zlen >> 24)
            )
            f.write("    // compressed data\n")
            f.write("   ")
            for b in zdata:
                f.write(" 0x%02x," % b)
            f.write("\n};\n")
            # Name the header file here; the .py file was reported above.
            print("Written %s ... done" % ofn_h)

    return data


def main():
    """Command-line entry point: convert the PNG named in ``sys.argv[1]``.

    Returns:
        An integer status: 0 on success, 1 when no file argument is given,
        2 when the argument does not end in ``.png``, or the integer error
        code reported by ``process_image``.
    """
    if len(sys.argv) < 2:
        print("Usage png2toi image.png")
        return 1

    ifn = sys.argv[1]
    if not ifn.endswith(".png"):
        print("Must provide PNG file")
        return 2

    result = process_image(ifn, savefiles=True)
    # process_image returns the blob on success and an int code on failure;
    # pass failures on instead of dropping them.
    if isinstance(result, int):
        return result
    return 0


# Run the converter only when executed as a script, so the helpers above can
# be imported without side effects; main()'s return value becomes the process
# exit status (None is treated as success by sys.exit).
if __name__ == "__main__":
    sys.exit(main())