#include "stdafx.h"
#include "regscan.h"

// Upper bound of the progress bar range.  The three root scans in
// scan_registry() are given sub-ranges expressed as fractions of this value.
#define MAX_PROGRESS 500
// Access rights requested when scan() opens a subkey: read values, enumerate
// subkeys, and (via KEY_WOW64_64KEY) ask for the 64-bit registry view.
// NOTE(review): the Windows SDK also declares a type named REGSAM; this macro
// hides it for the rest of the file -- consider renaming.
#define REGSAM (KEY_QUERY_VALUE | KEY_ENUMERATE_SUB_KEYS | KEY_WOW64_64KEY)

// UI management
// Set to true by cancel_registry_scan(); polled by scan() and scan_registry()
// to stop the scan early.
static volatile bool g_cancelled = false;
// Dialog that owns the progress bar and buttons; set by
// initialize_registry_scanner().
static HWND g_hDlg;
// Progress value most recently recorded by update_progress(); reset to 0 by
// initialize_registry_scanner().
static unsigned g_current_progress;

// Reporting
// Report output stream; opened and closed by scan_registry(), written to by
// xprintf().  NULL while no report file is open.
static FILE *g_ofp;
// "C" locale created by start_registry_scan() and used for all formatting in
// this file.
static _locale_t g_locale;
// Full path of the report file; assigned by scan_registry() and exposed
// through get_report_filename().
static std::wstring report_filename;

// Sets of specific keys: one for interesting keys (i.e., to be reported) and the
// other for ignored keys (i.e., to be excluded from the report).  Note that, as far
// as I know, the Windows comparison algorithm for registry keys is simply described
// as "case insensitive", but internally it does a lot of undocumented and
// non-trivial processing.  I can't write a <less> operator that behaves identically
// to the Windows Reg* APIs.  The following code seems to give a reasonable estimate.

// Strict-weak-ordering predicate for registry key names: case-insensitive
// comparison under the invariant locale.  This only approximates how the
// registry itself compares names.
//
// std::binary_function was removed in C++17, so the nested typedefs it used to
// supply are declared directly; this keeps the functor usable with both old
// adaptors and current compilers.
struct KeyLess
{
	typedef std::wstring first_argument_type;
	typedef std::wstring second_argument_type;
	typedef bool result_type;

	// Returns true when x orders strictly before y.  If CompareStringW fails
	// (returns 0) the result is false.
	bool operator()(std::wstring const &x, std::wstring const &y) const
	{
		// CompareStringW takes the lengths as int; make the narrowing explicit.
		int const order = CompareStringW(LOCALE_INVARIANT, NORM_IGNORECASE,
			x.data(), static_cast<int>(x.size()),
			y.data(), static_cast<int>(y.size()));
		return CSTR_LESS_THAN == order;
	}
};

class KeySet : public std::set<std::wstring, KeyLess>
{
public:
	bool contains(std::wstring const &key) const { return find(key) != end(); }
};

// Names that switch reporting on: scan() reports a value whose name is in this
// set, and reports everything below a subkey whose name is in this set.
// Filled by scan_registry() and the add_interest_from_* helpers.
static KeySet g_interested_keys;
// Subkey names that scan() never descends into; filled by scan_registry().
static KeySet g_ignore_keys;

// Empty tag exception thrown by fatal_error() after the user has been told
// about the failure; start_registry_scan() catches it silently.
class Aborted {};

// Formats `fmt` with the arguments in `ap` (using g_locale) and shows the
// result in a warning message box owned by the dialog.  The box caption is
// copied from the dialog's own title.
static void msgvprintf(LPCWSTR fmt, va_list ap)
{
	enum { kBufferChars = 1024 };

	// Caption: whatever the dialog currently shows in its title bar.
	WCHAR title[kBufferChars];
	title[0] = L'\0';
	GetWindowTextW(g_hDlg, title, kBufferChars - 1);

	// Body: the formatted message.
	WCHAR message[kBufferChars];
	message[0] = L'\0';
	_vswprintf_l(message, kBufferChars - 1, fmt, g_locale, ap);

	MessageBoxW(g_hDlg, message, title, MB_OK | MB_ICONWARNING);
}

// printf-style front end for msgvprintf(): shows the formatted text in a
// warning message box.
static void msgprintf(LPCWSTR fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	msgvprintf(fmt, args);
	va_end(args);
}

