summaryrefslogtreecommitdiffstats
path: root/application
diff options
context:
space:
mode:
author Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-02-27 17:02:23 +0100
committer Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-02-27 18:13:08 +0100
commit 5271b5d7241aa2aa2538b2072b8cae9b8f3d689a (patch)
tree 12f025b12e86e5f9490b74dd2cae68283f779e67 /application
parent 6b40c6053b8542ae20a5bbe669f84f2d478fd697 (diff)
Replace `OnnxEvaluatorCache` with `OnnxRuntime`
Require an `OnnxRuntime` instance to create `OnnxEvaluator` instances. Cache the underlying `OrtSession` instead of the `OnnxEvaluator`. Move the static helpers for checking ONNX Runtime availability from `OnnxEvaluator` to `OnnxRuntime`.
Diffstat (limited to 'application')
-rw-r--r-- application/pom.xml | 6
-rw-r--r-- application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java | 10
2 files changed, 11 insertions, 5 deletions
diff --git a/application/pom.xml b/application/pom.xml
index 2193f0fe2e3..236bcb6d81a 100644
--- a/application/pom.xml
+++ b/application/pom.xml
@@ -182,6 +182,12 @@
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <!-- Required for ContainerModelEvaluationTest -->
+ <groupId>com.microsoft.onnxruntime</groupId>
+ <artifactId>onnxruntime</artifactId>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
diff --git a/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java b/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java
index f838d7a5481..cd5fd42a81a 100644
--- a/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java
+++ b/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.application.container;
-import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.application.Application;
import com.yahoo.application.Networking;
import com.yahoo.application.container.handler.Request;
@@ -40,7 +40,7 @@ public class ContainerModelEvaluationTest {
@Test
void testCreateApplicationInstanceWithModelEvaluation() {
- assumeTrue(OnnxEvaluator.isRuntimeAvailable());
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
try (Application application =
Application.fromApplicationPackage(new File("src/test/app-packages/model-evaluation"),
Networking.disable)) {
@@ -54,17 +54,17 @@ public class ContainerModelEvaluationTest {
}
{
- String expected = "{\"cells\":[{\"address\":{},\"value\":2.496898}]}";
+ String expected = "{\"type\":\"tensor()\",\"cells\":[{\"address\":{},\"value\":2.496898}]}";
assertResponse("http://localhost/model-evaluation/v1/xgboost_xgboost_2_2/eval?format.tensors=long", expected, jdisc);
}
{
- String expected = "{\"cells\":[{\"address\":{},\"value\":1.9130086820218188}]}";
+ String expected = "{\"type\":\"tensor()\",\"cells\":[{\"address\":{},\"value\":1.9130086820218188}]}";
assertResponse("http://localhost/model-evaluation/v1/lightgbm_regression/eval?format.tensors=long", expected, jdisc);
}
{
- String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":0.3006095290184021},{\"address\":{\"d0\":\"1\"},\"value\":0.33222490549087524},{\"address\":{\"d0\":\"2\"},\"value\":0.3671652674674988}]}";
+ String expected = "{\"type\":\"tensor<float>(d0[3])\",\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":0.3006095290184021},{\"address\":{\"d0\":\"1\"},\"value\":0.33222490549087524},{\"address\":{\"d0\":\"2\"},\"value\":0.36716532707214355}]}";
assertResponse("http://localhost/model-evaluation/v1/onnx_softmax_func/output/eval?format.tensors=long&input=" + inputTensor(), expected, jdisc);
assertResponse("http://localhost/model-evaluation/v1/onnx_softmax_func/default.output/eval?format.tensors=long&input=" + inputTensor(), expected, jdisc);
assertResponse("http://localhost/model-evaluation/v1/onnx_softmax_func/default/output/eval?format.tensors=long&input=" + inputTensor(), expected, jdisc);