diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-08-25 12:33:10 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-08-25 12:33:10 +0200 |
commit | 011d94ba6809d9c931a9e3d8d8bbcf9a28a97a61 (patch) | |
tree | 1cf090b79ef73ec1cd6e9799642ea62c86377578 /searchlib | |
parent | ca44e13502f4d6b7efe1ba327973900f9b8e0f44 (diff) | |
parent | ed3923c95484e57e4eac43c25e985512ca3aa645 (diff) |
Merge pull request #6672 from vespa-engine/bratseth/generate-rank-profiles-for-all-models-part-6
Bratseth/generate rank profiles for all models part 6
Diffstat (limited to 'searchlib')
2 files changed, 52 insertions, 2 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index ec7bdcf5f2b..045844ee219 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -1,5 +1,6 @@ package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.collections.Pair; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -97,6 +98,37 @@ public class ImportedModel { void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } /** + * Returns all the outputs of this by name. The names consist of one to three parts + * separated by dot, where the first part is the model name, the second is the signature name + * if signatures are used, or the expression name if signatures are not used and there are multiple + * expressions, and the third is the output name if signature names are used. + */ + public List<Pair<String, RankingExpression>> outputExpressions() { + List<Pair<String, RankingExpression>> names = new ArrayList<>(); + for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { + for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) + names.add(new Pair<>(name + "." + signatureEntry.getKey() + "." + outputEntry.getKey(), + expressions().get(outputEntry.getValue()))); + if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs + names.add(new Pair<>(name + "." + signatureEntry.getKey(), + expressions().get(signatureEntry.getKey()))); + } + if (signatures().isEmpty()) { // fallback for models without signatures + if (expressions().size() == 1) {// Use just model name + names.add(new Pair<>(name, + expressions().values().iterator().next())); + } + else { + for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) { + names.add(new Pair<>(name + "." + expressionEntry.getKey(), + expressionEntry.getValue())); + } + } + } + return names; + } + + /** * A signature is a set of named inputs and outputs, where the inputs maps to argument * ("placeholder") names+types, and outputs maps to expressions nodes. * Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java index 3fa6141a696..827b1911369 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java @@ -6,6 +6,9 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.path.Path; import java.io.File; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; import java.util.Optional; /** @@ -63,9 +66,24 @@ public class ImportedModels { return importedModels.get(toName(modelPath)); } + /** Returns an immutable collection of all the imported models */ + public Collection<ImportedModel> all() { + return importedModels.values(); + } + private static String toName(File modelPath) { - Path localPath = Path.fromString(modelPath.toString()).getChildPath(); - return localPath.toString().replace("/", "_").replace('.', '_'); + String localPath = concatenateAfterModelsDirectory(Path.fromString(modelPath.toString())); + return localPath.replace('.', '_'); + } + + private static String concatenateAfterModelsDirectory(Path path) { + boolean afterModels = false; + StringBuilder result = new StringBuilder(); + for (String element : path.elements()) { + if (afterModels) result.append(element).append("_"); + if (element.equals("models")) afterModels = true; + } + return result.substring(0, result.length()-1); } } |