package com.xebialabs.xlplatform.security.sql

import java.util

import com.xebialabs.deployit.booter.local.utils.Strings.isEmpty
import com.xebialabs.deployit.checks.Checks.checkArgument
import com.xebialabs.deployit.exception.NotFoundException
import com.xebialabs.deployit.security.{Permissions, Role, RoleService}
import com.xebialabs.xlplatform.repository.sql.Database
import com.xebialabs.xlplatform.security.sql.db.Ids.toDbId
import com.xebialabs.xlplatform.security.sql.db.{Ids, Tables}
import grizzled.slf4j.Logging
import org.springframework.security.core.Authentication
import slick.dbio.DBIOAction.seq
import slick.jdbc.JdbcProfile

import scala.collection.JavaConverters._

/**
 * [[RoleService]] backed by the SQL security database, accessed through Slick.
 *
 * Roles are stored in `Tables.roles`, their principals in `Tables.rolePrincipals` and role-to-role memberships in
 * `Tables.roleRoles`. Member roles are referenced by name in [[Role]] and resolved to ids against global roles.
 *
 * @param securityDatabase database whose members (`config`, `runAwait`, ...) are imported into this class
 */
class SqlRoleService(securityDatabase: Database) extends RoleService with Logging {

  import securityDatabase._

  /** Slick profile for the configured database type; its `api` is imported below for query building. */
  val profile: JdbcProfile = config.databaseType.profile

  import profile.api._

  /** Returns every global role. Principals are not loaded: each returned role has an empty principal list. */
  override def getRoles: util.List[Role] = getGlobalRoles().map(toRole(_)).asJavaMutable()

  /**
   * Returns the roles selected by `isOnConfigurationItem` for the given configuration item id
   * (a `null` id becomes `None`). Principals are not loaded.
   */
  override def getRoles(onConfigurationItem: String): util.List[Role] =
    runAwait(Tables.roles.filter(_.isOnConfigurationItem(Option(onConfigurationItem))).result).map(toRole(_)).asJavaMutable()

  /** Returns the global roles assigned to `principal`. Principals are not loaded. */
  override def getRolesFor(principal: String): util.List[Role] = getRolesFor(Seq(principal)).asJavaMutable()

  /**
   * Returns the global roles assigned to any principal that `Permissions.authenticationToPrincipals` derives from
   * `auth`. Principals are not loaded.
   */
  override def getRolesFor(auth: Authentication): util.List[Role] = getRolesFor(Permissions.authenticationToPrincipals(auth).asScala).asJavaMutable()

  /**
   * Returns the global roles that have at least one of `principalNames` assigned.
   *
   * Each role is returned once, even when several of the given principals are assigned to it.
   */
  private def getRolesFor(principalNames: Iterable[String]): Seq[Role] =
    if (principalNames.isEmpty) {
      // Nothing can match; skip the query so `in` is never given an empty collection (as elsewhere in this class).
      Seq.empty
    } else {
      runAwait(
        Tables.roles.join(Tables.rolePrincipals).on(_.id === _.roleId).filter {
          case (role, rolePrincipal) => role.isGlobal && rolePrincipal.principalName.in(principalNames)
        }.map(_._1).result
      ).distinct.map(toRole(_)) // the join yields one row per matching principal, hence the de-duplication
    }

  /**
   * Returns the role called `roleName`, with its principals and member role names.
   * If several roles share that name, the first one produced by the query is returned.
   *
   * @throws NotFoundException if no role has that name
   */
  override def getRoleForRoleName(roleName: String): Role = queryRoleAssignments((role, _, _) => role.name === roleName).toList match {
    case role :: _ => role
    case Nil => throw new NotFoundException("Could not find the role [%s]", roleName)
  }

  /** Returns the roles selected by `isOnConfigurationItem(None)`, with principals and member role names. */
  override def readRoleAssignments(): util.List[Role] = readRoleAssignments(None).asJavaMutable()

