summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2018-08-29 11:00:43 +0200
committerGitHub <noreply@github.com>2018-08-29 11:00:43 +0200
commit89e6983eb08238e4912aeabb7b001ad966cfb23b (patch)
treeb3ee61981cce60a9296e69d7d24b0a671fac52c2 /config-model
parent389c149b8f1a1f27ae6b060f6ffa8958daccec5b (diff)
parent2f615f8288220a2667d75544d8bc747119cd3013 (diff)
Merge pull request #6713 from vespa-engine/bratseth/generate-rank-profiles-for-all-models-part-10
Read stored models from Zk package for global rank profiles
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java126
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java245
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java6
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java4
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java35
-rw-r--r--config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java13
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java1
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java123
9 files changed, 306 insertions, 251 deletions
diff --git a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java
index 7404ae14a5d..757dab4cbf3 100644
--- a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java
+++ b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java
@@ -16,11 +16,17 @@ import com.yahoo.searchdefinition.parser.ParseException;
import com.yahoo.vespa.config.ConfigDefinitionKey;
import com.yahoo.config.application.api.ApplicationPackage;
+import java.io.BufferedInputStream;
import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
import java.io.IOException;
+import java.io.InputStream;
import java.io.Reader;
import java.io.StringReader;
+import java.io.UncheckedIOException;
import java.util.*;
+import java.util.stream.Collectors;
/**
* For testing purposes only
@@ -113,7 +119,7 @@ public class MockApplicationPackage implements ApplicationPackage {
@Override
public ApplicationFile getFile(Path file) {
- throw new UnsupportedOperationException();
+ return new MockApplicationFile(file, Path.fromString(root.toString()));
}
@Override
@@ -300,4 +306,122 @@ public class MockApplicationPackage implements ApplicationPackage {
return xmlStringWithIdAttribute.substring(idStart + 4, idEnd - 1);
}
+ public static class MockApplicationFile extends ApplicationFile {
+
+ /** The path to the application package root */
+ private final Path root;
+
+ /** The File pointing to the actual file represented by this */
+ private final File file;
+
+ public MockApplicationFile(Path filePath, Path applicationPackagePath) {
+ super(filePath);
+ this.root = applicationPackagePath;
+ file = applicationPackagePath.append(filePath).toFile();
+ }
+
+ @Override
+ public boolean isDirectory() {
+ return file.isDirectory();
+ }
+
+ @Override
+ public boolean exists() {
+ return file.exists();
+ }
+
+ @Override
+ public Reader createReader() throws FileNotFoundException {
+ try {
+ if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist");
+ return IOUtils.createReader(file, "UTF-8");
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ public InputStream createInputStream() throws FileNotFoundException {
+ try {
+ if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist");
+ return new BufferedInputStream(new FileInputStream(file));
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ public ApplicationFile createDirectory() {
+ file.mkdirs();
+ return this;
+ }
+
+ @Override
+ public ApplicationFile writeFile(Reader input) {
+ try {
+ IOUtils.writeFile(file, IOUtils.readAll(input), false);
+ return this;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ public ApplicationFile appendFile(String value) {
+ try {
+ IOUtils.writeFile(file, value, true);
+ return this;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ public List<ApplicationFile> listFiles(PathFilter filter) {
+ if ( ! isDirectory()) return Collections.emptyList();
+ return Arrays.stream(file.listFiles()).filter(f -> filter.accept(Path.fromString(f.toString())))
+ .map(f -> new MockApplicationFile(asApplicationRelativePath(f), root))
+ .collect(Collectors.toList());
+ }
+
+ @Override
+ public ApplicationFile delete() {
+ file.delete();
+ return this;
+ }
+
+ @Override
+ public MetaData getMetaData() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int compareTo(ApplicationFile other) {
+ return this.getPath().getName().compareTo((other).getPath().getName());
+ }
+
+ /** Strips the application package root path prefix from the path of the given file */
+ private Path asApplicationRelativePath(File file) {
+ Path path = Path.fromString(file.toString());
+
+ Iterator<String> pathIterator = path.iterator();
+ // Skip the path elements this shares with the root
+ for (Iterator<String> rootIterator = root.iterator(); rootIterator.hasNext(); ) {
+ String rootElement = rootIterator.next();
+ String pathElement = pathIterator.next();
+ if ( ! rootElement.equals(pathElement)) throw new RuntimeException("Assumption broken");
+ }
+ // Build a path from the remaining
+ Path relative = Path.fromString("");
+ while (pathIterator.hasNext())
+ relative = relative.append(pathIterator.next());
+ return relative;
+ }
+
+ }
+
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
index 0911f567fa1..935b9200868 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
@@ -1,5 +1,6 @@
package com.yahoo.searchdefinition.expressiontransforms;
+import com.google.common.collect.ImmutableMap;
import com.yahoo.collections.Pair;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
@@ -16,7 +17,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
@@ -38,7 +38,6 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat;
import java.io.BufferedReader;
import java.io.File;
-import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
@@ -65,49 +64,91 @@ import java.util.stream.Collectors;
public class ConvertedModel {
private final String modelName;
- private final Path modelPath;
-
- /**
- * The ranking expressions of this, indexed by their name. which is a 1-3 part string separated by dots
- * where the first part is always the model name, the second the signature or (if none)
- * expression name (if more than one), and the third is the output name (if any).
- */
- private final Map<String, RankingExpression> expressions;
+ private final String modelDescription;
+ private final ImmutableMap<String, RankingExpression> expressions;
+
+ /** The source importedModel, or empty if this was created from a stored converted model */
+ private final Optional<ImportedModel> sourceModel;
+
+ private ConvertedModel(String modelName,
+ String modelDescription,
+ Map<String, RankingExpression> expressions,
+ Optional<ImportedModel> sourceModel) {
+ this.modelName = modelName;
+ this.modelDescription = modelDescription;
+ this.expressions = ImmutableMap.copyOf(expressions);
+ this.sourceModel = sourceModel;
+ }
/**
- * Create a converted model for a rank profile given from either an imported model,
+ * Create and store a converted model for a rank profile given from either an imported model,
* or (if unavailable) from stored application package data.
*/
- public ConvertedModel(Path modelPath, RankProfileTransformContext context) {
- this.modelPath = modelPath;
- this.modelName = toModelName(modelPath);
- ModelStore store = new ModelStore(context.rankProfile().applicationPackage(), modelPath);
- if ( store.hasSourceModel())
- expressions = convertModel(store, context.rankProfile(), context.queryProfiles(), context.importedModels());
+ public static ConvertedModel fromSourceOrStore(Path modelPath, RankProfileTransformContext context) {
+ File sourceModel = sourceModelFile(context.rankProfile().applicationPackage(), modelPath);
+ if (sourceModel.exists())
+ return fromSource(toModelName(modelPath),
+ modelPath.toString(),
+ context.rankProfile(),
+ context.queryProfiles(),
+ context.importedModels().get(sourceModel)); // TODO: Convert to name here, make sure its done just one way
else
- expressions = transformFromStoredModel(store, context.rankProfile());
+ return fromStore(toModelName(modelPath),
+ modelPath.toString(),
+ context.rankProfile());
+ }
+
+ public static ConvertedModel fromSource(String modelName,
+ String modelDescription,
+ RankProfile rankProfile,
+ QueryProfileRegistry queryProfileRegistry,
+ ImportedModel importedModel) {
+ ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName);
+ return new ConvertedModel(modelName,
+ modelDescription,
+ convertAndStore(importedModel, rankProfile, queryProfileRegistry, modelStore),
+ Optional.of(importedModel));
}
- private Map<String, RankingExpression> convertModel(ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles,
- ImportedModels importedModels) {
- ImportedModel model = importedModels.get(store.sourceModelFile());
- return transformFromImportedModel(model, store, profile, queryProfiles);
+ public static ConvertedModel fromStore(String modelName,
+ String modelDescription,
+ RankProfile rankProfile) {
+ ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName);
+ return new ConvertedModel(modelName,
+ modelDescription,
+ convertStored(modelStore, rankProfile),
+ Optional.empty());
}
- /** Returns the expression matching the given arguments */
- public ExpressionNode expression(FeatureArguments arguments) {
+ /**
+ * Returns all the output expressions of this indexed by name. The names consist of one or two parts
+ * separated by dot, where the first part is the signature name
+ * if signatures are used, or the expression name if signatures are not used and there are multiple
+ * expressions, and the second is the output name if signature names are used.
+ */
+ public Map<String, RankingExpression> expressions() { return expressions; }
+
+ /**
+ * Returns the expression matching the given arguments.
+ */
+ public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) {
+ RankingExpression expression = selectExpression(arguments);
+ if (sourceModel.isPresent()) // we can verify
+ verifyRequiredMacros(expression, sourceModel.get(), context.rankProfile(), context.queryProfiles());
+ return expression.getRoot();
+ }
+
+ private RankingExpression selectExpression(FeatureArguments arguments) {
if (expressions.isEmpty())
throw new IllegalArgumentException("No expressions available in " + this);
RankingExpression expression = expressions.get(arguments.toName());
- if (expression != null) return expression.getRoot();
+ if (expression != null) return expression;
if ( ! arguments.signature().isPresent()) {
if (expressions.size() > 1)
throw new IllegalArgumentException("Multiple candidate expressions " + missingExpressionMessageSuffix());
- return expressions.values().iterator().next().getRoot();
+ return expressions.values().iterator().next();
}
if ( ! arguments.output().isPresent()) {
@@ -119,21 +160,23 @@ public class ConvertedModel {
if (entriesWithTheRightPrefix.size() > 1)
throw new IllegalArgumentException("Multiple candidate expression named '" + arguments.signature().get() +
missingExpressionMessageSuffix());
- return entriesWithTheRightPrefix.get(0).getValue().getRoot();
+ return entriesWithTheRightPrefix.get(0).getValue();
}
throw new IllegalArgumentException("No expression '" + arguments.toName() + missingExpressionMessageSuffix());
}
private String missingExpressionMessageSuffix() {
- return "' in model '" + this.modelPath + "'. " +
+ return "' in model '" + modelDescription + "'. " +
"Available expressions: " + expressions.keySet().stream().collect(Collectors.joining(", "));
}
- private Map<String, RankingExpression> transformFromImportedModel(ImportedModel model,
- ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles) {
+ // ----------------------- Static model conversion/storage below here
+
+ private static Map<String, RankingExpression> convertAndStore(ImportedModel model,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ ModelStore store) {
// Add constants
Set<String> constantsReplacedByMacros = new HashSet<>();
model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
@@ -161,22 +204,21 @@ public class ConvertedModel {
return expressions;
}
- private void addExpression(RankingExpression expression,
- String expressionName,
- Set<String> constantsReplacedByMacros,
- ImportedModel model,
- ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles,
- Map<String, RankingExpression> expressions) {
+ private static void addExpression(RankingExpression expression,
+ String expressionName,
+ Set<String> constantsReplacedByMacros,
+ ImportedModel model,
+ ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ Map<String, RankingExpression> expressions) {
expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- verifyRequiredMacros(expression, model, profile, queryProfiles);
reduceBatchDimensions(expression, model, profile, queryProfiles);
store.writeExpression(expressionName, expression);
expressions.put(expressionName, expression);
}
- private Map<String, RankingExpression> transformFromStoredModel(ModelStore store, RankProfile profile) {
+ private static Map<String, RankingExpression> convertStored(ModelStore store, RankProfile profile) {
for (Pair<String, Tensor> constant : store.readSmallConstants())
profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
@@ -192,12 +234,12 @@ public class ConvertedModel {
return store.readExpressions();
}
- private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
+ private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
store.writeSmallConstant(constantName, constantValue);
profile.addConstant(constantName, asValue(constantValue));
}
- private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
+ private static void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
Set<String> constantsReplacedByMacros,
String constantName, Tensor constantValue) {
RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
@@ -217,7 +259,7 @@ public class ConvertedModel {
}
}
- private void transformGeneratedMacro(ModelStore store,
+ private static void transformGeneratedMacro(ModelStore store,
Set<String> constantsReplacedByMacros,
String macroName,
RankingExpression expression) {
@@ -226,7 +268,7 @@ public class ConvertedModel {
store.writeMacro(macroName, expression);
}
- private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
+ private static void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
if (profile.getMacros().containsKey(macroName)) {
if ( ! profile.getMacros().get(macroName).getRankingExpression().equals(expression))
throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists in " + profile +
@@ -243,8 +285,8 @@ public class ConvertedModel {
* Verify that the macros referred in the given expression exists in the given rank profile,
* and return tensors of the types specified in requiredMacros.
*/
- private void verifyRequiredMacros(RankingExpression expression, ImportedModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
+ private static void verifyRequiredMacros(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
Set<String> macroNames = new HashSet<>();
addMacroNamesIn(expression.getRoot(), macroNames, model);
for (String macroName : macroNames) {
@@ -272,7 +314,7 @@ public class ConvertedModel {
}
}
- private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
+ private static String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
return "The required type of this is " + requiredType + ", but this macro returns " + actualType +
(actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " +
"in query profile types - see the documentation."
@@ -282,7 +324,7 @@ public class ConvertedModel {
/**
* Add the generated macros to the rank profile
*/
- private void addGeneratedMacros(ImportedModel model, RankProfile profile) {
+ private static void addGeneratedMacros(ImportedModel model, RankProfile profile) {
model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v.copy()));
}
@@ -291,8 +333,8 @@ public class ConvertedModel {
* macro specifies that a single exemplar should be evaluated, we can
* reduce the batch dimension out.
*/
- private void reduceBatchDimensions(RankingExpression expression, ImportedModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
+ private static void reduceBatchDimensions(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
@@ -319,8 +361,8 @@ public class ConvertedModel {
expression.setRoot(root);
}
- private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model,
- TypeContext<Reference> typeContext) {
+ private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model,
+ TypeContext<Reference> typeContext) {
if (node instanceof TensorFunctionNode) {
TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
if (tensorFunction instanceof Rename) {
@@ -350,7 +392,7 @@ public class ConvertedModel {
return node;
}
- private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
+ private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
TensorFunction result = function;
TensorType type = function.type(context);
if (type.dimensions().size() > 1) {
@@ -372,7 +414,7 @@ public class ConvertedModel {
* for any following computation of the tensor.
*/
// TODO: determine when this is not necessary!
- private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
+ private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
if (after.equals(before)) {
return node;
}
@@ -399,14 +441,14 @@ public class ConvertedModel {
* If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
* This method does that for the given expression and returns the result.
*/
- private RankingExpression replaceConstantsByMacros(RankingExpression expression,
+ private static RankingExpression replaceConstantsByMacros(RankingExpression expression,
Set<String> constantsReplacedByMacros) {
if (constantsReplacedByMacros.isEmpty()) return expression;
return new RankingExpression(expression.getName(),
replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
}
- private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
+ private static ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
if (node instanceof ReferenceNode) {
Reference reference = ((ReferenceNode)node).reference();
if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
@@ -424,7 +466,7 @@ public class ConvertedModel {
return node;
}
- private void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) {
+ private static void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) {
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode)node;
if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
@@ -440,7 +482,7 @@ public class ConvertedModel {
}
}
- private Value asValue(Tensor tensor) {
+ private static Value asValue(Tensor tensor) {
if (tensor.type().rank() == 0)
return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
else
@@ -455,6 +497,13 @@ public class ConvertedModel {
public String toString() { return "model '" + modelName + "'"; }
/**
+ * Returns the directory which contains the source model to use for these arguments
+ */
+ public static File sourceModelFile(ApplicationPackage application, Path sourceModelPath) {
+ return application.getFileReference(ApplicationPackage.MODELS_DIR.append(sourceModelPath));
+ }
+
+ /**
* Provides read/write access to the correct directories of the application package given by the feature arguments
*/
static class ModelStore {
@@ -462,20 +511,9 @@ public class ConvertedModel {
private final ApplicationPackage application;
private final ModelFiles modelFiles;
- ModelStore(ApplicationPackage application, Path modelPath) {
+ ModelStore(ApplicationPackage application, String modelName) {
this.application = application;
- this.modelFiles = new ModelFiles(modelPath);
- }
-
- public boolean hasSourceModel() {
- return sourceModelFile().exists();
- }
-
- /**
- * Returns the directory which contains the source model to use for these arguments
- */
- public File sourceModelFile() {
- return application.getFileReference(ApplicationPackage.MODELS_DIR.append(modelFiles.modelPath()));
+ this.modelFiles = new ModelFiles(modelName);
}
/**
@@ -548,7 +586,7 @@ public class ConvertedModel {
List<RankingConstant> readLargeConstants() {
try {
List<RankingConstant> constants = new ArrayList<>();
- for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsPath()).listFiles()) {
+ for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsInfoPath()).listFiles()) {
String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
}
@@ -566,13 +604,13 @@ public class ConvertedModel {
* @return the path to the stored constant, relative to the application package root
*/
Path writeLargeConstant(String name, Tensor constant) {
- Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(modelFiles.modelPath()).append("constants");
+ Path constantsPath = modelFiles.largeConstantsContentPath();
// "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
Path constantPath = constantsPath.append(name + ".tbf");
// Remember the constant in a file we replicate in ZooKeeper
- application.getFile(modelFiles.largeConstantsPath().append(name + ".constant"))
+ application.getFile(modelFiles.largeConstantsInfoPath().append(name + ".constant"))
.writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
// Write content explicitly as a file on the file system as this is distributed using file distribution
@@ -609,8 +647,8 @@ public class ConvertedModel {
public void writeSmallConstant(String name, Tensor constant) {
// Secret file format for remembering constants:
application.getFile(modelFiles.smallConstantsPath()).appendFile(name + "\t" +
- constant.type().toString() + "\t" +
- constant.toString() + "\n");
+ constant.type().toString() + "\t" +
+ constant.toString() + "\n");
}
/** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
@@ -632,40 +670,24 @@ public class ConvertedModel {
}
}
- private void close(Reader reader) {
- try {
- if (reader != null)
- reader.close();
- }
- catch (IOException e) {
- // ignore
- }
- }
-
}
static class ModelFiles {
- Path modelPath;
+ String modelName;
- public ModelFiles(Path modelPath) {
- this.modelPath = modelPath;
+ public ModelFiles(String modelName) {
+ this.modelName = modelName;
}
- /** Returns modelPath with slashes replaced by underscores */
- public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); }
-
- /** Returns relative path to this model below the "models/" dir in the application package */
- public Path modelPath() { return modelPath; }
-
/** Files stored below this path will be replicated in zookeeper */
public Path storedModelReplicatedPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath());
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelName);
}
- /** Files stored below this path will not be replicated */
+ /** Files stored below this path will not be replicated in zookeeper */
public Path storedModelPath() {
- return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath());
+ return ApplicationPackage.MODELS_GENERATED_DIR.append(modelName);
}
public Path expressionPath(String name) {
@@ -681,7 +703,12 @@ public class ConvertedModel {
}
/** Path to the large (ranking) constants directory */
- public Path largeConstantsPath() {
+ public Path largeConstantsContentPath() {
+ return storedModelPath().append("constants");
+ }
+
+ /** Path to the large (ranking) constants directory */
+ public Path largeConstantsInfoPath() {
return storedModelReplicatedPath().append("constants");
}
@@ -695,27 +722,19 @@ public class ConvertedModel {
/** Encapsulates the arguments of a specific model output */
static class FeatureArguments {
- private final String modelName;
- private final Path modelPath;
-
/** Optional arguments */
private final Optional<String> signature, output;
public FeatureArguments(Arguments arguments) {
- this(Path.fromString(asString(arguments.expressions().get(0))),
- optionalArgument(1, arguments),
+ this(optionalArgument(1, arguments),
optionalArgument(2, arguments));
}
- public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) {
- this.modelPath = modelPath;
- this.modelName = toModelName(modelPath);
+ public FeatureArguments(Optional<String> signature, Optional<String> output) {
this.signature = signature;
this.output = output;
}
- public Path modelPath() { return modelPath; }
-
public Optional<String> signature() { return signature; }
public Optional<String> output() { return output; }
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index 36dc200f3c9..229ae0ebaaf 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -13,6 +13,7 @@ import java.io.File;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
+import java.util.Optional;
/**
* Replaces instances of the onnx(model-path, output)
@@ -41,10 +42,11 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
if ( ! feature.getName().equals("onnx")) return feature;
try {
+ // TODO: Put modelPath in FeatureArguments instead
Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
- convertedOnnxModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context));
- return convertedModel.expression(asFeatureArguments(feature.getArguments()));
+ convertedOnnxModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context));
+ return convertedModel.expression(asFeatureArguments(feature.getArguments()), context);
}
catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use Onnx model from " + feature, e);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 619c13da764..bcb8ef1521d 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -41,8 +41,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
try {
Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
- convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context));
- return convertedModel.expression(asFeatureArguments(feature.getArguments()));
+ convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context));
+ return convertedModel.expression(asFeatureArguments(feature.getArguments()), context);
}
catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
index e6b08ab0350..b4a5069b9d6 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
@@ -43,8 +43,8 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr
try {
Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
- convertedXGBoostModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context));
- return convertedModel.expression(asFeatureArguments(feature.getArguments()));
+ convertedXGBoostModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context));
+ return convertedModel.expression(asFeatureArguments(feature.getArguments()), context);
} catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use XGBoost model from " + feature, e);
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index 73dd60f63eb..ec92905cd1f 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -8,6 +8,7 @@ import com.yahoo.config.ConfigInstance;
import com.yahoo.config.ConfigInstance.Builder;
import com.yahoo.config.ConfigurationRuntimeException;
import com.yahoo.config.FileReference;
+import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.application.api.ValidationId;
@@ -26,11 +27,13 @@ import com.yahoo.config.model.producer.AbstractConfigProducerRoot;
import com.yahoo.config.model.producer.UserConfigRepo;
import com.yahoo.config.provision.AllocatedHosts;
import com.yahoo.log.LogLevel;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.RankingConstants;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RankProfileList;
+import com.yahoo.searchdefinition.expressiontransforms.ConvertedModel;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
@@ -162,7 +165,9 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
this.applicationPackage = deployState.getApplicationPackage();
root = builder.getRoot(VespaModel.ROOT_CONFIGID, deployState, this);
- createGlobalRankProfiles(deployState.getImportedModels(), deployState.rankProfileRegistry());
+ createGlobalRankProfiles(deployState.getImportedModels(),
+ deployState.rankProfileRegistry(),
+ deployState.getQueryProfiles().getRegistry());
this.rankProfileList = new RankProfileList(null, // null search -> global
AttributeFields.empty,
deployState.rankProfileRegistry(),
@@ -220,14 +225,30 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
* Creates a rank profile not attached to any search definition, for each imported model in the application package
*/
private ImmutableList<RankProfile> createGlobalRankProfiles(ImportedModels importedModels,
- RankProfileRegistry rankProfileRegistry) {
+ RankProfileRegistry rankProfileRegistry,
+ QueryProfileRegistry queryProfiles) {
List<RankProfile> profiles = new ArrayList<>();
- for (ImportedModel model : importedModels.all()) {
- RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry);
- for (Pair<String, RankingExpression> entry : model.outputExpressions()) {
- profile.addMacro(entry.getFirst(), false).setRankingExpression(entry.getSecond());
+ if ( ! importedModels.all().isEmpty()) { // models/ directory is available
+ for (ImportedModel model : importedModels.all()) {
+ RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry);
+ rankProfileRegistry.add(profile);
+ ConvertedModel convertedModel = ConvertedModel.fromSource(model.name(), model.name(), profile, queryProfiles, model);
+ for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) {
+ profile.addMacro(entry.getKey(), false).setRankingExpression(entry.getValue());
+ }
+ }
+ }
+ else { // generated and saved model information may be available instead
+ ApplicationFile generatedModelsDir = applicationPackage.getFile(ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR);
+ for (ApplicationFile generatedModelDir : generatedModelsDir.listFiles()) {
+ String modelName = generatedModelDir.getPath().last();
+ RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry);
+ rankProfileRegistry.add(profile);
+ ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, modelName, profile);
+ for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) {
+ profile.addMacro(entry.getKey(), false).setRankingExpression(entry.getValue());
+ }
}
- rankProfileRegistry.add(profile);
}
return ImmutableList.copyOf(profiles);
}
diff --git a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java
index d06752c9b6d..8558ccc44bd 100644
--- a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java
@@ -2,10 +2,14 @@ package com.yahoo.config.model;
import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.ModelsEvaluator;
+import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.model.application.provider.FilesApplicationPackage;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.container.ContainerCluster;
+import org.junit.After;
import org.junit.Test;
import org.xml.sax.SAXException;
@@ -22,11 +26,16 @@ import static org.junit.Assert.assertTrue;
*/
public class ModelEvaluationTest {
- private static final String TESTDIR = "src/test/cfg/application/";
+ private static final String appDir = "src/test/cfg/application/ml_serving";
+
+ @After
+ public void removeGeneratedModelFiles() {
+ IOUtils.recursiveDeleteDir(Path.fromString(appDir).append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ }
@Test
public void testMl_ServingApplication() throws SAXException, IOException {
- ApplicationPackageTester tester = ApplicationPackageTester.create(TESTDIR + "ml_serving");
+ ApplicationPackageTester tester = ApplicationPackageTester.create(appDir);
VespaModel model = new VespaModel(tester.app());
ContainerCluster cluster = model.getContainerClusters().get("container");
RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 815a01cdb99..04a6f953bb6 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -142,7 +142,6 @@ public class RankingExpressionWithOnnxTestCase {
}
}
-
@Test
public void testOnnxReferenceWithWrongMacroType() {
try {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index c317f07b87a..28fcf871cf3 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -403,7 +403,7 @@ public class RankingExpressionWithTensorFlowTestCase {
*/
private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) {
try {
- Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax/saved/constants").append(name + ".tbf");
+ Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax_saved/constants").append(name + ".tbf");
RankingConstant rankingConstant = search.search().rankingConstants().get(name);
assertEquals(name, rankingConstant.getName());
assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString()));
@@ -485,7 +485,7 @@ public class RankingExpressionWithTensorFlowTestCase {
@Override
public ApplicationFile getFile(Path file) {
- return new StoringApplicationPackageFile(file, Path.fromString(root().toString()));
+ return new MockApplicationFile(file, Path.fromString(root().toString()));
}
@Override
@@ -505,123 +505,4 @@ public class RankingExpressionWithTensorFlowTestCase {
}
- static class StoringApplicationPackageFile extends ApplicationFile {
-
- /** The path to the application package root */
- private final Path root;
-
- /** The File pointing to the actual file represented by this */
- private final File file;
-
- StoringApplicationPackageFile(Path filePath, Path applicationPackagePath) {
- super(filePath);
- this.root = applicationPackagePath;
- file = applicationPackagePath.append(filePath).toFile();
- }
-
- @Override
- public boolean isDirectory() {
- return file.isDirectory();
- }
-
- @Override
- public boolean exists() {
- return file.exists();
- }
-
- @Override
- public Reader createReader() throws FileNotFoundException {
- try {
- if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist");
- return IOUtils.createReader(file, "UTF-8");
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- @Override
- public InputStream createInputStream() throws FileNotFoundException {
- try {
- if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist");
- return new BufferedInputStream(new FileInputStream(file));
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- @Override
- public ApplicationFile createDirectory() {
- file.mkdirs();
- return this;
- }
-
- @Override
- public ApplicationFile writeFile(Reader input) {
- try {
- IOUtils.writeFile(file, IOUtils.readAll(input), false);
- return this;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- @Override
- public ApplicationFile appendFile(String value) {
- try {
- IOUtils.writeFile(file, value, true);
- return this;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- @Override
- public List<ApplicationFile> listFiles(PathFilter filter) {
- if ( ! isDirectory()) return Collections.emptyList();
- return Arrays.stream(file.listFiles()).filter(f -> filter.accept(Path.fromString(f.toString())))
- .map(f -> new StoringApplicationPackageFile(asApplicationRelativePath(f),
- root))
- .collect(Collectors.toList());
- }
-
- @Override
- public ApplicationFile delete() {
- file.delete();
- return this;
- }
-
- @Override
- public MetaData getMetaData() {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public int compareTo(ApplicationFile other) {
- return this.getPath().getName().compareTo((other).getPath().getName());
- }
-
- /** Strips the application package root path prefix from the path of the given file */
- private Path asApplicationRelativePath(File file) {
- Path path = Path.fromString(file.toString());
-
- Iterator<String> pathIterator = path.iterator();
- // Skip the path elements this shares with the root
- for (Iterator<String> rootIterator = root.iterator(); rootIterator.hasNext(); ) {
- String rootElement = rootIterator.next();
- String pathElement = pathIterator.next();
- if ( ! rootElement.equals(pathElement)) throw new RuntimeException("Assumption broken");
- }
- // Build a path from the remaining
- Path relative = Path.fromString("");
- while (pathIterator.hasNext())
- relative = relative.append(pathIterator.next());
- return relative;
- }
-
- }
-
}