THRIFT-2377 Allow addition of custom HTTP Headers to an HTTP Transport
Patch: Sheran Gunasekera
diff --git a/lib/go/thrift/http_client.go b/lib/go/thrift/http_client.go
index 18b1671..9f60992 100644
--- a/lib/go/thrift/http_client.go
+++ b/lib/go/thrift/http_client.go
@@ -30,6 +30,7 @@
response *http.Response
url *url.URL
requestBuffer *bytes.Buffer
+ header http.Header
nsecConnectTimeout int64
nsecReadTimeout int64
}
@@ -85,7 +86,37 @@
return nil, err
}
buf := make([]byte, 0, 1024)
- return &THttpClient{url: parsedURL, requestBuffer: bytes.NewBuffer(buf)}, nil
+ return &THttpClient{url: parsedURL, requestBuffer: bytes.NewBuffer(buf), header: http.Header{}}, nil
+}
+
+// Set the HTTP Header for this specific Thrift Transport
+// It is important that you first assert the TTransport as a THttpClient type
+// like so:
+//
+// httpTrans := trans.(THttpClient)
+// httpTrans.SetHeader("User-Agent","Thrift Client 1.0")
+func (p *THttpClient) SetHeader(key string, value string) {
+ p.header.Add(key, value)
+}
+
+// Get the HTTP Header represented by the supplied Header Key for this specific Thrift Transport
+// It is important that you first assert the TTransport as a THttpClient type
+// like so:
+//
+// httpTrans := trans.(THttpClient)
+// hdrValue := httpTrans.GetHeader("User-Agent")
+func (p *THttpClient) GetHeader(key string) string {
+ return p.header.Get(key)
+}
+
+// Deletes the HTTP Header given a Header Key for this specific Thrift Transport
+// It is important that you first assert the TTransport as a THttpClient type
+// like so:
+//
+// httpTrans := trans.(THttpClient)
+// httpTrans.DelHeader("User-Agent")
+func (p *THttpClient) DelHeader(key string) {
+ p.header.Del(key)
}
func (p *THttpClient) Open() error {
@@ -128,7 +159,14 @@
}
func (p *THttpClient) Flush() error {
- response, err := http.Post(p.url.String(), "application/x-thrift", p.requestBuffer)
+ client := &http.Client{}
+ req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer)
+ if err != nil {
+ return NewTTransportExceptionFromError(err)
+ }
+ p.header.Add("Content-Type", "application/x-thrift")
+ req.Header = p.header
+ response, err := client.Do(req)
if err != nil {
return NewTTransportExceptionFromError(err)
}
diff --git a/lib/go/thrift/http_client_test.go b/lib/go/thrift/http_client_test.go
index 041faec..0c2cb28 100644
--- a/lib/go/thrift/http_client_test.go
+++ b/lib/go/thrift/http_client_test.go
@@ -35,3 +35,16 @@
}
TransportTest(t, trans, trans)
}
+
+func TestHttpClientHeaders(t *testing.T) {
+ l, addr := HttpClientSetupForTest(t)
+ if l != nil {
+ defer l.Close()
+ }
+ trans, err := NewTHttpPostClient("http://" + addr.String())
+ if err != nil {
+ l.Close()
+ t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
+ }
+ TransportHeaderTest(t, trans, trans)
+}
diff --git a/lib/go/thrift/protocol_test.go b/lib/go/thrift/protocol_test.go
index 632098c..d88afed 100644
--- a/lib/go/thrift/protocol_test.go
+++ b/lib/go/thrift/protocol_test.go
@@ -58,6 +58,7 @@
}
type HTTPEchoServer struct{}
+type HTTPHeaderEchoServer struct{}
func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
buf, err := ioutil.ReadAll(req.Body)
@@ -70,6 +71,17 @@
}
}
+func (p *HTTPHeaderEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+ buf, err := ioutil.ReadAll(req.Body)
+ if err != nil {
+ w.WriteHeader(http.StatusBadRequest)
+ w.Write(buf)
+ } else {
+ w.WriteHeader(http.StatusOK)
+ w.Write(buf)
+ }
+}
+
func HttpClientSetupForTest(t *testing.T) (net.Listener, net.Addr) {
addr, err := FindAvailableTCPServerPort(40000)
if err != nil {
@@ -85,6 +97,21 @@
return l, addr
}
+func HttpClientSetupForHeaderTest(t *testing.T) (net.Listener, net.Addr) {
+ addr, err := FindAvailableTCPServerPort(40000)
+ if err != nil {
+ t.Fatalf("Unable to find available tcp port addr: %s", err)
+ return nil, addr
+ }
+ l, err := net.Listen(addr.Network(), addr.String())
+ if err != nil {
+ t.Fatalf("Unable to setup tcp listener on %s: %s", addr.String(), err)
+ return l, addr
+ }
+ go http.Serve(l, &HTTPHeaderEchoServer{})
+ return l, addr
+}
+
func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) {
buf := bytes.NewBuffer(make([]byte, 0, 1024))
l, addr := HttpClientSetupForTest(t)
@@ -145,13 +172,13 @@
}
for _, tf := range transports {
- trans := tf.GetTransport(nil)
- p := protocolFactory.GetProtocol(trans);
- ReadWriteI64(t, p, trans);
- ReadWriteDouble(t, p, trans);
- ReadWriteBinary(t, p, trans);
- ReadWriteByte(t, p, trans);
- trans.Close()
+ trans := tf.GetTransport(nil)
+ p := protocolFactory.GetProtocol(trans)
+ ReadWriteI64(t, p, trans)
+ ReadWriteDouble(t, p, trans)
+ ReadWriteBinary(t, p, trans)
+ ReadWriteByte(t, p, trans)
+ trans.Close()
}
}
diff --git a/lib/go/thrift/transport_test.go b/lib/go/thrift/transport_test.go
index c9f1d56..864958a 100644
--- a/lib/go/thrift/transport_test.go
+++ b/lib/go/thrift/transport_test.go
@@ -29,7 +29,8 @@
const TRANSPORT_BINARY_DATA_SIZE = 4096
var (
- transport_bdata []byte // test data for writing; same as data
+ transport_bdata []byte // test data for writing; same as data
+ transport_header map[string]string
)
func init() {
@@ -37,6 +38,8 @@
for i := 0; i < TRANSPORT_BINARY_DATA_SIZE; i++ {
transport_bdata[i] = byte((i + 'a') % 255)
}
+ transport_header = map[string]string{"key": "User-Agent",
+ "value": "Mozilla/5.0 (Windows NT 6.2; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/32.0.1667.0 Safari/537.36"}
}
func TransportTest(t *testing.T, writeTrans TTransport, readTrans TTransport) {
@@ -94,6 +97,50 @@
}
}
+func TransportHeaderTest(t *testing.T, writeTrans TTransport, readTrans TTransport) {
+ buf := make([]byte, TRANSPORT_BINARY_DATA_SIZE)
+ if !writeTrans.IsOpen() {
+ t.Fatalf("Transport %T not open: %s", writeTrans, writeTrans)
+ }
+ if !readTrans.IsOpen() {
+ t.Fatalf("Transport %T not open: %s", readTrans, readTrans)
+ }
+ // Need to assert type of TTransport to THttpClient to expose the Setter
+ httpWPostTrans := writeTrans.(*THttpClient)
+ httpWPostTrans.SetHeader(transport_header["key"], transport_header["value"])
+
+ _, err := writeTrans.Write(transport_bdata)
+ if err != nil {
+ t.Fatalf("Transport %T cannot write binary data of length %d: %s", writeTrans, len(transport_bdata), err)
+ }
+ err = writeTrans.Flush()
+ if err != nil {
+ t.Fatalf("Transport %T cannot flush write of binary data: %s", writeTrans, err)
+ }
+ // Need to assert type of TTransport to THttpClient to expose the Getter
+ httpRPostTrans := readTrans.(*THttpClient)
+ readHeader := httpRPostTrans.GetHeader(transport_header["key"])
+ if err != nil {
+ t.Errorf("Transport %T cannot read HTTP Header Value", httpRPostTrans)
+ }
+
+ if transport_header["value"] != readHeader {
+ t.Errorf("Expected HTTP Header Value %s, got %s", transport_header["value"], readHeader)
+ }
+ n, err := io.ReadFull(readTrans, buf)
+ if err != nil {
+ t.Errorf("Transport %T cannot read binary data of length %d: %s", readTrans, TRANSPORT_BINARY_DATA_SIZE, err)
+ }
+ if n != TRANSPORT_BINARY_DATA_SIZE {
+ t.Errorf("Transport %T read only %d instead of %d bytes of binary data", readTrans, n, TRANSPORT_BINARY_DATA_SIZE)
+ }
+ for k, v := range buf {
+ if v != transport_bdata[k] {
+ t.Fatalf("Transport %T read %d instead of %d for index %d of binary data 2", readTrans, v, transport_bdata[k], k)
+ }
+ }
+}
+
func CloseTransports(t *testing.T, readTrans TTransport, writeTrans TTransport) {
err := readTrans.Close()
if err != nil {
@@ -118,3 +165,12 @@
}
return nil, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "Could not find available server port")
}
+
+func valueInSlice(value string, slice []string) bool {
+ for _, v := range slice {
+ if value == v {
+ return true
+ }
+ }
+ return false
+}