blob: 3df83a16fac088c833f32426e377c15aef2a9ed1 [file] [log] [blame]
Nobuaki Sukegawa6525f6a2016-02-11 13:58:39 +09001/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#ifndef THRIFT_PY_PROTOCOL_TCC
21#define THRIFT_PY_PROTOCOL_TCC
22
23#define CHECK_RANGE(v, min, max) (((v) <= (max)) && ((v) >= (min)))
24#define INIT_OUTBUF_SIZE 128
25
26#include <cStringIO.h>
27
28namespace apache {
29namespace thrift {
30namespace py {
31
32namespace detail {
33
34inline bool input_check(PyObject* input) {
35 return PycStringIO_InputCheck(input);
36}
37
38inline EncodeBuffer* new_encode_buffer(size_t size) {
39 if (!PycStringIO) {
40 PycString_IMPORT;
41 }
42 if (!PycStringIO) {
43 return NULL;
44 }
45 return PycStringIO->NewOutput(size);
46}
47
48inline int read_buffer(PyObject* buf, char** output, int len) {
49 if (!PycStringIO) {
50 PycString_IMPORT;
51 }
52 if (!PycStringIO) {
53 PyErr_SetString(PyExc_ImportError, "failed to import native cStringIO");
54 return -1;
55 }
56 return PycStringIO->cread(buf, output, len);
57}
58}
59
60template <typename Impl>
61inline ProtocolBase<Impl>::~ProtocolBase() {
62 if (output_) {
63 Py_CLEAR(output_);
64 }
65}
66
67template <typename Impl>
68inline bool ProtocolBase<Impl>::isUtf8(PyObject* typeargs) {
69 return PyString_Check(typeargs) && !strncmp(PyString_AS_STRING(typeargs), "UTF8", 4);
70}
71
72template <typename Impl>
73PyObject* ProtocolBase<Impl>::getEncodedValue() {
74 if (!PycStringIO) {
75 PycString_IMPORT;
76 }
77 if (!PycStringIO) {
78 return NULL;
79 }
80 return PycStringIO->cgetvalue(output_);
81}
82
83template <typename Impl>
84inline bool ProtocolBase<Impl>::writeBuffer(char* data, size_t size) {
85 if (!PycStringIO) {
86 PycString_IMPORT;
87 }
88 if (!PycStringIO) {
89 PyErr_SetString(PyExc_ImportError, "failed to import native cStringIO");
90 return false;
91 }
92 int len = PycStringIO->cwrite(output_, data, size);
93 if (len < 0) {
94 PyErr_SetString(PyExc_IOError, "failed to write to cStringIO object");
95 return false;
96 }
97 if (len != size) {
98 PyErr_Format(PyExc_EOFError, "write length mismatch: expected %d got %d", size, len);
99 return false;
100 }
101 return true;
102}
103
104namespace detail {
105
// DECLARE_OP_SCOPE(name, op) generates:
//   - a guard struct `<name>Scope<Impl>` that calls impl-><op>Begin() when it
//     is constructed, stores the result in `valid`, and calls impl-><op>End()
//     from its destructor only if Begin() reported success;
//   - a factory function `<op>Scope(T<Impl>*)` that downcasts the given
//     pointer to Impl* and returns such a guard by value.
// The guard converts to bool so callers can test whether Begin() succeeded.
// NOTE(review): the guard has no user-defined copy constructor and is returned
// by value, so a copy that the compiler does not elide would run End() more
// than once -- confirm the supported compilers elide it.
#define DECLARE_OP_SCOPE(name, op)                                                                 \
  template <typename Impl>                                                                         \
  struct name##Scope {                                                                             \
    Impl* impl;                                                                                    \
    bool valid;                                                                                    \
    name##Scope(Impl* target) : impl(target), valid(impl->op##Begin()) {}                          \
    ~name##Scope() {                                                                               \
      if (!valid) {                                                                                \
        return;                                                                                    \
      }                                                                                            \
      impl->op##End();                                                                             \
    }                                                                                              \
    operator bool() { return valid; }                                                              \
  };                                                                                               \
  template <typename Impl, template <typename> class T>                                            \
  name##Scope<Impl> op##Scope(T<Impl>* base) {                                                     \
    Impl* derived = static_cast<Impl*>(base);                                                      \
    return name##Scope<Impl>(derived);                                                             \
  }
DECLARE_OP_SCOPE(WriteStruct, writeStruct)
DECLARE_OP_SCOPE(ReadStruct, readStruct)
#undef DECLARE_OP_SCOPE
125
126inline bool check_ssize_t_32(Py_ssize_t len) {
127 // error from getting the int
128 if (INT_CONV_ERROR_OCCURRED(len)) {
129 return false;
130 }
131 if (!CHECK_RANGE(len, 0, INT32_MAX)) {
132 PyErr_SetString(PyExc_OverflowError, "size out of range: exceeded INT32_MAX");
133 return false;
134 }
135 return true;
136}
137}
138
139template <typename T>
140bool parse_pyint(PyObject* o, T* ret, int32_t min, int32_t max) {
141 long val = PyInt_AsLong(o);
142
143 if (INT_CONV_ERROR_OCCURRED(val)) {
144 return false;
145 }
146 if (!CHECK_RANGE(val, min, max)) {
147 PyErr_SetString(PyExc_OverflowError, "int out of range");
148 return false;
149 }
150
151 *ret = static_cast<T>(val);
152 return true;
153}
154
155template <typename Impl>
156inline bool ProtocolBase<Impl>::checkType(TType got, TType expected) {
157 if (expected != got) {
158 PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field");
159 return false;
160 }
161 return true;
162}
163
164template <typename Impl>
165bool ProtocolBase<Impl>::checkLengthLimit(int32_t len, long limit) {
166 if (len < 0) {
167 PyErr_Format(PyExc_OverflowError, "negative length: %d", limit);
168 return false;
169 }
170 if (len > limit) {
171 PyErr_Format(PyExc_OverflowError, "size exceeded specified limit: %d", limit);
172 return false;
173 }
174 return true;
175}
176
177template <typename Impl>
178bool ProtocolBase<Impl>::readBytes(char** output, int len) {
179 if (len < 0) {
180 PyErr_Format(PyExc_ValueError, "attempted to read negative length: %d", len);
181 return false;
182 }
183 // TODO(dreiss): Don't fear the malloc. Think about taking a copy of
184 // the partial read instead of forcing the transport
185 // to prepend it to its buffer.
186
187 int rlen = detail::read_buffer(input_.stringiobuf.get(), output, len);
188
189 if (rlen == len) {
190 return true;
191 } else if (rlen == -1) {
192 return false;
193 } else {
194 // using building functions as this is a rare codepath
195 ScopedPyObject newiobuf(
196 PyObject_CallFunction(input_.refill_callable.get(), refill_signature, *output, rlen, len, NULL));
197 if (!newiobuf) {
198 return false;
199 }
200
201 // must do this *AFTER* the call so that we don't deref the io buffer
202 input_.stringiobuf.reset(newiobuf.release());
203
204 rlen = detail::read_buffer(input_.stringiobuf.get(), output, len);
205
206 if (rlen == len) {
207 return true;
208 } else if (rlen == -1) {
209 return false;
210 } else {
211 // TODO(dreiss): This could be a valid code path for big binary blobs.
212 PyErr_SetString(PyExc_TypeError, "refill claimed to have refilled the buffer, but didn't!!");
213 return false;
214 }
215 }
216}
217
218template <typename Impl>
219bool ProtocolBase<Impl>::prepareDecodeBufferFromTransport(PyObject* trans) {
220 if (input_.stringiobuf) {
221 PyErr_SetString(PyExc_ValueError, "decode buffer is already initialized");
222 return false;
223 }
224
225 ScopedPyObject stringiobuf(PyObject_GetAttr(trans, INTERN_STRING(cstringio_buf)));
226 if (!stringiobuf) {
227 return false;
228 }
229 if (!detail::input_check(stringiobuf.get())) {
230 PyErr_SetString(PyExc_TypeError, "expecting stringio input_");
231 return false;
232 }
233
234 ScopedPyObject refill_callable(PyObject_GetAttr(trans, INTERN_STRING(cstringio_refill)));
235 if (!refill_callable) {
236 return false;
237 }
238 if (!PyCallable_Check(refill_callable.get())) {
239 PyErr_SetString(PyExc_TypeError, "expecting callable");
240 return false;
241 }
242
243 input_.stringiobuf.swap(stringiobuf);
244 input_.refill_callable.swap(refill_callable);
245 return true;
246}
247
248template <typename Impl>
249bool ProtocolBase<Impl>::prepareEncodeBuffer() {
250 output_ = detail::new_encode_buffer(INIT_OUTBUF_SIZE);
251 return output_ != NULL;
252}
253
254template <typename Impl>
255bool ProtocolBase<Impl>::encodeValue(PyObject* value, TType type, PyObject* typeargs) {
256 /*
257 * Refcounting Strategy:
258 *
259 * We assume that elements of the thrift_spec tuple are not going to be
260 * mutated, so we don't ref count those at all. Other than that, we try to
261 * keep a reference to all the user-created objects while we work with them.
262 * encodeValue assumes that a reference is already held. The *caller* is
263 * responsible for handling references
264 */
265
266 switch (type) {
267
268 case T_BOOL: {
269 int v = PyObject_IsTrue(value);
270 if (v == -1) {
271 return false;
272 }
273 impl()->writeBool(v);
274 return true;
275 }
276 case T_I08: {
277 int8_t val;
278
279 if (!parse_pyint(value, &val, INT8_MIN, INT8_MAX)) {
280 return false;
281 }
282
283 impl()->writeI8(val);
284 return true;
285 }
286 case T_I16: {
287 int16_t val;
288
289 if (!parse_pyint(value, &val, INT16_MIN, INT16_MAX)) {
290 return false;
291 }
292
293 impl()->writeI16(val);
294 return true;
295 }
296 case T_I32: {
297 int32_t val;
298
299 if (!parse_pyint(value, &val, INT32_MIN, INT32_MAX)) {
300 return false;
301 }
302
303 impl()->writeI32(val);
304 return true;
305 }
306 case T_I64: {
307 int64_t nval = PyLong_AsLongLong(value);
308
309 if (INT_CONV_ERROR_OCCURRED(nval)) {
310 return false;
311 }
312
313 if (!CHECK_RANGE(nval, INT64_MIN, INT64_MAX)) {
314 PyErr_SetString(PyExc_OverflowError, "int out of range");
315 return false;
316 }
317
318 impl()->writeI64(nval);
319 return true;
320 }
321
322 case T_DOUBLE: {
323 double nval = PyFloat_AsDouble(value);
324 if (nval == -1.0 && PyErr_Occurred()) {
325 return false;
326 }
327
328 impl()->writeDouble(nval);
329 return true;
330 }
331
332 case T_STRING: {
333 if (PyUnicode_Check(value)) {
334 value = PyUnicode_AsUTF8String(value);
335 if (!value) {
336 return false;
337 }
338 }
339
340 Py_ssize_t len = PyBytes_Size(value);
341 if (!detail::check_ssize_t_32(len)) {
342 return false;
343 }
344
345 impl()->writeString(value, static_cast<int32_t>(len));
346 return true;
347 }
348
349 case T_LIST:
350 case T_SET: {
351 SetListTypeArgs parsedargs;
352 if (!parse_set_list_args(&parsedargs, typeargs)) {
353 return false;
354 }
355
356 Py_ssize_t len = PyObject_Length(value);
357 if (!detail::check_ssize_t_32(len)) {
358 return false;
359 }
360
361 if (!impl()->writeListBegin(value, parsedargs, static_cast<int32_t>(len)) || PyErr_Occurred()) {
362 return false;
363 }
364 ScopedPyObject iterator(PyObject_GetIter(value));
365 if (!iterator) {
366 return false;
367 }
368
369 while (PyObject* rawItem = PyIter_Next(iterator.get())) {
370 ScopedPyObject item(rawItem);
371 if (!encodeValue(item.get(), parsedargs.element_type, parsedargs.typeargs)) {
372 return false;
373 }
374 }
375
376 return true;
377 }
378
379 case T_MAP: {
380 Py_ssize_t len = PyDict_Size(value);
381 if (!detail::check_ssize_t_32(len)) {
382 return false;
383 }
384
385 MapTypeArgs parsedargs;
386 if (!parse_map_args(&parsedargs, typeargs)) {
387 return false;
388 }
389
390 if (!impl()->writeMapBegin(value, parsedargs, static_cast<int32_t>(len)) || PyErr_Occurred()) {
391 return false;
392 }
393 Py_ssize_t pos = 0;
394 PyObject* k = NULL;
395 PyObject* v = NULL;
396 // TODO(bmaurer): should support any mapping, not just dicts
397 while (PyDict_Next(value, &pos, &k, &v)) {
398 if (!encodeValue(k, parsedargs.ktag, parsedargs.ktypeargs)
399 || !encodeValue(v, parsedargs.vtag, parsedargs.vtypeargs)) {
400 return false;
401 }
402 }
403 return true;
404 }
405
406 case T_STRUCT: {
407 StructTypeArgs parsedargs;
408 if (!parse_struct_args(&parsedargs, typeargs)) {
409 return false;
410 }
411
412 Py_ssize_t nspec = PyTuple_Size(parsedargs.spec);
413 if (nspec == -1) {
414 PyErr_SetString(PyExc_TypeError, "spec is not a tuple");
415 return false;
416 }
417
418 detail::WriteStructScope<Impl> scope = detail::writeStructScope(this);
419 if (!scope) {
420 return false;
421 }
422 for (Py_ssize_t i = 0; i < nspec; i++) {
423 PyObject* spec_tuple = PyTuple_GET_ITEM(parsedargs.spec, i);
424 if (spec_tuple == Py_None) {
425 continue;
426 }
427
428 StructItemSpec parsedspec;
429 if (!parse_struct_item_spec(&parsedspec, spec_tuple)) {
430 return false;
431 }
432
433 ScopedPyObject instval(PyObject_GetAttr(value, parsedspec.attrname));
434
435 if (!instval) {
436 return false;
437 }
438
439 if (instval.get() == Py_None) {
440 continue;
441 }
442
443 bool res = impl()->writeField(instval.get(), parsedspec);
444 if (!res) {
445 return false;
446 }
447 }
448 impl()->writeFieldStop();
449 return true;
450 }
451
452 case T_STOP:
453 case T_VOID:
454 case T_UTF16:
455 case T_UTF8:
456 case T_U64:
457 default:
458 PyErr_Format(PyExc_TypeError, "Unexpected TType for encodeValue: %d", type);
459 return false;
460 }
461
462 return true;
463}
464
465template <typename Impl>
466bool ProtocolBase<Impl>::skip(TType type) {
467 switch (type) {
468 case T_BOOL:
469 return impl()->skipBool();
470 case T_I08:
471 return impl()->skipByte();
472 case T_I16:
473 return impl()->skipI16();
474 case T_I32:
475 return impl()->skipI32();
476 case T_I64:
477 return impl()->skipI64();
478 case T_DOUBLE:
479 return impl()->skipDouble();
480
481 case T_STRING: {
482 return impl()->skipString();
483 }
484
485 case T_LIST:
486 case T_SET: {
487 TType etype = T_STOP;
488 int32_t len = impl()->readListBegin(etype);
489 if (len < 0) {
490 return false;
491 }
492 for (int32_t i = 0; i < len; i++) {
493 if (!skip(etype)) {
494 return false;
495 }
496 }
497 return true;
498 }
499
500 case T_MAP: {
501 TType ktype = T_STOP;
502 TType vtype = T_STOP;
503 int32_t len = impl()->readMapBegin(ktype, vtype);
504 if (len < 0) {
505 return false;
506 }
507 for (int32_t i = 0; i < len; i++) {
508 if (!skip(ktype) || !skip(vtype)) {
509 return false;
510 }
511 }
512 return true;
513 }
514
515 case T_STRUCT: {
516 detail::ReadStructScope<Impl> scope = detail::readStructScope(this);
517 if (!scope) {
518 return false;
519 }
520 while (true) {
521 TType type = T_STOP;
522 int16_t tag;
523 if (!impl()->readFieldBegin(type, tag)) {
524 return false;
525 }
526 if (type == T_STOP) {
527 return true;
528 }
529 if (!skip(type)) {
530 return false;
531 }
532 }
533 return true;
534 }
535
536 case T_STOP:
537 case T_VOID:
538 case T_UTF16:
539 case T_UTF8:
540 case T_U64:
541 default:
542 PyErr_Format(PyExc_TypeError, "Unexpected TType for skip: %d", type);
543 return false;
544 }
545
546 return true;
547}
548
549// Returns a new reference.
550template <typename Impl>
551PyObject* ProtocolBase<Impl>::decodeValue(TType type, PyObject* typeargs) {
552 switch (type) {
553
554 case T_BOOL: {
555 bool v = 0;
556 if (!impl()->readBool(v)) {
557 return NULL;
558 }
559 if (v) {
560 Py_RETURN_TRUE;
561 } else {
562 Py_RETURN_FALSE;
563 }
564 }
565 case T_I08: {
566 int8_t v = 0;
567 if (!impl()->readI8(v)) {
568 return NULL;
569 }
570 return PyInt_FromLong(v);
571 }
572 case T_I16: {
573 int16_t v = 0;
574 if (!impl()->readI16(v)) {
575 return NULL;
576 }
577 return PyInt_FromLong(v);
578 }
579 case T_I32: {
580 int32_t v = 0;
581 if (!impl()->readI32(v)) {
582 return NULL;
583 }
584 return PyInt_FromLong(v);
585 }
586
587 case T_I64: {
588 int64_t v = 0;
589 if (!impl()->readI64(v)) {
590 return NULL;
591 }
592 // TODO(dreiss): Find out if we can take this fastpath always when
593 // sizeof(long) == sizeof(long long).
594 if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) {
595 return PyInt_FromLong((long)v);
596 }
597 return PyLong_FromLongLong(v);
598 }
599
600 case T_DOUBLE: {
601 double v = 0.0;
602 if (!impl()->readDouble(v)) {
603 return NULL;
604 }
605 return PyFloat_FromDouble(v);
606 }
607
608 case T_STRING: {
609 char* buf = NULL;
610 int len = impl()->readString(&buf);
611 if (len < 0) {
612 return NULL;
613 }
614 if (isUtf8(typeargs)) {
615 return PyUnicode_DecodeUTF8(buf, len, 0);
616 } else {
617 return PyBytes_FromStringAndSize(buf, len);
618 }
619 }
620
621 case T_LIST:
622 case T_SET: {
623 SetListTypeArgs parsedargs;
624 if (!parse_set_list_args(&parsedargs, typeargs)) {
625 return NULL;
626 }
627
628 TType etype = T_STOP;
629 int32_t len = impl()->readListBegin(etype);
630 if (len < 0) {
631 return NULL;
632 }
633 if (len > 0 && !checkType(etype, parsedargs.element_type)) {
634 return NULL;
635 }
636
637 bool use_tuple = type == T_LIST && parsedargs.immutable;
638 ScopedPyObject ret(use_tuple ? PyTuple_New(len) : PyList_New(len));
639 if (!ret) {
640 return NULL;
641 }
642
643 for (int i = 0; i < len; i++) {
644 PyObject* item = decodeValue(etype, parsedargs.typeargs);
645 if (!item) {
646 return NULL;
647 }
648 if (use_tuple) {
649 PyTuple_SET_ITEM(ret.get(), i, item);
650 } else {
651 PyList_SET_ITEM(ret.get(), i, item);
652 }
653 }
654
655 // TODO(dreiss): Consider biting the bullet and making two separate cases
656 // for list and set, avoiding this post facto conversion.
657 if (type == T_SET) {
658 PyObject* setret;
659 setret = parsedargs.immutable ? PyFrozenSet_New(ret.get()) : PySet_New(ret.get());
660 return setret;
661 }
662 return ret.release();
663 }
664
665 case T_MAP: {
666 MapTypeArgs parsedargs;
667 if (!parse_map_args(&parsedargs, typeargs)) {
668 return NULL;
669 }
670
671 TType ktype = T_STOP;
672 TType vtype = T_STOP;
673 uint32_t len = impl()->readMapBegin(ktype, vtype);
674 if (len > 0 && (!checkType(ktype, parsedargs.ktag) || !checkType(vtype, parsedargs.vtag))) {
675 return NULL;
676 }
677
678 ScopedPyObject ret(PyDict_New());
679 if (!ret) {
680 return NULL;
681 }
682
683 for (uint32_t i = 0; i < len; i++) {
684 ScopedPyObject k(decodeValue(ktype, parsedargs.ktypeargs));
685 if (!k) {
686 return NULL;
687 }
688 ScopedPyObject v(decodeValue(vtype, parsedargs.vtypeargs));
689 if (!v) {
690 return NULL;
691 }
692 if (PyDict_SetItem(ret.get(), k.get(), v.get()) == -1) {
693 return NULL;
694 }
695 }
696
697 if (parsedargs.immutable) {
698 if (!ThriftModule) {
699 ThriftModule = PyImport_ImportModule("thrift.Thrift");
700 }
701 if (!ThriftModule) {
702 return NULL;
703 }
704
705 ScopedPyObject cls(PyObject_GetAttr(ThriftModule, INTERN_STRING(TFrozenDict)));
706 if (!cls) {
707 return NULL;
708 }
709
710 ScopedPyObject arg(PyTuple_New(1));
711 PyTuple_SET_ITEM(arg.get(), 0, ret.release());
712 ret.reset(PyObject_CallObject(cls.get(), arg.get()));
713 }
714
715 return ret.release();
716 }
717
718 case T_STRUCT: {
719 StructTypeArgs parsedargs;
720 if (!parse_struct_args(&parsedargs, typeargs)) {
721 return NULL;
722 }
723 return readStruct(Py_None, parsedargs.klass, parsedargs.spec);
724 }
725
726 case T_STOP:
727 case T_VOID:
728 case T_UTF16:
729 case T_UTF8:
730 case T_U64:
731 default:
732 PyErr_Format(PyExc_TypeError, "Unexpected TType for decodeValue: %d", type);
733 return NULL;
734 }
735}
736
737template <typename Impl>
738PyObject* ProtocolBase<Impl>::readStruct(PyObject* output, PyObject* klass, PyObject* spec_seq) {
739 int spec_seq_len = PyTuple_Size(spec_seq);
740 bool immutable = output == Py_None;
741 ScopedPyObject kwargs;
742 if (spec_seq_len == -1) {
743 return NULL;
744 }
745
746 if (immutable) {
747 kwargs.reset(PyDict_New());
748 if (!kwargs) {
749 PyErr_SetString(PyExc_TypeError, "failed to prepare kwargument storage");
750 return NULL;
751 }
752 }
753
754 detail::ReadStructScope<Impl> scope = detail::readStructScope(this);
755 if (!scope) {
756 return NULL;
757 }
758 while (true) {
759 TType type = T_STOP;
760 int16_t tag;
761 if (!impl()->readFieldBegin(type, tag)) {
762 return NULL;
763 }
764 if (type == T_STOP) {
765 break;
766 }
767 if (tag < 0 || tag >= spec_seq_len) {
768 if (!skip(type)) {
769 PyErr_SetString(PyExc_TypeError, "Error while skipping unknown field");
770 return NULL;
771 }
772 continue;
773 }
774
775 PyObject* item_spec = PyTuple_GET_ITEM(spec_seq, tag);
776 if (item_spec == Py_None) {
777 if (!skip(type)) {
778 PyErr_SetString(PyExc_TypeError, "Error while skipping unknown field");
779 return NULL;
780 }
781 continue;
782 }
783 StructItemSpec parsedspec;
784 if (!parse_struct_item_spec(&parsedspec, item_spec)) {
785 return NULL;
786 }
787 if (parsedspec.type != type) {
788 if (!skip(type)) {
789 PyErr_Format(PyExc_TypeError, "struct field had wrong type: expected %d but got %d",
790 parsedspec.type, type);
791 return NULL;
792 }
793 continue;
794 }
795
796 ScopedPyObject fieldval(decodeValue(parsedspec.type, parsedspec.typeargs));
797 if (!fieldval) {
798 return NULL;
799 }
800
801 if ((immutable && PyDict_SetItem(kwargs.get(), parsedspec.attrname, fieldval.get()) == -1)
802 || (!immutable && PyObject_SetAttr(output, parsedspec.attrname, fieldval.get()) == -1)) {
803 return NULL;
804 }
805 }
806 if (immutable) {
807 ScopedPyObject args(PyTuple_New(0));
808 if (!args) {
809 PyErr_SetString(PyExc_TypeError, "failed to prepare argument storage");
810 return NULL;
811 }
812 return PyObject_Call(klass, args.get(), kwargs.get());
813 }
814 Py_INCREF(output);
815 return output;
816}
817}
818}
819}
#endif // THRIFT_PY_PROTOCOL_TCC