diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-10-01 12:52:30 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-10-01 12:52:30 +0200 |
commit | d410a107a048070d07afc199a12dff85bdea139e (patch) | |
tree | e26f33065ebfb612a8c83d840de5661523590d1d /model-evaluation/src/main | |
parent | 50bc3b3c198d29374448cc3eac73fbb26e42cab0 (diff) |
Validate all bindings
Diffstat (limited to 'model-evaluation/src/main')
3 files changed, 33 insertions, 4 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 1412936d4a0..8ce44ef5ed2 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -3,10 +3,14 @@ package ai.vespa.models.evaluation; import com.google.common.annotations.Beta; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.Map; +import java.util.stream.Collectors; + /** * An evaluator which can be used to evaluate a single function once. * @@ -35,6 +39,14 @@ public class FunctionEvaluator { public FunctionEvaluator bind(String name, Tensor value) { if (evaluated) throw new IllegalStateException("You cannot bind a value in a used evaluator"); + TensorType requiredType = function.argumentTypes().get(name); + if (requiredType == null) + throw new IllegalArgumentException("'" + name + "' is not a valid argument in " + function + + ". Expected arguments: " + function.argumentTypes().entrySet().stream() + .map(e -> e.getKey() + ": " + e.getValue()) + .collect(Collectors.joining(", "))); + if ( ! value.type().isAssignableTo(requiredType)) + throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type()); context.put(name, new TensorValue(value)); return this; } @@ -52,10 +64,19 @@ public class FunctionEvaluator { } public Tensor evaluate() { + for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { + if (argument.getValue().rank() == 0) continue; // Scalar argumentds can be skipped (defaults to 0) + if (context.get(argument.getKey()) == LazyArrayContext.defaultContextValue) + throw new IllegalStateException("Missing argument '" + argument.getKey() + + "': Must be bound to a value of type " + argument.getValue()); + } evaluated = true; return function.getBody().evaluate(context).asTensor(); } + /** Returns the function evaluated by this */ + public ExpressionFunction function() { return function; } + public LazyArrayContext context() { return context; } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index d144411127e..093d487c31f 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -32,6 +32,8 @@ import java.util.Set; */ public final class LazyArrayContext extends Context implements ContextIndex { + public final static Value defaultContextValue = DoubleValue.zero; + private final IndexedBindings indexedBindings; private LazyArrayContext(IndexedBindings indexedBindings) { @@ -167,7 +169,7 @@ public final class LazyArrayContext extends Context implements ContextIndex { this.arguments = ImmutableSet.copyOf(arguments); values = new Value[bindTargets.size()]; - Arrays.fill(values, DoubleValue.zero); + Arrays.fill(values, defaultContextValue); int i = 0; ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>(); diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index 683a1f345d8..6edcd84272e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -10,13 +10,16 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.slime.Cursor; import com.yahoo.slime.Slime; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.JsonFormat; +import com.yahoo.yolean.Exceptions; import java.io.IOException; import java.io.OutputStream; import java.net.URI; import java.nio.charset.Charset; import java.util.Arrays; +import java.util.Map; import java.util.Optional; import java.util.concurrent.Executor; @@ -60,14 +63,17 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { return listModelInformation(request, model, function); } catch (IllegalArgumentException e) { - return new ErrorResponse(404, e.getMessage()); + return new ErrorResponse(404, Exceptions.toMessageString(e)); + } catch (IllegalStateException e) { // On missing bindings + return new ErrorResponse(400, Exceptions.toMessageString(e)); } } private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) { FunctionEvaluator evaluator = model.evaluatorOf(function); - for (String bindingName : evaluator.context().names()) { - property(request, bindingName).ifPresent(s -> evaluator.bind(bindingName, Tensor.from(s))); + for (Map.Entry<String, TensorType> argument : evaluator.function().argumentTypes().entrySet()) { + property(request, argument.getKey()).ifPresent(value -> evaluator.bind(argument.getKey(), + Tensor.from(argument.getValue(), value))); } Tensor result = evaluator.evaluate(); return new Response(200, JsonFormat.encode(result)); |