package com.wudsn.productions.atari800.badapplehd;

import java.awt.Point;
import java.awt.Transparency;
import java.awt.image.*;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

import javax.imageio.ImageIO;

import com.wudsn.gfx.avi.Log;

/**
 * Packs a stream of tile based frames into Atari output formats.
 * <p>
 * Reads a header file (columns, rows and tiles as big-endian ints), a tile
 * stream and a sound stream, logs statistics about tile reuse, converts the
 * frames and writes them as 16k ROM banks based at $8000 ("rom"), as a
 * sequence of address-tagged segments ("xex") and/or as one PNG per frame
 * ("png").
 */
public class TileStreamPacker {

    /**
     * Output stream decorator that counts the bytes written through it and
     * provides little-endian helpers for Atari binary data.
     * <p>
     * The (misspelled) class name is kept for compatibility with existing
     * callers.
     */
    public static class CountingOutpuStream extends OutputStream {

	/** Byte value used to fill alignment gaps. */
	private static final int PADDING_BYTE = 42;

	private final OutputStream os;
	private int count;

	/**
	 * Creates a counting stream that forwards all output to {@code os}.
	 * 
	 * @param os
	 *            the target stream, not {@code null}
	 */
	public CountingOutpuStream(OutputStream os) {
	    if (os == null) {
		throw new IllegalArgumentException(
			"Parameter os must not be null.");
	    }
	    this.os = os;
	}

	/**
	 * Writes the low 8 bits of {@code value}.
	 */
	public void writeByte(int value) throws IOException {
	    write(value & 0xff);
	}

	/**
	 * Writes the low 16 bits of {@code value}, low byte first.
	 */
	public void writeWord(int value) throws IOException {
	    writeByte(value);
	    writeByte(value >>> 8);
	}

	/**
	 * Writes all bytes of {@code data}.
	 * 
	 * @param data
	 *            the bytes to write, not {@code null}
	 */
	public void writeBytes(byte[] data) throws IOException {
	    if (data == null) {
		throw new IllegalArgumentException(
			"Parameter data must not be null.");
	    }
	    write(data, 0, data.length);
	}

	/**
	 * Writes a segment consisting of the start address word, the
	 * inclusive end address word and the data itself.
	 * 
	 * @param start
	 *            the start address of the segment
	 * @param data
	 *            the segment content, not {@code null}
	 */
	public void writeSegment(int start, byte[] data) throws IOException {
	    if (data == null) {
		throw new IllegalArgumentException(
			"Parameter data must not be null.");
	    }
	    writeWord(start);
	    writeWord(start + data.length - 1);
	    writeBytes(data);
	}

	/**
	 * Writes filler bytes until {@code offset} is a multiple of
	 * {@code alignment}.
	 * 
	 * @param offset
	 *            the current offset or address
	 * @param alignment
	 *            the alignment, must be positive
	 * @return the number of padding bytes written
	 */
	public int writePadding(int offset, int alignment) throws IOException {
	    if (alignment <= 0) {
		throw new IllegalArgumentException(
			"Parameter alignment must be positive, but is "
				+ alignment + ".");
	    }
	    int padding = (((offset + alignment - 1) / alignment) * alignment)
		    - offset;
	    for (int i = 0; i < padding; i++) {
		writeByte(PADDING_BYTE);
	    }
	    return padding;
	}

	@Override
	public void write(int b) throws IOException {
	    os.write(b);
	    count++;
	}

	/**
	 * Forwards the whole range in one call instead of the byte-by-byte
	 * default implementation of {@link OutputStream}.
	 */
	@Override
	public void write(byte[] b, int off, int len) throws IOException {
	    os.write(b, off, len);
	    count += len;
	}

	@Override
	public void flush() throws IOException {
	    os.flush();
	}

	@Override
	public void close() throws IOException {
	    os.close();
	}

	/**
	 * @return the number of bytes written so far
	 */
	public int getCount() {
	    return count;
	}

    }

    private DataInputStream headerDataInputStream;
    private DataInputStream tileDataInputStream;
    private DataInputStream soundDataInputStream;

