package com.xebialabs.xlplatform.security.sql

import java.util

import com.xebialabs.deployit.booter.local.utils.Strings.isEmpty
import com.xebialabs.deployit.engine.api.dto.Ordering
import com.xebialabs.deployit.checks.Checks.checkArgument
import com.xebialabs.deployit.engine.api.dto.Paging
import com.xebialabs.deployit.exception.NotFoundException
import com.xebialabs.deployit.security.{Permissions, Role, RoleService}
import com.xebialabs.xlplatform.repository.sql.Database
import com.xebialabs.xlplatform.security.sql.db.Ids.toDbId
import com.xebialabs.xlplatform.security.sql.db.Tables.RolePrincipal
import com.xebialabs.xlplatform.security.sql.db.{Ids, Tables}
import grizzled.slf4j.Logging
import org.springframework.security.core.Authentication
import slick.dbio.DBIOAction.{failed, seq}
import slick.jdbc.JdbcProfile
import slick.lifted.{Query => SQuery}

import scala.collection.JavaConverters._
import scala.concurrent.ExecutionContext

class SqlRoleService(securityDatabase: Database) extends RoleService with Logging {

  import securityDatabase._

  // Shape of a Slick query over the roles table alone (used by the with* helpers and getGlobalRoles).
  type Q = SQuery[Tables.Roles, Tables.Role, Seq]

  // Shape of a Slick query over roles paired with role-principal rows (see the join in queryRoles).
  type QP = SQuery[(Tables.Roles, Tables.RolePrincipals), (Tables.Role, RolePrincipal), Seq]

  // JDBC profile read from config.databaseType; its `api` members are imported right below.
  val profile: JdbcProfile = config.databaseType.profile

  import profile.api._

  /**
    * Lists the global roles, delegating pattern filtering, ordering and paging to `getGlobalRoles`.
    * The returned roles are built by `toRole` without principals.
    */
  override def getRoles(rolePattern: String, paging: Paging, order: Ordering): util.List[Role] = {
    val globalRoleRows = getGlobalRoles(rolePattern, paging, order)
    globalRoleRows.map(row => toRole(row)).asJavaMutable()
  }

  /**
    * Lists the roles for which `isOnConfigurationItem(Option(onConfigurationItem))` holds.
    *
    * The name pattern, ordering and paging are applied through the same helpers the
    * global overload uses; each of them is a no-op when its argument is `null`
    * (and the pattern helper also when the pattern is empty).
    *
    * @param onConfigurationItem configuration item to look at; may be `null`
    * @param rolePattern         case-insensitive "contains" pattern on the role name; may be `null`
    * @param paging              page to return; may be `null`
    * @param order               sort direction on the role name; may be `null`
    * @return the matching roles, built by `toRole` without principals
    */
  override def getRoles(onConfigurationItem: String, rolePattern: String, paging: Paging, order: Ordering): util.List[Role] = {
    val rolesOnCi: Q = Tables.roles.filter(_.isOnConfigurationItem(Option(onConfigurationItem)))
    // Same pipeline order as getGlobalRoles: filter, then sort, then page.
    val query = withPaging(withOrder(withRolePattern(rolesOnCi, rolePattern), order), paging)
    runAwait(query.result).map(toRole(_)).asJavaMutable()
  }

  /**
    * Lists the global roles the given principal is assigned to; see `queryRoles`.
    */
  override def getRolesFor(principal: String, rolePattern: String, paging: Paging, order: Ordering): util.List[Role] = {
    val singlePrincipal = Seq(principal)
    queryRoles(singlePrincipal, rolePattern, paging, order).asJavaMutable()
  }

  /**
    * Lists the global roles assigned to any principal that
    * `Permissions.authenticationToPrincipals` derives from the authentication; see `queryRoles`.
    */
  override def getRolesFor(auth: Authentication, rolePattern: String, paging: Paging, order: Ordering): util.List[Role] = {
    val principalNames = Permissions.authenticationToPrincipals(auth).asScala
    queryRoles(principalNames, rolePattern, paging, order).asJavaMutable()
  }

