package is.hugvit.bird.jaas.security;

import is.hugvit.bird.jaas.IBirdRolePrincipal;

import java.io.IOException;
import java.security.MessageDigest;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.Date;
import java.util.Iterator;
import java.util.Map;

import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.LoginException;
import javax.security.auth.spi.LoginModule;

public class JdbcLoginModule extends AbstractLoginModule implements LoginModule {

	private String dsDriver;
	private String dsConnection;
	private String dsUsername;
	private String dsPassword;
	
	public JdbcLoginModule() {
		super();
		debug(getClass().getName() + " loaded");
	}
	
	public void initialize(Subject sub, CallbackHandler handler, Map state, Map opts) {
		super.initialize(sub, handler, state, opts);
		
		// Set the JDBC connection options
		Object driver = options.get("driver");
		Object connection = options.get("conn");
		Object username = options.get("username");
		Object password = options.get("password");
		
		if (driver != null) {
			this.dsDriver = (String) driver;
		}
		if (connection != null) {
			this.dsConnection = (String) connection;
		}
		if (username != null) {
			this.dsUsername = (String) username;
		}
		if (password != null) {
			this.dsPassword = (String) password;
		}
	}
	
	/* (non-Javadoc)
	 * @see is.hugvit.bird.jaas.security.AbstractLoginModule#commit()
	 */
	@Override
	public boolean commit() throws LoginException {
		debug(getClass().getName() + " Commit");
    	
    	if (currentUser == null) {
    		throw new LoginException("No user to commit");
    	}
    	
        try {
        		
    		subject.getPrincipals().add(currentUser);
    		Iterator<String> it = currentUser.getRoles().iterator();
    		
    		while (it.hasNext()) {
    			IBirdRolePrincipal birdRole = new BirdRolePrincipal(it.next());
    			if (!subject.getPrincipals().contains(birdRole)) {
    				subject.getPrincipals().add(birdRole);
    				debug("Adding role: " + birdRole.getName());
    			}
    		}
    		
    		try {
    			createLogEntry();
    		} catch (Exception ex) {
    			debug("Error creating log entry: " + ex.getMessage());
    		}
    		
    		debug(getClass().getName() + " Commit finished");
    		
    		return true;

        } catch (Exception ex) {
            throw new LoginException(ex.getMessage());
        }
	}



	/* (non-Javadoc)
	 * @see is.hugvit.bird.jaas.security.AbstractLoginModule#login()
	 */
	@Override
	public boolean login() throws LoginException {
		debug(getClass().getName() + " login");
		Callback[] callbacks = new Callback[2];
        callbacks[0] = new NameCallback("login");
        callbacks[1] = new PasswordCallback("password", true);

        validateConnectionOptions();
        
        try {

        	callbackHandler.handle(callbacks);

            String name = ((NameCallback) callbacks[0]).getName();
            String password = String.valueOf(((PasswordCallback) callbacks[1]).getPassword());

            debug("user: " + name + " pass: " + password);
            
            
            debug("Launching jdbc authentication");
            
            
            currentUser = jdbcAuthenticate(name, password);
            
            if (currentUser == null) {
            	debug("Authentication failed. User: " + name);
            	throw new LoginException("Authentication failed");
            	// return false;
            }
            
            if (currentUser.getRoles() == null || currentUser.getRoles().size()==0) {
            	throw new LoginException("Authentication failed: No roles are assigned to user");
            }
  
            debug("User has been set - Login Succeded");
            
            return true;

        } catch (IOException ex) {
            throw new LoginException(ex.getMessage());
        } catch (UnsupportedCallbackException ex) {
           throw new LoginException(ex.getMessage());
        } catch (LoginException ex) {
        	throw ex;
        } catch (Exception ex) {
        	throw new LoginException(ex.getMessage());
        }
	}



	
	/* (non-Javadoc)
	 * @see is.hugvit.bird.jaas.security.AbstractLoginModule#logout()
	 */
	@Override
	public boolean logout() throws LoginException {
		return super.logout();
	}