// Reports an unrecoverable problem: shows the formatted message, asks the
// dialog to close, then unwinds by throwing Aborted.  Never returns normally.
static void fatal_error(LPCWSTR fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	msgvprintf(fmt, args);
	va_end(args);

	// Request dialog shutdown before unwinding the scanner.
	PostMessage(g_hDlg, WM_CLOSE, 0, 0);

	throw Aborted();
}

// printf-style write to the report file (g_ofp), formatted with g_locale.
// The report file must already be open.
static void xprintf(LPCWSTR fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	_vfwprintf_l(g_ofp, fmt, g_locale, args);
	va_end(args);
}

// Writes one report line for a registry value:
//     <keypath>\<name> = <type tag> <text>
// The text is printed only for the string types (REG_SZ, REG_EXPAND_SZ,
// REG_MULTI_SZ); for every other type only the tag appears.  Types that are
// not in the table are tagged "?".
static void report(LPCWSTR keypath, LPCWSTR name, DWORD type, LPCWSTR data)
{
	struct TypeInfo
	{
		DWORD type;
		char const *tag;
		bool has_text;
	};
	static TypeInfo const known_types[] = {
		{ REG_NONE,      "(NONE)",     false },
		{ REG_LINK,      "(LINK)",     false },
		{ REG_DWORD,     "(DWORD)",    false },
		{ REG_QWORD,     "(QWORD)",    false },
		{ REG_BINARY,    "(BINARY)",   false },
		{ REG_SZ,        "",           true  },
		{ REG_EXPAND_SZ, "(EXP_SZ)",   true  },
		{ REG_MULTI_SZ,  "(MULTI_SZ)", true  },
	};

	char const *tag = "?";
	LPCWSTR text = L"";
	for (size_t k = 0; k < sizeof(known_types) / sizeof(known_types[0]); ++k) {
		if (known_types[k].type != type) continue;
		tag = known_types[k].tag;
		if (known_types[k].has_text) text = data;
		break;
	}

	xprintf(L"%s\\%s = %hs %s\n", keypath, name, tag, text);
}

// Writes a report line noting that <keypath>\<subkey>\ was deliberately
// skipped.
static void report_ignore(LPCWSTR keypath, LPCWSTR subkey)
{
	xprintf(
		L"%s\\%s\\ : (ignored)\n",
		keypath,
		subkey);
}

// Advances the progress bar to `new_progress`.  Progress only moves forward:
// a value that is not greater than the recorded position is ignored.
//
// The new value is the one posted to the control (the previous code posted
// the old position, so the bar always lagged one update behind).
static void update_progress(unsigned new_progress)
{
	if (new_progress <= g_current_progress) return;
	g_current_progress = new_progress;
	PostMessage(GetDlgItem(g_hDlg, IDC_PROGRESS), PBM_SETPOS, new_progress, 0);
}

// Recursively walks the registry subtree rooted at `key` and writes matching
// entries to the report.
//
//   key     - open key to enumerate (not closed by this function).
//   keypath - display path of `key`, used as the prefix of report lines.
//   hit     - true when an ancestor (or this key) matched g_interested_keys;
//             every value below is then reported.  When false, only values
//             whose name is in g_interested_keys are reported.
//   rmin,
//   rmax    - slice of the progress range assigned to this subtree; it is
//             divided evenly among the subkeys that are descended into.
//
// Subkeys whose name is in g_ignore_keys are never opened; they are mentioned
// in the report only when `hit` is true.  The function returns early, without
// reporting anything further, as soon as g_cancelled becomes true.  Keys or
// values that cannot be queried are silently skipped.
static void scan(HKEY key, std::wstring const &keypath, bool hit, int rmin, int rmax)
{
	if (g_cancelled) return;

	LONG result;

	DWORD subkeys, max_subkey_len, values, max_value_name_len, max_value_len;
	result = RegQueryInfoKeyW(key, NULL, NULL, NULL, &subkeys, &max_subkey_len, NULL, &values, &max_value_name_len, &max_value_len, NULL, NULL);
	if (ERROR_SUCCESS != result) return; // Should report something.

	// Enumerate all values.
	{
		// We need WCHAR compatible memory alignment on the backing array
		// of value, although all Reg* functions consider the value to be an
		// array of BYTE.  One extra WCHAR is reserved for a terminator.
		std::vector<WCHAR> name(max_value_name_len + 1);
		std::vector<WCHAR> value((max_value_len + sizeof(WCHAR) - 1) / sizeof(WCHAR) + 1);

		// They say we need to keep enumerating until we get
		// ERROR_NO_MORE_ITEMS, so we need to repeat one more than the
		// actual values.
		for (DWORD i = 0; i <= values; i++) {
			if (g_cancelled) return;

			DWORD type;
			DWORD cch_value_name = static_cast<DWORD>(name.size());
			DWORD cb_value = static_cast<DWORD>(value.size() * sizeof(WCHAR));
			result = RegEnumValueW(key, i, &name.front(), &cch_value_name, NULL, &type, reinterpret_cast<LPBYTE>(&value.front()), &cb_value);
			if (ERROR_NO_MORE_ITEMS == result) break;
			if (ERROR_SUCCESS != result) {
				// I believe this is a serious failure and we should report the case.  FIXME.
				continue;
			}
			if (hit || g_interested_keys.contains(&name.front())) {
				// Registry strings are not guaranteed to be stored with a
				// terminator, and the buffer still holds bytes from earlier
				// (possibly longer) values.  Terminate right after the data
				// that was actually returned; a trailing partial WCHAR is
				// dropped.  The clamp covers a value that grew after
				// RegQueryInfoKeyW was called.
				std::vector<WCHAR>::size_type cch_value = cb_value / sizeof(WCHAR);
				if (cch_value >= value.size()) cch_value = value.size() - 1;
				value[cch_value] = L'\0';
				report(keypath.c_str(), &name.front(), type, &value.front());
			}
		}
	}

	// Enumerate all subkey names.
	{
		// They say we need to enumerate all subkeys before opening any of
		// the subkeys.
		std::vector<std::wstring> subkey_names;
		std::vector<WCHAR> name(max_subkey_len + 1);

		// They say we need to keep enumerating until we get
		// ERROR_NO_MORE_ITEMS, so we need to repeat one more than the
		// actual subkeys.
		for (DWORD i = 0; i <= subkeys; i++) {
			if (g_cancelled) return;

			DWORD size = static_cast<DWORD>(name.size());
			result = RegEnumKeyExW(key, i, &name.front(), &size, NULL, NULL, NULL, NULL);
			if (ERROR_NO_MORE_ITEMS == result) break;
			if (ERROR_SUCCESS != result) {
				// I believe this is a serious failure and we should report the case.  FIXME.
				continue;
			}
			std::wstring const name_str(&name.front());
			if (g_ignore_keys.contains(name_str)) {
				if (hit) report_ignore(keypath.c_str(), &name.front());
			} else {
				subkey_names.push_back(name_str);
			}
		}

		// Build "<keypath>\" once; each iteration below truncates back to
		// this prefix and appends the current subkey name.
		std::wstring newpath;
		newpath.reserve(keypath.size() + 1 + max_subkey_len + 1); // +1 for L'\\' and + 1 for L'\0'.
		newpath.assign(keypath);
		newpath.append(L"\\");
		std::wstring::size_type const newpath_base_size = newpath.size();

		unsigned pmin = rmin;
		for (std::vector<std::wstring>::size_type i = 0; i < subkey_names.size(); i++) {
			if (g_cancelled) return;

			HKEY subkey;
			// ulOptions is a DWORD and must be zero.
			result = RegOpenKeyExW(key, subkey_names[i].c_str(), 0, REGSAM, &subkey);
			if (ERROR_SUCCESS != result) {
				// I believe this is a serious failure and we should report the case.  FIXME.
				continue;
			}

			newpath.resize(newpath_base_size);
			newpath.append(subkey_names[i]);

			// Give the i-th subkey the i-th equal share of [rmin, rmax].
			update_progress(pmin);
			unsigned pmax = static_cast<unsigned>((rmax - rmin) * (i + 1) / subkey_names.size() + rmin);
			scan(subkey, newpath, hit || g_interested_keys.contains(subkey_names[i]), pmin, pmax);
			pmin = pmax;

			RegCloseKey(subkey);
		}
	}
}

// Binds the scanner to its dialog and resets the progress display: the
// recorded progress goes back to 0, the bar range becomes [0, MAX_PROGRESS]
// and the bar position is set to 0.
void initialize_registry_scanner(HWND hDlg)
{
	g_current_progress = 0;
	g_hDlg = hDlg;

	// Range first, then position.
	SendDlgItemMessageW(
		g_hDlg, IDC_PROGRESS, PBM_SETRANGE, 0, MAKELPARAM(0, MAX_PROGRESS));
	SendDlgItemMessageW(
		g_hDlg, IDC_PROGRESS, PBM_SETPOS, 0, 0);
}

// Returns nonzero when IsWow64Process reports that the current process runs
// under WOW64.  The function is looked up in kernel32 at run time; if it is
// not exported, or the call fails, the answer is zero.
static BOOL is_wow64()
{
	typedef BOOL (WINAPI *IsWow64ProcessFn)(HANDLE, PBOOL);

	IsWow64ProcessFn const query =
		(IsWow64ProcessFn)GetProcAddress(GetModuleHandleA("kernel32"), "IsWow64Process");
	if (!query) return false;

	BOOL running_under_wow64;
	if (!query(GetCurrentProcess(), &running_under_wow64)) return false;

	return running_under_wow64 != 0;
}

// Adds `progid` to g_interested_keys and follows its registration under
// HKEY_CLASSES_ROOT:
//   - the default value of the "CLSID" subkey, if non-empty, is added too;
//   - the default value of the "CurVer" subkey, if non-empty, names another
//     ProgID which is processed the same way.
// The walk stops on a NULL/empty name, on a name that is already in the set
// (which also guards against CurVer cycles), when the ProgID key cannot be
// opened, or when there is no CurVer to follow.
static void add_interest_from_progid(LPCWSTR progid)
{
	enum { kDataChars = 256 };

	// Two buffers are used alternately: `progid` may point into one of them
	// while the next round of queries writes into the other.
	WCHAR data1[kDataChars], data2[kDataChars];
	LPWSTR p1 = data1, p2 = data2;
	LONG size;
	for (;;) {
		if (!progid || !progid[0]) return;
		// insert() reports whether the name was new; an existing name means
		// this ProgID has been handled already.
		if (!g_interested_keys.insert(progid).second) return;

		HKEY key;
		if (ERROR_SUCCESS != RegOpenKeyW(HKEY_CLASSES_ROOT, progid, &key)) return;

		ZeroMemory(p1, sizeof(data1));
		size = sizeof(data1);
		if (ERROR_SUCCESS == RegQueryValueW(key, L"CLSID", p1, &size)) {
			// Guarantee termination even if the stored string fills the
			// whole buffer without a terminator.
			p1[kDataChars - 1] = L'\0';
			if (p1[0]) g_interested_keys.insert(p1);
		}

		bool follow_curver = false;
		ZeroMemory(p1, sizeof(data1));
		size = sizeof(data1);
		if (ERROR_SUCCESS == RegQueryValueW(key, L"CurVer", p1, &size)) {
			p1[kDataChars - 1] = L'\0';
			if (p1[0]) {
				// Keep the CurVer text alive as the next ProgID and switch
				// the scratch buffer to the other array.
				LPWSTR p = p1;
				p1 = p2;
				p2 = p;
				progid = p;
				follow_curver = true;
			}
		}
		RegCloseKey(key);

		// Without a CurVer there is nothing more to follow.
		if (!follow_curver) return;
	}
}

// Adds the file extension `ext` (e.g. L".mobi") to g_interested_keys, then
// reads the default value of HKEY_CLASSES_ROOT\<ext> and hands it to
// add_interest_from_progid().  A NULL or empty `ext` is ignored; if the
// default value cannot be read only the extension itself is added.
static void add_interest_from_extension(LPCWSTR ext)
{
	if (!ext || !ext[0]) return;
	g_interested_keys.insert(ext);

	enum { kDataChars = 256 };
	WCHAR data[kDataChars];
	ZeroMemory(data, sizeof(data));
	LONG size = sizeof(data);
	if (ERROR_SUCCESS != RegQueryValueW(HKEY_CLASSES_ROOT, ext, data, &size)) return;

	// Guarantee termination even if the stored string fills the whole
	// buffer without a terminator.
	data[kDataChars - 1] = L'\0';
	add_interest_from_progid(data);
}

static void scan_registry()
{
	static LPCWSTR const extensions[] = {
		L".azw",
		L".mobi",
		L".prc",
		NULL
	};
	static LPCWSTR const progids[] = {
		NULL
	};
	static LPCWSTR const guids[] = {
		L"{ba91bbaf-9243-49bd-b0da-97bfc25eb976}",	// CLSID_MobiHandler
		NULL
	};
	static LPCWSTR const ignores[] = {
		L"shell",
		L"RecentDocs",
		NULL
	};

	for (LPCWSTR const *p = extensions; *p; p++) {
		add_interest_from_extension(*p);
	}
	for (LPCWSTR const *p = progids; *p; p++) {
		add_interest_from_progid(*p);
	}
	for (LPCWSTR const *p = guids; *p; p++) {
		g_interested_keys.insert(*p);
	}
	for (LPCWSTR const *p = ignores; *p; p++) {
		g_ignore_keys.insert(*p);
	}

	{
		WCHAR tmpdir[MAX_PATH + 1];
		DWORD ret = GetTempPathW(MAX_PATH, tmpdir);
		if (ret <= 0 || ret >= MAX_PATH) {
			fatal_error(L"Temporary folder is not available.");
		}
		report_filename.assign(tmpdir).append(L"report.txt");
		g_ofp = _wfopen(report_filename.c_str(), L"w,ccs=UTF-8");
		if (!g_ofp) {
			fatal_error(L"Cannot create the report file: %s", report_filename.c_str());
		}
	}

	OSVERSIONINFOW osv;
	ZeroMemory(&osv, sizeof(osv));
	osv.dwOSVersionInfoSize = sizeof(osv);
	GetVersionExW(&osv);
	xprintf(L"Windows %d bit Kernel %d.%d.%d (%ls)\n", (is_wow64() ? 64 : 32), osv.dwMajorVersion, osv.dwMinorVersion, osv.dwBuildNumber, osv.szCSDVersion);

	SetDlgItemTextW(g_hDlg, IDC_SYSLINK, L"Scanning Registry...");

	PostMessage(GetDlgItem(g_hDlg, IDC_PROGRESS), PBM_SETSTATE, PBST_NORMAL, 0);

	if (!g_cancelled) scan(HKEY_CLASSES_ROOT,  L"HKCR", false, MAX_PROGRESS * 0 / 4, MAX_PROGRESS * 1 / 4); 
	if (!g_cancelled) scan(HKEY_CURRENT_USER,  L"HKCU", false, MAX_PROGRESS * 1 / 4, MAX_PROGRESS * 2 / 4);
	if (!g_cancelled) scan(HKEY_LOCAL_MACHINE, L"HKLM", false, MAX_PROGRESS * 2 / 4, MAX_PROGRESS * 4 / 4);

	if (!g_cancelled)xprintf(L"Done.\n");
	fclose(g_ofp);
	g_ofp = NULL;

	ShowWindowAsync(GetDlgItem(g_hDlg, IDOK), SW_SHOW);
	ShowWindowAsync(GetDlgItem(g_hDlg, IDCANCEL), SW_HIDE);

	if (g_cancelled) {
		PostMessage(GetDlgItem(g_hDlg, IDC_PROGRESS), PBM_SETSTATE, PBST_ERROR, 0);
		SetDlgItemTextW(g_hDlg, IDC_SYSLINK, L"Cancelled.");
	} else {
		PostMessage(GetDlgItem(g_hDlg, IDC_PROGRESS), PBM_SETPOS, MAX_PROGRESS, 0);
		SetDlgItemTextW(g_hDlg, IDC_SYSLINK, L"Done.  Click <a ID=\"a\">here</a> to see the report.");
	}
}

// Entry point that runs the whole scan.  The unused void* parameter and the
// __cdecl convention suggest it is meant to be started as a thread routine
// (NOTE(review): confirm against the caller).
//
// Creates the "C" locale used for formatting, runs scan_registry(), and
// releases the locale and any still-open report file afterwards.  No
// exception escapes: Aborted is swallowed (fatal_error() has already shown a
// message and asked the dialog to close); anything else is reported in a
// message box.
void __cdecl start_registry_scan(void *)
{
	g_locale = _create_locale(LC_ALL, "C");
	g_ofp = NULL;
	try {
		scan_registry();
	} catch (Aborted const &) {
		// Already reported by fatal_error().
	} catch (...) {
		msgprintf(L"Internal Error: Unhandled exception.");
	}

	// scan_registry() closes the report itself on the normal path; this
	// covers the case where it was left via an exception.  Both globals are
	// reset so that nothing keeps a pointer to a released resource.
	if (g_ofp) {
		fclose(g_ofp);
		g_ofp = NULL;
	}
	_free_locale(g_locale);
	g_locale = NULL;
}

// Requests cancellation of a running scan: raises the g_cancelled flag that
// the scanner polls, then disables the Cancel button.
void __cdecl cancel_registry_scan(void *)
{
	// Raise the flag first so the scanner can start winding down.
	g_cancelled = true;

	EnableWindow(
		GetDlgItem(g_hDlg, IDCANCEL),
		FALSE);
}


// Returns the path of the report file as last set by scan_registry(); the
// string is empty until a scan has chosen a file name.
std::wstring const &get_report_filename()
{
	std::wstring const &path = report_filename;
	return path;
}