package com.xebialabs.deployit.security;

import static com.google.common.base.Strings.nullToEmpty;
import static com.google.common.collect.Lists.newArrayList;
import static com.xebialabs.deployit.checks.Checks.checkNotNull;

import java.security.Principal;
import java.security.acl.Group;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.Vector;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import javax.jcr.Session;
import javax.naming.Context;
import javax.naming.InvalidNameException;
import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.directory.Attribute;
import javax.naming.directory.DirContext;
import javax.naming.directory.InitialDirContext;
import javax.naming.directory.SearchControls;
import javax.naming.directory.SearchResult;
import javax.naming.ldap.LdapName;

import org.apache.jackrabbit.api.security.principal.PrincipalIterator;
import org.apache.jackrabbit.core.security.principal.PrincipalIteratorAdapter;
import org.apache.jackrabbit.core.security.principal.PrincipalProvider;
import org.apache.jackrabbit.core.security.principal.UnknownPrincipal;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Function;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.MapMaker;

/**
 * Provides principals for Jackrabbit from LDAP. Should only be configured as a companion to {@link LdapLoginModule}.
 * <p>
 * The group membership of an {@link LdapPrincipal} is resolved by searching {@code groupSearchBase} (subtree scope)
 * with {@code groupFilter}. In that filter, the token {@code {DN}} is replaced by the filter-escaped DN of the member.
 * When {@code nestedGroups} is enabled (the default), the groups that each found group belongs to are collected as
 * well. Lookup results are kept in a cache shared by all instances of this class.
 */
public class LdapPrincipalProvider implements PrincipalProvider {

	private static final Logger logger = LoggerFactory.getLogger(LdapPrincipalProvider.class);

	private static final String DN_TOKEN = "{DN}";
	private static final String NESTED_GROUPS = "nestedGroups";
	private static final String GROUP_SEARCH_BASE = "groupSearchBase";
	private static final int DEFAULT_CACHE_TIMEOUT_SECS = 10;

	/**
	 * Group-membership cache shared by all instances of this class. It is created by the first call to
	 * {@link #init(Properties)}; later calls keep the existing cache and its TTL. Entries expire 'cacheTTL' seconds
	 * (default 10) after they were stored.
	 */
	private static final AtomicReference<Map<Principal, List<Principal>>> groupCache =
		new AtomicReference<Map<Principal, List<Principal>>>();

	/** LDAP URL of the directory, taken from the 'userProvider' option. */
	private String userProvider;
	/** Filter used to find the groups of a member; {@code {DN}} is replaced by the member's (escaped) DN. */
	private String groupFilter = "uniqueMember={DN}";
	/** Base name for the group search; empty when the option is not set. */
	private String groupSearchBase;
	/** JNDI environment used to open a directory context for each lookup. */
	private Hashtable<String, String> ldapEnvironment;
	/** Whether the groups of groups are collected as well (option 'nestedGroups'). */
	private boolean includeGroupsOfGroups = true;

	/**
	 * Configures this provider.
	 *
	 * @param options supported keys: 'userProvider' (LDAP URL), 'groupFilter', 'nestedGroups', 'groupSearchBase' and
	 *                'cacheTTL' (seconds). Every key containing a '.' is copied verbatim into the JNDI environment.
	 */
	@Override
	public void init(final Properties options) {
		initGroupCache(options.getProperty("cacheTTL"));

		userProvider = (String) options.get("userProvider");
		logger.debug("Initialized with userProvider={}", userProvider);
		if (options.containsKey("groupFilter")) {
			groupFilter = (String) options.get("groupFilter");
		}

		logger.debug("Configured groupFilter: {}", groupFilter);

		if (options.containsKey(NESTED_GROUPS)) {
			includeGroupsOfGroups = Boolean.valueOf(options.getProperty(NESTED_GROUPS));
		}
		logger.debug("Nested groups? {}", includeGroupsOfGroups);

		groupSearchBase = nullToEmpty(options.getProperty(GROUP_SEARCH_BASE));
		logger.debug("Setting group search base: '{}'", groupSearchBase);

		ldapEnvironment = new Hashtable<String, String>();
		ldapEnvironment.put(Context.INITIAL_CONTEXT_FACTORY, LdapContextFactory.class.getName());
		ldapEnvironment.put(Context.PROVIDER_URL, userProvider);
		for (Object o : options.keySet()) {
			// Dotted keys (e.g. JNDI properties) are passed through to the directory context as-is.
			if (((String) o).contains(".")) {
				logger.debug("Adding key to ldapEnvironment: {}", o);
				ldapEnvironment.put((String) o, options.getProperty((String) o));
			}
		}
	}