    private boolean png;
    private File pngFolder;

    private boolean rom;
    private CountingOutpuStream romOutputStream;
    private CountingOutpuStream romControlOutputStream;

    private boolean xex;
    private CountingOutpuStream xexOutputStream;
    private CountingOutpuStream screenMemoryOutputStream;
    private CountingOutpuStream charsetTableOutputStream;
    private CountingOutpuStream charsetOutputStream;

    public static void main(String[] args) {
	TileStreamPacker instance = new TileStreamPacker();
	System.exit(instance.run(args));
    }

    /**
     * Runs the packer.
     * 
     * @param args
     *            headerFilePath, tileFilePath, soundFilePath,
     *            outputFilePrefix and optionally outputFormats (any
     *            combination of "xex", "rom" and "png", default "rom")
     * @return 0 on success, -1 on failure
     */
    public int run(String[] args) {
	if (args.length < 4) {
	    System.err
		    .println("ERROR: Invalid arguments '"
			    + Arrays.toString(args)
			    + "'.\nUSAGE: TileStreamPacker headerFilePath tileFilePath soundFilePath outputFilePrefix outputFormats[xex,rom,png]");
	    return -1;
	}

	String headerFilePath = args[0];
	String tileFilePath = args[1];
	String soundFilePath = args[2];
	String outputFilePrefix = args[3];
	String outputFormats = (args.length >= 5) ? args[4] : "rom";

	// Formats are matched as substrings, so "xex,png" selects both
	png = outputFormats.contains("png");
	rom = outputFormats.contains("rom");
	xex = outputFormats.contains("xex");

	System.out.println("Header Input File: " + headerFilePath);
	System.out.println("Tile Input File: " + tileFilePath);
	System.out.println("Sound Input File: " + soundFilePath);
	System.out.println("Output File Prefix: " + outputFilePrefix);
	Log log = Log.create(System.out, true);

	int result = 0;
	FrameSequence frameSequence = null;
	try {
	    headerDataInputStream = createInputStream(headerFilePath);
	    tileDataInputStream = createInputStream(tileFilePath);
	    soundDataInputStream = createInputStream(soundFilePath);
	    frameSequence = read(log);
	} catch (IOException ex) {
	    ex.printStackTrace();
	    result = -1;
	} finally {
	    // Close every input, even if closing another one failed
	    if (!closeInputStream(headerDataInputStream)) {
		result = -1;
	    }
	    if (!closeInputStream(tileDataInputStream)) {
		result = -1;
	    }
	    if (!closeInputStream(soundDataInputStream)) {
		result = -1;
	    }
	}
	if (result != 0) {
	    return result;
	}

	analyzeFrameSequence(log, frameSequence);
	convertFrameSequence(log, frameSequence);
	return writeFrameSequence(log, frameSequence, outputFilePrefix);
    }

    /**
     * Closes a stream if it is not {@code null}.
     * 
     * @return {@code false} if closing failed, {@code true} otherwise
     */
    private static boolean closeInputStream(InputStream inputStream) {
	if (inputStream == null) {
	    return true;
	}
	try {
	    inputStream.close();
	} catch (IOException ex) {
	    ex.printStackTrace();
	    return false;
	}
	return true;
    }

    private static DataInputStream createInputStream(String filePath)
	    throws FileNotFoundException {
	return new DataInputStream(new BufferedInputStream(new FileInputStream(
		new File(filePath))));
    }

    /**
     * Reads the screen dimensions from the header stream and all frames from
     * the tile and sound streams.
     */
    private FrameSequence read(Log log) throws IOException {

	// Presumably means "no frame limit" - TODO confirm in readFrom
	int maxFrames = Integer.MAX_VALUE;
	Screen screen = new Screen(8);

	screen.columns = headerDataInputStream.readInt();
	screen.rows = headerDataInputStream.readInt();
	screen.tiles = headerDataInputStream.readInt();

	Sound sound = new Sound();
	sound.soundMemorySize = 1;

	log.printLongHex("screen.columns", screen.columns);
	log.printLongHex("screen.rows", screen.rows);
	log.printLongHex("screen.tiles", screen.tiles);
	log.printLongHex("sound.soundMemorySize", sound.soundMemorySize);
	log.println();

	FrameSequence frameSequence = new FrameSequence(screen, sound);
	frameSequence.readFrom(tileDataInputStream, soundDataInputStream,
		maxFrames);
	return frameSequence;
    }

