package com.xebialabs.deployit.security;

import static com.google.common.collect.Maps.filterKeys;

import java.io.IOException;
import java.security.Principal;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import javax.naming.Context;
import javax.naming.InvalidNameException;
import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.directory.SearchControls;
import javax.naming.directory.SearchResult;
import javax.naming.ldap.InitialLdapContext;
import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.LoginException;
import javax.security.auth.spi.LoginModule;

import org.apache.jackrabbit.api.security.principal.PrincipalIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Predicate;

/**
 * JAAS login module that authenticates a user against an LDAP directory.
 * <p>
 * Login proceeds in steps:
 * <ol>
 * <li>obtain the username and password through the {@link CallbackHandler};</li>
 * <li>open a connection to the configured LDAP server;</li>
 * <li>search for the user's entry with the configured filter;</li>
 * <li>re-bind as that entry using the supplied password.</li>
 * </ol>
 * <p>
 * Supported options:
 * <ul>
 * <li>{@code userProvider} - LDAP URL, used as {@link Context#PROVIDER_URL};</li>
 * <li>{@code userFilter} - search filter; every {@code {USERNAME}} token is replaced by the
 * RFC 4515-escaped username;</li>
 * <li>{@code useSSL} - {@code "true"} sets {@link Context#SECURITY_PROTOCOL} to {@code ssl},
 * any other value removes it;</li>
 * <li>any option whose name contains a {@code '.'} is copied verbatim into the JNDI environment.</li>
 * </ul>
 */
public class LdapLoginModule implements LoginModule {
	private static final Logger logger = LoggerFactory.getLogger(LdapLoginModule.class);

	private static final String USER_PROVIDER = "userProvider";
	private static final String USER_FILTER = "userFilter";

	// Used for the username token replacement
	private static final String USERNAME_TOKEN = "{USERNAME}";
	private static final Pattern USERNAME_PATTERN = Pattern.compile("\\{USERNAME\\}");
	private static final String USE_SSL = "useSSL";

	private Subject subject;
	private CallbackHandler callbackHandler;

	/** JNDI environment used to open the LDAP connection. */
	private final Hashtable<String, Object> ldap = new Hashtable<String, Object>();
	private String ldapUrl;
	private String userFilter;
	private SearchControls constraints;

	/** Per-attempt state; {@code null} until {@link #login()} is first called. */
	private LdapLoginState state;

	@Override
	public void initialize(Subject subject, CallbackHandler callbackHandler, Map<String, ?> sharedState, Map<String, ?> options) {
		this.subject = subject;
		this.callbackHandler = callbackHandler;

		ldap.put(Context.INITIAL_CONTEXT_FACTORY, LdapContextFactory.class.getName());
		// Options with a '.' in their name are treated as raw JNDI environment properties.
		ldap.putAll(filterKeys(options, new Predicate<String>() {
			@Override
			public boolean apply(String input) {
				return input.contains(".");
			}
		}));

		if (options.containsKey(USER_PROVIDER)) {
			ldapUrl = (String) options.get(USER_PROVIDER);
			ldap.put(Context.PROVIDER_URL, ldapUrl);
		}

		if (options.containsKey(USER_FILTER)) {
			userFilter = (String) options.get(USER_FILTER);
			if (!userFilter.contains(USERNAME_TOKEN)) {
				// Without the token the entered username is ignored when searching.
				logger.warn("userFilter {} does not contain {}; the username will not be part of the search", userFilter, USERNAME_TOKEN);
			}
			constraints = new SearchControls();
			constraints.setSearchScope(SearchControls.SUBTREE_SCOPE);
			constraints.setReturningAttributes(new String[0]); // return no attrs
		}

		if (options.containsKey(USE_SSL)) {
			if (getBooleanFromOptions(options, USE_SSL)) {
				ldap.put(Context.SECURITY_PROTOCOL, "ssl");
			} else {
				ldap.remove(Context.SECURITY_PROTOCOL);
			}
		}
	}

	/** Returns {@code true} only if {@code bool} is present in {@code options} and parses as {@code "true"}. */
	private boolean getBooleanFromOptions(Map<String, ?> options, String bool) {
		return options.containsKey(bool) && Boolean.valueOf((String) options.get(bool));
	}

	/**
	 * Authenticates the user against LDAP.
	 *
	 * @return {@code true} when authentication succeeded
	 * @throws LoginException if credentials cannot be obtained, the user is not found or the bind fails
	 */
	@Override
	public boolean login() throws LoginException {
		state = new LdapLoginState();
		logger.debug("Connection to LDAP {}", ldapUrl);

		boolean succeeded = false;
		try {
			logger.debug("Retrieving username and password");
			retrieveUsernamePassword();
			tryConnect();
			searchUser();
			authenticateUser();
			createPrincipals();
			state.loggedIn = true;
			succeeded = true;
			logger.debug("Logged in successfully");
		} finally {
			if (!succeeded) {
				// Close the connection and wipe the password on any failure, checked or unchecked.
				state.clear();
			}
		}

		return true;
	}

