diff --git a/core/src/main/scala/ox/race.scala b/core/src/main/scala/ox/race.scala index 6a57d91e..bc3b5cdb 100644 --- a/core/src/main/scala/ox/race.scala +++ b/core/src/main/scala/ox/race.scala @@ -4,6 +4,7 @@ import java.util.concurrent.ArrayBlockingQueue import scala.annotation.tailrec import scala.concurrent.TimeoutException import scala.concurrent.duration.FiniteDuration +import scala.util.control.{ControlThrowable, NonFatal} import scala.util.{Failure, Success, Try} /** A `Some` if the computation `t` took less than `duration`, and `None` otherwise. if the computation `t` throws an exception, it is @@ -37,8 +38,22 @@ def race[T](fs: Seq[() => T]): T = race(NoErrorMode)(fs) */ def race[E, F[_], T](em: ErrorMode[E, F])(fs: Seq[() => F[T]]): F[T] = unsupervised { - val result = new ArrayBlockingQueue[Try[F[T]]](fs.size) - fs.foreach(f => forkUnsupervised(result.put(Try(f())))) + val result = new ArrayBlockingQueue[RaceBranchResult[F[T]]](fs.size) + fs.foreach(f => + forkUnsupervised { + val r = + try RaceBranchResult.Success(f()) + catch + case NonFatal(e) => RaceBranchResult.NonFatalException(e) + // #213: we treat ControlThrowables as non-fatal, as in the context of `race` they should count as a + // "failed branch", but not cause immediate interruption + case e: ControlThrowable => RaceBranchResult.NonFatalException(e) + // #213: any other fatal exceptions must cause `race` to be interrupted immediately; this is needed as we + // are in an unsupervised scope, so by default exceptions aren't propagated + case e => RaceBranchResult.FatalException(e) + result.put(r) + } + ) @tailrec def takeUntilSuccess(failures: Vector[Either[E, Throwable]], left: Int): F[T] = @@ -57,10 +72,11 @@ def race[E, F[_], T](em: ErrorMode[E, F])(fs: Seq[() => F[T]]): F[T] = throw e else result.take() match - case Success(v) => + case RaceBranchResult.Success(v) => if em.isError(v) then takeUntilSuccess(failures :+ Left(em.getError(v)), left - 1) else v - case Failure(e) => takeUntilSuccess(failures :+ Right(e), left - 1) + case RaceBranchResult.NonFatalException(e) => takeUntilSuccess(failures :+ Right(e), left - 1) + case RaceBranchResult.FatalException(e) => throw e takeUntilSuccess(Vector.empty, fs.size) } @@ -113,7 +129,15 @@ def raceEither[E, T](f1: => Either[E, T], f2: => Either[E, T], f3: => Either[E, // /** Returns the result of the first computation to complete (either successfully or with an exception). */ -def raceResult[T](fs: Seq[() => T]): T = race(fs.map(f => () => Try(f()))).get // TODO optimize +def raceResult[T](fs: Seq[() => T]): T = race( + fs.map(f => + () => + // #213: the Try() constructor doesn't catch fatal exceptions; in this context, we want to propagate *all* + // exceptions as fast as possible + try Success(f()) + catch case e: Throwable => Failure(e) + ) +).get // TODO optimize /** Returns the result of the first computation to complete (either successfully or with an exception). */ def raceResult[T](f1: => T, f2: => T): T = raceResult(List(() => f1, () => f2)) @@ -123,3 +147,8 @@ def raceResult[T](f1: => T, f2: => T, f3: => T): T = raceResult(List(() => f1, ( /** Returns the result of the first computation to complete (either successfully or with an exception). */ def raceResult[T](f1: => T, f2: => T, f3: => T, f4: => T): T = raceResult(List(() => f1, () => f2, () => f3, () => f4)) + +private enum RaceBranchResult[+T]: + case Success(value: T) + case NonFatalException(throwable: Throwable) + case FatalException(throwable: Throwable) diff --git a/core/src/test/scala/ox/RaceTest.scala b/core/src/test/scala/ox/RaceTest.scala index 3b56ecee..af98936c 100644 --- a/core/src/test/scala/ox/RaceTest.scala +++ b/core/src/test/scala/ox/RaceTest.scala @@ -5,8 +5,10 @@ import org.scalatest.matchers.should.Matchers import ox.* import ox.util.Trail +import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.TimeoutException import scala.concurrent.duration.DurationInt +import scala.util.control.ControlThrowable class RaceTest extends AnyFlatSpec with Matchers { "timeout" should "short-circuit a long computation" in { @@ -131,6 +133,40 @@ class RaceTest extends AnyFlatSpec with Matchers { e.getSuppressed.map(_.getMessage).toSet shouldBe Set("boom2!", "boom3!") } + it should "treat ControlThrowable as a non-fatal exception" in { + try + race( + throw new NastyControlThrowable("boom1!"), { + sleep(200.millis) + throw new NastyControlThrowable("boom2!") + }, { + sleep(200.millis) + throw new NastyControlThrowable("boom3!") + } + ) + fail("Race should throw") + catch + case e: Throwable => + e.getMessage shouldBe "boom1!" + // Suppressed exceptions are not available for ControlThrowable + } + + it should "immediately rethrow other fatal exceptions" in { + val flag = new AtomicBoolean(false) + try + race( + throw new StackOverflowError(), { + sleep(1.second) + flag.set(true) + throw new RuntimeException() + } + ) + fail("Race should throw") + catch + case e: StackOverflowError => // the expected exception + flag.get() shouldBe false // because a fatal exception was thrown, the second computation should be interrupted + } + "raceEither" should "return the first successful computation to complete" in { val trail = Trail() val start = System.currentTimeMillis() @@ -155,4 +191,48 @@ class RaceTest extends AnyFlatSpec with Matchers { trail.get shouldBe Vector("error", "slow") end - start should be < 1000L } + + "raceResult" should "immediately return when a normal exception occurs" in { + val flag = new AtomicBoolean(false) + try + raceResult( + throw new RuntimeException("boom!"), { + sleep(1.second) + flag.set(true) + } + ) + fail("raceResult should throw") + catch + case e: Throwable => + e.getMessage shouldBe "boom!" + flag.get() shouldBe false + } + + it should "immediately return when a control exception occurs" in { + val flag = new AtomicBoolean(false) + try + raceResult( + throw new NastyControlThrowable("boom!"), { + sleep(1.second) + flag.set(true) + } + ) + fail("raceResult should throw") + catch case e: NastyControlThrowable => flag.get() shouldBe false + } + + it should "immediately return when a fatal exception occurs" in { + val flag = new AtomicBoolean(false) + try + raceResult( + throw new StackOverflowError(), { + sleep(1.second) + flag.set(true) + } + ) + fail("raceResult should throw") + catch case e: StackOverflowError => flag.get() shouldBe false + } + + class NastyControlThrowable(val message: String) extends ControlThrowable(message) {} }