diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-03-20 10:48:58 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-03-20 12:22:52 +0000 |
commit | ad2e68e457aa87c508a7e057cc178f1fe6af35d1 (patch) | |
tree | 274e1c15d8cd047538f78005497a3c684c6dca51 | |
parent | 4a8b6fc334178defd281c307033ca2e40ba64051 (diff) |
deeper processing of TensorFunctionNode and ONNX model references
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java index 4e320594918..a9eea3d2ead 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/ConstantTensorTransformer.java @@ -2,14 +2,18 @@ package com.yahoo.schema.expressiontransforms; import com.yahoo.schema.FeatureNames; +import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import java.io.StringReader; import java.util.ArrayList; import java.util.List; @@ -22,6 +26,9 @@ public class ConstantTensorTransformer extends ExpressionTransformer<RankProfile @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + if (node instanceof TensorFunctionNode tfn) { + node = tfn.withTransformedExpressions(expr -> transform(expr, context)); + } if (node instanceof ReferenceNode) { return transformFeature((ReferenceNode) node, context); } else if (node instanceof CompositeNode) { @@ -32,6 +39,32 @@ public class ConstantTensorTransformer extends ExpressionTransformer<RankProfile } private ExpressionNode transformFeature(ReferenceNode node, RankProfileTransformContext context) { + Reference ref = node.reference(); + String name = ref.name(); + var args = ref.arguments(); + if (name.equals("onnx") && args.size() == 1) { + var arg = args.expressions().get(0); + var models = context.rankProfile().onnxModels(); + var model = models.get(arg.toString()); + if (model != null) { + for (var entry : model.getInputMap().entrySet()) { + String source = entry.getValue(); + var reader = new StringReader(source); + try { + var asExpression = new RankingExpression(reader); + String transformed = transform(asExpression.getRoot(), context).toString(); + if (! source.equals(transformed)) { + // not sure about this: + throw new IllegalStateException("unexpected rewrite: " + source + " => " + transformed + " for onnx input " + entry.getKey()); + // consider instead: model.addInputNameMapping(entry.getKey(), transformed, true); + } + } catch (ParseException e) { + throw new IllegalArgumentException("illegal onnx input '" + source + "': " + e.getMessage()); + } + } + return node; + } + } if ( ! node.getArguments().isEmpty() && ! FeatureNames.isSimpleFeature(node.reference())) { return transformArguments(node, context); } else { |