package com.xebialabs.xlrelease.utils

import com.xebialabs.xlrelease.utils.Tree.{Empty, Node}

import scala.Function.const

/**
 * An immutable rose tree: either [[Tree.Empty]] or a [[Tree.Node]] holding a value and its child nodes.
 *
 * Extends `Product with Serializable` (which every case already does, being a case class / case object)
 * so that unifying `Empty` and `Node` in `if`/`match` expressions infers `Tree[A]` rather than
 * `Product with Serializable with Tree[A]`.
 *
 * @tparam A type of the values stored in the nodes (covariant)
 */
sealed trait Tree[+A] extends Product with Serializable

object Tree {

  /** The empty tree: it has no root and therefore holds no values. */
  case object Empty extends Tree[Nothing]

  /**
   * A non-empty tree: a root `value` plus its ordered child subtrees.
   *
   * Children are typed as [[Node]] rather than [[Tree]], so an [[Empty]] subtree can never appear below a node.
   *
   * @param value    the value stored at this node
   * @param children the child subtrees, in order
   */
  case class Node[+A](value: A, children: List[Node[A]]) extends Tree[A]

  object Node {
    /** Creates a leaf, i.e. a node without children. */
    def apply[A](value: A): Node[A] = Node(value, List.empty)
  }

  /** The empty tree, statically typed as `Tree[A]` instead of `Empty.type`. */
  def empty[A]: Tree[A] = Empty

  /** Creates a node from a root `value` and its `children` (none by default). */
  def apply[A](value: A, children: List[Node[A]] = List.empty): Node[A] = Node(value, children)

  /** Extension methods for a [[Node]], i.e. a tree that is known to be non-empty. */
  implicit class NodeOps[A](val n: Node[A]) extends AnyVal {
    /** Keeps this node and recursively removes every child subtree whose root value fails `f`. */
    def filterChildren(f: A => Boolean): Node[A] = TreeFunctions.filterChildren(f)(n)

    /** Applies `f` to every value, preserving the shape of the tree. */
    def map[B](f: A => B): Node[B] = TreeFunctions.mapNode(f)(n)

    /** Applies `f` to every node (parent before its children); `f` receives the original, untransformed subtree. */
    def unsafeMap[B](f: Node[A] => B): Node[B] = TreeFunctions.unsafeMapNode(f)(n)

    /** Applies `f` to this node only; descending into the children is left to `f`. */
    def unsafeFlatMap[B](f: Node[A] => Tree[B]): Tree[B] = TreeFunctions.unsafeFlatMapNode(f)(n)
  }

  /** Extension methods for any [[Tree]], empty or not. */
  implicit class TreeOps[A](val t: Tree[A]) extends AnyVal {
    /** Returns `default` for [[Empty]] (evaluated only in that case), otherwise `f` applied to the root node. */
    def fold[B](default: => B)(f: Node[A] => B): B = TreeFunctions.fold(default)(f)(t)

    /** Applies `f` to every value, preserving the shape of the tree; [[Empty]] maps to [[Empty]]. */
    def map[B](f: A => B): Tree[B] =
      TreeFunctions.map(f)(t) // TreeFunctions.map already handles Empty, so no extra fold is needed

    /** Calls `f` on every value in pre-order (a node before its children, children left to right). */
    def foreach[U](f: A => U): Unit = TreeFunctions.foreach(f)(t)

    /** Keeps values satisfying `f`; a failing node drops its whole subtree, and a failing root yields [[Empty]]. */
    def filter(f: A => Boolean): Tree[A] = TreeFunctions.filter(f)(t)

    /** Like `filter`, but never removes the root: only the child subtrees are filtered. */
    def filterChildren(f: A => Boolean): Tree[A] = TreeFunctions.fold(empty[A])(TreeFunctions.filterChildren(f))(t)

    /**
     * Replaces every value with a tree and grafts the results together: each inner tree's root takes the
     * place of its node, its own children come first, followed by the grafted children of the original node.
     * A value mapped to [[Empty]] drops that node's entire subtree.
     */
    def flatMap[B](f: A => Tree[B]): Tree[B] = TreeFunctions.flatMap(f)(t)

    /** `Some(root)` for a non-empty tree, `None` for [[Empty]]. */
    def toOption: Option[Node[A]] = TreeFunctions.toOption(t)

    /** All values in pre-order (a node before its children, children left to right). */
    def toList: List[A] = TreeFunctions.toList[A](t)

    /** Same as `toList`: all values in pre-order. */
    def topDown: List[A] = TreeFunctions.toList[A](t)

    /** All values in reverse pre-order, so every node comes after all of its descendants. */
    def bottomUp: List[A] = TreeFunctions.bottomUpList[A](t)

    /** Applies `f` to every node (parent before its children); `f` receives the original, untransformed subtree. */
    def unsafeMap[B](f: Node[A] => B): Tree[B] = TreeFunctions.unsafeMap(f)(t)

    /** Applies `f` to the root node only, yielding [[Empty]] for an empty tree; descending is left to `f`. */
    def unsafeFlatMap[B](f: Node[A] => Tree[B]): Tree[B] = TreeFunctions.unsafeFlatMap(f)(t)
  }

