Implemented AStar and rewrote BFS to use AStar algorithm

This commit is contained in:
Brandon Rozek 2023-11-02 16:26:20 -04:00
parent 2f08f98845
commit edde3cc8e5
No known key found for this signature in database
GPG key ID: 26E457DA82C9F480
9 changed files with 248 additions and 121 deletions

View file

@ -0,0 +1,161 @@
package org.rairlab.planner;
import org.rairlab.shadow.prover.representations.formula.Formula;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
class AStarComparator implements Comparator<Pair<State, List<Action>>> {
private Map<Pair<State, List<Action>>, Integer> heuristic;
public AStarComparator() {
this.heuristic = new HashMap<Pair<State, List<Action>>, Integer>();
}
@Override
public int compare(Pair<State, List<Action>> o1, Pair<State, List<Action>> o2) {
// Print nag message if undefined behavior is happening
if (!this.heuristic.containsKey(o1) || !this.heuristic.containsKey(o2)) {
System.out.println("[ERROR] Heuristic is not defined for state");
}
int i1 = this.heuristic.get(o1);
int i2 = this.heuristic.get(o2);
return i1 < i2 ? -1: 1;
}
public void setValue(Pair<State, List<Action>> k, int v) {
this.heuristic.put(k, v);
}
public int getValue(Pair<State, List<Action>> k) {
return this.heuristic.get(k);
}
}
/**
* Created by brandonrozek on 03/29/2023.
*/
public class AStarPlanner {
// The longest plan to search for, -1 means no bound
private Optional<Integer> MAX_DEPTH = Optional.empty();
// Number of plans to look for, -1 means up to max_depth
private Optional<Integer> K = Optional.empty();
public AStarPlanner(){ }
public Set<Plan> plan(Set<Formula> background, Set<Action> actions, State start, State goal, Function<State, Integer> heuristic) {
// Search Space Data Structures
Set<State> history = new HashSet<State>();
// Each node in the search space consists of
// (state, sequence of actions from initial)
AStarComparator comparator = new AStarComparator();
Queue<Pair<State, List<Action>>> search = new PriorityQueue<Pair<State,List<Action>>>(comparator);
// Submit Initial State
Pair<State, List<Action>> searchStart = Pair.of(start, new ArrayList<Action>());
comparator.setValue(searchStart, 0);
search.add(searchStart);
// Current set of plans
Set<Plan> plansFound = new HashSet<Plan>();
// AStar Traversal until
// - No more actions can be applied
// - Max depth reached
// - Found K plans
while (!search.isEmpty()) {
Pair<State, List<Action>> currentSearch = search.remove();
State lastState = currentSearch.getLeft();
List<Action> previous_actions = currentSearch.getRight();
// System.out.println("Considering state with heuristic: " + comparator.getValue(currentSearch));
// Exit loop if we've passed the depth limit
int currentDepth = previous_actions.size();
if (MAX_DEPTH.isPresent() && currentDepth > MAX_DEPTH.get()) {
break;
}
// If we're at the goal return
if (Operations.satisfies(background, lastState, goal)) {
plansFound.add(new Plan(previous_actions));
if (K.isPresent() && plansFound.size() >= K.get()) {
break;
}
continue;
}
// Only consider non-trivial actions
Set<Action> nonTrivialActions = actions.stream()
.filter(Action::isNonTrivial)
.collect(Collectors.toSet());
// Apply the action to the state and add to the search space
for (Action action : nonTrivialActions) {
Optional<Set<Pair<State, Action>>> optNextStateActionPairs = Operations.apply(background, action, lastState);
// Ignore actions that aren't applicable
if (optNextStateActionPairs.isEmpty()) {
continue;
}
// Action's aren't grounded so each nextState represents
// a different parameter binding
Set<Pair<State, Action>> nextStateActionPairs = optNextStateActionPairs.get();
for (Pair<State, Action> stateActionPair: nextStateActionPairs) {
State nextState = stateActionPair.getLeft();
Action nextAction = stateActionPair.getRight();
// Prune already visited states
if (history.contains(nextState)) {
continue;
}
// Add to history
history.add(nextState);
// Construct search space parameters
List<Action> next_actions = new ArrayList<Action>(previous_actions);
next_actions.add(nextAction);
// Add to search space
Pair<State, List<Action>> futureSearch = Pair.of(nextState, next_actions);
int planCost = next_actions.stream().map(Action::getCost).reduce(0, (a, b) -> a + b);
int heuristicValue = heuristic.apply(nextState);
comparator.setValue(futureSearch, planCost + heuristicValue);
search.add(futureSearch);
}
}
}
return plansFound;
}
public Optional<Integer> getMaxDepth() {
return MAX_DEPTH;
}
public void setMaxDepth(int maxDepth) {
MAX_DEPTH = Optional.of(maxDepth);
}
public void setK(int k) {
K = Optional.of(k);
}
public void clearK() {
K = Optional.empty();
}
public Optional<Integer> getK() {
return K;
}
}

