#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>

/* Number of bytes at the start of the input file that main() skips before
 * handing the ROM image to decode(). */
#define ROM_HEADER_SIZE 0x200
/* Per-game logical address of the bloc descriptor table; decode() reads the
 * 6-byte entry of bloc N at base + N * 6. */
#define OBOCCHAMA_KUN_BASE_ADDRESS 0xe982
#define MARCHEN_MAZE_BASE_ADDRESS 0xf8c8
#define ZIPANG_BASE_ADDRESS 0xf270

/* Descriptor table base used by decode(), i.e. the game being extracted. */
unsigned int g_base_address = ZIPANG_BASE_ADDRESS;

/* Bank mapping: entry n holds the ROM bank visible through the 8 KiB logical
 * page n (addresses n*0x2000 .. n*0x2000+0x1fff). Presumably mirrors the
 * HuC6280 MPR0-MPR7 registers. */
uint8_t g_mpr[8];

/*
 * Translate a logical address into an offset inside the ROM image.
 *
 * Bits 13-15 select a g_mpr entry whose bank number forms the upper part of
 * the offset; bits 0-12 are the offset inside that 8 KiB bank. Addresses
 * above 0xffff wrap around (as on a 16-bit address bus) instead of indexing
 * past the end of g_mpr.
 */
off_t get_physical_address(unsigned int address)
{
    return ((off_t)g_mpr[(address >> 13) & 0x07] << 13) | (off_t)(address & 0x1fff);
}

/*
 * Fetch the 2-bit encoding type of the next 32-byte bloc.
 *
 * Encoding types are packed four per byte, lowest bits first (despite the
 * parameter name, the fields are 2 bits wide, not nibbles).
 * *header points at the current packed byte and *header_nibble is the index
 * of the next field inside it; once it reaches 4 the byte is exhausted and
 * *header is advanced before reading. Both are updated in place, so
 * successive calls walk the whole stream.
 *
 * Returns a value in the range 0..3.
 */
uint8_t compute_encoding(uint8_t *header_nibble, uint8_t **header)
{
    unsigned int shift;

    /* Current byte fully consumed: move on to the next one. */
    if(*header_nibble == 4)
    {
        *header_nibble = 0;
        *header += 1;
    }

    shift = (unsigned int)(*header_nibble) * 2u;
    *header_nibble += 1;

    return (uint8_t)(((unsigned int)(**header) >> shift) & 0x03u);
}

/*
 * Decode one 32-byte bloc stored in "mask" form (called RLE here, although it
 * is not run-length encoding strictly speaking).
 *
 * The encoded bloc starts with a 4-byte bit mask, one bit per output byte,
 * least significant bit first. When a bit is set, the output byte is the next
 * byte of the data that follows the mask; when it is clear, the output byte
 * is 0. The mask is 4 bytes long, so exactly 32 bytes are written to buffer.
 *
 * Returns a pointer just past the last data byte consumed, i.e. the start of
 * the next encoded bloc.
 */
uint8_t *rle_decode(uint8_t *rom, uint8_t *buffer)
{
    const uint8_t *mask = rom;
    uint8_t *src = rom + 4;
    int bit;

    for(bit = 0; bit < 32; ++bit)
    {
        int present = (mask[bit >> 3] >> (bit & 7)) & 1;
        buffer[bit] = present ? *src++ : 0x00;
    }

    return src;
}

/*
 * Undo the "xor pass" delta applied to a 32-byte bloc.
 *
 * The compressor stores each 2-byte row as its XOR with the previous row to
 * maximise the number of zero bytes. This function rebuilds the original rows
 * in place with a running XOR, independently for the two 16-byte halves
 * (8 rows of 2 bytes each; presumably the two bitplane pairs of a 4bpp tile).
 * decode() calls it after rle_decode() for encoding type 3.
 */
void xor_pass(uint8_t *buffer)
{
    int row;

    for(row = 1; row < 8; ++row)
    {
        uint8_t *cur  = buffer + 2 * row;
        uint8_t *prev = cur - 2;

        /* First half. */
        cur[0x00] ^= prev[0x00];
        cur[0x01] ^= prev[0x01];
        /* Second half. */
        cur[0x10] ^= prev[0x10];
        cur[0x11] ^= prev[0x11];
    }
}

/*
 * Read an unsigned 16-bit little-endian value.
 */
static unsigned int read_le16(const uint8_t *p)
{
    return (unsigned int)p[0] | ((unsigned int)p[1] << 8);
}

