package com.xebialabs.xlrelease.stress.handlers

import java.util.concurrent.atomic.AtomicBoolean

import cats.effect._
import cats.implicits._
import com.xebialabs.xlrelease.stress.api
import org.joda.time.DateTime

import scala.concurrent.{CancellationException, ExecutionContext}
import scala.concurrent.duration._

/**
 * `IO`-based implementation of the control primitives (`api.control.Control`) and flow
 * combinators (`api.control.Flow`).
 *
 * Context shifts, fibers and timers all use the supplied `ExecutionContext`. Once [[shutdown]]
 * has run, every step guarded by `ifNotShuttingDown` fails with a `CancellationException`
 * instead of running. The guarded combinators are `repeat`, `parallel`, `sequenced`,
 * `concurrently` and `until`. `backgroundOf` also stops its background loop on shutdown.
 *
 * @param ec execution context used for context shifts, fibers and timers
 */
class ControlHandler()(implicit ec: ExecutionContext) extends api.control.Control with api.control.Flow {
  private implicit val cs: ContextShift[IO] = IO.contextShift(ec)
  private implicit val timer: Timer[IO] = IO.timer(ec)

  /** Set by [[shutdown]]; nothing in this class ever resets it. */
  private val isShuttingDown: AtomicBoolean = new AtomicBoolean(false)

  /**
   * Guards `p` with the shutdown flag.
   *
   * The flag is read when the returned IO *runs*, not when it is built. Programs assembled before
   * [[shutdown]] (e.g. the steps of `repeat` or `parallel`) therefore still observe it.
   */
  private def ifNotShuttingDown[A](p: => IO[A]): IO[A] = IO.suspend {
    if (isShuttingDown.get) IO.raiseError(new CancellationException("Shutting down")) else p
  }

  /** Raises the shutdown flag; subsequently started guarded steps fail with `CancellationException`. */
  override def shutdown(): IO[Unit] = IO(isShuttingDown.set(true))

  override def nop: IO[Unit] = IO.pure(())

  override def ok[A](a: A): IO[A] = IO.pure(a)

  /** Fails with a `RuntimeException` carrying `msg` and (optionally) `cause`. */
  override def fail[A](msg: => String, cause: => Throwable = null): IO[A] = IO.raiseError[A](new RuntimeException(msg, cause))

  override def error[A](err: => Throwable): IO[A] = IO.raiseError[A](err)

  override def sleep(duration: FiniteDuration): IO[Unit] =
    IO.sleep(duration)

  /**
   * Runs `left` and `right` concurrently and pairs their results.
   *
   * If either side fails, the other is cancelled and the error is propagated. Neither side is
   * left running as a detached fiber.
   */
  override def fork[A, B](left: IO[A], right: IO[B]): IO[(A, B)] =
    (IO.shift *> left, IO.shift *> right).parMapN((a, b) => (a, b))

  /**
   * Runs `foreground` while repeatedly running `background` on a separate fiber.
   *
   * The background loop checks its stop condition before each iteration. It stops after the
   * iteration in progress once one of these happens:
   *  - `foreground` completes, fails or is cancelled;
   *  - [[shutdown]] has been called.
   *
   * @return the foreground result and the background results, most recent first
   */
  override def backgroundOf[A, B](foreground: IO[A])(background: IO[B]): IO[(A, List[B])] = {
    val fg: IO[A] = IO.shift *> foreground
    val stop = new AtomicBoolean(false)

    def bgLoop(collected: List[B]): IO[List[B]] = IO.suspend {
      if (isShuttingDown.get() || stop.get()) IO.pure(collected)
      else IO.cancelBoundary *> background.flatMap(b => bgLoop(b :: collected))
    }

    for {
      fa <- IO.shift *> fg.start
      fb <- IO.shift *> bgLoop(List.empty[B]).start
      // `guarantee` makes sure the background loop is told to stop even when the foreground fails
      // or is cancelled; otherwise that fiber would run forever.
      a <- fa.join.guarantee(IO(stop.set(true)))
      bs <- fb.join
    } yield (a, bs)
  }

