summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-07-04 07:32:18 -0700
committerJon Bratseth <bratseth@verizonmedia.com>2019-07-04 07:32:18 -0700
commit1a8cc4d3d2076d6a25b92c1f08c716b356974f62 (patch)
tree25e6da26ac8faa59f5430a655cf8293ba99660ce
parent0ce6fa7cbdf71fd39cb5bb18accfa84a20e7e120 (diff)
Extract constraint solver
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java89
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java111
2 files changed, 117 insertions, 83 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
index 9821870e38b..10d39a43c61 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -5,15 +5,10 @@ import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import ai.vespa.rankingexpression.importer.operations.MatMul;
import ai.vespa.rankingexpression.importer.operations.Rename;
import com.yahoo.collections.ListMap;
-import com.yahoo.lang.MutableInteger;
-import com.yahoo.text.ExpressionFormatter;
-import java.util.ArrayDeque;
import java.util.ArrayList;
-import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
-import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -94,11 +89,10 @@ public class DimensionRenamer {
* @return the solution in the form of the renames to perform
*/
private Map<String, Integer> solve(int maxIterations) {
- // variables.freeze();
Map<String, Integer> renames = new HashMap<>();
// Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts
- boolean solved = trySolve(variables, constraints, maxIterations, renames);
+ boolean solved = NamingConstraintSolver.solve(variables, constraints, maxIterations, renames);
if ( ! solved) {
IntermediateOperation operation = graph.operations().get("dense_out/MatMul");
if (operation != null && operation instanceof MatMul) {
@@ -116,7 +110,7 @@ public class DimensionRenamer {
newOperation.addDimensionNameConstraints(this);
renames.clear();
- solved = trySolve(variables, constraints, maxIterations, renames);
+ solved = NamingConstraintSolver.solve(variables, constraints, maxIterations, renames);
}
}
if ( ! solved) {
@@ -124,7 +118,7 @@ public class DimensionRenamer {
ListMap<Arc, Constraint> hardConstraints = new ListMap<>();
boolean anyRemoved = copyHard(constraints, hardConstraints);
if (anyRemoved)
- solved = trySolve(variables, hardConstraints, maxIterations, renames);
+ solved = NamingConstraintSolver.solve(variables, hardConstraints, maxIterations, renames);
if ( ! solved) {
throw new IllegalArgumentException("Could not find a dimension naming solution " +
"given constraints\n" + constraintsToString(hardConstraints));
@@ -154,27 +148,6 @@ public class DimensionRenamer {
return removed;
}
- /** Try the solve the constraint problem given in the arguments, and put the result in renames */
- private static boolean trySolve(ListMap<String, Integer> inputVariables,
- ListMap<Arc, Constraint> constraints,
- int maxIterations,
- Map<String, Integer> renames) {
- var variables = new ListMap<>(inputVariables);
- initialize(variables);
- MutableInteger iterations = new MutableInteger(0);
- for (String dimension : variables.keySet()) {
- List<Integer> values = variables.get(dimension);
- if (values.size() > 1) {
- if ( ! ac3(iterations, variables, constraints)) return false;
- values.sort(Integer::compare);
- variables.replace(dimension, values.get(0));
- }
- renames.put(dimension, variables.get(dimension).get(0));
- if (iterations.get() > maxIterations) return false;
- }
- return true;
- }
-
void solve() {
log.log(Level.FINE, () -> "Rename problem:\n" + constraintsToString(constraints));
renames = solve(100000);
@@ -187,56 +160,6 @@ public class DimensionRenamer {
.collect(Collectors.joining("\n"));
}
- private static void initialize(ListMap<String, Integer> variables) {
- for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) {
- List<Integer> values = variable.getValue();
- for (int i = 0; i < variables.size(); ++i) {
- values.add(i); // invariant: values are in increasing order
- }
- }
- }
-
- private static boolean ac3(MutableInteger iterations,
- ListMap<String, Integer> variables,
- ListMap<Arc, Constraint> constraints) {
- Deque<Arc> workList = new ArrayDeque<>(constraints.keySet());
- while ( ! workList.isEmpty()) {
- Arc arc = workList.pop();
- iterations.add(1);
- if (revise(arc, variables, constraints)) {
- if (variables.get(arc.from).size() == 0) {
- return false; // no solution found
- }
- for (Arc constraint : constraints.keySet()) {
- if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) {
- workList.add(constraint);
- }
- }
- }
- }
- return true;
- }
-
- private static boolean revise(Arc arc,
- ListMap<String, Integer> variables,
- ListMap<Arc, Constraint> constraints) {
- boolean revised = false;
- for (Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) {
- Integer from = fromIterator.next();
- boolean satisfied = false;
- for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) {
- Integer to = toIterator.next();
- if (constraints.get(arc).stream().allMatch(constraint -> constraint.test(from, to)))
- satisfied = true;
- }
- if ( ! satisfied) {
- fromIterator.remove();
- revised = true;
- }
- }
- return revised;
- }
-
private static String constraintsToString(ListMap<Arc, Constraint> constraints) {
StringBuilder b = new StringBuilder();
for (var entry : constraints.entrySet()) {
@@ -253,10 +176,10 @@ public class DimensionRenamer {
return b.toString();
}
- private static class Arc {
+ static class Arc {
- private final String from;
- private final String to;
+ final String from;
+ final String to;
private final IntermediateOperation operation;
Arc(String from, String to, IntermediateOperation operation) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java
new file mode 100644
index 00000000000..b059bb96d91
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java
@@ -0,0 +1,111 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer;
+
+import com.yahoo.collections.ListMap;
+import com.yahoo.lang.MutableInteger;
+
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Solves a dimension naming constraint problem.
+ *
+ * @author lesters
+ * @author bratseth
+ */
+class NamingConstraintSolver {
+
+ private final ListMap<String, Integer> variables;
+ private final ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints;
+
+ private int iterations = 0;
+ private final int maxIterations;
+
+ /** The solution to this rename problem */
+ private Map<String, Integer> renames;
+
+ private NamingConstraintSolver(ListMap<String, Integer> inputVariables,
+ ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints,
+ int maxIterations,
+ Map<String, Integer> renames) {
+ this.variables = new ListMap<>(inputVariables);
+ initialize(variables);
+ this.constraints = constraints;
+ this.maxIterations = maxIterations;
+ this.renames = renames;
+ }
+
+ /** Try the solve the constraint problem given in the arguments, and put the result in renames */
+ private boolean trySolve() {
+ for (String dimension : variables.keySet()) {
+ List<Integer> values = variables.get(dimension);
+ if (values.size() > 1) {
+ if ( ! ac3()) return false;
+ values.sort(Integer::compare);
+ variables.replace(dimension, values.get(0));
+ }
+ renames.put(dimension, variables.get(dimension).get(0));
+ if (iterations > maxIterations) return false;
+ }
+ return true;
+ }
+
+ private static void initialize(ListMap<String, Integer> variables) {
+ for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) {
+ List<Integer> values = variable.getValue();
+ for (int i = 0; i < variables.size(); ++i) {
+ values.add(i); // invariant: values are in increasing order
+ }
+ }
+ }
+
+ private boolean ac3() {
+ Deque<DimensionRenamer.Arc> workList = new ArrayDeque<>(constraints.keySet());
+ while ( ! workList.isEmpty()) {
+ DimensionRenamer.Arc arc = workList.pop();
+ iterations++;
+ if (revise(arc, variables, constraints)) {
+ if (variables.get(arc.from).size() == 0) {
+ return false; // no solution found
+ }
+ for (DimensionRenamer.Arc constraint : constraints.keySet()) {
+ if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) {
+ workList.add(constraint);
+ }
+ }
+ }
+ }
+ return true;
+ }
+
+ private static boolean revise(DimensionRenamer.Arc arc,
+ ListMap<String, Integer> variables,
+ ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints) {
+ boolean revised = false;
+ for (Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) {
+ Integer from = fromIterator.next();
+ boolean satisfied = false;
+ for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) {
+ Integer to = toIterator.next();
+ if (constraints.get(arc).stream().allMatch(constraint -> constraint.test(from, to)))
+ satisfied = true;
+ }
+ if ( ! satisfied) {
+ fromIterator.remove();
+ revised = true;
+ }
+ }
+ return revised;
+ }
+
+ public static boolean solve(ListMap<String, Integer> inputVariables,
+ ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints,
+ int maxIterations,
+ Map<String, Integer> renames) {
+ return new NamingConstraintSolver(inputVariables, constraints, maxIterations, renames).trySolve();
+ }
+
+}