summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-03-20 10:48:58 +0000
committerArne Juul <arnej@yahooinc.com>2023-03-20 12:22:52 +0000
commitad2e68e457aa87c508a7e057cc178f1fe6af35d1 (patch)
tree274e1c15d8cd047538f78005497a3c684c6dca51
parent4a8b6fc334178defd281c307033ca2e40ba64051 (diff)
deeper processing of TensorFunctionNode and ONNX model references
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java33
1 files changed, 33 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java
index 4e320594918..a9eea3d2ead 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java
@@ -2,14 +2,18 @@
package com.yahoo.schema.expressiontransforms;
import com.yahoo.schema.FeatureNames;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;
@@ -22,6 +26,9 @@ public class ConstantTensorTransformer extends ExpressionTransformer<RankProfile
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
+ if (node instanceof TensorFunctionNode tfn) {
+ node = tfn.withTransformedExpressions(expr -> transform(expr, context));
+ }
if (node instanceof ReferenceNode) {
return transformFeature((ReferenceNode) node, context);
} else if (node instanceof CompositeNode) {
@@ -32,6 +39,32 @@ public class ConstantTensorTransformer extends ExpressionTransformer<RankProfile
}
private ExpressionNode transformFeature(ReferenceNode node, RankProfileTransformContext context) {
+ Reference ref = node.reference();
+ String name = ref.name();
+ var args = ref.arguments();
+ if (name.equals("onnx") && args.size() == 1) {
+ var arg = args.expressions().get(0);
+ var models = context.rankProfile().onnxModels();
+ var model = models.get(arg.toString());
+ if (model != null) {
+ for (var entry : model.getInputMap().entrySet()) {
+ String source = entry.getValue();
+ var reader = new StringReader(source);
+ try {
+ var asExpression = new RankingExpression(reader);
+ String transformed = transform(asExpression.getRoot(), context).toString();
+ if (! source.equals(transformed)) {
+ // not sure about this:
+ throw new IllegalStateException("unexpected rewrite: " + source + " => " + transformed + " for onnx input " + entry.getKey());
+ // consider instead: model.addInputNameMapping(entry.getKey(), transformed, true);
+ }
+ } catch (ParseException e) {
+ throw new IllegalArgumentException("illegal onnx input '" + source + "': " + e.getMessage());
+ }
+ }
+ return node;
+ }
+ }
if ( ! node.getArguments().isEmpty() && ! FeatureNames.isSimpleFeature(node.reference())) {
return transformArguments(node, context);
} else {