package com.xebialabs.deployit.security;

import com.google.common.collect.Lists;
import com.google.common.collect.MapMaker;
import org.apache.jackrabbit.api.security.principal.PrincipalIterator;
import org.apache.jackrabbit.core.security.principal.PrincipalIteratorAdapter;
import org.apache.jackrabbit.core.security.principal.PrincipalProvider;
import org.apache.jackrabbit.core.security.principal.UnknownPrincipal;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.jcr.Session;
import javax.naming.Context;
import javax.naming.InvalidNameException;
import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.directory.*;
import java.security.Principal;
import java.security.acl.Group;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static com.google.common.collect.Lists.newArrayList;
import static com.xebialabs.deployit.checks.Checks.checkNotNull;

/**
 * Provides principals for Jackrabbit from LDAP. Should only be configured as a companion to {@link LdapLoginModule}.
 * <p>
 * Group membership is resolved by searching the configured LDAP server with {@code groupFilter}, in which every
 * occurrence of {@code {DN}} is replaced by the (filter-escaped) DN of the member being looked up. Every group found
 * is in turn searched for its own parent groups, so nested membership is resolved transitively; each group DN is
 * processed at most once, which also protects against membership cycles. Results are kept in a short-lived cache that
 * is shared by all instances of this class.
 */
public class LdapPrincipalProvider implements PrincipalProvider {

	private static final Logger logger = LoggerFactory.getLogger(LdapPrincipalProvider.class);

	/** Placeholder in {@link #groupFilter} that is replaced by the DN of the member being looked up. */
	private static final String DN_TOKEN = "{DN}";

	/**
	 * Cache of resolved group memberships keyed by the principal that was looked up. Entries expire 10 seconds after
	 * they were last accessed (not after they were written). The map is static, so it is shared by all instances.
	 */
	private static final Map<Principal, List<Principal>> groupCache = new MapMaker().expireAfterAccess(10, TimeUnit.SECONDS).makeMap();

	/** LDAP URL taken from the {@code userProvider} option. */
	private String userProvider;

	/** LDAP search filter used to find the groups of a member; may contain {@link #DN_TOKEN}. */
	private String groupFilter = "uniqueMember={DN}";

	/** JNDI environment used to open a new directory context for every uncached group lookup. */
	private Hashtable<String, String> ldapEnvironment;

	/**
	 * Configures this provider from the Jackrabbit provider options.
	 * <p>
	 * Recognized options are {@code userProvider} (the LDAP URL) and the optional {@code groupFilter}. Every option
	 * whose key contains a dot is copied verbatim into the JNDI environment.
	 *
	 * @param options the provider options
	 */
	@Override
	public void init(final Properties options) {
		userProvider = (String) options.get("userProvider");
		logger.debug("Initialized with userProvider={}", userProvider);
		if (options.containsKey("groupFilter")) {
			groupFilter = (String) options.get("groupFilter");
		}
		logger.debug("Configured groupFilter: {}", groupFilter);

		ldapEnvironment = new Hashtable<String, String>();
		ldapEnvironment.put(Context.INITIAL_CONTEXT_FACTORY, LdapContextFactory.class.getName());
		ldapEnvironment.put(Context.PROVIDER_URL, userProvider);
		for (Map.Entry<Object, Object> option : options.entrySet()) {
			// The environment is String-typed: skip non-String keys/values instead of failing with a
			// ClassCastException (key cast) or NullPointerException (Hashtable rejects null values).
			if (!(option.getKey() instanceof String) || !(option.getValue() instanceof String)) {
				continue;
			}
			String key = (String) option.getKey();
			// Adding extra keys
			if (key.contains(".")) {
				logger.debug("Adding key to ldapEnvironment: {}", key);
				ldapEnvironment.put(key, (String) option.getValue());
			}
		}
	}

	/** Nothing to release: a directory context is opened and closed per lookup. */
	@Override
	public void close() {
	}

	/**
	 * Returns an {@link LdapPrincipal} for the given name, or an {@link UnknownPrincipal} if it cannot be created.
	 *
	 * @param principalName the principal name
	 * @return the principal for {@code principalName}
	 */
	@Override
	public Principal getPrincipal(final String principalName) {
		logger.debug("Getting principal for {}", principalName);
		try {
			return new LdapPrincipal(principalName);
		} catch (InvalidNameException e) {
			return new UnknownPrincipal(principalName);
		}
	}

	/** Not supported by this provider. */
	@Override
	public PrincipalIterator getPrincipals(final int searchType) {
		throw new UnsupportedOperationException("Cannot invoke getPrincipals(int) on LdapPrincipalProvider: not implemented.");
	}

	/** Every session may read every principal. */
	@Override
	public boolean canReadPrincipal(final Session session, final Principal principalToRead) {
		return true;
	}

	/** Not supported by this provider. */
	@Override
	public PrincipalIterator findPrincipals(final String simpleFilter) {
		throw new UnsupportedOperationException("Cannot invoke findPrincipals(String) on LdapPrincipalProvider: not implemented.");
	}