  /** Returns the roles defined on `onConfigurationItem`, with principals and member role names. */
  override def readRoleAssignments(onConfigurationItem: String): util.List[Role] = readRoleAssignments(Some(onConfigurationItem)).asJavaMutable()

  /** Loads the roles selected by `isOnConfigurationItem(onConfigurationItem)`, logging before and after at debug level. */
  private def readRoleAssignments(onConfigurationItem: Option[String]): Seq[Role] = {
    logger.debug(s"Reading role assignments from [$onConfigurationItem]")

    val roleAssignments = queryRoleAssignments((role, _, _) => role.isOnConfigurationItem(onConfigurationItem))

    logger.debug(s"Read from [$onConfigurationItem] role assignments ${roleAssignments.map(formatAssignments)}")
    roleAssignments
  }

  /**
   * Loads the roles matching `filter` together with their principals and member role names.
   *
   * Roles are left-joined with their principals and role-role links, so one role can come back in several rows
   * (one per principal/link combination). Rows are grouped per role and de-duplicated here. Member role ids are
   * resolved to names with one extra query.
   *
   * @param filter predicate over a role row and its optional principal and role-role rows
   * @return one [[Role]] per matching role, in no particular order (results are grouped through a `Map`)
   */
  private def queryRoleAssignments(filter: (Tables.Roles, Rep[Option[Tables.RolePrincipals]], Rep[Option[Tables.RoleRoles]]) => Rep[Boolean]): Seq[Role] = {
    val results = runAwait {
      Tables.roles
        .joinLeft(Tables.rolePrincipals).on { case (role, rolePrincipal) => role.id === rolePrincipal.roleId }
        .joinLeft(Tables.roleRoles).on { case ((role, _), roleRole) => role.id === roleRole.roleId }
        .filter { case ((role, rolePrincipal), roleRole) => filter(role, rolePrincipal, roleRole) }
        .map { case ((role, rolePrincipal), roleRole) => (role, rolePrincipal, roleRole) }
        .result
    }

    // Fetch the member roles once; skip the query when there are none so `in` never gets an empty collection.
    val memberRoleIds = results.flatMap(_._3.map(_.memberRoleId)).distinct
    val memberRoles = if (memberRoleIds.isEmpty) Seq.empty else runAwait(Tables.roles.filter(_.id in memberRoleIds).result)
    // Index by id so each link is resolved with a map lookup rather than a scan over all member roles.
    val memberRoleNamesById = memberRoles.map(memberRole => memberRole.id -> memberRole.name).toMap

    results.groupBy(_._1)
      .map { case (role, children) =>
        val roleNames = children.flatMap(_._3).distinct.flatMap(roleRole => memberRoleNamesById.get(roleRole.memberRoleId))
        val principalNames = children.flatMap(_._2).distinct.map(_.principalName)
        new Role(role.id, role.name, principalNames.asJavaMutable(), roleNames.asJavaMutable())
      }.toList
  }

  /** Replaces the role assignments read by `readRoleAssignments()` with `roles`. */
  override def writeRoleAssignments(roles: util.List[Role]): Unit = writeRoleAssignments(None, roles.asScala)

  /** Replaces the role assignments of `onConfigurationItem` with `roles`. */
  override def writeRoleAssignments(onConfigurationItem: String, roles: util.List[Role]): Unit = writeRoleAssignments(Some(onConfigurationItem), roles.asScala)

