diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-21 14:25:01 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-21 14:25:01 +0100 |
commit | ed8ec5305f6838e31de94ef87ddd3a75390b59ed (patch) | |
tree | 6266387837bafdc29713b1a9605919b59fd86079 /searchlib | |
parent | b56911f909e6ca68fa0a02cf5932d422a61a9f49 (diff) |
- Tensor generate implementation
- Cross tensor implementation equals
- Better iteration
Diffstat (limited to 'searchlib')
4 files changed, 128 insertions, 9 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java new file mode 100644 index 00000000000..1e3b0c4362d --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java @@ -0,0 +1,80 @@ +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; + +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.function.*; + +/** + * A tensor generating function, whose arguments are determined by a tensor type + * + * @author bratseth + */ +public class GeneratorLambdaFunctionNode extends CompositeNode { + + private final TensorType type; + private final ExpressionNode generator; + + public GeneratorLambdaFunctionNode(TensorType type, ExpressionNode generator) { + if ( ! type.dimensions().stream().allMatch(d -> d.size().isPresent())) + throw new IllegalArgumentException("A tensor generator function can only generate tensors with bound " + + "dimensions, but tried to generate " + type); + // TODO: Verify that the function only accesses the given arguments + this.type = type; + this.generator = generator; + } + + @Override + public List<ExpressionNode> children() { + return Collections.singletonList(generator); + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> children) { + if ( children.size() != 1) + throw new IllegalArgumentException("A lambda function must have a single child expression"); + return new GeneratorLambdaFunctionNode(type, children.get(0)); + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + return (type + "(" + generator.toString(context, path, this)) + ")"; + } + + /** Evaluate this in a context which must have the arguments bound */ + @Override + public Value evaluate(Context context) { + return generator.evaluate(context); + } + + /** + * Returns this as an operator which converts a list of integers into a double + */ + public IntegerListToDoubleLambda asIntegerListToDoubleOperator() { + return new IntegerListToDoubleLambda(); + } + + private class IntegerListToDoubleLambda implements java.util.function.Function<List<Integer>, Double> { + + @Override + public Double apply(List<Integer> arguments) { + MapContext context = new MapContext(); + for (int i = 0; i < type.dimensions().size(); i++) + context.put(type.dimensions().get(i).name(), arguments.get(i)); + return evaluate(context).asDouble(); + } + + @Override + public String toString() { + return GeneratorLambdaFunctionNode.this.toString(); + } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index 7b48288598d..5a96cf4bbae 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -22,7 +22,7 @@ public class LambdaFunctionNode extends CompositeNode { private final ExpressionNode functionExpression; public LambdaFunctionNode(List<String> arguments, ExpressionNode functionExpression) { - // TODO: Verify that the function only accesses the arguments in mapperVariables + // TODO: Verify that the function only accesses the given arguments this.arguments = ImmutableList.copyOf(arguments); this.functionExpression = functionExpression; } diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 0fcfdb5d40c..564b2cd9801 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -236,7 +236,7 @@ IfNode ifExpression() : } { ( <IF> <LBRACE> ( condition = expression() ) - <COMMA> ifTrue = expression() <COMMA> ifFalse = expression() ( <COMMA> trueProbability = number() )? <RBRACE> ) + <COMMA> ifTrue = expression() <COMMA> ifFalse = expression() ( <COMMA> trueProbability = doubleNumber() )? <RBRACE> ) { return new IfNode(condition, ifTrue, ifFalse, trueProbability); } @@ -420,15 +420,37 @@ ExpressionNode tensorRename() : { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); } } -// TODO: Notice that null is parsed below ExpressionNode tensorGenerate() : { TensorType type; - LambdaFunctionNode generator; + ExpressionNode generator; } { - <TENSOR> <LBRACE> <RBRACE> <LBRACE> - { return new TensorFunctionNode(new Generate(null, null)); } + type = tensorType() <LBRACE> generator = expression() <RBRACE> + { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asIntegerListToDoubleOperator())); } +} + +TensorType tensorType() : +{ + TensorType.Builder builder = new TensorType.Builder(); +} +{ + <TENSOR> <LBRACE> + ( tensorTypeDimension(builder) ) ? + ( <COMMA> tensorTypeDimension(builder) ) * + <RBRACE> + { return builder.build(); } +} + +// NOTE: Only indexed bound dimensions are parsed currently, as that is what we need +void tensorTypeDimension(TensorType.Builder builder) : +{ + String name; + int size; +} +{ + name = identifier() <LSQUARE> size = integerNumber() <RSQUARE> + { builder.indexed(name, size); } } ExpressionNode tensorL1Normalize() : @@ -574,7 +596,7 @@ List<ExpressionNode> expressionList() : { return list; } } -double number() : +double doubleNumber() : { String sign = ""; } @@ -583,6 +605,15 @@ double number() : { return Double.parseDouble(sign + token.image); } } +int integerNumber() : +{ + String sign = ""; +} +{ + ( <SUB> { sign = "-";} )? ( <INTEGER> ) + { return Integer.parseInt(sign + token.image); } +} + String identifier() : { String name; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 8fa9076993e..55638c3687b 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -232,8 +232,10 @@ public class EvaluationTestCase { tester.assertEvaluates("{ {newX:0,y:0}:3 }", "rename(tensor0, x, newX)", "{ {x:0,y:0}:3.0 }"); tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:5 }", "rename(tensor0, (x, y), (y, x))", "{ {x:0,y:0}:3.0, {x:0,y:1}:5.0 }"); - // tensor generate - TODO - // assertEvaluates("{ {x:0,y:0}:1, {x:1,y:0}:0, {x:2,y:2}:1, {x:1,y:2}:0 }", "tensor(x[2],y[2])(x==y)"); + // tensor generate + tester.assertEvaluates("{ {x:0,y:0}:0, {x:1,y:0}:0, {x:0,y:1}:1, {x:1,y:1}:0, {x:0,y:2}:0, {x:1,y:2}:1 }", "tensor(x[2],y[3])(x+1==y)"); + tester.assertEvaluates("{ {y:0,x:0}:0, {y:1,x:0}:0, {y:0,x:1}:1, {y:1,x:1}:0, {y:0,x:2}:0, {y:1,x:2}:1 }", "tensor(y[2],x[3])(y+1==x)"); + // TODO // range // diag // fill @@ -263,6 +265,12 @@ public class EvaluationTestCase { } @Test + public void testItz() { + EvaluationTester tester = new EvaluationTester(); + tester.assertEvaluates("{ {x:0}:0.25, {x:1}:0.75 }", "l1_normalize(tensor0, x)", "{ {x:0}:1, {x:1}:3 }"); + } + + @Test public void testProgrammaticBuildingAndPrecedence() { RankingExpression standardPrecedence = new RankingExpression(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, new ArithmeticNode(constant(3), ArithmeticOperator.MULTIPLY, constant(4)))); RankingExpression oppositePrecedence = new RankingExpression(new ArithmeticNode(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, constant(3)), ArithmeticOperator.MULTIPLY, constant(4))); |