/*
 * Copyright 2014-2026 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
 */

package io.ktor.websocket

import io.ktor.util.cio.*
import io.ktor.util.logging.*
import io.ktor.utils.io.*
import io.ktor.utils.io.core.*
import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.channels.*
import kotlinx.io.*
import kotlin.coroutines.*
import kotlin.time.*
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds

/** Logger used for trace output of the default WebSocket session implementation. */
internal val LOGGER: Logger = KtorSimpleLogger("io.ktor.websocket.WebSocket")

/**
 * Special value for [DefaultWebSocketSession.pingIntervalMillis] that turns the pinger off.
 * Any interval that is not greater than this value disables pinging.
 *
 * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.websocket.PINGER_DISABLED)
 *
 * @see DefaultWebSocketSession.pingIntervalMillis
 */
public const val PINGER_DISABLED: Long = 0L

/**
 * A [WebSocketSession] that additionally takes care of ping-pong exchange, ping timeouts,
 * and fills [closeReason] when the session ends.
 *
 * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.websocket.DefaultWebSocketSession)
 */
public interface DefaultWebSocketSession : WebSocketSession {

    /**
     * Interval between outgoing pings in milliseconds; [PINGER_DISABLED] turns pinging off.
     * Pongs received from the peer are processed regardless of this value.
     *
     * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.websocket.DefaultWebSocketSession.pingIntervalMillis)
     */
    public var pingIntervalMillis: Long

    /**
     * How long, in milliseconds, to wait for a pong after a ping before the session is terminated.
     * Ignored while [pingIntervalMillis] is [PINGER_DISABLED].
     *
     * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.websocket.DefaultWebSocketSession.timeoutMillis)
     */
    public var timeoutMillis: Long

    /**
     * Reason this session was closed with. Its value may be `null` when the session ended without a close
     * reason, for example because the connection failed.
     *
     * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.websocket.DefaultWebSocketSession.closeReason)
     */
    public val closeReason: Deferred<CloseReason?>

    /**
     * Begins the WebSocket conversation: starts processing of incoming and outgoing frames.
     *
     * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.websocket.DefaultWebSocketSession.start)
     *
     * @param negotiatedExtensions extensions negotiated for this session; they are applied to its frames.
     */
    @InternalAPI
    public fun start(negotiatedExtensions: List<WebSocketExtension<*>> = emptyList())
}

/**
 * Wraps a raw [WebSocketSession] into a [DefaultWebSocketSession].
 *
 * Incoming frames are buffered in an unlimited channel. The outgoing channel uses the platform-specific
 * capacity (suspending on overflow) when one is defined, and is unlimited otherwise.
 *
 * @param session raw [WebSocketSession] to wrap; must not itself be a [DefaultWebSocketSession].
 * @param pingIntervalMillis interval between pings in milliseconds, or [PINGER_DISABLED] to disable pinging.
 * @param timeoutMillis time in milliseconds to wait for a pong reply.
 * @throws IllegalArgumentException if [session] is already a [DefaultWebSocketSession].
 *
 * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.websocket.DefaultWebSocketSession)
 */
public fun DefaultWebSocketSession(
    session: WebSocketSession,
    pingIntervalMillis: Long = PINGER_DISABLED,
    timeoutMillis: Long = 15_000L,
): DefaultWebSocketSession {
    require(session !is DefaultWebSocketSession) { "Cannot wrap other DefaultWebSocketSession" }

    val outgoingConfig = OUTGOING_CHANNEL_CONFIG ?: ChannelConfig.UNLIMITED
    return DefaultWebSocketSessionImpl(
        raw = session,
        pingIntervalMillis = pingIntervalMillis,
        timeoutMillis = timeoutMillis,
        incomingFramesConfig = ChannelConfig.UNLIMITED,
        outgoingFramesConfig = outgoingConfig,
    )
}

/**
 * Wraps a raw [WebSocketSession] into a [DefaultWebSocketSession] whose frame channels follow [channelsConfig].
 *
 * @param session raw [WebSocketSession] to wrap; must not itself be a [DefaultWebSocketSession].
 * @param pingIntervalMillis interval between pings in milliseconds, or [PINGER_DISABLED] to disable pinging.
 * @param timeoutMillis time in milliseconds to wait for a pong reply.
 * @param channelsConfig configuration of the incoming and outgoing frame channels.
 * @throws IllegalArgumentException if [session] is already a [DefaultWebSocketSession].
 *
 * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.websocket.DefaultWebSocketSession)
 */
