THRIFT-5580: refactor Kotlin cross tests (#2600)
Refactor the Kotlin cross tests to:
* use a proper CLI framework (clikt) for command-line parsing and validation,
* add more transport/protocol cases.
diff --git a/lib/kotlin/cross-test-server/build.gradle.kts b/lib/kotlin/cross-test-server/build.gradle.kts
index 2246fae..7a0c48b 100644
--- a/lib/kotlin/cross-test-server/build.gradle.kts
+++ b/lib/kotlin/cross-test-server/build.gradle.kts
@@ -32,17 +32,17 @@
val httpcoreVersion: String by project
val logbackVersion: String by project
val kotlinxCoroutinesJdk8Version: String by project
+val cliktVersion: String by project
dependencies {
implementation(platform("org.jetbrains.kotlin:kotlin-bom"))
implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
- // https://mvnrepository.com/artifact/org.jetbrains.kotlinx/kotlinx-coroutines-jdk8
+ // clikt is used to drive command line parsing and validation
+ implementation("com.github.ajalt.clikt:clikt:$cliktVersion")
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-jdk8:$kotlinxCoroutinesJdk8Version")
- // https://mvnrepository.com/artifact/org.apache.thrift/libthrift
implementation("org.apache.thrift:libthrift:INCLUDED")
implementation("org.slf4j:slf4j-api:$slf4jVersion")
implementation("org.apache.httpcomponents:httpcore:$httpcoreVersion")
- // https://mvnrepository.com/artifact/ch.qos.logback/logback-classic
implementation("ch.qos.logback:logback-classic:$logbackVersion")
testImplementation("org.jetbrains.kotlin:kotlin-test")
testImplementation("org.jetbrains.kotlin:kotlin-test-junit")
diff --git a/lib/kotlin/cross-test-server/src/main/kotlin/org/apache/thrift/test/TestHandler.kt b/lib/kotlin/cross-test-server/src/main/kotlin/org/apache/thrift/test/TestHandler.kt
index 4bbdb6a..4967345 100644
--- a/lib/kotlin/cross-test-server/src/main/kotlin/org/apache/thrift/test/TestHandler.kt
+++ b/lib/kotlin/cross-test-server/src/main/kotlin/org/apache/thrift/test/TestHandler.kt
@@ -89,10 +89,7 @@
override suspend fun testStruct(thing: Xtruct): Xtruct {
logger.info(
- """
-testStruct({"${thing.string_thing}", ${thing.byte_thing}, ${thing.i32_thing}, ${thing.i64_thing}})
-
-""".trimIndent()
+ """testStruct({"${thing.string_thing}", ${thing.byte_thing}, ${thing.i32_thing}, ${thing.i64_thing}})"""
)
return thing
}
@@ -100,10 +97,7 @@
override suspend fun testNest(thing: Xtruct2): Xtruct2 {
val thing2: Xtruct = thing.struct_thing!!
logger.info(
- """
-testNest({${thing.byte_thing}, {"${thing2.string_thing}", ${thing2.byte_thing}, ${thing2.i32_thing}, ${thing2.i64_thing}}, ${thing.i32_thing}})
-
-""".trimIndent()
+ """testNest({${thing.byte_thing}, {"${thing2.string_thing}", ${thing2.byte_thing}, ${thing2.i32_thing}, ${thing2.i64_thing}}, ${thing.i32_thing}})"""
)
return thing
}
diff --git a/lib/kotlin/cross-test-server/src/main/kotlin/org/apache/thrift/test/TestServer.kt b/lib/kotlin/cross-test-server/src/main/kotlin/org/apache/thrift/test/TestServer.kt
index 4b2bdff..d9c0c86 100644
--- a/lib/kotlin/cross-test-server/src/main/kotlin/org/apache/thrift/test/TestServer.kt
+++ b/lib/kotlin/cross-test-server/src/main/kotlin/org/apache/thrift/test/TestServer.kt
@@ -19,7 +19,13 @@
package org.apache.thrift.test
-import kotlin.system.exitProcess
+import com.github.ajalt.clikt.core.CliktCommand
+import com.github.ajalt.clikt.parameters.options.default
+import com.github.ajalt.clikt.parameters.options.flag
+import com.github.ajalt.clikt.parameters.options.option
+import com.github.ajalt.clikt.parameters.types.enum
+import com.github.ajalt.clikt.parameters.types.int
+import com.github.ajalt.clikt.parameters.types.long
import kotlinx.coroutines.GlobalScope
import org.apache.thrift.TException
import org.apache.thrift.TMultiplexedProcessor
@@ -136,177 +142,147 @@
}
}
-fun main(args: Array<String>) {
- try {
- var port = 9090
- var ssl = false
- var zlib = false
- var transportType = "buffered"
- var protocolType = "binary"
- // var serverType = "thread-pool"
- var serverType = "nonblocking"
- val domainSocket = ""
- var stringLimit: Long = -1
- var containerLimit: Long = -1
- try {
- for (i in args.indices) {
- if (args[i].startsWith("--port")) {
- port = Integer.valueOf(args[i].split("=").toTypedArray()[1])
- } else if (args[i].startsWith("--server-type")) {
- serverType = args[i].split("=").toTypedArray()[1]
- serverType.trim { it <= ' ' }
- } else if (args[i].startsWith("--port")) {
- port = args[i].split("=").toTypedArray()[1].toInt()
- } else if (args[i].startsWith("--protocol")) {
- protocolType = args[i].split("=").toTypedArray()[1]
- protocolType.trim { it <= ' ' }
- } else if (args[i].startsWith("--transport")) {
- transportType = args[i].split("=").toTypedArray()[1]
- transportType.trim { it <= ' ' }
- } else if (args[i] == "--ssl") {
- ssl = true
- } else if (args[i] == "--zlib") {
- zlib = true
- } else if (args[i].startsWith("--string-limit")) {
- stringLimit = args[i].split("=").toTypedArray()[1].toLong()
- } else if (args[i].startsWith("--container-limit")) {
- containerLimit = args[i].split("=").toTypedArray()[1].toLong()
- } else if (args[i] == "--help") {
- println("Allowed options:")
- println(" --help\t\t\tProduce help message")
- println(" --port=arg (=$port)\tPort number to connect")
- println(
- " --transport=arg (=$transportType)\n\t\t\t\tTransport: buffered, framed, fastframed, zlib"
- )
- println(
- " --protocol=arg (=$protocolType)\tProtocol: binary, compact, json, multi, multic, multij"
- )
- println(" --ssl\t\t\tEncrypted Transport using SSL")
- println(" --zlib\t\t\tCompressed Transport using Zlib")
- println(
- " --server-type=arg (=$serverType)\n\t\t\t\tType of server: simple, thread-pool, nonblocking, threaded-selector"
- )
- println(" --string-limit=arg (=$stringLimit)\tString read length limit")
- println(
- " --container-limit=arg (=$containerLimit)\tContainer read length limit"
- )
- exitProcess(0)
- }
- }
- } catch (e: Exception) {
- System.err.println("Can not parse arguments! See --help")
- exitProcess(1)
- }
- try {
- when (serverType) {
- "simple" -> {}
- "thread-pool" -> {}
- "nonblocking" -> {
- if (ssl) {
- throw Exception("SSL is not supported over nonblocking servers!")
- }
- }
- "threaded-selector" -> {
- if (ssl) {
- throw Exception("SSL is not supported over nonblocking servers!")
- }
- }
- else -> {
- throw Exception("Unknown server type! $serverType")
- }
- }
- when (protocolType) {
- "binary" -> {}
- "compact" -> {}
- "json" -> {}
- "multi" -> {}
- "multic" -> {}
- "multij" -> {}
- else -> {
- throw Exception("Unknown protocol type! $protocolType")
- }
- }
- when (transportType) {
- "buffered" -> {}
- "framed" -> {}
- "fastframed" -> {}
- "zlib" -> {}
- else -> {
- throw Exception("Unknown transport type! $transportType")
- }
- }
- } catch (e: Exception) {
- System.err.println("Error: " + e.message)
- exitProcess(1)
- }
+enum class ServerType(val key: String) {
+ Simple("simple"),
+ ThreadPool("thread-pool"),
+ NonBlocking("nonblocking"),
+ ThreadedSelector("threaded-selector")
+}
- // Processors
+enum class ProtocolType(val key: String) {
+ Binary("binary"),
+ Multi("multi"),
+ Json("json"),
+ MultiJson("multij"),
+ Compact("compact"),
+ MultiCompact("multic")
+}
+
+enum class TransportType(val key: String) {
+ Buffered("buffered"),
+ FastFramed("fastframed"),
+ Framed("framed"),
+ Zlib("zlib")
+}
+
+class TestServerCommand : CliktCommand() {
+ private val port: Int by option(help = "The port the cross test server listens on").int().default(9090)
+ private val protocolType: ProtocolType by option("--protocol", help = "Protocol type")
+ .enum<ProtocolType> { it.key }
+ .default(ProtocolType.Binary)
+ private val transportType: TransportType by option("--transport", help = "Transport type")
+ .enum<TransportType> { it.key }
+ .default(TransportType.Buffered)
+ private val serverType: ServerType by option("--server-type", help = "Server type")
+ .enum<ServerType> { it.key }
+ .default(ServerType.NonBlocking)
+ private val useSSL: Boolean by option("--ssl", help = "Use SSL for encrypted transport (simple and thread-pool servers only)")
+ .flag(default = false)
+ private val stringLimit: Long by option("--string-limit", help = "String read length limit").long().default(-1)
+ private val containerLimit: Long by option("--container-limit", help = "Container read length limit").long().default(-1)
+
+ @Suppress("OPT_IN_USAGE")
+ override fun run() {
val testHandler = TestHandler()
val testProcessor = ThriftTestProcessor(testHandler, scope = GlobalScope)
val secondHandler = TestServer.SecondHandler()
val secondProcessor = SecondServiceProcessor(secondHandler, scope = GlobalScope)
+ val serverEngine: TServer =
+ getServerEngine(
+ testProcessor,
+ secondProcessor,
+ serverType,
+ port,
+ protocolType,
+ getProtocolFactory(),
+ getTransportFactory(),
+ useSSL
+ )
+ // Set server event handler
+ serverEngine.setServerEventHandler(TestServer.TestServerEventHandler())
+ // Run it
+ println(
+ "Starting the ${if (useSSL) "ssl server" else "server"} [$protocolType/$transportType/$serverType] on port $port"
+ )
+ serverEngine.serve()
+ }
- // Protocol factory
- val tProtocolFactory: TProtocolFactory =
- when (protocolType) {
- "json", "multij" -> {
- TJSONProtocol.Factory()
- }
- "compact", "multic" -> {
- TCompactProtocol.Factory(stringLimit, containerLimit)
- }
- else -> { // also covers multi
- TBinaryProtocol.Factory(stringLimit, containerLimit)
- }
+ private fun getTransportFactory(): TTransportFactory =
+ when (transportType) {
+ TransportType.Framed -> {
+ TFramedTransport.Factory()
}
- val tTransportFactory: TTransportFactory =
- when (transportType) {
- "framed" -> {
- TFramedTransport.Factory()
- }
- "fastframed" -> {
- TFastFramedTransport.Factory()
- }
- "zlib" -> {
- TZlibTransport.Factory()
- }
- else -> { // .equals("buffered") => default value
- TTransportFactory()
- }
+ TransportType.FastFramed -> {
+ TFastFramedTransport.Factory()
}
- val serverEngine: TServer
- // If we are multiplexing services in one server...
- val multiplexedProcessor = TMultiplexedProcessor()
- multiplexedProcessor.registerDefault(testProcessor)
- multiplexedProcessor.registerProcessor("ThriftTest", testProcessor)
- multiplexedProcessor.registerProcessor("SecondService", secondProcessor)
- if (serverType == "nonblocking" || serverType == "threaded-selector") {
- // Nonblocking servers
+ TransportType.Zlib -> {
+ TZlibTransport.Factory()
+ }
+ TransportType.Buffered -> {
+ TTransportFactory()
+ }
+ }
+
+ private fun getProtocolFactory(): TProtocolFactory =
+ when (protocolType) {
+ ProtocolType.Json, ProtocolType.MultiJson -> TJSONProtocol.Factory()
+ ProtocolType.Compact, ProtocolType.MultiCompact ->
+ TCompactProtocol.Factory(stringLimit, containerLimit)
+ ProtocolType.Binary, ProtocolType.Multi ->
+ TBinaryProtocol.Factory(stringLimit, containerLimit)
+ }
+}
+
+fun main(args: Array<String>) {
+ TestServerCommand().main(args)
+}
+
+private fun getServerEngine(
+ testProcessor: ThriftTestProcessor,
+ secondProcessor: SecondServiceProcessor,
+ serverType: ServerType,
+ port: Int,
+ protocolType: ProtocolType,
+ tProtocolFactory: TProtocolFactory,
+ tTransportFactory: TTransportFactory,
+ ssl: Boolean
+): TServer {
+ val isMulti = protocolType in setOf(ProtocolType.Multi, ProtocolType.MultiCompact, ProtocolType.MultiJson)
+ // Only the blocking TServerSocket branch below wires up SSL; fail fast instead of silently serving plaintext.
+ val isNonBlocking = serverType == ServerType.NonBlocking || serverType == ServerType.ThreadedSelector
+ require(!(ssl && isNonBlocking)) { "SSL is not supported over nonblocking servers!" }
+ // If we are multiplexing services in one server...
+ val multiplexedProcessor = TMultiplexedProcessor()
+ multiplexedProcessor.registerDefault(testProcessor)
+ multiplexedProcessor.registerProcessor("ThriftTest", testProcessor)
+ multiplexedProcessor.registerProcessor("SecondService", secondProcessor)
+ when (serverType) {
+ ServerType.NonBlocking, ServerType.ThreadedSelector -> {
val tNonblockingServerSocket =
TNonblockingServerSocket(NonblockingAbstractServerSocketArgs().port(port))
- if (serverType.contains("nonblocking")) {
- // Nonblocking Server
- val tNonblockingServerArgs = TNonblockingServer.Args(tNonblockingServerSocket)
- tNonblockingServerArgs.processor(
- if (protocolType.startsWith("multi")) multiplexedProcessor else testProcessor
- )
- tNonblockingServerArgs.protocolFactory(tProtocolFactory)
- tNonblockingServerArgs.transportFactory(tTransportFactory)
- serverEngine = TNonblockingServer(tNonblockingServerArgs)
- } else { // server_type.equals("threaded-selector")
- // ThreadedSelector Server
- val tThreadedSelectorServerArgs =
- TThreadedSelectorServer.Args(tNonblockingServerSocket)
- tThreadedSelectorServerArgs.processor(
- if (protocolType.startsWith("multi")) multiplexedProcessor else testProcessor
- )
- tThreadedSelectorServerArgs.protocolFactory(tProtocolFactory)
- tThreadedSelectorServerArgs.transportFactory(tTransportFactory)
- serverEngine = TThreadedSelectorServer(tThreadedSelectorServerArgs)
+ when (serverType) {
+ ServerType.NonBlocking -> {
+ val tNonblockingServerArgs = TNonblockingServer.Args(tNonblockingServerSocket)
+ tNonblockingServerArgs.processor(
+ if (isMulti) multiplexedProcessor else testProcessor
+ )
+ tNonblockingServerArgs.protocolFactory(tProtocolFactory)
+ tNonblockingServerArgs.transportFactory(tTransportFactory)
+ return TNonblockingServer(tNonblockingServerArgs)
+ }
+ else -> {
+ val tThreadedSelectorServerArgs =
+ TThreadedSelectorServer.Args(tNonblockingServerSocket)
+ tThreadedSelectorServerArgs.processor(
+ if (isMulti) multiplexedProcessor else testProcessor
+ )
+ tThreadedSelectorServerArgs.protocolFactory(tProtocolFactory)
+ tThreadedSelectorServerArgs.transportFactory(tTransportFactory)
+ return TThreadedSelectorServer(tThreadedSelectorServerArgs)
+ }
}
- } else {
- // Blocking servers
-
+ }
+ ServerType.Simple, ServerType.ThreadPool -> {
// SSL socket
val tServerSocket: TServerSocket =
if (ssl) {
@@ -314,46 +290,24 @@
} else {
TServerSocket(ServerSocketTransportArgs().port(port))
}
- if (serverType == "simple") {
- // Simple Server
- val tServerArgs = TServer.Args(tServerSocket)
- tServerArgs.processor(
- if (protocolType.startsWith("multi")) multiplexedProcessor else testProcessor
- )
- tServerArgs.protocolFactory(tProtocolFactory)
- tServerArgs.transportFactory(tTransportFactory)
- serverEngine = TSimpleServer(tServerArgs)
- } else { // server_type.equals("threadpool")
- // ThreadPool Server
- val tThreadPoolServerArgs = TThreadPoolServer.Args(tServerSocket)
- tThreadPoolServerArgs.processor(
- if (protocolType.startsWith("multi")) multiplexedProcessor else testProcessor
- )
- tThreadPoolServerArgs.protocolFactory(tProtocolFactory)
- tThreadPoolServerArgs.transportFactory(tTransportFactory)
- serverEngine = TThreadPoolServer(tThreadPoolServerArgs)
+ when (serverType) {
+ ServerType.Simple -> {
+ val tServerArgs = TServer.Args(tServerSocket)
+ tServerArgs.processor(if (isMulti) multiplexedProcessor else testProcessor)
+ tServerArgs.protocolFactory(tProtocolFactory)
+ tServerArgs.transportFactory(tTransportFactory)
+ return TSimpleServer(tServerArgs)
+ }
+ else -> {
+ val tThreadPoolServerArgs = TThreadPoolServer.Args(tServerSocket)
+ tThreadPoolServerArgs.processor(
+ if (isMulti) multiplexedProcessor else testProcessor
+ )
+ tThreadPoolServerArgs.protocolFactory(tProtocolFactory)
+ tThreadPoolServerArgs.transportFactory(tTransportFactory)
+ return TThreadPoolServer(tThreadPoolServerArgs)
+ }
}
}
-
- // Set server event handler
- serverEngine.setServerEventHandler(TestServer.TestServerEventHandler())
-
- // Run it
- println(
- "Starting the " +
- (if (ssl) "ssl server" else "server") +
- " [" +
- protocolType +
- "/" +
- transportType +
- "/" +
- serverType +
- "] on " +
- if (domainSocket === "") "port $port" else "unix socket $domainSocket"
- )
- serverEngine.serve()
- } catch (x: Exception) {
- x.printStackTrace()
}
- println("done.")
}