aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java25
1 files changed, 23 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
index 14aa3ebf84e..3c8a6bde232 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
@@ -7,6 +7,7 @@ import ai.vespa.rankingexpression.importer.operations.MatMul;
import java.util.Collection;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -74,6 +75,8 @@ public class IntermediateGraph {
renameDimensions();
}
+ static int counter = 0;
+
/**
* Find dimension names to avoid excessive renaming while evaluating the model.
*/
@@ -93,16 +96,34 @@ public class IntermediateGraph {
}
private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) {
+ Set<String> operations = new HashSet<>();
+ addDimensionNameConstraints(operation, renamer, operations);
+ }
+
+ private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer, Set<String> operations) {
+ if (operations.contains(operation.name())) {
+ return;
+ }
if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
+ operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer, operations));
operation.addDimensionNameConstraints(renamer);
+ operations.add(operation.name());
}
}
private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) {
+ Set<String> operations = new HashSet<>();
+ renameDimensions(operation, renamer, operations);
+ }
+
+ private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer, Set<String> operations) {
+ if (operations.contains(operation.name())) {
+ return;
+ }
if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> renameDimensions(input, renamer));
+ operation.inputs().forEach(input -> renameDimensions(input, renamer, operations));
operation.renameDimensions(renamer);
+ operations.add(operation.name());
}
}