diff options
author | Lester Solbakken <lesters@oath.com> | 2020-02-02 17:39:44 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-02-02 17:39:44 +0100 |
commit | f656ff5c15d95905f48d5829278ec241f1941577 (patch) | |
tree | 41d1fd4f8bc22df172acac42bfc39abd136036c0 /model-evaluation/src | |
parent | 99f3a7193090cfcd6b5fdbbe612f53d892f9d86b (diff) |
Add support for importing LightGBM models
Diffstat (limited to 'model-evaluation/src')
5 files changed, 79 insertions, 4 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 994f6dd9b64..e373a54bcd1 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -2,6 +2,7 @@ package ai.vespa.models.evaluation; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -61,6 +62,21 @@ public class FunctionEvaluator { } /** + * Binds the given variable referred in this expression to the given value. + * String values are not yet supported in tensors. + * + * @param name the variable to bind + * @param value the value this becomes bound to + * @return this for chaining + */ + public FunctionEvaluator bind(String name, String value) { + if (evaluated) + throw new IllegalStateException("Cannot bind a new value in a used evaluator"); + context.put(name, new StringValue(value)); + return this; + } + + /** * Sets the default value to use for variables which are not bound * * @param value the default value diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index 5c353fcdf35..de23a8c6526 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -77,8 +77,14 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { property(request, missingValueKey).ifPresent(missingValue -> evaluator.setMissingValue(Tensor.from(missingValue))); for (Map.Entry<String, TensorType> argument : evaluator.function().argumentTypes().entrySet()) { - property(request, argument.getKey()).ifPresent(value -> evaluator.bind(argument.getKey(), - Tensor.from(argument.getValue(), value))); + Optional<String> value = property(request, argument.getKey()); + if (value.isPresent()) { + try { + evaluator.bind(argument.getKey(), Tensor.from(argument.getValue(), value.get())); + } catch (IllegalArgumentException e) { + evaluator.bind(argument.getKey(), value.get()); // since we don't yet support tensors with string values + } + } } Tensor result = evaluator.evaluate(); return new Response(200, JsonFormat.encode(result)); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java index 9320ac3fad8..0d13b7d4660 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -25,7 +25,7 @@ public class MlModelsImportingTest { public void testImportingModels() { ModelTester tester = new ModelTester("src/test/resources/config/models/"); - assertEquals(4, tester.models().size()); + assertEquals(5, tester.models().size()); // TODO: When we get type information in Models, replace the evaluator.context().names() check below by that { @@ -47,7 +47,24 @@ public class MlModelsImportingTest { } { + Model lightgbm = tester.models().get("lightgbm_regression"); + // Function + assertEquals(1, lightgbm.functions().size()); + ExpressionFunction function = tester.assertFunction("lightgbm_regression", + "(optimized sum of condition trees of size 480 bytes)", + lightgbm); + assertEquals("tensor()", function.returnType().get().toString()); + assertEquals("categorical_1, categorical_2, numerical_1, numerical_2", commaSeparated(function.arguments())); + function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg))); + + // Evaluator + FunctionEvaluator evaluator = lightgbm.evaluatorOf(); + assertEquals("categorical_1, categorical_2, numerical_1, numerical_2", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + assertEquals(1.91300868202, evaluator.evaluate().sum().asDouble(), delta); + } + + { Model onnxMnistSoftmax = tester.models().get("mnist_softmax"); // Function diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index 629c82f410a..c9e49d3be02 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -56,7 +56,7 @@ public class ModelsEvaluationHandlerTest { public void testListModels() { String url = "http://localhost/model-evaluation/v1"; String expected = - "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\"}"; + "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost/model-evaluation/v1/lightgbm_regression\"}"; assertResponse(url, 200, expected); } @@ -94,6 +94,39 @@ public class ModelsEvaluationHandlerTest { } @Test + public void testLightGBMEvaluationWithoutBindings() { + String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval"; + String expected = "{\"cells\":[{\"address\":{},\"value\":1.9130086820218188}]}"; + assertResponse(url, 200, expected); + } + + @Test + public void testLightGBMEvaluationWithBindings() { + Map<String, String> properties = new HashMap<>(); + properties.put("numerical_1", "0.1"); + properties.put("numerical_2", "0.2"); + properties.put("categorical_1", "a"); + properties.put("categorical_2", "i"); + properties.put("non-existing-binding", "-1"); + String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval"; + String expected = "{\"cells\":[{\"address\":{},\"value\":2.054697758469921}]}"; + assertResponse(url, properties, 200, expected); + } + + @Test + public void testLightGBMEvaluationWithMissingValue() { + Map<String, String> properties = new HashMap<>(); + properties.put("missing-value", "-1.0"); + properties.put("numerical_2", "0.5"); + properties.put("categorical_1", "b"); + properties.put("categorical_2", "j"); + properties.put("non-existing-binding", "-1"); + String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval"; + String expected = "{\"cells\":[{\"address\":{},\"value\":2.0745534018208094}]}"; + assertResponse(url, properties, 200, expected); + } + + @Test public void testMnistSoftmaxDetails() { String url = "http://localhost:8080/model-evaluation/v1/mnist_softmax"; String expected = "{\"model\":\"mnist_softmax\",\"functions\":[{\"function\":\"default.add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add/eval\",\"arguments\":[{\"name\":\"Placeholder\",\"type\":\"tensor(d0[],d1[784])\"}]}]}"; diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg index c25c5ba555b..385115b7cd4 100644 --- a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg +++ b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg @@ -26,3 +26,6 @@ rankprofile[3].fef.property[3].name "rankingExpression(serving_default.y).input. rankprofile[3].fef.property[3].value "tensor(d0[],d1[784])" rankprofile[3].fef.property[4].name "rankingExpression(serving_default.y).type" rankprofile[3].fef.property[4].value "tensor(d1[10])" +rankprofile[4].name "lightgbm_regression" +rankprofile[4].fef.property[0].name "rankingExpression(lightgbm_regression).rankingScript" +rankprofile[4].fef.property[0].value "if (!(numerical_2 >= 0.46643291586559305), 2.1594397038037663, if (categorical_2 in ["k", "l", "m"], 2.235297305276056, 2.1792953471546546)) + if (categorical_1 in ["d", "e"], 0.03070842919354316, if (!(numerical_1 >= 0.5102250691730842), -0.04439151147520909, 0.005117411709368601)) + if (!(numerical_2 >= 0.668665477622446), if (!(numerical_2 >= 0.008118820676863816), -0.15361238490967524, -0.01192330846157292), 0.03499044894987518) + if (!(numerical_1 >= 0.5201391072644542), -0.02141000620783247, if (categorical_1 in ["a", "b"], -0.004121485787596721, 0.04534090904886873)) + if (categorical_2 in ["k", "l", "m"], if (!(numerical_2 >= 0.27283279016959255), -0.01924803254356527, 0.03643772842347651), -0.02701711918923075)" |