THRIFT-5580: refactor kotlin cross tests (#2600)

Refactor the Kotlin cross test to:
* use a proper CLI framework (clikt),
* add more transport/protocol cases
diff --git a/lib/kotlin/cross-test-server/build.gradle.kts b/lib/kotlin/cross-test-server/build.gradle.kts
index 2246fae..7a0c48b 100644
--- a/lib/kotlin/cross-test-server/build.gradle.kts
+++ b/lib/kotlin/cross-test-server/build.gradle.kts
@@ -32,17 +32,17 @@
 val httpcoreVersion: String by project
 val logbackVersion: String by project
 val kotlinxCoroutinesJdk8Version: String by project
+val cliktVersion: String by project
 
 dependencies {
     implementation(platform("org.jetbrains.kotlin:kotlin-bom"))
     implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
-    // https://mvnrepository.com/artifact/org.jetbrains.kotlinx/kotlinx-coroutines-jdk8
+    // clikt parses and validates the cross-test server's command-line options (see TestServerCommand)
+    implementation("com.github.ajalt.clikt:clikt:$cliktVersion")
     implementation("org.jetbrains.kotlinx:kotlinx-coroutines-jdk8:$kotlinxCoroutinesJdk8Version")
-    // https://mvnrepository.com/artifact/org.apache.thrift/libthrift
     implementation("org.apache.thrift:libthrift:INCLUDED")
     implementation("org.slf4j:slf4j-api:$slf4jVersion")
     implementation("org.apache.httpcomponents:httpcore:$httpcoreVersion")
-    // https://mvnrepository.com/artifact/ch.qos.logback/logback-classic
     implementation("ch.qos.logback:logback-classic:$logbackVersion")
     testImplementation("org.jetbrains.kotlin:kotlin-test")
     testImplementation("org.jetbrains.kotlin:kotlin-test-junit")
diff --git a/lib/kotlin/cross-test-server/src/main/kotlin/org/apache/thrift/test/TestHandler.kt b/lib/kotlin/cross-test-server/src/main/kotlin/org/apache/thrift/test/TestHandler.kt
index 4bbdb6a..4967345 100644
--- a/lib/kotlin/cross-test-server/src/main/kotlin/org/apache/thrift/test/TestHandler.kt
+++ b/lib/kotlin/cross-test-server/src/main/kotlin/org/apache/thrift/test/TestHandler.kt
@@ -89,10 +89,7 @@
 
     override suspend fun testStruct(thing: Xtruct): Xtruct {
         logger.info(
-            """
-testStruct({"${thing.string_thing}", ${thing.byte_thing}, ${thing.i32_thing}, ${thing.i64_thing}})
-
-""".trimIndent()
+            """testStruct({"${thing.string_thing}", ${thing.byte_thing}, ${thing.i32_thing}, ${thing.i64_thing}})""" // logged as one line, no surrounding blank lines
         )
         return thing
     }
@@ -100,10 +97,7 @@
     override suspend fun testNest(thing: Xtruct2): Xtruct2 {
         val thing2: Xtruct = thing.struct_thing!!
         logger.info(
-            """
-testNest({${thing.byte_thing}, {"${thing2.string_thing}", ${thing2.byte_thing}, ${thing2.i32_thing}, ${thing2.i64_thing}}, ${thing.i32_thing}})
-
-""".trimIndent()
+            """testNest({${thing.byte_thing}, {"${thing2.string_thing}", ${thing2.byte_thing}, ${thing2.i32_thing}, ${thing2.i64_thing}}, ${thing.i32_thing}})""" // single-line message, same form as testStruct
         )
         return thing
     }
diff --git a/lib/kotlin/cross-test-server/src/main/kotlin/org/apache/thrift/test/TestServer.kt b/lib/kotlin/cross-test-server/src/main/kotlin/org/apache/thrift/test/TestServer.kt
index 4b2bdff..d9c0c86 100644
--- a/lib/kotlin/cross-test-server/src/main/kotlin/org/apache/thrift/test/TestServer.kt
+++ b/lib/kotlin/cross-test-server/src/main/kotlin/org/apache/thrift/test/TestServer.kt
@@ -19,7 +19,13 @@
 
 package org.apache.thrift.test
 
-import kotlin.system.exitProcess
+import com.github.ajalt.clikt.core.CliktCommand
+import com.github.ajalt.clikt.parameters.options.default
+import com.github.ajalt.clikt.parameters.options.flag
+import com.github.ajalt.clikt.parameters.options.option
+import com.github.ajalt.clikt.parameters.types.enum
+import com.github.ajalt.clikt.parameters.types.int
+import com.github.ajalt.clikt.parameters.types.long
 import kotlinx.coroutines.GlobalScope
 import org.apache.thrift.TException
 import org.apache.thrift.TMultiplexedProcessor
@@ -136,177 +142,147 @@
     }
 }
 