  /**
   * Makes the stored roles of `onConfigurationItem` equal to `updatedRoles`, in one transaction.
   *
   * The given roles are diffed against the stored ones. Stored roles that are missing from `updatedRoles` are
   * deleted, and roles not yet stored are created. Roles present in both are matched by [[Role]] equality; their
   * name, principals and member roles are rewritten only where they changed. Roles without an id (empty or "-1")
   * are assigned a generated id in place.
   *
   * @throws NotFoundException if a member role name cannot be resolved
   */
  private def writeRoleAssignments(onConfigurationItem: Option[String], updatedRoles: Seq[Role]): Unit = {
    logger.debug(s"Writing role assignments ${updatedRoles.map(formatAssignments)} to CI [$onConfigurationItem]")

    val duplicateRoles = updatedRoles.groupBy(_.getName).filter(_._2.size > 1)
    checkArgument(duplicateRoles.isEmpty, s"Roles with duplicate names [${duplicateRoles.keys.mkString(", ")}] are not allowed")

    updatedRoles.foreach(generateIdIfNecessary)

    val originalRoles = readRoleAssignments(onConfigurationItem)
    val rolesDiff = Diff(originalRoles.toSet, updatedRoles.toSet)

    // Member role names are resolved against global roles. Without a CI, this call replaces the whole set read by
    // readRoleAssignments(None) (presumably the global roles), so after the write the global roles are exactly
    // `updatedRoles`. Resolving against them rather than the current database contents has two effects: roles
    // created or renamed in this same call can be referenced, and references to roles this call deletes are
    // rejected instead of being linked to ids that are about to disappear.
    val globalRoles: Option[Iterable[Tables.Role]] =
      if (onConfigurationItem.isEmpty) Some(toDbRoles(updatedRoles, None)) else Some(getGlobalRoles())

    runAwait(
      seq(
        deleteActions(rolesDiff.deletedEntries.map(_.getId)),
        createActions(rolesDiff.newEntries, onConfigurationItem, globalRoles),
        updateActions(rolesDiff.updatedEntries, globalRoles)
      ).transactionally
    )
  }

  /** Inserts `roles` without a configuration item; see the two-argument overload. */
  def create(roles: Role*): Unit = create(null, roles: _*)

  /**
   * Inserts `roles` with their principals and member roles on `onConfigurationItem`, in one transaction.
   * A `null` configuration item becomes `None`. Roles without an id are assigned a generated one in place.
   */
  def create(onConfigurationItem: String, roles: Role*): Unit = {
    roles.foreach(generateIdIfNecessary)
    runAwait(createActions(roles, Option(onConfigurationItem)).transactionally)
  }

  /**
   * Builds the inserts for `roles`, their principals and their member-role links, skipping empty batches.
   * Roles are inserted before the links that reference them.
   *
   * @param globalRoles roles to resolve member role names against. When `None`, the global roles are read from
   *                    the database (only if `roles` is non-empty).
   * @throws NotFoundException if a member role name cannot be resolved (while building the action, before it runs)
   */
  private def createActions(roles: Iterable[Role], onConfigurationItem: Option[String], globalRoles: Option[Iterable[Tables.Role]] = None): DBIOAction[_, NoStream, Effect.Write] = {
    val dbRoles = toDbRoles(roles, onConfigurationItem)
    val dbPrincipals = toDbPrincipals(roles)
    val dbRoleRoles = toDbRoleRoles(roles, globalRoles)

    seq(
      ifNotEmpty(dbRoles)(items => Seq(Tables.roles ++= items)) ++
        ifNotEmpty(dbPrincipals)(items => Seq(Tables.rolePrincipals ++= items)) ++
        ifNotEmpty(dbRoleRoles)(items => Seq(Tables.roleRoles ++= items)): _*
    )
  }

  /** Rewrites the name, principals and member roles of every given role, in one transaction. */
  def update(roles: Role*): Unit = runAwait(updateActions(roles).transactionally)

  /** Builds an update of name, principals and member roles for all of `roles`. */
  private def updateActions(roles: Iterable[Role]): DBIOAction[_, NoStream, Effect.Write] = updateActions(roles, roles, roles)

