summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java131
1 files changed, 109 insertions, 22 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
index 22fabe3ada7..89ba36f0d39 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -2,17 +2,19 @@
package ai.vespa.rankingexpression.importer;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
-import ai.vespa.rankingexpression.importer.operations.MatMul;
import ai.vespa.rankingexpression.importer.operations.Rename;
import com.yahoo.collections.ListMap;
import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
+import java.util.TreeMap;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
@@ -68,34 +70,55 @@ public class DimensionRenamer {
}
private Map<String, Integer> solve(int maxIterations) {
- int renamesTried = 0;
- while (renamesTried++ <= dimensions.size()) {
- Map<String, Integer> solution = solveWithOrWithoutSoftConstraints(maxIterations);
+ Map<String, Integer> solution = solveWithOrWithoutSoftConstraints(maxIterations);
+ if (solution != null) return solution;
+
+ for (RenameTarget renameTarget : prioritizedRenameTargets()) {
+ System.out.println("Trying rename " + renameTarget);
+ insertRenameOperation(renameTarget);
+ solution = solveWithOrWithoutSoftConstraints(maxIterations);
if (solution != null) return solution;
- if ( ! insertRenameOperation()) return null;
+ rollbackRenameOperation(renameTarget);
}
throw new IllegalArgumentException("Could not find a dimension naming solution " +
"given constraints\n" + constraintsToString(constraints));
}
- private boolean insertRenameOperation() {
- IntermediateOperation operation = graph.operations().get("dense_out/MatMul");
- if (operation instanceof MatMul) {
- IntermediateOperation arg0 = operation.inputs().get(0);
- List<IntermediateOperation> inputs = new ArrayList<>(operation.inputs());
- inputs.set(0, new Rename(arg0.modelName(), "Dot_ExpandDims_1", "renamed_0", arg0));
- IntermediateOperation newOperation = operation.withInputs(inputs);
- graph.put("dense_out/MatMul", newOperation);
-
- for (Arc key : new HashSet<>(constraints.keySet())) {
- if (key.operation == operation)
- constraints.removeAll(key);
- }
- addDimension("renamed_0");
- newOperation.addDimensionNameConstraints(this);
- return true;
+ /** Inserts a rename operation if possible. Returns whether an operation was inserted. */
+ private boolean insertRenameOperation(RenameTarget target) {
+ Rename rename = new Rename(target.operation.modelName(),
+ "Dot_ExpandDims_1", "renamed_0",
+ target.input());
+
+ List<IntermediateOperation> newInputs = new ArrayList<>(target.operation.inputs());
+ newInputs.set(target.inputNumber, rename);
+ IntermediateOperation newOperation = target.operation.withInputs(newInputs);
+ if (target.rootKey == null)
+ throw new IllegalStateException("Renaming non-roots is not implemented");
+ graph.put(target.rootKey, newOperation);
+
+ removeConstraintsOf(target.operation);
+ rename.addDimensionNameConstraints(this);
+ newOperation.addDimensionNameConstraints(this);
+ return true;
+ }
+
+ /** Undo what insertRenameOperation has done: Set back the original operation and remove+add constraints */
+ private void rollbackRenameOperation(RenameTarget target) {
+ IntermediateOperation newOperation = graph.operations().get(target.rootKey);
+ Rename rename = (Rename)newOperation.inputs().get(target.inputNumber);
+ graph.put(target.rootKey, target.operation);
+
+ removeConstraintsOf(rename);
+ removeConstraintsOf(newOperation);
+ target.operation.addDimensionNameConstraints(this);
+ }
+
+ private void removeConstraintsOf(IntermediateOperation operation) {
+ for (Arc key : new HashSet<>(constraints.keySet())) {
+ if (key.operation == operation)
+ constraints.removeAll(key);
}
- return false;
}
private Map<String, Integer> solveWithOrWithoutSoftConstraints(int maxIterations) {
@@ -124,6 +147,30 @@ public class DimensionRenamer {
return removed;
}
+ private List<RenameTarget> prioritizedRenameTargets() {
+ Map<IntermediateOperation, Integer> constraintsPerOperation = new HashMap<>();
+
+ for (var constraint : constraints.entrySet()) {
+ constraintsPerOperation.compute(constraint.getKey().operation,
+ (operation, count) -> count == null ? 1 : ++count);
+ }
+ List<IntermediateOperation> prioritizedOperations =
+ constraintsPerOperation.entrySet().stream()
+ .sorted(Comparator.comparingInt(entry -> - entry.getValue()))
+ .map(entry -> entry.getKey())
+ .collect(Collectors.toList());
+
+ List<RenameTarget> targets = new ArrayList<>();
+ for (IntermediateOperation operation : prioritizedOperations) {
+ for (int i = 0; i < operation.inputs().size(); i++) {
+ RenameTarget target = new RenameTarget(operation, i, graph);
+ if (target.rootKey != null) // Inserting renames under non-roots is not implemented
+ targets.add(new RenameTarget(operation, i, graph));
+ }
+ }
+ return targets;
+ }
+
/**
* Retrieve resulting name of a dimension after solving for constraints, or empty if no
* solution is found yet, or this dimension was not added before finding a solution.
@@ -285,4 +332,44 @@ public class DimensionRenamer {
}
+ /**
+ * An operation and an input number which we may want to insert a rename operation at.
+ * That is, we may want to change op(..., input, ...) to op(..., rename(input), ...).
+ */
+ private static class RenameTarget {
+
+ final IntermediateOperation operation;
+ final int inputNumber;
+
+ /**
+ * Returns the key of this operation in the root operations of the graph,
+ * or null if it is not a root operation
+ */
+ final String rootKey;
+
+ public RenameTarget(IntermediateOperation operation, int inputNumber, IntermediateGraph graph) {
+ this.operation = operation;
+ this.inputNumber = inputNumber;
+ this.rootKey = findRootKey(operation, graph);
+ }
+
+ public IntermediateOperation input() {
+ return operation.inputs().get(inputNumber);
+ }
+
+ private static String findRootKey(IntermediateOperation operation, IntermediateGraph graph) {
+ for (var entry : graph.operations().entrySet()) {
+ if (entry.getValue() == operation)
+ return entry.getKey();
+ }
+ return null;
+ }
+
+ @Override
+ public String toString() {
+ return operation + ", input " + inputNumber;
+ }
+
+ }
+
}