#include "buffer.h"
#include "Log.h"

/*
 * Start a bit accumulator from the byte at the read position.
 *
 * The bits of the current byte from bs->bitindex onwards are placed at the
 * top of *pval, *pnum receives how many bits that is (8 - bitindex), and the
 * read position moves to the start of the next byte, wrapping from last back
 * to data. bs->count is reduced by *pnum without being checked first.
 */
void bs_buildmask(bytebuf* bs, unsigned int* pval, int* pnum)
{
	// Read the byte as unsigned char: the buffer is supplied as plain char,
	// and a promoted (possibly negative) int must not be shifted into or
	// past the sign bit.
	unsigned int byte = (unsigned char)*bs->start;

	*pval = byte << (24 + bs->bitindex);
	*pnum = 8 - bs->bitindex;
	bs->count -= *pnum;
	bs->start = (bs->start < bs->last) ? bs->start + 1 : bs->data;
	bs->bitindex = 0;
}

/*
 * Top up a bit accumulator until it holds at least 24 bits.
 *
 * *pval holds *pnum valid bits at its most-significant end. Whole bytes are
 * taken from the read position (wrapping from last back to data) and placed
 * directly below the bits already held; each byte adds 8 to the held count
 * and removes 8 from bs->count. bs->count is not checked here.
 */
void bs_refreshmask(bytebuf* bs, unsigned int* pval, int* pnum)
{
	unsigned int acc = *pval;
	int held = *pnum;

	while (held < 24)
	{
		// Read as unsigned char: a sign-extended byte would add one-bits
		// on top of the bits that are already in the accumulator.
		unsigned int byte = (unsigned char)*bs->start;

		acc += byte << (24 - held);
		bs->start = (bs->start < bs->last) ? bs->start + 1 : bs->data;
		held += 8;
		bs->count -= 8;
	}

	*pnum = held;
	*pval = acc;
}

/*
 * Set up bs as an empty ring buffer over the size bytes starting at data.
 *
 * The read position (start) and the write position (free) both begin at
 * data, last points at the final byte of the region, no bits of the current
 * byte are consumed and the bit count is zero. The memory is supplied by the
 * caller; nothing is allocated or cleared here.
 */
void bs_init(bytebuf* bs, char* data, int size)
{
	// extent of the storage
	bs->data = data;
	bs->size = size;
	bs->last = data + size - 1;

	// read and write positions both at the beginning
	bs->start = data;
	bs->free = data;

	// nothing buffered yet
	bs->bitindex = 0;
	bs->count = 0;
}

/*
 * Recompute bs->count (in bits) from the buffer pointers: the byte distance
 * from start to free, wrapped by size when free lies behind start, times
 * eight, minus the bits already consumed from the current byte.
 */
void bs_fixcount(bytebuf* bs)
{
	int bytes = bs->free - bs->start;

	if (bytes < 0)
		bytes += bs->size;

	bs->count = (bytes << 3) - bs->bitindex;
}

/*
 * bs_check: consistency check of bs->count against the value derived from
 * the buffer pointers. It is compiled in only when both MAKELOG and MAKEABS
 * are defined; in every other configuration it expands to nothing.
 */
#if defined(MAKELOG) && defined(MAKEABS)
static void bs_check(bytebuf* bs)
{
	int bytes = bs->free - bs->start;

	if (bytes < 0)
		bytes += bs->size;

	// bits between start and free, less those consumed from the current byte
	if (bs->count != (bytes << 3) - bs->bitindex)
		Log("Stream count corrupted\n");
}
#else
#define bs_check(x)
#endif

/*
 * Read 1 bit from the bit stream.
 *
 * Returns the next bit (0 or 1), most-significant bit of each byte first,
 * and advances the read position by one bit; after the eighth bit of a byte
 * the position moves to the next byte, wrapping from last back to data.
 * If bs->count is zero or negative a message is logged, nothing is consumed
 * and 0 is returned.
 */
unsigned int bs_get1bit(bytebuf* bs)
{
	unsigned int byte;
	unsigned int bit;

	if (bs->count <= 0)
	{
		Log("Attempt to read from empty stream\n");

		return 0;
	}
	bs->count--;

	// Read as unsigned char so that no negative promoted value is shifted.
	byte = (unsigned char)*bs->start;
	bit = byte << bs->bitindex;

	bs->bitindex++;
	if (bs->bitindex > 7)
	{
		bs->start = (bs->start < bs->last) ? bs->start + 1 : bs->data;
		bs->bitindex = 0;
	}

	bs_check(bs);

	// the wanted bit now sits at bit 7 of the shifted byte
	return (bit >> 7) & 1;
}

