package ai.digital.deploy.task.serdes.kryo

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
import com.twitter.chill.java.{Java8ClosureRegistrar, PackageRegistrar}
import com.twitter.chill.{BitSetSerializer, ClassTagSerializer, EnumerationSerializer, JavaWrapperCollectionRegistrar, KSerializer, KryoInstantiator, LeftSerializer, ManifestSerializer, RegexSerializer, RichKryo, RightSerializer, ScalaTupleSerialization, SingletonSerializer, SomeSerializer, SortedMapSerializer, SortedSetSerializer, StreamSerializer, VolatileByteRefSerializer, WrappedArraySerializer}
import io.altoo.serialization.kryo.scala.serializer.ScalaKryo

import java.math.{BigDecimal => JBigDecimal}
import scala.collection.immutable.{BitSet, HashMap, HashSet, ListMap, ListSet, NumericRange, Queue, Range, SortedMap, SortedSet, TreeMap, TreeSet, WrappedString}
import scala.collection.mutable
import scala.collection.mutable.{Buffer, ListBuffer, WrappedArray, BitSet => MBitSet, HashMap => MHashMap, HashSet => MHashSet, Map => MMap, Queue => MQueue, Set => MSet}
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
import scala.util.matching.Regex

/**
 * Registers the Scala serializers copied from Twitter chill. The logic was copied here because chill
 * does not use Kryo 5 yet.
 *
 * First, the runtime classes of the Scala -> Java collection wrappers produced by `.asJava` are passed
 * to `useField` (defined in `Registrar`). Then the chill-derived registrars in this file are applied.
 *
 * NOTE(review): Kryo assigns IDs to registrations made without an explicit ID in call order. The order
 * of the calls here, and in the registrars they invoke, is therefore most likely part of the serialized
 * form. Confirm compatibility with already-serialized data before reordering, adding or removing
 * registrations.
 */
object TwitterChillSerializers extends Registrar {
  def apply(implicit kryo: Kryo): Unit = {
    // The wrapper classes are private, so their runtime classes are taken from sample instances.
    useField(Set(1, 2, 3).asJava.getClass)
    useField(mutable.Set(1, 2, 3).asJava.getClass)
    // NOTE(review): this sample and the `Buffer(...).iterator` sample below most likely yield the same
    // Java iterator wrapper class.
    useField(Set(1, 2, 3).iterator.asJava.getClass)
    useField(mutable.Buffer(1, 2, 3).asJava.getClass)
    useField(mutable.Buffer(1, 2, 3).iterator.asJava.getClass)
    useField(mutable.Map("k" -> "v").asJava.getClass)

    // Mirrors chill's `new AllScalaRegistrar().apply(kryo)`, using the local copies defined in this file:
    new AllScalaRegistrar_0_9_5()(kryo)
    new AllScalaRegistrarCompat()(kryo)
  }

}

/**
 * Base trait for the chill-derived registrars in this file.
 *
 * It provides implicit views that wrap a Kryo instance in chill's [[RichKryo]]. This lets registrars
 * use chill's registration DSL (`forClass`, `forSubclass`, `forTraversableClass`, `registerClasses`, ...)
 * directly on the Kryo they receive.
 */
trait ChillRegistrar extends Registrar {

  import scala.language.implicitConversions

  /** Wraps a plain [[Kryo]] so the chill registration helpers can be called on it. */
  implicit def kryoToRichKryo(kryo: Kryo): RichKryo = new RichKryo(kryo)

  /** Same enrichment, declared for the altoo [[ScalaKryo]] type. */
  implicit def scalaKryoToRichKryo(kryo: ScalaKryo): RichKryo = new RichKryo(kryo)
}

/**
 * Compat registrar applied by `TwitterChillSerializers`. It registers the runtime classes of an empty
 * `Vector` and of a one-element `Vector`, in that order, through chill's `forConcreteTraversableClass`.
 */
class AllScalaRegistrarCompat extends ChillRegistrar {
  override def apply(implicit newK: Kryo): Unit = {
    // Sample instances are used because the concrete Vector class can depend on the size
    // (NOTE(review): presumably Vector0 / Vector1 on recent Scala 2.13 — confirm).
    val emptyVector = Vector[Any]()
    val singletonVector = Vector(Symbol("a"))
    newK.forConcreteTraversableClass(emptyVector)
    newK.forConcreteTraversableClass(singletonVector)
    ()
  }
}

