THRIFT-2377 Allow addition of custom HTTP Headers to an HTTP Transport

Patch: Sheran Gunasekera
diff --git a/lib/go/thrift/http_client.go b/lib/go/thrift/http_client.go
index 18b1671..9f60992 100644
--- a/lib/go/thrift/http_client.go
+++ b/lib/go/thrift/http_client.go
@@ -30,6 +30,7 @@
 	response           *http.Response
 	url                *url.URL
 	requestBuffer      *bytes.Buffer
+	header             http.Header
 	nsecConnectTimeout int64
 	nsecReadTimeout    int64
 }
@@ -85,7 +86,37 @@
 		return nil, err
 	}
 	buf := make([]byte, 0, 1024)
-	return &THttpClient{url: parsedURL, requestBuffer: bytes.NewBuffer(buf)}, nil
+	return &THttpClient{url: parsedURL, requestBuffer: bytes.NewBuffer(buf), header: http.Header{}}, nil
+}
+
+// Set the HTTP Header for this specific Thrift Transport
+// It is important that you first assert the TTransport as a THttpClient type
+// like so:
+//
+// httpTrans := trans.(THttpClient)
+// httpTrans.SetHeader("User-Agent","Thrift Client 1.0")
+func (p *THttpClient) SetHeader(key string, value string) {
+	p.header.Add(key, value)
+}
+
+// Get the HTTP Header represented by the supplied Header Key for this specific Thrift Transport
+// It is important that you first assert the TTransport as a THttpClient type
+// like so:
+//
+// httpTrans := trans.(THttpClient)
+// hdrValue := httpTrans.GetHeader("User-Agent")
+func (p *THttpClient) GetHeader(key string) string {
+	return p.header.Get(key)
+}
+
+// Deletes the HTTP Header given a Header Key for this specific Thrift Transport
+// It is important that you first assert the TTransport as a THttpClient type
+// like so:
+//
+// httpTrans := trans.(THttpClient)
+// httpTrans.DelHeader("User-Agent")
+func (p *THttpClient) DelHeader(key string) {
+	p.header.Del(key)
 }
 
 func (p *THttpClient) Open() error {
@@ -128,7 +159,14 @@
 }
 
 func (p *THttpClient) Flush() error {
-	response, err := http.Post(p.url.String(), "application/x-thrift", p.requestBuffer)
+	client := &http.Client{}
+	req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer)
+	if err != nil {
+		return NewTTransportExceptionFromError(err)
+	}
+	p.header.Add("Content-Type", "application/x-thrift")
+	req.Header = p.header
+	response, err := client.Do(req)
 	if err != nil {
 		return NewTTransportExceptionFromError(err)
 	}
diff --git a/lib/go/thrift/http_client_test.go b/lib/go/thrift/http_client_test.go
index 041faec..0c2cb28 100644
--- a/lib/go/thrift/http_client_test.go
+++ b/lib/go/thrift/http_client_test.go
@@ -35,3 +35,16 @@
 	}
 	TransportTest(t, trans, trans)
 }
+
+func TestHttpClientHeaders(t *testing.T) {
+	l, addr := HttpClientSetupForTest(t)
+	if l != nil {
+		defer l.Close()
+	}
+	trans, err := NewTHttpPostClient("http://" + addr.String())
+	if err != nil {
+		l.Close()
+		t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
+	}
+	TransportHeaderTest(t, trans, trans)
+}
diff --git a/lib/go/thrift/protocol_test.go b/lib/go/thrift/protocol_test.go
index 632098c..d88afed 100644
--- a/lib/go/thrift/protocol_test.go
+++ b/lib/go/thrift/protocol_test.go
@@ -58,6 +58,7 @@
 }
 
 type HTTPEchoServer struct{}
+type HTTPHeaderEchoServer struct{}
 
 func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 	buf, err := ioutil.ReadAll(req.Body)
@@ -70,6 +71,17 @@
 	}
 }
 
