diff options
author | Lester Solbakken <lesters@oath.com> | 2020-02-10 10:33:53 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-02-10 10:33:53 +0100 |
commit | c6ea3aa88e8929c2cbfe90f9c9ffdde482b7adc5 (patch) | |
tree | 550c76a8310c4951a3c5ae4c6e53af889bb9b54c /model-integration/src/test | |
parent | 7b5b53d288ab8b3c9ec8e054d4d5ecf2f88f7ff0 (diff) |
Add gather,slice,cast,unsqueeze onnx operations
Diffstat (limited to 'model-integration/src/test')
4 files changed, 225 insertions, 19 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java index 6954abe5157..94c5577357b 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -17,6 +17,7 @@ import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; import onnx.Onnx; +import org.junit.Ignore; import org.junit.Test; import java.util.ArrayList; @@ -26,7 +27,9 @@ import static ai.vespa.rankingexpression.importer.onnx.GraphImporter.*; import static onnx.Onnx.AttributeProto.AttributeType.FLOAT; import static onnx.Onnx.AttributeProto.AttributeType.INT; import static onnx.Onnx.AttributeProto.AttributeType.INTS; +import static onnx.Onnx.AttributeProto.AttributeType.TENSOR; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; /** * Unit tests for ONNX operators. The number on the test reflects the minimum @@ -294,6 +297,27 @@ public class OnnxOperationsTestCase { } @Test + public void testUnsqueeze1() throws ParseException { + Tensor x = evaluate("tensor(d0[2]):[1, 2]"); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {0})); + assertEval("unsqueeze", x, evaluate("tensor(d0[2],d1[1]):[1, 2]"), createAttribute("axes", new int[] {1})); + assertEval("unsqueeze", x, evaluate("tensor(d0[2],d1[1]):[1, 2]"), createAttribute("axes", new int[] {-1})); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {-2})); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {0,0})); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1]):[1, 2]"), createAttribute("axes", new int[] {0,2})); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1]):[1, 2]"), createAttribute("axes", new int[] {2,0})); + + x = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[1],d2[2],d3[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0,1})); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1],d3[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0,2})); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[3],d3[1]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0,3})); + assertEval("unsqueeze", x, evaluate("tensor(d0[2],d1[1],d2[1],d3[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {1,2})); + assertEval("unsqueeze", x, evaluate("tensor(d0[2],d1[3],d2[1],d3[1]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {2,3})); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1],d3[3],d4[1]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0,2,4})); + assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1],d3[3],d4[1]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {4,2,0})); + } + + @Test public void testWhere9() throws ParseException { Tensor x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]"); Tensor y = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]"); @@ -308,6 +332,109 @@ public class OnnxOperationsTestCase { assertEval("where", evaluate("tensor(d0[1],d1[1]):[1]"), x, y, x); } + @Test + public void testCast1() throws ParseException { + Tensor x = evaluate("tensor(d0[4]):[-1.9, 0.0, 1.1, 2.9]"); + assertEval("cast", x, evaluate("tensor(d0[4]):[1,0,1,1]"), createAttribute("to", 9)); // boolean + assertEval("cast", x, evaluate("tensor(d0[4]):[-1,0,1,2]"), createAttribute("to", 6)); // int32 + assertEval("cast", x, evaluate("tensor(d0[4]):[-1,0,1,2]"), createAttribute("to", 12)); // uint32 + assertEval("cast", x, evaluate("tensor(d0[4]):[-1.9,0,1.1,2.9]"), createAttribute("to", 1)); // float + try { + assertEval("cast", x, evaluate("tensor(d0[4]):[1,0,1,1]"), createAttribute("to", 8)); // string + fail(); + } catch (IllegalArgumentException e) { + assertEquals(e.getMessage(), "OnnxCast in cast: Casting to string is not implemented."); + } + } + + @Test + public void testGather1() throws ParseException { + // 1 dim input, 1 dim indices + Tensor x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); + Tensor y = evaluate("tensor(d0[3]):[0,2,4]"); + assertEval("gather", x, y, evaluate("tensor(d0[3]):[1,3,5]")); + + // 2 dim input, 1 dim indices - axis 0 + x = evaluate("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]"); + y = evaluate("tensor(d0[4]):[2, 1, 0, 2]"); + assertEval("gather", x, y, evaluate("tensor(d0[4],d1[2]):[5, 6, 3, 4, 1, 2, 5, 6]")); + + // 1 dim input, 2 dim indices - axis 0 + x = evaluate("tensor(d0[6]):[1, 2, 3, 4, 5, 6]"); + y = evaluate("tensor(d0[2],d1[2]):[0, 1, 3, 5]"); + assertEval("gather", x, y, evaluate("tensor(d0[2],d1[2]):[1, 2, 4, 6]")); + + // 2 dim input, 2 dim indices - axis 0 + x = evaluate("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]"); + y = evaluate("tensor(d0[2],d1[2]):[0, 1, 1, 2]"); + assertEval("gather", x, y, evaluate("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 3, 4, 5, 6]"), createAttribute("axis", -2)); + + // 2 dim input, 1 dim indices - axis 1 + x = evaluate("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]"); + y = evaluate("tensor(d0[4]):[0, 1, 0, 1]"); + assertEval("gather", x, y, evaluate("tensor(d0[3],d1[4]):[1,2,1,2,3,4,3,4,5,6,5,6]"), createAttribute("axis", 1)); + + // 2 dim input, 2 dim indices - axis 1 + x = evaluate("tensor(d0[3],d1[3]):[1, 2, 3, 4, 5, 6, 7, 8, 9]"); + y = evaluate("tensor(d0[1],d1[2]):[0, 2]"); + assertEval("gather", x, y, evaluate("tensor(d0[3],d1[1],d2[2]):[1,3,4,6,7,9]"), createAttribute("axis", 1)); + + // 1 dim input, 1 dim indices - negative indices + x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); + y = evaluate("tensor(d0[3]):[0,-2,-4]"); + assertEval("gather", x, y, evaluate("tensor(d0[3]):[1,5,3]")); + } + + @Test + public void testSlice1() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[4]):[ [1,2,3,4],[5,6,7,8] ]"); + AttributeConverter attributes = createAttributes(). + attr("starts", new int[] {1, 0}). + attr("ends", new int[] {2, 3}). + attr("axes", new int[] {0, 1}).build(); + assertEval("slice", x, evaluate("tensor(d0[1],d1[3]):[ [5,6,7] ]"), attributes); + + attributes = createAttributes(). + attr("starts", new int[] {0, 1}). + attr("ends", new int[] {-1, 1000}).build(); + assertEval("slice", x, evaluate("tensor(d0[1],d1[3]):[ [2,3,4] ]"), attributes); + + attributes = createAttributes(). + attr("starts", new int[] {0, 1}). + attr("ends", new int[] {3, 2}). + attr("axes", new int[] {1, 0}).build(); // axes are switched + assertEval("slice", x, evaluate("tensor(d0[1],d1[3]):[ [5,6,7] ]"), attributes); + + attributes = createAttributes(). + attr("starts", new int[] {1, 0}). + attr("ends", new int[] {2, 3}). + attr("axes", new int[] {0, -1}).build(); // negative axes + assertEval("slice", x, evaluate("tensor(d0[1],d1[3]):[ [5,6,7] ]"), attributes); + + attributes = createAttributes(). + attr("starts", new int[] {1}). + attr("ends", new int[] {2}). + attr("axes", new int[] {0}).build(); // axis 1 is not specified + assertEval("slice", x, evaluate("tensor(d0[1],d1[4]):[ [5,6,7,8] ]"), attributes); + + attributes = createAttributes(). + attr("starts", new int[] {0}). + attr("ends", new int[] {1}).build(); + assertEval("slice", x, evaluate("tensor(d0[1],d1[4]):[ [1,2,3,4] ]"), attributes); + } + + @Ignore + @Test + public void testSlice10() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[4]):[ [1,2,3,4],[5,6,7,8] ]"); + Tensor starts = evaluate("tensor(d0[2]):[1,0]"); + Tensor ends = evaluate("tensor(d0[2]):[2,3]"); + Tensor axes = evaluate("tensor(d0[2]):[0,1]"); + Tensor steps = evaluate("tensor(d0[2]):[1,2]"); + assertEval("slice", x, starts, ends, axes, steps, evaluate("tensor(d0[1],d1[2]):[ [5,7] ]")); + + } + private Tensor evaluate(String expr) throws ParseException { return evaluate(expr, null, null, null); } @@ -334,28 +461,40 @@ public class OnnxOperationsTestCase { } private void assertEval(String opName, Tensor x, Tensor expected) { - assertEval(opName, x, null, null, expected, null); + assertEval(opName, x, null, null, null, null, expected, null); } private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr) { - assertEval(opName, x, null, null, expected, attr); + assertEval(opName, x, null, null, null, null, expected, attr); } private void assertEval(String opName, Tensor x, Tensor y, Tensor expected, AttributeConverter attr) { - assertEval(opName, x, y, null, expected, attr); + assertEval(opName, x, y, null, null, null, expected, attr); } private void assertEval(String opName, Tensor x, Tensor y, Tensor expected) { - assertEval(opName, x, y, null, expected, null); + assertEval(opName, x, y, null, null, null, expected, null); } private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected) { - assertEval(opName, x, y, z, expected, null); + assertEval(opName, x, y, z, null, null, expected, null); } private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected, AttributeConverter attr) { + assertEval(opName, x, y, z, null, null, expected, attr); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor expected) { + assertEval(opName, x, y, z, q, null, expected, null); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected) { + assertEval(opName, x, y, z, q, r, expected, null); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected, AttributeConverter attr) { Context context = new MapContext(DoubleValue.NaN); - List<IntermediateOperation> inputs = createInputs(context, x, y, z); + List<IntermediateOperation> inputs = createInputs(context, x, y, z, q, r); IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build()); optimizeAndRename(opName, op); Tensor result = evaluate(op); @@ -363,11 +502,13 @@ public class OnnxOperationsTestCase { assertEquals(expected.type(), result.type()); } - private List<IntermediateOperation> createInputs(Context context, Tensor x, Tensor y, Tensor z) { + private List<IntermediateOperation> createInputs(Context context, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r) { List<IntermediateOperation> inputs = new ArrayList<>(); addInput(inputs, context, x, "x"); addInput(inputs, context, y, "y"); addInput(inputs, context, z, "z"); + addInput(inputs, context, q, "q"); + addInput(inputs, context, r, "r"); return inputs; } @@ -451,6 +592,16 @@ public class OnnxOperationsTestCase { return this; } + Attributes attr(String name, Tensor tensor) { + Onnx.TensorProto.Builder builder = Onnx.TensorProto.newBuilder(); + builder.setDataType(Onnx.TensorProto.DataType.DOUBLE);; + tensor.type().dimensions().forEach(d -> builder.addDims(d.size().get())); + tensor.valueIterator().forEachRemaining(builder::addDoubleData); + Onnx.TensorProto val = builder.build(); + nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(TENSOR).setT(val).build()); + return this; + } + AttributeConverter build() { return AttributeConverter.convert(nodeBuilder.build()); } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java index d1dea730da5..9631bddd93d 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java @@ -3,8 +3,13 @@ package ai.vespa.rankingexpression.importer.onnx; import ai.vespa.rankingexpression.importer.ImportedModel; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -21,21 +26,48 @@ public class SimpleImportTestCase { ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/simple.onnx"); MapContext context = new MapContext(); - context.put("query_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[1],d1[4])")). - cell(0.1, 0, 0). - cell(0.2, 0, 1). - cell(0.3, 0, 2). - cell(0.4, 0, 3).build())); - context.put("attribute_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[4],d1[1])")). - cell(0.1, 0, 0). - cell(0.2, 1, 0). - cell(0.3, 2, 0). - cell(0.4, 3, 0).build())); - context.put("bias_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[1],d1[1])")). - cell(1.0, 0, 0).build())); + context.put("query_tensor", new TensorValue(Tensor.from("tensor(d0[1],d1[4]):[0.1, 0.2, 0.3, 0.4]"))); + context.put("attribute_tensor", new TensorValue(Tensor.from("tensor(d0[4],d1[1]):[0.1, 0.2, 0.3, 0.4]"))); + context.put("bias_tensor", new TensorValue(Tensor.from("tensor(d0[1],d1[1]):[1.0]"))); Tensor result = model.expressions().get("output").evaluate(context).asTensor(); assertEquals(result, Tensor.from("tensor(d0[1],d1[1]):{{d0:0,d1:0}:1.3}")); } + @Test + public void testGather() { + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx"); + + MapContext context = new MapContext(); + context.put("data", new TensorValue(Tensor.from("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]"))); + context.put("indices", new TensorValue(Tensor.from("tensor(d0[2],d1[2]):[0, 1, 1, 2]"))); + + model.functions().forEach((k, v) -> evaluateFunction(context, model, k)); + + Tensor result = model.expressions().get("y").evaluate(context).asTensor(); + assertEquals(result, Tensor.from("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 3, 4, 5, 6]")); + } + + private void evaluateFunction(Context context, ImportedModel model, String functionName) { + if (!context.names().contains(functionName)) { + RankingExpression e = RankingExpression.from(model.functions().get(functionName)); + evaluateFunctionDependencies(context, model, e.getRoot()); + context.put(functionName, new TensorValue(e.evaluate(context).asTensor())); + } + } + + private void evaluateFunctionDependencies(Context context, ImportedModel model, ExpressionNode node) { + if (node instanceof ReferenceNode) { + String name = node.toString(); + if (model.functions().containsKey(name)) { + evaluateFunction(context, model, name); + } + } + else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode)node).children()) { + evaluateFunctionDependencies(context, model, child); + } + } + } + } diff --git a/model-integration/src/test/models/onnx/simple/gather.onnx b/model-integration/src/test/models/onnx/simple/gather.onnx Binary files differnew file mode 100644 index 00000000000..62451ad953d --- /dev/null +++ b/model-integration/src/test/models/onnx/simple/gather.onnx diff --git a/model-integration/src/test/models/onnx/simple/gather.py b/model-integration/src/test/models/onnx/simple/gather.py new file mode 100755 index 00000000000..63a2103fd86 --- /dev/null +++ b/model-integration/src/test/models/onnx/simple/gather.py @@ -0,0 +1,23 @@ +# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +import onnx +import numpy as np +from onnx import helper, TensorProto + +data_type = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3,2]) +indices_type = helper.make_tensor_value_info('indices', TensorProto.FLOAT, [2,2]) +output_type = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2,2,2]) + +node = onnx.helper.make_node( + 'Gather', + inputs=['data', 'indices'], + outputs=['y'], + axis=0, +) +graph_def = onnx.helper.make_graph( + nodes = [node], + name = 'gather_test', + inputs = [data_type, indices_type], + outputs = [output_type] +) +model_def = helper.make_model(graph_def, producer_name='gather.py') +onnx.save(model_def, 'gather.onnx') |