//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//

package org.eclipse.jetty.ee11.servlet;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.io.Writer;
import java.nio.charset.Charset;
import java.util.List;

import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.handler.ContextHandler;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.StringUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Servlet-aware error handler.
 * <p>
 * Extends the core {@link org.eclipse.jetty.server.handler.ErrorHandler} so that:
 * <ul>
 *     <li>errors raised before the request has entered the {@code ServletChannel} are deferred
 *     (see {@link #writeError(Request, Response, Callback, int)});</li>
 *     <li>subclasses that also implement {@link ErrorPageMapper} can dispatch to a mapped error page;</li>
 *     <li>generated error responses are completed through the {@code ServletChannel}.</li>
 * </ul>
 * Requests that are not associated with a {@link ServletContextRequest} are handled by the core implementation.
 */
public class ErrorHandler extends org.eclipse.jetty.server.handler.ErrorHandler
{
    private static final Logger LOG = LoggerFactory.getLogger(ErrorHandler.class);

    /**
     * Creates an error handler with origin reporting, stack traces and the message in the page title enabled.
     */
    public ErrorHandler()
    {
        setShowOrigin(true);
        setShowStacks(true);
        setShowMessageInTitle(true);
    }

    /**
     * Writes an error, deferring it if the request has not yet entered the servlet channel.
     * <p>
     * Error-page mapping is only possible from within the {@code ServletChannel}. If the channel has not been
     * entered yet, the status is set on the response and stored in the {@code ERROR_STATUS} request attribute,
     * and {@code false} is returned without writing anything. The error can then be handled once the request
     * reaches the servlet channel. Requests without a {@link ServletContextRequest} are delegated to the core
     * implementation.
     *
     * @param request the request in error
     * @param response the response for the request
     * @param callback the callback, passed through to the core implementation when the error is not deferred
     * @param code the HTTP status code of the error
     * @return {@code false} if the error was deferred, otherwise the result of the core implementation
     */
    @Override
    public boolean writeError(Request request, Response response, Callback callback, int code)
    {
        ServletContextRequest servletContextRequest = Request.asInContext(request, ServletContextRequest.class);

        // If we have not entered the servlet channel we should trigger a sendError for when we do enter the servlet channel.
        if (servletContextRequest != null && !hasEnteredServletChannel(servletContextRequest))
        {
            response.setStatus(code);
            request.setAttribute(ERROR_STATUS, code);
            return false;
        }

        return super.writeError(request, response, callback, code);
    }

    /**
     * Handles an error request.
     * <p>
     * If no error page should be generated for the request method, the callback is succeeded immediately.
     * Otherwise, when this handler also implements {@link ErrorPageMapper} and the request has entered the
     * servlet channel, a mapped error page is looked up. If one is found, the request is dispatched to it, with
     * the context's request-initialized/destroyed notifications around the dispatch.
     * <p>
     * If no error page is dispatched, the default error response is generated. That also happens when the
     * dispatch throws a {@link ServletException} before the response is committed. The response uses the
     * response status and either the {@code ERROR_MESSAGE} attribute or the standard reason phrase.
     * Requests without a {@link ServletContextRequest} are delegated to the core implementation.
     *
     * @param request the request in error
     * @param response the response to write the error to
     * @param callback completed (or failed) when error handling is done
     * @return {@code true}, unless the core implementation returns otherwise for non-servlet requests
     * @throws Exception if dispatching to the error page or generating the response fails
     */
    @Override
    public boolean handle(Request request, Response response, Callback callback) throws Exception
    {
        ServletContextRequest servletContextRequest = Request.asInContext(request, ServletContextRequest.class);
        // Without a servlet context there is no servlet channel nor error-page mapping: use the core handling.
        if (servletContextRequest == null)
            return super.handle(request, response, callback);

        if (!errorPageForMethod(request.getMethod()))
        {
            callback.succeeded();
            return true;
        }

        generateCacheControl(response);

        HttpServletRequest httpServletRequest = servletContextRequest.getServletApiRequest();
        HttpServletResponse httpServletResponse = servletContextRequest.getHttpServletResponse();
        ServletContextHandler contextHandler = servletContextRequest.getServletContext().getServletContextHandler();
        // NOTE(review): generateCacheControl(response) above may already set this header; confirm whether redundant.
        String cacheControl = getCacheControl();
        if (cacheControl != null)
            response.getHeaders().put(HttpHeader.CACHE_CONTROL.asString(), cacheControl);

        // Look for an error page dispatcher
        // This logic really should be in ErrorPageErrorHandler, but some implementations extend ErrorHandler
        // and implement ErrorPageMapper directly, so we do this here in the base class.
        ServletContextHandler.ServletScopedContext context = servletContextRequest.getErrorContext();
        Integer errorStatus = (Integer)request.getAttribute(ERROR_STATUS);
        Throwable errorCause = (Throwable)request.getAttribute(ERROR_EXCEPTION);

        // Error page mapping can only be supported from within the ServletChannel handling.
        // If an error that may be mapped to an error page occurs before entering ServletChannel,
        // then the ErrorHandler#writeError(...) method should be used to delay
        // invoking sendError until the handling is within the ServletChannel.
        if (this instanceof ErrorPageMapper mapper && hasEnteredServletChannel(servletContextRequest))
        {
            ErrorPageMapper.ErrorPage errorPage = mapper.getErrorPage(errorStatus, errorCause);
            if (LOG.isDebugEnabled())
                LOG.debug("{} {} {} -> {}", context, errorStatus, errorCause, errorPage);
            if (errorPage != null && context.getServletContext().getRequestDispatcher(errorPage.errorPage) instanceof Dispatcher errorDispatcher)
            {
                try
                {
                    try
                    {
                        mapper.prepare(errorPage, httpServletRequest, httpServletResponse);
                        contextHandler.requestInitialized(servletContextRequest, httpServletRequest);
                        errorDispatcher.error(httpServletRequest, httpServletResponse);
                    }
                    finally
                    {
                        contextHandler.requestDestroyed(servletContextRequest, httpServletRequest);
                    }
                    callback.succeeded();
                    return true;
                }
                catch (ServletException e)
                {
                    if (LOG.isDebugEnabled())
                        LOG.debug("Unable to call error dispatcher", e);
                    if (response.isCommitted())
                    {
                        callback.failed(e);
                        return true;
                    }
                    // Not committed: fall through and generate the default error response below.
                }
            }
        }

        // Attributes are re-read here because a failed error-page dispatch may have run in between.
        String message = (String)request.getAttribute(ERROR_MESSAGE);
        if (message == null)
            message = HttpStatus.getMessage(response.getStatus());
        generateResponse(request, response, response.getStatus(), message, (Throwable)request.getAttribute(ERROR_EXCEPTION), callback);
        callback.succeeded();
        return true;
    }

    /**
     * Generates the error response for an acceptable content type.
     * <p>
     * Delegates generation to the core implementation. When it produced a response and the request is a
     * servlet request, the error response is then sent and completed through the {@code ServletChannel}.
     *
     * @param request the request in error
     * @param response the response to write the error to
     * @param callback the callback passed through to the core implementation
     * @param contentType the candidate content type
     * @param charsets the acceptable charsets
     * @param code the HTTP status code
     * @param message the error message
     * @param cause the error cause, possibly {@code null}
     * @return {@code true} if the core implementation generated a response for {@code contentType}
     * @throws IOException if writing the response fails
     */
    @Override
    protected boolean generateAcceptableResponse(Request request, Response response, Callback callback, String contentType, List<Charset> charsets, int code, String message, Throwable cause) throws IOException
    {
        boolean result = super.generateAcceptableResponse(request, response, callback, contentType, charsets, code, message, cause);
        if (result)
        {
            // Do an asynchronous completion
            ServletContextRequest servletContextRequest = Request.as(request, ServletContextRequest.class);
            if (servletContextRequest != null)
                servletContextRequest.getServletChannel().sendErrorResponseAndComplete();
        }
        return result;
    }

    /**
     * Writes the HTML body of the error message.
     * <p>
     * Writes an {@code <h2>} heading with the status code and the sanitized message (when it differs from the
     * status). That is followed by a table with URI, STATUS, MESSAGE and SERVLET (the error origin, if any)
     * rows, and one CAUSED BY row per throwable in the cause chain.
     *
     * @param request the request in error
     * @param writer the writer to write HTML to
     * @param code the HTTP status code
     * @param message the error message, possibly {@code null}
     * @param cause the error cause, possibly {@code null}
     * @param uri the request URI to report
     * @throws IOException if writing fails
     */
    @Override
    protected void writeErrorHtmlMessage(Request request, Writer writer, int code, String message, Throwable cause, String uri) throws IOException
    {
        writer.write("<h2>HTTP ERROR ");
        String status = Integer.toString(code);
        writer.write(status);
        if (message != null && !message.equals(status))
        {
            writer.write(' ');
            writer.write(StringUtil.sanitizeXmlString(message));
        }
        writer.write("</h2>\n");
        writer.write("<table>\n");
        htmlRow(writer, "URI", uri);
        htmlRow(writer, "STATUS", status);
        htmlRow(writer, "MESSAGE", message);
        writeErrorOrigin((String)request.getAttribute(ERROR_ORIGIN), (o) ->
        {
            try
            {
                htmlRow(writer, "SERVLET", o);
            }
            catch (IOException x)
            {
                // The consumer cannot throw checked exceptions, so tunnel the IOException.
                throw new UncheckedIOException(x);
            }
        });

        while (cause != null)
        {
            htmlRow(writer, "CAUSED BY", cause);
            cause = cause.getCause();
        }
        writer.write("</table>\n");
    }

    /**
     * Returns whether handling of the request has entered the {@code ServletChannel}.
     * A channel callback is only present once the channel is handling the request.
     */
    private static boolean hasEnteredServletChannel(ServletContextRequest servletContextRequest)
    {
        return servletContextRequest.getServletChannel().getCallback() != null;
    }

    /**
     * Maps an error status and/or throwable to an error page.
     * <p>
     * When an {@code ErrorHandler} also implements this interface, {@code ErrorHandler.handle(...)} uses it to
     * dispatch to an error page from within the servlet channel.
     */
    public interface ErrorPageMapper
    {
        /**
         * Identifies which lookup produced an {@link ErrorPage} match.
         */
        enum PageLookupTechnique
        {
            THROWABLE, STATUS_CODE, GLOBAL
        }

        /**
         * A resolved error page.
         *
         * @param errorPage the path passed to {@code getRequestDispatcher(...)} to dispatch the error
         * @param match the lookup technique that matched
         * @param error presumably the throwable that was matched, if any (TODO confirm with implementations)
         * @param cause presumably the cause of {@code error} that was matched, if any (TODO confirm)
         * @param matchedClass presumably the exception class whose mapping matched, if any (TODO confirm)
         */
        record ErrorPage(String errorPage, PageLookupTechnique match, Throwable error, Throwable cause, Class<?> matchedClass)
        {
        }

        /**
         * Looks up the error page for an error.
         *
         * @param errorStatusCode the error status from the request attributes, possibly {@code null}
         * @param error the error from the request attributes, possibly {@code null}
         * @return the error page to dispatch to, or {@code null} if there is none
         */
        ErrorPage getErrorPage(Integer errorStatusCode, Throwable error);

        /**
         * Hook invoked just before dispatching to {@code errorPage}; the default implementation does nothing.
         *
         * @param errorPage the error page about to be dispatched to
         * @param request the servlet request
         * @param response the servlet response
         */
        default void prepare(ErrorPage errorPage, HttpServletRequest request, HttpServletResponse response)
        {}
    }

    /**
     * Returns the error handler of {@code context}, falling back to the one of {@code server}.
     *
     * @param server the server, possibly {@code null}
     * @param context the context, possibly {@code null}
     * @return the applicable error handler, or {@code null} if neither provides one
     */
    public static Request.Handler getErrorHandler(Server server, ContextHandler context)
    {
        Request.Handler errorHandler = null;
        if (context != null)
            errorHandler = context.getErrorHandler();
        if (errorHandler == null && server != null)
            errorHandler = server.getErrorHandler();
        return errorHandler;
    }
}
