#include "mem.h"

#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "kernel.h"
#include "swis.h"

// Address-only sentinel: __mem_free() and mem_freeproc() silently ignore a
// pointer equal to &nil.
// NOTE(review): presumably mem.h hands out &nil for empty data — confirm.
const char nil = 0;
// Byte value written into the guard bytes of every checked block
// (and over the user data when such a block is freed).
#define Mem_Tag 0xdc
// Number of guard bytes at each end of a checked block.
#define Tag_Size	8
// Sentinel passed as 'filename' for allocations made through mem_allocproc.
// It is not a string and must never be printed with %s.
#define Filename_Stack ((void*) 1)

// Debug header stored in front of every block allocated while EMem_Check or
// EMem_Dump is active; the caller's pointer starts just after it.
typedef struct __mem__
{
	// Allocating source file; NULL for mem_slotextend blocks, and set to
	// NULL by __mem_free just before the block is released.
	const char*	file;
	// Allocating source line; 0 for mem_allocproc/mem_slotextend blocks
	// (mem_dump only reports blocks with a non-zero line).
	int		line;
	// Total size requested from OS_Heap: this header + user data + the
	// trailing Tag_Size guard bytes.
	int		size;
	// Leading guard bytes, all Mem_Tag.
	unsigned char	tag[Tag_Size];
} __mem__;

// Heap descriptor at the start of the OS_Heap heap, as read by
// mem_check() and mem_dump().
typedef struct
{
	// NOTE(review): not read in this file; presumably OS_Heap's
	// identification word — confirm.
	int	tag;
	// Offset of the first free block, measured from this field itself
	// (hence the "+ 4" in mem_check/mem_dump); 0 when there is no free list.
	int	free;
	// Offset from the heap start to the end of the block area; the block
	// walks stop here.
	int	base;
	// NOTE(review): not read in this file; presumably the heap size.
	int	end;
} Heap;

// A used OS_Heap block as walked by mem_check()/mem_dump(): the block size
// word OS_Heap keeps in front of the allocation, then our debug header.
typedef struct
{
	// Whole block size including this word; adding it steps to the next block.
	int	size;
	__mem__	header;
} Heap_Alloc;

// An entry of OS_Heap's free list as walked by mem_check()/mem_dump().
typedef struct
{
	// Offset from this entry to the next free block; 0 for the last one.
	int	next;
	// Size of this free block; adding it steps to the following used block.
	int	size;
} Heap_Free;

// Cursor that can view the same address as raw bytes, a used block or a
// free block while walking the heap.
typedef union __p__
{
	unsigned char*	pc;
	Heap_Alloc*	pa;
	Heap_Free*	pf;
} __p__;

// State of the single private heap managed by this module.
static struct
{
	// Start of the OS_Heap heap, placed at the top of the slot by mem_init().
	Heap*	m_pHeap;
	// Page size used to round slot size changes.
	int	m_PageSize;
	// Bytes currently handed to OS_Heap.
	int	m_HeapSize;
	// Task slot size read in mem_init(); the heap lies above it.
	int	m_StartSize;
	// EMem_* option and state flags.
	int	m_flags;
} TheHeap = {NULL, 0, 0, 0, 0};

// Callbacks registered with _kernel_register_allocs() and
// _kernel_register_slotextend() in mem_init(), plus the raw allocator.
static void* mem_allocproc(unsigned  size);
static void mem_freeproc(void* p);
static void* __mem_alloc0(const char* filename, int atline, int size);
static int mem_slotextend(int size, void** pp);


// Show message p in a Wimp error box titled "VorbisAbs".
// The text is copied (truncated if necessary) into a local error block
// with error number 1.
static void mem_err(const char* p)
{
	_kernel_oserror		report;

	report.errnum = 1;
	snprintf(report.errmess, sizeof report.errmess, "%s", p);

	_swix(Wimp_ReportError, _INR(0,2), &report, 1, "VorbisAbs");
}

