diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java | 103 |
1 files changed, 76 insertions, 27 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index c7accd00619..1b72565b423 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -4,6 +4,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; @@ -22,51 +23,97 @@ import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; import java.util.ArrayList; -import java.util.Iterator; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; public class Reshape extends IntermediateOperation { - public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) { + private final AttributeMap attributeMap; + + public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; } @Override protected OrderedTensorType lazyGetType() { - if ( ! allInputTypesPresent(2)) return null; + if (inputs.size() == 2) { + return typeWithShapeAsInput(); + } else if (inputs.size() == 1) { + return typeWithShapeAsAttribute(); + } + throw new IllegalArgumentException("Expected 2 or 3 inputs for '" + name + "', got " + inputs.size()); + } + private OrderedTensorType typeWithShapeAsInput() { IntermediateOperation newShape = inputs.get(1); if (newShape.getConstantValue().isEmpty()) - throw new IllegalArgumentException("Reshape in " + name + ": Shape input must be a constant."); + throw new IllegalArgumentException("Reshape " + name + ": Shape input must be a constant."); + OrderedTensorType inputType = inputs.get(0).type().get(); Tensor shape = newShape.getConstantValue().get().asTensor(); + List<Integer> dimSizes = new ArrayList<>(shape.type().rank()); + shape.valueIterator().forEachRemaining(v -> dimSizes.add(v.intValue())); + + // first pass - set 0 values + for (int i = 0; i < dimSizes.size(); ++i) { + if (dimSizes.get(i) == 0) { + if (i >= inputType.dimensions().size()) { + throw new IllegalArgumentException("Reshape " + name + ": 0 value for dimension not found in input"); + } + dimSizes.set(i, inputType.dimensions().get(i).size().get().intValue()); + } + } + + // second pass - set any -1 values + for (int i = 0; i < dimSizes.size(); ++i) { + if (dimSizes.get(i) < 0) { + int shapeSize = dimSizes.stream().reduce(1, (a, b) -> a * b); + int tensorSize = OrderedTensorType.tensorSize(inputType.type()).intValue(); + dimSizes.set(i, -1 * tensorSize / (shapeSize == 0 ? -1 : shapeSize)); + } + } + + return buildOutputType(dimSizes); + } + + private OrderedTensorType typeWithShapeAsAttribute() { + if (attributeMap.getList("shape").isEmpty() || attributeMap.getList("shape").get().size() == 0) + throw new IllegalArgumentException("Reshape in " + name + ": Shape attribute is empty."); OrderedTensorType inputType = inputs.get(0).type().get(); - OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType()); - int dimensionIndex = 0; - for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { - Tensor.Cell cell = cellIterator.next(); - int size = cell.getValue().intValue(); + List<Value> shape = attributeMap.getList("shape").get(); + List<Integer> dimSizes = new ArrayList<>(shape.size()); + + for (Value v : shape) { + int size = (int) v.asDouble(); if (size < 0) { - size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / - OrderedTensorType.tensorSize(inputType.type()).intValue(); + int shapeSize = (int) shape.stream().mapToDouble(Value::asDouble).reduce(1, (a, b) -> a * b); + int tensorSize = OrderedTensorType.tensorSize(inputType.type()).intValue(); + size = -1 * shapeSize / tensorSize; } - outputTypeBuilder.add(TensorType.Dimension.indexed( - String.format("%s_%d", vespaName(), dimensionIndex), size)); - dimensionIndex++; + dimSizes.add(size); + } + return buildOutputType(dimSizes); + } + + private OrderedTensorType buildOutputType(List<Integer> dimSizes) { + OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < dimSizes.size(); ++i) { + outputTypeBuilder.add(TensorType.Dimension.indexed(String.format("%s_%d", vespaName(), i), dimSizes.get(i))); } return outputTypeBuilder.build(); } @Override protected TensorFunction lazyGetFunction() { - if ( ! allInputTypesPresent(2)) return null; - if ( ! allInputFunctionsPresent(2)) return null; + if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent) ) return null; + if ( ! inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent) ) return null; OrderedTensorType inputType = inputs.get(0).type().get(); TensorFunction inputFunction = inputs.get(0).function().get(); - return reshape(inputFunction, inputType.type(), type.type()); + return reshape(inputFunction, inputType, type); } @Override @@ -76,11 +123,11 @@ public class Reshape extends IntermediateOperation { @Override public Reshape withInputs(List<IntermediateOperation> inputs) { - return new Reshape(modelName(), name(), inputs); + return new Reshape(modelName(), name(), inputs, attributeMap); } - public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { - if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) + public static TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) { + if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type()))) throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, @@ -89,25 +136,27 @@ public class Reshape extends IntermediateOperation { // the new shape. We have to introduce temporary dimension names and rename back if dimension names // in the new and old tensor type overlap. + // Todo: change this to use tensor generate when available + List<String> from = new ArrayList<>(); List<String> to = new ArrayList<>(); boolean dimensionNamesOverlap = dimensionNamesOverlap(inputType, outputType); if (dimensionNamesOverlap) { - TensorType.Builder builder = new TensorType.Builder(outputType.valueType()); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(outputType.type().valueType()); for (int i = 0; i < outputType.rank(); ++i) { TensorType.Dimension dim = outputType.dimensions().get(i); from.add(dim.name()); to.add("temp_" + dim.name()); - builder.dimension(dim.withName("temp_" + dim.name())); + builder.add(dim.withName("temp_" + dim.name())); } outputType = builder.build(); } ExpressionNode unrollFrom = unrollTensorExpression(inputType); ExpressionNode unrollTo = unrollTensorExpression(outputType); - ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, new EmbracedNode(unrollTo)); + ExpressionNode transformExpression = new ComparisonNode(new EmbracedNode(unrollFrom), TruthOperator.EQUAL, new EmbracedNode(unrollTo)); - TensorType transformationType = new TensorType.Builder(inputType, outputType).build(); + TensorType transformationType = new TensorType.Builder(inputType.type(), outputType.type()).build(); Generate transformTensor = new Generate(transformationType, new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); @@ -121,11 +170,11 @@ public class Reshape extends IntermediateOperation { return result; } - private static boolean dimensionNamesOverlap(TensorType a, TensorType b) { - return a.dimensionNames().stream().anyMatch(d -> b.dimension(d).isPresent()); + private static boolean dimensionNamesOverlap(OrderedTensorType a, OrderedTensorType b) { + return a.dimensionNames().stream().anyMatch(d -> b.type().indexOfDimension(d).isPresent()); } - private static ExpressionNode unrollTensorExpression(TensorType type) { + private static ExpressionNode unrollTensorExpression(OrderedTensorType type) { if (type.rank() == 0) return new ConstantNode(DoubleValue.zero); |