package is.hugvit.bird.jaas.security;

import is.hugvit.bird.jaas.IBirdRolePrincipal;
import is.hugvit.bird.jaas.IBirdUserPrincipal;

import java.io.IOException;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.Map;

import javax.naming.Context;
import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.directory.Attribute;
import javax.naming.directory.Attributes;
import javax.naming.directory.SearchControls;
import javax.naming.directory.SearchResult;
import javax.naming.ldap.InitialLdapContext;
import javax.naming.ldap.LdapContext;
import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.LoginException;
import javax.security.auth.spi.LoginModule;

/**
 * JAAS {@link LoginModule} that authenticates users against an LDAP directory
 * (Active Directory) and maps the groups found for the user to role principals.
 *
 * <p>Authentication is two-step: the module binds with the configured service
 * account ({@code search-bind-dn}) to locate the user entry with
 * {@code search-filter}, then performs a simple bind as that entry's
 * {@code distinguishedName} with the submitted password. On success, entries
 * matching {@code group-search-filter} are read and their {@code cn} values
 * become roles.
 *
 * <p>Required JAAS options: {@code directory} (provider URL), {@code base-dn},
 * {@code search-bind-dn}, {@code search-bind-password}, {@code search-filter}
 * (a {@link String#format} pattern that receives the LDAP-escaped user name),
 * {@code group-search-filter} (a {@link String#format} pattern that receives the
 * LDAP-escaped user DN) and {@code add-citizenrole} ({@code "true"} adds a
 * {@code citizen} role to every authenticated user).
 */
public class ActiveDirectoryLoginModule extends AbstractLoginModule implements LoginModule {

	/** JAAS options that must be present for the module to work. */
	private static final String[] REQUIRED_OPTIONS = new String[] {"directory", "base-dn", "search-bind-dn",
			"search-bind-password", "search-filter", "group-search-filter", "add-citizenrole"};

	/** Time limit applied to every directory search, in milliseconds. */
	private static final int SEARCH_TIME_LIMIT_MS = 30000;

	private String directory;
	private String baseDn;
	private String searchBindDn;
	private String searchBindPassword;
	private String searchFilter;
	private String groupSearchFilter;
	private boolean attachCitizenRole = false;

	/**
	 * Reads and validates the directory configuration from the JAAS options.
	 *
	 * @throws IllegalArgumentException if any required option is missing
	 */
	@Override
	public void initialize(Subject sub, CallbackHandler handler, Map state, Map opts) {
		super.initialize(sub, handler, state, opts);

		for (final String key : REQUIRED_OPTIONS) {
			if (!options.containsKey(key)) {
				throw new IllegalArgumentException("Jaas config option: " + key + " is missing.  Cannot continue.");
			}
		}

		// Set the Directory Server options
		directory          = (String) options.get("directory");
		baseDn             = (String) options.get("base-dn");
		searchBindDn       = (String) options.get("search-bind-dn");
		searchBindPassword = (String) options.get("search-bind-password");
		searchFilter       = (String) options.get("search-filter");
		groupSearchFilter  = (String) options.get("group-search-filter");

		// Null-safe: only the exact value "true" enables the extra role.
		attachCitizenRole = "true".equals(options.get("add-citizenrole"));
	}

	/**
	 * Adds the authenticated user principal, plus one role principal per assigned
	 * role, to the subject.
	 *
	 * @return {@code true} once the subject has been populated
	 * @throws LoginException if {@link #login()} did not authenticate a user, or if
	 *         populating the subject fails (the original failure is kept as the cause)
	 */
	@Override
	public boolean commit() throws LoginException {
		debug(getClass().getName() + " Commit");

		if (currentUser == null) {
			throw new LoginException("No user to commit");
		}

		try {
			subject.getPrincipals().add(currentUser);

			Iterator<String> it = currentUser.getRoles().iterator();
			while (it.hasNext()) {
				IBirdRolePrincipal birdRole = new BirdRolePrincipal(it.next());
				// Do not add a role principal the subject already holds.
				if (!subject.getPrincipals().contains(birdRole)) {
					subject.getPrincipals().add(birdRole);
					debug("Adding role: " + birdRole.getName());
				}
			}

			debug(getClass().getName() + " Commit finished");
			return true;

		} catch (Exception ex) {
			throw toLoginException(ex);
		}
	}

