diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-04 08:14:31 -0700 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-04 08:14:31 -0700 |
commit | 97f7d89a3b769d6e23fb367d6c96b5cfa8bae700 (patch) | |
tree | fc85e5cea9abb54943f08c7d38bf1b6260bfb313 /model-integration | |
parent | f0473187794b105ba8bf5ae32f99889dd2d909ad (diff) |
Keep solution possibilities internal to solver
Diffstat (limited to 'model-integration')
2 files changed, 42 insertions, 47 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index e8291470054..5bbcd3a7265 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -12,6 +12,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -21,6 +22,7 @@ import java.util.stream.Collectors; * amount of necessary renaming during evaluation of an imported model. * * @author lesters + * @author bratseth */ public class DimensionRenamer { @@ -31,7 +33,10 @@ public class DimensionRenamer { /** The graph we are renaming the dimensions of */ private final IntermediateGraph graph; - private final ListMap<String, Integer> variables = new ListMap<>(); + /** The set of dimensions to find a solution for */ + private final Set<String> dimensions = new HashSet<>(); + + /** The constraints on the dimension name assignment */ private final ListMap<Arc, Constraint> constraints = new ListMap<>(); /** The solution to this, or null if no solution is found yet */ @@ -46,16 +51,10 @@ public class DimensionRenamer { this.dimensionPrefix = dimensionPrefix; } - /** - * Add a dimension name variable. - */ - public void addDimension(String name) { - variables.put(name); - } + /** Add a dimension to the set of dimensions to be renamed */ + public void addDimension(String name) { dimensions.add(name); } - /** - * Add a constraint between dimension names. - */ + /** Add a constraint between dimension names */ public void addConstraint(String from, String to, Constraint constraint, IntermediateOperation operation) { Arc arc = new Arc(from, to, operation); constraints.put(arc, constraint); @@ -63,12 +62,12 @@ public class DimensionRenamer { } /** - * Retrieve resulting name of dimension after solving for constraints. + * Retrieve resulting name of a dimension after solving for constraints, or empty if no + * solution is found yet, or this dimension was not added before finding a solution. */ public Optional<String> dimensionNameOf(String name) { - if ( ! renames.containsKey(name)) { + if ( renames == null || ! renames.containsKey(name)) return Optional.empty(); - } return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name))); } @@ -88,7 +87,7 @@ public class DimensionRenamer { * @return the solution in the form of the renames to perform */ private Map<String, Integer> solve(int maxIterations) { - Map<String, Integer> solution = NamingConstraintSolver.solve(variables, constraints, maxIterations); + Map<String, Integer> solution = NamingConstraintSolver.solve(dimensions, constraints, maxIterations); if ( solution == null) { IntermediateOperation operation = graph.operations().get("dense_out/MatMul"); if (operation != null && operation instanceof MatMul) { @@ -105,14 +104,14 @@ public class DimensionRenamer { addDimension("renamed_0"); newOperation.addDimensionNameConstraints(this); - solution = NamingConstraintSolver.solve(variables, constraints, maxIterations); + solution = NamingConstraintSolver.solve(dimensions, constraints, maxIterations); } } if ( solution == null) { ListMap<Arc, Constraint> hardConstraints = new ListMap<>(); boolean anyRemoved = copyHard(constraints, hardConstraints); if (anyRemoved) - solution = NamingConstraintSolver.solve(variables, hardConstraints, maxIterations); + solution = NamingConstraintSolver.solve(dimensions, hardConstraints, maxIterations); if ( solution == null) { throw new IllegalArgumentException("Could not find a dimension naming solution " + "given constraints\n" + constraintsToString(hardConstraints)); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java index 54f97a3427c..3b1b9ce1715 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java @@ -2,7 +2,6 @@ package ai.vespa.rankingexpression.importer; import com.yahoo.collections.ListMap; -import com.yahoo.lang.MutableInteger; import java.util.ArrayDeque; import java.util.Deque; @@ -10,6 +9,7 @@ import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; /** * Solves a dimension naming constraint problem. @@ -19,75 +19,71 @@ import java.util.Map; */ class NamingConstraintSolver { - private final ListMap<String, Integer> variables; + private final ListMap<String, Integer> possibleAssignments; private final ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints; private int iterations = 0; private final int maxIterations; - private NamingConstraintSolver(ListMap<String, Integer> inputVariables, + private NamingConstraintSolver(Set<String> dimensions, ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints, int maxIterations) { - this.variables = new ListMap<>(inputVariables); - initialize(variables); + this.possibleAssignments = allPossibilities(dimensions); this.constraints = constraints; this.maxIterations = maxIterations; } + /** Returns a list containing a list of all assignment possibilities for each of the given dimensions */ + private static ListMap<String, Integer> allPossibilities(Set<String> dimensions) { + ListMap<String, Integer> all = new ListMap<>(); + for (String dimension : dimensions) { + for (int i = 0; i < dimensions.size(); ++i) + all.put(dimension, i); + } + return all; + } + /** Try the solve the constraint problem given in the arguments, and put the result in renames */ private Map<String, Integer> trySolve() { // TODO: Evaluate possible improved efficiency by using a heuristic such as min-conflicts Map<String, Integer> solution = new HashMap<>(); - for (String dimension : variables.keySet()) { - List<Integer> values = variables.get(dimension); + for (String dimension : possibleAssignments.keySet()) { + List<Integer> values = possibleAssignments.get(dimension); if (values.size() > 1) { if ( ! ac3()) return null; values.sort(Integer::compare); - variables.replace(dimension, values.get(0)); + possibleAssignments.replace(dimension, values.get(0)); } - solution.put(dimension, variables.get(dimension).get(0)); + solution.put(dimension, possibleAssignments.get(dimension).get(0)); if (iterations > maxIterations) return null; } return solution; } - private static void initialize(ListMap<String, Integer> variables) { - for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) { - List<Integer> values = variable.getValue(); - for (int i = 0; i < variables.size(); ++i) { - values.add(i); // invariant: values are in increasing order - } - } - } - private boolean ac3() { Deque<DimensionRenamer.Arc> workList = new ArrayDeque<>(constraints.keySet()); while ( ! workList.isEmpty()) { DimensionRenamer.Arc arc = workList.pop(); iterations++; - if (revise(arc, variables, constraints)) { - if (variables.get(arc.from).size() == 0) { - return false; // no solution found - } + if (revise(arc)) { + if (possibleAssignments.get(arc.from).isEmpty()) return false; + for (DimensionRenamer.Arc constraint : constraints.keySet()) { - if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) { + if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) workList.add(constraint); - } } } } return true; } - private static boolean revise(DimensionRenamer.Arc arc, - ListMap<String, Integer> variables, - ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints) { + private boolean revise(DimensionRenamer.Arc arc) { boolean revised = false; - for (Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) { + for (Iterator<Integer> fromIterator = possibleAssignments.get(arc.from).iterator(); fromIterator.hasNext(); ) { Integer from = fromIterator.next(); boolean satisfied = false; - for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) { + for (Iterator<Integer> toIterator = possibleAssignments.get(arc.to).iterator(); toIterator.hasNext(); ) { Integer to = toIterator.next(); if (constraints.get(arc).stream().allMatch(constraint -> constraint.test(from, to))) satisfied = true; @@ -106,10 +102,10 @@ class NamingConstraintSolver { * @return the solution as a map from existing names to name ids represented as integers, or NULL * if no solution could be found */ - public static Map<String, Integer> solve(ListMap<String, Integer> inputVariables, + public static Map<String, Integer> solve(Set<String> dimensions, ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints, int maxIterations) { - return new NamingConstraintSolver(inputVariables, constraints, maxIterations).trySolve(); + return new NamingConstraintSolver(dimensions, constraints, maxIterations).trySolve(); } } |