summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-24 10:50:29 +0200
committerJon Bratseth <bratseth@oath.com>2018-08-24 10:50:29 +0200
commitb88fd2c2b9c3e220b6884da0392a3602fb3aa994 (patch)
tree1a175bc9a8b1c64adb2b9a97da9ef600c6094db8
parentdd34698bcd051c1eff8d94506a7ac7a1545ee1d2 (diff)
Let ImportedModel know how to address each expression
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java38
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java20
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java8
4 files changed, 62 insertions, 36 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
index a38fbe1aaa0..0dcba922ee3 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
@@ -145,39 +145,11 @@ public class ConvertedModel {
// Add expressions
Map<String, RankingExpression> expressions = new HashMap<>();
- for (Map.Entry<String, ImportedModel.Signature> signatureEntry : model.signatures().entrySet()) {
- for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) {
- addExpression(model.expressions().get(outputEntry.getValue()),
- modelName + "." + signatureEntry.getKey() + "." + outputEntry.getKey(),
- constantsReplacedByMacros,
- model, store, profile, queryProfiles,
- expressions);
- }
- if (signatureEntry.getValue().outputs().isEmpty()) { // fallback: Signature without outputs
- addExpression(model.expressions().get(signatureEntry.getKey()),
- modelName + "." + signatureEntry.getKey(),
- constantsReplacedByMacros,
- model, store, profile, queryProfiles,
- expressions);
- }
- }
- if (model.signatures().isEmpty()) { // fallback: Model without signatures
- if (model.expressions().size() == 1) { // Use just model name
- addExpression(model.expressions().values().iterator().next(),
- modelName,
- constantsReplacedByMacros,
- model, store, profile, queryProfiles,
- expressions);
- }
- else {
- for (Map.Entry<String, RankingExpression> expressionEntry : model.expressions().entrySet()) {
- addExpression(expressionEntry.getValue(),
- modelName + "." + expressionEntry.getKey(),
- constantsReplacedByMacros,
- model, store, profile, queryProfiles,
- expressions);
- }
- }
+ for (Pair<String, RankingExpression> output : model.outputExpressions(modelName)) {
+ addExpression(output.getSecond(), output.getFirst(),
+ constantsReplacedByMacros,
+ model, store, profile, queryProfiles,
+ expressions);
}
// Transform and save macro - must come after reading expressions due to optimization transforms
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index f3e7a9623d1..1ebcc73027d 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model;
+import com.google.common.collect.ImmutableList;
import com.yahoo.config.ConfigBuilder;
import com.yahoo.config.ConfigInstance;
import com.yahoo.config.ConfigInstance.Builder;
@@ -25,6 +26,10 @@ import com.yahoo.config.model.producer.AbstractConfigProducerRoot;
import com.yahoo.config.model.producer.UserConfigRepo;
import com.yahoo.config.provision.AllocatedHosts;
import com.yahoo.log.LogLevel;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.vespa.config.ConfigDefinitionKey;
import com.yahoo.vespa.config.ConfigKey;
import com.yahoo.vespa.config.ConfigPayload;
@@ -94,11 +99,11 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
private ApplicationConfigProducerRoot root = null;
- /**
- * Generic service instances - service clusters which have no specific model
- */
+ /** Generic service instances - service clusters which have no specific model */
private List<ServiceCluster> serviceClusters = new ArrayList<>();
+ private final ImmutableList<RankProfile> globalRankProfiles;
+
private DeployState deployState;
/** The validation overrides of this. This is never null. */
@@ -148,6 +153,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
if (complete) { // create a a completed, frozen model
configModelRepo.readConfigModels(deployState, builder, root, configModelRegistry);
addServiceClusters(deployState.getApplicationPackage(), builder);
+ this.globalRankProfiles = createGlobalRankProfiles(deployState.getImportedModels());
this.allocatedHosts = AllocatedHosts.withHosts(root.getHostSystem().getHostSpecs()); // must happen after the two lines above
setupRouting();
this.fileDistributor = root.getFileDistributionConfigProducer().getFileDistributor();
@@ -161,6 +167,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
else { // create a model with no services instantiated and the given file distributor
this.allocatedHosts = AllocatedHosts.withHosts(root.getHostSystem().getHostSpecs());
this.fileDistributor = fileDistributor;
+ this.globalRankProfiles = ImmutableList.of();
}
}
@@ -185,6 +192,13 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
serviceClusters.add(sc);
}
+ /**
+ * Creates a rank profile not attached to any search definition, for each imported model in the application package
+ */
+ private ImmutableList<RankProfile> createGlobalRankProfiles(ImportedModels importedModels) {
+ return ImmutableList.of();
+ }
+
private void setupRouting() {
root.setupRouting(configModelRepo);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index ec7bdcf5f2b..0ddc202a3e0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -1,5 +1,6 @@
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.collections.Pair;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -97,6 +98,37 @@ public class ImportedModel {
void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
/**
+ * Returns all the outputs of this by name. The names consist of one to three parts
+ * separated by dot, where the first part is the model name, the second is the signature name
+ * if signatures are used, or the expression name if signatures are not used and there are multiple
+ * expressions, and the third is the output name if signature names are used.
+ */
+ public List<Pair<String, RankingExpression>> outputExpressions(String modelName) {
+ List<Pair<String, RankingExpression>> names = new ArrayList<>();
+ for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
+ for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
+ names.add(new Pair<>(modelName + "." + signatureEntry.getKey() + "." + outputEntry.getKey(),
+ expressions().get(outputEntry.getValue())));
+ if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
+ names.add(new Pair<>(modelName + "." + signatureEntry.getKey(),
+ expressions().get(signatureEntry.getKey())));
+ }
+ if (signatures().isEmpty()) { // fallback for models without signatures
+ if (expressions().size() == 1) {// Use just model name
+ names.add(new Pair<>(modelName,
+ expressions().values().iterator().next()));
+ }
+ else {
+ for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) {
+ names.add(new Pair<>(modelName + "." + expressionEntry.getKey(),
+ expressionEntry.getValue()));
+ }
+ }
+ }
+ return names;
+ }
+
+ /**
* A signature is a set of named inputs and outputs, where the inputs maps to argument
* ("placeholder") names+types, and outputs maps to expressions nodes.
* Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
index 3fa6141a696..268064f98d8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
@@ -6,6 +6,9 @@ import com.google.common.collect.ImmutableMap;
import com.yahoo.path.Path;
import java.io.File;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
import java.util.Optional;
/**
@@ -63,6 +66,11 @@ public class ImportedModels {
return importedModels.get(toName(modelPath));
}
+ /** Returns an immutable collection of all the imported models */
+ public Collection<ImportedModel> all() {
+ return importedModels.values();
+ }
+
private static String toName(File modelPath) {
Path localPath = Path.fromString(modelPath.toString()).getChildPath();
return localPath.toString().replace("/", "_").replace('.', '_');