#include "pch.h"
#include "lkmUtils.h"

//..............................................................................

#ifdef CONFIG_X86
#	ifndef TDEVMON_X86_WPR_PTE

// returns the number of bytes of backup storage the CR0-based implementation
// needs; the address range is not used here because the size is constant

size_t
getWriteProtectionBackupSize(
	const void* begin,
	const void* end
) {
	enum {
		SavedRegisterCount = 2 // one slot for CR0, one for CR4
	};

	return SavedRegisterCount * sizeof(ulong);
}

// starting with Linux 5.3, the kernel pins critical bits of the control registers

// writes val into CR0 with a raw mov instruction; the "memory" clobber tells
// the compiler not to move memory accesses across this write

static
inline
void
write_cr0_direct(unsigned long val) {
	asm volatile("mov %0,%%cr0": "+r" (val) : : "memory");
}

// writes val into CR4 with a raw mov instruction; the "memory" clobber tells
// the compiler not to move memory accesses across this write

static
inline
void
write_cr4_direct(unsigned long val) {
	asm volatile("mov %0,%%cr4": "+r" (val) : : "memory");
}

static
inline
void
disableWriteProtection(
	const void* begin,
	const void* end,
	void* backup,
	size_t backupSize
) {
	ulong cr0;
	ulong cr4;

	ASSERT(backupSize >= getWriteProtectionBackupSize(begin, end));

	cr0 = read_cr0();
	cr4 = native_read_cr4();

#if (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 18, 0))
	if (cr4 & X86_CR4_CET)
		write_cr4_direct(cr4 & ~X86_CR4_CET);
#endif

	write_cr0_direct(cr0 & ~X86_CR0_WP);

	((ulong*)backup)[0] = cr0;
	((ulong*)backup)[1] = cr4;
}

// writes back the CR0 value saved in backup[0]; on kernels 5.18+ the CR4 value
// saved in backup[1] is written back as well, but only if it had CET set;
// begin and end are only used for the backup size check

static
inline
void
restoreWriteProtection(
	const void* begin,
	const void* end,
	const void* backup,
	size_t backupSize
) {
	const ulong* saved = (const ulong*)backup;
	ulong savedCr0;
	ulong savedCr4;

	ASSERT(backupSize >= getWriteProtectionBackupSize(begin, end));

	savedCr0 = saved[0];
	savedCr4 = saved[1];

	// CR0 goes first, CR4 second (the reverse of the order used when disabling)

	write_cr0_direct(savedCr0);

#if (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 18, 0))
	if (savedCr4 & X86_CR4_CET)
		write_cr4_direct(savedCr4);
#endif
}

#	else // TDEVMON_X86_WPR_PTE

// returns the number of bytes of backup storage needed for the range
// [begin0, end0): one pteval_t per PAGE_SIZE page touched by the range

size_t
getWriteProtectionBackupSize(
	const void* begin0,
	const void* end0
) {
	// begin0 is rounded down and end0 is rounded up to a page boundary
	size_t firstPage = ALIGN((size_t)begin0 - (PAGE_SIZE - 1), PAGE_SIZE);
	size_t pastLastPage = ALIGN((size_t)end0, PAGE_SIZE);

	// may be more slots than actually consumed, but that's ok
	size_t slotCount = (pastLastPage - firstPage) / PAGE_SIZE;

	ASSERT(slotCount);
	return slotCount * sizeof(pteval_t);
}

// walks the mappings covering [begin0, end0) and either makes each of them
// writable (saving the original entry value into backup) or writes the saved
// entry values back; one backup slot is consumed per visited mapping

