blob: d3c11266b0588efdca874e0767ba11fd82f65545 [file] [log] [blame]
Greg Steindb84a102009-01-31 07:40:26 +00001#
# Simple generator for Thrift: renders a parsed program through EZT templates.
3#
4
5import sys
6import os
7import cStringIO
8import operator
9
10import parser
11import ezt
12
13
### temporary: machine-specific absolute path to the template directory
PATH = '/Users/gstein/src/asf/thrift/compiler/py/src/templates-py'


def _load_template(filename):
    """Load the EZT template `filename` from PATH (compress_whitespace off)."""
    return ezt.Template(os.path.join(PATH, filename),
                        compress_whitespace=False)


# Templates used by this module: the top-level one, the per-field
# serializer/deserializer snippets, and the constant-value snippet.
t_py = _load_template('py.ezt')
t_py_ser = _load_template('py_ser.ezt')
t_py_deser = _load_template('py_deser.ezt')
t_py_cvalue = _load_template('py_cvalue.ezt')
24
25
def generate(program):
    """Render `program` through the top-level template, writing to stdout.

    The program is wrapped in a Proxy before being handed to the template.
    """
    wrapped = Proxy(program)
    t_py.generate(sys.stdout, wrapped)
28
29
class AutoVars(object):
    """Hands out generated variable names on attribute access.

    Reading an attribute (e.g. ``auto.foo``) returns a string made of the
    attribute name followed by the current counter value (``'foo0'``).
    Within one context, reading the same attribute again returns the same
    string.  open_context() switches to a fresh, empty mapping and
    close_context() restores the previous one.  The counter is shared by
    all contexts, so each newly generated name gets its own counter value.
    """

    def __init__(self):
        self._counter = 0   # suffix for the next generated name; never reset
        self._mapping = {}  # attribute name -> generated name, current context
        self._saved = []    # stack of mappings belonging to outer contexts

    def open_context(self):
        """Save the current mapping and continue with an empty one."""
        self._saved.append(self._mapping)
        self._mapping = {}

    def close_context(self):
        """Restore the mapping saved by the matching open_context().

        Raises IndexError when no context is open.
        """
        self._mapping = self._saved.pop()

    def __getattr__(self, name):
        """Return the generated name for `name`, creating it on first use.

        Names starting with a double underscore raise AttributeError
        instead (presumably so that special-attribute probes are not
        answered with a generated variable name).
        """
        if name.startswith('__'):
            raise AttributeError(name)

        if name in self._mapping:
            return self._mapping[name]
        var = '%s%d' % (name, self._counter)
        self._counter += 1
        self._mapping[name] = var
        return var


# Shared instance: exposed to templates as the 'auto' attribute of every
# Proxy, and used by Subtemplate to name result variables.  Other code in
# this module refers to it by this global name.
g_auto = AutoVars()
53
54
class Proxy(object):
    """Template-facing wrapper around a parsed object.

    Every instance attribute of the wrapped object is copied onto the
    proxy, with these substitutions:

      * a value for which custom_proxy() returns a class is wrapped in
        that class;
      * a non-empty list is wrapped item by item, using the class that
        custom_proxy() returns for its first item, or plain Proxy when
        the first item has a __dict__;
      * anything else is copied unchanged.

    The extra attribute 'auto' resolves to the shared g_auto object.
    """

    def __init__(self, ob):
        self._ob = ob

        for attr, val in vars(ob).items():
            wrapper = custom_proxy(val)
            if wrapper:
                val = wrapper(val)
            elif isinstance(val, list) and val:
                # lists are homogenous, so the first item decides for all
                head = val[0]
                wrapper = custom_proxy(head)
                if not wrapper and hasattr(head, '__dict__'):
                    wrapper = Proxy
                if wrapper:
                    val = [wrapper(item) for item in val]
            setattr(self, attr, val)

    def __getattr__(self, name):
        # only reached for attributes that were not copied in __init__
        if name != 'auto':
            raise AttributeError(name)
        return g_auto
76
77
class ProxyFieldType(Proxy):
    """Proxy for a field type; adds 'serializer' and 'deserializer'.

    Each of the two attributes yields a new Subtemplate bound to this
    proxy and to the matching template (t_py_ser / t_py_deser).
    """

    def __getattr__(self, name):
        if name == 'serializer':
            template = t_py_ser
        elif name == 'deserializer':
            template = t_py_deser
        else:
            return Proxy.__getattr__(self, name)
        return Subtemplate(template, self)
