blob: 6e3a5d31be98cf984370f29dd6d825b847e74601 [file] [log] [blame]
Jens Geyerf4598682014-05-08 23:18:44 +02001/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20package common
21
import (
	"compress/zlib"
	"crypto/tls"
	"flag"
	"fmt"
	"net"
	"strconv"
	"strings"

	"github.com/apache/thrift/lib/go/thrift"
	"github.com/apache/thrift/test/go/src/gen/thrifttest"
)
31
// debugServerProtocol is bound to the -debug_server_protocol command-line
// flag. When true, GetServerParams wraps its protocol factory in a
// thrift debug protocol factory.
var debugServerProtocol bool

// certPath is declared at package level but is not read by the code in this
// file; GetServerParams takes its own certPath parameter, which shadows it.
// NOTE(review): confirm whether other files in package common rely on it.
var certPath string

// init registers this package's flags on the default command-line flag set.
func init() {
	flag.CommandLine.BoolVar(&debugServerProtocol, "debug_server_protocol", false, "turn server protocol trace on")
}
40
claudemirof8ca0552016-01-10 23:31:30 -020041func GetServerParams(
Jens Geyerf4598682014-05-08 23:18:44 +020042 host string,
43 port int64,
44 domain_socket string,
45 transport string,
46 protocol string,
47 ssl bool,
Roger Meier41ad4342015-03-24 22:30:40 +010048 certPath string,
claudemirof8ca0552016-01-10 23:31:30 -020049 handler thrifttest.ThriftTest) (thrift.TProcessor, thrift.TServerTransport, thrift.TTransportFactory, thrift.TProtocolFactory, error) {
Jens Geyerf4598682014-05-08 23:18:44 +020050
claudemirof8ca0552016-01-10 23:31:30 -020051 var err error
Jens Geyerf4598682014-05-08 23:18:44 +020052 hostPort := fmt.Sprintf("%s:%d", host, port)
53
54 var protocolFactory thrift.TProtocolFactory
55 switch protocol {
56 case "compact":
57 protocolFactory = thrift.NewTCompactProtocolFactory()
58 case "simplejson":
59 protocolFactory = thrift.NewTSimpleJSONProtocolFactory()
60 case "json":
61 protocolFactory = thrift.NewTJSONProtocolFactory()
62 case "binary":
63 protocolFactory = thrift.NewTBinaryProtocolFactoryDefault()
Yuxuan 'fishy' Wang4d46c112019-06-07 20:47:18 +080064 case "header":
65 protocolFactory = thrift.NewTHeaderProtocolFactory()
Jens Geyerf4598682014-05-08 23:18:44 +020066 default:
claudemirof8ca0552016-01-10 23:31:30 -020067 return nil, nil, nil, nil, fmt.Errorf("Invalid protocol specified %s", protocol)
Jens Geyerf4598682014-05-08 23:18:44 +020068 }
69 if debugServerProtocol {
70 protocolFactory = thrift.NewTDebugProtocolFactory(protocolFactory, "server:")
71 }
72
73 var serverTransport thrift.TServerTransport
74 if ssl {
75 cfg := new(tls.Config)
76 if cert, err := tls.LoadX509KeyPair(certPath+"/server.crt", certPath+"/server.key"); err != nil {
claudemirof8ca0552016-01-10 23:31:30 -020077 return nil, nil, nil, nil, err
Jens Geyerf4598682014-05-08 23:18:44 +020078 } else {
79 cfg.Certificates = append(cfg.Certificates, cert)
80 }
81 serverTransport, err = thrift.NewTSSLServerSocket(hostPort, cfg)
82 } else {
83 if domain_socket != "" {
84 serverTransport, err = thrift.NewTServerSocket(domain_socket)
85 } else {
86 serverTransport, err = thrift.NewTServerSocket(hostPort)
87 }
88 }
89 if err != nil {
claudemirof8ca0552016-01-10 23:31:30 -020090 return nil, nil, nil, nil, err
Jens Geyerf4598682014-05-08 23:18:44 +020091 }
92
93 var transportFactory thrift.TTransportFactory
94
95 switch transport {
96 case "http":
claudemirof8ca0552016-01-10 23:31:30 -020097 // there is no such factory, and we don't need any
98 transportFactory = nil
Jens Geyerf4598682014-05-08 23:18:44 +020099 case "framed":
100 transportFactory = thrift.NewTTransportFactory()
101 transportFactory = thrift.NewTFramedTransportFactory(transportFactory)
102 case "buffered":
103 transportFactory = thrift.NewTBufferedTransportFactory(8192)
Jens Geyerc6b991f2015-08-07 23:41:09 +0200104 case "zlib":
105 transportFactory = thrift.NewTZlibTransportFactory(zlib.BestCompression)
Jens Geyerf4598682014-05-08 23:18:44 +0200106 case "":
107 transportFactory = thrift.NewTTransportFactory()
108 default:
claudemirof8ca0552016-01-10 23:31:30 -0200109 return nil, nil, nil, nil, fmt.Errorf("Invalid transport specified %s", transport)
Jens Geyerf4598682014-05-08 23:18:44 +0200110 }
111 processor := thrifttest.NewThriftTestProcessor(handler)
claudemirof8ca0552016-01-10 23:31:30 -0200112
113 return processor, serverTransport, transportFactory, protocolFactory, nil
Jens Geyerf4598682014-05-08 23:18:44 +0200114}