diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java | 201 |
1 files changed, 99 insertions, 102 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index e2236feb336..adf5c81283e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -13,6 +13,7 @@ import com.yahoo.searchdefinition.FeatureNames; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; @@ -139,7 +140,7 @@ public class ConvertedModel { public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) { RankingExpression expression = selectExpression(arguments); if (sourceModel.isPresent()) // we can verify - verifyRequiredMacros(expression, sourceModel.get(), context.rankProfile(), context.queryProfiles()); + verifyRequiredFunctions(expression, sourceModel.get(), context.rankProfile(), context.queryProfiles()); return expression.getRoot(); } @@ -183,41 +184,41 @@ public class ConvertedModel { QueryProfileRegistry queryProfiles, ModelStore store) { // Add constants - Set<String> constantsReplacedByMacros = new HashSet<>(); + Set<String> constantsReplacedByFunctions = new HashSet<>(); model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, - constantsReplacedByMacros, k, v)); + constantsReplacedByFunctions, k, v)); - // Add macros - addGeneratedMacros(model, profile); + // Add functions + addGeneratedFunctions(model, profile); // Add expressions Map<String, RankingExpression> expressions = new HashMap<>(); for (Pair<String, RankingExpression> output : model.outputExpressions()) { addExpression(output.getSecond(), output.getFirst(), - constantsReplacedByMacros, + constantsReplacedByFunctions, model, store, profile, queryProfiles, expressions); } - // Transform and save macro - must come after reading expressions due to optimization transforms - // and must use the macro expression added to the profile, which may differ from the one saved in the model, + // Transform and save function - must come after reading expressions due to optimization transforms + // and must use the function expression added to the profile, which may differ from the one saved in the model, // after rewrite - model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, - profile.getMacros().get(k).getRankingExpression())); + model.functions().forEach((k, v) -> transformGeneratedFunction(store, constantsReplacedByFunctions, k, + profile.getFunctions().get(k).function().getBody())); return expressions; } private static void addExpression(RankingExpression expression, String expressionName, - Set<String> constantsReplacedByMacros, + Set<String> constantsReplacedByFunctions, ImportedModel model, ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, Map<String, RankingExpression> expressions) { - expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); + expression = replaceConstantsByFunctions(expression, constantsReplacedByFunctions); reduceBatchDimensions(expression, model, profile, queryProfiles); store.writeExpression(expressionName, expression); expressions.put(expressionName, expression); @@ -232,8 +233,8 @@ public class ConvertedModel { profile.rankingConstants().add(constant); } - for (Pair<String, RankingExpression> macro : store.readMacros()) { - addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); + for (Pair<String, RankingExpression> function : store.readFunctions()) { + addGeneratedFunctionToProfile(profile, function.getFirst(), function.getSecond()); } return store.readExpressions(); @@ -247,16 +248,16 @@ public class ConvertedModel { private static void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, - Set<String> constantsReplacedByMacros, + Set<String> constantsReplacedByFunctions, String constantName, Tensor constantValue) { - RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); - if (macroOverridingConstant != null) { - TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); - if ( ! macroType.equals(constantValue.type())) - throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + - typeMismatchExplanation(constantValue.type(), macroType)); - constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later + RankProfile.RankingExpressionFunction rankingExpressionFunctionOverridingConstant = profile.getFunctions().get(constantName); + if (rankingExpressionFunctionOverridingConstant != null) { + TensorType functionType = rankingExpressionFunctionOverridingConstant.function().getBody().type(profile.typeContext(queryProfiles)); + if ( ! functionType.equals(constantValue.type())) + throw new IllegalArgumentException("Function '" + constantName + "' replaces the constant with this name. " + + typeMismatchExplanation(constantValue.type(), functionType)); + constantsReplacedByFunctions.add(constantName); // will replace constant(constantName) by constantName later } else { Path constantPath = store.writeLargeConstant(constantName, constantValue); @@ -267,79 +268,75 @@ public class ConvertedModel { } } - private static void transformGeneratedMacro(ModelStore store, - Set<String> constantsReplacedByMacros, - String macroName, - RankingExpression expression) { + private static void transformGeneratedFunction(ModelStore store, + Set<String> constantsReplacedByFunctions, + String functionName, + RankingExpression expression) { - expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); - store.writeMacro(macroName, expression); + expression = replaceConstantsByFunctions(expression, constantsReplacedByFunctions); + store.writeFunction(functionName, expression); } - private static void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { - if (profile.getMacros().containsKey(macroName)) { - if ( ! profile.getMacros().get(macroName).getRankingExpression().equals(expression)) - throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists in " + profile + + private static void addGeneratedFunctionToProfile(RankProfile profile, String functionName, RankingExpression expression) { + if (profile.getFunctions().containsKey(functionName)) { + if ( ! profile.getFunctions().get(functionName).function().getBody().equals(expression)) + throw new IllegalArgumentException("Generated function '" + functionName + "' already exists in " + profile + " - with a different definition" + - ": Has\n" + profile.getMacros().get(macroName).getRankingExpression() + + ": Has\n" + profile.getFunctions().get(functionName).function().getBody() + "\nwant to add " + expression + "\n"); return; } - RankProfile.Macro macro = profile.addMacro(macroName, false); // TODO: Inline if only used once - macro.setRankingExpression(expression); - macro.setTextualExpression(expression.getRoot().toString()); + profile.addFunction(new ExpressionFunction(functionName, expression), false); // TODO: Inline if only used once } /** - * Verify that the macros referred in the given expression exists in the given rank profile, - * and return tensors of the types specified in requiredMacros. + * Verify that the functions referred in the given expression exists in the given rank profile, + * and return tensors of the types specified in requiredFunctions. */ - private static void verifyRequiredMacros(RankingExpression expression, ImportedModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { - Set<String> macroNames = new HashSet<>(); - addMacroNamesIn(expression.getRoot(), macroNames, model); - for (String macroName : macroNames) { - TensorType requiredType = model.requiredMacros().get(macroName); - if (requiredType == null) continue; // Not a required macro - - RankProfile.Macro macro = profile.getMacros().get(macroName); - if (macro == null) - throw new IllegalArgumentException("Model refers input '" + macroName + - "' of type " + requiredType + " but this macro is not present in " + + private static void verifyRequiredFunctions(RankingExpression expression, ImportedModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { + Set<String> functionNames = new HashSet<>(); + addFunctionNamesIn(expression.getRoot(), functionNames, model); + for (String functionName : functionNames) { + TensorType requiredType = model.requiredFunctions().get(functionName); + if (requiredType == null) continue; // Not a required function + + RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); + if (rankingExpressionFunction == null) + throw new IllegalArgumentException("Model refers input '" + functionName + + "' of type " + requiredType + " but this function is not present in " + profile); // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second // phase and summary features), as it may only resolve correctly given those bindings - // Or, probably better, annotate the macros with type constraints here and verify during general + // Or, probably better, annotate the functions with type constraints here and verify during general // type verification - TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles)); + TensorType actualType = rankingExpressionFunction.function().getBody().getRoot().type(profile.typeContext(queryProfiles)); if ( actualType == null) - throw new IllegalArgumentException("Model refers input '" + macroName + + throw new IllegalArgumentException("Model refers input '" + functionName + "' of type " + requiredType + - " which must be produced by a macro in the rank profile, but " + - "this macro references a feature which is not declared"); + " which must be produced by a function in the rank profile, but " + + "this function references a feature which is not declared"); if ( ! actualType.isAssignableTo(requiredType)) - throw new IllegalArgumentException("Model refers input '" + macroName + "'. " + + throw new IllegalArgumentException("Model refers input '" + functionName + "'. " + typeMismatchExplanation(requiredType, actualType)); } } private static String typeMismatchExplanation(TensorType requiredType, TensorType actualType) { - return "The required type of this is " + requiredType + ", but this macro returns " + actualType + + return "The required type of this is " + requiredType + ", but this function returns " + actualType + (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " + "in query profile types - see the documentation." : ""); } - /** - * Add the generated macros to the rank profile - */ - private static void addGeneratedMacros(ImportedModel model, RankProfile profile) { - model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v.copy())); + /** Add the generated functions to the rank profile */ + private static void addGeneratedFunctions(ImportedModel model, RankProfile profile) { + model.functions().forEach((k, v) -> addGeneratedFunctionToProfile(profile, k, v.copy())); } /** * Check if batch dimensions of inputs can be reduced out. If the input - * macro specifies that a single exemplar should be evaluated, we can + * function specifies that a single exemplar should be evaluated, we can * reduce the batch dimension out. */ private static void reduceBatchDimensions(RankingExpression expression, ImportedModel model, @@ -347,19 +344,19 @@ public class ConvertedModel { TypeContext<Reference> typeContext = profile.typeContext(queryProfiles); TensorType typeBeforeReducing = expression.getRoot().type(typeContext); - // Check generated macros for inputs to reduce - Set<String> macroNames = new HashSet<>(); - addMacroNamesIn(expression.getRoot(), macroNames, model); - for (String macroName : macroNames) { - if ( ! model.macros().containsKey(macroName)) continue; + // Check generated functions for inputs to reduce + Set<String> functionNames = new HashSet<>(); + addFunctionNamesIn(expression.getRoot(), functionNames, model); + for (String functionName : functionNames) { + if ( ! model.functions().containsKey(functionName)) continue; - RankProfile.Macro macro = profile.getMacros().get(macroName); - if (macro == null) { - throw new IllegalArgumentException("Model refers to generated macro '" + macroName + - "but this macro is not present in " + profile); + RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); + if (rankingExpressionFunction == null) { + throw new IllegalArgumentException("Model refers to generated function '" + functionName + + "but this function is not present in " + profile); } - RankingExpression macroExpression = macro.getRankingExpression(); - macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext)); + RankingExpression functionExpression = rankingExpressionFunction.function().getBody(); + functionExpression.setRoot(reduceBatchDimensionsAtInput(functionExpression.getRoot(), model, typeContext)); } // Check expression for inputs to reduce @@ -378,7 +375,7 @@ public class ConvertedModel { List<ExpressionNode> children = ((TensorFunctionNode)node).children(); if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.requiredMacros().containsKey(referenceNode.getName())) { + if (model.requiredFunctions().containsKey(referenceNode.getName())) { return reduceBatchDimensionExpression(tensorFunction, typeContext); } } @@ -386,7 +383,7 @@ public class ConvertedModel { } if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) node; - if (model.requiredMacros().containsKey(referenceNode.getName())) { + if (model.requiredFunctions().containsKey(referenceNode.getName())) { return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); } } @@ -447,47 +444,47 @@ public class ConvertedModel { } /** - * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. + * If a constant c is overridden by a function, we need to replace instances of "constant(c)" by "c" in expressions. * This method does that for the given expression and returns the result. */ - private static RankingExpression replaceConstantsByMacros(RankingExpression expression, - Set<String> constantsReplacedByMacros) { - if (constantsReplacedByMacros.isEmpty()) return expression; + private static RankingExpression replaceConstantsByFunctions(RankingExpression expression, + Set<String> constantsReplacedByFunctions) { + if (constantsReplacedByFunctions.isEmpty()) return expression; return new RankingExpression(expression.getName(), - replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); + replaceConstantsByFunctions(expression.getRoot(), constantsReplacedByFunctions)); } - private static ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { + private static ExpressionNode replaceConstantsByFunctions(ExpressionNode node, Set<String> constantsReplacedByFunctions) { if (node instanceof ReferenceNode) { Reference reference = ((ReferenceNode)node).reference(); if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { String argument = reference.simpleArgument().get(); - if (constantsReplacedByMacros.contains(argument)) + if (constantsReplacedByFunctions.contains(argument)) return new ReferenceNode(argument); } } if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above CompositeNode composite = (CompositeNode)node; return composite.setChildren(composite.children().stream() - .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) + .map(child -> replaceConstantsByFunctions(child, constantsReplacedByFunctions)) .collect(Collectors.toList())); } return node; } - private static void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) { + private static void addFunctionNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) { if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode)node; - if (referenceNode.getOutput() == null) { // macro references cannot specify outputs + if (referenceNode.getOutput() == null) { // function references cannot specify outputs names.add(referenceNode.getName()); - if (model.macros().containsKey(referenceNode.getName())) { - addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model); + if (model.functions().containsKey(referenceNode.getName())) { + addFunctionNamesIn(model.functions().get(referenceNode.getName()).getRoot(), names, model); } } } else if (node instanceof CompositeNode) { for (ExpressionNode child : ((CompositeNode)node).children()) - addMacroNamesIn(child, names, model); + addFunctionNamesIn(child, names, model); } } @@ -551,19 +548,19 @@ public class ConvertedModel { return expressions; } - /** Adds this macro expression to the application package so it can be read later. */ - void writeMacro(String name, RankingExpression expression) { - application.getFile(modelFiles.macrosPath()).appendFile(name + "\t" + - expression.getRoot().toString() + "\n"); + /** Adds this function expression to the application package so it can be read later. */ + void writeFunction(String name, RankingExpression expression) { + application.getFile(modelFiles.functionsPath()).appendFile(name + "\t" + + expression.getRoot().toString() + "\n"); } - /** Reads the previously stored macro expressions for these arguments */ - List<Pair<String, RankingExpression>> readMacros() { + /** Reads the previously stored function expressions for these arguments */ + List<Pair<String, RankingExpression>> readFunctions() { try { - ApplicationFile file = application.getFile(modelFiles.macrosPath()); + ApplicationFile file = application.getFile(modelFiles.functionsPath()); if ( ! file.exists()) return Collections.emptyList(); - List<Pair<String, RankingExpression>> macros = new ArrayList<>(); + List<Pair<String, RankingExpression>> functions = new ArrayList<>(); BufferedReader reader = new BufferedReader(file.createReader()); String line; while (null != (line = reader.readLine())) { @@ -571,13 +568,13 @@ public class ConvertedModel { String name = parts[0]; try { RankingExpression expression = new RankingExpression(parts[0], parts[1]); - macros.add(new Pair<>(name, expression)); + functions.add(new Pair<>(name, expression)); } catch (ParseException e) { throw new IllegalStateException("Could not parse " + name, e); } } - return macros; + return functions; } catch (IOException e) { throw new UncheckedIOException(e); @@ -725,9 +722,9 @@ public class ConvertedModel { return storedModelReplicatedPath().append("constants"); } - /** Path to the macros file */ - public Path macrosPath() { - return storedModelReplicatedPath().append("macros.txt"); + /** Path to the functions file */ + public Path functionsPath() { + return storedModelReplicatedPath().append("functions.txt"); } } |