#!/usr/bin/env python3
from PIL import Image
import sys
import struct
import zlib


def process_rgb(w, h, pix):
    """Pack RGB pixels into big-endian RGB565 words.

    Pixels are emitted row by row (``j`` = row, top to bottom), and within a
    row left to right (``i`` = column).  Each pixel becomes one big-endian
    16-bit word: the top 5 bits of red, top 6 bits of green, top 5 bits of
    blue.

    :param w: image width in pixels
    :param h: image height in pixels
    :param pix: pixel accessor indexed as ``pix[x, y]`` that yields an
        ``(r, g, b)`` triple (e.g. the object returned by ``Image.load()``)
    :return: ``bytes`` of length ``2 * w * h``
    """
    pack = struct.Struct('>H').pack
    # A bytearray grows in place; ``bytes += ...`` would copy the whole
    # buffer for every pixel and make the conversion quadratic.
    data = bytearray()
    for j in range(h):
        for i in range(w):
            r, g, b = pix[i, j]
            c = ((r & 0xF8) << 8) | ((g & 0xFC) << 3) | ((b & 0xF8) >> 3)
            data += pack(c)
    return bytes(data)


def process_grayscale(w, h, pix):
    """Pack 8-bit grayscale pixels two-per-byte at 4 bits each.

    For every horizontal pair of pixels, the left pixel's top nibble becomes
    the high nibble of the output byte and the right pixel's top nibble the
    low nibble.  Rows are emitted top to bottom.  Only ``w // 2`` pairs are
    read per row, so with an odd ``w`` the last column is ignored (the caller
    is expected to reject odd widths).

    :param w: image width in pixels
    :param h: image height in pixels
    :param pix: pixel accessor indexed as ``pix[x, y]`` that yields an int
        luminance value (e.g. ``Image.load()`` of an 'L' image)
    :return: ``bytes`` of length ``(w // 2) * h``
    """
    # bytearray.append is O(1) amortised; ``bytes +=`` would be quadratic.
    data = bytearray()
    for j in range(h):
        for i in range(w // 2):
            l1, l2 = pix[i * 2, j], pix[i * 2 + 1, j]
            data.append((l1 & 0xF0) | (l2 >> 4))
    return bytes(data)


def _c_identifier(ifn):
    """Return the C array name ``toi_<stem>`` for the input path *ifn*.

    The directory and extension are dropped and every character that is not
    legal in a C identifier is replaced with ``_``, so ``assets/logo-dark.png``
    yields ``toi_logo_dark`` instead of an uncompilable name.
    """
    import os
    import re

    stem = os.path.splitext(os.path.basename(ifn))[0]
    return 'toi_' + re.sub(r'[^0-9A-Za-z_]', '_', stem)


def _write_c_header(path, ident, magic, w, h, zdata):
    """Write a C ``uint8_t`` array holding the same bytes as the TOI file.

    Layout: 4-character magic, width and height as little-endian uint16,
    compressed length as little-endian uint32, then the compressed payload.
    Multi-byte fields are emitted as their individual bytes so the array
    matches the binary file exactly.
    """
    def hex_bytes(raw):
        return ', '.join('0x%02x' % b for b in raw)

    with open(path, 'wt') as f:
        f.write('static const uint8_t %s[] = {\n' % ident)
        f.write('    %s,\n' % ', '.join("'%s'" % c for c in magic))
        f.write('    // width (16-bit), height (16-bit)\n')
        f.write('    %s,\n' % hex_bytes(struct.pack('<HH', w, h)))
        f.write('    // compressed data length (32-bit)\n')
        f.write('    %s,\n' % hex_bytes(struct.pack('<I', len(zdata))))
        f.write('   %s\n};\n' % ''.join(' 0x%02x,' % b for b in zdata))


def process_image(ifn):
    """Convert the PNG *ifn* into a TOI file plus a C header with the same bytes.

    'RGB' images become ``<name>.toif`` (RGB565 pixels) and 'L' grayscale
    images become ``<name>.toig`` (two 4-bit pixels per byte).  ``<name>`` is
    *ifn* without its last four characters (the ``.png`` extension).  A
    companion ``<output>.h`` file declares the identical byte sequence as a C
    array.

    Binary layout: magic ``TOIf``/``TOIg``, width and height as little-endian
    uint16, payload length as little-endian uint32, then the pixel data
    compressed as a raw DEFLATE stream.

    :param ifn: path to the input image; the caller ensures it ends in ``.png``
    :return: ``None`` on success, 3 if a grayscale image has an odd width,
        4 if the image mode is neither 'RGB' nor 'L'
    """
    # Keep all pixel access inside the ``with`` so the image is open while read.
    with Image.open(ifn) as im:
        w, h = im.size
        print('Opened %s ... %d x %d @ %s' % (ifn, w, h, im.mode))

        if im.mode == 'RGB':
            print('Detected RGB mode')
            magic, convert = 'TOIf', process_rgb
        elif im.mode == 'L':
            if w % 2 > 0:
                print('PNG file must have width divisible by 2')
                return 3
            print('Detected GRAYSCALE mode')
            magic, convert = 'TOIg', process_grayscale
        else:
            print('Unknown mode:', im.mode)
            return 4

        pixeldata = convert(w, h, im.load())

    # Compress with a 1 KiB window (wbits=10), then strip zlib's 2-byte
    # header and 4-byte Adler-32 trailer to leave the bare deflate stream.
    z = zlib.compressobj(level=9, wbits=10)
    zdata = (z.compress(pixeldata) + z.flush())[2:-4]

    ofn = '%s.%s' % (ifn[:-4], magic.lower())
    # '<' disables alignment padding, so this is magic + 2 + 2 + 4 bytes.
    header = magic.encode('ascii') + struct.pack('<HHI', w, h, len(zdata))
    with open(ofn, 'wb') as f:
        f.write(header)
        f.write(zdata)
    print('Written %s ... %d bytes' % (ofn, len(header) + len(zdata)))

    _write_c_header(ofn + '.h', _c_identifier(ifn), magic, w, h, zdata)
    print('Written %s ...' % (ofn + '.h'))


def main():
    """Command-line entry point: ``png2toi image.png``.

    :return: exit status -- 1 if no argument was given, 2 if the argument
        does not end in ``.png``, otherwise the result of ``process_image``
        (``None`` on success, which ``sys.exit`` treats as 0)
    """
    if len(sys.argv) < 2:
        print('Usage png2toi image.png')
        return 1

    ifn = sys.argv[1]
    if not ifn.endswith('.png'):
        print('Must provide PNG file')
        return 2

    # Propagate conversion failures so the shell sees a non-zero status.
    return process_image(ifn)


if __name__ == '__main__':
    sys.exit(main())