static
void
disableWriteProtectionImpl(
	bool isApply, // 'true' to apply (and back-up), 'false' to restore from back-up
	const void* begin0,
	const void* end0,
	pteval_t* backup
) {
	// begin0 is rounded down and end0 is rounded up to a page boundary
	size_t begin = ALIGN((size_t)begin0 - (PAGE_SIZE - 1), PAGE_SIZE);
	size_t end = ALIGN((size_t)end0, PAGE_SIZE);
	size_t addr;
	pte_t* pte;
	pteval_t pteVal;
	unsigned int level;
	size_t pageSize;

	for (addr = begin; addr < end; backup++) {
		// NOTE(review): the result of lookup_address() is dereferenced below
		// without a NULL check -- confirm the range is always mapped
		pte = lookup_address(addr, &level);
		pageSize = page_level_size(level);

		if (isApply) {
			// set the RW bit and remember the original entry value
			pteVal = pte_val(*pte);
			set_pte(pte, __pte(pteVal | _PAGE_RW));
			*backup = pteVal;
		} else {
			// put the saved entry value back as is
			pteVal = *backup;
			set_pte(pte, __pte(pteVal));
		}

		// advance to the start of the next mapping at this mapping's page size
		ASSERT(pageSize >= PAGE_SIZE);
		addr = ALIGN(addr + 1, pageSize);
	}

	__flush_tlb_all();
}

static
inline
void
disableWriteProtection(
	const void* begin,
	const void* end,
	void* backup,
	size_t backupSize
) {
	ASSERT(backupSize >= getWriteProtectionBackupSize(begin, end));
	disableWriteProtectionImpl(true, begin, end, (pmdval_t*)backup);
}

static
inline
void
restoreWriteProtection(
	const void* begin,
	const void* end,
	const void* backup,
	size_t backupSize
) {
	ASSERT(backupSize >= getWriteProtectionBackupSize(begin, end));
	disableWriteProtectionImpl(false, begin, end, (pmdval_t*)backup);
}

#	endif // TDEVMON_X86_WPR_PTE
#elif (defined CONFIG_ARM)

// section descriptor bits used by disableWriteProtectionImpl below:
// PmdFlag_WpMask is cleared from the entry and PmdFlag_WpDisabled is then set;
// the actual bits depend on LPAE and on the kernel version

enum {
#ifdef CONFIG_ARM_LPAE
#	if (LINUX_VERSION_CODE >= KERNEL_VERSION(4, 14, 3))
	PmdFlag_WpMask     = L_PMD_SECT_RDONLY | PMD_SECT_AP2,
#	else
	PmdFlag_WpMask     = L_PMD_SECT_RDONLY,
#	endif
	PmdFlag_WpDisabled = 0, // with LPAE, clearing the mask bits is all that is done
#else
	PmdFlag_WpMask     = PMD_SECT_APX | PMD_SECT_AP_WRITE,
	PmdFlag_WpDisabled = PMD_SECT_AP_WRITE,
#endif
};

// goes over [begin0, end0) in SECTION_SIZE steps and either rewrites the
// write-protection bits of each PMD entry (saving the original value into
// backup) or writes the saved values back; one backup slot per section

static
void
disableWriteProtectionImpl(
	bool isApply, // 'true' to apply (and back-up), 'false' to restore from back-up
	const void* begin0,
	const void* end0,
	pmdval_t* backup
) {
	// begin0 is rounded down and end0 is rounded up to a section boundary
	size_t begin = ALIGN((size_t)begin0 - (SECTION_SIZE - 1), SECTION_SIZE);
	size_t end = ALIGN((size_t)end0, SECTION_SIZE);
	size_t addr;

	pgd_t* pgd;
#if (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 10, 0))
	p4d_t* p4d;
#endif
	pud_t* pud;
	pmd_t* pmd;
	size_t pmdIdx = 0;
	pmdval_t pmdVal;

	for (addr = begin; addr < end; addr += (size_t)SECTION_SIZE, backup++) {
		// walk down from the PGD of the current task's active mm
		pgd = pgd_offset(current->active_mm, addr);
#if (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 10, 0))
		p4d = p4d_offset(pgd, addr);
		pud = pud_offset(p4d, addr);
#else
		pud = pud_offset(pgd, addr);
#endif
		pmd = pmd_offset(pud, addr);

#ifndef CONFIG_ARM_LPAE
		// without LPAE, index 1 is used when the SECTION_SIZE bit of addr is set
		pmdIdx = (addr & SECTION_SIZE) ? 1 : 0;
#endif

		if (isApply) {
			// replace the write-protection bits and remember the original value
			pmdVal = pmd_val(pmd[pmdIdx]);
			pmd[pmdIdx] = __pmd((pmdVal & ~PmdFlag_WpMask) | PmdFlag_WpDisabled);
			*backup = pmdVal;
		} else {
			// put the saved value back as is
			pmdVal = *backup;
			pmd[pmdIdx] = __pmd(pmdVal);
		}

		flush_pmd_entry(pmd);
	}

	local_flush_tlb_all();
}

