diff options
author | Lester Solbakken <lesters@oath.com> | 2020-02-10 10:33:53 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-02-10 10:33:53 +0100 |
commit | c6ea3aa88e8929c2cbfe90f9c9ffdde482b7adc5 (patch) | |
tree | 550c76a8310c4951a3c5ae4c6e53af889bb9b54c /model-integration/src/main | |
parent | 7b5b53d288ab8b3c9ec8e054d4d5ecf2f88f7ff0 (diff) |
Add gather, slice, cast, unsqueeze ONNX operations
Diffstat (limited to 'model-integration/src/main')
10 files changed, 610 insertions, 4 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index d22a8067bd4..c7f320ed3b4 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -224,6 +224,11 @@ public class DimensionRenamer { /** Returns whether this is an opposite of another constraint */ boolean isOpposite() { return opposite; } + public static Constraint equal() { return new EqualConstraint(false, false); } + public static Constraint notEqual() { return new NotEqualConstraint(false, false); } + public static Constraint lessThan() { return new LessThanConstraint(false, false); } + public static Constraint greaterThan() { return new GreaterThanConstraint(false, false); } + public static Constraint equal(boolean soft) { return new EqualConstraint(soft, false); } public static Constraint notEqual(boolean soft) { return new NotEqualConstraint(soft, false); } public static Constraint lessThan(boolean soft) { return new LessThanConstraint(soft, false); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java index 8caa158e5be..b272d4c6750 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java @@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; +import 
com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import onnx.Onnx; @@ -37,6 +38,7 @@ class AttributeConverter implements IntermediateOperation.AttributeMap { case INT: return Optional.of(DoubleValue.frozen(attr.getI())); case FLOAT: return Optional.of(DoubleValue.frozen(attr.getF())); case STRING: return Optional.of(StringValue.frozen(attr.getS().toString())); + case TENSOR: return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attr.getT(), TypeConverter.typeFrom(attr.getT())))); default: return Optional.empty(); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index d42338deaf8..ffc64c38f16 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -2,13 +2,18 @@ package ai.vespa.rankingexpression.importer.onnx; +import ai.vespa.rankingexpression.importer.operations.ExpandDims; +import ai.vespa.rankingexpression.importer.operations.Gather; +import ai.vespa.rankingexpression.importer.operations.OnnxCast; import ai.vespa.rankingexpression.importer.operations.Gemm; import ai.vespa.rankingexpression.importer.operations.ConcatReduce; import ai.vespa.rankingexpression.importer.operations.OnnxConcat; import ai.vespa.rankingexpression.importer.operations.Reduce; import ai.vespa.rankingexpression.importer.operations.Select; +import ai.vespa.rankingexpression.importer.operations.Slice; import ai.vespa.rankingexpression.importer.operations.Softmax; import ai.vespa.rankingexpression.importer.operations.Squeeze; +import ai.vespa.rankingexpression.importer.operations.Unsqueeze; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import 
com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import ai.vespa.rankingexpression.importer.IntermediateGraph; @@ -67,6 +72,7 @@ class GraphImporter { case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin()); case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan()); + case "cast": return new OnnxCast(modelName, nodeName, inputs, attributes); case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil()); case "concat": return new OnnxConcat(modelName, nodeName, inputs, attributes); case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); @@ -75,6 +81,7 @@ class GraphImporter { case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal()); case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); + case "gather": return new Gather(modelName, nodeName, inputs, attributes); case "gemm": return new Gemm(modelName, nodeName, inputs, attributes); case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater()); case "identity": return new Identity(modelName, nodeName, inputs); @@ -105,6 +112,7 @@ class GraphImporter { case "shape": return new Shape(modelName, nodeName, inputs); case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin()); + case "slice": return new Slice(modelName, nodeName, inputs, attributes); case "softmax": return new Softmax(modelName, nodeName, inputs, attributes); case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); @@ -113,6 +121,7 @@ class GraphImporter { case "where": return new Select(modelName, 
nodeName, inputs); case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan()); case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); + case "unsqueeze": return new Unsqueeze(modelName, nodeName, inputs, attributes); } IntermediateOperation op = new NoOp(modelName, nodeName, inputs); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java index 69d18d0ffcb..f8c7dc15857 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java @@ -37,12 +37,14 @@ class TensorConverter { case BOOL: return new RawBoolValues(tensorProto); case FLOAT: return new RawFloatValues(tensorProto); case DOUBLE: return new RawDoubleValues(tensorProto); + case INT32: return new RawIntValues(tensorProto); case INT64: return new RawLongValues(tensorProto); } } else { switch (tensorProto.getDataType()) { case FLOAT: return new FloatValues(tensorProto); case DOUBLE: return new DoubleValues(tensorProto); + case INT32: return new IntValues(tensorProto); case INT64: return new LongValues(tensorProto); } } @@ -96,6 +98,17 @@ class TensorConverter { @Override int size() { return size; } } + private static class RawIntValues extends RawValues { + private final IntBuffer values; + private final int size; + RawIntValues(Onnx.TensorProto tensorProto) { + values = bytes(tensorProto).asIntBuffer(); + size = values.remaining(); + } + @Override double get(int i) { return values.get(i); } + @Override int size() { return size; } + } + private static class RawLongValues extends RawValues { private final LongBuffer values; private final int size; @@ -125,6 +138,15 @@ class TensorConverter { @Override int size() { return tensorProto.getDoubleDataCount(); } } + private static class 
IntValues extends Values { + private final Onnx.TensorProto tensorProto; + IntValues(Onnx.TensorProto tensorProto) { + this.tensorProto = tensorProto; + } + @Override double get(int i) { return tensorProto.getInt32Data(i); } + @Override int size() { return tensorProto.getInt32DataCount(); } + } + private static class LongValues extends Values { private final Onnx.TensorProto tensorProto; LongValues(Onnx.TensorProto tensorProto) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index 3487d889338..e02f29a63f9 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -40,7 +40,7 @@ public class ExpandDims extends IntermediateOperation { OrderedTensorType inputType = inputs.get(0).type().get(); int dimensionToInsert = (int)axis.asDouble(); if (dimensionToInsert < 0) { - dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; + dimensionToInsert = inputType.dimensions().size() + dimensionToInsert; } OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java new file mode 100644 index 00000000000..2a34ae53d5e --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java @@ -0,0 +1,170 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Slice; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +/* + * Onnx gather is the same as Numpy take. + */ +public class Gather extends IntermediateOperation { + + private final AttributeMap attributeMap; + + private int axis; + + public Gather(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! 
allInputTypesPresent(2)) return null; + + OrderedTensorType dataType = inputs.get(0).type().get(); + OrderedTensorType indicesType = inputs.get(1).type().get(); + + axis = (int) attributeMap.get("axis").orElse(DoubleValue.zero).asDouble(); + if (axis < 0) + axis = dataType.rank() + axis; + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < axis; ++i) { + addDimension(i, dataType.dimensions().get(i).size().orElse(-1L), typeBuilder); + } + for (int i = 0; i < indicesType.rank(); ++i) { + addDimension(i + axis, indicesType.dimensions().get(i).size().orElse(-1L), typeBuilder); + } + for (int i = axis + 1; i < dataType.rank(); ++i) { + addDimension(i + indicesType.rank(), dataType.dimensions().get(i).size().orElse(-1L), typeBuilder); + } + + inputs.get(0).exportAsRankingFunction = true; + inputs.get(1).exportAsRankingFunction = true; + + return typeBuilder.build(); + } + + private void addDimension(int dimensionIndex, long size, OrderedTensorType.Builder typeBuilder) { + String name = String.format("%s_%d", vespaName(), dimensionIndex); + typeBuilder.add(TensorType.Dimension.indexed(name, size)); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! 
allInputFunctionsPresent(2)) return null; + + IntermediateOperation data = inputs.get(0); + IntermediateOperation indices = inputs.get(1); + OrderedTensorType dataType = data.type().get(); + OrderedTensorType indicesType = indices.type().get(); + + String dataFunctionName = data.rankingExpressionFunctionName(); + String indicesFunctionName = indices.rankingExpressionFunctionName(); + + List<Slice.DimensionValue<Reference>> dataSliceDimensions = new ArrayList<>(); + for (int i = 0; i < axis; ++i) { + addSliceDimension(dataSliceDimensions, dataType.dimensions().get(i).name(), i); + } + + List<Slice.DimensionValue<Reference>> indicesSliceDimensions = new ArrayList<>(); + for (int i = 0; i < indicesType.rank(); ++i) { + addSliceDimension(indicesSliceDimensions, indicesType.dimensions().get(i).name(), axis + i); + } + ExpressionNode sliceExpression = createSliceExpression(indicesSliceDimensions, indicesFunctionName); + ExpressionNode indexExpression = createIndexExpression(dataType, sliceExpression); + addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression); + + for (int i = axis + 1; i < dataType.rank(); ++i) { + addSliceDimension(dataSliceDimensions, dataType.dimensions().get(i).name(), i + indicesType.rank() - 1); + } + + sliceExpression = createSliceExpression(dataSliceDimensions, dataFunctionName); + return Generate.bound(type.type(), wrapScalar(sliceExpression)); + } + + private ExpressionNode createSliceExpression(List<Slice.DimensionValue<Reference>> dimensionValues, String referenceName) { + TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(referenceName)); + Slice<Reference> sliceIndices = new Slice<>(inputIndices, dimensionValues); + return new TensorFunctionNode(sliceIndices); + } + + /** to support negative indexing */ + private ExpressionNode createIndexExpression(OrderedTensorType dataType, ExpressionNode slice) { + ExpressionNode axisSize = new 
ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get())); + ExpressionNode plus = new EmbracedNode(new ArithmeticNode(slice, ArithmeticOperator.PLUS, axisSize)); + ExpressionNode mod = new ArithmeticNode(plus, ArithmeticOperator.MODULO, axisSize); + return mod; + } + + private void addSliceDimension(List<Slice.DimensionValue<Reference>> dimensionValues, String dimensionName, ExpressionNode expr) { + dimensionValues.add(new Slice.DimensionValue<>(Optional.of(dimensionName), wrapScalar(new EmbracedNode(expr)))); + } + + private void addSliceDimension(List<Slice.DimensionValue<Reference>> dimensionValues, String dimensionName, int dimensionIndex) { + String outputDimensionName = type.dimensions().get(dimensionIndex).name(); + addSliceDimension(dimensionValues, dimensionName, new ReferenceNode(outputDimensionName)); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if ( ! allInputTypesPresent(2)) return; + + for (int i = 0; i < type.dimensions().size(); i++) { + renamer.addDimension(type.dimensions().get(i).name()); + for (int j = i + 1; j < type.dimensions().size(); j++) { + renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(), + DimensionRenamer.Constraint.lessThan(), this); + } + } + + OrderedTensorType dataType = inputs.get(0).type().get(); + OrderedTensorType indicesType = inputs.get(1).type().get(); + + for (int i = 0; i < axis; ++i) { + renamer.addConstraint(type.dimensions().get(i).name(), + dataType.dimensions().get(i).name(), + DimensionRenamer.Constraint.equal(), this); + } + for (int i = 0; i < indicesType.rank(); ++i) { + renamer.addConstraint(type.dimensions().get(i + axis).name(), + indicesType.dimensions().get(i).name(), + DimensionRenamer.Constraint.equal(), this); + } + for (int i = axis + 1; i < dataType.rank(); ++i) { + renamer.addConstraint(type.dimensions().get(i + indicesType.rank() - 1).name(), + dataType.dimensions().get(i).name(), + 
DimensionRenamer.Constraint.equal(), this); + } + + } + + @Override + public Gather withInputs(List<IntermediateOperation> inputs) { + return new Gather(modelName(), name(), inputs, attributeMap); + } + + @Override + public String operationName() { return "Gather"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 724b5c6b3ac..2aa8b2a0d48 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -45,6 +45,7 @@ public abstract class IntermediateOperation { protected OrderedTensorType type; protected TensorFunction function; protected TensorFunction rankingExpressionFunction = null; + protected boolean exportAsRankingFunction = false; private final List<String> importWarnings = new ArrayList<>(); private Value constantValue = null; @@ -78,7 +79,7 @@ public abstract class IntermediateOperation { if (isConstant()) { ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName())); function = new TensorFunctionNode.ExpressionTensorFunction(constant); - } else if (outputs.size() > 1) { + } else if (outputs.size() > 1 || exportAsRankingFunction) { rankingExpressionFunction = lazyGetFunction(); function = new VariableTensor(rankingExpressionFunctionName(), type.type()); } else { @@ -137,7 +138,7 @@ public abstract class IntermediateOperation { return Optional.of(constantValue); } if (constantValueFunction != null) { - return Optional.of(constantValueFunction.apply(type)); + return Optional.of(constantValueFunction.apply(type().orElse(null))); } return Optional.empty(); } @@ -188,7 +189,7 @@ public abstract class IntermediateOperation { throw new IllegalArgumentException("Attempted to evaluate 
non-constant operation as a constant."); } Value val = evaluateAsConstant(new MapContext(DoubleValue.NaN)); - if ( ! val.asTensor().type().equals(type.type()) ) { + if (type != null && ! val.asTensor().type().equals(type.type()) ) { throw new IllegalArgumentException("Constant evaluation in " + name + " resulted in wrong type. " + "Expected: " + type.type() + " Got: " + val.asTensor().type()); } @@ -211,6 +212,9 @@ public abstract class IntermediateOperation { result = new TensorValue(lazyGetFunction().evaluate(context)); } context.put(constantName, result); + if (outputs.size() > 1 || exportAsRankingFunction) { + context.put(rankingExpressionFunctionName(), result); + } } return result; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java new file mode 100644 index 00000000000..d15ac1b69f7 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java @@ -0,0 +1,82 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx.TensorProto.DataType; + +import java.util.List; +import java.util.function.DoubleUnaryOperator; + +public class OnnxCast extends IntermediateOperation { + + private final AttributeMap attributeMap; + private final DataType toType; + + public OnnxCast(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + if (attributeMap.get("to").isEmpty()) { + throw new IllegalArgumentException("OnnxCast in " + name + ": Required attribute 'to' is missing."); + } + toType = DataType.forNumber((int) attributeMap.get("to").get().asDouble()); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) + return null; + return inputs.get(0).type().orElse(null); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! 
allInputFunctionsPresent(1)) + return null; + TensorFunction input = inputs.get(0).function().get(); + switch (toType) { + case BOOL: + return new com.yahoo.tensor.functions.Map(input, new AsBool()); + case INT8: + case INT16: + case INT32: + case INT64: + case UINT8: + case UINT16: + case UINT32: + case UINT64: + return new com.yahoo.tensor.functions.Map(input, new AsInt()); + case FLOAT: + case DOUBLE: + case FLOAT16: + return input; + case STRING: + throw new IllegalArgumentException("OnnxCast in " + name + ": Casting to string is not implemented."); + default: + throw new IllegalArgumentException("OnnxCast in " + name + ": Unknown or undefined cast: " + toType.name()); + } + } + + @Override + public OnnxCast withInputs(List<IntermediateOperation> inputs) { + return new OnnxCast(modelName(), name(), inputs, attributeMap); + } + + @Override + public String operationName() { return "Cast"; } + + private static class AsBool implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return operand != 0.0 ? 1 : 0; } + @Override + public String toString() { return "f(a)(a!=0)"; } + } + + private static class AsInt implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return operand < 0 ? Math.ceil(operand) : Math.floor(operand); } + @Override + public String toString() { return "f(a)(if (a < 0, ceil(a), floor(a)))"; } + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java new file mode 100644 index 00000000000..b7c366d1034 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java @@ -0,0 +1,203 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +/** + * Onnx slice operation. + * + * Opset 1 to 9 accepts starts, ends, and axes tensors as attributes + * + * Opset >= 10 accepts starts, ends, axes, and steps tensors as inputs. Here we assume these are + * constants, otherwise we can't import this model because that would mean we + * would not know the resulting tensor type until run-time, and that is currently + * not supported in Vespa. 
+ */ +public class Slice extends IntermediateOperation { + + private final AttributeMap attributes; + + private int[] starts; + private int[] ends; + private int[] steps; + + public Slice(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributes) { + super(modelName, nodeName, inputs); + this.attributes = attributes; + } + + @Override + protected OrderedTensorType lazyGetType() { + if (inputs.size() < 1 || inputs.get(0).type().isEmpty()) { + return null; + } + OrderedTensorType dataType = inputs.get(0).type().get(); + + // required as we use tensor create + inputs.get(0).exportAsRankingFunction = true; + + // Todo: only supports opsets 1-9, for >= get these from inputs + int[] startsInput = attributeListAsArray("starts", 0); + int[] endsInput = attributeListAsArray("ends", 0); + int[] stepsInput = new int[dataType.rank()]; Arrays.fill(stepsInput, 1); // Todo: get from input when opset >= 10 + + int[] axes; + if (attributes.getList("axes").isPresent()) { + axes = attributeListAsArray("axes", 0); + } else { + // infer axes: default is [0, 1, ..., len('starts')-1] + axes = new int[startsInput.length]; + for (int i = 0; i < startsInput.length; ++i) { + axes[i] = i; + } + } + + if (startsInput.length != endsInput.length) { + throw new IllegalArgumentException("Slice in " + name + ": 'starts' and 'ends' indexes are not of the same size."); + } + if (startsInput.length != axes.length) { + throw new IllegalArgumentException("Slice in " + name + ": 'axes' and 'starts' are not of same size."); + } + + int[] dimensionSizes = new int [dataType.rank()]; + for (int i = 0; i < dataType.rank(); ++i) { + dimensionSizes[i] = dataType.dimensions().get(i).size().get().intValue(); + } + + starts = new int[dataType.rank()]; Arrays.fill(starts, 0); + ends = new int[dataType.rank()]; + steps = new int[dataType.rank()]; Arrays.fill(steps, 1); + + for (int i = 0; i < axes.length; ++i) { + int axis = axes[i]; + int start = startsInput[i]; + int end = 
endsInput[i]; + int step = stepsInput[i]; + + axis = Math.min(axis, dataType.rank() - 1); + axis = axis < 0 ? axis + dataType.rank() : axis; + + start = Math.min(start, dimensionSizes[axis]); + start = start < 0 ? start + dimensionSizes[axis] : start; + + end = Math.min(end, dimensionSizes[axis]); + end = end < 0 ? end + dimensionSizes[axis] : end; + + // Todo: check negative values for step size + + starts[axis] = start; + steps[axis] = step; + + if (step == 0) { + throw new IllegalArgumentException("Slice in " + name + ": illegal step size of 0."); + } + if ((end - start) < 1) { + throw new IllegalArgumentException("Slice in " + name + ": illegal start (" + start + ") and end (" + end + ") index."); + } + dimensionSizes[axis] = (end - start) / step; + } + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < dataType.rank(); ++i) { + addDimension(i, dimensionSizes[i], typeBuilder); + } + return typeBuilder.build(); + } + + private int[] attributeListAsArray(String name, int defaultValue) { + if (attributes.getList(name).isEmpty()) { + throw new IllegalArgumentException("Slice in " + name + ": Required attribute '" + name + "' is missing."); + } + List<Value> list = attributes.getList(name).get(); + int[] result = new int[list.size()]; Arrays.fill(result, defaultValue); + for (int i = 0; i < list.size(); ++i) { + result[i] = (int)list.get(i).asDouble(); + } + return result; + } + + private void addDimension(int dimensionIndex, long size, OrderedTensorType.Builder typeBuilder) { + String name = String.format("%s_%d", vespaName(), dimensionIndex); + typeBuilder.add(TensorType.Dimension.indexed(name, size)); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (inputs.size() < 1 || inputs.get(0).function().isEmpty()) { + return null; + } + + IntermediateOperation data = inputs.get(0); + OrderedTensorType dataType = data.type().get(); + String dataFunctionName = 
data.rankingExpressionFunctionName(); + + List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>(); + + for (int axis = 0; axis < dataType.rank(); ++axis) { + int start = starts[axis]; + int step = steps[axis]; + + String inputDimensionName = dataType.dimensions().get(axis).name(); + String outputDimensionName = type.dimensions().get(axis).name(); + + ExpressionNode stepSize = new ConstantNode(new DoubleValue(step)); + ExpressionNode startIndex = new ConstantNode(new DoubleValue(start)); + + // step * (d0 + start) + ExpressionNode reference = new ReferenceNode(outputDimensionName); + ExpressionNode plus = new EmbracedNode(new ArithmeticNode(reference, ArithmeticOperator.PLUS, startIndex)); + ExpressionNode mul = new ArithmeticNode(stepSize, ArithmeticOperator.MULTIPLY, plus); + + dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mul)))); + } + + TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(dataFunctionName)); + com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); + ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); + + TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); + return generate; + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + // Todo: what to do? 
+ for (int i = 0; i < type.dimensions().size(); i++) { + renamer.addDimension(type.dimensions().get(i).name()); + for (int j = i + 1; j < type.dimensions().size(); j++) { + renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(), + DimensionRenamer.Constraint.lessThan(), this); + } + } + } + + @Override + public Slice withInputs(List<IntermediateOperation> inputs) { + return new Slice(modelName(), name(), inputs, attributes); + } + + @Override + public String operationName() { return "Slice"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java new file mode 100644 index 00000000000..0df09c21530 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java @@ -0,0 +1,109 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +public class Unsqueeze extends IntermediateOperation { + + private final AttributeMap attributeMap; + private List<String> expandDimensions; + + public Unsqueeze(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + if (attributeMap.getList("axes").isEmpty()) { + throw new IllegalArgumentException("Unsqueeze in " + name + ": Required attribute 'axes' is missing."); + } + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(1)) return null; + + OrderedTensorType inputType = inputs.get(0).type().get(); + Set<Integer> dimensionsToInsert = attributeMap.getList("axes").get().stream(). + map(d -> (int)d.asDouble()).collect(Collectors.toSet()); + + // handle negative dimension indexes + int rank = inputType.rank() + dimensionsToInsert.size(); + dimensionsToInsert = dimensionsToInsert.stream().map(d -> d < 0 ? 
rank + d : d).collect(Collectors.toSet()); + + expandDimensions = new ArrayList<>(); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + int inputDimensionIndex = 0; + for (int expandedDimensionIndex = 0; expandedDimensionIndex < rank; ++expandedDimensionIndex) { + if (dimensionsToInsert.contains(expandedDimensionIndex)) { + addDimension(expandedDimensionIndex, typeBuilder); + } else { + typeBuilder.add(inputType.dimensions().get(inputDimensionIndex)); + inputDimensionIndex++; + } + } + return typeBuilder.build(); + } + + private void addDimension(int dimensionIndex, OrderedTensorType.Builder typeBuilder) { + String name = String.format("%s_%d", vespaName(), dimensionIndex); + expandDimensions.add(name); + typeBuilder.add(TensorType.Dimension.indexed(name, 1L)); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputFunctionsPresent(1)) return null; + + // multiply with a generated tensor created from the expanded dimensions + TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); + for (String name : expandDimensions) { + typeBuilder.indexed(name, 1); + } + TensorType generatedType = typeBuilder.build(); + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); + Generate generatedFunction = new Generate(generatedType, + new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); + return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply()); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + addConstraintsFrom(type, renamer); + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + List<String> renamedDimensions = new ArrayList<>(expandDimensions.size()); + for (String name : expandDimensions) { + Optional<String> newName = renamer.dimensionNameOf(name); + if 
(newName.isEmpty()) { + return; // presumably, already renamed + } + renamedDimensions.add(newName.get()); + } + expandDimensions = renamedDimensions; + } + + @Override + public Unsqueeze withInputs(List<IntermediateOperation> inputs) { + return new Unsqueeze(modelName(), name(), inputs, attributeMap); + } + + @Override + public String operationName() { return "Unsqueeze"; } + +} |