summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java19
1 files changed, 19 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index d309f48d6df..96c043bdb34 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -18,6 +18,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.VespaModel;
@@ -158,6 +159,10 @@ public class RankProfile implements Cloneable {
return search != null ? search.rankingConstants() : model.rankingConstants();
}
+ private Map<String, OnnxModel> onnxModels() {
+ return search != null ? search.onnxModels().asMap() : Collections.emptyMap();
+ }
+
private Stream<ImmutableSDField> allFields() {
if (search == null) return Stream.empty();
if (allFieldsList == null) {
@@ -821,6 +826,20 @@ public class RankProfile implements Cloneable {
}
}
+ // Add output types for ONNX models
+ for (Map.Entry<String, OnnxModel> entry : onnxModels().entrySet()) {
+ String modelName = entry.getKey();
+ OnnxModel model = entry.getValue();
+ Arguments args = new Arguments(new ReferenceNode(modelName));
+
+ TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), context);
+ context.setType(new Reference("onnxModel", args, null), defaultOutputType);
+
+ for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) {
+ TensorType type = model.getTensorType(mapping.getKey(), context);
+ context.setType(new Reference("onnxModel", args, mapping.getValue()), type);
+ }
+ }
return context;
}