diff options
Diffstat (limited to 'model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java')
-rw-r--r-- | model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java | 51 |
1 files changed, 45 insertions, 6 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java index d45372fc7da..84e01e58280 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java @@ -3,8 +3,10 @@ package ai.vespa.models.evaluation; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.config.subscription.FileSource; +import com.yahoo.path.Path; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import org.junit.Test; import java.io.File; @@ -21,14 +23,41 @@ import static org.junit.Assert.assertNotNull; public class RankProfilesImporterTest { @Test - public void testImporting() { - String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg"; - RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); - Map<String, Model> models = new RankProfilesConfigImporter().importFrom(config); + public void testImportingModels() { + Map<String, Model> models = createModels("src/test/resources/config/models/"); + + assertEquals(4, models.size()); + + Model xgboost = models.get("xgboost_2_2"); + assertFunction("xgboost_2_2", + "(optimized sum of condition trees of size 192 bytes)", + xgboost); + + Model onnxMnistSoftmax = models.get("mnist_softmax"); + assertFunction("default.add", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", + onnxMnistSoftmax); + assertEquals("tensor(d1[10],d2[784])", + onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); + + Model tfMnistSoftmax = models.get("mnist_softmax_saved"); + assertFunction("serving_default.y", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", + tfMnistSoftmax); + + Model tfMnist = models.get("mnist_saved"); + assertFunction("serving_default.y", + "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", + tfMnist); + } + + @Test + public void testImportingRankExpressions() { + Map<String, Model> models = createModels("src/test/resources/config/rankexpression/"); + assertEquals(18, models.size()); Model macros = models.get("macros"); - assertNotNull(macros); assertEquals("macros", macros.name()); assertEquals(4, macros.functions().size()); assertFunction("fourtimessum", "4 * (var1 + var2)", macros); @@ -44,8 +73,9 @@ public class RankProfilesImporterTest { } private void assertFunction(String name, String expression, Model model) { + assertNotNull("Model is present in config", model); ExpressionFunction function = model.function(name); - assertNotNull(function); + assertNotNull("Function '" + name + "' is in " + model, function); assertEquals(name, function.getName()); assertEquals(expression, function.getBody().getRoot().toString()); } @@ -57,4 +87,13 @@ public class RankProfilesImporterTest { assertEquals(expression, function.getBody().getRoot().toString()); } + private Map<String, Model> createModels(String path) { + Path configDir = Path.fromString(path); + RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), + RankProfilesConfig.class).getConfig(""); + RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), + RankingConstantsConfig.class).getConfig(""); + return new RankProfilesConfigImporter().importFrom(config, constantsConfig); + } + } |