summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-12-21 14:25:01 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-12-21 14:25:01 +0100
commited8ec5305f6838e31de94ef87ddd3a75390b59ed (patch)
tree6266387837bafdc29713b1a9605919b59fd86079 /searchlib
parentb56911f909e6ca68fa0a02cf5932d422a61a9f49 (diff)
- Tensor generate implementation
- Cross tensor implementation equals - Better iteration
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java80
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java2
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj43
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java12
4 files changed, 128 insertions, 9 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
new file mode 100644
index 00000000000..1e3b0c4362d
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
@@ -0,0 +1,80 @@
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.google.common.collect.ImmutableList;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+import java.util.function.*;
+
+/**
+ * A tensor generating function, whose arguments are determined by a tensor type
+ *
+ * @author bratseth
+ */
+public class GeneratorLambdaFunctionNode extends CompositeNode {
+
+ private final TensorType type;
+ private final ExpressionNode generator;
+
+ public GeneratorLambdaFunctionNode(TensorType type, ExpressionNode generator) {
+ if ( ! type.dimensions().stream().allMatch(d -> d.size().isPresent()))
+ throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " +
+ "dimensions, but tried to generate " + type);
+ // TODO: Verify that the function only accesses the given arguments
+ this.type = type;
+ this.generator = generator;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return Collections.singletonList(generator);
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> children) {
+ if ( children.size() != 1)
+ throw new IllegalArgumentException("A lambda function must have a single child expression");
+ return new GeneratorLambdaFunctionNode(type, children.get(0));
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ return (type + "(" + generator.toString(context, path, this)) + ")";
+ }
+
+ /** Evaluate this in a context which must have the arguments bound */
+ @Override
+ public Value evaluate(Context context) {
+ return generator.evaluate(context);
+ }
+
+ /**
+ * Returns this as an operator which converts a list of integers into a double
+ */
+ public IntegerListToDoubleLambda asIntegerListToDoubleOperator() {
+ return new IntegerListToDoubleLambda();
+ }
+
+ private class IntegerListToDoubleLambda implements java.util.function.Function<List<Integer>, Double> {
+
+ @Override
+ public Double apply(List<Integer> arguments) {
+ MapContext context = new MapContext();
+ for (int i = 0; i < type.dimensions().size(); i++)
+ context.put(type.dimensions().get(i).name(), arguments.get(i));
+ return evaluate(context).asDouble();
+ }
+
+ @Override
+ public String toString() {
+ return GeneratorLambdaFunctionNode.this.toString();
+ }
+
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index 7b48288598d..5a96cf4bbae 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -22,7 +22,7 @@ public class LambdaFunctionNode extends CompositeNode {
private final ExpressionNode functionExpression;
public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) {
- // TODO: Verify that the function only accesses the arguments in mapperVariables
+ // TODO: Verify that the function only accesses the given arguments
this.arguments = ImmutableList.copyOf(arguments);
this.functionExpression = functionExpression;
}
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 0fcfdb5d40c..564b2cd9801 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -236,7 +236,7 @@ IfNode ifExpression() :
}
{
( <IF> <LBRACE> ( condition = expression() )
- <COMMA> ifTrue = expression() <COMMA> ifFalse = expression() ( <COMMA> trueProbability = number() )? <RBRACE> )
+ <COMMA> ifTrue = expression() <COMMA> ifFalse = expression() ( <COMMA> trueProbability = doubleNumber() )? <RBRACE> )
{
return new IfNode(condition, ifTrue, ifFalse, trueProbability);
}
@@ -420,15 +420,37 @@ ExpressionNode tensorRename() :
{ return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); }
}
-// TODO: Notice that null is parsed below
ExpressionNode tensorGenerate() :
{
TensorType type;
- LambdaFunctionNode generator;
+ ExpressionNode generator;
}
{
- <TENSOR> <LBRACE> <RBRACE> <LBRACE>
- { return new TensorFunctionNode(new Generate(null, null)); }
+ type = tensorType() <LBRACE> generator = expression() <RBRACE>
+ { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asIntegerListToDoubleOperator())); }
+}
+
+TensorType tensorType() :
+{
+ TensorType.Builder builder = new TensorType.Builder();
+}
+{
+ <TENSOR> <LBRACE>
+ ( tensorTypeDimension(builder) ) ?
+ ( <COMMA> tensorTypeDimension(builder) ) *
+ <RBRACE>
+ { return builder.build(); }
+}
+
+// NOTE: Only indexed bound dimensions are parsed currently, as that is what we need
+void tensorTypeDimension(TensorType.Builder builder) :
+{
+ String name;
+ int size;
+}
+{
+ name = identifier() <LSQUARE> size = integerNumber() <RSQUARE>
+ { builder.indexed(name, size); }
}
ExpressionNode tensorL1Normalize() :
@@ -574,7 +596,7 @@ List<ExpressionNode> expressionList() :
{ return list; }
}
-double number() :
+double doubleNumber() :
{
String sign = "";
}
@@ -583,6 +605,15 @@ double number() :
{ return Double.parseDouble(sign + token.image); }
}
+int integerNumber() :
+{
+ String sign = "";
+}
+{
+ ( <SUB> { sign = "-";} )? ( <INTEGER> )
+ { return Integer.parseInt(sign + token.image); }
+}
+
String identifier() :
{
String name;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 8fa9076993e..55638c3687b 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -232,8 +232,10 @@ public class EvaluationTestCase {
tester.assertEvaluates("{ {newX:0,y:0}:3 }", "rename(tensor0, x, newX)", "{ {x:0,y:0}:3.0 }");
tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:5 }", "rename(tensor0, (x, y), (y, x))", "{ {x:0,y:0}:3.0, {x:0,y:1}:5.0 }");
- // tensor generate - TODO
- // assertEvaluates("{ {x:0,y:0}:1, {x:1,y:0}:0, {x:2,y:2}:1, {x:1,y:2}:0 }", "tensor(x[2],y[2])(x==y)");
+ // tensor generate
+ tester.assertEvaluates("{ {x:0,y:0}:0, {x:1,y:0}:0, {x:0,y:1}:1, {x:1,y:1}:0, {x:0,y:2}:0, {x:1,y:2}:1 }", "tensor(x[2],y[3])(x+1==y)");
+ tester.assertEvaluates("{ {y:0,x:0}:0, {y:1,x:0}:0, {y:0,x:1}:1, {y:1,x:1}:0, {y:0,x:2}:0, {y:1,x:2}:1 }", "tensor(y[2],x[3])(y+1==x)");
+ // TODO
// range
// diag
// fill
@@ -263,6 +265,12 @@ public class EvaluationTestCase {
}
@Test
+ public void testItz() {
+ EvaluationTester tester = new EvaluationTester();
+ tester.assertEvaluates("{ {x:0}:0.25, {x:1}:0.75 }", "l1_normalize(tensor0, x)", "{ {x:0}:1, {x:1}:3 }");
+ }
+
+ @Test
public void testProgrammaticBuildingAndPrecedence() {
RankingExpression standardPrecedence = new RankingExpression(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, new ArithmeticNode(constant(3), ArithmeticOperator.MULTIPLY, constant(4))));
RankingExpression oppositePrecedence = new RankingExpression(new ArithmeticNode(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, constant(3)), ArithmeticOperator.MULTIPLY, constant(4)));