#include "WimpLib:mem.h"

#include <limits.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "kernel.h"

#include "WimpLib:Exception.h"
#include "WimpLib:File.h"
#include "WimpLib:Log.h"
#include "WimpLib:Task.h"
#include "swis.h"

// Application name, defined in another module. NOTE(review): declared here as a
// single char, so the text is reached through its address (&ProgramName) —
// presumably it is a char array at its definition; confirm.
extern const char ProgramName;
// Shared empty string: returned by __mem_allocstring for NULL/"" and ignored
// by the free paths, so it is never released.
const char nil = 0;
// Error text thrown when an allocation fails.
const char mem_NoMem[] = "NoMem:Not enough memory.";

// Byte value written into every guard tag around a tracked block.
#define Mem_Tag     0xdc
// Number of guard bytes in front of (inside the header) and behind the user data.
#define Tag_Size    8
// Pseudo file name recorded for blocks handed to the C library's stack allocator.
static const char Filename_Stack[] = "Stack";

// Tracking header placed in front of the user data when EMem_UseBlock is set.
typedef struct __mem__ __mem__;

struct __mem__
{
	const char*   file;          // allocating source file; set to NULL when the block is freed
	int           line;          // allocating source line; 0 for blocks made for the C library
	int           size;          // whole block: header + user data + trailing tags
	__mem__*      pprev;         // previous block in the TheHeap.m_pChain list
	__mem__*      pnext;         // next block in the TheHeap.m_pChain list
	unsigned char tag[Tag_Size]; // leading guard bytes, each Mem_Tag
};

// Start of an OS_Heap as read by mem_check() and mem_dump().
// NOTE(review): field meanings are inferred from how those walkers use them —
// confirm against the OS_Heap documentation.
typedef struct
{
	int tag;   // presumably the heap identification word (not used here)
	int free;  // offset of the first free block, relative to this field (0 = none)
	int base;  // offset of the end of the used area
	int end;   // presumably the heap end/size (not used here)
} Heap;

// A used OS_Heap block: the heap's size word followed (with EMem_UseBlock) by our header.
typedef struct
{
	int     size;    // block size in bytes, counted from this word to the next block
	__mem__ header;
} Heap_Alloc;

// A free OS_Heap block.
typedef struct
{
	int next;   // offset to the next free block, relative to this one (0 = last)
	int size;   // size of this free block in bytes
} Heap_Free;

// One heap-walk cursor viewed as raw bytes, a used block, or a free block.
typedef union __p__
{
	unsigned char* pc;
	Heap_Alloc*    pa;
	Heap_Free*     pf;
} __p__;

// Global state of the memory manager.
static struct
{
	Heap* m_pHeap;       // private OS_Heap (non-EMem_CPP mode), placed just past the initial slot
	int   m_PageSize;    // granularity used when growing the Wimp slot
	int   m_HeapSize;    // current size of the OS_Heap in bytes
	int   m_StartSize;   // Wimp slot size before the heap was added
	int   m_flags;       // EMem_* flags
	__mem__* m_pChain;   // head of the tracked-block list (EMem_UseBlock)
} TheHeap = {NULL, 0, 0, 0, 0, NULL};

// Hooks registered with the C library in mem_init(), and the raw heap allocator they share.
static void* mem_allocproc(unsigned  size);
static void mem_freeproc(void* p);
static void* __mem_alloc0(const char* filename, int atline, int size);
static int mem_slotextend(int size, void** pp);


/*
 * Show the message p in a Wimp error box with an OK button, titled with the
 * program name. Used by mem_check() to report heap corruption.
 */
static void mem_err(const char* p)
{
	_kernel_oserror err;

	err.errnum = 1;
	snprintf(err.errmess, sizeof(err.errmess), "%s", p);

	// R2 must point at the name text: pass the address of ProgramName, not
	// its first character (which would be used as a bogus pointer).
	_swix(Wimp_ReportError, _INR(0,2), &err, 1, &ProgramName);
}