	private void createPrincipals() throws FailedLoginException {
		state.principal = new LdapPrincipalProvider.SimplePrincipal(state.username);
		try {
			state.ldapPrincipal = new LdapPrincipalProvider.LdapPrincipal(state.userDN);
		} catch (InvalidNameException e) {
			throw new FailedLoginException("Cannot create LdapPrincipal from " + state.userDN, e);
		}
	}

	/** Re-binds the existing connection as the found user's DN with the supplied password (simple bind). */
	private void authenticateUser() throws FailedLoginException {
		try {
			state.ctx.addToEnvironment(Context.SECURITY_AUTHENTICATION, "simple");
			state.ctx.addToEnvironment(Context.SECURITY_PRINCIPAL, state.userDN);
			state.ctx.addToEnvironment(Context.SECURITY_CREDENTIALS, state.password);
			state.ctx.reconnect(null);
		} catch (NamingException e) {
			throw new FailedLoginException("Cannot authenticate user " + state.userDN, e);
		}
	}

	/**
	 * Finds the user's entry with the configured filter and stores its DN in {@code state.userDN}.
	 *
	 * @throws FailedLoginException if no filter is configured, the search fails or no entry matches
	 */
	private void searchUser() throws FailedLoginException {
		if (userFilter == null) {
			throw new FailedLoginException("No " + USER_FILTER + " option configured");
		}
		String filter = replaceUsername(userFilter);
		try {
			NamingEnumeration<SearchResult> results = state.ctx.search("", filter, constraints);
			try {
				// hasMore()/next() surface NamingExceptions instead of hiding them.
				if (results.hasMore()) {
					SearchResult result = results.next();
					state.userDN = result.getNameInNamespace();
					logger.debug("Found user: {}", state.userDN);
				}
			} finally {
				results.close();
			}
		} catch (NamingException e) {
			throw new FailedLoginException("Cannot find user in LDAP", e);
		}
		if (state.userDN == null) {
			throw new FailedLoginException("Cannot find user " + state.username + " in LDAP");
		}
	}

	/**
	 * Substitutes every {@code {USERNAME}} token in {@code filterTemplate} with the escaped username.
	 * <p>
	 * The value is escaped per RFC 4515, so characters such as {@code *} or {@code )} cannot
	 * alter the filter. It is also quoted for {@link Matcher#replaceAll}, so {@code $} and
	 * {@code \} are not treated as group references.
	 */
	private String replaceUsername(String filterTemplate) {
		String escaped = escapeFilterValue(state.username);
		return USERNAME_PATTERN.matcher(filterTemplate).replaceAll(Matcher.quoteReplacement(escaped));
	}

	/** Escapes a value for literal use inside an LDAP search filter (RFC 4515, section 3). */
	private static String escapeFilterValue(String value) {
		StringBuilder escaped = new StringBuilder(value.length() + 8);
		for (int i = 0; i < value.length(); i++) {
			char c = value.charAt(i);
			switch (c) {
			case '\\':
				escaped.append("\\5c");
				break;
			case '*':
				escaped.append("\\2a");
				break;
			case '(':
				escaped.append("\\28");
				break;
			case ')':
				escaped.append("\\29");
				break;
			case '\0':
				escaped.append("\\00");
				break;
			default:
				escaped.append(c);
			}
		}
		return escaped.toString();
	}

	private void tryConnect() throws FailedLoginException {
		try {
			state.ctx = new InitialLdapContext(ldap, null);
		} catch (NamingException e) {
			throw new FailedLoginException("Cannot connect to LDAP " + ldapUrl, e);
		}
	}

	/**
	 * Obtains the username and password from the callback handler and validates them.
	 *
	 * @throws LoginException if there is no handler, the callbacks fail, or the username or password is empty
	 */
	private void retrieveUsernamePassword() throws LoginException {
		if (callbackHandler == null) {
			throw new LoginException("No CallbackHandler available to retrieve username and password");
		}
		NameCallback nameCallback = new NameCallback("username: ");
		PasswordCallback passwordCallback = new PasswordCallback("password: ", false);
		try {
			callbackHandler.handle(new Callback[] { nameCallback, passwordCallback });
		} catch (IOException e) {
			throw new FailedLoginException("Cannot retrieve username and password", e);
		} catch (UnsupportedCallbackException e) {
			throw new FailedLoginException("Cannot retrieve username and password", e);
		}
		state.username = nameCallback.getName();
		state.password = passwordCallback.getPassword(); // returns a copy
		passwordCallback.clearPassword();

		if (state.username == null || state.username.length() == 0) {
			throw new FailedLoginException("No username supplied");
		}
		// A simple bind with an empty password is an "unauthenticated bind" (RFC 4513, 5.1.2).
		// Many servers accept it without verifying anything, so it must never count as a login.
		if (state.password == null || state.password.length == 0) {
			throw new FailedLoginException("No password supplied for user " + state.username);
		}
	}

