aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-10-11 09:42:05 +0200
committerLester Solbakken <lesters@oath.com>2019-10-11 09:42:05 +0200
commit3acec4a95bc2f75f8384bde14d35f3a5c073460b (patch)
tree14d3874a55ebf1493842ba8a8a8f029ba6f1530b /model-evaluation
parentd78ad93a081552d5f671e266a15c0de770305c92 (diff)
Set default missing value to NaN for model evaluation
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/abi-spec.json4
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java4
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java9
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java2
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java2
5 files changed, 12 insertions, 9 deletions
diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json
index 2657779c5cf..c79883450e6 100644
--- a/model-evaluation/abi-spec.json
+++ b/model-evaluation/abi-spec.json
@@ -37,9 +37,7 @@
"public java.util.Set arguments()",
"public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)"
],
- "fields": [
- "public static final com.yahoo.searchlib.rankingexpression.evaluation.Value defaultContextValue"
- ]
+ "fields": []
},
"ai.vespa.models.evaluation.Model": {
"superClass": "java.lang.Object",
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index 8c728867f45..9db26d7ecd8 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -65,8 +65,8 @@ public class FunctionEvaluator {
public Tensor evaluate() {
for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
- if (argument.getValue().rank() == 0) continue; // Scalar argumentds can be skipped (defaults to 0)
- if (context.get(argument.getKey()) == LazyArrayContext.defaultContextValue)
+ if (argument.getValue().rank() == 0) continue; // Scalar arguments can be skipped (defaults to 0)
+ if (context.isMissing(argument.getKey()))
throw new IllegalStateException("Missing argument '" + argument.getKey() +
"': Must be bound to a value of type " + argument.getValue());
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 55da2e78894..bc80989f030 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -6,7 +6,9 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
@@ -27,6 +29,9 @@ public class Model {
/** The prefix generated by mode-integration/../IntermediateOperation */
private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_";
+ /** Default value to return if value is not supplied */
+ private final static Value missingValue = DoubleValue.frozen(Double.NaN);
+
private final String name;
/** Free functions */
@@ -61,7 +66,7 @@ public class Model {
ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>();
for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
try {
- LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this);
+ LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this, missingValue);
contextBuilder.put(function.getValue().getName(), context);
if ( ! function.getValue().returnType().isPresent()) {
functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty));
@@ -135,7 +140,7 @@ public class Model {
return context;
}
- /** Returns the function withe the given name, or null if none */ // TODO: Parameter overloading?
+ /** Returns the function with the given name, or null if none */ // TODO: Parameter overloading?
ExpressionFunction function(String name) {
for (ExpressionFunction function : functions)
if (function.getName().equals(name))
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index db892dce593..9320ac3fad8 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -43,7 +43,7 @@ public class MlModelsImportingTest {
// Evaluator
FunctionEvaluator evaluator = xgboost.evaluatorOf();
assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
- assertEquals(-8.17695, evaluator.evaluate().sum().asDouble(), delta);
+ assertEquals(-4.37659, evaluator.evaluate().sum().asDouble(), delta);
}
{
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index 23f0fa7a571..95f9888024a 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -63,7 +63,7 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testXgBoostEvaluationWithoutBindings() {
String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval"; // only has a single function
- String expected = "{\"cells\":[{\"address\":{},\"value\":-8.17695}]}";
+ String expected = "{\"cells\":[{\"address\":{},\"value\":-4.376589999999999}]}";
assertResponse(url, 200, expected);
}