diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-03-20 10:45:28 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-03-20 12:22:52 +0000 |
commit | 4a8b6fc334178defd281c307033ca2e40ba64051 (patch) | |
tree | b72a625c63a22885c4af1c7b6d2a88ac7c8b526f | |
parent | 494f5d0e133417880881b05c4fe4f08a265a7510 (diff) |
add withTransformedExpressions() to TensorFunctionNode API
3 files changed, 25 insertions, 0 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index de3ec9648d1..f3fe86e261f 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -1696,6 +1696,7 @@ "public void <init>(com.yahoo.tensor.functions.TensorFunction)", "public com.yahoo.tensor.functions.TensorFunction function()", "public java.util.List children()", + "public com.yahoo.searchlib.rankingexpression.rule.ExpressionNode withTransformedExpressions(java.util.function.Function)", "public com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)", "public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)", "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 7577c65527b..41ece967491 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -25,6 +25,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Function; import java.util.stream.Collectors; /** @@ -58,6 +59,25 @@ public class TensorFunctionNode extends CompositeNode { return new TensorFunctionNode(f); } + private static ScalarFunction<Reference> transform(ScalarFunction<Reference> input, + Function<ExpressionNode, ExpressionNode> transformer) + { + if (input instanceof ExpressionScalarFunction wrapper) { + ExpressionNode transformed = transformer.apply(wrapper.expression); + return new ExpressionScalarFunction(transformed); + } + return input; + } + + public ExpressionNode withTransformedExpressions(Function<ExpressionNode, ExpressionNode> transformer) { + if (function instanceof ExpressionTensorFunction etf) { + ExpressionNode orig = etf.expression; + return transformer.apply(orig); + } + TensorFunction<Reference> transformed = function.withTransformedFunctions(fun -> transform(fun, transformer)); + return new TensorFunctionNode(transformed); + } + @Override public CompositeNode setChildren(List<ExpressionNode> children) { List<TensorFunction<Reference>> wrappedChildren = children.stream() diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java index 39afcfff541..225b260f403 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java @@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import java.util.ArrayList; import java.util.List; @@ -20,6 +21,9 @@ public class ConstantDereferencer extends ExpressionTransformer<TransformContext @Override public ExpressionNode transform(ExpressionNode node, TransformContext context) { + if (node instanceof TensorFunctionNode tfn) { + node = tfn.withTransformedExpressions(expr -> transform(expr, context)); + } if (node instanceof ReferenceNode) return transformFeature((ReferenceNode) node, context); else if (node instanceof CompositeNode) |