Saturday, September 14, 2013

A generic problem solver from State to Goal

I thought of making a generic problem solver that can solve any problem that has an initial state, a final (goal) state and a set of possible moves.

My program is not perfect, but it should still be instructive. If you have suggestions to make it better, please feel free to write them in the comments.

I have tried to use the Template Method design pattern here, where an algorithm is implemented in an abstract class with some parts left out to be implemented by the user.

Create a State abstract class

package mystate;
import java.util.ArrayList;
import java.util.List;
public abstract class State
{
    // Define a state that'll hold reference to predecessor.
    public State predecessor = null;    
    public abstract boolean isGoal();
    @Override
    public abstract boolean equals(Object xx);
    @Override
    public abstract int hashCode();
    public abstract List<State> getChildren() ;
    @Override
    public abstract String toString();

    public State getPredecessor() {
        return this.predecessor;
    }

    public void setPredecessor(State s) {
        this.predecessor = s;
    }
    public void printResult()
    {
        List<State> result =  new ArrayList<State>();
        State c = this;
        while(c.getPredecessor() !=null)
        {
            result.add(c.getPredecessor());
            c=c.getPredecessor();
        }      
        for(int i=result.size()-1;i>=0;i--)
        {
            State curr = result.get(i);
            System.out.println(curr);
        }   
        System.out.print(this);
    }
}


Now implement this abstract class and create a custom class that defines the state of your problem.
Please make sure that you correctly override all the abstract methods declared in the abstract class.
Add member variables that represent your state.
For example, let us take the 15-puzzle (the 4x4 sliding-number puzzle).

Implement the abstract class and create SlidingState class

package mystate;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * A position of the 15-puzzle: a 4x4 grid holding the tiles 1..15 and one
 * blank cell, which is stored as {@code 0}.
 *
 * <p>Equality and hashing are based on the grid contents, so two instances
 * with the same layout count as the same state.
 *
 * @author Yogi
 */
public class SlidingState extends State {

    /** Side length of the square board. */
    private static final int SIZE = 4;

    /** The solved layout: tiles in row-major order with the blank last. */
    private static final int[][] GOAL = {{1,2,3,4},{5,6,7,8},{9,10,11,12},{13,14,15,0}};

    /**
     * Row/column offsets of the cells next to the blank, in the order in which
     * children are generated: below, right, left, above.
     */
    private static final int[][] NEIGHBOUR_OFFSETS = {{1,0},{0,1},{0,-1},{-1,0}};

    /** The grid, indexed as {@code [row][column]}; {@code 0} marks the blank. */
    int[][] currentState = {{1,2,3,4}, {5,6,7,8}, {9,10,11,12},{0,13,14,15}};

    /**
     * Creates a state with the default layout, which is one move away from
     * the goal.
     */
    public SlidingState()
    {

    }

    /**
     * Creates a state from the given grid.
     *
     * <p>The array is stored as-is, not copied; changing it afterwards also
     * changes this state's hash code and equality.
     *
     * @param state the grid, indexed {@code [row][column]}, with {@code 0} as the blank
     */
    public SlidingState(int[][] state)
    {
        this.currentState = state;
    }

    /**
     * @return a hash of the grid contents, consistent with {@link #equals(Object)}
     */
    @Override
    public int hashCode()
    {
        return Arrays.deepHashCode(currentState);
    }

    /**
     * Compares grid contents. The searches rely on this to recognise states
     * they have already seen; a wrong implementation makes them loop forever.
     *
     * @param xx the object to compare with
     * @return {@code true} when {@code xx} is a {@code SlidingState} with the same grid
     */
    @Override
    public boolean equals(Object xx)
    {
        // instanceof is false for null, so no separate null check is needed.
        if(!(xx instanceof SlidingState))
        {
            return false;
        }
        return Arrays.deepEquals(this.currentState, ((SlidingState)xx).currentState);
    }

    /**
     * @return {@code true} when the grid equals the solved layout
     */
    @Override
    public boolean isGoal() {
        return Arrays.deepEquals(currentState, GOAL);
    }

