blob: f66da922e81d8e580eef106f7743eedba804ddb9 [file] [log] [blame]
Chris Simpsona9b6c702018-04-08 07:11:37 -04001/*
2* Licensed to the Apache Software Foundation (ASF) under one
3* or more contributor license agreements. See the NOTICE file
4* distributed with this work for additional information
5* regarding copyright ownership. The ASF licenses this file
6* to you under the Apache License, Version 2.0 (the
7* "License"); you may not use this file except in compliance
8* with the License. You may obtain a copy of the License at
9*
10* http://www.apache.org/licenses/LICENSE-2.0
11*
12* Unless required by applicable law or agreed to in writing,
13* software distributed under the License is distributed on an
14* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15* KIND, either express or implied. See the License for the
16* specific language governing permissions and limitations
17* under the License.
18*/
19
20import Foundation
Chris Simpsona9b6c702018-04-08 07:11:37 -040021
// Conditional import for URLRequest/URLSession.
// Since Swift 5.1, non-Darwin platforms (e.g. Linux) vend these types from
// FoundationNetworking instead of Foundation; on Darwin they remain in
// Foundation. See https://stackoverflow.com/a/58606520
25#if canImport(FoundationNetworking)
26import FoundationNetworking
27#endif
28
29import Dispatch
Chris Simpsona9b6c702018-04-08 07:11:37 -040030
/// Asynchronous Thrift transport that sends each flushed request as an HTTP
/// POST through a `URLSession` and buffers the response body for reading.
///
/// Bytes passed to `write(data:)` accumulate until `flush`, which POSTs them to
/// the factory's URL. On success, the response body replaces the read buffer
/// consumed by `read(size:)` / `readAll(size:)`.
public class THTTPSessionTransport: TAsyncTransport {

  /// Creates `THTTPSessionTransport` instances that share one `URLSession`
  /// and one endpoint URL.
  public class Factory : TAsyncTransportFactory {
    /// Optional hook run on every response that passed the built-in checks
    /// (an `HTTPURLResponse` with status 200). If it throws, the flush fails
    /// with the thrown error and the response body is discarded.
    public var responseValidate: ((HTTPURLResponse?, Data?) throws -> Void)?

    var session: URLSession
    var url: URL

    /// Configures `config` for Thrift-over-HTTP traffic. It disables response
    /// caching, enables pipelining and cookies, and sets the Thrift
    /// `Content-Type`, `Accept` and `User-Agent` headers.
    ///
    /// - Parameters:
    ///   - config: The session configuration, mutated in place.
    ///   - protocolName: When non-nil, appended to the content type as
    ///     `; p=<protocolName>`.
    public class func setupDefaultsForSessionConfiguration(_ config: URLSessionConfiguration, withProtocolName protocolName: String?) {
      var thriftContentType = "application/x-thrift"

      if let protocolName = protocolName {
        thriftContentType += "; p=\(protocolName)"
      }

      config.requestCachePolicy = .reloadIgnoringLocalCacheData
      config.urlCache = nil

      config.httpShouldUsePipelining = true
      config.httpShouldSetCookies = true
      config.httpAdditionalHeaders = ["Content-Type": thriftContentType,
                                      "Accept": thriftContentType,
                                      "User-Agent": "Thrift/Swift (Session)"]
    }

    /// - Parameters:
    ///   - session: The session used to create every data task.
    ///   - url: The endpoint every request is POSTed to.
    public init(session: URLSession, url: URL) {
      self.session = session
      self.url = url
    }

    /// Returns a new transport bound to this factory's session and URL.
    public func newTransport() -> THTTPSessionTransport {
      return THTTPSessionTransport(factory: self)
    }

    /// Runs the user-supplied `responseValidate` hook, if one is set.
    func validateResponse(_ response: HTTPURLResponse?, data: Data?) throws {
      try responseValidate?(response, data)
    }

    /// Creates (but does not resume) a data task for `request`.
    ///
    /// `URLSession.dataTask(with:completionHandler:)` returns a non-optional
    /// task, so this never throws in practice. `throws` is kept for source
    /// compatibility with existing callers.
    func taskWithRequest(_ request: URLRequest, completionHandler: @escaping (Data?, URLResponse?, Error?) -> ()) throws -> URLSessionTask {
      return session.dataTask(with: request, completionHandler: completionHandler)
    }
  }

  var factory: Factory
  /// Bytes written since the last flush; sent as the body of the next POST.
  var requestData = Data()
  /// Body of the most recent successful response; empty after a failed flush.
  var responseData = Data()
  /// Number of bytes of `responseData` already consumed by `read(size:)`.
  var responseDataOffset: Int = 0