  /** Extension methods for a tree whose values are themselves trees. */
  implicit class TreeTreeOps[A](val t: Tree[Tree[A]]) extends AnyVal {
    /**
     * Collapses the nesting: each inner tree's root takes the place of its outer node, with the inner
     * children first and the flattened outer children after them. An [[Empty]] inner tree drops the
     * outer node's whole subtree.
     */
    def flatten: Tree[A] = TreeFunctions.flatten[A](t)
  }

  /** Extension methods for a list of trees. */
  implicit class ListTreeOps[A](val ts: List[Tree[A]]) extends AnyVal {
    /** Drops the [[Empty]] trees, keeping the remaining nodes in their original order. */
    def prune: List[Node[A]] = TreeFunctions.prune(ts)
  }

}


object TreeFunctions {

  /**
   * Case analysis on a tree: returns `default` for [[Tree.Empty]] and `f(node)` for a [[Tree.Node]].
   *
   * @param default result for the empty tree; by-name, so it is only evaluated when `t` is empty
   * @param f       applied to the root node when `t` is non-empty
   * @param t       the tree to inspect
   */
  def fold[A, B](default: => B)(f: Node[A] => B)(t: Tree[A]): B =
    t match {
      case Empty => default
      case n@Node(_, _) => f(n)
    }

  /** Drops the [[Tree.Empty]] entries of `nodes`, keeping the remaining nodes in their original order. */
  def prune[A](nodes: List[Tree[A]]): List[Node[A]] =
    nodes.flatMap(toOption(_))

  /** Applies `f` to every value of the node's subtree, preserving its shape; `f` is evaluated in pre-order. */
  def mapNode[A, B](f: A => B): Node[A] => Node[B] = {
    case Node(value, children) =>
      // Mapping a node always yields a node, so the mapped children need no pruning.
      Node(f(value), children.map(mapNode(f)))
  }

  /** Applies `f` to every value of the tree; the empty tree maps to the empty tree. */
  def map[A, B](f: A => B): Tree[A] => Tree[B] =
    fold(Tree.empty[B])(mapNode(f))

  /**
   * Keeps only the values satisfying `f`. A node whose value fails `f` is removed together with its
   * entire subtree; if the root fails, the result is [[Tree.Empty]].
   */
  def filter[A](f: A => Boolean): Tree[A] => Tree[A] =
    fold(Tree.empty[A]) {
      case n@Node(value, _) if f(value) =>
        filterChildren(f)(n)
      case Node(_, _) =>
        Empty
    }

  /** Keeps the node itself and recursively removes every child subtree whose root value fails `f`. */
  def filterChildren[A](f: A => Boolean): Node[A] => Node[A] = {
    case Node(value, children) =>
      // filter may turn a child into Empty, so those results must be pruned away.
      Node(value, prune(children.map(filter(f))))
  }

  /**
   * Flattens a node whose value is itself a tree. The inner tree's root replaces the node; the inner
   * tree's own children come first, followed by the flattened outer children. If the inner tree is
   * [[Tree.Empty]], the whole subtree (outer children included) is dropped.
   */
  def flattenNode[A]: Node[Tree[A]] => Tree[A] =
    n =>
      n.value.fold(Tree.empty[A]) { case Node(value, children) =>
        Node(value, children ++ prune(n.children.map(flattenNode)))
      }

