#!/usr/bin/env python3
from PIL import Image
import sys
import struct
import zlib


def process_rgb(w, h, data):
    """Expand 16-bit RGB565 pixels into 24-bit RGB.

    Args:
        w: Image width in pixels.
        h: Image height in pixels.
        data: At least ``w * h * 2`` bytes; every pixel is one 16-bit value
            stored high byte first, laid out as 5 bits red, 6 bits green,
            5 bits blue.

    Returns:
        ``bytes`` of length ``w * h * 3`` holding R, G, B for each pixel.
        Each channel keeps its source bits in the top of the byte; the low
        bits are left at zero (no range expansion is applied).

    Raises:
        IndexError: If ``data`` is shorter than ``w * h * 2`` bytes.
    """
    pix = bytearray(w * h * 3)
    for i in range(w * h):
        c = (data[i * 2] << 8) | data[i * 2 + 1]
        pix[i * 3 + 0] = (c & 0xF800) >> 8  # bits 11-15 -> red
        # Green occupies bits 5-10, so the mask must be 0x07E0; a narrower
        # mask silently discards the least significant green bit.
        pix[i * 3 + 1] = (c & 0x07E0) >> 3
        pix[i * 3 + 2] = (c & 0x001F) << 3  # bits 0-4 -> blue
    return bytes(pix)


def process_grayscale(w, h, data):
    """Unpack 4-bit grayscale samples into one byte per pixel.

    Every input byte carries two pixels: the high nibble comes first and the
    low nibble second. Each nibble is placed in the upper four bits of its
    output byte. When ``w * h`` is odd, the final output byte is left at zero
    because only ``w * h // 2`` input bytes are consumed.

    Returns ``bytes`` of length ``w * h``.
    """
    pixel_count = w * h
    out = bytearray(pixel_count)
    for pos in range(pixel_count // 2):
        packed = data[pos]
        first = packed & 0xF0
        second = (packed & 0x0F) << 4
        # Replace the two output slots for this input byte in one go.
        out[2 * pos:2 * pos + 2] = (first, second)
    return bytes(out)


def process_image(ifn):
    """Convert one TOIF/TOIG file to a PNG.

    The output path is the input path with its five-character extension
    replaced by ``.png``. The file layout read here is: 4-byte magic
    (``TOIf`` or ``TOIg``), little-endian 16-bit width and height, a
    little-endian 32-bit compressed length, then the compressed pixel data.

    Args:
        ifn: Path of the input file; must end in ``.toif`` or ``.toig``.

    Returns:
        ``None`` on success, otherwise a non-zero error code after printing
        a message:

        * 1 - bad magic in a ``.toif`` file
        * 2 - bad magic in a ``.toig`` file
        * 3 - unsupported file extension
        * 4 - compressed length field disagrees with the payload size
        * 5 - decompressed ``.toif`` data is not ``w * h * 2`` bytes
        * 6 - decompressed ``.toig`` data is not ``w * h // 2`` bytes
        * 7 - file is too short to contain the 12-byte header

    Raises:
        OSError: If the input cannot be read or the output cannot be written.
        zlib.error: If the payload is not a valid deflate stream.
    """
    # A context manager closes the handle promptly instead of leaking it.
    with open(ifn, 'rb') as f:
        data = f.read()

    if ifn.endswith('.toif'):
        if data[:4] != b'TOIf':
            print('Unknown TOIF header')
            return 1
        is_rgb = True
    elif ifn.endswith('.toig'):
        if data[:4] != b'TOIg':
            print('Unknown TOIG header')
            return 2
        is_rgb = False
    else:
        print('Unsupported format')
        return 3
    ofn = '%s.png' % ifn[:-5]

    # Without this guard a truncated file fails inside struct with an
    # unhelpful traceback.
    if len(data) < 12:
        print('Truncated header (%d bytes)' % len(data))
        return 7

    w, h, length = struct.unpack_from('<HHI', data, 4)

    print('Opened %s ... %d x %d' % (ifn, w, h))

    data = data[12:]
    if len(data) != length:
        print('Compressed data length mismatch (%d vs %d)' %
              (len(data), length))
        return 4
    # Negative wbits selects a raw deflate stream (no zlib header/trailer).
    data = zlib.decompress(data, -10)

    if is_rgb:
        expected, mode, convert = w * h * 2, 'RGB', process_rgb
        error_code = 5
    else:
        expected, mode, convert = w * h // 2, 'L', process_grayscale
        error_code = 6

    if len(data) != expected:
        print('Uncompressed data length mismatch (%d vs %d)' %
              (len(data), expected))
        return error_code

    pix = convert(w, h, data)
    img = Image.frombuffer(mode, (w, h), pix, 'raw', mode, 0, 1)
    img.save(ofn)
    print('Written %s ...' % ofn)


def main():
    """Command-line entry point: convert the file named in ``sys.argv[1]``.

    Returns:
        1 if no argument was given, 2 if the argument does not end in
        ``.toif`` or ``.toig``; otherwise whatever ``process_image`` returns
        (``None`` on success, a non-zero code on failure).
    """
    if len(sys.argv) < 2:
        print('Usage: toi2png image.toi[fg]')
        return 1

    ifn = sys.argv[1]
    if not ifn.endswith(('.toif', '.toig')):
        print('Must provide TOIF/TOIG file')
        return 2

    # Hand the conversion result back so failures reach the exit status.
    return process_image(ifn)


# Guarded so that importing this module does not run the converter.
# sys.exit(None) exits with status 0, so a successful run still reports 0.
if __name__ == '__main__':
    sys.exit(main())