diff options
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java | 46 |
1 files changed, 26 insertions, 20 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index bf5d836b809..f4cf1b5fabc 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -69,32 +69,38 @@ public class DimensionRenamer { private Map<String, Integer> solve(int maxIterations) { Map<String, Integer> solution = solveWithOrWithoutSoftConstraints(maxIterations); - if ( solution == null) { - IntermediateOperation operation = graph.operations().get("dense_out/MatMul"); - if (operation instanceof MatMul) { - IntermediateOperation arg0 = operation.inputs().get(0); - List<IntermediateOperation> inputs = new ArrayList<>(operation.inputs()); - inputs.set(0, new Rename(arg0.modelName(), "Dot_ExpandDims_1", "renamed_0", arg0)); - IntermediateOperation newOperation = operation.withInputs(inputs); - graph.put("dense_out/MatMul", newOperation); - - for (Arc key : new HashSet<>(constraints.keySet())) { - if (key.operation == operation) - constraints.removeAll(key); - } - addDimension("renamed_0"); - newOperation.addDimensionNameConstraints(this); - - solution = solveWithOrWithoutSoftConstraints(maxIterations); - } + int renamesTried = 0; + while (solution == null && renamesTried++ < dimensions.size()) { + boolean inserted = insertRenameOperation(); + if ( ! inserted ) break; + solution = solveWithOrWithoutSoftConstraints(maxIterations); } - if ( solution == null) { + if ( solution == null) throw new IllegalArgumentException("Could not find a dimension naming solution " + "given constraints\n" + constraintsToString(constraints)); - } return solution; } + private boolean insertRenameOperation() { + IntermediateOperation operation = graph.operations().get("dense_out/MatMul"); + if (operation instanceof MatMul) { + IntermediateOperation arg0 = operation.inputs().get(0); + List<IntermediateOperation> inputs = new ArrayList<>(operation.inputs()); + inputs.set(0, new Rename(arg0.modelName(), "Dot_ExpandDims_1", "renamed_0", arg0)); + IntermediateOperation newOperation = operation.withInputs(inputs); + graph.put("dense_out/MatMul", newOperation); + + for (Arc key : new HashSet<>(constraints.keySet())) { + if (key.operation == operation) + constraints.removeAll(key); + } + addDimension("renamed_0"); + newOperation.addDimensionNameConstraints(this); + return true; + } + return false; + } + private Map<String, Integer> solveWithOrWithoutSoftConstraints(int maxIterations) { Map<String, Integer> solution = NamingConstraintSolver.solve(dimensions, constraints, maxIterations); if ( solution == null) { |