diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java | 170 |
1 files changed, 170 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java new file mode 100644 index 00000000000..2a34ae53d5e --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java @@ -0,0 +1,170 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Slice; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +/* + * Onnx gather is the same as Numpy take. + */ +public class Gather extends IntermediateOperation { + + private final AttributeMap attributeMap; + + private int axis; + + public Gather(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(2)) return null; + + OrderedTensorType dataType = inputs.get(0).type().get(); + OrderedTensorType indicesType = inputs.get(1).type().get(); + + axis = (int) attributeMap.get("axis").orElse(DoubleValue.zero).asDouble(); + if (axis < 0) + axis = dataType.rank() + axis; + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < axis; ++i) { + addDimension(i, dataType.dimensions().get(i).size().orElse(-1L), typeBuilder); + } + for (int i = 0; i < indicesType.rank(); ++i) { + addDimension(i + axis, indicesType.dimensions().get(i).size().orElse(-1L), typeBuilder); + } + for (int i = axis + 1; i < dataType.rank(); ++i) { + addDimension(i + indicesType.rank(), dataType.dimensions().get(i).size().orElse(-1L), typeBuilder); + } + + inputs.get(0).exportAsRankingFunction = true; + inputs.get(1).exportAsRankingFunction = true; + + return typeBuilder.build(); + } + + private void addDimension(int dimensionIndex, long size, OrderedTensorType.Builder typeBuilder) { + String name = String.format("%s_%d", vespaName(), dimensionIndex); + typeBuilder.add(TensorType.Dimension.indexed(name, size)); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputFunctionsPresent(2)) return null; + + IntermediateOperation data = inputs.get(0); + IntermediateOperation indices = inputs.get(1); + OrderedTensorType dataType = data.type().get(); + OrderedTensorType indicesType = indices.type().get(); + + String dataFunctionName = data.rankingExpressionFunctionName(); + String indicesFunctionName = indices.rankingExpressionFunctionName(); + + List<Slice.DimensionValue<Reference>> dataSliceDimensions = new ArrayList<>(); + for (int i = 0; i < axis; ++i) { + addSliceDimension(dataSliceDimensions, dataType.dimensions().get(i).name(), i); + } + + List<Slice.DimensionValue<Reference>> indicesSliceDimensions = new ArrayList<>(); + for (int i = 0; i < indicesType.rank(); ++i) { + addSliceDimension(indicesSliceDimensions, indicesType.dimensions().get(i).name(), axis + i); + } + ExpressionNode sliceExpression = createSliceExpression(indicesSliceDimensions, indicesFunctionName); + ExpressionNode indexExpression = createIndexExpression(dataType, sliceExpression); + addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression); + + for (int i = axis + 1; i < dataType.rank(); ++i) { + addSliceDimension(dataSliceDimensions, dataType.dimensions().get(i).name(), i + indicesType.rank() - 1); + } + + sliceExpression = createSliceExpression(dataSliceDimensions, dataFunctionName); + return Generate.bound(type.type(), wrapScalar(sliceExpression)); + } + + private ExpressionNode createSliceExpression(List<Slice.DimensionValue<Reference>> dimensionValues, String referenceName) { + TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(referenceName)); + Slice<Reference> sliceIndices = new Slice<>(inputIndices, dimensionValues); + return new TensorFunctionNode(sliceIndices); + } + + /** to support negative indexing */ + private ExpressionNode createIndexExpression(OrderedTensorType dataType, ExpressionNode slice) { + ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get())); + ExpressionNode plus = new EmbracedNode(new ArithmeticNode(slice, ArithmeticOperator.PLUS, axisSize)); + ExpressionNode mod = new ArithmeticNode(plus, ArithmeticOperator.MODULO, axisSize); + return mod; + } + + private void addSliceDimension(List<Slice.DimensionValue<Reference>> dimensionValues, String dimensionName, ExpressionNode expr) { + dimensionValues.add(new Slice.DimensionValue<>(Optional.of(dimensionName), wrapScalar(new EmbracedNode(expr)))); + } + + private void addSliceDimension(List<Slice.DimensionValue<Reference>> dimensionValues, String dimensionName, int dimensionIndex) { + String outputDimensionName = type.dimensions().get(dimensionIndex).name(); + addSliceDimension(dimensionValues, dimensionName, new ReferenceNode(outputDimensionName)); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if ( ! allInputTypesPresent(2)) return; + + for (int i = 0; i < type.dimensions().size(); i++) { + renamer.addDimension(type.dimensions().get(i).name()); + for (int j = i + 1; j < type.dimensions().size(); j++) { + renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(), + DimensionRenamer.Constraint.lessThan(), this); + } + } + + OrderedTensorType dataType = inputs.get(0).type().get(); + OrderedTensorType indicesType = inputs.get(1).type().get(); + + for (int i = 0; i < axis; ++i) { + renamer.addConstraint(type.dimensions().get(i).name(), + dataType.dimensions().get(i).name(), + DimensionRenamer.Constraint.equal(), this); + } + for (int i = 0; i < indicesType.rank(); ++i) { + renamer.addConstraint(type.dimensions().get(i + axis).name(), + indicesType.dimensions().get(i).name(), + DimensionRenamer.Constraint.equal(), this); + } + for (int i = axis + 1; i < dataType.rank(); ++i) { + renamer.addConstraint(type.dimensions().get(i + indicesType.rank() - 1).name(), + dataType.dimensions().get(i).name(), + DimensionRenamer.Constraint.equal(), this); + } + + } + + @Override + public Gather withInputs(List<IntermediateOperation> inputs) { + return new Gather(modelName(), name(), inputs, attributeMap); + } + + @Override + public String operationName() { return "Gather"; } + +} |