#include "Daodan_Patch.h"
#include "Patches/Utility.h"
#include <beaengine/BeaEngine.h>

#include <windows.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/*
 * Overwrite the 5 bytes at `from` with `jmp rel32` (opcode e9) targeting `to`.
 *
 * The rel32 displacement is relative to the end of the 5-byte instruction.
 * The bytes are temporarily made PAGE_EXECUTE_READWRITE and their previous
 * protection is restored afterwards. Returns false, leaving `from` untouched,
 * if the protection change fails.
 */
bool DDrPatch_MakeJump(void* from, void* to)
{
	unsigned char* const site = from;
	DWORD oldp;

	if (!VirtualProtect(site, 5, PAGE_EXECUTE_READWRITE, &oldp))
		return false;

	const int32_t rel = (int32_t)((uintptr_t)to - ((uintptr_t)site + 5));
	site[0] = 0xe9; // jmp rel32
	// memcpy: the displacement field is generally not 4-byte aligned.
	memcpy(site + 1, &rel, sizeof rel);

	// Restore exactly the range that was unprotected above.
	VirtualProtect(site, 5, oldp, &oldp);
	return true;
}

/*
 * Overwrite the 5 bytes at `from` with `call rel32` (opcode e8) targeting `to`.
 *
 * The rel32 displacement is relative to the end of the 5-byte instruction.
 * The bytes are temporarily made PAGE_EXECUTE_READWRITE and their previous
 * protection is restored afterwards. Returns false, leaving `from` untouched,
 * if the protection change fails.
 */
bool DDrPatch_MakeCall(void* from, void* to)
{
	unsigned char* const site = from;
	DWORD oldp;

	if (!VirtualProtect(site, 5, PAGE_EXECUTE_READWRITE, &oldp))
		return false;

	const int32_t rel = (int32_t)((uintptr_t)to - ((uintptr_t)site + 5));
	site[0] = 0xe8; // call rel32
	// memcpy: the displacement field is generally not 4-byte aligned.
	memcpy(site + 1, &rel, sizeof rel);

	// Restore exactly the range that was unprotected above.
	VirtualProtect(site, 5, oldp, &oldp);
	return true;
}

/*
 * Detour the code at `from` to `to`.
 *
 * Whole instructions at `from` are disassembled until at least 5 bytes (the
 * size of a `jmp rel32`) are covered. They are copied into a newly allocated
 * trampoline, with relative branches relocated, and followed by a jump back
 * to the first instruction after the copied region. The copied region in the
 * original code is then NOP'ed and its start replaced by a jump to `to`.
 *
 * Returns the trampoline, or (void*)-1 on failure. On failure, `from` has
 * not been modified and the trampoline buffer has been freed. On success,
 * the trampoline is never freed by this function.
 *
 * Limitations: at most one control-transfer instruction may appear in the
 * copied region, and branches carrying prefixes are rejected.
 */