/*
 * One-time initialisation of the memory manager.
 *
 * flags  EMem_* options; only bits inside EMem_InitMask are kept. Requesting
 *        EMem_Check or EMem_Dump also switches on EMem_UseBlock (tracked
 *        headers and guard tags).
 *
 * Unless EMem_CPP is requested, a private OS_Heap of one page is created just
 * past the current Wimp slot. mem_allocproc/mem_freeproc and mem_slotextend
 * are then registered with the C library. Calls after the first are ignored.
 */
void mem_init(int flags)
{
	if ((TheHeap.m_flags & EMem_Init) != 0)
		return;

	TheHeap.m_flags = EMem_Init + (flags & EMem_InitMask);

	// both checking and leak dumps depend on the per-block header/tag layout
	if ((TheHeap.m_flags & (EMem_Check | EMem_Dump)) != 0)
		TheHeap.m_flags |= EMem_UseBlock;

	if ((TheHeap.m_flags & EMem_CPP) == 0)
	{
		// page size; fall back to 4K when OS_ReadMemMapInfo fails
		if (_swix(OS_ReadMemMapInfo, _OUT(0), &TheHeap.m_PageSize))
			TheHeap.m_PageSize = 4096;
		TheHeap.m_HeapSize = TheHeap.m_PageSize;

		// read the current slot size; the heap starts right after it
		// (application space begins at 0x8000)
		_swi(Wimp_SlotSize, _INR(0,1)|_OUT(0), -1, -1, &TheHeap.m_StartSize);
		TheHeap.m_pHeap = (Heap*) (0x8000 + TheHeap.m_StartSize);

		// enlarge the slot by the initial heap size and format it as an OS_Heap
		_swi(Wimp_SlotSize, _INR(0,1), TheHeap.m_StartSize + TheHeap.m_HeapSize, -1);
		_swi(OS_Heap, _INR(0,3), 0, TheHeap.m_pHeap, 0, TheHeap.m_HeapSize);

		_kernel_register_allocs(mem_allocproc, mem_freeproc);
		_kernel_register_slotextend(mem_slotextend);
	}
}

void* __throw_mem_alloc(const char* filename, int atline, int size)
{
	void* p = __mem_alloc(filename, atline, size);

	if (!p) __throw_string(filename, atline, mem_NoMem);

	return p;
}

void* __throw_mem_calloc(const char* filename, int atline, int amount, int size)
{
	void* p = __mem_calloc(filename, atline, amount, size);

	if (!p) __throw_string(filename, atline, mem_NoMem);

	return p;
}

char* __throw_mem_allocstring(const char* filename, int atline, const char* pfrom)
{
	void* p = __mem_allocstring(filename, atline, pfrom);

	if (!p) __throw_string(filename, atline, mem_NoMem);

	return p;
}

/*
 * Return a heap copy of pfrom, or NULL if memory runs out.
 *
 * NULL and "" both yield the shared empty string &nil rather than a new
 * block. The free paths ignore &nil, and callers must never write through it.
 */
char* __mem_allocstring(const char* filename, int atline, const char* pfrom)
{
	int length = 0;
	char* pcopy;

	if (pfrom != NULL)
		length = strlen(pfrom);

	if (length == 0)
		return (char*) &nil;

	pcopy = __mem_alloc(filename, atline, length + 1);
	if (pcopy != NULL)
		memcpy(pcopy, pfrom, length + 1);   // includes the terminating NUL

	return pcopy;
}

/*
 * vprintf-style formatting into a newly allocated string.
 *
 * On failure, including a formatting error reported by vsnprintf, it calls
 * __throw_string with mem_NoMem instead of returning NULL. arg is consumed.
 */
char* __throw_mem_allocvprint(const char* filename, int atline, const char* pformat, va_list arg)
{
	va_list arg2;
	int len;
	char* p = NULL;

	// vsnprintf consumes its va_list: keep a copy for the formatting pass
	va_copy(arg2, arg);

	// a negative length (format/encoding error) must not be used as a size
	len = vsnprintf(NULL, 0, pformat, arg);
	if ((len >= 0) && (len < INT_MAX))
	{
		p = __mem_alloc(filename, atline, len + 1);
		if (p) vsnprintf(p, len + 1, pformat, arg2);
	}

	// end the copy before throwing; __throw_string presumably does not return
	va_end(arg2);

	if (!p) __throw_string(filename, atline, mem_NoMem);

	return p;
}

