#include <string.h>

#include "Loaders.h"
#include "Song.h"

// Reader for a compressed stream that is consumed backwards, from the last
// byte of the buffer towards the first.
struct bitstream
{
	// bit buffer for rolling data bit by bit from the compressed file;
	// getb() shifts it left and takes each data bit from bit 8
	uint32_t        word;
	// next compressed byte to fetch, moving down towards orgsrc
	const uint8_t*  src;
	// first byte of the compressed data, i.e. the lower limit for src
	const uint8_t*  orgsrc;
};

// Fetch the next compressed byte and step src one byte backwards.
// Once every byte of the buffer has been taken (src is below orgsrc) this
// returns 0 and leaves src alone, so a truncated or corrupt stream decodes
// as zero bits instead of reading memory in front of the buffer.
static uint32_t getbNextByte(struct bitstream* bs)
{
	if (bs->src - bs->orgsrc < 0)
		return 0;
	return *bs->src--;
}

// Prepare bs for reading the src_length bytes at src, starting with the
// last byte. A zero length stream is accepted and reads as all zero bits.
// Note the strange bit reading used will work only if bit-0 of first byte is 1
static void initGetb(struct bitstream* bs, const uint8_t* src, uint32_t src_length)
{
	bs->src = src + src_length - 1;
	bs->orgsrc = src;

	// get the first 8-bits of the compressed stream (0 if there are none)
	bs->word = getbNextByte(bs);
}

// get nbits from the compressed stream, first bit read ends up as the most
// significant bit of the result; returns 0 when nbits <= 0
static uint32_t getb(struct bitstream* bs, int nbits)
{
	uint32_t val = 0;
	for (;nbits > 0; nbits--)
	{
		bs->word <<= 1;
		// low byte ran empty: refill it from the stream, carrying the bit
		// that was just shifted out of it into bit 0 of the new word
		if (!(bs->word & 0xff))
			bs->word = (getbNextByte(bs) << 1) + (bs->word >> 8);
		val = (val << 1) + (bs->word >> 8);
		bs->word &= 0xff;
	}
	return val;
}

