#include "Loaders.h"
#include "kernel.h"
#ifndef NULL
#define NULL (void*) 0
#endif

// One LZW dictionary slot. The string for a code is the string of its
// `older` code followed by the byte `newer`; codes 0..255 stand for the
// single byte of the same value and are never looked up here.
typedef struct lzwd_dict_entry
{
	uint16_t older; // code of the prefix string
	uint16_t newer; // last byte of this code's string (0..255)
} dict_entry;

// Running state of the LZW decoder.
typedef struct lzwd_state
{
	dict_entry*  dict;       // dictionary storage, 1 << maxbits entries
	uint32_t     bits;       // current code width in bits
	uint32_t     index;      // next free dictionary slot
	uint32_t     min_index;  // first slot handed out after a clear (256..258)
	uint32_t     max_index;  // first code that no longer fits in `bits` bits
	uint32_t     full_index; // dictionary capacity (1 << maxbits)
	uint32_t     shift;      // number of unread bits left at the top of `preloaded`
	uint32_t     preloaded;  // last 32-bit word fetched from the file
	uint32_t     count;      // codes read so far, modulo 8 (compress block mode)
} lzwd;

/*
 * Read the next variable-width code from the LZW bit stream.
 *
 * Codes are packed least-significant-bit first into 32-bit words fetched
 * with FileLoad_ReadInt. l->preloaded holds the current word and l->shift
 * the number of still-unread bits, which occupy its top l->shift bits.
 *
 * pSong: song being loaded, passed through to FileLoad_ReadInt.
 * l:     decoder state; l->bits is the width of the code to read.
 * pval:  receives the code, in [0, 2^l->bits - 1].
 *
 * Returns NULL on success, or the error reported by FileLoad_ReadInt.
 * Also advances l->count, the number of codes read modulo 8.
 */
static const _kernel_oserror* lzwd_getcode(SongHdr* pSong, lzwd* l, uint32_t* pval)
{
	uint32_t val;

	// Take the unread top bits of the current word. Shifting a 32-bit value
	// by 32 is undefined behaviour, so an exhausted word contributes nothing.
	val = (l->shift != 0) ? (l->preloaded >> (32 - l->shift)) : 0;

	if (l->shift < l->bits)
	{
		// Not enough bits left: fetch the next word and place its low bits
		// just above the ones already taken.
		const _kernel_oserror* err = FileLoad_ReadInt(pSong, &l->preloaded, 4);
		if (err) return err;
		val |= l->preloaded << l->shift;
		l->shift += 32;
	}
	l->shift -= l->bits;

	// keep only the bits belonging to this code (bits is at most 16)
	*pval = val & ((1u << l->bits) - 1u);

	// position inside the current group of 8 codes (compress block mode)
	l->count = (l->count + 1) & 7;

	return NULL;
}

/*
 * Reset the decoder to an empty dictionary using 9-bit codes.
 *
 * In compress block mode (min_index == 257) the codes still pending in the
 * current group of 8 are read and thrown away first, so that decoding
 * resumes on a group boundary.
 *
 * Returns NULL on success or the first read error encountered.
 */
static const _kernel_oserror* lzwd_cleardict(SongHdr* pSong, lzwd* l)
{
	// Block mode: drain the rest of the current group of 8 codes.
	while (l->min_index == 257 && l->count != 0)
	{
		uint32_t discarded;
		const _kernel_oserror* e = lzwd_getcode(pSong, l, &discarded);
		if (e != NULL)
			return e;
	}

	// Codes below min_index are the literal bytes [0,255] plus any reserved
	// codes; the first free slot is therefore min_index.
	l->count     = 0;
	l->bits      = 9;
	l->max_index = 1u << 9;
	l->index     = l->min_index;

	return NULL;
}

// 256 is reserved to clear a full dictionnary
// 257 is reserved to mark the end of the data, cf DSym

const _kernel_oserror* lzwd_decode(SongHdr* pSong, uint8_t* buffer, uint32_t len, int maxbits, int min_code)
{
	const _kernel_oserror* err = NULL;
	lzwd l;
	uint32_t new_code, last_code;
	uint8_t *to, *to1, *tp1, *tp2, *toend;

	if ((maxbits < 9) || (maxbits > 16) || (min_code < 256) || (min_code > 258))
		return Loaders_Error(pSong, FileLoad_GetPos(pSong), "Invalid LZWD parameters");

	// Reserve storage for dictionnary
	err = Loaders_Alloc(pSong, (void**) &l.dict, sizeof(*l.dict) << maxbits);
	if (err) return err;

	// set buffer limits
	l.shift = 0;
	l.preloaded = 0;
	l.min_index = min_code;
	l.full_index = 1 << maxbits;
	l.count = 0;

	lzwd_cleardict(pSong, &l);

	to = buffer;
	toend = to + len;

	while ((to < toend) || (l.min_index == 258))
	{
		err = lzwd_getcode(pSong, &l, &new_code);
		if (err) goto lzwd_error;

		if (new_code >= l.index)
		{
			// valid codes must already be in dictionnary
			err = Loaders_Error(pSong, FileLoad_GetPos(pSong), "LZWD code not in dict");
			goto lzwd_error;
		}

		if (new_code < l.min_index)
		{
			// Use a special code to mark correct decoding end
			if (new_code == 257)
				break;

			// Use a special code to reset the dictionnary once it is full
			if (new_code == 256)
			{
				err = lzwd_cleardict(pSong, &l);
				if (err) goto lzwd_error;
				continue;
			}
		}

		// this forms a new code to add to the dictionnary

		// still free entries in dictionnary?
		if (l.index < l.full_index)
		{
			// still fits in current code size?
			if (l.index == l.max_index)
			{
				// Compress block mode, purge remaining of last block of 8 codes first
				if (l.min_index == 257)
				{
					while (l.count != 0)
					{
						uint32_t val;

						err = lzwd_getcode(pSong, &l, &val);
						if (err) goto lzwd_error;
					}
				}

				// no, increase code size
				l.bits++;
				l.max_index <<= 1;
			}

			// store new entry
			l.dict[l.index].older = new_code;
		}
		if (l.index <= l.full_index)
		{
			last_code = new_code;
			while (last_code > 255)
			{
				last_code = l.dict[last_code].older;
			}
			l.dict[l.index-1].newer = last_code;
			l.index++;
		}

		// expand code
		to1 = to;

		last_code = new_code;
		while (last_code > 255)
		{
			if (to1 >= toend)
			{
				err = Loaders_Error(pSong, FileLoad_GetPos(pSong), "LZWD code expension past end");
				goto lzwd_error;
			}
			*to1++ = l.dict[last_code].newer;
			last_code = l.dict[last_code].older;
		}

		if (to1 >= toend)
		{
			err = Loaders_Error(pSong, FileLoad_GetPos(pSong), "LZWD code expension past end");
			goto lzwd_error;
		}
		*to1++ = last_code;

		// expanded code is in the wrong order, we must reverse it
		tp2 = to1 - 1;
		tp1 = to;
		while(tp2 > tp1)
		{
			uint8_t byte = *tp2;
			*tp2-- = *tp1;
			*tp1++ = byte;
		}

		// code was expanded correctly
		to = to1;
	}

	while (to < toend)
		*to++ = 0;

lzwd_error:
	Loaders_Free(pSong, l.dict);

	return err;
}