  /**
    * Selects global roles joined to any of the given principals, then applies the optional
    * name pattern, ordering and paging, in that order.
    *
    * NOTE(review): the inner join yields one row per (role, principal) match, so a role
    * assigned to several of the given principals appears once per match - confirm that
    * callers expect this.
    *
    * @param principalNames principals whose role assignments are looked up
    * @param rolePattern    case-insensitive "contains" pattern; ignored when `null` or empty
    * @param paging         page to return; ignored when `null`
    * @param order          sort direction on the lower-cased name; ignored when `null`
    * @return the selected roles, built by `toRole` without principals
    */
  private def queryRoles(principalNames: Iterable[String], rolePattern: String, paging: Paging, order: Ordering): Seq[Role] = {
    val assignedGlobalRoles: QP = Tables.roles
      .join(Tables.rolePrincipals).on(_.id === _.roleId)
      .filter { case (role, rolePrincipal) => role.isGlobal && rolePrincipal.principalName.in(principalNames) }

    val matching: QP = Option(rolePattern).filter(_.nonEmpty) match {
      case Some(pattern) => assignedGlobalRoles.filter(_._1.name.toLowerCase.like(s"%${pattern.toLowerCase}%"))
      case None => assignedGlobalRoles
    }

    val sorted: QP = Option(order) match {
      case Some(ord) if ord.isAscending => matching.sortBy(_._1.name.toLowerCase.asc)
      case Some(_) => matching.sortBy(_._1.name.toLowerCase.desc)
      case None => matching
    }

    val paged: QP = withPaging(sorted, paging)

    // Only the role half of each joined row is fetched.
    val roleRows = runAwait(paged.map(_._1).result)
    roleRows.map(toRole(_))
  }

  /**
    * Returns the first role whose name equals `roleName`, including its principals and member roles.
    *
    * @throws NotFoundException when no role has that name
    */
  override def getRoleForRoleName(roleName: String): Role =
    queryRoleAssignments((role, _, _) => role.name === roleName)
      .headOption
      .getOrElse(throw new NotFoundException("Could not find the role [%s]", roleName))

  /**
    * Reads role assignments without naming a configuration item (passes `None` to the query).
    */
  override def readRoleAssignments(rolePattern: String, paging: Paging, order: Ordering): util.List[Role] = {
    val assignments = queryRoleAssignments(None, rolePattern, paging, order)
    assignments.asJavaMutable()
  }

  /**
    * Reads role assignments for the given configuration item (wrapped in `Some` for the query).
    */
  override def readRoleAssignments(onConfigurationItem: String, rolePattern: String, paging: Paging, order: Ordering): util.List[Role] = {
    val assignments = queryRoleAssignments(Some(onConfigurationItem), rolePattern, paging, order)
    assignments.asJavaMutable()
  }

  /**
    * Reads the roles for which `isOnConfigurationItem(onConfigurationItem)` holds, together
    * with their principals and member roles, logging the request and the result at debug level.
    *
    * Only `onConfigurationItem` restricts the result: `rolePattern`, `paging` and `order`
    * are accepted but not applied by this method.
    *
    * @param onConfigurationItem configuration item to read from; `None` is what the CI-less callers pass
    * @return the roles with their principal names and member role names
    */
  private def queryRoleAssignments(onConfigurationItem: Option[String], rolePattern: String, paging: Paging, order: Ordering): Seq[Role] = {
    logger.debug(s"Reading role assignments from [$onConfigurationItem]")

    val roleAssignments = queryRoleAssignments((role, _, _) => role.isOnConfigurationItem(onConfigurationItem))

    logger.debug(s"Read from [$onConfigurationItem] role assignments ${roleAssignments.map(formatAssignments)}")
    roleAssignments
  }

