diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-09-03 13:16:06 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-09-03 13:16:06 +0200 |
commit | 03fb3fc851fe6f5bec4f4d86d7ff6ea5dcce5fd7 (patch) | |
tree | 094dd8d8c7c324ff6208f35d1bad201ba58d2782 /model-evaluation | |
parent | 5baad482446f664754aaa4ad422fa00a055470e6 (diff) |
Test importing of ml models
Diffstat (limited to 'model-evaluation')
3 files changed, 45 insertions, 4 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index 60cf0d25ded..23928c5b7e7 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -46,6 +46,5 @@ public class ModelsEvaluatorTest { // TODO: Test that binding nonexisting variable doesn't work // TODO: Test that rebinding doesn't work // TODO: Test with nested macros - // TODO: Test TF/ONNX model } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java index d45372fc7da..e0f5674e016 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java @@ -21,14 +21,41 @@ import static org.junit.Assert.assertNotNull; public class RankProfilesImporterTest { @Test - public void testImporting() { + public void testImportingModels() { + String configPath = "src/test/resources/config/models/rank-profiles.cfg"; + RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); + Map<String, Model> models = new RankProfilesConfigImporter().importFrom(config); + assertEquals(4, models.size()); + + Model xgboost = models.get("xgboost_2_2"); + assertFunction("xgboost_2_2", + "(optimized sum of condition trees of size 192 bytes)", + xgboost); + + Model onnxMnistSoftmax = models.get("mnist_softmax"); + assertFunction("default.add", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", + onnxMnistSoftmax); + + Model tfMnistSoftmax = models.get("mnist_softmax_saved"); + assertFunction("serving_default.y", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", + tfMnistSoftmax); + + Model tfMnist = models.get("mnist_saved"); + assertFunction("serving_default.y", + "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", + tfMnist); + } + + @Test + public void testImportingRankExpressions() { String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg"; RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); Map<String, Model> models = new RankProfilesConfigImporter().importFrom(config); assertEquals(18, models.size()); Model macros = models.get("macros"); - assertNotNull(macros); assertEquals("macros", macros.name()); assertEquals(4, macros.functions().size()); assertFunction("fourtimessum", "4 * (var1 + var2)", macros); @@ -44,8 +71,9 @@ public class RankProfilesImporterTest { } private void assertFunction(String name, String expression, Model model) { + assertNotNull("Model is present in config", model); ExpressionFunction function = model.function(name); - assertNotNull(function); + assertNotNull("Function '" + name + "' is in " + model, function); assertEquals(name, function.getName()); assertEquals(expression, function.getBody().getRoot().toString()); } diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg new file mode 100644 index 00000000000..1cc36f75158 --- /dev/null +++ b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg @@ -0,0 +1,14 @@ +rankprofile[0].name "mnist_saved" +rankprofile[0].fef.property[0].name "rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add).rankingScript" +rankprofile[0].fef.property[0].value "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))" +rankprofile[0].fef.property[1].name "rankingExpression(serving_default.y).rankingScript" +rankprofile[0].fef.property[1].value "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))" +rankprofile[1].name "xgboost_2_2" +rankprofile[1].fef.property[0].name "rankingExpression(xgboost_2_2).rankingScript" +rankprofile[1].fef.property[0].value "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)" +rankprofile[2].name "mnist_softmax_saved" +rankprofile[2].fef.property[0].name "rankingExpression(serving_default.y).rankingScript" +rankprofile[2].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))" +rankprofile[3].name "mnist_softmax" +rankprofile[3].fef.property[0].name "rankingExpression(default.add).rankingScript" +rankprofile[3].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))" |