aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-06-18 16:18:44 +0200
committerLester Solbakken <lesters@oath.com>2020-06-18 16:18:44 +0200
commit5688a50eb92fc4459e51dccca45858aecca8264a (patch)
tree5e374097e6697ef7c3652b4be7851e16e824d398 /model-integration
parentc9c64237f0ee4c117ecafb9ef188ed853a7fa0c8 (diff)
Support additional ONNX operators
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java9
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java83
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java122
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java91
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java86
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java12
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java64
8 files changed, 459 insertions, 11 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index d14ad033a69..a6ce5e40ed3 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -2,11 +2,15 @@
package ai.vespa.rankingexpression.importer.onnx;
+import ai.vespa.rankingexpression.importer.operations.ConstantOfShape;
+import ai.vespa.rankingexpression.importer.operations.Expand;
import ai.vespa.rankingexpression.importer.operations.Gather;
+import ai.vespa.rankingexpression.importer.operations.OnnxConstant;
import ai.vespa.rankingexpression.importer.operations.OnnxCast;
import ai.vespa.rankingexpression.importer.operations.Gemm;
import ai.vespa.rankingexpression.importer.operations.ConcatReduce;
import ai.vespa.rankingexpression.importer.operations.OnnxConcat;
+import ai.vespa.rankingexpression.importer.operations.Range;
import ai.vespa.rankingexpression.importer.operations.Reduce;
import ai.vespa.rankingexpression.importer.operations.Select;
import ai.vespa.rankingexpression.importer.operations.Slice;
@@ -81,11 +85,15 @@ class GraphImporter {
case "cast": return new OnnxCast(modelName, nodeName, inputs, attributes);
case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
case "concat": return new OnnxConcat(modelName, nodeName, inputs, attributes);
+ case "constant": return new OnnxConstant(modelName, nodeName, inputs, attributes);
+ case "constantofshape": return new ConstantOfShape(modelName, nodeName, inputs, attributes);
case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu(attributes.get("alpha").orElse(eluAlpha).asDouble()));
+ case "erf": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); // approximation until we have erf in backend.
case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
+ case "expand": return new Expand(modelName, nodeName, inputs);
case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
case "gather": return new Gather(modelName, nodeName, inputs, attributes);
case "gemm": return new Gemm(modelName, nodeName, inputs, attributes);
@@ -100,6 +108,7 @@ class GraphImporter {
case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow());
+ case "range": return new Range(modelName, nodeName, inputs);
case "reshape": return new Reshape(modelName, nodeName, inputs, attributes);
case "reducel1": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.abs(), null);
case "reducel2": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.square(), ScalarFunctions.sqrt());
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java
new file mode 100644
index 00000000000..887e350b430
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java
@@ -0,0 +1,83 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Optional;
+
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
+public class ConstantOfShape extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+
+ private TensorType.Value valueTypeOfTensor = TensorType.Value.DOUBLE;
+ private double valueToFillWith = 0.0;
+
+
+ public ConstantOfShape(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+
+ Optional<Value> value = attributeMap.get("value");
+ if (value.isPresent()) {
+ Tensor t = value.get().asTensor();
+ valueTypeOfTensor = t.type().valueType();
+ valueToFillWith = t.valueIterator().next();
+ }
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! allInputTypesPresent(1)) return null;
+
+ IntermediateOperation input = inputs.get(0);
+ if (input.getConstantValue().isEmpty()) {
+ throw new IllegalArgumentException("ConstantOfShape: 'shape' input must be a constant.");
+ }
+ Tensor shape = input.getConstantValue().get().asTensor();
+ if (shape.type().dimensions().size() > 1) {
+ throw new IllegalArgumentException("ConstantOfShape: 'shape' input must be a tensor with 0 or 1 dimensions.");
+ }
+
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(valueTypeOfTensor);
+ Iterator<Double> iter = shape.valueIterator();
+ for (int i = 0; iter.hasNext(); i++) {
+ builder.add(TensorType.Dimension.indexed(vespaName() + "_" + i, iter.next().longValue()));
+ }
+ return builder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! allInputTypesPresent(1)) return null;
+ ExpressionNode valueExpr = new ConstantNode(new DoubleValue(valueToFillWith));
+ TensorFunction function = Generate.bound(type.type(), wrapScalar(valueExpr));
+ return function;
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ addConstraintsFrom(type, renamer);
+ }
+
+ @Override
+ public ConstantOfShape withInputs(List<IntermediateOperation> inputs) {
+ return new ConstantOfShape(modelName(), name(), inputs, attributeMap);
+ }
+
+ @Override
+ public String operationName() { return "ConstantOfShape"; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java
new file mode 100644
index 00000000000..30a7bc3bbad
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java
@@ -0,0 +1,122 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Optional;
+
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
+public class Expand extends IntermediateOperation {
+
+ public Expand(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) return null;
+
+ // required as we use tensor create
+ inputs.get(0).exportAsRankingFunction = true;
+
+ Optional<Value> shapeValue = inputs.get(1).getConstantValue();
+ if (shapeValue.isEmpty())
+ throw new IllegalArgumentException("Expand " + name + ": shape must be a constant.");
+
+ Tensor shape = shapeValue.get().asTensor();
+ if (shape.type().rank() != 1)
+ throw new IllegalArgumentException("Expand " + name + ": shape must be a 1-d tensor.");
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+
+ int inputRank = inputType.rank();
+ int shapeSize = shape.type().dimensions().get(0).size().get().intValue();
+ int sizeDiff = shapeSize - inputRank;
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(inputType.type().valueType());
+ Iterator<Double> iter = shape.valueIterator();
+
+ // Add any extra dimensions
+ for (int i = 0; i < sizeDiff; ++i) {
+ typeBuilder.add(TensorType.Dimension.indexed(vespaName() + "_" + i, iter.next().intValue()));
+ }
+
+ // Dimensions are matched innermost
+ for (int i = sizeDiff; i < shapeSize; i++) {
+ int shapeDimSize = iter.next().intValue();
+ int inputDimSize = inputType.dimensions().get(i - sizeDiff).size().get().intValue();
+ if (shapeDimSize != inputDimSize && shapeDimSize != 1 && inputDimSize != 1) {
+ throw new IllegalArgumentException("Expand " + name + ": dimension sizes of input and shape " +
+ "are not compatible. Either they must be equal or one must be of size 1.");
+ }
+ int dimSize = Math.max(shapeDimSize, inputDimSize);
+ typeBuilder.add(TensorType.Dimension.indexed(vespaName() + "_" + i, dimSize));
+ }
+
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(2)) return null;
+
+ IntermediateOperation input = inputs.get(0);
+ OrderedTensorType inputType = input.type().get();
+ OrderedTensorType type = type().get();
+ String inputFunctionName = input.rankingExpressionFunctionName();
+
+ List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>();
+
+ int sizeDiff = type().get().rank() - inputType.rank();
+ for (int i = sizeDiff; i < type().get().rank(); ++i) {
+ String inputDimensionName = inputType.dimensions().get(i - sizeDiff).name();
+ String typeDimensionName = type.dimensionNames().get(i);
+ long inputDimensionSize = inputType.dimensions().get(i - sizeDiff).size().get();
+
+ ExpressionNode index;
+ if (inputDimensionSize == 1) {
+ index = new ConstantNode(new DoubleValue(0.0));
+ } else {
+ index = new EmbracedNode(new ReferenceNode(typeDimensionName));
+ }
+ dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(index)));
+ }
+
+ TensorFunction<Reference> externalRef = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName));
+ com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(externalRef, dimensionValues);
+ ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
+ return Generate.bound(type.type(), wrapScalar(sliceExpression));
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ addConstraintsFrom(type, renamer);
+ }
+
+ @Override
+ public Expand withInputs(List<IntermediateOperation> inputs) {
+ return new Expand(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "Expand"; }
+
+}
+
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
index 2a34ae53d5e..91ff5d9cdd8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
@@ -105,6 +105,9 @@ public class Gather extends IntermediateOperation {
private ExpressionNode createSliceExpression(List<Slice.DimensionValue<Reference>> dimensionValues, String referenceName) {
TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(referenceName));
+ if (dimensionValues.isEmpty()) {
+ return new TensorFunctionNode(inputIndices);
+ }
Slice<Reference> sliceIndices = new Slice<>(inputIndices, dimensionValues);
return new TensorFunctionNode(sliceIndices);
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java
new file mode 100644
index 00000000000..3c5ddf48cfc
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java
@@ -0,0 +1,91 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+public class OnnxConstant extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+ private final Value value;
+
+ public OnnxConstant(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ this.value = value();
+ setConstantValueFunction(type -> new TensorValue(this.value.asTensor()));
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ OrderedTensorType type;
+ if (value instanceof TensorValue) {
+ type = OrderedTensorType.fromSpec(value.type().toString()).rename(vespaName() + "_");
+ } else {
+ type = OrderedTensorType.fromDimensionList(TensorType.Value.DOUBLE, Collections.emptyList());
+ }
+ return type;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null; // will be added by function() since this is constant.
+ }
+
+ @Override
+ public Optional<Value> getConstantValue() {
+ return Optional.of(new TensorValue(value.asTensor().withType(type().get().type())));
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ addConstraintsFrom(type, renamer);
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true;
+ }
+
+ @Override
+ public OnnxConstant withInputs(List<IntermediateOperation> inputs) {
+ return new OnnxConstant(modelName(), name(), inputs, attributeMap);
+ }
+
+ @Override
+ public String operationName() { return "Constant"; }
+
+ @Override
+ public String toString() {
+ return "Constant(" + type + ")";
+ }
+
+ @Override
+ public String toFullString() {
+ return "\t" + type + ":\tConstant(" + type + ")";
+ }
+
+ private Value value() {
+ Optional<Value> value = attributeMap.get("value");
+ if (value.isEmpty()) {
+ value = attributeMap.get("value_float");
+ if (value.isEmpty()) {
+ value = attributeMap.get("value_int");
+ }
+ }
+ if (value.isEmpty()) {
+ throw new IllegalArgumentException("Node '" + name + "' of type " +
+ "constant has missing or non-supported 'value' attribute");
+ }
+ return value.get();
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
new file mode 100644
index 00000000000..6df686cf910
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
@@ -0,0 +1,86 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
+public class Range extends IntermediateOperation {
+
+ private double start;
+ private double limit;
+ private double delta;
+ private long elements;
+
+ public Range(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ }
+
+ private double getConstantInput(int index, String name) {
+ IntermediateOperation input = inputs.get(index);
+ if (input.getConstantValue().isEmpty()) {
+ throw new IllegalArgumentException("Range: " + name + " input must be a constant.");
+ }
+ Tensor value = input.getConstantValue().get().asTensor();
+ if ( ! input.getConstantValue().get().hasDouble()) {
+ throw new IllegalArgumentException("Range: " + name + " input must be a scalar.");
+ }
+ return value.asDouble();
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! allInputTypesPresent(3)) return null;
+
+ start = getConstantInput(0, "start"); // must be constant because we need to know type
+ limit = getConstantInput(1, "limit");
+ delta = getConstantInput(2, "delta");
+ elements = (long) Math.ceil((limit - start) / delta);
+
+ OrderedTensorType type = new OrderedTensorType.Builder()
+ .add(TensorType.Dimension.indexed(vespaName(), elements))
+ .build();
+ return type;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! allInputTypesPresent(3)) return null;
+ String dimensionName = type().get().dimensionNames().get(0);
+ ExpressionNode startExpr = new ConstantNode(new DoubleValue(start));
+ ExpressionNode deltaExpr = new ConstantNode(new DoubleValue(delta));
+ ExpressionNode dimExpr = new EmbracedNode(new ReferenceNode(dimensionName));
+ ExpressionNode stepExpr = new ArithmeticNode(deltaExpr, ArithmeticOperator.MULTIPLY, dimExpr);
+ ExpressionNode addExpr = new ArithmeticNode(startExpr, ArithmeticOperator.PLUS, stepExpr);
+ TensorFunction function = Generate.bound(type.type(), wrapScalar(addExpr));
+ return function;
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ addConstraintsFrom(type, renamer);
+ }
+
+ @Override
+ public Range withInputs(List<IntermediateOperation> inputs) {
+ return new Range(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "Range"; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
index 8696d0f1858..69283f10711 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
@@ -78,14 +78,12 @@ public class Select extends IntermediateOperation {
List<TensorType.Dimension> aDimensions = inputs.get(1).type().get().dimensions();
List<TensorType.Dimension> bDimensions = inputs.get(2).type().get().dimensions();
- String aDim0 = aDimensions.get(0).name();
- String aDim1 = aDimensions.get(1).name();
- String bDim0 = bDimensions.get(0).name();
- String bDim1 = bDimensions.get(1).name();
-
// These tensors should have the same dimension names
- renamer.addConstraint(aDim0, bDim0, DimensionRenamer.Constraint.equal(false), this);
- renamer.addConstraint(aDim1, bDim1, DimensionRenamer.Constraint.equal(false), this);
+ for (int i = 0; i < aDimensions.size(); ++i) {
+ String aDim = aDimensions.get(i).name();
+ String bDim = bDimensions.get(i).name();
+ renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this);
+ }
}
@Override
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
index d5dff7fb1b7..20d1891adb8 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
@@ -13,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
@@ -405,9 +406,14 @@ public class OnnxOperationsTestCase {
@Test
public void testGather1() throws ParseException {
- // 1 dim input, 1 dim indices
+ // 1 dim input, 0 dim indices
Tensor x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
- Tensor y = evaluate("tensor(d0[3]):[0,2,4]");
+ Tensor y = evaluate("tensor():[0]");
+ assertEval("gather", x, y, evaluate("tensor():[1]"));
+
+ // 1 dim input, 1 dim indices
+ x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
+ y = evaluate("tensor(d0[3]):[0,2,4]");
assertEval("gather", x, y, evaluate("tensor(d0[3]):[1,3,5]"));
// 2 dim input, 1 dim indices - axis 0
@@ -533,6 +539,43 @@ public class OnnxOperationsTestCase {
assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[3,6]"), createAttribute("axis", 1), 2);
}
+ @Test
+ public void testRange11() throws ParseException {
+ Tensor start = evaluate("tensor():[3]");
+ Tensor limit = evaluate("tensor():[9]");
+ Tensor delta = evaluate("tensor():[3]");
+ assertEval("range", start, limit, delta, evaluate("tensor(d0[2]):[3,6]"));
+
+ start = evaluate("tensor():[10]");
+ limit = evaluate("tensor():[4]");
+ delta = evaluate("tensor():[-2]");
+ assertEval("range", start, limit, delta, evaluate("tensor(d0[3]):[10,8,6]"));
+ assertEval("range", start, limit, delta, evaluate("tensor(d0[3]):[10,8,6]"));
+ }
+
+ @Test
+ public void testConstant12() throws ParseException {
+ assertEval("constant", evaluate("tensor(d0[3]):[1,2,3]"), createAttribute("value", evaluate("tensor(d0[3]):[1,2,3]")));
+ assertEval("constant", evaluate("tensor<float>():[313.0]"), createAttribute("value_float", 313.0f));
+ assertEval("constant", evaluate("tensor():[42]"), createAttribute("value_int", 42));
+ }
+
+ @Test
+ public void testConstantOfShape9() throws ParseException {
+ Tensor shape = evaluate("tensor(d0[3]):[1,2,3]");
+ assertEval("constantofshape", shape, evaluate("tensor(d0[1],d1[2],d2[3]):[0,0,0,0,0,0]"));
+ assertEval("constantofshape", shape, evaluate("tensor<float>(d0[1],d1[2],d2[3]):[1,1,1,1,1,1]"), createAttribute("value", evaluate("tensor<float>(d0[1]):[1]")));
+ }
+
+ @Test
+ public void testExpand8() throws ParseException {
+ Tensor input = evaluate("tensor(d0[3],d1[1]):[1,2,3]");
+ Tensor shape = evaluate("tensor(d0[2]):[3,4]");
+ assertEval("expand", input, shape, evaluate("tensor(d0[3],d1[4]):[1,1,1,1,2,2,2,2,3,3,3,3]"));
+ shape = evaluate("tensor(d0[3]):[2,1,4]");
+ assertEval("expand", input, shape, evaluate("tensor(d0[2],d1[3],d2[4]):[1,1,1,1,2,2,2,2,3,3,3,3,1,1,1,1,2,2,2,2,3,3,3,3]"));
+ }
+
private Tensor evaluate(String expr) throws ParseException {
return evaluate(expr, null, null, null);
}
@@ -558,6 +601,10 @@ public class OnnxOperationsTestCase {
return renameToStandardType(op, tensor);
}
+ private void assertEval(String opName, Tensor expected, AttributeConverter attr) {
+ assertEval(opName, null, null, null, null, null, expected, attr, 0);
+ }
+
private void assertEval(String opName, Tensor x, Tensor expected) {
assertEval(opName, x, null, null, null, null, expected, null, 0);
}
@@ -667,6 +714,10 @@ public class OnnxOperationsTestCase {
return new Attributes().attr(name, vals).build();
}
+ static AttributeConverter createAttribute(String name, Tensor val) {
+ return new Attributes().attr(name, val).build();
+ }
+
static Attributes createAttributes() {
return new Attributes();
}
@@ -700,9 +751,14 @@ public class OnnxOperationsTestCase {
Attributes attr(String name, Tensor tensor) {
Onnx.TensorProto.Builder builder = Onnx.TensorProto.newBuilder();
- builder.setDataType(Onnx.TensorProto.DataType.DOUBLE);;
tensor.type().dimensions().forEach(d -> builder.addDims(d.size().get()));
- tensor.valueIterator().forEachRemaining(builder::addDoubleData);
+ if (tensor.type().valueType() == TensorType.Value.FLOAT) {
+ builder.setDataType(Onnx.TensorProto.DataType.FLOAT);
+ tensor.valueIterator().forEachRemaining(d -> builder.addFloatData(d.floatValue()));
+ } else {
+ builder.setDataType(Onnx.TensorProto.DataType.DOUBLE);
+ tensor.valueIterator().forEachRemaining(builder::addDoubleData);
+ }
Onnx.TensorProto val = builder.build();
nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(TENSOR).setT(val).build());
return this;