public fun DefaultWebSocketSession(
    session: WebSocketSession,
    pingIntervalMillis: Long = PINGER_DISABLED,
    timeoutMillis: Long = 15_000L,
    channelsConfig: WebSocketChannelsConfig = WebSocketChannelsConfig.UNLIMITED,
): DefaultWebSocketSession {
    require(session !is DefaultWebSocketSession) { "Cannot wrap other DefaultWebSocketSession" }

    return DefaultWebSocketSessionImpl(
        raw = session,
        pingIntervalMillis = pingIntervalMillis,
        timeoutMillis = timeoutMillis,
        incomingFramesConfig = channelsConfig.incoming,
        outgoingFramesConfig = channelsConfig.outgoing,
    )
}

private val IncomingProcessorCoroutineName: CoroutineName = CoroutineName("ws-incoming-processor")
private val OutgoingProcessorCoroutineName: CoroutineName = CoroutineName("ws-outgoing-processor")

/** Close reason echoed back to the peer when its Close frame carries no readable reason. */
private val NORMAL_CLOSE: CloseReason = CloseReason(CloseReason.Codes.NORMAL, "OK")

/**
 * Suspending outgoing channel configuration built from [OUTGOING_CHANNEL_CAPACITY],
 * or `null` when the platform defines no capacity.
 */
private val OUTGOING_CHANNEL_CONFIG: ChannelConfig? = OUTGOING_CHANNEL_CAPACITY?.let { capacity ->
    ChannelConfig(capacity = capacity, onOverflow = ChannelOverflow.SUSPEND)
}

/**
 * A default WebSocket session implementation that handles ping-pongs, close sequence, and frame fragmentation.
 *
 * Frames read from [raw] are handled by an incoming processor coroutine:
 * - pings are forwarded to the ponger;
 * - pongs are forwarded to the active pinger;
 * - fragmented data frames are reassembled;
 * - negotiated extensions are applied, and the result is published to [incoming].
 *
 * Frames written to [outgoing] are handled by an outgoing processor coroutine. It applies the extensions to text
 * and binary frames before forwarding them to [raw]; a [Frame.Close] starts the close sequence instead.
 *
 * @param raw the wrapped session that actually transfers frames.
 * @param pingIntervalMillis initial ping interval in milliseconds, or [PINGER_DISABLED].
 * @param timeoutMillis initial time in milliseconds to wait for a pong reply.
 * @param incomingFramesConfig configuration of the channel behind [incoming].
 * @param outgoingFramesConfig configuration of the channel behind [outgoing].
 */