/**
 * Copy of chill's `AllScalaRegistrar_0_9_5` registrar.
 *
 * It first applies [[AllScalaRegistrar_0_9_2]] and [[AllScalaRegistrarCompat_0_9_5]]. It then
 * registers:
 *  - primitive and String arrays and their `WrappedArray` counterparts;
 *  - common immutable collection classes;
 *  - the standard orderings used by `TreeSet`/`TreeMap`;
 *  - `Stream`, `VolatileByteRef`, `BigDecimal` and a few `Map`-derived collections.
 *
 * NOTE(review): Kryo gives a registration without an explicit ID the next ID in call order, so the
 * order of the statements below is most likely part of the serialized format. Reordering, inserting
 * or removing entries may make previously serialized data unreadable — verify before changing.
 */
class AllScalaRegistrar_0_9_5 extends ChillRegistrar {
  // scalastyle:off method.length
  def apply(implicit k: Kryo): Unit = {
    new AllScalaRegistrar_0_9_2()(k)
    new AllScalaRegistrarCompat_0_9_5()(k)
    // Plain registrations: no serializer is supplied, so each class gets Kryo's default serializer.
    k.registerClasses(
      Seq(
        classOf[Array[Byte]],
        classOf[Array[Short]],
        classOf[Array[Int]],
        classOf[Array[Long]],
        classOf[Array[Float]],
        classOf[Array[Double]],
        classOf[Array[Boolean]],
        classOf[Array[Char]],
        classOf[Array[String]],
        classOf[Array[Any]],
        classOf[Class[_]], // needed for the WrappedArraySerializer
        classOf[Any], // needed for scala.collection.mutable.WrappedArray$ofRef
        // Concrete WrappedArray subclasses, taken from empty sample instances.
        mutable.WrappedArray.make(Array[Byte]()).getClass,
        mutable.WrappedArray.make(Array[Short]()).getClass,
        mutable.WrappedArray.make(Array[Int]()).getClass,
        mutable.WrappedArray.make(Array[Long]()).getClass,
        mutable.WrappedArray.make(Array[Float]()).getClass,
        mutable.WrappedArray.make(Array[Double]()).getClass,
        mutable.WrappedArray.make(Array[Boolean]()).getClass,
        mutable.WrappedArray.make(Array[Char]()).getClass,
        mutable.WrappedArray.make(Array[String]()).getClass,
        None.getClass,
        classOf[Queue[_]],
        Nil.getClass,
        classOf[::[_]],
        classOf[Range],
        classOf[WrappedString],
        classOf[TreeSet[_]],
        classOf[TreeMap[_, _]],
        // The most common orderings for TreeSet and TreeMap
        Ordering.Byte.getClass,
        Ordering.Short.getClass,
        Ordering.Int.getClass,
        Ordering.Long.getClass,
        // NOTE(review): on Scala 2.13, `Ordering.Float` / `Ordering.Double` look like container objects
        // (holding `TotalOrdering` / `IeeeOrdering`) rather than the implicit orderings themselves. If so,
        // these two entries do not cover the ordering a TreeSet/TreeMap actually serializes. Do not
        // "fix" them without accounting for the registration-ID shift described above.
        Ordering.Float.getClass,
        Ordering.Double.getClass,
        Ordering.Boolean.getClass,
        Ordering.Char.getClass,
        Ordering.String.getClass
      )
    ).forConcreteTraversableClass(Set[Any]())
      // Empty and small immutable sets/maps have their own concrete classes; each sample instance
      // registers its runtime class through chill's traversable helper.
      .forConcreteTraversableClass(ListSet[Any]())
      .forConcreteTraversableClass(ListSet[Any]('a))
      .forConcreteTraversableClass(HashSet[Any]())
      .forConcreteTraversableClass(HashSet[Any]('a))
      .forConcreteTraversableClass(Map[Any, Any]())
      .forConcreteTraversableClass(HashMap[Any, Any]())
      .forConcreteTraversableClass(HashMap('a -> 'a))
      .forConcreteTraversableClass(ListMap[Any, Any]())
      .forConcreteTraversableClass(ListMap('a -> 'a))
    // Non-empty Stream cells use chill's StreamSerializer; the empty Stream gets a plain registration.
    k.register(classOf[Stream.Cons[_]], new StreamSerializer[Any])
    k.register(Stream.empty[Any].getClass)
    k.forClass[scala.runtime.VolatileByteRef](new VolatileByteRefSerializer)
    // scala.math.BigDecimal is written through its java.math.BigDecimal (see BigDecimalSerializer).
    k.forClass[BigDecimal](new BigDecimalSerializer)
    // NOTE(review): on Scala 2.13 this may be the same class as `classOf[Queue[_]]` registered above.
    k.register(Queue.empty[Any].getClass)
    // NOTE(review): `.toMap` materializes both of the next two samples into ordinary single-entry
    // immutable Maps (the filter keeps key 1). They presumably re-register the same class as
    // `Map[Any, Any]('a -> 'a)` in ScalaCollectionsRegistrar rather than a view class. Kept as-is to
    // avoid changing registrations.
    k.forConcreteTraversableClass(Map(1 -> 2).filterKeys(_ != 2).toMap)
      .forConcreteTraversableClass(Map(1 -> 2).mapValues(_ + 1).toMap)
      .forConcreteTraversableClass(Map(1 -> 2).keySet)
  }
}

/**
 * Compat registrar applied by [[AllScalaRegistrar_0_9_5]]. It registers `Range.Exclusive` without an
 * explicit serializer, so Kryo's default serializer is used.
 */
class AllScalaRegistrarCompat_0_9_5 extends ChillRegistrar {
  def apply(implicit newK: Kryo): Unit = {
    val exclusiveRangeClass = classOf[Range.Exclusive]
    newK.register(exclusiveRangeClass)
    ()
  }
}

/**
 * Copy of chill's `AllScalaRegistrar_0_9_2` registrar.
 *
 * It applies [[ScalaCollectionsRegistrar]] and chill's `JavaWrapperCollectionRegistrar`. It then
 * registers:
 *  - the tuple serializers;
 *  - `Symbol`, `Regex`, `ClassTag`, `Manifest` and `Enumeration#Value`;
 *  - boxed `Unit`;
 *  - chill's `PackageRegistrar.all()` set and Java 8 closures.
 *
 * NOTE(review): statement order is likely significant, both for Kryo registration IDs and for
 * default-serializer precedence. Keep the order unchanged unless compatibility has been checked.
 */
class AllScalaRegistrar_0_9_2 extends ChillRegistrar {

  def apply(implicit k: Kryo): Unit = {
    new ScalaCollectionsRegistrar()(k)
    new JavaWrapperCollectionRegistrar()(k)

    // Register all 22 tuple serializers and specialized serializers
    ScalaTupleSerialization.register(k)
    // A Symbol is written as its name and rebuilt with `Symbol(name)` on read. The serializer declares
    // the type immutable.
    k.forClass[Symbol](new KSerializer[Symbol] {
      override def isImmutable = true

      def write(k: Kryo, out: Output, obj: Symbol): Unit = out.writeString(obj.name)

      def read(k: Kryo, in: Input, cls: Class[_ <: Symbol]): Symbol = Symbol(in.readString)
    }).forSubclass[Regex](new RegexSerializer)
      .forClass[ClassTag[Any]](new ClassTagSerializer[Any])
      .forSubclass[Manifest[Any]](new ManifestSerializer[Any])
      .forSubclass[scala.Enumeration#Value](new EnumerationSerializer)

    // use the singleton serializer for boxed Unit
    val boxedUnit = scala.runtime.BoxedUnit.UNIT
    k.register(boxedUnit.getClass, new SingletonSerializer(boxedUnit))
    PackageRegistrar.all()(k)
    new Java8ClosureRegistrar()(k)
  }
}

/**
 * Copy of chill's `ScalaCollectionsRegistrar`. It registers serializers for:
 *  - the core Scala collection types;
 *  - `Some`/`Left`/`Right`;
 *  - the Scala <-> Java collection wrappers.
 *
 * Note that additional Scala collection registrations are made by [[AllScalaRegistrar_0_9_2]] and
 * [[AllScalaRegistrar_0_9_5]], the local copies of chill's `AllScalaRegistrar` chain. They have not
 * been included in this registrar for backwards compatibility reasons.
 */
class ScalaCollectionsRegistrar extends ChillRegistrar {
  def apply(implicit newK: Kryo): Unit = {
    // The wrappers are private classes, so their runtime classes are taken from sample instances:
    useField(List(1, 2, 3).asJava.getClass)
    useField(List(1, 2, 3).iterator.asJava.getClass)
    useField(Map(1 -> 2, 4 -> 3).asJava.getClass)
    useField(new _root_.java.util.ArrayList().asScala.getClass)
    useField(new _root_.java.util.HashMap().asScala.getClass)

    /*
     * Subclass-based helpers (`forSubclass`, `forTraversableSubclass`) use addDefaultSerializer; the
     * others use register. Default serializers are tried in the order they were added and the FIRST
     * match is used, so they must go from the MOST specific to the least specific type.
     */
    newK
      // wrapper array is abstract
      .forSubclass[WrappedArray[Any]](new WrappedArraySerializer[Any])
      .forSubclass[BitSet](new BitSetSerializer)
      .forSubclass[SortedSet[Any]](new SortedSetSerializer)
      .forClass[Some[Any]](new SomeSerializer[Any])
      .forClass[Left[Any, Any]](new LeftSerializer[Any, Any])
      .forClass[Right[Any, Any]](new RightSerializer[Any, Any])
      .forTraversableSubclass(Queue.empty[Any])
      // List is a sealed class with only two subclasses (`::` and `Nil`), so one subclass default covers both:
      .forTraversableSubclass(List.empty[Any])
      // Add ListBuffer subclass before Buffer to prevent the more general case taking precedence
      .forTraversableSubclass(ListBuffer.empty[Any], isImmutable = false)
      // add mutable Buffer before Vector, otherwise Vector is used
      .forTraversableSubclass(Buffer.empty[Any], isImmutable = false)
      // Vector is a final class
      .forTraversableClass(Vector.empty[Any])
      .forTraversableSubclass(ListSet.empty[Any])
      // specifically register small sets since Scala represents them differently
      .forConcreteTraversableClass(Set[Any]('a))
      .forConcreteTraversableClass(Set[Any]('a, 'b))
      .forConcreteTraversableClass(Set[Any]('a, 'b, 'c))
      .forConcreteTraversableClass(Set[Any]('a, 'b, 'c, 'd))
      // default set implementation
      .forConcreteTraversableClass(HashSet[Any]('a, 'b, 'c, 'd, 'e))
      // specifically register small maps since Scala represents them differently
      .forConcreteTraversableClass(Map[Any, Any]('a -> 'a))
      .forConcreteTraversableClass(Map[Any, Any]('a -> 'a, 'b -> 'b))
      .forConcreteTraversableClass(Map[Any, Any]('a -> 'a, 'b -> 'b, 'c -> 'c))
      .forConcreteTraversableClass(Map[Any, Any]('a -> 'a, 'b -> 'b, 'c -> 'c, 'd -> 'd))
      // default map implementation
      .forConcreteTraversableClass(HashMap[Any, Any]('a -> 'a, 'b -> 'b, 'c -> 'c, 'd -> 'd, 'e -> 'e))
      // The normal fields serializer works for ranges
      .registerClasses(
        Seq(classOf[Range.Inclusive], classOf[NumericRange.Inclusive[_]], classOf[NumericRange.Exclusive[_]])
      )
      // Add some maps
      .forSubclass[SortedMap[Any, Any]](new SortedMapSerializer)
      .forTraversableSubclass(ListMap.empty[Any, Any])
      .forTraversableSubclass(HashMap.empty[Any, Any])
      // The above ListMap/HashMap must appear before this:
      .forTraversableSubclass(Map.empty[Any, Any])
      // here are the mutable ones:
      .forTraversableClass(MBitSet.empty, isImmutable = false)
      .forTraversableClass(MHashMap.empty[Any, Any], isImmutable = false)
      .forTraversableClass(MHashSet.empty[Any], isImmutable = false)
      .forTraversableSubclass(MQueue.empty[Any], isImmutable = false)
      .forTraversableSubclass(MMap.empty[Any, Any], isImmutable = false)
      .forTraversableSubclass(MSet.empty[Any], isImmutable = false)
  }
}

/**
 * Kryo serializer for `scala.math.BigDecimal`.
 *
 * Only the wrapped `java.math.BigDecimal` is written, together with its class, so the Scala wrapper's
 * `MathContext` is not part of the serialized form. On read, the value is rebuilt with
 * `BigDecimal.apply`, which attaches that factory's default context.
 */
class BigDecimalSerializer extends KSerializer[BigDecimal] {

  override def write(kryo: Kryo, output: Output, obj: BigDecimal): Unit = {
    val javaValue: JBigDecimal = obj.bigDecimal
    kryo.writeClassAndObject(output, javaValue)
  }

  override def read(kryo: Kryo, input: Input, cls: Class[_ <: BigDecimal]): BigDecimal =
    BigDecimal(kryo.readClassAndObject(input).asInstanceOf[JBigDecimal])
}

/**
 * Creates a fresh [[TwitterChillKryoBase]]. This instantiator applies no class registrations of its own.
 *
 * The returned instance:
 *  - does not require classes to be registered;
 *  - creates objects through Objenesis' `StdInstantiatorStrategy`;
 *  - has Kryo's optimized generics disabled;
 *  - loads classes through the calling thread's context class loader when one is set.
 */
class EmptyScalaKryoInstantiator extends KryoInstantiator {
  override def newKryo: TwitterChillKryoBase = {
    val k = new TwitterChillKryoBase
    k.setRegistrationRequired(false)
    k.setInstantiatorStrategy(new org.objenesis.strategy.StdInstantiatorStrategy)
    k.setOptimizedGenerics(false)

    // Prefer the context class loader, to handle odd classloader setups such as libjars on Hadoop.
    // The context loader can be null (e.g. on threads attached from native code), and Kryo's
    // setClassLoader throws on null. In that case Kryo keeps its own default loader instead of
    // failing here.
    Option(Thread.currentThread.getContextClassLoader).foreach(loader => k.setClassLoader(loader))

    k
  }
}
