Improvements to BreadthFirstPlanner

- Allow for unbounded search
- Allow for a way to look for a certain number of plans
- Use Pair from standard library
This commit is contained in:
Brandon Rozek 2023-09-26 15:08:10 -04:00
parent 88b5ceee7a
commit 08b9c01f3e
No known key found for this signature in database
GPG key ID: 26E457DA82C9F480
4 changed files with 90 additions and 87 deletions

View file

@ -1,119 +1,120 @@
package com.naveensundarg.planner; package com.naveensundarg.planner;
import com.naveensundarg.planner.utils.PlanningProblem;
import com.naveensundarg.shadow.prover.representations.formula.Formula; import com.naveensundarg.shadow.prover.representations.formula.Formula;
import com.naveensundarg.shadow.prover.utils.Pair;
import com.naveensundarg.shadow.prover.utils.Sets;
import java.util.*; import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Triple; import org.apache.commons.lang3.tuple.Pair;
/** /**
* Created by naveensundarg on 1/13/17. * Created by brandonrozek on 03/29/2023.
*/ */
public class BreadthFirstPlanner implements Planner { public class BreadthFirstPlanner {
private static int MAX_DEPTH = 7; // The longest plan to search for, -1 means no bound
private int MAX_DEPTH = -1;
// Number of plans to look for, -1 means up to max_depth
private int K = -1;
public BreadthFirstPlanner(){ } public BreadthFirstPlanner(){ }
public static int getMaxDepth() { public Set<Plan> plan(Set<Formula> background, Set<Action> actions, State start, State goal) {
return MAX_DEPTH;
}
public static void setMaxDepth(int maxDepth) {
MAX_DEPTH = maxDepth;
}
@Override
public Optional<Set<Plan>> plan(Set<Formula> background, Set<Action> actions, State start, State goal) {
// Search Space Data Structures // Search Space Data Structures
Set<State> history = new HashSet<State>(); Set<State> history = new HashSet<State>();
Queue<Triple<List<State>, List<Action>, Integer>> search = new ArrayDeque<Triple<List<State>,List<Action>,Integer>>(); // Each node in the search space consists of
// (state, sequence of actions from initial)
Queue<Pair<List<State>, List<Action>>> search = new ArrayDeque<Pair<List<State>,List<Action>>>();
// Submit Initial State // Submit Initial State
search.add(Triple.of(List.of(start), new ArrayList<Action>(), 0)); search.add(Pair.of(List.of(start), new ArrayList<Action>()));
// Current set of plans
Set<Plan> plansFound = new HashSet<Plan>();
// Breadth First Traversal until // Breadth First Traversal until
// - Goal Reached
// - No more actions can be applied // - No more actions can be applied
// - Max depth reached // - Max depth reached
while (!search.isEmpty()) { // - Found K plans
while (!search.isEmpty() && !(K > 0 && plansFound.size() >= K)) {
Triple<List<State>, List<Action>, Integer> currentSearch = search.remove();
// Return if we're past the depth limit
int currentDepth = currentSearch.getRight();
if (currentDepth >= MAX_DEPTH) {
return Optional.empty();
}
Pair<List<State>, List<Action>> currentSearch = search.remove();
List<State> previous_states = currentSearch.getLeft(); List<State> previous_states = currentSearch.getLeft();
List<Action> previous_actions = currentSearch.getMiddle(); List<Action> previous_actions = currentSearch.getRight();
State lastState = previous_states.get(previous_states.size() - 1); State lastState = previous_states.get(previous_states.size() - 1);
// Exit loop if we've passed the depth limit
int currentDepth = previous_actions.size();
if (MAX_DEPTH > 0 && currentDepth > MAX_DEPTH) {
break;
}
// If we're at the goal return // If we're at the goal return
if (Operations.satisfies(background, lastState, goal)) { if (Operations.satisfies(background, lastState, goal)) {
return Optional.of(Sets.with( plansFound.add(new Plan(previous_actions, previous_states, background));
new Plan(previous_actions, previous_states, background) continue;
));
} }
// Try to apply each action to get to the next state // Only consider non-trivial actions
for (Action action : actions.stream().filter(Action::isNonTrivial).collect(Collectors.toSet())) { Set<Action> nonTrivialActions = actions.stream()
Optional<Set<Pair<State, Action>>> nextStateActionPairs = Operations.apply(background, action, lastState); .filter(Action::isNonTrivial)
.collect(Collectors.toSet());
if (nextStateActionPairs.isPresent()) { // Apply the action to the state and add to the search space
for (Action action : nonTrivialActions) {
Optional<Set<Pair<State, Action>>> optNextStateActionPairs = Operations.apply(background, action, lastState);
// Actions aren't grounded, so each nextState represents a different // Ignore actions that aren't applicable
// paramter binding if (optNextStateActionPairs.isEmpty()) {
for (Pair<State, Action> stateActionPair : nextStateActionPairs.get()) { continue;
State nextState = stateActionPair.first(); }
Action nextAction = stateActionPair.second();
// Action's aren't grounded so each nextState represents
// a different parameter binding
Set<Pair<State, Action>> nextStateActionPairs = optNextStateActionPairs.get();
for (Pair<State, Action> stateActionPair: nextStateActionPairs) {
State nextState = stateActionPair.getLeft();
Action nextAction = stateActionPair.getRight();
// Prune already visited states // Prune already visited states
if (history.contains(nextState)) { if (history.contains(nextState)) {
continue; continue;
} }
// Add to history
history.add(nextState);
// Construct search space parameters
List<State> next_states = new ArrayList<State>(previous_states); List<State> next_states = new ArrayList<State>(previous_states);
next_states.add(nextState); next_states.add(nextState);
List<Action> next_actions = new ArrayList<Action>(previous_actions); List<Action> next_actions = new ArrayList<Action>(previous_actions);
next_actions.add(nextAction); next_actions.add(nextAction);
// Add new state to history and search space // Add to search space
search.add(Triple.of(next_states, next_actions, currentDepth + 1)); search.add(Pair.of(next_states, next_actions));
history.add(nextState);
} }
} }
} }
return plansFound;
} }
return Optional.empty(); public int getMaxDepth() {
return MAX_DEPTH;
} }
@Override public void setMaxDepth(int maxDepth) {
public Optional<Set<Plan>> plan(PlanningProblem problem, Set<Formula> background, Set<Action> actions, State start, State goal) { MAX_DEPTH = maxDepth;
return Optional.empty();
} }
public void setK(int k) {
public Optional<Set<Plan>> plan(PlanningProblem problem, Set<Formula> background, Set<Action> actions, State start, State goal, List<PlanMethod> planMethods){ K = k;
return Optional.empty();
} }
public int getK() {
public Optional<Plan> verify(Set<Formula> background, State start, State goal, PlanSketch planSketch){ return K;
return Optional.empty();
} }
} }