@OptIn(InternalAPI::class)
internal class DefaultWebSocketSessionImpl(
    private val raw: WebSocketSession,
    pingIntervalMillis: Long,
    timeoutMillis: Long,
    incomingFramesConfig: ChannelConfig,
    outgoingFramesConfig: ChannelConfig,
) : DefaultWebSocketSession, WebSocketSession {
    // Input side of the currently running pinger (null when pinging is disabled or the session is closed).
    // Pongs received from the peer are forwarded here.
    private val pinger = atomic<SendChannel<Frame.Pong>?>(null)

    // Completed by sendCloseSequence with the final close reason; exposed as [closeReason].
    private val closeReasonRef = CompletableDeferred<CloseReason>()

    // This session's own job. It replaces the Job of raw's context, so coroutines launched in this
    // session (the frame processors) are children of this job rather than of raw's job.
    private val context = Job()
    override val coroutineContext: CoroutineContext =
        raw.coroutineContext.minusKey(Job) + context + CoroutineName("ws-default")

    // Channel behind [incoming]: reassembled frames after incoming extensions have been applied.
    private val filtered = Channel.from<Frame>(incomingFramesConfig)

    // Channel behind [outgoing]: frames written by the user, before outgoing extensions are applied.
    private val outgoingToBeProcessed = Channel.from<Frame>(outgoingFramesConfig)

    // Flipped once by tryClose() so that the close sequence runs at most once.
    private val closed: AtomicBoolean = atomic(false)

    private val _extensions: MutableList<WebSocketExtension<*>> = mutableListOf()

    // Guards start() against being invoked more than once.
    private val started = atomic(false)

    override val incoming: ReceiveChannel<Frame> get() = filtered

    override val outgoing: SendChannel<Frame> get() = outgoingToBeProcessed

    override val extensions: List<WebSocketExtension<*>>
        get() = _extensions

    override var masking: Boolean
        get() = raw.masking
        set(value) {
            raw.masking = value
        }

    override var maxFrameSize: Long
        get() = raw.maxFrameSize
        set(value) {
            raw.maxFrameSize = value
        }

    // Changing the interval restarts (or stops) the pinger immediately.
    override var pingIntervalMillis: Long = pingIntervalMillis
        set(newValue) {
            field = newValue
            runOrCancelPinger()
        }

    // Changing the timeout restarts the pinger so the new value takes effect.
    override var timeoutMillis: Long = timeoutMillis
        set(newValue) {
            field = newValue
            runOrCancelPinger()
        }

    override val closeReason: Deferred<CloseReason?> = closeReasonRef

    /**
     * Registers [negotiatedExtensions], starts the pinger (if enabled) and both frame processors.
     * Once both processors have finished, this session's job is cancelled.
     *
     * @throws IllegalStateException if the session was already started.
     */
    @OptIn(InternalAPI::class)
    override fun start(negotiatedExtensions: List<WebSocketExtension<*>>) {
        if (!started.compareAndSet(false, true)) {
            error("WebSocket session $this is already started.")
        }

        LOGGER.trace {
            "Starting default WebSocketSession($this) " +
                "with negotiated extensions: ${negotiatedExtensions.joinToString()}"
        }

        _extensions.addAll(negotiatedExtensions)
        runOrCancelPinger()

        val incomingJob = runIncomingProcessor(ponger(outgoing))
        val outgoingJob = runOutgoingProcessor()

        launch {
            incomingJob.join()
            outgoingJob.join()

            context.cancel()
        }
    }

    /**
     * Starts the close sequence with [CloseReason.Codes.GOING_AWAY] and [message], bypassing [outgoing].
     */
    suspend fun goingAway(message: String = "Server is going down") {
        sendCloseSequence(CloseReason(CloseReason.Codes.GOING_AWAY, message))
    }

    override suspend fun flush() {
        raw.flush()
    }

    /** Cancels this session's job and the wrapped [raw] session. */
    @Deprecated(
        "Use cancel() instead.",
        ReplaceWith("cancel()", "kotlinx.coroutines.cancel"),
        level = DeprecationLevel.ERROR
    )
    override fun terminate() {
        context.cancel()
        raw.cancel()
    }

    /**
     * Launches the coroutine that consumes [raw]'s incoming frames until a Close frame arrives or the channel ends.
     *
     * - A received Close frame is echoed to [outgoing], which makes the outgoing processor run the close sequence.
     * - If the channel ends without a Close frame, the session is closed with [CloseReason.Codes.CLOSED_ABNORMALLY].
     * - Data frames are reassembled from fragments, checked against [maxFrameSize], and passed through
     *   the incoming extensions before being published to [incoming].
     *
     * @param ponger channel that received pings are forwarded to (created by `ponger(outgoing)` in [start]).
     */
    @OptIn(InternalAPI::class)
    private fun runIncomingProcessor(ponger: SendChannel<Frame.Ping>): Job = launch(
        IncomingProcessorCoroutineName + Dispatchers.Unconfined
    ) {
        // First fragment of a message being reassembled, and the bytes of all its fragments so far.
        var firstFrame: Frame? = null
        var frameBody: Sink? = null
        var closeFramePresented = false
        try {
            @OptIn(DelicateCoroutinesApi::class)
            raw.incoming.consumeEach { frame ->
                // Inside `launch`, a bare `this` is the launched coroutine, so refer to the session explicitly.
                LOGGER.trace { "WebSocketSession(${this@DefaultWebSocketSessionImpl}) receiving frame $frame" }
                when (frame) {
                    is Frame.Close -> {
                        if (!outgoing.isClosedForSend) {
                            outgoing.send(Frame.Close(frame.readReason() ?: NORMAL_CLOSE))
                        }
                        closeFramePresented = true
                        return@launch
                    }

                    is Frame.Pong -> pinger.value?.send(frame)
                    is Frame.Ping -> ponger.send(frame)
                    else -> {
                        checkMaxFrameSize(frameBody, frame)

                        if (!frame.fin) {
                            // Non-final fragment: remember the header of the first one and buffer the payload.
                            if (firstFrame == null) {
                                firstFrame = frame
                            }
                            if (frameBody == null) {
                                frameBody = BytePacketBuilder()
                            }

                            frameBody.writeFully(frame.data)
                            return@consumeEach
                        }

                        if (firstFrame == null) {
                            // Unfragmented frame: deliver as is.
                            filtered.send(processIncomingExtensions(frame))
                            return@consumeEach
                        }

                        // Final fragment: frameBody is non-null whenever firstFrame is non-null.
                        frameBody!!.writeFully(frame.data)
                        val defragmented = Frame.byType(
                            fin = true,
                            firstFrame.frameType,
                            frameBody.build().readByteArray(),
                            firstFrame.rsv1,
                            firstFrame.rsv2,
                            firstFrame.rsv3
                        )

                        firstFrame = null
                        filtered.send(processIncomingExtensions(defragmented))
                    }
                }
            }
        } catch (_: ClosedSendChannelException) {
        } catch (cause: Throwable) {
            ponger.close()
            filtered.close(cause)
        } finally {
            ponger.close()
            frameBody?.close()
            filtered.close()

            if (!closeFramePresented) {
                close(CloseReason(CloseReason.Codes.CLOSED_ABNORMALLY, "Connection was closed without close frame"))
            }
        }
    }

    /**
     * Launches the coroutine that forwards frames from [outgoing] to [raw] via [outgoingProcessorLoop].
     *
     * It starts undispatched, so it runs in the caller's frame until its first suspension.
     * - On cancellation, a NORMAL close sequence is attempted.
     * - On an unexpected failure, the outgoing channel is cancelled with the cause and [raw] is closed exceptionally.
     * - In every case, the outgoing channel is cancelled and [raw] is closed afterwards.
     */
    private fun runOutgoingProcessor(): Job = launch(
        OutgoingProcessorCoroutineName + Dispatchers.Unconfined,
        start = CoroutineStart.UNDISPATCHED
    ) {
        try {
            outgoingProcessorLoop()
        } catch (_: ClosedSendChannelException) {
        } catch (_: ClosedReceiveChannelException) {
        } catch (_: CancellationException) {
            sendCloseSequence(CloseReason(CloseReason.Codes.NORMAL, ""))
        } catch (_: ChannelIOException) {
        } catch (cause: Throwable) {
            outgoingToBeProcessed.cancel(CancellationException("Failed to send frame", cause))
            raw.closeExceptionally(cause)
            return@launch
        } finally {
            outgoingToBeProcessed.cancel()
            raw.close()
        }
    }

    /**
     * Sends every frame from [outgoingToBeProcessed] to [raw]; text and binary frames go through the outgoing
     * extensions first. A Close frame runs the close sequence and ends the loop.
     */
    private suspend fun outgoingProcessorLoop() {
        for (frame in outgoingToBeProcessed) {
            LOGGER.trace { "Sending $frame from session $this" }
            val processedFrame: Frame = when (frame) {
                is Frame.Close -> {
                    sendCloseSequence(frame.readReason())
                    break
                }

                is Frame.Text,
                is Frame.Binary -> processOutgoingExtensions(frame)

                else -> frame
            }

            raw.outgoing.send(processedFrame)
        }
    }

    /**
     * Runs the close sequence at most once. It performs these steps in order:
     * - completes this session's job;
     * - stops the pinger;
     * - sends a Close frame with [reason] (NORMAL with an empty message when `null`) to [raw],
     *   unless the code is CLOSED_ABNORMALLY;
     * - completes [closeReason].
     *
     * When [exception] is given, [outgoing] and [incoming] are also closed with it.
     */
    @OptIn(InternalAPI::class)
    private suspend fun sendCloseSequence(reason: CloseReason?, exception: Throwable? = null) {
        if (!tryClose()) return
        LOGGER.trace { "Sending Close Sequence for session $this with reason $reason and exception $exception" }
        // don't cancel because sendCloseSequence is invoked inside a child coroutine of this context
        context.complete()

        val reasonToSend = reason ?: CloseReason(CloseReason.Codes.NORMAL, "")
        try {
            // closed is already true here, so this only shuts the current pinger down
            runOrCancelPinger()
            if (reasonToSend.code != CloseReason.Codes.CLOSED_ABNORMALLY.code) {
                raw.outgoing.send(Frame.Close(reasonToSend))
            }
        } finally {
            closeReasonRef.complete(reasonToSend)

            if (exception != null) {
                outgoingToBeProcessed.close(exception)
                filtered.close(exception)
            }
        }
    }

    /** Marks the session as closed; returns `true` only for the first caller. */
    private fun tryClose(): Boolean = closed.compareAndSet(false, true)

    /**
     * Replaces the current pinger according to [pingIntervalMillis] and [timeoutMillis].
     * A new pinger is started only when the interval is positive and the session is not closed.
     * Otherwise the previous pinger (if any) is just closed.
     */
    private fun runOrCancelPinger() {
        val interval = pingIntervalMillis

        val newPinger: SendChannel<Frame.Pong>? = when {
            closed.value -> null
            interval > PINGER_DISABLED -> pinger(raw.outgoing, interval, timeoutMillis) {
                sendCloseSequence(it, IOException("Ping timeout"))
            }

            else -> null
        }

        // pinger is always lazy so we publish it first and then start it by sending EmptyPong
        // otherwise it may send ping before it get published so corresponding pong will not be dispatched to pinger
        // that will cause it to terminate connection on timeout
        pinger.getAndSet(newPinger)?.close()

        // it is safe here to send dummy pong because pinger will ignore it
        newPinger?.trySend(EmptyPong)?.isSuccess

        // the session may have been closed while the new pinger was being created: run again to shut it down
        if (closed.value && newPinger != null) {
            runOrCancelPinger()
        }
    }

    /**
     * Closes the session with [CloseReason.Codes.TOO_BIG] and throws [FrameTooBigException] when the size of
     * [frame] plus the bytes already buffered in [packet] exceeds [maxFrameSize].
     */
    private suspend fun checkMaxFrameSize(
        packet: Sink?,
        frame: Frame
    ) {
        val size = frame.data.size + (packet?.size ?: 0)
        if (size > maxFrameSize) {
            packet?.close()
            close(CloseReason(CloseReason.Codes.TOO_BIG, "Frame is too big: $size. Max size is $maxFrameSize"))
            throw FrameTooBigException(size.toLong())
        }
    }

    /** Applies the negotiated extensions to an incoming frame, in the order they were negotiated. */
    private fun processIncomingExtensions(frame: Frame): Frame =
        extensions.fold(frame) { current, extension -> extension.processIncomingFrame(current) }

    /** Applies the negotiated extensions to an outgoing frame, in the order they were negotiated. */
    private fun processOutgoingExtensions(frame: Frame): Frame =
        extensions.fold(frame) { current, extension -> extension.processOutgoingFrame(current) }

    companion object {
        // Empty pong used only to kick off a freshly published pinger (see runOrCancelPinger).
        private val EmptyPong = Frame.Pong(ByteArray(0), NonDisposableHandle)
    }
}