/*
 * vprintf-style formatting into a newly allocated string.
 *
 * Returns the string, or NULL if memory runs out or vsnprintf reports a
 * formatting error. arg is consumed.
 */
char* __mem_allocvprint(const char* filename, int atline, const char* pformat, va_list arg)
{
	va_list arg2;
	int len;
	char* p = NULL;

	// vsnprintf consumes its va_list: keep a copy for the formatting pass
	va_copy(arg2, arg);

	// On error vsnprintf returns a negative value. Using it as a size would
	// hand back an unterminated buffer.
	len = vsnprintf(NULL, 0, pformat, arg);
	if ((len >= 0) && (len < INT_MAX))
	{
		p = __mem_alloc(filename, atline, len + 1);
		if (p) vsnprintf(p, len + 1, pformat, arg2);
	}
	va_end(arg2);

	return p;
}

char* __throw_mem_allocprint(const char* filename, int atline, const char* pformat, ...)
{
	va_list arg;
	va_start(arg, pformat);
	void* p = __mem_allocvprint(filename, atline, pformat, arg);
	va_end(arg);

	if (!p) __throw_string(filename, atline, mem_NoMem);

	return p;
}

/*
 * printf-style formatting into a newly allocated string. Returns NULL on
 * failure.
 */
char* __mem_allocprint(const char* filename, int atline, const char* pformat, ...)
{
	va_list args;
	char* presult;

	va_start(args, pformat);
	presult = __mem_allocvprint(filename, atline, pformat, args);
	va_end(args);

	return presult;
}

/*
 * Replace the string owned by *ppTo with a copy of pfrom.
 *
 * NULL is treated like "" on both sides. Returns false (and does nothing) if
 * the text is unchanged, otherwise true. Running out of memory calls
 * __throw_string before the old string is released.
 */
bool __throw_mem_setstring(const char* filename, int atline, const char** ppTo, const char* pfrom)
{
	const char* pold = (*ppTo != NULL) ? *ppTo : &nil;
	const char* pnew = (pfrom != NULL) ? pfrom : &nil;
	char* pcopy;

	if (strcmp(pold, pnew) == 0)
		return false;

	// allocate first so *ppTo is untouched if the allocation throws
	pcopy = __throw_mem_allocstring(filename, atline, pfrom);
	mem_free(*ppTo);
	*ppTo = pcopy;

	return true;
}

/*
 * Replace the string owned by *ppTo with a copy of pfrom.
 *
 * NULL is treated like "" on both sides. Returns true only when *ppTo was
 * replaced. It returns false when the text is unchanged or memory runs out;
 * in both cases *ppTo is left as it was.
 */
bool __mem_setstring(const char* filename, int atline, const char** ppTo, const char* pfrom)
{
	const char* pold = (*ppTo != NULL) ? *ppTo : &nil;
	char* pcopy;

	if (strcmp(pold, (pfrom != NULL) ? pfrom : &nil) == 0)
		return false;

	pcopy = __mem_allocstring(filename, atline, pfrom);
	if (pcopy == NULL)
		return false;

	mem_free(*ppTo);
	*ppTo = pcopy;
	return true;
}

/*
 * Give unused memory at the top of the private OS_Heap back to the Wimp.
 *
 * Does nothing in EMem_CPP mode, or when no block has been freed since the
 * last pack (EMem_Pack is set by the free paths). The heap is shrunk as far
 * as possible. If that frees at least one page, the Wimp slot is shrunk to
 * match. The heap is then re-extended to fill whatever the slot now provides.
 * SWIs are issued through _kernel_swi with X-form numbers, and errors are not
 * checked here.
 */