  /**
    * Loads roles with their principals and member roles.
    *
    * Roles are left-joined to role-principal rows and to role-role rows, so a role without
    * principals or member roles is still returned. The rows are then grouped per role; the
    * order of the returned list is whatever that grouping yields, not a database order.
    * Member role names are resolved with one extra query over the referenced role ids.
    *
    * @param filter predicate over the joined (role, optional principal, optional role-role) columns
    * @return one `Role` per selected role row, carrying mutable principal and member-role name lists
    */
  private def queryRoleAssignments(filter: (Tables.Roles, Rep[Option[Tables.RolePrincipals]], Rep[Option[Tables.RoleRoles]]) => Rep[Boolean]): Seq[Role] = {
    val results = runAwait {
      Tables.roles
        .joinLeft(Tables.rolePrincipals).on { case (role, rolePrincipal) => role.id === rolePrincipal.roleId }
        .joinLeft(Tables.roleRoles).on { case ((role, _), roleRole) => role.id === roleRole.roleId }
        .filter { case ((role, rolePrincipal), roleRole) => filter(role, rolePrincipal, roleRole) }
        .map { case ((role, rolePrincipal), roleRole) => (role, rolePrincipal, roleRole) }
        .result
    }

    // Skip the second query entirely when no selected role has member roles.
    val memberRoles = results.flatMap(_._3.map(_.memberRoleId)).distinct match {
      case Seq() => Seq.empty
      case roleIds => runAwait(Tables.roles.filter(_.id in roleIds).result)
    }

    // Index the member roles once instead of scanning the list for every role-role row.
    // Assumes role ids are unique among the fetched rows.
    val memberRoleNamesById = memberRoles.map(memberRole => memberRole.id -> memberRole.name).toMap

    results.groupBy(_._1)
      .map { case (role, children) =>
        // The two left joins multiply rows, hence the `distinct` on both sides.
        val roleNames = children.flatMap(_._3).distinct.flatMap(roleRole => memberRoleNamesById.get(roleRole.memberRoleId))
        val principalNames = children.flatMap(_._2).distinct.map(_.principalName)
        new Role(role.id, role.name, principalNames.asJavaMutable(), roleNames.asJavaMutable())
      }.toList
  }

  /**
    * Creates a role with the given name on the given configuration item and returns its id.
    *
    * The id is read from the `Role` instance that was just persisted rather than looked up
    * again by name: a name-only lookup is not restricted to `onConfigurationItem` and could
    * return a different role with the same name, and it costs an extra query.
    *
    * @param name                name of the new role
    * @param onConfigurationItem configuration item to attach the role to; may be `null`
    * @return the id under which the role was stored
    */
  override def create(name: String, onConfigurationItem: String): String = {
    val role = new Role(name)
    // create(...) assigns a generated id to the role when it has none, before inserting it.
    create(onConfigurationItem, role)
    role.getId
  }

  /**
    * Replaces the stored role assignments with `roles`, without naming a configuration item.
    */
  override def writeRoleAssignments(roles: util.List[Role]): Unit = {
    val updatedRoles = roles.asScala
    doWriteRoleAssignments(None, updatedRoles)
  }

  /**
    * Replaces the role assignments stored for the given configuration item with `roles`.
    */
  override def writeRoleAssignments(onConfigurationItem: String, roles: util.List[Role]): Unit = {
    val updatedRoles = roles.asScala
    doWriteRoleAssignments(Some(onConfigurationItem), updatedRoles)
  }

  /**
    * Creates the role, or updates it when a global role with the same id already exists.
    *
    * The lookup and the write run in one transaction. A role without a usable id (empty or
    * "-1") gets a generated one first, as `create(onConfigurationItem, roles*)` does, so
    * that a newly created role is never stored - or returned - with a placeholder id.
    *
    * @param role                role to store; its id may be assigned by this method
    * @param onConfigurationItem configuration item used when the role has to be created; may be `null`
    * @return the id of the stored role
    * @throws IllegalArgumentException when a different global role matches the name
    */
  override def createOrUpdateRole(role: Role, onConfigurationItem: String): String = {
    generateIdIfNecessary(role)

    val findGlobalRoleIdsByNameOrId = Tables.roles
      .filter(roleRow => roleRow.isGlobal && (roleRow.name === role.getName || roleRow.id === role.getId))
      .map(roleRow => roleRow.id)
      .result

    def createOrUpdateAction(existingIds: Seq[String]) = existingIds match {
      // A match carrying another id can only have matched on the name: the name is taken.
      case found if found.exists(_ != role.getId) =>
        failed(new IllegalArgumentException(s"A role named '${role.getName}' already exists."))
      case found if found.contains(role.getId) =>
        updateActions(Seq(role))
      case _ =>
        createActions(Seq(role), Option(onConfigurationItem))

    }

    implicit val ec: ExecutionContext = securityDatabase.database.ioExecutionContext
    runAwaitTry {
      val transaction = for {
        existingIds <- findGlobalRoleIdsByNameOrId
        _ <- createOrUpdateAction(existingIds)
      } yield role.getId
      transaction.transactionally
    }.get // rethrows the failure, e.g. the IllegalArgumentException above
  }

