package com.xebialabs.xlrelease.planner

/** A simple State monad.
  *
  * A `State[S, A]` wraps a transition function that, given a state `S`,
  * produces the next state together with a computed value `A`. Transitions are
  * composed with `map`/`flatMap` (and therefore `for`-comprehensions); nothing
  * is evaluated until the composed transition is run against an initial state.
  *
  * @param run the transition from a state to a new state and a computed value
  * @tparam S the type of the state that is carried on during the transition
  * @tparam A the type of the value returned by the transition
  */
case class State[S, A](run: S => (S, A)) {

  /** Runs the transition on `s`; an alias for `run(s)`. */
  def apply(s: S): (S, A) = run(s)

  /** Transforms the computed value with `f`, leaving the resulting state untouched. */
  def map[T](f: A => T): State[S, T] = State[S, T] { s =>
    val (s1, a) = run(s)
    (s1, f(a))
  }

  /** Runs this transition, feeds its value to `f`, and runs the transition
    * returned by `f` on the updated state.
    */
  def flatMap[T](f: A => State[S, T]): State[S, T] = State[S, T] { s =>
    val (s1, a) = run(s)
    f(a).run(s1)
  }

  /** Chains two transitions: runs this one, then `next`, discarding this one's value. */
  def andThen[U](next: State[S, U]): State[S, U] = flatMap(_ => next)

  /** Runs the computation on `s` and returns only the value, discarding the final state. */
  def run0(s: S): A = run(s)._2

  /** Runs the computation on `s` and returns only the final state, discarding the value. */
  def runState(s: S): S = run(s)._1
}


object State {

  /** Returns the current state as the value, leaving the state unchanged. */
  def get[S]: State[S, S] = State[S, S](s => (s, s))

  /** Replaces the state with `v`, ignoring the previous state. */
  def put[S](v: S): State[S, Unit] = State[S, Unit](_ => (v, ()))

  /** Applies `f` to the state and returns the modified state as the value. */
  def modify[S](f: S => S): State[S, S] = State[S, S] { s =>
    val s1 = f(s)
    (s1, s1)
  }

  /** Applies `f` to the state, computing no meaningful value (`Unit`). */
  def update[S](f: S => S): State[S, Unit] = modify(f).map(_ => ())

  /** Returns the constant `v`, with no state modifications. */
  def const[S, T](v: T): State[S, T] = State[S, T](s => (s, v))

  /** Does nothing: the state is unchanged and the value is `()`. */
  def nop[S]: State[S, Unit] = State.const[S, Unit](())

  /** Reads the current state and maps it to a value with `f`; the state is unchanged. */
  def gets[S, T](f: S => T): State[S, T] = get[S].map(f)

  implicit class ListStateOps[S, T](val list: List[State[S, T]]) extends AnyVal {

    /** Turns a `List[State[S, T]]` into a single `State[S, List[T]]`.
      *
      * The transitions are run one after another in list order, each on the
      * state produced by the previous one. The computed values are returned in
      * the same order as the input transitions.
      *
      * The state is threaded through a single strict fold instead of a chain of
      * nested `flatMap`s. A nested chain would recurse once per element when
      * run, so long lists could overflow the stack.
      */
    def sequence: State[S, List[T]] = State[S, List[T]] { initial =>
      val (finalState, reversedResults) =
        list.foldLeft((initial, List.empty[T])) { case ((current, acc), transition) =>
          val (next, value) = transition.run(current)
          // prepend for O(1) accumulation; the order is restored once below
          (next, value :: acc)
        }
      (finalState, reversedResults.reverse)
    }
  }
}