/*
 * Number of bytes rle_decode() consumes for the encoded bloc at p: the 4-byte
 * bit mask plus one data byte per set bit. p must have 4 readable bytes.
 */
static size_t rle_encoded_size(const uint8_t *p)
{
    size_t n = 4;
    int bit;

    for(bit = 0; bit < 32; ++bit)
    {
        if(p[bit >> 3] & (1u << (bit & 7)))
        {
            ++n;
        }
    }
    return n;
}

/*
 * Translate a logical address (wrapped to 16 bits) into a ROM offset through
 * get_physical_address() and check that `needed` bytes starting there lie
 * inside the ROM. Returns 1 and stores the offset on success, 0 otherwise.
 * Working with offsets avoids forming out-of-range pointers (undefined
 * behaviour) from untrusted table data.
 */
static int rom_offset(unsigned int address, size_t needed, size_t romSize, size_t *offset)
{
    off_t physical = get_physical_address(address & 0xffff);

    if((physical < 0) || ((size_t)physical > romSize) ||
       ((romSize - (size_t)physical) < needed))
    {
        return 0;
    }
    *offset = (size_t)physical;
    return 1;
}

/*
 * Decode gfx bloc `bloc` of the descriptor table at g_base_address.
 *
 * bloc    - index of the 6-byte table entry.
 * buffer  - in/out: heap buffer, grown with realloc() to hold the output.
 * size    - out: number of decoded bytes (32 per bloc) on success.
 * rom     - ROM image (main() passes it without its file header).
 * romSize - size of rom in bytes; every read is checked against it.
 *
 * Side effect: maps the bloc's banks into g_mpr[2..5].
 * Returns 0 on success, 1 on inconsistent data or allocation failure
 * (a message is printed to stderr in both cases).
 */
int decode(unsigned int bloc, uint8_t **buffer, size_t *size, uint8_t *rom, size_t romSize)
{
    uint8_t *rom_end = rom + romSize;
    uint8_t *ptr;
    uint8_t *out;
    uint8_t *header;
    uint8_t encoding;
    size_t offset;
    size_t out_size;
    size_t avail;
    unsigned int data_address;
    unsigned int compressed_data_size;
    unsigned int bloc_count;
    unsigned int i;
    uint8_t bank, header_nibble;

    /* Table entry: data address (16-bit LE), VRAM output address (16-bit LE,
     * not needed here) and ROM bank; 6 bytes per entry. */
    if(!rom_offset(g_base_address + (bloc * 6), 6, romSize, &offset))
    {
        goto inconsistent;
    }
    ptr = rom + offset;
    data_address = read_le16(ptr);
    bank = ptr[4];

    /* Some sanity check. */
    if(data_address == 0)
    {
        goto inconsistent;
    }

    /* Map four consecutive banks starting at `bank` into g_mpr[2..5]
     * (presumably data_address lies in 0x4000-0xbfff). */
    for(i = 0; i < 4; i++)
    {
        g_mpr[2 + i] = (uint8_t)(bank + i);
    }

    /* Data layout: 2 skipped bytes, bloc count (16-bit LE, a bloc is 32
     * bytes), compressed data size (16-bit LE), then the encoded blocs. */
    if(!rom_offset(data_address, 6, romSize, &offset))
    {
        goto inconsistent;
    }
    ptr = rom + offset + 2;
    bloc_count = read_le16(ptr);
    compressed_data_size = read_le16(ptr + 2);
    ptr += 4;

    if((bloc_count == 0) ||
       (compressed_data_size == 0) || (compressed_data_size >= 32768))
    {
        goto inconsistent;
    }

    /* The packed 2-bit encoding types are located at the logical address
     * data_address + compressed_data_size. */
    if(!rom_offset(data_address + compressed_data_size, 1, romSize, &offset))
    {
        goto inconsistent;
    }
    header = rom + offset;

    /* Allocate buffer; only publish the new size once realloc succeeded. */
    out_size = (size_t)bloc_count * 0x20;
    out = realloc(*buffer, out_size);
    if(out == NULL)
    {
        fprintf(stderr, "Failed to allocate buffer!\n");
        return 1;
    }
    *buffer = out;
    *size = out_size;

    /* Decompress data. Invariant: header < rom_end and ptr <= rom_end. */
    header_nibble = 0;
    while(bloc_count)
    {
        /* compute_encoding() steps to the next header byte once the four
         * fields of the current one are used: that byte must exist. */
        if((header_nibble == 4) && ((header + 1) >= rom_end))
        {
            goto inconsistent;
        }
        encoding = compute_encoding(&header_nibble, &header);

        avail = (size_t)(rom_end - ptr);
        switch(encoding)
        {
        case 0:
            /* Empty bloc. */
            memset(out, 0, 0x20);
            break;

        case 2:
        case 3:
            /* Bit-mask coded bloc; make sure the mask and all the data
             * bytes it announces are inside the ROM. */
            if((avail < 4) || (avail < rle_encoded_size(ptr)))
            {
                goto inconsistent;
            }
            ptr = rle_decode(ptr, out);
            if(encoding == 3)
            {
                /* Undo the row delta. */
                xor_pass(out);
            }
            break;

        default:
            /* Encoding 1: raw copy. */
            if(avail < 0x20)
            {
                goto inconsistent;
            }
            memcpy(out, ptr, 0x20);
            ptr += 0x20;
            break;
        }

        out += 0x20;
        --bloc_count;
    }

    return 0;

inconsistent:
    fprintf(stderr, "Inconsistent data found in bloc %02x.\n", bloc);
    return 1;
}

