Chris Simpson | a9b6c70 | 2018-04-08 07:11:37 -0400 | [diff] [blame] | 1 | /* |
| 2 | * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | * or more contributor license agreements. See the NOTICE file |
| 4 | * distributed with this work for additional information |
| 5 | * regarding copyright ownership. The ASF licenses this file |
| 6 | * to you under the Apache License, Version 2.0 (the |
| 7 | * "License"); you may not use this file except in compliance |
| 8 | * with the License. You may obtain a copy of the License at |
| 9 | * |
| 10 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | * |
| 12 | * Unless required by applicable law or agreed to in writing, |
| 13 | * software distributed under the License is distributed on an |
| 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | * KIND, either express or implied. See the License for the |
| 16 | * specific language governing permissions and limitations |
| 17 | * under the License. |
| 18 | */ |
| 19 | |
| 20 | import Foundation |
Chris Simpson | a9b6c70 | 2018-04-08 07:11:37 -0400 | [diff] [blame] | 21 | |
Jano Svitok | 2e11577 | 2020-03-06 09:01:43 +0100 | [diff] [blame] | 22 | // Conditional import for URLRequest |
| 23 | // It was moved from Foundation to FoundationNetworking in 5.1, but |
| 24 | // not on Darwin. See https://stackoverflow.com/a/58606520 |
| 25 | #if canImport(FoundationNetworking) |
| 26 | import FoundationNetworking |
| 27 | #endif |
| 28 | |
| 29 | import Dispatch |
Chris Simpson | a9b6c70 | 2018-04-08 07:11:37 -0400 | [diff] [blame] | 30 | |
/// An asynchronous Thrift transport that sends everything written since the
/// last flush as a single HTTP POST through a `URLSession`, then buffers the
/// response body so it can be consumed with `read(size:)` / `readAll(size:)`.
public class THTTPSessionTransport: TAsyncTransport {

  /// Creates `THTTPSessionTransport` instances that share one `URLSession`,
  /// one endpoint URL and an optional response-validation hook.
  public class Factory : TAsyncTransportFactory {
    /// Optional hook run against a response that already passed the built-in
    /// checks (it is an `HTTPURLResponse` with status 200). If it throws, the
    /// flush fails with the thrown error and the response body is discarded.
    public var responseValidate: ((HTTPURLResponse?, Data?) throws -> Void)?

    var session: URLSession
    var url: URL

    /// Prepares `config` for Thrift-over-HTTP traffic.
    ///
    /// Caching is disabled, pipelining and cookies are enabled, and the
    /// `Content-Type`/`Accept`/`User-Agent` headers are set.
    ///
    /// - Parameters:
    ///   - config: The session configuration to mutate in place.
    ///   - protocolName: When non-nil, appended to the content type as `; p=<name>`.
    public class func setupDefaultsForSessionConfiguration(_ config: URLSessionConfiguration, withProtocolName protocolName: String?) {
      var thriftContentType = "application/x-thrift"

      if let protocolName = protocolName {
        thriftContentType += "; p=\(protocolName)"
      }

      // Every RPC must actually reach the server; never answer from a cache.
      config.requestCachePolicy = .reloadIgnoringLocalCacheData
      config.urlCache = nil

      config.httpShouldUsePipelining = true
      config.httpShouldSetCookies = true
      config.httpAdditionalHeaders = ["Content-Type": thriftContentType,
                                      "Accept": thriftContentType,
                                      "User-Agent": "Thrift/Swift (Session)"]
    }

    /// - Parameters:
    ///   - session: Session used to create a data task for every flush.
    ///   - url: Endpoint every request is POSTed to.
    public init(session: URLSession, url: URL) {
      self.session = session
      self.url = url
    }

    /// Returns a new transport bound to this factory's session and URL.
    public func newTransport() -> THTTPSessionTransport {
      return THTTPSessionTransport(factory: self)
    }

    /// Runs `responseValidate`, if one is set; a thrown error propagates.
    func validateResponse(_ response: HTTPURLResponse?, data: Data?) throws {
      try responseValidate?(response, data)
    }