	/**
	 * Creates the shared group cache unless another instance already did.
	 *
	 * @param cacheTimeToLiveSecs raw value of the 'cacheTTL' option, may be null
	 */
	private static void initGroupCache(String cacheTimeToLiveSecs) {
		int cacheTimeToLive = getCacheTimeout(cacheTimeToLiveSecs);
		// expireAfterWrite rather than expireAfterAccess: a time-to-live must bound staleness even for principals that
		// are looked up continuously. Otherwise LDAP group changes would never be noticed for an active user.
		Map<Principal, List<Principal>> cache = new MapMaker()
				.expireAfterWrite(cacheTimeToLive, TimeUnit.SECONDS)
				.<Principal, List<Principal>>makeMap();
		if (groupCache.compareAndSet(null, cache)) {
			logger.debug("Initialized cache with TTL (secs): {}", cacheTimeToLive);
		} else {
			logger.debug("Cache already initialized");
		}
	}

	/**
	 * Parses the 'cacheTTL' option.
	 *
	 * @param cacheTimeoutSecs raw option value, may be null
	 * @return the configured number of seconds, or {@link #DEFAULT_CACHE_TIMEOUT_SECS} when the value is absent,
	 *         not an integer, or negative
	 */
	private static int getCacheTimeout(String cacheTimeoutSecs) {
		if (cacheTimeoutSecs == null) {
			return DEFAULT_CACHE_TIMEOUT_SECS;
		}
		try {
			int parsed = Integer.parseInt(cacheTimeoutSecs.trim());
			if (parsed < 0) {
				// MapMaker rejects negative durations; fall back instead of failing init().
				logger.warn("Invalid value for configuration option 'cacheTTL': {}. A non-negative value is required",
						cacheTimeoutSecs);
				return DEFAULT_CACHE_TIMEOUT_SECS;
			}
			return parsed;
		} catch (NumberFormatException exception) {
			logger.warn("Invalid value for configuration option 'cacheTTL': {}. A numeric value is required",
					cacheTimeoutSecs);
			return DEFAULT_CACHE_TIMEOUT_SECS;
		}
	}

	/** Nothing to release: a directory context is opened and closed for every lookup. */
	@Override
	public void close() {

	}

	/**
	 * Returns an {@link LdapPrincipal} for a name that parses as an LDAP DN, otherwise an {@link UnknownPrincipal}.
	 */
	@Override
	public Principal getPrincipal(final String principalName) {
		logger.debug("Getting principal for {}", principalName);
		try {
			return new LdapPrincipal(principalName);
		} catch (InvalidNameException e) {
			logger.debug("Could not create LdapPrincipal for {}, returning UnknownPrincipal", principalName);
			return new UnknownPrincipal(principalName);
		}
	}

	@Override
	public PrincipalIterator getPrincipals(final int searchType) {
		throw new UnsupportedOperationException("Cannot invoke getPrincipals(int) on LdapPrincipalProvider: not implemented.");
	}

	@Override
	public boolean canReadPrincipal(final Session session, final Principal principalToRead) {
		return true;
	}