/*
 * Entry point: extract the compressed graphics blocs of a ROM image.
 *
 * Usage: <program> <rom file>
 *
 * The first ROM_HEADER_SIZE bytes of the file are skipped. Blocs 0x00-0x3f of
 * the table at g_base_address are then decoded, and each bloc that decodes
 * successfully is written to "out_XX.pce" in the current directory.
 *
 * Returns 1 if the ROM cannot be loaded. Otherwise it returns the status of
 * the last decode() call (kept from the original behaviour).
 */
int main(int argc, char **argv)
{
    FILE *input;
    long file_size;
    size_t romSize, len, count;
    uint8_t *rom;
    uint8_t *buffer;
    int err = 0;
    unsigned int i;

    memset(g_mpr, 0, 8);

    if(argc < 2)
    {
        fprintf(stderr, "Usage: %s <rom file>\n",
                ((argc > 0) && (argv[0] != NULL)) ? argv[0] : "program");
        return 1;
    }

    input = fopen(argv[1], "rb");
    if(input == NULL)
    {
        fprintf(stderr, "Unable to open %s: %s\n", argv[1], strerror(errno));
        return 1;
    }

    /* Determine the file size. */
    if((fseek(input, 0, SEEK_END) != 0) ||
       ((file_size = ftell(input)) < 0) ||
       (fseek(input, 0, SEEK_SET) != 0))
    {
        fprintf(stderr, "Unable to get the size of %s: %s\n", argv[1], strerror(errno));
        fclose(input);
        return 1;
    }
    romSize = (size_t)file_size;

    /* decode() receives romSize - ROM_HEADER_SIZE: reject files that would
     * make this unsigned subtraction wrap around. */
    if(romSize <= ROM_HEADER_SIZE)
    {
        fprintf(stderr, "%s is too small (%zu bytes).\n", argv[1], romSize);
        fclose(input);
        return 1;
    }

    rom = malloc(romSize);
    if(rom == NULL)
    {
        fprintf(stderr, "Unable to allocate ROM buffer.\n");
        fclose(input);
        return 1;
    }

    count = fread(rom, 1, romSize, input);
    if(count != romSize)
    {
        /* A short count may also mean end-of-file; errno is only worth
         * reporting when the stream's error indicator is set. */
        fprintf(stderr, "Unable to read ROM buffer: %s\n",
                ferror(input) ? strerror(errno) : "unexpected end of file");
        fclose(input);
        free(rom);
        return 1;
    }
    fclose(input);

    buffer = NULL;
    len    = 0;

    for(i = 0; i < 0x40; i++)
    {
        char name[64];
        FILE *out;

        err = decode(i, &buffer, &len, rom + ROM_HEADER_SIZE, romSize - ROM_HEADER_SIZE);
        if(err)
        {
            /* decode() already reported the problem; try the next bloc. */
            continue;
        }

        snprintf(name, sizeof name, "out_%02x.pce", i);

        out = fopen(name, "wb");
        if(out == NULL)
        {
            fprintf(stderr, "Unable to open %s: %s.\n", name, strerror(errno));
            continue;
        }
        if(fwrite(buffer, 1, len, out) != len)
        {
            fprintf(stderr, "Unable to write %s.\n", name);
        }
        if(fclose(out) != 0)
        {
            fprintf(stderr, "Unable to close %s: %s.\n", name, strerror(errno));
        }
    }

    /* free(NULL) is a no-op. */
    free(buffer);
    free(rom);

    /* NOTE(review): only the last bloc's decode status is returned. */
    return err;
}
