pass
def visit_pointer(self, pointer, instance):
+ print " if (%s) {" % instance
self.visit(pointer.type, "*" + instance)
+ print " }"
def visit_handle(self, handle, instance):
self.visit(handle.type, instance)
def visit_interface(self, interface, instance):
assert instance.startswith('*')
instance = instance[1:]
- print " if (%s)" % instance
+ print " if (%s) {" % instance
print " %s = new %s(%s);" % (instance, interface_wrap_name(interface), instance)
+ print " }"
class Unwrapper(Wrapper):
def visit_interface(self, interface, instance):
assert instance.startswith('*')
instance = instance[1:]
- print " if (%s)" % instance
+ print " if (%s) {" % instance
print " %s = static_cast<%s *>(%s)->m_pInstance;" % (instance, interface_wrap_name(interface), instance)
+ print " }"
wrap_instance = Wrapper().visit
class Tracer:
+ def __init__(self):
+ self.api = None
+
def trace_api(self, api):
+ self.api = api
+
self.header(api)
# Includes
wrap_instance(method.type, '__result')
print ' Trace::EndLeave();'
if method.name == 'QueryInterface':
- print ' if (*ppvObj == m_pInstance)'
- print ' *ppvObj = this;'
+ print ' if (ppvObj && *ppvObj) {'
+ print ' if (*ppvObj == m_pInstance) {'
+ print ' *ppvObj = this;'
+ print ' }'
+ for iface in self.api.interfaces:
+ print ' else if (riid == IID_%s) {' % iface.name
+ print ' *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name)
+ print ' }'
+ print ' }'
if method.name == 'Release':
assert method.type is not stdapi.Void
print ' if (!__result)'
def __init__(self, dllname):
self.dllname = dllname
- def get_function_address(self, function):
- return '__%s' % (function.name,)
-
def header(self, api):
print '''
static HINSTANCE g_hDll = NULL;