Improved caching of operations

This commit is contained in:
Brandon Rozek 2023-10-31 12:07:00 -04:00
parent 61ae5a5ed6
commit 0e3895d871
No known key found for this signature in database
GPG key ID: 26E457DA82C9F480

View file

@ -4,9 +4,11 @@ import org.rairlab.planner.utils.Visualizer;
import org.rairlab.shadow.prover.core.Prover; import org.rairlab.shadow.prover.core.Prover;
import org.rairlab.shadow.prover.core.SnarkWrapper; import org.rairlab.shadow.prover.core.SnarkWrapper;
import org.rairlab.shadow.prover.core.proof.Justification; import org.rairlab.shadow.prover.core.proof.Justification;
import org.rairlab.planner.utils.Commons;
import org.rairlab.shadow.prover.representations.formula.BiConditional; import org.rairlab.shadow.prover.representations.formula.BiConditional;
import org.rairlab.shadow.prover.representations.formula.Formula; import org.rairlab.shadow.prover.representations.formula.Formula;
import org.rairlab.shadow.prover.representations.formula.Predicate; import org.rairlab.shadow.prover.representations.formula.Predicate;
import org.rairlab.shadow.prover.representations.value.Value; import org.rairlab.shadow.prover.representations.value.Value;
import org.rairlab.shadow.prover.representations.value.Variable; import org.rairlab.shadow.prover.representations.value.Variable;
import org.rairlab.shadow.prover.utils.CollectionUtils; import org.rairlab.shadow.prover.utils.CollectionUtils;
@ -52,115 +54,92 @@ public class Operations {
public static synchronized Optional<Justification> proveCached(Set<Formula> assumptions, Formula goal) { public static synchronized Optional<Justification> proveCached(Set<Formula> assumptions, Formula goal) {
// (1) If we've asked to prove this exact goal from assumptions before
// then return the previous result
Pair<Set<Formula>, Formula> inputPair = ImmutablePair.of(assumptions, goal); Pair<Set<Formula>, Formula> inputPair = ImmutablePair.of(assumptions, goal);
if (proverCache.containsKey(inputPair)) { if (proverCache.containsKey(inputPair)) {
return proverCache.get(inputPair); return proverCache.get(inputPair);
} }
Optional<Map.Entry<Pair<Set<Formula>, Formula>, Optional<Justification>>> cachedOptionalSuccessful = proverCache.entrySet().stream().filter(pairOptionalEntry -> { // Iterate through the cache
for (Map.Entry<Pair<Set<Formula>, Formula>, Optional<Justification>> entry : proverCache.entrySet()) {
Set<Formula> cachedAssumptions = entry.getKey().getLeft();
Formula cachedGoal = entry.getKey().getRight();
Optional<Justification> optJust = entry.getValue();
Set<Formula> cachedAssumptions = pairOptionalEntry.getKey().getLeft(); // (2) Return the cached justification if:
Formula cachedGoal = pairOptionalEntry.getKey().getRight(); // - Goals are the same
// - The cached assumptions are a subset of the current ones
return cachedGoal.equals(goal) && Sets.subset(cachedAssumptions, assumptions); // - A justification was found
}).findAny(); if (optJust.isPresent() && cachedGoal.equals(goal) && Sets.subset(cachedAssumptions, assumptions)) {
return optJust;
if(cachedOptionalSuccessful.isPresent() && cachedOptionalSuccessful.get().getValue().isPresent()){
return cachedOptionalSuccessful.get().getValue();
} }
// (3) Return cached failure if:
Optional<Map.Entry<Pair<Set<Formula>, Formula>, Optional<Justification>>> cachedOptionalFailed = proverCache.entrySet().stream().filter(pairOptionalEntry -> { // - Goals are the same
// - Assumptions are a subset of cached assumptions
Set<Formula> cachedAssumptions = pairOptionalEntry.getKey().getLeft(); // - No justification was found
Formula cachedGoal = pairOptionalEntry.getKey().getRight(); if (optJust.isEmpty() && cachedGoal.equals(goal) && Sets.subset(assumptions, cachedAssumptions)) {
return optJust;
return cachedGoal.equals(goal) && Sets.subset(assumptions, cachedAssumptions); }
}).findAny();
if(cachedOptionalFailed.isPresent() && !cachedOptionalFailed.get().getValue().isPresent()){
return cachedOptionalFailed.get().getValue();
} }
if(goal instanceof Predicate && ((Predicate) goal).getName().equals("sameroom")){ // Otherwise create a new call to the theorem prover
Predicate p = (Predicate) goal;
Value v1 = p.getArguments()[0];
Value v2 = p.getArguments()[1];
Optional<Formula> inOptv1 = assumptions.stream().filter(x-> x instanceof Predicate &&
((Predicate)x).getName().equals("in") && ((Predicate) x).getArguments()[0].equals(v1)).findAny();
Optional<Formula> inOptv2 = assumptions.stream().filter(x-> x instanceof Predicate &&
((Predicate)x).getName().equals("in") && ((Predicate) x).getArguments()[0].equals(v2)).findAny();
if(inOptv1.isPresent() && inOptv2.isPresent()){
Value room1 = ((Predicate)inOptv1.get()).getArguments()[1];
Value room2 = ((Predicate)inOptv2.get()).getArguments()[1];
if(room1.equals(room2)){
return Optional.of(Justification.trivial(assumptions, goal));
}
}
}
{
Optional<Justification> answer = prover.prove(assumptions, goal); Optional<Justification> answer = prover.prove(assumptions, goal);
proverCache.put(inputPair, answer); proverCache.put(inputPair, answer);
return answer; return answer;
}
} }
public static synchronized Optional<Set<Map<Variable, Value>>> proveAndGetBindingsCached(Set<Formula> givens, Formula goal, List<Variable> variables) { public static synchronized Optional<Set<Map<Variable, Value>>> proveAndGetBindingsCached(Set<Formula> assumptions, Formula goal, List<Variable> variables) {
Triple<Set<Formula>, Formula, List<Variable>> inputTriple = Triple.of(givens, goal, variables);
// (1) If we've asked to find the variables that satisfy this exact goal from assumptions before
// then return the previous result
Triple<Set<Formula>, Formula, List<Variable>> inputTriple = Triple.of(assumptions, goal, variables);
if (proverBindingsCache.containsKey(inputTriple)) { if (proverBindingsCache.containsKey(inputTriple)) {
return proverBindingsCache.get(inputTriple); return proverBindingsCache.get(inputTriple);
} else {
Optional<Set<Map<Variable, Value>>> answer = proveAndGetMultipleBindings(givens, goal, variables);
proverBindingsCache.put(inputTriple, answer);
return answer;
} }
for (Map.Entry<Triple<Set<Formula>, Formula, List<Variable>>, Optional<Set<Map<Variable, Value>>>> entry : proverBindingsCache.entrySet()) {
Set<Formula> cachedAssumptions = entry.getKey().getLeft();
Formula cachedGoal = entry.getKey().getMiddle();
List<Variable> cachedVars = entry.getKey().getRight();
Optional<Set<Map<Variable, Value>>> optMapping = entry.getValue();
// (2) Return the cached justification if:
// - Goals are the same
// - The variable list requested is the same
// - The cached assumptions are a subset of the current ones
// - A justification was found
if (optMapping.isPresent() && cachedGoal.equals(goal) && cachedVars.equals(variables) && Sets.subset(cachedAssumptions, assumptions)) {
return optMapping;
}
// (3) Return cached failure if:
// - Goals are the same
// - The variable list requested is the same
// - Assumptions are a subset of cached assumptions
// - No justification was found
if (optMapping.isEmpty() && cachedGoal.equals(goal) && cachedVars.equals(variables) && Sets.subset(assumptions, cachedAssumptions)) {
return optMapping;
}
}
// Otherwise create a new call to the theorem prover
Optional<Set<Map<Variable, Value>>> answer = proveAndGetMultipleBindings(assumptions, goal, variables);
proverBindingsCache.put(inputTriple, answer);
return answer;
} }
public static synchronized Optional<Map<Variable, Value>> proveAndGetBindings(Set<Formula> givens, Formula goal, List<Variable> variables) { public static synchronized Optional<Map<Variable, Value>> proveAndGetBindings(Set<Formula> givens, Formula goal, List<Variable> variables) {
Future<Optional<Map<Variable, Value>>> future = new FutureTask<>(() -> { Future<Optional<Map<Variable, Value>>> future = new FutureTask<>(() -> {
return prover.proveAndGetBindings(givens, goal, variables); return prover.proveAndGetBindings(givens, goal, variables);
}); });
Optional<Map<Variable, Value>> answer; Optional<Map<Variable, Value>> answer;
try { try {
answer = future.get(1, TimeUnit.SECONDS); answer = future.get(1, TimeUnit.SECONDS);
} catch (InterruptedException | ExecutionException | TimeoutException e) { } catch (InterruptedException | ExecutionException | TimeoutException e) {
answer = Optional.empty(); answer = Optional.empty();
@ -174,26 +153,11 @@ public class Operations {
Optional<org.apache.commons.lang3.tuple.Pair<Justification, Set<Map<Variable, Value>>>> ans = prover.proveAndGetMultipleBindings(givens, goal, variables); Optional<org.apache.commons.lang3.tuple.Pair<Justification, Set<Map<Variable, Value>>>> ans = prover.proveAndGetMultipleBindings(givens, goal, variables);
if(ans.isPresent()){ if (ans.isEmpty()) {
return Optional.of(ans.get().getRight());
}else {
return Optional.empty(); return Optional.empty();
} }
/* Future<Optional<Set<Map<Variable, Value>>>> future = new FutureTask<>(()-> prover.proveAndGetMultipleBindings(givens, goal, variables));
Optional<Set<Map<Variable, Value>>> answer; return Optional.of(ans.get().getRight());
try{
answer = future.get(50, TimeUnit.SECONDS);
} catch (InterruptedException | ExecutionException | TimeoutException e ) {
answer = Optional.empty();
}
return answer;
*/
} }
@ -301,18 +265,22 @@ public class Operations {
public static boolean satisfies(Set<Formula> background, State state, State goal) { public static boolean satisfies(Set<Formula> background, State state, State goal) {
if ((Sets.union(background, state.getFormulae()).containsAll(goal.getFormulae()))) { if ((Sets.union(background, state.getFormulae()).containsAll(goal.getFormulae()))) {
return true; return true;
} }
return goal.getFormulae().stream().
allMatch(x -> proveCached(Sets.union(background, state.getFormulae()), x).isPresent()); return proveCached(
Sets.union(background, state.getFormulae()),
Commons.makeAnd(goal.getFormulae())
).isPresent();
} }
public static boolean conflicts(Set<Formula> background, State state1, State state2) { public static boolean conflicts(Set<Formula> background, State state1, State state2) {
return proveCached(
return proveCached(Sets.union(background, Sets.union(state1.getFormulae(), state2.getFormulae())), State.FALSE).isPresent(); Sets.union(background, Sets.union(state1.getFormulae(), state2.getFormulae())),
State.FALSE
).isPresent();
} }