    /// Creates (but does not resume) a data task for `request`.
    ///
    /// - Throws: Declared for source compatibility; the current implementation
    ///   has no failure path because `dataTask(with:completionHandler:)`
    ///   returns a non-optional task.
    func taskWithRequest(_ request: URLRequest, completionHandler: @escaping (Data?, URLResponse?, Error?) -> ()) throws -> URLSessionTask {
      return session.dataTask(with: request, completionHandler: completionHandler)
    }
  }

  var factory: Factory
  /// Bytes written since the last flush; sent as the next request body.
  var requestData = Data()
  /// Body of the most recent successful response (empty after a failed flush).
  var responseData = Data()
  /// Number of bytes of `responseData` already consumed by `read(size:)`.
  var responseDataOffset: Int = 0

  init(factory: Factory) {
    self.factory = factory
  }

  /// Reads exactly `size` bytes from the buffered response.
  ///
  /// - Throws: `TTransportError(.endOfFile)` if fewer than `size` bytes remain.
  public func readAll(size: Int) throws -> Data {
    let read = try self.read(size: size)
    if read.count != size {
      throw TTransportError(error: .endOfFile)
    }
    return read
  }

  /// Reads up to `size` bytes from the buffered response, advancing the read
  /// position; returns fewer bytes (possibly none) when the buffer runs out.
  public func read(size: Int) throws -> Data {
    let available = responseData.count - responseDataOffset
    // Clamp to 0...available so a negative size yields empty Data instead of
    // trapping on an inverted range.
    let count = max(0, min(size, available))
    // Offset from startIndex: a Data value is not guaranteed to start at 0.
    let start = responseData.startIndex + responseDataOffset
    let read = responseData.subdata(in: start..<(start + count))
    responseDataOffset += count
    return read
  }

  /// Appends `data` to the pending request body; nothing is sent until flush.
  public func write(data: Data) throws {
    requestData.append(data)
  }

  /// POSTs the pending request body and reports the outcome via `completed`.
  ///
  /// On success the response body becomes readable through `read(size:)`.
  /// On any failure the response buffer is cleared so stale data from a
  /// previous call cannot be read. The pending request buffer is emptied
  /// immediately, so writes made after this call belong to the next request.
  ///
  /// - Parameter completed: Called exactly once, with `nil` on success or the
  ///   error that caused the flush to fail.
  public func flush(_ completed: @escaping (TAsyncTransport, Error?) -> Void) {
    var request = URLRequest(url: factory.url)
    request.httpMethod = "POST"
    request.httpBody = requestData

    requestData = Data()

    let task: URLSessionTask
    do {
      task = try factory.taskWithRequest(request) { data, response, taskError in
        let error = self.checkResponse(response, data: data, taskError: taskError)

        self.responseDataOffset = 0
        self.responseData = (error == nil) ? (data ?? Data()) : Data()
        completed(self, error)
      }
    } catch {
      completed(self, error)
      return
    }

    task.resume()
  }

  /// Classifies a finished data task: returns the error `flush` should report,
  /// or `nil` when the response is acceptable.
  private func checkResponse(_ response: URLResponse?, data: Data?, taskError: Error?) -> Error? {
    // Network-level failures are reported as timeouts, as before, but now
    // carry the underlying error's description instead of discarding it.
    if let taskError = taskError {
      return TTransportError(error: .timedOut, message: taskError.localizedDescription)
    }

    guard let httpResponse = response as? HTTPURLResponse else {
      return THTTPTransportError(error: .invalidResponse)
    }

    switch httpResponse.statusCode {
    case 200:
      break
    case 401:
      return THTTPTransportError(error: .authentication)
    default:
      return THTTPTransportError(error: .invalidStatus(statusCode: httpResponse.statusCode))
    }

    // Allow the factory to apply its own checks to an otherwise good response.
    do {
      try factory.validateResponse(httpResponse, data: data)
    } catch {
      return error
    }
    return nil
  }

  /// Synchronous flush: blocks the calling thread until the asynchronous
  /// flush completes, then rethrows its error, if any.
  ///
  /// NOTE(review): the completion runs on the URLSession's delegate queue, so
  /// calling this from that same queue would presumably deadlock — avoid.
  public func flush() throws {
    let completed = DispatchSemaphore(value: 0)
    var internalError: Error?

    flush { _, error in
      internalError = error
      completed.signal()
    }

    completed.wait()

    if let error = internalError {
      throw error
    }
  }
}