diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-26 16:51:50 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-26 16:51:50 +0200 |
commit | f4203c3cc571722f08ee65047437c1290ed63f69 (patch) | |
tree | 7d06d17091a2e388e6771187a11cf4f4023a0c1e /searchlib | |
parent | 316c941e90f39d2e9bc46f12b96ca0f87471d1bd (diff) |
Allow bound functions in tensor generate
Diffstat (limited to 'searchlib')
5 files changed, 45 insertions, 17 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index d5970a4b69e..dcf42069373 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -1615,8 +1615,9 @@ "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)", "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)", - "public static java.util.Map wrap(java.util.Map)", - "public static java.util.List wrap(java.util.List)" + "public static java.util.Map wrapScalars(java.util.Map)", + "public static java.util.List wrapScalars(java.util.List)", + "public static com.yahoo.tensor.functions.ScalarFunction wrapScalar(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)" ], "fields": [] }, diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index d68f8c85ad1..cf17c6465f3 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -34,7 +34,7 @@ public abstract class Context implements EvaluationContext<Reference> { @Override public TensorType getType(String reference) { - throw new UnsupportedOperationException("Not able to parse gereral references from string form"); + throw new UnsupportedOperationException("Not able to parse general references from string form"); } /** Returns a variable as a tensor */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 4ffd40f00f7..18f1fa8a78f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -16,6 +16,7 @@ import com.yahoo.tensor.functions.ScalarFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; +import java.sql.Ref; import java.util.ArrayList; import java.util.Collections; import java.util.Deque; @@ -81,21 +82,22 @@ public class TensorFunctionNode extends CompositeNode { return new ExpressionTensorFunction(node); } - public static Map<TensorAddress, ScalarFunction> wrap(Map<TensorAddress, ExpressionNode> nodes) { - Map<TensorAddress, ScalarFunction> functions = new LinkedHashMap<>(); + public static Map<TensorAddress, ScalarFunction<Reference>> wrapScalars(Map<TensorAddress, ExpressionNode> nodes) { + Map<TensorAddress, ScalarFunction<Reference>> functions = new LinkedHashMap<>(); for (var entry : nodes.entrySet()) - functions.put(entry.getKey(), new ExpressionScalarFunction(entry.getValue())); + functions.put(entry.getKey(), wrapScalar(entry.getValue())); return functions; } - public static List<ScalarFunction> wrap(List<ExpressionNode> nodes) { - List<ScalarFunction> functions = new ArrayList<>(); - for (var entry : nodes) - functions.add(new ExpressionScalarFunction(entry)); - return functions; + public static List<ScalarFunction<Reference>> wrapScalars(List<ExpressionNode> nodes) { + return nodes.stream().map(node -> wrapScalar(node)).collect(Collectors.toList()); + } + + public static ScalarFunction<Reference> wrapScalar(ExpressionNode node) { + return new ExpressionScalarFunction(node); } - private static class ExpressionScalarFunction implements ScalarFunction { + private static class ExpressionScalarFunction implements ScalarFunction<Reference> { private final ExpressionNode expression; @@ -104,8 +106,8 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public Double apply(EvaluationContext<?> context) { - return expression.evaluate((Context)context).asDouble(); + public Double apply(EvaluationContext<Reference> context) { + return expression.evaluate(new ContextWrapper(context)).asDouble(); } @Override @@ -209,4 +211,26 @@ public class TensorFunctionNode extends CompositeNode { } + /** Turns an EvaluationContext into a Context */ + // TODO: We should be able to change RankingExpression.evaluate to take an EvaluationContext and then get rid of this + private static class ContextWrapper extends Context { + + private final EvaluationContext<Reference> delegate; + + public ContextWrapper(EvaluationContext<Reference> delegate) { + this.delegate = delegate; + } + + @Override + public Value get(String name) { + return new TensorValue(delegate.getTensor(name)); + } + + @Override + public TensorType getType(Reference name) { + return delegate.getType(name); + } + + } + } diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 3e9649cd9c6..beab722a1eb 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -487,7 +487,7 @@ TensorFunctionNode tensorGenerateBody(TensorType type) : } { <LBRACE> generator = expression() <RBRACE> - { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asLongListToDoubleOperator())); } + { return new TensorFunctionNode(Generate.bound(type, TensorFunctionNode.wrapScalar(generator))); } } TensorFunctionNode tensorRange() : @@ -847,7 +847,7 @@ DynamicTensor mappedTensorValueBody(TensorType type) : ( tensorCell(type, cells))* ( <COMMA> tensorCell(type, cells))* <RCURLY> - { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); } + { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); } } DynamicTensor indexedTensorValueBody(TensorType type) : @@ -860,7 +860,7 @@ DynamicTensor indexedTensorValueBody(TensorType type) : ( (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )* ( <COMMA> (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )* // <RSQUARE> - { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); } + { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); } } void tensorCell(TensorType type, java.util.Map cells) : diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index a8afc230bde..05ad8c97c7f 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -358,6 +358,9 @@ public class EvaluationTestCase { tester.assertEvaluates("500", "join(tensor0, tensor1, f(x,y) (x*y)){tag2}", "tensor(tag{}):{{tag:tag1}:10, {tag:tag2}:20}", "{25}"); + tester.assertEvaluates("tensor(j[3]):[3, 3, 3]", + "tensor(j[3])(tensor0[2])", + "tensor(values[5]):[1, 2, 3, 4, 5]"); // tensor result dimensions are given from argument dimensions, not the resulting values tester.assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:0}:1 }", "tensor(x{}):{ {x:1}:1 }"); |