diff options
author | Lester Solbakken <lesters@oath.com> | 2022-02-23 14:45:31 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2022-03-10 14:41:08 +0100 |
commit | 3c87f16724e2d4e88809b8c56fd57397b410c8a9 (patch) | |
tree | a9ec8a3e03eabbc469e75b3e14e4bf05d85ea77a /config-model/src/test/java/com | |
parent | 80abc0659f07445536d92c59fae58dfb1f0ecae8 (diff) |
Add probing of ONNX models for type resolving in Java
Diffstat (limited to 'config-model/src/test/java/com')
-rw-r--r-- | config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java new file mode 100644 index 00000000000..55e2f6f018c --- /dev/null +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/OnnxModelProbeTest.java @@ -0,0 +1,50 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.ml; + +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.application.provider.FilesApplicationPackage; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import java.io.IOException; +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +public class OnnxModelProbeTest { + + @Test + public void testProbedOutputTypes() throws IOException { + + Path appDir = Path.fromString("src/test/cfg/application/onnx_probe"); + Path storedAppDir = appDir.append("copy"); + try { + FilesApplicationPackage app = FilesApplicationPackage.fromFile(appDir.toFile()); + Path modelPath = Path.fromString("files/dynamic_model.onnx"); + String output = "out"; + Map<String, TensorType> inputTypes = Map.of( + "input1", TensorType.fromSpec("tensor<float>(d0[1],d1[2])"), + "input2", TensorType.fromSpec("tensor<float>(d0[1],d1[2])")); + TensorType expected = TensorType.fromSpec("tensor<float>(d0[1],d1[2],d2[2])"); + + // Can't test model probing directly as 'vespa-analyze-onnx-model' is not available + + OnnxModelProbe.writeProbedOutputType(app, modelPath, output, inputTypes, expected); + + // Test loading from generated info + storedAppDir.toFile().mkdirs(); + IOUtils.copyDirectory(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + app = FilesApplicationPackage.fromFile(storedAppDir.toFile()); + TensorType outputType = OnnxModelProbe.probeModel(app, modelPath, output, inputTypes); + assertEquals(outputType, expected); + + } finally { + IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + IOUtils.recursiveDeleteDir(storedAppDir.toFile()); + } + } + +} |