aboutsummaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorJon Bratseth <jonbratseth@yahoo.com>2018-01-17 13:51:14 +0100
committerGitHub <noreply@github.com>2018-01-17 13:51:14 +0100
commitfd26b36e3607df463b35e856b37d24b5e3514fb7 (patch)
tree403836969d050736403f6512a455198a2c63edad /config-model
parentceec6d572c06ff812715c97d2c35383c48402f24 (diff)
parentc84b8f952ef5857aa44fad479551eda1f3a4e106 (diff)
Merge pull request #4692 from vespa-engine/bratseth/store-converted-expressions-in-zk
Bratseth/store converted expressions in zk
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/config/model/ConfigModel.java2
-rw-r--r--config-model/src/main/java/com/yahoo/config/model/admin/AdminModel.java14
-rw-r--r--config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/Search.java7
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java287
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/admin/Admin.java4
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java12
-rw-r--r--config-model/src/test/integration/tensorflow/models/mnist_softmax/mnist_sftmax_with_saving.py (renamed from config-model/src/test/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py)0
-rw-r--r--config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/saved_model.pbtxt (renamed from config-model/src/test/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt)0
-rw-r--r--config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/variables/variables.data-00000-of-00001 (renamed from config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001)bin31400 -> 31400 bytes
-rw-r--r--config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/variables/variables.index (renamed from config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.index)bin159 -> 159 bytes
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java8
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java267
13 files changed, 477 insertions, 128 deletions
diff --git a/config-model/src/main/java/com/yahoo/config/model/ConfigModel.java b/config-model/src/main/java/com/yahoo/config/model/ConfigModel.java
index 5daf5ca70a5..385cd883da4 100644
--- a/config-model/src/main/java/com/yahoo/config/model/ConfigModel.java
+++ b/config-model/src/main/java/com/yahoo/config/model/ConfigModel.java
@@ -8,7 +8,7 @@ package com.yahoo.config.model;
*
* @author gjoranv
* @author bratseth
- * @author lulf
+ * @author Ulf Lilleengen
*/
public abstract class ConfigModel {
diff --git a/config-model/src/main/java/com/yahoo/config/model/admin/AdminModel.java b/config-model/src/main/java/com/yahoo/config/model/admin/AdminModel.java
index 5eb4afcc241..5912b476783 100644
--- a/config-model/src/main/java/com/yahoo/config/model/admin/AdminModel.java
+++ b/config-model/src/main/java/com/yahoo/config/model/admin/AdminModel.java
@@ -21,8 +21,7 @@ import java.util.*;
/**
* Config model adaptor of the Admin class.
*
- * @author lulf
- * @since 5.1
+ * @author Ulf Lilleengen
*/
public class AdminModel extends ConfigModel {
@@ -46,8 +45,9 @@ public class AdminModel extends ConfigModel {
@Override
public void prepare(ConfigModelRepo configModelRepo) {
verifyClusterControllersOnlyDefinedForContent(configModelRepo);
- if (admin == null || admin.getClusterControllers() == null) return;
- admin.getClusterControllers().prepare();
+ if (admin == null) return;
+ if (admin.getClusterControllers() != null)
+ admin.getClusterControllers().prepare();
}
private void verifyClusterControllersOnlyDefinedForContent(ConfigModelRepo configModelRepo) {
@@ -61,9 +61,9 @@ public class AdminModel extends ConfigModel {
public static class BuilderV2 extends ConfigModelBuilder<AdminModel> {
public static final List<ConfigModelId> configModelIds =
- ImmutableList.of(ConfigModelId.fromNameAndVersion("admin", "2.0"),
+ ImmutableList.of(ConfigModelId.fromNameAndVersion("admin", "2.0"),
ConfigModelId.fromNameAndVersion("admin", "1.0"));
-
+
public BuilderV2() {
super(AdminModel.class);
}
@@ -91,7 +91,7 @@ public class AdminModel extends ConfigModel {
public static class BuilderV4 extends ConfigModelBuilder<AdminModel> {
public static final List<ConfigModelId> configModelIds =
- ImmutableList.of(ConfigModelId.fromNameAndVersion("admin", "3.0"),
+ ImmutableList.of(ConfigModelId.fromNameAndVersion("admin", "3.0"),
ConfigModelId.fromNameAndVersion("admin", "4.0"));
public BuilderV4() {
diff --git a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java
index ddee0be6e9c..271ec6958ec 100644
--- a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java
+++ b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java
@@ -37,8 +37,8 @@ public class MockApplicationPackage implements ApplicationPackage {
private final Optional<String> validationOverrides;
private final boolean failOnValidateXml;
- private MockApplicationPackage(String hosts, String services, List<String> searchDefinitions, String searchDefinitionDir,
- String deploymentSpec, String validationOverrides, boolean failOnValidateXml) {
+ protected MockApplicationPackage(String hosts, String services, List<String> searchDefinitions, String searchDefinitionDir,
+ String deploymentSpec, String validationOverrides, boolean failOnValidateXml) {
this.hostsS = hosts;
this.servicesS = services;
this.searchDefinitions = searchDefinitions;
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
index f37ab9fb89f..df5697de0d5 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java
@@ -165,9 +165,8 @@ public class Search implements Serializable, ImmutableSearch {
public void addRankingConstant(RankingConstant constant) {
constant.validate();
String name = constant.getName();
- if (rankingConstants.get(name) != null) {
- throw new IllegalArgumentException("Ranking constant '"+name+"' defined twice");
- }
+ if (rankingConstants.containsKey(name))
+ throw new IllegalArgumentException("Ranking constant '" + name + "' defined twice");
rankingConstants.put(name, constant);
}
@@ -268,6 +267,8 @@ public class Search implements Serializable, ImmutableSearch {
return sourceApplication.getRankingExpression(fileName);
}
+ public ApplicationPackage sourceApplication() { return sourceApplication; }
+
/**
* Returns a field defined in this search definition or one if its documents. Fields in this search definition takes
* precedence over document fields having the same name
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 32f8f4871df..0dd5b4166ef 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -2,14 +2,17 @@
package com.yahoo.searchdefinition.expressiontransforms;
import com.google.common.base.Joiner;
+import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel.Signature;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
@@ -17,12 +20,16 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import java.io.File;
import java.io.IOException;
+import java.io.StringReader;
import java.io.UncheckedIOException;
+import java.util.ArrayList;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -33,17 +40,14 @@ import java.util.Optional;
*
* @author bratseth
*/
-// TODO: - Verify types of macros
-// - Avoid name conflicts across models for constants
+// TODO: Verify types of macros
+// TODO: Avoid name conflicts across models for constants
public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
- // TODO: Make system test work with this set to true, then remove the "true" path
- private static final boolean constantsInConfig = true;
-
private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
- private final Map<String, TensorFlowModel> importedModels = new HashMap<>();
+ private final Map<Path, TensorFlowModel> importedModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -56,40 +60,48 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
- try {
- if ( ! feature.getName().equals("tensorflow")) return feature;
+ if ( ! feature.getName().equals("tensorflow")) return feature;
- if (feature.getArguments().isEmpty())
- throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
- "the tensorflow model directory under [application]/models");
+ try {
+ ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(),
+ feature.getArguments());
+ if (store.hasTensorFlowModels())
+ return transformFromTensorFlowModel(store, context.rankProfile());
+ else // is should have previously stored model information instead
+ return transformFromStoredModel(store, context.rankProfile());
+ }
+ catch (IllegalArgumentException | UncheckedIOException e) {
+ throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e);
+ }
+ }
- String modelPath = ApplicationPackage.MODELS_DIR + "/" + asString(feature.getArguments().expressions().get(0));
- TensorFlowModel result = importedModels.computeIfAbsent(modelPath, k -> tensorFlowImporter.importModel(modelPath));
+ private ExpressionNode transformFromTensorFlowModel(ModelStore store, RankProfile profile) {
+ TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
+ k -> tensorFlowImporter.importModel(store.tensorFlowModelDir()));
- // Find the specified expression
- TensorFlowModel.Signature signature = chooseSignature(result,
- optionalArgument(1, feature.getArguments()));
- RankingExpression expression = chooseOutput(signature,
- optionalArgument(2, feature.getArguments()));
+ // Find the specified expression
+ Signature signature = chooseSignature(model, store.arguments().signature());
+ String output = chooseOutput(signature, store.arguments().output());
+ RankingExpression expression = model.expressions().get(output);
+ store.writeConverted(expression);
- // Add all constants (after finding outputs to fail faster when the output is not found)
- if (constantsInConfig)
- result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v)));
- else // correct way, disabled for now
- result.constants().forEach((k, v) -> transformConstant(modelPath, context.rankProfile(), k, v));
+ model.constants().forEach((k, v) -> transformConstant(store, profile, k, v));
+ return expression.getRoot();
+ }
- return expression.getRoot();
- }
- catch (IllegalArgumentException e) {
- throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e);
+ private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
+ for (RankingConstant constant : store.readRankingConstants()) {
+ if (!profile.getSearch().getRankingConstants().containsKey(constant.getName()))
+ profile.getSearch().addRankingConstant(constant);
}
+ return store.readConverted().getRoot();
}
/**
* Returns the specified, existing signature, or the only signature if none is specified.
* Throws IllegalArgumentException in all other cases.
*/
- private TensorFlowModel.Signature chooseSignature(TensorFlowModel importResult, Optional<String> signatureName) {
+ private Signature chooseSignature(TensorFlowModel importResult, Optional<String> signatureName) {
if ( ! signatureName.isPresent()) {
if (importResult.signatures().size() == 0)
throw new IllegalArgumentException("No signatures are available");
@@ -101,7 +113,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
return importResult.signatures().values().stream().findFirst().get();
}
else {
- TensorFlowModel.Signature signature = importResult.signatures().get(signatureName.get());
+ Signature signature = importResult.signatures().get(signatureName.get());
if (signature == null)
throw new IllegalArgumentException("Model does not have the specified signature '" +
signatureName.get() + "'");
@@ -113,7 +125,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
* Returns the specified, existing output expression, or the only output expression if no output name is specified.
* Throws IllegalArgumentException in all other cases.
*/
- private RankingExpression chooseOutput(TensorFlowModel.Signature signature, Optional<String> outputName) {
+ private String chooseOutput(Signature signature, Optional<String> outputName) {
if ( ! outputName.isPresent()) {
if (signature.outputs().size() == 0)
throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature));
@@ -122,11 +134,11 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
Joiner.on(", ").join(signature.outputs().keySet()) +
"), one must be specified " +
"as a third argument to tensorflow()");
- return signature.outputExpression(signature.outputs().keySet().stream().findFirst().get());
+ return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get());
}
else {
- RankingExpression expression = signature.outputExpression(outputName.get());
- if (expression == null) {
+ String output = signature.outputs().get(outputName.get());
+ if (output == null) {
if (signature.skippedOutputs().containsKey(outputName.get()))
throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
signature.skippedOutputs().get(outputName.get()));
@@ -134,28 +146,16 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
throw new IllegalArgumentException("Model does not have the specified output '" +
outputName.get() + "'");
}
- return expression;
+ return output;
}
}
- private void transformConstant(String modelPath, RankProfile profile, String constantName, Tensor constantValue) {
- try {
- if (profile.getSearch().getRankingConstants().containsKey(constantName)) return;
+ private void transformConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
+ if (profile.getSearch().getRankingConstants().containsKey(constantName)) return;
- File constantFilePath = new File(modelPath, "converted_variables").getCanonicalFile();
- if (!constantFilePath.exists()) {
- if (!constantFilePath.mkdir())
- throw new IOException("Could not create directory " + constantFilePath);
- }
-
- // "tbf" ending for "typed binary format" - recognized by the nodes reciving the file:
- File constantFile = new File(constantFilePath, constantName + ".tbf");
- IOUtils.writeFile(constantFile, TypedBinaryFormat.encode(constantValue));
- profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), constantFile.getPath()));
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
+ Path constantPath = store.writeConstant(constantName, constantValue);
+ profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
+ constantPath.toString()));
}
private String skippedOutputsDescription(TensorFlowModel.Signature signature) {
@@ -165,27 +165,176 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
return b.toString();
}
- private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
- if (argumentIndex >= arguments.expressions().size())
- return Optional.empty();
- return Optional.of(asString(arguments.expressions().get(argumentIndex)));
- }
+ /**
+ * Provides read/write access to the correct directories of the application package given by the feature arguments
+ */
+ private static class ModelStore {
- private String asString(ExpressionNode node) {
- if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
- }
+ private final ApplicationPackage application;
+ private final FeatureArguments arguments;
+
+ public ModelStore(ApplicationPackage application, Arguments arguments) {
+ this.application = application;
+ this.arguments = new FeatureArguments(arguments);
+ }
+
+ public FeatureArguments arguments() { return arguments; }
+
+ public boolean hasTensorFlowModels() {
+ try {
+ return application.getFileReference(ApplicationPackage.MODELS_DIR).exists();
+ }
+ catch (UnsupportedOperationException e) {
+ return false; // No files -> no TensorFlow models
+ }
+ }
+
+ /**
+ * Returns the directory which (if hasTensorFlowModels is true)
+ * contains the source model to use for these arguments
+ */
+ public File tensorFlowModelDir() {
+ return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
+ }
+
+ /**
+ * Adds this expression to the application package, such that it can be read later.
+ */
+ public void writeConverted(RankingExpression expression) {
+ application.getFile(arguments.expressionPath())
+ .writeFile(new StringReader(expression.getRoot().toString()));
+ }
+
+ /** Reads the previously stored ranking expression for these arguments */
+ public RankingExpression readConverted() {
+ try {
+ return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
+ }
+ catch (ParseException e) {
+ throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+ }
+ }
+
+ /**
+ * Reads the information about all the constants stored in the application package
+ * (the constant value itself is replicated with file distribution).
+ */
+ public List<RankingConstant> readRankingConstants() {
+ try {
+ List<RankingConstant> constants = new ArrayList<>();
+ for (ApplicationFile constantFile : application.getFile(arguments.rankingConstantsPath()).listFiles()) {
+ String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
+ constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
+ }
+ return constants;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Adds this constant to the application package as a file,
+ * such that it can be distributed using file distribution.
+ *
+ * @return the path to the stored constant, relative to the application package root
+ */
+ public Path writeConstant(String name, Tensor constant) {
+ Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
+
+ // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
+ Path constantPath = constantsPath.append(name + ".tbf");
+
+ // Remember the constant in a file we replicate in ZooKeeper
+ application.getFile(arguments.rankingConstantsPath().append(name + ".constant"))
+ .writeFile(new StringReader(name + ":" + constant.type() + ":" + constantPath));
+
+ // Write content explicitly as a file on the file system as this is distributed using file distribution
+ createIfNeeded(constantsPath);
+ IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
+ return constantPath;
+ }
+
+ private void createIfNeeded(Path path) {
+ File dir = application.getFileReference(path);
+ if ( ! dir.exists()) {
+ if (!dir.mkdirs())
+ throw new IllegalStateException("Could not create " + dir);
+ }
+ }
- private String stripQuotes(String s) {
- if ( ! isQuoteSign(s.codePointAt(0))) return s;
- if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
- throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote");
- return s.substring(1, s.length()-1);
}
- private boolean isQuoteSign(int c) {
- return c == '\'' || c == '"';
+ /** Encapsulates the 1, 2 or 3 arguments to a tensorflow feature */
+ private static class FeatureArguments {
+
+ private final Path modelPath;
+
+ /** Optional arguments */
+ private final Optional<String> signature, output;
+
+ public FeatureArguments(Arguments arguments) {
+ if (arguments.isEmpty())
+ throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
+ "the tensorflow model directory under [application]/models");
+ if (arguments.expressions().size() > 3)
+ throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments");
+
+ modelPath = Path.fromString(asString(arguments.expressions().get(0)));
+ signature = optionalArgument(1, arguments);
+ output = optionalArgument(2, arguments);
+ }
+
+ /** Returns relative path to this model below the "models/" dir in the application package */
+ public Path modelPath() { return modelPath; }
+ public Optional<String> signature() { return signature; }
+ public Optional<String> output() { return output; }
+
+ public Path rankingConstantsPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
+ }
+
+ public Path expressionPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
+ .append(modelPath).append("expressions").append(expressionFileName());
+ }
+
+ private String expressionFileName() {
+ StringBuilder fileName = new StringBuilder();
+ signature.ifPresent(s -> fileName.append(s).append("."));
+ output.ifPresent(s -> fileName.append(s).append("."));
+ if (fileName.length() == 0) // single signature and output
+ fileName.append("single.");
+ fileName.append("expression");
+ return fileName.toString();
+ }
+
+ private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
+ if (argumentIndex >= arguments.expressions().size())
+ return Optional.empty();
+ return Optional.of(asString(arguments.expressions().get(argumentIndex)));
+ }
+
+ private String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+ throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node);
+ return stripQuotes(((ConstantNode)node).sourceString());
+ }
+
+ private String stripQuotes(String s) {
+ if ( ! isQuoteSign(s.codePointAt(0))) return s;
+ if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
+ throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote");
+ return s.substring(1, s.length()-1);
+ }
+
+ private boolean isQuoteSign(int c) {
+ return c == '\'' || c == '"';
+ }
+
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/Admin.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/Admin.java
index 59b7388f5bb..071b3090f99 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/admin/Admin.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/Admin.java
@@ -73,9 +73,7 @@ public class Admin extends AbstractConfigProducer implements Serializable {
this.fileDistribution = fileDistributionConfigProducer;
}
- public Configserver getConfigserver() {
- return defaultConfigserver;
- }
+ public Configserver getConfigserver() { return defaultConfigserver; }
/** Returns the configured monitoring endpoint, or null if not configured */
public Monitoring getMonitoring() {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java
index fd062dc4ea4..58fc76f1508 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/search/AbstractSearchCluster.java
@@ -49,16 +49,7 @@ public abstract class AbstractSearchCluster extends AbstractConfigProducer
public static final IndexingMode REALTIME = new IndexingMode("REALTIME");
public static final IndexingMode STREAMING = new IndexingMode("STREAMING");
- public static IndexingMode createIndexingMode(String ixm) {
- if ("REALTIME".equalsIgnoreCase(ixm)) {
- return REALTIME;
- } else if ("STREAMING".equalsIgnoreCase(ixm)) {
- return STREAMING;
- }
- return null;
- }
-
- private String name;
+ private final String name;
private IndexingMode(String name) {
this.name = name;
@@ -72,6 +63,7 @@ public abstract class AbstractSearchCluster extends AbstractConfigProducer
}
public static final class SearchDefinitionSpec {
+
private final SearchDefinition searchDefinition;
private final UserConfigRepo userConfigRepo;
diff --git a/config-model/src/test/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py b/config-model/src/test/integration/tensorflow/models/mnist_softmax/mnist_sftmax_with_saving.py
index a1861a1c981..a1861a1c981 100644
--- a/config-model/src/test/integration/tensorflow/mnist_softmax/mnist_sftmax_with_saving.py
+++ b/config-model/src/test/integration/tensorflow/models/mnist_softmax/mnist_sftmax_with_saving.py
diff --git a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/saved_model.pbtxt
index 8100dfd594d..8100dfd594d 100644
--- a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/saved_model.pbtxt
+++ b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/saved_model.pbtxt
diff --git a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001 b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/variables/variables.data-00000-of-00001
index 8474aa0a04c..8474aa0a04c 100644
--- a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.data-00000-of-00001
+++ b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.index b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/variables/variables.index
index cfcdac20409..cfcdac20409 100644
--- a/config-model/src/test/integration/tensorflow/mnist_softmax/saved/variables/variables.index
+++ b/config-model/src/test/integration/tensorflow/models/mnist_softmax/saved/variables/variables.index
Binary files differ
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index ff53fdafacf..7c749608e1f 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -1,5 +1,7 @@
package com.yahoo.searchdefinition.processing;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.Search;
@@ -22,7 +24,11 @@ class RankProfileSearchFixture {
private Search search;
RankProfileSearchFixture(String rankProfiles) throws ParseException {
- SearchBuilder builder = new SearchBuilder(rankProfileRegistry);
+ this(MockApplicationPackage.createEmpty(), rankProfiles);
+ }
+
+ RankProfileSearchFixture(ApplicationPackage applicationpackage, String rankProfiles) throws ParseException {
+ SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry);
String sdContent = "search test {\n" +
" document test {\n" +
" }\n" +
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 31f7511155b..0354173f365 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -1,24 +1,36 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.processing;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.yolean.Exceptions;
import org.junit.After;
import org.junit.Test;
+import java.io.BufferedInputStream;
import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
import java.io.IOException;
+import java.io.InputStream;
+import java.io.Reader;
import java.io.UncheckedIOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
import java.util.Optional;
+import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -27,47 +39,52 @@ import static org.junit.Assert.fail;
*/
public class RankingExpressionWithTensorFlowTestCase {
- // The "../" is to escape the "models/" element prepended to the path
- private final String modelDirectory = "../src/test/integration/tensorflow/mnist_softmax/saved";
+ private final Path applicationDir = Path.fromString("src/test/integration/tensorflow/");
private final String vespaExpression = "join(rename(reduce(join(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(Variable_1), d0, d1), f(a,b)(a + b))";
@After
public void removeGeneratedConstantTensorFiles() {
- IOUtils.recursiveDeleteDir(new File(modelDirectory.substring(3), "converted_variables"));
+ IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
}
@Test
public void testMinimalTensorFlowReference() throws ParseException {
+ StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = new RankProfileSearchFixture(
+ application,
" rank-profile my_profile {\n" +
" first-phase {\n" +
- " expression: tensorflow('" + modelDirectory + "')" +
+ " expression: tensorflow('mnist_softmax/saved')" +
" }\n" +
" }");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertConstant(10, "Variable_1", search);
- assertConstant(7840, "Variable", search);
+ assertConstant("Variable_1", search, Optional.of(10L));
+ assertConstant("Variable", search, Optional.of(7840L));
}
@Test
public void testNestedTensorFlowReference() throws ParseException {
+ StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = new RankProfileSearchFixture(
+ application,
" rank-profile my_profile {\n" +
" first-phase {\n" +
- " expression: 5 + sum(tensorflow('" + modelDirectory + "'))" +
+ " expression: 5 + sum(tensorflow('mnist_softmax/saved'))" +
" }\n" +
" }");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
- assertConstant(10, "Variable_1", search);
- assertConstant(7840, "Variable", search);
+ assertConstant("Variable_1", search, Optional.of(10L));
+ assertConstant("Variable", search, Optional.of(7840L));
}
@Test
public void testTensorFlowReferenceSpecifyingSignature() throws ParseException {
+ StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = new RankProfileSearchFixture(
+ application,
" rank-profile my_profile {\n" +
" first-phase {\n" +
- " expression: tensorflow('" + modelDirectory + "', 'serving_default')" +
+ " expression: tensorflow('mnist_softmax/saved', 'serving_default')" +
" }\n" +
" }");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
@@ -75,10 +92,12 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testTensorFlowReferenceSpecifyingSignatureAndOutput() throws ParseException {
+ StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = new RankProfileSearchFixture(
+ application,
" rank-profile my_profile {\n" +
" first-phase {\n" +
- " expression: tensorflow('" + modelDirectory + "', 'serving_default', 'y')" +
+ " expression: tensorflow('mnist_softmax/saved', 'serving_default', 'y')" +
" }\n" +
" }");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
@@ -87,18 +106,21 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testTensorFlowReferenceSpecifyingNonExistingSignature() throws ParseException {
try {
+ StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = new RankProfileSearchFixture(
+ application,
" rank-profile my_profile {\n" +
" first-phase {\n" +
- " expression: tensorflow('" + modelDirectory + "', 'serving_defaultz')" +
+ " expression: tensorflow('mnist_softmax/saved', 'serving_defaultz')" +
" }\n" +
" }");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
}
catch (IllegalArgumentException expected) {
- assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from tensorflow('" +
- modelDirectory + "','serving_defaultz'): Model does not have the specified signature 'serving_defaultz'",
+ assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
+ "tensorflow('mnist_softmax/saved','serving_defaultz'): " +
+ "Model does not have the specified signature 'serving_defaultz'",
Exceptions.toMessageString(expected));
}
}
@@ -106,36 +128,83 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testTensorFlowReferenceSpecifyingNonExistingOutput() throws ParseException {
try {
+ StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = new RankProfileSearchFixture(
+ application,
" rank-profile my_profile {\n" +
" first-phase {\n" +
- " expression: tensorflow('" + modelDirectory + "', 'serving_default', 'x')" +
+ " expression: tensorflow('mnist_softmax/saved', 'serving_default', 'x')" +
" }\n" +
" }");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
}
catch (IllegalArgumentException expected) {
- assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from tensorflow('" +
- modelDirectory + "','serving_default','x'): Model does not have the specified output 'x'",
+ assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
+ "tensorflow('mnist_softmax/saved','serving_default','x'): " +
+ "Model does not have the specified output 'x'",
Exceptions.toMessageString(expected));
}
}
- private void assertConstant(int expectedSize, String name, RankProfileSearchFixture search) {
+ @Test
+ public void testImportingFromStoredExpressions() throws ParseException, IOException {
+ StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
+ RankProfileSearchFixture search = new RankProfileSearchFixture(
+ application,
+ " rank-profile my_profile {\n" +
+ " first-phase {\n" +
+ " expression: tensorflow('mnist_softmax/saved', 'serving_default')" +
+ " }\n" +
+ " }");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ assertConstant("Variable_1", search, Optional.of(10L));
+ assertConstant("Variable", search, Optional.of(7840L));
+
+ // At this point the expression is stored - copy application to another location which do not have a models dir
+ Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
+ try {
+ storedApplicationDirectory.toFile().mkdirs();
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
+ RankProfileSearchFixture searchFromStored = new RankProfileSearchFixture(
+ storedApplication,
+ " rank-profile my_profile {\n" +
+ " first-phase {\n" +
+ " expression: tensorflow('mnist_softmax/saved', 'serving_default')" +
+ " }\n" +
+ " }");
+ searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ // Verify that the constants exists, but don't verify the content as we are not
+ // simulating file distribution in this test
+ assertConstant("Variable_1", searchFromStored, Optional.empty());
+ assertConstant("Variable", searchFromStored, Optional.empty());
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
+ }
+
+ }
+
+ /**
+ * Verifies that the constant with the given name exists, and - only if an expected size is given -
+ * that the content of the constant is available and has the expected size.
+ */
+ private void assertConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) {
try {
- TensorValue constant = (TensorValue)search.rankProfile("my_profile").getConstants().get(name); // Old way. TODO: Remove
- if (constant == null) { // New way
- File constantFile = new File(modelDirectory.substring(3) + "/converted_variables", name + ".tbf");
- RankingConstant rankingConstant = search.search().getRankingConstants().get(name);
- assertEquals(name, rankingConstant.getName());
- assertEquals(constantFile.getAbsolutePath(), rankingConstant.getFileName());
- assertTrue("Constant file has been written", constantFile.exists());
- Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantFile)));
- assertEquals(expectedSize, deserializedConstant.size());
- } else { // Old way. TODO: Remove
- assertNotNull(name + " is imported", constant);
- assertEquals(expectedSize, constant.asTensor().size());
+ Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax/saved/constants").append(name + ".tbf");
+ RankingConstant rankingConstant = search.search().getRankingConstants().get(name);
+ assertEquals(name, rankingConstant.getName());
+ assertEquals(constantApplicationPackagePath.toString(), rankingConstant.getFileName());
+
+ if (expectedSize.isPresent()) {
+ Path constantPath = applicationDir.append(constantApplicationPackagePath);
+ assertTrue("Constant file '" + constantPath + "' has been written",
+ constantPath.toFile().exists());
+ Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(),
+ GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile())));
+ assertEquals(expectedSize.get().longValue(), deserializedConstant.size());
}
}
catch (IOException e) {
@@ -143,4 +212,138 @@ public class RankingExpressionWithTensorFlowTestCase {
}
}
+ private static class StoringApplicationPackage extends MockApplicationPackage {
+
+ private final File root;
+
+ StoringApplicationPackage(Path applicationPackageWritableRoot) {
+ this(applicationPackageWritableRoot.toFile());
+ }
+
+ StoringApplicationPackage(File applicationPackageWritableRoot) {
+ super(null, null, Collections.emptyList(), null,
+ null, null, false);
+ this.root = applicationPackageWritableRoot;
+ }
+
+ @Override
+ public File getFileReference(Path path) {
+ return Path.fromString(root.toString()).append(path).toFile();
+ }
+
+ @Override
+ public ApplicationFile getFile(Path file) {
+ return new StoringApplicationPackageFile(file, Path.fromString(root.toString()));
+ }
+
+ }
+
+ private static class StoringApplicationPackageFile extends ApplicationFile {
+
+ /** The path to the application package root */
+ private final Path root;
+
+ /** The File pointing to the actual file represented by this */
+ private final File file;
+
+ StoringApplicationPackageFile(Path filePath, Path applicationPackagePath) {
+ super(filePath);
+ this.root = applicationPackagePath;
+ file = applicationPackagePath.append(filePath).toFile();
+ }
+
+ @Override
+ public boolean isDirectory() {
+ return file.isDirectory();
+ }
+
+ @Override
+ public boolean exists() {
+ return file.exists();
+ }
+
+ @Override
+ public Reader createReader() throws FileNotFoundException {
+ try {
+ if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist");
+ return IOUtils.createReader(file, "UTF-8");
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ public InputStream createInputStream() throws FileNotFoundException {
+ try {
+ if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist");
+ return new BufferedInputStream(new FileInputStream(file));
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ public ApplicationFile createDirectory() {
+ file.mkdirs();
+ return this;
+ }
+
+ @Override
+ public ApplicationFile writeFile(Reader input) {
+ try {
+ IOUtils.writeFile(file, IOUtils.readAll(input), false);
+ return this;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ public List<ApplicationFile> listFiles(PathFilter filter) {
+ if ( ! isDirectory()) return Collections.emptyList();
+ return Arrays.stream(file.listFiles()).filter(f -> filter.accept(Path.fromString(f.toString())))
+ .map(f -> new StoringApplicationPackageFile(asApplicationRelativePath(f),
+ root))
+ .collect(Collectors.toList());
+ }
+
+ @Override
+ public ApplicationFile delete() {
+ file.delete();
+ return this;
+ }
+
+ @Override
+ public MetaData getMetaData() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int compareTo(ApplicationFile other) {
+ return this.getPath().getName().compareTo((other).getPath().getName());
+ }
+
+ /** Strips the application package root path prefix from the path of the given file */
+ private Path asApplicationRelativePath(File file) {
+ Path path = Path.fromString(file.toString());
+
+ Iterator<String> pathIterator = path.iterator();
+ // Skip the path elements this shares with the root
+ for (Iterator<String> rootIterator = root.iterator(); rootIterator.hasNext(); ) {
+ String rootElement = rootIterator.next();
+ String pathElement = pathIterator.next();
+ if ( ! rootElement.equals(pathElement)) throw new RuntimeException("Assumption broken");
+ }
+ // Build a path from the remaining
+ Path relative = Path.fromString("");
+ while (pathIterator.hasNext())
+ relative = relative.append(pathIterator.next());
+ return relative;
+ }
+
+ }
+
}