diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-13 09:25:55 -0800 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-13 09:25:55 -0800 |
commit | 4b3768bb618725d34d40dab49d7b375b1a19035e (patch) | |
tree | eb13c325e79eb749e935691863de6fa0ece43907 /searchlib/src | |
parent | 58bbdee39ce7b38e9dad0956e7b0e57319e8b0b8 (diff) |
Parse mixed tensors
Diffstat (limited to 'searchlib/src')
3 files changed, 53 insertions, 12 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 0a67ab5534e..11fc581640d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -91,6 +92,33 @@ public class TensorFunctionNode extends CompositeNode { return functions; } + public static void wrapScalarBlock(TensorType type, + String mappedDimensionLabel, + List<ExpressionNode> nodes, + Map<TensorAddress, ScalarFunction<Reference>> receivingMap) { + TensorType.Dimension sparseDimension = type.dimensions().stream().filter(d -> ! d.isIndexed()).findFirst().get(); + TensorType denseSubtype = new TensorType(type.valueType(), + type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList())); + + IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseSubtype); + for (ExpressionNode node : nodes) { + indexes.next(); + + // Insert the mapped dimension into the dense subspace address of indexes + String[] labels = new String[type.rank()]; + int indexedDimensionsIndex = 0; + int allDimensionsIndex = 0; + for (TensorType.Dimension dimension : type.dimensions()) { + if (dimension.isIndexed()) + labels[allDimensionsIndex++] = String.valueOf(indexes.indexesForReading()[indexedDimensionsIndex++]); + else + labels[allDimensionsIndex++] = mappedDimensionLabel; + } + + receivingMap.put(TensorAddress.of(labels), wrapScalar(node)); + } + } + public static List<ScalarFunction<Reference>> wrapScalars(List<ExpressionNode> nodes) { return nodes.stream().map(node -> wrapScalar(node)).collect(Collectors.toList()); } diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 6abd9396ecf..92b465c1303 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -839,9 +839,9 @@ TensorFunctionNode tensorValueBody(TensorType type) : { <COLON> ( + LOOKAHEAD(2) dynamicTensor = mixedTensorValueBody(type) | dynamicTensor = mappedTensorValueBody(type) | - dynamicTensor = indexedTensorValueBody(type) | - dynamicTensor = mixedTensorValueBody(type) + dynamicTensor = indexedTensorValueBody(type) ) { return new TensorFunctionNode(dynamicTensor); } } @@ -852,7 +852,7 @@ DynamicTensor mappedTensorValueBody(TensorType type) : } { <LCURLY> - ( tensorCell(type, cells))* + [ tensorCell(type, cells)] ( <COMMA> tensorCell(type, cells))* <RCURLY> { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); } @@ -864,7 +864,7 @@ DynamicTensor mixedTensorValueBody(TensorType type) : } { <LCURLY> - mixedBlock(type, cells) + mixedBlock(type, cells) ( <COMMA> mixedBlock(type, cells))* <RCURLY> { return DynamicTensor.from(type, cells); } @@ -879,29 +879,42 @@ DynamicTensor indexedTensorValueBody(TensorType type) : { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); } } -void mixedBlock(TensorType type, Map cellMap) : +void mixedBlock(TensorType type, java.util.Map cellMap) : { String label; List cells; } { label = tag() <COLON> cells = indexedTensorCells() - { TensorFunctionNode.wrapScalarBlock(label, cells, cellMap); } + { TensorFunctionNode.wrapScalarBlock(type, label, cells, cellMap); } } List indexedTensorCells() : { List cells = new ArrayList(); - ExpressionNode value; } { - <LSQUARE> // TODO: Parse inner square brackets properly - ( (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )* - ( <COMMA> (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )* -// <RSQUARE> + <LSQUARE> indexedTensorCellSubspaceList(cells) <RSQUARE> { return cells; } } +void indexedTensorCellSubspaceList(List cells) : +{ +} +{ + indexedTensorCellSubspace(cells) ( LOOKAHEAD(2) <COMMA> indexedTensorCellSubspace(cells) )* +} + +void indexedTensorCellSubspace(List cells) : +{ + ExpressionNode value; +} +{ + ( <LSQUARE> indexedTensorCellSubspaceList(cells) <RSQUARE> ) + | + ( value = expression() { cells.add(value); } ) +} + void tensorCell(TensorType type, java.util.Map cells) : { ExpressionNode value; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index e3d3ac7b2e1..b750a7607cc 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -191,7 +191,7 @@ public class RankingExpressionTestCase { // Accessing a function in a dynamic tensor, short form assertSerialization(List.of("tensor(x[2]):{{x:0}:rankingExpression(scalarFunction),{x:1}:rankingExpression(scalarFunction)}"), - "tensor(x[2]):[scalarFunction(), scalarFunction()]]", + "tensor(x[2]):[scalarFunction(), scalarFunction()]", functions, false); // Accessing a function in a dynamic tensor, long form |