  /**
   * Builds the updates for `(original, updated)` role pairs, touching only what changed.
   *
   * A name is rewritten when it differs. Principals or member roles are rewritten when their sets differ.
   */
  private def updateActions(rolesDiff: Iterable[(Role, Role)], globalRoles: Option[Iterable[Tables.Role]]): DBIOAction[_, NoStream, Effect.Write] = {
    val rolesToUpdate = rolesDiff.filter { case (original, updated) => original.getName != updated.getName }.map(_._2)

    val rolesWithPrincipalsToUpdate = rolesDiff.filter { case (original, updated) =>
      val diff = Diff(original.getPrincipals.asScala.toSet, updated.getPrincipals.asScala.toSet)
      diff.newEntries.nonEmpty || diff.deletedEntries.nonEmpty
    }.map(_._2)

    val rolesWithRolesToUpdate = rolesDiff.filter { case (original, updated) =>
      val diff = Diff(original.getRoles.asScala.toSet, updated.getRoles.asScala.toSet)
      diff.newEntries.nonEmpty || diff.deletedEntries.nonEmpty
    }.map(_._2)

    updateActions(rolesToUpdate, rolesWithPrincipalsToUpdate, rolesWithRolesToUpdate)
  }

  /**
   * Builds the updates proper.
   *
   * Names are updated row by row. Principals and member-role links are replaced wholesale: the existing rows of
   * the affected roles are deleted, then the current ones are inserted.
   *
   * @param globalRoles roles to resolve member role names against; `None` reads them from the database
   */
  private def updateActions(rolesToUpdate: Iterable[Role],
                            rolesWithPrincipalsToUpdate: Iterable[Role],
                            rolesWithRolesToUpdate: Iterable[Role],
                            globalRoles: Option[Iterable[Tables.Role]] = None): DBIOAction[_, NoStream, Effect.Write] = {
    seq(
      ifNotEmpty(rolesToUpdate)(_.map(r => Tables.roles.filter(_.id === r.getId).map(_.name).update(r.getName)).toSeq) ++

        ifNotEmpty(rolesWithPrincipalsToUpdate)(items => Seq(Tables.rolePrincipals.filter(_.roleId in items.map(_.getId)).delete)) ++
        ifNotEmpty(rolesWithPrincipalsToUpdate)(items => Seq(Tables.rolePrincipals ++= toDbPrincipals(items))) ++

        ifNotEmpty(rolesWithRolesToUpdate)(items => Seq(Tables.roleRoles.filter(_.roleId in items.map(_.getId)).delete)) ++
        ifNotEmpty(rolesWithRolesToUpdate)(items => Seq(Tables.roleRoles ++= toDbRoleRoles(items, globalRoles))): _*
    )
  }

  /** Deletes the roles with the given ids and everything attached to them, in one transaction. */
  def delete(roleIds: String*): Unit = runAwait(deleteActions(roleIds).transactionally)

  /**
   * Builds the deletes for the given role ids. Permissions, principals and role-role links are removed first;
   * links are removed in both directions, whether the role is the member or the container. The roles go last.
   */
  private def deleteActions(roleIds: Iterable[String]): DBIOAction[Unit, NoStream, Effect.Write] = seq(
    ifNotEmpty(roleIds)(roleIds => Seq(Tables.rolePermissions.filter(_.roleId in roleIds).delete)) ++
      ifNotEmpty(roleIds)(roleIds => Seq(Tables.rolePrincipals.filter(_.roleId in roleIds).delete)) ++
      ifNotEmpty(roleIds)(roleIds => Seq(Tables.roleRoles.filter(roleRole => roleRole.memberRoleId.in(roleIds) || roleRole.roleId.in(roleIds)).delete)) ++
      ifNotEmpty(roleIds)(roleIds => Seq(Tables.roles.filter(_.id in roleIds).delete)): _*
  )

  /** Assigns a freshly generated id to `role` when its id is empty or the placeholder "-1". */
  private def generateIdIfNecessary(role: Role) = {
    if (isEmpty(role.getId) || role.getId == "-1") {
      role.setId(Ids.generate())
    }
  }