  /**
    * Renames the role called `name` to `newName` and returns its id.
    *
    * The role is located by name only; `onConfigurationItem` is not used by this method.
    *
    * @param name    current role name
    * @param newName name to store
    * @return the id of the renamed role
    * @throws NotFoundException when no role is called `name`
    */
  override def rename(name: String, newName: String, onConfigurationItem: String): String = {
    val roleId: String = getRoleForRoleName(name).getId
    // A Slick update is only a description until it is run; without runAwait nothing is written.
    runAwait(Tables.roles.filter(_.id === roleId).map(_.name).update(newName))
    roleId
  }

  /**
    * Deletes the role with the given name.
    *
    * @throws NotFoundException when no role has that name
    */
  override def deleteByName(name: String): Unit = {
    val roleId = getRoleForRoleName(name).getId
    delete(roleId)
  }

  /**
    * Deletes the role with the given id; see `delete`.
    */
  override def deleteById(roleId: String): Unit =
    delete(roleId)

  /**
    * Makes the stored role assignments of a configuration item equal to `updatedRoles`.
    *
    * Roles are compared with the currently stored ones through `Diff`: stored roles missing
    * from the update are deleted, unknown ones are created and the rest are updated, all in
    * a single transaction. Roles without a usable id get a generated one before comparing.
    *
    * @param onConfigurationItem configuration item whose assignments are rewritten
    * @param updatedRoles        desired roles; role names must be unique within this list
    */
  private def doWriteRoleAssignments(onConfigurationItem: Option[String], updatedRoles: Seq[Role]): Unit = {
    logger.debug(s"Writing role assignments ${updatedRoles.map(formatAssignments)} to CI [$onConfigurationItem]")

    val nameClashes = updatedRoles.groupBy(_.getName).filter { case (_, sameName) => sameName.size > 1 }
    checkArgument(nameClashes.isEmpty, s"Roles with duplicate names [${nameClashes.keys.mkString(", ")}] are not allowed")

    // Ids must be in place before the roles are put into sets for the diff.
    for (role <- updatedRoles) generateIdIfNecessary(role)

    val storedRoles = queryRoleAssignments(onConfigurationItem, null, null, null)
    val changes = Diff(storedRoles.toSet, updatedRoles.toSet)
    // Fetched once and shared by the create and update steps to resolve member role names.
    val knownGlobalRoles = Some(getGlobalRoles(null, null, null))

    val removals = deleteActions(changes.deletedEntries.map(_.getId))
    val additions = createActions(changes.newEntries, onConfigurationItem, knownGlobalRoles)
    val modifications = updateActions(changes.updatedEntries, knownGlobalRoles)

    runAwait(seq(removals, additions, modifications).transactionally)
  }

  /**
    * Creates the given roles without a configuration item (a `null` one is passed on).
    */
  def create(roles: Role*): Unit = {
    val noConfigurationItem: String = null
    create(noConfigurationItem, roles: _*)
  }

  /**
    * Creates the given roles on the given configuration item in one transaction.
    * Roles without a usable id get a generated one assigned first.
    */
  def create(onConfigurationItem: String, roles: Role*): Unit = {
    for (role <- roles) generateIdIfNecessary(role)
    val target = Option(onConfigurationItem)
    runAwait(createActions(roles, target).transactionally)
  }