    /**
     * Logs per-frame and overall statistics about tile usage and tile reuse
     * between consecutive frames.
     */
    private static void analyzeFrameSequence(Log log,
	    FrameSequence frameSequence) {
	// Only the set of distinct tiles is needed, not their counts
	Set<Tile> distinctTiles = new TreeSet<Tile>();

	int maxUsedTilesCount = 0;
	int allFramesUsedTilesCount = 0;
	int sameTilesCount1 = 0;
	int sameTilesCount2 = 0;
	Screen screen = frameSequence.screen;
	int frameCount = frameSequence.frames.size();
	for (int f = 0; f < frameCount; f++) {
	    Frame frame = frameSequence.frames.get(f);
	    log.printLong("frameCount", f);
	    distinctTiles.addAll(frame.tileCountMap.keySet());
	    int usedTilesCount = frame.getUsedTilesCount();
	    log.printLong("usedTilesCount", usedTilesCount);
	    maxUsedTilesCount = Math.max(maxUsedTilesCount, usedTilesCount);
	    allFramesUsedTilesCount += usedTilesCount;
	    sameTilesCount1 += findSameTiles(log, frameSequence, f, f - 1);
	    sameTilesCount2 += findSameTiles(log, frameSequence, f, f - 2);

	    log.println();
	}
	long distinctTileCount = distinctTiles.size();
	log.printLong("frameCount", frameCount);
	// long arithmetic avoids int overflow for long sequences
	log.printLongHex("allFramesTilesCount", (long) frameCount
		* screen.tiles);
	log.printLongHex("allFramesUsedTilesCount", allFramesUsedTilesCount);
	log.println();

	log.printLongHex("distinctTileCount", distinctTileCount);
	log.printLongHex("distinctTileBytes", distinctTileCount
		* screen.bytesPerTile);
	log.println();

	log.printLongHex("maxUsedTilesCount", maxUsedTilesCount);
	log.printLongHex("maxUsedTilesBytes", (long) maxUsedTilesCount
		* screen.bytesPerTile);

	log.printLongHex("sameTilesCount1", sameTilesCount1);
	log.printLongHex("sameTilesCount2", sameTilesCount2);
	log.println();
    }

    /**
     * Counts the distinct tiles that frame {@code current} shares with frame
     * {@code previous} and logs the count and percentage.
     * 
     * @return the number of shared tiles, 0 if {@code previous} is negative
     */
    private static int findSameTiles(Log log, FrameSequence frameSequence,
	    int current, int previous) {
	// Frame 0 is a valid predecessor; only negative indexes are skipped
	if (previous < 0) {
	    return 0;
	}
	Frame frame = frameSequence.frames.get(current);
	Frame previousFrame = frameSequence.frames.get(previous);
	Set<Tile> sameTiles = new TreeSet<Tile>(frame.tileCountMap.keySet());
	sameTiles.retainAll(previousFrame.tileCountMap.keySet());
	int sameTileCount = sameTiles.size();
	log.printLong("sameTileCount(" + (previous - current) + ")",
		sameTileCount);
	int usedTilesCount = frame.getUsedTilesCount();
	// Guard against frames without used tiles
	int sameTileCountPercent = (usedTilesCount == 0) ? 0 : sameTileCount
		* 100 / usedTilesCount;
	log.printLong("sameTileCountPercent", sameTileCountPercent);
	return sameTileCount;
    }

    private static void convertFrameSequence(Log log,
	    FrameSequence frameSequence) {
	for (int f = 0; f < frameSequence.frames.size(); f++) {
	    frameSequence.frames.get(f).convert();
	}
    }