// Initialise the allocator once; later calls return NULL without effect.
// flags: EMem_* options; only the bits in EMem_InitMask are kept.
// With EMem_CPP the standard library allocator is used and checks/dumps are
// disabled. Otherwise the Wimp slot is grown by one page and an OS_Heap is
// created there.
// The kernel callbacks (mem_allocproc/mem_freeproc/mem_slotextend) are then
// registered.
// Returns NULL on success. If the Wimp does not grant the extra page, the
// slot is restored, the module falls back to EMem_CPP mode and an error
// block is returned.
_kernel_oserror* mem_init(int flags)
{
	static _kernel_oserror	err_noheap = {1, "Not enough memory to create the heap"};
	int			slot;

	if (TheHeap.m_flags & EMem_Init) return NULL;

	TheHeap.m_flags = EMem_Init + (flags & EMem_InitMask);

	if (TheHeap.m_flags & EMem_CPP)
	{
		TheHeap.m_flags &= ~(EMem_Check | EMem_Dump);
		return NULL;
	}

	// OS_ReadMemMapInfo, read page size (fall back to 4K on error)
	if (_swix(OS_ReadMemMapInfo, _OUT(0), &TheHeap.m_PageSize) != NULL)
		TheHeap.m_PageSize = 4096;

	TheHeap.m_HeapSize = TheHeap.m_PageSize;

	// Wimp_SlotSize, read slot size
	_swi(Wimp_SlotSize, _INR(0,1)|_OUT(0), -1, -1, &TheHeap.m_StartSize);

	// The heap starts at the current top of the slot
	// (application space assumed to begin at 0x8000)
	TheHeap.m_pHeap = (Heap*) (TheHeap.m_StartSize + 0x8000);

	// Wimp_SlotSize, increase slot size and read back what was granted
	_swi(Wimp_SlotSize, _INR(0,1)|_OUT(0), TheHeap.m_StartSize + TheHeap.m_HeapSize, -1, &slot);

	if (slot < TheHeap.m_StartSize + TheHeap.m_HeapSize)
	{
		// The page was not granted: creating the heap now would write past
		// the end of the slot. Undo and use the C library allocator instead.
		_swi(Wimp_SlotSize, _INR(0,1), TheHeap.m_StartSize, -1);
		TheHeap.m_pHeap = NULL;
		TheHeap.m_flags |= EMem_CPP;
		TheHeap.m_flags &= ~(EMem_Check | EMem_Dump);
		return &err_noheap;
	}

	// Initialise heap
	_swi(OS_Heap, _INR(0,3), 0, TheHeap.m_pHeap, 0, TheHeap.m_HeapSize);

	_kernel_register_allocs(mem_allocproc, mem_freeproc);
	_kernel_register_slotextend(mem_slotextend);

	return NULL;
}

// Shutdown counterpart of mem_init(). Intentionally empty: no teardown is
// performed.
void mem_finalise(void)
{
	return;
}

// Give unused memory at the top of the heap back to the Wimp.
// Does nothing in EMem_CPP mode, or unless a block has been freed since the
// last pack (EMem_Pack is set by __mem_free/mem_freeproc and cleared here).
// Steps:
// 1. Shrink the heap as far as OS_Heap allows.
// 2. If at least one page was released, reduce the task slot to match.
// 3. Extend the heap again so that it covers the space the slot still holds.
void mem_pack(void)
{
	int			size;
	_kernel_swi_regs	regs;

	if (TheHeap.m_flags & EMem_CPP) return;
	if (!(TheHeap.m_flags & EMem_Pack)) return;

	TheHeap.m_flags &= ~EMem_Pack;

	// OS_Heap, shrink
	// (0x02001d is the X form of OS_Heap; reason 5 changes the heap size by
	// R3. Asking for -m_HeapSize presumably lets OS_Heap shrink as far as it
	// can. Any returned error is ignored here.)
	regs.r[0] = 5;
	regs.r[1] = (int) TheHeap.m_pHeap;
	regs.r[2] = 0;
	regs.r[3] = - TheHeap.m_HeapSize;
	_kernel_swi(0x02001d, &regs, &regs);

	// 'size' keeps the pre-shrink heap size.
	// NOTE(review): the subtraction assumes R3 on exit is the number of bytes
	// the heap actually shrank by, as a positive value — confirm against the
	// OS_Heap reason 5 documentation.
	size = TheHeap.m_HeapSize;
	TheHeap.m_HeapSize -= regs.r[3];

	// Only change the slot when at least a whole page was released
	if ((size - TheHeap.m_HeapSize) >= TheHeap.m_PageSize)
	{
		// Change slot size
		// (0x600ec is the X form of Wimp_SlotSize; R0 on exit is read back as
		// the resulting slot size, so 'size' becomes the heap space it holds)
		regs.r[0] = TheHeap.m_StartSize + TheHeap.m_HeapSize;
		regs.r[1] = -1;
		_kernel_swi(0x600ec, &regs, &regs);
		size = regs.r[0] - TheHeap.m_StartSize;
	}

	// OS_Heap, resize
	// (grow the heap back by the gap between 'size' and the shrunk heap)
	regs.r[0] = 5;
	regs.r[1] = (int) TheHeap.m_pHeap;
	regs.r[2] = 0;
	regs.r[3] = size - TheHeap.m_HeapSize;
	_kernel_swi(0x02001d, &regs, &regs);

	TheHeap.m_HeapSize = size;
}

