diff options
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java | 6 | ||||
-rw-r--r-- | model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java | 13 |
2 files changed, 19 insertions, 0 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index 4ae96bfd62f..5c353fcdf35 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -26,6 +26,9 @@ import java.util.concurrent.Executor; public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { + /** A dash in this key ensures it does not collide with feature names */ + private static final String missingValueKey = "missing-value"; + public static final String API_ROOT = "model-evaluation"; public static final String VERSION_V1 = "v1"; public static final String EVALUATE = "eval"; @@ -70,6 +73,9 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) { FunctionEvaluator evaluator = model.evaluatorOf(function); + + property(request, missingValueKey).ifPresent(missingValue -> evaluator.setMissingValue(Tensor.from(missingValue))); + for (Map.Entry<String, TensorType> argument : evaluator.function().argumentTypes().entrySet()) { property(request, argument.getKey()).ifPresent(value -> evaluator.bind(argument.getKey(), Tensor.from(argument.getValue(), value))); diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index 5d30fc93a4c..629c82f410a 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -81,6 +81,19 @@ public class ModelsEvaluationHandlerTest { } @Test + public void testXgBoostEvaluationWithMissingValue() { + Map<String, String> properties = new HashMap<>(); + properties.put("missing-value", "-1.0"); + properties.put("f56", "0.2"); + properties.put("f60", "0.3"); + properties.put("f109", "0.4"); + properties.put("non-existing-binding", "-1"); + String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval"; + String expected = "{\"cells\":[{\"address\":{},\"value\":-7.936679999999999}]}"; + assertResponse(url, properties, 200, expected); + } + + @Test public void testMnistSoftmaxDetails() { String url = "http://localhost:8080/model-evaluation/v1/mnist_softmax"; String expected = "{\"model\":\"mnist_softmax\",\"functions\":[{\"function\":\"default.add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add/eval\",\"arguments\":[{\"name\":\"Placeholder\",\"type\":\"tensor(d0[],d1[784])\"}]}]}"; |