THRIFT-4848: Add ability to set Content-Type,Accept headers in HTTP client
Client: netstd
Patch: Kyle Smith
This closes #1801
diff --git a/lib/netstd/Thrift/Transport/Client/THttpTransport.cs b/lib/netstd/Thrift/Transport/Client/THttpTransport.cs
index 5d7f1de..982e91e 100644
--- a/lib/netstd/Thrift/Transport/Client/THttpTransport.cs
+++ b/lib/netstd/Thrift/Transport/Client/THttpTransport.cs
@@ -33,46 +33,40 @@
private readonly X509Certificate[] _certificates;
private readonly Uri _uri;
- // Timeouts in milliseconds
- private int _connectTimeout = 30000;
+ private int _connectTimeout = 30000; // Timeouts in milliseconds
private HttpClient _httpClient;
private Stream _inputStream;
-
- private bool _isDisposed;
private MemoryStream _outputStream = new MemoryStream();
+ private bool _isDisposed;
- public THttpTransport(Uri u, IDictionary<string, string> customHeaders = null, string userAgent = null)
- : this(u, Enumerable.Empty<X509Certificate>(), customHeaders, userAgent)
+ public THttpTransport(Uri uri, IDictionary<string, string> customRequestHeaders = null, string userAgent = null)
+ : this(uri, Enumerable.Empty<X509Certificate>(), customRequestHeaders, userAgent)
{
}
- public THttpTransport(Uri u, IEnumerable<X509Certificate> certificates,
- IDictionary<string, string> customHeaders, string userAgent = null)
+ public THttpTransport(Uri uri, IEnumerable<X509Certificate> certificates,
+ IDictionary<string, string> customRequestHeaders, string userAgent = null)
{
- _uri = u;
+ _uri = uri;
_certificates = (certificates ?? Enumerable.Empty<X509Certificate>()).ToArray();
- CustomHeaders = customHeaders;
if (!string.IsNullOrEmpty(userAgent))
UserAgent = userAgent;
// due to current bug with performance of Dispose in netcore https://github.com/dotnet/corefx/issues/8809
// this can be switched to default way (create client->use->dispose per flush) later
- _httpClient = CreateClient();
+ _httpClient = CreateClient(customRequestHeaders);
}
// According to RFC 2616 section 3.8, the "User-Agent" header may not carry a version number
public readonly string UserAgent = "Thrift netstd THttpClient";
- public IDictionary<string, string> CustomHeaders { get; }
-
- public int ConnectTimeout
- {
- set { _connectTimeout = value; }
- }
-
public override bool IsOpen => true;
+ public HttpRequestHeaders RequestHeaders => _httpClient.DefaultRequestHeaders;
+
+ public MediaTypeHeaderValue ContentType { get; set; }
+
public override async Task OpenAsync(CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
@@ -141,7 +135,7 @@
await _outputStream.WriteAsync(buffer, offset, length, cancellationToken);
}
- private HttpClient CreateClient()
+ private HttpClient CreateClient(IDictionary<string, string> customRequestHeaders)
{
var handler = new HttpClientHandler();
handler.ClientCertificates.AddRange(_certificates);
@@ -155,10 +149,10 @@
httpClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/x-thrift"));
httpClient.DefaultRequestHeaders.UserAgent.TryParseAdd(UserAgent);
-
- if (CustomHeaders != null)
+
+ if (customRequestHeaders != null)
{
- foreach (var item in CustomHeaders)
+ foreach (var item in customRequestHeaders)
{
httpClient.DefaultRequestHeaders.Add(item.Key, item.Value);
}
@@ -171,41 +165,34 @@
{
try
{
- try
+ _outputStream.Seek(0, SeekOrigin.Begin);
+
+ using (var contentStream = new StreamContent(_outputStream))
{
- if (_outputStream.CanSeek)
+ contentStream.Headers.ContentType = ContentType ?? new MediaTypeHeaderValue(@"application/x-thrift");
+
+ var response = (await _httpClient.PostAsync(_uri, contentStream, cancellationToken)).EnsureSuccessStatusCode();
+
+ _inputStream?.Dispose();
+ _inputStream = await response.Content.ReadAsStreamAsync();
+ if (_inputStream.CanSeek)
{
- _outputStream.Seek(0, SeekOrigin.Begin);
- }
-
- using (var outStream = new StreamContent(_outputStream))
- {
- var msg = await _httpClient.PostAsync(_uri, outStream, cancellationToken);
-
- msg.EnsureSuccessStatusCode();
-
- if (_inputStream != null)
- {
- _inputStream.Dispose();
- _inputStream = null;
- }
-
- _inputStream = await msg.Content.ReadAsStreamAsync();
- if (_inputStream.CanSeek)
- {
- _inputStream.Seek(0, SeekOrigin.Begin);
- }
+ _inputStream.Seek(0, SeekOrigin.Begin);
}
}
- catch (IOException iox)
- {
- throw new TTransportException(TTransportException.ExceptionType.Unknown, iox.ToString());
- }
- catch (HttpRequestException wx)
- {
- throw new TTransportException(TTransportException.ExceptionType.Unknown,
- "Couldn't connect to server: " + wx);
- }
+ }
+ catch (IOException iox)
+ {
+ throw new TTransportException(TTransportException.ExceptionType.Unknown, iox.ToString());
+ }
+ catch (HttpRequestException wx)
+ {
+ throw new TTransportException(TTransportException.ExceptionType.Unknown,
+ "Couldn't connect to server: " + wx);
+ }
+ catch (Exception ex)
+ {
+ throw new TTransportException(TTransportException.ExceptionType.Unknown, ex.Message);
}
finally
{