package com.xebialabs.deployit.engine.tasker

import java.io._
import java.nio.file.{Files, StandardCopyOption}
import javax.crypto.SealedObject

import akka.actor.{Actor, ActorLogging, Cancellable, Props}
import com.xebialabs.deployit.engine.tasker.RecoverySupervisorActor.{Delete, Deleted, Write}
import com.xebialabs.deployit.engine.tasker.TaskRecoveryActor.WriteDelayed
import com.xebialabs.deployit.security.SecretKeyHolder
import com.xebialabs.xlplatform.utils.ResourceManagement._
import org.slf4j.LoggerFactory

import scala.concurrent.duration._
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

/** Companion of [[RecoverySupervisorActor]]: its factory, its actor name and the messages it routes. */
object RecoverySupervisorActor {

  /** Actor name for the supervisor (presumably used by whoever creates it; that code is outside this file). */
  val name: String = "recovery-supervisor"

  /** Builds the [[Props]] of a supervisor whose children keep their recovery files in `recoveryDir`. */
  def props(recoveryDir: File): Props = Props(new RecoverySupervisorActor(recoveryDir))

  /** Asks for a recovery snapshot of `task` to be persisted. */
  case class Write(task: Task)

  /** Asks for the recovery file of the task identified by `taskId` to be removed. */
  case class Delete(taskId: String)

  /** Reply sent to the requester once a [[Delete]] for `taskId` has been handled. */
  case class Deleted(taskId: String)

}

/**
 * Routes recovery requests to one [[TaskRecoveryActor]] child per task. The child is named after the task id
 * and is created the first time a message for that task arrives.
 *
 * A `Write` is only routed when the task's specification is recoverable. Any other `Write` matches no case
 * and is left to Akka's unhandled-message handling.
 *
 * @param recoveryDir directory passed on to every child for its recovery file
 */
class RecoverySupervisorActor(recoveryDir: File) extends Actor with RecoveryWriter with ActorContextCreationSupport {

  override def receive: Actor.Receive = {
    case delete: Delete =>
      lookUpOrCreateTaskRecoveryActor(delete.taskId) forward delete
    case write: Write if write.task.getSpecification.isRecoverable =>
      lookUpOrCreateTaskRecoveryActor(write.task.getId) forward write
  }

  /** Returns the child responsible for `taskId`, creating it when no child of that name exists yet. */
  private def lookUpOrCreateTaskRecoveryActor(taskId: TaskId) =
    context.child(taskId) match {
      case Some(existing) => existing
      case None => createChild(TaskRecoveryActor.props(taskId, recoveryDir), taskId)
    }
}

/** Companion of [[TaskRecoveryActor]]: its factory and its self-addressed message. */
object TaskRecoveryActor {

  /** Message the actor schedules to itself to perform the deferred write of `task`. */
  case class WriteDelayed(task: Task)

  /** Builds the [[Props]] of the actor that owns the recovery file of `taskId` inside `recoveryDir`. */
  def props(taskId: TaskId, recoveryDir: File): Props = Props(new TaskRecoveryActor(taskId, recoveryDir))

}

/**
 * Owns the recovery file of a single task and coalesces bursts of `Write` requests.
 *
 * The first `Write` schedules a `WriteDelayed` one second later. Every `Write` received before that fires only
 * records the newest task, which is the one written when the delayed write runs.
 *
 * A `Delete` for this actor's task does three things: it removes the recovery file, replies with `Deleted`, and
 * stops the actor. A still-pending delayed write is cancelled when the actor stops.
 *
 * @param taskId      id of the task whose recovery file this actor manages
 * @param recoveryDir directory that holds the recovery file
 */
class TaskRecoveryActor(taskId: TaskId, recoveryDir: File) extends Actor with ActorLogging with RecoveryWriter {

  import context._

  lazy private val keyHolder: SecretKeyHolder = SecretKeyHolder.get()

  /** Handle of the scheduled `WriteDelayed`, if a write is pending. */
  private var scheduledWrite: Option[Cancellable] = None

  /** Newest task received since the pending write was scheduled. */
  private var latestTask: Option[Task] = None

  override def receive: Actor.Receive = handleWrite orElse handleDelete

  override def postStop(): Unit = {
    // Without this, a write still scheduled when the actor stops (e.g. after Delete) ends up in dead letters.
    scheduledWrite.foreach(_.cancel())
    scheduledWrite = None
    super.postStop()
  }

  def handleWrite: Actor.Receive = {
    case Write(task) =>
      implicit val dispatcher = system.dispatcher
      latestTask = Some(task)
      scheduledWrite = Some(context.system.scheduler.scheduleOnce(1.second, self, WriteDelayed(task)))
      become(handleWriteDelayed orElse handleDelete)
  }

