THRIFT-4285 Move TX/RX methods from gen. code to library
This change removes a lot of duplication from generated code and allows
the caller to customize how they can read from / write to the
transport. Backwards compatible adapters make the change compatible
with existing code in use by consuming applications.
Client: Go
This closes #1382
diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc
index bac1c57..e869b00 100644
--- a/compiler/cpp/src/thrift/generate/t_go_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc
@@ -1878,20 +1878,16 @@
f_types_ << indent() << "type " << serviceName << "Client struct {" << endl;
indent_up();
+ f_types_ << indent() << "c thrift.TClient" << endl;
if (!extends_client.empty()) {
f_types_ << indent() << "*" << extends_client << endl;
- } else {
- f_types_ << indent() << "Transport thrift.TTransport" << endl;
- f_types_ << indent() << "ProtocolFactory thrift.TProtocolFactory" << endl;
- f_types_ << indent() << "InputProtocol thrift.TProtocol" << endl;
- f_types_ << indent() << "OutputProtocol thrift.TProtocol" << endl;
- f_types_ << indent() << "SeqId int32" << endl;
- /*f_types_ << indent() << "reqs map[int32]Deferred" << endl*/;
}
indent_down();
f_types_ << indent() << "}" << endl << endl;
- // Constructor function
+
+ // Legacy constructor function
+ f_types_ << indent() << "// Deprecated: Use New" << serviceName << " instead" << endl;
f_types_ << indent() << "func New" << serviceName
<< "ClientFactory(t thrift.TTransport, f thrift.TProtocolFactory) *" << serviceName
<< "Client {" << endl;
@@ -1902,19 +1898,16 @@
f_types_ << "{" << extends_field << ": " << extends_client_new << "Factory(t, f)}";
} else {
indent_up();
- f_types_ << "{Transport: t," << endl;
- f_types_ << indent() << "ProtocolFactory: f," << endl;
- f_types_ << indent() << "InputProtocol: f.GetProtocol(t)," << endl;
- f_types_ << indent() << "OutputProtocol: f.GetProtocol(t)," << endl;
- f_types_ << indent() << "SeqId: 0," << endl;
- /*f_types_ << indent() << "Reqs: make(map[int32]Deferred)" << endl*/;
+ f_types_ << "{" << endl;
+ f_types_ << indent() << "c: thrift.NewTStandardClient(f.GetProtocol(t), f.GetProtocol(t))," << endl;
indent_down();
f_types_ << indent() << "}" << endl;
}
indent_down();
f_types_ << indent() << "}" << endl << endl;
- // Constructor function
+ // Legacy constructor function with custom input & output protocols
+ f_types_ << indent() << "// Deprecated: Use New" << serviceName << " instead" << endl;
f_types_
<< indent() << "func New" << serviceName
<< "ClientProtocol(t thrift.TTransport, iprot thrift.TProtocol, oprot thrift.TProtocol) *"
@@ -1927,18 +1920,32 @@
<< endl;
} else {
indent_up();
- f_types_ << "{Transport: t," << endl;
- f_types_ << indent() << "ProtocolFactory: nil," << endl;
- f_types_ << indent() << "InputProtocol: iprot," << endl;
- f_types_ << indent() << "OutputProtocol: oprot," << endl;
- f_types_ << indent() << "SeqId: 0," << endl;
- /*f_types_ << indent() << "Reqs: make(map[int32]interface{})" << endl*/;
+ f_types_ << "{" << endl;
+ f_types_ << indent() << "c: thrift.NewTStandardClient(iprot, oprot)," << endl;
indent_down();
f_types_ << indent() << "}" << endl;
}
indent_down();
f_types_ << indent() << "}" << endl << endl;
+
+ // Constructor function
+ f_types_ << indent() << "func New" << serviceName
+ << "Client(c thrift.TClient) *" << serviceName << "Client {" << endl;
+ indent_up();
+ f_types_ << indent() << "return &" << serviceName << "Client{" << endl;
+
+ indent_up();
+ f_types_ << indent() << "c: c," << endl;
+ if (!extends.empty()) {
+ f_types_ << indent() << extends_field << ": " << extends_client_new << "(c)," << endl;
+ }
+ indent_down();
+ f_types_ << indent() << "}" << endl;
+
+ indent_down();
+ f_types_ << indent() << "}" << endl << endl;
+
// Generate client method implementations
vector<t_function*> functions = tservice->get_functions();
vector<t_function*>::const_iterator f_iter;
@@ -1953,177 +1960,75 @@
f_types_ << indent() << "func (p *" << serviceName << "Client) "
<< function_signature_if(*f_iter, "", true) << " {" << endl;
indent_up();
- /*
- f_types_ <<
- indent() << "p.SeqId += 1" << endl;
- if (!(*f_iter)->is_oneway()) {
- f_types_ <<
- indent() << "d := defer.Deferred()" << endl <<
- indent() << "p.Reqs[p.SeqId] = d" << endl;
- }
- */
- f_types_ << indent() << "if err = p.send" << funname << "(";
- bool first = true;
+
+ std::string method = (*f_iter)->get_name();
+ std::string argsType = publicize(method + "_args", true);
+ std::string argsName = tmp("_args");
+ f_types_ << indent() << "var " << argsName << " " << argsType << endl;
for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
- if (first) {
- first = false;
- } else {
- f_types_ << ", ";
- }
-
- f_types_ << variable_name_to_go_name((*fld_iter)->get_name());
+ f_types_ << indent() << argsName << "." << publicize((*fld_iter)->get_name())
+ << " = " << variable_name_to_go_name((*fld_iter)->get_name()) << endl;
}
- f_types_ << "); err != nil { return }" << endl;
-
if (!(*f_iter)->is_oneway()) {
- f_types_ << indent() << "return p.recv" << funname << "()" << endl;
- } else {
- f_types_ << indent() << "return" << endl;
- }
+ std::string resultName = tmp("_result");
+ std::string resultType = publicize(method + "_result", true);
+ f_types_ << indent() << "var " << resultName << " " << resultType << endl;
+ f_types_ << indent() << "if err = p.c.Call(ctx, \""
+ << method << "\", &" << argsName << ", &" << resultName << "); err != nil {" << endl;
- indent_down();
- f_types_ << indent() << "}" << endl << endl;
- f_types_ << indent() << "func (p *" << serviceName << "Client) send"
- << function_signature(*f_iter) << "(err error) {" << endl;
- indent_up();
- std::string argsname = publicize((*f_iter)->get_name() + "_args", true);
- // Serialize the request header
- f_types_ << indent() << "oprot := p.OutputProtocol" << endl;
- f_types_ << indent() << "if oprot == nil {" << endl;
- f_types_ << indent() << " oprot = p.ProtocolFactory.GetProtocol(p.Transport)" << endl;
- f_types_ << indent() << " p.OutputProtocol = oprot" << endl;
- f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "p.SeqId++" << endl;
- f_types_ << indent() << "if err = oprot.WriteMessageBegin(\"" << (*f_iter)->get_name()
- << "\", " << ((*f_iter)->is_oneway() ? "thrift.ONEWAY" : "thrift.CALL")
- << ", p.SeqId); err != nil {" << endl;
- indent_up();
- f_types_ << indent() << " return" << endl;
- indent_down();
- f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "args := " << argsname << "{" << endl;
-
- for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
- f_types_ << indent() << publicize((*fld_iter)->get_name()) << " : "
- << variable_name_to_go_name((*fld_iter)->get_name()) << "," << endl;
- }
- f_types_ << indent() << "}" << endl;
-
- // Write to the stream
- f_types_ << indent() << "if err = args." << write_method_name_ << "(oprot); err != nil {" << endl;
- indent_up();
- f_types_ << indent() << " return" << endl;
- indent_down();
- f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if err = oprot.WriteMessageEnd(); err != nil {" << endl;
- indent_up();
- f_types_ << indent() << " return" << endl;
- indent_down();
- f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "return oprot.Flush()" << endl;
- indent_down();
- f_types_ << indent() << "}" << endl << endl;
-
- if (!(*f_iter)->is_oneway()) {
- std::string resultname = publicize((*f_iter)->get_name() + "_result", true);
- // Open function
- f_types_ << endl << indent() << "func (p *" << serviceName << "Client) recv"
- << publicize((*f_iter)->get_name()) << "() (";
-
- if (!(*f_iter)->get_returntype()->is_void()) {
- f_types_ << "value " << type_to_go_type((*f_iter)->get_returntype()) << ", ";
- }
-
- f_types_ << "err error) {" << endl;
indent_up();
- // TODO(mcslee): Validate message reply here, seq ids etc.
- string error(tmp("error"));
- string error2(tmp("error"));
- f_types_ << indent() << "iprot := p.InputProtocol" << endl;
- f_types_ << indent() << "if iprot == nil {" << endl;
- f_types_ << indent() << " iprot = p.ProtocolFactory.GetProtocol(p.Transport)" << endl;
- f_types_ << indent() << " p.InputProtocol = iprot" << endl;
- f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "method, mTypeId, seqId, err := iprot.ReadMessageBegin()" << endl;
- f_types_ << indent() << "if err != nil {" << endl;
- f_types_ << indent() << " return" << endl;
- f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if method != \"" << (*f_iter)->get_name() << "\" {" << endl;
- f_types_ << indent() << " err = thrift.NewTApplicationException("
- << "thrift.WRONG_METHOD_NAME, \"" << (*f_iter)->get_name()
- << " failed: wrong method name\")" << endl;
- f_types_ << indent() << " return" << endl;
- f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if p.SeqId != seqId {" << endl;
- f_types_ << indent() << " err = thrift.NewTApplicationException("
- << "thrift.BAD_SEQUENCE_ID, \"" << (*f_iter)->get_name()
- << " failed: out of sequence response\")" << endl;
- f_types_ << indent() << " return" << endl;
- f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if mTypeId == thrift.EXCEPTION {" << endl;
- f_types_ << indent() << " " << error
- << " := thrift.NewTApplicationException(thrift.UNKNOWN_APPLICATION_EXCEPTION, "
- "\"Unknown Exception\")" << endl;
- f_types_ << indent() << " var " << error2 << " error" << endl;
- f_types_ << indent() << " " << error2 << ", err = " << error << ".Read(iprot)" << endl;
- f_types_ << indent() << " if err != nil {" << endl;
- f_types_ << indent() << " return" << endl;
- f_types_ << indent() << " }" << endl;
- f_types_ << indent() << " if err = iprot.ReadMessageEnd(); err != nil {" << endl;
- f_types_ << indent() << " return" << endl;
- f_types_ << indent() << " }" << endl;
- f_types_ << indent() << " err = " << error2 << endl;
- f_types_ << indent() << " return" << endl;
- f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if mTypeId != thrift.REPLY {" << endl;
- f_types_ << indent() << " err = thrift.NewTApplicationException("
- << "thrift.INVALID_MESSAGE_TYPE_EXCEPTION, \"" << (*f_iter)->get_name()
- << " failed: invalid message type\")" << endl;
- f_types_ << indent() << " return" << endl;
- f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "result := " << resultname << "{}" << endl;
- f_types_ << indent() << "if err = result." << read_method_name_ << "(iprot); err != nil {" << endl;
- f_types_ << indent() << " return" << endl;
- f_types_ << indent() << "}" << endl;
- f_types_ << indent() << "if err = iprot.ReadMessageEnd(); err != nil {" << endl;
- f_types_ << indent() << " return" << endl;
+ f_types_ << indent() << "return" << endl;
+ indent_down();
f_types_ << indent() << "}" << endl;
t_struct* xs = (*f_iter)->get_xceptions();
const std::vector<t_field*>& xceptions = xs->get_members();
vector<t_field*>::const_iterator x_iter;
- for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) {
- const std::string pubname = publicize((*x_iter)->get_name());
+ if (!xceptions.empty()) {
+ f_types_ << indent() << "switch {" << endl;
- f_types_ << indent() << "if result." << pubname << " != nil {" << endl;
- f_types_ << indent() << " err = result." << pubname << endl;
- f_types_ << indent() << " return " << endl;
- f_types_ << indent() << "}";
+ for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) {
+ const std::string pubname = publicize((*x_iter)->get_name());
+ const std::string field = resultName + "." + pubname;
- if ((x_iter + 1) != xceptions.end()) {
- f_types_ << " else ";
- } else {
- f_types_ << endl;
+ f_types_ << indent() << "case " << field << "!= nil:" << endl;
+ indent_up();
+
+ if (!(*f_iter)->get_returntype()->is_void()) {
+ f_types_ << indent() << "return r, " << field << endl;
+ } else {
+ f_types_ << indent() << "return "<< field << endl;
+ }
+
+ indent_down();
}
+
+ f_types_ << indent() << "}" << endl << endl;
}
- // Careful, only return _result if not a void function
if (!(*f_iter)->get_returntype()->is_void()) {
- f_types_ << indent() << "value = result.GetSuccess()" << endl;
+ f_types_ << indent() << "return " << resultName << ".GetSuccess(), nil" << endl;
+ } else {
+ f_types_ << indent() << "return nil" << endl;
}
+ } else {
+ // TODO: would be nice to not to duplicate the call generation
+ f_types_ << indent() << "if err := p.c.Call(ctx, \""
+ << method << "\", &"<< argsName << ", nil); err != nil {" << endl;
- f_types_ << indent() << "return" << endl;
- // Close function
+ indent_up();
+ f_types_ << indent() << "return err" << endl;
indent_down();
- f_types_ << indent() << "}" << endl << endl;
+ f_types_ << indent() << "}" << endl;
+ f_types_ << indent() << "return nil" << endl;
}
- }
- // indent_down();
- f_types_ << endl;
+ indent_down();
+ f_types_ << "}" << endl << endl;
+ }
}
/**
@@ -2308,8 +2213,10 @@
f_remote << indent() << " Usage()" << endl;
f_remote << indent() << " os.Exit(1)" << endl;
f_remote << indent() << "}" << endl;
+ f_remote << indent() << "iprot := protocolFactory.GetProtocol(trans)" << endl;
+ f_remote << indent() << "oprot := protocolFactory.GetProtocol(trans)" << endl;
f_remote << indent() << "client := " << package_name_ << ".New" << publicize(service_name_)
- << "ClientFactory(trans, protocolFactory)" << endl;
+ << "Client(thrift.NewTStandardClient(iprot, oprot))" << endl;
f_remote << indent() << "if err := trans.Open(); err != nil {" << endl;
f_remote << indent() << " fmt.Fprintln(os.Stderr, \"Error opening socket to \", "
"host, \":\", port, \" \", err)" << endl;
@@ -3444,10 +3351,13 @@
* @return String of rendered function definition
*/
string t_go_generator::function_signature_if(t_function* tfunction, string prefix, bool addError) {
- // TODO(mcslee): Nitpicky, no ',' if argument_list is empty
string signature = publicize(prefix + tfunction->get_name()) + "(";
- signature += "ctx context.Context, ";
- signature += argument_list(tfunction->get_arglist()) + ") (";
+ signature += "ctx context.Context";
+ if (!tfunction->get_arglist()->get_members().empty()) {
+ signature += ", " + argument_list(tfunction->get_arglist());
+ }
+ signature += ") (";
+
t_type* ret = tfunction->get_returntype();
t_struct* exceptions = tfunction->get_xceptions();
string errs = argument_list(exceptions);
diff --git a/lib/go/test/tests/client_error_test.go b/lib/go/test/tests/client_error_test.go
index ad43447..4a8ef13 100644
--- a/lib/go/test/tests/client_error_test.go
+++ b/lib/go/test/tests/client_error_test.go
@@ -20,11 +20,12 @@
package tests
import (
- "github.com/golang/mock/gomock"
"errors"
"errortest"
"testing"
"thrift"
+
+ "github.com/golang/mock/gomock"
)
// TestCase: Comprehensive call and reply workflow in the client.
@@ -397,7 +398,6 @@
// Expecting TTransportError on fail.
func TestClientReportTTransportErrors(t *testing.T) {
mockCtrl := gomock.NewController(t)
- transport := thrift.NewTMemoryBuffer()
thing := errortest.NewTestStruct()
thing.M = make(map[string]string)
@@ -411,6 +411,38 @@
if !prepareClientCallReply(protocol, i, err) {
return
}
+ client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol))
+ _, retErr := client.TestStruct(defaultCtx, thing)
+ mockCtrl.Finish()
+ err2, ok := retErr.(thrift.TTransportException)
+ if !ok {
+ t.Fatal("Expected a TTrasportException")
+ }
+
+ if err2.TypeId() != thrift.TIMED_OUT {
+ t.Fatal("Expected TIMED_OUT error")
+ }
+ }
+}
+
+// TestCase: Comprehensive call and reply workflow in the client.
+// Expecting TTransportError on fail.
+// Similar to TestClientReportTTransportErrors, but using legacy client constructor.
+func TestClientReportTTransportErrorsLegacy(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ transport := thrift.NewTMemoryBuffer()
+ thing := errortest.NewTestStruct()
+ thing.M = make(map[string]string)
+ thing.L = make([]string, 0)
+ thing.S = make([]string, 0)
+ thing.I = 3
+
+ err := thrift.NewTTransportException(thrift.TIMED_OUT, "test")
+ for i := 0; ; i++ {
+ protocol := NewMockTProtocol(mockCtrl)
+ if !prepareClientCallReply(protocol, i, err) {
+ return
+ }
client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
_, retErr := client.TestStruct(defaultCtx, thing)
mockCtrl.Finish()
@@ -429,7 +461,6 @@
// Expecting TTProtocolErrors on fail.
func TestClientReportTProtocolErrors(t *testing.T) {
mockCtrl := gomock.NewController(t)
- transport := thrift.NewTMemoryBuffer()
thing := errortest.NewTestStruct()
thing.M = make(map[string]string)
@@ -443,6 +474,37 @@
if !prepareClientCallReply(protocol, i, err) {
return
}
+ client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol))
+ _, retErr := client.TestStruct(defaultCtx, thing)
+ mockCtrl.Finish()
+ err2, ok := retErr.(thrift.TProtocolException)
+ if !ok {
+ t.Fatal("Expected a TProtocolException")
+ }
+ if err2.TypeId() != thrift.INVALID_DATA {
+ t.Fatal("Expected INVALID_DATA error")
+ }
+ }
+}
+
+// TestCase: Comprehensive call and reply workflow in the client.
+// Expecting TTProtocolErrors on fail.
+// Similar to TestClientReportTProtocolErrors, but using legacy client constructor.
+func TestClientReportTProtocolErrorsLegacy(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ transport := thrift.NewTMemoryBuffer()
+ thing := errortest.NewTestStruct()
+ thing.M = make(map[string]string)
+ thing.L = make([]string, 0)
+ thing.S = make([]string, 0)
+ thing.I = 3
+
+ err := thrift.NewTProtocolExceptionWithType(thrift.INVALID_DATA, errors.New("test"))
+ for i := 0; ; i++ {
+ protocol := NewMockTProtocol(mockCtrl)
+ if !prepareClientCallReply(protocol, i, err) {
+ return
+ }
client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
_, retErr := client.TestStruct(defaultCtx, thing)
mockCtrl.Finish()
@@ -557,13 +619,47 @@
// TestCase: call and reply with exception workflow in the client.
func TestClientCallException(t *testing.T) {
mockCtrl := gomock.NewController(t)
- transport := thrift.NewTMemoryBuffer()
err := thrift.NewTTransportException(thrift.TIMED_OUT, "test")
for i := 0; ; i++ {
protocol := NewMockTProtocol(mockCtrl)
willComplete := !prepareClientCallException(protocol, i, err)
+ client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol))
+ _, retErr := client.TestString(defaultCtx, "test")
+ mockCtrl.Finish()
+
+ if !willComplete {
+ err2, ok := retErr.(thrift.TTransportException)
+ if !ok {
+ t.Fatal("Expected a TTransportException")
+ }
+ if err2.TypeId() != thrift.TIMED_OUT {
+ t.Fatal("Expected TIMED_OUT error")
+ }
+ } else {
+ err2, ok := retErr.(thrift.TApplicationException)
+ if !ok {
+ t.Fatal("Expected a TApplicationException")
+ }
+ if err2.TypeId() != thrift.PROTOCOL_ERROR {
+ t.Fatal("Expected PROTOCOL_ERROR error")
+ }
+ break
+ }
+ }
+}
+
+// TestCase: call and reply with exception workflow in the client.
+// Similar to TestClientCallException, but using legacy client constructor.
+func TestClientCallExceptionLegacy(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
+ transport := thrift.NewTMemoryBuffer()
+ err := thrift.NewTTransportException(thrift.TIMED_OUT, "test")
+ for i := 0; ; i++ {
+ protocol := NewMockTProtocol(mockCtrl)
+ willComplete := !prepareClientCallException(protocol, i, err)
+
client := errortest.NewErrorTestClientProtocol(transport, protocol, protocol)
_, retErr := client.TestString(defaultCtx, "test")
mockCtrl.Finish()
@@ -592,6 +688,36 @@
// TestCase: Mismatching sequence id has been received in the client.
func TestClientSeqIdMismatch(t *testing.T) {
mockCtrl := gomock.NewController(t)
+ protocol := NewMockTProtocol(mockCtrl)
+ gomock.InOrder(
+ protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)),
+ protocol.EXPECT().WriteStructBegin("testString_args"),
+ protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)),
+ protocol.EXPECT().WriteString("test"),
+ protocol.EXPECT().WriteFieldEnd(),
+ protocol.EXPECT().WriteFieldStop(),
+ protocol.EXPECT().WriteStructEnd(),
+ protocol.EXPECT().WriteMessageEnd(),
+ protocol.EXPECT().Flush(),
+ protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.REPLY, int32(2), nil),
+ )
+
+ client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol))
+ _, err := client.TestString(defaultCtx, "test")
+ mockCtrl.Finish()
+ appErr, ok := err.(thrift.TApplicationException)
+ if !ok {
+ t.Fatal("Expected TApplicationException")
+ }
+ if appErr.TypeId() != thrift.BAD_SEQUENCE_ID {
+ t.Fatal("Expected BAD_SEQUENCE_ID error")
+ }
+}
+
+// TestCase: Mismatching sequence id has been received in the client.
+// Similar to TestClientSeqIdMismatch, but using legacy client constructor.
+func TestClientSeqIdMismatchLegeacy(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
transport := thrift.NewTMemoryBuffer()
protocol := NewMockTProtocol(mockCtrl)
gomock.InOrder(
@@ -622,6 +748,36 @@
// TestCase: Wrong method name has been received in the client.
func TestClientWrongMethodName(t *testing.T) {
mockCtrl := gomock.NewController(t)
+ protocol := NewMockTProtocol(mockCtrl)
+ gomock.InOrder(
+ protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)),
+ protocol.EXPECT().WriteStructBegin("testString_args"),
+ protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)),
+ protocol.EXPECT().WriteString("test"),
+ protocol.EXPECT().WriteFieldEnd(),
+ protocol.EXPECT().WriteFieldStop(),
+ protocol.EXPECT().WriteStructEnd(),
+ protocol.EXPECT().WriteMessageEnd(),
+ protocol.EXPECT().Flush(),
+ protocol.EXPECT().ReadMessageBegin().Return("unknown", thrift.REPLY, int32(1), nil),
+ )
+
+ client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol))
+ _, err := client.TestString(defaultCtx, "test")
+ mockCtrl.Finish()
+ appErr, ok := err.(thrift.TApplicationException)
+ if !ok {
+ t.Fatal("Expected TApplicationException")
+ }
+ if appErr.TypeId() != thrift.WRONG_METHOD_NAME {
+ t.Fatal("Expected WRONG_METHOD_NAME error")
+ }
+}
+
+// TestCase: Wrong method name has been received in the client.
+// Similar to TestClientWrongMethodName, but using legacy client constructor.
+func TestClientWrongMethodNameLegacy(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
transport := thrift.NewTMemoryBuffer()
protocol := NewMockTProtocol(mockCtrl)
gomock.InOrder(
@@ -652,6 +808,36 @@
// TestCase: Wrong message type has been received in the client.
func TestClientWrongMessageType(t *testing.T) {
mockCtrl := gomock.NewController(t)
+ protocol := NewMockTProtocol(mockCtrl)
+ gomock.InOrder(
+ protocol.EXPECT().WriteMessageBegin("testString", thrift.CALL, int32(1)),
+ protocol.EXPECT().WriteStructBegin("testString_args"),
+ protocol.EXPECT().WriteFieldBegin("s", thrift.TType(thrift.STRING), int16(1)),
+ protocol.EXPECT().WriteString("test"),
+ protocol.EXPECT().WriteFieldEnd(),
+ protocol.EXPECT().WriteFieldStop(),
+ protocol.EXPECT().WriteStructEnd(),
+ protocol.EXPECT().WriteMessageEnd(),
+ protocol.EXPECT().Flush(),
+ protocol.EXPECT().ReadMessageBegin().Return("testString", thrift.INVALID_TMESSAGE_TYPE, int32(1), nil),
+ )
+
+ client := errortest.NewErrorTestClient(thrift.NewTStandardClient(protocol, protocol))
+ _, err := client.TestString(defaultCtx, "test")
+ mockCtrl.Finish()
+ appErr, ok := err.(thrift.TApplicationException)
+ if !ok {
+ t.Fatal("Expected TApplicationException")
+ }
+ if appErr.TypeId() != thrift.INVALID_MESSAGE_TYPE_EXCEPTION {
+ t.Fatal("Expected INVALID_MESSAGE_TYPE_EXCEPTION error")
+ }
+}
+
+// TestCase: Wrong message type has been received in the client.
+// Similar to TestClientWrongMessageType, but using legacy client constructor.
+func TestClientWrongMessageTypeLegacy(t *testing.T) {
+ mockCtrl := gomock.NewController(t)
transport := thrift.NewTMemoryBuffer()
protocol := NewMockTProtocol(mockCtrl)
gomock.InOrder(
diff --git a/lib/go/test/tests/multiplexed_protocol_test.go b/lib/go/test/tests/multiplexed_protocol_test.go
index 27802e5..0b5896b 100644
--- a/lib/go/test/tests/multiplexed_protocol_test.go
+++ b/lib/go/test/tests/multiplexed_protocol_test.go
@@ -36,15 +36,22 @@
}
}
+func createTransport(addr net.Addr) (thrift.TTransport, error) {
+ socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT)
+ transport := thrift.NewTFramedTransport(socket)
+ err := transport.Open()
+ if err != nil {
+ return nil, err
+ }
+ return transport, nil
+}
-var processor = thrift.NewTMultiplexedProcessor()
-
-func TestInitTwoServers(t *testing.T) {
- var err error
+func TestMultiplexedProtocolFirst(t *testing.T) {
+ processor := thrift.NewTMultiplexedProcessor()
protocolFactory := thrift.NewTBinaryProtocolFactoryDefault()
transportFactory := thrift.NewTTransportFactory()
transportFactory = thrift.NewTFramedTransportFactory(transportFactory)
- addr = FindAvailableTCPServerPort()
+ addr := FindAvailableTCPServerPort()
serverTransport, err := thrift.NewTServerSocketTimeout(addr.String(), TIMEOUT)
if err != nil {
t.Fatal("Unable to create server socket", err)
@@ -57,82 +64,117 @@
secondProcessor := multiplexedprotocoltest.NewSecondProcessor(&SecondImpl{})
processor.RegisterProcessor("SecondService", secondProcessor)
+ defer server.Stop()
go server.Serve()
time.Sleep(10 * time.Millisecond)
-}
-var firstClient *multiplexedprotocoltest.FirstClient
-
-func TestInitClient1(t *testing.T) {
- socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT)
- transport := thrift.NewTFramedTransport(socket)
- var protocol thrift.TProtocol = thrift.NewTBinaryProtocolTransport(transport)
- protocol = thrift.NewTMultiplexedProtocol(protocol, "FirstService")
- firstClient = multiplexedprotocoltest.NewFirstClientProtocol(transport, protocol, protocol)
- err := transport.Open()
+ transport, err := createTransport(addr)
if err != nil {
- t.Fatal("Unable to open client socket", err)
+ t.Fatal(err)
}
-}
+ defer transport.Close()
+ protocol := thrift.NewTMultiplexedProtocol(thrift.NewTBinaryProtocolTransport(transport), "FirstService")
-var secondClient *multiplexedprotocoltest.SecondClient
+ client := multiplexedprotocoltest.NewFirstClient(thrift.NewTStandardClient(protocol, protocol))
-func TestInitClient2(t *testing.T) {
- socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT)
- transport := thrift.NewTFramedTransport(socket)
- var protocol thrift.TProtocol = thrift.NewTBinaryProtocolTransport(transport)
- protocol = thrift.NewTMultiplexedProtocol(protocol, "SecondService")
- secondClient = multiplexedprotocoltest.NewSecondClientProtocol(transport, protocol, protocol)
- err := transport.Open()
- if err != nil {
- t.Fatal("Unable to open client socket", err)
- }
-}
-
-//create client without service prefix
-func createLegacyClient(t *testing.T) *multiplexedprotocoltest.SecondClient {
- socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT)
- transport := thrift.NewTFramedTransport(socket)
- var protocol thrift.TProtocol = thrift.NewTBinaryProtocolTransport(transport)
- legacyClient := multiplexedprotocoltest.NewSecondClientProtocol(transport, protocol, protocol)
- err := transport.Open()
- if err != nil {
- t.Fatal("Unable to open client socket", err)
- }
- return legacyClient
-}
-
-func TestCallFirst(t *testing.T) {
- ret, err := firstClient.ReturnOne(defaultCtx)
+ ret, err := client.ReturnOne(defaultCtx)
if err != nil {
t.Fatal("Unable to call first server:", err)
- }
- if ret != 1 {
+ } else if ret != 1 {
t.Fatal("Unexpected result from server: ", ret)
}
}
-func TestCallSecond(t *testing.T) {
- ret, err := secondClient.ReturnTwo(defaultCtx)
+func TestMultiplexedProtocolSecond(t *testing.T) {
+ processor := thrift.NewTMultiplexedProcessor()
+ protocolFactory := thrift.NewTBinaryProtocolFactoryDefault()
+ transportFactory := thrift.NewTTransportFactory()
+ transportFactory = thrift.NewTFramedTransportFactory(transportFactory)
+ addr := FindAvailableTCPServerPort()
+ serverTransport, err := thrift.NewTServerSocketTimeout(addr.String(), TIMEOUT)
+ if err != nil {
+ t.Fatal("Unable to create server socket", err)
+ }
+ server = thrift.NewTSimpleServer4(processor, serverTransport, transportFactory, protocolFactory)
+
+ firstProcessor := multiplexedprotocoltest.NewFirstProcessor(&FirstImpl{})
+ processor.RegisterProcessor("FirstService", firstProcessor)
+
+ secondProcessor := multiplexedprotocoltest.NewSecondProcessor(&SecondImpl{})
+ processor.RegisterProcessor("SecondService", secondProcessor)
+
+ defer server.Stop()
+ go server.Serve()
+ time.Sleep(10 * time.Millisecond)
+
+ transport, err := createTransport(addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer transport.Close()
+ protocol := thrift.NewTMultiplexedProtocol(thrift.NewTBinaryProtocolTransport(transport), "SecondService")
+
+ client := multiplexedprotocoltest.NewSecondClient(thrift.NewTStandardClient(protocol, protocol))
+
+ ret, err := client.ReturnTwo(defaultCtx)
if err != nil {
t.Fatal("Unable to call second server:", err)
- }
- if ret != 2 {
+ } else if ret != 2 {
t.Fatal("Unexpected result from server: ", ret)
}
}
-func TestCallLegacy(t *testing.T) {
- legacyClient := createLegacyClient(t)
- ret, err := legacyClient.ReturnTwo(defaultCtx)
+func TestMultiplexedProtocolLegacy(t *testing.T) {
+ processor := thrift.NewTMultiplexedProcessor()
+ protocolFactory := thrift.NewTBinaryProtocolFactoryDefault()
+ transportFactory := thrift.NewTTransportFactory()
+ transportFactory = thrift.NewTFramedTransportFactory(transportFactory)
+ addr := FindAvailableTCPServerPort()
+ serverTransport, err := thrift.NewTServerSocketTimeout(addr.String(), TIMEOUT)
+ if err != nil {
+ t.Fatal("Unable to create server socket", err)
+ }
+ server = thrift.NewTSimpleServer4(processor, serverTransport, transportFactory, protocolFactory)
+
+ firstProcessor := multiplexedprotocoltest.NewFirstProcessor(&FirstImpl{})
+ processor.RegisterProcessor("FirstService", firstProcessor)
+
+ secondProcessor := multiplexedprotocoltest.NewSecondProcessor(&SecondImpl{})
+ processor.RegisterProcessor("SecondService", secondProcessor)
+
+ defer server.Stop()
+ go server.Serve()
+ time.Sleep(10 * time.Millisecond)
+
+ transport, err := createTransport(addr)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer transport.Close()
+
+ protocol := thrift.NewTBinaryProtocolTransport(transport)
+ client := multiplexedprotocoltest.NewSecondClient(thrift.NewTStandardClient(protocol, protocol))
+
+ ret, err := client.ReturnTwo(defaultCtx)
//expect error since default processor is not registered
if err == nil {
t.Fatal("Expecting error")
}
+
//register default processor and call again
processor.RegisterDefault(multiplexedprotocoltest.NewSecondProcessor(&SecondImpl{}))
- legacyClient = createLegacyClient(t)
- ret, err = legacyClient.ReturnTwo(defaultCtx)
+ transport, err = createTransport(addr)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer transport.Close()
+
+ protocol = thrift.NewTBinaryProtocolTransport(transport)
+ client = multiplexedprotocoltest.NewSecondClient(thrift.NewTStandardClient(protocol, protocol))
+
+ ret, err = client.ReturnTwo(defaultCtx)
if err != nil {
t.Fatal("Unable to call legacy server:", err)
}
@@ -140,9 +182,3 @@
t.Fatal("Unexpected result from server: ", ret)
}
}
-
-func TestShutdownServerAndClients(t *testing.T) {
- firstClient.Transport.Close()
- secondClient.Transport.Close()
- server.Stop()
-}
diff --git a/lib/go/test/tests/one_way_test.go b/lib/go/test/tests/one_way_test.go
index 32881e2..8abd671 100644
--- a/lib/go/test/tests/one_way_test.go
+++ b/lib/go/test/tests/one_way_test.go
@@ -59,7 +59,7 @@
func TestInitOnewayClient(t *testing.T) {
transport := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT)
protocol := thrift.NewTBinaryProtocolTransport(transport)
- client = onewaytest.NewOneWayClientProtocol(transport, protocol, protocol)
+ client = onewaytest.NewOneWayClient(thrift.NewTStandardClient(protocol, protocol))
err := transport.Open()
if err != nil {
t.Fatal("Unable to open client socket", err)
diff --git a/lib/go/test/tests/protocol_mock.go b/lib/go/test/tests/protocol_mock.go
index 9197fed..8476c86 100644
--- a/lib/go/test/tests/protocol_mock.go
+++ b/lib/go/test/tests/protocol_mock.go
@@ -24,6 +24,7 @@
import (
thrift "thrift"
+
gomock "github.com/golang/mock/gomock"
)
diff --git a/lib/go/test/tests/protocols_test.go b/lib/go/test/tests/protocols_test.go
index 1580678..cffd9c3 100644
--- a/lib/go/test/tests/protocols_test.go
+++ b/lib/go/test/tests/protocols_test.go
@@ -47,7 +47,7 @@
t.Fatal(err)
}
var protocol thrift.TProtocol = protocolFactory.GetProtocol(transport)
- thriftTestClient := thrifttest.NewThriftTestClientProtocol(transport, protocol, protocol)
+ thriftTestClient := thrifttest.NewThriftTestClient(thrift.NewTStandardClient(protocol, protocol))
err = transport.Open()
if err != nil {
t.Fatal("Unable to open client socket", err)
diff --git a/lib/go/thrift/application_exception.go b/lib/go/thrift/application_exception.go
index 525bce2..b9d7eed 100644
--- a/lib/go/thrift/application_exception.go
+++ b/lib/go/thrift/application_exception.go
@@ -45,7 +45,7 @@
type TApplicationException interface {
TException
TypeId() int32
- Read(iprot TProtocol) (TApplicationException, error)
+ Read(iprot TProtocol) error
Write(oprot TProtocol) error
}
@@ -69,10 +69,11 @@
return p.type_
}
-func (p *tApplicationException) Read(iprot TProtocol) (TApplicationException, error) {
+func (p *tApplicationException) Read(iprot TProtocol) error {
+ // TODO: this should really be generated by the compiler
_, err := iprot.ReadStructBegin()
if err != nil {
- return nil, err
+ return err
}
message := ""
@@ -81,7 +82,7 @@
for {
_, ttype, id, err := iprot.ReadFieldBegin()
if err != nil {
- return nil, err
+ return err
}
if ttype == STOP {
break
@@ -90,33 +91,40 @@
case 1:
if ttype == STRING {
if message, err = iprot.ReadString(); err != nil {
- return nil, err
+ return err
}
} else {
if err = SkipDefaultDepth(iprot, ttype); err != nil {
- return nil, err
+ return err
}
}
case 2:
if ttype == I32 {
if type_, err = iprot.ReadI32(); err != nil {
- return nil, err
+ return err
}
} else {
if err = SkipDefaultDepth(iprot, ttype); err != nil {
- return nil, err
+ return err
}
}
default:
if err = SkipDefaultDepth(iprot, ttype); err != nil {
- return nil, err
+ return err
}
}
if err = iprot.ReadFieldEnd(); err != nil {
- return nil, err
+ return err
}
}
- return NewTApplicationException(type_, message), iprot.ReadStructEnd()
+ if err := iprot.ReadStructEnd(); err != nil {
+ return err
+ }
+
+ p.message = message
+ p.type_ = type_
+
+ return nil
}
func (p *tApplicationException) Write(oprot TProtocol) (err error) {
diff --git a/lib/go/thrift/client.go b/lib/go/thrift/client.go
new file mode 100644
index 0000000..8bdb53d
--- /dev/null
+++ b/lib/go/thrift/client.go
@@ -0,0 +1,78 @@
+package thrift
+
+import "fmt"
+
+type TStandardClient struct {
+ seqId int32
+ iprot, oprot TProtocol
+}
+
+// TStandardClient implements TClient, and uses the standard message format for Thrift.
+// It is not safe for concurrent use.
+func NewTStandardClient(inputProtocol, outputProtocol TProtocol) *TStandardClient {
+ return &TStandardClient{
+ iprot: inputProtocol,
+ oprot: outputProtocol,
+ }
+}
+
+func (p *TStandardClient) Send(oprot TProtocol, seqId int32, method string, args TStruct) error {
+ if err := oprot.WriteMessageBegin(method, CALL, seqId); err != nil {
+ return err
+ }
+ if err := args.Write(oprot); err != nil {
+ return err
+ }
+ if err := oprot.WriteMessageEnd(); err != nil {
+ return err
+ }
+ return oprot.Flush()
+}
+
+func (p *TStandardClient) Recv(iprot TProtocol, seqId int32, method string, result TStruct) error {
+ rMethod, rTypeId, rSeqId, err := iprot.ReadMessageBegin()
+ if err != nil {
+ return err
+ }
+
+ if method != rMethod {
+ return NewTApplicationException(WRONG_METHOD_NAME, fmt.Sprintf("%s: wrong method name", method))
+ } else if seqId != rSeqId {
+ return NewTApplicationException(BAD_SEQUENCE_ID, fmt.Sprintf("%s: out of order sequence response", method))
+ } else if rTypeId == EXCEPTION {
+ var exception tApplicationException
+ if err := exception.Read(iprot); err != nil {
+ return err
+ }
+
+ if err := iprot.ReadMessageEnd(); err != nil {
+ return err
+ }
+
+ return &exception
+ } else if rTypeId != REPLY {
+ return NewTApplicationException(INVALID_MESSAGE_TYPE_EXCEPTION, fmt.Sprintf("%s: invalid message type", method))
+ }
+
+ if err := result.Read(iprot); err != nil {
+ return err
+ }
+
+ return iprot.ReadMessageEnd()
+}
+
+func (p *TStandardClient) call(method string, args, result TStruct) error {
+ p.seqId++
+ seqId := p.seqId
+
+ if err := p.Send(p.oprot, seqId, method, args); err != nil {
+ return err
+ }
+
+ // method is oneway
+ if result == nil {
+ return nil
+ }
+
+ return p.Recv(p.iprot, seqId, method, result)
+}
diff --git a/lib/go/thrift/client_go17.go b/lib/go/thrift/client_go17.go
new file mode 100644
index 0000000..15c1c52
--- /dev/null
+++ b/lib/go/thrift/client_go17.go
@@ -0,0 +1,13 @@
+// +build go1.7
+
+package thrift
+
+import "context"
+
+type TClient interface {
+ Call(ctx context.Context, method string, args, result TStruct) error
+}
+
+func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) error {
+ return p.call(method, args, result)
+}
diff --git a/lib/go/thrift/client_pre_go17.go b/lib/go/thrift/client_pre_go17.go
new file mode 100644
index 0000000..d2e99ef
--- /dev/null
+++ b/lib/go/thrift/client_pre_go17.go
@@ -0,0 +1,13 @@
+// +build !go1.7
+
+package thrift
+
+import "golang.org/x/net/context"
+
+type TClient interface {
+ Call(ctx context.Context, method string, args, result TStruct) error
+}
+
+func (p *TStandardClient) Call(ctx context.Context, method string, args, result TStruct) error {
+ return p.call(method, args, result)
+}
diff --git a/test/go/Makefile.am b/test/go/Makefile.am
index db27258..6bc97f5 100644
--- a/test/go/Makefile.am
+++ b/test/go/Makefile.am
@@ -30,6 +30,8 @@
ThriftTest.thrift: $(THRIFTTEST)
grep -v list.*map.*list.*map $(THRIFTTEST) > ThriftTest.thrift
+.PHONY: gopath
+
# Thrift for GO has problems with complex map keys: THRIFT-2063
gopath: $(THRIFT) ThriftTest.thrift
mkdir -p src/gen
diff --git a/test/go/src/bin/testclient/main.go b/test/go/src/bin/testclient/main.go
index b34c539..ab24cbf 100644
--- a/test/go/src/bin/testclient/main.go
+++ b/test/go/src/bin/testclient/main.go
@@ -38,7 +38,7 @@
func main() {
flag.Parse()
- client, err := common.StartClient(*host, *port, *domain_socket, *transport, *protocol, *ssl)
+ client, _, err := common.StartClient(*host, *port, *domain_socket, *transport, *protocol, *ssl)
if err != nil {
t.Fatalf("Unable to start client: ", err)
}
@@ -128,7 +128,7 @@
}
bin, err := client.TestBinary(defaultCtx, binout)
for i := 0; i < 256; i++ {
- if (binout[i] != bin[i]) {
+ if binout[i] != bin[i] {
t.Fatalf("Unexpected TestBinary() result expected %d, got %d ", binout[i], bin[i])
}
}
@@ -224,21 +224,21 @@
}
crazy := thrifttest.NewInsanity()
- crazy.UserMap = map[thrifttest.Numberz]thrifttest.UserId {
- thrifttest.Numberz_FIVE: 5,
+ crazy.UserMap = map[thrifttest.Numberz]thrifttest.UserId{
+ thrifttest.Numberz_FIVE: 5,
thrifttest.Numberz_EIGHT: 8,
}
truck1 := thrifttest.NewXtruct()
truck1.StringThing = "Goodbye4"
- truck1.ByteThing = 4;
- truck1.I32Thing = 4;
- truck1.I64Thing = 4;
+ truck1.ByteThing = 4
+ truck1.I32Thing = 4
+ truck1.I64Thing = 4
truck2 := thrifttest.NewXtruct()
truck2.StringThing = "Hello2"
- truck2.ByteThing = 2;
- truck2.I32Thing = 2;
- truck2.I64Thing = 2;
- crazy.Xtructs = []*thrifttest.Xtruct {
+ truck2.ByteThing = 2
+ truck2.I32Thing = 2
+ truck2.I64Thing = 2
+ crazy.Xtructs = []*thrifttest.Xtruct{
truck1,
truck2,
}
@@ -248,17 +248,17 @@
}
if !reflect.DeepEqual(crazy, insanity[1][2]) {
t.Fatalf("Unexpected TestInsanity() first result expected %#v, got %#v ",
- crazy,
- insanity[1][2])
+ crazy,
+ insanity[1][2])
}
if !reflect.DeepEqual(crazy, insanity[1][3]) {
t.Fatalf("Unexpected TestInsanity() second result expected %#v, got %#v ",
- crazy,
- insanity[1][3])
+ crazy,
+ insanity[1][3])
}
if len(insanity[2][6].UserMap) > 0 || len(insanity[2][6].Xtructs) > 0 {
t.Fatalf("Unexpected TestInsanity() non-empty result got %#v ",
- insanity[2][6])
+ insanity[2][6])
}
xxsret, err := client.TestMulti(defaultCtx, 42, 4242, 424242, map[int16]string{1: "blah", 2: "thing"}, thrifttest.Numberz_EIGHT, thrifttest.UserId(24))
diff --git a/test/go/src/common/client.go b/test/go/src/common/client.go
index 4251d91..236ce43 100644
--- a/test/go/src/common/client.go
+++ b/test/go/src/common/client.go
@@ -41,7 +41,7 @@
domain_socket string,
transport string,
protocol string,
- ssl bool) (client *thrifttest.ThriftTestClient, err error) {
+ ssl bool) (client *thrifttest.ThriftTestClient, trans thrift.TTransport, err error) {
hostPort := fmt.Sprintf("%s:%d", host, port)
@@ -56,12 +56,11 @@
case "binary":
protocolFactory = thrift.NewTBinaryProtocolFactoryDefault()
default:
- return nil, fmt.Errorf("Invalid protocol specified %s", protocol)
+ return nil, nil, fmt.Errorf("Invalid protocol specified %s", protocol)
}
if debugClientProtocol {
protocolFactory = thrift.NewTDebugProtocolFactory(protocolFactory, "client:")
}
- var trans thrift.TTransport
if ssl {
trans, err = thrift.NewTSSLSocket(hostPort, &tls.Config{InsecureSkipVerify: true})
} else {
@@ -72,7 +71,7 @@
}
}
if err != nil {
- return nil, err
+ return nil, nil, err
}
switch transport {
case "http":
@@ -86,29 +85,25 @@
} else {
trans, err = thrift.NewTHttpPostClient(fmt.Sprintf("http://%s/", hostPort))
}
-
- if err != nil {
- return nil, err
- }
-
case "framed":
trans = thrift.NewTFramedTransport(trans)
case "buffered":
trans = thrift.NewTBufferedTransport(trans, 8192)
case "zlib":
trans, err = thrift.NewTZlibTransport(trans, zlib.BestCompression)
- if err != nil {
- return nil, err
- }
case "":
trans = trans
default:
- return nil, fmt.Errorf("Invalid transport specified %s", transport)
+ return nil, nil, fmt.Errorf("Invalid transport specified %s", transport)
}
-
+ if err != nil {
+ return nil, nil, err
+ }
if err = trans.Open(); err != nil {
- return nil, err
+ return nil, nil, err
}
- client = thrifttest.NewThriftTestClientFactory(trans, protocolFactory)
+ iprot := protocolFactory.GetProtocol(trans)
+ oprot := protocolFactory.GetProtocol(trans)
+ client = thrifttest.NewThriftTestClient(thrift.NewTStandardClient(iprot, oprot))
return
}
diff --git a/test/go/src/common/clientserver_test.go b/test/go/src/common/clientserver_test.go
index ecd021f..c4cfd44 100644
--- a/test/go/src/common/clientserver_test.go
+++ b/test/go/src/common/clientserver_test.go
@@ -23,6 +23,7 @@
"errors"
"gen/thrifttest"
"reflect"
+ "sync"
"testing"
"thrift"
@@ -47,10 +48,15 @@
func TestAllConnection(t *testing.T) {
certPath = "../../../keys"
+ wg := &sync.WaitGroup{}
+ wg.Add(len(units))
for _, unit := range units {
- t.Logf("%#v", unit)
- doUnit(t, &unit)
+ go func(u test_unit) {
+ defer wg.Done()
+ doUnit(t, &u)
+ }(unit)
}
+ wg.Wait()
}
func doUnit(t *testing.T, unit *test_unit) {
@@ -62,17 +68,17 @@
server := thrift.NewTSimpleServer4(processor, serverTransport, transportFactory, protocolFactory)
if err = server.Listen(); err != nil {
- t.Errorf("Unable to start server", err)
- t.FailNow()
+ t.Errorf("Unable to start server: %v", err)
+ return
}
go server.AcceptLoop()
defer server.Stop()
- client, err := StartClient(unit.host, unit.port, unit.domain_socket, unit.transport, unit.protocol, unit.ssl)
+ client, trans, err := StartClient(unit.host, unit.port, unit.domain_socket, unit.transport, unit.protocol, unit.ssl)
if err != nil {
- t.Errorf("Unable to start client", err)
- t.FailNow()
+ t.Errorf("Unable to start client: %v", err)
+ return
}
- defer client.Transport.Close()
+ defer trans.Close()
callEverythingWithMock(t, client, handler)
}
@@ -273,7 +279,7 @@
xxsret, err := client.TestMulti(defaultCtx, 42, 4242, 424242, map[int16]string{1: "blah", 2: "thing"}, thrifttest.Numberz_EIGHT, thrifttest.UserId(24))
if err != nil {
- t.Errorf("Unexpected error in TestMulti() call: ", err)
+ t.Errorf("Unexpected error in TestMulti() call: %v", err)
}
if !reflect.DeepEqual(xxs, xxsret) {
t.Errorf("Unexpected TestMulti() result expected %#v, got %#v ", xxs, xxsret)
@@ -289,9 +295,12 @@
// TODO: connection is being closed on this
err = client.TestException(defaultCtx, "TException")
- tex, ok := err.(thrift.TApplicationException)
- if err == nil || !ok || tex.TypeId() != thrift.INTERNAL_ERROR {
- t.Errorf("Unexpected TestException() result expected ApplicationError, got %#v ", err)
+ if err == nil {
+ t.Error("expected exception got nil")
+ } else if tex, ok := err.(thrift.TApplicationException); !ok {
+ t.Errorf("Unexpected TestException() result expected ApplicationError, got %T ", err)
+ } else if tex.TypeId() != thrift.INTERNAL_ERROR {
+ t.Errorf("expected internal_error got %v", tex.TypeId())
}
ign, err := client.TestMultiException(defaultCtx, "Xception", "ignoreme")
diff --git a/tutorial/go/src/client.go b/tutorial/go/src/client.go
index 65027ea..25616bf 100644
--- a/tutorial/go/src/client.go
+++ b/tutorial/go/src/client.go
@@ -22,8 +22,9 @@
import (
"crypto/tls"
"fmt"
- "git.apache.org/thrift.git/lib/go/thrift"
"tutorial"
+
+ "git.apache.org/thrift.git/lib/go/thrift"
)
func handleClient(client *tutorial.CalculatorClient) (err error) {
@@ -98,5 +99,7 @@
if err := transport.Open(); err != nil {
return err
}
- return handleClient(tutorial.NewCalculatorClientFactory(transport, protocolFactory))
+ iprot := protocolFactory.GetProtocol(transport)
+ oprot := protocolFactory.GetProtocol(transport)
+ return handleClient(tutorial.NewCalculatorClient(thrift.NewTStandardClient(iprot, oprot)))
}