	private BirdUserPrincipal jdbcAuthenticate(final String username, final String password) throws LoginException {
		
		BirdUserPrincipal principal = null;
		Connection con = null;
		PreparedStatement stmt = null;
		ResultSet rs = null;
		
		debug("Starting SQL Query");
		
		try {
			
			final String TBL_USERS 		= "ss_users";
			final String TBL_USERROLES  = "ss_userroles";
			final String TBL_ROLEMAP 	= "ss_userrolemap";
			
			String sql = "SELECT u.user_id, u.username, u.fullname, u.password, u.email, u.lang " +
					"FROM " + TBL_USERS + " u WHERE u.username = ? AND u.is_deleted = 0";
			
			con = getConnection();
			stmt = con.prepareStatement(sql);
			stmt.setString(1, username);
			rs = stmt.executeQuery();
			
			if (rs.next()) {
				String dbPassword = rs.getString("password");
				if (dbPassword != null) {
					if ( dbPassword.equals(encryptPassword(password)) ) {
						debug("Found user match for: " + username);
						principal = new BirdUserPrincipal();
						principal.setName(rs.getString("username"));
						principal.setFullName(rs.getString("fullname"));
						principal.setUserId(rs.getString("user_id"));
						principal.setEmail(rs.getString("email"));
						principal.setLocale(rs.getString("lang"));
					}
				}
			}
			rs.close();
			
			
			if (principal != null) {

				// Set the user roles
				sql = "SELECT r.rolename FROM " + TBL_USERROLES + " r " +
					"INNER JOIN " + TBL_ROLEMAP + " m ON r.role_id = m.role_id " +
					"WHERE m.user_id = ? AND r.is_deleted = 0";
				
				stmt = con.prepareStatement(sql);
				stmt.setString(1, principal.getUserId());
				rs = stmt.executeQuery();
				while (rs.next()) {
					principal.getRoles().add(rs.getString("rolename"));
				}
				rs.close();
			}
			
			if (principal != null) {
				debug("Principal successfully set for: " + principal.getName());
			}
			
		} catch (Exception ex) {
			throw new LoginException(ex.getMessage());
		} finally {
			try {
				if (stmt != null) {
					stmt.close();
				}
				if (con != null) {
					con.close();
				}
			} catch (SQLException ex) {
				throw new LoginException(ex.getMessage());
			}
		}
		
		
		return principal;
	}
	
	private void createLogEntry() throws Exception {
	
		Connection con = null;
		PreparedStatement stmt = null;
		
		try {
			
			final String TBL_LOG = "ss_history";
			
			String sql = "INSERT INTO " + TBL_LOG + " (document_id, user_id, stamp, action, datatype) " +
					"VALUES (?,?,?,?,?)";
			
			con = getConnection();
			stmt = con.prepareStatement(sql);
			
			stmt.setString(1, currentUser.getUserId());
			stmt.setString(2, currentUser.getUserId());
			stmt.setTimestamp(3, new Timestamp(new Date().getTime()));
			stmt.setString(4, "LOGIN");
			stmt.setString(5, "USER");
			
			stmt.executeUpdate();
			
		} finally {
			if (stmt != null) {
				stmt.close();
			}
			if (con != null) {
				con.close();
			}
		}
	}
	
	private Connection getConnection() throws Exception {
		Class.forName(this.dsDriver).newInstance();
		return DriverManager.getConnection(
				this.dsConnection, this.dsUsername, this.dsPassword);
	}


	private void validateConnectionOptions() throws LoginException {
		String[] fields = new String[] {dsConnection, dsDriver, dsUsername, dsPassword};
		for (String s : fields) {
			if (s == null) {
				throw new LoginException("JDBC options are not set, " +
						"the following options must be set: [driver,conn,username,password]");
			}
		}
	}

	private static final String encryptPassword(String password) throws Exception {
		
		StringBuffer encoded = new StringBuffer();
		if (password != null) {
	        MessageDigest md5 = MessageDigest.getInstance("MD5");
	        md5.update(password.getBytes("UTF-8"));
	        byte[] digest = md5.digest();
	        for (int i = 0; i < digest.length; ++i) {
	          int b = (int) digest[i] & 0xFF;
	          	if (b < 16) {
	          		encoded.append("0");
	          	}
	          encoded.append(Integer.toHexString(b));
	        }
		}
		
		return encoded.toString();
    
	}
	

}
