summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-06-03 09:51:40 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-06-03 09:51:40 +0200
commit2bcae10ce88746923d34e1db46e5e4b3ce835bd2 (patch)
treee75250b5620f7cf151eeb6ba3bcdd28e12a35bc1 /model-integration
parent716c293fe3e766c6af8e3e62d65ea40f0bf8369d (diff)
Test evaluation
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java26
-rw-r--r--model-integration/src/test/models/vespa/example.model4
2 files changed, 26 insertions, 4 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
index 4c8890f6476..767af147ad7 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
@@ -3,6 +3,12 @@ package ai.vespa.rankingexpression.importer.vespa;
import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.tensor.Tensor;
import org.junit.Test;
import java.util.List;
@@ -32,14 +38,17 @@ public class VespaImportTestCase {
assertEquals("tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}", model.largeConstants().get("constant1asLarge"));
assertEquals(2, model.expressions().size());
- assertEquals("max(reduce(input1 * input2, sum, name) * constant1,x) * constant2",
+ assertEquals("reduce(reduce(input1 * input2, sum, name) * constant1, max, x) * constant2",
model.expressions().get("foo1").getRoot().toString());
- assertEquals("max(reduce(input1 * input2, sum, name) * constant1asLarge,x) * constant2",
+ assertEquals("reduce(reduce(input1 * input2, sum, name) * constant1asLarge, max, x) * constant2",
model.expressions().get("foo2").getRoot().toString());
List<ImportedMlFunction> functions = model.outputExpressions();
assertEquals(2, functions.size());
ImportedMlFunction foo1Function = functions.get(0);
+ assertEquals("foo1", foo1Function.name());
+ assertEquals("reduce(reduce(input1 * input2, sum, name) * constant1, max, x) * constant2", foo1Function.expression());
+ assertEquals("tensor():{202.5}", evaluate(foo1Function, "{{name:a, x:0}: 1, {name:a, x:1}: 2, {name:a, x:2}: 3}").toString());
assertEquals(2, foo1Function.arguments().size());
assertTrue(foo1Function.arguments().contains("input1"));
assertTrue(foo1Function.arguments().contains("input2"));
@@ -80,4 +89,17 @@ public class VespaImportTestCase {
return model;
}
+ private Tensor evaluate(ImportedMlFunction function, String input1Argument) {
+ try {
+ MapContext context = new MapContext();
+ context.put("input1", new TensorValue(Tensor.from(function.argumentTypes().get("input1"), input1Argument)));
+ context.put("input2", new TensorValue(Tensor.from(function.argumentTypes().get("input2"), "{{x:0}:3, {x:1}:6, {x:2}:9}")));
+ context.put("constant1", new TensorValue(Tensor.from("tensor(x[3]):{{x:0}:0.5, {x:1}:1.5, {x:2}:2.5}")));
+ context.put("constant2", new TensorValue(Tensor.from("tensor():{{}:3}")));
+ return new RankingExpression(function.expression()).evaluate(context).asTensor();
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
}
diff --git a/model-integration/src/test/models/vespa/example.model b/model-integration/src/test/models/vespa/example.model
index 9579be4e44c..66d21cfc53f 100644
--- a/model-integration/src/test/models/vespa/example.model
+++ b/model-integration/src/test/models/vespa/example.model
@@ -15,11 +15,11 @@ model example {
}
function foo1() {
- expression: max(sum(input1 * input2, name) * constant1, x) * constant2
+ expression: reduce(sum(input1 * input2, name) * constant1, max, x) * constant2
}
function foo2() {
- expression: max(sum(input1 * input2, name) * constant1asLarge, x) * constant2
+ expression: reduce(sum(input1 * input2, name) * constant1asLarge, max, x) * constant2
}
} \ No newline at end of file