diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java | 131 |
1 files changed, 109 insertions, 22 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index 22fabe3ada7..89ba36f0d39 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -2,17 +2,19 @@ package ai.vespa.rankingexpression.importer; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; -import ai.vespa.rankingexpression.importer.operations.MatMul; import ai.vespa.rankingexpression.importer.operations.Rename; import com.yahoo.collections.ListMap; import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.TreeMap; import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -68,34 +70,55 @@ public class DimensionRenamer { } private Map<String, Integer> solve(int maxIterations) { - int renamesTried = 0; - while (renamesTried++ <= dimensions.size()) { - Map<String, Integer> solution = solveWithOrWithoutSoftConstraints(maxIterations); + Map<String, Integer> solution = solveWithOrWithoutSoftConstraints(maxIterations); + if (solution != null) return solution; + + for (RenameTarget renameTarget : prioritizedRenameTargets()) { + System.out.println("Trying rename " + renameTarget); + insertRenameOperation(renameTarget); + solution = solveWithOrWithoutSoftConstraints(maxIterations); if (solution != null) return solution; - if ( ! insertRenameOperation()) return null; + rollbackRenameOperation(renameTarget); } throw new IllegalArgumentException("Could not find a dimension naming solution " + "given constraints\n" + constraintsToString(constraints)); } - private boolean insertRenameOperation() { - IntermediateOperation operation = graph.operations().get("dense_out/MatMul"); - if (operation instanceof MatMul) { - IntermediateOperation arg0 = operation.inputs().get(0); - List<IntermediateOperation> inputs = new ArrayList<>(operation.inputs()); - inputs.set(0, new Rename(arg0.modelName(), "Dot_ExpandDims_1", "renamed_0", arg0)); - IntermediateOperation newOperation = operation.withInputs(inputs); - graph.put("dense_out/MatMul", newOperation); - - for (Arc key : new HashSet<>(constraints.keySet())) { - if (key.operation == operation) - constraints.removeAll(key); - } - addDimension("renamed_0"); - newOperation.addDimensionNameConstraints(this); - return true; + /** Inserts a rename operation if possible. Returns whether an operation was inserted. */ + private boolean insertRenameOperation(RenameTarget target) { + Rename rename = new Rename(target.operation.modelName(), + "Dot_ExpandDims_1", "renamed_0", + target.input()); + + List<IntermediateOperation> newInputs = new ArrayList<>(target.operation.inputs()); + newInputs.set(target.inputNumber, rename); + IntermediateOperation newOperation = target.operation.withInputs(newInputs); + if (target.rootKey == null) + throw new IllegalStateException("Renaming non-roots is not implemented"); + graph.put(target.rootKey, newOperation); + + removeConstraintsOf(target.operation); + rename.addDimensionNameConstraints(this); + newOperation.addDimensionNameConstraints(this); + return true; + } + + /** Undo what insertRenameOperation has done: Set back the original operation and remove+add constraints */ + private void rollbackRenameOperation(RenameTarget target) { + IntermediateOperation newOperation = graph.operations().get(target.rootKey); + Rename rename = (Rename)newOperation.inputs().get(target.inputNumber); + graph.put(target.rootKey, target.operation); + + removeConstraintsOf(rename); + removeConstraintsOf(newOperation); + target.operation.addDimensionNameConstraints(this); + } + + private void removeConstraintsOf(IntermediateOperation operation) { + for (Arc key : new HashSet<>(constraints.keySet())) { + if (key.operation == operation) + constraints.removeAll(key); } - return false; } private Map<String, Integer> solveWithOrWithoutSoftConstraints(int maxIterations) { @@ -124,6 +147,30 @@ public class DimensionRenamer { return removed; } + private List<RenameTarget> prioritizedRenameTargets() { + Map<IntermediateOperation, Integer> constraintsPerOperation = new HashMap<>(); + + for (var constraint : constraints.entrySet()) { + constraintsPerOperation.compute(constraint.getKey().operation, + (operation, count) -> count == null ? 1 : ++count); + } + List<IntermediateOperation> prioritizedOperations = + constraintsPerOperation.entrySet().stream() + .sorted(Comparator.comparingInt(entry -> - entry.getValue())) + .map(entry -> entry.getKey()) + .collect(Collectors.toList()); + + List<RenameTarget> targets = new ArrayList<>(); + for (IntermediateOperation operation : prioritizedOperations) { + for (int i = 0; i < operation.inputs().size(); i++) { + RenameTarget target = new RenameTarget(operation, i, graph); + if (target.rootKey != null) // Inserting renames under non-roots is not implemented + targets.add(new RenameTarget(operation, i, graph)); + } + } + return targets; + } + /** * Retrieve resulting name of a dimension after solving for constraints, or empty if no * solution is found yet, or this dimension was not added before finding a solution. @@ -285,4 +332,44 @@ public class DimensionRenamer { } + /** + * An operation and an input number which we may want to insert a rename operation at. + * That is, we may want to change op(..., input, ...) to op(..., rename(input), ...). + */ + private static class RenameTarget { + + final IntermediateOperation operation; + final int inputNumber; + + /** + * Returns the key of this operation in the root operations of the graph, + * or null if it is not a root operation + */ + final String rootKey; + + public RenameTarget(IntermediateOperation operation, int inputNumber, IntermediateGraph graph) { + this.operation = operation; + this.inputNumber = inputNumber; + this.rootKey = findRootKey(operation, graph); + } + + public IntermediateOperation input() { + return operation.inputs().get(inputNumber); + } + + private static String findRootKey(IntermediateOperation operation, IntermediateGraph graph) { + for (var entry : graph.operations().entrySet()) { + if (entry.getValue() == operation) + return entry.getKey(); + } + return null; + } + + @Override + public String toString() { + return operation + ", input " + inputNumber; + } + + } + } |