summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-03-20 10:45:28 +0000
committerArne Juul <arnej@yahooinc.com>2023-03-20 12:22:52 +0000
commit4a8b6fc334178defd281c307033ca2e40ba64051 (patch)
treeb72a625c63a22885c4af1c7b6d2a88ac7c8b526f
parent494f5d0e133417880881b05c4fe4f08a265a7510 (diff)
add withTransformedExpressions() to TensorFunctionNode API
-rw-r--r--searchlib/abi-spec.json1
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java20
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java4
3 files changed, 25 insertions, 0 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index de3ec9648d1..f3fe86e261f 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -1696,6 +1696,7 @@
"public void <init>(com.yahoo.tensor.functions.TensorFunction)",
"public com.yahoo.tensor.functions.TensorFunction function()",
"public java.util.List children()",
+ "public com.yahoo.searchlib.rankingexpression.rule.ExpressionNode withTransformedExpressions(java.util.function.Function)",
"public com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)",
"public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 7577c65527b..41ece967491 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -25,6 +25,7 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.function.Function;
import java.util.stream.Collectors;
/**
@@ -58,6 +59,25 @@ public class TensorFunctionNode extends CompositeNode {
return new TensorFunctionNode(f);
}
+ private static ScalarFunction<Reference> transform(ScalarFunction<Reference> input,
+ Function<ExpressionNode, ExpressionNode> transformer)
+ {
+ if (input instanceof ExpressionScalarFunction wrapper) {
+ ExpressionNode transformed = transformer.apply(wrapper.expression);
+ return new ExpressionScalarFunction(transformed);
+ }
+ return input;
+ }
+
+ public ExpressionNode withTransformedExpressions(Function<ExpressionNode, ExpressionNode> transformer) {
+ if (function instanceof ExpressionTensorFunction etf) {
+ ExpressionNode orig = etf.expression;
+ return transformer.apply(orig);
+ }
+ TensorFunction<Reference> transformed = function.withTransformedFunctions(fun -> transform(fun, transformer));
+ return new TensorFunctionNode(transformed);
+ }
+
@Override
public CompositeNode setChildren(List<ExpressionNode> children) {
List<TensorFunction<Reference>> wrappedChildren = children.stream()
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
index 39afcfff541..225b260f403 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
@@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import java.util.ArrayList;
import java.util.List;
@@ -20,6 +21,9 @@ public class ConstantDereferencer extends ExpressionTransformer<TransformContext
@Override
public ExpressionNode transform(ExpressionNode node, TransformContext context) {
+ if (node instanceof TensorFunctionNode tfn) {
+ node = tfn.withTransformedExpressions(expr -> transform(expr, context));
+ }
if (node instanceof ReferenceNode)
return transformFeature((ReferenceNode) node, context);
else if (node instanceof CompositeNode)