	/**
	 * Collects a user name and password through the callback handler and
	 * authenticates them against the directory.
	 *
	 * <p>On success the authenticated user is kept in {@code currentUser} for
	 * {@link #commit()}. On any failure {@code currentUser} is left {@code null},
	 * so a later commit cannot publish a rejected user.
	 *
	 * @return {@code true} when the user was authenticated and has at least one role
	 * @throws LoginException if credentials are missing, authentication fails, the
	 *         user has no roles, or the callback/directory interaction fails
	 */
	@Override
	public boolean login() throws LoginException {
		debug(getClass().getName() + " login");
		// Forget any result of a previous attempt on this module instance.
		currentUser = null;

		NameCallback nameCallback = new NameCallback("login");
		PasswordCallback passwordCallback = new PasswordCallback("password", true);

		try {
			callbackHandler.handle(new Callback[] {nameCallback, passwordCallback});

			String name = nameCallback.getName();
			// getPassword() may return null when no password was supplied.
			char[] passwordChars = passwordCallback.getPassword();
			String password = (passwordChars == null) ? "" : new String(passwordChars);

			// Never write the password to the log.
			debug("user: " + name);

			if (name == null || name.length() == 0 || password.length() == 0) {
				debug("Authentication failed. Missing user name or password");
				throw new LoginException("Authentication failed");
			}

			debug("Launching active directory authentication");
			currentUser = ldapAuthenticate(name, password);

			if (currentUser == null) {
				debug("Authentication failed. User: " + name);
				throw new LoginException("Authentication failed");
			}

			if (currentUser.getRoles() == null || currentUser.getRoles().size() == 0) {
				// Drop the user so commit() cannot add a login that was rejected here.
				currentUser = null;
				throw new LoginException("Authentication failed: No roles are assigned to user");
			}

			debug("User has been set - Login Succeded");
			return true;

		} catch (LoginException ex) {
			throw ex;
		} catch (IOException ex) {
			throw toLoginException(ex);
		} catch (UnsupportedCallbackException ex) {
			throw toLoginException(ex);
		} catch (Exception ex) {
			throw toLoginException(ex);
		} finally {
			passwordCallback.clearPassword();
		}
	}

	/**
	 * Looks the user up with the service account, binds as the user to verify the
	 * password and, on success, builds a principal carrying the user's roles.
	 *
	 * @param username the submitted login name (escaped before use in a filter)
	 * @param password the submitted password
	 * @return the authenticated principal, or {@code null} if the user was not found
	 *         or the password was rejected
	 * @throws NamingException if a directory operation fails or the user entry lacks
	 *         a {@code cn} or {@code distinguishedName} attribute
	 */
	private IBirdUserPrincipal ldapAuthenticate(String username, String password) throws NamingException {
		LdapContext ctx = null;
		try {
			ctx = createContext();
			ctx.setRequestControls(null);

			debug("Base DN: " + baseDn);
			SearchResult userEntry = findUser(ctx, username);
			if (userEntry == null) {
				return null;
			}

			Attributes attrs = userEntry.getAttributes();
			String commonName = requireAttribute(attrs, "cn");
			String distinguishedName = requireAttribute(attrs, "distinguishedName");

			// Try to authenticate as the user
			debug("Found user in directory, trying authentication for: " + distinguishedName);
			if (!authenticateUser(distinguishedName, password)) {
				return null;
			}
			debug("User " + commonName + " was authenticated");

			BirdUserPrincipal principal = new BirdUserPrincipal();
			principal.setName(username);
			principal.setFullName(commonName);
			principal.setDirectoryAuthentication(true);
			principal.setCertificateAuthentication(false);

			String emailAddress = getAttributeValue(attrs, "mail");
			if (emailAddress != null) {
				debug("Got mail entry: " + emailAddress);
				principal.setEmail(emailAddress);
			} else {
				debug("No mail entry was provided");
			}

			addGroupRoles(ctx, principal, distinguishedName, commonName);

			if (attachCitizenRole) {
				principal.getRoles().add("citizen");
			}
			return principal;

		} finally {
			closeQuietly(ctx);
		}
	}

