#include "ka_scalers.h"
#include "ka_mem.h"
#include "ka_log.h"
#include <stdint.h>
#include <string.h>

/*
 * Merge two word-aligned lines whose components are 8 bits wide (TRGB32,
 * TBGR32, or 8-bit mono packed four pixels per word) using a linear
 * interpolation factor in [0, 255].
 */

/*
 * Line merger signature: for each of the `width` words, every 8-bit component
 * of pdst[i] becomes (line1[i] * sc + line0[i] * (256 - sc)) >> 8, as done by
 * the C implementation merge_lines32_8 in this file.
 */
typedef void (*FNmerge_lines32_8)(uint32_t* pdst, const uint32_t* line0, const uint32_t* line1, uint32_t sc, int width);
/*
 * NEON variant, defined elsewhere. It is selected instead of merge_lines32_8 when
 * scale->hardware has ka_hardware_neon set.
 * NOTE(review): it is assumed to produce the same results as the C version — confirm.
 */
extern void merge_lines32_8_neon(uint32_t* pdst, const uint32_t* line0, const uint32_t* line1, uint32_t sc, int width);

/*
 * Blend a single 32-bit pixel: every 8-bit component becomes
 * (c1 * sc + c0 * sc2) >> 8, where sc2 is normally 256 - sc.
 * The word is processed as two interleaved lanes (bytes 1/3 and bytes 0/2).
 * Each component then owns 16 bits during the multiply, so with sc in [0, 256]
 * the products of neighbouring components cannot overlap.
 */
static inline uint32_t merge_pixel32_8(uint32_t p0, uint32_t p1, uint32_t sc, uint32_t sc2)
{
	const uint32_t lanes = 0x00ff00ffu;
	const uint32_t hi = ((p1 >> 8) & lanes) * sc + ((p0 >> 8) & lanes) * sc2;
	const uint32_t lo = (p1 & lanes) * sc + (p0 & lanes) * sc2;

	/* After the implicit "* 256", hi already sits in bytes 1/3; lo must be shifted down. */
	return (hi & ~lanes) | ((lo >> 8) & lanes);
}

/*
 * Blend `width` words of line1 into line0 with weight sc / 256 (line0 gets
 * (256 - sc) / 256) and store the result in pdst. This works for any layout
 * with four 8-bit components per word.
 */
static void merge_lines32_8(uint32_t* pdst, const uint32_t* line0, const uint32_t* line1, uint32_t sc, int width)
{
	const uint32_t sc2 = 256 - sc;
	int left = width;

	/* First peel off width % 4 pixels one at a time ... */
	for (; left & 3; left--)
		*pdst++ = merge_pixel32_8(*line0++, *line1++, sc, sc2);

	/* ... then do the rest in groups of four. */
	for (; left > 0; left -= 4)
	{
		pdst[0] = merge_pixel32_8(line0[0], line1[0], sc, sc2);
		pdst[1] = merge_pixel32_8(line0[1], line1[1], sc, sc2);
		pdst[2] = merge_pixel32_8(line0[2], line1[2], sc, sc2);
		pdst[3] = merge_pixel32_8(line0[3], line1[3], sc, sc2);
		pdst += 4;
		line0 += 4;
		line1 += 4;
	}
}

/*
 * Scale a line of 32-bit pixels (word aligned), 8 bits per component.
 * Use linear interpolation when zooming in and no interpolation when zooming out.
 */
/*
 * Line scaler signature: write `width` pixels to pdst, sampling psrc starting
 * at the 16.16 fixed-point position x0 and advancing x_mag per output pixel.
 */
typedef void (*FNscale_line32_8)(uint32_t* pdst, uint32_t width, const uint32_t* psrc, int x0, int x_mag);
/*
 * NEON variant of scaleup_line32_8, defined elsewhere. ka_select_scale_line32_8
 * picks it when zooming in and ka_hardware_neon is set.
 * NOTE(review): it is assumed to match the C version's output — confirm.
 */
extern void scaleup_line32_8_neon(uint32_t* pdst, uint32_t width, const uint32_t* psrc, int x0, int x_mag);

