summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java116
1 files changed, 67 insertions, 49 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index adf5c81283e..fb0109ed32e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -48,6 +48,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -67,14 +68,14 @@ public class ConvertedModel {
private final ModelName modelName;
private final String modelDescription;
- private final ImmutableMap<String, RankingExpression> expressions;
+ private final ImmutableMap<String, ExpressionFunction> expressions;
/** The source importedModel, or empty if this was created from a stored converted model */
private final Optional<ImportedModel> sourceModel;
private ConvertedModel(ModelName modelName,
String modelDescription,
- Map<String, RankingExpression> expressions,
+ Map<String, ExpressionFunction> expressions,
Optional<ImportedModel> sourceModel) {
this.modelName = modelName;
this.modelDescription = modelDescription;
@@ -132,23 +133,23 @@ public class ConvertedModel {
* if signatures are used, or the expression name if signatures are not used and there are multiple
* expressions, and the second is the output name if signature names are used.
*/
- public Map<String, RankingExpression> expressions() { return expressions; }
+ public Map<String, ExpressionFunction> expressions() { return expressions; }
/**
* Returns the expression matching the given arguments.
*/
public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) {
- RankingExpression expression = selectExpression(arguments);
- if (sourceModel.isPresent()) // we can verify
- verifyRequiredFunctions(expression, sourceModel.get(), context.rankProfile(), context.queryProfiles());
- return expression.getRoot();
+ ExpressionFunction expression = selectExpression(arguments);
+ if (sourceModel.isPresent()) // we should verify
+ verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles());
+ return expression.getBody().getRoot();
}
- private RankingExpression selectExpression(FeatureArguments arguments) {
+ private ExpressionFunction selectExpression(FeatureArguments arguments) {
if (expressions.isEmpty())
throw new IllegalArgumentException("No expressions available in " + this);
- RankingExpression expression = expressions.get(arguments.toName());
+ ExpressionFunction expression = expressions.get(arguments.toName());
if (expression != null) return expression;
if ( ! arguments.signature().isPresent()) {
@@ -158,7 +159,7 @@ public class ConvertedModel {
}
if ( ! arguments.output().isPresent()) {
- List<Map.Entry<String, RankingExpression>> entriesWithTheRightPrefix =
+ List<Map.Entry<String, ExpressionFunction>> entriesWithTheRightPrefix =
expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(arguments.signature().get() + ".")).collect(Collectors.toList());
if (entriesWithTheRightPrefix.size() < 1)
throw new IllegalArgumentException("No expressions named '" + arguments.signature().get() +
@@ -179,10 +180,10 @@ public class ConvertedModel {
// ----------------------- Static model conversion/storage below here
- private static Map<String, RankingExpression> convertAndStore(ImportedModel model,
- RankProfile profile,
- QueryProfileRegistry queryProfiles,
- ModelStore store) {
+ private static Map<String, ExpressionFunction> convertAndStore(ImportedModel model,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ ModelStore store) {
// Add constants
Set<String> constantsReplacedByFunctions = new HashSet<>();
model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
@@ -193,8 +194,8 @@ public class ConvertedModel {
addGeneratedFunctions(model, profile);
// Add expressions
- Map<String, RankingExpression> expressions = new HashMap<>();
- for (Pair<String, RankingExpression> output : model.outputExpressions()) {
+ Map<String, ExpressionFunction> expressions = new HashMap<>();
+ for (Pair<String, ExpressionFunction> output : model.outputExpressions()) {
addExpression(output.getSecond(), output.getFirst(),
constantsReplacedByFunctions,
model, store, profile, queryProfiles,
@@ -210,21 +211,21 @@ public class ConvertedModel {
return expressions;
}
- private static void addExpression(RankingExpression expression,
+ private static void addExpression(ExpressionFunction expression,
String expressionName,
Set<String> constantsReplacedByFunctions,
ImportedModel model,
ModelStore store,
RankProfile profile,
QueryProfileRegistry queryProfiles,
- Map<String, RankingExpression> expressions) {
- expression = replaceConstantsByFunctions(expression, constantsReplacedByFunctions);
- reduceBatchDimensions(expression, model, profile, queryProfiles);
+ Map<String, ExpressionFunction> expressions) {
+ expression = expression.withBody(replaceConstantsByFunctions(expression.getBody(), constantsReplacedByFunctions));
+ reduceBatchDimensions(expression.getBody(), model, profile, queryProfiles);
store.writeExpression(expressionName, expression);
expressions.put(expressionName, expression);
}
- private static Map<String, RankingExpression> convertStored(ModelStore store, RankProfile profile) {
+ private static Map<String, ExpressionFunction> convertStored(ModelStore store, RankProfile profile) {
for (Pair<String, Tensor> constant : store.readSmallConstants())
profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
@@ -290,15 +291,15 @@ public class ConvertedModel {
}
/**
- * Verify that the functions referred in the given expression exists in the given rank profile,
- * and return tensors of the types specified in requiredFunctions.
+ * Verify that the inputs declared in the given expression exists in the given rank profile as functions,
+ * and return tensors of the correct types.
*/
- private static void verifyRequiredFunctions(RankingExpression expression, ImportedModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
+ private static void verifyInputs(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
Set<String> functionNames = new HashSet<>();
addFunctionNamesIn(expression.getRoot(), functionNames, model);
for (String functionName : functionNames) {
- TensorType requiredType = model.requiredFunctions().get(functionName);
+ TensorType requiredType = model.inputs().get(functionName);
if (requiredType == null) continue; // Not a required function
RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName);
@@ -375,7 +376,7 @@ public class ConvertedModel {
List<ExpressionNode> children = ((TensorFunctionNode)node).children();
if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (model.requiredFunctions().containsKey(referenceNode.getName())) {
+ if (model.inputs().containsKey(referenceNode.getName())) {
return reduceBatchDimensionExpression(tensorFunction, typeContext);
}
}
@@ -383,7 +384,7 @@ public class ConvertedModel {
}
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode) node;
- if (model.requiredFunctions().containsKey(referenceNode.getName())) {
+ if (model.inputs().containsKey(referenceNode.getName())) {
return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
}
}
@@ -451,7 +452,8 @@ public class ConvertedModel {
Set<String> constantsReplacedByFunctions) {
if (constantsReplacedByFunctions.isEmpty()) return expression;
return new RankingExpression(expression.getName(),
- replaceConstantsByFunctions(expression.getRoot(), constantsReplacedByFunctions));
+ replaceConstantsByFunctions(expression.getRoot(),
+ constantsReplacedByFunctions));
}
private static ExpressionNode replaceConstantsByFunctions(ExpressionNode node, Set<String> constantsReplacedByFunctions) {
@@ -524,19 +526,21 @@ public class ConvertedModel {
* @param name the name of this ranking expression - may have 1-3 parts separated by dot where the first part
* is always the model name
*/
- void writeExpression(String name, RankingExpression expression) {
- application.getFile(modelFiles.expressionPath(name))
- .writeFile(new StringReader(expression.getRoot().toString()));
+ void writeExpression(String name, ExpressionFunction expression) {
+ StringBuilder b = new StringBuilder(expression.getBody().getRoot().toString());
+ for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet())
+ b.append('\n').append(input.getKey()).append('\t').append(input.getValue());
+ application.getFile(modelFiles.expressionPath(name)).writeFile(new StringReader(b.toString()));
}
- Map<String, RankingExpression> readExpressions() {
- Map<String, RankingExpression> expressions = new HashMap<>();
+ Map<String, ExpressionFunction> readExpressions() {
+ Map<String, ExpressionFunction> expressions = new HashMap<>();
ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath());
if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap();
for (ApplicationFile expressionFile : expressionPath.listFiles()) {
- try (Reader reader = new BufferedReader(expressionFile.createReader())){
+ try (BufferedReader reader = new BufferedReader(expressionFile.createReader())){
String name = expressionFile.getPath().getName();
- expressions.put(name, new RankingExpression(name, reader));
+ expressions.put(name, readExpression(name, reader));
}
catch (IOException e) {
throw new UncheckedIOException("Failed reading " + expressionFile.getPath(), e);
@@ -548,8 +552,22 @@ public class ConvertedModel {
return expressions;
}
+ private ExpressionFunction readExpression(String name, BufferedReader reader)
+ throws IOException, ParseException {
+ // First line is expression
+ RankingExpression expression = new RankingExpression(name, reader.readLine());
+ // Next lines are inputs on the format name\ttensorTypeSpec
+ Map<String, TensorType> inputs = new LinkedHashMap<>();
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ inputs.put(parts[0], TensorType.fromSpec(parts[1]));
+ }
+ return new ExpressionFunction(name, new ArrayList<>(inputs.keySet()), expression, inputs, Optional.empty());
+ }
+
/** Adds this function expression to the application package so it can be read later. */
- void writeFunction(String name, RankingExpression expression) {
+ public void writeFunction(String name, RankingExpression expression) {
application.getFile(modelFiles.functionsPath()).appendFile(name + "\t" +
expression.getRoot().toString() + "\n");
}
@@ -561,20 +579,20 @@ public class ConvertedModel {
if ( ! file.exists()) return Collections.emptyList();
List<Pair<String, RankingExpression>> functions = new ArrayList<>();
- BufferedReader reader = new BufferedReader(file.createReader());
- String line;
- while (null != (line = reader.readLine())) {
- String[] parts = line.split("\t");
- String name = parts[0];
- try {
- RankingExpression expression = new RankingExpression(parts[0], parts[1]);
- functions.add(new Pair<>(name, expression));
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + name, e);
+ try (BufferedReader reader = new BufferedReader(file.createReader())) {
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ String name = parts[0];
+ try {
+ RankingExpression expression = new RankingExpression(parts[0], parts[1]);
+ functions.add(new Pair<>(name, expression));
+ } catch (ParseException e) {
+ throw new IllegalStateException("Could not parse " + name, e);
+ }
}
+ return functions;
}
- return functions;
}
catch (IOException e) {
throw new UncheckedIOException(e);