static
inline
void
disableWriteProtection(
	const void* begin,
	const void* end,
	void* backup,
	size_t backupSize
) {
	ASSERT(backupSize >= getWriteProtectionBackupSize(begin, end));
	disableWriteProtectionImpl(true, begin, end, (pmdval_t*)backup);
}

static
inline
void
restoreWriteProtection(
	const void* begin,
	const void* end,
	const void* backup,
	size_t backupSize
) {
	ASSERT(backupSize >= getWriteProtectionBackupSize(begin, end));
	disableWriteProtectionImpl(false, begin, end, (pmdval_t*)backup);
}

// returns the number of bytes of backup storage needed for the range
// [begin0, end0): one pmdval_t per SECTION_SIZE section touched by the range

size_t
getWriteProtectionBackupSize(
	const void* begin0,
	const void* end0
) {
	// begin0 is rounded down and end0 is rounded up to a section boundary
	size_t firstSection = ALIGN((size_t)begin0 - (SECTION_SIZE - 1), SECTION_SIZE);
	size_t pastLastSection = ALIGN((size_t)end0, SECTION_SIZE);
	size_t slotCount = (pastLastSection - firstSection) / SECTION_SIZE;

	ASSERT(slotCount);

	return slotCount * sizeof(pmdval_t);
}

#elif (defined CONFIG_ARM64)
#	ifdef CONFIG_ARM64_PA_BITS_52
#		error "CONFIG_ARM64_PA_BITS_52 is not supported yet"
#	endif

// returns the current value of the TTBR1_EL1 system register (read with mrs)

static
inline
uint64_t
read_ttbr1_el1(void) {
    uint64_t val;
    asm volatile("mrs %0, ttbr1_el1" : "=r" (val));
    return val;
}

// walks the page tables rooted at TTBR1_EL1 for [begin0, end0) and either
// makes each visited PMD leaf or PTE writable (saving the original entry value
// into backup) or write-protects again those entries whose saved value had
// PTE_RDONLY set; one backup slot is consumed per visited entry

static
void
disableWriteProtectionImpl(
	bool isApply, // 'true' to apply (and back-up), 'false' to restore from back-up
	const void* begin0,
	const void* end0,
	uint64_t* backup
) {
	// begin0 is rounded down to a page boundary; end0 is advanced by a full
	// page before masking
	size_t begin = (size_t)begin0 & PAGE_MASK;
	size_t end = ((size_t)end0 + PAGE_SIZE) & PAGE_MASK;
	size_t addr = begin;
	uint64_t ttbr1_el1 = read_ttbr1_el1(); // the TTBR1_EL1 register holds kernel PGD
	pgd_t* pgd0 = (pgd_t*)phys_to_virt(ttbr1_el1 & PAGE_MASK); // clear CnP and other flags
	pgd_t* pgd;
#if (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 10, 0))
	p4d_t* p4d;
#endif
	pud_t* pud;
	pmd_t* pmd;
	pmd_t pmdv;
	pte_t* pte;
	pte_t ptev;

	while (addr < end) {
		pgd = pgd_offset_pgd(pgd0, addr);
#if (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 10, 0))
		p4d = p4d_offset(pgd, addr);
		pud = pud_offset(p4d, addr);
#else
		pud = pud_offset(pgd, addr);
#endif
		pmd = pmd_offset(pud, addr);
		pmdv = READ_ONCE(*pmd);

		if (pmd_leaf(pmdv)) {
			// the PMD entry itself is the mapping: patch it and skip to the
			// next PMD_SIZE boundary
			if (isApply) {
				*backup = pmd_val(pmdv);
				pmdv = pmd_mkwrite_novma(pmdv);
			} else if (*backup & PTE_RDONLY) {
				// only entries that were read-only before get write-protected again
				pmdv = pmd_wrprotect(pmdv);
			}

			set_pmd(pmd, pmdv);
			addr = (addr + PMD_SIZE) & PMD_MASK;
		} else {
			// the PMD entry points to a PTE table: patch the PTE and move on
			// to the next page
			pte = pte_offset_kernel(pmd, addr);
			ptev = READ_ONCE(*pte);

			if (isApply) {
				*backup = pte_val(ptev);
				ptev = pte_mkwrite_novma(ptev);
			} else if (*backup & PTE_RDONLY) {
				// only entries that were read-only before get write-protected again
				ptev = pte_wrprotect(ptev);
			}

			set_pte(pte, ptev);
			addr = (addr + PAGE_SIZE) & PAGE_MASK;
		}

		backup++;
	}

	local_flush_tlb_all();
}

