| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| import Foundation |
| |
// Conditional import for URLRequest.
// Since Swift 5.1, URLRequest and related networking types live in
// FoundationNetworking rather than Foundation on non-Darwin platforms.
// See https://stackoverflow.com/a/58606520
| #if canImport(FoundationNetworking) |
| import FoundationNetworking |
| #endif |
| |
| import Dispatch |
| |
/// An asynchronous Thrift transport that sends each flushed request as an
/// HTTP POST through a `URLSession`.
///
/// Bytes passed to `write(data:)` are buffered until `flush`. `flush` posts
/// them to the factory's URL. On success, the response body is buffered and
/// served to subsequent `read(size:)` / `readAll(size:)` calls.
public class THTTPSessionTransport: TAsyncTransport {

  /// Creates `THTTPSessionTransport` instances that share one `URLSession`
  /// and one endpoint URL.
  public class Factory: TAsyncTransportFactory {
    /// Optional hook called with every HTTP response (and its body) received
    /// by transports from this factory.
    ///
    /// If the hook throws, the flush fails with the thrown error. That error
    /// replaces any status-code error.
    public var responseValidate: ((HTTPURLResponse?, Data?) throws -> Void)?

    var session: URLSession
    var url: URL

    /// Applies Thrift-over-HTTP defaults to `config`.
    ///
    /// The defaults are: local caching disabled, pipelining and cookies
    /// enabled, and Thrift `Content-Type`, `Accept` and `User-Agent` headers.
    ///
    /// - Parameters:
    ///   - config: The configuration to mutate in place.
    ///   - protocolName: When non-nil, appended to the content type as
    ///     `; p=<protocolName>`.
    public class func setupDefaultsForSessionConfiguration(_ config: URLSessionConfiguration,
                                                           withProtocolName protocolName: String?) {
      var thriftContentType = "application/x-thrift"
      if let protocolName = protocolName {
        thriftContentType += "; p=\(protocolName)"
      }

      config.requestCachePolicy = .reloadIgnoringLocalCacheData
      config.urlCache = nil

      config.httpShouldUsePipelining = true
      config.httpShouldSetCookies = true
      config.httpAdditionalHeaders = ["Content-Type": thriftContentType,
                                      "Accept": thriftContentType,
                                      "User-Agent": "Thrift/Swift (Session)"]
    }

    public init(session: URLSession, url: URL) {
      self.session = session
      self.url = url
    }

    public func newTransport() -> THTTPSessionTransport {
      return THTTPSessionTransport(factory: self)
    }

    /// Runs `responseValidate`, if one is set. Rethrows whatever it throws.
    func validateResponse(_ response: HTTPURLResponse?, data: Data?) throws {
      try responseValidate?(response, data)
    }

    /// Creates, but does not resume, a data task for `request` on `session`.
    ///
    /// Declared `throws` for interface compatibility. `URLSession.dataTask`
    /// returns a non-optional task, so this implementation never throws.
    func taskWithRequest(_ request: URLRequest,
                         completionHandler: @escaping (Data?, URLResponse?, Error?) -> ()) throws -> URLSessionTask {
      return session.dataTask(with: request, completionHandler: completionHandler)
    }
  }

  var factory: Factory
  var requestData = Data()
  var responseData = Data()
  var responseDataOffset: Int = 0

  init(factory: Factory) {
    self.factory = factory
  }

  /// Reads exactly `size` bytes from the buffered response.
  ///
  /// - Throws: `TTransportError(.endOfFile)` if fewer than `size` bytes remain.
  public func readAll(size: Int) throws -> Data {
    let read = try self.read(size: size)
    if read.count != size {
      throw TTransportError(error: .endOfFile)
    }
    return read
  }

  /// Reads up to `size` bytes from the buffered response and advances the
  /// read offset.
  ///
  /// Returns fewer bytes (possibly none) when the buffer is short. A
  /// non-positive `size` yields empty `Data`.
  public func read(size: Int) throws -> Data {
    let available = responseData.count - responseDataOffset
    // Clamp so a negative `size` cannot build an inverted (trapping) range.
    let count = max(0, min(size, available))
    // Offsets are relative to the buffer start; Data indices may not start at 0.
    let start = responseData.startIndex + responseDataOffset
    let chunk = responseData.subdata(in: start..<(start + count))
    responseDataOffset += count
    return chunk
  }

  /// Appends `data` to the pending request body. Nothing is sent until `flush`.
  public func write(data: Data) throws {
    requestData.append(data)
  }

  /// POSTs the buffered request bytes to the factory URL.
  ///
  /// `completed` is called exactly once. On success, the error is `nil` and
  /// the response body is available to `read`. On failure, the read buffer is
  /// emptied.
  ///
  /// - Parameter completed: Called with this transport and the error, if any.
  ///   Invoked from the `URLSession` completion handler, or synchronously if
  ///   the task cannot be created.
  public func flush(_ completed: @escaping (TAsyncTransport, Error?) -> Void) {
    var request = URLRequest(url: factory.url)
    request.httpMethod = "POST"
    request.httpBody = requestData

    // The buffered bytes now belong to this request; start the next one fresh.
    requestData = Data()

    let task: URLSessionTask
    do {
      task = try factory.taskWithRequest(request) { data, response, taskError in
        let error = self.responseError(data: data, response: response, taskError: taskError)

        // Reset the read buffer on every outcome, so a failed flush never
        // leaves a previous response readable.
        self.responseDataOffset = 0
        self.responseData = error == nil ? (data ?? Data()) : Data()
        completed(self, error)
      }
    } catch {
      completed(self, error)
      return
    }
    task.resume()
  }

  /// Maps the outcome of a finished data task to the error `flush` reports,
  /// or `nil` on success.
  private func responseError(data: Data?, response: URLResponse?, taskError: Error?) -> Error? {
    if let taskError = taskError {
      // Transport-level failures keep the `.timedOut` code callers already
      // rely on. The underlying description is kept for diagnostics.
      return TTransportError(error: .timedOut, message: taskError.localizedDescription)
    }

    guard let httpResponse = response as? HTTPURLResponse else {
      return THTTPTransportError(error: .invalidResponse)
    }

    let statusError: Error?
    switch httpResponse.statusCode {
    case 200:
      statusError = nil
    case 401:
      statusError = THTTPTransportError(error: .authentication)
    default:
      statusError = THTTPTransportError(error: .invalidStatus(statusCode: httpResponse.statusCode))
    }

    // The factory hook sees every HTTP response, successful or not. It may
    // reject a 200 response, or replace a status error with a more specific one.
    do {
      try factory.validateResponse(httpResponse, data: data)
    } catch let validationError {
      return validationError
    }
    return statusError
  }

  /// Synchronous `flush`: blocks the calling thread until the request finishes.
  ///
  /// The completion is delivered by `URLSession`. Calling this from the queue
  /// the session uses for completion handlers would presumably deadlock.
  ///
  /// - Throws: The error the asynchronous `flush` reported, if any.
  public func flush() throws {
    let done = DispatchSemaphore(value: 0)
    var internalError: Error?

    flush { _, error in
      internalError = error
      done.signal()
    }

    done.wait()

    if let error = internalError {
      throw error
    }
  }
}