void mem_pack(void)
{
	int size;
	_kernel_swi_regs regs;

	if (TheHeap.m_flags & EMem_CPP) return;
	if (!(TheHeap.m_flags & EMem_Pack)) return;

	TheHeap.m_flags &= ~EMem_Pack;

	// OS_Heap, shrink
	// 0x02001d = XOS_Heap; reason 5 changes the heap size by R3 (here: shrink by everything)
	regs.r[0] = 5;
	regs.r[1] = (int) TheHeap.m_pHeap;
	regs.r[2] = 0;
	regs.r[3] = - TheHeap.m_HeapSize;
	_kernel_swi(0x02001d, &regs, &regs);

	// NOTE(review): the code treats the returned R3 as the (positive) number of
	// bytes actually released — confirm the sign OS_Heap 5 returns.
	size = TheHeap.m_HeapSize;
	TheHeap.m_HeapSize -= regs.r[3];

	// only worth touching the slot when at least a whole page was released
	if ((size - TheHeap.m_HeapSize) >= TheHeap.m_PageSize)
	{
		// Change slot size
		// 0x600ec = XWimp_SlotSize; R0 = new slot size, R0 on exit = slot actually granted
		regs.r[0] = TheHeap.m_StartSize + TheHeap.m_HeapSize;
		regs.r[1] = -1;
		_kernel_swi(0x600ec, &regs, &regs);
		size = regs.r[0] - TheHeap.m_StartSize;
	}

	// OS_Heap, resize
	// grow the heap back so it fills the space the slot now provides
	regs.r[0] = 5;
	regs.r[1] = (int) TheHeap.m_pHeap;
	regs.r[2] = 0;
	regs.r[3] = size - TheHeap.m_HeapSize;
	_kernel_swi(0x02001d, &regs, &regs);

	TheHeap.m_HeapSize = size;
}

/*
 * Report blocks that are still allocated (memory leaks) through Log().
 *
 * Does nothing unless EMem_Dump is set. A full mem_check() runs first, and
 * the dump is skipped if that finds corruption.
 * - EMem_CPP mode: every block on the tracking chain is listed. The size shown
 *   is the stored size, which includes header and tags.
 * - Otherwise: the OS_Heap is walked and every used block whose header
 *   records a non-zero line is listed. Blocks made for the C library carry
 *   line 0 and are skipped.
 */
void mem_dump(void)
{
	int  flags = TheHeap.m_flags;
	bool bValid;

	if (!(TheHeap.m_flags & EMem_Dump)) return;

	// Force a full consistency check. Restore the caller's flags even when it
	// fails, so a failed dump does not leave EMem_Check switched on.
	TheHeap.m_flags |= EMem_Check;
	bValid = mem_check();
	TheHeap.m_flags = flags;

	if (!bValid) return;

	if (TheHeap.m_flags & EMem_CPP)
	{
		if (TheHeap.m_pChain != NULL)
			Log("Memory leaks detected\n");

		for (const __mem__* pblock = TheHeap.m_pChain; pblock; pblock = pblock->pnext)
		{
			Log("In file %s, at line %d, size %06d\n", pblock->file, pblock->line, pblock->size);
		}
	}
	else
	{
		__p__ pfree;
		__p__ pfreeend;
		__p__ pblock;
		__p__ pblockend;
		bool  bLeaks = false;

		// heap header is "Heap": free-list offset, end of used area (base), end
		pfreeend.pc = ((unsigned char*) TheHeap.m_pHeap) + TheHeap.m_pHeap->base;
		if (TheHeap.m_pHeap->free)
			pfree.pc = (((unsigned char*) TheHeap.m_pHeap) + TheHeap.m_pHeap->free) + 4;
		else
			pfree.pc = pfreeend.pc;
		pblock.pc = (unsigned char*)(TheHeap.m_pHeap + 1);
		pblockend.pc = pfree.pc;

		while(pblock.pc < pfreeend.pc)
		{
			// used blocks up to the next free block
			while(pblock.pc < pblockend.pc)
			{
				if (pblock.pa->header.line)
				{
					if (!bLeaks)
					{
						Log("Memory leaks detected\n");
						bLeaks = true;
					}
					Log("In file %s, at line %d\n"
						, pblock.pa->header.file, pblock.pa->header.line);
				}

				pblock.pc += pblock.pa->size;
			}
			// skip free block
			if (pblockend.pc < pfreeend.pc)
			{
				pblock.pc    = pfree.pc + pfree.pf->size;
				pblockend.pc = (pfree.pf->next) ? pfree.pc + pfree.pf->next : pfreeend.pc;
				pfree.pc = pblockend.pc;
			}
		}
	}
}

/*
 * Switch off consistency checking and leak dumping (EMem_Check, EMem_Dump).
 * EMem_UseBlock is left as it is.
 */
