summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-07-04 08:36:46 -0700
committerJon Bratseth <bratseth@verizonmedia.com>2019-07-04 08:36:46 -0700
commit3be34404a96cd782a7f259f29491581272b00c11 (patch)
tree4db26e35aab6601bbd9912c51a3fffd5bad1d055
parent840ddab2e6e3b2e243960b9bea8ff7051963a4c2 (diff)
Loop and rename
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java46
1 files changed, 26 insertions, 20 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
index bf5d836b809..f4cf1b5fabc 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -69,32 +69,38 @@ public class DimensionRenamer {
private Map<String, Integer> solve(int maxIterations) {
Map<String, Integer> solution = solveWithOrWithoutSoftConstraints(maxIterations);
- if ( solution == null) {
- IntermediateOperation operation = graph.operations().get("dense_out/MatMul");
- if (operation instanceof MatMul) {
- IntermediateOperation arg0 = operation.inputs().get(0);
- List<IntermediateOperation> inputs = new ArrayList<>(operation.inputs());
- inputs.set(0, new Rename(arg0.modelName(), "Dot_ExpandDims_1", "renamed_0", arg0));
- IntermediateOperation newOperation = operation.withInputs(inputs);
- graph.put("dense_out/MatMul", newOperation);
-
- for (Arc key : new HashSet<>(constraints.keySet())) {
- if (key.operation == operation)
- constraints.removeAll(key);
- }
- addDimension("renamed_0");
- newOperation.addDimensionNameConstraints(this);
-
- solution = solveWithOrWithoutSoftConstraints(maxIterations);
- }
+ int renamesTried = 0;
+ while (solution == null && renamesTried++ < dimensions.size()) {
+ boolean inserted = insertRenameOperation();
+ if ( ! inserted ) break;
+ solution = solveWithOrWithoutSoftConstraints(maxIterations);
}
- if ( solution == null) {
+ if ( solution == null)
throw new IllegalArgumentException("Could not find a dimension naming solution " +
"given constraints\n" + constraintsToString(constraints));
- }
return solution;
}
+ private boolean insertRenameOperation() {
+ IntermediateOperation operation = graph.operations().get("dense_out/MatMul");
+ if (operation instanceof MatMul) {
+ IntermediateOperation arg0 = operation.inputs().get(0);
+ List<IntermediateOperation> inputs = new ArrayList<>(operation.inputs());
+ inputs.set(0, new Rename(arg0.modelName(), "Dot_ExpandDims_1", "renamed_0", arg0));
+ IntermediateOperation newOperation = operation.withInputs(inputs);
+ graph.put("dense_out/MatMul", newOperation);
+
+ for (Arc key : new HashSet<>(constraints.keySet())) {
+ if (key.operation == operation)
+ constraints.removeAll(key);
+ }
+ addDimension("renamed_0");
+ newOperation.addDimensionNameConstraints(this);
+ return true;
+ }
+ return false;
+ }
+
private Map<String, Integer> solveWithOrWithoutSoftConstraints(int maxIterations) {
Map<String, Integer> solution = NamingConstraintSolver.solve(dimensions, constraints, maxIterations);
if ( solution == null) {