diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-03 09:47:49 -0700 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-03 09:47:49 -0700 |
commit | 0ce6fa7cbdf71fd39cb5bb18accfa84a20e7e120 (patch) | |
tree | 0afa3697a10fad159883a402a2f367ee8175c027 /model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java | |
parent | fdc6a48fe913de5f5a84c6eb42123543c1d2ee46 (diff) |
Inject rename to relax constraint proof of concept
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java | 22 |
1 files changed, 12 insertions, 10 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java index 54d4bd3cb0a..6c583d960bd 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java @@ -3,9 +3,11 @@ package ai.vespa.rankingexpression.importer; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.operations.MatMul; import java.util.Collection; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; @@ -20,7 +22,7 @@ import java.util.Set; public class IntermediateGraph { private final String modelName; - private final Map<String, IntermediateOperation> index = new HashMap<>(); + private final Map<String, IntermediateOperation> operations = new HashMap<>(); private final Map<String, GraphSignature> signatures = new HashMap<>(); private static class GraphSignature { @@ -37,11 +39,11 @@ public class IntermediateGraph { } public IntermediateOperation put(String key, IntermediateOperation operation) { - return index.put(key, operation); + return operations.put(key, operation); } public IntermediateOperation get(String key) { - return index.get(key); + return operations.get(key); } public Set<String> signatures() { @@ -61,11 +63,11 @@ public class IntermediateGraph { } public boolean alreadyImported(String key) { - return index.containsKey(key); + return operations.containsKey(key); } - public Collection<IntermediateOperation> operations() { - return index.values(); + public Map<String, IntermediateOperation> operations() { + return operations; } void optimize() { @@ -76,16 +78,16 @@ public class IntermediateGraph { * Find dimension names to avoid excessive renaming while evaluating the model. */ private void renameDimensions() { - DimensionRenamer renamer = new DimensionRenamer(); + DimensionRenamer renamer = new DimensionRenamer(this); for (String signature : signatures()) { for (String output : outputs(signature).values()) { - addDimensionNameConstraints(index.get(output), renamer); + addDimensionNameConstraints(operations.get(output), renamer); } } renamer.solve(); for (String signature : signatures()) { for (String output : outputs(signature).values()) { - renameDimensions(index.get(output), renamer); + renameDimensions(operations.get(output), renamer); } } } @@ -111,7 +113,7 @@ public class IntermediateGraph { public String toFullString() { StringBuilder b = new StringBuilder(); - for (var input : index.entrySet()) + for (var input : operations.entrySet()) b.append(input.getKey()).append(": ").append(input.getValue().toFullString()).append("\n"); return b.toString(); } |