    /**
     * Writes the frame sequence in all selected output formats.
     * 
     * @return 0 on success, -1 if writing or closing any output failed
     */
    private int writeFrameSequence(Log log, FrameSequence frameSequence,
	    String fileNamePrefix) {
	int result = 0;
	try {
	    if (xex) {
		xexOutputStream = createOutputStream(fileNamePrefix, "-xex.bin");
		screenMemoryOutputStream = createOutputStream(fileNamePrefix,
			"-screenMemory.bin");
		charsetTableOutputStream = createOutputStream(fileNamePrefix,
			"-charsetTable.bin");
		charsetOutputStream = createOutputStream(fileNamePrefix,
			"-charset.bin");
		writeFrameSequenceAsXEX(log, frameSequence);
	    }
	    if (rom) {
		romOutputStream = createOutputStream(fileNamePrefix, "-rom.bin");
		romControlOutputStream = createOutputStream(fileNamePrefix,
			"-rom-control.bin");
		writeFrameSequenceAsROM(log, frameSequence);
	    }
	    if (png) {
		pngFolder = new File(fileNamePrefix + ".png");
		pngFolder.mkdirs();
		if (!pngFolder.isDirectory()) {
		    throw new IOException("Cannot create folder '"
			    + pngFolder + "'.");
		}
		writeFrameSequenceAsPNG(log, frameSequence);
	    }
	    log.printLong("WriteSequence.Done", 1);
	    log.println();
	} catch (IOException ex) {
	    ex.printStackTrace();
	    result = -1;
	} finally {
	    // Close every output, even if closing another one failed; a failed
	    // close may mean buffered data was lost, so report it as an error
	    if (!closeOutputStream(xexOutputStream)) {
		result = -1;
	    }
	    if (!closeOutputStream(screenMemoryOutputStream)) {
		result = -1;
	    }
	    if (!closeOutputStream(charsetTableOutputStream)) {
		result = -1;
	    }
	    if (!closeOutputStream(charsetOutputStream)) {
		result = -1;
	    }
	    if (!closeOutputStream(romOutputStream)) {
		result = -1;
	    }
	    if (!closeOutputStream(romControlOutputStream)) {
		result = -1;
	    }
	}
	return result;
    }

    private static CountingOutpuStream createOutputStream(String fileName,
	    String suffix) throws FileNotFoundException {
	return new CountingOutpuStream(new BufferedOutputStream(
		new FileOutputStream(fileName + suffix)));
    }

    /**
     * Closes a stream if it is not {@code null}.
     * 
     * @return {@code false} if closing failed, {@code true} otherwise
     */
    private static boolean closeOutputStream(OutputStream outputStream) {
	if (outputStream == null) {
	    return true;
	}
	try {
	    outputStream.close();
	} catch (IOException ex) {
	    ex.printStackTrace();
	    return false;
	}
	return true;
    }