/*
 * State of the horizontal interpolator used by scaleup_line32_8. It holds the
 * two source pixels being blended, each pre-split into its byte 1/3 lane
 * ("hi") and byte 0/2 lane ("lo"), plus the 16.16 fractional position.
 */
struct scaleup32_8_state
{
	const uint32_t* src;   /* next source pixel to fetch */
	int frac;              /* 16.16 position between the left and right pixel */
	int step;              /* x_mag: source advance per destination pixel */
	uint32_t left_hi, left_lo;
	uint32_t right_hi, right_lo;
};

/* Produce one destination pixel, then advance the interpolator by one step. */
static inline uint32_t scaleup32_8_next(struct scaleup32_8_state* st)
{
	const uint32_t lanes = 0xff00ff;
	const uint32_t w = st->frac >> 8;
	const uint32_t hi = st->right_hi * w + st->left_hi * (256 - w);
	const uint32_t lo = st->right_lo * w + st->left_lo * (256 - w);
	const uint32_t pixel = (hi & ~lanes) | ((lo >> 8) & lanes);

	st->frac += st->step;
	if (st->frac >= (1 << 16))
	{
		/* Crossed a pixel boundary: the right pixel becomes the left one, and a new right pixel is fetched. */
		const uint32_t next = *st->src++;
		st->frac &= 0xffff;
		st->left_hi = st->right_hi;
		st->left_lo = st->right_lo;
		st->right_hi = (next >> 8) & lanes;
		st->right_lo = next & lanes;
	}
	return pixel;
}

/*
 * Enlarge one line of 32-bit pixels (8 bits per component) with linear
 * interpolation.
 *  - x0 is the 16.16 source position of the first output pixel.
 *  - x_mag is the 16.16 source step per output pixel; the selector only uses
 *    this routine for x_mag < 1.0.
 * Both taps start on source pixel x0 >> 16, and the right tap only advances
 * when the position crosses a pixel boundary. An output at source position p
 * therefore blends src[floor(p) - 1] and src[floor(p)], with src[0] standing
 * in for src[-1].
 */
static void scaleup_line32_8(uint32_t* pdst, uint32_t width, const uint32_t* psrc, int x0, int x_mag)
{
	struct scaleup32_8_state st;
	int n = width;
	uint32_t first;

	psrc += x0 >> 16;
	first = *psrc++;

	st.src = psrc;
	st.frac = x0 & 0xffff;
	st.step = x_mag;
	st.left_hi = st.right_hi = (first >> 8) & 0xff00ff;
	st.left_lo = st.right_lo = first & 0xff00ff;

	/* Emit width % 4 pixels singly, then the remainder in groups of four. */
	while (n & 3)
	{
		*pdst++ = scaleup32_8_next(&st);
		n--;
	}
	while (n > 0)
	{
		pdst[0] = scaleup32_8_next(&st);
		pdst[1] = scaleup32_8_next(&st);
		pdst[2] = scaleup32_8_next(&st);
		pdst[3] = scaleup32_8_next(&st);
		pdst += 4;
		n -= 4;
	}
}

/*
 * 1:1 horizontal "scale": copy `width` pixels starting at source pixel x0 >> 16.
 * ka_select_scale_line32_8 only picks this routine when x_mag == 1 << 16, so
 * x_mag and the fractional part of x0 are not needed.
 */
static void copy_line32_8(uint32_t* pdst, uint32_t width, const uint32_t* psrc, int x0, int x_mag)
{
	/* Idiomatic "unused parameter"; the former self-assignment trips -Wself-assign. */
	(void) x_mag;
	/* Compute the byte count in size_t so it cannot wrap in 32-bit arithmetic. */
	memcpy(pdst, psrc + (x0 >> 16), width * sizeof *pdst);
}

/*
 * Shrink one line of 32-bit pixels by nearest-neighbour sampling.
 * Each output pixel is psrc[pos >> 16], where pos starts at x0 (16.16 fixed
 * point) and advances by x_mag per output pixel. No interpolation is done.
 */