void* DDrPatch_MakeDetour(void* from, void* to)
{
	enum {
		DETOUR_JMP_SIZE = 5,         // e9 + rel32
		DETOUR_TRAMPOLINE_SIZE = 40, // copied code + relocations + jump back
	};
	const UIntPtr start = (UIntPtr)from;

	char* trampoline = malloc(DETOUR_TRAMPOLINE_SIZE);
	if (trampoline == NULL) {
		STARTUPMESSAGE("Detour: Could not allocate trampoline for address 0x%08x", from);
		return (void*)-1;
	}
	DDrPatch_NOOP(trampoline, DETOUR_TRAMPOLINE_SIZE);

	DISASM disasm;
	memset(&disasm, 0, sizeof(DISASM));
	disasm.EIP = start;

	int pos = 0;
	int branches = 0;

	while (disasm.EIP - start < DETOUR_JMP_SIZE) {
		const int len = Disasm(&disasm);
		if (len == UNKNOWN_OPCODE) {
			STARTUPMESSAGE("Detour: Unknown opcode in trampoline area from address 0x%08x", from);
			goto fail;
		}

		if ((disasm.Instruction.Category & 0xffff) == CONTROL_TRANSFER) {
			if (disasm.Prefix.Number > 0) {
				STARTUPMESSAGE("Detour: Branch in trampoline area from address 0x%08x with prefixes", from);
				goto fail;
			}
			branches++;

			// NOTE(review): indirect jmp/call (e.g. `call eax`) presumably report
			// AddrValue 0 and would be redirected to address 0 here -- confirm.
			const UIntPtr target = (UIntPtr)disasm.Instruction.AddrValue;
			// Both bounds are required: a target *before* `from` is outside the
			// copied code and must not be relocated into the trampoline.
			// NOTE(review): targets in [from+5, end of copied code) are treated
			// as outside and land on NOPs in the original code.
			const bool targetInTrampoline =
				target >= start && target < start + DETOUR_JMP_SIZE;

			switch (disasm.Instruction.BranchType) {
				case JmpType:
				case CallType: {
					void* dest;
					if (targetInTrampoline) {
						// Code before this instruction is copied verbatim, so the
						// distance backwards is preserved. Code after it shifts by
						// however much this instruction grows as a 5-byte rel32.
						int offset = (int)(target - disasm.EIP);
						if (offset > 0)
							offset += DETOUR_JMP_SIZE - len;
						dest = &trampoline[pos] + offset;
					} else {
						dest = (void*)target;
					}
					if (disasm.Instruction.BranchType == JmpType)
						DDrPatch_MakeJump(&trampoline[pos], dest);
					else
						DDrPatch_MakeCall(&trampoline[pos], dest);
					pos += DETOUR_JMP_SIZE;
					break;
				}
				case RetType:
					memcpy(&trampoline[pos], (void*)disasm.EIP, len);
					pos += len;
					break;
				case JECXZ:
					// Only a rel8 form exists; from the heap it cannot reach the
					// original code, so only in-trampoline targets are supported.
					if (!targetInTrampoline) {
						STARTUPMESSAGE("Detour: Short branch out of trampoline area from address 0x%08x", from);
						goto fail;
					}
					memcpy(&trampoline[pos], (void*)disasm.EIP, len);
					pos += len;
					break;
				case JO:
				case JNO:
				case JC:
				case JNC:
				case JE:
				case JNE:
				case JNA:
				case JA:
				case JS:
				case JNS:
				case JP:
				case JNP:
				case JL:
				case JNL:
				case JNG:
				case JG:
					if (targetInTrampoline) {
						memcpy(&trampoline[pos], (void*)disasm.EIP, len);
						pos += len;
					} else {
						// Emit "j!cc +5; jmp target". The condition code is the low
						// nibble of both the short (7x) and near (0F 8x) encodings,
						// and flipping its bit 0 inverts the condition.
						trampoline[pos++] = (char)(0x70 | ((disasm.Instruction.Opcode & 0x0F) ^ 1));
						trampoline[pos++] = DETOUR_JMP_SIZE;
						DDrPatch_MakeJump(&trampoline[pos], (void*)target);
						pos += DETOUR_JMP_SIZE;
					}
					break;
				default:
					STARTUPMESSAGE("Detour: Unknown branch in trampoline area from address 0x%08x", from);
					goto fail;
			}
		} else {
			memcpy(&trampoline[pos], (void*)disasm.EIP, len);
			pos += len;
		}
		disasm.EIP += (UIntPtr)len;
	}

	if (branches > 1) {
		STARTUPMESSAGE("Detour: Too many branches in trampoline'd code from address 0x%08x: %d", from, branches);
		goto fail;
	}

	// Resume the original code right after the copied instructions.
	DDrPatch_MakeJump(&trampoline[pos], (void*)disasm.EIP);

	// Make the trampoline executable *before* touching `from`, so a failure
	// here leaves the original code intact.
	DWORD oldp;
	if (!VirtualProtect(trampoline, DETOUR_TRAMPOLINE_SIZE, PAGE_EXECUTE_READWRITE, &oldp)) {
		STARTUPMESSAGE("Detour: Could not mark page for trampoline as executable: from address 0x%08x", from);
		goto fail;
	}

	// The copied region may be longer than the jump; NOP all of it.
	DDrPatch_NOOP(from, disasm.EIP - start);
	DDrPatch_MakeJump(from, to);
	return trampoline;

fail:
	free(trampoline);
	return (void*)-1;
}

/*
 * Overwrite `length` bytes at `dest` with the bytes at `string`.
 *
 * The destination is made PAGE_EXECUTE_READWRITE for the copy and its
 * previous protection is reapplied afterwards. Returns false, copying
 * nothing, if the protection change fails.
 */
bool DDrPatch_String(char* dest, const unsigned char* string, int length)
{
	DWORD savedProtect;

	if (!VirtualProtect(dest, length, PAGE_EXECUTE_READWRITE, &savedProtect))
		return false;

	memcpy(dest, string, length);
	VirtualProtect(dest, length, savedProtect, &savedProtect);
	return true;
}

/*
 * Write the single byte `value` at `dest`.
 *
 * The byte is made PAGE_EXECUTE_READWRITE for the write and its previous
 * protection is reapplied afterwards. Returns false, writing nothing, if
 * the protection change fails.
 */
bool DDrPatch_Byte(char* dest, unsigned char value)
{
	DWORD savedProtect;

	if (!VirtualProtect(dest, 1, PAGE_EXECUTE_READWRITE, &savedProtect))
		return false;

	dest[0] = (char)value;
	VirtualProtect(dest, 1, savedProtect, &savedProtect);
	return true;
}

