#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "IStream.h"
#include "ICodec.h"
#include "Channels.h"
#include "Codecs.h"
#include "Formats.h"
#include "Utils.h"
#include "Log.h"

#include "Format_WAVE.h"

static const wav_tag_t TagLIST = {{'L','I','S','T'}};
static const wav_tag_t TagINFO = {{'I','N','F','O'}};

// Maps a RIFF sub-chunk id to the stream metadata id it is stored under
typedef struct WavMeta
{
	const wav_tag_t Tag;     // sub-chunk identifier (e.g. 'INAM')
	unsigned int    MetaId;  // metadata id handed to IStream_SetText / IStream_SetMetadata
	unsigned int    type;    // 0: stored as a string, 1: parsed and stored as an integer
} WavMeta;

static const WavMeta Metas[] =
{ {{'I','A','R','T'}, EMeta_StreamAuthor, 0}
, {{'I','C','M','T'}, EMeta_StreamComment, 0}
, {{'I','C','R','D'}, EMeta_StreamDate, 0}
, {{'I','G','N','R'}, EMeta_StreamGenre, 0}
, {{'I','N','A','M'}, EMeta_StreamTitle, 0}
, {{'I','P','R','D'}, EMeta_StreamAlbum, 0}
, {{'I','T','R','K'}, EMeta_StreamTrackNumber, 1}
, {0, 0, 0}
};

// Parser states of Format_WAVE, stored in riff_work_t.state
enum
{
	WAVE_STATE_WAIT_RIFF   = 0, // expecting the 'RIFF' / 'WAVE' file header
	WAVE_STATE_WAIT_TAG    = 1, // expecting the next chunk header
	WAVE_STATE_WAIT_FMT    = 2, // expecting the 'fmt ' header and standard fields
	WAVE_STATE_WAIT_FMTEX  = 3, // expecting the 'fmt ' extension fields
	WAVE_STATE_WAIT_STRING = 4  // expecting the body of a metadata string chunk
};

// Nesting depth riff_work_t.levels[] is sized for (it holds MAX_LEVEL + 1 entries)
enum { MAX_LEVEL = 2 };

// One level of the RIFF chunk tree. The first 12 bytes (id, size, list type)
// are filled straight from the input with bs_peekBytes, so these three fields
// must stay first and in this order.
typedef struct riff_level
{
	wav_tag_t   tag;      // chunk identifier
	int32_t     size;     // chunk size field as read from the header
	wav_tag_t   listtag;  // form / list type (meaningful for RIFF and LIST chunks)
	int32_t     read;     // bytes of this chunk accounted for so far
} riff_level;

// Parser state kept in s->codec.data between calls to Format_WAVE
typedef struct riff_work
{
	int32_t         skipsize;               // input bytes still to skip before parsing resumes
	int32_t         state;                  // one of the WAVE_STATE_* values
	wav_fmt_t       fmt;                    // copy of the 'fmt ' chunk header and standard fields
	riff_level      levels[MAX_LEVEL + 1];  // levels[level - 1]: container, levels[level]: current child
	int32_t         level;                  // current nesting depth (initialised to 1)
	const WavMeta*  pMeta;                  // Metas entry of the pending WAIT_STRING chunk
} riff_work_t;

/*
 * Format_WAVE - incremental RIFF/WAVE header parser.
 *
 * Walks the RIFF chunk tree until the 'data' chunk is reached:
 *  - 'fmt ' selects the codec in s->codec.prefn (MP3_Codec for formats
 *    0x0050/0x0055, PCM_Codec for 0x0001; WAVE_FORMAT_EXTENSIBLE 0xFFFE is
 *    resolved to its sub-format) and, for PCM, allocates the pcm_params
 *    stored in s->codec.params;
 *  - LIST/INFO lists are descended into, and chunks listed in Metas are
 *    stored as stream metadata;
 *  - 'data' narrows s->source.start/end to the chunk payload (rounded down to
 *    whole blocks when a PCM block size is known) and ends parsing.
 * Every other chunk is skipped.
 *
 * Parser state lives in s->codec.data; it is allocated on the first call and
 * released whenever *bContinue ends up false or an error is returned.
 *
 * @param s          stream whose input buffer (s->bitb / s->inb) is parsed
 * @param bContinue  set to true only when parsing stopped because more input
 *                   is needed and the input is not finished yet
 * @return NULL on success, and also when the input is not RIFF/WAVE or ends
 *         before a 'data' chunk; otherwise an error (usually &s->LastError)
 */
