diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index d309f48d6df..96c043bdb34 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -18,6 +18,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.VespaModel; @@ -158,6 +159,10 @@ public class RankProfile implements Cloneable { return search != null ? search.rankingConstants() : model.rankingConstants(); } + private Map<String, OnnxModel> onnxModels() { + return search != null ? search.onnxModels().asMap() : Collections.emptyMap(); + } + private Stream<ImmutableSDField> allFields() { if (search == null) return Stream.empty(); if (allFieldsList == null) { @@ -821,6 +826,20 @@ public class RankProfile implements Cloneable { } } + // Add output types for ONNX models + for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) { + String modelName = entry.getKey(); + OnnxModel model = entry.getValue(); + Arguments args = new Arguments(new ReferenceNode(modelName)); + + TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), context); + context.setType(new Reference("onnxModel", args, null), defaultOutputType); + + for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) { + TensorType type = model.getTensorType(mapping.getKey(), context); + context.setType(new Reference("onnxModel", args, mapping.getValue()), type); + } + } return context; } |