    /**
     * Generates every state reachable by sliding one neighbouring tile into
     * the blank cell.
     *
     * <p>Children are produced in the order below, right, left, above
     * (relative to the blank); neighbours that would fall outside the board
     * are skipped. Each child has this state as its predecessor. If the grid
     * contains no blank, the list is empty.
     *
     * @return the successor states, possibly empty
     */
    @Override
    public List<State> getChildren() {
        List<State> children = new ArrayList<State>();
        int[] blank = findBlank();
        if(blank == null)
        {
            return children;
        }
        int blankRow = blank[0];
        int blankCol = blank[1];
        for(int[] offset : NEIGHBOUR_OFFSETS)
        {
            int tileRow = blankRow + offset[0];
            int tileCol = blankCol + offset[1];
            if(!onBoard(tileRow, tileCol))
            {
                continue;
            }
            // Slide the neighbouring tile into the blank on a private copy.
            int[][] next = copyGrid();
            next[blankRow][blankCol] = next[tileRow][tileCol];
            next[tileRow][tileCol] = 0;
            State child = new SlidingState(next);
            child.setPredecessor(this);
            children.add(child);
        }
        return children;
    }

    /**
     * Renders the grid as a brace-delimited block with all rows on one line,
     * for example a line containing <code>{1,2,3,4}{5,6,7,8}...</code> between
     * an opening and a closing brace line.
     *
     * @return the text form of this state
     */
    @Override
    public String toString()
    {
        StringBuilder s = new StringBuilder("");
        s.append("{\n");
        for(int i=0;i<SIZE;i++)
        {
            s.append("{");
            for(int j=0;j<SIZE;j++)
            {
                s.append(currentState[i][j]);
                if(j!=SIZE-1) s.append(",");
            }
            s.append("}");
        }
        s.append("\n}");
        return s.toString();
    }

    /**
     * Locates the blank cell, scanning row by row.
     *
     * @return {@code {row, column}} of the first {@code 0}, or {@code null} if there is none
     */
    private int[] findBlank() {
        for(int i=0;i<SIZE;i++)
        {
            for(int j=0;j<SIZE;j++)
            {
                if(currentState[i][j]==0)
                {
                    return new int[] {i, j};
                }
            }
        }
        return null;
    }

    /**
     * @return {@code true} when the given cell lies inside the board
     */
    private static boolean onBoard(int row, int col) {
        return row >= 0 && row < SIZE && col >= 0 && col < SIZE;
    }

    /**
     * @return a new grid with the same contents, sharing no rows with this state
     */
    private int[][] copyGrid() {
        int[][] copy = new int[SIZE][SIZE];
        for(int i=0;i<SIZE;i++)
        {
            System.arraycopy(currentState[i], 0, copy[i], 0, SIZE);
        }
        return copy;
    }

}

Create a SolveProblem abstract class as follows

SolveProblem Abstract class

package solution;


import mystate.State;
import java.util.Set;
import java.util.List;

/**
 * Template for state-space search: subclasses provide the starting state via
 * {@link #getStateObject()}, and this class provides breadth-first and
 * depth-first searches that print the path to the goal once it is found.
 */
public abstract class SolveProblem {
    /** States that have already been expanded; only membership is queried. */
    private Set<State> visitedStates;

    /** States waiting to be expanded by {@link #bfs()}, oldest first. */
    private Set<State> stateQueue;

    /**
     * Supplies the state the search starts from.
     *
     * @return the initial state
     */
    public abstract State getStateObject();

    /**
     * Creates a solver with an empty visited set and an empty work queue.
     */
    public SolveProblem()
    {
        // A plain HashSet is enough here: the visited set is never iterated,
        // only queried with contains().
        visitedStates=new java.util.HashSet<State>();
        // The queue must hand states back in insertion order for the search
        // to be breadth-first, hence LinkedHashSet. A TreeSet cannot be used:
        // State is not Comparable, so the first add() would throw
        // ClassCastException, and a sorted set would not be FIFO anyway.
        stateQueue = new java.util.LinkedHashSet<State>();
    }

    /**
     * @return the set used as the breadth-first work queue
     */
    public Set<State> getStateQueue() {
        return stateQueue;
    }

    /**
     * Replaces the breadth-first work queue. For {@link #bfs()} to stay
     * breadth-first, the set has to iterate in insertion order.
     *
     * @param stateQueue the set to use as the work queue
     */
    public void setStateQueue(Set<State> stateQueue) {
        this.stateQueue = stateQueue;
    }

    /**
     * @return the set of states that have been expanded so far
     */
    public java.util.Set<State> getVisitedStates() {
        return visitedStates;
    }

