diff options
Diffstat (limited to 'model-integration')
14 files changed, 835 insertions, 23 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index d22a8067bd4..c7f320ed3b4 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -224,6 +224,11 @@ public class DimensionRenamer { /** Returns whether this is an opposite of another constraint */ boolean isOpposite() { return opposite; } + public static Constraint equal() { return new EqualConstraint(false, false); } + public static Constraint notEqual() { return new NotEqualConstraint(false, false); } + public static Constraint lessThan() { return new LessThanConstraint(false, false); } + public static Constraint greaterThan() { return new GreaterThanConstraint(false, false); } + public static Constraint equal(boolean soft) { return new EqualConstraint(soft, false); } public static Constraint notEqual(boolean soft) { return new NotEqualConstraint(soft, false); } public static Constraint lessThan(boolean soft) { return new LessThanConstraint(soft, false); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java index 8caa158e5be..b272d4c6750 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java @@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import onnx.Onnx; @@ -37,6 +38,7 @@ class AttributeConverter implements IntermediateOperation.AttributeMap { case INT: return Optional.of(DoubleValue.frozen(attr.getI())); case FLOAT: return Optional.of(DoubleValue.frozen(attr.getF())); case STRING: return Optional.of(StringValue.frozen(attr.getS().toString())); + case TENSOR: return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attr.getT(), TypeConverter.typeFrom(attr.getT())))); default: return Optional.empty(); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index d42338deaf8..ffc64c38f16 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -2,13 +2,18 @@ package ai.vespa.rankingexpression.importer.onnx; +import ai.vespa.rankingexpression.importer.operations.ExpandDims; +import ai.vespa.rankingexpression.importer.operations.Gather; +import ai.vespa.rankingexpression.importer.operations.OnnxCast; import ai.vespa.rankingexpression.importer.operations.Gemm; import ai.vespa.rankingexpression.importer.operations.ConcatReduce; import ai.vespa.rankingexpression.importer.operations.OnnxConcat; import ai.vespa.rankingexpression.importer.operations.Reduce; import ai.vespa.rankingexpression.importer.operations.Select; +import ai.vespa.rankingexpression.importer.operations.Slice; import ai.vespa.rankingexpression.importer.operations.Softmax; import ai.vespa.rankingexpression.importer.operations.Squeeze; +import ai.vespa.rankingexpression.importer.operations.Unsqueeze; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import ai.vespa.rankingexpression.importer.IntermediateGraph; @@ -67,6 +72,7 @@ class GraphImporter { case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin()); case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan()); + case "cast": return new OnnxCast(modelName, nodeName, inputs, attributes); case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil()); case "concat": return new OnnxConcat(modelName, nodeName, inputs, attributes); case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); @@ -75,6 +81,7 @@ class GraphImporter { case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal()); case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); + case "gather": return new Gather(modelName, nodeName, inputs, attributes); case "gemm": return new Gemm(modelName, nodeName, inputs, attributes); case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater()); case "identity": return new Identity(modelName, nodeName, inputs); @@ -105,6 +112,7 @@ class GraphImporter { case "shape": return new Shape(modelName, nodeName, inputs); case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin()); + case "slice": return new Slice(modelName, nodeName, inputs, attributes); case "softmax": return new Softmax(modelName, nodeName, inputs, attributes); case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); @@ -113,6 +121,7 @@ class GraphImporter { case "where": return new Select(modelName, nodeName, inputs); case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan()); case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); + case "unsqueeze": return new Unsqueeze(modelName, nodeName, inputs, attributes); } IntermediateOperation op = new NoOp(modelName, nodeName, inputs); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java index 69d18d0ffcb..f8c7dc15857 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java @@ -37,12 +37,14 @@ class TensorConverter { case BOOL: return new RawBoolValues(tensorProto); case FLOAT: return new RawFloatValues(tensorProto); case DOUBLE: return new RawDoubleValues(tensorProto); + case INT32: return new RawIntValues(tensorProto); case INT64: return new RawLongValues(tensorProto); } } else { switch (tensorProto.getDataType()) { case FLOAT: return new FloatValues(tensorProto); case DOUBLE: return new DoubleValues(tensorProto); + case INT32: return new IntValues(tensorProto); case INT64: return new LongValues(tensorProto); } } @@ -96,6 +98,17 @@ class TensorConverter { @Override int size() { return size; } } + private static class RawIntValues extends RawValues { + private final IntBuffer values; + private final int size; + RawIntValues(Onnx.TensorProto tensorProto) { + values = bytes(tensorProto).asIntBuffer(); + size = values.remaining(); + } + @Override double get(int i) { return values.get(i); } + @Override int size() { return size; } + } + private static class RawLongValues extends RawValues { private final LongBuffer values; private final int size; @@ -125,6 +138,15 @@ class TensorConverter { @Override int size() { return tensorProto.getDoubleDataCount(); } } + private static class IntValues extends Values { + private final Onnx.TensorProto tensorProto; + IntValues(Onnx.TensorProto tensorProto) { + this.tensorProto = tensorProto; + } + @Override double get(int i) { return tensorProto.getInt32Data(i); } + @Override int size() { return tensorProto.getInt32DataCount(); } + } + private static class LongValues extends Values { private final Onnx.TensorProto tensorProto; LongValues(Onnx.TensorProto tensorProto) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index 3487d889338..e02f29a63f9 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -40,7 +40,7 @@ public class ExpandDims extends IntermediateOperation { OrderedTensorType inputType = inputs.get(0).type().get(); int dimensionToInsert = (int)axis.asDouble(); if (dimensionToInsert < 0) { - dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; + dimensionToInsert = inputType.dimensions().size() + dimensionToInsert; } OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java new file mode 100644 index 00000000000..2a34ae53d5e --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java @@ -0,0 +1,170 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Slice; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +/* + * Onnx gather is the same as Numpy take. + */ +public class Gather extends IntermediateOperation { + + private final AttributeMap attributeMap; + + private int axis; + + public Gather(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(2)) return null; + + OrderedTensorType dataType = inputs.get(0).type().get(); + OrderedTensorType indicesType = inputs.get(1).type().get(); + + axis = (int) attributeMap.get("axis").orElse(DoubleValue.zero).asDouble(); + if (axis < 0) + axis = dataType.rank() + axis; + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < axis; ++i) { + addDimension(i, dataType.dimensions().get(i).size().orElse(-1L), typeBuilder); + } + for (int i = 0; i < indicesType.rank(); ++i) { + addDimension(i + axis, indicesType.dimensions().get(i).size().orElse(-1L), typeBuilder); + } + for (int i = axis + 1; i < dataType.rank(); ++i) { + addDimension(i + indicesType.rank(), dataType.dimensions().get(i).size().orElse(-1L), typeBuilder); + } + + inputs.get(0).exportAsRankingFunction = true; + inputs.get(1).exportAsRankingFunction = true; + + return typeBuilder.build(); + } + + private void addDimension(int dimensionIndex, long size, OrderedTensorType.Builder typeBuilder) { + String name = String.format("%s_%d", vespaName(), dimensionIndex); + typeBuilder.add(TensorType.Dimension.indexed(name, size)); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputFunctionsPresent(2)) return null; + + IntermediateOperation data = inputs.get(0); + IntermediateOperation indices = inputs.get(1); + OrderedTensorType dataType = data.type().get(); + OrderedTensorType indicesType = indices.type().get(); + + String dataFunctionName = data.rankingExpressionFunctionName(); + String indicesFunctionName = indices.rankingExpressionFunctionName(); + + List<Slice.DimensionValue<Reference>> dataSliceDimensions = new ArrayList<>(); + for (int i = 0; i < axis; ++i) { + addSliceDimension(dataSliceDimensions, dataType.dimensions().get(i).name(), i); + } + + List<Slice.DimensionValue<Reference>> indicesSliceDimensions = new ArrayList<>(); + for (int i = 0; i < indicesType.rank(); ++i) { + addSliceDimension(indicesSliceDimensions, indicesType.dimensions().get(i).name(), axis + i); + } + ExpressionNode sliceExpression = createSliceExpression(indicesSliceDimensions, indicesFunctionName); + ExpressionNode indexExpression = createIndexExpression(dataType, sliceExpression); + addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression); + + for (int i = axis + 1; i < dataType.rank(); ++i) { + addSliceDimension(dataSliceDimensions, dataType.dimensions().get(i).name(), i + indicesType.rank() - 1); + } + + sliceExpression = createSliceExpression(dataSliceDimensions, dataFunctionName); + return Generate.bound(type.type(), wrapScalar(sliceExpression)); + } + + private ExpressionNode createSliceExpression(List<Slice.DimensionValue<Reference>> dimensionValues, String referenceName) { + TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(referenceName)); + Slice<Reference> sliceIndices = new Slice<>(inputIndices, dimensionValues); + return new TensorFunctionNode(sliceIndices); + } + + /** to support negative indexing */ + private ExpressionNode createIndexExpression(OrderedTensorType dataType, ExpressionNode slice) { + ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get())); + ExpressionNode plus = new EmbracedNode(new ArithmeticNode(slice, ArithmeticOperator.PLUS, axisSize)); + ExpressionNode mod = new ArithmeticNode(plus, ArithmeticOperator.MODULO, axisSize); + return mod; + } + + private void addSliceDimension(List<Slice.DimensionValue<Reference>> dimensionValues, String dimensionName, ExpressionNode expr) { + dimensionValues.add(new Slice.DimensionValue<>(Optional.of(dimensionName), wrapScalar(new EmbracedNode(expr)))); + } + + private void addSliceDimension(List<Slice.DimensionValue<Reference>> dimensionValues, String dimensionName, int dimensionIndex) { + String outputDimensionName = type.dimensions().get(dimensionIndex).name(); + addSliceDimension(dimensionValues, dimensionName, new ReferenceNode(outputDimensionName)); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if ( ! allInputTypesPresent(2)) return; + + for (int i = 0; i < type.dimensions().size(); i++) { + renamer.addDimension(type.dimensions().get(i).name()); + for (int j = i + 1; j < type.dimensions().size(); j++) { + renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(), + DimensionRenamer.Constraint.lessThan(), this); + } + } + + OrderedTensorType dataType = inputs.get(0).type().get(); + OrderedTensorType indicesType = inputs.get(1).type().get(); + + for (int i = 0; i < axis; ++i) { + renamer.addConstraint(type.dimensions().get(i).name(), + dataType.dimensions().get(i).name(), + DimensionRenamer.Constraint.equal(), this); + } + for (int i = 0; i < indicesType.rank(); ++i) { + renamer.addConstraint(type.dimensions().get(i + axis).name(), + indicesType.dimensions().get(i).name(), + DimensionRenamer.Constraint.equal(), this); + } + for (int i = axis + 1; i < dataType.rank(); ++i) { + renamer.addConstraint(type.dimensions().get(i + indicesType.rank() - 1).name(), + dataType.dimensions().get(i).name(), + DimensionRenamer.Constraint.equal(), this); + } + + } + + @Override + public Gather withInputs(List<IntermediateOperation> inputs) { + return new Gather(modelName(), name(), inputs, attributeMap); + } + + @Override + public String operationName() { return "Gather"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 724b5c6b3ac..2aa8b2a0d48 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -45,6 +45,7 @@ public abstract class IntermediateOperation { protected OrderedTensorType type; protected TensorFunction function; protected TensorFunction rankingExpressionFunction = null; + protected boolean exportAsRankingFunction = false; private final List<String> importWarnings = new ArrayList<>(); private Value constantValue = null; @@ -78,7 +79,7 @@ public abstract class IntermediateOperation { if (isConstant()) { ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName())); function = new TensorFunctionNode.ExpressionTensorFunction(constant); - } else if (outputs.size() > 1) { + } else if (outputs.size() > 1 || exportAsRankingFunction) { rankingExpressionFunction = lazyGetFunction(); function = new VariableTensor(rankingExpressionFunctionName(), type.type()); } else { @@ -137,7 +138,7 @@ public abstract class IntermediateOperation { return Optional.of(constantValue); } if (constantValueFunction != null) { - return Optional.of(constantValueFunction.apply(type)); + return Optional.of(constantValueFunction.apply(type().orElse(null))); } return Optional.empty(); } @@ -188,7 +189,7 @@ public abstract class IntermediateOperation { throw new IllegalArgumentException("Attempted to evaluate non-constant operation as a constant."); } Value val = evaluateAsConstant(new MapContext(DoubleValue.NaN)); - if ( ! val.asTensor().type().equals(type.type()) ) { + if (type != null && ! val.asTensor().type().equals(type.type()) ) { throw new IllegalArgumentException("Constant evaluation in " + name + " resulted in wrong type. " + "Expected: " + type.type() + " Got: " + val.asTensor().type()); } @@ -211,6 +212,9 @@ public abstract class IntermediateOperation { result = new TensorValue(lazyGetFunction().evaluate(context)); } context.put(constantName, result); + if (outputs.size() > 1 || exportAsRankingFunction) { + context.put(rankingExpressionFunctionName(), result); + } } return result; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java new file mode 100644 index 00000000000..d15ac1b69f7 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java @@ -0,0 +1,82 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx.TensorProto.DataType; + +import java.util.List; +import java.util.function.DoubleUnaryOperator; + +public class OnnxCast extends IntermediateOperation { + + private final AttributeMap attributeMap; + private final DataType toType; + + public OnnxCast(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + if (attributeMap.get("to").isEmpty()) { + throw new IllegalArgumentException("OnnxCast in " + name + ": Required attribute 'to' is missing."); + } + toType = DataType.forNumber((int) attributeMap.get("to").get().asDouble()); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) + return null; + return inputs.get(0).type().orElse(null); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputFunctionsPresent(1)) + return null; + TensorFunction input = inputs.get(0).function().get(); + switch (toType) { + case BOOL: + return new com.yahoo.tensor.functions.Map(input, new AsBool()); + case INT8: + case INT16: + case INT32: + case INT64: + case UINT8: + case UINT16: + case UINT32: + case UINT64: + return new com.yahoo.tensor.functions.Map(input, new AsInt()); + case FLOAT: + case DOUBLE: + case FLOAT16: + return input; + case STRING: + throw new IllegalArgumentException("OnnxCast in " + name + ": Casting to string is not implemented."); + default: + throw new IllegalArgumentException("OnnxCast in " + name + ": Unknown or undefined cast: " + toType.name()); + } + } + + @Override + public OnnxCast withInputs(List<IntermediateOperation> inputs) { + return new OnnxCast(modelName(), name(), inputs, attributeMap); + } + + @Override + public String operationName() { return "Cast"; } + + private static class AsBool implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return operand != 0.0 ? 1 : 0; } + @Override + public String toString() { return "f(a)(a!=0)"; } + } + + private static class AsInt implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return operand < 0 ? Math.ceil(operand) : Math.floor(operand); } + @Override + public String toString() { return "f(a)(if (a < 0, ceil(a), floor(a)))"; } + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java new file mode 100644 index 00000000000..b7c366d1034 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java @@ -0,0 +1,203 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +/** + * Onnx slice operation. + * + * Opset 1 to 9 accepts starts, ends, and axes tensors as attributes + * + * Opset >= 10 accepts starts, ends, axes, and steps tensors as inputs. Here we assume these are + * constants, otherwise we can't import this model because that would mean we + * would not know the resulting tensor type until run-time, and that is currently + * not supported in Vespa. + */ +public class Slice extends IntermediateOperation { + + private final AttributeMap attributes; + + private int[] starts; + private int[] ends; + private int[] steps; + + public Slice(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributes) { + super(modelName, nodeName, inputs); + this.attributes = attributes; + } + + @Override + protected OrderedTensorType lazyGetType() { + if (inputs.size() < 1 || inputs.get(0).type().isEmpty()) { + return null; + } + OrderedTensorType dataType = inputs.get(0).type().get(); + + // required as we use tensor create + inputs.get(0).exportAsRankingFunction = true; + + // Todo: only supports opsets 1-9, for >= get these from inputs + int[] startsInput = attributeListAsArray("starts", 0); + int[] endsInput = attributeListAsArray("ends", 0); + int[] stepsInput = new int[dataType.rank()]; Arrays.fill(stepsInput, 1); // Todo: get from input when opset >= 10 + + int[] axes; + if (attributes.getList("axes").isPresent()) { + axes = attributeListAsArray("axes", 0); + } else { + // infer axes: default is [0, 1, ..., len('starts')-1] + axes = new int[startsInput.length]; + for (int i = 0; i < startsInput.length; ++i) { + axes[i] = i; + } + } + + if (startsInput.length != endsInput.length) { + throw new IllegalArgumentException("Slice in " + name + ": 'starts' and 'ends' indexes are not of the same size."); + } + if (startsInput.length != axes.length) { + throw new IllegalArgumentException("Slice in " + name + ": 'axes' and 'starts' are not of same size."); + } + + int[] dimensionSizes = new int [dataType.rank()]; + for (int i = 0; i < dataType.rank(); ++i) { + dimensionSizes[i] = dataType.dimensions().get(i).size().get().intValue(); + } + + starts = new int[dataType.rank()]; Arrays.fill(starts, 0); + ends = new int[dataType.rank()]; + steps = new int[dataType.rank()]; Arrays.fill(steps, 1); + + for (int i = 0; i < axes.length; ++i) { + int axis = axes[i]; + int start = startsInput[i]; + int end = endsInput[i]; + int step = stepsInput[i]; + + axis = Math.min(axis, dataType.rank() - 1); + axis = axis < 0 ? axis + dataType.rank() : axis; + + start = Math.min(start, dimensionSizes[axis]); + start = start < 0 ? start + dimensionSizes[axis] : start; + + end = Math.min(end, dimensionSizes[axis]); + end = end < 0 ? end + dimensionSizes[axis] : end; + + // Todo: check negative values for step size + + starts[axis] = start; + steps[axis] = step; + + if (step == 0) { + throw new IllegalArgumentException("Slice in " + name + ": illegal step size of 0."); + } + if ((end - start) < 1) { + throw new IllegalArgumentException("Slice in " + name + ": illegal start (" + start + ") and end (" + end + ") index."); + } + dimensionSizes[axis] = (end - start) / step; + } + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < dataType.rank(); ++i) { + addDimension(i, dimensionSizes[i], typeBuilder); + } + return typeBuilder.build(); + } + + private int[] attributeListAsArray(String name, int defaultValue) { + if (attributes.getList(name).isEmpty()) { + throw new IllegalArgumentException("Slice in " + name + ": Required attribute '" + name + "' is missing."); + } + List<Value> list = attributes.getList(name).get(); + int[] result = new int[list.size()]; Arrays.fill(result, defaultValue); + for (int i = 0; i < list.size(); ++i) { + result[i] = (int)list.get(i).asDouble(); + } + return result; + } + + private void addDimension(int dimensionIndex, long size, OrderedTensorType.Builder typeBuilder) { + String name = String.format("%s_%d", vespaName(), dimensionIndex); + typeBuilder.add(TensorType.Dimension.indexed(name, size)); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (inputs.size() < 1 || inputs.get(0).function().isEmpty()) { + return null; + } + + IntermediateOperation data = inputs.get(0); + OrderedTensorType dataType = data.type().get(); + String dataFunctionName = data.rankingExpressionFunctionName(); + + List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>(); + + for (int axis = 0; axis < dataType.rank(); ++axis) { + int start = starts[axis]; + int step = steps[axis]; + + String inputDimensionName = dataType.dimensions().get(axis).name(); + String outputDimensionName = type.dimensions().get(axis).name(); + + ExpressionNode stepSize = new ConstantNode(new DoubleValue(step)); + ExpressionNode startIndex = new ConstantNode(new DoubleValue(start)); + + // step * (d0 + start) + ExpressionNode reference = new ReferenceNode(outputDimensionName); + ExpressionNode plus = new EmbracedNode(new ArithmeticNode(reference, ArithmeticOperator.PLUS, startIndex)); + ExpressionNode mul = new ArithmeticNode(stepSize, ArithmeticOperator.MULTIPLY, plus); + + dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mul)))); + } + + TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(dataFunctionName)); + com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); + ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); + + TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); + return generate; + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + // Todo: what to do? + for (int i = 0; i < type.dimensions().size(); i++) { + renamer.addDimension(type.dimensions().get(i).name()); + for (int j = i + 1; j < type.dimensions().size(); j++) { + renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(), + DimensionRenamer.Constraint.lessThan(), this); + } + } + } + + @Override + public Slice withInputs(List<IntermediateOperation> inputs) { + return new Slice(modelName(), name(), inputs, attributes); + } + + @Override + public String operationName() { return "Slice"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java new file mode 100644 index 00000000000..0df09c21530 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java @@ -0,0 +1,109 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +public class Unsqueeze extends IntermediateOperation { + + private final AttributeMap attributeMap; + private List<String> expandDimensions; + + public Unsqueeze(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + if (attributeMap.getList("axes").isEmpty()) { + throw new IllegalArgumentException("Unsqueeze in " + name + ": Required attribute 'axes' is missing."); + } + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(1)) return null; + + OrderedTensorType inputType = inputs.get(0).type().get(); + Set<Integer> dimensionsToInsert = attributeMap.getList("axes").get().stream(). + map(d -> (int)d.asDouble()).collect(Collectors.toSet()); + + // handle negative dimension indexes + int rank = inputType.rank() + dimensionsToInsert.size(); + dimensionsToInsert = dimensionsToInsert.stream().map(d -> d < 0 ? rank + d : d).collect(Collectors.toSet()); + + expandDimensions = new ArrayList<>(); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + int inputDimensionIndex = 0; + for (int expandedDimensionIndex = 0; expandedDimensionIndex < rank; ++expandedDimensionIndex) { + if (dimensionsToInsert.contains(expandedDimensionIndex)) { + addDimension(expandedDimensionIndex, typeBuilder); + } else { + typeBuilder.add(inputType.dimensions().get(inputDimensionIndex)); + inputDimensionIndex++; + } + } + return typeBuilder.build(); + } + + private void addDimension(int dimensionIndex, OrderedTensorType.Builder typeBuilder) { + String name = String.format("%s_%d", vespaName(), dimensionIndex); + expandDimensions.add(name); + typeBuilder.add(TensorType.Dimension.indexed(name, 1L)); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputFunctionsPresent(1)) return null; + + // multiply with a generated tensor created from the expanded dimensions + TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); + for (String name : expandDimensions) { + typeBuilder.indexed(name, 1); + } + TensorType generatedType = typeBuilder.build(); + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); + Generate generatedFunction = new Generate(generatedType, + new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); + return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply()); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + addConstraintsFrom(type, renamer); + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + List<String> renamedDimensions = new ArrayList<>(expandDimensions.size()); + for (String name : expandDimensions) { + Optional<String> newName = renamer.dimensionNameOf(name); + if (newName.isEmpty()) { + return; // presumably, already renamed + } + renamedDimensions.add(newName.get()); + } + expandDimensions = renamedDimensions; + } + + @Override + public Unsqueeze withInputs(List<IntermediateOperation> inputs) { + return new Unsqueeze(modelName(), name(), inputs, attributeMap); + } + + @Override + public String operationName() { return "Unsqueeze"; } + +} diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java index 6954abe5157..94c5577357b 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -17,6 +17,7 @@ import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; import onnx.Onnx; +import org.junit.Ignore; import org.junit.Test; import java.util.ArrayList; @@ -26,7 +27,9 @@ import static ai.vespa.rankingexpression.importer.onnx.GraphImporter.*; import static onnx.Onnx.AttributeProto.AttributeType.FLOAT; import static onnx.Onnx.AttributeProto.AttributeType.INT; import static onnx.Onnx.AttributeProto.AttributeType.INTS; +import static onnx.Onnx.AttributeProto.AttributeType.TENSOR; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; /** * Unit tests for ONNX operators. The number on the test reflects the minimum @@ -294,6 +297,27 @@ public class OnnxOperationsTestCase { } @Test + public void testUnsqueeze1() throws ParseException { + Tensor x = evaluate("tensor(d0[2]):[1, 2]"); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {0})); + assertEval("unsqueeze", x, evaluate("tensor(d0[2],d1[1]):[1, 2]"), createAttribute("axes", new int[] {1})); + assertEval("unsqueeze", x, evaluate("tensor(d0[2],d1[1]):[1, 2]"), createAttribute("axes", new int[] {-1})); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {-2})); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {0,0})); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1]):[1, 2]"), createAttribute("axes", new int[] {0,2})); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1]):[1, 2]"), createAttribute("axes", new int[] {2,0})); + + x = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[1],d2[2],d3[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0,1})); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1],d3[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0,2})); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[3],d3[1]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0,3})); + assertEval("unsqueeze", x, evaluate("tensor(d0[2],d1[1],d2[1],d3[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {1,2})); + assertEval("unsqueeze", x, evaluate("tensor(d0[2],d1[3],d2[1],d3[1]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {2,3})); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1],d3[3],d4[1]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0,2,4})); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1],d3[3],d4[1]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {4,2,0})); + } + + @Test public void testWhere9() throws ParseException { Tensor x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]"); Tensor y = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]"); @@ -308,6 +332,109 @@ public class OnnxOperationsTestCase { assertEval("where", evaluate("tensor(d0[1],d1[1]):[1]"), x, y, x); } + @Test + public void testCast1() throws ParseException { + Tensor x = evaluate("tensor(d0[4]):[-1.9, 0.0, 1.1, 2.9]"); + assertEval("cast", x, evaluate("tensor(d0[4]):[1,0,1,1]"), createAttribute("to", 9)); // boolean + assertEval("cast", x, evaluate("tensor(d0[4]):[-1,0,1,2]"), createAttribute("to", 6)); // int32 + assertEval("cast", x, evaluate("tensor(d0[4]):[-1,0,1,2]"), createAttribute("to", 12)); // uint32 + assertEval("cast", x, evaluate("tensor(d0[4]):[-1.9,0,1.1,2.9]"), createAttribute("to", 1)); // float + try { + assertEval("cast", x, evaluate("tensor(d0[4]):[1,0,1,1]"), createAttribute("to", 8)); // string + fail(); + } catch (IllegalArgumentException e) { + assertEquals(e.getMessage(), "OnnxCast in cast: Casting to string is not implemented."); + } + } + + @Test + public void testGather1() throws ParseException { + // 1 dim input, 1 dim indices + Tensor x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); + Tensor y = evaluate("tensor(d0[3]):[0,2,4]"); + assertEval("gather", x, y, evaluate("tensor(d0[3]):[1,3,5]")); + + // 2 dim input, 1 dim indices - axis 0 + x = evaluate("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]"); + y = evaluate("tensor(d0[4]):[2, 1, 0, 2]"); + assertEval("gather", x, y, evaluate("tensor(d0[4],d1[2]):[5, 6, 3, 4, 1, 2, 5, 6]")); + + // 1 dim input, 2 dim indices - axis 0 + x = evaluate("tensor(d0[6]):[1, 2, 3, 4, 5, 6]"); + y = evaluate("tensor(d0[2],d1[2]):[0, 1, 3, 5]"); + assertEval("gather", x, y, evaluate("tensor(d0[2],d1[2]):[1, 2, 4, 6]")); + + // 2 dim input, 2 dim indices - axis 0 + x = evaluate("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]"); + y = evaluate("tensor(d0[2],d1[2]):[0, 1, 1, 2]"); + assertEval("gather", x, y, evaluate("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 3, 4, 5, 6]"), createAttribute("axis", -2)); + + // 2 dim input, 1 dim indices - axis 1 + x = evaluate("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]"); + y = evaluate("tensor(d0[4]):[0, 1, 0, 1]"); + assertEval("gather", x, y, evaluate("tensor(d0[3],d1[4]):[1,2,1,2,3,4,3,4,5,6,5,6]"), createAttribute("axis", 1)); + + // 2 dim input, 2 dim indices - axis 1 + x = evaluate("tensor(d0[3],d1[3]):[1, 2, 3, 4, 5, 6, 7, 8, 9]"); + y = evaluate("tensor(d0[1],d1[2]):[0, 2]"); + assertEval("gather", x, y, evaluate("tensor(d0[3],d1[1],d2[2]):[1,3,4,6,7,9]"), createAttribute("axis", 1)); + + // 1 dim input, 1 dim indices - negative indices + x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); + y = evaluate("tensor(d0[3]):[0,-2,-4]"); + assertEval("gather", x, y, evaluate("tensor(d0[3]):[1,5,3]")); + } + + @Test + public void testSlice1() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[4]):[ [1,2,3,4],[5,6,7,8] ]"); + AttributeConverter attributes = createAttributes(). + attr("starts", new int[] {1, 0}). + attr("ends", new int[] {2, 3}). + attr("axes", new int[] {0, 1}).build(); + assertEval("slice", x, evaluate("tensor(d0[1],d1[3]):[ [5,6,7] ]"), attributes); + + attributes = createAttributes(). + attr("starts", new int[] {0, 1}). + attr("ends", new int[] {-1, 1000}).build(); + assertEval("slice", x, evaluate("tensor(d0[1],d1[3]):[ [2,3,4] ]"), attributes); + + attributes = createAttributes(). + attr("starts", new int[] {0, 1}). + attr("ends", new int[] {3, 2}). + attr("axes", new int[] {1, 0}).build(); // axes are switched + assertEval("slice", x, evaluate("tensor(d0[1],d1[3]):[ [5,6,7] ]"), attributes); + + attributes = createAttributes(). + attr("starts", new int[] {1, 0}). + attr("ends", new int[] {2, 3}). + attr("axes", new int[] {0, -1}).build(); // negative axes + assertEval("slice", x, evaluate("tensor(d0[1],d1[3]):[ [5,6,7] ]"), attributes); + + attributes = createAttributes(). + attr("starts", new int[] {1}). + attr("ends", new int[] {2}). + attr("axes", new int[] {0}).build(); // axis 1 is not specified + assertEval("slice", x, evaluate("tensor(d0[1],d1[4]):[ [5,6,7,8] ]"), attributes); + + attributes = createAttributes(). + attr("starts", new int[] {0}). + attr("ends", new int[] {1}).build(); + assertEval("slice", x, evaluate("tensor(d0[1],d1[4]):[ [1,2,3,4] ]"), attributes); + } + + @Ignore + @Test + public void testSlice10() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[4]):[ [1,2,3,4],[5,6,7,8] ]"); + Tensor starts = evaluate("tensor(d0[2]):[1,0]"); + Tensor ends = evaluate("tensor(d0[2]):[2,3]"); + Tensor axes = evaluate("tensor(d0[2]):[0,1]"); + Tensor steps = evaluate("tensor(d0[2]):[1,2]"); + assertEval("slice", x, starts, ends, axes, steps, evaluate("tensor(d0[1],d1[2]):[ [5,7] ]")); + + } + private Tensor evaluate(String expr) throws ParseException { return evaluate(expr, null, null, null); } @@ -334,28 +461,40 @@ public class OnnxOperationsTestCase { } private void assertEval(String opName, Tensor x, Tensor expected) { - assertEval(opName, x, null, null, expected, null); + assertEval(opName, x, null, null, null, null, expected, null); } private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr) { - assertEval(opName, x, null, null, expected, attr); + assertEval(opName, x, null, null, null, null, expected, attr); } private void assertEval(String opName, Tensor x, Tensor y, Tensor expected, AttributeConverter attr) { - assertEval(opName, x, y, null, expected, attr); + assertEval(opName, x, y, null, null, null, expected, attr); } private void assertEval(String opName, Tensor x, Tensor y, Tensor expected) { - assertEval(opName, x, y, null, expected, null); + assertEval(opName, x, y, null, null, null, expected, null); } private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected) { - assertEval(opName, x, y, z, expected, null); + assertEval(opName, x, y, z, null, null, expected, null); } private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected, AttributeConverter attr) { + assertEval(opName, x, y, z, null, null, expected, attr); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor expected) { + assertEval(opName, x, y, z, q, null, expected, null); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected) { + assertEval(opName, x, y, z, q, r, expected, null); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected, AttributeConverter attr) { Context context = new MapContext(DoubleValue.NaN); - List<IntermediateOperation> inputs = createInputs(context, x, y, z); + List<IntermediateOperation> inputs = createInputs(context, x, y, z, q, r); IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build()); optimizeAndRename(opName, op); Tensor result = evaluate(op); @@ -363,11 +502,13 @@ public class OnnxOperationsTestCase { assertEquals(expected.type(), result.type()); } - private List<IntermediateOperation> createInputs(Context context, Tensor x, Tensor y, Tensor z) { + private List<IntermediateOperation> createInputs(Context context, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r) { List<IntermediateOperation> inputs = new ArrayList<>(); addInput(inputs, context, x, "x"); addInput(inputs, context, y, "y"); addInput(inputs, context, z, "z"); + addInput(inputs, context, q, "q"); + addInput(inputs, context, r, "r"); return inputs; } @@ -451,6 +592,16 @@ public class OnnxOperationsTestCase { return this; } + Attributes attr(String name, Tensor tensor) { + Onnx.TensorProto.Builder builder = Onnx.TensorProto.newBuilder(); + builder.setDataType(Onnx.TensorProto.DataType.DOUBLE);; + tensor.type().dimensions().forEach(d -> builder.addDims(d.size().get())); + tensor.valueIterator().forEachRemaining(builder::addDoubleData); + Onnx.TensorProto val = builder.build(); + nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(TENSOR).setT(val).build()); + return this; + } + AttributeConverter build() { return AttributeConverter.convert(nodeBuilder.build()); } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java index d1dea730da5..9631bddd93d 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java @@ -3,8 +3,13 @@ package ai.vespa.rankingexpression.importer.onnx; import ai.vespa.rankingexpression.importer.ImportedModel; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -21,21 +26,48 @@ public class SimpleImportTestCase { ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/simple.onnx"); MapContext context = new MapContext(); - context.put("query_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[1],d1[4])")). - cell(0.1, 0, 0). - cell(0.2, 0, 1). - cell(0.3, 0, 2). - cell(0.4, 0, 3).build())); - context.put("attribute_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[4],d1[1])")). - cell(0.1, 0, 0). - cell(0.2, 1, 0). - cell(0.3, 2, 0). - cell(0.4, 3, 0).build())); - context.put("bias_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[1],d1[1])")). - cell(1.0, 0, 0).build())); + context.put("query_tensor", new TensorValue(Tensor.from("tensor(d0[1],d1[4]):[0.1, 0.2, 0.3, 0.4]"))); + context.put("attribute_tensor", new TensorValue(Tensor.from("tensor(d0[4],d1[1]):[0.1, 0.2, 0.3, 0.4]"))); + context.put("bias_tensor", new TensorValue(Tensor.from("tensor(d0[1],d1[1]):[1.0]"))); Tensor result = model.expressions().get("output").evaluate(context).asTensor(); assertEquals(result, Tensor.from("tensor(d0[1],d1[1]):{{d0:0,d1:0}:1.3}")); } + @Test + public void testGather() { + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx"); + + MapContext context = new MapContext(); + context.put("data", new TensorValue(Tensor.from("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]"))); + context.put("indices", new TensorValue(Tensor.from("tensor(d0[2],d1[2]):[0, 1, 1, 2]"))); + + model.functions().forEach((k, v) -> evaluateFunction(context, model, k)); + + Tensor result = model.expressions().get("y").evaluate(context).asTensor(); + assertEquals(result, Tensor.from("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 3, 4, 5, 6]")); + } + + private void evaluateFunction(Context context, ImportedModel model, String functionName) { + if (!context.names().contains(functionName)) { + RankingExpression e = RankingExpression.from(model.functions().get(functionName)); + evaluateFunctionDependencies(context, model, e.getRoot()); + context.put(functionName, new TensorValue(e.evaluate(context).asTensor())); + } + } + + private void evaluateFunctionDependencies(Context context, ImportedModel model, ExpressionNode node) { + if (node instanceof ReferenceNode) { + String name = node.toString(); + if (model.functions().containsKey(name)) { + evaluateFunction(context, model, name); + } + } + else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode)node).children()) { + evaluateFunctionDependencies(context, model, child); + } + } + } + } diff --git a/model-integration/src/test/models/onnx/simple/gather.onnx b/model-integration/src/test/models/onnx/simple/gather.onnx Binary files differnew file mode 100644 index 00000000000..62451ad953d --- /dev/null +++ b/model-integration/src/test/models/onnx/simple/gather.onnx diff --git a/model-integration/src/test/models/onnx/simple/gather.py b/model-integration/src/test/models/onnx/simple/gather.py new file mode 100755 index 00000000000..63a2103fd86 --- /dev/null +++ b/model-integration/src/test/models/onnx/simple/gather.py @@ -0,0 +1,23 @@ +# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +import onnx +import numpy as np +from onnx import helper, TensorProto + +data_type = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3,2]) +indices_type = helper.make_tensor_value_info('indices', TensorProto.FLOAT, [2,2]) +output_type = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2,2,2]) + +node = onnx.helper.make_node( + 'Gather', + inputs=['data', 'indices'], + outputs=['y'], + axis=0, +) +graph_def = onnx.helper.make_graph( + nodes = [node], + name = 'gather_test', + inputs = [data_type, indices_type], + outputs = [output_type] +) +model_def = helper.make_model(graph_def, producer_name='gather.py') +onnx.save(model_def, 'gather.onnx') |