  /** Flattens a tree of trees into a single tree (see `flattenNode` for how nodes are grafted). */
  def flatten[A]: Tree[Tree[A]] => Tree[A] =
    fold(Tree.empty[A])(flattenNode)

  /** Maps every value to a tree and flattens the result. */
  def flatMap[A, B](f: A => Tree[B]): Tree[A] => Tree[B] =
    map(f) andThen flatten

  /**
   * Pre-order left fold: `f(...f(f(z, v1), v2)..., vn)` where `v1..vn` are the values in pre-order
   * (a node before its children, children left to right). Recursion depth grows with the tree depth.
   */
  def foldLeft[A, B](z: B)(f: (B, A) => B): Tree[A] => B =
    fold(z) { case Node(value, children) =>
      val z0: B = f(z, value)
      children.foldLeft(z0)({ case (z1, c) => foldLeft(z1)(f)(c) })
    }

  /**
   * Right fold over the pre-order values: `f(v1, f(v2, ... f(vn, z)))`. Children are folded right to
   * left before their parent, so `f` is invoked in reverse pre-order. Recursion depth grows with the tree depth.
   */
  def foldRight[B, A](z: B)(f: (A, B) => B): Tree[A] => B =
    fold(z) { case Node(value, children) =>
      f(value, children.foldRight(z)({ case (c, z1) => foldRight(z1)(f)(c) }))
    }

  /** `Some(root)` for a non-empty tree, `None` for [[Tree.Empty]]. */
  def toOption[A]: Tree[A] => Option[Node[A]] =
    fold(Option.empty[Node[A]])(Some(_))

  /** All values in pre-order (a node before its children, children left to right). */
  def toList[A]: Tree[A] => List[A] = foldRight(List.empty[A])(_ :: _)

  /** All values in reverse pre-order, so every node comes after all of its descendants. */
  def bottomUpList[A]: Tree[A] => List[A] = foldLeft(List.empty[A])((l, v) => v :: l)

  /**
   * Calls `f` on every value in pre-order, purely for its side effects.
   * Implemented as a fold so no intermediate mapped tree is allocated.
   */
  def foreach[A, U](f: A => U): Tree[A] => Unit =
    // const keeps the (unit) accumulator and discards f's result; f(a) is still evaluated eagerly.
    foldLeft[A, Unit](())((done, a) => const(done)(f(a)))

  /**
   * Bottom-up node map: the children are mapped first, then `unsafeF` receives the node's value together
   * with its already-mapped children.
   */
  def unsafeMapNode[A, B](unsafeF: (A, List[Node[B]]) => B): Node[A] => Node[B] =
    n => {
      val mappedChildren = n.children.map(unsafeMapNode(unsafeF))
      Node(unsafeF(n.value, mappedChildren), mappedChildren)
    }

  /**
   * Top-down node map: `unsafeF` receives each original node (with its untransformed subtree), and is
   * applied to a parent before its children.
   */
  def unsafeMapNode[A, B](unsafeF: Node[A] => B): Node[A] => Node[B] = { node =>
    val b = unsafeF(node) // evaluated before the children, preserving parent-first order
    Node(b, node.children.map(unsafeMapNode(unsafeF)))
  }

  /** Top-down node map over a tree (see the `Node[A] => B` overload of `unsafeMapNode`); Empty stays Empty. */
  def unsafeMap[A, B](unsafeF: Node[A] => B): Tree[A] => Tree[B] =
    fold(Tree.empty[B])(unsafeMapNode(unsafeF))

  /** Applies `unsafeF` to the node itself; its children are only visited if `unsafeF` does so. */
  def unsafeFlatMapNode[A, B](unsafeF: Node[A] => Tree[B]): Node[A] => Tree[B] = unsafeF

  /** Applies `unsafeF` to the root node of a non-empty tree; the empty tree yields [[Tree.Empty]]. */
  def unsafeFlatMap[A, B](unsafeF: Node[A] => Tree[B]): Tree[A] => Tree[B] =
    fold(Tree.empty[B])(unsafeFlatMapNode(unsafeF))

}