    /**
     * Renders every frame as a black and white PNG into {@link #pngFolder}.
     * Bit 7 of a screen code selects inverse video; clear pixel bits are
     * drawn white.
     */
    private void writeFrameSequenceAsPNG(Log log, FrameSequence frameSequence)
	    throws IOException {
	final int pixelPerByte = 8;
	int columns = frameSequence.screen.columns;
	int rows = frameSequence.screen.rows;
	int bytesPerTile = frameSequence.screen.bytesPerTile;
	int width = columns * pixelPerByte;
	int height = rows * bytesPerTile;

	// Frame independent, so created only once
	int[] color = new int[] { 255, 255, 255 };
	ColorModel cm = new ComponentColorModel(ColorModel.getRGBdefault()
		.getColorSpace(), false, true, Transparency.OPAQUE,
		DataBuffer.TYPE_BYTE);

	for (int f = 0; f < frameSequence.frames.size(); f++) {

	    // 3 bytes per pixel: red, green, blue; initially all black
	    byte[] aByteArray = new byte[width * height * 3];
	    DataBuffer buffer = new DataBufferByte(aByteArray,
		    aByteArray.length);
	    WritableRaster raster = Raster.createInterleavedRaster(buffer,
		    width, height, 3 * width, 3, new int[] { 0, 1, 2 },
		    (Point) null);

	    Frame frame = frameSequence.frames.get(f);
	    int sm = 0;
	    for (int row = 0; row < rows; row++) {
		try {
		    int tileSetIndex = frame.tileSetTable[row];
		    TileSet tileSet = frame.tileSets.get(tileSetIndex);
		    byte[] charSet = tileSet.getTilesAsBytes();
		    int yOffset = row * bytesPerTile;
		    for (int column = 0; column < columns; column++) {
			int xOffset = column * pixelPerByte;
			byte c = frame.screenMemory[sm++];
			boolean inverse = (c < 0);
			if (inverse) {
			    c = (byte) (c & 0x7f);
			}
			int cOffset = c * bytesPerTile;
			for (int y = 0; y < bytesPerTile; y++) {
			    int p;
			    try {
				p = charSet[cOffset + y];
			    } catch (ArrayIndexOutOfBoundsException ex) {
				throw new RuntimeException("Exception at row "
					+ row + ", column=" + column, ex);
			    }
			    if (inverse) {
				p = p ^ 0xff;
			    }
			    int py = yOffset + y;
			    // Leftmost pixel is bit 7
			    for (int x = 0; x < pixelPerByte; x++) {
				int px = xOffset + x;
				try {
				    if ((p & 0x80) == 0) {
					raster.setPixel(px, py, color);
				    }
				} catch (ArrayIndexOutOfBoundsException ex) {
				    throw new RuntimeException(
					    "Exception at px=" + px + ", py="
						    + py, ex);
				}
				p = p << 1;
			    }
			}
		    }
		} catch (RuntimeException ex) {
		    throw new RuntimeException("Exception in frame " + f, ex);
		}
	    }

	    BufferedImage image = new BufferedImage(cm, raster, true, null);

	    String formatted = String.format("%05d", f);
	    ImageIO.write(image, "png", new File(pngFolder, "Frame-"
		    + formatted + ".png"));
	}
    }

    /**
     * Computes the padding needed in front of a screen memory area so that
     * it does not cross a 4k boundary.
     * 
     * @param address
     *            the start address the area would have without padding
     * @param size
     *            the size of the area in bytes, expected to be at most 4k
     * @return 0 if the area fits, otherwise the number of bytes up to the
     *         next 4k boundary
     */
    private static int getScreenMemoryPadding(int address, int size) {
	if (size <= 0) {
	    return 0;
	}
	// Compare the 4k pages of the first and the last (inclusive) byte
	int lastAddress = address + size - 1;
	if ((address >>> 12) == (lastAddress >>> 12)) {
	    return 0;
	}
	return ((address + 0xfff) & ~0xfff) - address;
    }

