]> git.notmuchmail.org Git - apitrace/blobdiff - trace.py
Build fixes and minor corrections.
[apitrace] / trace.py
index 9f115a32e92687086280fee60d415419d44405a9..164c9777b073f68f2db7c37d993fe824d2646d10 100644 (file)
--- a/trace.py
+++ b/trace.py
@@ -238,7 +238,9 @@ class Wrapper(stdapi.Visitor):
         pass
 
     def visit_pointer(self, pointer, instance):
+        print "    if (%s) {" % instance
         self.visit(pointer.type, "*" + instance)
+        print "    }"
 
     def visit_handle(self, handle, instance):
         self.visit(handle.type, instance)
@@ -252,8 +254,9 @@ class Wrapper(stdapi.Visitor):
     def visit_interface(self, interface, instance):
         assert instance.startswith('*')
         instance = instance[1:]
-        print "    if (%s)" % instance
+        print "    if (%s) {" % instance
         print "        %s = new %s(%s);" % (instance, interface_wrap_name(interface), instance)
+        print "    }"
 
 
 class Unwrapper(Wrapper):
@@ -261,8 +264,9 @@ class Unwrapper(Wrapper):
     def visit_interface(self, interface, instance):
         assert instance.startswith('*')
         instance = instance[1:]
-        print "    if (%s)" % instance
+        print "    if (%s) {" % instance
         print "        %s = static_cast<%s *>(%s)->m_pInstance;" % (instance, interface_wrap_name(interface), instance)
+        print "    }"
 
 
 wrap_instance = Wrapper().visit
@@ -271,7 +275,12 @@ unwrap_instance = Unwrapper().visit
 
 class Tracer:
 
+    def __init__(self):
+        self.api = None
+
     def trace_api(self, api):
+        self.api = api
+
         self.header(api)
 
         # Includes
@@ -427,8 +436,15 @@ class Tracer:
             wrap_instance(method.type, '__result')
         print '    Trace::EndLeave();'
         if method.name == 'QueryInterface':
-            print '    if (*ppvObj == m_pInstance)'
-            print '        *ppvObj = this;'
+            print '    if (ppvObj && *ppvObj) {'
+            print '        if (*ppvObj == m_pInstance) {'
+            print '            *ppvObj = this;'
+            print '        }'
+            for iface in self.api.interfaces:
+                print '        else if (riid == IID_%s) {' % iface.name
+                print '            *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name)
+                print '        }'
+            print '    }'
         if method.name == 'Release':
             assert method.type is not stdapi.Void
             print '    if (!__result)'
@@ -444,9 +460,6 @@ class DllTracer(Tracer):
     def __init__(self, dllname):
         self.dllname = dllname
     
-    def get_function_address(self, function):
-        return '__%s' % (function.name,)
-
     def header(self, api):
         print '''
 static HINSTANCE g_hDll = NULL;