package com.xebialabs.satellite.streaming

import java.util.zip.{Deflater, Inflater}

import akka.stream.scaladsl.Flow
import akka.util.{ByteString, ByteStringBuilder}

import scala.collection.mutable

/**
 * Optional compression stages for streams of `ByteString`.
 *
 * Each compressed element is a self-contained frame: a 4-byte little-endian length of the
 * uncompressed payload followed by a zlib-format deflate stream (see `ZipCompressor`).
 * Whether the stages do anything is controlled by `StreamConfig.compression`.
 */
object Compression {

  /** Size in bytes of the little-endian uncompressed-length prefix of every frame. */
  private val LengthPrefixBytes = 4

  /**
   * A flow that decompresses every element produced by `compress`. If
   * `streamConfig.compression` is false, elements pass through unchanged.
   */
  def decompress(implicit streamConfig: StreamConfig) = ifEnabled(decompressFunction)

  /**
   * A flow that compresses every element into a length-prefixed zlib frame. If
   * `streamConfig.compression` is false, elements pass through unchanged.
   */
  def compress(implicit streamConfig: StreamConfig) = ifEnabled(compressFunction)

  // The flag is read once, when the flow is built. When compression is disabled no map
  // stage is added at all, instead of mapping every element through `identity`.
  private def ifEnabled(function: (ByteString) => ByteString)(implicit streamConfig: StreamConfig) =
    if (streamConfig.compression) Flow[ByteString].map(function) else Flow[ByteString]

  /** Compresses one element into a `[length][zlib stream]` frame. */
  val compressFunction: (ByteString) => ByteString = { bytes: ByteString =>
    // Use a fresh compressor per element, because a ZipCompressor is stateful and this shared
    // function value may be used by more than one materialized stream. Calling end() releases
    // the native zlib memory right away instead of leaving it to GC finalization.
    val compressor = new ZipCompressor
    try ByteString(compressor.compress(bytes.toArray))
    finally compressor.deflater.end()
  }

  /** Decompresses one frame produced by `compressFunction`. */
  val decompressFunction: (ByteString) => ByteString = { bytes: ByteString =>
    // Same lifecycle as compressFunction: one compressor per element, and its Inflater
    // is always ended.
    val compressor = new ZipCompressor
    try ByteString(compressor.decompress(bytes.toArray))
    finally compressor.inflater.end()
  }

  /**
   * Encodes byte arrays as `[4-byte little-endian uncompressed length][zlib stream]`.
   *
   * An instance holds one `Deflater` and one `Inflater`, each created lazily and reset after
   * every call. It therefore keeps mutable state and should not be shared between threads.
   * Code that discards an instance should call `end()` on whichever of `deflater` or
   * `inflater` it used, so the native zlib memory is freed promptly.
   */
  class ZipCompressor {

    lazy val deflater = new Deflater(Deflater.BEST_SPEED)
    lazy val inflater = new Inflater()

    /**
     * Compresses `inputBuff` at `Deflater.BEST_SPEED`.
     *
     * @param inputBuff the uncompressed bytes
     * @return the 4-byte little-endian length of `inputBuff`, followed by the zlib stream
     */
    def compress(inputBuff: Array[Byte]): Array[Byte] = {
      val inputSize = inputBuff.length
      val outputBuff = new mutable.ArrayBuilder.ofByte
      outputBuff += (inputSize & 0xff).toByte
      outputBuff += (inputSize >> 8 & 0xff).toByte
      outputBuff += (inputSize >> 16 & 0xff).toByte
      outputBuff += (inputSize >> 24 & 0xff).toByte

      val chunk = new Array[Byte](4096)
      try {
        deflater.setInput(inputBuff)
        deflater.finish()
        while (!deflater.finished) {
          val n = deflater.deflate(chunk)
          outputBuff ++= chunk.take(n)
        }
      } finally deflater.reset() // keep the deflater reusable even if deflation failed
      outputBuff.result()
    }

    /**
     * Reverses `compress`.
     *
     * @param inputBuff a frame produced by `compress`
     * @return the uncompressed bytes, exactly as many as the frame header declares
     * @throws IllegalArgumentException if the frame is shorter than its length prefix or
     *                                  declares a negative length
     * @throws java.util.zip.DataFormatException if the zlib data is corrupt or yields fewer
     *                                           bytes than the header declares
     */
    def decompress(inputBuff: Array[Byte]): Array[Byte] = {
      require(inputBuff.length >= LengthPrefixBytes,
        s"compressed frame too short: ${inputBuff.length} bytes, need at least $LengthPrefixBytes")
      // A Byte promotes to Int with sign extension, so mask with 0xff before shifting.
      val size: Int = (inputBuff(0) & 0xff) |
        (inputBuff(1) & 0xff) << 8 |
        (inputBuff(2) & 0xff) << 16 |
        (inputBuff(3) & 0xff) << 24
      require(size >= 0, s"compressed frame declares a negative length: $size")

      val outputBuff = new Array[Byte](size)
      try {
        inflater.setInput(inputBuff, LengthPrefixBytes, inputBuff.length - LengthPrefixBytes)
        // A single inflate() call is not guaranteed to fill the buffer, so loop until the
        // declared size is reached. All input is supplied up front, so a call that produces
        // nothing means the stream cannot yield more (finished or truncated). Stop in that
        // case instead of spinning.
        var written = 0
        var progressing = true
        while (progressing && written < size) {
          val n = inflater.inflate(outputBuff, written, size - written)
          written += n
          progressing = n > 0
        }
        if (written != size)
          throw new java.util.zip.DataFormatException(
            s"truncated compressed frame: expected $size bytes, inflated $written")
      } finally inflater.reset() // keep the inflater reusable even if inflation failed
      outputBuff
    }
  }
}