85
86
class Subtemplate(object):
    """Renders a template on attribute access.

    Reading any attribute renders the stored template against the stored
    data object and returns the generated text.  The attribute name picks
    the result variable: it is looked up on g_auto and the resulting name
    is stored on the data object as 'result_var' before rendering.
    """

    def __init__(self, template, data):
        self._template = template
        self._data = data

    def __getattr__(self, name):
        # jam the name of the result variable into the data params
        self._data.result_var = getattr(g_auto, name)

        # use a new variable context for this template generation; the
        # finally clause restores the previous context even when the
        # generation raises, so the context stack stays balanced
        g_auto.open_context()
        try:
            return gen_value(self._template, self._data)
        finally:
            g_auto.close_context()
102
103
class ProxyField(Proxy):
    """Proxy for a field; adds the computed 'type_enum' attribute.

    'type_enum' is the TYPE_ENUM entry for the field type's ident, or the
    ident's 'tvalue' attribute when the ident is not listed there.
    """

    def __getattr__(self, name):
        if name == 'type_enum':
            ident = self._ob.field_type.ident
            try:
                return TYPE_ENUM[ident]
            except KeyError:
                # Only read 'tvalue' on a miss: passing it as the default
                # of dict.get() would evaluate it for every ident, which
                # fails for any listed ident that lacks the attribute.
                return ident.tvalue
        return Proxy.__getattr__(self, name)
110
111
class ProxyStruct(Proxy):
    """Proxy for a struct; adds the computed 'sorted_fields' attribute.

    'sorted_fields' is a list indexed by field id: slot N holds a
    ProxyField for the field whose id is N, and None where no field has
    that id.  Fields without an id, or with an id <= 0, are left out.
    The list is empty when no field has a positive id, including when
    the struct has no fields at all.
    """

    def __getattr__(self, name):
        if name == 'sorted_fields':
            # -1 stands in for fields without an id.  Seeding the list
            # with it keeps max() from raising ValueError on a struct
            # that has no fields.
            ids = [-1]
            ids.extend(int(f.field_id or -1) for f in self._ob.fields)
            slots = [None] * (max(ids) + 1)
            for field in self._ob.fields:
                if field.field_id:
                    field_id = int(field.field_id)
                    if field_id > 0:
                        slots[field_id] = ProxyField(field)
            return slots
        return Proxy.__getattr__(self, name)
124
125
class ProxyConstValue(Proxy):
    """Proxy for a constant value; adds the computed 'cvalue' attribute.

    'cvalue' is the text produced by rendering this proxy through the
    t_py_cvalue template.
    """

    def __getattr__(self, name):
        if name != 'cvalue':
            return Proxy.__getattr__(self, name)
        return gen_value(t_py_cvalue, self)
131
132
def custom_proxy(value):
    """Return the dedicated Proxy subclass for `value`, or None.

    The parser types are checked in the order listed below; the first
    one that `value` is an instance of decides the result.
    """
    dispatch = (
        (parser.FieldType, ProxyFieldType),
        (parser.Field, ProxyField),
        (parser.Struct, ProxyStruct),
        (parser.ConstValue, ProxyConstValue),
    )
    for parsed_type, proxy_class in dispatch:
        if isinstance(value, parsed_type):
            return proxy_class
    return None
143
144
# Maps a parser type ident to the TType constant name used for it.
# Idents not listed here are handled by the caller (see ProxyField).
TYPE_ENUM = dict([
    (parser.ID_STRING, 'TType.STRING'),
    (parser.ID_BOOL, 'TType.BOOL'),
    (parser.ID_BYTE, 'TType.BYTE'),
    (parser.ID_I16, 'TType.I16'),
    (parser.ID_I32, 'TType.I32'),
    (parser.ID_I64, 'TType.I64'),
    (parser.ID_DOUBLE, 'TType.DOUBLE'),
    (parser.ID_MAP, 'TType.MAP'),
    (parser.ID_SET, 'TType.SET'),
    (parser.ID_LIST, 'TType.LIST'),
    # TType.STRUCT and TType.I32 for enums
])
158
159
160def gen_value(template, ob):
161 buf = cStringIO.StringIO()
162 template.generate(buf, ob)
163 return buf.getvalue()
164
165
if __name__ == '__main__':
    # Script entry point: parse the file named by the first command-line
    # argument and write the generated code to stdout.
    if len(sys.argv) < 2:
        sys.stderr.write('usage: %s THRIFT_FILE\n' % sys.argv[0])
        sys.exit(1)

    # read the whole input, then close the file explicitly instead of
    # leaving the handle to the garbage collector
    source = open(sys.argv[1])
    try:
        text = source.read()
    finally:
        source.close()

    program = parser.parse(text)
    generate(program)
169 generate(program)