diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-06-03 09:51:40 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-06-03 09:51:40 +0200 |
commit | 2bcae10ce88746923d34e1db46e5e4b3ce835bd2 (patch) | |
tree | e75250b5620f7cf151eeb6ba3bcdd28e12a35bc1 /model-integration | |
parent | 716c293fe3e766c6af8e3e62d65ea40f0bf8369d (diff) |
Test evaluation
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java | 26 | ||||
-rw-r--r-- | model-integration/src/test/models/vespa/example.model | 4 |
2 files changed, 26 insertions, 4 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java index 4c8890f6476..767af147ad7 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java @@ -3,6 +3,12 @@ package ai.vespa.rankingexpression.importer.vespa; import ai.vespa.rankingexpression.importer.ImportedModel; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.Tensor; import org.junit.Test; import java.util.List; @@ -32,14 +38,17 @@ public class VespaImportTestCase { assertEquals("tensor(x[3]):{{x:0}:0.5,{x:1}:1.5,{x:2}:2.5}", model.largeConstants().get("constant1asLarge")); assertEquals(2, model.expressions().size()); - assertEquals("max(reduce(input1 * input2, sum, name) * constant1,x) * constant2", + assertEquals("reduce(reduce(input1 * input2, sum, name) * constant1, max, x) * constant2", model.expressions().get("foo1").getRoot().toString()); - assertEquals("max(reduce(input1 * input2, sum, name) * constant1asLarge,x) * constant2", + assertEquals("reduce(reduce(input1 * input2, sum, name) * constant1asLarge, max, x) * constant2", model.expressions().get("foo2").getRoot().toString()); List<ImportedMlFunction> functions = model.outputExpressions(); assertEquals(2, functions.size()); ImportedMlFunction foo1Function = functions.get(0); + assertEquals("foo1", foo1Function.name()); + assertEquals("reduce(reduce(input1 * input2, sum, name) * constant1, max, x) * constant2", foo1Function.expression()); + assertEquals("tensor():{202.5}", evaluate(foo1Function, "{{name:a, x:0}: 1, {name:a, x:1}: 2, {name:a, x:2}: 3}").toString()); assertEquals(2, foo1Function.arguments().size()); assertTrue(foo1Function.arguments().contains("input1")); assertTrue(foo1Function.arguments().contains("input2")); @@ -80,4 +89,17 @@ public class VespaImportTestCase { return model; } + private Tensor evaluate(ImportedMlFunction function, String input1Argument) { + try { + MapContext context = new MapContext(); + context.put("input1", new TensorValue(Tensor.from(function.argumentTypes().get("input1"), input1Argument))); + context.put("input2", new TensorValue(Tensor.from(function.argumentTypes().get("input2"), "{{x:0}:3, {x:1}:6, {x:2}:9}"))); + context.put("constant1", new TensorValue(Tensor.from("tensor(x[3]):{{x:0}:0.5, {x:1}:1.5, {x:2}:2.5}"))); + context.put("constant2", new TensorValue(Tensor.from("tensor():{{}:3}"))); + return new RankingExpression(function.expression()).evaluate(context).asTensor(); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } } diff --git a/model-integration/src/test/models/vespa/example.model b/model-integration/src/test/models/vespa/example.model index 9579be4e44c..66d21cfc53f 100644 --- a/model-integration/src/test/models/vespa/example.model +++ b/model-integration/src/test/models/vespa/example.model @@ -15,11 +15,11 @@ model example { } function foo1() { - expression: max(sum(input1 * input2, name) * constant1, x) * constant2 + expression: reduce(sum(input1 * input2, name) * constant1, max, x) * constant2 } function foo2() { - expression: max(sum(input1 * input2, name) * constant1asLarge, x) * constant2 + expression: reduce(sum(input1 * input2, name) * constant1asLarge, max, x) * constant2 } }
\ No newline at end of file |