summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-11-26 16:51:50 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-11-26 16:51:50 +0200
commitf4203c3cc571722f08ee65047437c1290ed63f69 (patch)
tree7d06d17091a2e388e6771187a11cf4f4023a0c1e /searchlib
parent316c941e90f39d2e9bc46f12b96ca0f87471d1bd (diff)
Allow bound functions in tensor generate
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/abi-spec.json5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java46
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj6
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java3
5 files changed, 45 insertions, 17 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index d5970a4b69e..dcf42069373 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -1615,8 +1615,9 @@
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
"public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
- "public static java.util.Map wrap(java.util.Map)",
- "public static java.util.List wrap(java.util.List)"
+ "public static java.util.Map wrapScalars(java.util.Map)",
+ "public static java.util.List wrapScalars(java.util.List)",
+ "public static com.yahoo.tensor.functions.ScalarFunction wrapScalar(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)"
],
"fields": []
},
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index d68f8c85ad1..cf17c6465f3 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -34,7 +34,7 @@ public abstract class Context implements EvaluationContext<Reference> {
@Override
public TensorType getType(String reference) {
- throw new UnsupportedOperationException("Not able to parse gereral references from string form");
+ throw new UnsupportedOperationException("Not able to parse general references from string form");
}
/** Returns a variable as a tensor */
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 4ffd40f00f7..18f1fa8a78f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -16,6 +16,7 @@ import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
+import java.sql.Ref;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
@@ -81,21 +82,22 @@ public class TensorFunctionNode extends CompositeNode {
return new ExpressionTensorFunction(node);
}
- public static Map<TensorAddress, ScalarFunction> wrap(Map<TensorAddress, ExpressionNode> nodes) {
- Map<TensorAddress, ScalarFunction> functions = new LinkedHashMap<>();
+ public static Map<TensorAddress, ScalarFunction<Reference>> wrapScalars(Map<TensorAddress, ExpressionNode> nodes) {
+ Map<TensorAddress, ScalarFunction<Reference>> functions = new LinkedHashMap<>();
for (var entry : nodes.entrySet())
- functions.put(entry.getKey(), new ExpressionScalarFunction(entry.getValue()));
+ functions.put(entry.getKey(), wrapScalar(entry.getValue()));
return functions;
}
- public static List<ScalarFunction> wrap(List<ExpressionNode> nodes) {
- List<ScalarFunction> functions = new ArrayList<>();
- for (var entry : nodes)
- functions.add(new ExpressionScalarFunction(entry));
- return functions;
+ public static List<ScalarFunction<Reference>> wrapScalars(List<ExpressionNode> nodes) {
+ return nodes.stream().map(node -> wrapScalar(node)).collect(Collectors.toList());
+ }
+
+ public static ScalarFunction<Reference> wrapScalar(ExpressionNode node) {
+ return new ExpressionScalarFunction(node);
}
- private static class ExpressionScalarFunction implements ScalarFunction {
+ private static class ExpressionScalarFunction implements ScalarFunction<Reference> {
private final ExpressionNode expression;
@@ -104,8 +106,8 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public Double apply(EvaluationContext<?> context) {
- return expression.evaluate((Context)context).asDouble();
+ public Double apply(EvaluationContext<Reference> context) {
+ return expression.evaluate(new ContextWrapper(context)).asDouble();
}
@Override
@@ -209,4 +211,26 @@ public class TensorFunctionNode extends CompositeNode {
}
+ /** Turns an EvaluationContext into a Context */
+ // TODO: We should be able to change RankingExpression.evaluate to take an EvaluationContext and then get rid of this
+ private static class ContextWrapper extends Context {
+
+ private final EvaluationContext<Reference> delegate;
+
+ public ContextWrapper(EvaluationContext<Reference> delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public Value get(String name) {
+ return new TensorValue(delegate.getTensor(name));
+ }
+
+ @Override
+ public TensorType getType(Reference name) {
+ return delegate.getType(name);
+ }
+
+ }
+
}
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 3e9649cd9c6..beab722a1eb 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -487,7 +487,7 @@ TensorFunctionNode tensorGenerateBody(TensorType type) :
}
{
<LBRACE> generator = expression() <RBRACE>
- { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asLongListToDoubleOperator())); }
+ { return new TensorFunctionNode(Generate.bound(type, TensorFunctionNode.wrapScalar(generator))); }
}
TensorFunctionNode tensorRange() :
@@ -847,7 +847,7 @@ DynamicTensor mappedTensorValueBody(TensorType type) :
( tensorCell(type, cells))*
( <COMMA> tensorCell(type, cells))*
<RCURLY>
- { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); }
+ { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); }
}
DynamicTensor indexedTensorValueBody(TensorType type) :
@@ -860,7 +860,7 @@ DynamicTensor indexedTensorValueBody(TensorType type) :
( (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )*
( <COMMA> (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )*
// <RSQUARE>
- { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); }
+ { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); }
}
void tensorCell(TensorType type, java.util.Map cells) :
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index a8afc230bde..05ad8c97c7f 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -358,6 +358,9 @@ public class EvaluationTestCase {
tester.assertEvaluates("500",
"join(tensor0, tensor1, f(x,y) (x*y)){tag2}",
"tensor(tag{}):{{tag:tag1}:10, {tag:tag2}:20}", "{25}");
+ tester.assertEvaluates("tensor(j[3]):[3, 3, 3]",
+ "tensor(j[3])(tensor0[2])",
+ "tensor(values[5]):[1, 2, 3, 4, 5]");
// tensor result dimensions are given from argument dimensions, not the resulting values
tester.assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:0}:1 }", "tensor(x{}):{ {x:1}:1 }");