	@Override
	public PrincipalIterator findPrincipals(final String simpleFilter) {
		throw new UnsupportedOperationException("Cannot invoke findPrincipals(String) on LdapPrincipalProvider: not implemented.");
	}

	@Override
	public PrincipalIterator findPrincipals(final String simpleFilter, final int searchType) {
		throw new UnsupportedOperationException("Cannot invoke findPrincipals(String, int) on LdapPrincipalProvider: not implemented.");
	}

	/**
	 * Returns the groups of the given principal, served from the cache when possible and from LDAP otherwise.
	 */
	@Override
	public PrincipalIterator getGroupMembership(Principal principal) {
		logger.trace("getGroupMembership invoked with {}", principal);
		// Use a single get() rather than containsKey() + get(): with an expiring map, the entry may vanish in between,
		// which would hand a null collection to the iterator adapter. Cached values are never null.
		Map<Principal, List<Principal>> cache = groupCache.get();
		List<Principal> cached = (cache != null && principal != null) ? cache.get(principal) : null;
		if (cached != null) {
			logger.debug("Cache hit for {}", principal);
			return new PrincipalIteratorAdapter(cached);
		}
		return new PrincipalIteratorAdapter(getLoginGroupMembership(principal));
	}

	/**
	 * Looks up the groups of the given principal in LDAP (bypassing the cache) and stores the result in the cache.
	 *
	 * @param principal the principal whose groups are wanted; only {@link LdapPrincipal}s are looked up
	 * @return the groups found, or an empty list when the principal is not an {@link LdapPrincipal} (or is null)
	 * @throws RuntimeException wrapping the {@link NamingException} when LDAP cannot be queried
	 */
	public Collection<? extends Principal> getLoginGroupMembership(final Principal principal) {
		logger.trace("Getting group membership for principal {}", principal);
		if (!(principal instanceof LdapPrincipal)) {
			logger.debug("Expected principal of type {} but was of type {}, so returning empty set", LdapPrincipal.class,
					principal == null ? null : principal.getClass());
			return ImmutableList.of();
		}

		try {
			logger.trace("Opening initial dir context to {}", userProvider);
			DirContext dc = new InitialDirContext(new Hashtable<Object, Object>(ldapEnvironment));
			try {
				List<Principal> groupPrincipals = newArrayList();
				collectGroupsForMember(principal, dc, principal.getName(), groupPrincipals);
				Map<Principal, List<Principal>> cache = groupCache.get();
				if (cache != null) {
					// Cache an immutable copy so callers mutating the returned list cannot corrupt the shared cache.
					cache.put(principal, ImmutableList.copyOf(groupPrincipals));
				}
				logger.trace("Found the following groups for user [{}]: {}", principal.getName(), principalsAsNames(groupPrincipals));
				return groupPrincipals;
			} finally {
				dc.close();
			}
		} catch (NamingException exc) {
			// Need to print, because this exception might be swallowed.
			logger.error("NamingException occurred", exc);
			throw new RuntimeException("Cannot retrieve groups from LDAP server", exc);
		}
	}

	/** Returns a lazy view of the principals' names, for logging. */
	private List<String> principalsAsNames(List<Principal> groupPrincipals) {
		return Lists.transform(groupPrincipals, new Function<Principal, String>() {
			@Override
			public String apply(Principal input) {
				return input.getName();
			}
		});
	}

	/**
	 * Adds an {@link LdapGroup} to the collector for every group that {@code memberDn} is a member of. When nested
	 * groups are enabled, this continues with the groups of those groups. Every group DN is processed at most once, so
	 * cyclic group memberships terminate and groups reachable via several paths are added only once.
	 *
	 * @param principal                the principal the found groups are reported for
	 * @param dc                       open directory context to search in
	 * @param memberDn                 DN substituted for {@code {DN}} in the group filter
	 * @param groupPrincipalsCollector receives the groups found
	 * @throws NamingException when the LDAP search fails
	 */
	protected void collectGroupsForMember(final Principal principal, final DirContext dc, final String memberDn, final List<Principal> groupPrincipalsCollector)
			throws NamingException {
		collectGroupsRecursively(principal, dc, memberDn, groupPrincipalsCollector, new HashSet<String>());
	}

