package com.xebialabs.xlrelease.utils

import com.xebialabs.xlrelease.utils.Graph.Edge
import grizzled.slf4j.Logging

import scala.annotation.tailrec


// Immutable directed graph described solely by its set of edges.
// A node exists only as an endpoint of some edge, so isolated nodes cannot be represented.
case class Graph[A](edges: Set[Edge[A]]) {
  // every node that appears as the source or the target of at least one edge (computed on first access)
  lazy val nodes: Set[A] = edges.flatMap(edge => Set(edge.from, edge.to))

  // targets of all edges that start at `a`; empty when `a` has no outgoing edge
  def outgoing(a: A): Set[A] =
    edges
      .filter(edge => edge.from == a)
      .map(edge => edge.to)

  // sources of all edges that end at `a`; empty when `a` has no incoming edge
  def incoming(a: A): Set[A] =
    edges
      .filter(edge => edge.to == a)
      .map(edge => edge.from)

  // Topological order and cycle flag, both obtained from one lazily evaluated call to Graph.sort.
  // When the sort reports a cycle, `order` is empty and `hasCycle` is true.
  lazy val (order, hasCycle) = Graph.sort[A](this) match {
    case Left(_) => (List.empty[A], true)
    case Right(sorted) => (sorted, false)
  }
}

object Graph extends Logging {
  // build a graph from any collection of edges; duplicates collapse because the edges are stored in a set
  def apply[A](edges: Iterable[Edge[A]]): Graph[A] = new Graph(edges.toSet)

  // directed edge pointing from `from` to `to`
  case class Edge[A](from: A, to: A)

  object Edge {
    // build an edge from a `(from, to)` pair
    def apply[A](edge: (A, A)): Edge[A] = new Edge(from = edge._1, to = edge._2)
  }

  // Depth-first search state:
  //  - value:    user-defined accumulator, updated by the function handed to `walk`
  //  - visited:  nodes whose processing has completed
  //  - temp:     nodes whose processing has started (nodes are never removed from this set)
  //  - order:    completed nodes, most recently completed first
  //  - done:     stops every step when true (nothing in this object ever sets it to true)
  //  - hasCycle: set as soon as a cycle is found; stops every step
  case class DFS[A, B](value: B, visited: Set[A], temp: Set[A], order: List[A], done: Boolean, hasCycle: Boolean)

  object DFS {
    // initial state: nothing visited, no order yet, no flag raised
    def empty[A, B](initialValue: B): DFS[A, B] =
      DFS(
        value = initialValue,
        visited = Set.empty[A],
        temp = Set.empty[A],
        order = List.empty[A],
        done = false,
        hasCycle = false
      )
  }

  // a state transition of the search
  type STEP[A, B] = DFS[A, B] => DFS[A, B]

  // Topological sort based on the classic depth-first-search algorithm, with cycle detection:
  // https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
  // Returns Right(order) when the graph is acyclic, Left(final state) when a cycle was found.
  def sort[A](graph: Graph[A]): Either[DFS[A, Unit], List[A]] = {
    val outcome = walk[A, Unit](graph)(())((_, _) => ())
    if (outcome.hasCycle) Left(outcome) else Right(outcome.order)
  }

  // Walk the graph depth-first, starting from the initial value `z`, until every node has been
  // handled or a cycle is detected.
  // `f` computes the new `value` for a node once its outgoing edges have been followed; it is given
  // the current state, so it can inspect `hasCycle` itself.
  def walk[A, B](graph: Graph[A])(z: B)(f: (DFS[A, B], A) => B): DFS[A, B] = {
    val initial = DFS.empty[A, B](z)
    walk0(initial)(f)(graph)
  }

  // Recursive driver:
  // 0. a state that already has `hasCycle` or `done` set is returned untouched
  //    (not reachable through `walk`, which starts from a fresh state and only recurses on cycle-free states)
  // 1. pick any node that is neither in `visited` nor in `temp`
  // 2. if there is none, return the current state
  // 3. otherwise visit that node:
  //    a. no cycle found  => clear `done` and recurse
  //    b. cycle found     => apply `f` one more time to the picked node and return
  @tailrec
  private def walk0[A, B](state: DFS[A, B])(f: (DFS[A, B], A) => B)(implicit graph: Graph[A]): DFS[A, B] =
    if (state.hasCycle || state.done) {
      logger.debug(s"cycle detected among ${state.temp}")
      state
    } else {
      val pending = graph.nodes.diff(state.visited.union(state.temp))
      if (pending.isEmpty) {
        state
      } else {
        val root = pending.head
        val after = visit(root)(f).apply(state)
        if (after.hasCycle) {
          logger.debug(s"cycle detected: ${after.order} | ${after.temp} | ${after.visited}")
          after.copy(value = f(after, root))
        } else {
          // no cycle so far: make sure `done` is unset and look for the next unvisited node
          walk0(after.copy(done = false))(f)
        }
      }
    }

  // Visit a single node:
  //  - already in `visited`        => nothing to do
  //  - in `temp` but not `visited` => we came back to a node still being processed: cycle
  //  - otherwise                   => mark as temp, process, mark as visited, prepend to `order`
  // `visited` is tested first, so completed nodes that are still in `temp` are not reported as cycles.
  private def visit[A, B](node: A)(f: (DFS[A, B], A) => B)(implicit graph: Graph[A]): STEP[A, B] =
    step[A, B] { current =>
      if (current.visited.contains(node)) {
        current
      } else if (current.temp.contains(node)) {
        current.copy(hasCycle = true)
      } else {
        val entered = markAsTemp[A, B](node).apply(current)
        val explored = process(node)(f).apply(entered)
        val completed = markAsVisited[A, B](node).apply(explored) // skipped if a cycle was found below
        prepend[A, B](node).apply(completed)                      // skipped if a cycle was found below
      }
    }

  // Process a node: visit the target of each of its outgoing edges in turn, then recompute `value`
  // with `f`. `f` is applied even when one of those visits found a cycle.
  private def process[A, B](node: A)(f: (DFS[A, B], A) => B)(implicit graph: Graph[A]): STEP[A, B] =
    step[A, B] { entry =>
      val afterChildren = graph.outgoing(node).foldLeft(entry) { (acc, child) =>
        visit(child)(f).apply(acc)
      }
      afterChildren.copy(value = f(afterChildren, node))
    }

  // guard shared by all steps: a state that is done or has a cycle passes through unchanged
  private def step[A, B](f: STEP[A, B]): STEP[A, B] =
    current => if (current.done || current.hasCycle) current else f(current)

  // add node to `temp`
  private def markAsTemp[A, B](node: A): STEP[A, B] =
    step[A, B](current => current.copy(temp = current.temp + node))

  // add node to `visited`
  private def markAsVisited[A, B](node: A): STEP[A, B] =
    step[A, B](current => current.copy(visited = current.visited + node))

  // put node at the head of `order`
  private def prepend[A, B](node: A): STEP[A, B] =
    step[A, B](current => current.copy(order = node :: current.order))
}