  def handleWriteDelayed: Receive = {
    case WriteDelayed(task) =>
      // Persist the newest task seen while waiting rather than the one captured when the write was scheduled.
      val toWrite = latestTask.getOrElse(task)
      scheduledWrite = None
      latestTask = None
      writeSealedTask(toWrite, keyHolder)(recoveryDir)
      become(handleWrite orElse handleDelete)
    case Write(task) =>
      latestTask = Some(task)
      log.debug(s"Already scheduled for writing a task recovery for task ${task.getId}")
  }

  def handleDelete: Receive = {
    case Delete(`taskId`) =>
      deleteRecoveryFile(taskId)(recoveryDir)
      sender() ! Deleted(taskId)
      context.stop(self)
  }
}

/** Writes tasks to, and deletes them from, per-task recovery files (`<id>.task`) in a recovery directory. */
trait RecoveryWriter {

  private[this] val log = LoggerFactory.getLogger(getClass)

  /**
   * Serializes `task`, seals it with `keyHolder.getEncryption` and writes the sealed object to `file`.
   *
   * @return `file`
   * @throws IOException if the file could not be written or moved into place
   */
  def writeSealedTaskTo(file: File, task: Task, keyHolder: SecretKeyHolder): File = {
    writeToFile(file, new SealedObject(taskToByteArray(task), keyHolder.getEncryption))
  }

  /** Java-serializes `task` into a byte array wrapped in a [[TaskByteArrayWrapper]]. */
  private def taskToByteArray(task: Task): TaskByteArrayWrapper = {
    val taskByteStream = new ByteArrayOutputStream()
    // NOTE(review): relies on `using` closing (hence flushing) the object stream before the bytes are read.
    using(new ObjectOutputStream(taskByteStream)) { os =>
      os.writeObject(task)
    }
    new TaskByteArrayWrapper(taskByteStream.toByteArray)
  }

  /** Writes the sealed `task` to its recovery file inside the implicit `recoveryDir`. */
  def writeSealedTask(task: Task, keyHolder: SecretKeyHolder)(implicit recoveryDir: File): File = {
    writeSealedTaskTo(recoveryFile(task.getId), task, keyHolder)
  }

  /**
   * Java-serializes `any` into `file` in two steps: it writes a sibling `.tmp` file, then moves that file into
   * place with `ATOMIC_MOVE`, so `file` is only ever replaced as a whole.
   *
   * If writing or moving fails, the temporary file is removed and the failure is rethrown. `IOException`s are
   * also logged.
   *
   * @return `file`
   */
  private[tasker] def writeToFile[T](file: File, any: T): File = {
    val tmpFile = new File(s"${file.getAbsolutePath}.tmp")
    log.debug(s"Writing recovery file to: ${file.getAbsolutePath}")
    Try {
      // Separate `using`s so the file handle is closed even if the ObjectOutputStream constructor fails.
      using(new FileOutputStream(tmpFile)) { fileOut =>
        using(new ObjectOutputStream(fileOut)) { os =>
          os.writeObject(any)
        }
      }
      Files.move(tmpFile.toPath, file.toPath, StandardCopyOption.REPLACE_EXISTING, StandardCopyOption.ATOMIC_MOVE)
      file
    } match {
      case Success(written) =>
        written
      case Failure(e: IOException) =>
        deleteTmpFileQuietly(tmpFile)
        log.error(s"Could not write recovery file [$file]", e)
        throw e
      case Failure(e) =>
        deleteTmpFileQuietly(tmpFile)
        throw e
    }
  }

  /** Best-effort removal of a leftover temporary file; its own failure is only logged so it cannot hide the original error. */
  private def deleteTmpFileQuietly(tmpFile: File): Unit = {
    Try(Files.deleteIfExists(tmpFile.toPath)).failed.foreach { e =>
      log.warn(s"Could not remove temporary recovery file [$tmpFile]", e)
    }
  }

  /**
   * Deletes the recovery file of task `id`.
   *
   * @return the result of `File.delete()`, i.e. `true` only if the file existed and was removed
   */
  def deleteRecoveryFile(id: String)(implicit recoveryDir: File): Boolean = {
    log.debug(s"Deleting recovery file for [$id]")
    recoveryFile(id).delete()
  }

  /**
   * Returns the recovery file (`<id>.task`) of task `id` inside `recoveryDir`.
   *
   * The directory is created if missing; a failure to create it is logged.
   */
  def recoveryFile(id: String)(implicit recoveryDir: File): File = {
    // The isDirectory re-check tolerates the directory having been created concurrently.
    if (!recoveryDir.exists() && !recoveryDir.mkdirs() && !recoveryDir.isDirectory) {
      log.warn(s"Could not create recovery directory [$recoveryDir]")
    }
    new File(recoveryDir, s"$id.task")
  }
}

/**
 * Reads tasks back from recovery files written by [[RecoveryWriter]].
 *
 * The sealed (encrypted) format is tried first. Files holding a plain serialized task are still accepted
 * (presumably an older format) and are re-written sealed on a best-effort basis.
 */
trait TaskRecovery extends RecoveryWriter {
  private[this] val log = LoggerFactory.getLogger(getClass)