	/**
	 * Searches {@code baseDn} for the first entry matching {@code searchFilter}
	 * with the given user name.
	 *
	 * @return the first matching entry, or {@code null} if there is none
	 */
	private SearchResult findUser(LdapContext ctx, String username) throws NamingException {
		// Escape the user-supplied name so it cannot alter the filter's structure.
		final String userFilter = String.format(searchFilter, escapeLdapFilterValue(username));
		debug("Search filter: " + userFilter);

		NamingEnumeration<SearchResult> results = ctx.search(baseDn, userFilter, getSimpleSearchControls());
		try {
			return results.hasMore() ? results.next() : null;
		} finally {
			closeQuietly(results);
		}
	}

	/**
	 * Adds the {@code cn} of every entry matching {@code groupSearchFilter} for the
	 * given user DN as a role on the principal. Entries without a {@code cn} are skipped.
	 */
	private void addGroupRoles(LdapContext ctx, BirdUserPrincipal principal, String distinguishedName,
			String commonName) throws NamingException {
		// AD DNs often contain filter metacharacters such as '\' (e.g. "CN=Doe\, John")
		// or parentheses, so the DN must be escaped to form a valid filter.
		String groupSearch = String.format(groupSearchFilter, escapeLdapFilterValue(distinguishedName));

		NamingEnumeration<SearchResult> groups = ctx.search(baseDn, groupSearch, getSimpleSearchControls());
		try {
			while (groups.hasMore()) {
				String roleName = getAttributeValue(groups.next().getAttributes(), "cn");
				if (roleName != null && roleName.length() > 0) {
					debug("[" + commonName + "] Adding user role: " + roleName);
					principal.getRoles().add(roleName);
				}
			}
		} finally {
			closeQuietly(groups);
		}
	}

	/**
	 * Tries to authenticate using the submitted credentials by binding as the user.
	 *
	 * @param distinguishedName The distinguished name of the user
	 * @param password The user password
	 * @return {@code true} if the bind succeeded, {@code false} if the password is
	 *         empty or the directory rejected the bind
	 */
	private boolean authenticateUser(final String distinguishedName, final String password) {
		// A simple bind with a DN and an empty password is an "unauthenticated bind"
		// (RFC 4513 section 5.1.2) that a server may report as successful, so it must
		// never be treated as proof of the password.
		if (password == null || password.length() == 0) {
			debug("Refusing authentication with an empty password for: " + distinguishedName);
			return false;
		}

		LdapContext ctx = null;
		try {
			// The bind happens while the context is created; success means valid credentials.
			ctx = tryAuthentication(distinguishedName, password);
			return true;
		} catch (NamingException ex) {
			debug("Bind failed for " + distinguishedName + ": " + ex.getMessage());
			return false;
		} finally {
			// A failure while closing must not turn a successful bind into a rejection.
			closeQuietly(ctx);
		}
	}

	/**
	 * Opens a non-pooled context bound as the given DN with a simple bind.
	 *
	 * @throws NamingException if the bind fails
	 */
	private LdapContext tryAuthentication(String bindDN, String bindPassword) throws NamingException {
		return new InitialLdapContext(buildEnvironment(bindDN, bindPassword), null);
	}

	/**
	 * Opens a pooled context bound as the configured search (service) account.
	 *
	 * @throws NamingException if the bind fails
	 */
	private LdapContext createContext() throws NamingException {
		Hashtable<String, String> env = buildEnvironment(this.searchBindDn, this.searchBindPassword);
		env.put("com.sun.jndi.ldap.connect.pool", "true");

		// AD specific to fix (unprocessed continuation reference exceptions)
		env.put(Context.REFERRAL, "follow");

		return new InitialLdapContext(env, null);
	}