	/** Not supported by this provider. */
	@Override
	public PrincipalIterator findPrincipals(final String simpleFilter, final int searchType) {
		throw new UnsupportedOperationException("Cannot invoke findPrincipals(String, int) on LdapPrincipalProvider: not implemented.");
	}

	/**
	 * Returns the groups of {@code principal}, served from the shared cache when possible and otherwise looked up
	 * via {@link #getLoginGroupMembership(Principal)}.
	 *
	 * @param principal the principal whose groups are requested
	 * @return an iterator over the group principals
	 */
	@Override
	public PrincipalIterator getGroupMembership(Principal principal) {
		logger.trace("getGroupMembership invoked with {}", principal);
		// A single get(): a containsKey()/get() pair can race with entry expiry and hand null to the adapter.
		List<Principal> cached = groupCache.get(principal);
		if (cached != null) {
			logger.debug("Cache hit for {}", principal);
			return new PrincipalIteratorAdapter(cached);
		}
		return getLoginGroupMembership(principal);
	}

	/**
	 * Looks up the groups of {@code principal} in LDAP (bypassing the cache) and stores the result in the cache.
	 * Principals that are not {@link LdapPrincipal}s have no groups.
	 *
	 * @param principal the principal whose groups are requested
	 * @return an iterator over the group principals
	 * @throws RuntimeException if the LDAP server cannot be queried
	 */
	public PrincipalIterator getLoginGroupMembership(final Principal principal) {
		logger.trace("Getting group membership for principal {}", principal);
		if (!(principal instanceof LdapPrincipal)) {
			logger.debug("Expected principal of type {} but was of type {}, so returning empty set", LdapPrincipal.class, principal.getClass());
			groupCache.put(principal, Collections.<Principal>emptyList());
			return PrincipalIteratorAdapter.EMPTY;
		}

		try {
			logger.trace("Opening initial dir context to {}", userProvider);
			DirContext dc = new InitialDirContext(new Hashtable<Object, Object>(ldapEnvironment));
			try {
				List<Principal> groupPrincipals = newArrayList();
				// The principal name is used as the member DN of the first search.
				collectGroupsForMember(principal, dc, principal.getName(), groupPrincipals);
				// The cached list is shared between callers, so do not expose it mutably.
				List<Principal> result = Collections.unmodifiableList(groupPrincipals);
				groupCache.put(principal, result);
				return new PrincipalIteratorAdapter(result);
			} finally {
				dc.close();
			}
		} catch (NamingException exc) {
			throw new RuntimeException("Cannot retrieve groups from LDAP server", exc);
		}
	}

	/**
	 * Adds an {@link LdapGroup} for every group that {@code memberDn} belongs to, directly or through nested groups,
	 * to {@code groupPrincipalsCollector}. Each group DN is searched at most once, so membership cycles terminate.
	 *
	 * @param principal               the principal the groups are collected for
	 * @param dc                      the directory context to search
	 * @param memberDn                DN substituted for {@code {DN}} in the group filter
	 * @param groupPrincipalsCollector receives the group principals found
	 * @throws NamingException if an LDAP search fails
	 */
	protected void collectGroupsForMember(final Principal principal, final DirContext dc, final String memberDn, final List<Principal> groupPrincipalsCollector)
			throws NamingException {
		Set<String> visitedDns = new HashSet<String>();
		visitedDns.add(memberDn);
		collectGroupsForMember(principal, dc, memberDn, groupPrincipalsCollector, visitedDns);
	}

	/** Recursive worker for the public overload; {@code visitedDns} holds every DN already searched or queued. */
	private void collectGroupsForMember(final Principal principal, final DirContext dc, final String memberDn,
			final List<Principal> groupPrincipalsCollector, final Set<String> visitedDns) throws NamingException {
		logger.trace("Getting groups for member dn {} of principal {} from dir context {} and adding it to collector {}", new Object[]{memberDn, principal,
				dc, groupPrincipalsCollector});
		// Only the cn is read from each result; do not transfer (potentially huge) member attributes.
		SearchControls controls = new SearchControls(SearchControls.SUBTREE_SCOPE, 0, 0, new String[]{"cn"}, false, false);
		// Literal replacement of an RFC 4515-escaped value: Matcher.replaceAll would interpret '\' and '$' in the DN,
		// and unescaped '(', ')', '*' or '\' would alter the filter.
		String filter = groupFilter.replace(DN_TOKEN, escapeFilterValue(memberDn));

		if (logger.isTraceEnabled()) {
			logger.trace("Search dir context {} with name in namespace {} for filter {} with controls {}", new Object[]{dc, dc.getNameInNamespace(), filter,
					controls});
		}
		List<String> dnsFound = newArrayList();
		NamingEnumeration<SearchResult> groupsFound = dc.search("", filter, controls);
		try {
			while (groupsFound.hasMore()) {
				SearchResult group = groupsFound.next();
				String groupDn = group.getNameInNamespace();
				if (!visitedDns.add(groupDn)) {
					logger.debug("Group with dn {} already processed, skipping", groupDn);
					continue;
				}
				logger.debug("Found group with dn {}, getting cn", group);
				dnsFound.add(groupDn);
				Attribute cn = group.getAttributes().get("cn");
				// Attribute.get() throws NoSuchElementException for an attribute without values.
				Object cnValue = (cn != null && cn.size() > 0) ? cn.get() : null;
				if (cnValue == null) {
					logger.warn("Group with dn {} has no cn, not adding it as a group principal", groupDn);
					continue;
				}
				logger.debug("Found group with dn {} with cn {} with value {}", new Object[]{group, cn, cnValue});
				groupPrincipalsCollector.add(new LdapGroup(principal, (String) cnValue));
			}
		} finally {
			groupsFound.close();
		}

		for (String each : dnsFound) {
			logger.trace("Invoking recursively for dn {}", each);
			collectGroupsForMember(principal, dc, each, groupPrincipalsCollector, visitedDns);
			logger.trace("End of recursive invocation for dn {}", each);
		}
	}

