summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-11-08 15:23:49 +0100
committerJon Bratseth <bratseth@verizonmedia.com>2019-11-08 15:23:49 +0100
commitcdcafc6fc8b4417abab8c72bbce5c503533558ea (patch)
treeea4b1c1a64da79bf42270d1b59f10ae32013b4d8 /searchlib
parentdf287b9364b8088192146df70f5f4814ff6c94c1 (diff)
Serialize scalar functions with context
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/abi-spec.json4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java73
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java2
-rwxr-xr-xsearchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java6
4 files changed, 46 insertions, 39 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 1258601a2d1..8d7bf4f9f14 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -1580,7 +1580,7 @@
],
"fields": []
},
- "com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$TensorFunctionExpressionNode": {
+ "com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction": {
"superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction",
"interfaces": [],
"attributes": [
@@ -1612,7 +1612,7 @@
"public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
- "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$TensorFunctionExpressionNode wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
+ "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
"public static java.util.Map wrap(java.util.Map)",
"public static java.util.List wrap(java.util.List)"
],
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index e6e49e07c34..4ffd40f00f7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -12,6 +12,7 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
+import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
@@ -21,7 +22,6 @@ import java.util.Deque;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
-import java.util.function.Function;
import java.util.stream.Collectors;
/**
@@ -49,8 +49,8 @@ public class TensorFunctionNode extends CompositeNode {
}
private ExpressionNode toExpressionNode(TensorFunction f) {
- if (f instanceof TensorFunctionExpressionNode)
- return ((TensorFunctionExpressionNode)f).expression;
+ if (f instanceof ExpressionTensorFunction)
+ return ((ExpressionTensorFunction)f).expression;
else
return new TensorFunctionNode(f);
}
@@ -58,7 +58,7 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public CompositeNode setChildren(List<ExpressionNode> children) {
List<TensorFunction> wrappedChildren = children.stream()
- .map(TensorFunctionExpressionNode::new)
+ .map(ExpressionTensorFunction::new)
.collect(Collectors.toList());
return new TensorFunctionNode(function.withArguments(wrappedChildren));
}
@@ -66,7 +66,7 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) {
// Serialize as primitive
- return string.append(function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this)));
+ return string.append(function.toPrimitive().toString(new ExpressionToStringContext(context, path, this)));
}
@Override
@@ -77,29 +77,29 @@ public class TensorFunctionNode extends CompositeNode {
return new TensorValue(function.evaluate(context));
}
- public static TensorFunctionExpressionNode wrap(ExpressionNode node) {
- return new TensorFunctionExpressionNode(node);
+ public static ExpressionTensorFunction wrap(ExpressionNode node) {
+ return new ExpressionTensorFunction(node);
}
- public static Map<TensorAddress, Function<EvaluationContext<?>, Double>> wrap(Map<TensorAddress, ExpressionNode> nodes) {
- Map<TensorAddress, Function<EvaluationContext<?>, Double>> closures = new LinkedHashMap<>();
+ public static Map<TensorAddress, ScalarFunction> wrap(Map<TensorAddress, ExpressionNode> nodes) {
+ Map<TensorAddress, ScalarFunction> functions = new LinkedHashMap<>();
for (var entry : nodes.entrySet())
- closures.put(entry.getKey(), new ExpressionClosure(entry.getValue()));
- return closures;
+ functions.put(entry.getKey(), new ExpressionScalarFunction(entry.getValue()));
+ return functions;
}
- public static List<Function<EvaluationContext<?>, Double>> wrap(List<ExpressionNode> nodes) {
- List<Function<EvaluationContext<?>, Double>> closures = new ArrayList<>();
+ public static List<ScalarFunction> wrap(List<ExpressionNode> nodes) {
+ List<ScalarFunction> functions = new ArrayList<>();
for (var entry : nodes)
- closures.add(new ExpressionClosure(entry));
- return closures;
+ functions.add(new ExpressionScalarFunction(entry));
+ return functions;
}
- private static class ExpressionClosure implements java.util.function.Function<EvaluationContext<?> , Double> {
+ private static class ExpressionScalarFunction implements ScalarFunction {
private final ExpressionNode expression;
- public ExpressionClosure(ExpressionNode expression) {
+ public ExpressionScalarFunction(ExpressionNode expression) {
this.expression = expression;
}
@@ -110,7 +110,18 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public String toString() {
- return expression.toString();
+ return toString(ExpressionToStringContext.empty);
+ }
+
+ @Override
+ public String toString(ToStringContext c) {
+ if (c instanceof ExpressionToStringContext) {
+ ExpressionToStringContext context = (ExpressionToStringContext) c;
+ return expression.toString(new StringBuilder(),context.context, context.path, context.parent).toString();
+ }
+ else {
+ return expression.toString();
+ }
}
}
@@ -119,12 +130,12 @@ public class TensorFunctionNode extends CompositeNode {
* A tensor function implemented by an expression.
* This allows us to pass expressions as tensor function arguments.
*/
- public static class TensorFunctionExpressionNode extends PrimitiveTensorFunction {
+ public static class ExpressionTensorFunction extends PrimitiveTensorFunction {
/** An expression which produces a tensor */
private final ExpressionNode expression;
- public TensorFunctionExpressionNode(ExpressionNode expression) {
+ public ExpressionTensorFunction(ExpressionNode expression) {
this.expression = expression;
}
@@ -132,7 +143,7 @@ public class TensorFunctionNode extends CompositeNode {
public List<TensorFunction> arguments() {
if (expression instanceof CompositeNode)
return ((CompositeNode)expression).children().stream()
- .map(TensorFunctionExpressionNode::new)
+ .map(ExpressionTensorFunction::new)
.collect(Collectors.toList());
else
return Collections.emptyList();
@@ -142,9 +153,9 @@ public class TensorFunctionNode extends CompositeNode {
public TensorFunction withArguments(List<TensorFunction> arguments) {
if (arguments.size() == 0) return this;
List<ExpressionNode> unwrappedChildren = arguments.stream()
- .map(arg -> ((TensorFunctionExpressionNode)arg).expression)
+ .map(arg -> ((ExpressionTensorFunction)arg).expression)
.collect(Collectors.toList());
- return new TensorFunctionExpressionNode(((CompositeNode)expression).setChildren(unwrappedChildren));
+ return new ExpressionTensorFunction(((CompositeNode)expression).setChildren(unwrappedChildren));
}
@Override
@@ -163,13 +174,13 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public String toString() {
- return toString(ExpressionNodeToStringContext.empty);
+ return toString(ExpressionToStringContext.empty);
}
@Override
public String toString(ToStringContext c) {
- if (c instanceof ExpressionNodeToStringContext) {
- ExpressionNodeToStringContext context = (ExpressionNodeToStringContext) c;
+ if (c instanceof ExpressionToStringContext) {
+ ExpressionToStringContext context = (ExpressionToStringContext) c;
return expression.toString(new StringBuilder(),context.context, context.path, context.parent).toString();
}
else {
@@ -180,17 +191,17 @@ public class TensorFunctionNode extends CompositeNode {
}
/** Allows passing serialization context arguments through TensorFunctions */
- private static class ExpressionNodeToStringContext implements ToStringContext {
+ private static class ExpressionToStringContext implements ToStringContext {
final SerializationContext context;
final Deque<String> path;
final CompositeNode parent;
- public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(new SerializationContext(),
- null,
- null);
+ public static final ExpressionToStringContext empty = new ExpressionToStringContext(new SerializationContext(),
+ null,
+ null);
- public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ public ExpressionToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
this.context = context;
this.path = path;
this.parent = parent;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
index 6d687b015f1..9a38b5efc1f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
@@ -83,7 +83,7 @@ public class TensorMaxMinTransformer<CONTEXT extends TransformContext> extends E
ExpressionNode arg1 = node.children().get(0);
ExpressionNode arg2 = node.children().get(1);
- TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrap(arg1);
+ TensorFunctionNode.ExpressionTensorFunction expression = TensorFunctionNode.wrap(arg1);
Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name());
String dimension = ((ReferenceNode) arg2).getName();
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
index efebdb310f7..e7024b87452 100755
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
@@ -1,20 +1,16 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
-import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.TensorFunction;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@@ -65,7 +61,7 @@ public class RankingExpressionTestCase {
ReferenceNode input = new ReferenceNode("input");
ReferenceNode constant = new ReferenceNode("constant");
ArithmeticNode product = new ArithmeticNode(input, ArithmeticOperator.MULTIPLY, constant);
- Reduce sum = new Reduce(new TensorFunctionNode.TensorFunctionExpressionNode(product), Reduce.Aggregator.sum);
+ Reduce sum = new Reduce(new TensorFunctionNode.ExpressionTensorFunction(product), Reduce.Aggregator.sum);
RankingExpression expression = new RankingExpression(new TensorFunctionNode(sum));
RankingExpression expected = new RankingExpression("sum(input * constant)");