// Unpack an "ICE!" compressed file.
//
// Reads a 12 byte header (4 byte tag, packed size, unpacked size), loads the
// packed data into a temporary buffer and decompresses it backwards: both
// the packed input and the output are walked from their last byte to their
// first.
//
// The output buffer is allocated with Loaders_Alloc and stored, together
// with its size, in pSong->pLoaderData->Decomp (pMemory / MemSize); it is
// left in place even when decompression fails part way through.
//
// Returns NULL on success, Loaders_NotThisType when the tag is not "ICE!",
// otherwise the error from the failing read/allocation or an "ICE: ..."
// error describing the inconsistency found in the packed data.
const _kernel_oserror* Decomp_ICE(SongHdr* pSong)
{
	const _kernel_oserror* e = NULL;
	// the tag is read into a real uint32_t rather than a cast char array so
	// the store is always suitably aligned; only its bytes are compared
	uint32_t tag;
	uint32_t outSize, inSize;
	uint8_t* pSrc = NULL;
	int32_t l, x;
	uint8_t* dst;
	uint8_t* dst2;
	uint8_t* dststart;
	struct bitstream bs;

	e = FileLoad_ReadInt(pSong, &tag, 4);
	if (e) return e;

	if (memcmp(&tag, "ICE!", 4) != 0)
		return Loaders_NotThisType;

	e = FileLoad_ReadReverseInt(pSong, &inSize, 4);
	if (!e) e = FileLoad_ReadReverseInt(pSong, &outSize, 4);
	if (e) return e;

	// the 12 header bytes read above are subtracted from inSize below; a
	// smaller value would wrap around and an equal one would leave no
	// stream to read the first bits from
	if (inSize <= 12)
		return Loaders_Error(pSong, 0, "ICE: Bad packed size, %u", (unsigned) inSize);

	e = Loaders_Alloc(pSong, (void**) &pSong->pLoaderData->Decomp.pMemory, outSize);
	if (e) return e;
	pSong->pLoaderData->Decomp.MemSize = outSize;

	inSize -= 12;
	e = Loaders_Alloc(pSong, (void**) &pSrc, inSize);
	if (e) return e;
	e = FileLoad_Read(pSong, pSrc, inSize);
	// from here on pSrc is owned by this function: always leave through
	// decomp_error so it gets released
	if (e) goto decomp_error;

	dststart = pSong->pLoaderData->Decomp.pMemory;
	dst = dststart + outSize - 1;

	initGetb(&bs, pSrc, inSize);

	// each round is an optional run of literal bytes followed by one
	// copy from the already unpacked output
	while (dst >= dststart)
	{
		// normal bytes
		if (getb(&bs, 1))
		{
			// run length, coded in steps of growing width; an all-ones
			// field means "continue with the next, wider field"
			l = 0;
			if (getb(&bs, 1))
			{
				if ((l = getb(&bs, 2)) != 3)
					l += 1;
				else if ((l = getb(&bs, 2)) != 3)
					l += 4;
				else if ((l = getb(&bs, 3)) != 7)
					l += 7;
				else if ((l = getb(&bs, 8)) != 0xff)
					l += 14;
				else
				{
					l = getb(&bs, 15);
					l += 269;
				}
			}
			l++;
			// dst + 1 - dststart is the room left in the output,
			// bs.src + 1 - bs.orgsrc the bytes left in the input
			if ((dst + 1 - dststart) < l)
			{
				e = Loaders_Error(pSong, 0, "ICE: Writing out of bounds, %d", (int) l);
				goto decomp_error;
			}
			if ((bs.src + 1 - bs.orgsrc) < l)
			{
				e = Loaders_Error(pSong, 0, "ICE: Reading out of bounds, %d", (int) l);
				goto decomp_error;
			}

			// copy direct
			for (; l > 0; l--)
			{
				*dst-- = *bs.src--;
			}
		}

		// output complete after the literal run: no string follows
		if (dst < dststart)
			break;

		// strings: l is the copy length minus 2
		if (!getb(&bs, 1))
			l = 0;
		else if (!getb(&bs, 1))
			l = 1;
		else if (!getb(&bs, 1))
			l = 2 + getb(&bs, 1);
		else if (!getb(&bs, 1))
			l = 4 + getb(&bs, 2);
		else
			l = 8 + getb(&bs, 10);

		// x is the coded distance; how it is coded depends on whether the
		// copy is the minimum length (l == 0) or longer
		if (l)
		{
			if (!getb(&bs, 1))
				x = 0x1f + getb(&bs, 8);
			else if (!getb(&bs, 1))
				x = -1 + getb(&bs, 5);
			else
				x = 0x11f + getb(&bs, 12);

			// x == -1: cancel l below so that the source is dst + 1
			if (x < 0)
				x -= l;
		}
		else
		{
			if (getb(&bs, 1))
				x = 0x3f + getb(&bs, 9);
			else
				x = -1 + getb(&bs, 6);
		}
		// source of the copy; always at least one byte above dst, i.e. in
		// the part of the output that has been written already
		dst2 = dst + l + 2 + x;
		l += 2;

		if ((dst + 1 - dststart) < l)
		{
			e = Loaders_Error(pSong, 0, "ICE: Pasting out of bounds, %d", (int) l);
			goto decomp_error;
		}
		if (dst2 >= dststart + outSize)
		{
			e = Loaders_Error(pSong, 0, "ICE: Copying out of bounds, %p %p %p %d", (void*) dst, (void*) dst2, (void*) (dststart + outSize), (int) l);
			goto decomp_error;
		}
		// byte by byte on purpose: source and destination may overlap
		for(; l > 0; l--)
		{
			*dst-- = *dst2--;
		}
	}

	// reached on success (e == NULL) as well as on error
decomp_error:
	Loaders_Free(pSong, pSrc);

	return e;
}
