diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-04 22:08:22 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-04 22:08:22 +0100 |
commit | 41ea443855d08588395562f47c41b25c19afe76f (patch) | |
tree | e6cd9ca79ff819cf992116e65a4e6f86fba0a576 /searchlib | |
parent | 9ce04da332229898dce815cea1050cfb36e50d5e (diff) |
Support literal tensors in expressions
Diffstat (limited to 'searchlib')
3 files changed, 93 insertions, 34 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index d2170cf2d1d..5ef3cd61366 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -880,6 +880,7 @@ "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorRename()", "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorConcat()", "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorGenerate()", + "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorGenerateBody(com.yahoo.tensor.TensorType)", "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorRange()", "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorDiag()", "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorRandom()", @@ -892,7 +893,7 @@ "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorArgmin()", "public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()", "public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()", - "public final com.yahoo.tensor.TensorType tensorTypeArgument()", + "public final com.yahoo.tensor.TensorType tensorType()", "public final com.yahoo.tensor.TensorType$Value optionalTensorValueTypeParameter()", "public final void tensorTypeDimension(com.yahoo.tensor.TensorType$Builder)", "public final java.lang.String tensorFunctionName()", @@ -908,6 +909,11 @@ "public final java.util.List tagCommaLeadingList()", "public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode constantPrimitive()", "public final com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue()", + "public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode tensorValueBody(com.yahoo.tensor.TensorType)", + "public final void mappedTensorValueBody(com.yahoo.tensor.Tensor$Builder)", + "public final void indexedTensorValueBody(com.yahoo.tensor.Tensor$Builder)", + "public final void tensorCell(com.yahoo.tensor.Tensor$Builder$CellBuilder)", + "public final void labelAndDimension(com.yahoo.tensor.Tensor$Builder$CellBuilder)", "public void <init>(java.io.InputStream)", "public void <init>(java.io.InputStream, java.lang.String)", "public void ReInit(java.io.InputStream)", diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 954fe75577e..8f411bf6593 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -147,18 +147,6 @@ TOKEN : <MAX: "max"> | <MIN: "min"> | - <TENSOR_VALUE_SL: "value" (" ")* ":" (" ")* ("{"<BRACE_SL_LEVEL_1>) ("\n")? > | - <TENSOR_VALUE_ML: "value" (<SEARCHLIB_SKIP>)? "{" (["\n"," "])* ("{"<BRACE_ML_LEVEL_1>) (["\n"," "])* "}" ("\n")? > | - < #BRACE_SL_LEVEL_1: (("{"<BRACE_SL_LEVEL_2>)|<BRACE_SL_CONTENT>)* "}" > | - < #BRACE_SL_LEVEL_2: (("{"<BRACE_SL_LEVEL_3>)|<BRACE_SL_CONTENT>)* "}" > | - < #BRACE_SL_LEVEL_3: <BRACE_SL_CONTENT> "}" > | - < #BRACE_SL_CONTENT: (~["{","}","\n"])* > | - < #BRACE_ML_LEVEL_1: (("{"<BRACE_ML_LEVEL_2>)|<BRACE_ML_CONTENT>)* "}" > | - < #BRACE_ML_LEVEL_2: (("{"<BRACE_ML_LEVEL_3>)|<BRACE_ML_CONTENT>)* "}" > | - < #BRACE_ML_LEVEL_3: <BRACE_ML_CONTENT> "}" > | - < #BRACE_ML_CONTENT: (~["{","}"])* > | - < #SEARCHLIB_SKIP: ([" ","\f","\n","\r","\t"])+ > | - <IDENTIFIER: (["A"-"Z","a"-"z","0"-"9","_","@"](["A"-"Z","a"-"z","0"-"9","_","@","$"])*)> } @@ -255,8 +243,7 @@ ExpressionNode value() : LOOKAHEAD(4) ret = function() | ret = feature() | ret = legacyQueryFeature() | - ( <LBRACE> ret = expression() <RBRACE> { ret = new EmbracedNode(ret); } ) ) | - ret = tensorValue() + ( <LBRACE> ret = expression() <RBRACE> { ret = new EmbracedNode(ret); } ) ) ) { @@ -481,10 +468,23 @@ ExpressionNode tensorConcat() : ExpressionNode tensorGenerate() : { TensorType type; + ExpressionNode expression; +} +{ + <TENSOR> type = tensorType() + ( + expression = tensorGenerateBody(type) | + expression = tensorValueBody(type) + ) + { return expression; } +} + +ExpressionNode tensorGenerateBody(TensorType type) : +{ ExpressionNode generator; } { - <TENSOR> type = tensorType() <LBRACE> generator = expression() <RBRACE> + <LBRACE> generator = expression() <RBRACE> { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asLongListToDoubleOperator())); } } @@ -635,15 +635,18 @@ TensorType.Value optionalTensorValueTypeParameter() : { return TensorType.Value.fromId(valueType); } } -// NOTE: Only indexed bound dimensions are parsed currently, as that is what we need void tensorTypeDimension(TensorType.Builder builder) : { String name; int size; } { - name = identifier() <LSQUARE> size = integerNumber() <RSQUARE> - { builder.indexed(name, size); } + name = identifier() + ( + ( <LCURLY> <RCURLY> { builder.mapped(name); } ) | + LOOKAHEAD(2) ( <LSQUARE> <RSQUARE> { builder.indexed(name); } ) | + ( <LSQUARE> size = integerNumber() <RSQUARE> { builder.indexed(name, size); } ) + ) } // This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge @@ -717,8 +720,8 @@ List<ExpressionNode> expressionList() : } { <LSQUARE> - expression=expression() { list.add(expression); } - ( LOOKAHEAD(2) <COMMA> expression=expression() { list.add(expression); } ) * + expression = expression() { list.add(expression); } + ( LOOKAHEAD(2) <COMMA> expression = expression() { list.add(expression); } ) * <RSQUARE> { return list; } } @@ -820,18 +823,63 @@ Value primitiveValue() : { return Value.parse(sign + token.image); } } -ConstantNode tensorValue() : +ConstantNode tensorValueBody(TensorType type) : { - TensorType type; - String value; + Tensor.Builder builder = Tensor.Builder.of(type); } { - type = tensorType() <COLON> - ( <TENSOR_VALUE_SL> { value = token.image.substring(token.image.indexOf(":") + 1); } | - <TENSOR_VALUE_ML> { value = token.image.substring(token.image.indexOf("{") + 1, - token.image.lastIndexOf("}")); } ) + ( + mappedTensorValueBody(builder) | + indexedTensorValueBody(builder) + ) + { return new ConstantNode(new TensorValue(builder.build())); } +} + +void mappedTensorValueBody(Tensor.Builder builder) : {} +{ + <LCURLY> + ( tensorCell(builder.cell()))* + ( <COMMA> tensorCell(builder.cell()))* + <RCURLY> +} + +void indexedTensorValueBody(Tensor.Builder builder) : +{ + IndexedTensor.BoundBuilder indexedBuilder; + long index = 0; + double value; +} +{ { - return new ConstantNode(new TensorValue(Tensor.from(type, value))); + if ( ! (builder instanceof IndexedTensor.BoundBuilder)) + throw new IllegalArgumentException("The tensor short form [n, n, ...] can only be used for indexed " + + "bound tensors, not " + builder.type()); + indexedBuilder = (IndexedTensor.BoundBuilder)builder; } + <LSQUARE> + ( value = doubleNumber() { indexedBuilder.cellByDirectIndex(index++, value); } )* + ( <COMMA> value = doubleNumber() { indexedBuilder.cellByDirectIndex(index++, value); } )* + <RSQUARE> } + +void tensorCell(Tensor.Builder.CellBuilder cellBuilder) : +{ + double value; +} +{ + <LCURLY> + ( labelAndDimension(cellBuilder))* + ( <COMMA> labelAndDimension(cellBuilder))* + <RCURLY> + <COLON> value = doubleNumber() { cellBuilder.value(value); } +} + +void labelAndDimension(Tensor.Builder.CellBuilder cellBuilder) : +{ + String dimension, label; +} +{ + dimension = identifier() <COLON> label = tag() + { cellBuilder.label(dimension, label); } +}
\ No newline at end of file diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 39d8a043e18..7aafb8efee7 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -358,11 +358,16 @@ public class EvaluationTestCase { @Test public void testLiteralTensors() { EvaluationTester tester = new EvaluationTester(); - //tester.assertEvaluates("tensor(x{}):{ {x:a}:1.0, {x:b}:2.0, {x:c}:3.0 }", - // "tensor(x{}):{ {x:a}:1.0, {x:b}:2.0, {x:c}:3.0 }"); - //tester.assertEvaluates("tensor(x[3]):[1.0, 2.0, 3.0]", - // "tensor(x[3]):[1.0, 2.0, 3.0]"); - + tester.assertEvaluates("tensor(x{}):{ {x:a}:1.0, {x:b}:2.0, {x:c}:3.0 }", + "tensor(x{}):{ {x:a}:1.0, {x:b}:2.0, {x:c}:3.0 }"); + tester.assertEvaluates("tensor(x[3]):[1.0, 2, 3.0]", + "tensor(x[3]):[1.0, 2.0, 3]"); + tester.assertEvaluates("tensor(x{},y{}):{ {x:a,y:0}:1.0, {x:b,y:0}:2.0, {x:c,y:0}:3.0 }", + "tensor(x{},y{}):{ {x:a,y:0}:1.0, {x:b,y:0}:2.0, {x:c,y:0}:3.0 }"); + tester.assertEvaluates("tensor(x{}):{}", + "tensor(x{}):{}"); + tester.assertEvaluates("tensor():{{}:1}", + "tensor():{{}:1}"); } @Test |