summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-04-21 15:26:58 +0200
committerLester Solbakken <lesters@oath.com>2020-04-21 15:26:58 +0200
commitaad5c7184f37e1441c928efa77b434620742ff88 (patch)
tree34a92e7f954aa92e21d48816335771ff607fe404 /model-integration/src/test
parent6f5ca49e45cdc8262fcf360b1c731a393385ffa8 (diff)
Update model-integration for supporting BERT-type models
Diffstat (limited to 'model-integration/src/test')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java132
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java14
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java2
-rw-r--r--model-integration/src/test/models/onnx/simple/concat.onnxbin0 -> 135 bytes
-rwxr-xr-xmodel-integration/src/test/models/onnx/simple/concat.py25
-rw-r--r--model-integration/src/test/models/onnx/simple/gather.onnxbin150 -> 150 bytes
-rw-r--r--model-integration/src/test/models/onnx/simple/simple.onnx4
7 files changed, 160 insertions, 17 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
index 94c5577357b..d5dff7fb1b7 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
@@ -107,6 +107,18 @@ public class OnnxOperationsTestCase {
assertEval("less", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a < b))", x, y));
assertEval("equal", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a == b))", x, y));
assertEval("pow", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(pow(a,b)))", x, y));
+
+ // broadcasting - opposite order
+ x = evaluate("random(d0[4]) + 1");
+ y = evaluate("random(d0[2],d1[3],d2[4]) + 1");
+ assertEval("add", x, y, evaluate("rename(x, d0, d2) + y", x, y));
+ assertEval("sub", x, y, evaluate("rename(x, d0, d2) - y", x, y));
+ assertEval("mul", x, y, evaluate("rename(x, d0, d2) * y", x, y));
+ assertEval("div", x, y, evaluate("rename(x, d0, d2) / y", x, y));
+ assertEval("greater", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a > b))", x, y));
+ assertEval("less", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a < b))", x, y));
+ assertEval("equal", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a == b))", x, y));
+ assertEval("pow", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(pow(a,b)))", x, y));
}
@Test
@@ -185,9 +197,49 @@ public class OnnxOperationsTestCase {
@Test
public void testMatMul1() throws ParseException {
- Tensor a = evaluate("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 6]");
- Tensor b = evaluate("tensor(d0[3],d1[2]):[7, 8, 9, 10, 11, 12]");
- assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[58, 64, 139, 154]"));
+ Tensor a = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
+ Tensor b = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("91"));
+
+ a = evaluate("tensor(d0[3]):[1,2,3]");
+ b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2]):[22, 28]"));
+
+ a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[3]):[1,2,3]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2]):[14, 32]"));
+
+ a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[22,28,49,64]"));
+
+ a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]"));
+
+ a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[1],d1[3],d2[2]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]"));
+
+ a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[1],d1[3],d2[2]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]"));
+
+ a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[2],d1[3],d2[2]):[1,2,3,4,5,6,7,8,9,10,11,12]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2],d2[2]):[22,28,49,64,58,64,139,154]"));
+
+ a = evaluate("tensor(d0[2],d1[2],d2[3]):[1,2,3,4,5,6,7,8,9,10,11,12]");
+ b = evaluate("tensor(d0[1],d1[3],d2[2]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2],d2[2]):[22,28,49,64,76,100,103,136]"));
+
+ a = evaluate("tensor(d0[1],d1[4],d2[2],d3[3]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]");
+ b = evaluate("tensor(d0[1],d1[4],d2[3],d3[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[4],d2[2],d3[2]):[22,28,49,64,220,244,301,334,634,676,769,820,1264,1324,1453,1522]"));
+
+ a = evaluate("tensor(d0[1],d1[1],d2[2],d3[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[2],d1[2],d2[3],d3[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2],d2[2],d3[2]):[22,28,49,64,58,64,139,154,94,100,229,244,130,136,319,334]"));
}
@Test
@@ -217,6 +269,10 @@ public class OnnxOperationsTestCase {
y = evaluate("tensor(d0[4]):[3,2,-1,1]");
assertEval("reshape", x, y, evaluate("tensor(d0[3],d1[2],d2[1],d3[1]):[1,2,3,4,5,6]"));
+
+ x = evaluate("tensor(d0[1],d1[2],d2[2],d3[3]):[1,2,3,4,5,6,7,8,9,10,11,12]");
+ y = evaluate("tensor(d0[2]):[2,6]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[6]):[1,2,3,4,5,6,7,8,9,10,11,12]"));
}
@Test
@@ -435,6 +491,48 @@ public class OnnxOperationsTestCase {
}
+ @Test
+ public void testTranspose1() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2],d1[3]):[[1,2,3],[4,5,6]]");
+ assertEval("transpose", x, evaluate("tensor(d0[3],d1[2]):[[1,4],[2,5],[3,6]]"));
+ }
+
+ @Test
+ public void testTile6() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]");
+ Tensor y = evaluate("tensor(d0[2]):[1,2]");
+ assertEval("tile", x, y, evaluate("tensor(d0[2],d1[4]):[1,2,1,2,3,4,3,4]"));
+
+ x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]");
+ y = evaluate("tensor(d0[2]):[3,1]");
+ assertEval("tile", x, y, evaluate("tensor(d0[6],d1[2]):[1,2,3,4,1,2,3,4,1,2,3,4]"));
+
+ x = evaluate("tensor(d0[1],d1[1],d2[1]):[1]");
+ y = evaluate("tensor(d0[3]):[1,6,1]");
+ assertEval("tile", x, y, evaluate("tensor(d0[1],d1[6],d2[1]):[1,1,1,1,1,1]"));
+ }
+
+ @Test
+ public void testSplit2() throws ParseException {
+ Tensor x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
+ assertEval("split", x, evaluate("tensor(d0[3]):[1,2,3]"), 0);
+ assertEval("split", x, evaluate("tensor(d0[3]):[4,5,6]"), 1);
+ assertEval("split", x, evaluate("tensor(d0[2]):[1,2]"), createAttribute("split", new int[] {2}), 0);
+ assertEval("split", x, evaluate("tensor(d0[4]):[3,4,5,6]"), createAttribute("split", new int[] {2}), 1);
+ assertEval("split", x, evaluate("tensor(d0[3]):[3,4,5]"), createAttribute("split", new int[] {2,3}), 1);
+ assertEval("split", x, evaluate("tensor(d0[1]):[6]"), createAttribute("split", new int[] {2,3}), 2);
+
+ x = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+ assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]"));
+ assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]"), 0);
+ assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[4,5,6]"), 1);
+ assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]"), createAttribute("split", new int[] {1}), 0);
+ assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[4,5,6]"), createAttribute("split", new int[] {1}), 1);
+ assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[1,4]"), createAttribute("axis", 1), 0);
+ assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[2,5]"), createAttribute("axis", 1), 1);
+ assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[3,6]"), createAttribute("axis", 1), 2);
+ }
+
private Tensor evaluate(String expr) throws ParseException {
return evaluate(expr, null, null, null);
}
@@ -461,41 +559,49 @@ public class OnnxOperationsTestCase {
}
private void assertEval(String opName, Tensor x, Tensor expected) {
- assertEval(opName, x, null, null, null, null, expected, null);
+ assertEval(opName, x, null, null, null, null, expected, null, 0);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor expected, int output) {
+ assertEval(opName, x, null, null, null, null, expected, null, output);
}
private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr) {
- assertEval(opName, x, null, null, null, null, expected, attr);
+ assertEval(opName, x, null, null, null, null, expected, attr, 0);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr, int output) {
+ assertEval(opName, x, null, null, null, null, expected, attr, output);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor expected, AttributeConverter attr) {
- assertEval(opName, x, y, null, null, null, expected, attr);
+ assertEval(opName, x, y, null, null, null, expected, attr, 0);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor expected) {
- assertEval(opName, x, y, null, null, null, expected, null);
+ assertEval(opName, x, y, null, null, null, expected, null, 0);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected) {
- assertEval(opName, x, y, z, null, null, expected, null);
+ assertEval(opName, x, y, z, null, null, expected, null, 0);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected, AttributeConverter attr) {
- assertEval(opName, x, y, z, null, null, expected, attr);
+ assertEval(opName, x, y, z, null, null, expected, attr, 0);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor expected) {
- assertEval(opName, x, y, z, q, null, expected, null);
+ assertEval(opName, x, y, z, q, null, expected, null, 0);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected) {
- assertEval(opName, x, y, z, q, r, expected, null);
+ assertEval(opName, x, y, z, q, r, expected, null, 0);
}
- private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected, AttributeConverter attr) {
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected, AttributeConverter attr, int output) {
Context context = new MapContext(DoubleValue.NaN);
List<IntermediateOperation> inputs = createInputs(context, x, y, z, q, r);
- IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build());
+ IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build(), output);
optimizeAndRename(opName, op);
Tensor result = evaluate(op);
assertEquals(expected, result);
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
index 9631bddd93d..04db902073b 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
@@ -11,7 +11,6 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -48,6 +47,19 @@ public class SimpleImportTestCase {
assertEquals(result, Tensor.from("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 3, 4, 5, 6]"));
}
+ @Test
+ public void testConcat() {
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx");
+
+ MapContext context = new MapContext();
+ context.put("i", new TensorValue(Tensor.from("tensor(d0[1]):[1]")));
+ context.put("j", new TensorValue(Tensor.from("tensor(d0[1]):[2]")));
+ context.put("k", new TensorValue(Tensor.from("tensor(d0[1]):[3]")));
+
+ Tensor result = model.expressions().get("y").evaluate(context).asTensor();
+ assertEquals(result, Tensor.from("tensor(d0[3]):[1, 2, 3]"));
+ }
+
private void evaluateFunction(Context context, ImportedModel model, String functionName) {
if (!context.names().contains(functionName)) {
RankingExpression e = RankingExpression.from(model.functions().get(functionName));
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
index b9d767774be..25f8acf1f6d 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
@@ -34,7 +34,7 @@ public class DropoutImportTestCase {
ImportedMlFunction function = signature.outputFunction("y", "y");
assertNotNull(function);
- assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
+ assertEquals("join(join(reduce(constant(test_outputs_Const), sum, d1), imported_ml_function_test_outputs_BiasAdd, f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
function.expression());
model.assertEqualResult("X", "outputs/Maximum");
assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString());
diff --git a/model-integration/src/test/models/onnx/simple/concat.onnx b/model-integration/src/test/models/onnx/simple/concat.onnx
new file mode 100644
index 00000000000..945bc3c9445
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/concat.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/simple/concat.py b/model-integration/src/test/models/onnx/simple/concat.py
new file mode 100755
index 00000000000..186002c2abb
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/concat.py
@@ -0,0 +1,25 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+i_type = helper.make_tensor_value_info('i', TensorProto.FLOAT, [1])
+j_type = helper.make_tensor_value_info('j', TensorProto.FLOAT, [1])
+k_type = helper.make_tensor_value_info('k', TensorProto.FLOAT, [1])
+
+output_type = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3])
+
+node = onnx.helper.make_node(
+ 'Concat',
+ inputs=['i', 'j', 'k'],
+ outputs=['y'],
+ axis=0,
+)
+graph_def = onnx.helper.make_graph(
+ nodes = [node],
+ name = 'concat_test',
+ inputs = [i_type, j_type, k_type],
+ outputs = [output_type]
+)
+model_def = helper.make_model(graph_def, producer_name='concat.py')
+onnx.save(model_def, 'concat.onnx')
diff --git a/model-integration/src/test/models/onnx/simple/gather.onnx b/model-integration/src/test/models/onnx/simple/gather.onnx
index 62451ad953d..0647d86ed0f 100644
--- a/model-integration/src/test/models/onnx/simple/gather.onnx
+++ b/model-integration/src/test/models/onnx/simple/gather.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/simple/simple.onnx b/model-integration/src/test/models/onnx/simple/simple.onnx
index 1c746c90efa..41b458451d0 100644
--- a/model-integration/src/test/models/onnx/simple/simple.onnx
+++ b/model-integration/src/test/models/onnx/simple/simple.onnx
@@ -1,4 +1,4 @@
- simple.py:ã
+ simple.py:ã
0
query_tensor
attribute_tensormatmul"MatMul
@@ -20,4 +20,4 @@
output


-B
+B \ No newline at end of file