#include "IStream.h"
#include "bitbuffer.h"
#include "Log.h"

#if defined(MAKELOG) && defined(MAKEABS)
/*
 * Debug-only consistency check. Recomputes how many bytes lie between the
 * read pointer (start) and bs->free (presumably the write position),
 * allowing for ring wrap-around by adding bs->size. Logs if bs->count does
 * not equal 8 * bytes - bitindex.
 */
static void bs_check(bytebuf* bs)
{
	int bytes = bs->free - bs->start;

	if (bytes < 0)
		bytes += bs->size;

	if ((bytes << 3) - bs->bitindex != bs->count)
		Log("Stream count corrupted %d != 8*%d - %d\n", bs->count, bytes, bs->bitindex);
}
#else
/* Compiled out unless both MAKELOG and MAKEABS are defined. */
#define bs_check(x)
#endif

/*
 * Read 1 bit from the bit stream, most significant bit of each byte first.
 * Returns 0 (and consumes nothing) when no bits are buffered.
 */
unsigned int bs_get1bit(bytebuf* bs)
{
	unsigned int shifted;

	if (!(bs->count > 0))
		return 0;

	/* Move the current bit up to bit 7 of the byte. */
	shifted = (unsigned int)(*bs->start) << bs->bitindex;
	--bs->count;

	/* Step to the next byte once all 8 bits are used, wrapping the ring. */
	if (++bs->bitindex > 7)
	{
		bs->bitindex = 0;
		if (bs->start < bs->last)
			++bs->start;
		else
			bs->start = bs->data;
	}

	bs_check(bs);

	return (shifted & 0x80u) ? 1u : 0u;
}

/*
 * Skip N bits of the bit stream.
 * Does nothing if N <= 0 or if fewer than N bits are buffered.
 */
void bs_skipbits(bytebuf* bs, int N)
{
	int j;
	int bytes;

	if ((bs->count < N) || (N <= 0))
		return;

	bs->count -= N;

	j = bs->bitindex + N;
	bytes = j >> 3;

	// Advance around the ring without first forming a pointer past the end
	// of the buffer. Pointer arithmetic outside the array is undefined in C,
	// even if the result is wrapped back afterwards.
	if (bytes > bs->last - bs->start)
		bs->start -= bs->size - bytes;
	else
		bs->start += bytes;

	bs->bitindex = j & 7;

	bs_check(bs);
}

/*
 * Read N bits (MSB first) from the bit stream and advance past them.
 *
 * N may be 1..32. Returns the bits right-aligned in an unsigned int. If N is
 * out of range or fewer than N bits are buffered, returns 0 and consumes
 * nothing. Assumes a 32-bit unsigned int, as the shift arithmetic below does.
 */
unsigned int bs_getbits(bytebuf* bs, int N)
{
	unsigned int val = 0;
	unsigned int high;
	int j;

	if ((N <= 0) || (N > 32) || (bs->count < N))
		return 0;

	if (N > 24)
	{
		// At an arbitrary bit offset the gathering loop below can only keep
		// 24 bits in 32 (otherwise the shift below goes negative), so split
		// wider reads: high N-24 bits first, then the low 24 bits. The high
		// part is read into its own variable to fix the evaluation order.
		high = bs_getbits(bs, N - 24);
		return (high << 24) | bs_getbits(bs, 24);
	}

	bs->count -= N;

	// Gather every byte the field touches. The final byte is read even when
	// the field ends exactly on its boundary; its bits are shifted out below.
	j = bs->bitindex + N;
	while (j > 7)
	{
		val += *bs->start;
		val <<= 8;
		bs->start = (bs->start < bs->last) ? bs->start + 1 : bs->data;
		j -= 8;
	}
	val += *bs->start;
	bs->bitindex = j;

	// Drop the already-consumed leading bits, then right-align the field.
	val <<= (24 - N + j);
	val >>= (32 - N);

	bs_check(bs);

	return val;
}

/*
 * Read a signed (two's complement) N-bit field from the bit stream, MSB
 * first, and return it sign-extended to int.
 *
 * N may be 1..32. If N is out of range or fewer than N bits are buffered,
 * returns 0 and consumes nothing.
 */
