diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java | 60 |
1 files changed, 59 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index c88fc18e6c6..f96dd420d30 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -2,8 +2,10 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; @@ -11,8 +13,11 @@ import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.Function; +import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -27,6 +32,8 @@ import java.util.List; import java.util.Optional; import java.util.stream.Collectors; +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + public class Reshape extends IntermediateOperation { private final AttributeMap attributeMap; @@ -38,6 +45,10 @@ public class Reshape extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { + + // required as we use tensor create + inputs.get(0).exportAsRankingFunction = true; + if (inputs.size() == 2) { return typeWithShapeAsInput(); } else if (inputs.size() == 1) { @@ -126,10 +137,54 @@ public class Reshape extends IntermediateOperation { return new Reshape(modelName(), name(), inputs, attributeMap); } - public static TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) { + public TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) { if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type()))) throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); + IntermediateOperation input = inputs.get(0); + String inputFunctionName = input.rankingExpressionFunctionName(); + + List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>(); + + // ala (d0 * 2 + d1) + ExpressionNode unrolled = new EmbracedNode(unrollTensorExpression(outputType)); + + long innerSize = 1; + for (int dim = 0; dim < inputType.rank(); ++dim) { + innerSize *= inputType.dimensions().get(dim).size().get(); + } + + for (int dim = 0; dim < inputType.rank(); ++dim) { + String inputDimensionName = inputType.dimensions().get(dim).name(); + long inputDimensionSize = inputType.dimensions().get(dim).size().get(); + long previousInnerSize = innerSize; + innerSize /= inputDimensionSize; + + ExpressionNode inputDimensionExpression; + if (inputDimensionSize == 1) { + inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero)); + } else if (dim == (inputType.rank() - 1)) { + ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); + ExpressionNode div = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, size); + inputDimensionExpression = new EmbracedNode(div); + } else { + ExpressionNode size = new ConstantNode(new DoubleValue(innerSize)); + ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize)); + ExpressionNode mod = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, previousSize); + ExpressionNode div = new ArithmeticNode(new EmbracedNode(mod), ArithmeticOperator.DIVIDE, size); + inputDimensionExpression = new EmbracedNode(new FunctionNode(Function.floor, div)); + } + dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression))); + } + + TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName)); + com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); + ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); + + TensorFunction generate = Generate.bound(outputType.type(), wrapScalar(sliceExpression)); + return generate; + + /* // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, // then use the dimension order of the new shape to roll back into a tensor. // Here we create a transformation tensor that is multiplied with the from tensor to map into @@ -168,11 +223,14 @@ public class Reshape extends IntermediateOperation { result = new Rename(result, to, from); } return result; + */ } + /* private static boolean dimensionNamesOverlap(OrderedTensorType a, OrderedTensorType b) { return a.dimensionNames().stream().anyMatch(d -> b.type().indexOfDimension(d).isPresent()); } + */ private static ExpressionNode unrollTensorExpression(OrderedTensorType type) { if (type.rank() == 0) |