	/**
	 * Implementation of {@link #collectGroupsForMember}.
	 *
	 * @param visitedGroupDns DNs of the groups already processed during this lookup; updated by this method
	 */
	private void collectGroupsRecursively(final Principal principal, final DirContext dc, final String memberDn,
			final List<Principal> groupPrincipalsCollector, final Set<String> visitedGroupDns) throws NamingException {
		logger.trace("Getting groups for member dn {} of principal {} from dir context {} and adding it to collector {}", new Object[]{memberDn, principal,
				dc, groupPrincipalsCollector});
		List<String> dnsFound = newArrayList();
		SearchControls controls = new SearchControls(SearchControls.SUBTREE_SCOPE, 0, 0, null, false, false);
		// The DN is a value inside the filter, so it must be escaped; otherwise '(', ')' or '*' in a DN would break or
		// alter the filter.
		String filter = groupFilter.replace(DN_TOKEN, escapeFilterValue(memberDn));

		if (logger.isTraceEnabled()) {
			// Guarded: getNameInNamespace() may itself throw, and is only needed for this message.
			logger.trace("Search dir context {} with name in namespace {} from base {} for filter {} with controls {}", new Object[]{dc, dc.getNameInNamespace(), groupSearchBase, filter,
					controls});
		}
		NamingEnumeration<SearchResult> groupsFound = dc.search(groupSearchBase, filter, controls);
		try {
			while (groupsFound.hasMore()) {
				SearchResult group = groupsFound.next();
				logger.trace("Found group with dn {}, getting cn", group);
				String nameInNamespace = group.getNameInNamespace();
				if (!visitedGroupDns.add(nameInNamespace)) {
					logger.trace("Group {} was already processed, skipping it", nameInNamespace);
					continue;
				}
				dnsFound.add(nameInNamespace);
				logger.debug("Found group: {}", nameInNamespace);
				Attribute cn = group.getAttributes().get("cn");
				Object cnValue = (cn == null) ? null : cn.get();
				if (cnValue == null) {
					// Still traversed for nested groups below, but it cannot be named as a principal.
					logger.warn("Group {} has no 'cn' attribute and is not added as a principal", nameInNamespace);
					continue;
				}
				logger.trace("Found group with dn {} with cn {} with value {}", new Object[]{group, cn, cnValue});
				groupPrincipalsCollector.add(new LdapGroup(principal, (String) cnValue));
			}
		} finally {
			groupsFound.close();
		}

		if (includeGroupsOfGroups) {
			for (String each : dnsFound) {
				logger.trace("Finding 'supergroups' for group dn {}", each);
				collectGroupsRecursively(principal, dc, each, groupPrincipalsCollector, visitedGroupDns);
				logger.trace("End of 'supergroup' search for group dn {}", each);
			}
		}
	}

	/**
	 * Escapes a value for use inside an LDAP search filter (RFC 4515). Each of '\', '*', '(', ')' and NUL is replaced by
	 * a backslash followed by its two-digit hexadecimal code.
	 */
	private static String escapeFilterValue(final String value) {
		StringBuilder escaped = new StringBuilder(value.length() + 8);
		for (int i = 0; i < value.length(); i++) {
			char c = value.charAt(i);
			switch (c) {
			case '\\':
				escaped.append("\\5c");
				break;
			case '*':
				escaped.append("\\2a");
				break;
			case '(':
				escaped.append("\\28");
				break;
			case ')':
				escaped.append("\\29");
				break;
			case '\0':
				escaped.append("\\00");
				break;
			default:
				escaped.append(c);
			}
		}
		return escaped.toString();
	}

