diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-13 09:25:55 -0800 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-13 09:25:55 -0800 |
commit | 4b3768bb618725d34d40dab49d7b375b1a19035e (patch) | |
tree | eb13c325e79eb749e935691863de6fa0ece43907 /searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java | |
parent | 58bbdee39ce7b38e9dad0956e7b0e57319e8b0b8 (diff) |
Parse mixed tensors
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 0a67ab5534e..11fc581640d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -91,6 +92,33 @@ public class TensorFunctionNode extends CompositeNode { return functions; } + public static void wrapScalarBlock(TensorType type, + String mappedDimensionLabel, + List<ExpressionNode> nodes, + Map<TensorAddress, ScalarFunction<Reference>> receivingMap) { + TensorType.Dimension sparseDimension = type.dimensions().stream().filter(d -> ! d.isIndexed()).findFirst().get(); + TensorType denseSubtype = new TensorType(type.valueType(), + type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList())); + + IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseSubtype); + for (ExpressionNode node : nodes) { + indexes.next(); + + // Insert the mapped dimension into the dense subspace address of indexes + String[] labels = new String[type.rank()]; + int indexedDimensionsIndex = 0; + int allDimensionsIndex = 0; + for (TensorType.Dimension dimension : type.dimensions()) { + if (dimension.isIndexed()) + labels[allDimensionsIndex++] = String.valueOf(indexes.indexesForReading()[indexedDimensionsIndex++]); + else + labels[allDimensionsIndex++] = mappedDimensionLabel; + } + + receivingMap.put(TensorAddress.of(labels), wrapScalar(node)); + } + } + public static List<ScalarFunction<Reference>> wrapScalars(List<ExpressionNode> nodes) { return nodes.stream().map(node -> wrapScalar(node)).collect(Collectors.toList()); } |