summary refs log tree commit diff stats
path: root/searchlib
diff options
context:
space:
mode:
author Jon Bratseth <bratseth@verizonmedia.com> 2019-11-04 22:08:22 +0100
committer Jon Bratseth <bratseth@verizonmedia.com> 2019-11-04 22:08:22 +0100
commit 41ea443855d08588395562f47c41b25c19afe76f (patch)
tree e6cd9ca79ff819cf992116e65a4e6f86fba0a576 /searchlib
parent 9ce04da332229898dce815cea1050cfb36e50d5e (diff)
Support literal tensors in expressions
Diffstat (limited to 'searchlib')
-rw-r--r-- searchlib/abi-spec.json 8
-rwxr-xr-x searchlib/src/main/javacc/RankingExpressionParser.jj 104
-rw-r--r-- searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java 15
3 files changed, 93 insertions, 34 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index d2170cf2d1d..5ef3cd61366 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -880,6 +880,7 @@
"public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorRename()",
"public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorConcat()",
"public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorGenerate()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorGenerateBody(com.yahoo.tensor.TensorType)",
"public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorRange()",
"public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorDiag()",
"public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorRandom()",
@@ -892,7 +893,7 @@
"public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorArgmin()",
"public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()",
"public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()",
- "public final com.yahoo.tensor.TensorType tensorTypeArgument()",
+ "public final com.yahoo.tensor.TensorType tensorType()",
"public final com.yahoo.tensor.TensorType$Value optionalTensorValueTypeParameter()",
"public final void tensorTypeDimension(com.yahoo.tensor.TensorType$Builder)",
"public final java.lang.String tensorFunctionName()",
@@ -908,6 +909,11 @@
"public final java.util.List tagCommaLeadingList()",
"public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode constantPrimitive()",
"public final com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode tensorValueBody(com.yahoo.tensor.TensorType)",
+ "public final void mappedTensorValueBody(com.yahoo.tensor.Tensor$Builder)",
+ "public final void indexedTensorValueBody(com.yahoo.tensor.Tensor$Builder)",
+ "public final void tensorCell(com.yahoo.tensor.Tensor$Builder$CellBuilder)",
+ "public final void labelAndDimension(com.yahoo.tensor.Tensor$Builder$CellBuilder)",
"public void <init>(java.io.InputStream)",
"public void <init>(java.io.InputStream, java.lang.String)",
"public void ReInit(java.io.InputStream)",
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 954fe75577e..8f411bf6593 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -147,18 +147,6 @@ TOKEN :
<MAX: "max"> |
<MIN: "min"> |
- <TENSOR_VALUE_SL: "value" (" ")* ":" (" ")* ("{"<BRACE_SL_LEVEL_1>) ("\n")? > |
- <TENSOR_VALUE_ML: "value" (<SEARCHLIB_SKIP>)? "{" (["\n"," "])* ("{"<BRACE_ML_LEVEL_1>) (["\n"," "])* "}" ("\n")? > |
- < #BRACE_SL_LEVEL_1: (("{"<BRACE_SL_LEVEL_2>)|<BRACE_SL_CONTENT>)* "}" > |
- < #BRACE_SL_LEVEL_2: (("{"<BRACE_SL_LEVEL_3>)|<BRACE_SL_CONTENT>)* "}" > |
- < #BRACE_SL_LEVEL_3: <BRACE_SL_CONTENT> "}" > |
- < #BRACE_SL_CONTENT: (~["{","}","\n"])* > |
- < #BRACE_ML_LEVEL_1: (("{"<BRACE_ML_LEVEL_2>)|<BRACE_ML_CONTENT>)* "}" > |
- < #BRACE_ML_LEVEL_2: (("{"<BRACE_ML_LEVEL_3>)|<BRACE_ML_CONTENT>)* "}" > |
- < #BRACE_ML_LEVEL_3: <BRACE_ML_CONTENT> "}" > |
- < #BRACE_ML_CONTENT: (~["{","}"])* > |
- < #SEARCHLIB_SKIP: ([" ","\f","\n","\r","\t"])+ > |
-
<IDENTIFIER: (["A"-"Z","a"-"z","0"-"9","_","@"](["A"-"Z","a"-"z","0"-"9","_","@","$"])*)>
}
@@ -255,8 +243,7 @@ ExpressionNode value() :
LOOKAHEAD(4) ret = function() |
ret = feature() |
ret = legacyQueryFeature() |
- ( <LBRACE> ret = expression() <RBRACE> { ret = new EmbracedNode(ret); } ) ) |
- ret = tensorValue()
+ ( <LBRACE> ret = expression() <RBRACE> { ret = new EmbracedNode(ret); } ) )
)
{
@@ -481,10 +468,23 @@ ExpressionNode tensorConcat() :
+// Parses an expression that starts with <TENSOR> followed by a tensor type.
+// What follows the type selects the result: an <LBRACE>-delimited generator expression
+// (tensorGenerateBody) or a <COLON>-prefixed literal value (tensorValueBody).
+// The parsed type is passed on to whichever body production matches.
ExpressionNode tensorGenerate() :
{
TensorType type;
+ ExpressionNode expression;
+}
+{
+ <TENSOR> type = tensorType()
+ (
+ expression = tensorGenerateBody(type) |
+ expression = tensorValueBody(type)
+ )
+ { return expression; }
+}
+
+// Parses the generator part of a tensor generate expression: an expression enclosed in
+// <LBRACE> ... <RBRACE>. Returns a TensorFunctionNode wrapping a Generate function of the
+// given type, with the parsed expression adapted through GeneratorLambdaFunctionNode.
+ExpressionNode tensorGenerateBody(TensorType type) :
+{
ExpressionNode generator;
}
{
- <TENSOR> type = tensorType() <LBRACE> generator = expression() <RBRACE>
+ <LBRACE> generator = expression() <RBRACE>
{ return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asLongListToDoubleOperator())); }
}
@@ -635,15 +635,18 @@ TensorType.Value optionalTensorValueTypeParameter() :
{ return TensorType.Value.fromId(valueType); }
}
-// NOTE: Only indexed bound dimensions are parsed currently, as that is what we need
+// Parses one dimension of a tensor type and adds it to the given type builder.
+// Three forms are accepted after the dimension name:
+//   <LCURLY> <RCURLY>            - added with builder.mapped(name)
+//   <LSQUARE> <RSQUARE>          - added with builder.indexed(name), no size
+//   <LSQUARE> integer <RSQUARE>  - added with builder.indexed(name, size)
void tensorTypeDimension(TensorType.Builder builder) :
{
String name;
int size;
}
{
- name = identifier() <LSQUARE> size = integerNumber() <RSQUARE>
- { builder.indexed(name, size); }
+ name = identifier()
+ (
+ ( <LCURLY> <RCURLY> { builder.mapped(name); } ) |
+ // Both indexed forms start with <LSQUARE>, so two tokens of lookahead are needed to pick one
+ LOOKAHEAD(2) ( <LSQUARE> <RSQUARE> { builder.indexed(name); } ) |
+ ( <LSQUARE> size = integerNumber() <RSQUARE> { builder.indexed(name, size); } )
+ )
}
// This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge
@@ -717,8 +720,8 @@ List<ExpressionNode> expressionList() :
}
{
<LSQUARE>
- expression=expression() { list.add(expression); }
- ( LOOKAHEAD(2) <COMMA> expression=expression() { list.add(expression); } ) *
+ expression = expression() { list.add(expression); }
+ ( LOOKAHEAD(2) <COMMA> expression = expression() { list.add(expression); } ) *
<RSQUARE>
{ return list; }
}
@@ -820,18 +823,63 @@ Value primitiveValue() :
{ return Value.parse(sign + token.image); }
}
-ConstantNode tensorValue() :
+// Parses the literal value that follows a tensor type: a <COLON> and then either the mapped
+// form (mappedTensorValueBody) or the indexed short form (indexedTensorValueBody).
+// Cells are collected in a builder created for the given type; the built tensor is returned
+// wrapped in a TensorValue inside a ConstantNode.
+ConstantNode tensorValueBody(TensorType type) :
{
- TensorType type;
- String value;
+ Tensor.Builder builder = Tensor.Builder.of(type);
}
{
- type = tensorType()
<COLON>
- ( <TENSOR_VALUE_SL> { value = token.image.substring(token.image.indexOf(":") + 1); } |
- <TENSOR_VALUE_ML> { value = token.image.substring(token.image.indexOf("{") + 1,
- token.image.lastIndexOf("}")); } )
+ (
+ mappedTensorValueBody(builder) |
+ indexedTensorValueBody(builder)
+ )
+ { return new ConstantNode(new TensorValue(builder.build())); }
+}
+
+// Parses the mapped form of a tensor literal body: zero or more comma-separated cells
+// enclosed in <LCURLY> ... <RCURLY>, e.g. { {x:a}:1.0, {x:b}:2.0 } or {}.
+// Each cell is parsed by tensorCell into a cell builder obtained from builder.cell().
+void mappedTensorValueBody(Tensor.Builder builder) : {}
+{
+    <LCURLY>
+    // A comma is required between cells and is not allowed before the first one.
+    // The looser form ( cell )* ( <COMMA> cell )* also matched "{ {x:a}:1 {x:b}:2 }" and "{ , {x:a}:1 }".
+    (
+        tensorCell(builder.cell())
+        ( <COMMA> tensorCell(builder.cell()) )*
+    )?
+    <RCURLY>
+}
+
+// Parses the indexed short form of a tensor literal body: zero or more comma-separated
+// numbers enclosed in <LSQUARE> ... <RSQUARE>, e.g. [1.0, 2, 3.0].
+// Values are written through cellByDirectIndex at consecutive indexes starting from 0.
+// Throws IllegalArgumentException if the given builder is not an IndexedTensor.BoundBuilder.
+void indexedTensorValueBody(Tensor.Builder builder) :
+{
+    IndexedTensor.BoundBuilder indexedBuilder;
+    long index = 0;
+    double value;
+}
+{
{
- return new ConstantNode(new TensorValue(Tensor.from(type, value)));
+    // Checked before any token is consumed: the short form has no cell addresses,
+    // so it is only accepted when the builder is a bound indexed builder.
+    if ( ! (builder instanceof IndexedTensor.BoundBuilder))
+        throw new IllegalArgumentException("The tensor short form [n, n, ...] can only be used for indexed " +
+                                           "bound tensors, not " + builder.type());
+    indexedBuilder = (IndexedTensor.BoundBuilder)builder;
}
+    <LSQUARE>
+    // A comma is required between values and is not allowed before the first one.
+    // The looser form ( n )* ( <COMMA> n )* also matched "[1.0 2.0]" and "[, 1.0]".
+    ( value = doubleNumber() { indexedBuilder.cellByDirectIndex(index++, value); }
+      ( <COMMA> value = doubleNumber() { indexedBuilder.cellByDirectIndex(index++, value); } )* )?
+    <RSQUARE>
}
+
+// Parses one cell of a mapped tensor literal: an address of zero or more comma-separated
+// dimension:label pairs in <LCURLY> ... <RCURLY>, then <COLON> and the cell value,
+// e.g. {x:a,y:0}:1.0 or {}:1. The pairs and the value are recorded on the given cell builder.
+void tensorCell(Tensor.Builder.CellBuilder cellBuilder) :
+{
+    double value;
+}
+{
+    <LCURLY>
+    // A comma is required between pairs and is not allowed before the first one.
+    // The looser form ( pair )* ( <COMMA> pair )* also matched "{x:a y:0}" and "{,x:a}".
+    (
+        labelAndDimension(cellBuilder)
+        ( <COMMA> labelAndDimension(cellBuilder) )*
+    )?
+    <RCURLY>
+    <COLON> value = doubleNumber() { cellBuilder.value(value); }
+}
+
+// Parses a single "dimension:label" pair of a cell address (dimension name first, then the
+// label) and records it on the given cell builder.
+void labelAndDimension(Tensor.Builder.CellBuilder cellBuilder) :
+{
+    String dimensionName;
+    String cellLabel;
+}
+{
+    dimensionName = identifier()
+    <COLON>
+    cellLabel = tag()
+    { cellBuilder.label(dimensionName, cellLabel); }
+} \ No newline at end of file
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 39d8a043e18..7aafb8efee7 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -358,11 +358,16 @@ public class EvaluationTestCase {
@Test
public void testLiteralTensors() {
EvaluationTester tester = new EvaluationTester();
- //tester.assertEvaluates("tensor(x{}):{ {x:a}:1.0, {x:b}:2.0, {x:c}:3.0 }",
- // "tensor(x{}):{ {x:a}:1.0, {x:b}:2.0, {x:c}:3.0 }");
- //tester.assertEvaluates("tensor(x[3]):[1.0, 2.0, 3.0]",
- // "tensor(x[3]):[1.0, 2.0, 3.0]");
-
+ // Mapped literal with a single dimension
+ tester.assertEvaluates("tensor(x{}):{ {x:a}:1.0, {x:b}:2.0, {x:c}:3.0 }",
+ "tensor(x{}):{ {x:a}:1.0, {x:b}:2.0, {x:c}:3.0 }");
+ // Indexed short form; the two sides write the same numbers with and without decimals
+ tester.assertEvaluates("tensor(x[3]):[1.0, 2, 3.0]",
+ "tensor(x[3]):[1.0, 2.0, 3]");
+ // Mapped literal with two dimensions per cell address
+ tester.assertEvaluates("tensor(x{},y{}):{ {x:a,y:0}:1.0, {x:b,y:0}:2.0, {x:c,y:0}:3.0 }",
+ "tensor(x{},y{}):{ {x:a,y:0}:1.0, {x:b,y:0}:2.0, {x:c,y:0}:3.0 }");
+ // Mapped literal with no cells
+ tester.assertEvaluates("tensor(x{}):{}",
+ "tensor(x{}):{}");
+ // Dimensionless tensor: one cell with an empty address
+ tester.assertEvaluates("tensor():{{}:1}",
+ "tensor():{{}:1}");
}
@Test