diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2023-11-05 13:24:30 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-05 13:24:30 +0100 |
commit | b123fea3dd152e5e81bba6235e10f4643e7924b7 (patch) | |
tree | e1dc6277ca5622f2389ef3adc4a3da84281f0853 /searchlib | |
parent | 1c4a7f037627c082f2660ffbc88daaeeec5c0b1a (diff) | |
parent | 7ca0c9dcb417f1502666efc5ce47f44a41fc264a (diff) |
Merge pull request #29210 from vespa-engine/arnej/add-map-subspaces-for-java
Arnej/add map subspaces for java
Diffstat (limited to 'searchlib')
6 files changed, 81 insertions, 11 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index a8e028ff6ad..ca475a77b6c 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -930,6 +930,7 @@ "public final com.yahoo.searchlib.rankingexpression.rule.FunctionNode scalarOrTensorFunction()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorFunction()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMap()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMapSubspaces()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorReduce()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorReduceComposites()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorJoin()", @@ -1084,6 +1085,7 @@ "public static final int BIT", "public static final int HAMMING", "public static final int MAP", + "public static final int MAP_SUBSPACES", "public static final int REDUCE", "public static final int JOIN", "public static final int MERGE", @@ -1469,6 +1471,7 @@ ], "methods" : [ "public void <init>(java.util.List, com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)", + "public java.lang.String singleArgumentName()", "public java.util.List children()", "public com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)", "public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)", diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index 1f99ea64a88..b2641fdf229 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -40,6 +40,13 @@ public class LambdaFunctionNode extends CompositeNode { this.functionExpression = functionExpression; } + public String singleArgumentName() { + if (arguments.size() != 1) { + throw new IllegalArgumentException("Cannot apply " + this + " in map, must have a single argument"); + } + return arguments.get(0); + } + @Override public List<ExpressionNode> children() { return Collections.singletonList(functionExpression); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 47f0fb29799..b3f2f265900 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -53,8 +53,8 @@ public class TensorFunctionNode extends CompositeNode { } private ExpressionNode toExpressionNode(TensorFunction<Reference> f) { - if (f instanceof ExpressionTensorFunction) - return ((ExpressionTensorFunction)f).expression; + if (f instanceof ExpressionTensorFunction etf) + return etf.expression; else return new TensorFunctionNode(f); } @@ -176,7 +176,7 @@ public class TensorFunctionNode extends CompositeNode { @Override public Double apply(EvaluationContext<Reference> context) { - return expression.evaluate(new ContextWrapper(context)).asDouble(); + return expression.evaluate(asContext(context)).asDouble(); } @Override @@ -233,10 +233,10 @@ public class TensorFunctionNode extends CompositeNode { @Override public List<TensorFunction<Reference>> arguments() { - if (expression instanceof CompositeNode) - return ((CompositeNode)expression).children().stream() - .map(ExpressionTensorFunction::new) - .collect(Collectors.toList()); + if (expression instanceof CompositeNode cNode) + return cNode.children().stream() + .map(ExpressionTensorFunction::new) + .collect(Collectors.toList()); else return Collections.emptyList(); } @@ -265,7 +265,7 @@ public class TensorFunctionNode extends CompositeNode { @Override public Tensor evaluate(EvaluationContext<Reference> context) { - return expression.evaluate((Context)context).asTensor(); + return expression.evaluate(asContext(context)).asTensor(); } @Override @@ -412,7 +412,12 @@ public class TensorFunctionNode extends CompositeNode { public TensorType getType(Reference name) { return delegate.getType(name); } - } + private static Context asContext(EvaluationContext<Reference> generic) { + if (generic instanceof Context context) { + return context; + } + return new ContextWrapper(generic); + } } diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 02a49cea8bc..9b088825201 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -127,6 +127,7 @@ TOKEN : <HAMMING: "hamming"> | <MAP: "map"> | + <MAP_SUBSPACES: "map_subspaces"> | <REDUCE: "reduce"> | <JOIN: "join"> | <MERGE: "merge"> | @@ -369,6 +370,7 @@ TensorFunctionNode tensorFunction() : { ( tensorExpression = tensorMap() | + tensorExpression = tensorMapSubspaces() | tensorExpression = tensorReduce() | tensorExpression = tensorReduceComposites() | tensorExpression = tensorJoin() | @@ -405,6 +407,22 @@ TensorFunctionNode tensorMap() : doubleMapper.asDoubleUnaryOperator())); } } +TensorFunctionNode tensorMapSubspaces() : +{ + ExpressionNode tensor; + LambdaFunctionNode denseMapper; +} +{ + <MAP_SUBSPACES> <LBRACE> tensor = expression() <COMMA> denseMapper = lambdaFunction() <RBRACE> + { + return new TensorFunctionNode( + new MapSubspaces( + TensorFunctionNode.wrap(tensor), + denseMapper.singleArgumentName(), + TensorFunctionNode.wrap(denseMapper.children().get(0)))); + } +} + TensorFunctionNode tensorReduce() : { ExpressionNode tensor; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 6aac6ce0983..f9ba7552560 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -177,6 +177,25 @@ public class EvaluationTestCase { } @Test + public void testMapSubspaces() { + EvaluationTester tester = new EvaluationTester(); + tester.assertEvaluates("tensor<float>(a{},x[2]):{foo:[2,3],bar:[7,10]}", + "map_subspaces(tensor0, f(t)(t))", + "tensor<float>(a{},x[2]):{foo:[2,3],bar:[7,10]}"); + tester.assertEvaluates("tensor<float>(a{},x[2]):{foo:[2,3],bar:[7,10]}", + "map_subspaces(tensor0, f(t)(t+2))", + "tensor<float>(a{},x[2]):{foo:[0,1],bar:[5,8]}"); + + tester.assertEvaluates("tensor<float>(a{},y[2]):{foo:[3,5],bar:[9,11]}", + "map_subspaces(tensor0, f(t)(tensor<float>(y[2])(t{x:(y)}+t{x:(y+1)})))", + "tensor(a{},x[3]):{foo:[1,2,3],bar:[4,5,6]}"); + + tester.assertEvaluates("tensor<double>(a{},x[2]):{foo:[3,5],bar:[9,11]}", + "map_subspaces(tensor0, f(t)(tensor(x[2])(t{x:(x)}+t{x:(x+1)})))", + "tensor<float>(a{},x[3]):{foo:[1,2,3],bar:[4,5,6]}"); + } + + @Test public void testTensorEvaluation() { EvaluationTester tester = new EvaluationTester(); tester.assertEvaluates("{}", "tensor0", "{}"); @@ -296,7 +315,7 @@ public class EvaluationTestCase { "{{x:0}:1}", "{}", "{{y:0,z:0}:1}"); tester.assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:0}:3 }", "tensor(x{}):{ {x:1}:5 }"); - tester.assertEvaluates("tensor<float>(x{}):{}", + tester.assertEvaluates("tensor<double>(x{}):{}", "tensor0 * tensor1", "{ {x:0}:3 }", "tensor<float>(x{}):{ {x:1}:5 }"); tester.assertEvaluates("{ {x:0}:15 }", "tensor0 * tensor1", "{ {x:0}:3 }", "{ {x:0}:5 }"); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java index cc278b3d73b..e1018e3b2a5 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -75,7 +76,24 @@ public class EvaluationTester { RankingExpression expression = new RankingExpression(expressionString); if ( ! explanation.isEmpty()) explanation = explanation + ": "; - assertEquals(explanation + expression, value, expression.evaluate(context)); + var result = expression.evaluate(context); + assertEquals(explanation + expression, value, result); + assertEquals(value.type().valueType(), result.type().valueType()); + var root = expression.getRoot(); + String asString = root.toString(); + try { + expression = new RankingExpression(asString); + result = expression.evaluate(context); + assertEquals(explanation + expressionString + " -> " + asString, value, result); + assertEquals(value.type().valueType(), result.type().valueType()); + } catch (Exception e) { + System.err.println("toString() failure, " + expressionString + " -> " + asString); + System.err.println("root: " + root.getClass()); + if (root instanceof TensorFunctionNode tfn) { + System.err.println("root func: " + tfn.function().getClass()); + } + throw new RuntimeException(e); + } return expression; } catch (ParseException e) { |