	/**
	 * Read-only group principal named after an LDAP group's {@code cn}, with the single principal it was looked up for
	 * as its only member.
	 */
	public static class LdapGroup implements Group {

		private final Principal member;

		private final String cn;

		public LdapGroup(Principal member, String cn) {
			this.member = member;
			this.cn = cn;
		}

		@Override
		public String getName() {
			return cn;
		}

		@Override
		public boolean addMember(Principal user) {
			throw new UnsupportedOperationException("Cannot add a member to an LdapGroup. Use LDAP directly.");
		}

		@Override
		public boolean removeMember(Principal user) {
			throw new UnsupportedOperationException("Cannot remove a member from an LdapGroup. Use LDAP directly.");
		}

		/** Compares by name with the member this group was looked up for. */
		@Override
		public boolean isMember(Principal member) {
			return member.getName().equals(this.member.getName());
		}

		@Override
		public Enumeration<? extends Principal> members() {
			return new Vector<Principal>(Collections.singleton(member)).elements();
		}

		@Override
		public boolean equals(Object o) {
			if (this == o) return true;
			if (o == null || getClass() != o.getClass()) return false;

			LdapGroup ldapGroup = (LdapGroup) o;

			if (cn != null ? !cn.equals(ldapGroup.cn) : ldapGroup.cn != null) return false;
			if (member != null ? !member.equals(ldapGroup.member) : ldapGroup.member != null) return false;

			return true;
		}

		@Override
		public int hashCode() {
			int result = member != null ? member.hashCode() : 0;
			result = 31 * result + (cn != null ? cn.hashCode() : 0);
			return result;
		}

		@Override
		public String toString() {
			return "LdapGroup[" + cn + ": " + member + "]";
		}
	}

	/**
	 * Principal whose name is an LDAP DN. Equality and hash code are based on the parsed {@link LdapName}, so
	 * differently formatted spellings of the same DN compare equal.
	 */
	@SuppressWarnings("serial")
	public static final class LdapPrincipal implements Principal, java.io.Serializable {

		private final LdapName ldapName;
		private final String name;

		/**
		 * @param name the DN of the principal, must not be null
		 * @throws InvalidNameException when {@code name} is not a valid DN
		 */
		public LdapPrincipal(String name) throws InvalidNameException {
			checkNotNull(name, "Name of a principal cannot be null");

			this.ldapName = new LdapName(name);
			this.name = name;
		}

		/**
		 * Equal to any {@link Principal} whose name parses to an equal {@link LdapName}.
		 * NOTE(review): other Principal implementations presumably do not reciprocate, so this relation may be
		 * asymmetric; confirm before relying on it outside this class.
		 */
		@Override
		public boolean equals(Object object) {
			if (this == object) {
				return true;
			}
			if (object instanceof Principal) {
				try {
					return ldapName.equals(new LdapName(((Principal) object).getName()));
				} catch (InvalidNameException e) {
					return false;
				}
			}
			return false;
		}

		@Override
		public int hashCode() {
			return ldapName.hashCode();
		}

		@Override
		public String getName() {
			return name;
		}

		@Override
		public String toString() {
			return ldapName.toString();
		}
	}

	/** Principal identified only by its name. */
	public static final class SimplePrincipal implements Principal {
		private final String name;

		public SimplePrincipal(String name) {
			checkNotNull(name, "Name cannot be null");
			this.name = name;
		}

		@Override
		public String getName() {
			return name;
		}

		@Override
		public boolean equals(Object o) {

			if (this == o)
				return true;
			if (o == null || getClass() != o.getClass())
				return false;

			SimplePrincipal that = (SimplePrincipal) o;

			return name.equals(that.name);
		}

		@Override
		public int hashCode() {
			return name.hashCode();
		}

		@Override
		public String toString() {
			return name;
		}
	}
}