View file

@ -13,6 +13,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.ArrayList;
/** /**
* Created by naveensundarg on 1/13/17. * Created by naveensundarg on 1/13/17.
@ -28,12 +29,13 @@ public class Action {
private final String name; private final String name;
private final Formula precondition; private final Formula precondition;
private int cost;
private int weight; private int weight;
private final boolean trivial; private final boolean trivial;
private final Compound shorthand; private final Compound shorthand;
public Action(String name, Set<Formula> preconditions, Set<Formula> additions, Set<Formula> deletions, List<Variable> freeVariables, List<Variable> interestedVars) { public Action(String name, Set<Formula> preconditions, Set<Formula> additions, Set<Formula> deletions, int cost, List<Variable> freeVariables, List<Variable> interestedVars) {
this.name = name; this.name = name;
this.preconditions = preconditions; this.preconditions = preconditions;
@ -52,6 +54,7 @@ public class Action {
this.weight = preconditions.stream().mapToInt(Formula::getWeight).sum() + this.weight = preconditions.stream().mapToInt(Formula::getWeight).sum() +
additions.stream().mapToInt(Formula::getWeight).sum() + additions.stream().mapToInt(Formula::getWeight).sum() +
deletions.stream().mapToInt(Formula::getWeight).sum(); deletions.stream().mapToInt(Formula::getWeight).sum();
this.cost = cost;
List<Value> valuesList = interestedVars.stream().collect(Collectors.toList());; List<Value> valuesList = interestedVars.stream().collect(Collectors.toList());;
this.shorthand = new Compound(name, valuesList); this.shorthand = new Compound(name, valuesList);
@ -61,7 +64,7 @@ public class Action {
} }
public Action(String name, Set<Formula> preconditions, Set<Formula> additions, public Action(String name, Set<Formula> preconditions, Set<Formula> additions,
Set<Formula> deletions, List<Variable> freeVariables, Set<Formula> deletions, int cost, List<Variable> freeVariables,
Compound shorthand Compound shorthand
) { ) {
this.name = name; this.name = name;
@ -82,6 +85,7 @@ public class Action {
this.weight = preconditions.stream().mapToInt(Formula::getWeight).sum() + this.weight = preconditions.stream().mapToInt(Formula::getWeight).sum() +
additions.stream().mapToInt(Formula::getWeight).sum() + additions.stream().mapToInt(Formula::getWeight).sum() +
deletions.stream().mapToInt(Formula::getWeight).sum(); deletions.stream().mapToInt(Formula::getWeight).sum();
this.cost = cost;
this.shorthand = shorthand; this.shorthand = shorthand;
this.trivial = computeTrivialOrNot(); this.trivial = computeTrivialOrNot();
@ -94,9 +98,10 @@ public class Action {
Set<Formula> preconditions, Set<Formula> preconditions,
Set<Formula> additions, Set<Formula> additions,
Set<Formula> deletions, Set<Formula> deletions,
int cost,
List<Variable> freeVariables) { List<Variable> freeVariables) {
return new Action(name, preconditions, additions, deletions, freeVariables, freeVariables); return new Action(name, preconditions, additions, deletions, cost, freeVariables, freeVariables);
} }
@ -104,9 +109,10 @@ public class Action {
Set<Formula> preconditions, Set<Formula> preconditions,
Set<Formula> additions, Set<Formula> additions,
Set<Formula> deletions, Set<Formula> deletions,
int cost,
List<Variable> freeVariables, List<Variable> interestedVars) { List<Variable> freeVariables, List<Variable> interestedVars) {
return new Action(name, preconditions, additions, deletions, freeVariables, interestedVars); return new Action(name, preconditions, additions, deletions, cost, freeVariables, interestedVars);
} }
@ -114,6 +120,10 @@ public class Action {
return weight; return weight;
} }
public int getCost() {
return cost;
}
public Formula getPrecondition() { public Formula getPrecondition() {
return precondition; return precondition;
} }
@ -131,16 +141,11 @@ public class Action {
} }
public List<Variable> openVars() { public List<Variable> openVars() {
return freeVariables;
}
Set<Variable> variables = Sets.newSet(); public List<Variable> getInterestedVars() {
return interestedVars;
variables.addAll(freeVariables);
List<Variable> variablesList = CollectionUtils.newEmptyList();
variablesList.addAll(variables);
return variablesList;
} }
public Set<Formula> instantiateAdditions(Map<Variable, Value> mapping) { public Set<Formula> instantiateAdditions(Map<Variable, Value> mapping) {
@ -172,7 +177,7 @@ public class Action {
List<Value> valuesList = interestedVars.stream().collect(Collectors.toList());; List<Value> valuesList = interestedVars.stream().collect(Collectors.toList());;
Compound shorthand = (Compound)(new Compound(name, valuesList)).apply(binding); Compound shorthand = (Compound)(new Compound(name, valuesList)).apply(binding);
return new Action(name, newPreconditions, newAdditions, newDeletions, newFreeVariables, shorthand); return new Action(name, newPreconditions, newAdditions, newDeletions, cost, newFreeVariables, shorthand);
} }
public String getName() { public String getName() {

View file

@ -1,127 +1,58 @@
package org.rairlab.planner; package org.rairlab.planner;
import org.rairlab.shadow.prover.representations.formula.Formula; import org.rairlab.shadow.prover.representations.formula.Formula;
import org.rairlab.planner.Action;
import java.util.*; import java.util.*;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
/** /**
* Created by brandonrozek on 03/29/2023. * Created by brandonrozek on 03/29/2023.
*/ */
public class BreadthFirstPlanner { public class BreadthFirstPlanner {
// The longest plan to search for, -1 means no bound private AStarPlanner planner;
private Optional<Integer> MAX_DEPTH = Optional.empty();
// Number of plans to look for, -1 means up to max_depth
private Optional<Integer> K = Optional.empty();
public BreadthFirstPlanner(){ } public BreadthFirstPlanner(){
planner = new AStarPlanner();
}
public static int h(State s) {
return 1;
}
public Set<Plan> plan(Set<Formula> background, Set<Action> actions, State start, State goal) { public Set<Plan> plan(Set<Formula> background, Set<Action> actions, State start, State goal) {
// Search Space Data Structures // For BFS, need to ignore action costs
Set<State> history = new HashSet<State>(); Set<Action> newActions = new HashSet<Action>();
// Each node in the search space consists of for (Action a : actions) {
// (state, sequence of actions from initial) newActions.add(new Action(
Queue<Pair<List<State>, List<Action>>> search = new ArrayDeque<Pair<List<State>,List<Action>>>(); a.getName(), a.getPreconditions(), a.getAdditions(), a.getDeletions(),
1, a.openVars(), a.getInterestedVars()
// Submit Initial State ));
search.add(Pair.of(List.of(start), new ArrayList<Action>()));
// Current set of plans
Set<Plan> plansFound = new HashSet<Plan>();
// Breadth First Traversal until
// - No more actions can be applied
// - Max depth reached
// - Found K plans
while (!search.isEmpty()) {
Pair<List<State>, List<Action>> currentSearch = search.remove();
List<State> previous_states = currentSearch.getLeft();
List<Action> previous_actions = currentSearch.getRight();
State lastState = previous_states.get(previous_states.size() - 1);
// Exit loop if we've passed the depth limit
int currentDepth = previous_actions.size();
if (MAX_DEPTH.isPresent() && currentDepth > MAX_DEPTH.get()) {
break;
}
// If we're at the goal return
if (Operations.satisfies(background, lastState, goal)) {
plansFound.add(new Plan(previous_actions, previous_states, background));
if (K.isPresent() && plansFound.size() >= K.get()) {
break;
}
continue;
}
// Only consider non-trivial actions
Set<Action> nonTrivialActions = actions.stream()
.filter(Action::isNonTrivial)
.collect(Collectors.toSet());
// Apply the action to the state and add to the search space
for (Action action : nonTrivialActions) {
Optional<Set<Pair<State, Action>>> optNextStateActionPairs = Operations.apply(background, action, lastState);
// Ignore actions that aren't applicable
if (optNextStateActionPairs.isEmpty()) {
continue;
}
// Action's aren't grounded so each nextState represents
// a different parameter binding
Set<Pair<State, Action>> nextStateActionPairs = optNextStateActionPairs.get();
for (Pair<State, Action> stateActionPair: nextStateActionPairs) {
State nextState = stateActionPair.getLeft();
Action nextAction = stateActionPair.getRight();
// Prune already visited states
if (history.contains(nextState)) {
continue;
}
// Add to history
history.add(nextState);
// Construct search space parameters
List<State> next_states = new ArrayList<State>(previous_states);
next_states.add(nextState);
List<Action> next_actions = new ArrayList<Action>(previous_actions);
next_actions.add(nextAction);
// Add to search space
search.add(Pair.of(next_states, next_actions));
}
}
} }
return plansFound; return planner.plan(background, actions, start, goal, BreadthFirstPlanner::h);
} }
public Optional<Integer> getMaxDepth() { public Optional<Integer> getMaxDepth() {
return MAX_DEPTH; return planner.getMaxDepth();
} }
public void setMaxDepth(int maxDepth) { public void setMaxDepth(int maxDepth) {
MAX_DEPTH = Optional.of(maxDepth); planner.setMaxDepth(maxDepth);
} }
public void setK(int k) { public void setK(int k) {
K = Optional.of(k); planner.setK(k);
} }
public void clearK() { public void clearK() {
K = Optional.empty(); planner.clearK();
} }
public Optional<Integer> getK() { public Optional<Integer> getK() {
return K; return planner.getK();
} }
} }

