diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-05 22:49:08 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-05 22:49:08 +0100 |
commit | ed8c274dc76794efa692efba6cf509b058b13648 (patch) | |
tree | c1dcb9fbc70b851be5cfdb8c335089283715f698 /searchlib/src | |
parent | 64c5daa351557869e64786188afa75ed3b59991b (diff) |
Literal tensors with value expressions
Diffstat (limited to 'searchlib/src')
5 files changed, 98 insertions, 51 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index c1732aabf0b..e6e49e07c34 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.TypeContext; @@ -14,9 +15,13 @@ import com.yahoo.tensor.functions.PrimitiveTensorFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; +import java.util.ArrayList; import java.util.Collections; import java.util.Deque; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; /** @@ -72,10 +77,44 @@ public class TensorFunctionNode extends CompositeNode { return new TensorValue(function.evaluate(context)); } - public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) { + public static TensorFunctionExpressionNode wrap(ExpressionNode node) { return new TensorFunctionExpressionNode(node); } + public static Map<TensorAddress, Function<EvaluationContext<?>, Double>> wrap(Map<TensorAddress, ExpressionNode> nodes) { + Map<TensorAddress, Function<EvaluationContext<?>, Double>> closures = new LinkedHashMap<>(); + for (var entry : nodes.entrySet()) + closures.put(entry.getKey(), new ExpressionClosure(entry.getValue())); + return closures; + } + + public static List<Function<EvaluationContext<?>, Double>> wrap(List<ExpressionNode> nodes) { + List<Function<EvaluationContext<?>, Double>> closures = new ArrayList<>(); + for (var entry : nodes) + closures.add(new ExpressionClosure(entry)); + return closures; + } + + private static class ExpressionClosure implements java.util.function.Function<EvaluationContext<?> , Double> { + + private final ExpressionNode expression; + + public ExpressionClosure(ExpressionNode expression) { + this.expression = expression; + } + + @Override + public Double apply(EvaluationContext<?> context) { + return expression.evaluate((Context)context).asDouble(); + } + + @Override + public String toString() { + return expression.toString(); + } + + } + /** * A tensor function implemented by an expression. * This allows us to pass expressions as tensor function arguments. diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java index 979c5b0f88c..6d687b015f1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java @@ -83,7 +83,7 @@ public class TensorMaxMinTransformer<CONTEXT extends TransformContext> extends E ExpressionNode arg1 = node.children().get(0); ExpressionNode arg2 = node.children().get(1); - TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1); + TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrap(arg1); Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name()); String dimension = ((ReferenceNode) arg2).getName(); diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 8f411bf6593..47555d95e58 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -30,6 +30,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; +@SuppressWarnings({"rawtypes", "unchecked"}) public class RankingExpressionParser { } @@ -401,8 +402,8 @@ ExpressionNode tensorMap() : } { <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE> - { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor), - doubleMapper.asDoubleUnaryOperator())); } + { return new TensorFunctionNode(new Map(TensorFunctionNode.wrap(tensor), + doubleMapper.asDoubleUnaryOperator())); } } ExpressionNode tensorReduce() : @@ -413,7 +414,7 @@ ExpressionNode tensorReduce() : } { <REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE> - { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } + { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); } } ExpressionNode tensorReduceComposites() : @@ -425,7 +426,7 @@ ExpressionNode tensorReduceComposites() : { aggregator = tensorReduceAggregator() <LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE> - { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } + { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); } } ExpressionNode tensorJoin() : @@ -435,9 +436,9 @@ ExpressionNode tensorJoin() : } { <JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE> - { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), - doubleJoiner.asDoubleBinaryOperator())); } + { return new TensorFunctionNode(new Join(TensorFunctionNode.wrap(tensor1), + TensorFunctionNode.wrap(tensor2), + doubleJoiner.asDoubleBinaryOperator())); } } ExpressionNode tensorRename() : @@ -450,7 +451,7 @@ ExpressionNode tensorRename() : fromDimensions = bracedIdentifierList() <COMMA> toDimensions = bracedIdentifierList() <RBRACE> - { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); } + { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrap(tensor), fromDimensions, toDimensions)); } } ExpressionNode tensorConcat() : @@ -460,8 +461,8 @@ ExpressionNode tensorConcat() : } { <CONCAT> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = tag() <RBRACE> - { return new TensorFunctionNode(new Concat(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), + { return new TensorFunctionNode(new Concat(TensorFunctionNode.wrap(tensor1), + TensorFunctionNode.wrap(tensor2), dimension)); } } @@ -522,7 +523,7 @@ ExpressionNode tensorL1Normalize() : } { <L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } + { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrap(tensor), dimension)); } } ExpressionNode tensorL2Normalize() : @@ -532,7 +533,7 @@ ExpressionNode tensorL2Normalize() : } { <L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } + { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrap(tensor), dimension)); } } ExpressionNode tensorMatmul() : @@ -542,9 +543,9 @@ ExpressionNode tensorMatmul() : } { <MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), - dimension)); } + { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrap(tensor1), + TensorFunctionNode.wrap(tensor2), + dimension)); } } ExpressionNode tensorSoftmax() : @@ -554,7 +555,7 @@ ExpressionNode tensorSoftmax() : } { <SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); } + { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrap(tensor), dimension)); } } ExpressionNode tensorXwPlusB() : @@ -567,9 +568,9 @@ ExpressionNode tensorXwPlusB() : tensor2 = expression() <COMMA> tensor3 = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), - TensorFunctionNode.wrapArgument(tensor3), + { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrap(tensor1), + TensorFunctionNode.wrap(tensor2), + TensorFunctionNode.wrap(tensor3), dimension)); } } @@ -580,7 +581,7 @@ ExpressionNode tensorArgmax() : } { <ARGMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrapArgument(tensor), dimension)); } + { return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrap(tensor), dimension)); } } ExpressionNode tensorArgmin() : @@ -590,7 +591,7 @@ ExpressionNode tensorArgmin() : } { <ARGMIN> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new Argmin(TensorFunctionNode.wrapArgument(tensor), dimension)); } + { return new TensorFunctionNode(new Argmin(TensorFunctionNode.wrap(tensor), dimension)); } } LambdaFunctionNode lambdaFunction() : @@ -823,63 +824,62 @@ Value primitiveValue() : { return Value.parse(sign + token.image); } } -ConstantNode tensorValueBody(TensorType type) : +ExpressionNode tensorValueBody(TensorType type) : { - Tensor.Builder builder = Tensor.Builder.of(type); + DynamicTensor dynamicTensor; } { <COLON> ( - mappedTensorValueBody(builder) | - indexedTensorValueBody(builder) + dynamicTensor = mappedTensorValueBody(type) | + dynamicTensor = indexedTensorValueBody(type) ) - { return new ConstantNode(new TensorValue(builder.build())); } + { return new TensorFunctionNode(dynamicTensor); } } -void mappedTensorValueBody(Tensor.Builder builder) : {} +DynamicTensor mappedTensorValueBody(TensorType type) : +{ + java.util.Map cells = new LinkedHashMap(); +} { <LCURLY> - ( tensorCell(builder.cell()))* - ( <COMMA> tensorCell(builder.cell()))* + ( tensorCell(type, cells))* + ( <COMMA> tensorCell(type, cells))* <RCURLY> + { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); } } -void indexedTensorValueBody(Tensor.Builder builder) : +DynamicTensor indexedTensorValueBody(TensorType type) : { - IndexedTensor.BoundBuilder indexedBuilder; - long index = 0; - double value; + List cells = new ArrayList(); + ExpressionNode value; } { - { - if ( ! (builder instanceof IndexedTensor.BoundBuilder)) - throw new IllegalArgumentException("The tensor short form [n, n, ...] can only be used for indexed " + - "bound tensors, not " + builder.type()); - indexedBuilder = (IndexedTensor.BoundBuilder)builder; - } <LSQUARE> - ( value = doubleNumber() { indexedBuilder.cellByDirectIndex(index++, value); } )* - ( <COMMA> value = doubleNumber() { indexedBuilder.cellByDirectIndex(index++, value); } )* + ( value = expression() { cells.add(value); } )* + ( <COMMA> value = expression() { cells.add(value); } )* <RSQUARE> + { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); } } -void tensorCell(Tensor.Builder.CellBuilder cellBuilder) : +void tensorCell(TensorType type, java.util.Map cells) : { - double value; + ExpressionNode value; + TensorAddress.Builder addressBuilder = new TensorAddress.Builder(type); } { <LCURLY> - ( labelAndDimension(cellBuilder))* - ( <COMMA> labelAndDimension(cellBuilder))* + ( labelAndDimension(addressBuilder))* + ( <COMMA> labelAndDimension(addressBuilder))* <RCURLY> - <COLON> value = doubleNumber() { cellBuilder.value(value); } + <COLON> value = expression() { cells.put(addressBuilder.build(), value); } } -void labelAndDimension(Tensor.Builder.CellBuilder cellBuilder) : +void labelAndDimension(TensorAddress.Builder addressBuilder) : { String dimension, label; } { dimension = identifier() <COLON> label = tag() - { cellBuilder.label(dimension, label); } + { addressBuilder.add(dimension, label); } }
\ No newline at end of file diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index 571e1f4d608..a41f24b3b8a 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -154,7 +154,12 @@ public class RankingExpressionTestCase { "map(constant(tensor0), f(a)(cos(a))) + l2_normalize(attribute(tensor1), x)"); assertSerialization("join(reduce(join(reduce(join(constant(tensor0), attribute(tensor1), f(a,b)(a * b)), sum, x), attribute(tensor1), f(a,b)(a * b)), sum, y), query(tensor2), f(a,b)(a + b))", "xw_plus_b(matmul(constant(tensor0), attribute(tensor1), x), attribute(tensor1), query(tensor2), y)"); - + assertSerialization("tensor(x{}):{{x:a}:1 + 2 + 3,{x:b}:if (1 > 2, 3, 4),{x:c}:reduce(tensor0 * tensor1, sum)}", + "tensor(x{}):{ {x:a}:1+2+3, {x:b}:if(1>2,3,4), {x:c}:sum(tensor0*tensor1) }"); + assertSerialization("tensor(x[3]):[1.0,2.0,3]", + "tensor(x[3]):[1.0, 2.0, 3]"); + assertSerialization("tensor(x[3]):[1.0,reduce(tensor0 * tensor1, sum),3]", + "tensor(x[3]):[1.0, sum(tensor0*tensor1), 3]"); } @Test diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 7aafb8efee7..e28daefdabf 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -368,6 +368,9 @@ public class EvaluationTestCase { "tensor(x{}):{}"); tester.assertEvaluates("tensor():{{}:1}", "tensor():{{}:1}"); + tester.assertEvaluates("tensor(x{}):{ {x:a}:6.0, {x:b}:4.0, {x:c}:14.0 }", + "tensor(x{}):{ {x:a}:1+2+3, {x:b}:if(1>2,3,4), {x:c}:sum(tensor0*tensor1) }", + "{ {x:0}:7 }", "tensor(x{}):{ {x:0}:2 }"); } @Test |