summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java245
1 files changed, 113 insertions, 132 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
index 935b9200868..0911f567fa1 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
@@ -1,6 +1,5 @@
package com.yahoo.searchdefinition.expressiontransforms;
-import com.google.common.collect.ImmutableMap;
import com.yahoo.collections.Pair;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
@@ -17,6 +16,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
@@ -38,6 +38,7 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat;
import java.io.BufferedReader;
import java.io.File;
+import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
@@ -64,91 +65,49 @@ import java.util.stream.Collectors;
public class ConvertedModel {
private final String modelName;
- private final String modelDescription;
- private final ImmutableMap<String, RankingExpression> expressions;
-
- /** The source importedModel, or empty if this was created from a stored converted model */
- private final Optional<ImportedModel> sourceModel;
-
- private ConvertedModel(String modelName,
- String modelDescription,
- Map<String, RankingExpression> expressions,
- Optional<ImportedModel> sourceModel) {
- this.modelName = modelName;
- this.modelDescription = modelDescription;
- this.expressions = ImmutableMap.copyOf(expressions);
- this.sourceModel = sourceModel;
- }
+ private final Path modelPath;
/**
- * Create and store a converted model for a rank profile given from either an imported model,
- * or (if unavailable) from stored application package data.
+ * The ranking expressions of this, indexed by their name. which is a 1-3 part string separated by dots
+ * where the first part is always the model name, the second the signature or (if none)
+ * expression name (if more than one), and the third is the output name (if any).
*/
- public static ConvertedModel fromSourceOrStore(Path modelPath, RankProfileTransformContext context) {
- File sourceModel = sourceModelFile(context.rankProfile().applicationPackage(), modelPath);
- if (sourceModel.exists())
- return fromSource(toModelName(modelPath),
- modelPath.toString(),
- context.rankProfile(),
- context.queryProfiles(),
- context.importedModels().get(sourceModel)); // TODO: Convert to name here, make sure its done just one way
- else
- return fromStore(toModelName(modelPath),
- modelPath.toString(),
- context.rankProfile());
- }
-
- public static ConvertedModel fromSource(String modelName,
- String modelDescription,
- RankProfile rankProfile,
- QueryProfileRegistry queryProfileRegistry,
- ImportedModel importedModel) {
- ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName);
- return new ConvertedModel(modelName,
- modelDescription,
- convertAndStore(importedModel, rankProfile, queryProfileRegistry, modelStore),
- Optional.of(importedModel));
- }
-
- public static ConvertedModel fromStore(String modelName,
- String modelDescription,
- RankProfile rankProfile) {
- ModelStore modelStore = new ModelStore(rankProfile.applicationPackage(), modelName);
- return new ConvertedModel(modelName,
- modelDescription,
- convertStored(modelStore, rankProfile),
- Optional.empty());
- }
+ private final Map<String, RankingExpression> expressions;
/**
- * Returns all the output expressions of this indexed by name. The names consist of one or two parts
- * separated by dot, where the first part is the signature name
- * if signatures are used, or the expression name if signatures are not used and there are multiple
- * expressions, and the second is the output name if signature names are used.
+ * Create a converted model for a rank profile given from either an imported model,
+ * or (if unavailable) from stored application package data.
*/
- public Map<String, RankingExpression> expressions() { return expressions; }
+ public ConvertedModel(Path modelPath, RankProfileTransformContext context) {
+ this.modelPath = modelPath;
+ this.modelName = toModelName(modelPath);
+ ModelStore store = new ModelStore(context.rankProfile().applicationPackage(), modelPath);
+ if ( store.hasSourceModel())
+ expressions = convertModel(store, context.rankProfile(), context.queryProfiles(), context.importedModels());
+ else
+ expressions = transformFromStoredModel(store, context.rankProfile());
+ }
- /**
- * Returns the expression matching the given arguments.
- */
- public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) {
- RankingExpression expression = selectExpression(arguments);
- if (sourceModel.isPresent()) // we can verify
- verifyRequiredMacros(expression, sourceModel.get(), context.rankProfile(), context.queryProfiles());
- return expression.getRoot();
+ private Map<String, RankingExpression> convertModel(ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ ImportedModels importedModels) {
+ ImportedModel model = importedModels.get(store.sourceModelFile());
+ return transformFromImportedModel(model, store, profile, queryProfiles);
}
- private RankingExpression selectExpression(FeatureArguments arguments) {
+ /** Returns the expression matching the given arguments */
+ public ExpressionNode expression(FeatureArguments arguments) {
if (expressions.isEmpty())
throw new IllegalArgumentException("No expressions available in " + this);
RankingExpression expression = expressions.get(arguments.toName());
- if (expression != null) return expression;
+ if (expression != null) return expression.getRoot();
if ( ! arguments.signature().isPresent()) {
if (expressions.size() > 1)
throw new IllegalArgumentException("Multiple candidate expressions " + missingExpressionMessageSuffix());
- return expressions.values().iterator().next();
+ return expressions.values().iterator().next().getRoot();
}
if ( ! arguments.output().isPresent()) {
@@ -160,23 +119,21 @@ public class ConvertedModel {
if (entriesWithTheRightPrefix.size() > 1)
throw new IllegalArgumentException("Multiple candidate expression named '" + arguments.signature().get() +
missingExpressionMessageSuffix());
- return entriesWithTheRightPrefix.get(0).getValue();
+ return entriesWithTheRightPrefix.get(0).getValue().getRoot();
}
throw new IllegalArgumentException("No expression '" + arguments.toName() + missingExpressionMessageSuffix());
}
private String missingExpressionMessageSuffix() {
- return "' in model '" + modelDescription + "'. " +
+ return "' in model '" + this.modelPath + "'. " +
"Available expressions: " + expressions.keySet().stream().collect(Collectors.joining(", "));
}
- // ----------------------- Static model conversion/storage below here
-
- private static Map<String, RankingExpression> convertAndStore(ImportedModel model,
- RankProfile profile,
- QueryProfileRegistry queryProfiles,
- ModelStore store) {
+ private Map<String, RankingExpression> transformFromImportedModel(ImportedModel model,
+ ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles) {
// Add constants
Set<String> constantsReplacedByMacros = new HashSet<>();
model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
@@ -204,21 +161,22 @@ public class ConvertedModel {
return expressions;
}
- private static void addExpression(RankingExpression expression,
- String expressionName,
- Set<String> constantsReplacedByMacros,
- ImportedModel model,
- ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles,
- Map<String, RankingExpression> expressions) {
+ private void addExpression(RankingExpression expression,
+ String expressionName,
+ Set<String> constantsReplacedByMacros,
+ ImportedModel model,
+ ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ Map<String, RankingExpression> expressions) {
expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+ verifyRequiredMacros(expression, model, profile, queryProfiles);
reduceBatchDimensions(expression, model, profile, queryProfiles);
store.writeExpression(expressionName, expression);
expressions.put(expressionName, expression);
}
- private static Map<String, RankingExpression> convertStored(ModelStore store, RankProfile profile) {
+ private Map<String, RankingExpression> transformFromStoredModel(ModelStore store, RankProfile profile) {
for (Pair<String, Tensor> constant : store.readSmallConstants())
profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
@@ -234,12 +192,12 @@ public class ConvertedModel {
return store.readExpressions();
}
- private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
+ private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
store.writeSmallConstant(constantName, constantValue);
profile.addConstant(constantName, asValue(constantValue));
}
- private static void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
+ private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
Set<String> constantsReplacedByMacros,
String constantName, Tensor constantValue) {
RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
@@ -259,7 +217,7 @@ public class ConvertedModel {
}
}
- private static void transformGeneratedMacro(ModelStore store,
+ private void transformGeneratedMacro(ModelStore store,
Set<String> constantsReplacedByMacros,
String macroName,
RankingExpression expression) {
@@ -268,7 +226,7 @@ public class ConvertedModel {
store.writeMacro(macroName, expression);
}
- private static void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
+ private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
if (profile.getMacros().containsKey(macroName)) {
if ( ! profile.getMacros().get(macroName).getRankingExpression().equals(expression))
throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists in " + profile +
@@ -285,8 +243,8 @@ public class ConvertedModel {
* Verify that the macros referred in the given expression exists in the given rank profile,
* and return tensors of the types specified in requiredMacros.
*/
- private static void verifyRequiredMacros(RankingExpression expression, ImportedModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
+ private void verifyRequiredMacros(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
Set<String> macroNames = new HashSet<>();
addMacroNamesIn(expression.getRoot(), macroNames, model);
for (String macroName : macroNames) {
@@ -314,7 +272,7 @@ public class ConvertedModel {
}
}
- private static String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
+ private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
return "The required type of this is " + requiredType + ", but this macro returns " + actualType +
(actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " +
"in query profile types - see the documentation."
@@ -324,7 +282,7 @@ public class ConvertedModel {
/**
* Add the generated macros to the rank profile
*/
- private static void addGeneratedMacros(ImportedModel model, RankProfile profile) {
+ private void addGeneratedMacros(ImportedModel model, RankProfile profile) {
model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v.copy()));
}
@@ -333,8 +291,8 @@ public class ConvertedModel {
* macro specifies that a single exemplar should be evaluated, we can
* reduce the batch dimension out.
*/
- private static void reduceBatchDimensions(RankingExpression expression, ImportedModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
+ private void reduceBatchDimensions(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
@@ -361,8 +319,8 @@ public class ConvertedModel {
expression.setRoot(root);
}
- private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model,
- TypeContext<Reference> typeContext) {
+ private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model,
+ TypeContext<Reference> typeContext) {
if (node instanceof TensorFunctionNode) {
TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
if (tensorFunction instanceof Rename) {
@@ -392,7 +350,7 @@ public class ConvertedModel {
return node;
}
- private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
+ private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
TensorFunction result = function;
TensorType type = function.type(context);
if (type.dimensions().size() > 1) {
@@ -414,7 +372,7 @@ public class ConvertedModel {
* for any following computation of the tensor.
*/
// TODO: determine when this is not necessary!
- private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
+ private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
if (after.equals(before)) {
return node;
}
@@ -441,14 +399,14 @@ public class ConvertedModel {
* If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
* This method does that for the given expression and returns the result.
*/
- private static RankingExpression replaceConstantsByMacros(RankingExpression expression,
+ private RankingExpression replaceConstantsByMacros(RankingExpression expression,
Set<String> constantsReplacedByMacros) {
if (constantsReplacedByMacros.isEmpty()) return expression;
return new RankingExpression(expression.getName(),
replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
}
- private static ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
+ private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
if (node instanceof ReferenceNode) {
Reference reference = ((ReferenceNode)node).reference();
if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
@@ -466,7 +424,7 @@ public class ConvertedModel {
return node;
}
- private static void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) {
+ private void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) {
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode)node;
if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
@@ -482,7 +440,7 @@ public class ConvertedModel {
}
}
- private static Value asValue(Tensor tensor) {
+ private Value asValue(Tensor tensor) {
if (tensor.type().rank() == 0)
return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
else
@@ -497,13 +455,6 @@ public class ConvertedModel {
public String toString() { return "model '" + modelName + "'"; }
/**
- * Returns the directory which contains the source model to use for these arguments
- */
- public static File sourceModelFile(ApplicationPackage application, Path sourceModelPath) {
- return application.getFileReference(ApplicationPackage.MODELS_DIR.append(sourceModelPath));
- }
-
- /**
* Provides read/write access to the correct directories of the application package given by the feature arguments
*/
static class ModelStore {
@@ -511,9 +462,20 @@ public class ConvertedModel {
private final ApplicationPackage application;
private final ModelFiles modelFiles;
- ModelStore(ApplicationPackage application, String modelName) {
+ ModelStore(ApplicationPackage application, Path modelPath) {
this.application = application;
- this.modelFiles = new ModelFiles(modelName);
+ this.modelFiles = new ModelFiles(modelPath);
+ }
+
+ public boolean hasSourceModel() {
+ return sourceModelFile().exists();
+ }
+
+ /**
+ * Returns the directory which contains the source model to use for these arguments
+ */
+ public File sourceModelFile() {
+ return application.getFileReference(ApplicationPackage.MODELS_DIR.append(modelFiles.modelPath()));
}
/**
@@ -586,7 +548,7 @@ public class ConvertedModel {
List<RankingConstant> readLargeConstants() {
try {
List<RankingConstant> constants = new ArrayList<>();
- for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsInfoPath()).listFiles()) {
+ for (ApplicationFile constantFile : application.getFile(modelFiles.largeConstantsPath()).listFiles()) {
String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
}
@@ -604,13 +566,13 @@ public class ConvertedModel {
* @return the path to the stored constant, relative to the application package root
*/
Path writeLargeConstant(String name, Tensor constant) {
- Path constantsPath = modelFiles.largeConstantsContentPath();
+ Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(modelFiles.modelPath()).append("constants");
// "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
Path constantPath = constantsPath.append(name + ".tbf");
// Remember the constant in a file we replicate in ZooKeeper
- application.getFile(modelFiles.largeConstantsInfoPath().append(name + ".constant"))
+ application.getFile(modelFiles.largeConstantsPath().append(name + ".constant"))
.writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
// Write content explicitly as a file on the file system as this is distributed using file distribution
@@ -647,8 +609,8 @@ public class ConvertedModel {
public void writeSmallConstant(String name, Tensor constant) {
// Secret file format for remembering constants:
application.getFile(modelFiles.smallConstantsPath()).appendFile(name + "\t" +
- constant.type().toString() + "\t" +
- constant.toString() + "\n");
+ constant.type().toString() + "\t" +
+ constant.toString() + "\n");
}
/** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
@@ -670,24 +632,40 @@ public class ConvertedModel {
}
}
+ private void close(Reader reader) {
+ try {
+ if (reader != null)
+ reader.close();
+ }
+ catch (IOException e) {
+ // ignore
+ }
+ }
+
}
static class ModelFiles {
- String modelName;
+ Path modelPath;
- public ModelFiles(String modelName) {
- this.modelName = modelName;
+ public ModelFiles(Path modelPath) {
+ this.modelPath = modelPath;
}
+ /** Returns modelPath with slashes replaced by underscores */
+ public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); }
+
+ /** Returns relative path to this model below the "models/" dir in the application package */
+ public Path modelPath() { return modelPath; }
+
/** Files stored below this path will be replicated in zookeeper */
public Path storedModelReplicatedPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelName);
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath());
}
- /** Files stored below this path will not be replicated in zookeeper */
+ /** Files stored below this path will not be replicated */
public Path storedModelPath() {
- return ApplicationPackage.MODELS_GENERATED_DIR.append(modelName);
+ return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath());
}
public Path expressionPath(String name) {
@@ -703,12 +681,7 @@ public class ConvertedModel {
}
/** Path to the large (ranking) constants directory */
- public Path largeConstantsContentPath() {
- return storedModelPath().append("constants");
- }
-
- /** Path to the large (ranking) constants directory */
- public Path largeConstantsInfoPath() {
+ public Path largeConstantsPath() {
return storedModelReplicatedPath().append("constants");
}
@@ -722,19 +695,27 @@ public class ConvertedModel {
/** Encapsulates the arguments of a specific model output */
static class FeatureArguments {
+ private final String modelName;
+ private final Path modelPath;
+
/** Optional arguments */
private final Optional<String> signature, output;
public FeatureArguments(Arguments arguments) {
- this(optionalArgument(1, arguments),
+ this(Path.fromString(asString(arguments.expressions().get(0))),
+ optionalArgument(1, arguments),
optionalArgument(2, arguments));
}
- public FeatureArguments(Optional<String> signature, Optional<String> output) {
+ public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) {
+ this.modelPath = modelPath;
+ this.modelName = toModelName(modelPath);
this.signature = signature;
this.output = output;
}
+ public Path modelPath() { return modelPath; }
+
public Optional<String> signature() { return signature; }
public Optional<String> output() { return output; }