void mem_dump(void)
{
	int	flags = TheHeap.m_flags;

	if (TheHeap.m_flags & EMem_CPP) return;

	if (!(TheHeap.m_flags & EMem_Dump)) return;

	TheHeap.m_flags |= EMem_Check;

	if (!mem_check()) return;

	TheHeap.m_flags = flags;

	{
		__p__ pfree;
		__p__ pfreeend;
		__p__ pblock;
		__p__ pblockend;
		int	bLeaks = 0;

		// struct of heap is "Heap", free, base, end
		pfreeend.pc = ((unsigned char*)TheHeap.m_pHeap) + TheHeap.m_pHeap->base;
		if (TheHeap.m_pHeap->free)
			pfree.pc = (((unsigned char*)TheHeap.m_pHeap) + TheHeap.m_pHeap->free) + 4;
		else
			pfree.pc = pfreeend.pc;
		pblock.pc = (unsigned char*)(TheHeap.m_pHeap + 1);
		pblockend.pc = pfree.pc;


		while(pblock.pc < pfreeend.pc)
		{
			// check used blocks
			while(pblock.pc < pblockend.pc)
			{
				if (pblock.pa->header.line)
				{
					if (!bLeaks)
					{
						printf("Memory leaks detected\n");
						bLeaks = 1;
					}
					printf("In file %s, at line %d\n"
						, pblock.pa->header.file, pblock.pa->header.line);
				}

				pblock.pc += pblock.pa->size;
			}
			// skip free block
			if (pblockend.pc < pfreeend.pc)
			{
				pblock.pc = pfree.pc + pfree.pf->size;
				pblockend.pc = (pfree.pf->next) ? pfree.pc + pfree.pf->next : pfreeend.pc;
				pfree.pc = pblockend.pc;
			}
		}
	}
}

// Turn off block validation and leak tracking (EMem_Check and EMem_Dump).
// NOTE(review): blocks allocated while these flags were on carry a __mem__
// header. Once the flags are cleared, __mem_free/mem_freeproc no longer step
// back over that header, so such blocks should presumably not be freed
// afterwards — confirm how callers use this.
void mem_removechecks(void)
{
	TheHeap.m_flags = TheHeap.m_flags & ~(EMem_Check | EMem_Dump);
}

