about summary refs log tree commit diff stats
path: root/model-evaluation
diff options
context:
space:
mode:
author Lester Solbakken <lesters@oath.com> 2020-02-02 17:39:44 +0100
committer Lester Solbakken <lesters@oath.com> 2020-02-02 17:39:44 +0100
commit f656ff5c15d95905f48d5829278ec241f1941577 (patch)
tree 41d1fd4f8bc22df172acac42bfc39abd136036c0 /model-evaluation
parent 99f3a7193090cfcd6b5fdbbe612f53d892f9d86b (diff)
Add support for importing LightGBM models
Diffstat (limited to 'model-evaluation')
-rw-r--r-- model-evaluation/abi-spec.json 1
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java 16
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java 10
-rw-r--r-- model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java 19
-rw-r--r-- model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java 35
-rw-r--r-- model-evaluation/src/test/resources/config/models/rank-profiles.cfg 3
6 files changed, 80 insertions, 4 deletions
diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json
index 5a75c8b31ea..d465464de7f 100644
--- a/model-evaluation/abi-spec.json
+++ b/model-evaluation/abi-spec.json
@@ -8,6 +8,7 @@
"methods": [
"public ai.vespa.models.evaluation.FunctionEvaluator bind(java.lang.String, com.yahoo.tensor.Tensor)",
"public ai.vespa.models.evaluation.FunctionEvaluator bind(java.lang.String, double)",
+ "public ai.vespa.models.evaluation.FunctionEvaluator bind(java.lang.String, java.lang.String)",
"public ai.vespa.models.evaluation.FunctionEvaluator setMissingValue(com.yahoo.tensor.Tensor)",
"public ai.vespa.models.evaluation.FunctionEvaluator setMissingValue(double)",
"public com.yahoo.tensor.Tensor evaluate()",
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index 994f6dd9b64..e373a54bcd1 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -2,6 +2,7 @@
package ai.vespa.models.evaluation;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -61,6 +62,21 @@ public class FunctionEvaluator {
}
/**
+ * Binds the given variable referred to in this expression to the given string value.
+ * String values are not yet supported in tensors, so the value is wrapped in a
+ * {@code StringValue} and put into the evaluation context under the given name.
+ *
+ * @param name the variable to bind
+ * @param value the string value this becomes bound to
+ * @return this for chaining
+ * @throws IllegalStateException if {@code evaluated} is set, i.e. this evaluator has already been used
+ */
+ public FunctionEvaluator bind(String name, String value) {
+ if (evaluated)
+ throw new IllegalStateException("Cannot bind a new value in a used evaluator");
+ context.put(name, new StringValue(value));
+ return this;
+ }
+
+ /**
* Sets the default value to use for variables which are not bound
*
* @param value the default value
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
index 5c353fcdf35..de23a8c6526 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
@@ -77,8 +77,14 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
property(request, missingValueKey).ifPresent(missingValue -> evaluator.setMissingValue(Tensor.from(missingValue)));
for (Map.Entry<String, TensorType> argument : evaluator.function().argumentTypes().entrySet()) {
- property(request, argument.getKey()).ifPresent(value -> evaluator.bind(argument.getKey(),
- Tensor.from(argument.getValue(), value)));
+ Optional<String> value = property(request, argument.getKey());
+ if (value.isPresent()) {
+ try {
+ evaluator.bind(argument.getKey(), Tensor.from(argument.getValue(), value.get()));
+ } catch (IllegalArgumentException e) {
+ evaluator.bind(argument.getKey(), value.get()); // since we don't yet support tensors with string values
+ }
+ }
}
Tensor result = evaluator.evaluate();
return new Response(200, JsonFormat.encode(result));
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index 9320ac3fad8..0d13b7d4660 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -25,7 +25,7 @@ public class MlModelsImportingTest {
public void testImportingModels() {
ModelTester tester = new ModelTester("src/test/resources/config/models/");
- assertEquals(4, tester.models().size());
+ assertEquals(5, tester.models().size());
// TODO: When we get type information in Models, replace the evaluator.context().names() check below by that
{
@@ -47,7 +47,24 @@ public class MlModelsImportingTest {
}
{
+ Model lightgbm = tester.models().get("lightgbm_regression");
+ // Function
+ assertEquals(1, lightgbm.functions().size());
+ ExpressionFunction function = tester.assertFunction("lightgbm_regression",
+ "(optimized sum of condition trees of size 480 bytes)",
+ lightgbm);
+ assertEquals("tensor()", function.returnType().get().toString());
+ assertEquals("categorical_1, categorical_2, numerical_1, numerical_2", commaSeparated(function.arguments()));
+ function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg)));
+
+ // Evaluator
+ FunctionEvaluator evaluator = lightgbm.evaluatorOf();
+ assertEquals("categorical_1, categorical_2, numerical_1, numerical_2", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ assertEquals(1.91300868202, evaluator.evaluate().sum().asDouble(), delta);
+ }
+
+ {
Model onnxMnistSoftmax = tester.models().get("mnist_softmax");
// Function
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index 629c82f410a..c9e49d3be02 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -56,7 +56,7 @@ public class ModelsEvaluationHandlerTest {
public void testListModels() {
String url = "http://localhost/model-evaluation/v1";
String expected =
- "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\"}";
+ "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost/model-evaluation/v1/lightgbm_regression\"}";
assertResponse(url, 200, expected);
}
@@ -94,6 +94,39 @@ public class ModelsEvaluationHandlerTest {
}
@Test
+ public void testLightGBMEvaluationWithoutBindings() { // evaluates lightgbm_regression with no request properties, so no argument is bound
+ String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval";
+ String expected = "{\"cells\":[{\"address\":{},\"value\":1.9130086820218188}]}"; // agrees, within the test delta, with the value MlModelsImportingTest asserts for an evaluator with no bindings
+ assertResponse(url, 200, expected);
+ }
+
@Test
+ public void testLightGBMEvaluationWithBindings() { // binds all four arguments of lightgbm_regression through request properties
+ Map<String, String> properties = new HashMap<>();
+ properties.put("numerical_1", "0.1");
+ properties.put("numerical_2", "0.2");
+ properties.put("categorical_1", "a"); // a category label, not a number; presumably bound through the handler's string fallback - confirm
+ properties.put("categorical_2", "i"); // likewise a category label
+ properties.put("non-existing-binding", "-1"); // not an argument of the function; the handler only looks up properties named after declared arguments
+ String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval";
+ String expected = "{\"cells\":[{\"address\":{},\"value\":2.054697758469921}]}"; // sum of the leaf values the rank-profiles.cfg expression selects for the four bound arguments
+ assertResponse(url, properties, 200, expected);
+ }
+
@Test
+ public void testLightGBMEvaluationWithMissingValue() { // leaves numerical_1 unbound and supplies a missing value instead
+ Map<String, String> properties = new HashMap<>();
+ properties.put("missing-value", "-1.0"); // presumably the handler's missing-value key, used for arguments that are not bound - confirm
+ properties.put("numerical_2", "0.5");
+ properties.put("categorical_1", "b"); // a category label, not a number
+ properties.put("categorical_2", "j"); // likewise a category label
+ properties.put("non-existing-binding", "-1"); // not an argument of the function; the handler only looks up properties named after declared arguments
+ String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval";
+ String expected = "{\"cells\":[{\"address\":{},\"value\":2.0745534018208094}]}"; // matches the rank-profiles.cfg expression evaluated with numerical_1 = -1.0
+ assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
public void testMnistSoftmaxDetails() {
String url = "http://localhost:8080/model-evaluation/v1/mnist_softmax";
String expected = "{\"model\":\"mnist_softmax\",\"functions\":[{\"function\":\"default.add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add/eval\",\"arguments\":[{\"name\":\"Placeholder\",\"type\":\"tensor(d0[],d1[784])\"}]}]}";
diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
index c25c5ba555b..385115b7cd4 100644
--- a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
+++ b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
@@ -26,3 +26,6 @@ rankprofile[3].fef.property[3].name "rankingExpression(serving_default.y).input.
rankprofile[3].fef.property[3].value "tensor(d0[],d1[784])"
rankprofile[3].fef.property[4].name "rankingExpression(serving_default.y).type"
rankprofile[3].fef.property[4].value "tensor(d1[10])"
+rankprofile[4].name "lightgbm_regression"
+rankprofile[4].fef.property[0].name "rankingExpression(lightgbm_regression).rankingScript"
+rankprofile[4].fef.property[0].value "if (!(numerical_2 >= 0.46643291586559305), 2.1594397038037663, if (categorical_2 in ["k", "l", "m"], 2.235297305276056, 2.1792953471546546)) + if (categorical_1 in ["d", "e"], 0.03070842919354316, if (!(numerical_1 >= 0.5102250691730842), -0.04439151147520909, 0.005117411709368601)) + if (!(numerical_2 >= 0.668665477622446), if (!(numerical_2 >= 0.008118820676863816), -0.15361238490967524, -0.01192330846157292), 0.03499044894987518) + if (!(numerical_1 >= 0.5201391072644542), -0.02141000620783247, if (categorical_1 in ["a", "b"], -0.004121485787596721, 0.04534090904886873)) + if (categorical_2 in ["k", "l", "m"], if (!(numerical_2 >= 0.27283279016959255), -0.01924803254356527, 0.03643772842347651), -0.02701711918923075)"