summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java421
1 files changed, 299 insertions, 122 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
index 9e9f66be700..0f563a75b11 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -2,179 +2,179 @@
package ai.vespa.rankingexpression.importer;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
+import ai.vespa.rankingexpression.importer.operations.Rename;
+import com.yahoo.collections.ListMap;
-import java.util.ArrayDeque;
import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Deque;
+import java.util.Comparator;
import java.util.HashMap;
-import java.util.Iterator;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
+import java.util.Set;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
/**
- * A constraint satisfier to find suitable dimension names to reduce the
+ * A constraint solver which finds suitable dimension names to reduce the
* amount of necessary renaming during evaluation of an imported model.
*
* @author lesters
+ * @author bratseth
*/
public class DimensionRenamer {
+ private static final Logger log = Logger.getLogger(DimensionRenamer.class.getName());
+
private final String dimensionPrefix;
- private final Map<String, List<Integer>> variables = new HashMap<>();
- private final Map<Arc, Constraint> constraints = new HashMap<>();
- private final Map<String, Integer> renames = new HashMap<>();
- private int iterations = 0;
+ /** The graph we are renaming the dimensions of */
+ private final IntermediateGraph graph;
+
+ /** The set of dimensions to find a solution for */
+ private final Set<String> dimensions = new HashSet<>();
+
+ /** The constraints on the dimension name assignment */
+ private final ListMap<Arc, Constraint> constraints = new ListMap<>();
+
+ /** The solution to this, or null if no solution is found yet */
+ private Map<String, Integer> renames = null;
- public DimensionRenamer() {
- this("d");
+ public DimensionRenamer(IntermediateGraph graph) {
+ this(graph, "d");
}
- public DimensionRenamer(String dimensionPrefix) {
+ public DimensionRenamer(IntermediateGraph graph, String dimensionPrefix) {
+ this.graph = graph;
this.dimensionPrefix = dimensionPrefix;
}
- /**
- * Add a dimension name variable.
- */
- public void addDimension(String name) {
- variables.computeIfAbsent(name, d -> new ArrayList<>());
- }
+ /** Add a dimension to the set of dimensions to be renamed */
+ public void addDimension(String name) { dimensions.add(name); }
+
+ /** Add a constraint between dimension names */
+ public void addConstraint(String from, String to, Constraint constraint, IntermediateOperation operation) {
+ if (constraint instanceof EqualConstraint && from.equals(to)) return;
- /**
- * Add a constraint between dimension names.
- */
- public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) {
Arc arc = new Arc(from, to, operation);
- Arc opposite = arc.opposite();
- constraints.put(arc, pred);
- constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric
+ constraints.put(arc, constraint);
+ constraints.put(arc.opposite(), constraint.opposite()); // make constraint graph symmetric
}
- /**
- * Retrieve resulting name of dimension after solving for constraints.
- */
- public Optional<String> dimensionNameOf(String name) {
- if (!renames.containsKey(name)) {
- return Optional.empty();
- }
- return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name)));
+ void solve() {
+ log.log(Level.FINE, () -> "Rename problem:\n" + constraintsToString(constraints));
+ renames = solve(100000);
+ log.log(Level.FINE, () -> "Rename solution:\n" + renamesToString(renames));
}
- /**
- * Perform iterative arc consistency until we have found a solution. After
- * an initial iteration, the variables (dimensions) will have multiple
- * valid values. Find a single valid assignment by iteratively locking one
- * dimension after another, and running the arc consistency algorithm
- * multiple times.
- *
- * This requires having constraints that result in an absolute ordering:
- * equals, lesserThan and greaterThan do that, but adding notEquals does
- * not typically result in a guaranteed ordering. If that is needed, the
- * algorithm below needs to be adapted with a backtracking (tree) search
- * to find solutions.
- */
- private void solve(int maxIterations) {
- initialize();
-
- // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts
+ private Map<String, Integer> solve(int maxIterations) {
+ Map<String, Integer> solution = solveWithOrWithoutSoftConstraints(maxIterations);
+ if (solution != null) return solution;
- for (String dimension : variables.keySet()) {
- List<Integer> values = variables.get(dimension);
- if (values.size() > 1) {
- if (!ac3()) {
- throw new IllegalArgumentException("Dimension renamer unable to find a solution.");
- }
- values.sort(Integer::compare);
- variables.put(dimension, Collections.singletonList(values.get(0)));
- }
- renames.put(dimension, variables.get(dimension).get(0));
- if (iterations > maxIterations) {
- throw new IllegalArgumentException("Dimension renamer unable to find a solution within " +
- maxIterations + " iterations");
- }
+ for (RenameTarget target : prioritizedRenameTargets()) {
+ System.out.println("Trying rename " + target);
+ target.insertRename(this);
+ solution = solveWithOrWithoutSoftConstraints(maxIterations);
+ if (solution != null) return solution;
+ target.uninsertRename(this);
}
-
- // Todo: handle failure more gracefully:
- // If a solution can't be found, look at the operation node in the arc
- // with the most remaining constraints, and inject a rename operation.
- // Then run this algorithm again.
+ throw new IllegalArgumentException("Could not find a dimension naming solution " +
+ "given constraints\n" + constraintsToString(constraints));
}
- void solve() {
- solve(100000);
+ private Map<String, Integer> solveWithOrWithoutSoftConstraints(int maxIterations) {
+ Map<String, Integer> solution = NamingConstraintSolver.solve(dimensions, constraints, maxIterations);
+ if ( solution == null) {
+ ListMap<Arc, Constraint> hardConstraints = new ListMap<>();
+ boolean anyRemoved = copyHard(constraints, hardConstraints);
+ if (anyRemoved)
+ solution = NamingConstraintSolver.solve(dimensions, hardConstraints, maxIterations);
+ }
+ return solution;
}
- private void initialize() {
- for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) {
- List<Integer> values = variable.getValue();
- for (int i = 0; i < variables.size(); ++i) {
- values.add(i); // invariant: values are in increasing order
+ /** Removes soft constraints and returns whether something was removed */
+ private boolean copyHard(ListMap<Arc, Constraint> source, ListMap<Arc, Constraint> target) {
+ boolean removed = false;
+ for (var entry : source.entrySet()) {
+ Arc arc = entry.getKey();
+ for (Constraint constraint : entry.getValue()) {
+ if ( ! constraint.isSoft())
+ target.put(arc, constraint);
+ else
+ removed = true;
}
}
+ return removed;
}
- private boolean ac3() {
- Deque<Arc> workList = new ArrayDeque<>(constraints.keySet());
- while (!workList.isEmpty()) {
- Arc arc = workList.pop();
- iterations += 1;
- if (revise(arc)) {
- if (variables.get(arc.from).size() == 0) {
- return false; // no solution found
- }
- for (Arc constraint : constraints.keySet()) {
- if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) {
- workList.add(constraint);
- }
- }
- }
+ private List<RenameTarget> prioritizedRenameTargets() {
+ Map<IntermediateOperation, Integer> constraintsPerOperation = new HashMap<>();
+
+ for (var constraint : constraints.entrySet()) {
+ constraintsPerOperation.compute(constraint.getKey().operation,
+ (operation, count) -> count == null ? 1 : ++count);
}
- return true;
- }
+ List<IntermediateOperation> prioritizedOperations =
+ constraintsPerOperation.entrySet().stream()
+ .sorted(Comparator.comparingInt(entry -> - entry.getValue()))
+ .map(entry -> entry.getKey())
+ .collect(Collectors.toList());
- private boolean revise(Arc arc) {
- boolean revised = false;
- for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) {
- Integer from = fromIterator.next();
- boolean satisfied = false;
- for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) {
- Integer to = toIterator.next();
- if (constraints.get(arc).test(from, to)) {
- satisfied = true;
+ List<RenameTarget> targets = new ArrayList<>();
+ for (IntermediateOperation operation : prioritizedOperations) {
+ for (int i = 0; i < operation.inputs().size(); i++) {
+ Optional<OrderedTensorType> inputType = operation.inputs().get(i).type();
+ if (inputType.isEmpty()) continue;
+ for (String dimensionName : inputType.get().dimensionNames()) {
+ RenameTarget target = new RenameTarget(operation, i, dimensionName, graph);
+ if (target.rootKey != null) // TODO: Inserting renames under non-roots is not implemented
+ targets.add(target);
}
}
- if (!satisfied) {
- fromIterator.remove();
- revised = true;
- }
}
- return revised;
- }
-
- public interface Constraint {
- boolean test(Integer x, Integer y);
+ return targets;
}
- public static boolean equals(Integer x, Integer y) {
- return Objects.equals(x, y);
+ /**
+ * Retrieve resulting name of a dimension after solving for constraints, or empty if no
+ * solution is found yet, or this dimension was not added before finding a solution.
+ */
+ public Optional<String> dimensionNameOf(String name) {
+ if ( renames == null || ! renames.containsKey(name))
+ return Optional.empty();
+ return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name)));
}
- public static boolean lesserThan(Integer x, Integer y) {
- return x < y;
+ private static String renamesToString(Map<String, Integer> renames) {
+ return renames.entrySet().stream()
+ .map(e -> " " + e.getKey() + " -> " + e.getValue())
+ .collect(Collectors.joining("\n"));
}
- public static boolean greaterThan(Integer x, Integer y) {
- return x > y;
+ private static String constraintsToString(ListMap<Arc, Constraint> constraints) {
+ StringBuilder b = new StringBuilder();
+ for (var entry : constraints.entrySet()) {
+ Arc arc = entry.getKey();
+ for (Constraint constraint : entry.getValue()) {
+ if (constraint.isOpposite()) continue; // noise
+ b.append(" ");
+ if (constraint.isSoft())
+ b.append("(soft) ");
+ b.append(arc.from).append(" ").append(constraint).append(" ").append(arc.to);
+ b.append(" (origin: ").append(arc.operation).append(")\n");
+ }
+ }
+ return b.toString();
}
- private static class Arc {
+ static class Arc {
- private final String from;
- private final String to;
+ final String from;
+ final String to;
private final IntermediateOperation operation;
Arc(String from, String to, IntermediateOperation operation) {
@@ -194,7 +194,7 @@ public class DimensionRenamer {
@Override
public boolean equals(Object obj) {
- if (obj == null || !(obj instanceof Arc)) {
+ if (!(obj instanceof Arc)) {
return false;
}
Arc other = (Arc) obj;
@@ -203,8 +203,185 @@ public class DimensionRenamer {
@Override
public String toString() {
- return String.format("%s -> %s", from, to);
+ return from + " -> " + to;
+ }
+ }
+
+ public static abstract class Constraint {
+
+ private final boolean soft, opposite;
+
+ protected Constraint(boolean soft, boolean opposite) {
+ this.soft = soft;
+ this.opposite = opposite;
+ }
+
+ abstract boolean test(Integer x, Integer y);
+ abstract Constraint opposite();
+
+ /** Returns whether this constraint can be violated if that is necessary to achieve a solution */
+ boolean isSoft() { return soft; }
+
+ /** Returns whether this is an opposite of another constraint */
+ boolean isOpposite() { return opposite; }
+
+ public static Constraint equal(boolean soft) { return new EqualConstraint(soft, false); }
+ public static Constraint notEqual(boolean soft) { return new NotEqualConstraint(soft, false); }
+ public static Constraint lessThan(boolean soft) { return new LessThanConstraint(soft, false); }
+ public static Constraint greaterThan(boolean soft) { return new GreaterThanConstraint(soft, false); }
+
+ }
+
+ private static class EqualConstraint extends Constraint {
+
+ private EqualConstraint(boolean soft, boolean opposite) {
+ super(soft, opposite);
+ }
+
+ @Override
+ public boolean test(Integer x, Integer y) { return Objects.equals(x, y); }
+
+ @Override
+ public Constraint opposite() { return new EqualConstraint(isSoft(), true); }
+
+ @Override
+ public String toString() { return "=="; }
+
+ }
+
+ private static class NotEqualConstraint extends Constraint {
+
+ private NotEqualConstraint(boolean soft, boolean opposite) {
+ super(soft, opposite);
+ }
+
+ @Override
+ public boolean test(Integer x, Integer y) { return ! Objects.equals(x, y); }
+
+ @Override
+ public Constraint opposite() { return new NotEqualConstraint(isSoft(), true); }
+
+ @Override
+ public String toString() { return "!="; }
+
+ }
+
+ private static class LessThanConstraint extends Constraint {
+
+ private LessThanConstraint(boolean soft, boolean opposite) {
+ super(soft, opposite);
+ }
+
+ @Override
+ public boolean test(Integer x, Integer y) { return x < y; }
+
+ @Override
+ public Constraint opposite() { return new GreaterThanConstraint(isSoft(), true); }
+
+ @Override
+ public String toString() { return "<"; }
+
+ }
+
+ private static class GreaterThanConstraint extends Constraint {
+
+ private GreaterThanConstraint(boolean soft, boolean opposite) {
+ super(soft, opposite);
+ }
+
+ @Override
+ public boolean test(Integer x, Integer y) { return x > y; }
+
+ @Override
+ public Constraint opposite() { return new LessThanConstraint(isSoft(), true); }
+
+ @Override
+ public String toString() { return ">"; }
+
+ }
+
+ /**
+ * An operation and an input number which we may want to insert a rename operation at.
+ * That is, we may want to change op(..., input, ...) to op(..., rename(input), ...).
+ *
+ * This class is (and must remain) immutable.
+ */
+ private static class RenameTarget {
+
+ final IntermediateOperation operation;
+ final int inputNumber;
+ final String dimensionName;
+ final IntermediateGraph graph;
+
+ /**
+ * Returns the key of this operation in the root operations of the graph,
+ * or null if it is not a root operation
+ */
+ final String rootKey;
+
+ public RenameTarget(IntermediateOperation operation, int inputNumber, String dimensionName, IntermediateGraph graph) {
+ this.operation = operation;
+ this.inputNumber = inputNumber;
+ this.dimensionName = dimensionName;
+ this.rootKey = findRootKey(operation, graph);
+ this.graph = graph;
+ }
+
+ public IntermediateOperation input() {
+ return operation.inputs().get(inputNumber);
+ }
+
+ private static String findRootKey(IntermediateOperation operation, IntermediateGraph graph) {
+ for (var entry : graph.operations().entrySet()) {
+ if (entry.getValue() == operation)
+ return entry.getKey();
+ }
+ return null;
+ }
+
+ /** Inserts a rename operation if possible. Returns whether an operation was inserted. */
+ private boolean insertRename(DimensionRenamer renamer) {
+ Rename rename = new Rename(operation.modelName(),
+ dimensionName,
+ renamer.dimensionPrefix + renamer.dimensions.size(),
+ input());
+
+ List<IntermediateOperation> newInputs = new ArrayList<>(operation.inputs());
+ newInputs.set(inputNumber, rename);
+ IntermediateOperation newOperation = operation.withInputs(newInputs);
+ if (rootKey == null)
+ throw new IllegalStateException("Renaming non-roots is not implemented");
+ graph.put(rootKey, newOperation);
+
+ removeConstraintsOf(operation, renamer);
+ rename.addDimensionNameConstraints(renamer);
+ newOperation.addDimensionNameConstraints(renamer);
+ return true;
+ }
+
+ /** Undo what insertRenameOperation has done: Set back the original operation and remove+add constraints */
+ private void uninsertRename(DimensionRenamer renamer) {
+ IntermediateOperation newOperation = graph.operations().get(rootKey);
+ Rename rename = (Rename)newOperation.inputs().get(inputNumber);
+ graph.put(rootKey, operation);
+
+ removeConstraintsOf(rename, renamer);
+ removeConstraintsOf(newOperation, renamer);
+ operation.addDimensionNameConstraints(renamer);
+ }
+
+ private void removeConstraintsOf(IntermediateOperation operation, DimensionRenamer renamer) {
+ for (Arc key : new HashSet<>(renamer.constraints.keySet())) {
+ if (key.operation == operation)
+ renamer.constraints.removeAll(key);
+ }
+ }
+
+ @Override
+ public String toString() {
+ return operation + ", input " + inputNumber;
}
+
}
}