Skip to content

Commit

Permalink
Allow users to pass a thread pool to run things on
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarchambault committed Apr 26, 2022
1 parent 831445b commit 1752556
Showing 1 changed file with 42 additions and 24 deletions.
66 changes: 42 additions & 24 deletions core/src/main/scala/snailgun/protocol/Protocol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import java.nio.file.Paths
import java.nio.file.Files
import java.nio.ByteBuffer

import java.util.concurrent.ExecutorService
import java.util.concurrent.Future
import java.util.concurrent.Semaphore
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
Expand Down Expand Up @@ -45,7 +47,8 @@ class Protocol(
environment: Map[String, String],
logger: Logger,
stopFurtherProcessing: AtomicBoolean,
interactiveSession: Boolean
interactiveSession: Boolean,
threadPoolOpt: Option[ExecutorService] = None
) {
private val absoluteCwd = cwd.toAbsolutePath().toString
private val exitCode: AtomicInteger = new AtomicInteger(-1)
Expand Down Expand Up @@ -74,10 +77,8 @@ class Protocol(
val in = new DataInputStream(in0)
val out = new DataOutputStream(out0)

val sendStdinOpt = createStdinThread(out)
var sendStdinOpt = Option.empty[(Either[Future[_], Thread], Semaphore)]
val scheduleHeartbeat = createHeartbeatAndShutdownThread(in, out)
// Start heartbeat thread before sending command as python and C clients do
scheduleHeartbeat.start()

try {
// Send client command's environment to Nailgun server
Expand All @@ -93,7 +94,7 @@ class Protocol(

// Start thread sending stdin right after sending command
logger.debug("Starting thread to read stdin...")
sendStdinOpt.foreach(_._1.start())
sendStdinOpt = createStdinThread(out)

while (exitCode.get() == -1) {
val action = processChunkFromServer(in)
Expand Down Expand Up @@ -129,13 +130,22 @@ class Protocol(
}

if (stopFurtherProcessing.get()) {
sendStdinOpt.foreach(_._1.interrupt())
sendStdinOpt.map(_._1).foreach {
case Left(f) => f.cancel(true)
case Right(t) => t.interrupt()
}
}

logger.debug("Waiting for stdin thread to finish...")
sendStdinOpt.foreach(_._1.join())
sendStdinOpt.map(_._1).foreach {
case Left(f) => f.get()
case Right(t) => t.join()
}
logger.debug("Waiting for heartbeat thread to finish...")
scheduleHeartbeat.join()
scheduleHeartbeat match {
case Left(f) => f.get()
case Right(t) => t.join()
}
logger.debug("Returning exit code...")
exitCode.get()
}
Expand Down Expand Up @@ -198,7 +208,7 @@ class Protocol(
def createHeartbeatAndShutdownThread(
in: DataInputStream,
out: DataOutputStream
): Thread = {
): Either[Future[_], Thread] = {
daemonThread("snailgun-heartbeat") { () =>
var continue: Boolean = true
while (continue) {
Expand Down Expand Up @@ -226,10 +236,10 @@ class Protocol(
}
}

def createStdinThread(out: DataOutputStream): Option[(Thread, Semaphore)] = {
def createStdinThread(out: DataOutputStream): Option[(Either[Future[_], Thread], Semaphore)] = {
streams.in.map { in =>
val sendStdinSemaphore = new Semaphore(0)
val thread = daemonThread("snailgun-stdin") { () =>
val threadOrFuture = daemonThread("snailgun-stdin") { () =>
val reader = new BufferedReader(new InputStreamReader(in))
def shouldStop = !isRunning.get() || stopFurtherProcessing.get()
try {
Expand Down Expand Up @@ -261,7 +271,7 @@ class Protocol(
}
} finally reader.close()
}
(thread, sendStdinSemaphore)
(threadOrFuture, sendStdinSemaphore)
}
}

Expand Down Expand Up @@ -293,19 +303,27 @@ class Protocol(
logger.trace(exception)
}

private def daemonThread(name: String)(run0: () => Unit): Thread = {
val t = new Thread(name) {
override def run(): Unit = {
try run0()
catch {
case NonFatal(exception) =>
if (anyThreadFailed.compareAndSet(false, true)) {
printException(exception)
}
}
private def daemonThread(name: String)(run0: () => Unit): Either[Future[_], Thread] = {

val runnable: Runnable = { () =>
try run0()
catch {
case NonFatal(exception) =>
if (anyThreadFailed.compareAndSet(false, true)) {
printException(exception)
}
}
}
t.setDaemon(true)
t

threadPoolOpt match {
case Some(threadPool) =>
val f = threadPool.submit(runnable)
Left(f)
case None =>
val t = new Thread(runnable, name)
t.setDaemon(true)
t.start()
Right(t)
}
}
}

0 comments on commit 1752556

Please sign in to comment.