summaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-10-01 12:52:30 +0200
committerJon Bratseth <bratseth@oath.com>2018-10-01 12:52:30 +0200
commitd410a107a048070d07afc199a12dff85bdea139e (patch)
treee26f33065ebfb612a8c83d840de5661523590d1d /model-evaluation/src/main
parent50bc3b3c198d29374448cc3eac73fbb26e42cab0 (diff)
Validate all bindings
Diffstat (limited to 'model-evaluation/src/main')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java21
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java4
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java12
3 files changed, 33 insertions, 4 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index 1412936d4a0..8ce44ef5ed2 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -3,10 +3,14 @@ package ai.vespa.models.evaluation;
import com.google.common.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.util.Map;
+import java.util.stream.Collectors;
+
/**
* An evaluator which can be used to evaluate a single function once.
*
@@ -35,6 +39,14 @@ public class FunctionEvaluator {
public FunctionEvaluator bind(String name, Tensor value) {
if (evaluated)
throw new IllegalStateException("You cannot bind a value in a used evaluator");
+ TensorType requiredType = function.argumentTypes().get(name);
+ if (requiredType == null)
+ throw new IllegalArgumentException("'" + name + "' is not a valid argument in " + function +
+ ". Expected arguments: " + function.argumentTypes().entrySet().stream()
+ .map(e -> e.getKey() + ": " + e.getValue())
+ .collect(Collectors.joining(", ")));
+ if ( ! value.type().isAssignableTo(requiredType))
+ throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type());
context.put(name, new TensorValue(value));
return this;
}
@@ -52,10 +64,19 @@ public class FunctionEvaluator {
}
public Tensor evaluate() {
+ for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
+ if (argument.getValue().rank() == 0) continue; // Scalar argumentds can be skipped (defaults to 0)
+ if (context.get(argument.getKey()) == LazyArrayContext.defaultContextValue)
+ throw new IllegalStateException("Missing argument '" + argument.getKey() +
+ "': Must be bound to a value of type " + argument.getValue());
+ }
evaluated = true;
return function.getBody().evaluate(context).asTensor();
}
+ /** Returns the function evaluated by this */
+ public ExpressionFunction function() { return function; }
+
public LazyArrayContext context() { return context; }
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index d144411127e..093d487c31f 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -32,6 +32,8 @@ import java.util.Set;
*/
public final class LazyArrayContext extends Context implements ContextIndex {
+ public final static Value defaultContextValue = DoubleValue.zero;
+
private final IndexedBindings indexedBindings;
private LazyArrayContext(IndexedBindings indexedBindings) {
@@ -167,7 +169,7 @@ public final class LazyArrayContext extends Context implements ContextIndex {
this.arguments = ImmutableSet.copyOf(arguments);
values = new Value[bindTargets.size()];
- Arrays.fill(values, DoubleValue.zero);
+ Arrays.fill(values, defaultContextValue);
int i = 0;
ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>();
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
index 683a1f345d8..6edcd84272e 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
@@ -10,13 +10,16 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.slime.Cursor;
import com.yahoo.slime.Slime;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.JsonFormat;
+import com.yahoo.yolean.Exceptions;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.nio.charset.Charset;
import java.util.Arrays;
+import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Executor;
@@ -60,14 +63,17 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
return listModelInformation(request, model, function);
} catch (IllegalArgumentException e) {
- return new ErrorResponse(404, e.getMessage());
+ return new ErrorResponse(404, Exceptions.toMessageString(e));
+ } catch (IllegalStateException e) { // On missing bindings
+ return new ErrorResponse(400, Exceptions.toMessageString(e));
}
}
private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) {
FunctionEvaluator evaluator = model.evaluatorOf(function);
- for (String bindingName : evaluator.context().names()) {
- property(request, bindingName).ifPresent(s -> evaluator.bind(bindingName, Tensor.from(s)));
+ for (Map.Entry<String, TensorType> argument : evaluator.function().argumentTypes().entrySet()) {
+ property(request, argument.getKey()).ifPresent(value -> evaluator.bind(argument.getKey(),
+ Tensor.from(argument.getValue(), value)));
}
Tensor result = evaluator.evaluate();
return new Response(200, JsonFormat.encode(result));