aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java60
1 files changed, 59 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index c88fc18e6c6..f96dd420d30 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -2,8 +2,10 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
@@ -11,8 +13,11 @@ import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.Function;
+import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -27,6 +32,8 @@ import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
public class Reshape extends IntermediateOperation {
private final AttributeMap attributeMap;
@@ -38,6 +45,10 @@ public class Reshape extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
+
+ // required as we use tensor create
+ inputs.get(0).exportAsRankingFunction = true;
+
if (inputs.size() == 2) {
return typeWithShapeAsInput();
} else if (inputs.size() == 1) {
@@ -126,10 +137,54 @@ public class Reshape extends IntermediateOperation {
return new Reshape(modelName(), name(), inputs, attributeMap);
}
- public static TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) {
+ public TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) {
if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type())))
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
+ IntermediateOperation input = inputs.get(0);
+ String inputFunctionName = input.rankingExpressionFunctionName();
+
+ List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>();
+
+ // ala (d0 * 2 + d1)
+ ExpressionNode unrolled = new EmbracedNode(unrollTensorExpression(outputType));
+
+ long innerSize = 1;
+ for (int dim = 0; dim < inputType.rank(); ++dim) {
+ innerSize *= inputType.dimensions().get(dim).size().get();
+ }
+
+ for (int dim = 0; dim < inputType.rank(); ++dim) {
+ String inputDimensionName = inputType.dimensions().get(dim).name();
+ long inputDimensionSize = inputType.dimensions().get(dim).size().get();
+ long previousInnerSize = innerSize;
+ innerSize /= inputDimensionSize;
+
+ ExpressionNode inputDimensionExpression;
+ if (inputDimensionSize == 1) {
+ inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero));
+ } else if (dim == (inputType.rank() - 1)) {
+ ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize));
+ ExpressionNode div = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, size);
+ inputDimensionExpression = new EmbracedNode(div);
+ } else {
+ ExpressionNode size = new ConstantNode(new DoubleValue(innerSize));
+ ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize));
+ ExpressionNode mod = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, previousSize);
+ ExpressionNode div = new ArithmeticNode(new EmbracedNode(mod), ArithmeticOperator.DIVIDE, size);
+ inputDimensionExpression = new EmbracedNode(new FunctionNode(Function.floor, div));
+ }
+ dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression)));
+ }
+
+ TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName));
+ com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues);
+ ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
+
+ TensorFunction generate = Generate.bound(outputType.type(), wrapScalar(sliceExpression));
+ return generate;
+
+ /*
// Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order,
// then use the dimension order of the new shape to roll back into a tensor.
// Here we create a transformation tensor that is multiplied with the from tensor to map into
@@ -168,11 +223,14 @@ public class Reshape extends IntermediateOperation {
result = new Rename(result, to, from);
}
return result;
+ */
}
+ /*
private static boolean dimensionNamesOverlap(OrderedTensorType a, OrderedTensorType b) {
return a.dimensionNames().stream().anyMatch(d -> b.type().indexOfDimension(d).isPresent());
}
+ */
private static ExpressionNode unrollTensorExpression(OrderedTensorType type) {
if (type.rank() == 0)