summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java201
1 files changed, 99 insertions, 102 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index e2236feb336..adf5c81283e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -13,6 +13,7 @@ import com.yahoo.searchdefinition.FeatureNames;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
@@ -139,7 +140,7 @@ public class ConvertedModel {
public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) {
RankingExpression expression = selectExpression(arguments);
if (sourceModel.isPresent()) // we can verify
- verifyRequiredMacros(expression, sourceModel.get(), context.rankProfile(), context.queryProfiles());
+ verifyRequiredFunctions(expression, sourceModel.get(), context.rankProfile(), context.queryProfiles());
return expression.getRoot();
}
@@ -183,41 +184,41 @@ public class ConvertedModel {
QueryProfileRegistry queryProfiles,
ModelStore store) {
// Add constants
- Set<String> constantsReplacedByMacros = new HashSet<>();
+ Set<String> constantsReplacedByFunctions = new HashSet<>();
model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
- constantsReplacedByMacros, k, v));
+ constantsReplacedByFunctions, k, v));
- // Add macros
- addGeneratedMacros(model, profile);
+ // Add functions
+ addGeneratedFunctions(model, profile);
// Add expressions
Map<String, RankingExpression> expressions = new HashMap<>();
for (Pair<String, RankingExpression> output : model.outputExpressions()) {
addExpression(output.getSecond(), output.getFirst(),
- constantsReplacedByMacros,
+ constantsReplacedByFunctions,
model, store, profile, queryProfiles,
expressions);
}
- // Transform and save macro - must come after reading expressions due to optimization transforms
- // and must use the macro expression added to the profile, which may differ from the one saved in the model,
+ // Transform and save function - must come after reading expressions due to optimization transforms
+ // and must use the function expression added to the profile, which may differ from the one saved in the model,
// after rewrite
- model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k,
- profile.getMacros().get(k).getRankingExpression()));
+ model.functions().forEach((k, v) -> transformGeneratedFunction(store, constantsReplacedByFunctions, k,
+ profile.getFunctions().get(k).function().getBody()));
return expressions;
}
private static void addExpression(RankingExpression expression,
String expressionName,
- Set<String> constantsReplacedByMacros,
+ Set<String> constantsReplacedByFunctions,
ImportedModel model,
ModelStore store,
RankProfile profile,
QueryProfileRegistry queryProfiles,
Map<String, RankingExpression> expressions) {
- expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+ expression = replaceConstantsByFunctions(expression, constantsReplacedByFunctions);
reduceBatchDimensions(expression, model, profile, queryProfiles);
store.writeExpression(expressionName, expression);
expressions.put(expressionName, expression);
@@ -232,8 +233,8 @@ public class ConvertedModel {
profile.rankingConstants().add(constant);
}
- for (Pair<String, RankingExpression> macro : store.readMacros()) {
- addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond());
+ for (Pair<String, RankingExpression> function : store.readFunctions()) {
+ addGeneratedFunctionToProfile(profile, function.getFirst(), function.getSecond());
}
return store.readExpressions();
@@ -247,16 +248,16 @@ public class ConvertedModel {
private static void transformLargeConstant(ModelStore store,
RankProfile profile,
QueryProfileRegistry queryProfiles,
- Set<String> constantsReplacedByMacros,
+ Set<String> constantsReplacedByFunctions,
String constantName,
Tensor constantValue) {
- RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
- if (macroOverridingConstant != null) {
- TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles));
- if ( ! macroType.equals(constantValue.type()))
- throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
- typeMismatchExplanation(constantValue.type(), macroType));
- constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
+ RankProfile.RankingExpressionFunction rankingExpressionFunctionOverridingConstant = profile.getFunctions().get(constantName);
+ if (rankingExpressionFunctionOverridingConstant != null) {
+ TensorType functionType = rankingExpressionFunctionOverridingConstant.function().getBody().type(profile.typeContext(queryProfiles));
+ if ( ! functionType.equals(constantValue.type()))
+ throw new IllegalArgumentException("Function '" + constantName + "' replaces the constant with this name. " +
+ typeMismatchExplanation(constantValue.type(), functionType));
+ constantsReplacedByFunctions.add(constantName); // will replace constant(constantName) by constantName later
}
else {
Path constantPath = store.writeLargeConstant(constantName, constantValue);
@@ -267,79 +268,75 @@ public class ConvertedModel {
}
}
- private static void transformGeneratedMacro(ModelStore store,
- Set<String> constantsReplacedByMacros,
- String macroName,
- RankingExpression expression) {
+ private static void transformGeneratedFunction(ModelStore store,
+ Set<String> constantsReplacedByFunctions,
+ String functionName,
+ RankingExpression expression) {
- expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- store.writeMacro(macroName, expression);
+ expression = replaceConstantsByFunctions(expression, constantsReplacedByFunctions);
+ store.writeFunction(functionName, expression);
}
- private static void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
- if (profile.getMacros().containsKey(macroName)) {
- if ( ! profile.getMacros().get(macroName).getRankingExpression().equals(expression))
- throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists in " + profile +
+ private static void addGeneratedFunctionToProfile(RankProfile profile, String functionName, RankingExpression expression) {
+ if (profile.getFunctions().containsKey(functionName)) {
+ if ( ! profile.getFunctions().get(functionName).function().getBody().equals(expression))
+ throw new IllegalArgumentException("Generated function '" + functionName + "' already exists in " + profile +
" - with a different definition" +
- ": Has\n" + profile.getMacros().get(macroName).getRankingExpression() +
+ ": Has\n" + profile.getFunctions().get(functionName).function().getBody() +
"\nwant to add " + expression + "\n");
return;
}
- RankProfile.Macro macro = profile.addMacro(macroName, false); // TODO: Inline if only used once
- macro.setRankingExpression(expression);
- macro.setTextualExpression(expression.getRoot().toString());
+ profile.addFunction(new ExpressionFunction(functionName, expression), false); // TODO: Inline if only used once
}
/**
- * Verify that the macros referred in the given expression exists in the given rank profile,
- * and return tensors of the types specified in requiredMacros.
+ * Verify that the functions referred in the given expression exists in the given rank profile,
+ * and return tensors of the types specified in requiredFunctions.
*/
- private static void verifyRequiredMacros(RankingExpression expression, ImportedModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
- Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames, model);
- for (String macroName : macroNames) {
- TensorType requiredType = model.requiredMacros().get(macroName);
- if (requiredType == null) continue; // Not a required macro
-
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- if (macro == null)
- throw new IllegalArgumentException("Model refers input '" + macroName +
- "' of type " + requiredType + " but this macro is not present in " +
+ private static void verifyRequiredFunctions(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
+ Set<String> functionNames = new HashSet<>();
+ addFunctionNamesIn(expression.getRoot(), functionNames, model);
+ for (String functionName : functionNames) {
+ TensorType requiredType = model.requiredFunctions().get(functionName);
+ if (requiredType == null) continue; // Not a required function
+
+ RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName);
+ if (rankingExpressionFunction == null)
+ throw new IllegalArgumentException("Model refers input '" + functionName +
+ "' of type " + requiredType + " but this function is not present in " +
profile);
// TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second
// phase and summary features), as it may only resolve correctly given those bindings
- // Or, probably better, annotate the macros with type constraints here and verify during general
+ // Or, probably better, annotate the functions with type constraints here and verify during general
// type verification
- TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
+ TensorType actualType = rankingExpressionFunction.function().getBody().getRoot().type(profile.typeContext(queryProfiles));
if ( actualType == null)
- throw new IllegalArgumentException("Model refers input '" + macroName +
+ throw new IllegalArgumentException("Model refers input '" + functionName +
"' of type " + requiredType +
- " which must be produced by a macro in the rank profile, but " +
- "this macro references a feature which is not declared");
+ " which must be produced by a function in the rank profile, but " +
+ "this function references a feature which is not declared");
if ( ! actualType.isAssignableTo(requiredType))
- throw new IllegalArgumentException("Model refers input '" + macroName + "'. " +
+ throw new IllegalArgumentException("Model refers input '" + functionName + "'. " +
typeMismatchExplanation(requiredType, actualType));
}
}
private static String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
- return "The required type of this is " + requiredType + ", but this macro returns " + actualType +
+ return "The required type of this is " + requiredType + ", but this function returns " + actualType +
(actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " +
"in query profile types - see the documentation."
: "");
}
- /**
- * Add the generated macros to the rank profile
- */
- private static void addGeneratedMacros(ImportedModel model, RankProfile profile) {
- model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v.copy()));
+ /** Add the generated functions to the rank profile */
+ private static void addGeneratedFunctions(ImportedModel model, RankProfile profile) {
+ model.functions().forEach((k, v) -> addGeneratedFunctionToProfile(profile, k, v.copy()));
}
/**
* Check if batch dimensions of inputs can be reduced out. If the input
- * macro specifies that a single exemplar should be evaluated, we can
+ * function specifies that a single exemplar should be evaluated, we can
* reduce the batch dimension out.
*/
private static void reduceBatchDimensions(RankingExpression expression, ImportedModel model,
@@ -347,19 +344,19 @@ public class ConvertedModel {
TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
- // Check generated macros for inputs to reduce
- Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames, model);
- for (String macroName : macroNames) {
- if ( ! model.macros().containsKey(macroName)) continue;
+ // Check generated functions for inputs to reduce
+ Set<String> functionNames = new HashSet<>();
+ addFunctionNamesIn(expression.getRoot(), functionNames, model);
+ for (String functionName : functionNames) {
+ if ( ! model.functions().containsKey(functionName)) continue;
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- if (macro == null) {
- throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
- "but this macro is not present in " + profile);
+ RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName);
+ if (rankingExpressionFunction == null) {
+ throw new IllegalArgumentException("Model refers to generated function '" + functionName +
+ "but this function is not present in " + profile);
}
- RankingExpression macroExpression = macro.getRankingExpression();
- macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
+ RankingExpression functionExpression = rankingExpressionFunction.function().getBody();
+ functionExpression.setRoot(reduceBatchDimensionsAtInput(functionExpression.getRoot(), model, typeContext));
}
// Check expression for inputs to reduce
@@ -378,7 +375,7 @@ public class ConvertedModel {
List<ExpressionNode> children = ((TensorFunctionNode)node).children();
if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (model.requiredMacros().containsKey(referenceNode.getName())) {
+ if (model.requiredFunctions().containsKey(referenceNode.getName())) {
return reduceBatchDimensionExpression(tensorFunction, typeContext);
}
}
@@ -386,7 +383,7 @@ public class ConvertedModel {
}
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode) node;
- if (model.requiredMacros().containsKey(referenceNode.getName())) {
+ if (model.requiredFunctions().containsKey(referenceNode.getName())) {
return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
}
}
@@ -447,47 +444,47 @@ public class ConvertedModel {
}
/**
- * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
+ * If a constant c is overridden by a function, we need to replace instances of "constant(c)" by "c" in expressions.
* This method does that for the given expression and returns the result.
*/
- private static RankingExpression replaceConstantsByMacros(RankingExpression expression,
- Set<String> constantsReplacedByMacros) {
- if (constantsReplacedByMacros.isEmpty()) return expression;
+ private static RankingExpression replaceConstantsByFunctions(RankingExpression expression,
+ Set<String> constantsReplacedByFunctions) {
+ if (constantsReplacedByFunctions.isEmpty()) return expression;
return new RankingExpression(expression.getName(),
- replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
+ replaceConstantsByFunctions(expression.getRoot(), constantsReplacedByFunctions));
}
- private static ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
+ private static ExpressionNode replaceConstantsByFunctions(ExpressionNode node, Set<String> constantsReplacedByFunctions) {
if (node instanceof ReferenceNode) {
Reference reference = ((ReferenceNode)node).reference();
if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
String argument = reference.simpleArgument().get();
- if (constantsReplacedByMacros.contains(argument))
+ if (constantsReplacedByFunctions.contains(argument))
return new ReferenceNode(argument);
}
}
if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above
CompositeNode composite = (CompositeNode)node;
return composite.setChildren(composite.children().stream()
- .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros))
+ .map(child -> replaceConstantsByFunctions(child, constantsReplacedByFunctions))
.collect(Collectors.toList()));
}
return node;
}
- private static void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) {
+ private static void addFunctionNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) {
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode)node;
- if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
+ if (referenceNode.getOutput() == null) { // function references cannot specify outputs
names.add(referenceNode.getName());
- if (model.macros().containsKey(referenceNode.getName())) {
- addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model);
+ if (model.functions().containsKey(referenceNode.getName())) {
+ addFunctionNamesIn(model.functions().get(referenceNode.getName()).getRoot(), names, model);
}
}
}
else if (node instanceof CompositeNode) {
for (ExpressionNode child : ((CompositeNode)node).children())
- addMacroNamesIn(child, names, model);
+ addFunctionNamesIn(child, names, model);
}
}
@@ -551,19 +548,19 @@ public class ConvertedModel {
return expressions;
}
- /** Adds this macro expression to the application package so it can be read later. */
- void writeMacro(String name, RankingExpression expression) {
- application.getFile(modelFiles.macrosPath()).appendFile(name + "\t" +
- expression.getRoot().toString() + "\n");
+ /** Adds this function expression to the application package so it can be read later. */
+ void writeFunction(String name, RankingExpression expression) {
+ application.getFile(modelFiles.functionsPath()).appendFile(name + "\t" +
+ expression.getRoot().toString() + "\n");
}
- /** Reads the previously stored macro expressions for these arguments */
- List<Pair<String, RankingExpression>> readMacros() {
+ /** Reads the previously stored function expressions for these arguments */
+ List<Pair<String, RankingExpression>> readFunctions() {
try {
- ApplicationFile file = application.getFile(modelFiles.macrosPath());
+ ApplicationFile file = application.getFile(modelFiles.functionsPath());
if ( ! file.exists()) return Collections.emptyList();
- List<Pair<String, RankingExpression>> macros = new ArrayList<>();
+ List<Pair<String, RankingExpression>> functions = new ArrayList<>();
BufferedReader reader = new BufferedReader(file.createReader());
String line;
while (null != (line = reader.readLine())) {
@@ -571,13 +568,13 @@ public class ConvertedModel {
String name = parts[0];
try {
RankingExpression expression = new RankingExpression(parts[0], parts[1]);
- macros.add(new Pair<>(name, expression));
+ functions.add(new Pair<>(name, expression));
}
catch (ParseException e) {
throw new IllegalStateException("Could not parse " + name, e);
}
}
- return macros;
+ return functions;
}
catch (IOException e) {
throw new UncheckedIOException(e);
@@ -725,9 +722,9 @@ public class ConvertedModel {
return storedModelReplicatedPath().append("constants");
}
- /** Path to the macros file */
- public Path macrosPath() {
- return storedModelReplicatedPath().append("macros.txt");
+ /** Path to the functions file */
+ public Path functionsPath() {
+ return storedModelReplicatedPath().append("functions.txt");
}
}