View file

@ -6,9 +6,11 @@ import com.naveensundarg.planner.utils.Visualizer;
import com.naveensundarg.shadow.prover.core.proof.Justification; import com.naveensundarg.shadow.prover.core.proof.Justification;
import com.naveensundarg.shadow.prover.representations.formula.Formula; import com.naveensundarg.shadow.prover.representations.formula.Formula;
import com.naveensundarg.shadow.prover.utils.CollectionUtils; import com.naveensundarg.shadow.prover.utils.CollectionUtils;
import com.naveensundarg.shadow.prover.utils.Pair;
import com.naveensundarg.shadow.prover.utils.Sets; import com.naveensundarg.shadow.prover.utils.Sets;
import org.apache.commons.lang3.tuple.Pair;
import java.util.*; import java.util.*;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -250,7 +252,7 @@ public class DepthFirstPlanner implements Planner {
for (Pair<State, Action> stateActionPair : nextStateActionPairs.get()) { for (Pair<State, Action> stateActionPair : nextStateActionPairs.get()) {
Visualizer.push(); Visualizer.push();
Optional<Set<Plan>> planOpt = planInternal(history, currentDepth + 1, maxDepth, background, actions, stateActionPair.first(), goal); Optional<Set<Plan>> planOpt = planInternal(history, currentDepth + 1, maxDepth, background, actions, stateActionPair.getLeft(), goal);
Visualizer.pop(); Visualizer.pop();
@ -259,8 +261,8 @@ public class DepthFirstPlanner implements Planner {
atleastOnePlanFound = true; atleastOnePlanFound = true;
Set<Plan> nextPlans = planOpt.get(); Set<Plan> nextPlans = planOpt.get();
State nextSate = stateActionPair.first(); State nextSate = stateActionPair.getLeft();
Action instantiatedAction = stateActionPair.second(); Action instantiatedAction = stateActionPair.getRight();
Set<Plan> augmentedPlans = nextPlans.stream(). Set<Plan> augmentedPlans = nextPlans.stream().
map(plan -> plan.getPlanByStartingWith(instantiatedAction, nextSate)). map(plan -> plan.getPlanByStartingWith(instantiatedAction, nextSate)).

View file

@ -10,10 +10,10 @@ import com.naveensundarg.shadow.prover.representations.formula.Predicate;
import com.naveensundarg.shadow.prover.representations.value.Value; import com.naveensundarg.shadow.prover.representations.value.Value;
import com.naveensundarg.shadow.prover.representations.value.Variable; import com.naveensundarg.shadow.prover.representations.value.Variable;
import com.naveensundarg.shadow.prover.utils.CollectionUtils; import com.naveensundarg.shadow.prover.utils.CollectionUtils;
import com.naveensundarg.shadow.prover.utils.ImmutablePair;
import com.naveensundarg.shadow.prover.utils.Pair;
import com.naveensundarg.shadow.prover.utils.Sets; import com.naveensundarg.shadow.prover.utils.Sets;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Triple; import org.apache.commons.lang3.tuple.Triple;
import java.util.List; import java.util.List;
@ -52,7 +52,7 @@ public class Operations {
public static synchronized Optional<Justification> proveCached(Set<Formula> assumptions, Formula goal) { public static synchronized Optional<Justification> proveCached(Set<Formula> assumptions, Formula goal) {
Pair<Set<Formula>, Formula> inputPair = ImmutablePair.from(assumptions, goal); Pair<Set<Formula>, Formula> inputPair = ImmutablePair.of(assumptions, goal);
if (proverCache.containsKey(inputPair)) { if (proverCache.containsKey(inputPair)) {
@ -62,8 +62,8 @@ public class Operations {
Optional<Map.Entry<Pair<Set<Formula>, Formula>, Optional<Justification>>> cachedOptionalSuccessful = proverCache.entrySet().stream().filter(pairOptionalEntry -> { Optional<Map.Entry<Pair<Set<Formula>, Formula>, Optional<Justification>>> cachedOptionalSuccessful = proverCache.entrySet().stream().filter(pairOptionalEntry -> {
Set<Formula> cachedAssumptions = pairOptionalEntry.getKey().first(); Set<Formula> cachedAssumptions = pairOptionalEntry.getKey().getLeft();
Formula cachedGoal = pairOptionalEntry.getKey().second(); Formula cachedGoal = pairOptionalEntry.getKey().getRight();
return cachedGoal.equals(goal) && Sets.subset(cachedAssumptions, assumptions); return cachedGoal.equals(goal) && Sets.subset(cachedAssumptions, assumptions);
}).findAny(); }).findAny();
@ -77,8 +77,8 @@ public class Operations {
Optional<Map.Entry<Pair<Set<Formula>, Formula>, Optional<Justification>>> cachedOptionalFailed = proverCache.entrySet().stream().filter(pairOptionalEntry -> { Optional<Map.Entry<Pair<Set<Formula>, Formula>, Optional<Justification>>> cachedOptionalFailed = proverCache.entrySet().stream().filter(pairOptionalEntry -> {
Set<Formula> cachedAssumptions = pairOptionalEntry.getKey().first(); Set<Formula> cachedAssumptions = pairOptionalEntry.getKey().getLeft();
Formula cachedGoal = pairOptionalEntry.getKey().second(); Formula cachedGoal = pairOptionalEntry.getKey().getRight();
return cachedGoal.equals(goal) && Sets.subset(assumptions, cachedAssumptions); return cachedGoal.equals(goal) && Sets.subset(assumptions, cachedAssumptions);
}).findAny(); }).findAny();
@ -249,7 +249,7 @@ public class Operations {
newState = State.initializeWith(newFormulae); newState = State.initializeWith(newFormulae);
nexts.add(ImmutablePair.from(newState, action.instantiate(binding))); nexts.add(ImmutablePair.of(newState, action.instantiate(binding)));
} }
@ -273,12 +273,12 @@ public class Operations {
newState = State.initializeWith(newFormulae); newState = State.initializeWith(newFormulae);
nexts.add(ImmutablePair.from(newState, action.instantiate(emptyBinding))); nexts.add(ImmutablePair.of(newState, action.instantiate(emptyBinding)));
} }
nexts = nexts.stream().filter(n-> !n.first().getFormulae().equals(state.getFormulae())).collect(Collectors.toSet());; nexts = nexts.stream().filter(n-> !n.getLeft().getFormulae().equals(state.getFormulae())).collect(Collectors.toSet());;

View file

@ -44,17 +44,17 @@ public final class Runner {
return; return;
} }
Planner breadthFirstPlanner = new BreadthFirstPlanner(); BreadthFirstPlanner breadthFirstPlanner = new BreadthFirstPlanner();
for (PlanningProblem planningProblem : planningProblemList) { for (PlanningProblem planningProblem : planningProblemList) {
Optional<Set<Plan>> optionalPlans = breadthFirstPlanner.plan( Set<Plan> plans = breadthFirstPlanner.plan(
planningProblem.getBackground(), planningProblem.getBackground(),
planningProblem.getActions(), planningProblem.getActions(),
planningProblem.getStart(), planningProblem.getStart(),
planningProblem.getGoal()); planningProblem.getGoal());
if(optionalPlans.isPresent()) { if(plans.size() > 0) {
System.out.println(optionalPlans.get().toString()); System.out.println(plans.toString());
} }
else { else {
System.out.println("FAILED"); System.out.println("FAILED");