summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-02-10 10:33:53 +0100
committerLester Solbakken <lesters@oath.com>2020-02-10 10:33:53 +0100
commitc6ea3aa88e8929c2cbfe90f9c9ffdde482b7adc5 (patch)
tree550c76a8310c4951a3c5ae4c6e53af889bb9b54c /model-integration/src/test/java
parent7b5b53d288ab8b3c9ec8e054d4d5ecf2f88f7ff0 (diff)
Add gather,slice,cast,unsqueeze onnx operations
Diffstat (limited to 'model-integration/src/test/java')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java165
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java56
2 files changed, 202 insertions, 19 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
index 6954abe5157..94c5577357b 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
@@ -17,6 +17,7 @@ import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
import onnx.Onnx;
+import org.junit.Ignore;
import org.junit.Test;
import java.util.ArrayList;
@@ -26,7 +27,9 @@ import static ai.vespa.rankingexpression.importer.onnx.GraphImporter.*;
import static onnx.Onnx.AttributeProto.AttributeType.FLOAT;
import static onnx.Onnx.AttributeProto.AttributeType.INT;
import static onnx.Onnx.AttributeProto.AttributeType.INTS;
+import static onnx.Onnx.AttributeProto.AttributeType.TENSOR;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
/**
* Unit tests for ONNX operators. The number on the test reflects the minimum
@@ -294,6 +297,27 @@ public class OnnxOperationsTestCase {
}
@Test
+ public void testUnsqueeze1() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2]):[1, 2]");
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {0}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[2],d1[1]):[1, 2]"), createAttribute("axes", new int[] {1}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[2],d1[1]):[1, 2]"), createAttribute("axes", new int[] {-1}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {-2}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {0,0}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1]):[1, 2]"), createAttribute("axes", new int[] {0,2}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1]):[1, 2]"), createAttribute("axes", new int[] {2,0}));
+
+ x = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[1],d2[2],d3[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0,1}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1],d3[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0,2}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[3],d3[1]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0,3}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[2],d1[1],d2[1],d3[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {1,2}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[2],d1[3],d2[1],d3[1]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {2,3}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1],d3[3],d4[1]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0,2,4}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1],d3[3],d4[1]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {4,2,0}));
+ }
+
+ @Test
public void testWhere9() throws ParseException {
Tensor x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]");
Tensor y = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]");
@@ -308,6 +332,109 @@ public class OnnxOperationsTestCase {
assertEval("where", evaluate("tensor(d0[1],d1[1]):[1]"), x, y, x);
}
+ @Test
+ public void testCast1() throws ParseException {
+ Tensor x = evaluate("tensor(d0[4]):[-1.9, 0.0, 1.1, 2.9]");
+ assertEval("cast", x, evaluate("tensor(d0[4]):[1,0,1,1]"), createAttribute("to", 9)); // boolean
+ assertEval("cast", x, evaluate("tensor(d0[4]):[-1,0,1,2]"), createAttribute("to", 6)); // int32
+ assertEval("cast", x, evaluate("tensor(d0[4]):[-1,0,1,2]"), createAttribute("to", 12)); // uint32
+ assertEval("cast", x, evaluate("tensor(d0[4]):[-1.9,0,1.1,2.9]"), createAttribute("to", 1)); // float
+ try {
+ assertEval("cast", x, evaluate("tensor(d0[4]):[1,0,1,1]"), createAttribute("to", 8)); // string
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertEquals(e.getMessage(), "OnnxCast in cast: Casting to string is not implemented.");
+ }
+ }
+
+ @Test
+ public void testGather1() throws ParseException {
+ // 1 dim input, 1 dim indices
+ Tensor x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
+ Tensor y = evaluate("tensor(d0[3]):[0,2,4]");
+ assertEval("gather", x, y, evaluate("tensor(d0[3]):[1,3,5]"));
+
+ // 2 dim input, 1 dim indices - axis 0
+ x = evaluate("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]");
+ y = evaluate("tensor(d0[4]):[2, 1, 0, 2]");
+ assertEval("gather", x, y, evaluate("tensor(d0[4],d1[2]):[5, 6, 3, 4, 1, 2, 5, 6]"));
+
+ // 1 dim input, 2 dim indices - axis 0
+ x = evaluate("tensor(d0[6]):[1, 2, 3, 4, 5, 6]");
+ y = evaluate("tensor(d0[2],d1[2]):[0, 1, 3, 5]");
+ assertEval("gather", x, y, evaluate("tensor(d0[2],d1[2]):[1, 2, 4, 6]"));
+
+ // 2 dim input, 2 dim indices - axis 0
+ x = evaluate("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]");
+ y = evaluate("tensor(d0[2],d1[2]):[0, 1, 1, 2]");
+ assertEval("gather", x, y, evaluate("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 3, 4, 5, 6]"), createAttribute("axis", -2));
+
+ // 2 dim input, 1 dim indices - axis 1
+ x = evaluate("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]");
+ y = evaluate("tensor(d0[4]):[0, 1, 0, 1]");
+ assertEval("gather", x, y, evaluate("tensor(d0[3],d1[4]):[1,2,1,2,3,4,3,4,5,6,5,6]"), createAttribute("axis", 1));
+
+ // 2 dim input, 2 dim indices - axis 1
+ x = evaluate("tensor(d0[3],d1[3]):[1, 2, 3, 4, 5, 6, 7, 8, 9]");
+ y = evaluate("tensor(d0[1],d1[2]):[0, 2]");
+ assertEval("gather", x, y, evaluate("tensor(d0[3],d1[1],d2[2]):[1,3,4,6,7,9]"), createAttribute("axis", 1));
+
+ // 1 dim input, 1 dim indices - negative indices
+ x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
+ y = evaluate("tensor(d0[3]):[0,-2,-4]");
+ assertEval("gather", x, y, evaluate("tensor(d0[3]):[1,5,3]"));
+ }
+
+ @Test
+ public void testSlice1() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2],d1[4]):[ [1,2,3,4],[5,6,7,8] ]");
+ AttributeConverter attributes = createAttributes().
+ attr("starts", new int[] {1, 0}).
+ attr("ends", new int[] {2, 3}).
+ attr("axes", new int[] {0, 1}).build();
+ assertEval("slice", x, evaluate("tensor(d0[1],d1[3]):[ [5,6,7] ]"), attributes);
+
+ attributes = createAttributes().
+ attr("starts", new int[] {0, 1}).
+ attr("ends", new int[] {-1, 1000}).build();
+ assertEval("slice", x, evaluate("tensor(d0[1],d1[3]):[ [2,3,4] ]"), attributes);
+
+ attributes = createAttributes().
+ attr("starts", new int[] {0, 1}).
+ attr("ends", new int[] {3, 2}).
+ attr("axes", new int[] {1, 0}).build(); // axes are switched
+ assertEval("slice", x, evaluate("tensor(d0[1],d1[3]):[ [5,6,7] ]"), attributes);
+
+ attributes = createAttributes().
+ attr("starts", new int[] {1, 0}).
+ attr("ends", new int[] {2, 3}).
+ attr("axes", new int[] {0, -1}).build(); // negative axes
+ assertEval("slice", x, evaluate("tensor(d0[1],d1[3]):[ [5,6,7] ]"), attributes);
+
+ attributes = createAttributes().
+ attr("starts", new int[] {1}).
+ attr("ends", new int[] {2}).
+ attr("axes", new int[] {0}).build(); // axis 1 is not specified
+ assertEval("slice", x, evaluate("tensor(d0[1],d1[4]):[ [5,6,7,8] ]"), attributes);
+
+ attributes = createAttributes().
+ attr("starts", new int[] {0}).
+ attr("ends", new int[] {1}).build();
+ assertEval("slice", x, evaluate("tensor(d0[1],d1[4]):[ [1,2,3,4] ]"), attributes);
+ }
+
+ @Ignore
+ @Test
+ public void testSlice10() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2],d1[4]):[ [1,2,3,4],[5,6,7,8] ]");
+ Tensor starts = evaluate("tensor(d0[2]):[1,0]");
+ Tensor ends = evaluate("tensor(d0[2]):[2,3]");
+ Tensor axes = evaluate("tensor(d0[2]):[0,1]");
+ Tensor steps = evaluate("tensor(d0[2]):[1,2]");
+ assertEval("slice", x, starts, ends, axes, steps, evaluate("tensor(d0[1],d1[2]):[ [5,7] ]"));
+
+ }
+
private Tensor evaluate(String expr) throws ParseException {
return evaluate(expr, null, null, null);
}
@@ -334,28 +461,40 @@ public class OnnxOperationsTestCase {
}
private void assertEval(String opName, Tensor x, Tensor expected) {
- assertEval(opName, x, null, null, expected, null);
+ assertEval(opName, x, null, null, null, null, expected, null);
}
private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr) {
- assertEval(opName, x, null, null, expected, attr);
+ assertEval(opName, x, null, null, null, null, expected, attr);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor expected, AttributeConverter attr) {
- assertEval(opName, x, y, null, expected, attr);
+ assertEval(opName, x, y, null, null, null, expected, attr);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor expected) {
- assertEval(opName, x, y, null, expected, null);
+ assertEval(opName, x, y, null, null, null, expected, null);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected) {
- assertEval(opName, x, y, z, expected, null);
+ assertEval(opName, x, y, z, null, null, expected, null);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected, AttributeConverter attr) {
+ assertEval(opName, x, y, z, null, null, expected, attr);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor expected) {
+ assertEval(opName, x, y, z, q, null, expected, null);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected) {
+ assertEval(opName, x, y, z, q, r, expected, null);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected, AttributeConverter attr) {
Context context = new MapContext(DoubleValue.NaN);
- List<IntermediateOperation> inputs = createInputs(context, x, y, z);
+ List<IntermediateOperation> inputs = createInputs(context, x, y, z, q, r);
IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build());
optimizeAndRename(opName, op);
Tensor result = evaluate(op);
@@ -363,11 +502,13 @@ public class OnnxOperationsTestCase {
assertEquals(expected.type(), result.type());
}
- private List<IntermediateOperation> createInputs(Context context, Tensor x, Tensor y, Tensor z) {
+ private List<IntermediateOperation> createInputs(Context context, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r) {
List<IntermediateOperation> inputs = new ArrayList<>();
addInput(inputs, context, x, "x");
addInput(inputs, context, y, "y");
addInput(inputs, context, z, "z");
+ addInput(inputs, context, q, "q");
+ addInput(inputs, context, r, "r");
return inputs;
}
@@ -451,6 +592,16 @@ public class OnnxOperationsTestCase {
return this;
}
+ Attributes attr(String name, Tensor tensor) {
+ Onnx.TensorProto.Builder builder = Onnx.TensorProto.newBuilder();
+ builder.setDataType(Onnx.TensorProto.DataType.DOUBLE);;
+ tensor.type().dimensions().forEach(d -> builder.addDims(d.size().get()));
+ tensor.valueIterator().forEachRemaining(builder::addDoubleData);
+ Onnx.TensorProto val = builder.build();
+ nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(TENSOR).setT(val).build());
+ return this;
+ }
+
AttributeConverter build() {
return AttributeConverter.convert(nodeBuilder.build());
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
index d1dea730da5..9631bddd93d 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
@@ -3,8 +3,13 @@
package ai.vespa.rankingexpression.importer.onnx;
import ai.vespa.rankingexpression.importer.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -21,21 +26,48 @@ public class SimpleImportTestCase {
ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/simple.onnx");
MapContext context = new MapContext();
- context.put("query_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[1],d1[4])")).
- cell(0.1, 0, 0).
- cell(0.2, 0, 1).
- cell(0.3, 0, 2).
- cell(0.4, 0, 3).build()));
- context.put("attribute_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[4],d1[1])")).
- cell(0.1, 0, 0).
- cell(0.2, 1, 0).
- cell(0.3, 2, 0).
- cell(0.4, 3, 0).build()));
- context.put("bias_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[1],d1[1])")).
- cell(1.0, 0, 0).build()));
+ context.put("query_tensor", new TensorValue(Tensor.from("tensor(d0[1],d1[4]):[0.1, 0.2, 0.3, 0.4]")));
+ context.put("attribute_tensor", new TensorValue(Tensor.from("tensor(d0[4],d1[1]):[0.1, 0.2, 0.3, 0.4]")));
+ context.put("bias_tensor", new TensorValue(Tensor.from("tensor(d0[1],d1[1]):[1.0]")));
Tensor result = model.expressions().get("output").evaluate(context).asTensor();
assertEquals(result, Tensor.from("tensor(d0[1],d1[1]):{{d0:0,d1:0}:1.3}"));
}
+ @Test
+ public void testGather() {
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx");
+
+ MapContext context = new MapContext();
+ context.put("data", new TensorValue(Tensor.from("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]")));
+ context.put("indices", new TensorValue(Tensor.from("tensor(d0[2],d1[2]):[0, 1, 1, 2]")));
+
+ model.functions().forEach((k, v) -> evaluateFunction(context, model, k));
+
+ Tensor result = model.expressions().get("y").evaluate(context).asTensor();
+ assertEquals(result, Tensor.from("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 3, 4, 5, 6]"));
+ }
+
+ private void evaluateFunction(Context context, ImportedModel model, String functionName) {
+ if (!context.names().contains(functionName)) {
+ RankingExpression e = RankingExpression.from(model.functions().get(functionName));
+ evaluateFunctionDependencies(context, model, e.getRoot());
+ context.put(functionName, new TensorValue(e.evaluate(context).asTensor()));
+ }
+ }
+
+ private void evaluateFunctionDependencies(Context context, ImportedModel model, ExpressionNode node) {
+ if (node instanceof ReferenceNode) {
+ String name = node.toString();
+ if (model.functions().containsKey(name)) {
+ evaluateFunction(context, model, name);
+ }
+ }
+ else if (node instanceof CompositeNode) {
+ for (ExpressionNode child : ((CompositeNode)node).children()) {
+ evaluateFunctionDependencies(context, model, child);
+ }
+ }
+ }
+
}