diff options
author | Lester Solbakken <lesters@oath.com> | 2020-04-21 15:26:58 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-04-21 15:26:58 +0200 |
commit | aad5c7184f37e1441c928efa77b434620742ff88 (patch) | |
tree | 34a92e7f954aa92e21d48816335771ff607fe404 /model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java | |
parent | 6f5ca49e45cdc8262fcf360b1c731a393385ffa8 (diff) |
Update model-integration for supporting BERT-type models
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java | 21 |
1 files changed, 15 insertions, 6 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java index 14aa3ebf84e..ea981603481 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java @@ -7,6 +7,7 @@ import ai.vespa.rankingexpression.importer.operations.MatMul; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -81,28 +82,36 @@ public class IntermediateGraph { DimensionRenamer renamer = new DimensionRenamer(this); for (String signature : signatures()) { for (String output : outputs(signature).values()) { - addDimensionNameConstraints(operations.get(output), renamer); + addDimensionNameConstraints(operations.get(output), renamer, new HashSet<>()); } } renamer.solve(); for (String signature : signatures()) { for (String output : outputs(signature).values()) { - renameDimensions(operations.get(output), renamer); + renameDimensions(operations.get(output), renamer, new HashSet<>()); } } } - private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) { + private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer, Set<String> processed) { + if (processed.contains(operation.name())) { + return; + } if (operation.type().isPresent()) { - operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); + operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer, processed)); operation.addDimensionNameConstraints(renamer); + processed.add(operation.name()); } } - private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) { + private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer, Set<String> processed) { + if (processed.contains(operation.name())) { + return; + } if (operation.type().isPresent()) { - operation.inputs().forEach(input -> renameDimensions(input, renamer)); + operation.inputs().forEach(input -> renameDimensions(input, renamer, processed)); operation.renameDimensions(renamer); + processed.add(operation.name()); } } |