]> git.notmuchmail.org Git - apitrace/blob - retrace.py
96a09abc1fcd623048453a578d202d612c6e10e3
[apitrace] / retrace.py
1 ##########################################################################
2 #
3 # Copyright 2010 VMware, Inc.
4 # All Rights Reserved.
5 #
6 # Permission is hereby granted, free of charge, to any person obtaining a copy
7 # of this software and associated documentation files (the "Software"), to deal
8 # in the Software without restriction, including without limitation the rights
9 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 # copies of the Software, and to permit persons to whom the Software is
11 # furnished to do so, subject to the following conditions:
12 #
13 # The above copyright notice and this permission notice shall be included in
14 # all copies or substantial portions of the Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 # THE SOFTWARE.
23 #
24 ##########################################################################/
25
26
27 """Generic retracing code generator."""
28
29
30 import sys
31
32 import specs.stdapi as stdapi
33 import specs.glapi as glapi
34
35
36 class ConstRemover(stdapi.Rebuilder):
37     '''Type visitor which strips out const qualifiers from types.'''
38
39     def visitConst(self, const):
40         return const.type
41
42     def visitOpaque(self, opaque):
43         return opaque
44
45
46 def lookupHandle(handle, value):
47     if handle.key is None:
48         return "__%s_map[%s]" % (handle.name, value)
49     else:
50         key_name, key_type = handle.key
51         return "__%s_map[%s][%s]" % (handle.name, key_name, value)
52
53
54 class ValueDeserializer(stdapi.Visitor):
55
56     def visitLiteral(self, literal, lvalue, rvalue):
57         print '    %s = (%s).to%s();' % (lvalue, rvalue, literal.kind)
58
59     def visitConst(self, const, lvalue, rvalue):
60         self.visit(const.type, lvalue, rvalue)
61
62     def visitAlias(self, alias, lvalue, rvalue):
63         self.visit(alias.type, lvalue, rvalue)
64     
65     def visitEnum(self, enum, lvalue, rvalue):
66         print '    %s = (%s).toSInt();' % (lvalue, rvalue)
67
68     def visitBitmask(self, bitmask, lvalue, rvalue):
69         self.visit(bitmask.type, lvalue, rvalue)
70
71     def visitArray(self, array, lvalue, rvalue):
72         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
73         print '    if (__a%s) {' % (array.tag)
74         length = '__a%s->values.size()' % array.tag
75         print '        %s = new %s[%s];' % (lvalue, array.type, length)
76         index = '__j' + array.tag
77         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
78         try:
79             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
80         finally:
81             print '        }'
82             print '    } else {'
83             print '        %s = NULL;' % lvalue
84             print '    }'
85     
86     def visitPointer(self, pointer, lvalue, rvalue):
87         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
88         print '    if (__a%s) {' % (pointer.tag)
89         print '        %s = new %s;' % (lvalue, pointer.type)
90         try:
91             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
92         finally:
93             print '    } else {'
94             print '        %s = NULL;' % lvalue
95             print '    }'
96
97     def visitHandle(self, handle, lvalue, rvalue):
98         OpaqueValueDeserializer().visit(handle.type, lvalue, rvalue);
99         new_lvalue = lookupHandle(handle, lvalue)
100         print '    if (retrace::verbosity >= 2) {'
101         print '        std::cout << "%s " << size_t(%s) << " <- " << size_t(%s) << "\\n";' % (handle.name, lvalue, new_lvalue)
102         print '    }'
103         print '    %s = %s;' % (lvalue, new_lvalue)
104     
105     def visitBlob(self, blob, lvalue, rvalue):
106         print '    %s = static_cast<%s>((%s).toPointer());' % (lvalue, blob, rvalue)
107     
108     def visitString(self, string, lvalue, rvalue):
109         print '    %s = (%s)((%s).toString());' % (lvalue, string.expr, rvalue)
110
111
112 class OpaqueValueDeserializer(ValueDeserializer):
113     '''Value extractor that also understands opaque values.
114
115     Normally opaque values can't be retraced, unless they are being extracted
116     in the context of handles.'''
117
118     def visitOpaque(self, opaque, lvalue, rvalue):
119         print '    %s = static_cast<%s>(retrace::toPointer(%s));' % (lvalue, opaque, rvalue)
120
121
122 class SwizzledValueRegistrator(stdapi.Visitor):
123     '''Type visitor which will register (un)swizzled value pairs, to later be
124     swizzled.'''
125
126     def visitLiteral(self, literal, lvalue, rvalue):
127         pass
128
129     def visitAlias(self, alias, lvalue, rvalue):
130         self.visit(alias.type, lvalue, rvalue)
131     
132     def visitEnum(self, enum, lvalue, rvalue):
133         pass
134
135     def visitBitmask(self, bitmask, lvalue, rvalue):
136         pass
137
138     def visitArray(self, array, lvalue, rvalue):
139         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (array.tag, rvalue)
140         print '    if (__a%s) {' % (array.tag)
141         length = '__a%s->values.size()' % array.tag
142         index = '__j' + array.tag
143         print '        for (size_t {i} = 0; {i} < {length}; ++{i}) {{'.format(i = index, length = length)
144         try:
145             self.visit(array.type, '%s[%s]' % (lvalue, index), '*__a%s->values[%s]' % (array.tag, index))
146         finally:
147             print '        }'
148             print '    }'
149     
150     def visitPointer(self, pointer, lvalue, rvalue):
151         print '    const trace::Array *__a%s = dynamic_cast<const trace::Array *>(&%s);' % (pointer.tag, rvalue)
152         print '    if (__a%s) {' % (pointer.tag)
153         try:
154             self.visit(pointer.type, '%s[0]' % (lvalue,), '*__a%s->values[0]' % (pointer.tag,))
155         finally:
156             print '    }'
157     
158     def visitHandle(self, handle, lvalue, rvalue):
159         print '    %s __orig_result;' % handle.type
160         OpaqueValueDeserializer().visit(handle.type, '__orig_result', rvalue);
161         if handle.range is None:
162             rvalue = "__orig_result"
163             entry = lookupHandle(handle, rvalue) 
164             print "    %s = %s;" % (entry, lvalue)
165             print '    if (retrace::verbosity >= 2) {'
166             print '        std::cout << "{handle.name} " << {rvalue} << " -> " << {lvalue} << "\\n";'.format(**locals())
167             print '    }'
168         else:
169             i = '__h' + handle.tag
170             lvalue = "%s + %s" % (lvalue, i)
171             rvalue = "__orig_result + %s" % (i,)
172             entry = lookupHandle(handle, rvalue) 
173             print '    for ({handle.type} {i} = 0; {i} < {handle.range}; ++{i}) {{'.format(**locals())
174             print '        {entry} = {lvalue};'.format(**locals())
175             print '        if (retrace::verbosity >= 2) {'
176             print '            std::cout << "{handle.name} " << ({rvalue}) << " -> " << ({lvalue}) << "\\n";'.format(**locals())
177             print '        }'
178             print '    }'
179     
180     def visitBlob(self, blob, lvalue, rvalue):
181         pass
182     
183     def visitString(self, string, lvalue, rvalue):
184         pass
185
186
187 class Retracer:
188
189     def retraceFunction(self, function):
190         print 'static void retrace_%s(trace::Call &call) {' % function.name
191         self.retraceFunctionBody(function)
192         print '}'
193         print
194
195     def retraceFunctionBody(self, function):
196         if not function.sideeffects:
197             print '    (void)call;'
198             return
199
200         success = True
201         for arg in function.args:
202             arg_type = ConstRemover().visit(arg.type)
203             #print '    // %s ->  %s' % (arg.type, arg_type)
204             print '    %s %s;' % (arg_type, arg.name)
205             rvalue = 'call.arg(%u)' % (arg.index,)
206             lvalue = arg.name
207             try:
208                 self.extractArg(function, arg, arg_type, lvalue, rvalue)
209             except NotImplementedError:
210                 success = False
211                 print '    %s = 0; // FIXME' % arg.name
212         if not success:
213             print '    if (1) {'
214             self.failFunction(function)
215             print '    }'
216         self.invokeFunction(function)
217         for arg in function.args:
218             if arg.output:
219                 arg_type = ConstRemover().visit(arg.type)
220                 rvalue = 'call.arg(%u)' % (arg.index,)
221                 lvalue = arg.name
222                 try:
223                     self.regiterSwizzledValue(arg_type, lvalue, rvalue)
224                 except NotImplementedError:
225                     print '    // XXX: %s' % arg.name
226         if function.type is not stdapi.Void:
227             rvalue = '*call.ret'
228             lvalue = '__result'
229             try:
230                 self.regiterSwizzledValue(function.type, lvalue, rvalue)
231             except NotImplementedError:
232                 print '    // XXX: result'
233         if not success:
234             if function.name[-1].islower():
235                 sys.stderr.write('warning: unsupported %s call\n' % function.name)
236
237     def failFunction(self, function):
238         print '    if (retrace::verbosity >= 0) {'
239         print '        retrace::unsupported(call);'
240         print '    }'
241         print '    return;'
242
243     def extractArg(self, function, arg, arg_type, lvalue, rvalue):
244         ValueDeserializer().visit(arg_type, lvalue, rvalue)
245     
246     def extractOpaqueArg(self, function, arg, arg_type, lvalue, rvalue):
247         OpaqueValueDeserializer().visit(arg_type, lvalue, rvalue)
248
249     def regiterSwizzledValue(self, type, lvalue, rvalue):
250         visitor = SwizzledValueRegistrator()
251         visitor.visit(type, lvalue, rvalue)
252
253     def invokeFunction(self, function):
254         arg_names = ", ".join([arg.name for arg in function.args])
255         if function.type is not stdapi.Void:
256             print '    %s __result;' % (function.type)
257             print '    __result = %s(%s);' % (function.name, arg_names)
258             print '    (void)__result;'
259         else:
260             print '    %s(%s);' % (function.name, arg_names)
261
262     def filterFunction(self, function):
263         return True
264
265     table_name = 'retrace::callbacks'
266
267     def retraceFunctions(self, functions):
268         functions = filter(self.filterFunction, functions)
269
270         for function in functions:
271             self.retraceFunction(function)
272
273         print 'const retrace::Entry %s[] = {' % self.table_name
274         for function in functions:
275             print '    {"%s", &retrace_%s},' % (function.name, function.name)
276         print '    {NULL, NULL}'
277         print '};'
278         print
279
280
281     def retraceApi(self, api):
282
283         print '#include "trace_parser.hpp"'
284         print '#include "retrace.hpp"'
285         print
286
287         types = api.all_types()
288         handles = [type for type in types if isinstance(type, stdapi.Handle)]
289         handle_names = set()
290         for handle in handles:
291             if handle.name not in handle_names:
292                 if handle.key is None:
293                     print 'static retrace::map<%s> __%s_map;' % (handle.type, handle.name)
294                 else:
295                     key_name, key_type = handle.key
296                     print 'static std::map<%s, retrace::map<%s> > __%s_map;' % (key_type, handle.type, handle.name)
297                 handle_names.add(handle.name)
298         print
299
300         self.retraceFunctions(api.functions)
301