  /** Runs `program` `n` times, one after another, each run guarded by the shutdown flag. */
  override def repeat[A](n: Int)(program: IO[A]): IO[List[A]] =
    List.fill(n)(ifNotShuttingDown(program)).sequence

  /**
   * Runs `p(0)`, ..., `p(n - 1)` concurrently, each guarded by the shutdown flag.
   *
   * On the first failure the remaining tasks are cancelled and the error is propagated.
   */
  override def parallel[A](n: Int)(p: Int => IO[A]): IO[List[A]] =
    (0 until n).toList.parTraverse(i => IO.shift *> ifNotShuttingDown(p(i)))

  /**
   * Runs `p(start)`, ..., `p(n - 1)` one after another, each guarded by the shutdown flag.
   *
   * Note that `n` is an exclusive upper bound, not a count.
   */
  override def sequenced[A](n: Int, start: Int = 0)(p: Int => IO[A]): IO[List[A]] =
    (start until n).toList.traverse(i => ifNotShuttingDown(p(i)))

  /**
   * Drains `queue` with `n` concurrent workers.
   *
   * Each worker repeatedly takes the next unclaimed job, until the queue is exhausted. Each inner
   * list holds one worker's results, in the order that worker ran its jobs.
   *
   * NOTE: the queue's iterator is created when this method is called, so running the returned IO
   * a second time finds no jobs left.
   */
  override def concurrently[A](n: Int)(queue: Iterable[IO[A]]): IO[List[List[A]]] = {
    if (n <= 0) {
      IO.raiseError(new IllegalArgumentException("concurrently: first argument must be >= 1"))
    } else {
      val pending = queue.iterator
      // Private lock: guards only this call's iterator instead of the whole handler.
      val lock = new Object

      val getNext: IO[Option[IO[A]]] = IO {
        lock.synchronized {
          if (pending.hasNext) {
            // Must be taken inside the lock: ifNotShuttingDown is by-name and would otherwise
            // evaluate `next()` later, outside the lock.
            val job = pending.next()
            Some(ifNotShuttingDown(job))
          } else None
        }
      }

      // Accumulate newest-first and reverse exactly once, so the order never depends on the job count.
      def runJobs(done: List[A]): IO[List[A]] =
        getNext.flatMap {
          case None => IO.pure(done.reverse)
          case Some(job) => job.flatMap(a => runJobs(a :: done))
        }

      parallel[List[A]](n)(_ => runJobs(List.empty[A]))
    }
  }

  override def now(): IO[DateTime] =
    IO(DateTime.now)

  /**
   * Runs `p` and reports how long it took, with millisecond resolution.
   *
   * Uses the monotonic clock, so wall-clock adjustments cannot produce wrong or negative durations.
   */
  override def time[A](p: IO[A]): IO[(FiniteDuration, A)] =
    for {
      start <- timer.clock.monotonic(MILLISECONDS)
      result <- p
      end <- timer.clock.monotonic(MILLISECONDS)
    } yield FiniteDuration(end - start, MILLISECONDS) -> result

  /**
   * Polls `get` until `cond` holds, sleeping `interval` between attempts.
   *
   * @param retries maximum number of attempts; `None` means poll indefinitely. Must be positive
   *                when given (checked eagerly; violations throw `IllegalArgumentException`).
   * @return the first value satisfying `cond`; fails with a `RuntimeException` once the allowed
   *         attempts are used up
   */
  override def until[A](cond: A => Boolean, interval: FiniteDuration, retries: Option[Int])
                       (get: IO[A]): IO[A] = {
    require(retries.forall(_ > 0), "until: retries must be positive when specified")

    def loop(attempt: Int): IO[A] = retries match {
      case Some(max) if attempt > max =>
        IO.raiseError(new RuntimeException(s"Exceeded allowed $max retries. Giving up"))
      case _ =>
        ifNotShuttingDown(get).flatMap { found =>
          if (cond(found)) IO.pure(found)
          else sleep(interval).flatMap(_ => loop(attempt + 1))
        }
    }

    loop(attempt = 1)
  }
}