/*
 * Write the 4-byte value `value` at `dest`.
 *
 * The 4 bytes are made PAGE_EXECUTE_READWRITE for the write and their
 * previous protection is reapplied afterwards. Returns false, writing
 * nothing, if the protection change fails.
 */
bool DDrPatch_Int32(int* dest, unsigned int value)
{
	DWORD savedProtect;

	if (!VirtualProtect(dest, 4, PAGE_EXECUTE_READWRITE, &savedProtect))
		return false;

	// Byte copy: patch sites are not guaranteed to be 4-byte aligned.
	memcpy(dest, &value, 4);
	VirtualProtect(dest, 4, savedProtect, &savedProtect);
	return true;
}

/*
 * Write the 2-byte value `value` at `dest`.
 *
 * The 2 bytes are made PAGE_EXECUTE_READWRITE for the write and their
 * previous protection is reapplied afterwards. Returns false, writing
 * nothing, if the protection change fails.
 */
bool DDrPatch_Int16(short* dest, unsigned short value)
{
	DWORD savedProtect;

	if (!VirtualProtect(dest, 2, PAGE_EXECUTE_READWRITE, &savedProtect))
		return false;

	// Byte copy: patch sites are not guaranteed to be 2-byte aligned.
	memcpy(dest, &value, 2);
	VirtualProtect(dest, 2, savedProtect, &savedProtect);
	return true;
}

/*
 * Fill `length` bytes at `dest` with 0x90 (x86 single-byte NOP).
 *
 * The range is made PAGE_EXECUTE_READWRITE for the fill and its previous
 * protection is reapplied afterwards. Returns false, writing nothing, if
 * the protection change fails.
 */
bool DDrPatch_NOOP(char* dest, unsigned int length)
{
	enum { NOP_OPCODE = 0x90 };
	DWORD savedProtect;

	if (!VirtualProtect(dest, length, PAGE_EXECUTE_READWRITE, &savedProtect))
		return false;

	memset(dest, NOP_OPCODE, length);
	VirtualProtect(dest, length, savedProtect, &savedProtect);
	return true;
}

/*
 * Inject `length` bytes of machine code at `from`.
 *
 * The code is copied into a new heap buffer, followed by a jump to
 * `nextInst`. The buffer is made executable, and `from` is overwritten with
 * a jump to it.
 *
 * Returns the new code buffer, or (void*)-1 on failure. On failure the
 * buffer is freed and `from` is left unpatched.
 */
void* DDrPatch_ExecutableASM(char* from, char* nextInst, const unsigned char* code, int length)
{
	if (length < 0)
		return (void*)-1;

	// Injected code plus the trailing 5-byte `jmp nextInst`.
	const int size = length + 5;
	char* newCode = malloc(size);
	if (newCode == NULL)
		return (void*)-1;

	if (!DDrPatch_NOOP(newCode, size))
		goto fail;

	memcpy(newCode, code, length);
	if (!DDrPatch_MakeJump(&newCode[length], nextInst))
		goto fail;

	DWORD oldp;
	if (!VirtualProtect(newCode, size, PAGE_EXECUTE_READWRITE, &oldp)) {
		STARTUPMESSAGE("ExecASM: Could not mark page for new code as executable: from address 0x%08x", from);
		goto fail;
	}

	if (!DDrPatch_MakeJump(from, newCode))
		goto fail;

	return newCode;

fail:
	free(newCode);
	return (void*)-1;
}

void DDrPatch_PrintDisasm(void* addr, int instLimit, int sizeLimit)
{
	DISASM MyDisasm;
	int len = 0;
	int size = 0;
	int i = 0;

	memset(&MyDisasm, 0, sizeof(DISASM));

	MyDisasm.EIP = (UIntPtr) addr;

	STARTUPMESSAGE("", 0);
	STARTUPMESSAGE("Disassembly @ 0x%06x", addr);

	if (sizeLimit <= 0)
		sizeLimit = 20 * instLimit;

	while ((i < instLimit) && (size < sizeLimit)) {
		len = Disasm(&MyDisasm);
		if (len != UNKNOWN_OPCODE) {
			size += len;
			STARTUPMESSAGE("    %s, Opcode: 0x%x, len: %d, branch: %d, to: 0x%06x", MyDisasm.CompleteInstr, MyDisasm.Instruction.Opcode, len, MyDisasm.Instruction.BranchType, MyDisasm.Instruction.AddrValue);
			STARTUPMESSAGE("          Cat: 0x%04x, prefix count: %d", MyDisasm.Instruction.Category & 0xffff, MyDisasm.Prefix.Number );

			MyDisasm.EIP += (UIntPtr)len;
			i++;
		}
	};

	STARTUPMESSAGE("", 0);
}