  /**
   * Converts the member role names of `roles` to role-role rows, resolving names to ids.
   *
   * @param globalRolesOpt roles to resolve names against. When `None`, the global roles are read from the
   *                       database, but only if `roles` is non-empty.
   * @throws NotFoundException if a member role name is not found
   */
  private def toDbRoleRoles(roles: Iterable[Role], globalRolesOpt: Option[Iterable[Tables.Role]] = None) = {
    val globalRoles = globalRolesOpt.getOrElse(if (roles.isEmpty) Seq.empty else getGlobalRoles())

    roles.flatMap(r => r.getRoles.asScala.map(roleName => Tables.RoleRole(r.getId, roleNameToId(roleName, globalRoles))))
  }

  /** Reads all roles for which `isGlobal` holds. */
  private def getGlobalRoles(): Seq[Tables.Role] = runAwait(Tables.roles.filter(_.isGlobal).result)

  /** Converts `roles` to role rows attached to `onConfigurationItem`. */
  private def toDbRoles(roles: Iterable[Role], onConfigurationItem: Option[String]) =
    roles.map(r => Tables.Role(r.getId, r.getName, toDbId(onConfigurationItem)))

  /** Converts the principals of `roles` to role-principal rows. */
  private def toDbPrincipals(roles: Iterable[Role]) =
    roles.flatMap(r => r.getPrincipals.asScala.map(Tables.RolePrincipal(r.getId, _)))

  /**
   * Returns the id of the role called `roleName` among `roles`.
   *
   * @throws NotFoundException if no such role exists
   */
  private def roleNameToId(roleName: String, roles: Iterable[Tables.Role]): String = roles.find(_.name == roleName) match {
    case Some(role) => role.id
    case None => throw new NotFoundException("Role [%s] not found", roleName)
  }

  /**
   * Converts a role row to a [[Role]] with the given principals (none by default).
   * The principal list is mutable, like that of every other [[Role]] built in this class.
   */
  private def toRole(role: Tables.Role, principals: Seq[Tables.RolePrincipal] = Seq.empty): Role =
    new Role(role.id, role.name, principals.map(_.principalName).asJavaMutable())

  /** Renders a role's name, principals and member roles for debug logging. */
  private def formatAssignments(role: Role) = s"${role.getName} -> PRINCIPALS=${role.getPrincipals}, ROLES=${role.getRoles}"

  /**
   * Returns `mapper(items)`, or an empty sequence without calling `mapper` when `items` is empty.
   * Used to avoid issuing statements for empty batches.
   */
  private def ifNotEmpty[T, R](items: Iterable[T])(mapper: Iterable[T] => Seq[R]): Seq[R] = if (items.isEmpty) Seq.empty else mapper(items)
}

/**
 * Difference between an `original` and an `updated` set of entries.
 *
 * Entries are matched by `==`/`hashCode`. For types whose equality is based on an identity key, an entry present
 * in both sets may therefore carry different data in each; [[updatedEntries]] exposes both versions.
 *
 * @param original entries before the change
 * @param updated  entries after the change
 * @tparam T entry type
 */
case class Diff[T](original: Set[T], updated: Set[T]) {

  /** Entries present in `updated` but not in `original`. */
  lazy val newEntries: Set[T] = updated -- original

  /** Entries present in both sets, as `(originalEntry, updatedEntry)` pairs. */
  lazy val updatedEntries: Set[(T, T)] = {
    // Index the original entries by themselves. Looking up an updated entry then returns the *original* instance
    // equal to it in one hash lookup, instead of scanning all of `original` for every updated entry.
    val originalByKey: Map[T, T] = original.iterator.map(o => o -> o).toMap
    updated.flatMap(u => originalByKey.get(u).map(o => (o, u)))
  }

  /** Entries present in `original` but not in `updated`. */
  lazy val deletedEntries: Set[T] = original -- updated
}