aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-12-13 09:25:55 -0800
committerJon Bratseth <bratseth@verizonmedia.com>2019-12-13 09:25:55 -0800
commit4b3768bb618725d34d40dab49d7b375b1a19035e (patch)
treeeb13c325e79eb749e935691863de6fa0ece43907 /searchlib/src
parent58bbdee39ce7b38e9dad0956e7b0e57319e8b0b8 (diff)
Parse mixed tensors
Diffstat (limited to 'searchlib/src')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java28
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj35
-rwxr-xr-xsearchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java2
3 files changed, 53 insertions, 12 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 0a67ab5534e..11fc581640d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -91,6 +92,33 @@ public class TensorFunctionNode extends CompositeNode {
return functions;
}
+ public static void wrapScalarBlock(TensorType type,
+ String mappedDimensionLabel,
+ List<ExpressionNode> nodes,
+ Map<TensorAddress, ScalarFunction<Reference>> receivingMap) {
+ TensorType.Dimension sparseDimension = type.dimensions().stream().filter(d -> ! d.isIndexed()).findFirst().get();
+ TensorType denseSubtype = new TensorType(type.valueType(),
+ type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList()));
+
+ IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseSubtype);
+ for (ExpressionNode node : nodes) {
+ indexes.next();
+
+ // Insert the mapped dimension into the dense subspace address of indexes
+ String[] labels = new String[type.rank()];
+ int indexedDimensionsIndex = 0;
+ int allDimensionsIndex = 0;
+ for (TensorType.Dimension dimension : type.dimensions()) {
+ if (dimension.isIndexed())
+ labels[allDimensionsIndex++] = String.valueOf(indexes.indexesForReading()[indexedDimensionsIndex++]);
+ else
+ labels[allDimensionsIndex++] = mappedDimensionLabel;
+ }
+
+ receivingMap.put(TensorAddress.of(labels), wrapScalar(node));
+ }
+ }
+
public static List<ScalarFunction<Reference>> wrapScalars(List<ExpressionNode> nodes) {
return nodes.stream().map(node -> wrapScalar(node)).collect(Collectors.toList());
}
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 6abd9396ecf..92b465c1303 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -839,9 +839,9 @@ TensorFunctionNode tensorValueBody(TensorType type) :
{
<COLON>
(
+ LOOKAHEAD(2) dynamicTensor = mixedTensorValueBody(type) |
dynamicTensor = mappedTensorValueBody(type) |
- dynamicTensor = indexedTensorValueBody(type) |
- dynamicTensor = mixedTensorValueBody(type)
+ dynamicTensor = indexedTensorValueBody(type)
)
{ return new TensorFunctionNode(dynamicTensor); }
}
@@ -852,7 +852,7 @@ DynamicTensor mappedTensorValueBody(TensorType type) :
}
{
<LCURLY>
- ( tensorCell(type, cells))*
+ [ tensorCell(type, cells)]
( <COMMA> tensorCell(type, cells))*
<RCURLY>
{ return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); }
@@ -864,7 +864,7 @@ DynamicTensor mixedTensorValueBody(TensorType type) :
}
{
<LCURLY>
- mixedBlock(type, cells)
+ mixedBlock(type, cells)
( <COMMA> mixedBlock(type, cells))*
<RCURLY>
{ return DynamicTensor.from(type, cells); }
@@ -879,29 +879,42 @@ DynamicTensor indexedTensorValueBody(TensorType type) :
{ return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); }
}
-void mixedBlock(TensorType type, Map cellMap) :
+void mixedBlock(TensorType type, java.util.Map cellMap) :
{
String label;
List cells;
}
{
label = tag() <COLON> cells = indexedTensorCells()
- { TensorFunctionNode.wrapScalarBlock(label, cells, cellMap); }
+ { TensorFunctionNode.wrapScalarBlock(type, label, cells, cellMap); }
}
List indexedTensorCells() :
{
List cells = new ArrayList();
- ExpressionNode value;
}
{
- <LSQUARE> // TODO: Parse inner square brackets properly
- ( (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )*
- ( <COMMA> (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )*
-// <RSQUARE>
+ <LSQUARE> indexedTensorCellSubspaceList(cells) <RSQUARE>
{ return cells; }
}
+void indexedTensorCellSubspaceList(List cells) :
+{
+}
+{
+ indexedTensorCellSubspace(cells) ( LOOKAHEAD(2) <COMMA> indexedTensorCellSubspace(cells) )*
+}
+
+void indexedTensorCellSubspace(List cells) :
+{
+ ExpressionNode value;
+}
+{
+ ( <LSQUARE> indexedTensorCellSubspaceList(cells) <RSQUARE> )
+ |
+ ( value = expression() { cells.add(value); } )
+}
+
void tensorCell(TensorType type, java.util.Map cells) :
{
ExpressionNode value;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
index e3d3ac7b2e1..b750a7607cc 100755
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
@@ -191,7 +191,7 @@ public class RankingExpressionTestCase {
// Accessing a function in a dynamic tensor, short form
assertSerialization(List.of("tensor(x[2]):{{x:0}:rankingExpression(scalarFunction),{x:1}:rankingExpression(scalarFunction)}"),
- "tensor(x[2]):[scalarFunction(), scalarFunction()]]",
+ "tensor(x[2]):[scalarFunction(), scalarFunction()]",
functions, false);
// Accessing a function in a dynamic tensor, long form