1 /**************************************************************************
3 * Copyright 2011-2012 Jose Fonseca
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to deal
8 * in the Software without restriction, including without limitation the rights
9 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 * copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
13 * The above copyright notice and this permission notice shall be included in
14 * all copies or substantial portions of the Software.
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 **************************************************************************/
28 * Code for the DLL that will be injected in the target process.
30 * The injected DLL will manipulate the import tables to hook the
31 * modules/functions of interest.
34 * - http://www.codeproject.com/KB/system/api_spying_hack.aspx
35 * - http://www.codeproject.com/KB/threads/APIHooking.aspx
36 * - http://msdn.microsoft.com/en-us/magazine/cc301808.aspx
/*
 * Lock taken by replaceAddress() around each import-address-table write.
 *
 * NOTE(review): the CRITICAL_SECTION is initialized statically by spelling
 * out its fields (DebugInfo = -1, LockCount = -1, everything else 0) instead
 * of calling InitializeCriticalSection(). That layout is not a documented
 * contract -- confirm it is valid on every targeted Windows version.
 */
54 static CRITICAL_SECTION Mutex = {(PCRITICAL_SECTION_DEBUG)-1, -1, 0, 0, 0, 0};
58 debugPrintf(const char *format, ...)
64 _vsnprintf(buf, sizeof buf, format, ap);
67 OutputDebugStringA(buf);
72 MyLoadLibraryA(LPCSTR lpLibFileName);
75 MyLoadLibraryW(LPCWSTR lpLibFileName);
78 MyLoadLibraryExA(LPCSTR lpFileName, HANDLE hFile, DWORD dwFlags);
81 MyLoadLibraryExW(LPCWSTR lpFileName, HANDLE hFile, DWORD dwFlags);
84 MyGetProcAddress(HMODULE hModule, LPCSTR lpProcName);
88 getImportDescriptionName(HMODULE hModule, const PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor) {
89 const char* szName = (const char*)((PBYTE)hModule + pImportDescriptor->Name);
94 static PIMAGE_IMPORT_DESCRIPTOR
95 getImportDescriptor(HMODULE hModule,
97 const char *pszDllName)
99 MEMORY_BASIC_INFORMATION MemoryInfo;
100 if (VirtualQuery(hModule, &MemoryInfo, sizeof MemoryInfo) != sizeof MemoryInfo) {
101 debugPrintf("%s: %s: VirtualQuery failed\n", __FUNCTION__, szModule);
104 if (MemoryInfo.Protect & (PAGE_NOACCESS | PAGE_EXECUTE)) {
105 debugPrintf("%s: %s: no read access (Protect = 0x%08x)\n", __FUNCTION__, szModule, MemoryInfo.Protect);
109 PIMAGE_DOS_HEADER pDosHeader = (PIMAGE_DOS_HEADER)hModule;
110 PIMAGE_NT_HEADERS pNtHeaders = (PIMAGE_NT_HEADERS)((PBYTE)hModule + pDosHeader->e_lfanew);
112 PIMAGE_OPTIONAL_HEADER pOptionalHeader = &pNtHeaders->OptionalHeader;
114 UINT_PTR ImportAddress = pOptionalHeader->DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress;
116 if (!ImportAddress) {
120 PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = (PIMAGE_IMPORT_DESCRIPTOR)((PBYTE)hModule + ImportAddress);
122 while (pImportDescriptor->FirstThunk) {
123 const char* szName = getImportDescriptionName(hModule, pImportDescriptor);
124 if (stricmp(pszDllName, szName) == 0) {
125 return pImportDescriptor;
135 replaceAddress(LPVOID *lpOldAddress, LPVOID lpNewAddress)
139 if (*lpOldAddress == lpNewAddress) {
143 EnterCriticalSection(&Mutex);
145 if (!(VirtualProtect(lpOldAddress, sizeof *lpOldAddress, PAGE_READWRITE, &flOldProtect))) {
146 LeaveCriticalSection(&Mutex);
150 *lpOldAddress = lpNewAddress;
152 if (!(VirtualProtect(lpOldAddress, sizeof *lpOldAddress, flOldProtect, &flOldProtect))) {
153 LeaveCriticalSection(&Mutex);
157 LeaveCriticalSection(&Mutex);
163 getOldFunctionAddress(HMODULE hModule,
164 PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor,
165 const char* pszFunctionName)
167 PIMAGE_THUNK_DATA pOriginalFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->OriginalFirstThunk);
168 PIMAGE_THUNK_DATA pFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->FirstThunk);
170 //debugPrintf(" %s\n", __FUNCTION__);
172 while (pOriginalFirstThunk->u1.Function) {
173 PIMAGE_IMPORT_BY_NAME pImport = (PIMAGE_IMPORT_BY_NAME)((PBYTE)hModule + pOriginalFirstThunk->u1.AddressOfData);
174 const char* szName = (const char* )pImport->Name;
175 //debugPrintf(" %s\n", szName);
176 if (strcmp(pszFunctionName, szName) == 0) {
177 //debugPrintf(" %s succeeded\n", __FUNCTION__);
178 return (LPVOID *)(&pFirstThunk->u1.Function);
180 ++pOriginalFirstThunk;
184 //debugPrintf(" %s failed\n", __FUNCTION__);
191 replaceModule(HMODULE hModule,
192 const char *szModule,
193 PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor,
196 PIMAGE_THUNK_DATA pOriginalFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->OriginalFirstThunk);
197 PIMAGE_THUNK_DATA pFirstThunk = (PIMAGE_THUNK_DATA)((PBYTE)hModule + pImportDescriptor->FirstThunk);
199 while (pOriginalFirstThunk->u1.Function) {
200 PIMAGE_IMPORT_BY_NAME pImport = (PIMAGE_IMPORT_BY_NAME)((PBYTE)hModule + pOriginalFirstThunk->u1.AddressOfData);
201 const char* szFunctionName = (const char* )pImport->Name;
203 debugPrintf(" hooking %s->%s!%s\n", szModule,
204 getImportDescriptionName(hModule, pImportDescriptor),
208 PROC pNewProc = GetProcAddress(hNewModule, szFunctionName);
210 debugPrintf("warning: no replacement for %s\n", szFunctionName);
212 LPVOID *lpOldAddress = (LPVOID *)(&pFirstThunk->u1.Function);
213 replaceAddress(lpOldAddress, (LPVOID)pNewProc);
216 ++pOriginalFirstThunk;
223 hookFunction(HMODULE hModule,
224 const char *szModule,
225 const char *pszDllName,
226 const char *pszFunctionName,
229 PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = getImportDescriptor(hModule, szModule, pszDllName);
230 if (pImportDescriptor == NULL) {
233 LPVOID* lpOldFunctionAddress = getOldFunctionAddress(hModule, pImportDescriptor, pszFunctionName);
234 if (lpOldFunctionAddress == NULL) {
238 if (*lpOldFunctionAddress == lpNewAddress) {
242 if (VERBOSITY >= 3) {
243 debugPrintf(" hooking %s->%s!%s\n", szModule, pszDllName, pszFunctionName);
246 return replaceAddress(lpOldFunctionAddress, lpNewAddress);
/*
 * Redirect all of hModule's imports from pszDllName to the same-named
 * exports of hNewModule (via replaceModule()).
 *
 * szModule is the module's path, passed through for diagnostics.
 * NOTE(review): hNewModule's declaration and a few lines between the
 * parameter list and the lookup are not visible here -- confirm that
 * hNewModule is the final parameter.
 */
251 replaceImport(HMODULE hModule,
252 const char *szModule,
253 const char *pszDllName,
// Find the import descriptor for pszDllName in hModule's import table.
260 PIMAGE_IMPORT_DESCRIPTOR pImportDescriptor = getImportDescriptor(hModule, szModule, pszDllName);
// hModule does not import pszDllName (or its headers are unreadable);
// presumably returns early -- the branch body is not visible.
261 if (pImportDescriptor == NULL) {
// Patch every named import of that DLL to point into hNewModule.
265 replaceModule(hModule, szModule, pImportDescriptor, hNewModule);
/* Handle of this DLL, set in DllMain on DLL_PROCESS_ATTACH; hookModule() never patches it. */
270 static HMODULE g_hThisModule = NULL;
/*
 * Members of a Replacement entry. Modules that import szMatchModule
 * (compared case-insensitively) get those imports redirected to
 * hReplaceModule.
 * NOTE(review): the struct's opening and closing lines are not visible here.
 */
274 const char *szMatchModule;
275 HMODULE hReplaceModule;
/* Number of valid entries in replacements[]. */
278 static unsigned numReplacements = 0;
/*
 * Replacement table, filled in by DllMain.
 * NOTE(review): fixed capacity of 32 entries -- confirm every writer checks
 * numReplacements against it.
 */
279 static Replacement replacements[32];
284 hookModule(HMODULE hModule,
285 const char *szModule)
287 if (hModule == g_hThisModule) {
291 for (unsigned i = 0; i < numReplacements; ++i) {
292 if (hModule == replacements[i].hReplaceModule) {
297 hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryA", (LPVOID)MyLoadLibraryA);
298 hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryW", (LPVOID)MyLoadLibraryW);
299 hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryExA", (LPVOID)MyLoadLibraryExA);
300 hookFunction(hModule, szModule, "kernel32.dll", "LoadLibraryExW", (LPVOID)MyLoadLibraryExW);
301 hookFunction(hModule, szModule, "kernel32.dll", "GetProcAddress", (LPVOID)MyGetProcAddress);
303 const char *szBaseName = getBaseName(szModule);
304 for (unsigned i = 0; i < numReplacements; ++i) {
305 if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) {
310 /* Don't hook internal dependencies */
311 if (stricmp(szBaseName, "d3d10core.dll") == 0 ||
312 stricmp(szBaseName, "d3d10level9.dll") == 0 ||
313 stricmp(szBaseName, "d3d10sdklayers.dll") == 0 ||
314 stricmp(szBaseName, "d3d10_1core.dll") == 0 ||
315 stricmp(szBaseName, "d3d11sdklayers.dll") == 0 ||
316 stricmp(szBaseName, "d3d11_1sdklayers.dll") == 0) {
320 for (unsigned i = 0; i < numReplacements; ++i) {
321 replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule);
322 replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule);
323 replaceImport(hModule, szModule, replacements[i].szMatchModule, replacements[i].hReplaceModule);
// NOTE(review): the header of the function containing these lines is not
// visible in this view; it walks every module of the current process.
// Take a snapshot of all modules loaded in the current process.
330 HANDLE hModuleSnap = CreateToolhelp32Snapshot(TH32CS_SNAPMODULE, GetCurrentProcessId());
// Snapshot failed; the handling is not visible here (presumably an early return).
331 if (hModuleSnap == INVALID_HANDLE_VALUE) {
// The MODULEENTRY32 size must be set before Module32First() is called.
336 me32.dwSize = sizeof me32;
// Presumably limits the module listing below to the first call -- confirm.
339 static bool first = true;
// Log the path of every module in the snapshot.
341 if (Module32First(hModuleSnap, &me32)) {
342 debugPrintf(" modules:\n");
344 debugPrintf(" %s\n", me32.szExePath);
345 } while (Module32Next(hModuleSnap, &me32));
// Walk the snapshot again and install the hooks into each module.
351 if (Module32First(hModuleSnap, &me32)) {
353 hookModule(me32.hModule, me32.szExePath);
354 } while (Module32Next(hModuleSnap, &me32));
// Release the snapshot handle.
357 CloseHandle(hModuleSnap);
/*
 * Common implementation behind the MyLoadLibrary* hooks. It loads the
 * library with LoadLibraryExA(), forwarding hFile and dwFlags. The
 * wide-character hooks convert the file name to a multibyte string before
 * calling this.
 *
 * NOTE(review): the rest of this function (after the load) is not visible
 * here. The comment below mentions notifying a server, but no such
 * notification appears in the visible lines -- confirm.
 */
363 static HMODULE WINAPI
364 MyLoadLibrary(LPCSTR lpLibFileName, HANDLE hFile = NULL, DWORD dwFlags = 0)
366 // To Send the information to the server informing that,
367 // LoadLibrary is invoked.
368 HMODULE hModule = LoadLibraryExA(lpLibFileName, hFile, dwFlags);
// NOTE(review): per-module hooking of the newly loaded module is disabled here.
370 //hookModule(hModule, lpLibFileName);
/*
 * Hook for LoadLibraryA.
 *
 * If the requested file's base name matches one of the replaced DLLs, the
 * request and the file name of the calling module are logged. The visible
 * path then defers to MyLoadLibrary().
 *
 * NOTE(review): several lines of this function are not shown here; confirm
 * there is no other exit path.
 */
376 static HMODULE WINAPI
377 MyLoadLibraryA(LPCSTR lpLibFileName)
379 if (VERBOSITY >= 2) {
380 debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName);
// Check whether this load targets one of the replaced DLLs.
384 const char *szBaseName = getBaseName(lpLibFileName);
385 for (unsigned i = 0; i < numReplacements; ++i) {
386 if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) {
387 debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName);
// Return address of this hook, i.e. a code address inside the caller.
// NOTE(review): __builtin_return_address is a GCC/Clang builtin; MSVC uses _ReturnAddress().
389 void *caller = __builtin_return_address (0);
// Map that address to the module containing it.
// NOTE(review): without GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT this
// call increments the caller module's reference count; confirm a matching
// FreeLibrary() exists in the lines not shown.
392 BOOL bRet = GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
397 DWORD dwRet = GetModuleFileNameA(hModule, szCaller, sizeof szCaller);
399 debugPrintf(" called from %s\n", szCaller);
406 return MyLoadLibrary(lpLibFileName);
409 static HMODULE WINAPI
410 MyLoadLibraryW(LPCWSTR lpLibFileName)
412 if (VERBOSITY >= 2) {
413 debugPrintf("%s(L\"%S\")\n", __FUNCTION__, lpLibFileName);
416 char szFileName[256];
417 wcstombs(szFileName, lpLibFileName, sizeof szFileName);
419 return MyLoadLibrary(szFileName);
422 static HMODULE WINAPI
423 MyLoadLibraryExA(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
425 if (VERBOSITY >= 2) {
426 debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpLibFileName);
428 return MyLoadLibrary(lpLibFileName, hFile, dwFlags);
431 static HMODULE WINAPI
432 MyLoadLibraryExW(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags)
434 if (VERBOSITY >= 2) {
435 debugPrintf("%s(L\"%S\")\n", __FUNCTION__, lpLibFileName);
438 char szFileName[256];
439 wcstombs(szFileName, lpLibFileName, sizeof szFileName);
441 return MyLoadLibrary(szFileName, hFile, dwFlags);
/*
 * Hook for GetProcAddress.
 *
 * Lookups on a replacement module go straight to the real GetProcAddress().
 * Otherwise the base name of hModule's file is compared, case-insensitively,
 * against each replaced DLL. On a match, lpProcName is looked up in the
 * corresponding replacement module. The final visible path falls through to
 * the real GetProcAddress() on hModule.
 *
 * NOTE(review): lpProcName may be an ordinal (a MAKEINTRESOURCE-style
 * value) rather than a string pointer. Every "%s" use of it below would then
 * dereference an invalid pointer, which is a plausible cause of the crash
 * noted in the VERBOSITY >= 99 branch. Several lines are not visible here
 * (e.g. szModule's declaration and what is done with pProcAddress); confirm
 * before relying on this summary.
 */
444 static FARPROC WINAPI
445 MyGetProcAddress(HMODULE hModule, LPCSTR lpProcName) {
447 if (VERBOSITY >= 99) {
448 /* XXX this can cause segmentation faults */
449 debugPrintf("%s(\"%s\")\n", __FUNCTION__, lpProcName);
// Assumes no hooked module ever looks up symbols in this DLL.
452 assert(hModule != g_hThisModule);
// Replacement modules must see the genuine exports of whatever they query.
453 for (unsigned i = 0; i < numReplacements; ++i) {
454 if (hModule == replacements[i].hReplaceModule) {
455 return GetProcAddress(hModule, lpProcName);
// Identify the queried module by the base name of its file path.
461 DWORD dwRet = GetModuleFileNameA(hModule, szModule, sizeof szModule);
463 const char *szBaseName = getBaseName(szModule);
465 for (unsigned i = 0; i < numReplacements; ++i) {
467 if (stricmp(szBaseName, replacements[i].szMatchModule) == 0) {
469 debugPrintf(" %s(\"%s\", \"%s\")\n", __FUNCTION__, szModule, lpProcName);
// Look the same symbol up in the replacement module.
471 FARPROC pProcAddress = GetProcAddress(replacements[i].hReplaceModule, lpProcName);
473 if (VERBOSITY >= 2) {
474 debugPrintf(" replacing %s!%s\n", szBaseName, lpProcName);
// Presumably reached when the replacement module lacks the symbol -- confirm.
479 debugPrintf(" ignoring %s!%s\n", szBaseName, lpProcName);
// Not a replaced module: defer to the real GetProcAddress().
487 return GetProcAddress(hModule, lpProcName);
492 DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved)
494 const char *szNewDllName = NULL;
495 HMODULE hNewModule = NULL;
496 const char *szNewDllBaseName;
499 case DLL_PROCESS_ATTACH:
501 debugPrintf("DLL_PROCESS_ATTACH\n");
504 g_hThisModule = hinstDLL;
507 char szProcess[MAX_PATH];
508 GetModuleFileNameA(NULL, szProcess, sizeof szProcess);
510 debugPrintf(" attached to %s\n", szProcess);
515 * Calling LoadLibrary inside DllMain is strongly discouraged. But it
516 * works quite well, provided that the loaded DLL does not require or do
517 * anything special in its DllMain, which seems to be the general case.
520 * - http://stackoverflow.com/questions/4370812/calling-loadlibrary-from-dllmain
521 * - http://msdn.microsoft.com/en-us/library/ms682583
525 szNewDllName = getenv("INJECT_DLL");
527 debugPrintf("warning: INJECT_DLL not set\n");
531 static char szSharedMemCopy[MAX_PATH];
532 GetSharedMem(szSharedMemCopy, sizeof szSharedMemCopy);
533 szNewDllName = szSharedMemCopy;
536 debugPrintf(" injecting %s\n", szNewDllName);
539 hNewModule = LoadLibraryA(szNewDllName);
541 debugPrintf("warning: failed to load %s\n", szNewDllName);
545 szNewDllBaseName = getBaseName(szNewDllName);
546 if (stricmp(szNewDllBaseName, "dxgitrace.dll") == 0) {
547 replacements[numReplacements].szMatchModule = "dxgi.dll";
548 replacements[numReplacements].hReplaceModule = hNewModule;
551 replacements[numReplacements].szMatchModule = "d3d10.dll";
552 replacements[numReplacements].hReplaceModule = hNewModule;
555 replacements[numReplacements].szMatchModule = "d3d10_1.dll";
556 replacements[numReplacements].hReplaceModule = hNewModule;
559 replacements[numReplacements].szMatchModule = "d3d11.dll";
560 replacements[numReplacements].hReplaceModule = hNewModule;
563 replacements[numReplacements].szMatchModule = szNewDllBaseName;
564 replacements[numReplacements].hReplaceModule = hNewModule;
571 case DLL_THREAD_ATTACH:
574 case DLL_THREAD_DETACH:
577 case DLL_PROCESS_DETACH:
579 debugPrintf("DLL_PROCESS_DETACH\n");