// Resize the block at pdata to newsize bytes (0 is treated as 1).
// A NULL pdata behaves like __mem_alloc(). In EMem_CPP mode this defers to
// realloc(). Otherwise a new block is allocated, the old contents are
// copied, and then the old block is freed.
// A checked block whose header or guard bytes are damaged aborts via exit(1).
// Returns the new block, or NULL on failure. As with realloc(), the original
// block is left intact and still owned by the caller on failure.
void* __mem_realloc(const char* filename, int atline, void* pdata, int newsize)
{
	void*	p;
	int	size;

	if (!pdata) return __mem_alloc(filename, atline, newsize);

	if (newsize == 0)
		newsize = 1;

	if (TheHeap.m_flags & EMem_CPP)
		return realloc(pdata, newsize);

	if (TheHeap.m_flags & (EMem_Check | EMem_Dump))
	{
		const __mem__*		pblock = (const __mem__*) pdata - 1;
		const unsigned char*	pc;
		int			i;

		// Validate the recorded size before using it to locate the guard
		if (!pblock->file
		||  (pblock->size < (int) sizeof(__mem__) + Tag_Size))
			{printf("Realloc of freed/corrupted block in %s at line %d\n", filename, atline);exit(1);}

		// check start/end markers (unsigned char: plain char may be signed)
		pc = ((const unsigned char*) pblock) + pblock->size - Tag_Size;
		for (i = 0; i < Tag_Size; i++)
		{
			if ((pblock->tag[i] != Mem_Tag)
			||  (pc[i] != Mem_Tag))
				{printf("Realloc of freed/corrupted block in %s at line %d\n", filename, atline);exit(1);}
		}

		// Exact size originally requested by the caller
		size = pblock->size - (int) sizeof(__mem__) - Tag_Size;
	}
	else
	{
		// OS_Heap, read block size, minus OS_Heap's 4-byte size word
		_swi(OS_Heap, _INR(0,2)|_OUT(3), 6, TheHeap.m_pHeap, pdata, &size);
		size -= 4;
	}

	p = __mem_alloc(filename, atline, newsize);
	if (!p)
		return NULL;	// keep the original block, as realloc() does

	if (size > newsize) size = newsize;
	memcpy(p, pdata, size);
	__mem_free(filename, atline, pdata);

	return p;
}

// Allocate a zero-filled array of 'amount' elements of 'size' bytes each.
// Returns NULL if either count is negative, if amount * size would overflow
// int, or if the allocation itself fails.
void* __mem_calloc(const char* filename, int atline, int amount, int size)
{
	char*	p;
	int	total;

	// amount * size is signed int arithmetic: check before multiplying
	if ((amount < 0) || (size < 0)
	||  ((size != 0) && (amount > INT_MAX / size)))
		return NULL;

	total = amount * size;
	p = __mem_alloc(filename, atline, total);

	if (p) memset(p, 0, total);

	return p;
}

// Allocate 'size' bytes (0 is treated as 1) and record filename/atline for
// leak reports.
// For ordinary callers this defers to malloc() in EMem_CPP mode. Otherwise
// it first validates the heap and calls exit(1) if it is corrupted. Both
// steps are skipped when filename is Filename_Stack.
// Returns NULL and prints a message when no memory is available.
void* __mem_alloc(const char* filename, int atline, int size)
{
	char* p;

	if (size == 0) size = 1;

	if (filename != Filename_Stack)
	{
		if (TheHeap.m_flags & EMem_CPP)
			return malloc(size);

		if (!mem_check())
			exit(1);
	}

	p = __mem_alloc0(filename, atline, size);
	if (!p)
	{
		// filename may be the Filename_Stack sentinel (not a string) or NULL
		const char* where = (filename && (filename != Filename_Stack)) ? filename : "<unknown>";

		printf("no memory for %d bytes in %s at line %d\n", size, where, atline);
	}

	return p;
}

void __mem_free(const char* filename, int atline, const void* pdata)
{
	__mem__* pblock = (void*) pdata;

	if (!pdata || (pdata == &nil)) return;

	if (filename != Filename_Stack)
	{
		if (TheHeap.m_flags & EMem_CPP)
		{
			free((void*) pdata);
			return;
		}

		if (!mem_check())
			exit(1);
	}

	if (TheHeap.m_flags & (EMem_Check | EMem_Dump))
	{
		pblock--;

		if (filename != Filename_Stack)
		{
			char* pc = ((char*) pblock) + pblock->size - Tag_Size;
			int i;

			if (!pblock->file || (pblock->size < 0))
				{printf("Free of freed/corrupted block in %s at line %d\n", filename, atline);exit(1);}

			// check start/end markers
			for (i = 0; i < Tag_Size; i++)
			{
				if ((pblock->tag[i] != Mem_Tag)
				||  (pc[i] != Mem_Tag))
					{printf("Free of freed/corrupted block in %s at line %d\n", filename, atline);exit(1);}
			}

			// overwrite memory
			while (pc >= pdata)
				*pc-- = Mem_Tag;

			// mark block as freed
			pblock->file = NULL;
		}
	}

	// OS_Heap, free block
	_swi(OS_Heap, _INR(0,2), 3, TheHeap.m_pHeap, pblock);

	TheHeap.m_flags |= EMem_Pack;
}