// makes the mappings covering [begin, end) writable via
// disableWriteProtectionImpl and saves the original entry values into backup,
// which must hold at least getWriteProtectionBackupSize(begin, end) bytes

static
inline
void
disableWriteProtection(
	const void* begin,
	const void* end,
	void* backup,
	size_t backupSize
) {
	uint64_t* slots = (uint64_t*)backup;

	ASSERT(backupSize >= getWriteProtectionBackupSize(begin, end));
	disableWriteProtectionImpl(true, begin, end, slots);
}

// restores write protection for the mappings covering [begin, end) from the
// entry values saved in backup; backup must hold at least
// getWriteProtectionBackupSize(begin, end) bytes

static
inline
void
restoreWriteProtection(
	const void* begin,
	const void* end,
	const void* backup,
	size_t backupSize
) {
	// the shared implementation takes a non-const pointer
	uint64_t* slots = (uint64_t*)backup;

	ASSERT(backupSize >= getWriteProtectionBackupSize(begin, end));
	disableWriteProtectionImpl(false, begin, end, slots);
}

// returns the number of bytes of backup storage needed for the range
// [begin0, end0): one uint64_t per PAGE_SIZE page

size_t
getWriteProtectionBackupSize(
	const void* begin0,
	const void* end0
) {
	// begin0 is rounded down to a page boundary; end0 is advanced by a full
	// page before masking, so a page-aligned end0 still adds one more page
	size_t firstPage = (size_t)begin0 & PAGE_MASK;
	size_t pastLastPage = ((size_t)end0 + PAGE_SIZE) & PAGE_MASK;
	size_t slotCount = (pastLastPage - firstPage) / PAGE_SIZE;

	ASSERT(slotCount);
	return slotCount * sizeof(uint64_t);
}

#endif

// disables preemption and then disables write protection for [begin, end);
// backup and backupSize are passed through to disableWriteProtection

void
disablePreemptionAndWriteProtection(
	const void* begin,
	const void* end,
	void* backup,
	size_t backupSize
) {
	preempt_disable(); // barrier is created
	disableWriteProtection(begin, end, backup, backupSize);
}

// restores write protection for [begin, end) from backup and then re-enables
// preemption, i.e. the reverse order of disablePreemptionAndWriteProtection

void
restoreWriteProtectionAndPreemption(
	const void* begin,
	const void* end,
	const void* backup,
	size_t backupSize
) {
	restoreWriteProtection(begin, end, backup, backupSize);
	preempt_enable(); // barrier is created
}

// . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .

// builds the textual form of path with d_path() and returns a kmalloc'ed
// copy of it (the caller is responsible for kfree); on failure returns the
// ERR_PTR produced by d_path() or ERR_PTR(-ENOMEM)

char*
createPathString(const struct path* path) {
	char pathBuffer[128]; // more than enough
	char* pathString;
	char* copy;
	size_t copySize;

	pathString = d_path(path, pathBuffer, sizeof(pathBuffer) - 1);
	if (IS_ERR(pathString))
		return pathString;

	pathBuffer[sizeof(pathBuffer) - 1] = 0;

	copySize = strlen(pathString) + 1; // including the terminating zero
	copy = kmalloc(copySize, GFP_KERNEL);
	if (copy) {
		memcpy(copy, pathString, copySize);
		return copy;
	}

	return ERR_PTR(-ENOMEM);
}

// copies a dm_String from user space into a freshly kmalloc'ed, always
// zero-terminated kernel buffer (the caller is responsible for kfree);
// the characters are read from directly after the dm_String header
//
// returns ERR_PTR(-EFAULT) if user memory can't be read, ERR_PTR(-EINVAL) if
// the user-supplied length is too large to add a terminator, and
// ERR_PTR(-ENOMEM) if the allocation fails

