package com.xebialabs.deployit.jetty;

import java.util.Collections;
import java.util.Enumeration;
import java.util.List;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.RequestMethod;

import com.google.common.collect.Lists;

/**
 * Request wrapper that forces CSV content negotiation when a configured query parameter is present.
 * <p>
 * If the wrapped request carries a parameter whose name equals the {@code queryParam} given to the
 * constructor (only its presence is checked; its value is ignored), this wrapper:
 * <ul>
 * <li>reports the {@code Accept} header (matched case-insensitively) as {@code text/csv}, from both
 * {@link #getHeader(String)} and {@link #getHeaders(String)}, and</li>
 * <li>reports the HTTP method as {@code GET}, whatever the original method was.</li>
 * </ul>
 * Without that parameter, every call is delegated to the wrapped request unchanged.
 */
public class HttpHeaderOverrideServletRequest extends HttpServletRequestWrapper {

	private static final Logger logger = LoggerFactory.getLogger(HttpHeaderOverrideServletRequest.class);

	/** Name of the header that is overridden (compared case-insensitively). */
	private static final String DEFAULT_ACCEPT_HEADER = "Accept";

	/** Accept header value reported while the override query parameter is present. */
	private static final String OVERRIDDEN_ACCEPT_VALUE = "text/csv";

	/** Default name for the override query parameter; this class does not apply it automatically. */
	public static final String DEFAULT_QUERY_PARAM = "acceptHeader";

	/** Name of the query parameter whose presence triggers the overrides. */
	private final String headerOverrideQueryParam;

	/** Lazily resolved HTTP method; cached after the first call to {@link #getMethod()}. */
	private transient String method;

	/**
	 * @param request    the request to wrap
	 * @param queryParam name of the query parameter whose presence enables the overrides
	 */
	public HttpHeaderOverrideServletRequest(HttpServletRequest request, String queryParam) {
		super(request);
		this.headerOverrideQueryParam = queryParam;
	}

	/**
	 * Returns {@code text/csv} as the only Accept value when the override applies; otherwise
	 * delegates to the wrapped request.
	 */
	@SuppressWarnings("unchecked")
	@Override
	public Enumeration<String> getHeaders(String name) {
		if (shouldOverrideAcceptHeader(name)) {
			return Collections.enumeration(Collections.singletonList(OVERRIDDEN_ACCEPT_VALUE));
		}
		return super.getHeaders(name);
	}

	/**
	 * Single-value counterpart of {@link #getHeaders(String)}. It is overridden as well so that
	 * both accessors report the same Accept value.
	 */
	@Override
	public String getHeader(String name) {
		if (shouldOverrideAcceptHeader(name)) {
			return OVERRIDDEN_ACCEPT_VALUE;
		}
		return super.getHeader(name);
	}

	/**
	 * @param name header name requested by the caller; may be {@code null}
	 * @return {@code true} if {@code name} is the Accept header and the override parameter is present
	 */
	private boolean shouldOverrideAcceptHeader(String name) {
		// Constant-first comparison so a null header name yields false instead of an NPE.
		return DEFAULT_ACCEPT_HEADER.equalsIgnoreCase(name) && hasOverrideQueryParam();
	}

	/** Returns whether the wrapped request carries the configured override query parameter. */
	@SuppressWarnings("unchecked")
	private boolean hasOverrideQueryParam() {
		// Cast needed for Servlet API versions whose getParameterNames() returns a raw Enumeration.
		List<String> requestParameterNames = Collections.list((Enumeration<String>) super.getParameterNames());
		return requestParameterNames.contains(headerOverrideQueryParam);
	}

	/** Returns the (possibly overridden) HTTP method, resolving it once and caching the result. */
	@Override
	public String getMethod() {
		if (method == null) {
			method = resolveMethod();
		}
		return method;
	}

	/**
	 * Determines the effective HTTP method.
	 *
	 * @return {@code GET} if the override query parameter is present, else the wrapped request's method
	 */
	protected String resolveMethod() {
		if (hasOverrideQueryParam()) {
			String resolved = RequestMethod.GET.toString();
			logger.debug("Overriding {} request to be a {} request because a {} queryParameter was present",
			        super.getMethod(), resolved, headerOverrideQueryParam);
			return resolved;
		}
		String resolved = super.getMethod();
		logger.debug("Not overriding a {} request", resolved);
		return resolved;
	}

}