  init(factory: Factory) {
    self.factory = factory
  }

  /// Reads exactly `size` bytes from the buffered response.
  ///
  /// - Throws: `TTransportError(error: .endOfFile)` if fewer than `size` bytes
  ///   remain. Any bytes that were available are still consumed.
  public func readAll(size: Int) throws -> Data {
    let read = try self.read(size: size)
    if read.count != size {
      throw TTransportError(error: .endOfFile)
    }
    return read
  }

  /// Reads up to `size` bytes from the buffered response and advances the read
  /// offset. Returns fewer bytes (possibly none) when the buffer runs out; a
  /// non-positive `size` yields empty `Data`.
  public func read(size: Int) throws -> Data {
    let available = responseData.count - responseDataOffset
    // Clamp at 0 so a negative size cannot build an inverted range and trap.
    let count = max(0, min(size, available))
    let start = responseData.startIndex + responseDataOffset
    let read = responseData.subdata(in: start..<(start + count))
    responseDataOffset += count
    return read
  }

  /// Appends `data` to the pending request body. Nothing is sent until `flush`.
  public func write(data: Data) throws {
    requestData.append(data)
  }

  /// Sends the buffered request bytes as an HTTP POST and reports the outcome.
  ///
  /// `completed` is called exactly once. It is called synchronously if the
  /// data task cannot be created; otherwise it is called from the session's
  /// completion handler, after `responseData` has been replaced by the
  /// response body (on success) or emptied (on any failure).
  ///
  /// - Parameter completed: Receives this transport and the error, or `nil`
  ///   on success.
  public func flush(_ completed: @escaping (TAsyncTransport, Error?) -> Void) {
    var request = URLRequest(url: factory.url)
    request.httpMethod = "POST"
    request.httpBody = requestData

    // The pending bytes now belong to `request`; start the next message empty.
    requestData = Data()

    let task: URLSessionTask
    do {
      // The handler retains `self` strongly, which keeps the transport alive
      // until the response is delivered.
      task = try factory.taskWithRequest(request) { data, response, taskError in
        let error = self.checkResponse(response, data: data, taskError: taskError)

        // Reset the read buffer on every outcome so a failed flush never
        // leaves a previous response's unread bytes readable.
        self.responseDataOffset = 0
        self.responseData = (error == nil) ? (data ?? Data()) : Data()
        completed(self, error)
      }
    } catch let taskError {
      completed(self, taskError)
      return
    }
    task.resume()
  }

  /// Returns the error to report for a finished data task, or `nil` on success.
  ///
  /// Checks run in order:
  /// 1. A network error is reported as `TTransportError(.timedOut)`.
  /// 2. A missing or non-HTTP response is reported as `.invalidResponse`.
  /// 3. A non-200 status is reported as `.authentication` (for 401) or
  ///    `.invalidStatus`.
  /// 4. The factory's validation hook runs, and whatever it throws is reported.
  private func checkResponse(_ response: URLResponse?, data: Data?, taskError: Error?) -> Error? {
    if let taskError = taskError {
      // Every network-level failure maps to `.timedOut` (as before); keep the
      // underlying cause in the message for diagnostics.
      return TTransportError(error: .timedOut, message: taskError.localizedDescription)
    }

    guard let httpResponse = response as? HTTPURLResponse else {
      return THTTPTransportError(error: .invalidResponse)
    }

    switch httpResponse.statusCode {
    case 200:
      break
    case 401:
      return THTTPTransportError(error: .authentication)
    default:
      return THTTPTransportError(error: .invalidStatus(statusCode: httpResponse.statusCode))
    }

    // Allow the factory to check responses that passed the built-in checks.
    do {
      try factory.validateResponse(httpResponse, data: data)
    } catch {
      return error
    }
    return nil
  }

  /// Synchronous flush: sends the buffered request and blocks the calling
  /// thread until the asynchronous `flush(_:)` reports completion.
  ///
  /// - Throws: The error reported by `flush(_:)`, if any.
  /// - Note: The wait has no timeout. Presumably this must not be called on
  ///   the queue that delivers the session's completion handlers, or it will
  ///   deadlock; confirm against the session's delegate-queue setup.
  public func flush() throws {
    let done = DispatchSemaphore(value: 0)
    var internalError: Error?

    flush { _, error in
      internalError = error
      done.signal()
    }

    done.wait()

    if let error = internalError {
      throw error
    }
  }
}