package com.xebialabs.xlrelease.utils

import scala.collection.immutable.ListMap
import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._

/**
 * Key-by-key comparison of two maps, `before` and `after`.
 *
 * Each key is classified as
 *  - '''deleted''' when it occurs only in `before`,
 *  - '''new''' when it occurs only in `after`,
 *  - '''updated''' when it occurs in both and `areEqual(beforeValue, afterValue)` is `false`.
 * Keys present in both maps whose values satisfy `areEqual` are not reported at all.
 *
 * The classification is computed on first access and then cached. Deleted and updated
 * entries follow the iteration order of `before`; new entries follow that of `after`.
 *
 * @param before   entries of the earlier state
 * @param after    entries of the later state
 * @param areEqual decides whether a value kept under the same key is unchanged; the default
 *                 always returns `false`, so every key present in both maps counts as updated
 * @tparam K key type
 * @tparam A value type
 */
case class Diff[K, A] private(before: Map[K, A],
                              after: Map[K, A],
                              areEqual: (A, A) => Boolean = (_: A, _: A) => false) {

  /** Holder for the three classified groups of entries. */
  private case class Difference(deletedEntries: ListMap[K, A],
                                newEntries: ListMap[K, A],
                                updatedEntries: ListMap[K, (A, A)])

  private lazy val diff: Difference = {
    val removed = ListMap.newBuilder[K, A]
    val changed = ListMap.newBuilder[K, (A, A)]
    val added = ListMap.newBuilder[K, A]

    // A single pass over `before` finds both removals and modifications.
    before.foreachEntry { (key, oldValue) =>
      after.get(key) match {
        case None =>
          removed += key -> oldValue
        case Some(newValue) =>
          if (!areEqual(oldValue, newValue)) changed += key -> (oldValue -> newValue)
      }
    }

    // Anything in `after` that `before` lacks is an addition.
    after.foreachEntry { (key, newValue) =>
      if (!before.contains(key)) added += key -> newValue
    }

    Difference(
      deletedEntries = removed.result(),
      newEntries = added.result(),
      updatedEntries = changed.result()
    )
  }

  /** Entries whose key occurs only in `before`. */
  lazy val deletedEntries: ListMap[K, A] = diff.deletedEntries

  /** Entries whose key occurs only in `after`. */
  lazy val newEntries: ListMap[K, A] = diff.newEntries

  /** Keys kept in both maps with values not deemed equal, mapped to `(beforeValue, afterValue)`. */
  lazy val updatedEntries: ListMap[K, (A, A)] = diff.updatedEntries

  /** Keys of [[deletedEntries]]. */
  def deletedKeys: Set[K] = deletedEntries.keySet

  /** Keys of [[newEntries]]. */
  def newKeys: Set[K] = newEntries.keySet

  /** Keys of [[updatedEntries]]. */
  def updatedKeys: Set[K] = updatedEntries.keySet

  /** Values removed from `before`. */
  def deletedValues: Iterable[A] = deletedEntries.values

  /** Values added in `after`. */
  def newValues: Iterable[A] = newEntries.values

  /** `(beforeValue, afterValue)` pairs of the updated keys. */
  def updatedPairs: Iterable[(A, A)] = updatedEntries.values

  /** The `after` side of every updated pair. */
  def updatedValues: Iterable[A] = updatedPairs.map { case (_, current) => current }

  /**
   * Folds over all classified entries, visiting deletions first, then additions, then updates.
   *
   * @param init      initial accumulator
   * @param onNew     step applied to each entry of [[newEntries]]
   * @param onUpdated step applied to each entry of [[updatedEntries]]
   * @param onDeleted step applied to each entry of [[deletedEntries]]
   * @return the accumulator after all three groups have been visited
   */
  def fold[B](init: B)
             (onNew: (B, (K, A)) => B,
              onUpdated: (B, (K, (A, A))) => B,
              onDeleted: (B, (K, A)) => B
             ): B = {
    val afterDeletions = deletedEntries.foldLeft(init)(onDeleted)
    val afterAdditions = newEntries.foldLeft(afterDeletions)(onNew)
    updatedEntries.foldLeft(afterAdditions)(onUpdated)
  }

  /** Runs a callback per classified entry (key and value), in the same order as [[fold]]. */
  def foreach(onNew: (K, A) => Unit, onUpdated: (K, (A, A)) => Unit, onDeleted: (K, A) => Unit): Unit =
    fold(())(
      { case (_, (key, value)) => onNew(key, value) },
      { case (_, (key, pair)) => onUpdated(key, pair) },
      { case (_, (key, value)) => onDeleted(key, value) }
    )

  /**
   * Runs a callback per classified value, in the same order as [[fold]]; updated entries
   * pass only their `after` value.
   */
  def foreachValue(onNew: A => Unit, onUpdated: A => Unit, onDeleted: A => Unit): Unit =
    fold(())(
      { case (_, (_, value)) => onNew(value) },
      { case (_, (_, (_, current))) => onUpdated(current) },
      { case (_, (_, value)) => onDeleted(value) }
    )
}

/** Factory methods for [[Diff]]. */
object Diff {

  /**
   * Indexes `items` by `keyOf`. When several items map to the same key, `toMap` keeps the
   * last one encountered.
   */
  private def keyed[K, A](items: Iterable[A])(keyOf: A => K): Map[K, A] =
    items.map(item => (keyOf(item), item)).toMap

  /**
   * Compares two maps key by key using the default comparator, which never considers two
   * values equal, so every key present in both maps is reported as updated.
   */
  def apply[K, A](before: Map[K, A], after: Map[K, A]): Diff[K, A] =
    new Diff(before, after)

  /**
   * Compares two collections where each element is both its own key and its own value.
   * Elements that are equal collapse into one entry. With the default comparator, elements
   * present in both collections are reported as updated (with an `(x, x)` pair).
   */
  def apply[A](before: Iterable[A], after: Iterable[A]): Diff[A, A] =
    new Diff(keyed[A, A](before)(x => x), keyed[A, A](after)(x => x))

  /**
   * Compares two collections after keying every element with `keyMapping`, using the default
   * comparator. Elements sharing a key within one collection collapse to the last one.
   */
  def applyWithKeyMapping[K, A](before: Iterable[A], after: Iterable[A])
                               (keyMapping: A => K): Diff[K, A] =
    new Diff(keyed(before)(keyMapping), keyed(after)(keyMapping))

  /** Java-list variant of the `Iterable` overload of `applyWithKeyMapping`. */
  def applyWithKeyMapping[K, A](before: java.util.List[A], after: java.util.List[A])(keyMapping: A => K): Diff[K, A] =
    new Diff(keyed(before.asScala)(keyMapping), keyed(after.asScala)(keyMapping))

  /**
   * Compares two collections after keying every element with `keyMapping`. Values kept under
   * the same key count as updated only when `areEqual` returns `false` for them.
   */
  def applyWithKeyMappingAndComparator[K, A](before: Iterable[A], after: Iterable[A])
                                            (keyMapping: A => K, areEqual: (A, A) => Boolean): Diff[K, A] =
    new Diff(keyed(before)(keyMapping), keyed(after)(keyMapping), areEqual)
}