	/**
	 * Escapes a value for inclusion in an LDAP search filter (RFC 4515, section 3): {@code \ * ( )} and NUL are
	 * replaced by a backslash followed by their two-digit hex code.
	 */
	private static String escapeFilterValue(final String value) {
		StringBuilder escaped = new StringBuilder(value.length() + 8);
		for (int i = 0; i < value.length(); i++) {
			char c = value.charAt(i);
			switch (c) {
			case '\\':
				escaped.append("\\5c");
				break;
			case '*':
				escaped.append("\\2a");
				break;
			case '(':
				escaped.append("\\28");
				break;
			case ')':
				escaped.append("\\29");
				break;
			case '\0':
				escaped.append("\\00");
				break;
			default:
				escaped.append(c);
			}
		}
		return escaped.toString();
	}

	/**
	 * Read-only group principal named after the group's {@code cn}, recording the single member it was resolved for.
	 */
	public static class LdapGroup implements Group {

		private final Principal member;

		private final String cn;

		public LdapGroup(Principal member, String cn) {
			this.member = member;
			this.cn = cn;
		}

		@Override
		public String getName() {
			return cn;
		}

		@Override
		public boolean addMember(Principal user) {
			throw new UnsupportedOperationException("Cannot add a member to an LdapGroup. Use LDAP directly.");
		}

		@Override
		public boolean removeMember(Principal user) {
			throw new UnsupportedOperationException("Cannot remove a member from an LdapGroup. Use LDAP directly.");
		}

		/** Returns whether {@code candidate} has the same name as the member this group was resolved for. */
		@Override
		public boolean isMember(Principal candidate) {
			return candidate != null && candidate.getName().equals(member.getName());
		}

		@Override
		public Enumeration<? extends Principal> members() {
			return Collections.enumeration(Collections.singleton(member));
		}

		@Override
		public boolean equals(Object o) {
			if (this == o) return true;
			if (o == null || getClass() != o.getClass()) return false;

			LdapGroup ldapGroup = (LdapGroup) o;

			if (cn != null ? !cn.equals(ldapGroup.cn) : ldapGroup.cn != null) return false;
			if (member != null ? !member.equals(ldapGroup.member) : ldapGroup.member != null) return false;

			return true;
		}

		@Override
		public int hashCode() {
			int result = member != null ? member.hashCode() : 0;
			result = 31 * result + (cn != null ? cn.hashCode() : 0);
			return result;
		}

		@Override
		public String toString() {
			return "LdapGroup[" + cn + ": " + member + "]";
		}
	}

	/**
	 * Principal identified by its name (used as the member DN for group lookups).
	 * <p>
	 * NOTE: {@link #equals(Object)} treats any {@link Principal} with the same name as equal, whereas
	 * {@link SimplePrincipal} only equals instances of its own class, so equality between the two is not symmetric.
	 */
	@SuppressWarnings("serial")
	public static final class LdapPrincipal implements Principal, java.io.Serializable {

		private final String name;

		public LdapPrincipal(String name) throws InvalidNameException {
			checkNotNull(name, "Name of a principal cannot be null");

			this.name = name;
		}

		@Override
		public boolean equals(Object object) {
			if (this == object) {
				return true;
			}
			if (object instanceof Principal) {
				Principal other = (Principal) object;
				return name.equals(other.getName());
			}
			return false;
		}

		@Override
		public int hashCode() {
			return name.hashCode();
		}

		@Override
		public String getName() {
			return name;
		}

		@Override
		public String toString() {
			return name;
		}
	}

	/** Plain named principal; equal only to other {@code SimplePrincipal}s with the same name. */
	public static final class SimplePrincipal implements Principal {
		private final String name;

		public SimplePrincipal(String name) {
			checkNotNull(name, "Name cannot be null");
			this.name = name;
		}

		@Override
		public String getName() {
			return name;
		}

		@Override
		public boolean equals(Object o) {

			if (this == o)
				return true;
			if (o == null || getClass() != o.getClass())
				return false;

			SimplePrincipal that = (SimplePrincipal) o;

			return name.equals(that.name);
		}

		@Override
		public int hashCode() {
			return name.hashCode();
		}

		@Override
		public String toString() {
			return name;
		}
	}
}