  /**
   * Recovers the task stored in `file` and marks it as recovered.
   *
   * If sealed reading fails with an `Exception`, plain reading is attempted. A plain task is returned even when
   * re-sealing its file fails. A task recovered in a final state has its recovery file removed and is not
   * returned.
   *
   * @return the recovered, non-final task; `None` if the file could not be read or the task was final
   */
  def recover(file: File)(implicit keyHolder: SecretKeyHolder): Option[Task] = {
    securedRecover(file).recoverWith {
      case ex: Exception =>
        log.error("Secured recovery failed. Trying plain recovery", ex)
        plainRecover(file).map {
          task: Task =>
            convertPlainToSealedQuietly(file, task)
            task
        }
    } match {
      case Success(task) =>
        task.recovered()
        if (task.getState.isFinal) {
          log.info(s"Task [${task.getId}] was recovered in final state [${task.getState}]. Removing recovery file.")
          if (!file.delete()) {
            log.warn(s"Could not remove recovery file [$file]")
          }
          None
        } else {
          Option(task)
        }
      case Failure(e: ClassNotFoundException) =>
        log.error(s"Could not find serialized class in recovery file [$file]", e)
        None
      case Failure(e) =>
        log.error(s"Could not read recovery file [$file]", e)
        None
    }
  }

  /** Re-writes `file` in sealed format and returns `task` unchanged. */
  def convertPlainToSealed(file: File, task: Task)(implicit keyHolder: SecretKeyHolder): Task = {
    log.info(s"Converting plain task [$file] to sealed task")
    writeSealedTaskTo(file, task, keyHolder)
    task
  }

  /** Like [[convertPlainToSealed]], but a failure is only logged, so an already-recovered task is not discarded. */
  private def convertPlainToSealedQuietly(file: File, task: Task)(implicit keyHolder: SecretKeyHolder): Unit = {
    Try(convertPlainToSealed(file, task)).failed.foreach { e =>
      log.warn(s"Could not convert plain recovery file [$file] to sealed format", e)
    }
  }

  /** Deserializes the task held by `wrapper`; fails if the bytes do not contain a [[Task]]. */
  def byteArrayToTask(wrapper: TaskByteArrayWrapper): Try[Task] = {
    Try(using(new CLObjectInputStream(new ByteArrayInputStream(wrapper.getSerializedTask))) { is =>
      is.readObject match {
        case task: Task => task
        case other => throw unexpectedContent(classOf[Task], other)
      }
    })
  }

  /** Reads a [[SealedObject]] from `file`, unseals it with `keyHolder.getDecryption` and deserializes the task inside. */
  def securedRecover(file: File)(implicit keyHolder: SecretKeyHolder): Try[Task] = {
    recover[SealedObject](file)
      .map { sealedObject =>
        sealedObject.getObject(keyHolder.getDecryption) match {
          case wrapper: TaskByteArrayWrapper => wrapper
          case other => throw unexpectedContent(classOf[TaskByteArrayWrapper], other)
        }
      }
      .flatMap(byteArrayToTask)
  }

  /** Reads a task serialized directly (unsealed) into `file`. */
  def plainRecover(file: File): Try[Task] = recover[Task](file)

  /**
   * Deserializes the single object stored in `file` and checks that it is a `T` using the `ClassTag`.
   * Without that check, the erased `asInstanceOf[T]` could yield a `Success` holding a non-`T` or `null`.
   */
  private def recover[T: ClassTag](file: File): Try[T] = {
    val expected = implicitly[ClassTag[T]].runtimeClass
    log.info(s"Recovering [${expected.getName}] from file [$file]")
    Try(using(new FileInputStream(file)) { fileIn =>
      // Separate `using`s so the file handle is closed even when the stream header cannot be read.
      using(new CLObjectInputStream(fileIn)) { is =>
        is.readObject match {
          case value: T => value
          case other => throw unexpectedContent(expected, other)
        }
      }
    })
  }

  /** Builds the error reported when a recovery stream holds something other than an instance of `expected`. */
  private def unexpectedContent(expected: Class[_], found: AnyRef): ClassCastException = {
    val foundType = Option(found).map(_.getClass.getName).getOrElse("null")
    new ClassCastException(s"Expected [${expected.getName}] in recovery stream but found [$foundType]")
  }

}

/**
 * [[ObjectInputStream]] with a class-resolution fallback.
 *
 * A class is first resolved the standard way. If that fails, the stream retries with the `loadClass` extension
 * from `ClassLoaderUtils`; which class loader that uses is defined there, not here.
 *
 * NOTE(review): this deserializes whatever the stream contains, so it is only as trustworthy as its input.
 */
class CLObjectInputStream(input: InputStream) extends ObjectInputStream(input) {
  override def resolveClass(desc: ObjectStreamClass): Class[_] = {
    import com.xebialabs.xlplatform.utils.ClassLoaderUtils._
    Try(super.resolveClass(desc)) match {
      case Success(resolved) => resolved
      case Failure(_) => desc.getName.loadClass
    }
  }
}
