aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/vespa/model/ml
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-04-21 11:42:49 +0000
committerArne Juul <arnej@verizonmedia.com>2021-04-23 06:54:56 +0000
commit2eb948d4085f099ed4420d4acb0339f907c03fa6 (patch)
tree93312cbb490f1b9bc7f0361f5c1f577d925e9ed1 /config-model/src/main/java/com/yahoo/vespa/model/ml
parent0db383e464bb24c525ffc4b3df51950a8f10444f (diff)
add input parameters to rank profile
* also: try to resolve type of output expressions
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/ml')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java42
1 files changed, 33 insertions, 9 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index b757259102b..01da57fc9bb 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -210,6 +210,9 @@ public class ConvertedModel {
Map<String, ExpressionFunction> expressions = new HashMap<>();
for (ImportedMlFunction outputFunction : model.outputExpressions()) {
ExpressionFunction expression = asExpressionFunction(outputFunction);
+ for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet()) {
+ profile.addInputParameter(input.getKey(), input.getValue());
+ }
addExpression(expression, expression.getName(),
constantsReplacedByFunctions,
model, store, profile, queryProfiles,
@@ -251,13 +254,20 @@ public class ConvertedModel {
QueryProfileRegistry queryProfiles,
Map<String, ExpressionFunction> expressions) {
expression = expression.withBody(replaceConstantsByFunctions(expression.getBody(), constantsReplacedByFunctions));
+ if (expression.returnType().isEmpty()) {
+ TensorType type = expression.getBody().type(profile.typeContext(queryProfiles));
+ if (type != null) {
+ expression = expression.withReturnType(type);
+ }
+ }
store.writeExpression(expressionName, expression);
expressions.put(expressionName, expression);
}
private static Map<String, ExpressionFunction> convertStored(ModelStore store, RankProfile profile) {
- for (Pair<String, Tensor> constant : store.readSmallConstants())
+ for (Pair<String, Tensor> constant : store.readSmallConstants()) {
profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
+ }
for (RankingConstant constant : store.readLargeConstants()) {
if ( ! profile.rankingConstants().asMap().containsKey(constant.getName())) {
@@ -269,7 +279,20 @@ public class ConvertedModel {
addGeneratedFunctionToProfile(profile, function.getFirst(), function.getSecond());
}
- return store.readExpressions();
+ Map<String, ExpressionFunction> expressions = new HashMap<>();
+ for (Pair<String, ExpressionFunction> output : store.readExpressions()) {
+ String name = output.getFirst();
+ ExpressionFunction expression = output.getSecond();
+ for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet()) {
+ profile.addInputParameter(input.getKey(), input.getValue());
+ }
+ TensorType type = expression.getBody().type(profile.typeContext());
+ if (type != null) {
+ expression = expression.withReturnType(type);
+ }
+ expressions.put(name, expression);
+ }
+ return expressions;
}
private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName,
@@ -321,8 +344,9 @@ public class ConvertedModel {
"\nwant to add " + expression + "\n");
return;
}
- var fun = new ExpressionFunction(functionName, expression);
- profile.addFunction(fun, false); // TODO: Inline if only used once
+ ExpressionFunction function = new ExpressionFunction(functionName, expression);
+ // XXX should we resolve type here?
+ profile.addFunction(function, false); // TODO: Inline if only used once
}
/**
@@ -465,14 +489,14 @@ public class ConvertedModel {
application.getFile(modelFiles.expressionPath(name)).writeFile(new StringReader(b.toString()));
}
- Map<String, ExpressionFunction> readExpressions() {
- Map<String, ExpressionFunction> expressions = new HashMap<>();
+ List<Pair<String, ExpressionFunction>> readExpressions() {
+ List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>();
ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath());
- if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap();
+ if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyList();
for (ApplicationFile expressionFile : expressionPath.listFiles()) {
- try (BufferedReader reader = new BufferedReader(expressionFile.createReader())){
+ try (BufferedReader reader = new BufferedReader(expressionFile.createReader())) {
String name = expressionFile.getPath().getName();
- expressions.put(name, readExpression(name, reader));
+ expressions.add(new Pair<>(name, readExpression(name, reader)));
}
catch (IOException e) {
throw new UncheckedIOException("Failed reading " + expressionFile.getPath(), e);