	/**
	 * Adds the user, LDAP and group principals to the Subject if login succeeded.
	 *
	 * @return {@code false} if this module did not authenticate the user
	 * @throws LoginException if the Subject is read-only
	 */
	@Override
	public boolean commit() throws LoginException {
		if (state == null || !state.loggedIn) {
			return false;
		}
		checkSubject();

		try {
			Set<Principal> principals = subject.getPrincipals();
			addIfNotContained(principals, state.ldapPrincipal);
			addIfNotContained(principals, state.principal);
			getGroupPrincipals(principals);
		} finally {
			// Always release the connection and wipe the password, even if group lookup fails.
			state.clear();
		}
		state.committed = true;
		return true;
	}

	/** Adds the user's group principals to the collector and records those newly added for logout. */
	private void getGroupPrincipals(Set<Principal> principalsCollector) {
		logger.info("Getting group principals for {}", state.ldapPrincipal);
		LdapPrincipalProvider ldapPrincipalProvider = getPrincipalProvider();
		PrincipalIterator groupMemberships = ldapPrincipalProvider.getGroupMembership(state.ldapPrincipal);
		while (groupMemberships.hasNext()) {
			Principal nextPrincipal = groupMemberships.nextPrincipal();
			logger.info("Adding group principal {}", nextPrincipal.getName());
			if (addIfNotContained(principalsCollector, nextPrincipal)) {
				// Only principals this module added are removed again in logout().
				state.groupPrincipals.add(nextPrincipal);
			}
		}
	}

	private LdapPrincipalProvider getPrincipalProvider() {
		// FIXME: Does not check value of "principalProvider" option
		// FIXME: Instantiates a new PrincipalProvider every time
		LdapPrincipalProvider ldapPrincipalProvider = new LdapPrincipalProvider();
		Properties ldapPrincipalProviderOptions = new Properties();
		ldapPrincipalProviderOptions.put(USER_PROVIDER, ldapUrl);
		ldapPrincipalProvider.init(ldapPrincipalProviderOptions);
		return ldapPrincipalProvider;
	}

	/** Adds {@code principal} unless already present; returns {@code true} if it was added. */
	private boolean addIfNotContained(Set<Principal> principals, Principal principal) {
		if (!principals.contains(principal)) {
			return principals.add(principal);
		}
		return false;
	}

	/**
	 * Aborts the login, undoing a commit if one happened.
	 *
	 * @return {@code false} if this module's login did not succeed (or never ran)
	 */
	@Override
	public boolean abort() throws LoginException {
		if (state == null || !state.loggedIn) {
			return false;
		}
		if (state.committed) {
			logout();
		} else {
			state.loggedIn = false;
			state.clear();
			state.clearPrincipals();
		}
		return true;
	}

	/** Removes the principals this module added to the Subject and resets the login state. */
	@Override
	public boolean logout() throws LoginException {
		if (state == null)
			return true;

		checkSubject();

		Set<Principal> principals = subject.getPrincipals();
		if (principals != null) {
			principals.remove(state.ldapPrincipal);
			principals.remove(state.principal);
			principals.removeAll(state.groupPrincipals);
		}

		state.clear();
		state.loggedIn = false;
		state.committed = false;
		state.clearPrincipals();

		return true;
	}

	private void checkSubject() throws LoginException {
		if (subject.isReadOnly()) {
			state.clear();
			throw new LoginException("Subject is readOnly");
		}
	}

	/** Mutable state for a single login attempt. */
	private static class LdapLoginState {
		private String username;
		private char[] password;
		private InitialLdapContext ctx;
		private String userDN;
		private LdapPrincipalProvider.SimplePrincipal principal;
		private LdapPrincipalProvider.LdapPrincipal ldapPrincipal;
		/** Group principals that this module itself added to the Subject. */
		private final Set<Principal> groupPrincipals = new HashSet<Principal>();
		private boolean loggedIn;
		private boolean committed;

		/** Closes the LDAP connection (best effort) and overwrites the password; safe to call repeatedly. */
		void clear() {
			if (ctx != null) {
				try {
					ctx.close();
				} catch (NamingException ignored) {
					// best effort: nothing useful can be done if close fails
				}
				ctx = null;
			}
			if (password != null) {
				Arrays.fill(password, '*');
			}
		}

		void clearPrincipals() {
			ldapPrincipal = null;
			principal = null;
			groupPrincipals.clear();
		}
	}

	private static class FailedLoginException extends LoginException {
		private static final long serialVersionUID = 1L;

		private FailedLoginException(String msg) {
			super(msg);
		}

		private FailedLoginException(String msg, Exception cause) {
			super(msg);
			initCause(cause);
		}
	}
}
