blob: 0e7aa1a25331343b8426fca51e4d79a87bbb14aa [file] [log] [blame]
Bryan Duxburyc0166282009-02-02 00:48:17 +00001
#include <struct.h>
#include <constants.h>

#include <string.h>
4
5static native_proto_method_table *mt;
6
7#define IS_CONTAINER(ttype) ((ttype) == TTYPE_MAP || (ttype) == TTYPE_LIST || (ttype) == TTYPE_SET)
8#define STRUCT_FIELDS(obj) rb_const_get(CLASS_OF(obj), fields_const_id)
9
10//-------------------------------------------
11// Writing section
12//-------------------------------------------
13
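// Default method-table entries (used for protocols without a native table): each forwards one read or write to the matching Ruby method on the protocol object.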
15
16VALUE default_write_bool(VALUE protocol, VALUE value) {
17 rb_funcall(protocol, write_boolean_method_id, 1, value);
18 return Qnil;
19}
20
21VALUE default_write_byte(VALUE protocol, VALUE value) {
22 rb_funcall(protocol, write_byte_method_id, 1, value);
23 return Qnil;
24}
25
26VALUE default_write_i16(VALUE protocol, VALUE value) {
27 rb_funcall(protocol, write_i16_method_id, 1, value);
28 return Qnil;
29}
30
31VALUE default_write_i32(VALUE protocol, VALUE value) {
32 rb_funcall(protocol, write_i32_method_id, 1, value);
33 return Qnil;
34}
35
36VALUE default_write_i64(VALUE protocol, VALUE value) {
37 rb_funcall(protocol, write_i64_method_id, 1, value);
38 return Qnil;
39}
40
41VALUE default_write_double(VALUE protocol, VALUE value) {
42 rb_funcall(protocol, write_double_method_id, 1, value);
43 return Qnil;
44}
45
46VALUE default_write_string(VALUE protocol, VALUE value) {
47 rb_funcall(protocol, write_string_method_id, 1, value);
48 return Qnil;
49}
50
51VALUE default_write_list_begin(VALUE protocol, VALUE etype, VALUE length) {
52 rb_funcall(protocol, write_list_begin_method_id, 2, etype, length);
53 return Qnil;
54}
55
56VALUE default_write_list_end(VALUE protocol) {
57 rb_funcall(protocol, write_list_end_method_id, 0);
58 return Qnil;
59}
60
61VALUE default_write_set_begin(VALUE protocol, VALUE etype, VALUE length) {
62 rb_funcall(protocol, write_set_begin_method_id, 2, etype, length);
63 return Qnil;
64}
65
66VALUE default_write_set_end(VALUE protocol) {
67 rb_funcall(protocol, write_set_end_method_id, 0);
68 return Qnil;
69}
70
71VALUE default_write_map_begin(VALUE protocol, VALUE ktype, VALUE vtype, VALUE length) {
72 rb_funcall(protocol, write_map_begin_method_id, 3, ktype, vtype, length);
73 return Qnil;
74}
75
76VALUE default_write_map_end(VALUE protocol) {
77 rb_funcall(protocol, write_map_end_method_id, 0);
78 return Qnil;
79}
80
81VALUE default_write_struct_begin(VALUE protocol, VALUE struct_name) {
82 rb_funcall(protocol, write_struct_begin_method_id, 1, struct_name);
83 return Qnil;
84}
85
86VALUE default_write_struct_end(VALUE protocol) {
87 rb_funcall(protocol, write_struct_end_method_id, 0);
88 return Qnil;
89}
90
91VALUE default_write_field_begin(VALUE protocol, VALUE name, VALUE type, VALUE id) {
92 rb_funcall(protocol, write_field_begin_method_id, 3, name, type, id);
93 return Qnil;
94}
95
96VALUE default_write_field_end(VALUE protocol) {
97 rb_funcall(protocol, write_field_end_method_id, 0);
98 return Qnil;
99}
100
101VALUE default_write_field_stop(VALUE protocol) {
102 rb_funcall(protocol, write_field_stop_method_id, 0);
103 return Qnil;
104}
105
106VALUE default_read_field_begin(VALUE protocol) {
107 return rb_funcall(protocol, read_field_begin_method_id, 0);
108}
109
110VALUE default_read_field_end(VALUE protocol) {
111 return rb_funcall(protocol, read_field_end_method_id, 0);
112}
113
114VALUE default_read_map_begin(VALUE protocol) {
115 return rb_funcall(protocol, read_map_begin_method_id, 0);
116}
117
118VALUE default_read_map_end(VALUE protocol) {
119 return rb_funcall(protocol, read_map_end_method_id, 0);
120}
121
122VALUE default_read_list_begin(VALUE protocol) {
123 return rb_funcall(protocol, read_list_begin_method_id, 0);
124}
125
126VALUE default_read_list_end(VALUE protocol) {
127 return rb_funcall(protocol, read_list_end_method_id, 0);
128}
129
130VALUE default_read_set_begin(VALUE protocol) {
131 return rb_funcall(protocol, read_set_begin_method_id, 0);
132}
133
134VALUE default_read_set_end(VALUE protocol) {
135 return rb_funcall(protocol, read_set_end_method_id, 0);
136}
137
138VALUE default_read_byte(VALUE protocol) {
139 return rb_funcall(protocol, read_byte_method_id, 0);
140}
141
142VALUE default_read_bool(VALUE protocol) {
143 return rb_funcall(protocol, read_bool_method_id, 0);
144}
145
146VALUE default_read_i16(VALUE protocol) {
147 return rb_funcall(protocol, read_i16_method_id, 0);
148}
149
150VALUE default_read_i32(VALUE protocol) {
151 return rb_funcall(protocol, read_i32_method_id, 0);
152}
153
154VALUE default_read_i64(VALUE protocol) {
155 return rb_funcall(protocol, read_i64_method_id, 0);
156}
157
158VALUE default_read_double(VALUE protocol) {
159 return rb_funcall(protocol, read_double_method_id, 0);
160}
161
162VALUE default_read_string(VALUE protocol) {
163 return rb_funcall(protocol, read_string_method_id, 0);
164}
165
166VALUE default_read_struct_begin(VALUE protocol) {
167 return rb_funcall(protocol, read_struct_begin_method_id, 0);
168}
169
170VALUE default_read_struct_end(VALUE protocol) {
171 return rb_funcall(protocol, read_struct_end_method_id, 0);
172}
173
174static void set_default_proto_function_pointers() {
175 mt = ALLOC(native_proto_method_table);
176
177 mt->write_field_begin = default_write_field_begin;
178 mt->write_field_stop = default_write_field_stop;
179 mt->write_map_begin = default_write_map_begin;
180 mt->write_map_end = default_write_map_end;
181 mt->write_list_begin = default_write_list_begin;
182 mt->write_list_end = default_write_list_end;
183 mt->write_set_begin = default_write_set_begin;
184 mt->write_set_end = default_write_set_end;
185 mt->write_byte = default_write_byte;
186 mt->write_bool = default_write_bool;
187 mt->write_i16 = default_write_i16;
188 mt->write_i32 = default_write_i32;
189 mt->write_i64 = default_write_i64;
190 mt->write_double = default_write_double;
191 mt->write_string = default_write_string;
192 mt->write_struct_begin = default_write_struct_begin;
193 mt->write_struct_end = default_write_struct_end;
194 mt->write_field_end = default_write_field_end;
195
196 mt->read_struct_begin = default_read_struct_begin;
197 mt->read_struct_end = default_read_struct_end;
198 mt->read_field_begin = default_read_field_begin;
199 mt->read_field_end = default_read_field_end;
200 mt->read_map_begin = default_read_map_begin;
201 mt->read_map_end = default_read_map_end;
202 mt->read_list_begin = default_read_list_begin;
203 mt->read_list_end = default_read_list_end;
204 mt->read_set_begin = default_read_set_begin;
205 mt->read_set_end = default_read_set_end;
206 mt->read_byte = default_read_byte;
207 mt->read_bool = default_read_bool;
208 mt->read_i16 = default_read_i16;
209 mt->read_i32 = default_read_i32;
210 mt->read_i64 = default_read_i64;
211 mt->read_double = default_read_double;
212 mt->read_string = default_read_string;
213
214}
215
216static void set_native_proto_function_pointers(VALUE protocol) {
217 VALUE method_table_object = rb_const_get(CLASS_OF(protocol), rb_intern("@native_method_table"));
218 // TODO: check nil?
219 Data_Get_Struct(method_table_object, native_proto_method_table, mt);
220}
221
222// end default protocol methods
223
224
225static VALUE rb_thrift_struct_write(VALUE self, VALUE protocol);
226static void write_anything(int ttype, VALUE value, VALUE protocol, VALUE field_info);
227
228VALUE get_field_value(VALUE obj, VALUE field_name) {
229 char name_buf[RSTRING(field_name)->len + 1];
230
231 name_buf[0] = '@';
232 strlcpy(&name_buf[1], RSTRING(field_name)->ptr, sizeof(name_buf));
233
234 VALUE value = rb_ivar_get(obj, rb_intern(name_buf));
235
236 return value;
237}
238
239static void write_container(int ttype, VALUE field_info, VALUE value, VALUE protocol) {
240 int sz, i;
241
242 if (ttype == TTYPE_MAP) {
243 VALUE keys;
244 VALUE key;
245 VALUE val;
246
247 Check_Type(value, T_HASH);
248
249 VALUE key_info = rb_hash_aref(field_info, key_sym);
250 VALUE keytype_value = rb_hash_aref(key_info, type_sym);
251 int keytype = FIX2INT(keytype_value);
252
253 VALUE value_info = rb_hash_aref(field_info, value_sym);
254 VALUE valuetype_value = rb_hash_aref(value_info, type_sym);
255 int valuetype = FIX2INT(valuetype_value);
256
257 keys = rb_funcall(value, keys_method_id, 0);
258
259 sz = RARRAY(keys)->len;
260
261 mt->write_map_begin(protocol, keytype_value, valuetype_value, INT2FIX(sz));
262
263 for (i = 0; i < sz; i++) {
264 key = rb_ary_entry(keys, i);
265 val = rb_hash_aref(value, key);
266
267 if (IS_CONTAINER(keytype)) {
268 write_container(keytype, key_info, key, protocol);
269 } else {
270 write_anything(keytype, key, protocol, key_info);
271 }
272
273 if (IS_CONTAINER(valuetype)) {
274 write_container(valuetype, value_info, val, protocol);
275 } else {
276 write_anything(valuetype, val, protocol, value_info);
277 }
278 }
279
280 mt->write_map_end(protocol);
281 } else if (ttype == TTYPE_LIST) {
282 Check_Type(value, T_ARRAY);
283
284 sz = RARRAY(value)->len;
285
286 VALUE element_type_info = rb_hash_aref(field_info, element_sym);
287 VALUE element_type_value = rb_hash_aref(element_type_info, type_sym);
288 int element_type = FIX2INT(element_type_value);
289
290 mt->write_list_begin(protocol, element_type_value, INT2FIX(sz));
291 for (i = 0; i < sz; ++i) {
292 VALUE val = rb_ary_entry(value, i);
293 if (IS_CONTAINER(element_type)) {
294 write_container(element_type, element_type_info, val, protocol);
295 } else {
296 write_anything(element_type, val, protocol, element_type_info);
297 }
298 }
299 mt->write_list_end(protocol);
300 } else if (ttype == TTYPE_SET) {
301 VALUE items;
302
303 if (TYPE(value) == T_ARRAY) {
304 items = value;
305 } else {
306 if (rb_cSet == CLASS_OF(value)) {
307 items = rb_funcall(value, entries_method_id, 0);
308 } else {
309 Check_Type(value, T_HASH);
310 items = rb_funcall(value, keys_method_id, 0);
311 }
312 }
313
314 sz = RARRAY(items)->len;
315
316 VALUE element_type_info = rb_hash_aref(field_info, element_sym);
317 VALUE element_type_value = rb_hash_aref(element_type_info, type_sym);
318 int element_type = FIX2INT(element_type_value);
319
320 mt->write_set_begin(protocol, element_type_value, INT2FIX(sz));
321
322 for (i = 0; i < sz; i++) {
323 VALUE val = rb_ary_entry(items, i);
324 if (IS_CONTAINER(element_type)) {
325 write_container(element_type, element_type_info, val, protocol);
326 } else {
327 write_anything(element_type, val, protocol, element_type_info);
328 }
329 }
330
331 mt->write_set_end(protocol);
332 } else {
333 rb_raise(rb_eNotImpError, "can't write container of type: %d", ttype);
334 }
335}
336
337static void write_anything(int ttype, VALUE value, VALUE protocol, VALUE field_info) {
338 if (ttype == TTYPE_BOOL) {
339 mt->write_bool(protocol, value);
340 } else if (ttype == TTYPE_BYTE) {
341 mt->write_byte(protocol, value);
342 } else if (ttype == TTYPE_I16) {
343 mt->write_i16(protocol, value);
344 } else if (ttype == TTYPE_I32) {
345 mt->write_i32(protocol, value);
346 } else if (ttype == TTYPE_I64) {
347 mt->write_i64(protocol, value);
348 } else if (ttype == TTYPE_DOUBLE) {
349 mt->write_double(protocol, value);
350 } else if (ttype == TTYPE_STRING) {
351 mt->write_string(protocol, value);
352 } else if (IS_CONTAINER(ttype)) {
353 write_container(ttype, field_info, value, protocol);
354 } else if (ttype == TTYPE_STRUCT) {
355 rb_thrift_struct_write(value, protocol);
356 } else {
357 rb_raise(rb_eNotImpError, "Unknown type for binary_encoding: %d", ttype);
358 }
359}
360
361static VALUE rb_thrift_struct_write(VALUE self, VALUE protocol) {
362 // call validate
363 rb_funcall(self, validate_method_id, 0);
364
365 if (RTEST(rb_funcall(protocol, native_qmark_method_id, 0))) {
366 set_native_proto_function_pointers(protocol);
367 } else {
368 set_default_proto_function_pointers();
369 }
370
371 // write struct begin
372 mt->write_struct_begin(protocol, rb_class_name(CLASS_OF(self)));
373
374 // iterate through all the fields here
375 VALUE struct_fields = STRUCT_FIELDS(self);
376 VALUE struct_field_ids_unordered = rb_funcall(struct_fields, keys_method_id, 0);
377 VALUE struct_field_ids_ordered = rb_funcall(struct_field_ids_unordered, sort_method_id, 0);
378
379 int i = 0;
380 for (i=0; i < RARRAY(struct_field_ids_ordered)->len; i++) {
381 VALUE field_id = rb_ary_entry(struct_field_ids_ordered, i);
382 VALUE field_info = rb_hash_aref(struct_fields, field_id);
383
384 VALUE ttype_value = rb_hash_aref(field_info, type_sym);
385 int ttype = FIX2INT(ttype_value);
386 VALUE field_name = rb_hash_aref(field_info, name_sym);
387 VALUE field_value = get_field_value(self, field_name);
388
389 if (!NIL_P(field_value)) {
390 mt->write_field_begin(protocol, field_name, ttype_value, field_id);
391
392 write_anything(ttype, field_value, protocol, field_info);
393
394 mt->write_field_end(protocol);
395 }
396 }
397
398 mt->write_field_stop(protocol);
399
400 // write struct end
401 mt->write_struct_end(protocol);
402
403 return Qnil;
404}
405
406//-------------------------------------------
407// Reading section
408//-------------------------------------------
409
410static VALUE rb_thrift_struct_read(VALUE self, VALUE protocol);
411
412static void set_field_value(VALUE obj, VALUE field_name, VALUE value) {
413 char name_buf[RSTRING(field_name)->len + 1];
414
415 name_buf[0] = '@';
416 strlcpy(&name_buf[1], RSTRING(field_name)->ptr, sizeof(name_buf));
417
418 rb_ivar_set(obj, rb_intern(name_buf), value);
419}
420
421static VALUE read_anything(VALUE protocol, int ttype, VALUE field_info) {
422 VALUE result = Qnil;
423
424 if (ttype == TTYPE_BOOL) {
425 result = mt->read_bool(protocol);
426 } else if (ttype == TTYPE_BYTE) {
427 result = mt->read_byte(protocol);
428 } else if (ttype == TTYPE_I16) {
429 result = mt->read_i16(protocol);
430 } else if (ttype == TTYPE_I32) {
431 result = mt->read_i32(protocol);
432 } else if (ttype == TTYPE_I64) {
433 result = mt->read_i64(protocol);
434 } else if (ttype == TTYPE_STRING) {
435 result = mt->read_string(protocol);
436 } else if (ttype == TTYPE_DOUBLE) {
437 result = mt->read_double(protocol);
438 } else if (ttype == TTYPE_STRUCT) {
439 VALUE klass = rb_hash_aref(field_info, class_sym);
440 result = rb_class_new_instance(0, NULL, klass);
441 rb_thrift_struct_read(result, protocol);
442 } else if (ttype == TTYPE_MAP) {
443 int i;
444
445 VALUE map_header = mt->read_map_begin(protocol);
446 int key_ttype = FIX2INT(rb_ary_entry(map_header, 0));
447 int value_ttype = FIX2INT(rb_ary_entry(map_header, 1));
448 int num_entries = FIX2INT(rb_ary_entry(map_header, 2));
449
450 VALUE key_info = rb_hash_aref(field_info, key_sym);
451 VALUE value_info = rb_hash_aref(field_info, value_sym);
452
453 result = rb_hash_new();
454
455 for (i = 0; i < num_entries; ++i) {
456 VALUE key, val;
457
458 key = read_anything(protocol, key_ttype, key_info);
459 val = read_anything(protocol, value_ttype, value_info);
460
461 rb_hash_aset(result, key, val);
462 }
463
464 mt->read_map_end(protocol);
465 } else if (ttype == TTYPE_LIST) {
466 int i;
467
468 VALUE list_header = mt->read_list_begin(protocol);
469 int element_ttype = FIX2INT(rb_ary_entry(list_header, 0));
470 int num_elements = FIX2INT(rb_ary_entry(list_header, 1));
471 result = rb_ary_new2(num_elements);
472
473 for (i = 0; i < num_elements; ++i) {
474 rb_ary_push(result, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym)));
475 }
476
477
478 mt->read_list_end(protocol);
479 } else if (ttype == TTYPE_SET) {
480 VALUE items;
481 int i;
482
483 VALUE set_header = mt->read_set_begin(protocol);
484 int element_ttype = FIX2INT(rb_ary_entry(set_header, 0));
485 int num_elements = FIX2INT(rb_ary_entry(set_header, 1));
486 items = rb_ary_new2(num_elements);
487
488 for (i = 0; i < num_elements; ++i) {
489 rb_ary_push(items, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym)));
490 }
491
492
493 mt->read_set_end(protocol);
494
495 result = rb_class_new_instance(1, &items, rb_cSet);
496 } else {
497 rb_raise(rb_eNotImpError, "read_anything not implemented for type %d!", ttype);
498 }
499
500 return result;
501}
502
503static VALUE rb_thrift_struct_read(VALUE self, VALUE protocol) {
504 // read struct begin
505 mt->read_struct_begin(protocol);
506
507 VALUE struct_fields = STRUCT_FIELDS(self);
508
509 // read each field
510 while (true) {
511 VALUE field_header = rb_funcall(protocol, read_field_begin_method_id, 0);
512 VALUE field_type_value = rb_ary_entry(field_header, 1);
513 int field_type = FIX2INT(field_type_value);
514
515 if (field_type == TTYPE_STOP) {
516 break;
517 }
518
519 // make sure we got a type we expected
520 VALUE field_info = rb_hash_aref(struct_fields, rb_ary_entry(field_header, 2));
521
522 if (!NIL_P(field_info)) {
523 int specified_type = FIX2INT(rb_hash_aref(field_info, type_sym));
524 if (field_type == specified_type) {
525 // read the value
526 VALUE name = rb_hash_aref(field_info, name_sym);
527 set_field_value(self, name, read_anything(protocol, field_type, field_info));
528 } else {
529 rb_funcall(protocol, skip_method_id, 1, field_type_value);
530 }
531 } else {
532 rb_funcall(protocol, skip_method_id, 1, field_type_value);
533 }
534
535 // read field end
536 mt->read_field_end(protocol);
537 }
538
539 // read struct end
540 mt->read_struct_end(protocol);
541
542 return Qnil;
543}
544
545void Init_struct() {
546 VALUE struct_module = rb_const_get(thrift_module, rb_intern("Struct"));
547
548 rb_define_method(struct_module, "write", rb_thrift_struct_write, 1);
549 rb_define_method(struct_module, "read", rb_thrift_struct_read, 1);
550
551 set_default_proto_function_pointers();
552}
553