summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-09-23 16:20:32 -0700
committerJon Bratseth <bratseth@oath.com>2018-09-23 16:20:32 -0700
commit4e44e5472829c033c3d995c618f2febcc4463eb7 (patch)
tree402dc48f0fce44759ce7bca8068c6b98097dd031 /config-model
parent2ee637ff5ef12924e77d5fbf087fb9fb803f0143 (diff)
Use ExpressionFunction
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java8
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java57
2 files changed, 33 insertions, 32 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index 04481a3bc8d..2c6a7941772 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -237,8 +237,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
rankProfileRegistry.add(profile);
ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()),
model.name(), profile, queryProfiles.getRegistry(), model);
- for (Map.Entry<String, ImportedModel.ExpressionWithInputs> entry : convertedModel.expressions().entrySet()) {
- profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue().expression()), false); // TODO: Use inputs
+ for (Map.Entry<String, ExpressionFunction> entry : convertedModel.expressions().entrySet()) {
+ profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue().getBody()), false); // TODO: Use arguments
}
}
}
@@ -249,8 +249,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry);
rankProfileRegistry.add(profile);
ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile);
- for (Map.Entry<String, ImportedModel.ExpressionWithInputs> entry : convertedModel.expressions().entrySet()) {
- profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue().expression()), false); // TODO: Use inputs
+ for (Map.Entry<String, ExpressionFunction> entry : convertedModel.expressions().entrySet()) {
+ profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue().getBody()), false); // TODO: Use inputs
}
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index d72a22f7c5e..fb0109ed32e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -48,6 +48,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -67,14 +68,14 @@ public class ConvertedModel {
private final ModelName modelName;
private final String modelDescription;
- private final ImmutableMap<String, ImportedModel.ExpressionWithInputs> expressions;
+ private final ImmutableMap<String, ExpressionFunction> expressions;
/** The source importedModel, or empty if this was created from a stored converted model */
private final Optional<ImportedModel> sourceModel;
private ConvertedModel(ModelName modelName,
String modelDescription,
- Map<String, ImportedModel.ExpressionWithInputs> expressions,
+ Map<String, ExpressionFunction> expressions,
Optional<ImportedModel> sourceModel) {
this.modelName = modelName;
this.modelDescription = modelDescription;
@@ -132,23 +133,23 @@ public class ConvertedModel {
* if signatures are used, or the expression name if signatures are not used and there are multiple
* expressions, and the second is the output name if signature names are used.
*/
- public Map<String, ImportedModel.ExpressionWithInputs> expressions() { return expressions; }
+ public Map<String, ExpressionFunction> expressions() { return expressions; }
/**
* Returns the expression matching the given arguments.
*/
public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) {
- ImportedModel.ExpressionWithInputs expression = selectExpression(arguments);
+ ExpressionFunction expression = selectExpression(arguments);
if (sourceModel.isPresent()) // we should verify
- verifyInputs(expression.expression(), sourceModel.get(), context.rankProfile(), context.queryProfiles());
- return expression.expression().getRoot();
+ verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles());
+ return expression.getBody().getRoot();
}
- private ImportedModel.ExpressionWithInputs selectExpression(FeatureArguments arguments) {
+ private ExpressionFunction selectExpression(FeatureArguments arguments) {
if (expressions.isEmpty())
throw new IllegalArgumentException("No expressions available in " + this);
- ImportedModel.ExpressionWithInputs expression = expressions.get(arguments.toName());
+ ExpressionFunction expression = expressions.get(arguments.toName());
if (expression != null) return expression;
if ( ! arguments.signature().isPresent()) {
@@ -158,7 +159,7 @@ public class ConvertedModel {
}
if ( ! arguments.output().isPresent()) {
- List<Map.Entry<String, ImportedModel.ExpressionWithInputs>> entriesWithTheRightPrefix =
+ List<Map.Entry<String, ExpressionFunction>> entriesWithTheRightPrefix =
expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(arguments.signature().get() + ".")).collect(Collectors.toList());
if (entriesWithTheRightPrefix.size() < 1)
throw new IllegalArgumentException("No expressions named '" + arguments.signature().get() +
@@ -179,10 +180,10 @@ public class ConvertedModel {
// ----------------------- Static model conversion/storage below here
- private static Map<String, ImportedModel.ExpressionWithInputs> convertAndStore(ImportedModel model,
- RankProfile profile,
- QueryProfileRegistry queryProfiles,
- ModelStore store) {
+ private static Map<String, ExpressionFunction> convertAndStore(ImportedModel model,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ ModelStore store) {
// Add constants
Set<String> constantsReplacedByFunctions = new HashSet<>();
model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
@@ -193,8 +194,8 @@ public class ConvertedModel {
addGeneratedFunctions(model, profile);
// Add expressions
- Map<String, ImportedModel.ExpressionWithInputs> expressions = new HashMap<>();
- for (Pair<String, ImportedModel.ExpressionWithInputs> output : model.outputExpressions()) {
+ Map<String, ExpressionFunction> expressions = new HashMap<>();
+ for (Pair<String, ExpressionFunction> output : model.outputExpressions()) {
addExpression(output.getSecond(), output.getFirst(),
constantsReplacedByFunctions,
model, store, profile, queryProfiles,
@@ -210,21 +211,21 @@ public class ConvertedModel {
return expressions;
}
- private static void addExpression(ImportedModel.ExpressionWithInputs expression,
+ private static void addExpression(ExpressionFunction expression,
String expressionName,
Set<String> constantsReplacedByFunctions,
ImportedModel model,
ModelStore store,
RankProfile profile,
QueryProfileRegistry queryProfiles,
- Map<String, ImportedModel.ExpressionWithInputs> expressions) {
- expression = expression.with(replaceConstantsByFunctions(expression.expression(), constantsReplacedByFunctions));
- reduceBatchDimensions(expression.expression(), model, profile, queryProfiles);
+ Map<String, ExpressionFunction> expressions) {
+ expression = expression.withBody(replaceConstantsByFunctions(expression.getBody(), constantsReplacedByFunctions));
+ reduceBatchDimensions(expression.getBody(), model, profile, queryProfiles);
store.writeExpression(expressionName, expression);
expressions.put(expressionName, expression);
}
- private static Map<String, ImportedModel.ExpressionWithInputs> convertStored(ModelStore store, RankProfile profile) {
+ private static Map<String, ExpressionFunction> convertStored(ModelStore store, RankProfile profile) {
for (Pair<String, Tensor> constant : store.readSmallConstants())
profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
@@ -525,15 +526,15 @@ public class ConvertedModel {
* @param name the name of this ranking expression - may have 1-3 parts separated by dot where the first part
* is always the model name
*/
- void writeExpression(String name, ImportedModel.ExpressionWithInputs expression) {
- StringBuilder b = new StringBuilder(expression.expression().getRoot().toString());
- for (Map.Entry<String, TensorType> input : expression.inputs().entrySet())
+ void writeExpression(String name, ExpressionFunction expression) {
+ StringBuilder b = new StringBuilder(expression.getBody().getRoot().toString());
+ for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet())
b.append('\n').append(input.getKey()).append('\t').append(input.getValue());
application.getFile(modelFiles.expressionPath(name)).writeFile(new StringReader(b.toString()));
}
- Map<String, ImportedModel.ExpressionWithInputs> readExpressions() {
- Map<String, ImportedModel.ExpressionWithInputs> expressions = new HashMap<>();
+ Map<String, ExpressionFunction> readExpressions() {
+ Map<String, ExpressionFunction> expressions = new HashMap<>();
ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath());
if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap();
for (ApplicationFile expressionFile : expressionPath.listFiles()) {
@@ -551,18 +552,18 @@ public class ConvertedModel {
return expressions;
}
- private ImportedModel.ExpressionWithInputs readExpression(String name, BufferedReader reader)
+ private ExpressionFunction readExpression(String name, BufferedReader reader)
throws IOException, ParseException {
// First line is expression
RankingExpression expression = new RankingExpression(name, reader.readLine());
// Next lines are inputs on the format name\ttensorTypeSpec
- Map<String, TensorType> inputs = new HashMap<>();
+ Map<String, TensorType> inputs = new LinkedHashMap<>();
String line;
while (null != (line = reader.readLine())) {
String[] parts = line.split("\t");
inputs.put(parts[0], TensorType.fromSpec(parts[1]));
}
- return new ImportedModel.ExpressionWithInputs(expression, inputs);
+ return new ExpressionFunction(name, new ArrayList<>(inputs.keySet()), expression, inputs, Optional.empty());
}
/** Adds this function expression to the application package so it can be read later. */