/*
 * Read max 8 bits from the bit stream.
 *
 * Returns the next N bits, most-significant first, in the low N bits of the
 * result and advances the read position by N bits. At most two bytes are
 * combined, so bitindex + N must not exceed 15; N <= 8 satisfies this at any
 * bit offset. When the read ends exactly on a byte boundary the following
 * byte is fetched as well, but none of its bits reach the result.
 *
 * If fewer than N bits are available a message is logged, nothing is
 * consumed and 0 is returned. N <= 0 returns 0 without touching the stream.
 */
unsigned int bs_getbits8(bytebuf* bs, int N)
{
	unsigned int	val = 0;

	if (bs->count < N)
	{
		Log("Attempt to read from empty stream\n");

		return 0;
	}

	// Nothing to extract; this also keeps the final right shift below the
	// width of val (32 - N would otherwise be 32 or more).
	if (N <= 0)
		return 0;

	bs->count -= N;

	bs->bitindex += N;
	if (bs->bitindex > 7)
	{
		// The read crosses into the next byte: keep the current byte in
		// bits 8..15. Bytes are read as unsigned char so that a value
		// >= 0x80 is not sign-extended.
		val = (unsigned int)(unsigned char)*bs->start << 8;
		bs->start = (bs->start < bs->last) ? bs->start + 1 : bs->data;
	}
	val += (unsigned char)*bs->start;
	bs->bitindex &= 7;

	// Move the first wanted bit to bit 31, then drop everything below the
	// N wanted bits.
	val <<= (24 - N + bs->bitindex);
	val >>= (32 - N);

	bs_check(bs);

	return val;
}

/*
 * Read N bits from the bit stream.
 *
 * Returns the next N bits, most-significant first, in the low N bits of the
 * result and advances the read position by N bits. The bytes touched are
 * gathered in one 32-bit accumulator, so bitindex + N must not exceed 31;
 * N <= 24 satisfies this at any bit offset.
 *
 * If fewer than N bits are available a message is logged, nothing is
 * consumed and 0 is returned. N <= 0 returns 0 without touching the stream.
 */
unsigned int bs_getbits(bytebuf* bs, int N)
{
	unsigned int	val = 0;
	int		j;

	if (bs->count < N)
	{
		Log("Attempt to read from empty stream\n");

		return 0;
	}

	// Nothing to extract; this also keeps the final right shift below the
	// width of val (32 - N would otherwise be 32 or more).
	if (N <= 0)
		return 0;

	bs->count -= N;

	// j counts the bits from the start of the current byte to the end of
	// the read; every full 8 of them means one more byte is consumed.
	j = bs->bitindex + N;
	while(j > 7)
	{
		// Bytes are read as unsigned char so that a value >= 0x80 is not
		// sign-extended into the bytes gathered so far.
		val += (unsigned char)*bs->start;
		val <<= 8;
		bs->start = (bs->start < bs->last) ? bs->start + 1 : bs->data;
		j -= 8;
	}
	val += (unsigned char)*bs->start;
	bs->bitindex = j;

	// Move the first wanted bit to bit 31, then drop everything below the
	// N wanted bits.
	val <<= (24 - N + j);
	val >>= (32 - N);

	bs_check(bs);

	return val;
}

/*
 * Move N bytes from the read position of bs to the write position of hbuf.
 *
 * Both positions wrap from last back to data. bs->count shrinks and
 * hbuf->count grows by N * 8 bits. If bs holds fewer than N bytes a message
 * is logged and nothing is transferred.
 *
 * Only when MAKELOG is defined: the transfer is also refused (with a log
 * message) when it would fill hbuf, and a message is logged, without
 * stopping the transfer, if either buffer is not byte aligned.
 */
void bs_transferBytes(bytebuf* hbuf, bytebuf* bs, int N)
{
	const int nbits = N << 3;
	char *src;
	char *dst;

	if (bs->count < nbits)
	{
		Log("Attempt to read from empty buffer\n");

		return;
	}

#ifdef MAKELOG
	if ((N + (hbuf->count >> 3)) >= hbuf->size)
	{
		Log("Attempt to write in full buffer\n");
		return;
	}

	if (bs->bitindex || hbuf->bitindex)
		Log("Attempt to use byte on non aligned buffer\n");
#endif
	bs->count -= nbits;
	hbuf->count += nbits;

	src = bs->start;
	dst = hbuf->free;
	while (N != 0)
	{
		*dst = *src;

		// step both cursors, wrapping at the end of their own buffer
		src = (src < bs->last) ? src + 1 : bs->data;
		dst = (dst < hbuf->last) ? dst + 1 : hbuf->data;
		N--;
	}
	hbuf->free = dst;
	bs->start = src;

	bs_check(bs);
	bs_check(hbuf);
}

/*
 * Move the read position of hbuf by N bits.
 *
 * N >= 0 skips forward: if fewer than N bits are buffered a message is
 * logged and nothing changes; otherwise the position advances and
 * hbuf->count shrinks by N.
 * N < 0 rewinds by -N bits: the position moves backwards, wrapping below
 * data by adding size, and hbuf->count grows by -N. The rewind distance is
 * not checked against anything.
 */
