1 ##########################################################################
3 # Copyright 2008-2010 VMware, Inc.
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24 ##########################################################################/
26 """Common trace code generation."""
32 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
35 import specs.stdapi as stdapi
def getWrapperInterfaceName(interface):
    '''Return the name of the generated wrapper class for `interface`.

    The name is the interface's `expr` with a "Wrap" prefix.
    '''
    prefix = "Wrap"
    return prefix + interface.expr
# Visitor (subclass of stdapi.OnceVisitor) that prints the C++ static tables
# and helper functions needed to serialize complex types: struct member
# name tables and signatures, enum/bitmask value tables and signatures, and
# `_write__<tag>` helpers for context-less polymorphic types.
# NOTE(review): the inner line numbers of this listing have gaps, so some
# short original lines (single-statement bodies, `else:`, closing prints)
# are not shown here.
43 class ComplexValueSerializer(stdapi.OnceVisitor):
44 '''Type visitors which generates serialization functions for
47 Simple types are serialized inline.
# `serializer` is kept so visitPolymorphic can emit per-case value
# serialization through it.
50 def __init__(self, serializer):
51 stdapi.OnceVisitor.__init__(self)
52 self.serializer = serializer
# Bodies of the next three visitors are not visible in this listing.
54 def visitVoid(self, literal):
57 def visitLiteral(self, literal):
60 def visitString(self, string):
# Const adds nothing of its own: recurse into the underlying type.
63 def visitConst(self, const):
64 self.visit(const.type)
# Prints a `_struct<tag>_members` array of member-name strings followed by
# a `_struct<tag>_sig` trace::StructSig initializer (id, name, member
# count, members array).
66 def visitStruct(self, struct):
67 print 'static const char * _struct%s_members[%u] = {' % (struct.tag, len(struct.members))
68 for type, name, in struct.members:
72 print ' "%s",' % (name,)
74 print 'static const trace::StructSig _struct%s_sig = {' % (struct.tag,)
# The branch taken when struct.name is None is not visible here; the
# visible assignment quotes the name as a C string literal.
75 if struct.name is None:
78 structName = '"%s"' % struct.name
79 print ' %u, %s, %u, _struct%s_members' % (struct.id, structName, len(struct.members), struct.tag)
# Arrays only need their element type visited.
83 def visitArray(self, array):
84 self.visit(array.type)
86 def visitBlob(self, array):
# Prints a `_enum<tag>_values` table of {"NAME", NAME} pairs and the
# matching `_enum<tag>_sig` (id, value count, table).
89 def visitEnum(self, enum):
90 print 'static const trace::EnumValue _enum%s_values[] = {' % (enum.tag)
91 for value in enum.values:
92 print ' {"%s", %s},' % (value, value)
95 print 'static const trace::EnumSig _enum%s_sig = {' % (enum.tag)
96 print ' %u, %u, _enum%s_values' % (enum.id, len(enum.values), enum.tag)
# Same layout as visitEnum, for bitmask flags.
100 def visitBitmask(self, bitmask):
101 print 'static const trace::BitmaskFlag _bitmask%s_flags[] = {' % (bitmask.tag)
102 for value in bitmask.values:
103 print ' {"%s", %s},' % (value, value)
106 print 'static const trace::BitmaskSig _bitmask%s_sig = {' % (bitmask.tag)
107 print ' %u, %u, _bitmask%s_flags' % (bitmask.id, len(bitmask.values), bitmask.tag)
# The wrapper-like types below simply recurse into the wrapped type.
111 def visitPointer(self, pointer):
112 self.visit(pointer.type)
114 def visitIntPointer(self, pointer):
117 def visitObjPointer(self, pointer):
118 self.visit(pointer.type)
120 def visitLinearPointer(self, pointer):
121 self.visit(pointer.type)
123 def visitHandle(self, handle):
124 self.visit(handle.type)
126 def visitReference(self, reference):
127 self.visit(reference.type)
129 def visitAlias(self, alias):
130 self.visit(alias.type)
132 def visitOpaque(self, opaque):
135 def visitInterface(self, interface):
# Prints a `_write__<tag>(int selector, const T & value)` helper that
# switches on `selector` and serializes `value` cast to each case type via
# self.serializer. What happens when `contextLess` is false is on a line
# not shown here (presumably an early return -- confirm).
138 def visitPolymorphic(self, polymorphic):
139 if not polymorphic.contextLess:
141 print 'static void _write__%s(int selector, const %s & value) {' % (polymorphic.tag, polymorphic.expr)
142 print ' switch (selector) {'
143 for cases, type in polymorphic.iterSwitch():
146 self.serializer.visit(type, 'static_cast<%s>(value)' % (type,))
# Visitor that prints C++ statements serializing the C++ expression
# `instance` (a string) through trace::localWriter, according to the
# visited type.
# NOTE(review): the inner line numbers of this listing have gaps; several
# `else:` lines, closing-brace prints and short assignments are not shown.
153 class ValueSerializer(stdapi.Visitor, stdapi.ExpanderMixin):
154 '''Visitor which generates code to serialize any type.
156 Simple types are serialized inline here, whereas the serialization of
157 complex types is dispatched to the serialization functions generated by
158 ComplexValueSerializer visitor above.
# Literal kinds map directly onto a `write<Kind>` writer method.
161 def visitLiteral(self, literal, instance):
162 print ' trace::localWriter.write%s(%s);' % (literal.kind, instance)
# Emits `write<suffix>(instance[, length])`. The condition choosing between
# the char / wchar_t casts and the assignments of `suffix` (and of the
# default `length`) are on lines not shown in this listing.
164 def visitString(self, string, instance):
166 cast = 'const char *'
169 cast = 'const wchar_t *'
171 if cast != string.expr:
172 # reinterpret_cast is necessary for GLubyte * <=> char *
173 instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
174 if string.length is not None:
175 length = ', %s' % self.expand(string.length)
178 print ' trace::localWriter.write%s(%s%s);' % (suffix, instance, length)
181 def visitConst(self, const, instance):
181 self.visit(const.type, instance)
# Brackets the members with beginStruct/endStruct, using the
# `_struct<tag>_sig` table name.
183 def visitStruct(self, struct, instance):
184 print ' trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,)
185 for member in struct.members:
186 self.visitMember(member, instance)
187 print ' trace::localWriter.endStruct();'
# Emits a null-checked loop over the array. The C++ length and index
# variable names are derived from the element type's tag, and the emitted
# length expression clamps non-positive lengths to 0. A writeNull() is
# emitted for the branch whose `else` line is not shown here.
189 def visitArray(self, array, instance):
190 length = '_c' + array.type.tag
191 index = '_i' + array.type.tag
192 array_length = self.expand(array.length)
193 print ' if (%s) {' % instance
194 print ' size_t %s = %s > 0 ? %s : 0;' % (length, array_length, array_length)
195 print ' trace::localWriter.beginArray(%s);' % length
196 print ' for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index)
197 print ' trace::localWriter.beginElement();'
198 self.visitElement(index, array.type, '(%s)[%s]' % (instance, index))
199 print ' trace::localWriter.endElement();'
201 print ' trace::localWriter.endArray();'
203 print ' trace::localWriter.writeNull();'
206 def visitBlob(self, blob, instance):
207 print ' trace::localWriter.writeBlob(%s, %s);' % (instance, self.expand(blob.size))
# Enums and bitmasks reference the signature tables by tag.
209 def visitEnum(self, enum, instance):
210 print ' trace::localWriter.writeEnum(&_enum%s_sig, %s);' % (enum.tag, instance)
212 def visitBitmask(self, bitmask, instance):
213 print ' trace::localWriter.writeBitmask(&_bitmask%s_sig, %s);' % (bitmask.tag, instance)
# A non-null pointer is serialized as a one-element array holding the
# dereferenced value; the other branch emits writeNull().
215 def visitPointer(self, pointer, instance):
216 print ' if (%s) {' % instance
217 print ' trace::localWriter.beginArray(1);'
218 print ' trace::localWriter.beginElement();'
219 self.visit(pointer.type, "*" + instance)
220 print ' trace::localWriter.endElement();'
221 print ' trace::localWriter.endArray();'
223 print ' trace::localWriter.writeNull();'
# These pointer flavours are written as raw addresses, without following
# the pointee.
226 def visitIntPointer(self, pointer, instance):
227 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
229 def visitObjPointer(self, pointer, instance):
230 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
232 def visitLinearPointer(self, pointer, instance):
233 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
# Reference, handle and alias defer to the underlying type.
235 def visitReference(self, reference, instance):
236 self.visit(reference.type, instance)
238 def visitHandle(self, handle, instance):
239 self.visit(handle.type, instance)
241 def visitAlias(self, alias, instance):
242 self.visit(alias.type, instance)
244 def visitOpaque(self, opaque, instance):
245 print ' trace::localWriter.writePointer((uintptr_t)%s);' % instance
# Body not visible in this listing.
247 def visitInterface(self, interface, instance):
# Context-less polymorphic values call the `_write__<tag>` helper;
# otherwise an inline switch over the expanded switch expression is
# emitted, casting `instance` to each case type that has an `expr`.
# When there is no default type, a warning log and writeNull() are emitted
# (the enclosing case label is on a line not shown here).
250 def visitPolymorphic(self, polymorphic, instance):
251 if polymorphic.contextLess:
252 print ' _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance)
254 switchExpr = self.expand(polymorphic.switchExpr)
255 print ' switch (%s) {' % switchExpr
256 for cases, type in polymorphic.iterSwitch():
259 caseInstance = instance
260 if type.expr is not None:
261 caseInstance = 'static_cast<%s>(%s)' % (type, caseInstance)
262 self.visit(type, caseInstance)
264 if polymorphic.defaultType is None:
266 print r' os::log("apitrace: warning: %%s: unexpected polymorphic case %%i\n", __FUNCTION__, (int)%s);' % (switchExpr,)
267 print r' trace::localWriter.writeNull();'
# Traverser that records, in `self.needsWrapping`, whether an interface was
# encountered anywhere in the visited type.
# NOTE(review): the `__init__` header and the body of visitLinearPointer
# are on lines not shown in this listing.
272 class WrapDecider(stdapi.Traverser):
273 '''Type visitor which will decide whether this type will need wrapping or not.
275 For complex types (arrays, structures), we need to know this beforehand.
# Initial state: nothing seen yet that requires wrapping.
279 self.needsWrapping = False
281 def visitLinearPointer(self, void):
# Any interface in the type means wrapping is required.
284 def visitInterface(self, interface):
285 self.needsWrapping = True
# Traverser that prints C++ replacing interface pointers reachable from the
# C++ expression `instance` (a string) with newly allocated wrapper
# objects.
# NOTE(review): the inner line numbers of this listing have gaps; the
# closing-brace prints and some `else:` lines are not shown.
288 class ValueWrapper(stdapi.Traverser, stdapi.ExpanderMixin):
289 '''Type visitor which will generate the code to wrap an instance.
291 Wrapping is necessary mostly for interfaces, however interface pointers can
292 appear anywhere inside complex types.
295 def visitStruct(self, struct, instance):
296 for member in struct.members:
297 self.visitMember(member, instance)
# Emits a null-checked loop (`_i` up to the expanded length) and visits
# each element expression `instance[_i]`.
299 def visitArray(self, array, instance):
300 array_length = self.expand(array.length)
301 print " if (%s) {" % instance
302 print " for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
303 self.visitElement('_i', array.type, instance + "[_i]")
# Null-checks the pointer and visits the dereferenced pointee.
307 def visitPointer(self, pointer, instance):
308 print " if (%s) {" % instance
309 self.visit(pointer.type, "*" + instance)
# Object pointers to an Interface (directly or through an Alias of one)
# go to visitInterfacePointer; anything else is handled like a plain
# pointer (the `else:` line is not shown here).
312 def visitObjPointer(self, pointer, instance):
313 elem_type = pointer.type.mutable()
314 if isinstance(elem_type, stdapi.Interface):
315 self.visitInterfacePointer(elem_type, instance)
316 elif isinstance(elem_type, stdapi.Alias) and isinstance(elem_type.type, stdapi.Interface):
317 self.visitInterfacePointer(elem_type.type, instance)
319 self.visitPointer(pointer, instance)
# Interfaces by value are not supported.
321 def visitInterface(self, interface, instance):
322 raise NotImplementedError
# Emits `instance = new Wrap<Iface>(instance);` guarded by a null check.
324 def visitInterfacePointer(self, interface, instance):
325 print " if (%s) {" % instance
326 print " %s = new %s(%s);" % (instance, getWrapperInterfaceName(interface), instance)
329 def visitPolymorphic(self, type, instance):
330 # XXX: There might be polymorphic values that need wrapping in the future
331 raise NotImplementedError
# Traverser that prints C++ replacing wrapper objects reachable from the
# C++ expression `instance` with the real interface pointers they hold.
# When the data has not yet been copied (`self.allocated` is false) a
# mutable copy is made with alloca() first.
# NOTE(review): the initial value of `allocated` and several `else:` /
# `try:` / closing-brace lines are not shown in this listing.
334 class ValueUnwrapper(ValueWrapper):
335 '''Reverse of ValueWrapper.'''
# On first entry copies the struct into an alloca'd `_t`, repoints the
# (dereferenced) `instance` at the copy, and then unwraps the members.
339 def visitStruct(self, struct, instance):
340 if not self.allocated:
341 # Argument is constant. We need to create a non const
343 print " %s * _t = static_cast<%s *>(alloca(sizeof *_t));" % (struct, struct)
344 print ' *_t = %s;' % (instance,)
# The copy can only be installed when `instance` is a `*ptr` expression.
345 assert instance.startswith('*')
346 print ' %s = _t;' % (instance[1:],)
348 self.allocated = True
350 return ValueWrapper.visitStruct(self, struct, instance)
354 return ValueWrapper.visitStruct(self, struct, instance)
# Copies the array into an alloca'd buffer element by element, unwraps
# each copied element, then repoints `instance` at the buffer.
# NOTE(review): `instance` is used as a string everywhere else in this
# class, so `isinstance(instance, stdapi.Interface)` looks like it can
# never be true -- confirm whether a type (e.g. array.type) was intended.
356 def visitArray(self, array, instance):
357 if self.allocated or isinstance(instance, stdapi.Interface):
358 return ValueWrapper.visitArray(self, array, instance)
359 array_length = self.expand(array.length)
360 elem_type = array.type.mutable()
361 print " if (%s && %s) {" % (instance, array_length)
362 print " %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array_length)
363 print " for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length
364 print " _t[_i] = %s[_i];" % instance
365 self.allocated = True
366 self.visit(array.type, "_t[_i]")
368 print " %s = _t;" % instance
# Emits a cast of `instance` to the wrapper class and, if the magic number
# matches 0xd8365d6c, replaces `instance` with the wrapped m_pInstance;
# a warning log is emitted for the non-matching case.
371 def visitInterfacePointer(self, interface, instance):
372 print r' if (%s) {' % instance
373 print r' const %s *pWrapper = static_cast<const %s*>(%s);' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), instance)
374 print r' if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {'
375 print r' %s = pWrapper->m_pInstance;' % (instance,)
377 print r' os::log("apitrace: warning: %%s: unexpected %%s pointer\n", __FUNCTION__, "%s");' % interface.name
383 '''Base class to orchestrate the code generation of API tracing.'''
388 def serializerFactory(self):
389 '''Create a serializer.
391 Can be overridden by derived classes to inject their own serializer.
# Default implementation: a fresh ValueSerializer on every call.
394 return ValueSerializer()
# Top-level driver for generating the tracing code of `api`.
# Visible steps: iterate the module headers, emit the complex-type
# serializer tables, emit the interface wrappers, then emit declarations
# for all functions followed by their implementations.
# NOTE(review): several lines of this method (including the loop body over
# headers) are not shown in this listing.
396 def traceApi(self, api):
402 for module in api.modules:
403 for header in module.headers:
407 # Generate the serializer functions
408 types = api.getAllTypes()
409 visitor = ComplexValueSerializer(self.serializerFactory())
# NOTE: map() is used only for its side effects; this relies on Python 2's
# eager map() (a lazy map object would never call visitor.visit).
410 map(visitor.visit, types)
414 self.traceInterfaces(api)
# Clear the current interface so later code paths (see wrapIid) know they
# are generating free functions, not wrapper methods.
417 self.interface = None
# All declarations are emitted before any implementation.
419 for function in api.getAllFunctions():
420 self.traceFunctionDecl(function)
421 for function in api.getAllFunctions():
422 self.traceFunctionImpl(function)
# Prints the preamble of the generated C++ file: a platform-dependent
# include providing alloca() (malloc.h with an `_alloca` fallback define on
# _WIN32, alloca.h in the other visible branch) and the "trace.hpp" include.
# NOTE(review): the `#else` / `#endif` prints are on lines not shown here.
427 def header(self, api):
428 print '#ifdef _WIN32'
429 print '# include <malloc.h> // alloca'
430 print '# ifndef alloca'
431 print '# define alloca _alloca'
434 print '# include <alloca.h> // alloca'
437 print '#include "trace.hpp"'
# Counterpart of header(); the body and the call site are not visible in
# this listing (presumably invoked at the end of traceApi -- confirm).
440 def footer(self, api):
# Prints the per-function static data for non-internal functions: a
# `_<name>_args` array of argument-name strings (or a NULL pointer in the
# other visible variant) and a `_<name>_sig` trace::FunctionSig holding the
# function id, name, argument count and the args array.
# NOTE(review): the condition selecting between the two `_args` variants is
# on a line not shown here.
443 def traceFunctionDecl(self, function):
444 # Per-function declarations
446 if not function.internal:
448 print 'static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), ', '.join(['"%s"' % arg.name for arg in function.args]))
450 print 'static const char ** _%s_args = NULL;' % (function.name,)
451 print 'static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, function.id, function.name, len(function.args), function.name)
# Predicate consulted by traceFunctionImpl to choose between the
# `extern "C" PUBLIC` and `extern "C" PRIVATE` linkage prefixes.
# The body is not visible in this listing.
454 def isFunctionPublic(self, function):
# Prints the C++ definition of the tracing entry point for `function`:
# linkage prefix, prototype, a `_result` local for non-void functions, a
# fast path taken when tracing is disabled, and then the traced body.
# NOTE(review): closing-brace prints and the `else:` line are not shown in
# this listing.
457 def traceFunctionImpl(self, function):
458 if self.isFunctionPublic(function):
459 print 'extern "C" PUBLIC'
461 print 'extern "C" PRIVATE'
462 print function.prototype() + ' {'
463 if function.type is not stdapi.Void:
464 print ' %s _result;' % function.type
466 # No-op if tracing is disabled
467 print ' if (!trace::isTracingEnabled()) {'
# Called on the Tracer class explicitly, so a subclass override of
# invokeFunction is not used on the tracing-disabled path.
468 Tracer.invokeFunction(self, function)
469 if function.type is not stdapi.Void:
470 print ' return _result;'
475 self.traceFunctionImplBody(function)
476 if function.type is not stdapi.Void:
477 print ' return _result;'
# Prints the traced body of `function`: the enter record (unwrapping and
# serializing arguments), the real call, then the leave record
# (serializing and wrapping arguments when the call succeeded, followed by
# the return value). Internal functions are invoked without any records.
# NOTE(review): the per-argument conditions inside both loops are on lines
# not shown here (presumably input vs output filtering -- confirm).
481 def traceFunctionImplBody(self, function):
482 if not function.internal:
483 print ' unsigned _call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
484 for arg in function.args:
486 self.unwrapArg(function, arg)
487 self.serializeArg(function, arg)
488 print ' trace::localWriter.endEnter();'
489 self.invokeFunction(function)
490 if not function.internal:
491 print ' trace::localWriter.beginLeave(_call);'
# Arguments in this loop are only recorded when the call succeeded.
492 print ' if (%s) {' % self.wasFunctionSuccessful(function)
493 for arg in function.args:
495 self.serializeArg(function, arg)
496 self.wrapArg(function, arg)
498 if function.type is not stdapi.Void:
499 self.serializeRet(function, "_result")
500 print ' trace::localWriter.endLeave();'
# The return value is wrapped only after it has been serialized.
501 if function.type is not stdapi.Void:
502 self.wrapRet(function, "_result")
# Prints the call to the real function, named `prefix + name + suffix`,
# passing the arguments by name. Non-void results are assigned to
# `_result`; the void branch's assignment of `result` is on a line not
# shown in this listing.
504 def invokeFunction(self, function, prefix='_', suffix=''):
505 if function.type is stdapi.Void:
508 result = '_result = '
509 dispatch = prefix + function.name + suffix
510 print ' %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args]))
# Returns a C++ boolean expression (as a string) telling whether the call
# succeeded. HRESULT-returning functions use `SUCCEEDED(_result)`; the
# values returned for void and for other return types are on lines not
# shown in this listing.
512 def wasFunctionSuccessful(self, function):
513 if function.type is stdapi.Void:
515 if str(function.type) == 'HRESULT':
516 return 'SUCCEEDED(_result)'
519 def serializeArg(self, function, arg):
520 print ' trace::localWriter.beginArg(%u);' % (arg.index,)
521 self.serializeArgValue(function, arg)
522 print ' trace::localWriter.endArg();'
def serializeArgValue(self, function, arg):
    '''Serialize the value of `arg`, using its declared type and its name
    as the C++ expression to read.'''
    value_type, value_expr = arg.type, arg.name
    self.serializeValue(value_type, value_expr)
# Prints the code wrapping an output argument after the call.
# If the function has an input argument of type REFIID and `arg` is a
# pointer to an object pointer, wrapping is delegated to wrapIid (the
# interface is selected at run time by the IID); otherwise the generic
# wrapValue path is used.
# NOTE(review): the initialisation and assignment of `riid`, and the early
# return after wrapIid, are on lines not shown in this listing.
527 def wrapArg(self, function, arg):
# Top-level object pointers are not expected here.
528 assert not isinstance(arg.type, stdapi.ObjPointer)
# Imported locally rather than at module level.
530 from specs.winapi import REFIID
532 for other_arg in function.args:
533 if not other_arg.output and other_arg.type is REFIID:
535 if riid is not None \
536 and isinstance(arg.type, stdapi.Pointer) \
537 and isinstance(arg.type.type, stdapi.ObjPointer):
538 self.wrapIid(function, riid, arg)
541 self.wrapValue(arg.type, arg.name)
def unwrapArg(self, function, arg):
    '''Unwrap the value of `arg`, using its declared type and its name as
    the C++ expression to modify.'''
    value_type, value_expr = arg.type, arg.name
    self.unwrapValue(value_type, value_expr)
546 def serializeRet(self, function, instance):
547 print ' trace::localWriter.beginReturn();'
548 self.serializeValue(function.type, instance)
549 print ' trace::localWriter.endReturn();'
def serializeValue(self, type, instance):
    '''Serialize the C++ expression `instance` of the given type with a
    serializer freshly obtained from serializerFactory().'''
    self.serializerFactory().visit(type, instance)
def wrapRet(self, function, instance):
    '''Wrap the return value expression `instance` according to the
    function's return type.'''
    result_type = function.type
    self.wrapValue(result_type, instance)
