aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/test
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-02-02 17:39:44 +0100
committerLester Solbakken <lesters@oath.com>2020-02-02 17:39:44 +0100
commitf656ff5c15d95905f48d5829278ec241f1941577 (patch)
tree41d1fd4f8bc22df172acac42bfc39abd136036c0 /model-evaluation/src/test
parent99f3a7193090cfcd6b5fdbbe612f53d892f9d86b (diff)
Add support for importing LightGBM models
Diffstat (limited to 'model-evaluation/src/test')
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java19
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java35
-rw-r--r--model-evaluation/src/test/resources/config/models/rank-profiles.cfg3
3 files changed, 55 insertions, 2 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index 9320ac3fad8..0d13b7d4660 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -25,7 +25,7 @@ public class MlModelsImportingTest {
public void testImportingModels() {
ModelTester tester = new ModelTester("src/test/resources/config/models/");
- assertEquals(4, tester.models().size());
+ assertEquals(5, tester.models().size());
// TODO: When we get type information in Models, replace the evaluator.context().names() check below by that
{
@@ -47,7 +47,24 @@ public class MlModelsImportingTest {
}
{
+ Model lightgbm = tester.models().get("lightgbm_regression");
+ // Function
+ assertEquals(1, lightgbm.functions().size());
+ ExpressionFunction function = tester.assertFunction("lightgbm_regression",
+ "(optimized sum of condition trees of size 480 bytes)",
+ lightgbm);
+ assertEquals("tensor()", function.returnType().get().toString());
+ assertEquals("categorical_1, categorical_2, numerical_1, numerical_2", commaSeparated(function.arguments()));
+ function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg)));
+
+ // Evaluator
+ FunctionEvaluator evaluator = lightgbm.evaluatorOf();
+ assertEquals("categorical_1, categorical_2, numerical_1, numerical_2", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ assertEquals(1.91300868202, evaluator.evaluate().sum().asDouble(), delta);
+ }
+
+ {
Model onnxMnistSoftmax = tester.models().get("mnist_softmax");
// Function
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index 629c82f410a..c9e49d3be02 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -56,7 +56,7 @@ public class ModelsEvaluationHandlerTest {
public void testListModels() {
String url = "http://localhost/model-evaluation/v1";
String expected =
- "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\"}";
+ "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost/model-evaluation/v1/lightgbm_regression\"}";
assertResponse(url, 200, expected);
}
@@ -94,6 +94,39 @@ public class ModelsEvaluationHandlerTest {
}
@Test
+ public void testLightGBMEvaluationWithoutBindings() {
+ String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval";
+ String expected = "{\"cells\":[{\"address\":{},\"value\":1.9130086820218188}]}";
+ assertResponse(url, 200, expected);
+ }
+
+ @Test
+ public void testLightGBMEvaluationWithBindings() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("numerical_1", "0.1");
+ properties.put("numerical_2", "0.2");
+ properties.put("categorical_1", "a");
+ properties.put("categorical_2", "i");
+ properties.put("non-existing-binding", "-1");
+ String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval";
+ String expected = "{\"cells\":[{\"address\":{},\"value\":2.054697758469921}]}";
+ assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
+ public void testLightGBMEvaluationWithMissingValue() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("missing-value", "-1.0");
+ properties.put("numerical_2", "0.5");
+ properties.put("categorical_1", "b");
+ properties.put("categorical_2", "j");
+ properties.put("non-existing-binding", "-1");
+ String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval";
+ String expected = "{\"cells\":[{\"address\":{},\"value\":2.0745534018208094}]}";
+ assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
public void testMnistSoftmaxDetails() {
String url = "http://localhost:8080/model-evaluation/v1/mnist_softmax";
String expected = "{\"model\":\"mnist_softmax\",\"functions\":[{\"function\":\"default.add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add/eval\",\"arguments\":[{\"name\":\"Placeholder\",\"type\":\"tensor(d0[],d1[784])\"}]}]}";
diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
index c25c5ba555b..385115b7cd4 100644
--- a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
+++ b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
@@ -26,3 +26,6 @@ rankprofile[3].fef.property[3].name "rankingExpression(serving_default.y).input.
rankprofile[3].fef.property[3].value "tensor(d0[],d1[784])"
rankprofile[3].fef.property[4].name "rankingExpression(serving_default.y).type"
rankprofile[3].fef.property[4].value "tensor(d1[10])"
+rankprofile[4].name "lightgbm_regression"
+rankprofile[4].fef.property[0].name "rankingExpression(lightgbm_regression).rankingScript"
+rankprofile[4].fef.property[0].value "if (!(numerical_2 >= 0.46643291586559305), 2.1594397038037663, if (categorical_2 in ["k", "l", "m"], 2.235297305276056, 2.1792953471546546)) + if (categorical_1 in ["d", "e"], 0.03070842919354316, if (!(numerical_1 >= 0.5102250691730842), -0.04439151147520909, 0.005117411709368601)) + if (!(numerical_2 >= 0.668665477622446), if (!(numerical_2 >= 0.008118820676863816), -0.15361238490967524, -0.01192330846157292), 0.03499044894987518) + if (!(numerical_1 >= 0.5201391072644542), -0.02141000620783247, if (categorical_1 in ["a", "b"], -0.004121485787596721, 0.04534090904886873)) + if (categorical_2 in ["k", "l", "m"], if (!(numerical_2 >= 0.27283279016959255), -0.01924803254356527, 0.03643772842347651), -0.02701711918923075)"