blob: 3c0af8eb89084fa8ca944c1732bcc77797d07633 [file] [log] [blame]
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19
20import Foundation
21import Dispatch
22
23
/// An asynchronous Thrift transport that sends each flushed message as an
/// HTTP POST through a `URLSession` and buffers the response body for reading.
///
/// Bytes passed to `write(data:)` accumulate until `flush` is called; the body
/// of the resulting HTTP response then becomes the source for `read`/`readAll`.
public class THTTPSessionTransport: TAsyncTransport {
  /// Creates `THTTPSessionTransport` instances that share one `URLSession`
  /// and one endpoint URL.
  public class Factory : TAsyncTransportFactory {
    /// Optional hook invoked with every response that passed the transport's
    /// own checks (no network error, an `HTTPURLResponse`, status 200).
    /// Throwing from it fails the flush with the thrown error.
    public var responseValidate: ((HTTPURLResponse?, Data?) throws -> Void)?

    var session: URLSession
    var url: URL

    /// Applies Thrift-oriented defaults to `config` (mutated in place):
    /// local caching disabled, pipelining and cookies enabled, and
    /// `Content-Type`/`Accept`/`User-Agent` headers set for Thrift.
    ///
    /// - Parameters:
    ///   - config: The session configuration to modify.
    ///   - protocolName: When non-nil, appended to the content type as `; p=<name>`.
    public class func setupDefaultsForSessionConfiguration(_ config: URLSessionConfiguration, withProtocolName protocolName: String?) {
      var thriftContentType = "application/x-thrift"

      if let protocolName = protocolName {
        thriftContentType += "; p=\(protocolName)"
      }

      config.requestCachePolicy = .reloadIgnoringLocalCacheData
      config.urlCache = nil

      config.httpShouldUsePipelining = true
      config.httpShouldSetCookies = true
      config.httpAdditionalHeaders = ["Content-Type": thriftContentType,
                                      "Accept": thriftContentType,
                                      "User-Agent": "Thrift/Swift (Session)"]
    }

    /// - Parameters:
    ///   - session: The session used to create data tasks for every transport.
    ///   - url: The endpoint every flushed message is POSTed to.
    public init(session: URLSession, url: URL) {
      self.session = session
      self.url = url
    }

    /// Returns a new transport bound to this factory's session and URL.
    public func newTransport() -> THTTPSessionTransport {
      return THTTPSessionTransport(factory: self)
    }

    /// Runs `responseValidate`, if set, rethrowing whatever it throws.
    func validateResponse(_ response: HTTPURLResponse?, data: Data?) throws {
      try responseValidate?(response, data)
    }

    /// Creates (but does not resume) a data task for `request`.
    ///
    /// - Throws: `TTransportError(.unknown)` if no task could be created.
    func taskWithRequest(_ request: URLRequest, completionHandler: @escaping (Data?, URLResponse?, Error?) -> ()) throws -> URLSessionTask {
      // Kept as an optional so the defensive nil check stays valid on any
      // Foundation implementation.
      let newTask: URLSessionTask? = session.dataTask(with: request, completionHandler: completionHandler)
      guard let task = newTask else {
        throw TTransportError(error: .unknown, message: "Failed to create session data task")
      }
      return task
    }
  }

  var factory: Factory
  /// Bytes written since the last flush; sent as the body of the next POST.
  var requestData = Data()
  /// Body of the most recent successful response (empty after a failed flush).
  var responseData = Data()
  /// Number of bytes of `responseData` already consumed by `read`.
  var responseDataOffset: Int = 0

  init(factory: Factory) {
    self.factory = factory
  }

  /// Reads exactly `size` bytes from the buffered response.
  ///
  /// - Throws: `TTransportError(.endOfFile)` when fewer than `size` bytes are
  ///   available. The bytes that were available are still consumed.
  public func readAll(size: Int) throws -> Data {
    let read = try self.read(size: size)
    if read.count != size {
      throw TTransportError(error: .endOfFile)
    }
    return read
  }

  /// Reads up to `size` bytes from the buffered response and advances the
  /// read offset by the number of bytes returned. A non-positive `size`
  /// returns empty data.
  public func read(size: Int) throws -> Data {
    let available = responseData.count - responseDataOffset
    // Clamp so a negative request yields no bytes instead of trapping on an
    // inverted range.
    let length = max(0, min(size, available))
    // Index from startIndex so a sliced Data (non-zero startIndex) is handled.
    let start = responseData.startIndex + responseDataOffset
    let chunk = responseData.subdata(in: start..<(start + length))
    responseDataOffset += length
    return chunk
  }

  /// Appends `data` to the pending request body.
  public func write(data: Data) throws {
    requestData.append(data)
  }

  /// POSTs all buffered request bytes to the factory URL.
  ///
  /// `completed` is called exactly once: immediately if no task could be
  /// created, otherwise from the data task's completion handler. On success
  /// the response body is available through `read`/`readAll`.
  public func flush(_ completed: @escaping (TAsyncTransport, Error?) -> Void) {
    var request = URLRequest(url: factory.url)
    request.httpMethod = "POST"
    request.httpBody = requestData

    // The pending bytes now belong to this request; start a fresh buffer.
    requestData = Data()

    let task: URLSessionTask
    do {
      task = try factory.taskWithRequest(request, completionHandler: { data, response, taskError in
        self.handleResponse(data: data, response: response, taskError: taskError, completed: completed)
      })
    } catch let taskError {
      completed(self, taskError)
      return
    }

    task.resume()
  }

  /// Stores the outcome of a finished data task and reports it to `completed`.
  /// On any error the response buffer is cleared, so a failed flush can never
  /// expose the previous message's bytes.
  private func handleResponse(data: Data?, response: URLResponse?, taskError: Error?,
                              completed: (TAsyncTransport, Error?) -> Void) {
    let error = responseError(data: data, response: response, taskError: taskError)

    responseDataOffset = 0
    responseData = (error == nil) ? (data ?? Data()) : Data()
    completed(self, error)
  }

  /// Returns the error that should fail the flush, or nil if the response is usable.
  private func responseError(data: Data?, response: URLResponse?, taskError: Error?) -> Error? {
    // Every network-level failure is reported with the .timedOut code; the
    // underlying description is kept in the message for diagnostics.
    if let taskError = taskError {
      return TTransportError(error: .timedOut, message: taskError.localizedDescription)
    }

    guard let httpResponse = response as? HTTPURLResponse else {
      return THTTPTransportError(error: .invalidResponse)
    }

    switch httpResponse.statusCode {
    case 200:
      break
    case 401:
      return THTTPTransportError(error: .authentication)
    default:
      return THTTPTransportError(error: .invalidStatus(statusCode: httpResponse.statusCode))
    }

    // Only a response that passed the transport's own checks is handed to
    // the factory's validator; anything it throws fails the flush.
    do {
      try factory.validateResponse(httpResponse, data: data)
    } catch {
      return error
    }
    return nil
  }

  /// Synchronous flush: blocks the calling thread until the asynchronous
  /// flush completes, then rethrows its error, if any.
  ///
  /// NOTE(review): this waits on a semaphore, so it would deadlock if the
  /// session delivers its completion on the thread/queue that is waiting.
  public func flush() throws {
    let done = DispatchSemaphore(value: 0)
    var internalError: Error?

    flush() { _, error in
      internalError = error
      done.signal()
    }

    done.wait()

    if let error = internalError {
      throw error
    }
  }
}
184}