summary refs log tree commit diff stats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java')
-rw-r--r-- model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java 103
1 files changed, 76 insertions, 27 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index c7accd00619..1b72565b423 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -4,6 +4,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
@@ -22,51 +23,97 @@ import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
-import java.util.Iterator;
import java.util.List;
+import java.util.Optional;
import java.util.stream.Collectors;
public class Reshape extends IntermediateOperation {
- public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ private final AttributeMap attributeMap;
+
+ public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
}
@Override
protected OrderedTensorType lazyGetType() {
- if ( ! allInputTypesPresent(2)) return null;
+ if (inputs.size() == 2) { // shape is supplied as a second, constant input
+ return typeWithShapeAsInput();
+ } else if (inputs.size() == 1) { // shape is supplied through the "shape" attribute
+ return typeWithShapeAsAttribute();
+ }
+ throw new IllegalArgumentException("Expected 1 or 2 inputs for '" + name + "', got " + inputs.size());
+ }
+ private OrderedTensorType typeWithShapeAsInput() {
IntermediateOperation newShape = inputs.get(1);
if (newShape.getConstantValue().isEmpty())
- throw new IllegalArgumentException("Reshape in " + name + ": Shape input must be a constant.");
+ throw new IllegalArgumentException("Reshape " + name + ": Shape input must be a constant.");
+ OrderedTensorType inputType = inputs.get(0).type().get();
Tensor shape = newShape.getConstantValue().get().asTensor();
+ List<Integer> dimSizes = new ArrayList<>(shape.type().rank());
+ shape.valueIterator().forEachRemaining(v -> dimSizes.add(v.intValue()));
+
+ // first pass - set 0 values
+ for (int i = 0; i < dimSizes.size(); ++i) {
+ if (dimSizes.get(i) == 0) {
+ if (i >= inputType.dimensions().size()) {
+ throw new IllegalArgumentException("Reshape " + name + ": 0 value for dimension not found in input");
+ }
+ dimSizes.set(i, inputType.dimensions().get(i).size().get().intValue());
+ }
+ }
+
+ // second pass - set any -1 values
+ for (int i = 0; i < dimSizes.size(); ++i) {
+ if (dimSizes.get(i) < 0) {
+ int shapeSize = dimSizes.stream().reduce(1, (a, b) -> a * b);
+ int tensorSize = OrderedTensorType.tensorSize(inputType.type()).intValue();
+ dimSizes.set(i, -1 * tensorSize / (shapeSize == 0 ? -1 : shapeSize));
+ }
+ }
+
+ return buildOutputType(dimSizes);
+ }
+
+ private OrderedTensorType typeWithShapeAsAttribute() {
+ if (attributeMap.getList("shape").isEmpty() || attributeMap.getList("shape").get().size() == 0)
+ throw new IllegalArgumentException("Reshape in " + name + ": Shape attribute is empty.");
OrderedTensorType inputType = inputs.get(0).type().get();
- OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType());
- int dimensionIndex = 0;
- for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
- Tensor.Cell cell = cellIterator.next();
- int size = cell.getValue().intValue();
+ List<Value> shape = attributeMap.getList("shape").get();
+ List<Integer> dimSizes = new ArrayList<>(shape.size());
+
+ for (Value v : shape) {
+ int size = (int) v.asDouble();
if (size < 0) {
- size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() /
- OrderedTensorType.tensorSize(inputType.type()).intValue();
+ int shapeSize = (int) shape.stream().mapToDouble(Value::asDouble).reduce(1, (a, b) -> a * b); // product of all entries, including the negative placeholder
+ int tensorSize = OrderedTensorType.tensorSize(inputType.type()).intValue();
+ size = -1 * tensorSize / (shapeSize == 0 ? -1 : shapeSize); // inferred size = input size / product of the other entries, as in typeWithShapeAsInput
}
- outputTypeBuilder.add(TensorType.Dimension.indexed(
- String.format("%s_%d", vespaName(), dimensionIndex), size));
- dimensionIndex++;
+ dimSizes.add(size);
+ }
+ return buildOutputType(dimSizes);
+ }
+
+ private OrderedTensorType buildOutputType(List<Integer> dimSizes) {
+ OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType());
+ for (int i = 0; i < dimSizes.size(); ++i) {
+ outputTypeBuilder.add(TensorType.Dimension.indexed(String.format("%s_%d", vespaName(), i), dimSizes.get(i)));
}
return outputTypeBuilder.build();
}
@Override
protected TensorFunction lazyGetFunction() {
- if ( ! allInputTypesPresent(2)) return null;
- if ( ! allInputFunctionsPresent(2)) return null;
+ if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent) ) return null;
+ if ( ! inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent) ) return null;
OrderedTensorType inputType = inputs.get(0).type().get();
TensorFunction inputFunction = inputs.get(0).function().get();
- return reshape(inputFunction, inputType.type(), type.type());
+ return reshape(inputFunction, inputType, type);
}
@Override
@@ -76,11 +123,11 @@ public class Reshape extends IntermediateOperation {
@Override
public Reshape withInputs(List<IntermediateOperation> inputs) {
- return new Reshape(modelName(), name(), inputs);
+ return new Reshape(modelName(), name(), inputs, attributeMap);
}
- public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
- if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType)))
+ public static TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) {
+ if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type())))
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
// Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order,
@@ -89,25 +136,27 @@ public class Reshape extends IntermediateOperation {
// the new shape. We have to introduce temporary dimension names and rename back if dimension names
// in the new and old tensor type overlap.
+ // Todo: change this to use tensor generate when available
+
List<String> from = new ArrayList<>();
List<String> to = new ArrayList<>();
boolean dimensionNamesOverlap = dimensionNamesOverlap(inputType, outputType);
if (dimensionNamesOverlap) {
- TensorType.Builder builder = new TensorType.Builder(outputType.valueType());
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(outputType.type().valueType());
for (int i = 0; i < outputType.rank(); ++i) {
TensorType.Dimension dim = outputType.dimensions().get(i);
from.add(dim.name());
to.add("temp_" + dim.name());
- builder.dimension(dim.withName("temp_" + dim.name()));
+ builder.add(dim.withName("temp_" + dim.name()));
}
outputType = builder.build();
}
ExpressionNode unrollFrom = unrollTensorExpression(inputType);
ExpressionNode unrollTo = unrollTensorExpression(outputType);
- ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, new EmbracedNode(unrollTo));
+ ExpressionNode transformExpression = new ComparisonNode(new EmbracedNode(unrollFrom), TruthOperator.EQUAL, new EmbracedNode(unrollTo));
- TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
+ TensorType transformationType = new TensorType.Builder(inputType.type(), outputType.type()).build();
Generate transformTensor = new Generate(transformationType,
new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
@@ -121,11 +170,11 @@ public class Reshape extends IntermediateOperation {
return result;
}
- private static boolean dimensionNamesOverlap(TensorType a, TensorType b) {
- return a.dimensionNames().stream().anyMatch(d -> b.dimension(d).isPresent());
+ private static boolean dimensionNamesOverlap(OrderedTensorType a, OrderedTensorType b) {
+ return a.dimensionNames().stream().anyMatch(d -> b.type().indexOfDimension(d).isPresent());
}
- private static ExpressionNode unrollTensorExpression(TensorType type) {
+ private static ExpressionNode unrollTensorExpression(OrderedTensorType type) {
if (type.rank() == 0)
return new ConstantNode(DoubleValue.zero);