diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-09-03 20:26:17 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-09-03 20:26:17 +0200 |
commit | 489aae55545c19d409ab0f23ac17ec62287645d3 (patch) | |
tree | 63c9d1f54d5e257b8bd3428fec5dfa607b18848a /model-evaluation/src/test | |
parent | 03fb3fc851fe6f5bec4f4d86d7ff6ea5dcce5fd7 (diff) |
Read and resolve constants in model evaluation
Diffstat (limited to 'model-evaluation/src/test')
4 files changed, 60 insertions, 14 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index 23928c5b7e7..d94e5b2af1b 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -3,8 +3,10 @@ package ai.vespa.models.evaluation; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.config.subscription.FileSource; +import com.yahoo.path.Path; import com.yahoo.tensor.Tensor; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import org.junit.Test; import java.io.File; @@ -18,15 +20,9 @@ public class ModelsEvaluatorTest { private static final double delta = 0.00000000001; - private ModelsEvaluator createModels() { - String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg"; - RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); - return new ModelsEvaluator(config); - } - @Test public void testTensorEvaluation() { - ModelsEvaluator models = createModels(); + ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); FunctionEvaluator function = models.evaluatorOf("macros", "fourtimessum"); function.bind("var1", Tensor.from("{{x:0}:3,{x:1}:5}")); function.bind("var2", Tensor.from("{{x:0}:7,{x:1}:11}")); @@ -35,7 +31,7 @@ public class ModelsEvaluatorTest { @Test public void testEvaluationDependingOnMacroTakingArguments() { - ModelsEvaluator models = createModels(); + ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); FunctionEvaluator function = models.evaluatorOf("macros", "secondphase"); function.bind("match", 3); function.bind("rankBoost", 5); @@ -47,4 +43,13 @@ public class ModelsEvaluatorTest { // TODO: Test that rebinding doesn't work // TODO: Test with nested macros + private ModelsEvaluator createModels(String path) { + Path configDir = Path.fromString(path); + RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), + RankProfilesConfig.class).getConfig(""); + RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), + RankingConstantsConfig.class).getConfig(""); + return new ModelsEvaluator(config, constantsConfig); + } + } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java index e0f5674e016..84e01e58280 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java @@ -3,8 +3,10 @@ package ai.vespa.models.evaluation; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.config.subscription.FileSource; +import com.yahoo.path.Path; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import org.junit.Test; import java.io.File; @@ -22,9 +24,8 @@ public class RankProfilesImporterTest { @Test public void testImportingModels() { - String configPath = "src/test/resources/config/models/rank-profiles.cfg"; - RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); - Map<String, Model> models = new RankProfilesConfigImporter().importFrom(config); + Map<String, Model> models = createModels("src/test/resources/config/models/"); + assertEquals(4, models.size()); Model xgboost = models.get("xgboost_2_2"); @@ -36,6 +37,8 @@ public class RankProfilesImporterTest { assertFunction("default.add", "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", onnxMnistSoftmax); + assertEquals("tensor(d1[10],d2[784])", + onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); Model tfMnistSoftmax = models.get("mnist_softmax_saved"); assertFunction("serving_default.y", @@ -50,9 +53,8 @@ public class RankProfilesImporterTest { @Test public void testImportingRankExpressions() { - String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg"; - RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); - Map<String, Model> models = new RankProfilesConfigImporter().importFrom(config); + Map<String, Model> models = createModels("src/test/resources/config/rankexpression/"); + assertEquals(18, models.size()); Model macros = models.get("macros"); @@ -85,4 +87,13 @@ public class RankProfilesImporterTest { assertEquals(expression, function.getBody().getRoot().toString()); } + private Map<String, Model> createModels(String path) { + Path configDir = Path.fromString(path); + RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), + RankProfilesConfig.class).getConfig(""); + RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), + RankingConstantsConfig.class).getConfig(""); + return new RankProfilesConfigImporter().importFrom(config, constantsConfig); + } + } diff --git a/model-evaluation/src/test/resources/config/models/ranking-constants.cfg b/model-evaluation/src/test/resources/config/models/ranking-constants.cfg new file mode 100644 index 00000000000..2b7495ace5e --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/ranking-constants.cfg @@ -0,0 +1,30 @@ +constant[0].name "mnist_saved_dnn_hidden1_weights_read" +constant[0].fileref "" +constant[0].type "tensor(d3[300],d4[784])" +constant[1].name "mnist_saved_dnn_hidden2_weights_read" +constant[1].fileref "" +constant[1].type "tensor(d2[100],d3[300])" +constant[2].name "mnist_softmax_saved_layer_Variable_1_read" +constant[2].fileref "" +constant[2].type "tensor(d1[10])" +constant[3].name "mnist_saved_dnn_hidden1_bias_read" +constant[3].fileref "" +constant[3].type "tensor(d3[300])" +constant[4].name "mnist_saved_dnn_hidden2_bias_read" +constant[4].fileref "" +constant[4].type "tensor(d2[100])" +constant[5].name "mnist_softmax_Variable" +constant[5].fileref "" +constant[5].type "tensor(d1[10],d2[784])" +constant[6].name "mnist_saved_dnn_outputs_weights_read" +constant[6].fileref "" +constant[6].type "tensor(d1[10],d2[100])" +constant[7].name "mnist_softmax_saved_layer_Variable_read" +constant[7].fileref "" +constant[7].type "tensor(d1[10],d2[784])" +constant[8].name "mnist_softmax_Variable_1" +constant[8].fileref "" +constant[8].type "tensor(d1[10])" +constant[9].name "mnist_saved_dnn_outputs_bias_read" +constant[9].fileref "" +constant[9].type "tensor(d1[10])"
\ No newline at end of file diff --git a/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg b/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/model-evaluation/src/test/resources/config/rankexpression/ranking-constants.cfg |