static void scaledown_line32_8(uint32_t* pdst, uint32_t width, const uint32_t* psrc, int x0, int x_mag)
{
	int todo = width;

	/* Single pixels until the remaining count is a multiple of four. */
	while (todo & 3)
	{
		*pdst++ = psrc[x0 >> 16];
		x0 += x_mag;
		todo--;
	}

	while (todo > 0)
	{
		pdst[0] = psrc[x0 >> 16];
		x0 += x_mag;
		pdst[1] = psrc[x0 >> 16];
		x0 += x_mag;
		pdst[2] = psrc[x0 >> 16];
		x0 += x_mag;
		pdst[3] = psrc[x0 >> 16];
		x0 += x_mag;
		pdst += 4;
		todo -= 4;
	}
}

/*
 * Pick the horizontal line scaler for scale->x_mag (16.16 source step per
 * destination pixel):
 *  - above 1.0 (shrinking): nearest neighbour;
 *  - exactly 1.0: plain copy;
 *  - below 1.0 (enlarging): linear interpolation, using the NEON variant when
 *    scale->hardware reports ka_hardware_neon.
 */
static inline FNscale_line32_8 ka_select_scale_line32_8(ka_scale_t* scale)
{
	if (scale->x_mag > (1 << 16))
		return scaledown_line32_8;
	if (scale->x_mag == (1 << 16))
		return copy_line32_8;
	if (scale->hardware & ka_hardware_neon)
		return scaleup_line32_8_neon;
	return scaleup_line32_8;
}

/*
 * Scale an image of 32-bit pixels (word aligned), 8 bits per component.
 * Rows are scaled horizontally by the routine from ka_select_scale_line32_8:
 * linear interpolation when zooming in, a plain copy at 1:1, and the nearest
 * pixel when zooming out. Vertically the nearest (floor) source row is used,
 * without interpolation.
 * Positions are 16.16 fixed point (source coordinate = mag * destination coordinate).
 */
void ka_scale_linear_32bpp(ka_scale_t* scale)
{
	FNscale_line32_8 fscale_line32_8 = ka_select_scale_line32_8(scale);
	uint32_t* pdst = (uint32_t*) scale->wdst;
	const int sx0 = scale->x_mag * scale->wdst_x0;
	int sy0 = scale->y_mag * scale->wdst_y0;

	if (scale->y_mag < (1 << 16))
	{
		// Zooming in vertically: consecutive destination rows often share a source
		// row, so scale each source row once into the scratch line (scale->wb_ptr)
		// and copy it out. Assume the extra memory is compensated by less computation.
		int bsize = 4 * scale->wdst_width;
		if (scale->wb_size < bsize)
		{
			void* ptr = ka_mem_alloc(bsize);
			if (!ptr)
				goto fallback; // no scratch line: scale every row directly

			if (scale->wb_ptr) ka_mem_free(scale->wb_ptr);
			scale->wb_ptr = ptr;
			scale->wb_size = bsize;
		}

		int sy16 = sy0 & 0xffff; // fractional part of the source row position
		sy0 >>= 16;              // current source row
		const uint32_t* psrc = ((const uint32_t*) scale->src) + scale->src_pix_width * sy0;
		uint32_t* line = scale->wb_ptr;
		fscale_line32_8(line, scale->wdst_width, psrc, sx0, scale->x_mag);

		for (int y = scale->wdst_height; y > 0; y--)
		{
			memcpy(pdst, line, 4 * scale->wdst_width);

			sy16 += scale->y_mag;
			if (sy16 >= (1<<16))
			{
				sy0 += 1;
				sy16 &= 0xffff;
				// Rows 0 .. src_height - 1 are all valid, including the last one,
				// which the fallback path below also reaches. (Comparing against
				// src_height - 1 never loaded that row.) Past the end, keep
				// repeating the line already scaled.
				if (sy0 < scale->src_height)
				{
					psrc += scale->src_pix_width;
					fscale_line32_8(line, scale->wdst_width, psrc, sx0, scale->x_mag);
				}
			}

			pdst += (scale->wdst_bpr / 4);
		}
		return;
	}

fallback:
	// Scale every destination row directly from its nearest source row.
	for (int y = scale->wdst_height; y > 0; y--)
	{
		const uint32_t* psrc = ((const uint32_t*) scale->src) + scale->src_pix_width * (sy0 >> 16);

		fscale_line32_8(pdst, scale->wdst_width, psrc, sx0, scale->x_mag);

		sy0 += scale->y_mag;
		pdst += (scale->wdst_bpr / 4);
	}
}

/*
 * Scale an image of 32-bit pixels (word aligned), 8 bits per component.
 * Zooming in vertically (y_mag < 1.0 in 16.16): each destination row is made
 * in two steps. First, two neighbouring source rows are blended into one
 * scratch line of source width. Then that line is scaled horizontally with the
 * routine from ka_select_scale_line32_8. Blending vertically first needs only
 * that single source-width scratch line.
 * Otherwise each destination row is the horizontally scaled nearest source row.
 * If the scratch line cannot be allocated, ka_scale_linear_32bpp is used instead.
 */
void ka_scale_bilinear_32bpp(ka_scale_t* scale)
{
	FNscale_line32_8 scale_row = ka_select_scale_line32_8(scale);
	uint32_t* dst_row = (uint32_t*) scale->wdst;
	const int sx0 = scale->x_mag * scale->wdst_x0;
	int sy = scale->y_mag * scale->wdst_y0;
	int y;

	if (scale->y_mag >= (1 << 16))
	{
		/* Shrinking or 1:1 vertically: use the nearest source row, no blending. */
		for (y = scale->wdst_height; y > 0; y--)
		{
			const uint32_t* src_row = ((const uint32_t*) scale->src) + scale->src_pix_width * (sy >> 16);
			scale_row(dst_row, scale->wdst_width, src_row, sx0, scale->x_mag);

			sy += scale->y_mag;
			dst_row += (scale->wdst_bpr / 4);
		}
		return;
	}

	{
		FNmerge_lines32_8 blend_rows = (scale->hardware & ka_hardware_neon) ? merge_lines32_8_neon : merge_lines32_8;
		/*
		 * The scratch line is rounded up to an even number of pixels, and blend_rows
		 * is given this rounded width.
		 * NOTE(review): this is presumably for the NEON merger. With an odd
		 * src_pix_width, each blend reads one pixel past the source row, which is
		 * past the image on the last row — TODO confirm.
		 */
		const int line_width = scale->src_pix_width + ((scale->src_pix_width & 1) ? 1 : 0);
		const int bytes = 4 * line_width;
		const uint32_t* const image = (const uint32_t*) scale->src;
		const uint32_t* upper;
		const uint32_t* lower;
		const uint32_t* last;
		uint32_t* line;

		if (scale->wb_size < bytes)
		{
			void* mem = ka_mem_alloc(bytes);
			if (!mem)
			{
				/* Out of memory: degrade to horizontal-only interpolation. */
				ka_scale_linear_32bpp(scale);
				return;
			}
			if (scale->wb_ptr)
				ka_mem_free(scale->wb_ptr);
			scale->wb_ptr = mem;
			scale->wb_size = bytes;
		}

		/* upper = floor row, lower = next row (clamped to the last image row). */
		upper = image + scale->src_pix_width * (sy >> 16);
		lower = upper + scale->src_pix_width;
		last = image + scale->src_pix_width * (scale->src_height - 1);
		line = scale->wb_ptr;
		if (lower > last)
			lower = last;
		sy &= 0xffff;

		for (y = scale->wdst_height; y > 0; y--)
		{
			/* The top 8 bits of the 16-bit fraction are the weight of the lower row. */
			blend_rows(line, upper, lower, sy >> 8, line_width);
			scale_row(dst_row, scale->wdst_width, line, sx0, scale->x_mag);

			sy += scale->y_mag;
			if (sy >= (1 << 16))
			{
				sy &= 0xffff;
				upper = lower;
				if (lower < last)
					lower += scale->src_pix_width;
			}

			dst_row += (scale->wdst_bpr / 4);
		}
	}
}

/*
 * Scale one line of 8-bit monochrome pixels (byte aligned).
 *  - x0 is the 16.16 fixed-point source position of the first destination pixel.
 *  - x_mag is the 16.16 source step per destination pixel.
 *
 * Zooming in (x_mag < 1.0): linear interpolation. Both taps start on source
 * pixel x0 >> 16, and the right tap only advances when the position crosses a
 * pixel boundary. A destination pixel at source position p therefore blends
 * src[floor(p) - 1] and src[floor(p)], with src[0] standing in for src[-1].
 * Zooming out or 1:1: the nearest (floor) source pixel, without interpolation.
 */
static void scale_line8(uint8_t* pdst, uint32_t width, const uint8_t* psrc, int x0, int x_mag)
{
	if (x_mag < (1 << 16))
	{
		uint32_t val0, val1, sc, b1, b2;
		int x;

		psrc += x0 >> 16;
		x0 &= 0xffff;
		val0 = *psrc++;   // left tap
		val1 = val0;      // right tap

		// Emit single pixels until pdst is word aligned. Also stop when the line
		// is exhausted, so a line shorter than the misalignment cannot overrun.
		// Use uintptr_t: casting the pointer to int truncates it on 64-bit targets.
		for (x = width; x > 0 && ((uintptr_t) pdst & 3); x--)
		{
			sc = x0 >> 8;
			b1 = val1 * sc + val0 * (256 - sc);

			*pdst++ = b1 >> 8;

			x0 += x_mag;
			if (x0 >= (1<<16))
			{
				x0 &= 0xffff;
				val0 = val1;
				val1 = *psrc++;
			}
		}

		// Four pixels per iteration, gathered into one word. Each b1 is at most
		// 255 * 256, so (b1 & 0xff00) is the 8-bit result shifted up by 8.
		for (; x > 3; x -= 4)
		{
			sc = x0 >> 8;
			b1 = val1 * sc + val0 * (256 - sc);
			b2 = (b1 & 0xff00) >> 8;

			x0 += x_mag;
			if (x0 >= (1<<16))
			{
				x0 &= 0xffff;
				val0 = val1;
				val1 = *psrc++;
			}

			sc = x0 >> 8;
			b1 = val1 * sc + val0 * (256 - sc);
			b2 |= (b1 & 0xff00);

			x0 += x_mag;
			if (x0 >= (1<<16))
			{
				x0 &= 0xffff;
				val0 = val1;
				val1 = *psrc++;
			}

			sc = x0 >> 8;
			b1 = val1 * sc + val0 * (256 - sc);
			b2 |= (b1 & 0xff00) << 8;

			x0 += x_mag;
			if (x0 >= (1<<16))
			{
				x0 &= 0xffff;
				val0 = val1;
				val1 = *psrc++;
			}

			sc = x0 >> 8;
			b1 = val1 * sc + val0 * (256 - sc);
			b2 |= (b1 & 0xff00) << 16;
			// memcpy instead of storing through a uint32_t* avoids the
			// strict-aliasing violation; compilers typically emit one word store.
			// NOTE: pixel k lands in bits 8k..8k+7, which matches memory order
			// only on little-endian targets.
			memcpy(pdst, &b2, sizeof b2);
			pdst += 4;

			x0 += x_mag;
			if (x0 >= (1<<16))
			{
				x0 &= 0xffff;
				val0 = val1;
				val1 = *psrc++;
			}
		}

		// Remaining 0..3 pixels.
		for (; x > 0; x--)
		{
			sc = x0 >> 8;
			b1 = val1 * sc + val0 * (256 - sc);

			*pdst++ = b1 >> 8;

			x0 += x_mag;
			if (x0 >= (1<<16))
			{
				x0 &= 0xffff;
				val0 = val1;
				val1 = *psrc++;
			}
		}
	}
	else
	{
		int x;

		// Nearest neighbour: single pixels until the count is a multiple of four.
		for (x = width; x & 3; x--)
		{
			*pdst++ = psrc[x0 >> 16];
			x0 += x_mag;
		}
		for (; x > 0; x-= 4)
		{
			*pdst++ = psrc[x0 >> 16];
			x0 += x_mag;
			*pdst++ = psrc[x0 >> 16];
			x0 += x_mag;
			*pdst++ = psrc[x0 >> 16];
			x0 += x_mag;
			*pdst++ = psrc[x0 >> 16];
			x0 += x_mag;
		}
	}
}

