diff options
author | Lester Solbakken <lesters@oath.com> | 2020-04-03 11:29:43 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-04-03 11:29:43 +0200 |
commit | 3789127189224d6cbd6f109b9a95f848869ea6cc (patch) | |
tree | 79cef74e6c61da059ed0eae79632fa001433ddc2 /model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java | |
parent | 706cb2d3b2d623318ba9c0a8db0e4355448af65a (diff) |
for testing onlylesters/bert-testing
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java | 25 |
1 files changed, 23 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java index 14aa3ebf84e..3c8a6bde232 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java @@ -7,6 +7,7 @@ import ai.vespa.rankingexpression.importer.operations.MatMul; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -74,6 +75,8 @@ public class IntermediateGraph { renameDimensions(); } + static int counter = 0; + /** * Find dimension names to avoid excessive renaming while evaluating the model. */ @@ -93,16 +96,34 @@ public class IntermediateGraph { } private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) { + Set<String> operations = new HashSet<>(); + addDimensionNameConstraints(operation, renamer, operations); + } + + private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer, Set<String> operations) { + if (operations.contains(operation.name())) { + return; + } if (operation.type().isPresent()) { - operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); + operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer, operations)); operation.addDimensionNameConstraints(renamer); + operations.add(operation.name()); } } private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) { + Set<String> operations = new HashSet<>(); + renameDimensions(operation, renamer, operations); + } + + private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer, Set<String> operations) { + if (operations.contains(operation.name())) { + return; + } if (operation.type().isPresent()) { - operation.inputs().forEach(input -> renameDimensions(input, renamer)); + operation.inputs().forEach(input -> renameDimensions(input, renamer, operations)); operation.renameDimensions(renamer); + operations.add(operation.name()); } } |