-fun main(args: Array<String>) {
-    try {
-        var port = 9090
-        var ssl = false
-        var zlib = false
-        var transportType = "buffered"
-        var protocolType = "binary"
-        //  var serverType = "thread-pool"
-        var serverType = "nonblocking"
-        val domainSocket = ""
-        var stringLimit: Long = -1
-        var containerLimit: Long = -1
-        try {
-            for (i in args.indices) {
-                if (args[i].startsWith("--port")) {
-                    port = Integer.valueOf(args[i].split("=").toTypedArray()[1])
-                } else if (args[i].startsWith("--server-type")) {
-                    serverType = args[i].split("=").toTypedArray()[1]
-                    serverType.trim { it <= ' ' }
-                } else if (args[i].startsWith("--port")) {
-                    port = args[i].split("=").toTypedArray()[1].toInt()
-                } else if (args[i].startsWith("--protocol")) {
-                    protocolType = args[i].split("=").toTypedArray()[1]
-                    protocolType.trim { it <= ' ' }
-                } else if (args[i].startsWith("--transport")) {
-                    transportType = args[i].split("=").toTypedArray()[1]
-                    transportType.trim { it <= ' ' }
-                } else if (args[i] == "--ssl") {
-                    ssl = true
-                } else if (args[i] == "--zlib") {
-                    zlib = true
-                } else if (args[i].startsWith("--string-limit")) {
-                    stringLimit = args[i].split("=").toTypedArray()[1].toLong()
-                } else if (args[i].startsWith("--container-limit")) {
-                    containerLimit = args[i].split("=").toTypedArray()[1].toLong()
-                } else if (args[i] == "--help") {
-                    println("Allowed options:")
-                    println("  --help\t\t\tProduce help message")
-                    println("  --port=arg (=$port)\tPort number to connect")
-                    println(
-                        "  --transport=arg (=$transportType)\n\t\t\t\tTransport: buffered, framed, fastframed, zlib"
-                    )
-                    println(
-                        "  --protocol=arg (=$protocolType)\tProtocol: binary, compact, json, multi, multic, multij"
-                    )
-                    println("  --ssl\t\t\tEncrypted Transport using SSL")
-                    println("  --zlib\t\t\tCompressed Transport using Zlib")
-                    println(
-                        "  --server-type=arg (=$serverType)\n\t\t\t\tType of server: simple, thread-pool, nonblocking, threaded-selector"
-                    )
-                    println("  --string-limit=arg (=$stringLimit)\tString read length limit")
-                    println(
-                        "  --container-limit=arg (=$containerLimit)\tContainer read length limit"
-                    )
-                    exitProcess(0)
-                }
-            }
-        } catch (e: Exception) {
-            System.err.println("Can not parse arguments! See --help")
-            exitProcess(1)
-        }
-        try {
-            when (serverType) {
-                "simple" -> {}
-                "thread-pool" -> {}
-                "nonblocking" -> {
-                    if (ssl) {
-                        throw Exception("SSL is not supported over nonblocking servers!")
-                    }
-                }
-                "threaded-selector" -> {
-                    if (ssl) {
-                        throw Exception("SSL is not supported over nonblocking servers!")
-                    }
-                }
-                else -> {
-                    throw Exception("Unknown server type! $serverType")
-                }
-            }
-            when (protocolType) {
-                "binary" -> {}
-                "compact" -> {}
-                "json" -> {}
-                "multi" -> {}
-                "multic" -> {}
-                "multij" -> {}
-                else -> {
-                    throw Exception("Unknown protocol type! $protocolType")
-                }
-            }
-            when (transportType) {
-                "buffered" -> {}
-                "framed" -> {}
-                "fastframed" -> {}
-                "zlib" -> {}
-                else -> {
-                    throw Exception("Unknown transport type! $transportType")
-                }
-            }
-        } catch (e: Exception) {
-            System.err.println("Error: " + e.message)
-            exitProcess(1)
-        }
+enum class ServerType(val key: String) { // key: the value accepted by --server-type
+    Simple("simple"),
+    ThreadPool("thread-pool"),
+    NonBlocking("nonblocking"),
+    ThreadedSelector("threaded-selector")
+}

-        // Processors
+enum class ProtocolType(val key: String) { // key: the value accepted by --protocol
+    Binary("binary"),
+    Multi("multi"), // Multi, MultiJson and MultiCompact are served through the TMultiplexedProcessor
+    Json("json"),
+    MultiJson("multij"),
+    Compact("compact"),
+    MultiCompact("multic")
+}
+
+enum class TransportType(val key: String) { // key: the value accepted by --transport
+    Buffered("buffered"),
+    FastFramed("fastframed"),
+    Framed("framed"),
+    Zlib("zlib")
+}
+
+/** Command-line entry point for the Kotlin cross-test server: parses the options below and starts the server. */
+class TestServerCommand : CliktCommand() {
+    private val port: Int by option(help = "The port the cross test server listens on").int().default(9090)
+    private val protocolType: ProtocolType by option("--protocol", help = "Protocol type")
+        .enum<ProtocolType> { it.key }
+        .default(ProtocolType.Binary)
+    private val transportType: TransportType by option("--transport", help = "Transport type")
+        .enum<TransportType> { it.key }
+        .default(TransportType.Buffered)
+    private val serverType: ServerType by option("--server-type", help = "Server type")
+        .enum<ServerType> { it.key }
+        .default(ServerType.NonBlocking)
+    private val useSSL: Boolean by option("--ssl", help = "Use SSL for encrypted transport").flag(default = false)
+    private val stringLimit: Long by option("--string-limit", help = "String read length limit").long().default(-1)
+    private val containerLimit: Long by option("--container-limit", help = "Container read length limit").long().default(-1)
+
+    @Suppress("OPT_IN_USAGE")
+    override fun run() {
         val testHandler = TestHandler()
         val testProcessor = ThriftTestProcessor(testHandler, scope = GlobalScope)
         val secondHandler = TestServer.SecondHandler()
         val secondProcessor = SecondServiceProcessor(secondHandler, scope = GlobalScope)
+        // SSL is only wired into the blocking server socket (see getServerEngine), so reject it for the
+        // nonblocking server types up front instead of silently serving plaintext.
+        if (useSSL && (serverType == ServerType.NonBlocking || serverType == ServerType.ThreadedSelector)) {
+            throw com.github.ajalt.clikt.core.UsageError("SSL is not supported over nonblocking servers!")
+        }
+        val serverEngine: TServer = getServerEngine(
+            testProcessor, secondProcessor, serverType, port, protocolType,
+            getProtocolFactory(), getTransportFactory(), useSSL
+        )
+        // Set server event handler
+        serverEngine.setServerEventHandler(TestServer.TestServerEventHandler())
+        // Run it; print the command-line keys (e.g. "binary") rather than the enum constant names (e.g. "Binary"),
+        // so the banner shows the same values that were accepted on the command line
+        println(
+            "Starting the ${if (useSSL) "ssl server" else "server"} " +
+                "[${protocolType.key}/${transportType.key}/${serverType.key}] on port $port"
+        )
+        serverEngine.serve()
+    }

-        // Protocol factory
-        val tProtocolFactory: TProtocolFactory =
-            when (protocolType) {
-                "json", "multij" -> {
-                    TJSONProtocol.Factory()
-                }
-                "compact", "multic" -> {
-                    TCompactProtocol.Factory(stringLimit, containerLimit)
-                }
-                else -> { // also covers multi
-                    TBinaryProtocol.Factory(stringLimit, containerLimit)
-                }
+    private fun getTransportFactory(): TTransportFactory =
+        when (transportType) {
+            TransportType.Framed -> {
+                TFramedTransport.Factory()
             }
-        val tTransportFactory: TTransportFactory =
-            when (transportType) {
-                "framed" -> {
-                    TFramedTransport.Factory()
-                }
-                "fastframed" -> {
-                    TFastFramedTransport.Factory()
-                }
-                "zlib" -> {
-                    TZlibTransport.Factory()
-                }
-                else -> { // .equals("buffered") => default value
-                    TTransportFactory()
-                }
+            TransportType.FastFramed -> {
+                TFastFramedTransport.Factory()
             }
-        val serverEngine: TServer
-        // If we are multiplexing services in one server...
-        val multiplexedProcessor = TMultiplexedProcessor()
-        multiplexedProcessor.registerDefault(testProcessor)
-        multiplexedProcessor.registerProcessor("ThriftTest", testProcessor)
-        multiplexedProcessor.registerProcessor("SecondService", secondProcessor)
-        if (serverType == "nonblocking" || serverType == "threaded-selector") {
-            // Nonblocking servers
+            TransportType.Zlib -> {
+                TZlibTransport.Factory()
+            }
+            TransportType.Buffered -> {
+                TTransportFactory()
+            }
+        }
+
+    private fun getProtocolFactory(): TProtocolFactory = // stringLimit/containerLimit only reach the binary and compact factories
+        when (protocolType) {
+            ProtocolType.Json, ProtocolType.MultiJson -> TJSONProtocol.Factory()
+            ProtocolType.Compact, ProtocolType.MultiCompact ->
+                TCompactProtocol.Factory(stringLimit, containerLimit)
+            ProtocolType.Binary, ProtocolType.Multi ->
+                TBinaryProtocol.Factory(stringLimit, containerLimit)
+        }
+}
+
+fun main(args: Array<String>) { // hands argv to clikt, which parses TestServerCommand's options and then calls run()
+    TestServerCommand().main(args)
+}
+
+/**
+ * Builds the [TServer] for [serverType] on [port]; the Multi* protocols get a multiplexed ThriftTest/SecondService processor.
+ * @throws IllegalArgumentException if [ssl] is set for a nonblocking server type (only the blocking socket does SSL here).
+ */
+private fun getServerEngine(
+    testProcessor: ThriftTestProcessor,
+    secondProcessor: SecondServiceProcessor,
+    serverType: ServerType,
+    port: Int,
+    protocolType: ProtocolType,
+    tProtocolFactory: TProtocolFactory,
+    tTransportFactory: TTransportFactory,
+    ssl: Boolean
+): TServer {
+    // SSL is only wired into the blocking TServerSocket below, so refuse it rather than silently serve plaintext
+    val blocking = serverType == ServerType.Simple || serverType == ServerType.ThreadPool
+    require(!ssl || blocking) { "SSL is not supported over nonblocking servers!" }
+    val isMulti = protocolType in setOf(ProtocolType.Multi, ProtocolType.MultiCompact, ProtocolType.MultiJson)
+    // If we are multiplexing services in one server...
+    val multiplexedProcessor = TMultiplexedProcessor()
+    multiplexedProcessor.registerDefault(testProcessor)
+    multiplexedProcessor.registerProcessor("ThriftTest", testProcessor)
+    multiplexedProcessor.registerProcessor("SecondService", secondProcessor)
+    val processor = if (isMulti) multiplexedProcessor else testProcessor
+    when (serverType) {
+        ServerType.NonBlocking, ServerType.ThreadedSelector -> {
             val tNonblockingServerSocket =
                 TNonblockingServerSocket(NonblockingAbstractServerSocketArgs().port(port))
-            if (serverType.contains("nonblocking")) {
-                // Nonblocking Server
-                val tNonblockingServerArgs = TNonblockingServer.Args(tNonblockingServerSocket)
-                tNonblockingServerArgs.processor(
-                    if (protocolType.startsWith("multi")) multiplexedProcessor else testProcessor
-                )
-                tNonblockingServerArgs.protocolFactory(tProtocolFactory)
-                tNonblockingServerArgs.transportFactory(tTransportFactory)
-                serverEngine = TNonblockingServer(tNonblockingServerArgs)
-            } else { // server_type.equals("threaded-selector")
-                // ThreadedSelector Server
-                val tThreadedSelectorServerArgs =
-                    TThreadedSelectorServer.Args(tNonblockingServerSocket)
-                tThreadedSelectorServerArgs.processor(
-                    if (protocolType.startsWith("multi")) multiplexedProcessor else testProcessor
-                )
-                tThreadedSelectorServerArgs.protocolFactory(tProtocolFactory)
-                tThreadedSelectorServerArgs.transportFactory(tTransportFactory)
-                serverEngine = TThreadedSelectorServer(tThreadedSelectorServerArgs)
+            when (serverType) {
+                ServerType.NonBlocking -> {
+                    val tNonblockingServerArgs = TNonblockingServer.Args(tNonblockingServerSocket)
+                    tNonblockingServerArgs.processor(processor)
+                    tNonblockingServerArgs.protocolFactory(tProtocolFactory)
+                    tNonblockingServerArgs.transportFactory(tTransportFactory)
+                    return TNonblockingServer(tNonblockingServerArgs)
+                }
+                else -> {
+                    val tThreadedSelectorServerArgs = TThreadedSelectorServer.Args(tNonblockingServerSocket)
+                    tThreadedSelectorServerArgs.processor(processor)
+                    tThreadedSelectorServerArgs.protocolFactory(tProtocolFactory)
+                    tThreadedSelectorServerArgs.transportFactory(tTransportFactory)
+                    return TThreadedSelectorServer(tThreadedSelectorServerArgs)
+                }
             }
-        } else {
-            // Blocking servers
-
+        }
+        ServerType.Simple, ServerType.ThreadPool -> {
             // SSL socket
             val tServerSocket: TServerSocket =
                 if (ssl) {
@@ -314,46 +290,22 @@
                 } else {
                     TServerSocket(ServerSocketTransportArgs().port(port))
                 }
-            if (serverType == "simple") {
-                // Simple Server
-                val tServerArgs = TServer.Args(tServerSocket)
-                tServerArgs.processor(
-                    if (protocolType.startsWith("multi")) multiplexedProcessor else testProcessor
-                )
-                tServerArgs.protocolFactory(tProtocolFactory)
-                tServerArgs.transportFactory(tTransportFactory)
-                serverEngine = TSimpleServer(tServerArgs)
-            } else { // server_type.equals("threadpool")
-                // ThreadPool Server
-                val tThreadPoolServerArgs = TThreadPoolServer.Args(tServerSocket)
-                tThreadPoolServerArgs.processor(
-                    if (protocolType.startsWith("multi")) multiplexedProcessor else testProcessor
-                )
-                tThreadPoolServerArgs.protocolFactory(tProtocolFactory)
-                tThreadPoolServerArgs.transportFactory(tTransportFactory)
-                serverEngine = TThreadPoolServer(tThreadPoolServerArgs)
+            when (serverType) {
+                ServerType.Simple -> {
+                    val tServerArgs = TServer.Args(tServerSocket)
+                    tServerArgs.processor(processor)
+                    tServerArgs.protocolFactory(tProtocolFactory)
+                    tServerArgs.transportFactory(tTransportFactory)
+                    return TSimpleServer(tServerArgs)
+                }
+                else -> {
+                    val tThreadPoolServerArgs = TThreadPoolServer.Args(tServerSocket)
+                    tThreadPoolServerArgs.processor(processor)
+                    tThreadPoolServerArgs.protocolFactory(tProtocolFactory)
+                    tThreadPoolServerArgs.transportFactory(tTransportFactory)
+                    return TThreadPoolServer(tThreadPoolServerArgs)
+                }
             }
         }
-
-        // Set server event handler
-        serverEngine.setServerEventHandler(TestServer.TestServerEventHandler())
-
-        // Run it
-        println(
-            "Starting the " +
-                (if (ssl) "ssl server" else "server") +
-                " [" +
-                protocolType +
-                "/" +
-                transportType +
-                "/" +
-                serverType +
-                "] on " +
-                if (domainSocket === "") "port $port" else "unix socket $domainSocket"
-        )
-        serverEngine.serve()
-    } catch (x: Exception) {
-        x.printStackTrace()
     }
-    println("done.")
 }