summaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-09-03 13:16:06 +0200
committerJon Bratseth <bratseth@oath.com>2018-09-03 13:16:06 +0200
commit03fb3fc851fe6f5bec4f4d86d7ff6ea5dcce5fd7 (patch)
tree094dd8d8c7c324ff6208f35d1bad201ba58d2782 /model-evaluation
parent5baad482446f664754aaa4ad422fa00a055470e6 (diff)
Test importing of ml models
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java1
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java34
-rw-r--r--model-evaluation/src/test/resources/config/models/rank-profiles.cfg14
3 files changed, 45 insertions, 4 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
index 60cf0d25ded..23928c5b7e7 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
@@ -46,6 +46,5 @@ public class ModelsEvaluatorTest {
// TODO: Test that binding nonexisting variable doesn't work
// TODO: Test that rebinding doesn't work
// TODO: Test with nested macros
- // TODO: Test TF/ONNX model
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java
index d45372fc7da..e0f5674e016 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java
@@ -21,14 +21,41 @@ import static org.junit.Assert.assertNotNull;
public class RankProfilesImporterTest {
@Test
- public void testImporting() {
+ public void testImportingModels() {
+ String configPath = "src/test/resources/config/models/rank-profiles.cfg";
+ RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig("");
+ Map<String, Model> models = new RankProfilesConfigImporter().importFrom(config);
+ assertEquals(4, models.size());
+
+ Model xgboost = models.get("xgboost_2_2");
+ assertFunction("xgboost_2_2",
+ "(optimized sum of condition trees of size 192 bytes)",
+ xgboost);
+
+ Model onnxMnistSoftmax = models.get("mnist_softmax");
+ assertFunction("default.add",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
+ onnxMnistSoftmax);
+
+ Model tfMnistSoftmax = models.get("mnist_softmax_saved");
+ assertFunction("serving_default.y",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
+ tfMnistSoftmax);
+
+ Model tfMnist = models.get("mnist_saved");
+ assertFunction("serving_default.y",
+ "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ }
+
+ @Test
+ public void testImportingRankExpressions() {
String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg";
RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig("");
Map<String, Model> models = new RankProfilesConfigImporter().importFrom(config);
assertEquals(18, models.size());
Model macros = models.get("macros");
- assertNotNull(macros);
assertEquals("macros", macros.name());
assertEquals(4, macros.functions().size());
assertFunction("fourtimessum", "4 * (var1 + var2)", macros);
@@ -44,8 +71,9 @@ public class RankProfilesImporterTest {
}
private void assertFunction(String name, String expression, Model model) {
+ assertNotNull("Model is present in config", model);
ExpressionFunction function = model.function(name);
- assertNotNull(function);
+ assertNotNull("Function '" + name + "' is in " + model, function);
assertEquals(name, function.getName());
assertEquals(expression, function.getBody().getRoot().toString());
}
diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
new file mode 100644
index 00000000000..1cc36f75158
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
@@ -0,0 +1,14 @@
+rankprofile[0].name "mnist_saved"
+rankprofile[0].fef.property[0].name "rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add).rankingScript"
+rankprofile[0].fef.property[0].value "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))"
+rankprofile[0].fef.property[1].name "rankingExpression(serving_default.y).rankingScript"
+rankprofile[0].fef.property[1].value "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))"
+rankprofile[1].name "xgboost_2_2"
+rankprofile[1].fef.property[0].name "rankingExpression(xgboost_2_2).rankingScript"
+rankprofile[1].fef.property[0].value "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)"
+rankprofile[2].name "mnist_softmax_saved"
+rankprofile[2].fef.property[0].name "rankingExpression(serving_default.y).rankingScript"
+rankprofile[2].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))"
+rankprofile[3].name "mnist_softmax"
+rankprofile[3].fef.property[0].name "rankingExpression(default.add).rankingScript"
+rankprofile[3].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))"