summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-21 09:38:31 +0200
committerJon Bratseth <bratseth@oath.com>2018-08-21 09:38:31 +0200
commit6110db7eb9e5ca01e6cc9426e68e74232004ae63 (patch)
treee8eb26a98ab9d283b184db9a8246a5edc4b1ff23
parent51318703c054e71ce575d5edde90c6f84c3fedfd (diff)
Revert "Merge pull request #6635 from vespa-engine/bratseth/generate-rank-profiles-for-all-models-part-2-4"
This reverts commit 3f91e18528b4982398332a30728eed8f7d2b580c, reversing changes made to 8e3ba08f1d3b79e573864726c6c03e58862feee6.
-rw-r--r--config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java3
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java6
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java394
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java8
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java8
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java5
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java28
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java22
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java43
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java9
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java4
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackage.java13
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKLiveApp.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/path/Path.java1
15 files changed, 350 insertions, 210 deletions
diff --git a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java
index 36009682022..7ca9bcf48f3 100644
--- a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java
+++ b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationPackage.java
@@ -45,6 +45,7 @@ import java.security.MessageDigest;
import java.util.*;
import java.util.jar.JarFile;
import java.util.logging.Logger;
+import java.util.stream.Collectors;
import static com.yahoo.text.Lowercase.toLowerCase;
@@ -164,7 +165,7 @@ public class FilesApplicationPackage implements ApplicationPackage {
return metaData;
}
- private List<NamedReader> getFiles(Path relativePath,String namePrefix,String suffix,boolean recurse) {
+ private List<NamedReader> getFiles(Path relativePath, String namePrefix, String suffix, boolean recurse) {
try {
List<NamedReader> readers=new ArrayList<>();
File dir = new File(appDir, relativePath.getRelative());
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java
index dd54fe11c39..a71a0878d3d 100644
--- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java
+++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java
@@ -88,12 +88,14 @@ public interface ApplicationPackage {
/**
* Contents of services.xml. Caller must close reader after use.
+ *
* @return a Reader, or null if no services.xml/vespa-services.xml present
*/
Reader getServices();
/**
* Contents of hosts.xml. Caller must close reader after use.
+ *
* @return a Reader, or null if no hosts.xml/vespa-hosts.xml present
*/
Reader getHosts();
@@ -160,8 +162,8 @@ public interface ApplicationPackage {
* Gets a file from the root of the application package
*
*
- * @param relativePath The relative path of the file within this application package.
- * @return reader for file
+ * @param relativePath the relative path of the file within this application package.
+ * @return information abut the file
* @throws IllegalArgumentException if the given path does not exist
*/
ApplicationFile getFile(Path relativePath);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java
index 7b4d70d85b1..6311751bb88 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java
@@ -52,11 +52,11 @@ public class RankProfileRegistry {
}
private void checkForDuplicateRankProfile(RankProfile rankProfile) {
- final String rankProfileName = rankProfile.getName();
+ String rankProfileName = rankProfile.getName();
RankProfile existingRangProfileWithSameName = rankProfiles.get(rankProfile.getSearch()).get(rankProfileName);
if (existingRangProfileWithSameName == null) return;
- if (!overridableRankProfileNames.contains(rankProfileName)) {
+ if ( ! overridableRankProfileNames.contains(rankProfileName)) {
throw new IllegalArgumentException("Cannot add rank profile '" + rankProfileName + "' in search definition '"
+ rankProfile.getSearch().getName() + "', since it already exists");
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
index 867740c7912..d85d0983509 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
@@ -1,11 +1,11 @@
package com.yahoo.searchdefinition.expressiontransforms;
-import com.google.common.base.Joiner;
import com.yahoo.collections.Pair;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.application.provider.FilesApplicationPackage;
import com.yahoo.io.IOUtils;
+import com.yahoo.io.reader.NamedReader;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.FeatureNames;
@@ -39,11 +39,14 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat;
import java.io.BufferedReader;
import java.io.File;
+import java.io.FileNotFoundException;
import java.io.IOException;
+import java.io.Reader;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -62,69 +65,155 @@ import java.util.stream.Collectors;
*/
public class ConvertedModel {
- private final ExpressionNode convertedExpression;
+ private final String modelName;
+ private final Path modelPath;
- public ConvertedModel(FeatureArguments arguments,
+ /**
+ * The ranking expressions of this, indexed by their name. which is a 1-3 part string separated by dots
+ * where the first part is always the model name, the second the signature or (if none)
+ * expression name (if more than one), and the third is the output name (if any).
+ */
+ private final Map<String, RankingExpression> expressions;
+
+ public ConvertedModel(Path modelPath,
RankProfileTransformContext context,
ModelImporter modelImporter,
- Map<Path, ImportedModel> importedModels) {
- ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
+ FeatureArguments arguments) { // TODO: Remove
+ this.modelPath = modelPath;
+ this.modelName = toModelName(modelPath);
+ ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), modelPath);
if ( ! store.hasStoredModel()) // not converted yet - access from models/ directory
- convertedExpression = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter, importedModels);
+ expressions = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter, arguments);
else
- convertedExpression = transformFromStoredModel(store, context.rankProfile());
+ expressions = transformFromStoredModel(store, context.rankProfile());
}
- private ExpressionNode importModel(ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles,
- ModelImporter modelImporter,
- Map<Path, ImportedModel> importedModels) {
- ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
- k -> modelImporter.importModel(store.arguments().modelName(),
- store.modelDir()));
- return transformFromImportedModel(model, store, profile, queryProfiles);
+ private Map<String, RankingExpression> importModel(ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ ModelImporter modelImporter,
+ FeatureArguments arguments) {
+ ImportedModel model = modelImporter.importModel(store.modelFiles.modelName(), store.modelDir());
+ return transformFromImportedModel(model, store, profile, queryProfiles, arguments);
+ }
+
+ /** Returns the expression matching the given arguments */
+ public ExpressionNode expression(FeatureArguments arguments) {
+ if (expressions.isEmpty())
+ throw new IllegalArgumentException("No expressions available in " + this);
+
+ RankingExpression expression = expressions.get(arguments.toName());
+ if (expression != null) return expression.getRoot();
+
+ if ( ! arguments.signature().isPresent()) {
+ if (expressions.size() > 1)
+ throw new IllegalArgumentException("Multiple candidate expressions " + missingExpressionMessageSuffix());
+ return expressions.values().iterator().next().getRoot();
+ }
+
+ if ( ! arguments.output().isPresent()) {
+ List<Map.Entry<String, RankingExpression>> entriesWithTheRightPrefix =
+ expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(modelName + "." + arguments.signature().get() + ".")).collect(Collectors.toList());
+ if (entriesWithTheRightPrefix.size() < 1)
+ throw new IllegalArgumentException("No expressions named '" + arguments.signature().get() +
+ missingExpressionMessageSuffix());
+ if (entriesWithTheRightPrefix.size() > 1)
+ throw new IllegalArgumentException("Multiple candidate expression named '" + arguments.signature().get() +
+ missingExpressionMessageSuffix());
+ return entriesWithTheRightPrefix.get(0).getValue().getRoot();
+ }
+
+ throw new IllegalArgumentException("No expression '" + arguments.toName() + missingExpressionMessageSuffix());
}
- public ExpressionNode expression() { return convertedExpression; }
+ private String missingExpressionMessageSuffix() {
+ return "' in model '" + this.modelPath + "'. " +
+ "Available expressions: " + expressions.keySet().stream().collect(Collectors.joining(", "));
+ }
- private ExpressionNode transformFromImportedModel(ImportedModel model,
- ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles) {
+ private Map<String, RankingExpression> transformFromImportedModel(ImportedModel model,
+ ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ FeatureArguments arguments) {
// Add constants
Set<String> constantsReplacedByMacros = new HashSet<>();
model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
constantsReplacedByMacros, k, v));
- // Find the specified expression
- ImportedModel.Signature signature = chooseSignature(model, store.arguments().signature());
- String output = chooseOutput(signature, store.arguments().output());
- if (signature.skippedOutputs().containsKey(output)) {
- String message = "Could not import model output '" + output + "'";
- if (!signature.skippedOutputs().get(output).isEmpty()) {
- message += ": " + signature.skippedOutputs().get(output);
+ // Add macros
+ addGeneratedMacros(model, profile);
+
+ // Add expressions
+ Map<String, RankingExpression> expressions = new HashMap<>();
+ for (Map.Entry<String, ImportedModel.Signature> signatureEntry : model.signatures().entrySet()) {
+ if ( ! matches(signatureEntry.getValue(), arguments, Optional.empty())) continue;
+
+ for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) {
+ if ( ! matches(signatureEntry.getValue(), arguments, Optional.of(outputEntry.getKey()))) continue;
+ addExpression(model.expressions().get(outputEntry.getValue()),
+ modelName + "." + signatureEntry.getKey() + "." + outputEntry.getKey(),
+ constantsReplacedByMacros,
+ model, store, profile, queryProfiles,
+ expressions);
}
- if (!signature.importWarnings().isEmpty()) {
- message += ": " + String.join(", ", signature.importWarnings());
+ if (signatureEntry.getValue().outputs().isEmpty()) { // fallback: Signature without outputs
+ addExpression(model.expressions().get(signatureEntry.getKey()),
+ modelName + "." + signatureEntry.getKey(),
+ constantsReplacedByMacros,
+ model, store, profile, queryProfiles,
+ expressions);
}
- throw new IllegalArgumentException(message);
}
+ if (model.signatures().isEmpty()) { // fallback: Model without signatures
+ if (model.expressions().size() == 1) { // Use just model name
+ addExpression(model.expressions().values().iterator().next(),
+ modelName,
+ constantsReplacedByMacros,
+ model, store, profile, queryProfiles,
+ expressions);
+ }
+ else {
+ for (Map.Entry<String, RankingExpression> expressionEntry : model.expressions().entrySet()) {
+ addExpression(expressionEntry.getValue(),
+ modelName + "." + expressionEntry.getKey(),
+ constantsReplacedByMacros,
+ model, store, profile, queryProfiles,
+ expressions);
+ }
+ }
+ }
+
+ // Transform and save macro - must come after reading expressions due to optimization transforms
+ model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v));
- RankingExpression expression = model.expressions().get(output);
+ return expressions;
+ }
+
+ private boolean matches(ImportedModel.Signature signature, FeatureArguments arguments, Optional<String> output) {
+ if ( ! modelName.equals(arguments.modelName)) return false;
+ if ( arguments.signature.isPresent() && ! signature.name().equals(arguments.signature().get())) return false;
+ if (output.isPresent() && arguments.output().isPresent() && ! output.get().matches(arguments.output().get())) return false;
+ return true;
+ }
+
+ private void addExpression(RankingExpression expression,
+ String expressionName,
+ Set<String> constantsReplacedByMacros,
+ ImportedModel model,
+ ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ Map<String, RankingExpression> expressions) {
expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
verifyRequiredMacros(expression, model, profile, queryProfiles);
- addGeneratedMacros(model, profile);
reduceBatchDimensions(expression, model, profile, queryProfiles);
-
- model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v));
-
- store.writeConverted(expression);
- return expression.getRoot();
+ store.writeExpression(expressionName, expression);
+ expressions.put(expressionName, expression);
}
- ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
+ private Map<String, RankingExpression> transformFromStoredModel(ModelStore store, RankProfile profile) {
for (Pair<String, Tensor> constant : store.readSmallConstants())
profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
@@ -137,60 +226,7 @@ public class ConvertedModel {
addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond());
}
- return store.readConverted().getRoot();
- }
-
- /**
- * Returns the specified, existing signature, or the only signature if none is specified.
- * Throws IllegalArgumentException in all other cases.
- */
- private ImportedModel.Signature chooseSignature(ImportedModel importResult, Optional<String> signatureName) {
- if ( ! signatureName.isPresent()) {
- if (importResult.signatures().size() == 0)
- throw new IllegalArgumentException("No signatures are available");
- if (importResult.signatures().size() > 1)
- throw new IllegalArgumentException("Model has multiple signatures (" +
- Joiner.on(", ").join(importResult.signatures().keySet()) +
- "), one must be specified " +
- "as a second argument to tensorflow()");
- return importResult.signatures().values().stream().findFirst().get();
- }
- else {
- ImportedModel.Signature signature = importResult.signatures().get(signatureName.get());
- if (signature == null)
- throw new IllegalArgumentException("Model does not have the specified signature '" +
- signatureName.get() + "'");
- return signature;
- }
- }
-
- /**
- * Returns the specified, existing output expression, or the only output expression if no output name is specified.
- * Throws IllegalArgumentException in all other cases.
- */
- private String chooseOutput(ImportedModel.Signature signature, Optional<String> outputName) {
- if ( ! outputName.isPresent()) {
- if (signature.outputs().size() == 0)
- throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature));
- if (signature.outputs().size() > 1)
- throw new IllegalArgumentException(signature + " has multiple outputs (" +
- Joiner.on(", ").join(signature.outputs().keySet()) +
- "), one must be specified " +
- "as a third argument to tensorflow()");
- return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get());
- }
- else {
- String output = signature.outputs().get(outputName.get());
- if (output == null) {
- if (signature.skippedOutputs().containsKey(outputName.get()))
- throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
- signature.skippedOutputs().get(outputName.get()));
- else
- throw new IllegalArgumentException("Model does not have the specified output '" +
- outputName.get() + "'");
- }
- return output;
- }
+ return store.readExpressions();
}
private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
@@ -227,22 +263,15 @@ public class ConvertedModel {
}
private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
- if (profile.getMacros().containsKey(macroName)) {
+ if (profile.getMacros().containsKey(macroName))
throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists.");
- }
+
profile.addMacro(macroName, false); // todo: inline if only used once
RankProfile.Macro macro = profile.getMacros().get(macroName);
macro.setRankingExpression(expression);
macro.setTextualExpression(expression.getRoot().toString());
}
- private String skippedOutputsDescription(ImportedModel.Signature signature) {
- if (signature.skippedOutputs().isEmpty()) return "";
- StringBuilder b = new StringBuilder(": ");
- signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v));
- return b.toString();
- }
-
/**
* Verify that the macros referred in the given expression exists in the given rank profile,
* and return tensors of the types specified in requiredMacros.
@@ -375,8 +404,8 @@ public class ConvertedModel {
/**
* If batch dimensions have been reduced away above, bring them back here
* for any following computation of the tensor.
- * Todo: determine when this is not necessary!
*/
+ // TODO: determine when this is not necessary!
private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
if (after.equals(before)) {
return node;
@@ -452,24 +481,29 @@ public class ConvertedModel {
return new TensorValue(tensor);
}
+ private static String toModelName(Path modelPath) {
+ return modelPath.toString().replace("/", "_");
+ }
+
+ @Override
+ public String toString() { return "model '" + modelName + "'"; }
+
/**
* Provides read/write access to the correct directories of the application package given by the feature arguments
*/
static class ModelStore {
private final ApplicationPackage application;
- private final FeatureArguments arguments;
+ private final ModelFiles modelFiles;
- ModelStore(ApplicationPackage application, FeatureArguments arguments) {
+ ModelStore(ApplicationPackage application, Path modelPath) {
this.application = application;
- this.arguments = arguments;
+ this.modelFiles = new ModelFiles(modelPath);
}
- public FeatureArguments arguments() { return arguments; }
-
public boolean hasStoredModel() {
try {
- return application.getFile(arguments.expressionPath()).exists();
+ return application.getFileReference(modelFiles.storedModelPath()).exists();
}
catch (UnsupportedOperationException e) {
return false;
@@ -480,40 +514,49 @@ public class ConvertedModel {
* Returns the directory which contains the source model to use for these arguments
*/
public File modelDir() {
- return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
+ return application.getFileReference(ApplicationPackage.MODELS_DIR.append(modelFiles.modelPath()));
}
/**
* Adds this expression to the application package, such that it can be read later.
+ *
+ * @param name the name of this ranking expression - may have 1-3 parts separated by dot where the first part
+ * is always the model name
*/
- void writeConverted(RankingExpression expression) {
- application.getFile(arguments.expressionPath())
+ void writeExpression(String name, RankingExpression expression) {
+ application.getFile(modelFiles.expressionPath(name))
.writeFile(new StringReader(expression.getRoot().toString()));
}
- /** Reads the previously stored ranking expression for these arguments */
- RankingExpression readConverted() {
- try {
- return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
- }
- catch (IOException e) {
- throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+ Map<String, RankingExpression> readExpressions() {
+ Map<String, RankingExpression> expressions = new HashMap<>();
+ ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath());
+ if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap();
+ for (ApplicationFile expressionFile : expressionPath.listFiles()) {
+ try {
+ String name = expressionFile.getPath().getName();
+ expressions.put(name, new RankingExpression(name, expressionFile.createReader()));
+ }
+ catch (FileNotFoundException e) {
+ throw new IllegalStateException("Expression file removed while reading: " + expressionFile, e);
+ }
+ catch (ParseException e) {
+ throw new IllegalStateException("Invalid stored expression in " + expressionFile, e);
+ }
}
+ return expressions;
}
/** Adds this macro expression to the application package to it can be read later. */
void writeMacro(String name, RankingExpression expression) {
- application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
+ application.getFile(modelFiles.macrosPath()).appendFile(name + "\t" +
expression.getRoot().toString() + "\n");
}
/** Reads the previously stored macro expressions for these arguments */
List<Pair<String, RankingExpression>> readMacros() {
try {
- ApplicationFile file = application.getFile(arguments.macrosPath());
+ ApplicationFile file = application.getFile(modelFiles.macrosPath());
if (!file.exists()) return Collections.emptyList();
List<Pair<String, RankingExpression>> macros = new ArrayList<>();
@@ -527,7 +570,7 @@ public class ConvertedModel {
macros.add(new Pair<>(name, expression));
}
catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+ throw new IllegalStateException("Could not parse " + name, e);
}
}
return macros;
@@ -544,7 +587,7 @@ public class ConvertedModel {
List<RankingConstant> readLargeConstants() {
try {
List<RankingConstant> constants = new ArrayList<>();
- for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
+ for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsPath()).listFiles()) {
String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
}
@@ -562,13 +605,13 @@ public class ConvertedModel {
* @return the path to the stored constant, relative to the application package root
*/
Path writeLargeConstant(String name, Tensor constant) {
- Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
+ Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(modelFiles.modelPath()).append("constants");
// "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
Path constantPath = constantsPath.append(name + ".tbf");
// Remember the constant in a file we replicate in ZooKeeper
- application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
+ application.getFile(modelFiles.largeConstantsPath().append(name + ".constant"))
.writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
// Write content explicitly as a file on the file system as this is distributed using file distribution
@@ -579,7 +622,7 @@ public class ConvertedModel {
private List<Pair<String, Tensor>> readSmallConstants() {
try {
- ApplicationFile file = application.getFile(arguments.smallConstantsPath());
+ ApplicationFile file = application.getFile(modelFiles.smallConstantsPath());
if (!file.exists()) return Collections.emptyList();
List<Pair<String, Tensor>> constants = new ArrayList<>();
@@ -604,7 +647,7 @@ public class ConvertedModel {
*/
public void writeSmallConstant(String name, Tensor constant) {
// Secret file format for remembering constants:
- application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
+ application.getFile(modelFiles.smallConstantsPath()).appendFile(name + "\t" +
constant.type().toString() + "\t" +
constant.toString() + "\n");
}
@@ -628,26 +671,24 @@ public class ConvertedModel {
}
}
+ private void close(Reader reader) {
+ try {
+ if (reader != null)
+ reader.close();
+ }
+ catch (IOException e) {
+ // ignore
+ }
+ }
+
}
- /** Encapsulates the arguments to the import feature */
- static class FeatureArguments {
+ static class ModelFiles {
Path modelPath;
- /** Optional arguments */
- Optional<String> signature, output;
-
- public FeatureArguments(Arguments arguments) {
- this(Path.fromString(asString(arguments.expressions().get(0))),
- optionalArgument(1, arguments),
- optionalArgument(2, arguments));
- }
-
- public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) {
+ public ModelFiles(Path modelPath) {
this.modelPath = modelPath;
- this.signature = signature;
- this.output = output;
}
/** Returns modelPath with slashes replaced by underscores */
@@ -655,37 +696,66 @@ public class ConvertedModel {
/** Returns relative path to this model below the "models/" dir in the application package */
public Path modelPath() { return modelPath; }
- public Optional<String> signature() { return signature; }
- public Optional<String> output() { return output; }
- /** Path to the small constants file */
+ public Path storedModelPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath());
+ }
+
+ public Path expressionPath(String name) {
+ return storedModelPath().append("expressions").append(name);
+ }
+
+ public Path expressionsPath() {
+ return storedModelPath().append("expressions");
+ }
+
public Path smallConstantsPath() {
- return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
+ return storedModelPath().append("constants.txt");
}
/** Path to the large (ranking) constants directory */
public Path largeConstantsPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
+ return storedModelPath().append("constants");
}
/** Path to the macros file */
public Path macrosPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt");
+ return storedModelPath().append("macros.txt");
}
- public Path expressionPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
- .append(modelPath).append("expressions").append(expressionFileName());
+ }
+
+ /** Encapsulates the arguments of a specific model output */
+ static class FeatureArguments {
+
+ private final String modelName;
+ private final Path modelPath;
+
+ /** Optional arguments */
+ private final Optional<String> signature, output;
+
+ public FeatureArguments(Arguments arguments) {
+ this(Path.fromString(asString(arguments.expressions().get(0))),
+ optionalArgument(1, arguments),
+ optionalArgument(2, arguments));
}
- private String expressionFileName() {
- StringBuilder fileName = new StringBuilder();
- signature.ifPresent(s -> fileName.append(s).append("."));
- output.ifPresent(s -> fileName.append(s).append("."));
- if (fileName.length() == 0) // single signature and output
- fileName.append("single.");
- fileName.append("expression");
- return fileName.toString();
+ public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) {
+ this.modelPath = modelPath;
+ this.modelName = toModelName(modelPath);
+ this.signature = signature;
+ this.output = output;
+ }
+
+ public Path modelPath() { return modelPath; }
+
+ public Optional<String> signature() { return signature; }
+ public Optional<String> output() { return output; }
+
+ public String toName() {
+ return modelName +
+ (signature.isPresent() ? "." + signature.get() : "") +
+ (output.isPresent() ? "." + output.get() : "");
}
private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
@@ -694,7 +764,7 @@ public class ConvertedModel {
return Optional.of(asString(arguments.expressions().get(argumentIndex)));
}
- private static String asString(ExpressionNode node) {
+ public static String asString(ExpressionNode node) {
if ( ! (node instanceof ConstantNode))
throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
return stripQuotes(((ConstantNode)node).sourceString());
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index d31ffefde65..97395c1aad3 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -31,7 +31,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
private final OnnxImporter onnxImporter = new OnnxImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
- private final Map<Path, ImportedModel> importedModels = new HashMap<>();
+ private final Map<Path, ConvertedModel> convertedModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -47,9 +47,9 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
if ( ! feature.getName().equals("onnx")) return feature;
try {
- ConvertedModel convertedModel = new ConvertedModel(asFeatureArguments(feature.getArguments()),
- context, onnxImporter, importedModels);
- return convertedModel.expression();
+ Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
+ ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, onnxImporter, new ConvertedModel.FeatureArguments(feature.getArguments())));
+ return convertedModel.expression(asFeatureArguments(feature.getArguments()));
}
catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use Onnx model from " + feature, e);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index d28299b1d30..b3778e2af84 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -28,7 +28,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
- private final Map<Path, ImportedModel> importedModels = new HashMap<>();
+ private final Map<Path, ConvertedModel> convertedModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -44,9 +44,9 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
if ( ! feature.getName().equals("tensorflow")) return feature;
try {
- ConvertedModel convertedModel = new ConvertedModel(asFeatureArguments(feature.getArguments()),
- context, tensorFlowImporter, importedModels);
- return convertedModel.expression();
+ Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
+ ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, tensorFlowImporter, new ConvertedModel.FeatureArguments(feature.getArguments())));
+ return convertedModel.expression(asFeatureArguments(feature.getArguments()));
}
catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
index 4ae223ec3a5..62f43e15849 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
@@ -37,7 +37,8 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr
try {
ConvertedModel.FeatureArguments arguments = asFeatureArguments(feature.getArguments());
- ConvertedModel.ModelStore store = new ConvertedModel.ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
+ ConvertedModel.ModelStore store = new ConvertedModel.ModelStore(context.rankProfile().getSearch().sourceApplication(),
+ arguments.modelPath());
RankingExpression expression = xgboostImporter.parseModel(store.modelDir().toString());
return expression.getRoot();
} catch (IllegalArgumentException | UncheckedIOException e) {
@@ -48,7 +49,7 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr
private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) {
if (arguments.isEmpty())
throw new IllegalArgumentException("An xgboost node must take an argument pointing to " +
- "the xgboost model directory under [application]/models");
+ "the xgboost model directory under [application]/models");
if (arguments.expressions().size() > 1)
throw new IllegalArgumentException("An xgboost feature can have at most 1 argument");
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index 0ce6129ef7f..ab689b88993 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -10,7 +10,9 @@ import com.yahoo.searchdefinition.Search;
import com.yahoo.searchdefinition.SearchBuilder;
import com.yahoo.searchdefinition.parser.ParseException;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import static org.junit.Assert.assertEquals;
@@ -25,6 +27,7 @@ class RankProfileSearchFixture {
private RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
private final QueryProfileRegistry queryProfileRegistry;
private Search search;
+ private Map<String, RankProfile> compiledRankProfiles = new HashMap<>();
RankProfileSearchFixture(String rankProfiles) throws ParseException {
this(MockApplicationPackage.createEmpty(), new QueryProfileRegistry(), rankProfiles);
@@ -54,25 +57,38 @@ class RankProfileSearchFixture {
}
public void assertFirstPhaseExpression(String expExpression, String rankProfile) {
- assertEquals(expExpression, rankProfile(rankProfile).getFirstPhaseRanking().getRoot().toString());
+ assertEquals(expExpression, compiledRankProfile(rankProfile).getFirstPhaseRanking().getRoot().toString());
}
public void assertSecondPhaseExpression(String expExpression, String rankProfile) {
- assertEquals(expExpression, rankProfile(rankProfile).getSecondPhaseRanking().getRoot().toString());
+ assertEquals(expExpression, compiledRankProfile(rankProfile).getSecondPhaseRanking().getRoot().toString());
}
public void assertRankProperty(String expValue, String name, String rankProfile) {
- List<RankProfile.RankProperty> rankPropertyList = rankProfile(rankProfile).getRankPropertyMap().get(name);
+ List<RankProfile.RankProperty> rankPropertyList = compiledRankProfile(rankProfile).getRankPropertyMap().get(name);
assertEquals(1, rankPropertyList.size());
assertEquals(expValue, rankPropertyList.get(0).getValue());
}
- public void assertMacro(String expExpression, String macroName, String rankProfile) {
- assertEquals(expExpression, rankProfile(rankProfile).getMacros().get(macroName).getRankingExpression().getRoot().toString());
+ public void assertMacro(String expexctedExpression, String macroName, String rankProfile) {
+ assertEquals(expexctedExpression,
+ compiledRankProfile(rankProfile).getMacros().get(macroName).getRankingExpression().getRoot().toString());
}
+ public RankProfile compileRankProfile(String rankProfile) {
+ RankProfile compiled = rankProfileRegistry.getRankProfile(search, rankProfile).compile(queryProfileRegistry);
+ compiledRankProfiles.put(rankProfile, compiled);
+ return compiled;
+ }
+
+ /** Returns the given uncompiled profile */
public RankProfile rankProfile(String rankProfile) {
- return rankProfileRegistry.getRankProfile(search, rankProfile).compile(queryProfileRegistry);
+ return rankProfileRegistry.getRankProfile(search, rankProfile);
+ }
+
+ /** Returns the given compiled profile, or null if not compiled yet or not present at all */
+ public RankProfile compiledRankProfile(String rankProfile) {
+ return compiledRankProfiles.get(rankProfile);
}
public Search search() { return search; }
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index b2ef08dcc36..a7465fa9695 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -123,6 +123,7 @@ public class RankingExpressionWithOnnxTestCase {
" expression: onnx('mnist_softmax.onnx')" +
" }\n" +
" }");
+ search.compileRankProfile("my_profile");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
}
@@ -164,7 +165,8 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx','y'): " +
- "Model does not have the specified signature 'y'",
+ "No expressions available in model 'mnist_softmax.onnx'",
+// "No expressions named 'y' in model 'mnist_softmax.onnx'. Available expressions: mnist_softmax.onnx.default.add",
Exceptions.toMessageString(expected));
}
}
@@ -220,7 +222,8 @@ public class RankingExpressionWithOnnxTestCase {
String vespaExpressionWithoutConstant =
"join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), mnist_softmax_onnx_Variable, f(a,b)(a * b)), sum, d2), constant(mnist_softmax_onnx_Variable_1), f(a,b)(a + b))";
- RankProfileSearchFixture search = fixtureWith(rankProfile, new StoringApplicationPackage(applicationDir));
+ RankProfileSearchFixture search = uncompiledFixtureWith(rankProfile, new StoringApplicationPackage(applicationDir));
+ search.compileRankProfile("my_profile");
search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
assertNull("Constant overridden by macro is not added",
@@ -234,7 +237,8 @@ public class RankingExpressionWithOnnxTestCase {
IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
- RankProfileSearchFixture searchFromStored = fixtureWith(rankProfile, storedApplication);
+ RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfile, storedApplication);
+ searchFromStored.compileRankProfile("my_profile");
searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
assertNull("Constant overridden by macro is not added",
searchFromStored.search().getRankingConstants().get("mnist_softmax_onnx_Variable"));
@@ -271,19 +275,19 @@ public class RankingExpressionWithOnnxTestCase {
private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) {
return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder",
- new StoringApplicationPackage(applicationDir));
+ new StoringApplicationPackage(applicationDir));
}
private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression,
String constant, String field) {
return fixtureWith(placeholderExpression, firstPhaseExpression, constant, field, "Placeholder",
- new StoringApplicationPackage(applicationDir));
+ new StoringApplicationPackage(applicationDir));
}
- private RankProfileSearchFixture fixtureWith(String rankProfile, StoringApplicationPackage application) {
+ private RankProfileSearchFixture uncompiledFixtureWith(String rankProfile, StoringApplicationPackage application) {
try {
return new RankProfileSearchFixture(application, application.getQueryProfiles(),
- rankProfile, null, null);
+ rankProfile, null, null);
}
catch (ParseException e) {
throw new IllegalArgumentException(e);
@@ -297,7 +301,7 @@ public class RankingExpressionWithOnnxTestCase {
String macroName,
StoringApplicationPackage application) {
try {
- return new RankProfileSearchFixture(
+ RankProfileSearchFixture fixture = new RankProfileSearchFixture(
application,
application.getQueryProfiles(),
" rank-profile my_profile {\n" +
@@ -310,6 +314,8 @@ public class RankingExpressionWithOnnxTestCase {
" }",
constant,
field);
+ fixture.compileRankProfile("my_profile");
+ return fixture;
}
catch (ParseException e) {
throw new IllegalArgumentException(e);
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 7228af2b0de..29859817736 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -6,6 +6,7 @@ import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
+import com.yahoo.io.reader.NamedReader;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankingConstant;
@@ -22,10 +23,12 @@ import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
+import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.io.UncheckedIOException;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
@@ -156,6 +159,7 @@ public class RankingExpressionWithTensorFlowTestCase {
" expression: tensorflow('mnist_softmax/saved')" +
" }\n" +
" }");
+ search.compileRankProfile("my_profile");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
}
@@ -196,7 +200,9 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved','serving_defaultz'): " +
- "Model does not have the specified signature 'serving_defaultz'",
+ "No expressions available in model 'mnist_softmax_saved'",
+// "No expressions named 'serving_defaultz' in model 'mnist_softmax/saved'. "+
+// "Available expressions: mnist_softmax_saved.serving_default.y",
Exceptions.toMessageString(expected));
}
}
@@ -212,7 +218,9 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved','serving_default','x'): " +
- "Model does not have the specified output 'x'",
+ "No expressions available in model 'mnist_softmax_saved'",
+// "No expression 'mnist_softmax_saved.serving_default.x' in model 'mnist_softmax/saved'. " +
+// "Available expressions: mnist_softmax_saved.serving_default.y",
Exceptions.toMessageString(expected));
}
}
@@ -268,7 +276,8 @@ public class RankingExpressionWithTensorFlowTestCase {
String vespaExpressionWithoutConstant =
"join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), mnist_softmax_saved_layer_Variable_read, f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))";
- RankProfileSearchFixture search = fixtureWith(rankProfile, new StoringApplicationPackage(applicationDir));
+ RankProfileSearchFixture search = fixtureWithUncompiled(rankProfile, new StoringApplicationPackage(applicationDir));
+ search.compileRankProfile("my_profile");
search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
assertNull("Constant overridden by macro is not added",
@@ -282,7 +291,8 @@ public class RankingExpressionWithTensorFlowTestCase {
IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
- RankProfileSearchFixture searchFromStored = fixtureWith(rankProfile, storedApplication);
+ RankProfileSearchFixture searchFromStored = fixtureWithUncompiled(rankProfile, storedApplication);
+ searchFromStored.compileRankProfile("my_profile");
searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
assertNull("Constant overridden by macro is not added",
searchFromStored.search().getRankingConstants().get("mnist_softmax_saved_layer_Variable_read"));
@@ -297,7 +307,7 @@ public class RankingExpressionWithTensorFlowTestCase {
public void testTensorFlowReduceBatchDimension() {
final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
- "tensorflow('mnist_softmax/saved')");
+ "tensorflow('mnist_softmax/saved')");
search.assertFirstPhaseExpression(expression, "my_profile");
assertLargeConstant("mnist_softmax_saved_layer_Variable_1_read", search, Optional.of(10L));
assertLargeConstant("mnist_softmax_saved_layer_Variable_read", search, Optional.of(7840L));
@@ -362,7 +372,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) {
- Value value = search.rankProfile("my_profile").getConstants().get(name);
+ Value value = search.compiledRankProfile("my_profile").getConstants().get(name);
assertNotNull(value);
assertEquals(type, value.type());
}
@@ -410,7 +420,7 @@ public class RankingExpressionWithTensorFlowTestCase {
String macroName,
StoringApplicationPackage application) {
try {
- return new RankProfileSearchFixture(
+ RankProfileSearchFixture fixture = new RankProfileSearchFixture(
application,
application.getQueryProfiles(),
" rank-profile my_profile {\n" +
@@ -423,13 +433,15 @@ public class RankingExpressionWithTensorFlowTestCase {
" }",
constant,
field);
+ fixture.compileRankProfile("my_profile");
+ return fixture;
}
catch (ParseException e) {
throw new IllegalArgumentException(e);
}
}
- private RankProfileSearchFixture fixtureWith(String rankProfile, StoringApplicationPackage application) {
+ private RankProfileSearchFixture fixtureWithUncompiled(String rankProfile, StoringApplicationPackage application) {
try {
return new RankProfileSearchFixture(application, application.getQueryProfiles(),
rankProfile, null, null);
@@ -463,6 +475,21 @@ public class RankingExpressionWithTensorFlowTestCase {
return new StoringApplicationPackageFile(file, Path.fromString(root.toString()));
}
+ @Override
+ public List<NamedReader> getFiles(Path path, String suffix) {
+ List<NamedReader> readers = new ArrayList<>();
+ for (File file : getFileReference(path).listFiles()) {
+ if ( ! file.getName().endsWith(suffix)) continue;
+ try {
+ readers.add(new NamedReader(file.getName(), new FileReader(file)));
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+ return readers;
+ }
+
}
static class StoringApplicationPackageFile extends ApplicationFile {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java
index dba2bdbfbbf..0866d3192cf 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java
@@ -25,6 +25,7 @@ public class RankingExpressionWithTensorTestCase {
" }\n" +
" }\n" +
" }");
+ f.compileRankProfile("my_profile");
f.assertFirstPhaseExpression("reduce(constant(my_tensor), sum)", "my_profile");
f.assertRankProperty("{{x:1,y:2}:1.0,{x:2,y:1}:2.0}", "constant(my_tensor).value", "my_profile");
f.assertRankProperty("tensor(x{},y{})", "constant(my_tensor).type", "my_profile");
@@ -47,6 +48,7 @@ public class RankingExpressionWithTensorTestCase {
" }\n" +
" }\n" +
" }");
+ f.compileRankProfile("my_profile");
f.assertFirstPhaseExpression("reduce(constant(my_tensor), sum)", "my_profile");
f.assertRankProperty("{{x:1,y:2}:1.0,{x:2,y:1}:2.0}", "constant(my_tensor).value", "my_profile");
f.assertRankProperty("tensor(x{},y{})", "constant(my_tensor).type", "my_profile");
@@ -65,6 +67,7 @@ public class RankingExpressionWithTensorTestCase {
" }\n" +
" }\n" +
" }");
+ f.compileRankProfile("my_profile");
f.assertSecondPhaseExpression("reduce(constant(my_tensor), sum)", "my_profile");
f.assertRankProperty("{{x:1}:1.0}", "constant(my_tensor).value", "my_profile");
f.assertRankProperty("tensor(x{})", "constant(my_tensor).type", "my_profile");
@@ -85,6 +88,7 @@ public class RankingExpressionWithTensorTestCase {
" expression: sum(my_tensor)\n" +
" }\n" +
" }");
+ f.compileRankProfile("my_profile");
f.assertFirstPhaseExpression("reduce(constant(my_tensor), sum)", "my_profile");
f.assertRankProperty("{{x:1}:1.0}", "constant(my_tensor).value", "my_profile");
f.assertRankProperty("tensor(x{})", "constant(my_tensor).type", "my_profile");
@@ -106,6 +110,7 @@ public class RankingExpressionWithTensorTestCase {
" }\n" +
" }\n" +
" }");
+ f.compileRankProfile("my_profile");
f.assertFirstPhaseExpression("5.0 + my_macro", "my_profile");
f.assertMacro("reduce(constant(my_tensor), sum)", "my_macro", "my_profile");
f.assertRankProperty("{{x:1}:1.0}", "constant(my_tensor).value", "my_profile");
@@ -127,6 +132,7 @@ public class RankingExpressionWithTensorTestCase {
" my_number_2: 5.0\n" +
" }\n" +
" }");
+ f.compileRankProfile("my_profile");
f.assertFirstPhaseExpression("3.0 + reduce(constant(my_tensor), sum) + 5.0", "my_profile");
f.assertRankProperty("{{x:1}:1.0}", "constant(my_tensor).value", "my_profile");
f.assertRankProperty("tensor(x{})", "constant(my_tensor).type", "my_profile");
@@ -139,7 +145,7 @@ public class RankingExpressionWithTensorTestCase {
public void requireThatInvalidTensorTypeSpecThrowsException() throws ParseException {
exception.expect(IllegalArgumentException.class);
exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: Failed parsing element 'x' in type spec 'tensor(x)'");
- new RankProfileSearchFixture(
+ RankProfileSearchFixture f = new RankProfileSearchFixture(
" rank-profile my_profile {\n" +
" constants {\n" +
" my_tensor {\n" +
@@ -148,6 +154,7 @@ public class RankingExpressionWithTensorTestCase {
" }\n" +
" }\n" +
" }");
+ f.compileRankProfile("my_profile");
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java
index b65cb0b3d5f..f98783ad671 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java
@@ -36,7 +36,7 @@ public class RankingExpressionWithXgboostTestCase {
String field,
RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage application) {
try {
- return new RankProfileSearchFixture(
+ RankProfileSearchFixture fixture = new RankProfileSearchFixture(
application,
application.getQueryProfiles(),
" rank-profile my_profile {\n" +
@@ -46,6 +46,8 @@ public class RankingExpressionWithXgboostTestCase {
" }",
constant,
field);
+ fixture.compileRankProfile("my_profile");
+ return fixture;
} catch (ParseException e) {
throw new IllegalArgumentException(e);
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackage.java b/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackage.java
index 2c46f591037..e9b4d6ac1aa 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackage.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationPackage.java
@@ -24,12 +24,14 @@ import java.io.File;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.stream.Collectors;
/**
* Represents an application residing in zookeeper.
@@ -200,16 +202,17 @@ public class ZKApplicationPackage implements ApplicationPackage {
return ret;
}
- //Returns readers for all the children of a node.
- //The node is looked up relative to the location of the active application package
- //in zookeeper.
+ /**
+ * Returns readers for all the children of a node.
+ * The node is looked up relative to the location of the active application package in zookeeper.
+ */
@Override
- public List<NamedReader> getFiles(Path relativePath,String suffix,boolean recurse) {
+ public List<NamedReader> getFiles(Path relativePath, String suffix, boolean recurse) {
return liveApp.getAllDataFromDirectory(ConfigCurator.USERAPP_ZK_SUBPATH + '/' + relativePath.getRelative(), suffix, recurse);
}
@Override
- public ApplicationFile getFile(Path file) { // foo/bar/baz.json
+ public ApplicationFile getFile(Path file) {
return new ZKApplicationFile(file, liveApp);
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKLiveApp.java b/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKLiveApp.java
index d7d43dea022..956af02e36f 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKLiveApp.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKLiveApp.java
@@ -69,7 +69,8 @@ public class ZKLiveApp {
log.finer("ZKApplicationPackage: Skipped '" + child + "' (did not match suffix " + fileNameSuffix + ")");
}
if (recursive)
- result.addAll(getAllDataFromDirectory(path + "/" + child, namePrefix + child + "/", fileNameSuffix, recursive));
+ result.addAll(getAllDataFromDirectory(path + "/" + child,
+ namePrefix + child + "/", fileNameSuffix, recursive));
}
if (log.isLoggable(Level.FINE))
log.fine("ZKApplicationPackage: Found '" + result.size() + "' files in " + fullPath);
@@ -80,14 +81,15 @@ public class ZKLiveApp {
}
/**
- * Retrieves a node relative to the node of the live application, e.g. /vespa/config/apps/$lt;app_id&gt;/&lt;path&gt;/&lt;node&gt;
+ * Retrieves a node relative to the node of the live application,
+ * e.g. /vespa/config/apps/$lt;app_id&gt;/&lt;path&gt;/&lt;node&gt;
*
* @param path a path relative to the currently active application
* @param node a path relative to the path above
* @return a Reader that can be used to get the data
*/
public Reader getDataReader(String path, String node) {
- final String data = getData(path, node);
+ String data = getData(path, node);
if (data == null) {
throw new IllegalArgumentException("No node for " + getFullPath(path) + "/" + node + " exists");
}
@@ -98,7 +100,8 @@ public class ZKLiveApp {
try {
return zk.getData(getFullPath(path), node);
} catch (Exception e) {
- throw new IllegalArgumentException("Could not retrieve node '" + getFullPath(path) + "/" + node + "' in zookeeper", e);
+ throw new IllegalArgumentException("Could not retrieve node '" +
+ getFullPath(path) + "/" + node + "' in zookeeper", e);
}
}
@@ -205,5 +208,6 @@ public class ZKLiveApp {
}
return reader(data);
}
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/path/Path.java b/vespajlib/src/main/java/com/yahoo/path/Path.java
index c466fe50d6f..2806631be18 100644
--- a/vespajlib/src/main/java/com/yahoo/path/Path.java
+++ b/vespajlib/src/main/java/com/yahoo/path/Path.java
@@ -84,6 +84,7 @@ public final class Path {
/**
* Get the name of this path element, typically the last element in the path string.
+ *
* @return the name
*/
public String getName() {