aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-30 22:48:17 +0200
committerJon Bratseth <bratseth@oath.com>2018-08-30 22:48:17 +0200
commit8a1c7bb53923ba9fa01c3c6ceba2046be0530d11 (patch)
tree24830adc8d42e2d9735eefeffa23ae3eb113f62a /config-model/src/main
parent6250e2b6c78e5ee27690b5071f68ae510e8c113f (diff)
Revert "Merge pull request #6742 from vespa-engine/revert-6732-bratseth/generate-rank-profiles-for-all-models-part-10-2"
This reverts commit a294ef166c59c795f9e6fd31fbd6914c502d559a, reversing changes made to cef4c0f9d7c084f320e77abb2a93522acd7f3f53.
Diffstat (limited to 'config-model/src/main')
-rw-r--r--config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java126
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java258
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java6
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java4
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java4
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java35
6 files changed, 300 insertions, 133 deletions
diff --git a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java
index 7404ae14a5d..757dab4cbf3 100644
--- a/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java
+++ b/config-model/src/main/java/com/yahoo/config/model/test/MockApplicationPackage.java
@@ -16,11 +16,17 @@ import com.yahoo.searchdefinition.parser.ParseException;
import com.yahoo.vespa.config.ConfigDefinitionKey;
import com.yahoo.config.application.api.ApplicationPackage;
+import java.io.BufferedInputStream;
import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
import java.io.IOException;
+import java.io.InputStream;
import java.io.Reader;
import java.io.StringReader;
+import java.io.UncheckedIOException;
import java.util.*;
+import java.util.stream.Collectors;
/**
* For testing purposes only
@@ -113,7 +119,7 @@ public class MockApplicationPackage implements ApplicationPackage {
@Override
public ApplicationFile getFile(Path file) {
- throw new UnsupportedOperationException();
+ return new MockApplicationFile(file, Path.fromString(root.toString()));
}
@Override
@@ -300,4 +306,122 @@ public class MockApplicationPackage implements ApplicationPackage {
return xmlStringWithIdAttribute.substring(idStart + 4, idEnd - 1);
}
+ public static class MockApplicationFile extends ApplicationFile {
+
+ /** The path to the application package root */
+ private final Path root;
+
+ /** The File pointing to the actual file represented by this */
+ private final File file;
+
+ public MockApplicationFile(Path filePath, Path applicationPackagePath) {
+ super(filePath);
+ this.root = applicationPackagePath;
+ file = applicationPackagePath.append(filePath).toFile();
+ }
+
+ @Override
+ public boolean isDirectory() {
+ return file.isDirectory();
+ }
+
+ @Override
+ public boolean exists() {
+ return file.exists();
+ }
+
+ @Override
+ public Reader createReader() throws FileNotFoundException {
+ try {
+ if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist");
+ return IOUtils.createReader(file, "UTF-8");
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ public InputStream createInputStream() throws FileNotFoundException {
+ try {
+ if ( ! exists()) throw new FileNotFoundException("File '" + file + "' does not exist");
+ return new BufferedInputStream(new FileInputStream(file));
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ public ApplicationFile createDirectory() {
+ file.mkdirs();
+ return this;
+ }
+
+ @Override
+ public ApplicationFile writeFile(Reader input) {
+ try {
+ IOUtils.writeFile(file, IOUtils.readAll(input), false);
+ return this;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ public ApplicationFile appendFile(String value) {
+ try {
+ IOUtils.writeFile(file, value, true);
+ return this;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ @Override
+ public List<ApplicationFile> listFiles(PathFilter filter) {
+ if ( ! isDirectory()) return Collections.emptyList();
+ return Arrays.stream(file.listFiles()).filter(f -> filter.accept(Path.fromString(f.toString())))
+ .map(f -> new MockApplicationFile(asApplicationRelativePath(f), root))
+ .collect(Collectors.toList());
+ }
+
+ @Override
+ public ApplicationFile delete() {
+ file.delete();
+ return this;
+ }
+
+ @Override
+ public MetaData getMetaData() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int compareTo(ApplicationFile other) {
+ return this.getPath().getName().compareTo((other).getPath().getName());
+ }
+
+ /** Strips the application package root path prefix from the path of the given file */
+ private Path asApplicationRelativePath(File file) {
+ Path path = Path.fromString(file.toString());
+
+ Iterator<String> pathIterator = path.iterator();
+ // Skip the path elements this shares with the root
+ for (Iterator<String> rootIterator = root.iterator(); rootIterator.hasNext(); ) {
+ String rootElement = rootIterator.next();
+ String pathElement = pathIterator.next();
+ if ( ! rootElement.equals(pathElement)) throw new RuntimeException("Assumption broken");
+ }
+ // Build a path from the remaining
+ Path relative = Path.fromString("");
+ while (pathIterator.hasNext())
+ relative = relative.append(pathIterator.next());
+ return relative;
+ }
+
+ }
+
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
index 0911f567fa1..f7a06f86ab7 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
@@ -1,5 +1,6 @@
package com.yahoo.searchdefinition.expressiontransforms;
+import com.google.common.collect.ImmutableMap;
import com.yahoo.collections.Pair;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
@@ -16,7 +17,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
@@ -38,7 +38,6 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat;
import java.io.BufferedReader;
import java.io.File;
-import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
@@ -65,49 +64,91 @@ import java.util.stream.Collectors;
public class ConvertedModel {
private final String modelName;
- private final Path modelPath;
-
- /**
- * The ranking expressions of this, indexed by their name. which is a 1-3 part string separated by dots
- * where the first part is always the model name, the second the signature or (if none)
- * expression name (if more than one), and the third is the output name (if any).
- */
- private final Map<String, RankingExpression> expressions;
+ private final String modelDescription;
+ private final ImmutableMap<String, RankingExpression> expressions;
+
+ /** The source importedModel, or empty if this was created from a stored converted model */
+ private final Optional<ImportedModel> sourceModel;
+
+ private ConvertedModel(String modelName,
+ String modelDescription,
+ Map<String, RankingExpression> expressions,
+ Optional<ImportedModel> sourceModel) {
+ this.modelName = modelName;
+ this.modelDescription = modelDescription;
+ this.expressions = ImmutableMap.copyOf(expressions);
+ this.sourceModel = sourceModel;
+ }
/**
- * Create a converted model for a rank profile given from either an imported model,
+ * Create and store a converted model for a rank profile given from either an imported model,
* or (if unavailable) from stored application package data.
*/
- public ConvertedModel(Path modelPath, RankProfileTransformContext context) {
- this.modelPath = modelPath;
- this.modelName = toModelName(modelPath);
- ModelStore store = new ModelStore(context.rankProfile().applicationPackage(), modelPath);
- if ( store.hasSourceModel())
- expressions = convertModel(store, context.rankProfile(), context.queryProfiles(), context.importedModels());
+ public static ConvertedModel fromSourceOrStore(Path modelPath, RankProfileTransformContext context) {
+ File sourceModel = sourceModelFile(context.rankProfile().applicationPackage(), modelPath);
+ if (sourceModel.exists())
+ return fromSource(toModelName(modelPath),
+ modelPath.toString(),
+ context.rankProfile(),
+ context.queryProfiles(),
+ context.importedModels().get(sourceModel)); // TODO: Convert to name here, make sure its done just one way
else
- expressions = transformFromStoredModel(store, context.rankProfile());
+ return fromStore(toModelName(modelPath),
+ modelPath.toString(),
+ context.rankProfile());
+ }
+
+ public static ConvertedModel fromSource(String modelName,
+ String modelDescription,
+ RankProfile rankProfile,
+ QueryProfileRegistry queryProfileRegistry,
+ ImportedModel importedModel) {
+ ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName);
+ return new ConvertedModel(modelName,
+ modelDescription,
+ convertAndStore(importedModel, rankProfile, queryProfileRegistry, modelStore),
+ Optional.of(importedModel));
}
- private Map<String, RankingExpression> convertModel(ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles,
- ImportedModels importedModels) {
- ImportedModel model = importedModels.get(store.sourceModelFile());
- return transformFromImportedModel(model, store, profile, queryProfiles);
+ public static ConvertedModel fromStore(String modelName,
+ String modelDescription,
+ RankProfile rankProfile) {
+ ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName);
+ return new ConvertedModel(modelName,
+ modelDescription,
+ convertStored(modelStore, rankProfile),
+ Optional.empty());
}
- /** Returns the expression matching the given arguments */
- public ExpressionNode expression(FeatureArguments arguments) {
+ /**
+ * Returns all the output expressions of this indexed by name. The names consist of one or two parts
+ * separated by dot, where the first part is the signature name
+ * if signatures are used, or the expression name if signatures are not used and there are multiple
+ * expressions, and the second is the output name if signature names are used.
+ */
+ public Map<String, RankingExpression> expressions() { return expressions; }
+
+ /**
+ * Returns the expression matching the given arguments.
+ */
+ public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) {
+ RankingExpression expression = selectExpression(arguments);
+ if (sourceModel.isPresent()) // we can verify
+ verifyRequiredMacros(expression, sourceModel.get(), context.rankProfile(), context.queryProfiles());
+ return expression.getRoot();
+ }
+
+ private RankingExpression selectExpression(FeatureArguments arguments) {
if (expressions.isEmpty())
throw new IllegalArgumentException("No expressions available in " + this);
RankingExpression expression = expressions.get(arguments.toName());
- if (expression != null) return expression.getRoot();
+ if (expression != null) return expression;
if ( ! arguments.signature().isPresent()) {
if (expressions.size() > 1)
throw new IllegalArgumentException("Multiple candidate expressions " + missingExpressionMessageSuffix());
- return expressions.values().iterator().next().getRoot();
+ return expressions.values().iterator().next();
}
if ( ! arguments.output().isPresent()) {
@@ -119,21 +160,23 @@ public class ConvertedModel {
if (entriesWithTheRightPrefix.size() > 1)
throw new IllegalArgumentException("Multiple candidate expression named '" + arguments.signature().get() +
missingExpressionMessageSuffix());
- return entriesWithTheRightPrefix.get(0).getValue().getRoot();
+ return entriesWithTheRightPrefix.get(0).getValue();
}
throw new IllegalArgumentException("No expression '" + arguments.toName() + missingExpressionMessageSuffix());
}
private String missingExpressionMessageSuffix() {
- return "' in model '" + this.modelPath + "'. " +
+ return "' in model '" + modelDescription + "'. " +
"Available expressions: " + expressions.keySet().stream().collect(Collectors.joining(", "));
}
- private Map<String, RankingExpression> transformFromImportedModel(ImportedModel model,
- ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles) {
+ // ----------------------- Static model conversion/storage below here
+
+ private static Map<String, RankingExpression> convertAndStore(ImportedModel model,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ ModelStore store) {
// Add constants
Set<String> constantsReplacedByMacros = new HashSet<>();
model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
@@ -161,22 +204,21 @@ public class ConvertedModel {
return expressions;
}
- private void addExpression(RankingExpression expression,
- String expressionName,
- Set<String> constantsReplacedByMacros,
- ImportedModel model,
- ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles,
- Map<String, RankingExpression> expressions) {
+ private static void addExpression(RankingExpression expression,
+ String expressionName,
+ Set<String> constantsReplacedByMacros,
+ ImportedModel model,
+ ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ Map<String, RankingExpression> expressions) {
expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- verifyRequiredMacros(expression, model, profile, queryProfiles);
reduceBatchDimensions(expression, model, profile, queryProfiles);
store.writeExpression(expressionName, expression);
expressions.put(expressionName, expression);
}
- private Map<String, RankingExpression> transformFromStoredModel(ModelStore store, RankProfile profile) {
+ private static Map<String, RankingExpression> convertStored(ModelStore store, RankProfile profile) {
for (Pair<String, Tensor> constant : store.readSmallConstants())
profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
@@ -192,12 +234,12 @@ public class ConvertedModel {
return store.readExpressions();
}
- private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
+ private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
store.writeSmallConstant(constantName, constantValue);
profile.addConstant(constantName, asValue(constantValue));
}
- private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
+ private static void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
Set<String> constantsReplacedByMacros,
String constantName, Tensor constantValue) {
RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
@@ -217,7 +259,7 @@ public class ConvertedModel {
}
}
- private void transformGeneratedMacro(ModelStore store,
+ private static void transformGeneratedMacro(ModelStore store,
Set<String> constantsReplacedByMacros,
String macroName,
RankingExpression expression) {
@@ -226,15 +268,16 @@ public class ConvertedModel {
store.writeMacro(macroName, expression);
}
- private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
+ private static void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
if (profile.getMacros().containsKey(macroName)) {
if ( ! profile.getMacros().get(macroName).getRankingExpression().equals(expression))
throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists in " + profile +
- " - with a different definition");
+ " - with a different definition" +
+ ": Has\n" + profile.getMacros().get(macroName).getRankingExpression() +
+ "\nwant to add " + expression + "\n");
return;
}
- profile.addMacro(macroName, false); // todo: inline if only used once
- RankProfile.Macro macro = profile.getMacros().get(macroName);
+ RankProfile.Macro macro = profile.addMacro(macroName, false); // TODO: Inline if only used once
macro.setRankingExpression(expression);
macro.setTextualExpression(expression.getRoot().toString());
}
@@ -243,8 +286,8 @@ public class ConvertedModel {
* Verify that the macros referred in the given expression exists in the given rank profile,
* and return tensors of the types specified in requiredMacros.
*/
- private void verifyRequiredMacros(RankingExpression expression, ImportedModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
+ private static void verifyRequiredMacros(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
Set<String> macroNames = new HashSet<>();
addMacroNamesIn(expression.getRoot(), macroNames, model);
for (String macroName : macroNames) {
@@ -272,7 +315,7 @@ public class ConvertedModel {
}
}
- private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
+ private static String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
return "The required type of this is " + requiredType + ", but this macro returns " + actualType +
(actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " +
"in query profile types - see the documentation."
@@ -282,7 +325,7 @@ public class ConvertedModel {
/**
* Add the generated macros to the rank profile
*/
- private void addGeneratedMacros(ImportedModel model, RankProfile profile) {
+ private static void addGeneratedMacros(ImportedModel model, RankProfile profile) {
model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v.copy()));
}
@@ -291,8 +334,8 @@ public class ConvertedModel {
* macro specifies that a single exemplar should be evaluated, we can
* reduce the batch dimension out.
*/
- private void reduceBatchDimensions(RankingExpression expression, ImportedModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
+ private static void reduceBatchDimensions(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
@@ -319,8 +362,8 @@ public class ConvertedModel {
expression.setRoot(root);
}
- private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model,
- TypeContext<Reference> typeContext) {
+ private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model,
+ TypeContext<Reference> typeContext) {
if (node instanceof TensorFunctionNode) {
TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
if (tensorFunction instanceof Rename) {
@@ -350,7 +393,7 @@ public class ConvertedModel {
return node;
}
- private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
+ private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
TensorFunction result = function;
TensorType type = function.type(context);
if (type.dimensions().size() > 1) {
@@ -372,7 +415,7 @@ public class ConvertedModel {
* for any following computation of the tensor.
*/
// TODO: determine when this is not necessary!
- private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
+ private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
if (after.equals(before)) {
return node;
}
@@ -399,14 +442,14 @@ public class ConvertedModel {
* If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
* This method does that for the given expression and returns the result.
*/
- private RankingExpression replaceConstantsByMacros(RankingExpression expression,
+ private static RankingExpression replaceConstantsByMacros(RankingExpression expression,
Set<String> constantsReplacedByMacros) {
if (constantsReplacedByMacros.isEmpty()) return expression;
return new RankingExpression(expression.getName(),
replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
}
- private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
+ private static ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
if (node instanceof ReferenceNode) {
Reference reference = ((ReferenceNode)node).reference();
if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
@@ -424,7 +467,7 @@ public class ConvertedModel {
return node;
}
- private void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) {
+ private static void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) {
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode)node;
if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
@@ -440,7 +483,7 @@ public class ConvertedModel {
}
}
- private Value asValue(Tensor tensor) {
+ private static Value asValue(Tensor tensor) {
if (tensor.type().rank() == 0)
return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
else
@@ -455,6 +498,13 @@ public class ConvertedModel {
public String toString() { return "model '" + modelName + "'"; }
/**
+ * Returns the directory which contains the source model to use for these arguments
+ */
+ public static File sourceModelFile(ApplicationPackage application, Path sourceModelPath) {
+ return application.getFileReference(ApplicationPackage.MODELS_DIR.append(sourceModelPath));
+ }
+
+ /**
* Provides read/write access to the correct directories of the application package given by the feature arguments
*/
static class ModelStore {
@@ -462,20 +512,9 @@ public class ConvertedModel {
private final ApplicationPackage application;
private final ModelFiles modelFiles;
- ModelStore(ApplicationPackage application, Path modelPath) {
+ ModelStore(ApplicationPackage application, String modelName) {
this.application = application;
- this.modelFiles = new ModelFiles(modelPath);
- }
-
- public boolean hasSourceModel() {
- return sourceModelFile().exists();
- }
-
- /**
- * Returns the directory which contains the source model to use for these arguments
- */
- public File sourceModelFile() {
- return application.getFileReference(ApplicationPackage.MODELS_DIR.append(modelFiles.modelPath()));
+ this.modelFiles = new ModelFiles(modelName);
}
/**
@@ -508,7 +547,7 @@ public class ConvertedModel {
return expressions;
}
- /** Adds this macro expression to the application package to it can be read later. */
+ /** Adds this macro expression to the application package so it can be read later. */
void writeMacro(String name, RankingExpression expression) {
application.getFile(modelFiles.macrosPath()).appendFile(name + "\t" +
expression.getRoot().toString() + "\n");
@@ -518,7 +557,7 @@ public class ConvertedModel {
List<Pair<String, RankingExpression>> readMacros() {
try {
ApplicationFile file = application.getFile(modelFiles.macrosPath());
- if (!file.exists()) return Collections.emptyList();
+ if ( ! file.exists()) return Collections.emptyList();
List<Pair<String, RankingExpression>> macros = new ArrayList<>();
BufferedReader reader = new BufferedReader(file.createReader());
@@ -527,7 +566,7 @@ public class ConvertedModel {
String[] parts = line.split("\t");
String name = parts[0];
try {
- RankingExpression expression = new RankingExpression(parts[1]);
+ RankingExpression expression = new RankingExpression(parts[0], parts[1]);
macros.add(new Pair<>(name, expression));
}
catch (ParseException e) {
@@ -548,7 +587,7 @@ public class ConvertedModel {
List<RankingConstant> readLargeConstants() {
try {
List<RankingConstant> constants = new ArrayList<>();
- for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsPath()).listFiles()) {
+ for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsInfoPath()).listFiles()) {
String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
}
@@ -566,13 +605,13 @@ public class ConvertedModel {
* @return the path to the stored constant, relative to the application package root
*/
Path writeLargeConstant(String name, Tensor constant) {
- Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(modelFiles.modelPath()).append("constants");
+ Path constantsPath = modelFiles.largeConstantsContentPath();
// "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
Path constantPath = constantsPath.append(name + ".tbf");
// Remember the constant in a file we replicate in ZooKeeper
- application.getFile(modelFiles.largeConstantsPath().append(name + ".constant"))
+ application.getFile(modelFiles.largeConstantsInfoPath().append(name + ".constant"))
.writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
// Write content explicitly as a file on the file system as this is distributed using file distribution
@@ -609,8 +648,8 @@ public class ConvertedModel {
public void writeSmallConstant(String name, Tensor constant) {
// Secret file format for remembering constants:
application.getFile(modelFiles.smallConstantsPath()).appendFile(name + "\t" +
- constant.type().toString() + "\t" +
- constant.toString() + "\n");
+ constant.type().toString() + "\t" +
+ constant.toString() + "\n");
}
/** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
@@ -632,40 +671,24 @@ public class ConvertedModel {
}
}
- private void close(Reader reader) {
- try {
- if (reader != null)
- reader.close();
- }
- catch (IOException e) {
- // ignore
- }
- }
-
}
static class ModelFiles {
- Path modelPath;
+ String modelName;
- public ModelFiles(Path modelPath) {
- this.modelPath = modelPath;
+ public ModelFiles(String modelName) {
+ this.modelName = modelName;
}
- /** Returns modelPath with slashes replaced by underscores */
- public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); }
-
- /** Returns relative path to this model below the "models/" dir in the application package */
- public Path modelPath() { return modelPath; }
-
/** Files stored below this path will be replicated in zookeeper */
public Path storedModelReplicatedPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath());
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelName);
}
- /** Files stored below this path will not be replicated */
+ /** Files stored below this path will not be replicated in zookeeper */
public Path storedModelPath() {
- return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath());
+ return ApplicationPackage.MODELS_GENERATED_DIR.append(modelName);
}
public Path expressionPath(String name) {
@@ -681,7 +704,12 @@ public class ConvertedModel {
}
/** Path to the large (ranking) constants directory */
- public Path largeConstantsPath() {
+ public Path largeConstantsContentPath() {
+ return storedModelPath().append("constants");
+ }
+
+ /** Path to the large (ranking) constants directory */
+ public Path largeConstantsInfoPath() {
return storedModelReplicatedPath().append("constants");
}
@@ -695,27 +723,19 @@ public class ConvertedModel {
/** Encapsulates the arguments of a specific model output */
static class FeatureArguments {
- private final String modelName;
- private final Path modelPath;
-
/** Optional arguments */
private final Optional<String> signature, output;
public FeatureArguments(Arguments arguments) {
- this(Path.fromString(asString(arguments.expressions().get(0))),
- optionalArgument(1, arguments),
+ this(optionalArgument(1, arguments),
optionalArgument(2, arguments));
}
- public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) {
- this.modelPath = modelPath;
- this.modelName = toModelName(modelPath);
+ public FeatureArguments(Optional<String> signature, Optional<String> output) {
this.signature = signature;
this.output = output;
}
- public Path modelPath() { return modelPath; }
-
public Optional<String> signature() { return signature; }
public Optional<String> output() { return output; }
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index 36dc200f3c9..229ae0ebaaf 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -13,6 +13,7 @@ import java.io.File;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
+import java.util.Optional;
/**
* Replaces instances of the onnx(model-path, output)
@@ -41,10 +42,11 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
if ( ! feature.getName().equals("onnx")) return feature;
try {
+ // TODO: Put modelPath in FeatureArguments instead
Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
- convertedOnnxModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context));
- return convertedModel.expression(asFeatureArguments(feature.getArguments()));
+ convertedOnnxModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context));
+ return convertedModel.expression(asFeatureArguments(feature.getArguments()), context);
}
catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use Onnx model from " + feature, e);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 619c13da764..bcb8ef1521d 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -41,8 +41,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
try {
Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
- convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context));
- return convertedModel.expression(asFeatureArguments(feature.getArguments()));
+ convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context));
+ return convertedModel.expression(asFeatureArguments(feature.getArguments()), context);
}
catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
index e6b08ab0350..b4a5069b9d6 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
@@ -43,8 +43,8 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr
try {
Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
- convertedXGBoostModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context));
- return convertedModel.expression(asFeatureArguments(feature.getArguments()));
+ convertedXGBoostModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context));
+ return convertedModel.expression(asFeatureArguments(feature.getArguments()), context);
} catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use XGBoost model from " + feature, e);
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index 73dd60f63eb..3e9d188670e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -8,6 +8,7 @@ import com.yahoo.config.ConfigInstance;
import com.yahoo.config.ConfigInstance.Builder;
import com.yahoo.config.ConfigurationRuntimeException;
import com.yahoo.config.FileReference;
+import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.application.api.ValidationId;
@@ -26,11 +27,13 @@ import com.yahoo.config.model.producer.AbstractConfigProducerRoot;
import com.yahoo.config.model.producer.UserConfigRepo;
import com.yahoo.config.provision.AllocatedHosts;
import com.yahoo.log.LogLevel;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.RankingConstants;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RankProfileList;
+import com.yahoo.searchdefinition.expressiontransforms.ConvertedModel;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
@@ -162,7 +165,9 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
this.applicationPackage = deployState.getApplicationPackage();
root = builder.getRoot(VespaModel.ROOT_CONFIGID, deployState, this);
- createGlobalRankProfiles(deployState.getImportedModels(), deployState.rankProfileRegistry());
+ createGlobalRankProfiles(deployState.getImportedModels(),
+ deployState.rankProfileRegistry(),
+ deployState.getQueryProfiles().getRegistry());
this.rankProfileList = new RankProfileList(null, // null search -> global
AttributeFields.empty,
deployState.rankProfileRegistry(),
@@ -220,14 +225,30 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
* Creates a rank profile not attached to any search definition, for each imported model in the application package
*/
private ImmutableList<RankProfile> createGlobalRankProfiles(ImportedModels importedModels,
- RankProfileRegistry rankProfileRegistry) {
+ RankProfileRegistry rankProfileRegistry,
+ QueryProfileRegistry queryProfiles) {
List<RankProfile> profiles = new ArrayList<>();
- for (ImportedModel model : importedModels.all()) {
- RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry);
- for (Pair<String, RankingExpression> entry : model.outputExpressions()) {
- profile.addMacro(entry.getFirst(), false).setRankingExpression(entry.getSecond());
+ if ( ! importedModels.all().isEmpty()) { // models/ directory is available
+ for (ImportedModel model : importedModels.all()) {
+ RankProfile profile = new RankProfile(model.name(), this, rankProfileRegistry);
+ rankProfileRegistry.add(profile);
+ ConvertedModel convertedModel = ConvertedModel.fromSource(model.name(), model.name(), profile, queryProfiles, model);
+ for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) {
+ profile.addMacro(entry.getKey(), false).setRankingExpression(entry.getValue());
+ }
+ }
+ }
+ else { // generated and stored model information may be available instead
+ ApplicationFile generatedModelsDir = applicationPackage.getFile(ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR);
+ for (ApplicationFile generatedModelDir : generatedModelsDir.listFiles()) {
+ String modelName = generatedModelDir.getPath().last();
+ RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry);
+ rankProfileRegistry.add(profile);
+ ConvertedModel convertedModel = ConvertedModel.fromStore(modelName, modelName, profile);
+ for (Map.Entry<String, RankingExpression> entry : convertedModel.expressions().entrySet()) {
+ profile.addMacro(entry.getKey(), false).setRankingExpression(entry.getValue());
+ }
}
- rankProfileRegistry.add(profile);
}
return ImmutableList.copyOf(profiles);
}