void mem_removechecks(void)
{
	const int nChecks = EMem_Dump | EMem_Check;

	TheHeap.m_flags = TheHeap.m_flags & ~nChecks;
}

static bool __mem_checkblock(const __mem__* pblock, bool log)
{
	const char* pc = ((char*) pblock) + pblock->size - Tag_Size;
	int i;

	if (!pblock->file || (pblock->size < 0))
	{
		if (log)
		{
			Log("corrupted block in %s at line %d\n", pblock->file, pblock->line);
			return false;
		}
		else
			throw_string("corrupted block in %s at line %d", pblock->file, pblock->line);
	}

	// check start/end markers
	for (i = 0; i < Tag_Size; i++)
	{
		if ((pblock->tag[i] != Mem_Tag)
		||  (pc[i] != Mem_Tag))
		{
			if (log)
			{
				Log("corrupted block in %s at line %d\n", pblock->file, pblock->line);
				return false;
			}
			else
				throw_string("corrupted block in %s at line %d", pblock->file, pblock->line);
		}
	}

	return true;
}

void* __mem_realloc(const char* filename, int atline, void* pdata, int newsize)
{
	__mem__* pblock;

	if (newsize == 0)
		newsize = 1;

	pblock = pdata;
	if (TheHeap.m_flags & EMem_UseBlock)
	{
		pblock--;

		__mem_checkblock(pblock, false);
	}

	if (TheHeap.m_flags & EMem_CPP)
	{
		const __mem__* pold = pblock;

		if (TheHeap.m_flags & EMem_Log)
			Log("Mem: realloc %06d in %s, line %d\n", newsize, filename, atline);

		if (TheHeap.m_flags & EMem_UseBlock)
			newsize += sizeof(__mem__) + Tag_Size;

		pblock = realloc((void*) pblock, newsize);

		if (pblock == NULL)
			return NULL;

		if (TheHeap.m_flags & EMem_UseBlock)
		{
			// replace old block in chain
			if (TheHeap.m_pChain == pold)
				TheHeap.m_pChain = pblock;
			if (pblock->pnext != NULL)
				pblock->pnext->pprev = pblock;
			if (pblock->pprev != NULL)
				pblock->pprev->pnext = pblock;
			pblock++;
		}

		return pblock;
	}
	else
	{
		int size;
		void* p;

		// OS_Heap, read block size
		_swi(OS_Heap, _INR(0,2)|_OUT(3), 6, TheHeap.m_pHeap, pblock, &size);
		size -= 4;

		if (TheHeap.m_flags & EMem_UseBlock)
			size -= sizeof(__mem__);

		p = __mem_alloc(filename, atline, newsize);
		if (p)
		{
			if (size > newsize) size = newsize;
			memcpy(p, pdata, size);
		}
		mem_free(pdata);

		return p;
	}
}

/*
 * Allocate amount * size bytes, zero-filled.
 *
 * Returns NULL if either count is negative, if the product does not fit in
 * an int, or if memory runs out.
 */
void* __mem_calloc(const char* filename, int atline, int amount, int size)
{
	char* p;

	// amount * size in int arithmetic would be undefined on overflow
	if ((amount < 0) || (size < 0) || ((size != 0) && (amount > INT_MAX / size)))
		return NULL;

	p = __mem_alloc(filename, atline, amount * size);

	if (p) memset(p, 0, amount * size);

	return p;
}

void* __mem_alloc(const char* filename, int atline, int size)
{
	void* p;

	if (size == 0) size = 1;

	if (!mem_check())
		exit(EXIT_FAILURE);

	if (TheHeap.m_flags & EMem_CPP)
	{
		__mem__* pblock;
		int i;

		if (TheHeap.m_flags & EMem_UseBlock)
			size += sizeof(__mem__) + Tag_Size;

		pblock = malloc(size);

		if (pblock == NULL)
			return NULL;

		if (TheHeap.m_flags & EMem_UseBlock)
		{
			unsigned char* pdata;

			pdata = (unsigned char*) pblock;

			pblock->file = filename;
			pblock->line = atline;
			pblock->size = size;
			pblock->pprev = NULL;
			pblock->pnext = TheHeap.m_pChain;
			TheHeap.m_pChain = pblock;
			if (pblock->pnext != NULL)
				pblock->pnext->pprev = pblock;

			for (i = 0; i < Tag_Size; i++)
			{
				pblock->tag[i] = Mem_Tag;
				pdata[size-1-i] = Mem_Tag;
			}

			size -= sizeof(__mem__) + Tag_Size;
			p = (pblock + 1);
		}
		else
			p = pblock;
	}
	else
		p = __mem_alloc0(filename, atline, size);

	if (TheHeap.m_flags & EMem_Log)
		Log("Mem: %p, %06d in %s, line %d\n", p, size, filename, atline);

	return p;
}

