Chris Simpson | a9b6c70 | 2018-04-08 07:11:37 -0400 | [diff] [blame^] | 1 | /* |
| 2 | * Licensed to the Apache Software Foundation (ASF) under one |
| 3 | * or more contributor license agreements. See the NOTICE file |
| 4 | * distributed with this work for additional information |
| 5 | * regarding copyright ownership. The ASF licenses this file |
| 6 | * to you under the Apache License, Version 2.0 (the |
| 7 | * "License"); you may not use this file except in compliance |
| 8 | * with the License. You may obtain a copy of the License at |
| 9 | * |
| 10 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | * |
| 12 | * Unless required by applicable law or agreed to in writing, |
| 13 | * software distributed under the License is distributed on an |
| 14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | * KIND, either express or implied. See the License for the |
| 16 | * specific language governing permissions and limitations |
| 17 | * under the License. |
| 18 | */ |
| 19 | |
| 20 | import Foundation |
| 21 | import Dispatch |
| 22 | |
| 23 | |
/// Thrift client transport that sends each flushed request buffer as one HTTP POST
/// through a `URLSession`, and buffers the complete response body for subsequent reads.
public class THTTPSessionTransport: TAsyncTransport {

  /// Creates `THTTPSessionTransport` instances that share a `URLSession` and endpoint URL.
  public class Factory : TAsyncTransportFactory {
    /// Optional hook invoked with each HTTP response (and its body) that passed the
    /// transport's built-in checks (HTTP response type, status 200). Throwing from it
    /// makes the flush fail with the thrown error.
    public var responseValidate: ((HTTPURLResponse?, Data?) throws -> Void)?

    var session: URLSession
    var url: URL

    /// Applies Thrift-oriented defaults to `config`: caching disabled, pipelining and
    /// cookies enabled, and Thrift `Content-Type`/`Accept`/`User-Agent` headers.
    ///
    /// - Parameters:
    ///   - config: The configuration to mutate in place.
    ///   - protocolName: If non-nil, appended to the content type as `; p=<name>`.
    public class func setupDefaultsForSessionConfiguration(_ config: URLSessionConfiguration, withProtocolName protocolName: String?) {
      var thriftContentType = "application/x-thrift"
      if let protocolName = protocolName {
        thriftContentType += "; p=\(protocolName)"
      }

      // RPC responses must never be served from a cache.
      config.requestCachePolicy = .reloadIgnoringLocalCacheData
      config.urlCache = nil

      config.httpShouldUsePipelining = true
      config.httpShouldSetCookies = true
      config.httpAdditionalHeaders = ["Content-Type": thriftContentType,
                                      "Accept": thriftContentType,
                                      "User-Agent": "Thrift/Swift (Session)"]
    }

    /// - Parameters:
    ///   - session: Session used to create the data tasks for every transport.
    ///   - url: Endpoint every request is POSTed to.
    public init(session: URLSession, url: URL) {
      self.session = session
      self.url = url
    }

    /// Returns a new transport bound to this factory's session and URL.
    public func newTransport() -> THTTPSessionTransport {
      return THTTPSessionTransport(factory: self)
    }

    /// Runs `responseValidate` if one is set; rethrows whatever it throws.
    func validateResponse(_ response: HTTPURLResponse?, data: Data?) throws {
      try responseValidate?(response, data)
    }

    /// Creates (but does not resume) a data task for `request`.
    ///
    /// `dataTask(with:completionHandler:)` never returns nil, so this does not
    /// currently throw; `throws` is kept so existing callers remain source-compatible.
    func taskWithRequest(_ request: URLRequest, completionHandler: @escaping (Data?, URLResponse?, Error?) -> ()) throws -> URLSessionTask {
      return session.dataTask(with: request, completionHandler: completionHandler)
    }
  }

  var factory: Factory
  var requestData = Data()
  var responseData = Data()
  var responseDataOffset: Int = 0

  init(factory: Factory) {
    self.factory = factory
  }

  /// Reads exactly `size` bytes from the buffered response.
  ///
  /// - Throws: `TTransportError(.endOfFile)` if fewer than `size` bytes remain.
  public func readAll(size: Int) throws -> Data {
    let chunk = try read(size: size)
    guard chunk.count == size else {
      throw TTransportError(error: .endOfFile)
    }
    return chunk
  }

  /// Reads up to `size` bytes from the buffered response, advancing the read offset.
  /// Returns fewer bytes (possibly none) when the buffer is nearly exhausted; a
  /// non-positive `size` returns empty data.
  public func read(size: Int) throws -> Data {
    let available = responseData.count - responseDataOffset
    // Clamp at zero so a negative `size` cannot build an invalid (crashing) range.
    let length = max(0, min(size, available))
    // Index relative to startIndex in case the buffer is a non-zero-based Data slice.
    let start = responseData.startIndex + responseDataOffset
    let chunk = responseData.subdata(in: start..<(start + length))
    responseDataOffset += length
    return chunk
  }

  /// Appends `data` to the pending request body; nothing is sent until `flush`.
  public func write(data: Data) throws {
    requestData.append(data)
  }

  /// POSTs the pending request body and reports the outcome via `completed`.
  ///
  /// On success the response body becomes readable through `read`/`readAll`; on
  /// any failure the response buffer is emptied. The transport retains itself until
  /// the request completes (the completion closure captures `self` strongly).
  public func flush(_ completed: @escaping (TAsyncTransport, Error?) -> Void) {
    var request = URLRequest(url: factory.url)
    request.httpMethod = "POST"
    request.httpBody = requestData

    // The pending bytes now live in the request; start a fresh buffer for the next call.
    requestData = Data()

    let task: URLSessionTask
    do {
      task = try factory.taskWithRequest(request, completionHandler: { data, response, taskError in
        let error = self.responseError(data: data, response: response, taskError: taskError)

        // Always reset the read state so a failed call never exposes a stale response.
        self.responseDataOffset = 0
        self.responseData = (error == nil) ? (data ?? Data()) : Data()

        completed(self, error)
      })
    } catch {
      completed(self, error)
      return
    }

    task.resume()
  }

  /// Classifies a finished data task: returns the error to report, or `nil` on success.
  private func responseError(data: Data?, response: URLResponse?, taskError: Error?) -> Error? {
    // Network-level failures are reported as `.timedOut` (existing contract);
    // keep the underlying description for diagnostics.
    if let taskError = taskError {
      return TTransportError(error: .timedOut, message: taskError.localizedDescription)
    }

    guard let httpResponse = response as? HTTPURLResponse else {
      return THTTPTransportError(error: .invalidResponse)
    }

    switch httpResponse.statusCode {
    case 200:
      break
    case 401:
      return THTTPTransportError(error: .authentication)
    default:
      return THTTPTransportError(error: .invalidStatus(statusCode: httpResponse.statusCode))
    }

    // Give the factory's hook a chance to reject a response that passed the built-in checks.
    do {
      try factory.validateResponse(httpResponse, data: data)
    } catch {
      return error
    }
    return nil
  }

  /// Synchronous flush: blocks the calling thread until the request completes.
  ///
  /// NOTE(review): if this is called on the queue the session delivers completions on,
  /// it would never return — confirm callers avoid that.
  ///
  /// - Throws: Whatever error the asynchronous `flush(_:)` reported.
  public func flush() throws {
    let done = DispatchSemaphore(value: 0)
    var internalError: Error?

    flush({ _, error in
      internalError = error
      done.signal()
    })

    done.wait()

    if let error = internalError {
      throw error
    }
  }
}