diff options
Diffstat (limited to 'application')
-rw-r--r-- | application/pom.xml | 6 | ||||
-rw-r--r-- | application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java | 10 |
2 files changed, 11 insertions, 5 deletions
diff --git a/application/pom.xml b/application/pom.xml index 2193f0fe2e3..236bcb6d81a 100644 --- a/application/pom.xml +++ b/application/pom.xml @@ -182,6 +182,12 @@ <artifactId>junit-jupiter-engine</artifactId> <scope>test</scope> </dependency> + <dependency> + <!-- Required for ContainerModelEvaluationTest --> + <groupId>com.microsoft.onnxruntime</groupId> + <artifactId>onnxruntime</artifactId> + <scope>test</scope> + </dependency> </dependencies> <build> diff --git a/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java b/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java index f838d7a5481..cd5fd42a81a 100644 --- a/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java +++ b/application/src/test/java/com/yahoo/application/container/ContainerModelEvaluationTest.java @@ -1,7 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.application.container; -import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.application.Application; import com.yahoo.application.Networking; import com.yahoo.application.container.handler.Request; @@ -40,7 +40,7 @@ public class ContainerModelEvaluationTest { @Test void testCreateApplicationInstanceWithModelEvaluation() { - assumeTrue(OnnxEvaluator.isRuntimeAvailable()); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); try (Application application = Application.fromApplicationPackage(new File("src/test/app-packages/model-evaluation"), Networking.disable)) { @@ -54,17 +54,17 @@ public class ContainerModelEvaluationTest { } { - String expected = "{\"cells\":[{\"address\":{},\"value\":2.496898}]}"; + String expected = "{\"type\":\"tensor()\",\"cells\":[{\"address\":{},\"value\":2.496898}]}"; assertResponse("http://localhost/model-evaluation/v1/xgboost_xgboost_2_2/eval?format.tensors=long", expected, jdisc); } { - String expected = "{\"cells\":[{\"address\":{},\"value\":1.9130086820218188}]}"; + String expected = "{\"type\":\"tensor()\",\"cells\":[{\"address\":{},\"value\":1.9130086820218188}]}"; assertResponse("http://localhost/model-evaluation/v1/lightgbm_regression/eval?format.tensors=long", expected, jdisc); } { - String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":0.3006095290184021},{\"address\":{\"d0\":\"1\"},\"value\":0.33222490549087524},{\"address\":{\"d0\":\"2\"},\"value\":0.3671652674674988}]}"; + String expected = "{\"type\":\"tensor<float>(d0[3])\",\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":0.3006095290184021},{\"address\":{\"d0\":\"1\"},\"value\":0.33222490549087524},{\"address\":{\"d0\":\"2\"},\"value\":0.36716532707214355}]}"; assertResponse("http://localhost/model-evaluation/v1/onnx_softmax_func/output/eval?format.tensors=long&input=" + inputTensor(), expected, jdisc); assertResponse("http://localhost/model-evaluation/v1/onnx_softmax_func/default.output/eval?format.tensors=long&input=" + inputTensor(), expected, jdisc); assertResponse("http://localhost/model-evaluation/v1/onnx_softmax_func/default/output/eval?format.tensors=long&input=" + inputTensor(), expected, jdisc); |