aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-04-21 15:26:58 +0200
committerLester Solbakken <lesters@oath.com>2020-04-21 15:26:58 +0200
commitaad5c7184f37e1441c928efa77b434620742ff88 (patch)
tree34a92e7f954aa92e21d48816335771ff607fe404 /model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
parent6f5ca49e45cdc8262fcf360b1c731a393385ffa8 (diff)
Update model-integration for supporting BERT-type models
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java21
1 files changed, 15 insertions, 6 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
index 14aa3ebf84e..ea981603481 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
@@ -7,6 +7,7 @@ import ai.vespa.rankingexpression.importer.operations.MatMul;
import java.util.Collection;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -81,28 +82,36 @@ public class IntermediateGraph {
DimensionRenamer renamer = new DimensionRenamer(this);
for (String signature : signatures()) {
for (String output : outputs(signature).values()) {
- addDimensionNameConstraints(operations.get(output), renamer);
+ addDimensionNameConstraints(operations.get(output), renamer, new HashSet<>());
}
}
renamer.solve();
for (String signature : signatures()) {
for (String output : outputs(signature).values()) {
- renameDimensions(operations.get(output), renamer);
+ renameDimensions(operations.get(output), renamer, new HashSet<>());
}
}
}
- private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) {
+ private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer, Set<String> processed) {
+ if (processed.contains(operation.name())) {
+ return;
+ }
if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
+ operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer, processed));
operation.addDimensionNameConstraints(renamer);
+ processed.add(operation.name());
}
}
- private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) {
+ private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer, Set<String> processed) {
+ if (processed.contains(operation.name())) {
+ return;
+ }
if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> renameDimensions(input, renamer));
+ operation.inputs().forEach(input -> renameDimensions(input, renamer, processed));
operation.renameDimensions(renamer);
+ processed.add(operation.name());
}
}