author    Lester Solbakken <lesters@oath.com>  2021-05-19 11:35:40 +0200
committer Lester Solbakken <lesters@oath.com>  2021-05-19 11:35:40 +0200
commit    a186020aa62214a714f24091b7928a159a55b166 (patch)
tree      418641c48b1fde584c19b8914608fee00bd37628 /model-integration
parent    00a724c605b3d1332a119454f1382830df2226d2 (diff)
Add ONNX-RT evaluator to model-integration module
Diffstat (limited to 'model-integration')
-rw-r--r--  model-integration/pom.xml | 4
-rw-r--r--  model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java | 79
-rw-r--r--  model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java | 181
-rw-r--r--  model-integration/src/main/java/ai/vespa/modelintegration/evaluator/package-info.java | 5
-rw-r--r--  model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java | 93
-rw-r--r--  model-integration/src/test/models/onnx/add_double.onnx | 16
-rwxr-xr-x  model-integration/src/test/models/onnx/add_double.py | 27
-rw-r--r--  model-integration/src/test/models/onnx/add_float.onnx | 16
-rwxr-xr-x  model-integration/src/test/models/onnx/add_float.py | 27
-rw-r--r--  model-integration/src/test/models/onnx/add_int64.onnx | 16
-rwxr-xr-x  model-integration/src/test/models/onnx/add_int64.py | 27
-rw-r--r--  model-integration/src/test/models/onnx/cast_bfloat16_float.onnx | 12
-rwxr-xr-x  model-integration/src/test/models/onnx/cast_bfloat16_float.py | 24
-rw-r--r--  model-integration/src/test/models/onnx/cast_float_int8.onnx | 12
-rwxr-xr-x  model-integration/src/test/models/onnx/cast_float_int8.py | 24
-rw-r--r--  model-integration/src/test/models/onnx/cast_int8_float.onnx | 12
-rwxr-xr-x  model-integration/src/test/models/onnx/cast_int8_float.py | 24
-rw-r--r--  model-integration/src/test/models/onnx/pytorch/one_layer.onnx | bin 0 -> 299 bytes
-rwxr-xr-x  model-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py | 38
-rw-r--r--  model-integration/src/test/models/onnx/simple/matmul.onnx | 16
-rwxr-xr-x  model-integration/src/test/models/onnx/simple/matmul.py | 27
21 files changed, 680 insertions, 0 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index 536d3578f8c..dc3154c5c41 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -53,6 +53,10 @@
         </dependency>
         <dependency>
+            <groupId>com.microsoft.onnxruntime</groupId>
+            <artifactId>onnxruntime</artifactId>
+        </dependency>
+        <dependency>
             <groupId>com.google.protobuf</groupId>
             <artifactId>protobuf-java</artifactId>
         </dependency>
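
The new dependency pulls in the ONNX Runtime Java bindings; no version is given here, so it is presumably managed in the parent pom. A minimal sketch (hypothetical, not part of this commit) to confirm that the native runtime loads:

import ai.onnxruntime.OrtEnvironment;

public class OrtSmokeCheck {
    public static void main(String[] args) {
        // getEnvironment() returns the singleton environment and loads the
        // bundled native ONNX Runtime library on first use.
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        System.out.println("ONNX Runtime environment created: " + env);
    }
}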
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
new file mode 100644
index 00000000000..59ad20b7714
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -0,0 +1,79 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.modelintegration.evaluator;
+
+import ai.onnxruntime.OnnxTensor;
+import ai.onnxruntime.OnnxValue;
+import ai.onnxruntime.OrtEnvironment;
+import ai.onnxruntime.OrtException;
+import ai.onnxruntime.OrtSession;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+
+/**
+ * Evaluates an ONNX model by deferring to ONNX Runtime.
+ *
+ * @author lesters
+ */
+public class OnnxEvaluator {
+
+ private final OrtEnvironment environment;
+ private final OrtSession session;
+
+ public OnnxEvaluator(String modelPath) {
+ try {
+ environment = OrtEnvironment.getEnvironment();
+ session = environment.createSession(modelPath, new OrtSession.SessionOptions());
+ } catch (OrtException e) {
+ throw new RuntimeException("ONNX Runtime exception", e);
+ }
+ }
+
+ public Tensor evaluate(Map<String, Tensor> inputs, String output) {
+ try {
+ Map<String, OnnxTensor> onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
+ try (OrtSession.Result result = session.run(onnxInputs, Collections.singleton(output))) {
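+ // a single output was requested, so the result holds exactly one value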
+ return TensorConverter.toVespaTensor(result.get(0));
+ }
+ } catch (OrtException e) {
+ throw new RuntimeException("ONNX Runtime exception", e);
+ }
+ }
+
+ public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) {
+ try {
+ Map<String, OnnxTensor> onnxInputs = TensorConverter.toOnnxTensors(inputs, environment, session);
+ Map<String, Tensor> outputs = new HashMap<>();
+ try (OrtSession.Result result = session.run(onnxInputs)) {
+ for (Map.Entry<String, OnnxValue> output : result) {
+ outputs.put(output.getKey(), TensorConverter.toVespaTensor(output.getValue()));
+ }
+ return outputs;
+ }
+ } catch (OrtException e) {
+ throw new RuntimeException("ONNX Runtime exception", e);
+ }
+ }
+
+ public Map<String, TensorType> getInputInfo() {
+ try {
+ return TensorConverter.toVespaTypes(session.getInputInfo());
+ } catch (OrtException e) {
+ throw new RuntimeException("ONNX Runtime exception", e);
+ }
+ }
+
+ public Map<String, TensorType> getOutputInfo() {
+ try {
+ return TensorConverter.toVespaTypes(session.getOutputInfo());
+ } catch (OrtException e) {
+ throw new RuntimeException("ONNX Runtime exception", e);
+ }
+ }
+
+}
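
A usage sketch for the evaluator above (hypothetical, not part of this commit; the model path and values are borrowed from the add_float test model added below):

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import com.yahoo.tensor.Tensor;

import java.util.HashMap;
import java.util.Map;

public class OnnxEvaluatorExample {
    public static void main(String[] args) {
        // Construction throws a RuntimeException wrapping the OrtException
        // if ONNX Runtime cannot load the model.
        OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/add_float.onnx");

        Map<String, Tensor> inputs = new HashMap<>();
        inputs.put("input1", Tensor.from("tensor<float>(d0[1]):[1]"));
        inputs.put("input2", Tensor.from("tensor<float>(d0[1]):[2]"));

        // Evaluate a single named output ...
        Tensor sum = evaluator.evaluate(inputs, "output");

        // ... or all outputs at once.
        Map<String, Tensor> outputs = evaluator.evaluate(inputs);

        System.out.println(sum);              // tensor<float>(d0[1]):[3.0]
        System.out.println(outputs.keySet()); // [output]
    }
}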
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
new file mode 100644
index 00000000000..c1f973300d6
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
@@ -0,0 +1,181 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.modelintegration.evaluator;
+
+import ai.onnxruntime.NodeInfo;
+import ai.onnxruntime.OnnxJavaType;
+import ai.onnxruntime.OnnxTensor;
+import ai.onnxruntime.OnnxValue;
+import ai.onnxruntime.OrtEnvironment;
+import ai.onnxruntime.OrtException;
+import ai.onnxruntime.OrtSession;
+import ai.onnxruntime.TensorInfo;
+import ai.onnxruntime.ValueInfo;
+import com.yahoo.tensor.DimensionSizes;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+import java.nio.LongBuffer;
+import java.nio.ShortBuffer;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+
+/**
+ * @author lesters
+ */
+class TensorConverter {
+
+ static Map<String, OnnxTensor> toOnnxTensors(Map<String, Tensor> tensorMap, OrtEnvironment env, OrtSession session)
+ throws OrtException
+ {
+ Map<String, OnnxTensor> result = new HashMap<>();
+ for (String name : tensorMap.keySet()) {
+ Tensor vespaTensor = tensorMap.get(name);
+ TensorInfo onnxTensorInfo = toTensorInfo(session.getInputInfo().get(name).getInfo());
+ OnnxTensor onnxTensor = toOnnxTensor(vespaTensor, onnxTensorInfo, env);
+ result.put(name, onnxTensor);
+ }
+ return result;
+ }
+
+ static OnnxTensor toOnnxTensor(Tensor vespaTensor, TensorInfo onnxTensorInfo, OrtEnvironment environment)
+ throws OrtException
+ {
+ if ( ! (vespaTensor instanceof IndexedTensor)) {
+ throw new IllegalArgumentException("OnnxEvaluator currently only supports tensors with indexed dimensions");
+ }
+ IndexedTensor tensor = (IndexedTensor) vespaTensor;
+
+ ByteBuffer buffer = ByteBuffer.allocateDirect((int)tensor.size() * onnxTensorInfo.type.size).order(ByteOrder.nativeOrder());
+ if (onnxTensorInfo.type == OnnxJavaType.FLOAT) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.putFloat(tensor.getFloat(i));
+ return OnnxTensor.createTensor(environment, buffer.rewind().asFloatBuffer(), tensor.shape());
+ }
+ if (onnxTensorInfo.type == OnnxJavaType.DOUBLE) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.putDouble(tensor.get(i));
+ return OnnxTensor.createTensor(environment, buffer.rewind().asDoubleBuffer(), tensor.shape());
+ }
+ if (onnxTensorInfo.type == OnnxJavaType.INT8) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.put((byte) tensor.get(i));
+ return OnnxTensor.createTensor(environment, buffer.rewind(), tensor.shape());
+ }
+ if (onnxTensorInfo.type == OnnxJavaType.INT16) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.putShort((short) tensor.get(i));
+ return OnnxTensor.createTensor(environment, buffer.rewind().asShortBuffer(), tensor.shape());
+ }
+ if (onnxTensorInfo.type == OnnxJavaType.INT32) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.putInt((int) tensor.get(i));
+ return OnnxTensor.createTensor(environment, buffer.rewind().asIntBuffer(), tensor.shape());
+ }
+ if (onnxTensorInfo.type == OnnxJavaType.INT64) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.putLong((long) tensor.get(i));
+ return OnnxTensor.createTensor(environment, buffer.rewind().asLongBuffer(), tensor.shape());
+ }
+ throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensorInfo.type);
+ }
+
+ static Tensor toVespaTensor(OnnxValue onnxValue) {
+ if ( ! (onnxValue instanceof OnnxTensor)) {
+ throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported");
+ }
+ OnnxTensor onnxTensor = (OnnxTensor) onnxValue;
+ TensorInfo tensorInfo = onnxTensor.getInfo();
+
+ TensorType type = toVespaType(onnxTensor.getInfo());
+ DimensionSizes sizes = sizesFromType(type);
+
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type, sizes);
+ if (tensorInfo.type == OnnxJavaType.FLOAT) {
+ FloatBuffer buffer = onnxTensor.getFloatBuffer();
+ for (long i = 0; i < sizes.totalSize(); i++)
+ builder.cellByDirectIndex(i, buffer.get());
+ }
+ else if (tensorInfo.type == OnnxJavaType.DOUBLE) {
+ DoubleBuffer buffer = onnxTensor.getDoubleBuffer();
+ for (long i = 0; i < sizes.totalSize(); i++)
+ builder.cellByDirectIndex(i, buffer.get());
+ }
+ else if (tensorInfo.type == OnnxJavaType.INT8) {
+ ByteBuffer buffer = onnxTensor.getByteBuffer();
+ for (long i = 0; i < sizes.totalSize(); i++)
+ builder.cellByDirectIndex(i, buffer.get());
+ }
+ else if (tensorInfo.type == OnnxJavaType.INT16) {
+ ShortBuffer buffer = onnxTensor.getShortBuffer();
+ for (long i = 0; i < sizes.totalSize(); i++)
+ builder.cellByDirectIndex(i, buffer.get());
+ }
+ else if (tensorInfo.type == OnnxJavaType.INT32) {
+ IntBuffer buffer = onnxTensor.getIntBuffer();
+ for (long i = 0; i < sizes.totalSize(); i++)
+ builder.cellByDirectIndex(i, buffer.get());
+ }
+ else if (tensorInfo.type == OnnxJavaType.INT64) {
+ LongBuffer buffer = onnxTensor.getLongBuffer();
+ for (long i = 0; i < sizes.totalSize(); i++)
+ builder.cellByDirectIndex(i, buffer.get());
+ }
+ else {
+ throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type);
+ }
+ return builder.build();
+ }
+
+ static private DimensionSizes sizesFromType(TensorType type) {
+ DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
+ for (int i = 0; i < type.dimensions().size(); i++)
+ builder.set(i, type.dimensions().get(i).size().get());
+ return builder.build();
+ }
+
+ static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) {
+ return infoMap.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> toVespaType(e.getValue().getInfo())));
+ }
+
+ static TensorType toVespaType(ValueInfo valueInfo) {
+ TensorInfo tensorInfo = toTensorInfo(valueInfo);
+ TensorType.Builder builder = new TensorType.Builder(toVespaValueType(tensorInfo.onnxType));
+ long[] shape = tensorInfo.getShape();
+ for (int i = 0; i < shape.length; ++i) {
+ long dimSize = shape[i];
+ String dimName = "d" + i; // standard naming convention
+ if (dimSize > 0)
+ builder.indexed(dimName, dimSize);
+ else
+ builder.indexed(dimName); // unbound dimension for dim size -1
+ }
+ return builder.build();
+ }
+
+ static private TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxType) {
+ switch (onnxType) {
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE;
+ }
+ return TensorType.Value.DOUBLE;
+ }
+
+ static private TensorInfo toTensorInfo(ValueInfo valueInfo) {
+ if ( ! (valueInfo instanceof TensorInfo)) {
+ throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported");
+ }
+ return (TensorInfo) valueInfo;
+ }
+
+}
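
To illustrate the mapping toVespaType performs above: ONNX dimensions are named d0, d1, ... in declaration order, a positive size becomes a bound indexed dimension, and a dynamic size (-1) becomes an unbound one. A sketch (hypothetical, but using the same TensorType.Builder calls as the code above) of what an ONNX float tensor of shape [-1, 3] maps to:

import com.yahoo.tensor.TensorType;

public class TypeMappingSketch {
    public static void main(String[] args) {
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT)
                .indexed("d0")      // shape entry -1: unbound dimension
                .indexed("d1", 3)   // shape entry  3: bound dimension
                .build();
        System.out.println(type);   // tensor<float>(d0[],d1[3])
    }
}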
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/package-info.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/package-info.java
new file mode 100644
index 00000000000..e44ea96c534
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/package-info.java
@@ -0,0 +1,5 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+package ai.vespa.modelintegration.evaluator;
+
+import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
new file mode 100644
index 00000000000..4b42e18d75e
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
@@ -0,0 +1,93 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.modelintegration.evaluator;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author lesters
+ */
+public class OnnxEvaluatorTest {
+
+ @Test
+ public void testSimpleModel() {
+ OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/simple/simple.onnx");
+
+ // Input types
+ Map<String, TensorType> inputTypes = evaluator.getInputInfo();
+ assertEquals(inputTypes.get("query_tensor"), TensorType.fromSpec("tensor<float>(d0[1],d1[4])"));
+ assertEquals(inputTypes.get("attribute_tensor"), TensorType.fromSpec("tensor<float>(d0[4],d1[1])"));
+ assertEquals(inputTypes.get("bias_tensor"), TensorType.fromSpec("tensor<float>(d0[1],d1[1])"));
+
+ // Output types
+ Map<String, TensorType> outputTypes = evaluator.getOutputInfo();
+ assertEquals(outputTypes.get("output"), TensorType.fromSpec("tensor<float>(d0[1],d1[1])"));
+
+ // Evaluation
+ Map<String, Tensor> inputs = new HashMap<>();
+ inputs.put("query_tensor", Tensor.from("tensor(d0[1],d1[4]):[0.1, 0.2, 0.3, 0.4]"));
+ inputs.put("attribute_tensor", Tensor.from("tensor(d0[4],d1[1]):[0.1, 0.2, 0.3, 0.4]"));
+ inputs.put("bias_tensor", Tensor.from("tensor(d0[1],d1[1]):[1.0]"));
+
+ assertEquals(evaluator.evaluate(inputs).get("output"), Tensor.from("tensor(d0[1],d1[1]):[1.3]"));
+ assertEquals(evaluator.evaluate(inputs, "output"), Tensor.from("tensor(d0[1],d1[1]):[1.3]"));
+ }
+
+ @Test
+ public void testBatchDimension() {
+ OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/pytorch/one_layer.onnx");
+
+ // Input types
+ Map<String, TensorType> inputTypes = evaluator.getInputInfo();
+ assertEquals(inputTypes.get("input"), TensorType.fromSpec("tensor<float>(d0[],d1[3])"));
+
+ // Output types
+ Map<String, TensorType> outputTypes = evaluator.getOutputInfo();
+ assertEquals(outputTypes.get("output"), TensorType.fromSpec("tensor<float>(d0[],d1[1])"));
+
+ // Evaluation
+ Map<String, Tensor> inputs = new HashMap<>();
+ inputs.put("input", Tensor.from("tensor<float>(d0[2],d1[3]):[[0.1, 0.2, 0.3],[0.4,0.5,0.6]]"));
+ assertEquals(evaluator.evaluate(inputs, "output"), Tensor.from("tensor<float>(d0[2],d1[1]):[0.6393113,0.67574286]"));
+ }
+
+ @Test
+ public void testMatMul() {
+ String expected = "tensor<float>(d0[2],d1[4]):[38,44,50,56,83,98,113,128]";
+ String input1 = "tensor<float>(d0[2],d1[3]):[1,2,3,4,5,6]";
+ String input2 = "tensor<float>(d0[3],d1[4]):[1,2,3,4,5,6,7,8,9,10,11,12]";
+ assertEvaluate("simple/matmul.onnx", expected, input1, input2);
+ }
+
+ @Test
+ public void testTypes() {
+ assertEvaluate("add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]");
+ assertEvaluate("add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]");
+ assertEvaluate("add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]");
+ assertEvaluate("cast_int8_float.onnx", "tensor<float>(d0[1]):[-128]", "tensor<int8>(d0[1]):[128]");
+ assertEvaluate("cast_float_int8.onnx", "tensor<int8>(d0[1]):[-1]", "tensor<float>(d0[1]):[255]");
+
+ // ONNX Runtime 1.7.0 does not support much of bfloat16 yet
+ // assertEvaluate("cast_bfloat16_float.onnx", "tensor<float>(d0[1]):[1]", "tensor<bfloat16>(d0[1]):[1]");
+ }
+
+ private void assertEvaluate(String model, String output, String... input) {
+ OnnxEvaluator evaluator = new OnnxEvaluator("src/test/models/onnx/" + model);
+ Map<String, Tensor> inputs = new HashMap<>();
+ for (int i = 0; i < input.length; ++i) {
+ inputs.put("input" + (i+1), Tensor.from(input[i]));
+ }
+ Tensor expected = Tensor.from(output);
+ Tensor result = evaluator.evaluate(inputs, "output");
+ assertEquals(expected, result);
+ assertEquals(expected.type().valueType(), result.type().valueType());
+ }
+
+}
diff --git a/model-integration/src/test/models/onnx/add_double.onnx b/model-integration/src/test/models/onnx/add_double.onnx
new file mode 100644
index 00000000000..9264d1eb9f9
--- /dev/null
+++ b/model-integration/src/test/models/onnx/add_double.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/add_double.py b/model-integration/src/test/models/onnx/add_double.py
new file mode 100755
index 00000000000..fa9aa48f4b2
--- /dev/null
+++ b/model-integration/src/test/models/onnx/add_double.py
@@ -0,0 +1,27 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.DOUBLE, [1])
+INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.DOUBLE, [1])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.DOUBLE, [1])
+
+nodes = [
+ helper.make_node(
+ 'Add',
+ ['input1', 'input2'],
+ ['output'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'add',
+ [
+ INPUT_1,
+ INPUT_2
+ ],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='add_double.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'add_double.onnx')
diff --git a/model-integration/src/test/models/onnx/add_float.onnx b/model-integration/src/test/models/onnx/add_float.onnx
new file mode 100644
index 00000000000..0e3ad8f900c
--- /dev/null
+++ b/model-integration/src/test/models/onnx/add_float.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/add_float.py b/model-integration/src/test/models/onnx/add_float.py
new file mode 100755
index 00000000000..e18b2c46d9d
--- /dev/null
+++ b/model-integration/src/test/models/onnx/add_float.py
@@ -0,0 +1,27 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1])
+INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1])
+
+nodes = [
+ helper.make_node(
+ 'Add',
+ ['input1', 'input2'],
+ ['output'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'add',
+ [
+ INPUT_1,
+ INPUT_2
+ ],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='add_float.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'add_float.onnx')
diff --git a/model-integration/src/test/models/onnx/add_int64.onnx b/model-integration/src/test/models/onnx/add_int64.onnx
new file mode 100644
index 00000000000..7b3a9ec6b95
--- /dev/null
+++ b/model-integration/src/test/models/onnx/add_int64.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/add_int64.py b/model-integration/src/test/models/onnx/add_int64.py
new file mode 100755
index 00000000000..87908e292a2
--- /dev/null
+++ b/model-integration/src/test/models/onnx/add_int64.py
@@ -0,0 +1,27 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.INT64, [1])
+INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.INT64, [1])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.INT64, [1])
+
+nodes = [
+ helper.make_node(
+ 'Add',
+ ['input1', 'input2'],
+ ['output'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'add',
+ [
+ INPUT_1,
+ INPUT_2
+ ],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='add_int64.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'add_int64.onnx')
diff --git a/model-integration/src/test/models/onnx/cast_bfloat16_float.onnx b/model-integration/src/test/models/onnx/cast_bfloat16_float.onnx
new file mode 100644
index 00000000000..cb19592abf4
--- /dev/null
+++ b/model-integration/src/test/models/onnx/cast_bfloat16_float.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/cast_bfloat16_float.py b/model-integration/src/test/models/onnx/cast_bfloat16_float.py
new file mode 100755
index 00000000000..14b05347262
--- /dev/null
+++ b/model-integration/src/test/models/onnx/cast_bfloat16_float.py
@@ -0,0 +1,24 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.BFLOAT16, [1])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1])
+
+nodes = [
+ helper.make_node(
+ 'Cast',
+ ['input1'],
+ ['output'],
+ to=TensorProto.FLOAT
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'cast',
+ [INPUT_1],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='cast_bfloat16_float.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'cast_bfloat16_float.onnx')
diff --git a/model-integration/src/test/models/onnx/cast_float_int8.onnx b/model-integration/src/test/models/onnx/cast_float_int8.onnx
new file mode 100644
index 00000000000..c30b023dd68
--- /dev/null
+++ b/model-integration/src/test/models/onnx/cast_float_int8.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/cast_float_int8.py b/model-integration/src/test/models/onnx/cast_float_int8.py
new file mode 100755
index 00000000000..bdc0850d033
--- /dev/null
+++ b/model-integration/src/test/models/onnx/cast_float_int8.py
@@ -0,0 +1,24 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.INT8, [1])
+
+nodes = [
+ helper.make_node(
+ 'Cast',
+ ['input1'],
+ ['output'],
+ to=TensorProto.INT8
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'cast',
+ [INPUT_1],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='cast_float_int8.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'cast_float_int8.onnx')
diff --git a/model-integration/src/test/models/onnx/cast_int8_float.onnx b/model-integration/src/test/models/onnx/cast_int8_float.onnx
new file mode 100644
index 00000000000..65aea4a36ae
--- /dev/null
+++ b/model-integration/src/test/models/onnx/cast_int8_float.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/cast_int8_float.py b/model-integration/src/test/models/onnx/cast_int8_float.py
new file mode 100755
index 00000000000..70bf2cf70ca
--- /dev/null
+++ b/model-integration/src/test/models/onnx/cast_int8_float.py
@@ -0,0 +1,24 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.INT8, [1])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1])
+
+nodes = [
+ helper.make_node(
+ 'Cast',
+ ['input1'],
+ ['output'],
+ to=TensorProto.FLOAT
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'cast',
+ [INPUT_1],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='cast_int8_float.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'cast_int8_float.onnx')
diff --git a/model-integration/src/test/models/onnx/pytorch/one_layer.onnx b/model-integration/src/test/models/onnx/pytorch/one_layer.onnx
new file mode 100644
index 00000000000..dc9f664b943
--- /dev/null
+++ b/model-integration/src/test/models/onnx/pytorch/one_layer.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py b/model-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py
new file mode 100755
index 00000000000..1296d84e180
--- /dev/null
+++ b/model-integration/src/test/models/onnx/pytorch/pytorch_one_layer.py
@@ -0,0 +1,38 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import torch
+import torch.onnx
+
+
+class MyModel(torch.nn.Module):
+ def __init__(self):
+ super(MyModel, self).__init__()
+ self.linear = torch.nn.Linear(in_features=3, out_features=1)
+ self.logistic = torch.nn.Sigmoid()
+
+ def forward(self, vec):
+ return self.logistic(self.linear(vec))
+
+
+def main():
+ model = MyModel()
+
+ # Omit training - just export randomly initialized network
+
+ data = torch.FloatTensor([[0.1, 0.2, 0.3],[0.4, 0.5, 0.6]])
+ torch.onnx.export(model,
+ data,
+ "one_layer.onnx",
+ input_names = ["input"],
+ output_names = ["output"],
+ dynamic_axes = {
+ "input": {0: "batch"},
+ "output": {0: "batch"},
+ },
+ opset_version=12)
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/model-integration/src/test/models/onnx/simple/matmul.onnx b/model-integration/src/test/models/onnx/simple/matmul.onnx
new file mode 100644
index 00000000000..9bb88406116
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/matmul.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/simple/matmul.py b/model-integration/src/test/models/onnx/simple/matmul.py
new file mode 100755
index 00000000000..beec55e9f5a
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/matmul.py
@@ -0,0 +1,27 @@
+# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [2, 3])
+INPUT2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [3, 4])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [2, 4])
+
+nodes = [
+ helper.make_node(
+ 'MatMul',
+ ['input1', 'input2'],
+ ['output'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'matmul',
+ [
+ INPUT1,
+ INPUT2,
+ ],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='matmul.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'matmul.onnx')
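
As a sanity check on the expected values in testMatMul above, the [2,3] x [3,4] product can be recomputed with a plain triple loop (hypothetical standalone check, not part of this commit):

public class MatMulCheck {
    public static void main(String[] args) {
        double[][] a = {{1, 2, 3}, {4, 5, 6}};
        double[][] b = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
        // c[i][j] = sum over k of a[i][k] * b[k][j]
        for (int i = 0; i < 2; i++) {
            StringBuilder row = new StringBuilder();
            for (int j = 0; j < 4; j++) {
                double sum = 0;
                for (int k = 0; k < 3; k++)
                    sum += a[i][k] * b[k][j];
                row.append(sum).append(" ");
            }
            System.out.println(row); // 38.0 44.0 50.0 56.0, then 83.0 98.0 113.0 128.0
        }
    }
}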