/*
 * Scale an image of 8-bit monochrome pixels (byte aligned).
 * Rows are scaled horizontally by scale_line8: linear interpolation when
 * zooming in, the nearest pixel otherwise. Vertically the nearest (floor)
 * source row is used, without interpolation.
 * Positions are 16.16 fixed point (source coordinate = mag * destination coordinate).
 */
void ka_scale_linear_8bpp_mono(ka_scale_t* scale)
{
	uint8_t* pdst = scale->wdst;
	int sx0 = scale->x_mag * scale->wdst_x0;
	int sy0 = scale->y_mag * scale->wdst_y0;

	if (scale->y_mag < (1 << 16))
	{
		// Zooming in vertically: consecutive destination rows often share a source
		// row, so scale each source row once into the scratch line (scale->wb_ptr)
		// and copy it out. Assume the extra memory is compensated by less computation.
		const uint8_t* psrc = scale->src + scale->src_pix_width * (sy0 >> 16);

		int bsize = scale->wdst_width;
		if (scale->wb_size < bsize)
		{
			void* ptr = ka_mem_alloc(bsize);
			if (!ptr)
				goto fallback; // no scratch line: scale every row directly

			if (scale->wb_ptr) ka_mem_free(scale->wb_ptr);
			scale->wb_ptr = ptr;
			scale->wb_size = bsize;
		}
		uint8_t* line = scale->wb_ptr;
		scale_line8(line, scale->wdst_width, psrc, sx0, scale->x_mag);
		int sy16 = sy0 & 0xffff; // fractional part of the source row position
		sy0 >>= 16;              // current source row

		for (int y = scale->wdst_height; y > 0; y--)
		{
			memcpy(pdst, line, scale->wdst_width);

			sy16 += scale->y_mag;
			if (sy16 >= (1<<16))
			{
				sy0 += 1;
				sy16 &= 0xffff;
				// Rows 0 .. src_height - 1 are all valid, including the last one,
				// which the fallback path below also reaches. (Comparing against
				// src_height - 1 never loaded that row.) Past the end, keep
				// repeating the line already scaled.
				if (sy0 < scale->src_height)
				{
					psrc += scale->src_pix_width;
					scale_line8(line, scale->wdst_width, psrc, sx0, scale->x_mag);
				}
			}

			pdst += scale->wdst_bpr;
		}
		return;
	}

fallback:
	// Scale every destination row directly from its nearest source row.
	for (int y = scale->wdst_height; y > 0; y--)
	{
		const uint8_t* psrc = scale->src + scale->src_pix_width * (sy0 >> 16);

		scale_line8(pdst, scale->wdst_width, psrc, sx0, scale->x_mag);

		sy0 += scale->y_mag;
		pdst += scale->wdst_bpr;
	}
}

/*
 * Scale an image of 8-bit monochrome pixels (byte aligned).
 * Zooming in vertically (y_mag < 1.0 in 16.16): two horizontally scaled rows
 * are kept in the scratch buffer. They are blended into the destination with
 * the 32-bit line merger, treating four mono pixels as one word. Otherwise each
 * destination row is the horizontally scaled nearest source row.
 * Like scale_line8 horizontally, the vertical blend lags one row: a destination
 * row at source position p blends rows floor(p) - 1 and floor(p), with the
 * starting row standing in for the row above it.
 *
 * Precondition (as originally documented): source byte width is word aligned.
 * The word-wise blend covers whole aligned words of each destination row. Up to
 * 3 bytes before and after the row are therefore rewritten with their current
 * values; a value blended with itself is unchanged by the C merge_lines32_8.
 * NOTE(review): the alignment offset is computed once from wdst, so this also
 * assumes wdst_bpr is a multiple of 4 and that the padding bytes are
 * addressable — confirm.
 */