    /**
     * Writes the frames as 16k banks based at $8000 plus a control stream.
     * <p>
     * Per frame the control stream contains either a "same frame" byte, or
     * an optional "next bank" byte followed by the screen memory address
     * (high byte first). The high byte is never below $80, so it cannot be
     * confused with the control bytes 0, 1 and 2.
     */
    private void writeFrameSequenceAsROM(Log log, FrameSequence frameSequence)
	    throws IOException {

	final byte control_same_frame = 0;
	final byte control_next_bank = 1;
	final byte control_last_frame = 2;

	final int charsetSize = 0x400;
	final int bankSize = 0x4000;
	final int baseAddress = 0x8000;

	int offset = 0;
	int totalPadding = 0;
	Frame previousFrame = null;

	for (int f = 0; f < frameSequence.frames.size(); f++) {

	    boolean details = (f < 0x20);
	    Frame frame = frameSequence.frames.get(f);

	    if (frame.contentEquals(previousFrame)) {
		romControlOutputStream.writeByte(control_same_frame);
		continue;
	    }
	    previousFrame = frame;

	    // Collect the charsets once; each is padded to 1k / 0x0400
	    int tileSetsCount = frame.tileSets.size();
	    byte[][] charsets = new byte[tileSetsCount][];
	    int tileMemorySize = 0;
	    for (int i = 0; i < tileSetsCount; i++) {
		charsets[i] = frame.getTilesAsBytesWithPadding(i, charsetSize);
		tileMemorySize += charsets[i].length;
	    }

	    int screenMemorySize = frame.screenMemory.length;
	    int tileSetTableSize = frame.tileSetTable.length;

	    // The frame size includes the 4k padding of the screen memory, so
	    // the padding can never push the frame beyond the end of the bank
	    int screenMemoryPadding = getScreenMemoryPadding(baseAddress
		    + offset + tileMemorySize, screenMemorySize);
	    int frameSize = tileMemorySize + screenMemoryPadding
		    + screenMemorySize + tileSetTableSize;

	    if (offset > 0 && offset + frameSize > bankSize) {
		romControlOutputStream.writeByte(control_next_bank);
		totalPadding += romOutputStream.writePadding(offset, bankSize);
		offset = 0;
		screenMemoryPadding = getScreenMemoryPadding(baseAddress
			+ tileMemorySize, screenMemorySize);
		frameSize = tileMemorySize + screenMemoryPadding
			+ screenMemorySize + tileSetTableSize;
	    }
	    if (frameSize > bankSize) {
		throw new IOException("Frame " + f + " requires " + frameSize
			+ " bytes and does not fit into a bank of "
			+ bankSize + " bytes.");
	    }

	    // Write charsets aligned to 1k boundaries
	    for (int i = 0; i < tileSetsCount; i++) {
		if (details) {
		    log.printLongHex("ROM frame", f);
		    log.printLongHex("romOutputStream.count",
			    romOutputStream.getCount());
		    log.printLongHex("charset", i);
		    log.println();
		}
		romOutputStream.writeBytes(charsets[i]);
	    }

	    // Pad so the screen memory does not cross a 4k boundary
	    int screenMemoryAddress = baseAddress + offset + tileMemorySize;
	    if (screenMemoryPadding > 0) {
		romOutputStream.writePadding(screenMemoryAddress, 0x1000);
		screenMemoryAddress += screenMemoryPadding;
		totalPadding += screenMemoryPadding;
	    }

	    if (details) {
		log.printLongHex("ROM frame", f);
		log.printLongHex("romOutputStream.count",
			romOutputStream.getCount());
		log.printLongHex("screenMemoryAddress", screenMemoryAddress);
		log.println();
	    }

	    // Write screen memory address with non-zero high-byte first
	    romControlOutputStream.writeByte(screenMemoryAddress / 256);
	    romControlOutputStream.writeByte(screenMemoryAddress & 255);

	    romOutputStream.writeBytes(frame.screenMemory);

	    // High bytes of the charset addresses; assumes every padded charset
	    // is exactly charsetSize bytes long - TODO confirm
	    int charsetAddress = baseAddress + offset;
	    byte[] charsetTable = new byte[tileSetTableSize];
	    for (int i = 0; i < tileSetTableSize; i++) {
		charsetTable[i] = (byte) ((charsetAddress + charsetSize
			* frame.tileSetTable[i]) / 256);
	    }
	    romOutputStream.writeBytes(charsetTable);

	    // bankSize is a multiple of charsetSize, so this padding stays
	    // within the bank
	    offset += frameSize;
	    int padding = romOutputStream.writePadding(offset, charsetSize);
	    offset += padding;
	    totalPadding += padding;
	}
	totalPadding += romOutputStream.writePadding(offset, bankSize);

	// Also align control stream
	romControlOutputStream.writeByte(control_last_frame);
	romControlOutputStream.writePadding(romControlOutputStream.getCount(),
		0x2000);

	log.printLong("totalPadding", totalPadding);
	log.println();
    }