    /**
     * @param visitedStates the set to use for recording expanded states
     */
    public void setVisitedStates(java.util.Set<State> visitedStates) {
        this.visitedStates = visitedStates;
    }

    /**
     * Runs a breadth-first search from {@link #getStateObject()}.
     *
     * <p>When a goal state is reached, its path is printed with
     * {@link State#printResult()} and the method returns. If the queue runs
     * empty first, the method returns without printing a path. The queue size
     * is printed whenever it is a multiple of 1000 as a progress indicator.
     */
    public void bfs()
    {
        State start = getStateObject();
        stateQueue.add(start);
        do
        {
            // Oldest queued state; it stays in the queue until it has been
            // expanded so that equal children are not queued again meanwhile.
            State head = stateQueue.iterator().next();
            if(head.isGoal())
            {
                head.printResult();
                return;
            }

            visitedStates.add(head);
            List<State> children = head.getChildren();
            for(State child : children)
            {
                // Check children eagerly so the search stops one level early.
                if(child.isGoal())
                {
                    child.printResult();
                    return;
                }
                if(!visitedStates.contains(child))
                {
                    stateQueue.add(child);
                }
            }
            stateQueue.remove(head);

            long sz=stateQueue.size();
            if(sz%1000==0)
            {
                System.out.println(sz);
            }
        }while(!stateQueue.isEmpty());
    }

    /**
     * Runs a recursive depth-first search from {@code currentState}.
     *
     * <p>When a goal state is reached, its path is printed and the JVM is
     * terminated with {@code System.exit(0)}; the method therefore only
     * returns normally when no goal was found below {@code currentState}.
     * Every recursive call receives its own copy of the visited set, so a
     * state is only excluded along the path that leads to it, not across
     * sibling branches.
     *
     * @param currentState the state to explore from
     * @param vStates      states already on the current path, or {@code null}
     *                     on the initial call to use this solver's visited set
     */
    public void dfs(State currentState, java.util.Set<State> vStates)
    {
        if(vStates==null) vStates = visitedStates;

        // Already seen on this path: nothing new to find down here.
        if(vStates.contains(currentState))
            return;

        if(currentState.isGoal())
        {
            currentState.printResult();
            // NOTE: ends the whole program, not just the search.
            System.exit(0);
        }
        else
        {
            System.out.println("Number of nodes checked = " + vStates.size());
        }

        vStates.add(currentState);

        List<State> children = currentState.getChildren();
        for(State c : children)
        {
            if(!vStates.contains(c))
            {
                // Each branch gets a private snapshot of the visited set.
                java.util.Set<State> clonedVStates = new java.util.HashSet<State>(vStates);
                dfs(c, clonedVStates);
            }
        }
    }
}


Now extend the SolveProblem class and override the abstract method to give your implementation as follows
And use the methods to solve your problem:

Extend SolveProblem class and use

package solution;

import mystate.SlidingState;
import mystate.State;

/**
 * Command-line driver that solves one 15-puzzle position with the
 * breadth-first search inherited from {@link SolveProblem}.
 *
 * @author Yogi
 */
public class Solve15Puzzle extends SolveProblem {

    /**
     * Prints the JVM's current total memory in megabytes, then runs the
     * breadth-first search on the board returned by {@link #getStateObject()}.
     *
     * @param args ignored
     */
    public static void main(String[] args)
    {
        long heapMegabytes = Runtime.getRuntime().totalMemory() / (1024*1024);
        System.out.println("Heap Size = " + heapMegabytes + " MB");

        new Solve15Puzzle().bfs();
    }

    /**
     * Supplies the starting position.
     *
     * <p>Other boards that were tried, kept for reference:
     * <pre>
     * {{8,5,0,6}, {2,1,9,4}, {14,10,7,11}, {13,3,15,12}}
     * {{1,2,3,4}, {5,6,7,8}, {9,0,10,12},  {13,14,11,15}}
     * {{1,2,3,4}, {5,6,11,7}, {9,10,0,8},  {13,14,15,12}}
     * </pre>
     * The last one needs a minimum of 6 moves to solve, but the DFS algorithm
     * takes around 40 moves.
     *
     * @return a state one move away from the solved board
     */
    @Override
    public State getStateObject() {
        int[][] board = {
            {1, 2, 3, 4},
            {5, 6, 7, 8},
            {9, 10, 11, 12},
            {0, 13, 14, 15}
        };
        return new SlidingState(board);
    }
}

No comments:

Post a Comment