1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
|
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.handler;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.container.jdisc.HttpRequest;
import com.yahoo.container.jdisc.HttpResponse;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.serialization.JsonFormat;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.Executors;
import static org.junit.Assert.assertEquals;
/**
 * Test helper which sends requests to a {@link ModelsEvaluationHandler} and asserts on its responses.
 */
class HandlerTester {

    private final ModelsEvaluationHandler handler;

    /**
     * Creates a tester whose handler evaluates the given models.
     * NOTE(review): the single-thread executor handed to the handler is never shut down by this tester.
     */
    HandlerTester(ModelsEvaluator models) {
        this.handler = new ModelsEvaluationHandler(models, Executors.newSingleThreadExecutor());
    }

    /** Asserts that GET and POST requests to {@code url} both return {@code expectedCode}; the body is not checked. */
    void assertResponse(String url, int expectedCode) {
        // The cast disambiguates between the String and Tensor overloads.
        assertResponse(url, Collections.emptyMap(), expectedCode, (String)null);
    }

    /** Asserts that GET and POST requests to {@code url} both return {@code expectedCode} and body {@code expectedResult}. */
    void assertResponse(String url, int expectedCode, String expectedResult) {
        assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult);
    }

    /** Like {@link #assertResponse(String, int, String)}, but adds {@code headers} to both requests. */
    void assertResponse(String url, int expectedCode, String expectedResult, Map<String, String> headers) {
        assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult, headers);
    }

    /** Like {@link #assertResponse(String, int, String)}, but sends {@code properties} as request properties. */
    void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult) {
        assertResponse(url, properties, expectedCode, expectedResult, Collections.emptyMap());
    }

    /**
     * Sends both a GET and a POST request built from the given url, properties and headers, and asserts that each
     * returns {@code expectedCode} and, unless {@code expectedResult} is null, exactly {@code expectedResult} as body.
     */
    void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult, Map<String, String> headers) {
        HttpRequest getRequest = createRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, properties, headers);
        HttpRequest postRequest = createRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.POST, properties, headers);
        assertResponse(getRequest, expectedCode, expectedResult);
        assertResponse(postRequest, expectedCode, expectedResult);
    }

    /**
     * Sends a GET request (only; no POST is issued) and asserts the status code and, unless
     * {@code expectedResult} is null, that the JSON body decodes to a tensor equal to {@code expectedResult}.
     */
    void assertResponse(String url, Map<String, String> properties, int expectedCode, Tensor expectedResult) {
        HttpRequest getRequest = createRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, properties, Collections.emptyMap());
        assertResponse(getRequest, expectedCode, expectedResult);
    }

    /** Handles {@code request} and asserts a JSON response with {@code expectedCode} and, if non-null, body {@code expectedResult}. */
    void assertResponse(HttpRequest request, int expectedCode, String expectedResult) {
        HttpResponse response = handleAndCheckStatus(request, expectedCode);
        if (expectedResult != null) {
            assertEquals(expectedResult, getContents(response));
        }
    }

    /**
     * Handles {@code request} and asserts a JSON response with {@code expectedCode}. If {@code expectedResult}
     * is non-null, the body is decoded as a tensor of the expected result's type and compared for equality.
     */
    void assertResponse(HttpRequest request, int expectedCode, Tensor expectedResult) {
        HttpResponse response = handleAndCheckStatus(request, expectedCode);
        if (expectedResult != null) {
            String contents = getContents(response);
            Tensor result = JsonFormat.decode(expectedResult.type(), contents.getBytes(StandardCharsets.UTF_8));
            assertEquals(expectedResult, result);
        }
    }

    /** Builds a test request and copies {@code headers} into its underlying jdisc request. */
    private static HttpRequest createRequest(String url,
                                             com.yahoo.jdisc.http.HttpRequest.Method method,
                                             Map<String, String> properties,
                                             Map<String, String> headers) {
        HttpRequest request = HttpRequest.createTestRequest(url, method, null, properties);
        headers.forEach((name, value) -> request.getJDiscRequest().headers().add(name, value));
        return request;
    }

    /** Runs the handler and asserts the checks shared by every response: JSON content type and the expected status. */
    private HttpResponse handleAndCheckStatus(HttpRequest request, int expectedCode) {
        HttpResponse response = handler.handle(request);
        assertEquals("application/json", response.getContentType());
        assertEquals(expectedCode, response.getStatus());
        return response;
    }

    /**
     * Renders the response body to a string. The bytes are decoded explicitly as UTF-8 (the JSON encoding, and
     * the charset the tensor path re-encodes with) rather than the JVM's platform default charset.
     */
    private String getContents(HttpResponse response) {
        try (ByteArrayOutputStream stream = new ByteArrayOutputStream()) {
            response.render(stream);
            return new String(stream.toByteArray(), StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed rendering response", e);
        }
    }

}
|