	/**
	 * Builds the JNDI environment shared by every connection: provider URL,
	 * optional SSL for {@code ldaps} URLs, and simple-bind credentials.
	 */
	private Hashtable<String, String> buildEnvironment(String bindDn, String bindPassword) {
		Hashtable<String, String> env = new Hashtable<String, String>();
		env.put(Context.INITIAL_CONTEXT_FACTORY, "com.sun.jndi.ldap.LdapCtxFactory");
		env.put(Context.PROVIDER_URL, this.directory);
		if (this.directory.startsWith("ldaps")) {
			// NOTE(review): judging by its name, TrustAllSSLSocketFactory may skip server
			// certificate validation, which would leave ldaps open to man-in-the-middle
			// attacks - confirm.
			// getName() yields the binary class name that the JNDI provider loads by name.
			env.put("java.naming.ldap.factory.socket", TrustAllSSLSocketFactory.class.getName());
			env.put(Context.SECURITY_PROTOCOL, "ssl");
		}
		env.put(Context.SECURITY_PRINCIPAL, bindDn);
		env.put(Context.SECURITY_CREDENTIALS, bindPassword);
		env.put(Context.SECURITY_AUTHENTICATION, "simple");
		return env;
	}

	/** Subtree search with a time limit and without dereferencing links. */
	private SearchControls getSimpleSearchControls() {
		SearchControls searchControls = new SearchControls();
		searchControls.setSearchScope(SearchControls.SUBTREE_SCOPE);
		searchControls.setDerefLinkFlag(false);
		searchControls.setTimeLimit(SEARCH_TIME_LIMIT_MS);
		return searchControls;
	}

	/**
	 * Escapes a value for safe use inside an LDAP search filter (RFC 4515):
	 * {@code \ * ( )} and NUL are replaced by their {@code \xx} hex escapes.
	 */
	private static String escapeLdapFilterValue(String value) {
		StringBuilder escaped = new StringBuilder(value.length() + 8);
		for (int i = 0; i < value.length(); i++) {
			char c = value.charAt(i);
			switch (c) {
				case '\\':
					escaped.append("\\5c");
					break;
				case '*':
					escaped.append("\\2a");
					break;
				case '(':
					escaped.append("\\28");
					break;
				case ')':
					escaped.append("\\29");
					break;
				case '\0':
					escaped.append("\\00");
					break;
				default:
					escaped.append(c);
			}
		}
		return escaped.toString();
	}

	/** Returns the first value of the attribute as a string, or {@code null} if absent or empty. */
	private static String getAttributeValue(Attributes attrs, String id) throws NamingException {
		Attribute attr = attrs.get(id);
		if (attr == null || attr.size() == 0) {
			return null;
		}
		Object value = attr.get();
		return (value == null) ? null : value.toString();
	}

	/** Like {@link #getAttributeValue} but fails with a descriptive error when the attribute is absent. */
	private static String requireAttribute(Attributes attrs, String id) throws NamingException {
		String value = getAttributeValue(attrs, id);
		if (value == null) {
			throw new NamingException("Directory entry has no '" + id + "' attribute");
		}
		return value;
	}

	/** Wraps an arbitrary failure in a {@link LoginException}, keeping it as the cause. */
	private static LoginException toLoginException(Exception cause) {
		LoginException ex = new LoginException(cause.getMessage());
		ex.initCause(cause);
		return ex;
	}

	/** Closes the context if non-null; a close failure is only logged so it cannot mask the real outcome. */
	private void closeQuietly(Context ctx) {
		if (ctx == null) {
			return;
		}
		try {
			ctx.close();
		} catch (NamingException ex) {
			debug("Failed to close LDAP context: " + ex.getMessage());
		}
	}

	/** Closes the enumeration if non-null; a close failure is only logged so it cannot mask the real outcome. */
	private void closeQuietly(NamingEnumeration<?> enumeration) {
		if (enumeration == null) {
			return;
		}
		try {
			enumeration.close();
		} catch (NamingException ex) {
			debug("Failed to close LDAP search results: " + ex.getMessage());
		}
	}

}