  /**
    * Builds the inserts for the given roles: the role rows, their principal rows and their
    * member-role rows. An insert is only included when it has rows to write.
    *
    * @param roles               roles to insert
    * @param onConfigurationItem configuration item the role rows are attached to
    * @param globalRoles         roles used to resolve member role names; looked up when absent
    * @return the combined, not yet executed, action
    */
  private def createActions(roles: Iterable[Role], onConfigurationItem: Option[String], globalRoles: Option[Iterable[Tables.Role]] = None): DBIOAction[_, NoStream, Effect.Write] = {
    val roleRows = toDbRoles(roles, onConfigurationItem)
    val principalRows = toDbPrincipals(roles)
    val memberRoleRows = toDbRoleRoles(roles, globalRoles)

    seq(
      ifNotEmpty(roleRows)(rows => Seq(Tables.roles ++= rows)) ++
        ifNotEmpty(principalRows)(rows => Seq(Tables.rolePrincipals ++= rows)) ++
        ifNotEmpty(memberRoleRows)(rows => Seq(Tables.roleRoles ++= rows)): _*
    )
  }

  /**
    * Rewrites name, principals and member roles of the given roles in one transaction.
    */
  def update(roles: Role*): Unit = {
    val rolesToRewrite: Iterable[Role] = roles
    runAwait(updateActions(rolesToRewrite).transactionally)
  }

  /**
    * Builds the actions that rewrite everything (name, principals, member roles) for each given role.
    */
  private def updateActions(roles: Iterable[Role]): DBIOAction[_, NoStream, Effect.Write] =
    updateActions(rolesToUpdate = roles, rolesWithPrincipalsToUpdate = roles, rolesWithRolesToUpdate = roles)

  /**
    * Builds the update actions for (original, updated) role pairs, touching only what changed:
    * the name, the set of principals and/or the set of member roles.
    *
    * @param rolesDiff   pairs of the stored role and its desired state
    * @param globalRoles roles used to resolve member role names; forwarded so that they are
    *                    not queried again while the member-role rows are built
    * @return the combined, not yet executed, action
    */
  private def updateActions(rolesDiff: Iterable[(Role, Role)], globalRoles: Option[Iterable[Tables.Role]]): DBIOAction[_, NoStream, Effect.Write] = {
    val rolesToUpdate = rolesDiff.filter { case (original, updated) => original.getName != updated.getName }.map(_._2)

    val rolesWithPrincipalsToUpdate = rolesDiff.filter { case (original, updated) =>
      val diff = Diff(original.getPrincipals.asScala.toSet, updated.getPrincipals.asScala.toSet)
      diff.newEntries.nonEmpty || diff.deletedEntries.nonEmpty
    }.map(_._2)

    val rolesWithRolesToUpdate = rolesDiff.filter { case (original, updated) =>
      val diff = Diff(original.getRoles.asScala.toSet, updated.getRoles.asScala.toSet)
      diff.newEntries.nonEmpty || diff.deletedEntries.nonEmpty
    }.map(_._2)

    updateActions(rolesToUpdate, rolesWithPrincipalsToUpdate, rolesWithRolesToUpdate, globalRoles)
  }

  /**
    * Builds the low-level update actions.
    *
    * Names are updated row by row. Principals and member roles are replaced wholesale:
    * all existing rows of the affected roles are deleted and the current ones inserted.
    * Every step is left out when its input is empty.
    *
    * @param rolesToUpdate               roles whose name is written
    * @param rolesWithPrincipalsToUpdate roles whose principal rows are replaced
    * @param rolesWithRolesToUpdate      roles whose member-role rows are replaced
    * @param globalRoles                 roles used to resolve member role names; looked up when absent
    * @return the combined, not yet executed, action
    */
  private def updateActions(rolesToUpdate: Iterable[Role],
                            rolesWithPrincipalsToUpdate: Iterable[Role],
                            rolesWithRolesToUpdate: Iterable[Role],
                            globalRoles: Option[Iterable[Tables.Role]] = None): DBIOAction[_, NoStream, Effect.Write] =
    seq(
      ifNotEmpty(rolesToUpdate)(renamed => renamed.map(role => Tables.roles.filter(_.id === role.getId).map(_.name).update(role.getName)).toSeq) ++

        ifNotEmpty(rolesWithPrincipalsToUpdate)(changed => Seq(Tables.rolePrincipals.filter(_.roleId in changed.map(_.getId)).delete)) ++
        ifNotEmpty(rolesWithPrincipalsToUpdate)(changed => Seq(Tables.rolePrincipals ++= toDbPrincipals(changed))) ++

        ifNotEmpty(rolesWithRolesToUpdate)(changed => Seq(Tables.roleRoles.filter(_.roleId in changed.map(_.getId)).delete)) ++
        ifNotEmpty(rolesWithRolesToUpdate)(changed => Seq(Tables.roleRoles ++= toDbRoleRoles(changed, globalRoles))): _*
    )