/**
 * Ping interval as a [Duration], or `null` when the pinger is disabled. Pongs are handled regardless of this value.
 *
 * The value is stored in whole milliseconds, so a duration shorter than one millisecond truncates to
 * [PINGER_DISABLED] and disables the pinger.
 *
 * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.websocket.pingInterval)
 */
public inline var DefaultWebSocketSession.pingInterval: Duration?
    get() {
        val millis = pingIntervalMillis
        return if (millis > PINGER_DISABLED) millis.milliseconds else null
    }
    set(newDuration) {
        pingIntervalMillis = if (newDuration == null) PINGER_DISABLED else newDuration.inWholeMilliseconds
    }

/**
 * Time to wait for a pong reply to a ping before the session is terminated, as a [Duration].
 * Has no effect while [pingInterval] is `null` (pinger disabled). Stored in whole milliseconds.
 *
 * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.websocket.timeout)
 */
public inline var DefaultWebSocketSession.timeout: Duration
    get() = timeoutMillis.toDuration(DurationUnit.MILLISECONDS)
    set(newDuration) {
        val millis = newDuration.inWholeMilliseconds
        timeoutMillis = millis
    }

// TODO: drop in version 4, pass channel config only
/**
 * Platform-specific capacity for the outgoing frames channel, or `null` for no fixed capacity.
 *
 * In this file, a non-null value becomes a channel configuration that suspends senders on overflow.
 * That configuration is used by the [DefaultWebSocketSession] overload without a channels config.
 * A `null` value makes that overload fall back to an unlimited outgoing channel.
 */
internal expect val OUTGOING_CHANNEL_CAPACITY: Int?
