diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-04-21 11:42:49 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-04-23 06:54:56 +0000 |
commit | 2eb948d4085f099ed4420d4acb0339f907c03fa6 (patch) | |
tree | 93312cbb490f1b9bc7f0361f5c1f577d925e9ed1 /config-model/src/main/java/com/yahoo/vespa/model/ml | |
parent | 0db383e464bb24c525ffc4b3df51950a8f10444f (diff) |
add input parameters to rank profile
* also: try to resolve type of output expressions
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/ml')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java | 42 |
1 files changed, 33 insertions, 9 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index b757259102b..01da57fc9bb 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -210,6 +210,9 @@ public class ConvertedModel { Map<String, ExpressionFunction> expressions = new HashMap<>(); for (ImportedMlFunction outputFunction : model.outputExpressions()) { ExpressionFunction expression = asExpressionFunction(outputFunction); + for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet()) { + profile.addInputParameter(input.getKey(), input.getValue()); + } addExpression(expression, expression.getName(), constantsReplacedByFunctions, model, store, profile, queryProfiles, @@ -251,13 +254,20 @@ public class ConvertedModel { QueryProfileRegistry queryProfiles, Map<String, ExpressionFunction> expressions) { expression = expression.withBody(replaceConstantsByFunctions(expression.getBody(), constantsReplacedByFunctions)); + if (expression.returnType().isEmpty()) { + TensorType type = expression.getBody().type(profile.typeContext(queryProfiles)); + if (type != null) { + expression = expression.withReturnType(type); + } + } store.writeExpression(expressionName, expression); expressions.put(expressionName, expression); } private static Map<String, ExpressionFunction> convertStored(ModelStore store, RankProfile profile) { - for (Pair<String, Tensor> constant : store.readSmallConstants()) + for (Pair<String, Tensor> constant : store.readSmallConstants()) { profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); + } for (RankingConstant constant : store.readLargeConstants()) { if ( ! profile.rankingConstants().asMap().containsKey(constant.getName())) { @@ -269,7 +279,20 @@ public class ConvertedModel { addGeneratedFunctionToProfile(profile, function.getFirst(), function.getSecond()); } - return store.readExpressions(); + Map<String, ExpressionFunction> expressions = new HashMap<>(); + for (Pair<String, ExpressionFunction> output : store.readExpressions()) { + String name = output.getFirst(); + ExpressionFunction expression = output.getSecond(); + for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet()) { + profile.addInputParameter(input.getKey(), input.getValue()); + } + TensorType type = expression.getBody().type(profile.typeContext()); + if (type != null) { + expression = expression.withReturnType(type); + } + expressions.put(name, expression); + } + return expressions; } private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, @@ -321,8 +344,9 @@ public class ConvertedModel { "\nwant to add " + expression + "\n"); return; } - var fun = new ExpressionFunction(functionName, expression); - profile.addFunction(fun, false); // TODO: Inline if only used once + ExpressionFunction function = new ExpressionFunction(functionName, expression); + // XXX should we resolve type here? + profile.addFunction(function, false); // TODO: Inline if only used once } /** @@ -465,14 +489,14 @@ public class ConvertedModel { application.getFile(modelFiles.expressionPath(name)).writeFile(new StringReader(b.toString())); } - Map<String, ExpressionFunction> readExpressions() { - Map<String, ExpressionFunction> expressions = new HashMap<>(); + List<Pair<String, ExpressionFunction>> readExpressions() { + List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>(); ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath()); - if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap(); + if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyList(); for (ApplicationFile expressionFile : expressionPath.listFiles()) { - try (BufferedReader reader = new BufferedReader(expressionFile.createReader())){ + try (BufferedReader reader = new BufferedReader(expressionFile.createReader())) { String name = expressionFile.getPath().getName(); - expressions.put(name, readExpression(name, reader)); + expressions.add(new Pair<>(name, readExpression(name, reader))); } catch (IOException e) { throw new UncheckedIOException("Failed reading " + expressionFile.getPath(), e); |