void __mem_free(const char* filename, int atline, const void* pdata)
{
	__mem__* pblock = (void*) pdata;

	if (!pdata || (pdata == &nil)) return;

	if (TheHeap.m_flags & EMem_Log)
		Log("Mem: free %p in %s, line %d\n", pdata, filename, atline);

	if (!mem_check())
		exit(EXIT_FAILURE);

	if (TheHeap.m_flags & EMem_UseBlock)
	{
		pblock--;


		if (!__mem_checkblock(pblock, true))
			throw_string("Attempt fo free invalid block in %s at line %d", filename, atline);

		if (TheHeap.m_flags & EMem_Check)
		{
			char* pc = ((char*) pblock) + pblock->size - Tag_Size - 1;

			// overwrite memory to try to make freed block usage noticable
			while (pc >= pdata)
				*pc-- = Mem_Tag;
		}

		// mark block as freed
		pblock->file = NULL;

		// remove from chain
		if (TheHeap.m_pChain == pblock)
			TheHeap.m_pChain =pblock->pnext;
		if (pblock->pnext != NULL)
			pblock->pnext->pprev = pblock->pprev;
		if (pblock->pprev != NULL)
			pblock->pprev->pnext = pblock->pnext;
	}

	if (TheHeap.m_flags & EMem_CPP)
		free((void*) pblock);
	else
	{
		// OS_Heap, free block
		_swi(OS_Heap, _INR(0,2), 3, TheHeap.m_pHeap, pblock);

		TheHeap.m_flags |= EMem_Pack;
	}
}

/*
 * Verify the consistency of all live blocks.
 *
 * Returns true immediately when EMem_Check is off. Otherwise:
 * - EMem_CPP mode: every block on TheHeap.m_pChain is checked with
 *   __mem_checkblock().
 * - OS_Heap mode: the heap is walked from just after its header, alternating
 *   between runs of used blocks and the free blocks that separate them. Used
 *   block sizes must be >= 8 and a multiple of 4. Each used block's header is
 *   checked with __mem_checkblock(). A run of used blocks must end exactly at
 *   the next free block. Free-block sizes/links must be sane and must not
 *   overlap the block that follows.
 * On the first problem a Wimp error box is shown via mem_err() and false is
 * returned.
 */