void ka_scale_bilinear_8bpp_mono(ka_scale_t* scale)
{
	FNmerge_lines32_8 fmerge_lines32_8 = (scale->hardware & ka_hardware_neon) ? merge_lines32_8_neon : merge_lines32_8;
	const int sx0 = scale->x_mag * scale->wdst_x0;
	int sy0 = scale->y_mag * scale->wdst_y0;
	uint8_t* pdst = scale->wdst;

	if (scale->y_mag < (1 << 16))
	{
		const uint8_t* psrc = scale->src + scale->src_pix_width * (sy0 >> 16);
		// Widen each row to whole words: skipb bytes before wdst and skipe after the row.
		// Use uintptr_t: casting the pointer to int truncates it on 64-bit targets.
		const int skipb = (int) ((uintptr_t) scale->wdst & 3);
		int width = skipb + scale->wdst_width;
		const int skipe = (width & 3) ? 4 - (width & 3) : 0;
		pdst -= skipb;
		width += skipe;

		// Two scratch lines (line0, line1) of `width` bytes each.
		int bsize = 2 * width;
		if (scale->wb_size < bsize)
		{
			void* ptr = ka_mem_alloc(bsize);
			if (ptr)
			{
				if (scale->wb_ptr) ka_mem_free(scale->wb_ptr);
				scale->wb_ptr = ptr;
				scale->wb_size = bsize;
			}
			else
			{
				// No scratch memory: fall back to horizontal-only interpolation.
				ka_scale_linear_8bpp_mono(scale);
				return;
			}
		}
		uint8_t* line0 = scale->wb_ptr;
		uint8_t* line1 = line0 + width;

		const uint8_t* psrcmax = scale->src + scale->src_pix_width * (scale->src_height - 1);

		// Both lines start as the first source row; psrc is the next row to load.
		scale_line8(line0 + skipb, scale->wdst_width, psrc, sx0, scale->x_mag);
		psrc += scale->src_pix_width;
		// When starting on the last source row there is no next row: keep reading the
		// last one instead of stepping past the end of the image. This matches the
		// psrc2 clamp in ka_scale_bilinear_32bpp.
		if (psrc > psrcmax)
			psrc = psrcmax;
		memcpy(line1, line0, width);
		sy0 &= 0xffff;

		for (int y = scale->wdst_height; y > 0; y--)
		{
			int x;

			// Copy the existing padding bytes into both lines so the blend writes
			// them back unchanged.
			for (x = 0; x < skipb; x++)
			{
				line0[x] = line1[x] = pdst[x];
			}
			for (x = width - skipe; x < width; x++)
			{
				line0[x] = line1[x] = pdst[x];
			}

			fmerge_lines32_8((uint32_t*) pdst, (uint32_t*) line0, (uint32_t*) line1, sy0 >> 8, width / 4);

			sy0 += scale->y_mag;
			if (sy0 >= (1<<16))
			{
				// Advance one source row: the newer line becomes line0, and line1 is reloaded.
				uint8_t* linex = line0;
				line0 = line1;
				line1 = linex;
				scale_line8(line1 + skipb, scale->wdst_width, psrc, sx0, scale->x_mag);
				if (psrc < psrcmax)
					psrc += scale->src_pix_width;
				sy0 &= 0xffff;
			}

			pdst += scale->wdst_bpr;
		}
	}
	else
	{
		// Shrinking or 1:1 vertically: nearest source row, no blending.
		for (int y = scale->wdst_height; y > 0; y--)
		{
			const uint8_t* psrc = scale->src + scale->src_pix_width * (sy0 >> 16);
			scale_line8(pdst, scale->wdst_width, psrc, sx0, scale->x_mag);

			sy0 += scale->y_mag;
			pdst += scale->wdst_bpr;
		}
	}
}
