diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-04 07:32:18 -0700 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-04 07:32:18 -0700 |
commit | 1a8cc4d3d2076d6a25b92c1f08c716b356974f62 (patch) | |
tree | 25e6da26ac8faa59f5430a655cf8293ba99660ce /model-integration | |
parent | 0ce6fa7cbdf71fd39cb5bb18accfa84a20e7e120 (diff) |
Extract constraint solver
Diffstat (limited to 'model-integration')
2 files changed, 117 insertions, 83 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index 9821870e38b..10d39a43c61 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -5,15 +5,10 @@ import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; import ai.vespa.rankingexpression.importer.operations.MatMul; import ai.vespa.rankingexpression.importer.operations.Rename; import com.yahoo.collections.ListMap; -import com.yahoo.lang.MutableInteger; -import com.yahoo.text.ExpressionFormatter; -import java.util.ArrayDeque; import java.util.ArrayList; -import java.util.Deque; import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -94,11 +89,10 @@ public class DimensionRenamer { * @return the solution in the form of the renames to perform */ private Map<String, Integer> solve(int maxIterations) { - // variables.freeze(); Map<String, Integer> renames = new HashMap<>(); // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts - boolean solved = trySolve(variables, constraints, maxIterations, renames); + boolean solved = NamingConstraintSolver.solve(variables, constraints, maxIterations, renames); if ( ! solved) { IntermediateOperation operation = graph.operations().get("dense_out/MatMul"); if (operation != null && operation instanceof MatMul) { @@ -116,7 +110,7 @@ public class DimensionRenamer { newOperation.addDimensionNameConstraints(this); renames.clear(); - solved = trySolve(variables, constraints, maxIterations, renames); + solved = NamingConstraintSolver.solve(variables, constraints, maxIterations, renames); } } if ( ! solved) { @@ -124,7 +118,7 @@ public class DimensionRenamer { ListMap<Arc, Constraint> hardConstraints = new ListMap<>(); boolean anyRemoved = copyHard(constraints, hardConstraints); if (anyRemoved) - solved = trySolve(variables, hardConstraints, maxIterations, renames); + solved = NamingConstraintSolver.solve(variables, hardConstraints, maxIterations, renames); if ( ! solved) { throw new IllegalArgumentException("Could not find a dimension naming solution " + "given constraints\n" + constraintsToString(hardConstraints)); @@ -154,27 +148,6 @@ public class DimensionRenamer { return removed; } - /** Try the solve the constraint problem given in the arguments, and put the result in renames */ - private static boolean trySolve(ListMap<String, Integer> inputVariables, - ListMap<Arc, Constraint> constraints, - int maxIterations, - Map<String, Integer> renames) { - var variables = new ListMap<>(inputVariables); - initialize(variables); - MutableInteger iterations = new MutableInteger(0); - for (String dimension : variables.keySet()) { - List<Integer> values = variables.get(dimension); - if (values.size() > 1) { - if ( ! ac3(iterations, variables, constraints)) return false; - values.sort(Integer::compare); - variables.replace(dimension, values.get(0)); - } - renames.put(dimension, variables.get(dimension).get(0)); - if (iterations.get() > maxIterations) return false; - } - return true; - } - void solve() { log.log(Level.FINE, () -> "Rename problem:\n" + constraintsToString(constraints)); renames = solve(100000); @@ -187,56 +160,6 @@ public class DimensionRenamer { .collect(Collectors.joining("\n")); } - private static void initialize(ListMap<String, Integer> variables) { - for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) { - List<Integer> values = variable.getValue(); - for (int i = 0; i < variables.size(); ++i) { - values.add(i); // invariant: values are in increasing order - } - } - } - - private static boolean ac3(MutableInteger iterations, - ListMap<String, Integer> variables, - ListMap<Arc, Constraint> constraints) { - Deque<Arc> workList = new ArrayDeque<>(constraints.keySet()); - while ( ! workList.isEmpty()) { - Arc arc = workList.pop(); - iterations.add(1); - if (revise(arc, variables, constraints)) { - if (variables.get(arc.from).size() == 0) { - return false; // no solution found - } - for (Arc constraint : constraints.keySet()) { - if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) { - workList.add(constraint); - } - } - } - } - return true; - } - - private static boolean revise(Arc arc, - ListMap<String, Integer> variables, - ListMap<Arc, Constraint> constraints) { - boolean revised = false; - for (Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) { - Integer from = fromIterator.next(); - boolean satisfied = false; - for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) { - Integer to = toIterator.next(); - if (constraints.get(arc).stream().allMatch(constraint -> constraint.test(from, to))) - satisfied = true; - } - if ( ! satisfied) { - fromIterator.remove(); - revised = true; - } - } - return revised; - } - private static String constraintsToString(ListMap<Arc, Constraint> constraints) { StringBuilder b = new StringBuilder(); for (var entry : constraints.entrySet()) { @@ -253,10 +176,10 @@ public class DimensionRenamer { return b.toString(); } - private static class Arc { + static class Arc { - private final String from; - private final String to; + final String from; + final String to; private final IntermediateOperation operation; Arc(String from, String to, IntermediateOperation operation) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java new file mode 100644 index 00000000000..b059bb96d91 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java @@ -0,0 +1,111 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer; + +import com.yahoo.collections.ListMap; +import com.yahoo.lang.MutableInteger; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/** + * Solves a dimension naming constraint problem. + * + * @author lesters + * @author bratseth + */ +class NamingConstraintSolver { + + private final ListMap<String, Integer> variables; + private final ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints; + + private int iterations = 0; + private final int maxIterations; + + /** The solution to this rename problem */ + private Map<String, Integer> renames; + + private NamingConstraintSolver(ListMap<String, Integer> inputVariables, + ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints, + int maxIterations, + Map<String, Integer> renames) { + this.variables = new ListMap<>(inputVariables); + initialize(variables); + this.constraints = constraints; + this.maxIterations = maxIterations; + this.renames = renames; + } + + /** Try the solve the constraint problem given in the arguments, and put the result in renames */ + private boolean trySolve() { + for (String dimension : variables.keySet()) { + List<Integer> values = variables.get(dimension); + if (values.size() > 1) { + if ( ! ac3()) return false; + values.sort(Integer::compare); + variables.replace(dimension, values.get(0)); + } + renames.put(dimension, variables.get(dimension).get(0)); + if (iterations > maxIterations) return false; + } + return true; + } + + private static void initialize(ListMap<String, Integer> variables) { + for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) { + List<Integer> values = variable.getValue(); + for (int i = 0; i < variables.size(); ++i) { + values.add(i); // invariant: values are in increasing order + } + } + } + + private boolean ac3() { + Deque<DimensionRenamer.Arc> workList = new ArrayDeque<>(constraints.keySet()); + while ( ! workList.isEmpty()) { + DimensionRenamer.Arc arc = workList.pop(); + iterations++; + if (revise(arc, variables, constraints)) { + if (variables.get(arc.from).size() == 0) { + return false; // no solution found + } + for (DimensionRenamer.Arc constraint : constraints.keySet()) { + if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) { + workList.add(constraint); + } + } + } + } + return true; + } + + private static boolean revise(DimensionRenamer.Arc arc, + ListMap<String, Integer> variables, + ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints) { + boolean revised = false; + for (Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) { + Integer from = fromIterator.next(); + boolean satisfied = false; + for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) { + Integer to = toIterator.next(); + if (constraints.get(arc).stream().allMatch(constraint -> constraint.test(from, to))) + satisfied = true; + } + if ( ! satisfied) { + fromIterator.remove(); + revised = true; + } + } + return revised; + } + + public static boolean solve(ListMap<String, Integer> inputVariables, + ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints, + int maxIterations, + Map<String, Integer> renames) { + return new NamingConstraintSolver(inputVariables, constraints, maxIterations, renames).trySolve(); + } + +} |