bool mem_check(void)
{
	if (!(TheHeap.m_flags & EMem_Check)) return true;

	if (TheHeap.m_flags & EMem_CPP)
	{
		for (const __mem__* pblock = TheHeap.m_pChain; pblock; pblock = pblock->pnext)
		{
			if (!__mem_checkblock(pblock, true))
			{
				mem_err("Memory block corrupted");
				return false;
			}
		}
	}
	else
	{
		__p__ pfree;      // next free block
		__p__ pfreeend;   // end of the used area of the heap
		__p__ pblock;     // current used block
		__p__ pblockend;  // end of the current run of used blocks

		// heap header is "Heap": free-list offset, end of used area (base), end.
		// The free offset is taken relative to the free field itself (hence + 4).
		pfreeend.pc = ((unsigned char*) TheHeap.m_pHeap) + TheHeap.m_pHeap->base;
		if (TheHeap.m_pHeap->free)
			pfree.pc = (((unsigned char*) TheHeap.m_pHeap) + TheHeap.m_pHeap->free) + 4;
		else
			pfree.pc = pfreeend.pc;
		pblock.pc = (unsigned char*)(TheHeap.m_pHeap + 1);
		pblockend.pc = pfree.pc;

		while(pblock.pc < pfreeend.pc)
		{
			// check used blocks
			while(pblock.pc < pblockend.pc)
			{
				if ((pblock.pa->size < 8)
				||  (pblock.pa->size & 3))
				{
					mem_err("Heap block corrupted, invalid size\n");
					return false;
				}

				if (!__mem_checkblock(&pblock.pa->header, true))
				{
					mem_err("Heap block memory corrupted");
					return false;
				}

				pblock.pc += pblock.pa->size;
			}

			// the last used block must end exactly where the free block starts
			if (pblock.pc != pblockend.pc)
			{
				mem_err("Heap block corrupted , overlaps end/free block\n");
				return false;
			}
			// skip free block
			if (pblockend.pc < pfreeend.pc)
			{
				if ((pfree.pf->size < 8)
				||  (pfree.pf->size & 3)
				||  (pfree.pf->next & 3))
				{
					mem_err("Heap free block corrupted, invalid size\n");
					return false;
				}
				// next used run starts after this free block and ends at the
				// following free block (or the end of the used area)
				pblock.pc    = pfree.pc + pfree.pf->size;
				pblockend.pc = (pfree.pf->next) ? pfree.pc + pfree.pf->next : pfreeend.pc;
				if (pblockend.pc < pblock.pc + 8)
				{
					mem_err("Heap free block corrupted , overlaps next block\n");
					return false;
				}
				pfree.pc = pblockend.pc;
			}
		}
	}

	return true;
}

void throw_memcheck(void* pdata)
{
	__mem__* pblock = pdata;
	if (TheHeap.m_flags & EMem_UseBlock)
	{
		pblock--;

		__mem_checkblock(pblock, false);
	}
}

// Norcroft: switch off stack-limit checking for the functions that follow.
// NOTE(review): presumably required because the C library calls them while it
// is managing stack chunks / extending the heap — confirm.
#pragma -s1

/*
 * Stack-chunk allocator registered with _kernel_register_allocs().
 * Blocks are tagged with Filename_Stack and line 0, so mem_dump() does not
 * list them as leaks.
 */
static void* mem_allocproc(unsigned  size)
{
	void* pchunk = __mem_alloc0(Filename_Stack, 0, (int) size);

	return pchunk;
}

/*
 * Stack-chunk release hook registered with _kernel_register_allocs().
 *
 * It is a stripped-down __mem_free without logging or consistency checks,
 * and it makes no calls beyond the OS_Heap SWI. It unlinks a tracked block
 * and returns it to the private OS_Heap. NULL and &nil are ignored.
 */
static void mem_freeproc(void* pdata)
{
	__mem__* pheader;

	if ((pdata == NULL) || (pdata == &nil))
		return;

	pheader = pdata;
	if (TheHeap.m_flags & EMem_UseBlock)
	{
		pheader -= 1;

		// a NULL file marks the header as freed
		pheader->file = NULL;

		// unlink from the tracking chain
		if (pheader->pprev != NULL)
			pheader->pprev->pnext = pheader->pnext;
		if (pheader->pnext != NULL)
			pheader->pnext->pprev = pheader->pprev;
		if (TheHeap.m_pChain == pheader)
			TheHeap.m_pChain = pheader->pnext;
	}

	// OS_Heap reason 3: free block
	_swi(OS_Heap, _INR(0,2), 3, TheHeap.m_pHeap, pheader);

	// tell mem_pack() there may be memory to hand back
	TheHeap.m_flags |= EMem_Pack;
}

/*
 * Claim size bytes (0 is treated as 1) from the private OS_Heap. If the first
 * claim fails, the Wimp slot and the heap are grown.
 *
 * With EMem_UseBlock, filename/atline are stored in a __mem__ header linked
 * into TheHeap.m_pChain. The returned pointer follows that header, and
 * Tag_Size guard bytes end the block. Returns NULL if the slot or heap cannot
 * be grown or the claim still fails.
 *
 * SWIs go through _kernel_swi with X-form numbers (0x02001d = XOS_Heap,
 * 0x600ec = XWimp_SlotSize), so errors come back as return values.
 */
