diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-08-24 10:50:29 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-08-24 10:50:29 +0200 |
commit | b88fd2c2b9c3e220b6884da0392a3602fb3aa994 (patch) | |
tree | 1a175bc9a8b1c64adb2b9a97da9ef600c6094db8 | |
parent | dd34698bcd051c1eff8d94506a7ac7a1545ee1d2 (diff) |
Let ImportedModel know how to address each expression
4 files changed, 62 insertions, 36 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java index a38fbe1aaa0..0dcba922ee3 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java @@ -145,39 +145,11 @@ public class ConvertedModel { // Add expressions Map<String, RankingExpression> expressions = new HashMap<>(); - for (Map.Entry<String, ImportedModel.Signature> signatureEntry : model.signatures().entrySet()) { - for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) { - addExpression(model.expressions().get(outputEntry.getValue()), - modelName + "." + signatureEntry.getKey() + "." + outputEntry.getKey(), - constantsReplacedByMacros, - model, store, profile, queryProfiles, - expressions); - } - if (signatureEntry.getValue().outputs().isEmpty()) { // fallback: Signature without outputs - addExpression(model.expressions().get(signatureEntry.getKey()), - modelName + "." + signatureEntry.getKey(), - constantsReplacedByMacros, - model, store, profile, queryProfiles, - expressions); - } - } - if (model.signatures().isEmpty()) { // fallback: Model without signatures - if (model.expressions().size() == 1) { // Use just model name - addExpression(model.expressions().values().iterator().next(), - modelName, - constantsReplacedByMacros, - model, store, profile, queryProfiles, - expressions); - } - else { - for (Map.Entry<String, RankingExpression> expressionEntry : model.expressions().entrySet()) { - addExpression(expressionEntry.getValue(), - modelName + "." + expressionEntry.getKey(), - constantsReplacedByMacros, - model, store, profile, queryProfiles, - expressions); - } - } + for (Pair<String, RankingExpression> output : model.outputExpressions(modelName)) { + addExpression(output.getSecond(), output.getFirst(), + constantsReplacedByMacros, + model, store, profile, queryProfiles, + expressions); } // Transform and save macro - must come after reading expressions due to optimization transforms diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index f3e7a9623d1..1ebcc73027d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model; +import com.google.common.collect.ImmutableList; import com.yahoo.config.ConfigBuilder; import com.yahoo.config.ConfigInstance; import com.yahoo.config.ConfigInstance.Builder; @@ -25,6 +26,10 @@ import com.yahoo.config.model.producer.AbstractConfigProducerRoot; import com.yahoo.config.model.producer.UserConfigRepo; import com.yahoo.config.provision.AllocatedHosts; import com.yahoo.log.LogLevel; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; import com.yahoo.vespa.config.ConfigDefinitionKey; import com.yahoo.vespa.config.ConfigKey; import com.yahoo.vespa.config.ConfigPayload; @@ -94,11 +99,11 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri private ApplicationConfigProducerRoot root = null; - /** - * Generic service instances - service clusters which have no specific model - */ + /** Generic service instances - service clusters which have no specific model */ private List<ServiceCluster> serviceClusters = new ArrayList<>(); + private final ImmutableList<RankProfile> globalRankProfiles; + private DeployState deployState; /** The validation overrides of this. This is never null. */ @@ -148,6 +153,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri if (complete) { // create a a completed, frozen model configModelRepo.readConfigModels(deployState, builder, root, configModelRegistry); addServiceClusters(deployState.getApplicationPackage(), builder); + this.globalRankProfiles = createGlobalRankProfiles(deployState.getImportedModels()); this.allocatedHosts = AllocatedHosts.withHosts(root.getHostSystem().getHostSpecs()); // must happen after the two lines above setupRouting(); this.fileDistributor = root.getFileDistributionConfigProducer().getFileDistributor(); @@ -161,6 +167,7 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri else { // create a model with no services instantiated and the given file distributor this.allocatedHosts = AllocatedHosts.withHosts(root.getHostSystem().getHostSpecs()); this.fileDistributor = fileDistributor; + this.globalRankProfiles = ImmutableList.of(); } } @@ -185,6 +192,13 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri serviceClusters.add(sc); } + /** + * Creates a rank profile not attached to any search definition, for each imported model in the application package + */ + private ImmutableList<RankProfile> createGlobalRankProfiles(ImportedModels importedModels) { + return ImmutableList.of(); + } + private void setupRouting() { root.setupRouting(configModelRepo); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index ec7bdcf5f2b..0ddc202a3e0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -1,5 +1,6 @@ package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.collections.Pair; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -97,6 +98,37 @@ public class ImportedModel { void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } /** + * Returns all the outputs of this by name. The names consist of one to three parts + * separated by dot, where the first part is the model name, the second is the signature name + * if signatures are used, or the expression name if signatures are not used and there are multiple + * expressions, and the third is the output name if signature names are used. + */ + public List<Pair<String, RankingExpression>> outputExpressions(String modelName) { + List<Pair<String, RankingExpression>> names = new ArrayList<>(); + for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { + for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) + names.add(new Pair<>(modelName + "." + signatureEntry.getKey() + "." + outputEntry.getKey(), + expressions().get(outputEntry.getValue()))); + if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs + names.add(new Pair<>(modelName + "." + signatureEntry.getKey(), + expressions().get(signatureEntry.getKey()))); + } + if (signatures().isEmpty()) { // fallback for models without signatures + if (expressions().size() == 1) {// Use just model name + names.add(new Pair<>(modelName, + expressions().values().iterator().next())); + } + else { + for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) { + names.add(new Pair<>(modelName + "." + expressionEntry.getKey(), + expressionEntry.getValue())); + } + } + } + return names; + } + + /** * A signature is a set of named inputs and outputs, where the inputs maps to argument * ("placeholder") names+types, and outputs maps to expressions nodes. * Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java index 3fa6141a696..268064f98d8 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java @@ -6,6 +6,9 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.path.Path; import java.io.File; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; import java.util.Optional; /** @@ -63,6 +66,11 @@ public class ImportedModels { return importedModels.get(toName(modelPath)); } + /** Returns an immutable collection of all the imported models */ + public Collection<ImportedModel> all() { + return importedModels.values(); + } + private static String toName(File modelPath) { Path localPath = Path.fromString(modelPath.toString()).getChildPath(); return localPath.toString().replace("/", "_").replace('.', '_'); |