THttpClient now utilizes a package level shared HTTP client and optionally allows users of the library to specify one.
diff --git a/lib/go/thrift/http_client.go b/lib/go/thrift/http_client.go
index bf0ed62..16f1cdd 100644
--- a/lib/go/thrift/http_client.go
+++ b/lib/go/thrift/http_client.go
@@ -27,7 +27,13 @@
"strconv"
)
+// Default to using the shared http client. Library users are
+// free to change this global client or specify one through
+// THttpClientOptions.
+var DefaultHttpClient *http.Client = http.DefaultClient
+
type THttpClient struct {
+ client *http.Client
response *http.Response
url *url.URL
requestBuffer *bytes.Buffer
@@ -37,8 +43,9 @@
}
type THttpClientTransportFactory struct {
- url string
- isPost bool
+ options THttpClientOptions
+ url string
+ isPost bool
}
func (p *THttpClientTransportFactory) GetTransport(trans TTransport) TTransport {
@@ -46,30 +53,43 @@
t, ok := trans.(*THttpClient)
if ok && t.url != nil {
if t.requestBuffer != nil {
- t2, _ := NewTHttpPostClient(t.url.String())
+ t2, _ := NewTHttpPostClientWithOptions(t.url.String(), p.options)
return t2
}
- t2, _ := NewTHttpClient(t.url.String())
+ t2, _ := NewTHttpClientWithOptions(t.url.String(), p.options)
return t2
}
}
if p.isPost {
- s, _ := NewTHttpPostClient(p.url)
+ s, _ := NewTHttpPostClientWithOptions(p.url, p.options)
return s
}
- s, _ := NewTHttpClient(p.url)
+ s, _ := NewTHttpClientWithOptions(p.url, p.options)
return s
}
+type THttpClientOptions struct {
+ // If nil, DefaultHttpClient is used
+ Client *http.Client
+}
+
func NewTHttpClientTransportFactory(url string) *THttpClientTransportFactory {
- return &THttpClientTransportFactory{url: url, isPost: false}
+ return NewTHttpClientTransportFactoryWithOptions(url, THttpClientOptions{})
+}
+
+func NewTHttpClientTransportFactoryWithOptions(url string, options THttpClientOptions) *THttpClientTransportFactory {
+ return &THttpClientTransportFactory{url: url, isPost: false, options: options}
}
func NewTHttpPostClientTransportFactory(url string) *THttpClientTransportFactory {
- return &THttpClientTransportFactory{url: url, isPost: true}
+ return NewTHttpPostClientTransportFactoryWithOptions(url, THttpClientOptions{})
}
-func NewTHttpClient(urlstr string) (TTransport, error) {
+func NewTHttpPostClientTransportFactoryWithOptions(url string, options THttpClientOptions) *THttpClientTransportFactory {
+ return &THttpClientTransportFactory{url: url, isPost: true, options: options}
+}
+
+func NewTHttpClientWithOptions(urlstr string, options THttpClientOptions) (TTransport, error) {
parsedURL, err := url.Parse(urlstr)
if err != nil {
return nil, err
@@ -78,16 +98,32 @@
if err != nil {
return nil, err
}
- return &THttpClient{response: response, url: parsedURL}, nil
+ client := options.Client
+ if client == nil {
+ client = DefaultHttpClient
+ }
+ return &THttpClient{client: client, response: response, url: parsedURL}, nil
}
-func NewTHttpPostClient(urlstr string) (TTransport, error) {
+func NewTHttpClient(urlstr string) (TTransport, error) {
+ return NewTHttpClientWithOptions(urlstr, THttpClientOptions{})
+}
+
+func NewTHttpPostClientWithOptions(urlstr string, options THttpClientOptions) (TTransport, error) {
parsedURL, err := url.Parse(urlstr)
if err != nil {
return nil, err
}
buf := make([]byte, 0, 1024)
- return &THttpClient{url: parsedURL, requestBuffer: bytes.NewBuffer(buf), header: http.Header{}}, nil
+ client := options.Client
+ if client == nil {
+ client = DefaultHttpClient
+ }
+ return &THttpClient{client: client, url: parsedURL, requestBuffer: bytes.NewBuffer(buf), header: http.Header{}}, nil
+}
+
+func NewTHttpPostClient(urlstr string) (TTransport, error) {
+ return NewTHttpPostClientWithOptions(urlstr, THttpClientOptions{})
}
// Set the HTTP Header for this specific Thrift Transport
@@ -179,14 +215,13 @@
// Close any previous response body to avoid leaking connections.
p.closeResponse()
- client := &http.Client{}
req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer)
if err != nil {
return NewTTransportExceptionFromError(err)
}
p.header.Add("Content-Type", "application/x-thrift")
req.Header = p.header
- response, err := client.Do(req)
+ response, err := p.client.Do(req)
if err != nil {
return NewTTransportExceptionFromError(err)
}
@@ -201,12 +236,11 @@
}
func (p *THttpClient) RemainingBytes() (num_bytes uint64) {
- len := p.response.ContentLength
+ len := p.response.ContentLength
if len >= 0 {
return uint64(len)
}
-
- const maxSize = ^uint64(0)
- return maxSize // the thruth is, we just don't know unless framed is used
-}
+ const maxSize = ^uint64(0)
+ return maxSize // the thruth is, we just don't know unless framed is used
+}
diff --git a/lib/go/thrift/http_client_test.go b/lib/go/thrift/http_client_test.go
index 0c2cb28..453680a 100644
--- a/lib/go/thrift/http_client_test.go
+++ b/lib/go/thrift/http_client_test.go
@@ -20,6 +20,7 @@
package thrift
import (
+ "net/http"
"testing"
)
@@ -48,3 +49,58 @@
}
TransportHeaderTest(t, trans, trans)
}
+
+func TestHttpCustomClient(t *testing.T) {
+ l, addr := HttpClientSetupForTest(t)
+ if l != nil {
+ defer l.Close()
+ }
+
+ httpTransport := &customHttpTransport{}
+
+ trans, err := NewTHttpPostClientWithOptions("http://"+addr.String(), THttpClientOptions{
+ Client: &http.Client{
+ Transport: httpTransport,
+ },
+ })
+ if err != nil {
+ l.Close()
+ t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
+ }
+ TransportHeaderTest(t, trans, trans)
+
+ if !httpTransport.hit {
+ t.Fatalf("Custom client was not used")
+ }
+}
+
+func TestHttpCustomClientPackageScope(t *testing.T) {
+ l, addr := HttpClientSetupForTest(t)
+ if l != nil {
+ defer l.Close()
+ }
+ httpTransport := &customHttpTransport{}
+ DefaultHttpClient = &http.Client{
+ Transport: httpTransport,
+ }
+
+ trans, err := NewTHttpPostClient("http://" + addr.String())
+ if err != nil {
+ l.Close()
+ t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
+ }
+ TransportHeaderTest(t, trans, trans)
+
+ if !httpTransport.hit {
+ t.Fatalf("Custom client was not used")
+ }
+}
+
+type customHttpTransport struct {
+ hit bool
+}
+
+func (c *customHttpTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ c.hit = true
+ return http.DefaultTransport.RoundTrip(req)
+}