  /**
    * Deletes the roles with the given ids, and the rows referring to them, in one transaction.
    */
  def delete(roleIds: String*): Unit = {
    val deletions = deleteActions(roleIds)
    runAwait(deletions.transactionally)
  }

  /**
    * Builds the deletes for the given role ids: permissions, principals and role-role links
    * (in either direction) are removed before the role rows themselves. Yields an empty
    * action sequence when no ids are given.
    */
  private def deleteActions(roleIds: Iterable[String]): DBIOAction[Unit, NoStream, Effect.Write] = {
    val deletions = ifNotEmpty(roleIds) { ids =>
      Seq(
        Tables.rolePermissions.filter(_.roleId in ids).delete,
        Tables.rolePrincipals.filter(_.roleId in ids).delete,
        Tables.roleRoles.filter(link => link.memberRoleId.in(ids) || link.roleId.in(ids)).delete,
        Tables.roles.filter(_.id in ids).delete
      )
    }
    seq(deletions: _*)
  }

  /**
    * Assigns a generated id to the role when its id is empty or the placeholder "-1".
    */
  private def generateIdIfNecessary(role: Role) = {
    val currentId = role.getId
    val needsId = isEmpty(currentId) || currentId == "-1"
    if (needsId) {
      role.setId(Ids.generate())
    }
  }

  /**
    * Converts the member role names of each role into role-role rows.
    *
    * @param roles          roles whose member roles are converted
    * @param globalRolesOpt roles used to translate names into ids; when absent the global
    *                       roles are queried, unless `roles` is empty
    * @throws NotFoundException when a member role name cannot be resolved
    */
  private def toDbRoleRoles(roles: Iterable[Role], globalRolesOpt: Option[Iterable[Tables.Role]] = None) = {
    val knownRoles = globalRolesOpt match {
      case Some(supplied) => supplied
      case None if roles.isEmpty => Seq.empty
      case None => getGlobalRoles(null, null, null)
    }

    for {
      owner <- roles
      memberName <- owner.getRoles.asScala
    } yield Tables.RoleRole(owner.getId, roleNameToId(memberName, knownRoles))
  }

  /**
    * Restricts the query to roles whose lower-cased name contains the lower-cased pattern.
    * A `null` or empty pattern leaves the query untouched.
    */
  private def withRolePattern(query: Q, rolePattern: String): Q =
    Option(rolePattern).filter(_.nonEmpty).fold(query) { pattern =>
      query.filter(_.name.toLowerCase.like(s"%${pattern.toLowerCase}%"))
    }

  /**
    * Sorts the query by lower-cased role name in the requested direction.
    * A `null` order leaves the query untouched.
    */
  private def withOrder(query: Q, order: Ordering): Q =
    Option(order).map(_.isAscending) match {
      case Some(true) => query.sortBy(_.name.toLowerCase.asc)
      case Some(false) => query.sortBy(_.name.toLowerCase.desc)
      case None => query
    }

  /**
    * Limits the query to the requested page; page numbers start at 1.
    * A `null` paging leaves the query untouched.
    */
  private def withPaging[A, B](query: SQuery[A, B, Seq], paging: Paging): SQuery[A, B, Seq] =
    Option(paging).fold(query) { p =>
      val rowsToSkip = (p.page - 1) * p.resultsPerPage
      query.drop(rowsToSkip).take(p.resultsPerPage)
    }

  /**
    * Fetches the global role rows, applying the optional name pattern, then the ordering,
    * then the paging. Each of the three may be `null`.
    */
  private def getGlobalRoles(rolePattern: String, paging: Paging, order: Ordering): Seq[Tables.Role] = {
    val allGlobal: Q = Tables.roles.filter(_.isGlobal)
    val matching: Q = withRolePattern(allGlobal, rolePattern)
    val sorted: Q = withOrder(matching, order)
    val paged: Q = withPaging(sorted, paging)
    runAwait(paged.result)
  }