int mem_check(void)
{
	if (TheHeap.m_flags & EMem_CPP) return 1;
	if (!(TheHeap.m_flags & EMem_Check)) return 1;

	{
		__p__	pfree;
		__p__	pfreeend;
		__p__	pblock;
		__p__	pblockend;
		int	i;

		// struct of heap is "Heap", free, base, end */
		pfreeend.pc = ((unsigned char*)TheHeap.m_pHeap) + TheHeap.m_pHeap->base;
		if (TheHeap.m_pHeap->free)
			pfree.pc = (((unsigned char*)TheHeap.m_pHeap) + TheHeap.m_pHeap->free) + 4;
		else
			pfree.pc = pfreeend.pc;
		pblock.pc = (unsigned char*)(TheHeap.m_pHeap + 1);
		pblockend.pc = pfree.pc;

		while(pblock.pc < pfreeend.pc)
		{
			// check used blocks
			while(pblock.pc < pblockend.pc)
			{
				if ((pblock.pa->size < 8)
				||  (pblock.pa->size & 3))
				{
					mem_err("Heap block corrupted case 1\n");
					return 0;
				}
				// check start/end markers
				for (i = 0; i < Tag_Size; i++)
				{
					if ((pblock.pa->header.tag[i] != Mem_Tag)
					||  (pblock.pc[pblock.pa->header.size+3-i] != Mem_Tag))
					{
						mem_err("Heap block corrupted case 2\n");
						return 0;
					}
				}
				pblock.pc += pblock.pa->size;
			}
			if (pblock.pc != pblockend.pc)
			{
				mem_err("Heap block corrupted case 3\n");
				return 0;
			}
			// skip free block
			if (pblockend.pc < pfreeend.pc)
			{
				if ((pfree.pf->size < 8)
				||  (pfree.pf->size & 3)
				||  (pfree.pf->next & 3))
				{
					mem_err("Heap corrupted case 4\n");
					return 0;
				}
				pblock.pc    = pfree.pc + pfree.pf->size;
				pblockend.pc = (pfree.pf->next) ? pfree.pc + pfree.pf->next : pfreeend.pc;
				if (pblockend.pc < pblock.pc + 8)
				{
					mem_err("Heap corrupted case 5\n");
					return 0;
				}
				pfree.pc = pblockend.pc;
			}
		}
	}

	return 1;
}

#pragma -s1

// Allocation callback registered via _kernel_register_allocs() in mem_init().
// The block is tagged with the Filename_Stack sentinel and line 0, so
// mem_dump() never reports it.
static void* mem_allocproc(unsigned size)
{
	void* pnew = __mem_alloc0(Filename_Stack, 0, size);

	return pnew;
}

// Free callback registered via _kernel_register_allocs() in mem_init().
// NULL and &nil are ignored. No block validation is done here; the debug
// header is skipped when checks are active. The block goes straight back to
// OS_Heap and a later mem_pack() is requested.
static void mem_freeproc(void* pdata)
{
	__mem__* pblock;

	if ((pdata == NULL) || (pdata == &nil))
		return;

	pblock = pdata;
	if ((TheHeap.m_flags & (EMem_Check | EMem_Dump)) != 0)
		pblock = pblock - 1;

	// OS_Heap, free block
	_swi(OS_Heap, _INR(0,2), 3, TheHeap.m_pHeap, pblock);
	TheHeap.m_flags |= EMem_Pack;
}

