summaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-09-07 13:54:51 +0200
committerJon Bratseth <bratseth@oath.com>2018-09-07 13:54:51 +0200
commit55b5c8c7c4b14303dfffb9ad017fd6bcea40e9b9 (patch)
tree18a8b20fb58da199f7ba832b351a3c27a3f309ce /model-evaluation
parentad6540e0e71b8db2d236f029266cfacc0a0f11a8 (diff)
Refactor test and prepare for injecting constants
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java3
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java82
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java77
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java36
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java123
5 files changed, 197 insertions, 124 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 9af4022b170..cd21a0a6813 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -120,7 +120,7 @@ class RankProfilesConfigImporter {
return constants;
}
- private Tensor readTensorFromFile(String name, TensorType type, String fileReference) {
+ Tensor readTensorFromFile(String name, TensorType type, String fileReference) {
try {
// TODO: Only allow these two fallbacks in testing mode
if (fileReference.isEmpty()) { // this may be the case in unit tests
@@ -135,6 +135,7 @@ class RankProfilesConfigImporter {
}
// TODO: Move these 2 lines to FileReference
+
dir = new File(Defaults.getDefaults().underVespaHome("var/db/vespa/filedistribution"), fileReference);
File file = dir.listFiles()[0]; // directory contains one file having the original name
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
new file mode 100644
index 00000000000..a823f16d727
--- /dev/null
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -0,0 +1,82 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.config.subscription.ConfigGetter;
+import com.yahoo.config.subscription.FileSource;
+import com.yahoo.path.Path;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests instantiating models from rank-profiles configs.
+ *
+ * @author bratseth
+ */
+public class MlModelsImportingTest {
+
+ @Test
+ public void testImportingModels() {
+ ModelTester tester = new ModelTester("src/test/resources/config/models/");
+
+ assertEquals(4, tester.models().size());
+
+ // TODO: When we get type information in Models, replace the evaluator.context().names() check below by that
+ {
+ Model xgboost = tester.models().get("xgboost_2_2");
+ tester.assertFunction("xgboost_2_2",
+ "(optimized sum of condition trees of size 192 bytes)",
+ xgboost);
+ FunctionEvaluator evaluator = xgboost.evaluatorOf();
+ assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ }
+
+ {
+
+ Model onnxMnistSoftmax = tester.models().get("mnist_softmax");
+ tester.assertFunction("default.add",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
+ onnxMnistSoftmax);
+ assertEquals("tensor(d1[10],d2[784])",
+ onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString());
+ FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available
+ assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ }
+
+ {
+ Model tfMnistSoftmax = tester.models().get("mnist_softmax_saved");
+ tester.assertFunction("serving_default.y",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
+ tfMnistSoftmax);
+ FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available
+ assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ }
+
+ {
+ Model tfMnist = tester.models().get("mnist_saved");
+ tester.assertFunction("serving_default.y",
+ "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ // Macro:
+ tester.assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add",
+ "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument
+ assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ }
+ }
+
+}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
new file mode 100644
index 00000000000..63e17e37bde
--- /dev/null
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
@@ -0,0 +1,77 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.config.subscription.ConfigGetter;
+import com.yahoo.config.subscription.FileSource;
+import com.yahoo.path.Path;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.logging.Logger;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * Helper for testing model import and evaluation
+ *
+ * @author bratseth
+ */
+public class ModelTester {
+
+ private final Map<String, Model> models;
+
+ public ModelTester(String modelConfigDirectory) {
+ models = createModels(modelConfigDirectory);
+ }
+
+ public Map<String, Model> models() { return models; }
+
+ private static Map<String, Model> createModels(String path) {
+ Path configDir = Path.fromString(path);
+ RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()),
+ RankProfilesConfig.class).getConfig("");
+ RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
+ RankingConstantsConfig.class).getConfig("");
+ return new RankProfilesConfigImporterWithMockedConstants().importFrom(config, constantsConfig);
+ }
+
+ public void assertFunction(String name, String expression, Model model) {
+ assertNotNull("Model is present in config", model);
+ ExpressionFunction function = model.function(name);
+ assertNotNull("Function '" + name + "' is in " + model, function);
+ assertEquals(name, function.getName());
+ assertEquals(expression, function.getBody().getRoot().toString());
+ }
+
+ public void assertBoundFunction(String name, String expression, Model model) {
+ ExpressionFunction function = model.referencedFunctions().get(FunctionReference.fromSerial(name).get());
+ assertNotNull("Function '" + name + "' is present", function);
+ assertEquals(name, function.getName());
+ assertEquals(expression, function.getBody().getRoot().toString());
+ }
+
+ /** Allows us to provide canned tensor constants during import since file distribution does not work in tests */
+ private static class RankProfilesConfigImporterWithMockedConstants extends RankProfilesConfigImporter {
+
+ private static final Logger log = Logger.getLogger(RankProfilesConfigImporterWithMockedConstants.class.getName());
+
+ Map<String, Tensor> constants = new HashMap<>();
+
+ @Override
+ Tensor readTensorFromFile(String name, TensorType type, String fileReference) {
+ if ( ! constants.containsKey(name)) {
+ log.warning("Missing a mocked tensor constant for '" + name + "': Returning an empty tensor");
+ return Tensor.from(type, "{}");
+ }
+ return constants.get(name);
+ }
+
+ }
+
+}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java
new file mode 100644
index 00000000000..210ffb823b2
--- /dev/null
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java
@@ -0,0 +1,36 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import org.junit.Test;
+
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bratseth
+ */
+public class RankProfileImportingTest {
+
+ @Test
+ public void testImportingRankExpressions() {
+ ModelTester tester = new ModelTester("src/test/resources/config/rankexpression/");
+
+ assertEquals(18, tester.models().size());
+
+ Model macros = tester.models().get("macros");
+ assertEquals("macros", macros.name());
+ assertEquals(4, macros.functions().size());
+ tester.assertFunction("fourtimessum", "4 * (var1 + var2)", macros);
+ tester.assertFunction("firstphase", "match + fieldMatch(title) + rankingExpression(myfeature)", macros);
+ tester.assertFunction("secondphase", "rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", macros);
+ tester.assertFunction("myfeature",
+ "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + " +
+ "30 * pow(0 - fieldMatch(description).earliness,2)",
+ macros);
+ assertEquals(4, macros.referencedFunctions().size());
+ tester.assertBoundFunction("rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)",
+ "4 * (match + rankBoost)", macros);
+ }
+
+}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java
deleted file mode 100644
index 2cb9602dfa7..00000000000
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java
+++ /dev/null
@@ -1,123 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.models.evaluation;
-
-import com.yahoo.config.subscription.ConfigGetter;
-import com.yahoo.config.subscription.FileSource;
-import com.yahoo.path.Path;
-import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.vespa.config.search.RankProfilesConfig;
-import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
-import org.junit.Test;
-
-import java.io.File;
-import java.util.Map;
-import java.util.stream.Collectors;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
-
-/**
- * Tests instantiating models from rank-profiles configs.
- *
- * @author bratseth
- */
-public class RankProfilesImporterTest {
-
- @Test
- public void testImportingModels() {
- Map<String, Model> models = createModels("src/test/resources/config/models/");
-
- assertEquals(4, models.size());
-
- // TODO: When we get type information in Models, replace the evaluator.context().names() check below by that
- {
- Model xgboost = models.get("xgboost_2_2");
- assertFunction("xgboost_2_2",
- "(optimized sum of condition trees of size 192 bytes)",
- xgboost);
- FunctionEvaluator evaluator = xgboost.evaluatorOf();
- assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
- }
-
- {
-
- Model onnxMnistSoftmax = models.get("mnist_softmax");
- assertFunction("default.add",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
- onnxMnistSoftmax);
- assertEquals("tensor(d1[10],d2[784])",
- onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString());
- FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available
- assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
- }
-
- {
- Model tfMnistSoftmax = models.get("mnist_softmax_saved");
- assertFunction("serving_default.y",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
- tfMnistSoftmax);
- FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available
- assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
- }
-
- {
- Model tfMnist = models.get("mnist_saved");
- assertFunction("serving_default.y",
- "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
- tfMnist);
- // Macro:
- assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add",
- "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
- tfMnist);
- FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument
- assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
- }
- }
-
- @Test
- public void testImportingRankExpressions() {
- Map<String, Model> models = createModels("src/test/resources/config/rankexpression/");
-
- assertEquals(18, models.size());
-
- Model macros = models.get("macros");
- assertEquals("macros", macros.name());
- assertEquals(4, macros.functions().size());
- assertFunction("fourtimessum", "4 * (var1 + var2)", macros);
- assertFunction("firstphase", "match + fieldMatch(title) + rankingExpression(myfeature)", macros);
- assertFunction("secondphase", "rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", macros);
- assertFunction("myfeature",
- "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + " +
- "30 * pow(0 - fieldMatch(description).earliness,2)",
- macros);
- assertEquals(4, macros.referencedFunctions().size());
- assertBoundFunction("rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)",
- "4 * (match + rankBoost)", macros);
- }
-
- private void assertFunction(String name, String expression, Model model) {
- assertNotNull("Model is present in config", model);
- ExpressionFunction function = model.function(name);
- assertNotNull("Function '" + name + "' is in " + model, function);
- assertEquals(name, function.getName());
- assertEquals(expression, function.getBody().getRoot().toString());
- }
-
- private void assertBoundFunction(String name, String expression, Model model) {
- ExpressionFunction function = model.referencedFunctions().get(FunctionReference.fromSerial(name).get());
- assertNotNull("Function '" + name + "' is present", function);
- assertEquals(name, function.getName());
- assertEquals(expression, function.getBody().getRoot().toString());
- }
-
- private Map<String, Model> createModels(String path) {
- Path configDir = Path.fromString(path);
- RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()),
- RankProfilesConfig.class).getConfig("");
- RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
- RankingConstantsConfig.class).getConfig("");
- return new RankProfilesConfigImporter().importFrom(config, constantsConfig);
- }
-
-}