diff options
author | Lester Solbakken <lesters@oath.com> | 2019-10-11 09:42:05 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-10-11 09:42:05 +0200 |
commit | 3acec4a95bc2f75f8384bde14d35f3a5c073460b (patch) | |
tree | 14d3874a55ebf1493842ba8a8a8f029ba6f1530b /model-evaluation/src/main | |
parent | d78ad93a081552d5f671e266a15c0de770305c92 (diff) |
Set default missing value to NaN for model evaluation
Diffstat (limited to 'model-evaluation/src/main')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java | 4 | ||||
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java | 9 |
2 files changed, 9 insertions, 4 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 8c728867f45..9db26d7ecd8 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -65,8 +65,8 @@ public class FunctionEvaluator { public Tensor evaluate() { for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { - if (argument.getValue().rank() == 0) continue; // Scalar argumentds can be skipped (defaults to 0) - if (context.get(argument.getKey()) == LazyArrayContext.defaultContextValue) + if (argument.getValue().rank() == 0) continue; // Scalar arguments can be skipped (defaults to 0) + if (context.isMissing(argument.getKey())) throw new IllegalStateException("Missing argument '" + argument.getKey() + "': Must be bound to a value of type " + argument.getValue()); } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 55da2e78894..bc80989f030 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -6,7 +6,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; import java.util.Arrays; @@ -27,6 +29,9 @@ public class Model { /** The prefix generated by mode-integration/../IntermediateOperation */ private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_"; + /** Default value to return if value is not supplied */ + private final static Value missingValue = DoubleValue.frozen(Double.NaN); + private final String name; /** Free functions */ @@ -61,7 +66,7 @@ public class Model { ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>(); for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { try { - LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this); + LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this, missingValue); contextBuilder.put(function.getValue().getName(), context); if ( ! function.getValue().returnType().isPresent()) { functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty)); @@ -135,7 +140,7 @@ public class Model { return context; } - /** Returns the function withe the given name, or null if none */ // TODO: Parameter overloading? + /** Returns the function with the given name, or null if none */ // TODO: Parameter overloading? ExpressionFunction function(String name) { for (ExpressionFunction function : functions) if (function.getName().equals(name)) |