summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-25 12:33:10 +0200
committerGitHub <noreply@github.com>2018-08-25 12:33:10 +0200
commit011d94ba6809d9c931a9e3d8d8bbcf9a28a97a61 (patch)
tree1cf090b79ef73ec1cd6e9799642ea62c86377578 /searchlib
parentca44e13502f4d6b7efe1ba327973900f9b8e0f44 (diff)
parented3923c95484e57e4eac43c25e985512ca3aa645 (diff)
Merge pull request #6672 from vespa-engine/bratseth/generate-rank-profiles-for-all-models-part-6
Bratseth/generate rank profiles for all models part 6
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java22
2 files changed, 52 insertions, 2 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index ec7bdcf5f2b..045844ee219 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -1,5 +1,6 @@
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.collections.Pair;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -97,6 +98,37 @@ public class ImportedModel {
void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
/**
+ * Returns all the outputs of this by name. The names consist of one to three parts
+ * separated by dot, where the first part is the model name, the second is the signature name
+ * if signatures are used, or the expression name if signatures are not used and there are multiple
+ * expressions, and the third is the output name if signature names are used.
+ */
+ public List<Pair<String, RankingExpression>> outputExpressions() {
+ List<Pair<String, RankingExpression>> names = new ArrayList<>();
+ for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
+ for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
+ names.add(new Pair<>(name + "." + signatureEntry.getKey() + "." + outputEntry.getKey(),
+ expressions().get(outputEntry.getValue())));
+ if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
+ names.add(new Pair<>(name + "." + signatureEntry.getKey(),
+ expressions().get(signatureEntry.getKey())));
+ }
+ if (signatures().isEmpty()) { // fallback for models without signatures
+ if (expressions().size() == 1) {// Use just model name
+ names.add(new Pair<>(name,
+ expressions().values().iterator().next()));
+ }
+ else {
+ for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) {
+ names.add(new Pair<>(name + "." + expressionEntry.getKey(),
+ expressionEntry.getValue()));
+ }
+ }
+ }
+ return names;
+ }
+
+ /**
* A signature is a set of named inputs and outputs, where the inputs maps to argument
* ("placeholder") names+types, and outputs maps to expressions nodes.
* Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
index 3fa6141a696..827b1911369 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
@@ -6,6 +6,9 @@ import com.google.common.collect.ImmutableMap;
import com.yahoo.path.Path;
import java.io.File;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
import java.util.Optional;
/**
@@ -63,9 +66,24 @@ public class ImportedModels {
return importedModels.get(toName(modelPath));
}
+ /** Returns an immutable collection of all the imported models */
+ public Collection<ImportedModel> all() {
+ return importedModels.values();
+ }
+
private static String toName(File modelPath) {
- Path localPath = Path.fromString(modelPath.toString()).getChildPath();
- return localPath.toString().replace("/", "_").replace('.', '_');
+ String localPath = concatenateAfterModelsDirectory(Path.fromString(modelPath.toString()));
+ return localPath.replace('.', '_');
+ }
+
+ private static String concatenateAfterModelsDirectory(Path path) {
+ boolean afterModels = false;
+ StringBuilder result = new StringBuilder();
+ for (String element : path.elements()) {
+ if (afterModels) result.append(element).append("_");
+ if (element.equals("models")) afterModels = true;
+ }
+ return result.substring(0, result.length()-1);
}
}