aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/test
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2022-02-23 14:45:31 +0100
committerLester Solbakken <lesters@oath.com>2022-02-23 14:45:31 +0100
commitf59b5aa627018e6574249488c802fbccd8360688 (patch)
treec027015867cc52dbd2d2dccffc0ea08f7a274e14 /config-model/src/test
parentd4d0db00c51be4b195721b331c3b958bf07b882d (diff)
Add probing of ONNX models for type resolving in Java
Diffstat (limited to 'config-model/src/test')
-rwxr-xr-xconfig-model/src/test/cfg/application/onnx_probe/files/create_dynamic_model.py19
-rw-r--r--config-model/src/test/cfg/application/onnx_probe/files/dynamic_model.onnx21
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java50
3 files changed, 90 insertions, 0 deletions
diff --git a/config-model/src/test/cfg/application/onnx_probe/files/create_dynamic_model.py b/config-model/src/test/cfg/application/onnx_probe/files/create_dynamic_model.py
new file mode 100755
index 00000000000..b493e394ee4
--- /dev/null
+++ b/config-model/src/test/cfg/application/onnx_probe/files/create_dynamic_model.py
@@ -0,0 +1,19 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+import numpy as np
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, ["batch", 2])
+INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, ["batch", 2])
+OUTPUT = helper.make_tensor_value_info('out', TensorProto.FLOAT, ["batch", "dim1", "dim2"])
+
+SHAPE = helper.make_tensor('shape', TensorProto.INT64, dims=[3], vals=np.array([1,2,2]).astype(np.int64))
+
+nodes = [
+ helper.make_node('Concat', ['input1', 'input2'], ['concat'], axis=1),
+ helper.make_node('Reshape', ['concat', 'shape'], ['out']),
+]
+graph_def = helper.make_graph(nodes, 'simple_scoring', [INPUT_1, INPUT_2], [OUTPUT], [SHAPE])
+model_def = helper.make_model(graph_def, producer_name='create_dynamic_model.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'dynamic_model_2.onnx')
diff --git a/config-model/src/test/cfg/application/onnx_probe/files/dynamic_model.onnx b/config-model/src/test/cfg/application/onnx_probe/files/dynamic_model.onnx
new file mode 100644
index 00000000000..28c600e2a09
--- /dev/null
+++ b/config-model/src/test/cfg/application/onnx_probe/files/dynamic_model.onnx
@@ -0,0 +1,21 @@
+create_dynamic_model_2.py:Ö
+-
+input1
+input2concat"Concat*
+axis 
+
+concat
+shapeout"Reshapesimple_scoring*:BshapeZ
+input1
+
+batch
+Z
+input2
+
+batch
+b&
+out
+
+batch
+dim1
+dim2B \ No newline at end of file
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java
new file mode 100644
index 00000000000..55e2f6f018c
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java
@@ -0,0 +1,50 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.ml;
+
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.application.provider.FilesApplicationPackage;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+public class OnnxModelProbeTest {
+
+ @Test
+ public void testProbedOutputTypes() throws IOException {
+
+ Path appDir = Path.fromString("src/test/cfg/application/onnx_probe");
+ Path storedAppDir = appDir.append("copy");
+ try {
+ FilesApplicationPackage app = FilesApplicationPackage.fromFile(appDir.toFile());
+ Path modelPath = Path.fromString("files/dynamic_model.onnx");
+ String output = "out";
+ Map<String, TensorType> inputTypes = Map.of(
+ "input1", TensorType.fromSpec("tensor<float>(d0[1],d1[2])"),
+ "input2", TensorType.fromSpec("tensor<float>(d0[1],d1[2])"));
+ TensorType expected = TensorType.fromSpec("tensor<float>(d0[1],d1[2],d2[2])");
+
+ // Can't test model probing directly as 'vespa-analyze-onnx-model' is not available
+
+ OnnxModelProbe.writeProbedOutputType(app, modelPath, output, inputTypes, expected);
+
+ // Test loading from generated info
+ storedAppDir.toFile().mkdirs();
+ IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ app = FilesApplicationPackage.fromFile(storedAppDir.toFile());
+ TensorType outputType = OnnxModelProbe.probeModel(app, modelPath, output, inputTypes);
+ assertEquals(outputType, expected);
+
+ } finally {
+ IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ IOUtils.recursiveDeleteDir(storedAppDir.toFile());
+ }
+ }
+
+}