// Claim 'size' bytes (0 is treated as 1) from the private OS_Heap.
// When the claim fails, the task slot and the heap are grown first.
// With EMem_Check/EMem_Dump active:
// - a __mem__ header (filename, atline, total size, leading guard) is placed
//   in front of the block;
// - Tag_Size guard bytes are placed behind it;
// - the returned pointer is just past the header.
// Returns NULL when the memory cannot be obtained. On that path the slot is
// left at m_StartSize + m_HeapSize.
static void* __mem_alloc0(const char* filename, int atline, int size)
{
	_kernel_swi_regs regs;
	int tsize;
	int delta;
	__mem__* pblock;
	int i;

	if (size == 0) size = 1;

	if (TheHeap.m_flags & (EMem_Check | EMem_Dump))
		size += sizeof(__mem__) + Tag_Size;

	// Request OS_Heap block (X OS_Heap 2); an error counts as "no block"
	regs.r[0] = 2;
	regs.r[1] = (int) TheHeap.m_pHeap;
	regs.r[2] = 0;
	regs.r[3] = size;
	if (_kernel_swi(0x02001d, &regs, &regs) != NULL)
		regs.r[2] = 0;

	if (regs.r[2] == 0)
	{
		// Not enough memory: request the size rounded up to whole pages plus
		// one spare page (presumably room for OS_Heap's own overhead)
		delta = 1 + size / TheHeap.m_PageSize;
		if (size % TheHeap.m_PageSize) delta++;
		delta *= TheHeap.m_PageSize;

		// Wimp_SlotSize (X form), increase slot size
		tsize = TheHeap.m_StartSize + TheHeap.m_HeapSize + delta;
		regs.r[0] = tsize;
		regs.r[1] = -1;
		if (_kernel_swi(0x600ec, &regs, &regs) != NULL)
			return NULL;

		// Wimp_SlotSize, increased correctly?
		if (tsize > regs.r[0])
		{
			// No, restore old slot size
			regs.r[0] = TheHeap.m_StartSize + TheHeap.m_HeapSize;
			regs.r[1] = -1;
			_kernel_swi(0x600ec, &regs, &regs);
			return NULL;
		}
		// Hand everything the Wimp granted to the heap
		delta = regs.r[0] - TheHeap.m_StartSize - TheHeap.m_HeapSize;

		// Update OS_Heap Size (X OS_Heap 5)
		regs.r[0] = 5;
		regs.r[1] = (int) TheHeap.m_pHeap;
		regs.r[2] = 0;
		regs.r[3] = delta;
		if (_kernel_swi(0x02001d, &regs, &regs) != NULL)
		{
			// The heap did not take the new memory: give it back so the slot
			// again matches m_StartSize + m_HeapSize
			regs.r[0] = TheHeap.m_StartSize + TheHeap.m_HeapSize;
			regs.r[1] = -1;
			_kernel_swi(0x600ec, &regs, &regs);
			return NULL;
		}

		TheHeap.m_HeapSize += delta;

		// Request block again
		regs.r[0] = 2;
		regs.r[1] = (int) TheHeap.m_pHeap;
		regs.r[2] = 0;
		regs.r[3] = size;
		if ((_kernel_swi(0x02001d, &regs, &regs) != NULL) || (regs.r[2] == 0))
			return NULL;
	}

	pblock = (__mem__*) regs.r[2];

	if (TheHeap.m_flags & (EMem_Check | EMem_Dump))
	{
		unsigned char* pdata;

		pdata = (unsigned char*) pblock;

		pblock->file = filename;
		pblock->line = atline;
		pblock->size = size;
		// leading guard in the header, trailing guard in the last Tag_Size bytes
		for (i = 0; i < Tag_Size; i++)
		{
			pblock->tag[i] = Mem_Tag;
			pdata[size-1-i] = Mem_Tag;
		}

		return (pblock + 1);
	}
	else
		return pblock;
}

// Slot-extension callback registered via _kernel_register_slotextend() in
// mem_init().
// It allocates 'size' bytes from the private heap, with no filename and
// line 0. On success it stores the address in *pp and returns size. It
// returns 0, leaving *pp untouched, if no memory could be obtained.
static int mem_slotextend(int size, void** pp)
{
	void* pnew = __mem_alloc0(NULL, 0, size);

	if (!pnew)
		return 0;

	*pp = pnew;
	return size;
}