def unwrapRet(self, function, instance):
    '''Unwrap the return value expression `instance` according to the
    function's return type.'''
    result_type = function.type
    self.unwrapValue(result_type, instance)
# Returns the `needsWrapping` flag of a fresh WrapDecider for `type`.
# NOTE(review): the line between creating the visitor and returning its
# flag is not shown in this listing (presumably the visit of `type`).
561 def needsWrapping(self, type):
562 visitor = WrapDecider()
564 return visitor.needsWrapping
def wrapValue(self, type, instance):
    '''Print wrapping code for the C++ expression `instance`, but only when
    needsWrapping() reports that the type requires it.'''
    if not self.needsWrapping(type):
        return
    ValueWrapper().visit(type, instance)
def unwrapValue(self, type, instance):
    '''Print unwrapping code for the C++ expression `instance`, but only
    when needsWrapping() reports that the type requires it.'''
    if not self.needsWrapping(type):
        return
    ValueUnwrapper().visit(type, instance)
# Emits everything related to interface wrappers, in order: the wrapper
# class declarations, the IID-based wrapping helpers, then the wrapper
# class implementations.
# NOTE(review): two lines after fetching `interfaces` are not shown here
# (presumably an early exit when there are none -- confirm).
576 def traceInterfaces(self, api):
577 interfaces = api.getAllInterfaces()
# NOTE: map() is used only for its side effects; this relies on Python 2's
# eager map().
580 map(self.declareWrapperInterface, interfaces)
581 self.implementIidWrapper(api)
582 map(self.implementWrapperInterface, interfaces)
# Prints the C++ declaration of the wrapper class for `interface`: it
# derives publicly from the interface, has a constructor taking the wrapped
# instance, a virtual destructor, one prototype per interface method, and
# the member variables listed by enumWrapperInterfaceVariables.
# NOTE(review): the prints for braces and access specifiers are on lines
# not shown in this listing.
585 def declareWrapperInterface(self, interface):
586 print "class %s : public %s " % (getWrapperInterfaceName(interface), interface.name)
589 print " %s(%s * pInstance);" % (getWrapperInterfaceName(interface), interface.name)
590 print " virtual ~%s();" % getWrapperInterfaceName(interface)
592 for method in interface.iterMethods():
593 print " " + method.prototype() + ";"
# Only type and name are needed for the declaration; the initial value is
# used by the constructor emitted in implementWrapperInterface.
596 for type, name, value in self.enumWrapperInterfaceVariables(interface):
597 print ' %s %s;' % (type, name)
# Describes the wrapper class member variables as (C++ type, member name,
# initial value expression) tuples: a magic number tag and the pointer to
# the wrapped instance. The magic value matches the one checked by
# ValueUnwrapper.visitInterfacePointer.
# NOTE(review): the enclosing `return [` / `]` lines are not shown here.
601 def enumWrapperInterfaceVariables(self, interface):
603 ("DWORD", "m_dwMagic", "0xd8365d6c"),
604 ("%s *" % interface.name, "m_pInstance", "pInstance"),
# Prints the C++ implementation of the wrapper class for `interface`: the
# constructor (initialising each member from
# enumWrapperInterfaceVariables), the destructor, and one method per
# (base, method) pair from iterBaseMethods().
# NOTE(review): closing-brace prints and the destructor body are on lines
# not shown in this listing.
607 def implementWrapperInterface(self, interface):
# Remember the interface being generated; wrapIid reads self.interface.
608 self.interface = interface
610 print '%s::%s(%s * pInstance) {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface), interface.name)
611 for type, name, value in self.enumWrapperInterfaceVariables(interface):
612 print ' %s = %s;' % (name, value)
615 print '%s::~%s() {' % (getWrapperInterfaceName(interface), getWrapperInterfaceName(interface))
619 for base, method in interface.iterBaseMethods():
621 self.implementWrapperInterfaceMethod(interface, base, method)
# Prints the C++ definition of one wrapper method: the qualified
# prototype, a `_result` local and final `return _result;` for non-void
# methods, and the traced body in between.
# NOTE(review): the closing-brace print is on a line not shown here.
625 def implementWrapperInterfaceMethod(self, interface, base, method):
626 print method.prototype(getWrapperInterfaceName(interface) + '::' + method.name) + ' {'
627 if method.type is not stdapi.Void:
628 print ' %s _result;' % method.type
630 self.implementWrapperInterfaceMethodBody(interface, base, method)
632 if method.type is not stdapi.Void:
633 print ' return _result;'
# Prints the traced body of one wrapper method. Mirrors
# traceFunctionImplBody, with an extra leading "this" argument (recorded as
# the wrapped instance pointer) and the call dispatched through `_this`.
# NOTE(review): the per-argument conditions inside both loops and some
# closing-brace prints are on lines not shown in this listing.
637 def implementWrapperInterfaceMethodBody(self, interface, base, method):
638 assert not method.internal
# Signature data is local to the generated method; argument 0 is "this",
# hence the `+ 1` on the argument count.
640 print ' static const char * _args[%u] = {%s};' % (len(method.args) + 1, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args]))
641 print ' static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (method.id, interface.name + '::' + method.name, len(method.args) + 1)
# The wrapped instance is cast to the base interface declaring the method.
643 print ' %s *_this = static_cast<%s *>(m_pInstance);' % (base, base)
645 print ' unsigned _call = trace::localWriter.beginEnter(&_sig);'
646 print ' trace::localWriter.beginArg(0);'
647 print ' trace::localWriter.writePointer((uintptr_t)m_pInstance);'
648 print ' trace::localWriter.endArg();'
649 for arg in method.args:
651 self.unwrapArg(method, arg)
652 self.serializeArg(method, arg)
653 print ' trace::localWriter.endEnter();'
655 self.invokeMethod(interface, base, method)
657 print ' trace::localWriter.beginLeave(_call);'
# Arguments in this loop are only recorded when the call succeeded.
659 print ' if (%s) {' % self.wasFunctionSuccessful(method)
660 for arg in method.args:
662 self.serializeArg(method, arg)
663 self.wrapArg(method, arg)
666 if method.type is not stdapi.Void:
667 self.serializeRet(method, '_result')
668 print ' trace::localWriter.endLeave();'
669 if method.type is not stdapi.Void:
670 self.wrapRet(method, '_result')
# For Release, the generated code deletes the wrapper itself when the
# call's result is zero.
672 if method.name == 'Release':
673 assert method.type is not stdapi.Void
674 print ' if (!_result)'
675 print ' delete this;'
# Prints two C++ helpers:
#  - warnIID(functionName, riid, reason): logs a warning containing the
#    function name, the reason and the IID fields;
#  - wrapIID(functionName, riid, ppvObj): compares `riid` against the IID
#    of every interface of `api` and replaces *ppvObj with a new wrapper of
#    the matching class, or calls warnIID with "unknown" when none matches.
# NOTE(review): return-type lines, closing braces and the updates of
# `else_` are on lines not shown in this listing.
677 def implementIidWrapper(self, api):
679 print r'warnIID(const char *functionName, REFIID riid, const char *reason) {'
680 print r' os::log("apitrace: warning: %s: %s IID {0x%08lX,0x%04X,0x%04X,{0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X,0x%02X}}\n",'
681 print r' functionName, reason,'
682 print r' riid.Data1, riid.Data2, riid.Data3,'
683 print r' riid.Data4[0], riid.Data4[1], riid.Data4[2], riid.Data4[3], riid.Data4[4], riid.Data4[5], riid.Data4[6], riid.Data4[7]);'
687 print r'wrapIID(const char *functionName, REFIID riid, void * * ppvObj) {'
# The generated helper checks for a null output first.
688 print r' if (!ppvObj || !*ppvObj) {'
# One `if (riid == IID_<name>)` branch per known interface; `else_` is the
# prefix chaining the branches together.
692 for iface in api.getAllInterfaces():
693 print r' %sif (riid == IID_%s) {' % (else_, iface.name)
694 print r' *ppvObj = new Wrap%s((%s *) *ppvObj);' % (iface.name, iface.name)
697 print r' %s{' % else_
698 print r' warnIID(functionName, riid, "unknown");'
# Prints the code wrapping output argument `out`, whose interface type is
# selected at run time by the IID argument `riid`.
# Inside a wrapper method (self.interface set), a result equal to the
# wrapped instance for one of the interface's own base IIDs is replaced by
# `this`; every other case goes through the generated wrapIID() helper.
# NOTE(review): the initial assignments of `out_name` and `else_`, and the
# closing-brace prints, are on lines not shown in this listing.
703 def wrapIid(self, function, riid, out):
704 # Cast output arg to `void **` if necessary
# `out` is a pointer to an object pointer; this is the pointed-to object
# type.
706 obj_type = out.type.type.type
# `not x is y` parses as `x is not y`.
707 if not obj_type is stdapi.Void:
708 assert isinstance(obj_type, stdapi.Interface)
709 out_name = 'reinterpret_cast<void * *>(%s)' % out_name
711 print r' if (%s && *%s) {' % (out.name, out.name)
712 functionName = function.name
714 if self.interface is not None:
# Qualify the name reported in warnings with the interface name.
715 functionName = self.interface.name + '::' + functionName
716 print r' if (*%s == m_pInstance &&' % (out_name,)
717 print r' (%s)) {' % ' || '.join('%s == IID_%s' % (riid.name, iface.name) for iface in self.interface.iterBases())
718 print r' *%s = this;' % (out_name,)
721 print r' %s{' % else_
722 print r' wrapIID("%s", %s, %s);' % (functionName, riid.name, out_name)
# Prints the call to the real method through `_this`, passing the
# arguments by name. Non-void results are assigned to `_result`; the void
# branch's assignment of `result` is on a line not shown in this listing.
726 def invokeMethod(self, interface, base, method):
727 if method.type is stdapi.Void:
730 result = '_result = '
731 print ' %s_this->%s(%s);' % (result, method.name, ', '.join([str(arg.name) for arg in method.args]))
733 def emit_memcpy(self, dest, src, length):
734 print ' unsigned _call = trace::localWriter.beginEnter(&trace::memcpy_sig);'
735 print ' trace::localWriter.beginArg(0);'
736 print ' trace::localWriter.writePointer((uintptr_t)%s);' % dest
737 print ' trace::localWriter.endArg();'
738 print ' trace::localWriter.beginArg(1);'
739 print ' trace::localWriter.writeBlob(%s, %s);' % (src, length)
740 print ' trace::localWriter.endArg();'
741 print ' trace::localWriter.beginArg(2);'
742 print ' trace::localWriter.writeUInt(%s);' % length
743 print ' trace::localWriter.endArg();'
744 print ' trace::localWriter.endEnter();'
745 print ' trace::localWriter.beginLeave(_call);'
746 print ' trace::localWriter.endLeave();'
748 def fake_call(self, function, args):
749 print ' unsigned _fake_call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,)
750 for arg, instance in zip(function.args, args):
751 assert not arg.output
752 print ' trace::localWriter.beginArg(%u);' % (arg.index,)
753 self.serializeValue(arg.type, instance)
754 print ' trace::localWriter.endArg();'
755 print ' trace::localWriter.endEnter();'
756 print ' trace::localWriter.beginLeave(_fake_call);'
757 print ' trace::localWriter.endLeave();'