static void* __mem_alloc0(const char* filename, int atline, int size)
{
	_kernel_swi_regs regs;
	int tsize;
	int delta;
	__mem__* pblock;
	int i;

	if (size == 0) size = 1;

	if (TheHeap.m_flags & EMem_UseBlock)
		size += sizeof(__mem__) + Tag_Size;

	// Request OS_Heap block (reason 2: claim)
	regs.r[0] = 2;
	regs.r[1] = (int) TheHeap.m_pHeap;
	regs.r[2] = 0;
	regs.r[3] = size;
	if (_kernel_swi(0x02001d, &regs, &regs) != NULL)
		regs.r[2] = 0;

	if (regs.r[2] == 0)
	{
		// Not enough memory, increase slot size.
		// Grow by ceil(size / page) + 1 pages; the extra page presumably
		// covers OS_Heap bookkeeping — confirm.
		delta = 1 + size / TheHeap.m_PageSize;
		if (size % TheHeap.m_PageSize) delta++;
		delta *= TheHeap.m_PageSize;

		// Wimp_SlotSize, increase slot size
		tsize = TheHeap.m_StartSize + TheHeap.m_HeapSize + delta;
		regs.r[0] = tsize;
		regs.r[1] = -1;
		if (_kernel_swi(0x600ec, &regs, &regs) != NULL)
			return NULL;

		// Wimp_SlotSize, increased correctly?
		if (tsize > regs.r[0])
		{
			// No, restore old slot size
			regs.r[0] = TheHeap.m_StartSize + TheHeap.m_HeapSize;
			regs.r[1] = -1;
			_kernel_swi(0x600ec, &regs, &regs);
			return NULL;
		}
		else delta = regs.r[0] - TheHeap.m_StartSize - TheHeap.m_HeapSize;

		// Update OS_Heap Size (reason 5: extend by delta)
		regs.r[0] = 5;
		regs.r[1] = (int) TheHeap.m_pHeap;
		regs.r[2] = 0;
		regs.r[3] = delta;
		if (_kernel_swi(0x02001d, &regs, &regs) != NULL)
		{
			// The heap could not take the new memory. Give the slot back so it
			// still matches m_StartSize + m_HeapSize.
			regs.r[0] = TheHeap.m_StartSize + TheHeap.m_HeapSize;
			regs.r[1] = -1;
			_kernel_swi(0x600ec, &regs, &regs);
			return NULL;
		}

		TheHeap.m_HeapSize += delta;

		// Request block
		regs.r[0] = 2;
		regs.r[1] = (int) TheHeap.m_pHeap;
		regs.r[2] = 0;
		regs.r[3] = size;
		if (_kernel_swi(0x02001d, &regs, &regs) != NULL)
			return NULL;
	}

	pblock = (__mem__*) regs.r[2];

	if (TheHeap.m_flags & EMem_UseBlock)
	{
		unsigned char* pdata;

		pdata = (unsigned char*) pblock;

		// fill in the header and push it on the front of the tracking chain
		pblock->file = filename;
		pblock->line = atline;
		pblock->size = size;
		pblock->pprev = NULL;
		pblock->pnext = TheHeap.m_pChain;
		TheHeap.m_pChain = pblock;
		if (pblock->pnext != NULL)
			pblock->pnext->pprev = pblock;

		// leading tags live in the header, trailing tags at the very end
		for (i = 0; i < Tag_Size; i++)
		{
			pblock->tag[i] = Mem_Tag;
			pdata[size-1-i] = Mem_Tag;
		}

		return (pblock + 1);
	}
	else
		return pblock;
}

/*
 * Slot-extension handler registered with _kernel_register_slotextend().
 * Claims size bytes from the private heap and stores the block in *pp.
 * Returns the number of bytes obtained (size), or 0 on failure.
 */
static int mem_slotextend(int size, void** pp)
{
	// Give these blocks a real name. A NULL file marks a freed block, so
	// mem_check() would otherwise report every slot-extension block as
	// corrupt. Line 0 keeps mem_dump() from listing them as leaks.
	static const char Filename_SlotExtend[] = "SlotExtend";
	void* p = __mem_alloc0(Filename_SlotExtend, 0, size);

	if (p)
	{
		*pp = p;
		return size;
	}

	return 0;
}
