diff options
author | Lester Solbakken <lesters@oath.com> | 2020-04-21 15:26:58 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-04-21 15:26:58 +0200 |
commit | aad5c7184f37e1441c928efa77b434620742ff88 (patch) | |
tree | 34a92e7f954aa92e21d48816335771ff607fe404 /model-integration/src/test | |
parent | 6f5ca49e45cdc8262fcf360b1c731a393385ffa8 (diff) |
Update model-integration for supporting BERT-type models
Diffstat (limited to 'model-integration/src/test')
7 files changed, 160 insertions, 17 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java index 94c5577357b..d5dff7fb1b7 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -107,6 +107,18 @@ public class OnnxOperationsTestCase { assertEval("less", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a < b))", x, y)); assertEval("equal", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a == b))", x, y)); assertEval("pow", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(pow(a,b)))", x, y)); + + // broadcasting - opposite order + x = evaluate("random(d0[4]) + 1"); + y = evaluate("random(d0[2],d1[3],d2[4]) + 1"); + assertEval("add", x, y, evaluate("rename(x, d0, d2) + y", x, y)); + assertEval("sub", x, y, evaluate("rename(x, d0, d2) - y", x, y)); + assertEval("mul", x, y, evaluate("rename(x, d0, d2) * y", x, y)); + assertEval("div", x, y, evaluate("rename(x, d0, d2) / y", x, y)); + assertEval("greater", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a > b))", x, y)); + assertEval("less", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a < b))", x, y)); + assertEval("equal", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a == b))", x, y)); + assertEval("pow", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(pow(a,b)))", x, y)); } @Test @@ -185,9 +197,49 @@ public class OnnxOperationsTestCase { @Test public void testMatMul1() throws ParseException { - Tensor a = evaluate("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 6]"); - Tensor b = evaluate("tensor(d0[3],d1[2]):[7, 8, 9, 10, 11, 12]"); - assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[58, 64, 139, 154]")); + Tensor a = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); + Tensor b = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("91")); + + a = evaluate("tensor(d0[3]):[1,2,3]"); + b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2]):[22, 28]")); + + a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[3]):[1,2,3]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2]):[14, 32]")); + + a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[22,28,49,64]")); + + a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]")); + + a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[1],d1[3],d2[2]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]")); + + a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[1],d1[3],d2[2]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]")); + + a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[2],d1[3],d2[2]):[1,2,3,4,5,6,7,8,9,10,11,12]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2],d2[2]):[22,28,49,64,58,64,139,154]")); + + a = evaluate("tensor(d0[2],d1[2],d2[3]):[1,2,3,4,5,6,7,8,9,10,11,12]"); + b = evaluate("tensor(d0[1],d1[3],d2[2]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2],d2[2]):[22,28,49,64,76,100,103,136]")); + + a = evaluate("tensor(d0[1],d1[4],d2[2],d3[3]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]"); + b = evaluate("tensor(d0[1],d1[4],d2[3],d3[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]"); + assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[4],d2[2],d3[2]):[22,28,49,64,220,244,301,334,634,676,769,820,1264,1324,1453,1522]")); + + a = evaluate("tensor(d0[1],d1[1],d2[2],d3[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[2],d1[2],d2[3],d3[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2],d2[2],d3[2]):[22,28,49,64,58,64,139,154,94,100,229,244,130,136,319,334]")); } @Test @@ -217,6 +269,10 @@ public class OnnxOperationsTestCase { y = evaluate("tensor(d0[4]):[3,2,-1,1]"); assertEval("reshape", x, y, evaluate("tensor(d0[3],d1[2],d2[1],d3[1]):[1,2,3,4,5,6]")); + + x = evaluate("tensor(d0[1],d1[2],d2[2],d3[3]):[1,2,3,4,5,6,7,8,9,10,11,12]"); + y = evaluate("tensor(d0[2]):[2,6]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[6]):[1,2,3,4,5,6,7,8,9,10,11,12]")); } @Test @@ -435,6 +491,48 @@ public class OnnxOperationsTestCase { } + @Test + public void testTranspose1() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[3]):[[1,2,3],[4,5,6]]"); + assertEval("transpose", x, evaluate("tensor(d0[3],d1[2]):[[1,4],[2,5],[3,6]]")); + } + + @Test + public void testTile6() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"); + Tensor y = evaluate("tensor(d0[2]):[1,2]"); + assertEval("tile", x, y, evaluate("tensor(d0[2],d1[4]):[1,2,1,2,3,4,3,4]")); + + x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"); + y = evaluate("tensor(d0[2]):[3,1]"); + assertEval("tile", x, y, evaluate("tensor(d0[6],d1[2]):[1,2,3,4,1,2,3,4,1,2,3,4]")); + + x = evaluate("tensor(d0[1],d1[1],d2[1]):[1]"); + y = evaluate("tensor(d0[3]):[1,6,1]"); + assertEval("tile", x, y, evaluate("tensor(d0[1],d1[6],d2[1]):[1,1,1,1,1,1]")); + } + + @Test + public void testSplit2() throws ParseException { + Tensor x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); + assertEval("split", x, evaluate("tensor(d0[3]):[1,2,3]"), 0); + assertEval("split", x, evaluate("tensor(d0[3]):[4,5,6]"), 1); + assertEval("split", x, evaluate("tensor(d0[2]):[1,2]"), createAttribute("split", new int[] {2}), 0); + assertEval("split", x, evaluate("tensor(d0[4]):[3,4,5,6]"), createAttribute("split", new int[] {2}), 1); + assertEval("split", x, evaluate("tensor(d0[3]):[3,4,5]"), createAttribute("split", new int[] {2,3}), 1); + assertEval("split", x, evaluate("tensor(d0[1]):[6]"), createAttribute("split", new int[] {2,3}), 2); + + x = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); + assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]")); + assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]"), 0); + assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[4,5,6]"), 1); + assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]"), createAttribute("split", new int[] {1}), 0); + assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[4,5,6]"), createAttribute("split", new int[] {1}), 1); + assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[1,4]"), createAttribute("axis", 1), 0); + assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[2,5]"), createAttribute("axis", 1), 1); + assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[3,6]"), createAttribute("axis", 1), 2); + } + private Tensor evaluate(String expr) throws ParseException { return evaluate(expr, null, null, null); } @@ -461,41 +559,49 @@ public class OnnxOperationsTestCase { } private void assertEval(String opName, Tensor x, Tensor expected) { - assertEval(opName, x, null, null, null, null, expected, null); + assertEval(opName, x, null, null, null, null, expected, null, 0); + } + + private void assertEval(String opName, Tensor x, Tensor expected, int output) { + assertEval(opName, x, null, null, null, null, expected, null, output); } private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr) { - assertEval(opName, x, null, null, null, null, expected, attr); + assertEval(opName, x, null, null, null, null, expected, attr, 0); + } + + private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr, int output) { + assertEval(opName, x, null, null, null, null, expected, attr, output); } private void assertEval(String opName, Tensor x, Tensor y, Tensor expected, AttributeConverter attr) { - assertEval(opName, x, y, null, null, null, expected, attr); + assertEval(opName, x, y, null, null, null, expected, attr, 0); } private void assertEval(String opName, Tensor x, Tensor y, Tensor expected) { - assertEval(opName, x, y, null, null, null, expected, null); + assertEval(opName, x, y, null, null, null, expected, null, 0); } private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected) { - assertEval(opName, x, y, z, null, null, expected, null); + assertEval(opName, x, y, z, null, null, expected, null, 0); } private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected, AttributeConverter attr) { - assertEval(opName, x, y, z, null, null, expected, attr); + assertEval(opName, x, y, z, null, null, expected, attr, 0); } private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor expected) { - assertEval(opName, x, y, z, q, null, expected, null); + assertEval(opName, x, y, z, q, null, expected, null, 0); } private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected) { - assertEval(opName, x, y, z, q, r, expected, null); + assertEval(opName, x, y, z, q, r, expected, null, 0); } - private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected, AttributeConverter attr) { + private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected, AttributeConverter attr, int output) { Context context = new MapContext(DoubleValue.NaN); List<IntermediateOperation> inputs = createInputs(context, x, y, z, q, r); - IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build()); + IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build(), output); optimizeAndRename(opName, op); Tensor result = evaluate(op); assertEquals(expected, result); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java index 9631bddd93d..04db902073b 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java @@ -11,7 +11,6 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -48,6 +47,19 @@ public class SimpleImportTestCase { assertEquals(result, Tensor.from("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 3, 4, 5, 6]")); } + @Test + public void testConcat() { + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx"); + + MapContext context = new MapContext(); + context.put("i", new TensorValue(Tensor.from("tensor(d0[1]):[1]"))); + context.put("j", new TensorValue(Tensor.from("tensor(d0[1]):[2]"))); + context.put("k", new TensorValue(Tensor.from("tensor(d0[1]):[3]"))); + + Tensor result = model.expressions().get("y").evaluate(context).asTensor(); + assertEquals(result, Tensor.from("tensor(d0[3]):[1, 2, 3]")); + } + private void evaluateFunction(Context context, ImportedModel model, String functionName) { if (!context.names().contains(functionName)) { RankingExpression e = RankingExpression.from(model.functions().get(functionName)); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java index b9d767774be..25f8acf1f6d 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java @@ -34,7 +34,7 @@ public class DropoutImportTestCase { ImportedMlFunction function = signature.outputFunction("y", "y"); assertNotNull(function); - assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))", + assertEquals("join(join(reduce(constant(test_outputs_Const), sum, d1), imported_ml_function_test_outputs_BiasAdd, f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))", function.expression()); model.assertEqualResult("X", "outputs/Maximum"); assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString()); diff --git a/model-integration/src/test/models/onnx/simple/concat.onnx b/model-integration/src/test/models/onnx/simple/concat.onnx Binary files differnew file mode 100644 index 00000000000..945bc3c9445 --- /dev/null +++ b/model-integration/src/test/models/onnx/simple/concat.onnx diff --git a/model-integration/src/test/models/onnx/simple/concat.py b/model-integration/src/test/models/onnx/simple/concat.py new file mode 100755 index 00000000000..186002c2abb --- /dev/null +++ b/model-integration/src/test/models/onnx/simple/concat.py @@ -0,0 +1,25 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import onnx +from onnx import helper, TensorProto + +i_type = helper.make_tensor_value_info('i', TensorProto.FLOAT, [1]) +j_type = helper.make_tensor_value_info('j', TensorProto.FLOAT, [1]) +k_type = helper.make_tensor_value_info('k', TensorProto.FLOAT, [1]) + +output_type = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3]) + +node = onnx.helper.make_node( + 'Concat', + inputs=['i', 'j', 'k'], + outputs=['y'], + axis=0, +) +graph_def = onnx.helper.make_graph( + nodes = [node], + name = 'concat_test', + inputs = [i_type, j_type, k_type], + outputs = [output_type] +) +model_def = helper.make_model(graph_def, producer_name='concat.py') +onnx.save(model_def, 'concat.onnx') diff --git a/model-integration/src/test/models/onnx/simple/gather.onnx b/model-integration/src/test/models/onnx/simple/gather.onnx Binary files differindex 62451ad953d..0647d86ed0f 100644 --- a/model-integration/src/test/models/onnx/simple/gather.onnx +++ b/model-integration/src/test/models/onnx/simple/gather.onnx diff --git a/model-integration/src/test/models/onnx/simple/simple.onnx b/model-integration/src/test/models/onnx/simple/simple.onnx index 1c746c90efa..41b458451d0 100644 --- a/model-integration/src/test/models/onnx/simple/simple.onnx +++ b/model-integration/src/test/models/onnx/simple/simple.onnx @@ -1,4 +1,4 @@ - simple.py:ã + simple.py:ã 0 query_tensor attribute_tensormatmul"MatMul @@ -20,4 +20,4 @@ output -B +B
\ No newline at end of file |