const _kernel_oserror* Format_WAVE(IStream* s, bool* bContinue)
{
	riff_work_t* w = s->codec.data;
	const _kernel_oserror* e = NULL;
	bool bLoop = true;

	*bContinue = false;

	// First time, we allocate our data
	if (!w)
	{
		e = IStream_Alloc(s, &s->codec.data, sizeof(*w));
		if (e) return e;
		w = s->codec.data;
		memset(w, 0, sizeof(*w));
		w->level = 1;
	}

	// Loop on available data
	while (bLoop)
	{
		// Accounting convention for every level: 'size' is the chunk size
		// field (bytes following the 8-byte id+size header) and 'read' is how
		// many of those bytes have been consumed so far.
		// pclevel is the container being scanned; plevel receives the header
		// of its current child chunk.
		riff_level* pclevel = &w->levels[w->level - 1];
		riff_level* plevel = &w->levels[w->level];

		// Some data to skip?
		if (w->skipsize) w->skipsize = IStream_SkipBytes(s, w->skipsize);

		// Still some data to skip?
		if (w->skipsize)
		{
			// Wait for buffer to be filled
			if (!s->inb->finishflag)
				*bContinue = true;

			bLoop = false;
			break;
		}

		switch (w->state)
		{
			case WAVE_STATE_WAIT_RIFF:
			{
				// Attempt to find RIFF/wave header
				if (!bs_peekBytes(&s->bitb, (uint8_t*) pclevel, 12))
				{
					// Wait for buffer to be filled
					if (!s->inb->finishflag)
						*bContinue = true;

					bLoop = false;
					break;
				}

				// Not an RIFF WAVE, exit
				if ((pclevel->tag.i != TagRIFF.i)
				||  (pclevel->listtag.i != TagWAVE.i))
				{
					bLoop = false;
					break;
				}

				// The RIFF size covers everything after the 8-byte 'RIFF'+size
				// header: clamp it to what the source really holds
				if ((s->source.end - s->source.start - 8) < pclevel->size)
					pclevel->size = s->source.end - s->source.start - 8;

				// Skip the 12-byte header; its 'WAVE' form type is the first
				// 4 bytes counted by the RIFF size
				w->skipsize = 12;
				pclevel->read = 4;
				w->state = WAVE_STATE_WAIT_TAG;
			}
			break;
			case WAVE_STATE_WAIT_TAG:
			{
				// First check if we reach end of level(s): fewer than 12 bytes
				// left cannot hold a chunk header we are able to peek
				while (bLoop && ((pclevel->size - pclevel->read) < 12))
				{
					int32_t remain = pclevel->size - pclevel->read;

					// Never skip a negative amount if a child overstated its size
					if (remain > 0)
						bs_skipBytes(&s->bitb, remain);
					if (w->level <= 1)
					{
						*bContinue = false;
						bLoop = false;
						break;
					}
					// Leave the list: its parent now accounts for the list body
					// (its 8-byte header was counted when entering it)
					w->level--;
					pclevel = &w->levels[w->level - 1];
					plevel = &w->levels[w->level];
					pclevel->read += plevel->size;
				}
				if (!bLoop) break;

				// Attempt to read tag header
				if (!bs_peekBytes(&s->bitb, (uint8_t*) plevel, 12))
				{
					// Wait for buffer to be filled
					if (!s->inb->finishflag)
						*bContinue = true;

					bLoop = false;
					break;
				}

				// Sanity checks
				if (plevel->size < 4)
				{
					s->LastError.errnum = ErrNum_CodecError;
					snprintf(s->LastError.errmess, sizeof(s->LastError.errmess)
					      , "Corrupted RIFF header '%.4s' of size: %08x"
					      , plevel->tag.c, (unsigned int) plevel->size);
					e = &s->LastError;

					bLoop = false;
					break;
				}

				if (plevel->tag.i == Tagfmt.i)
				{
					w->state = WAVE_STATE_WAIT_FMT;
					continue;
				}
				else if (plevel->tag.i == TagLIST.i)
				{
					// Padding
					if (plevel->size & 1) plevel->size++;

					// Descend into INFO lists while levels[] still has room for
					// the list's children; skip any other (or too deep) list
					if ((plevel->listtag.i == TagINFO.i)
					&&  ((w->level + 1) < (int32_t) (sizeof(w->levels) / sizeof(w->levels[0]))))
					{
						w->skipsize = 12;
						// Parent: the 8-byte LIST header now, the body on exit
						pclevel->read += 8;
						// The list itself has consumed its 'INFO' type (reset,
						// as this level entry may hold a previous list's count)
						plevel->read = 4;
						w->level++;
					}
					else
					{
						w->skipsize = 8 + plevel->size;
						pclevel->read += w->skipsize;
					}
					w->state = WAVE_STATE_WAIT_TAG;
					continue;
				}

				// Skip tag header
				bs_skipBytes(&s->bitb, 8);
				pclevel->read += 8;

				if (plevel->tag.i == Tagdata.i)
				{
					// If header tag is data, stop
					const pcm_params* params = s->codec.params;

					// A successfully parsed 'fmt ' chunk always leaves a
					// non-zero sample rate; zero means none came before 'data'
					if (!w->fmt.samplerate)
					{
						s->LastError.errnum = ErrNum_CodecError;
						snprintf(s->LastError.errmess, sizeof(s->LastError.errmess)
						      , "No RIFF header 'fmt ' present before header '%.4s' of size: %08x"
						      , plevel->tag.c, (unsigned int) plevel->size);
						e = &s->LastError;

						bLoop = false;
						break;
					}

					s->source.start = IStream_GetInputPos(s);
					if (s->source.end > (s->source.start + plevel->size))
						s->source.end = s->source.start + plevel->size;

					if (params && params->blocksize)
					{
						// beware of truncated files
						s->source.end -= s->source.start;
						s->source.end /= params->blocksize;
						s->source.end *= params->blocksize;
						s->source.end += s->source.start;
					}

					// That's it for us, switch to codec reader!
					bLoop = false;
					break;
				}

				// Metadata tag ?
				w->pMeta = NULL;
				for (int i = 0; Metas[i].Tag.i; i++)
				{
					if (plevel->tag.i == Metas[i].Tag.i)
					{
						w->state = WAVE_STATE_WAIT_STRING;
						w->pMeta = &Metas[i];
						break;
					}
				}
				if (w->pMeta) continue;

				// Skip tag
				// Padding
				if (plevel->size & 1) plevel->size++;
				// Size to skip
				w->skipsize = plevel->size;
				pclevel->read += w->skipsize;
			}
			break;
			case WAVE_STATE_WAIT_FMT:
			{
				// Attempt to read the 'fmt ' header and standard fields
				if (!bs_peekBytes(&s->bitb, (uint8_t*) &w->fmt, sizeof(w->fmt)))
				{
					// Wait for buffer to be filled
					if (!s->inb->finishflag)
						*bContinue = true;

					bLoop = false;
					break;
				}

				// Some sanity checks
				if (   (w->fmt.fmtsize < 16)
				   || !w->fmt.channels
				   || !w->fmt.samplerate)
				{
					s->LastError.errnum = ErrNum_CodecError;
					snprintf(s->LastError.errmess, sizeof(s->LastError.errmess)
					      , "Corrupted RIFF WAVE fmt standard header");
					e = &s->LastError;
					bLoop = false;
					break;
				}

				// Skip header and decode extended hdr
				w->skipsize = sizeof(w->fmt);
				pclevel->read += w->skipsize;
				w->state = WAVE_STATE_WAIT_FMTEX;
				w->fmt.fmtsize -= 16;
				// Padding (keeps the extension size even from here on)
				if (w->fmt.fmtsize & 1) w->fmt.fmtsize++;
			}
			break;
			case WAVE_STATE_WAIT_FMTEX:
			{
				uint32_t val = 0;
				uint32_t ch_layout = 0;
				// Whole (padded) extension size, for the parent's accounting:
				// the fields read with bs_getInt below are consumed too
				const int32_t extsize = (int32_t) w->fmt.fmtsize;

				// Wait for the whole extension (compared in bytes so that a
				// huge size cannot overflow when converted to bits)
				if ((bs_bitcount(&s->bitb) / 8) < w->fmt.fmtsize)
				{
					// Wait for buffer to be filled
					if (!s->inb->finishflag)
						*bContinue = true;

					bLoop = false;
					break;
				}

				if (w->fmt.fmtsize > 0)
				{
					// Extra header size
					bs_getInt(&s->bitb, &val, 2);
					if ((val + 2) > w->fmt.fmtsize)
					{
						s->LastError.errnum = ErrNum_CodecError;
						snprintf(s->LastError.errmess, sizeof(s->LastError.errmess)
						      , "Corrupted RIFF WAVE fmt, invalid extra size");
						e = &s->LastError;
						bLoop = false;
						break;
					}
					w->fmt.fmtsize -= 2;
				}

				if (w->fmt.format == 0xFFFE)
				{
					if (w->fmt.fmtsize < 22)
					{
						s->LastError.errnum = ErrNum_CodecError;
						snprintf(s->LastError.errmess, sizeof(s->LastError.errmess)
						      , "Corrupted RIFF WAVE fmt, invalid extensible size");
						e = &s->LastError;
						bLoop = false;
						break;
					}

					// Bits per sample
					bs_getInt(&s->bitb, &val, 2);
					if (val != w->fmt.bitspersample)
					{
						s->LastError.errnum = ErrNum_CodecError;
						snprintf(s->LastError.errmess, sizeof(s->LastError.errmess)
						      , "Corrupted RIFF WAVE fmt, extensible bits per sample");
						e = &s->LastError;
						bLoop = false;
						break;
					}

					// Channel layout
					bs_getInt(&s->bitb, &ch_layout, 4);
					// Codec id
					bs_getInt(&s->bitb, &val, 4);

					if (val < 0x10000)
						w->fmt.format = val;

					w->fmt.fmtsize -= 10;
				}

				if ((w->fmt.format == 0x0055)
				||  (w->fmt.format == 0x0050))
				{
					// MPEG in WAVE header
					LogInfo("Found RIFF type WAVE format MPEG1");

					s->codec.prefn = &MP3_Codec;
				}
				else if (w->fmt.format == 0x0001)
				{
					// PCM
					pcm_params* params;

					LogInfo("Found RIFF type WAVE format PCM");

					s->codec.prefn = &PCM_Codec;

					e = IStream_Alloc(s, (void**) &params, sizeof(*params));
					if (e)
					{
						bLoop = false;
						break;
					}
					s->codec.params = params;

					if (w->fmt.bitspersample <= 8)
						params->flags = DiskSample_Desc_Unsigned;
					else
						params->flags = 0;
					params->channels = w->fmt.channels;
					params->channels_layout = ch_layout
					                        ? ch_layout
					                        : Channels_GetDefaultLayout(params->channels)
					                        ;
					params->samplerate = w->fmt.samplerate;
					params->bitspersample = w->fmt.bitspersample;
					// x bits per samples are stored like (n*8 bits per sample)
					if (params->bitspersample & 7)
					{
						params->bitspersample &= ~7;
						params->bitspersample += 8;
					}
					params->blocksize = w->fmt.blocksize;
					if (!params->blocksize)
						params->blocksize = params->channels * params->bitspersample >> 3;
					// A zero block size (0 bits per sample) cannot describe samples
					if (!params->blocksize
					||  (params->blocksize < params->channels * params->bitspersample >> 3))
					{
						s->LastError.errnum = ErrNum_CodecError;
						snprintf(s->LastError.errmess, sizeof(s->LastError.errmess)
						      , "Corrupted RIFF WAVE fmt, invalid blocksize");
						e = &s->LastError;
						bLoop = false;
						break;
					}
				}
				else
				{
					// Unsupported format
					s->LastError.errnum = ErrNum_UnsupportedFormat;
					snprintf(s->LastError.errmess, sizeof(s->LastError.errmess)
					      , "Unsupported RIFF WAVE format %04x"
					      , w->fmt.format);
					e = &s->LastError;
					bLoop = false;
					break;
				}

				// Skip rest of extended header (already even: padded in
				// WAIT_FMT) and check first tag
				w->state = WAVE_STATE_WAIT_TAG;
				w->skipsize = w->fmt.fmtsize;
				pclevel->read += extsize;
			}
			break;
			case WAVE_STATE_WAIT_STRING:
			{
				char* pString = NULL;

				// Wait for the whole string (compared in bytes so that a huge
				// size cannot overflow when converted to bits)
				if ((bs_bitcount(&s->bitb) / 8) < plevel->size)
				{
					// Wait for buffer to be filled
					if (!s->inb->finishflag)
						*bContinue = true;

					bLoop = false;
					break;
				}

				e = IStream_Alloc(s, (void**) &pString, plevel->size + 1);
				if (e == NULL)
				{
					bs_getBytes(&s->bitb, (uint8_t*) pString, plevel->size);
					pString[plevel->size] = '\0';
					if (w->pMeta->type == 0)
						e = IStream_SetText(s, 1, w->pMeta->MetaId, pString, plevel->size);
					else
					{
						// strtol (unlike atoi) is defined for out-of-range text;
						// clamp the result into an int
						long lval = strtol(pString, NULL, 10);
						int val = (lval > INT_MAX) ? INT_MAX
						        : (lval < INT_MIN) ? INT_MIN
						        : (int) lval;
						e = IStream_SetMetadata(s, 1, w->pMeta->MetaId, &val, sizeof(int));
					}
				}
				IStream_Free(s, (void**) &pString);
				if (e)
				{
					s->LastError = *e;
					bLoop = false;
					break;
				}
				// Padding?
				if (plevel->size & 1)
				{
					plevel->size += 1;
					w->skipsize = 1;
				}
				pclevel->read += plevel->size;
				w->state = WAVE_STATE_WAIT_TAG;
			}
			break;
		}
	}

	// Release our data if we won't come back here
	if (!*bContinue || e)
		IStream_Free(s, &s->codec.data);

	return e;
}
