aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-12-13 09:25:55 -0800
committerJon Bratseth <bratseth@verizonmedia.com>2019-12-13 09:25:55 -0800
commit4b3768bb618725d34d40dab49d7b375b1a19035e (patch)
treeeb13c325e79eb749e935691863de6fa0ece43907 /searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
parent58bbdee39ce7b38e9dad0956e7b0e57319e8b0b8 (diff)
Parse mixed tensors
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java28
1 files changed, 28 insertions, 0 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 0a67ab5534e..11fc581640d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -91,6 +92,33 @@ public class TensorFunctionNode extends CompositeNode {
return functions;
}
+ public static void wrapScalarBlock(TensorType type,
+ String mappedDimensionLabel,
+ List<ExpressionNode> nodes,
+ Map<TensorAddress, ScalarFunction<Reference>> receivingMap) {
+ TensorType.Dimension sparseDimension = type.dimensions().stream().filter(d -> ! d.isIndexed()).findFirst().get();
+ TensorType denseSubtype = new TensorType(type.valueType(),
+ type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList()));
+
+ IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseSubtype);
+ for (ExpressionNode node : nodes) {
+ indexes.next();
+
+ // Insert the mapped dimension into the dense subspace address of indexes
+ String[] labels = new String[type.rank()];
+ int indexedDimensionsIndex = 0;
+ int allDimensionsIndex = 0;
+ for (TensorType.Dimension dimension : type.dimensions()) {
+ if (dimension.isIndexed())
+ labels[allDimensionsIndex++] = String.valueOf(indexes.indexesForReading()[indexedDimensionsIndex++]);
+ else
+ labels[allDimensionsIndex++] = mappedDimensionLabel;
+ }
+
+ receivingMap.put(TensorAddress.of(labels), wrapScalar(node));
+ }
+ }
+
public static List<ScalarFunction<Reference>> wrapScalars(List<ExpressionNode> nodes) {
return nodes.stream().map(node -> wrapScalar(node)).collect(Collectors.toList());
}