    /**
     * Writes the frames as address-tagged segments, alternating between two
     * memory pages for even and odd frames. Only the parts that changed
     * since the last frame on the same page are written.
     */
    private void writeFrameSequenceAsXEX(Log log, FrameSequence frameSequence)
	    throws IOException {

	// true: merge the changed areas of a frame into one segment
	// false: write one segment per changed area
	boolean blockStream = true;
	// Image of the 64k address space; 0x10000 entries so that address
	// 0xffff is valid
	byte[] blockMemory = new byte[0x10000];

	final int chrsize = 0x400;
	Page page1 = new Page(0x4a70, 0x4fe0, 0x5000);
	Page page2 = Page.createPageWithOffset(page1, 0x2000);
	for (int f = 0; f < frameSequence.frames.size(); f++) {

	    // blockStart == 0 means that nothing changed yet in this frame
	    int blockStart = 0;
	    int blockEnd = 0;
	    Frame frame = frameSequence.frames.get(f);
	    Page page = (f & 1) == 0 ? page1 : page2;

	    // Skip identical screens; copies keep working if sizes change
	    if (!Arrays.equals(frame.screenMemory, page.lastScreenMemory)) {
		page.lastScreenMemory = frame.screenMemory.clone();

		screenMemoryOutputStream.writeBytes(frame.screenMemory);
		if (!blockStream) {
		    xexOutputStream.writeSegment(page.sm, frame.screenMemory);
		} else {
		    blockStart = page.sm;
		    blockEnd = page.sm + frame.screenMemory.length;
		    System.arraycopy(frame.screenMemory, 0, blockMemory,
			    page.sm, frame.screenMemory.length);
		}
	    }

	    // Charset table: high byte of the charset address per entry
	    if (!Arrays.equals(frame.tileSetTable, page.lastTileSetTable)) {
		page.lastTileSetTable = frame.tileSetTable.clone();

		int length = frame.tileSetTable.length;
		byte[] charsetTable = new byte[length];
		for (int i = 0; i < length; i++) {
		    charsetTable[i] = (byte) ((page.chr + chrsize
			    * frame.tileSetTable[i]) / 256);
		}
		charsetTableOutputStream.writeBytes(charsetTable);
		if (!blockStream) {
		    xexOutputStream.writeSegment(page.chrtab, charsetTable);
		} else {
		    if (blockStart == 0) {
			blockStart = page.chrtab;
		    }
		    // The block ends with the charset table itself
		    blockEnd = page.chrtab + charsetTable.length;
		    System.arraycopy(charsetTable, 0, blockMemory, page.chrtab,
			    charsetTable.length);
		}
	    }

	    // Write only tile sets with more than one tile (tile 0 is blank)
	    for (int i = 0; i < frame.tileSets.size(); i++) {
		TileSet tileSet = frame.tileSets.get(i);
		if (tileSet.getSize() > 1) {
		    byte[] tileAsBytes = tileSet.getTilesAsBytes();
		    int length = tileAsBytes.length;
		    int chrbase = page.chr + chrsize * i;

		    // Character 0 is always the blank, so skip it
		    int start = tileSet.getBytesPerTile();
		    int effectiveLength = length - start;

		    charsetOutputStream.write(tileAsBytes, start,
			    effectiveLength);

		    if (!blockStream) {
			xexOutputStream.writeWord(chrbase + start);
			xexOutputStream.writeWord(chrbase + length - 1);
			xexOutputStream.write(tileAsBytes, start,
				effectiveLength);
		    } else {
			if (blockStart == 0) {
			    blockStart = chrbase;
			}
			blockEnd = chrbase + length;
			System.arraycopy(tileAsBytes, 0, blockMemory, chrbase,
				length);
		    }
		}
	    }
	    if (blockStream && blockStart != 0) {
		log.printLong("frame", f);
		log.printLongHex("blockStart", blockStart);
		log.printLongHex("blockEnd", blockEnd);
		log.println();
		xexOutputStream.writeWord(blockStart);
		xexOutputStream.writeWord(blockEnd - 1);
		xexOutputStream.write(blockMemory, blockStart, blockEnd
			- blockStart);
	    }
	    // Segment $02e2-$02e3 containing $2003; presumably the INITAD
	    // vector so the loader calls the player after each frame - TODO
	    // confirm
	    xexOutputStream.writeWord(0x02e2);
	    xexOutputStream.writeWord(0x02e3);
	    xexOutputStream.writeWord(0x2003);
	}
    }
}