  /**
    * Converts roles into role rows attached to the given configuration item.
    */
  private def toDbRoles(roles: Iterable[Role], onConfigurationItem: Option[String]) =
    for (role <- roles) yield Tables.Role(role.getId, role.getName, toDbId(onConfigurationItem))

  /**
    * Converts the principals of each role into role-principal rows, one row per (role, principal).
    */
  private def toDbPrincipals(roles: Iterable[Role]) =
    for {
      role <- roles
      principal <- role.getPrincipals.asScala
    } yield Tables.RolePrincipal(role.getId, principal)

  /**
    * Returns the id of the first role in `roles` with the given name.
    *
    * @throws NotFoundException when none of the roles has that name
    */
  private def roleNameToId(roleName: String, roles: Iterable[Tables.Role]): String =
    roles
      .collectFirst { case candidate if candidate.name == roleName => candidate.id }
      .getOrElse(throw new NotFoundException("Role [%s] not found", roleName))

  /**
    * Converts a role row, and optionally its principal rows, into a `Role`.
    *
    * The principal list is handed over through `asJavaMutable()`, as everywhere else in this
    * class, instead of the unmodifiable `asJava` view, so callers get the same kind of list
    * whichever method produced the role.
    *
    * @param role       role row to convert
    * @param principals principal rows of the role; none by default
    */
  private def toRole(role: Tables.Role, principals: Seq[Tables.RolePrincipal] = Seq.empty): Role =
    new Role(role.id, role.name, principals.map(_.principalName).asJavaMutable())

  /**
    * Renders a role with its principals and member roles for log messages.
    */
  private def formatAssignments(role: Role) = {
    val principals = role.getPrincipals
    val memberRoles = role.getRoles
    s"${role.getName} -> PRINCIPALS=$principals, ROLES=$memberRoles"
  }

  /**
    * Applies `mapper` to `items` unless `items` is empty, in which case nothing is produced
    * and `mapper` is not called.
    */
  private def ifNotEmpty[T, R](items: Iterable[T])(mapper: Iterable[T] => Seq[R]): Seq[R] =
    if (items.nonEmpty) mapper(items) else Seq.empty

  /**
    * Counts the roles that satisfy `isGlobal`, have the given numeric CI id and match the
    * optional name pattern.
    */
  def countRoles(onConfigurationItem: Number, rolePattern: String): Long = {
    val ciId = onConfigurationItem.intValue()
    val candidates = Tables.roles.filter(_.isGlobal).filter(_.ciId === ciId)
    val matchCount = withRolePattern(candidates, rolePattern).length
    runAwait(matchCount.result).toLong
  }

  /**
    * Tells whether at least one role with exactly this name is stored.
    */
  def roleExists(roleName: String): Boolean = {
    val anyRoleWithName = Tables.roles.filter(_.name === roleName).exists
    runAwait(anyRoleWithName.result)
  }

  /**
    * Returns the number of roles for the specified CI path that match the optional name
    * pattern. Only roles that also satisfy `isGlobal` are counted.
    */
  def countRoles(onConfigurationItem: String, rolePattern: String): Long = {
    val ciPath = Option(onConfigurationItem)
    val candidates = Tables.roles.filter(_.isGlobal).filter(_.isOnConfigurationItem(ciPath))
    val matchCount = withRolePattern(candidates, rolePattern).length
    runAwait(matchCount.result).toLong
  }

}

/**
  * Difference between two sets, based on element equality.
  *
  * @param original the state before the change
  * @param updated  the state after the change
  */
case class Diff[T](original: Set[T], updated: Set[T]) {
  /** Elements present only in `updated`. */
  lazy val newEntries: Set[T] = updated.diff(original)

  /** Pairs (element from `original`, equal element from `updated`) for everything present in both. */
  lazy val updatedEntries: Set[(T, T)] =
    updated.flatMap { after =>
      original.collect { case before if after == before => (before, after) }
    }

  /** Elements present only in `original`. */
  lazy val deletedEntries: Set[T] = original.diff(updated)
}