aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-07-03 09:47:49 -0700
committerJon Bratseth <bratseth@verizonmedia.com>2019-07-03 09:47:49 -0700
commit0ce6fa7cbdf71fd39cb5bb18accfa84a20e7e120 (patch)
tree0afa3697a10fad159883a402a2f367ee8175c027 /model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
parentfdc6a48fe913de5f5a84c6eb42123543c1d2ee46 (diff)
Inject rename to relax constraint proof of concept
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java22
1 files changed, 12 insertions, 10 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
index 54d4bd3cb0a..6c583d960bd 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
@@ -3,9 +3,11 @@
package ai.vespa.rankingexpression.importer;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
+import ai.vespa.rankingexpression.importer.operations.MatMul;
import java.util.Collection;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -20,7 +22,7 @@ import java.util.Set;
public class IntermediateGraph {
private final String modelName;
- private final Map<String, IntermediateOperation> index = new HashMap<>();
+ private final Map<String, IntermediateOperation> operations = new HashMap<>();
private final Map<String, GraphSignature> signatures = new HashMap<>();
private static class GraphSignature {
@@ -37,11 +39,11 @@ public class IntermediateGraph {
}
public IntermediateOperation put(String key, IntermediateOperation operation) {
- return index.put(key, operation);
+ return operations.put(key, operation);
}
public IntermediateOperation get(String key) {
- return index.get(key);
+ return operations.get(key);
}
public Set<String> signatures() {
@@ -61,11 +63,11 @@ public class IntermediateGraph {
}
public boolean alreadyImported(String key) {
- return index.containsKey(key);
+ return operations.containsKey(key);
}
- public Collection<IntermediateOperation> operations() {
- return index.values();
+ public Map<String, IntermediateOperation> operations() {
+ return operations;
}
void optimize() {
@@ -76,16 +78,16 @@ public class IntermediateGraph {
* Find dimension names to avoid excessive renaming while evaluating the model.
*/
private void renameDimensions() {
- DimensionRenamer renamer = new DimensionRenamer();
+ DimensionRenamer renamer = new DimensionRenamer(this);
for (String signature : signatures()) {
for (String output : outputs(signature).values()) {
- addDimensionNameConstraints(index.get(output), renamer);
+ addDimensionNameConstraints(operations.get(output), renamer);
}
}
renamer.solve();
for (String signature : signatures()) {
for (String output : outputs(signature).values()) {
- renameDimensions(index.get(output), renamer);
+ renameDimensions(operations.get(output), renamer);
}
}
}
@@ -111,7 +113,7 @@ public class IntermediateGraph {
public String toFullString() {
StringBuilder b = new StringBuilder();
- for (var input : index.entrySet())
+ for (var input : operations.entrySet())
b.append(input.getKey()).append(": ").append(input.getValue().toFullString()).append("\n");
return b.toString();
}