+func (p *HTTPHeaderEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	buf, err := ioutil.ReadAll(req.Body)
+	if err != nil {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write(buf)
+	} else {
+		w.WriteHeader(http.StatusOK)
+		w.Write(buf)
+	}
+}
+
 func HttpClientSetupForTest(t *testing.T) (net.Listener, net.Addr) {
 	addr, err := FindAvailableTCPServerPort(40000)
 	if err != nil {
@@ -85,6 +97,21 @@
 	return l, addr
 }
 
+func HttpClientSetupForHeaderTest(t *testing.T) (net.Listener, net.Addr) {
+	addr, err := FindAvailableTCPServerPort(40000)
+	if err != nil {
+		t.Fatalf("Unable to find available tcp port addr: %s", err)
+		return nil, addr
+	}
+	l, err := net.Listen(addr.Network(), addr.String())
+	if err != nil {
+		t.Fatalf("Unable to setup tcp listener on %s: %s", addr.String(), err)
+		return l, addr
+	}
+	go http.Serve(l, &HTTPHeaderEchoServer{})
+	return l, addr
+}
+
 func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) {
 	buf := bytes.NewBuffer(make([]byte, 0, 1024))
 	l, addr := HttpClientSetupForTest(t)
@@ -145,13 +172,13 @@
 	}
 
 	for _, tf := range transports {
-	  trans := tf.GetTransport(nil)
-	  p := protocolFactory.GetProtocol(trans);
-	  ReadWriteI64(t, p, trans);
-	  ReadWriteDouble(t, p, trans);
-	  ReadWriteBinary(t, p, trans);
-	  ReadWriteByte(t, p, trans);
-	  trans.Close()
+		trans := tf.GetTransport(nil)
+		p := protocolFactory.GetProtocol(trans)
+		ReadWriteI64(t, p, trans)
+		ReadWriteDouble(t, p, trans)
+		ReadWriteBinary(t, p, trans)
+		ReadWriteByte(t, p, trans)
+		trans.Close()
 	}
 
 }
diff --git a/lib/go/thrift/transport_test.go b/lib/go/thrift/transport_test.go
index c9f1d56..864958a 100644
--- a/lib/go/thrift/transport_test.go
+++ b/lib/go/thrift/transport_test.go
@@ -29,7 +29,8 @@
 const TRANSPORT_BINARY_DATA_SIZE = 4096
 
 var (
-	transport_bdata []byte // test data for writing; same as data
+	transport_bdata  []byte // test data for writing; same as data
+	transport_header map[string]string
 )
 
 func init() {
@@ -37,6 +38,8 @@
 	for i := 0; i < TRANSPORT_BINARY_DATA_SIZE; i++ {
 		transport_bdata[i] = byte((i + 'a') % 255)
 	}
+	transport_header = map[string]string{"key": "User-Agent",
+		"value": "Mozilla/5.0 (Windows NT 6.2; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/32.0.1667.0 Safari/537.36"}
 }
 
 func TransportTest(t *testing.T, writeTrans TTransport, readTrans TTransport) {
@@ -94,6 +97,50 @@
 	}
 }
 
+func TransportHeaderTest(t *testing.T, writeTrans TTransport, readTrans TTransport) {
+	buf := make([]byte, TRANSPORT_BINARY_DATA_SIZE)
+	if !writeTrans.IsOpen() {
+		t.Fatalf("Transport %T not open: %s", writeTrans, writeTrans)
+	}
+	if !readTrans.IsOpen() {
+		t.Fatalf("Transport %T not open: %s", readTrans, readTrans)
+	}
+	// Need to assert type of TTransport to THttpClient to expose the Setter
+	httpWPostTrans := writeTrans.(*THttpClient)
+	httpWPostTrans.SetHeader(transport_header["key"], transport_header["value"])
+
+	_, err := writeTrans.Write(transport_bdata)
+	if err != nil {
+		t.Fatalf("Transport %T cannot write binary data of length %d: %s", writeTrans, len(transport_bdata), err)
+	}
+	err = writeTrans.Flush()
+	if err != nil {
+		t.Fatalf("Transport %T cannot flush write of binary data: %s", writeTrans, err)
+	}
+	// Need to assert type of TTransport to THttpClient to expose the Getter
+	httpRPostTrans := readTrans.(*THttpClient)
+	readHeader := httpRPostTrans.GetHeader(transport_header["key"])
+	if err != nil {
+		t.Errorf("Transport %T cannot read HTTP Header Value", httpRPostTrans)
+	}
+
+	if transport_header["value"] != readHeader {
+		t.Errorf("Expected HTTP Header Value %s, got %s", transport_header["value"], readHeader)
+	}
+	n, err := io.ReadFull(readTrans, buf)
+	if err != nil {
+		t.Errorf("Transport %T cannot read binary data of length %d: %s", readTrans, TRANSPORT_BINARY_DATA_SIZE, err)
+	}
+	if n != TRANSPORT_BINARY_DATA_SIZE {
+		t.Errorf("Transport %T read only %d instead of %d bytes of binary data", readTrans, n, TRANSPORT_BINARY_DATA_SIZE)
+	}
+	for k, v := range buf {
+		if v != transport_bdata[k] {
+			t.Fatalf("Transport %T read %d instead of %d for index %d of binary data 2", readTrans, v, transport_bdata[k], k)
+		}
+	}
+}
+
 func CloseTransports(t *testing.T, readTrans TTransport, writeTrans TTransport) {
 	err := readTrans.Close()
 	if err != nil {
@@ -118,3 +165,12 @@
 	}
 	return nil, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "Could not find available server port")
 }
+
+func valueInSlice(value string, slice []string) bool {
+	for _, v := range slice {
+		if value == v {
+			return true
+		}
+	}
+	return false
+}