diff options
author | Jon Bratseth <jonbratseth@yahoo.com> | 2018-01-11 13:44:12 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-01-11 13:44:12 +0100 |
commit | 20194ff27f6d20ec22a951af068631a7e2305590 (patch) | |
tree | c1bcaae3e677dca252013aef88f5ea47a7852ee9 /config-model | |
parent | b6db8a09a19685e165a0759a5d4f7e70b60fa96b (diff) | |
parent | 7e2e38d17a51d0ca93dc74b8e7e0d34c5eeb19af (diff) |
Merge pull request #4618 from vespa-engine/bratseth/tf-constants-as-files
Use constant tensor files WIP
Diffstat (limited to 'config-model')
9 files changed, 87 insertions, 56 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index ec3100bc6b9..bacff94d776 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -39,7 +39,7 @@ public class RankProfile implements Serializable, Cloneable { private final String name; /** The search definition owning this profile, or null if none */ - private Search search=null; + private Search search = null; /** The name of the rank profile inherited by this */ private String inheritedName = null; @@ -51,7 +51,7 @@ public class RankProfile implements Serializable, Cloneable { protected Set<RankSetting> rankSettings = new java.util.LinkedHashSet<>(); /** The ranking expression to be used for first phase */ - private RankingExpression firstPhaseRanking= null; + private RankingExpression firstPhaseRanking = null; /** The ranking expression to be used for second phase */ private RankingExpression secondPhaseRanking = null; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java index 16e57ee913d..c65e0fad1c7 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstant.java @@ -6,7 +6,7 @@ import com.yahoo.tensor.TensorType; import java.util.Objects; /** - * Represents a global ranking constant (declared in a .sd file) + * Represents a global ranking constant * * @author arnej */ @@ -16,23 +16,35 @@ public class RankingConstant { private final String name; private TensorType tensorType = null; private String fileName = null; - private String fileRef = ""; + private String fileReference = ""; public RankingConstant(String name) { this.name = name; } - public void setFileName(String fileName) { + public RankingConstant(String name, TensorType type, String fileName) { + this(name); + this.tensorType = type; + this.fileName = fileName; + validate(); + } + + public void setFileName(String fileName) { Objects.requireNonNull(fileName, "Filename cannot be null"); - this.fileName = fileName; + this.fileName = fileName; } - public void setFileReference(String fileRef) { this.fileRef = fileRef; } + /** + * Set the internally generated reference to this file used to identify this instance of the file for + * file distribution. + */ + public void setFileReference(String fileReference) { this.fileReference = fileReference; } + public void setType(TensorType tensorType) { this.tensorType = tensorType; } public String getName() { return name; } public String getFileName() { return fileName; } - public String getFileReference() { return fileRef; } + public String getFileReference() { return fileReference; } public TensorType getTensorType() { return tensorType; } public String getType() { return tensorType.toString(); } @@ -47,7 +59,7 @@ public class RankingConstant { StringBuilder b = new StringBuilder(); b.append("constant '").append(name) .append("' from file '").append(fileName) - .append("' with ref '").append(fileRef) + .append("' with ref '").append(fileReference) .append("' of type '").append(tensorType) .append("'"); return b.toString(); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java index 7baa3eb170f..bd7b8ce6e15 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java @@ -58,7 +58,7 @@ public class Search implements Serializable, ImmutableSearch { // Field sets private FieldSets fieldSets = new FieldSets(); - + // Whether or not this object has been processed. private boolean processed; @@ -162,13 +162,13 @@ public class Search implements Serializable, ImmutableSearch { docType = document; } - public void addRankingConstant(RankingConstant rConstant) { - rConstant.validate(); - String name = rConstant.getName(); + public void addRankingConstant(RankingConstant constant) { + constant.validate(); + String name = constant.getName(); if (rankingConstants.get(name) != null) { throw new IllegalArgumentException("Ranking constant '"+name+"' defined twice"); } - rankingConstants.put(name, rConstant); + rankingConstants.put(name, constant); } public Iterable<RankingConstant> getRankingConstants() { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index d05027dda39..a36384ce6f2 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -1,17 +1,22 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.google.common.base.Joiner; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.Tensor; +import java.io.File; import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -48,7 +53,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + "the tensorflow model directory under [application]/models"); - String modelPath = asString(feature.getArguments().expressions().get(0)); + String modelPath = ApplicationPackage.MODELS_DIR + "/" + asString(feature.getArguments().expressions().get(0)); TensorFlowModel result = importedModels.computeIfAbsent(modelPath, k -> tensorFlowImporter.importModel(modelPath)); // Find the specified expression @@ -58,7 +63,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil optionalArgument(2, feature.getArguments())); // Add all constants (after finding outputs to fail faster when the output is not found) - result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v))); + if (1==1) + result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v))); + else // correct way, disabled for now + result.constants().forEach((k, v) -> transformConstant(modelPath, context.rankProfile(), k, v)); return expression.getRoot(); } @@ -120,6 +128,18 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } + private void transformConstant(String modelPath, RankProfile profile, String constantName, Tensor constantValue) { + File constantFilePath = new File(modelPath, "converted_variables"); + if ( ! constantFilePath.exists() ) { + if ( ! constantFilePath.mkdir() ) + throw new IllegalStateException("Could not create directory " + constantFilePath); + } + + File constantFile = new File(constantFilePath, constantName + ".json"); + // writeAsVespaTensor(constantValue, constantFile); + profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), constantFilePath.getPath())); + } + private String skippedOutputsDescription(TensorFlowModel.Signature signature) { if (signature.skippedOutputs().isEmpty()) return ""; StringBuilder b = new StringBuilder(": "); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java index c686f023d5b..a1f372f2307 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankingConstantsValidator.java @@ -11,7 +11,6 @@ import com.yahoo.vespa.model.application.validation.ConstantTensorJsonValidator. import com.yahoo.vespa.model.search.SearchDefinition; import java.io.FileNotFoundException; -import java.io.Reader; /** * RankingConstantsValidator validates all constant tensors (ranking constants) bundled with an application package @@ -22,10 +21,11 @@ import java.io.Reader; public class RankingConstantsValidator extends Validator { private static class ExceptionMessageCollector { - public String combinedMessage; - public boolean exceptionsOccurred = false; - public ExceptionMessageCollector(String messagePrelude) { + String combinedMessage; + boolean exceptionsOccurred = false; + + ExceptionMessageCollector(String messagePrelude) { this.combinedMessage = messagePrelude; } @@ -36,8 +36,8 @@ public class RankingConstantsValidator extends Validator { } } - public static class TensorValidationFailed extends RuntimeException { - public TensorValidationFailed(String message) { + static class TensorValidationFailed extends RuntimeException { + TensorValidationFailed(String message) { super(message); } } @@ -45,7 +45,7 @@ public class RankingConstantsValidator extends Validator { @Override public void validate(VespaModel model, DeployState deployState) { ApplicationPackage applicationPackage = deployState.getApplicationPackage(); - ExceptionMessageCollector exceptionMessageCollector = new ExceptionMessageCollector("Failed to validate constant tensor file(s):"); + ExceptionMessageCollector exceptionMessageCollector = new ExceptionMessageCollector("Invalid constant tensor file(s):"); for (SearchDefinition sd : deployState.getSearchDefinitions()) { for (RankingConstant rc : sd.getSearch().getRankingConstants()) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java index ec621e743fc..e3eb66e6a18 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java @@ -1,19 +1,20 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.search; +import com.yahoo.config.FileReference; +import com.yahoo.config.model.producer.AbstractConfigProducer; import com.yahoo.config.model.producer.UserConfigRepo; import com.yahoo.prelude.fastsearch.DocumentdbInfoConfig; -import com.yahoo.vespa.config.search.DispatchConfig; -import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.search.config.IndexInfoConfig; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.vespa.config.search.AttributesConfig; -import com.yahoo.config.model.producer.AbstractConfigProducer; -import com.yahoo.config.model.ConfigModelRepo; -import com.yahoo.search.config.IndexInfoConfig; +import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.configdefinition.IlscriptsConfig; import com.yahoo.vespa.model.utils.FileSender; -import com.yahoo.config.FileReference; -import java.util.*; + +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; /** * Superclass for search clusters. @@ -21,11 +22,11 @@ import java.util.*; * @author Peter Boros */ public abstract class AbstractSearchCluster extends AbstractConfigProducer - implements + implements DocumentdbInfoConfig.Producer, IndexInfoConfig.Producer, - IlscriptsConfig.Producer -{ + IlscriptsConfig.Producer { + private Double queryTimeout; protected String clusterName; protected int index; @@ -36,9 +37,9 @@ public abstract class AbstractSearchCluster extends AbstractConfigProducer public void prepareToDistributeFiles(List<SearchNode> backends) { for (SearchDefinitionSpec sds : localSDS) { - for (RankingConstant rc : sds.getSearchDefinition().getSearch().getRankingConstants()) { - FileReference reference = FileSender.sendFileToServices(rc.getFileName(), backends); - rc.setFileReference(reference.value()); + for (RankingConstant constant : sds.getSearchDefinition().getSearch().getRankingConstants()) { + FileReference reference = FileSender.sendFileToServices(constant.getFileName(), backends); + constant.setFileReference(reference.value()); } } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java index 4d4aea93a36..32548039fdd 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/DocumentDatabase.java @@ -3,19 +3,17 @@ package com.yahoo.vespa.model.search; import com.yahoo.config.model.producer.AbstractConfigProducer; import com.yahoo.search.config.IndexInfoConfig; -import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.vespa.config.search.AttributesConfig; import com.yahoo.vespa.config.search.ImportedFieldsConfig; -import com.yahoo.vespa.config.search.SummaryConfig; import com.yahoo.vespa.config.search.IndexschemaConfig; import com.yahoo.vespa.config.search.RankProfilesConfig; -import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import com.yahoo.vespa.config.search.SummaryConfig; import com.yahoo.vespa.config.search.SummarymapConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.config.search.summary.JuniperrcConfig; import com.yahoo.vespa.configdefinition.IlscriptsConfig; -import com.yahoo.config.FileReference; -import com.yahoo.vespa.model.utils.FileSender; /** * Represents a document database and the backend configuration needed for this database. @@ -59,17 +57,17 @@ public class DocumentDatabase extends AbstractConfigProducer implements public void getConfig(IndexInfoConfig.Builder builder) { derivedCfg.getIndexInfo().getConfig(builder); } - + @Override public void getConfig(IlscriptsConfig.Builder builder) { derivedCfg.getIndexingScript().getConfig(builder); } - + @Override public void getConfig(AttributesConfig.Builder builder) { derivedCfg.getAttributeFields().getConfig(builder); } - + @Override public void getConfig(RankProfilesConfig.Builder builder) { derivedCfg.getRankProfileList().getConfig(builder); @@ -77,15 +75,15 @@ public class DocumentDatabase extends AbstractConfigProducer implements @Override public void getConfig(RankingConstantsConfig.Builder builder) { - for (RankingConstant rConstant : derivedCfg.getSearch().getRankingConstants()) { - if ("".equals(rConstant.getFileReference())) { - System.err.println("INVALID rank constant "+rConstant.getName()+" [missing file reference]"); + for (RankingConstant constant : derivedCfg.getSearch().getRankingConstants()) { + if ("".equals(constant.getFileReference())) { + System.err.println("INVALID rank constant "+constant.getName()+" [missing file reference]"); // TODO: Throw or log warning continue; } builder.constant(new RankingConstantsConfig.Constant.Builder() - .name(rConstant.getName()) - .fileref(rConstant.getFileReference()) - .type(rConstant.getType())); + .name(constant.getName()) + .fileref(constant.getFileReference()) + .type(constant.getType())); } } @@ -93,17 +91,17 @@ public class DocumentDatabase extends AbstractConfigProducer implements public void getConfig(IndexschemaConfig.Builder builder) { derivedCfg.getIndexSchema().getConfig(builder); } - + @Override public void getConfig(JuniperrcConfig.Builder builder) { derivedCfg.getJuniperrc().getConfig(builder); } - + @Override public void getConfig(SummarymapConfig.Builder builder) { derivedCfg.getSummaryMap().getConfig(builder); } - + @Override public void getConfig(SummaryConfig.Builder builder) { derivedCfg.getSummaries().getConfig(builder); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/utils/FileSender.java b/config-model/src/main/java/com/yahoo/vespa/model/utils/FileSender.java index 605c87912ac..413363d7b0d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/utils/FileSender.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/utils/FileSender.java @@ -17,7 +17,6 @@ import java.util.*; * Utility methods for sending files to a collection of nodes. * * @author gjoranv - * @since 5.1.9 */ public class FileSender implements Serializable { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 8fcd821adfd..3ec621618e5 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -15,7 +15,8 @@ import static org.junit.Assert.fail; */ public class RankingExpressionWithTensorFlowTestCase { - private final String modelDirectory = "src/test/integration/tensorflow/mnist_softmax/saved"; + // The "../" is to escape the "models/" element prepended to the path + private final String modelDirectory = "../src/test/integration/tensorflow/mnist_softmax/saved"; private final String vespaExpression = "join(rename(reduce(join(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(Variable_1), d0, d1), f(a,b)(a + b))"; @Test |