summaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/test
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-09-03 20:26:17 +0200
committerJon Bratseth <bratseth@oath.com>2018-09-03 20:26:17 +0200
commit489aae55545c19d409ab0f23ac17ec62287645d3 (patch)
tree63c9d1f54d5e257b8bd3428fec5dfa607b18848a /model-evaluation/src/test
parent03fb3fc851fe6f5bec4f4d86d7ff6ea5dcce5fd7 (diff)
Read and resolve constants in model evaluation
Diffstat (limited to 'model-evaluation/src/test')
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java21
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java23
-rw-r--r--model-evaluation/src/test/resources/config/models/ranking-constants.cfg30
-rw-r--r--model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg0
4 files changed, 60 insertions, 14 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
index 23928c5b7e7..d94e5b2af1b 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
@@ -3,8 +3,10 @@ package ai.vespa.models.evaluation;
import com.yahoo.config.subscription.ConfigGetter;
import com.yahoo.config.subscription.FileSource;
+import com.yahoo.path.Path;
import com.yahoo.tensor.Tensor;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import org.junit.Test;
import java.io.File;
@@ -18,15 +20,9 @@ public class ModelsEvaluatorTest {
private static final double delta = 0.00000000001;
- private ModelsEvaluator createModels() {
- String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg";
- RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig("");
- return new ModelsEvaluator(config);
- }
-
@Test
public void testTensorEvaluation() {
- ModelsEvaluator models = createModels();
+ ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/");
FunctionEvaluator function = models.evaluatorOf("macros", "fourtimessum");
function.bind("var1", Tensor.from("{{x:0}:3,{x:1}:5}"));
function.bind("var2", Tensor.from("{{x:0}:7,{x:1}:11}"));
@@ -35,7 +31,7 @@ public class ModelsEvaluatorTest {
@Test
public void testEvaluationDependingOnMacroTakingArguments() {
- ModelsEvaluator models = createModels();
+ ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/");
FunctionEvaluator function = models.evaluatorOf("macros", "secondphase");
function.bind("match", 3);
function.bind("rankBoost", 5);
@@ -47,4 +43,13 @@ public class ModelsEvaluatorTest {
// TODO: Test that rebinding doesn't work
// TODO: Test with nested macros
+ private ModelsEvaluator createModels(String path) {
+ Path configDir = Path.fromString(path);
+ RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()),
+ RankProfilesConfig.class).getConfig("");
+ RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
+ RankingConstantsConfig.class).getConfig("");
+ return new ModelsEvaluator(config, constantsConfig);
+ }
+
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java
index e0f5674e016..84e01e58280 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java
@@ -3,8 +3,10 @@ package ai.vespa.models.evaluation;
import com.yahoo.config.subscription.ConfigGetter;
import com.yahoo.config.subscription.FileSource;
+import com.yahoo.path.Path;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import org.junit.Test;
import java.io.File;
@@ -22,9 +24,8 @@ public class RankProfilesImporterTest {
@Test
public void testImportingModels() {
- String configPath = "src/test/resources/config/models/rank-profiles.cfg";
- RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig("");
- Map<String, Model> models = new RankProfilesConfigImporter().importFrom(config);
+ Map<String, Model> models = createModels("src/test/resources/config/models/");
+
assertEquals(4, models.size());
Model xgboost = models.get("xgboost_2_2");
@@ -36,6 +37,8 @@ public class RankProfilesImporterTest {
assertFunction("default.add",
"join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
onnxMnistSoftmax);
+ assertEquals("tensor(d1[10],d2[784])",
+ onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString());
Model tfMnistSoftmax = models.get("mnist_softmax_saved");
assertFunction("serving_default.y",
@@ -50,9 +53,8 @@ public class RankProfilesImporterTest {
@Test
public void testImportingRankExpressions() {
- String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg";
- RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig("");
- Map<String, Model> models = new RankProfilesConfigImporter().importFrom(config);
+ Map<String, Model> models = createModels("src/test/resources/config/rankexpression/");
+
assertEquals(18, models.size());
Model macros = models.get("macros");
@@ -85,4 +87,13 @@ public class RankProfilesImporterTest {
assertEquals(expression, function.getBody().getRoot().toString());
}
+ private Map<String, Model> createModels(String path) {
+ Path configDir = Path.fromString(path);
+ RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()),
+ RankProfilesConfig.class).getConfig("");
+ RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
+ RankingConstantsConfig.class).getConfig("");
+ return new RankProfilesConfigImporter().importFrom(config, constantsConfig);
+ }
+
}
diff --git a/model-evaluation/src/test/resources/config/models/ranking-constants.cfg b/model-evaluation/src/test/resources/config/models/ranking-constants.cfg
new file mode 100644
index 00000000000..2b7495ace5e
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/models/ranking-constants.cfg
@@ -0,0 +1,30 @@
+constant[0].name "mnist_saved_dnn_hidden1_weights_read"
+constant[0].fileref ""
+constant[0].type "tensor(d3[300],d4[784])"
+constant[1].name "mnist_saved_dnn_hidden2_weights_read"
+constant[1].fileref ""
+constant[1].type "tensor(d2[100],d3[300])"
+constant[2].name "mnist_softmax_saved_layer_Variable_1_read"
+constant[2].fileref ""
+constant[2].type "tensor(d1[10])"
+constant[3].name "mnist_saved_dnn_hidden1_bias_read"
+constant[3].fileref ""
+constant[3].type "tensor(d3[300])"
+constant[4].name "mnist_saved_dnn_hidden2_bias_read"
+constant[4].fileref ""
+constant[4].type "tensor(d2[100])"
+constant[5].name "mnist_softmax_Variable"
+constant[5].fileref ""
+constant[5].type "tensor(d1[10],d2[784])"
+constant[6].name "mnist_saved_dnn_outputs_weights_read"
+constant[6].fileref ""
+constant[6].type "tensor(d1[10],d2[100])"
+constant[7].name "mnist_softmax_saved_layer_Variable_read"
+constant[7].fileref ""
+constant[7].type "tensor(d1[10],d2[784])"
+constant[8].name "mnist_softmax_Variable_1"
+constant[8].fileref ""
+constant[8].type "tensor(d1[10])"
+constant[9].name "mnist_saved_dnn_outputs_bias_read"
+constant[9].fileref ""
+constant[9].type "tensor(d1[10])" \ No newline at end of file
diff --git a/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg b/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg