1 /**************************************************************************
3 * Copyright 2011-2012 Jose Fonseca
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to deal
8 * in the Software without restriction, including without limitation the rights
9 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 * copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
13 * The above copyright notice and this permission notice shall be included in
14 * all copies or substantial portions of the Software.
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 **************************************************************************/
28 * Code for the DLL that will be injected in the target process.
30 * The injected DLL will manipulate the import tables to hook the
31 * modules/functions of interest.
34 * - http://www.codeproject.com/KB/system/api_spying_hack.aspx
35 * - http://www.codeproject.com/KB/threads/APIHooking.aspx
36 * - http://msdn.microsoft.com/en-us/magazine/cc301808.aspx
/*
 * Lock shared by debugPrintf() (which formats into a static buffer) and
 * replaceAddress() (which patches import address table entries).
 *
 * It is initialized statically instead of with InitializeCriticalSection():
 * the aggregate initializer fills the CRITICAL_SECTION fields in their
 * winnt.h declaration order (DebugInfo = -1, LockCount = -1,
 * RecursionCount/OwningThread/LockSemaphore/SpinCount = 0).
 * NOTE(review): this depends on the internal RTL_CRITICAL_SECTION layout;
 * presumably chosen so the lock is usable from DllMain before any
 * initialization code has run -- confirm.
 */
54 static CRITICAL_SECTION Mutex = {(PCRITICAL_SECTION_DEBUG)-1, -1, 0, 0, 0, 0};
58 debugPrintf(const char *format, ...)
61 static char buf[4096];
63 EnterCriticalSection(&Mutex);
67 _vsnprintf(buf, sizeof buf, format, ap);
70 OutputDebugStringA(buf);
72 LeaveCriticalSection(&Mutex);
78 MyLoadLibraryA(LPCSTR lpLibFileName);
81 MyLoadLibraryW(LPCWSTR lpLibFileName);
84 MyLoadLibraryExA(LPCSTR lpFileName, HANDLE hFile, DWORD dwFlags);
87 MyLoadLibraryExW(LPCWSTR lpFileName, HANDLE hFile, DWORD dwFlags);
90 MyGetProcAddress(HMODULE hModule, LPCSTR lpProcName);
94 getImportDescriptionName(HMODULE hModule, const PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor) {
95 const char* szName = (const char*)((PBYTE)hModule + pImportDescriptor->Name);
100 static PIMAGE_IMPORT_DESCRIPTOR
101 getImportDescriptor(HMODULE hModule,
102 const char *szModule,
103 const char *pszDllName)
105 MEMORY_BASIC_INFORMATION MemoryInfo;
106 if (VirtualQuery(hModule, &MemoryInfo, sizeof MemoryInfo) != sizeof MemoryInfo) {
107 debugPrintf("%s: %s: VirtualQuery failed\n", __FUNCTION__, szModule);
110 if (MemoryInfo.Protect & (PAGE_NOACCESS | PAGE_EXECUTE)) {
111 debugPrintf("%s: %s: no read access (Protect = 0x%08x)\n", __FUNCTION__, szModule, MemoryInfo.Protect);
115 PIMAGE_DOS_HEADER pDosHeader = (PIMAGE_DOS_HEADER)hModule;
116 PIMAGE_NT_HEADERS pNtHeaders = (PIMAGE_NT_HEADERS)((PBYTE)hModule + pDosHeader->e_lfanew);
118 PIMAGE_OPTIONAL_HEADER pOptionalHeader = &pNtHeaders->OptionalHeader;
120 UINT_PTR ImportAddress = pOptionalHeader->DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress;
122 if (!ImportAddress) {
126 PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = (PIMAGE_IMPORT_DESCRIPTOR)((PBYTE)hModule + ImportAddress);
128 while (pImportDescriptor->FirstThunk) {
129 const char* szName = getImportDescriptionName(hModule, pImportDescriptor);
130 if (stricmp(pszDllName, szName) == 0) {
131 return pImportDescriptor;
141 replaceAddress(LPVOID *lpOldAddress, LPVOID lpNewAddress)
145 if (*lpOldAddress == lpNewAddress) {
149 EnterCriticalSection(&Mutex);
151 if (!(VirtualProtect(lpOldAddress, sizeof *lpOldAddress, PAGE_READWRITE, &flOldProtect))) {
152 LeaveCriticalSection(&Mutex);
156 *lpOldAddress = lpNewAddress;
158 if (!(VirtualProtect(lpOldAddress, sizeof *lpOldAddress, flOldProtect, &flOldProtect))) {
159 LeaveCriticalSection(&Mutex);
163 LeaveCriticalSection(&Mutex);
169 getOldFunctionAddress(HMODULE hModule,
170 PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor,
171 const char* pszFunctionName)
173 PIMAGE_THUNK_DATA pOriginalFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->OriginalFirstThunk);
174 PIMAGE_THUNK_DATA pFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->FirstThunk);
176 //debugPrintf(" %s\n", __FUNCTION__);
178 while (pOriginalFirstThunk->u1.Function) {
179 PIMAGE_IMPORT_BY_NAME pImport = (PIMAGE_IMPORT_BY_NAME)((PBYTE)hModule + pOriginalFirstThunk->u1.AddressOfData);
180 const char* szName = (const char* )pImport->Name;
181 //debugPrintf(" %s\n", szName);
182 if (strcmp(pszFunctionName, szName) == 0) {
183 //debugPrintf(" %s succeeded\n", __FUNCTION__);
184 return (LPVOID *)(&pFirstThunk->u1.Function);
186 ++pOriginalFirstThunk;
190 //debugPrintf(" %s failed\n", __FUNCTION__);
197 replaceModule(HMODULE hModule,
198 const char *szModule,
199 PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor,
202 PIMAGE_THUNK_DATA pOriginalFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->OriginalFirstThunk);
203 PIMAGE_THUNK_DATA pFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->FirstThunk);
205 while (pOriginalFirstThunk->u1.Function) {
206 PIMAGE_IMPORT_BY_NAME pImport = (PIMAGE_IMPORT_BY_NAME)((PBYTE)hModule + pOriginalFirstThunk->u1.AddressOfData);
207 const char* szFunctionName = (const char* )pImport->Name;
208 debugPrintf(" hooking %s->%s!%s\n", szModule,
209 getImportDescriptionName(hModule, pImportDescriptor),
212 PROC pNewProc = GetProcAddress(hNewModule, szFunctionName);
214 debugPrintf(" warning: no replacement for %s\n", szFunctionName);
216 LPVOID *lpOldAddress = (LPVOID *)(&pFirstThunk->u1.Function);
217 replaceAddress(lpOldAddress, (LPVOID)pNewProc);
220 ++pOriginalFirstThunk;
227 hookFunction(HMODULE hModule,
228 const char *szModule,
229 const char *pszDllName,
230 const char *pszFunctionName,
233 PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = getImportDescriptor(hModule, szModule, pszDllName);
234 if (pImportDescriptor == NULL) {
237 LPVOID* lpOldFunctionAddress = getOldFunctionAddress(hModule, pImportDescriptor, pszFunctionName);
238 if (lpOldFunctionAddress == NULL) {
242 if (*lpOldFunctionAddress == lpNewAddress) {
246 if (VERBOSITY >= 3) {
247 debugPrintf(" hooking %s->%s!%s\n", szModule, pszDllName, pszFunctionName);
250 return replaceAddress(lpOldFunctionAddress, lpNewAddress);
/*
 * Redirects hModule's imports from the DLL named pszDllName to the
 * same-named exports of hNewModule: the import descriptor is located with
 * getImportDescriptor() and then rewritten by replaceModule().  szModule is
 * only passed through to those helpers, which use it for log messages.
 *
 * NOTE(review): getImportDescriptor() returns the first descriptor whose
 * name matches, so if a module has several descriptors for the same DLL only
 * the first one is redirected.
 */
255 replaceImport(HMODULE hModule,
256 const char *szModule,
257 const char *pszDllName,
264 PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = getImportDescriptor(hModule, szModule, pszDllName);
// NULL: hModule does not import pszDllName, or its headers could not be
// read -- presumably nothing is redirected in that case.
265 if (pImportDescriptor == NULL) {
269 replaceModule(hModule, szModule, pImportDescriptor, hNewModule);
274 static HMODULE g_hThisModule = NULL;
278 const char *szMatchModule;
279 HMODULE hReplaceModule;
282 static unsigned numReplacements = 0;
283 static Replacement replacements[32];
288 hookModule(HMODULE hModule,
289 const char *szModule)
291 if (hModule == g_hThisModule) {
295 for (unsigned i = 0; i < numReplacements; ++i) {
296 if (hModule == replacements[i].hReplaceModule) {
301 hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryA", (LPVOID)MyLoadLibraryA);
302 hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryW", (LPVOID)MyLoadLibraryW);
303 hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryExA", (LPVOID)MyLoadLibraryExA);
304 hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryExW", (LPVOID)MyLoadLibraryExW);
305 hookFunction(hModule, szModule, "kernel32.dll", "GetProcAddress", (LPVOID)MyGetProcAddress);
307 const char *szBaseName = getBaseName(szModule);
308 for (unsigned i = 0; i < numReplacements; ++i) {
309 if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) {
314 /* Don't hook internal dependencies */
315 if (stricmp(szBaseName, "d3d10core.dll") == 0 ||
316 stricmp(szBaseName, "d3d10level9.dll") == 0 ||
317 stricmp(szBaseName, "d3d10sdklayers.dll") == 0 ||
318 stricmp(szBaseName, "d3d10_1core.dll") == 0 ||
319 stricmp(szBaseName, "d3d11sdklayers.dll") == 0 ||
320 stricmp(szBaseName, "d3d11_1sdklayers.dll") == 0) {
324 for (unsigned i = 0; i < numReplacements; ++i) {
325 replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule);
326 replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule);
327 replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule);
// Enumerates the modules loaded in the current process with a Toolhelp
// snapshot and calls hookModule() on each one, then closes the snapshot.
// The module list is also logged; the function-local `first` flag
// presumably limits that listing to the first call -- confirm.
//
// NOTE(review): CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, ...) is
// documented to fail transiently with ERROR_BAD_LENGTH, in which case
// callers should retry; confirm whether the failure branch below does.
334 HANDLE hModuleSnap = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, GetCurrentProcessId());
335 if (hModuleSnap == INVALID_HANDLE_VALUE) {
// Module32First requires dwSize to be set to the structure size.
340 me32.dwSize = sizeof me32;
342 static bool first = true;
344 if (Module32First(hModuleSnap, &me32)) {
345 debugPrintf(" modules:\n");
// NOTE(review): szExePath is printed with "%s" and passed to hookModule()'s
// const char * parameter, which assumes the ANSI MODULEENTRY32 (non-UNICODE
// build) -- confirm.
347 debugPrintf(" %s\n", me32.szExePath);
348 } while (Module32Next(hModuleSnap, &me32));
353 if (Module32First(hModuleSnap, &me32)) {
355 hookModule(me32.hModule, me32.szExePath);
356 } while (Module32Next(hModuleSnap, &me32));
359 CloseHandle(hModuleSnap);
/*
 * Common implementation behind the MyLoadLibrary{A,W,ExA,ExW} hooks: loads
 * the library with LoadLibraryExA.  With the default hFile/dwFlags
 * (NULL, 0) this behaves like LoadLibraryA, since LoadLibraryExA with no
 * flags is documented to be equivalent to LoadLibrary.
 */
365 static HMODULE WINAPI
366 MyLoadLibrary(LPCSTR lpLibFileName, HANDLE hFile = NULL, DWORD dwFlags = 0)
368 // Forward to the loader. hookModule() special-cases g_hThisModule, so this
369 // call presumably reaches kernel32 directly instead of recursing -- confirm.
370 HMODULE hModule = LoadLibraryExA(lpLibFileName, hFile, dwFlags);
// NOTE(review): hooking only the newly loaded module is disabled here.
372 //hookModule(hModule, lpLibFileName);
/*
 * Hook for LoadLibraryA.  Logs the call when VERBOSITY >= 2.  When the
 * requested DLL's base name matches a replaced module, it also logs the call
 * and the module that issued it.  The load itself is delegated to
 * MyLoadLibrary().
 */
378 static HMODULE WINAPI
379 MyLoadLibraryA(LPCSTR lpLibFileName)
381 if (VERBOSITY >= 2) {
382 debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName);
385 const char *szBaseName = getBaseName(lpLibFileName);
386 for (unsigned i = 0; i < numReplacements; ++i) {
387 if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) {
388 debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName);
// Identify the calling module from this function's return address.
// NOTE(review): __builtin_return_address is a GCC/Clang extension, so this
// only builds with MinGW-style compilers.
390 void *caller = __builtin_return_address (0);
// NOTE(review): with only GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
// GetModuleHandleEx increments the module's reference count; unless
// GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT is also passed (or the
// handle is released with FreeLibrary), the caller can never be unloaded.
393 BOOL bRet = GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
398 DWORD dwRet = GetModuleFileNameA(hModule, szCaller, sizeof szCaller);
400 debugPrintf(" called from %s\n", szCaller);
406 return MyLoadLibrary(lpLibFileName);
409 static HMODULE WINAPI
410 MyLoadLibraryW(LPCWSTR lpLibFileName)
412 if (VERBOSITY >= 2) {
413 debugPrintf("%s(L\"%S\")\n", __FUNCTION__, lpLibFileName);
416 char szFileName[256];
417 wcstombs(szFileName, lpLibFileName, sizeof szFileName);
419 return MyLoadLibrary(szFileName);
422 static HMODULE WINAPI
423 MyLoadLibraryExA(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
425 if (VERBOSITY >= 2) {
426 debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName);
428 return MyLoadLibrary(lpLibFileName, hFile, dwFlags);
431 static HMODULE WINAPI
432 MyLoadLibraryExW(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
434 if (VERBOSITY >= 2) {
435 debugPrintf("%s(L\"%S\")\n", __FUNCTION__, lpLibFileName);
438 char szFileName[256];
439 wcstombs(szFileName, lpLibFileName, sizeof szFileName);
441 return MyLoadLibrary(szFileName, hFile, dwFlags);
/*
 * Hook for GetProcAddress.
 *
 * Lookups on one of the replacement DLLs are forwarded unchanged.  Lookups
 * on a module whose base name matches a replaced DLL are logged and resolved
 * against the corresponding replacement DLL.  The result is presumably
 * returned when found, and logged as "ignoring" otherwise -- confirm.  All
 * other lookups go to the real GetProcAddress.
 *
 * NOTE(review): lpProcName may be an ordinal (high word zero) rather than a
 * string.  The "%s" debug prints below would then dereference an invalid
 * pointer, and the XXX comment hints at exactly this.  Consider checking
 * HIWORD(lpProcName) before formatting it.
 */
444 static FARPROC WINAPI
445 MyGetProcAddress(HMODULE hModule, LPCSTR lpProcName) {
447 if (VERBOSITY >= 99) {
448 /* XXX this can cause segmentation faults */
449 debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpProcName);
452 assert(hModule != g_hThisModule);
// Lookups on a replacement DLL itself are passed straight through.
453 for (unsigned i = 0; i < numReplacements; ++i) {
454 if (hModule == replacements[i].hReplaceModule) {
455 return GetProcAddress(hModule, lpProcName);
// Match the target module by the base name of its file path.
461 DWORD dwRet = GetModuleFileNameA(hModule, szModule, sizeof szModule);
463 const char *szBaseName = getBaseName(szModule);
465 for (unsigned i = 0; i < numReplacements; ++i) {
467 if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) {
468 debugPrintf(" %s(\"%s\", \"%s\")\n", __FUNCTION__, szModule, lpProcName);
469 FARPROC pProcAddress = GetProcAddress(replacements[i].hReplaceModule, lpProcName);
471 if (VERBOSITY >= 2) {
472 debugPrintf(" replacing %s!%s\n", szBaseName, lpProcName);
476 debugPrintf(" ignoring %s!%s\n", szBaseName, lpProcName);
// Not a replaced module: resolve normally.
483 return GetProcAddress(hModule, lpProcName);
/*
 * DLL entry point.
 *
 * On DLL_PROCESS_ATTACH it does the following:
 *  - records this DLL's handle in g_hThisModule;
 *  - logs the host executable's path;
 *  - determines which DLL to inject, either from the INJECT_DLL environment
 *    variable or from a GetSharedMem() copy held in a static buffer (the
 *    exact selection condition is not shown here);
 *  - loads that DLL;
 *  - appends `replacements` entries describing which system DLLs it
 *    replaces.  dxgitrace.dll replaces dxgi.dll, d3d10.dll, d3d10_1.dll and
 *    d3d11.dll.  Presumably any other injected DLL is matched against its
 *    own base name.
 *
 * DLL_PROCESS_DETACH is logged.
 */
488 DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved)
490 const char *szNewDllName = NULL;
491 HMODULE hNewModule = NULL;
492 const char *szNewDllBaseName;
495 case DLL_PROCESS_ATTACH:
496 debugPrintf("DLL_PROCESS_ATTACH\n");
// Remembered so that hookModule() and MyGetProcAddress() can recognize
// this DLL.
498 g_hThisModule = hinstDLL;
501 char szProcess[MAX_PATH];
502 GetModuleFileNameA(NULL, szProcess, sizeof szProcess);
503 debugPrintf(" attached to %s\n", szProcess);
507 * Calling LoadLibrary inside DllMain is strongly discouraged. But it
508 * works quite well, provided that the loaded DLL does not require or do
509 * anything special in its DllMain, which seems to be the general case.
512 * - http://stackoverflow.com/questions/4370812/calling-loadlibrary-from-dllmain
513 * - http://msdn.microsoft.com/en-us/library/ms682583
517 szNewDllName = getenv("INJECT_DLL");
519 debugPrintf("warning: INJECT_DLL not set\n");
523 static char szSharedMemCopy[MAX_PATH];
524 GetSharedMem(szSharedMemCopy, sizeof szSharedMemCopy);
525 szNewDllName = szSharedMemCopy;
527 debugPrintf(" injecting %s\n", szNewDllName);
529 hNewModule = LoadLibraryA(szNewDllName);
531 debugPrintf("warning: failed to load %s\n", szNewDllName);
535 szNewDllBaseName = getBaseName(szNewDllName);
// NOTE(review): szNewDllBaseName is stored in `replacements` below and
// presumably points into szNewDllName.  When that came from getenv(), it
// refers to CRT environment storage that later environment changes may
// invalidate; consider copying it, as is done for the shared-memory case.
// NOTE(review): `replacements` has a fixed capacity of 32 entries; confirm
// numReplacements is bounds-checked wherever it is incremented.
536 if (stricmp(szNewDllBaseName, "dxgitrace.dll") == 0) {
537 replacements[numReplacements].szMatchModule = "dxgi.dll";
538 replacements[numReplacements].hReplaceModule = hNewModule;
541 replacements[numReplacements].szMatchModule = "d3d10.dll";
542 replacements[numReplacements].hReplaceModule = hNewModule;
545 replacements[numReplacements].szMatchModule = "d3d10_1.dll";
546 replacements[numReplacements].hReplaceModule = hNewModule;
549 replacements[numReplacements].szMatchModule = "d3d11.dll";
550 replacements[numReplacements].hReplaceModule = hNewModule;
553 replacements[numReplacements].szMatchModule = szNewDllBaseName;
554 replacements[numReplacements].hReplaceModule = hNewModule;
561 case DLL_THREAD_ATTACH:
564 case DLL_THREAD_DETACH:
567 case DLL_PROCESS_DETACH:
568 debugPrintf("DLL_PROCESS_DETACH\n");