summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-10-01 13:19:39 +0200
committerJon Bratseth <bratseth@oath.com>2018-10-01 13:19:39 +0200
commitde09ec45a3e4022ea36a142ce80dc880a975c177 (patch)
tree7def63569c614cf21a5d935c09799fe61e67591d
parentd410a107a048070d07afc199a12dff85bdea139e (diff)
Test validation explicitly
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java2
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java71
2 files changed, 71 insertions, 2 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index 8ce44ef5ed2..8c728867f45 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -38,7 +38,7 @@ public class FunctionEvaluator {
*/
public FunctionEvaluator bind(String name, Tensor value) {
if (evaluated)
- throw new IllegalStateException("You cannot bind a value in a used evaluator");
+ throw new IllegalStateException("Cannot bind a new value in a used evaluator");
TensorType requiredType = function.argumentTypes().get(name);
if (requiredType == null)
throw new IllegalArgumentException("'" + name + "' is not a valid argument in " + function +
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
index 57a415e8894..500b942016c 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
@@ -5,11 +5,19 @@ import com.yahoo.config.subscription.ConfigGetter;
import com.yahoo.config.subscription.FileSource;
import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
import com.yahoo.path.Path;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.yolean.Exceptions;
import org.junit.Test;
+import java.util.ArrayList;
+import java.util.List;
+
import static org.junit.Assert.assertEquals;
/**
@@ -28,8 +36,69 @@ public class ModelsEvaluatorTest {
assertEquals(32.0, function.evaluate().asDouble(), delta);
}
+ @Test
+ public void testBindingValidation() {
+ List<ExpressionFunction> functions = new ArrayList<>();
+ ExpressionFunction function = new ExpressionFunction("test", RankingExpression.from("sum(arg1 * arg2)"));
+ function = function.withArgument("arg1", TensorType.fromSpec("tensor(d0[1])"));
+ function = function.withArgument("arg2", TensorType.fromSpec("tensor(d1{})"));
+ functions.add(function);
+ Model model = new Model("test-model", functions);
+
+ try { // No bindings
+ FunctionEvaluator evaluator = model.evaluatorOf("test");
+ evaluator.evaluate();
+ }
+ catch (IllegalStateException e) {
+ assertEquals("Missing argument 'arg2': Must be bound to a value of type tensor(d1{})",
+ Exceptions.toMessageString(e));
+ }
+
+ try { // Just one binding
+ FunctionEvaluator evaluator = model.evaluatorOf("test");
+ evaluator.bind("arg2", Tensor.from(TensorType.fromSpec("tensor(d1{})"), "{{d1:foo}:0.1}"));
+ evaluator.evaluate();
+ }
+ catch (IllegalStateException e) {
+ assertEquals("Missing argument 'arg1': Must be bound to a value of type tensor(d0[1])",
+ Exceptions.toMessageString(e));
+ }
+
+ try { // Wrong binding argument
+ FunctionEvaluator evaluator = model.evaluatorOf("test");
+ evaluator.bind("argNone", Tensor.from(TensorType.fromSpec("tensor(d1{})"), "{{d1:foo}:0.1}"));
+ evaluator.evaluate();
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("'argNone' is not a valid argument in function 'test'. Expected arguments: arg2: tensor(d1{}), arg1: tensor(d0[1])",
+ Exceptions.toMessageString(e));
+ }
+
+ try { // Wrong binding type
+ FunctionEvaluator evaluator = model.evaluatorOf("test");
+ evaluator.bind("arg1", Tensor.from(TensorType.fromSpec("tensor(d3{})"), "{{d3:foo}:0.1}"));
+ evaluator.evaluate();
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("'arg1' must be of type tensor(d0[1]), not tensor(d3{})",
+ Exceptions.toMessageString(e));
+ }
+
+ try { // Attempt to reuse evaluator
+ FunctionEvaluator evaluator = model.evaluatorOf("test");
+ evaluator.bind("arg1", Tensor.from(TensorType.fromSpec("tensor(d0[1])"), "{{d0:0}:0.1}"));
+ evaluator.bind("arg2", Tensor.from(TensorType.fromSpec("tensor(d1{})"), "{{d1:foo}:0.1}"));
+ evaluator.evaluate();
+ evaluator.bind("arg1", Tensor.from(TensorType.fromSpec("tensor(d0[1])"), "{{d0:0}:0.1}"));
+ }
+ catch (IllegalStateException e) {
+ assertEquals("Cannot bind a new value in a used evaluator",
+ Exceptions.toMessageString(e));
+ }
+
+ }
+
// TODO: Test argument-less function
- // TODO: Test that binding nonexisting variable doesn't work
// TODO: Test that rebinding doesn't work
// TODO: Test with nested functions