aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2023-11-05 13:24:30 +0100
committerGitHub <noreply@github.com>2023-11-05 13:24:30 +0100
commitb123fea3dd152e5e81bba6235e10f4643e7924b7 (patch)
treee1dc6277ca5622f2389ef3adc4a3da84281f0853 /searchlib
parent1c4a7f037627c082f2660ffbc88daaeeec5c0b1a (diff)
parent7ca0c9dcb417f1502666efc5ce47f44a41fc264a (diff)
Merge pull request #29210 from vespa-engine/arnej/add-map-subspaces-for-java
Arnej/add map subspaces for java
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/abi-spec.json3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java7
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java23
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj18
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java21
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java20
6 files changed, 81 insertions, 11 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index a8e028ff6ad..ca475a77b6c 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -930,6 +930,7 @@
"public final com.yahoo.searchlib.rankingexpression.rule.FunctionNode scalarOrTensorFunction()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorFunction()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMap()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMapSubspaces()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorReduce()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorReduceComposites()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorJoin()",
@@ -1084,6 +1085,7 @@
"public static final int BIT",
"public static final int HAMMING",
"public static final int MAP",
+ "public static final int MAP_SUBSPACES",
"public static final int REDUCE",
"public static final int JOIN",
"public static final int MERGE",
@@ -1469,6 +1471,7 @@
],
"methods" : [
"public void <init>(java.util.List, com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
+ "public java.lang.String singleArgumentName()",
"public java.util.List children()",
"public com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)",
"public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index 1f99ea64a88..b2641fdf229 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -40,6 +40,13 @@ public class LambdaFunctionNode extends CompositeNode {
this.functionExpression = functionExpression;
}
+ public String singleArgumentName() {
+ if (arguments.size() != 1) {
+ throw new IllegalArgumentException("Cannot apply " + this + " in map, must have a single argument");
+ }
+ return arguments.get(0);
+ }
+
@Override
public List<ExpressionNode> children() {
return Collections.singletonList(functionExpression);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 47f0fb29799..b3f2f265900 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -53,8 +53,8 @@ public class TensorFunctionNode extends CompositeNode {
}
private ExpressionNode toExpressionNode(TensorFunction<Reference> f) {
- if (f instanceof ExpressionTensorFunction)
- return ((ExpressionTensorFunction)f).expression;
+ if (f instanceof ExpressionTensorFunction etf)
+ return etf.expression;
else
return new TensorFunctionNode(f);
}
@@ -176,7 +176,7 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public Double apply(EvaluationContext<Reference> context) {
- return expression.evaluate(new ContextWrapper(context)).asDouble();
+ return expression.evaluate(asContext(context)).asDouble();
}
@Override
@@ -233,10 +233,10 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public List<TensorFunction<Reference>> arguments() {
- if (expression instanceof CompositeNode)
- return ((CompositeNode)expression).children().stream()
- .map(ExpressionTensorFunction::new)
- .collect(Collectors.toList());
+ if (expression instanceof CompositeNode cNode)
+ return cNode.children().stream()
+ .map(ExpressionTensorFunction::new)
+ .collect(Collectors.toList());
else
return Collections.emptyList();
}
@@ -265,7 +265,7 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public Tensor evaluate(EvaluationContext<Reference> context) {
- return expression.evaluate((Context)context).asTensor();
+ return expression.evaluate(asContext(context)).asTensor();
}
@Override
@@ -412,7 +412,12 @@ public class TensorFunctionNode extends CompositeNode {
public TensorType getType(Reference name) {
return delegate.getType(name);
}
-
}
+ private static Context asContext(EvaluationContext<Reference> generic) {
+ if (generic instanceof Context context) {
+ return context;
+ }
+ return new ContextWrapper(generic);
+ }
}
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 02a49cea8bc..9b088825201 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -127,6 +127,7 @@ TOKEN :
<HAMMING: "hamming"> |
<MAP: "map"> |
+ <MAP_SUBSPACES: "map_subspaces"> |
<REDUCE: "reduce"> |
<JOIN: "join"> |
<MERGE: "merge"> |
@@ -369,6 +370,7 @@ TensorFunctionNode tensorFunction() :
{
(
tensorExpression = tensorMap() |
+ tensorExpression = tensorMapSubspaces() |
tensorExpression = tensorReduce() |
tensorExpression = tensorReduceComposites() |
tensorExpression = tensorJoin() |
@@ -405,6 +407,22 @@ TensorFunctionNode tensorMap() :
doubleMapper.asDoubleUnaryOperator())); }
}
+TensorFunctionNode tensorMapSubspaces() :
+{
+ ExpressionNode tensor;
+ LambdaFunctionNode denseMapper;
+}
+{
+ <MAP_SUBSPACES> <LBRACE> tensor = expression() <COMMA> denseMapper = lambdaFunction() <RBRACE>
+ {
+ return new TensorFunctionNode(
+ new MapSubspaces(
+ TensorFunctionNode.wrap(tensor),
+ denseMapper.singleArgumentName(),
+ TensorFunctionNode.wrap(denseMapper.children().get(0))));
+ }
+}
+
TensorFunctionNode tensorReduce() :
{
ExpressionNode tensor;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 6aac6ce0983..f9ba7552560 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -177,6 +177,25 @@ public class EvaluationTestCase {
}
@Test
+ public void testMapSubspaces() {
+ EvaluationTester tester = new EvaluationTester();
+ tester.assertEvaluates("tensor<float>(a{},x[2]):{foo:[2,3],bar:[7,10]}",
+ "map_subspaces(tensor0, f(t)(t))",
+ "tensor<float>(a{},x[2]):{foo:[2,3],bar:[7,10]}");
+ tester.assertEvaluates("tensor<float>(a{},x[2]):{foo:[2,3],bar:[7,10]}",
+ "map_subspaces(tensor0, f(t)(t+2))",
+ "tensor<float>(a{},x[2]):{foo:[0,1],bar:[5,8]}");
+
+ tester.assertEvaluates("tensor<float>(a{},y[2]):{foo:[3,5],bar:[9,11]}",
+ "map_subspaces(tensor0, f(t)(tensor<float>(y[2])(t{x:(y)}+t{x:(y+1)})))",
+ "tensor(a{},x[3]):{foo:[1,2,3],bar:[4,5,6]}");
+
+ tester.assertEvaluates("tensor<double>(a{},x[2]):{foo:[3,5],bar:[9,11]}",
+ "map_subspaces(tensor0, f(t)(tensor(x[2])(t{x:(x)}+t{x:(x+1)})))",
+ "tensor<float>(a{},x[3]):{foo:[1,2,3],bar:[4,5,6]}");
+ }
+
+ @Test
public void testTensorEvaluation() {
EvaluationTester tester = new EvaluationTester();
tester.assertEvaluates("{}", "tensor0", "{}");
@@ -296,7 +315,7 @@ public class EvaluationTestCase {
"{{x:0}:1}", "{}", "{{y:0,z:0}:1}");
tester.assertEvaluates("tensor(x{}):{}",
"tensor0 * tensor1", "{ {x:0}:3 }", "tensor(x{}):{ {x:1}:5 }");
- tester.assertEvaluates("tensor<float>(x{}):{}",
+ tester.assertEvaluates("tensor<double>(x{}):{}",
"tensor0 * tensor1", "{ {x:0}:3 }", "tensor<float>(x{}):{ {x:1}:5 }");
tester.assertEvaluates("{ {x:0}:15 }",
"tensor0 * tensor1", "{ {x:0}:3 }", "{ {x:0}:5 }");
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
index cc278b3d73b..e1018e3b2a5 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -75,7 +76,24 @@ public class EvaluationTester {
RankingExpression expression = new RankingExpression(expressionString);
if ( ! explanation.isEmpty())
explanation = explanation + ": ";
- assertEquals(explanation + expression, value, expression.evaluate(context));
+ var result = expression.evaluate(context);
+ assertEquals(explanation + expression, value, result);
+ assertEquals(value.type().valueType(), result.type().valueType());
+ var root = expression.getRoot();
+ String asString = root.toString();
+ try {
+ expression = new RankingExpression(asString);
+ result = expression.evaluate(context);
+ assertEquals(explanation + expressionString + " -> " + asString, value, result);
+ assertEquals(value.type().valueType(), result.type().valueType());
+ } catch (Exception e) {
+ System.err.println("toString() failure, " + expressionString + " -> " + asString);
+ System.err.println("root: " + root.getClass());
+ if (root instanceof TensorFunctionNode tfn) {
+ System.err.println("root func: " + tfn.function().getClass());
+ }
+ throw new RuntimeException(e);
+ }
return expression;
}
catch (ParseException e) {