char*
copyStringFromUser(const dm_String __user* string_u) {
	int result;
	dm_String string;
	size_t length;
	char* p;

	result = copy_from_user(&string, string_u, sizeof(dm_String));
	if (result != 0)
		return ERR_PTR(-EFAULT);

	// m_length comes from user space: do the size arithmetic in size_t and
	// reject a value for which 'length + 1' would wrap around to zero
	// (a zero-sized allocation followed by an out-of-bounds write)

	length = string.m_length;
	if (length + 1 == 0)
		return ERR_PTR(-EINVAL);

	p = kmalloc(length + 1, GFP_KERNEL);
	if (!p)
		return ERR_PTR(-ENOMEM);

	result = copy_from_user(p, string_u + 1, length);
	if (result != 0) {
		kfree(p);
		return ERR_PTR(-EFAULT);
	}

	p[length] = 0; // ensure zero-terminated
	return p;
}

// copies the zero-terminated string p into the user-space dm_String at
// string_u: the header is read to learn the user buffer size, then written
// back with m_length updated, followed by the characters and the terminator
//
// if the user buffer is too small, only the header is written back, with
// m_bufferSize set to the required size, and -ENOBUFS is returned;
// returns 0 on success and -EFAULT if user memory can't be accessed

int
copyStringToUser(
	dm_String __user* string_u,
	const char* p
) {
	int status;
	dm_String header;
	size_t payloadSize;
	size_t requiredSize;
	bool isTooSmall;

	status = copy_from_user(&header, string_u, sizeof(dm_String));
	if (status != 0)
		return -EFAULT;

	header.m_length = strlen(p);

	payloadSize = header.m_length + 1; // including the terminating zero
	requiredSize = sizeof(dm_String) + payloadSize;

	isTooSmall = header.m_bufferSize < requiredSize;
	if (isTooSmall)
		header.m_bufferSize = requiredSize; // let the caller know how much is needed

	// the header goes back to the user in either case

	status = copy_to_user(string_u, &header, sizeof(dm_String));
	if (status != 0)
		return -EFAULT;

	if (isTooSmall)
		return -ENOBUFS;

	status = copy_to_user(string_u + 1, p, payloadSize);
	return status != 0 ? -EFAULT : 0;
}

uint64_t
getTimestamp(void) {
	enum {
		// epoch difference between Unix time (1 Jan 1970 00:00) and Windows time (1 Jan 1601 00:00)
		EpochDiff = 11644473600LL
	};

#if (LINUX_VERSION_CODE < KERNEL_VERSION(4, 20, 0))
	struct timespec tspec;
	ktime_get_real_ts(&tspec);
#else
	struct timespec64 tspec;
	ktime_get_real_ts64(&tspec);
#endif

	return (uint64_t)(tspec.tv_sec + EpochDiff) * 10000000 + tspec.tv_nsec / 100;
}

// returns the module that owns the given file: the owner of its file
// operations if set, otherwise (for character devices only) the owner of the
// inode's cdev; returns NULL and logs a warning if no owner can be found

struct module*
getOwnerModule(struct file* filp) {
	if (filp->f_op && filp->f_op->owner)
		return filp->f_op->owner;

	if (!filp->f_inode) // bummer, nowhere to go from here. but wait, is it even possible?
		return NULL;

	if (!S_ISCHR(filp->f_inode->i_mode)) {
		printk(KERN_WARNING "tdevmon: couldn't find owner module of non-char device (%x)\n", filp->f_inode->i_mode);
		return NULL;
	}

	if (filp->f_inode->i_cdev &&
		filp->f_inode->i_cdev->owner)
		return filp->f_inode->i_cdev->owner;

	// we get here when either i_cdev or i_cdev->owner is NULL, so i_cdev must
	// not be dereferenced unconditionally while building the message

	printk(
		KERN_WARNING
		"tdevmon: couldn't find owner module (cdev: %p, cdev->owner: %p); "
		"TODO: lookup via f_inode->i_rdev (%x))\n",
		filp->f_inode->i_cdev,
		filp->f_inode->i_cdev ? filp->f_inode->i_cdev->owner : NULL,
		filp->f_inode->i_rdev
	);

	return NULL;
}

//..............................................................................
