summaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/test/java
diff options
context:
space:
mode:
Diffstat (limited to 'model-evaluation/src/test/java')
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java46
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java14
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java38
3 files changed, 69 insertions, 29 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
index 3b16be311a0..00531e373ee 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
@@ -12,25 +12,53 @@ import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.Executors;
+import java.util.function.Predicate;
import static org.junit.Assert.assertEquals;
+import static com.yahoo.slime.SlimeUtils.jsonToSlime;
class HandlerTester {
private final ModelsEvaluationHandler handler;
+ private static Predicate<String> nop() {
+ return s -> true;
+ }
+ private static Predicate<String> matchString(String expected) {
+ return s -> expected.equals(s);
+ }
+ public static Predicate<String> matchJson(String... expectedJson) {
+ var jExp = String.join("\n", expectedJson).replaceAll("'", "\"");
+ var expected = jsonToSlime(jExp);
+ return s -> {
+ var got = jsonToSlime(s);
+ boolean result = got.equalTo(expected);
+ if (!result) {
+ System.err.println("got:");
+ System.err.println(got);
+ System.err.println("expected:");
+ System.err.println(expected);
+ }
+ return result;
+ };
+ }
+
HandlerTester(ModelsEvaluator models) {
this.handler = new ModelsEvaluationHandler(models, Executors.newSingleThreadExecutor());
}
void assertResponse(String url, int expectedCode) {
- assertResponse(url, Map.of(), expectedCode, (String)null);
+ checkResponse(url, expectedCode, nop());
}
void assertResponse(String url, int expectedCode, String expectedResult) {
assertResponse(url, Map.of(), expectedCode, expectedResult);
}
+ void checkResponse(String url, int expectedCode, Predicate<String> check) {
+ checkResponse(url, Map.of(), expectedCode, check, Map.of());
+ }
+
void assertResponse(String url, int expectedCode, String expectedResult, Map<String, String> headers) {
assertResponse(url, Map.of(), expectedCode, expectedResult, headers);
}
@@ -40,14 +68,18 @@ class HandlerTester {
}
void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult, Map<String, String> headers) {
+ checkResponse(url, properties, expectedCode, matchString(expectedResult), headers);
+ }
+
+ void checkResponse(String url, Map<String, String> properties, int expectedCode, Predicate<String> check, Map<String, String> headers) {
HttpRequest getRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, null, properties);
HttpRequest postRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.POST, null, properties);
if (headers.size() > 0) {
headers.forEach((k,v) -> getRequest.getJDiscRequest().headers().add(k, v));
headers.forEach((k,v) -> postRequest.getJDiscRequest().headers().add(k, v));
}
- assertResponse(getRequest, expectedCode, expectedResult);
- assertResponse(postRequest, expectedCode, expectedResult);
+ checkResponse(getRequest, expectedCode, check);
+ checkResponse(postRequest, expectedCode, check);
}
void assertResponse(String url, Map<String, String> properties, int expectedCode, Tensor expectedResult) {
@@ -56,12 +88,14 @@ class HandlerTester {
}
void assertResponse(HttpRequest request, int expectedCode, String expectedResult) {
+ checkResponse(request, expectedCode, matchString(expectedResult));
+ }
+
+ void checkResponse(HttpRequest request, int expectedCode, Predicate<String> check) {
HttpResponse response = handler.handle(request);
assertEquals("application/json", response.getContentType());
assertEquals(expectedCode, response.getStatus());
- if (expectedResult != null) {
- assertEquals(expectedResult, getContents(response));
- }
+ assertEquals(true, check.test(getContents(response)));
}
void assertResponse(HttpRequest request, int expectedCode, Tensor expectedResult) {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index c52bf66626a..c0e5dd9ccda 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -14,7 +14,6 @@ import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import org.junit.BeforeClass;
-import org.junit.Ignore;
import org.junit.Test;
import java.util.HashMap;
@@ -262,7 +261,6 @@ public class ModelsEvaluationHandlerTest {
"tensor(a[2],b[2],c{},d[2]):{a:[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], b:[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]}");
}
- @Ignore
@Test
public void testMnistSavedEvaluateSpecificFunction() {
assumeTrue(OnnxEvaluator.isRuntimeAvailable());
@@ -270,7 +268,17 @@ public class ModelsEvaluationHandlerTest {
properties.put("input", inputTensor());
properties.put("format.tensors", "long");
String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval";
- String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.6319251673007533},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":-7.577770600619843E-4},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":-0.010707969042025622},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.6344759233540788},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":-0.17529455385847528},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":0.7490809723192187},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.022790284182901716},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.26799240657608936},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-0.3152438845465862},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":0.05949304847735276}]}";
+ Tensor expected = Tensor.from("tensor(d0[1],d1[10]):{"+
+ "{d0:0,d1:0}:-0.6319251673007533,"+
+ "{d0:0,d1:1}:-0.0007577770600619843,"+
+ "{d0:0,d1:2}:-0.010707969042025622,"+
+ "{d0:0,d1:3}:-0.6344759233540788,"+
+ "{d0:0,d1:4}:-0.17529455385847528,"+
+ "{d0:0,d1:5}:0.7490809723192187,"+
+ "{d0:0,d1:6}:-0.022790284182901716,"+
+ "{d0:0,d1:7}:0.26799240657608936,"+
+ "{d0:0,d1:8}:-0.3152438845465862,"+
+ "{d0:0,d1:9}:0.05949304847735276}");
handler.assertResponse(url, properties, 200, expected);
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
index cc3f2863068..29795fbcd95 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
@@ -12,7 +12,6 @@ import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import org.junit.BeforeClass;
-import org.junit.Ignore;
import org.junit.Test;
import java.io.File;
@@ -32,36 +31,35 @@ public class OnnxEvaluationHandlerTest {
handler = new HandlerTester(createModels());
}
- @Ignore
@Test
public void testListModels() {
String url = "http://localhost/model-evaluation/v1";
String expected = "{\"one_layer\":\"http://localhost/model-evaluation/v1/one_layer\"," +
"\"add_mul\":\"http://localhost/model-evaluation/v1/add_mul\"," +
"\"no_model\":\"http://localhost/model-evaluation/v1/no_model\"}";
- handler.assertResponse(url, 200, expected);
+ handler.checkResponse(url, 200, HandlerTester.matchJson(expected));
}
- @Ignore
@Test
public void testModelInfo() {
String url = "http://localhost/model-evaluation/v1/add_mul";
- String expected = "{\"model\":\"add_mul\",\"functions\":[" +
- "{\"function\":\"output1\"," +
- "\"info\":\"http://localhost/model-evaluation/v1/add_mul/output1\"," +
- "\"eval\":\"http://localhost/model-evaluation/v1/add_mul/output1/eval\"," +
- "\"arguments\":[" +
- "{\"name\":\"input1\",\"type\":\"tensor<float>(d0[1])\"}," +
- "{\"name\":\"input2\",\"type\":\"tensor<float>(d0[1])\"}" +
- "]}," +
- "{\"function\":\"output2\"," +
- "\"info\":\"http://localhost/model-evaluation/v1/add_mul/output2\"," +
- "\"eval\":\"http://localhost/model-evaluation/v1/add_mul/output2/eval\"," +
- "\"arguments\":[" +
- "{\"name\":\"input1\",\"type\":\"tensor<float>(d0[1])\"}," +
- "{\"name\":\"input2\",\"type\":\"tensor<float>(d0[1])\"}" +
- "]}]}";
- handler.assertResponse(url, 200, expected);
+ var check = HandlerTester.matchJson(
+ "{'model':'add_mul','functions':[",
+ " {'function':'output1',",
+ " 'info':'http://localhost/model-evaluation/v1/add_mul/output1',",
+ " 'eval':'http://localhost/model-evaluation/v1/add_mul/output1/eval',",
+ " 'arguments':[",
+ " {'name':'input1','type':'tensor<float>(d0[1])'},",
+ " {'name':'input2','type':'tensor<float>(d0[1])'}",
+ " ]},",
+ " {'function':'output2',",
+ " 'info':'http://localhost/model-evaluation/v1/add_mul/output2',",
+ " 'eval':'http://localhost/model-evaluation/v1/add_mul/output2/eval',",
+ " 'arguments':[",
+ " {'name':'input1','type':'tensor<float>(d0[1])'},",
+ " {'name':'input2','type':'tensor<float>(d0[1])'}",
+ " ]}]}");
+ handler.checkResponse(url, 200, check);
}
@Test