diff options
author | Lester Solbakken <lesters@oath.com> | 2020-06-18 16:18:44 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-06-18 16:18:44 +0200 |
commit | 5688a50eb92fc4459e51dccca45858aecca8264a (patch) | |
tree | 5e374097e6697ef7c3652b4be7851e16e824d398 /model-integration | |
parent | c9c64237f0ee4c117ecafb9ef188ed853a7fa0c8 (diff) |
Support additional ONNX operators
Diffstat (limited to 'model-integration')
8 files changed, 459 insertions, 11 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index d14ad033a69..a6ce5e40ed3 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -2,11 +2,15 @@ package ai.vespa.rankingexpression.importer.onnx; +import ai.vespa.rankingexpression.importer.operations.ConstantOfShape; +import ai.vespa.rankingexpression.importer.operations.Expand; import ai.vespa.rankingexpression.importer.operations.Gather; +import ai.vespa.rankingexpression.importer.operations.OnnxConstant; import ai.vespa.rankingexpression.importer.operations.OnnxCast; import ai.vespa.rankingexpression.importer.operations.Gemm; import ai.vespa.rankingexpression.importer.operations.ConcatReduce; import ai.vespa.rankingexpression.importer.operations.OnnxConcat; +import ai.vespa.rankingexpression.importer.operations.Range; import ai.vespa.rankingexpression.importer.operations.Reduce; import ai.vespa.rankingexpression.importer.operations.Select; import ai.vespa.rankingexpression.importer.operations.Slice; @@ -81,11 +85,15 @@ class GraphImporter { case "cast": return new OnnxCast(modelName, nodeName, inputs, attributes); case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil()); case "concat": return new OnnxConcat(modelName, nodeName, inputs, attributes); + case "constant": return new OnnxConstant(modelName, nodeName, inputs, attributes); + case "constantofshape": return new ConstantOfShape(modelName, nodeName, inputs, attributes); case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu(attributes.get("alpha").orElse(eluAlpha).asDouble())); + case "erf": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); // approximation until we have erf in backend. case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal()); case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); + case "expand": return new Expand(modelName, nodeName, inputs); case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); case "gather": return new Gather(modelName, nodeName, inputs, attributes); case "gemm": return new Gemm(modelName, nodeName, inputs, attributes); @@ -100,6 +108,7 @@ class GraphImporter { case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg()); case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow()); + case "range": return new Range(modelName, nodeName, inputs); case "reshape": return new Reshape(modelName, nodeName, inputs, attributes); case "reducel1": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.abs(), null); case "reducel2": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.square(), ScalarFunctions.sqrt()); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java new file mode 100644 index 00000000000..887e350b430 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java @@ -0,0 +1,83 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +public class ConstantOfShape extends IntermediateOperation { + + private final AttributeMap attributeMap; + + private TensorType.Value valueTypeOfTensor = TensorType.Value.DOUBLE; + private double valueToFillWith = 0.0; + + + public ConstantOfShape(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + + Optional<Value> value = attributeMap.get("value"); + if (value.isPresent()) { + Tensor t = value.get().asTensor(); + valueTypeOfTensor = t.type().valueType(); + valueToFillWith = t.valueIterator().next(); + } + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(1)) return null; + + IntermediateOperation input = inputs.get(0); + if (input.getConstantValue().isEmpty()) { + throw new IllegalArgumentException("ConstantOfShape: 'shape' input must be a constant."); + } + Tensor shape = input.getConstantValue().get().asTensor(); + if (shape.type().dimensions().size() > 1) { + throw new IllegalArgumentException("ConstantOfShape: 'shape' input must be a tensor with 0 or 1 dimensions."); + } + + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(valueTypeOfTensor); + Iterator<Double> iter = shape.valueIterator(); + for (int i = 0; iter.hasNext(); i++) { + builder.add(TensorType.Dimension.indexed(vespaName() + "_" + i, iter.next().longValue())); + } + return builder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputTypesPresent(1)) return null; + ExpressionNode valueExpr = new ConstantNode(new DoubleValue(valueToFillWith)); + TensorFunction function = Generate.bound(type.type(), wrapScalar(valueExpr)); + return function; + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + addConstraintsFrom(type, renamer); + } + + @Override + public ConstantOfShape withInputs(List<IntermediateOperation> inputs) { + return new ConstantOfShape(modelName(), name(), inputs, attributeMap); + } + + @Override + public String operationName() { return "ConstantOfShape"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java new file mode 100644 index 00000000000..30a7bc3bbad --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java @@ -0,0 +1,122 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +public class Expand extends IntermediateOperation { + + public Expand(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) return null; + + // required as we use tensor create + inputs.get(0).exportAsRankingFunction = true; + + Optional<Value> shapeValue = inputs.get(1).getConstantValue(); + if (shapeValue.isEmpty()) + throw new IllegalArgumentException("Expand " + name + ": shape must be a constant."); + + Tensor shape = shapeValue.get().asTensor(); + if (shape.type().rank() != 1) + throw new IllegalArgumentException("Expand " + name + ": shape must be a 1-d tensor."); + + OrderedTensorType inputType = inputs.get(0).type().get(); + + int inputRank = inputType.rank(); + int shapeSize = shape.type().dimensions().get(0).size().get().intValue(); + int sizeDiff = shapeSize - inputRank; + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(inputType.type().valueType()); + Iterator<Double> iter = shape.valueIterator(); + + // Add any extra dimensions + for (int i = 0; i < sizeDiff; ++i) { + typeBuilder.add(TensorType.Dimension.indexed(vespaName() + "_" + i, iter.next().intValue())); + } + + // Dimensions are matched innermost + for (int i = sizeDiff; i < shapeSize; i++) { + int shapeDimSize = iter.next().intValue(); + int inputDimSize = inputType.dimensions().get(i - sizeDiff).size().get().intValue(); + if (shapeDimSize != inputDimSize && shapeDimSize != 1 && inputDimSize != 1) { + throw new IllegalArgumentException("Expand " + name + ": dimension sizes of input and shape " + + "are not compatible. Either they must be equal or one must be of size 1."); + } + int dimSize = Math.max(shapeDimSize, inputDimSize); + typeBuilder.add(TensorType.Dimension.indexed(vespaName() + "_" + i, dimSize)); + } + + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(2)) return null; + + IntermediateOperation input = inputs.get(0); + OrderedTensorType inputType = input.type().get(); + OrderedTensorType type = type().get(); + String inputFunctionName = input.rankingExpressionFunctionName(); + + List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>(); + + int sizeDiff = type().get().rank() - inputType.rank(); + for (int i = sizeDiff; i < type().get().rank(); ++i) { + String inputDimensionName = inputType.dimensions().get(i - sizeDiff).name(); + String typeDimensionName = type.dimensionNames().get(i); + long inputDimensionSize = inputType.dimensions().get(i - sizeDiff).size().get(); + + ExpressionNode index; + if (inputDimensionSize == 1) { + index = new ConstantNode(new DoubleValue(0.0)); + } else { + index = new EmbracedNode(new ReferenceNode(typeDimensionName)); + } + dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(index))); + } + + TensorFunction<Reference> externalRef = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName)); + com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(externalRef, dimensionValues); + ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); + return Generate.bound(type.type(), wrapScalar(sliceExpression)); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + addConstraintsFrom(type, renamer); + } + + @Override + public Expand withInputs(List<IntermediateOperation> inputs) { + return new Expand(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "Expand"; } + +} + diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java index 2a34ae53d5e..91ff5d9cdd8 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java @@ -105,6 +105,9 @@ public class Gather extends IntermediateOperation { private ExpressionNode createSliceExpression(List<Slice.DimensionValue<Reference>> dimensionValues, String referenceName) { TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(referenceName)); + if (dimensionValues.isEmpty()) { + return new TensorFunctionNode(inputIndices); + } Slice<Reference> sliceIndices = new Slice<>(inputIndices, dimensionValues); return new TensorFunctionNode(sliceIndices); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java new file mode 100644 index 00000000000..3c5ddf48cfc --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java @@ -0,0 +1,91 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +public class OnnxConstant extends IntermediateOperation { + + private final AttributeMap attributeMap; + private final Value value; + + public OnnxConstant(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + this.value = value(); + setConstantValueFunction(type -> new TensorValue(this.value.asTensor())); + } + + @Override + protected OrderedTensorType lazyGetType() { + OrderedTensorType type; + if (value instanceof TensorValue) { + type = OrderedTensorType.fromSpec(value.type().toString()).rename(vespaName() + "_"); + } else { + type = OrderedTensorType.fromDimensionList(TensorType.Value.DOUBLE, Collections.emptyList()); + } + return type; + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; // will be added by function() since this is constant. + } + + @Override + public Optional<Value> getConstantValue() { + return Optional.of(new TensorValue(value.asTensor().withType(type().get().type()))); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + addConstraintsFrom(type, renamer); + } + + @Override + public boolean isConstant() { + return true; + } + + @Override + public OnnxConstant withInputs(List<IntermediateOperation> inputs) { + return new OnnxConstant(modelName(), name(), inputs, attributeMap); + } + + @Override + public String operationName() { return "Constant"; } + + @Override + public String toString() { + return "Constant(" + type + ")"; + } + + @Override + public String toFullString() { + return "\t" + type + ":\tConstant(" + type + ")"; + } + + private Value value() { + Optional<Value> value = attributeMap.get("value"); + if (value.isEmpty()) { + value = attributeMap.get("value_float"); + if (value.isEmpty()) { + value = attributeMap.get("value_int"); + } + } + if (value.isEmpty()) { + throw new IllegalArgumentException("Node '" + name + "' of type " + + "constant has missing or non-supported 'value' attribute"); + } + return value.get(); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java new file mode 100644 index 00000000000..6df686cf910 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java @@ -0,0 +1,86 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +public class Range extends IntermediateOperation { + + private double start; + private double limit; + private double delta; + private long elements; + + public Range(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + } + + private double getConstantInput(int index, String name) { + IntermediateOperation input = inputs.get(index); + if (input.getConstantValue().isEmpty()) { + throw new IllegalArgumentException("Range: " + name + " input must be a constant."); + } + Tensor value = input.getConstantValue().get().asTensor(); + if ( ! input.getConstantValue().get().hasDouble()) { + throw new IllegalArgumentException("Range: " + name + " input must be a scalar."); + } + return value.asDouble(); + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(3)) return null; + + start = getConstantInput(0, "start"); // must be constant because we need to know type + limit = getConstantInput(1, "limit"); + delta = getConstantInput(2, "delta"); + elements = (long) Math.ceil((limit - start) / delta); + + OrderedTensorType type = new OrderedTensorType.Builder() + .add(TensorType.Dimension.indexed(vespaName(), elements)) + .build(); + return type; + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputTypesPresent(3)) return null; + String dimensionName = type().get().dimensionNames().get(0); + ExpressionNode startExpr = new ConstantNode(new DoubleValue(start)); + ExpressionNode deltaExpr = new ConstantNode(new DoubleValue(delta)); + ExpressionNode dimExpr = new EmbracedNode(new ReferenceNode(dimensionName)); + ExpressionNode stepExpr = new ArithmeticNode(deltaExpr, ArithmeticOperator.MULTIPLY, dimExpr); + ExpressionNode addExpr = new ArithmeticNode(startExpr, ArithmeticOperator.PLUS, stepExpr); + TensorFunction function = Generate.bound(type.type(), wrapScalar(addExpr)); + return function; + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + addConstraintsFrom(type, renamer); + } + + @Override + public Range withInputs(List<IntermediateOperation> inputs) { + return new Range(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "Range"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java index 8696d0f1858..69283f10711 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java @@ -78,14 +78,12 @@ public class Select extends IntermediateOperation { List<TensorType.Dimension> aDimensions = inputs.get(1).type().get().dimensions(); List<TensorType.Dimension> bDimensions = inputs.get(2).type().get().dimensions(); - String aDim0 = aDimensions.get(0).name(); - String aDim1 = aDimensions.get(1).name(); - String bDim0 = bDimensions.get(0).name(); - String bDim1 = bDimensions.get(1).name(); - // These tensors should have the same dimension names - renamer.addConstraint(aDim0, bDim0, DimensionRenamer.Constraint.equal(false), this); - renamer.addConstraint(aDim1, bDim1, DimensionRenamer.Constraint.equal(false), this); + for (int i = 0; i < aDimensions.size(); ++i) { + String aDim = aDimensions.get(i).name(); + String bDim = bDimensions.get(i).name(); + renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this); + } } @Override diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java index d5dff7fb1b7..20d1891adb8 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -13,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; @@ -405,9 +406,14 @@ public class OnnxOperationsTestCase { @Test public void testGather1() throws ParseException { - // 1 dim input, 1 dim indices + // 1 dim input, 0 dim indices Tensor x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); - Tensor y = evaluate("tensor(d0[3]):[0,2,4]"); + Tensor y = evaluate("tensor():[0]"); + assertEval("gather", x, y, evaluate("tensor():[1]")); + + // 1 dim input, 1 dim indices + x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); + y = evaluate("tensor(d0[3]):[0,2,4]"); assertEval("gather", x, y, evaluate("tensor(d0[3]):[1,3,5]")); // 2 dim input, 1 dim indices - axis 0 @@ -533,6 +539,43 @@ public class OnnxOperationsTestCase { assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[3,6]"), createAttribute("axis", 1), 2); } + @Test + public void testRange11() throws ParseException { + Tensor start = evaluate("tensor():[3]"); + Tensor limit = evaluate("tensor():[9]"); + Tensor delta = evaluate("tensor():[3]"); + assertEval("range", start, limit, delta, evaluate("tensor(d0[2]):[3,6]")); + + start = evaluate("tensor():[10]"); + limit = evaluate("tensor():[4]"); + delta = evaluate("tensor():[-2]"); + assertEval("range", start, limit, delta, evaluate("tensor(d0[3]):[10,8,6]")); + assertEval("range", start, limit, delta, evaluate("tensor(d0[3]):[10,8,6]")); + } + + @Test + public void testConstant12() throws ParseException { + assertEval("constant", evaluate("tensor(d0[3]):[1,2,3]"), createAttribute("value", evaluate("tensor(d0[3]):[1,2,3]"))); + assertEval("constant", evaluate("tensor<float>():[313.0]"), createAttribute("value_float", 313.0f)); + assertEval("constant", evaluate("tensor():[42]"), createAttribute("value_int", 42)); + } + + @Test + public void testConstantOfShape9() throws ParseException { + Tensor shape = evaluate("tensor(d0[3]):[1,2,3]"); + assertEval("constantofshape", shape, evaluate("tensor(d0[1],d1[2],d2[3]):[0,0,0,0,0,0]")); + assertEval("constantofshape", shape, evaluate("tensor<float>(d0[1],d1[2],d2[3]):[1,1,1,1,1,1]"), createAttribute("value", evaluate("tensor<float>(d0[1]):[1]"))); + } + + @Test + public void testExpand8() throws ParseException { + Tensor input = evaluate("tensor(d0[3],d1[1]):[1,2,3]"); + Tensor shape = evaluate("tensor(d0[2]):[3,4]"); + assertEval("expand", input, shape, evaluate("tensor(d0[3],d1[4]):[1,1,1,1,2,2,2,2,3,3,3,3]")); + shape = evaluate("tensor(d0[3]):[2,1,4]"); + assertEval("expand", input, shape, evaluate("tensor(d0[2],d1[3],d2[4]):[1,1,1,1,2,2,2,2,3,3,3,3,1,1,1,1,2,2,2,2,3,3,3,3]")); + } + private Tensor evaluate(String expr) throws ParseException { return evaluate(expr, null, null, null); } @@ -558,6 +601,10 @@ public class OnnxOperationsTestCase { return renameToStandardType(op, tensor); } + private void assertEval(String opName, Tensor expected, AttributeConverter attr) { + assertEval(opName, null, null, null, null, null, expected, attr, 0); + } + private void assertEval(String opName, Tensor x, Tensor expected) { assertEval(opName, x, null, null, null, null, expected, null, 0); } @@ -667,6 +714,10 @@ public class OnnxOperationsTestCase { return new Attributes().attr(name, vals).build(); } + static AttributeConverter createAttribute(String name, Tensor val) { + return new Attributes().attr(name, val).build(); + } + static Attributes createAttributes() { return new Attributes(); } @@ -700,9 +751,14 @@ public class OnnxOperationsTestCase { Attributes attr(String name, Tensor tensor) { Onnx.TensorProto.Builder builder = Onnx.TensorProto.newBuilder(); - builder.setDataType(Onnx.TensorProto.DataType.DOUBLE);; tensor.type().dimensions().forEach(d -> builder.addDims(d.size().get())); - tensor.valueIterator().forEachRemaining(builder::addDoubleData); + if (tensor.type().valueType() == TensorType.Value.FLOAT) { + builder.setDataType(Onnx.TensorProto.DataType.FLOAT); + tensor.valueIterator().forEachRemaining(d -> builder.addFloatData(d.floatValue())); + } else { + builder.setDataType(Onnx.TensorProto.DataType.DOUBLE); + tensor.valueIterator().forEachRemaining(builder::addDoubleData); + } Onnx.TensorProto val = builder.build(); nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(TENSOR).setT(val).build()); return this; |