int bs_getsbits(bytebuf* bs, int N)
{
	unsigned int raw;
	unsigned int sign;

	// Check the whole request up front so that a split read below can never
	// consume only part of the field.
	if ((N <= 0) || (N > 32) || (bs->count < N))
		return 0;

	if (N > 24)
	{
		// bs_getbits() is only safe for up to 24 bits per call: read the high
		// N-24 bits, then the low 24 bits (separate statements fix the order).
		raw = bs_getbits(bs, N - 24) << 24;
		raw |= bs_getbits(bs, 24);
	}
	else
		raw = bs_getbits(bs, N);

	// Sign-extend using only unsigned arithmetic. Shifting a signed value
	// into or out of the sign bit is undefined or implementation-defined in
	// C. For a negative field, ~raw & (sign - 1) is the magnitude minus one,
	// which always fits in an int (even for N == 32).
	sign = 1u << (N - 1);
	if (raw & sign)
		return -(int)(~raw & (sign - 1u)) - 1;

	return (int)raw;
}

/*
 * Return the next N bits (MSB first) without moving the read position.
 * N must be 1..24. Returns 0 if N <= 0 or fewer than N bits are buffered.
 */
unsigned int bs_peekbits(const bytebuf* bs, int N)
{
	const byte* cursor = bs->start;
	unsigned int acc = 0;
	int span;

	if (N <= 0 || N > bs->count)
		return 0;

	/* Collect every byte the field spans, following the ring wrap. */
	for (span = bs->bitindex + N; span >= 8; span -= 8)
	{
		acc = (acc + *cursor) << 8;
		if (cursor < bs->last)
			cursor++;
		else
			cursor = bs->data;
	}
	acc += *cursor;

	/* Discard the bits before the read position, then right-align. */
	acc <<= 24 - N + span;
	return acc >> (32 - N);
}

/*
 * Returns the number of unread BITS in the buffer (not bytes): the value of
 * bs->count, which every reader in this file decrements per bit consumed.
 */
unsigned int bs_bitcount(const bytebuf* bs)
{
	return bs->count;
}

/*
 * Move the read position to the next byte boundary, discarding the unread
 * bits of a partially consumed byte. Does nothing when already aligned.
 */
void bs_align(bytebuf* bs)
{
	int dropped;

	if (bs->bitindex == 0)
		return;

	dropped = 8 - bs->bitindex;
	bs->bitindex = 0;
	bs->count -= dropped;

	if (bs->start < bs->last)
		bs->start++;
	else
		bs->start = bs->data;
}

/*
 * Skip N bytes in the buffer, after first discarding any partially consumed
 * byte. If fewer than N whole bytes are buffered, all of them are skipped.
 * Returns the number of bytes still missing to complete the request
 * (0 if it was fully satisfied).
 */
unsigned int bs_skipBytes(bytebuf* bs, unsigned int N)
{
	unsigned int count;

	// byte align
	bs_align(bs);

	// whole bytes currently in the buffer
	count = bs->count >> 3;

	if (count <= N)
		N -= count;
	else
	{
		count = N;
		N = 0;
	}

	bs->count -= (count << 3);

	// Discard bytes, wrapping around the ring without first forming a pointer
	// past the end of the buffer (undefined behaviour in C).
	if (count > (unsigned int)(bs->last - bs->start))
		bs->start -= bs->size - count;
	else
		bs->start += count;

	bs_check(bs);

	return N;
}

/*
 * Copy the next N bytes of the stream into p without consuming them.
 *
 * Returns false, copying nothing, if N is negative or fewer than N whole
 * bytes are buffered. The copy starts at the current byte pointer and
 * ignores bs->bitindex, so callers presumably align the stream first.
 */
bool bs_peekBytes(const bytebuf* bs, uint8_t* p, int N)
{
	const uint8_t* ps;
	uint8_t* ph;

	// Reject negative sizes: N << 3 would be undefined and the copy loop
	// would run away. Compare in bytes so that N * 8 cannot overflow. For a
	// non-negative count this equals the test count < N * 8.
	if ((N < 0) || (bs->count / 8 < N))
		return false;

	for (ph = p, ps = bs->start; N; N--)
	{
		*ph++ = *ps++;
		if (ps > bs->last) ps = bs->data;
	}

	return true;
}

