summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java27
1 files changed, 27 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index 4011ce43841..b153ff62e7d 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition;
+import com.yahoo.searchdefinition.expressiontransforms.OnnxModelTransformer;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
@@ -158,6 +159,12 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments())));
}
+ // A reference to an ONNX model?
+ Optional<TensorType> onnxFeatureType = onnxFeatureType(reference);
+ if (onnxFeatureType.isPresent()) {
+ return onnxFeatureType.get();
+ }
+
// A reference to a feature which returns a tensor?
Optional<TensorType> featureTensorType = tensorFeatureType(reference);
if (featureTensorType.isPresent()) {
@@ -210,6 +217,26 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
return Optional.of(function);
}
+ private Optional<TensorType> onnxFeatureType(Reference reference) {
+ if ( ! reference.name().equals("onnxModel"))
+ return Optional.empty();
+
+ if ( ! featureTypes.containsKey(reference)) {
+ String configOrFileName = reference.arguments().expressions().get(0).toString();
+
+ // Look up standardized format as added in RankProfile
+ String modelConfigName = OnnxModelTransformer.getModelConfigName(reference);
+ String modelOutput = OnnxModelTransformer.getModelOutput(reference, null);
+
+ reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput);
+ if ( ! featureTypes.containsKey(reference)) {
+ throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'");
+ }
+ }
+
+ return Optional.of(featureTypes.get(reference));
+ }
+
/**
* There are two features which returns the (non-empty) tensor type: tensorFromLabels and tensorFromWeightedSet.
* This returns the type of those features if this is a reference to either of them, or empty otherwise.