void bs_adjustNbits(bytebuf* hbuf, int N)
{
	if (N >= 0)
	{
		if (hbuf->count < N)
		{
			Log("Attempt to skip from empty buffer\n");

			return;
		}

		// whole bytes go to the pointer, the remainder stays in bitindex
		hbuf->bitindex += N;
		hbuf->start += (hbuf->bitindex >> 3);
		if (hbuf->start > hbuf->last)
			hbuf->start -= hbuf->size;
		hbuf->bitindex &= 7;
	}
	else
	{
		hbuf->bitindex += N;

		// borrow one byte at a time until the bit index is valid again
		while (hbuf->bitindex < 0)
		{
			hbuf->start--;
			hbuf->bitindex += 8;
		}
		if (hbuf->start < hbuf->data)
			hbuf->start += hbuf->size;
	}

	// N is negative when rewinding, so this adds the rewound bits back
	hbuf->count -= N;

	bs_check(hbuf);
}

/*
 * Keep N bytes in buffer.
 *
 * The read position is first byte aligned by reading the rest of the current
 * byte (the value read is discarded). Then:
 *  - if more than N bytes are buffered, the surplus is dropped from the read
 *    side and N is returned;
 *  - otherwise zero bytes are appended at the write position until N bytes
 *    are accounted for, and the number of bytes that were buffered before
 *    the padding is returned. A shortfall is logged.
 * In both cases hbuf->count ends up as N * 8 bits.
 */
int bs_keepBytes(bytebuf* hbuf, int N)
{
	int avail;
	int excess;

	// byte align
	if (hbuf->bitindex)
		bs_getbits8(hbuf, 8 - hbuf->bitindex);

	// whole bytes buffered, and how many of them are beyond the N to keep
	avail = hbuf->count >> 3;
	excess = avail - N;

	if (excess > 0)
	{
		hbuf->count = N << 3;

		// discard the surplus from the read side
		hbuf->start += excess;
		if (hbuf->start > hbuf->last)
			hbuf->start -= hbuf->size;

		bs_check(hbuf);

		return N;
	}

	if (excess < 0)
		Log("Missing %d bytes in reservoir\n", -excess);

	// Pad with zero bytes up to N. Note carried over from the original
	// author: filling the reservoir is a bad idea, muting the sound would
	// be better.
	while (excess != 0)
	{
		*hbuf->free = 0;
		hbuf->free = (hbuf->free < hbuf->last) ? hbuf->free + 1 : hbuf->data;
		excess++;
	}

	hbuf->count = N << 3;

	bs_check(hbuf);

	return avail;
}

/*
 * Skip N bytes in buffer.
 *
 * The read position is first byte aligned by reading the rest of the current
 * byte (the value read is discarded). If fewer than N whole bytes are
 * buffered a message is logged and only the buffered bytes are skipped.
 * The read position wraps past last by subtracting size.
 */
void bs_skipBytes(bytebuf* hbuf, unsigned int N)
{
	int avail;

	// byte align
	if (hbuf->bitindex)
		bs_getbits8(hbuf, 8 - hbuf->bitindex);

	// count bytes in buffer
	avail = hbuf->count >> 3;

	// bs_buildmask and bs_refreshmask lower the count without checking it,
	// so it can be negative here. Compared with the unsigned N, a negative
	// value would turn into a huge one and the clamp below would never
	// trigger; treat it as an empty buffer instead.
	if (avail < 0)
		avail = 0;

	if (N > (unsigned int)avail)
	{
		Log("Attempt to skip Bytes from empty buffer\n");
		N = (unsigned int)avail;
	}
	hbuf->count -= (int)(N << 3);

	// discard bytes
	hbuf->start += N;
	if (hbuf->start > hbuf->last)
		hbuf->start -= hbuf->size;

	bs_check(hbuf);
}

/*
 * Number of whole bytes accounted for by bs->count (the bit count divided
 * by eight). The stream is not modified.
 */
unsigned int bs_bytecount(const bytebuf* bs)
{
	unsigned int bytes = bs->count >> 3;

	return bytes;
}

/*
 * Copy the next N bytes at the read position of bs into p without consuming
 * them; the stream itself is left untouched. The source wraps from last
 * back to data. The bit index of bs is not taken into account.
 *
 * Returns 1 when the bytes were copied, 0 when bs holds fewer than N * 8
 * bits (in which case p is not written).
 */
int bs_peekBytes(const bytebuf* bs, char* p, int N)
{
	char *src = bs->start;
	char *dst = p;

	// Not yet that many bytes in buffer
	if (bs->count < (N << 3))
		return 0;

	while (N != 0)
	{
		*dst++ = *src;
		src = (src < bs->last) ? src + 1 : bs->data;
		N--;
	}

	return 1;
}