View file

@ -4,10 +4,8 @@ import org.rairlab.planner.utils.Visualizer;
import org.rairlab.shadow.prover.core.Prover; import org.rairlab.shadow.prover.core.Prover;
import org.rairlab.shadow.prover.core.SnarkWrapper; import org.rairlab.shadow.prover.core.SnarkWrapper;
import org.rairlab.shadow.prover.core.proof.Justification; import org.rairlab.shadow.prover.core.proof.Justification;
import org.rairlab.planner.utils.Commons;
import org.rairlab.shadow.prover.representations.formula.BiConditional; import org.rairlab.shadow.prover.representations.formula.BiConditional;
import org.rairlab.shadow.prover.representations.formula.Formula; import org.rairlab.shadow.prover.representations.formula.Formula;
import org.rairlab.shadow.prover.representations.formula.Predicate;
import org.rairlab.shadow.prover.representations.value.Value; import org.rairlab.shadow.prover.representations.value.Value;
import org.rairlab.shadow.prover.representations.value.Variable; import org.rairlab.shadow.prover.representations.value.Variable;
@ -23,7 +21,6 @@ import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.concurrent.*; import java.util.concurrent.*;
import java.util.stream.Collectors;
/** /**
* Created by naveensundarg on 1/13/17. * Created by naveensundarg on 1/13/17.
@ -225,10 +222,17 @@ public class Operations {
return true; return true;
} }
return proveCached( for (Formula g : goal.getFormulae()) {
Sets.union(background, state.getFormulae()), Optional<Justification> just = proveCached(
Commons.makeAnd(goal.getFormulae()) Sets.union(background, state.getFormulae()),
).isPresent(); g
);
if (just.isEmpty()) {
return false;
}
}
return true;
} }
public static boolean conflicts(Set<Formula> background, State state1, State state2) { public static boolean conflicts(Set<Formula> background, State state1, State state2) {

View file

@ -28,6 +28,12 @@ public class Plan {
this.background = background; this.background = background;
} }
public Plan(List<Action> actions) {
this.actions = actions;
this.expectedStates = CollectionUtils.newEmptyList();
this.background = CollectionUtils.newEmptySet();
}
public List<Action> getActions() { public List<Action> getActions() {
return actions; return actions;
} }

View file

@ -0,0 +1,9 @@
package org.rairlab.planner.heuristics;
import org.rairlab.planner.State;
public class ConstantHeuristic {
public static int h(State s) {
return 1;
}
}

View file

@ -10,6 +10,6 @@ import java.util.Set;
public class IndefiniteAction extends Action { public class IndefiniteAction extends Action {
private IndefiniteAction(String name, Set<Formula> preconditions, Set<Formula> additions, Set<Formula> deletions, List<Variable> freeVariables) { private IndefiniteAction(String name, Set<Formula> preconditions, Set<Formula> additions, Set<Formula> deletions, List<Variable> freeVariables) {
super(name, preconditions, additions, deletions, freeVariables, freeVariables); super(name, preconditions, additions, deletions, 1, freeVariables, freeVariables);
} }
} }

View file

@ -58,6 +58,7 @@ public class PlanningProblem {
private static final Keyword PRECONDITIONS = Keyword.newKeyword("preconditions"); private static final Keyword PRECONDITIONS = Keyword.newKeyword("preconditions");
private static final Keyword ADDITIONS = Keyword.newKeyword("additions"); private static final Keyword ADDITIONS = Keyword.newKeyword("additions");
private static final Keyword DELETIONS = Keyword.newKeyword("deletions"); private static final Keyword DELETIONS = Keyword.newKeyword("deletions");
private static final Keyword COST = Keyword.newKeyword("cost");
private static final Symbol ACTION_DEFINER = Symbol.newSymbol("define-action"); private static final Symbol ACTION_DEFINER = Symbol.newSymbol("define-action");
@ -311,11 +312,17 @@ public class PlanningProblem {
Set<Formula> preconditions = readFrom((List<?>) actionSpec.get(PRECONDITIONS)); Set<Formula> preconditions = readFrom((List<?>) actionSpec.get(PRECONDITIONS));
Set<Formula> additions = readFrom((List<?>) actionSpec.get(ADDITIONS)); Set<Formula> additions = readFrom((List<?>) actionSpec.get(ADDITIONS));
Set<Formula> deletions = readFrom((List<?>) actionSpec.get(DELETIONS)); Set<Formula> deletions = readFrom((List<?>) actionSpec.get(DELETIONS));
int cost;
if (actionSpec.containsKey(COST)) {
cost = Integer.parseInt(actionSpec.get(COST).toString());
} else {
cost = 1;
}
List<Variable> interestedVars = CollectionUtils.newEmptyList(); List<Variable> interestedVars = CollectionUtils.newEmptyList();
interestedVars.addAll(vars); interestedVars.addAll(vars);
vars.addAll(preconditions.stream().map(Formula::variablesPresent).reduce(Sets.newSet(), Sets::union)); vars.addAll(preconditions.stream().map(Formula::variablesPresent).reduce(Sets.newSet(), Sets::union));
return Action.buildActionFrom(name, preconditions, additions, deletions, vars, interestedVars); return Action.buildActionFrom(name, preconditions, additions, deletions, cost, vars, interestedVars);
} catch (Reader.ParsingException e) { } catch (Reader.ParsingException e) {

View file

@ -1,8 +1,8 @@
package org.rairlab.planner.utils; package org.rairlab.planner.utils;
import org.rairlab.planner.BreadthFirstPlanner; import org.rairlab.planner.AStarPlanner;
import org.rairlab.planner.Plan; import org.rairlab.planner.Plan;
import org.rairlab.planner.Planner; import org.rairlab.planner.heuristics.ConstantHeuristic;
import org.rairlab.shadow.prover.utils.Reader; import org.rairlab.shadow.prover.utils.Reader;
import java.io.FileInputStream; import java.io.FileInputStream;
@ -10,6 +10,7 @@ import java.io.FileNotFoundException;
import java.util.*; import java.util.*;
public final class Runner { public final class Runner {
public static void main(String[] args) { public static void main(String[] args) {
@ -44,15 +45,18 @@ public final class Runner {
return; return;
} }
BreadthFirstPlanner breadthFirstPlanner = new BreadthFirstPlanner(); AStarPlanner astarplanner = new AStarPlanner();
breadthFirstPlanner.setK(2); astarplanner.setK(2);
for (PlanningProblem planningProblem : planningProblemList) { for (PlanningProblem planningProblem : planningProblemList) {
Set<Plan> plans = breadthFirstPlanner.plan(
Set<Plan> plans = astarplanner.plan(
planningProblem.getBackground(), planningProblem.getBackground(),
planningProblem.getActions(), planningProblem.getActions(),
planningProblem.getStart(), planningProblem.getStart(),
planningProblem.getGoal()); planningProblem.getGoal(),
ConstantHeuristic::h
);
if(plans.size() > 0) { if(plans.size() > 0) {
System.out.println(plans.toString()); System.out.println(plans.toString());