mirror of
https://github.com/RAIRLab/Spectra.git
synced 2024-11-21 08:26:30 -05:00
Implemented AStar and rewrote BFS to use AStar algorithm
This commit is contained in:
parent
2f08f98845
commit
edde3cc8e5
9 changed files with 248 additions and 121 deletions
161
src/main/java/org/rairlab/planner/AStarPlanner.java
Normal file
161
src/main/java/org/rairlab/planner/AStarPlanner.java
Normal file
|
@ -0,0 +1,161 @@
|
|||
package org.rairlab.planner;
|
||||
|
||||
import org.rairlab.shadow.prover.representations.formula.Formula;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
class AStarComparator implements Comparator<Pair<State, List<Action>>> {
|
||||
private Map<Pair<State, List<Action>>, Integer> heuristic;
|
||||
|
||||
public AStarComparator() {
|
||||
this.heuristic = new HashMap<Pair<State, List<Action>>, Integer>();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int compare(Pair<State, List<Action>> o1, Pair<State, List<Action>> o2) {
|
||||
// Print nag message if undefined behavior is happening
|
||||
if (!this.heuristic.containsKey(o1) || !this.heuristic.containsKey(o2)) {
|
||||
System.out.println("[ERROR] Heuristic is not defined for state");
|
||||
}
|
||||
|
||||
int i1 = this.heuristic.get(o1);
|
||||
int i2 = this.heuristic.get(o2);
|
||||
return i1 < i2 ? -1: 1;
|
||||
}
|
||||
|
||||
public void setValue(Pair<State, List<Action>> k, int v) {
|
||||
this.heuristic.put(k, v);
|
||||
}
|
||||
|
||||
public int getValue(Pair<State, List<Action>> k) {
|
||||
return this.heuristic.get(k);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Created by brandonrozek on 03/29/2023.
|
||||
*/
|
||||
public class AStarPlanner {
|
||||
|
||||
// The longest plan to search for, -1 means no bound
|
||||
private Optional<Integer> MAX_DEPTH = Optional.empty();
|
||||
// Number of plans to look for, -1 means up to max_depth
|
||||
private Optional<Integer> K = Optional.empty();
|
||||
|
||||
public AStarPlanner(){ }
|
||||
|
||||
public Set<Plan> plan(Set<Formula> background, Set<Action> actions, State start, State goal, Function<State, Integer> heuristic) {
|
||||
|
||||
// Search Space Data Structures
|
||||
Set<State> history = new HashSet<State>();
|
||||
// Each node in the search space consists of
|
||||
// (state, sequence of actions from initial)
|
||||
AStarComparator comparator = new AStarComparator();
|
||||
Queue<Pair<State, List<Action>>> search = new PriorityQueue<Pair<State,List<Action>>>(comparator);
|
||||
|
||||
// Submit Initial State
|
||||
Pair<State, List<Action>> searchStart = Pair.of(start, new ArrayList<Action>());
|
||||
comparator.setValue(searchStart, 0);
|
||||
search.add(searchStart);
|
||||
|
||||
// Current set of plans
|
||||
Set<Plan> plansFound = new HashSet<Plan>();
|
||||
|
||||
// AStar Traversal until
|
||||
// - No more actions can be applied
|
||||
// - Max depth reached
|
||||
// - Found K plans
|
||||
while (!search.isEmpty()) {
|
||||
|
||||
|
||||
Pair<State, List<Action>> currentSearch = search.remove();
|
||||
State lastState = currentSearch.getLeft();
|
||||
List<Action> previous_actions = currentSearch.getRight();
|
||||
|
||||
// System.out.println("Considering state with heuristic: " + comparator.getValue(currentSearch));
|
||||
|
||||
// Exit loop if we've passed the depth limit
|
||||
int currentDepth = previous_actions.size();
|
||||
if (MAX_DEPTH.isPresent() && currentDepth > MAX_DEPTH.get()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// If we're at the goal return
|
||||
if (Operations.satisfies(background, lastState, goal)) {
|
||||
plansFound.add(new Plan(previous_actions));
|
||||
if (K.isPresent() && plansFound.size() >= K.get()) {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Only consider non-trivial actions
|
||||
Set<Action> nonTrivialActions = actions.stream()
|
||||
.filter(Action::isNonTrivial)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
// Apply the action to the state and add to the search space
|
||||
for (Action action : nonTrivialActions) {
|
||||
Optional<Set<Pair<State, Action>>> optNextStateActionPairs = Operations.apply(background, action, lastState);
|
||||
|
||||
// Ignore actions that aren't applicable
|
||||
if (optNextStateActionPairs.isEmpty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Action's aren't grounded so each nextState represents
|
||||
// a different parameter binding
|
||||
Set<Pair<State, Action>> nextStateActionPairs = optNextStateActionPairs.get();
|
||||
for (Pair<State, Action> stateActionPair: nextStateActionPairs) {
|
||||
State nextState = stateActionPair.getLeft();
|
||||
Action nextAction = stateActionPair.getRight();
|
||||
|
||||
// Prune already visited states
|
||||
if (history.contains(nextState)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Add to history
|
||||
history.add(nextState);
|
||||
|
||||
// Construct search space parameters
|
||||
List<Action> next_actions = new ArrayList<Action>(previous_actions);
|
||||
next_actions.add(nextAction);
|
||||
|
||||
// Add to search space
|
||||
Pair<State, List<Action>> futureSearch = Pair.of(nextState, next_actions);
|
||||
int planCost = next_actions.stream().map(Action::getCost).reduce(0, (a, b) -> a + b);
|
||||
int heuristicValue = heuristic.apply(nextState);
|
||||
comparator.setValue(futureSearch, planCost + heuristicValue);
|
||||
search.add(futureSearch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return plansFound;
|
||||
}
|
||||
|
||||
public Optional<Integer> getMaxDepth() {
|
||||
return MAX_DEPTH;
|
||||
}
|
||||
|
||||
public void setMaxDepth(int maxDepth) {
|
||||
MAX_DEPTH = Optional.of(maxDepth);
|
||||
}
|
||||
|
||||
public void setK(int k) {
|
||||
K = Optional.of(k);
|
||||
}
|
||||
|
||||
public void clearK() {
|
||||
K = Optional.empty();
|
||||
}
|
||||
|
||||
public Optional<Integer> getK() {
|
||||
return K;
|
||||
}
|
||||
}
|
|
@ -13,6 +13,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.ArrayList;
|
||||
|
||||
/**
|
||||
* Created by naveensundarg on 1/13/17.
|
||||
|
@ -28,12 +29,13 @@ public class Action {
|
|||
private final String name;
|
||||
private final Formula precondition;
|
||||
|
||||
private int cost;
|
||||
private int weight;
|
||||
private final boolean trivial;
|
||||
|
||||
private final Compound shorthand;
|
||||
|
||||
public Action(String name, Set<Formula> preconditions, Set<Formula> additions, Set<Formula> deletions, List<Variable> freeVariables, List<Variable> interestedVars) {
|
||||
public Action(String name, Set<Formula> preconditions, Set<Formula> additions, Set<Formula> deletions, int cost, List<Variable> freeVariables, List<Variable> interestedVars) {
|
||||
this.name = name;
|
||||
this.preconditions = preconditions;
|
||||
|
||||
|
@ -52,6 +54,7 @@ public class Action {
|
|||
this.weight = preconditions.stream().mapToInt(Formula::getWeight).sum() +
|
||||
additions.stream().mapToInt(Formula::getWeight).sum() +
|
||||
deletions.stream().mapToInt(Formula::getWeight).sum();
|
||||
this.cost = cost;
|
||||
|
||||
List<Value> valuesList = interestedVars.stream().collect(Collectors.toList());;
|
||||
this.shorthand = new Compound(name, valuesList);
|
||||
|
@ -61,7 +64,7 @@ public class Action {
|
|||
}
|
||||
|
||||
public Action(String name, Set<Formula> preconditions, Set<Formula> additions,
|
||||
Set<Formula> deletions, List<Variable> freeVariables,
|
||||
Set<Formula> deletions, int cost, List<Variable> freeVariables,
|
||||
Compound shorthand
|
||||
) {
|
||||
this.name = name;
|
||||
|
@ -82,6 +85,7 @@ public class Action {
|
|||
this.weight = preconditions.stream().mapToInt(Formula::getWeight).sum() +
|
||||
additions.stream().mapToInt(Formula::getWeight).sum() +
|
||||
deletions.stream().mapToInt(Formula::getWeight).sum();
|
||||
this.cost = cost;
|
||||
|
||||
this.shorthand = shorthand;
|
||||
this.trivial = computeTrivialOrNot();
|
||||
|
@ -94,9 +98,10 @@ public class Action {
|
|||
Set<Formula> preconditions,
|
||||
Set<Formula> additions,
|
||||
Set<Formula> deletions,
|
||||
int cost,
|
||||
List<Variable> freeVariables) {
|
||||
|
||||
return new Action(name, preconditions, additions, deletions, freeVariables, freeVariables);
|
||||
return new Action(name, preconditions, additions, deletions, cost, freeVariables, freeVariables);
|
||||
|
||||
}
|
||||
|
||||
|
@ -104,9 +109,10 @@ public class Action {
|
|||
Set<Formula> preconditions,
|
||||
Set<Formula> additions,
|
||||
Set<Formula> deletions,
|
||||
int cost,
|
||||
List<Variable> freeVariables, List<Variable> interestedVars) {
|
||||
|
||||
return new Action(name, preconditions, additions, deletions, freeVariables, interestedVars);
|
||||
return new Action(name, preconditions, additions, deletions, cost, freeVariables, interestedVars);
|
||||
|
||||
}
|
||||
|
||||
|
@ -114,6 +120,10 @@ public class Action {
|
|||
return weight;
|
||||
}
|
||||
|
||||
public int getCost() {
|
||||
return cost;
|
||||
}
|
||||
|
||||
public Formula getPrecondition() {
|
||||
return precondition;
|
||||
}
|
||||
|
@ -131,16 +141,11 @@ public class Action {
|
|||
}
|
||||
|
||||
public List<Variable> openVars() {
|
||||
return freeVariables;
|
||||
}
|
||||
|
||||
Set<Variable> variables = Sets.newSet();
|
||||
|
||||
variables.addAll(freeVariables);
|
||||
|
||||
List<Variable> variablesList = CollectionUtils.newEmptyList();
|
||||
|
||||
variablesList.addAll(variables);
|
||||
return variablesList;
|
||||
|
||||
public List<Variable> getInterestedVars() {
|
||||
return interestedVars;
|
||||
}
|
||||
|
||||
public Set<Formula> instantiateAdditions(Map<Variable, Value> mapping) {
|
||||
|
@ -172,7 +177,7 @@ public class Action {
|
|||
|
||||
List<Value> valuesList = interestedVars.stream().collect(Collectors.toList());;
|
||||
Compound shorthand = (Compound)(new Compound(name, valuesList)).apply(binding);
|
||||
return new Action(name, newPreconditions, newAdditions, newDeletions, newFreeVariables, shorthand);
|
||||
return new Action(name, newPreconditions, newAdditions, newDeletions, cost, newFreeVariables, shorthand);
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
|
|
|
@ -1,127 +1,58 @@
|
|||
package org.rairlab.planner;
|
||||
|
||||
import org.rairlab.shadow.prover.representations.formula.Formula;
|
||||
import org.rairlab.planner.Action;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
/**
|
||||
* Created by brandonrozek on 03/29/2023.
|
||||
*/
|
||||
public class BreadthFirstPlanner {
|
||||
|
||||
// The longest plan to search for, -1 means no bound
|
||||
private Optional<Integer> MAX_DEPTH = Optional.empty();
|
||||
// Number of plans to look for, -1 means up to max_depth
|
||||
private Optional<Integer> K = Optional.empty();
|
||||
private AStarPlanner planner;
|
||||
|
||||
public BreadthFirstPlanner(){ }
|
||||
public BreadthFirstPlanner(){
|
||||
planner = new AStarPlanner();
|
||||
}
|
||||
|
||||
public static int h(State s) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
public Set<Plan> plan(Set<Formula> background, Set<Action> actions, State start, State goal) {
|
||||
|
||||
// Search Space Data Structures
|
||||
Set<State> history = new HashSet<State>();
|
||||
// Each node in the search space consists of
|
||||
// (state, sequence of actions from initial)
|
||||
Queue<Pair<List<State>, List<Action>>> search = new ArrayDeque<Pair<List<State>,List<Action>>>();
|
||||
|
||||
// Submit Initial State
|
||||
search.add(Pair.of(List.of(start), new ArrayList<Action>()));
|
||||
|
||||
// Current set of plans
|
||||
Set<Plan> plansFound = new HashSet<Plan>();
|
||||
|
||||
// Breadth First Traversal until
|
||||
// - No more actions can be applied
|
||||
// - Max depth reached
|
||||
// - Found K plans
|
||||
while (!search.isEmpty()) {
|
||||
|
||||
Pair<List<State>, List<Action>> currentSearch = search.remove();
|
||||
List<State> previous_states = currentSearch.getLeft();
|
||||
List<Action> previous_actions = currentSearch.getRight();
|
||||
State lastState = previous_states.get(previous_states.size() - 1);
|
||||
|
||||
// Exit loop if we've passed the depth limit
|
||||
int currentDepth = previous_actions.size();
|
||||
if (MAX_DEPTH.isPresent() && currentDepth > MAX_DEPTH.get()) {
|
||||
break;
|
||||
// For BFS, need to ignore action costs
|
||||
Set<Action> newActions = new HashSet<Action>();
|
||||
for (Action a : actions) {
|
||||
newActions.add(new Action(
|
||||
a.getName(), a.getPreconditions(), a.getAdditions(), a.getDeletions(),
|
||||
1, a.openVars(), a.getInterestedVars()
|
||||
));
|
||||
}
|
||||
|
||||
// If we're at the goal return
|
||||
if (Operations.satisfies(background, lastState, goal)) {
|
||||
plansFound.add(new Plan(previous_actions, previous_states, background));
|
||||
if (K.isPresent() && plansFound.size() >= K.get()) {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Only consider non-trivial actions
|
||||
Set<Action> nonTrivialActions = actions.stream()
|
||||
.filter(Action::isNonTrivial)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
// Apply the action to the state and add to the search space
|
||||
for (Action action : nonTrivialActions) {
|
||||
Optional<Set<Pair<State, Action>>> optNextStateActionPairs = Operations.apply(background, action, lastState);
|
||||
|
||||
// Ignore actions that aren't applicable
|
||||
if (optNextStateActionPairs.isEmpty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Action's aren't grounded so each nextState represents
|
||||
// a different parameter binding
|
||||
Set<Pair<State, Action>> nextStateActionPairs = optNextStateActionPairs.get();
|
||||
for (Pair<State, Action> stateActionPair: nextStateActionPairs) {
|
||||
State nextState = stateActionPair.getLeft();
|
||||
Action nextAction = stateActionPair.getRight();
|
||||
|
||||
// Prune already visited states
|
||||
if (history.contains(nextState)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Add to history
|
||||
history.add(nextState);
|
||||
|
||||
// Construct search space parameters
|
||||
List<State> next_states = new ArrayList<State>(previous_states);
|
||||
next_states.add(nextState);
|
||||
|
||||
List<Action> next_actions = new ArrayList<Action>(previous_actions);
|
||||
next_actions.add(nextAction);
|
||||
|
||||
// Add to search space
|
||||
search.add(Pair.of(next_states, next_actions));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return plansFound;
|
||||
return planner.plan(background, actions, start, goal, BreadthFirstPlanner::h);
|
||||
}
|
||||
|
||||
public Optional<Integer> getMaxDepth() {
|
||||
return MAX_DEPTH;
|
||||
return planner.getMaxDepth();
|
||||
}
|
||||
|
||||
public void setMaxDepth(int maxDepth) {
|
||||
MAX_DEPTH = Optional.of(maxDepth);
|
||||
planner.setMaxDepth(maxDepth);
|
||||
}
|
||||
|
||||
public void setK(int k) {
|
||||
K = Optional.of(k);
|
||||
planner.setK(k);
|
||||
}
|
||||
|
||||
public void clearK() {
|
||||
K = Optional.empty();
|
||||
planner.clearK();
|
||||
}
|
||||
|
||||
public Optional<Integer> getK() {
|
||||
return K;
|
||||
return planner.getK();
|
||||
}
|
||||
|
||||
}
|
|
@ -4,10 +4,8 @@ import org.rairlab.planner.utils.Visualizer;
|
|||
import org.rairlab.shadow.prover.core.Prover;
|
||||
import org.rairlab.shadow.prover.core.SnarkWrapper;
|
||||
import org.rairlab.shadow.prover.core.proof.Justification;
|
||||
import org.rairlab.planner.utils.Commons;
|
||||
import org.rairlab.shadow.prover.representations.formula.BiConditional;
|
||||
import org.rairlab.shadow.prover.representations.formula.Formula;
|
||||
import org.rairlab.shadow.prover.representations.formula.Predicate;
|
||||
|
||||
import org.rairlab.shadow.prover.representations.value.Value;
|
||||
import org.rairlab.shadow.prover.representations.value.Variable;
|
||||
|
@ -23,7 +21,6 @@ import java.util.Map;
|
|||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Created by naveensundarg on 1/13/17.
|
||||
|
@ -225,10 +222,17 @@ public class Operations {
|
|||
return true;
|
||||
}
|
||||
|
||||
return proveCached(
|
||||
for (Formula g : goal.getFormulae()) {
|
||||
Optional<Justification> just = proveCached(
|
||||
Sets.union(background, state.getFormulae()),
|
||||
Commons.makeAnd(goal.getFormulae())
|
||||
).isPresent();
|
||||
g
|
||||
);
|
||||
if (just.isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
public static boolean conflicts(Set<Formula> background, State state1, State state2) {
|
||||
|
|
|
@ -28,6 +28,12 @@ public class Plan {
|
|||
this.background = background;
|
||||
}
|
||||
|
||||
public Plan(List<Action> actions) {
|
||||
this.actions = actions;
|
||||
this.expectedStates = CollectionUtils.newEmptyList();
|
||||
this.background = CollectionUtils.newEmptySet();
|
||||
}
|
||||
|
||||
public List<Action> getActions() {
|
||||
return actions;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
package org.rairlab.planner.heuristics;
|
||||
|
||||
import org.rairlab.planner.State;
|
||||
|
||||
public class ConstantHeuristic {
|
||||
public static int h(State s) {
|
||||
return 1;
|
||||
}
|
||||
}
|
|
@ -10,6 +10,6 @@ import java.util.Set;
|
|||
public class IndefiniteAction extends Action {
|
||||
|
||||
private IndefiniteAction(String name, Set<Formula> preconditions, Set<Formula> additions, Set<Formula> deletions, List<Variable> freeVariables) {
|
||||
super(name, preconditions, additions, deletions, freeVariables, freeVariables);
|
||||
super(name, preconditions, additions, deletions, 1, freeVariables, freeVariables);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -58,6 +58,7 @@ public class PlanningProblem {
|
|||
private static final Keyword PRECONDITIONS = Keyword.newKeyword("preconditions");
|
||||
private static final Keyword ADDITIONS = Keyword.newKeyword("additions");
|
||||
private static final Keyword DELETIONS = Keyword.newKeyword("deletions");
|
||||
private static final Keyword COST = Keyword.newKeyword("cost");
|
||||
|
||||
private static final Symbol ACTION_DEFINER = Symbol.newSymbol("define-action");
|
||||
|
||||
|
@ -311,11 +312,17 @@ public class PlanningProblem {
|
|||
Set<Formula> preconditions = readFrom((List<?>) actionSpec.get(PRECONDITIONS));
|
||||
Set<Formula> additions = readFrom((List<?>) actionSpec.get(ADDITIONS));
|
||||
Set<Formula> deletions = readFrom((List<?>) actionSpec.get(DELETIONS));
|
||||
int cost;
|
||||
if (actionSpec.containsKey(COST)) {
|
||||
cost = Integer.parseInt(actionSpec.get(COST).toString());
|
||||
} else {
|
||||
cost = 1;
|
||||
}
|
||||
|
||||
List<Variable> interestedVars = CollectionUtils.newEmptyList();
|
||||
interestedVars.addAll(vars);
|
||||
vars.addAll(preconditions.stream().map(Formula::variablesPresent).reduce(Sets.newSet(), Sets::union));
|
||||
return Action.buildActionFrom(name, preconditions, additions, deletions, vars, interestedVars);
|
||||
return Action.buildActionFrom(name, preconditions, additions, deletions, cost, vars, interestedVars);
|
||||
|
||||
|
||||
} catch (Reader.ParsingException e) {
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
package org.rairlab.planner.utils;
|
||||
|
||||
import org.rairlab.planner.BreadthFirstPlanner;
|
||||
import org.rairlab.planner.AStarPlanner;
|
||||
import org.rairlab.planner.Plan;
|
||||
import org.rairlab.planner.Planner;
|
||||
import org.rairlab.planner.heuristics.ConstantHeuristic;
|
||||
import org.rairlab.shadow.prover.utils.Reader;
|
||||
|
||||
import java.io.FileInputStream;
|
||||
|
@ -10,6 +10,7 @@ import java.io.FileNotFoundException;
|
|||
import java.util.*;
|
||||
|
||||
|
||||
|
||||
public final class Runner {
|
||||
|
||||
public static void main(String[] args) {
|
||||
|
@ -44,15 +45,18 @@ public final class Runner {
|
|||
return;
|
||||
}
|
||||
|
||||
BreadthFirstPlanner breadthFirstPlanner = new BreadthFirstPlanner();
|
||||
breadthFirstPlanner.setK(2);
|
||||
AStarPlanner astarplanner = new AStarPlanner();
|
||||
astarplanner.setK(2);
|
||||
|
||||
for (PlanningProblem planningProblem : planningProblemList) {
|
||||
Set<Plan> plans = breadthFirstPlanner.plan(
|
||||
|
||||
Set<Plan> plans = astarplanner.plan(
|
||||
planningProblem.getBackground(),
|
||||
planningProblem.getActions(),
|
||||
planningProblem.getStart(),
|
||||
planningProblem.getGoal());
|
||||
planningProblem.getGoal(),
|
||||
ConstantHeuristic::h
|
||||
);
|
||||
|
||||
if(plans.size() > 0) {
|
||||
System.out.println(plans.toString());
|
||||
|
|
Loading…
Reference in a new issue