/*
 * Copy the next N bytes of the stream into p and consume them.
 *
 * Returns false, consuming nothing, if N is negative or fewer than N whole
 * bytes are buffered. The copy starts at the current byte pointer and
 * ignores bs->bitindex, so callers presumably align the stream first.
 */
bool bs_getBytes(bytebuf* bs, uint8_t* p, int N)
{
	uint8_t* ph;

	// Reject negative sizes: N << 3 would be undefined and the copy loop
	// would run away. Compare in bytes so that N * 8 cannot overflow. For a
	// non-negative count this equals the test count < N * 8.
	if ((N < 0) || (bs->count / 8 < N))
		return false;

	bs->count -= N << 3;

	for (ph = p; N; N--)
	{
		*ph++ = *bs->start++;
		if (bs->start > bs->last) bs->start = bs->data;
	}

	bs_check(bs);

	return true;
}

/*
 * Read an N-byte unsigned integer into *p. The first byte read is the least
 * significant (little-endian).
 *
 * N may be 0..4. Returns false, leaving *p and the stream untouched, if any
 * of these hold:
 *  - N is out of range;
 *  - fewer than N bytes are buffered;
 *  - bs->count is not a multiple of 8 (presumably a partial byte is pending).
 */
bool bs_getInt(bytebuf* bs, uint32_t* p, int N)
{
	uint32_t val = 0;
	int nbits;
	int shift;

	// A uint32_t holds at most 4 bytes. A larger N would shift by >= 32 bits,
	// and a negative N would make N << 3 undefined.
	if ((N < 0) || (N > 4))
		return false;

	nbits = N << 3;

	// Not yet that many bytes in buffer, or not on a byte boundary.
	if ((bs->count < nbits) || (bs->count & 7))
		return false;

	bs->count -= nbits;

	for (shift = 0; shift < nbits; shift += 8)
	{
		// Widen before shifting: the byte promotes to int, and a value such
		// as 0x80 << 24 does not fit in a 32-bit int (undefined behaviour).
		val |= (uint32_t)*bs->start++ << shift;
		if (bs->start > bs->last) bs->start = bs->data;
	}

	bs_check(bs);

	*p = val;

	return true;
}

/*
 * Read a unary-coded value: count the run of identical bits at the current
 * position, then consume the run plus its single terminating bit.
 *
 *   b == true : count consecutive 1 bits, terminated by a 0 bit.
 *   b == false: count consecutive 0 bits, terminated by a 1 bit.
 *
 * Returns the run length, or 0 without consuming anything if no bits are
 * buffered.
 *
 * NOTE(review): the scan reads whole bytes, so it can inspect bits beyond
 * bs->count. bs_skipbits() also does nothing when run + terminator exceeds
 * the buffered bits. In that case the returned length is not consumed.
 */
int bs_getUnary(bytebuf* bs, bool b)
{
	// XOR mask that turns the bits being counted into zeros, so that in both
	// modes the terminator is the first 1 bit.
	const int flip = b ? 0xff : 0x00;
	int i = 0;
	int val;
	int index = bs->bitindex;
	uint8_t* start = bs->start;

	if (bs->count <= 0)
		return 0;

	while (bs->count > i)
	{
		// Move the current bit to bit 7. The XOR happens before the shift so
		// the operand stays non-negative; shifting ~byte, a negative int, is
		// undefined behaviour. Zeros shifted in at the bottom never act as a
		// terminator.
		val = ((*start ^ flip) << index) & 0xff;

		if (val == 0)
		{
			// No terminator in the rest of this byte.
			i += 8 - index;
		}
		else
		{
			// Count the leading zeros of the 8-bit value by binary search.
			if (!(val & 0xf0))
			{
				i += 4;
				val <<= 4;
			}
			if (!(val & 0xc0))
			{
				i += 2;
				val <<= 2;
			}
			if (!(val & 0x80))
				i += 1;

			break;
		}

		start = (start == bs->last) ? bs->data : start + 1;
		index = 0;
	}

	// Consume the run plus the terminating bit.
	bs_skipbits(bs, i + 1);

	return i;
}
