diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-09-07 13:54:51 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-09-07 13:54:51 +0200 |
commit | 55b5c8c7c4b14303dfffb9ad017fd6bcea40e9b9 (patch) | |
tree | 18a8b20fb58da199f7ba832b351a3c27a3f309ce /model-evaluation/src | |
parent | ad6540e0e71b8db2d236f029266cfacc0a0f11a8 (diff) |
Refactor test and prepare for injecting constants
Diffstat (limited to 'model-evaluation/src')
5 files changed, 197 insertions, 124 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 9af4022b170..cd21a0a6813 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -120,7 +120,7 @@ class RankProfilesConfigImporter { return constants; } - private Tensor readTensorFromFile(String name, TensorType type, String fileReference) { + Tensor readTensorFromFile(String name, TensorType type, String fileReference) { try { // TODO: Only allow these two fallbacks in testing mode if (fileReference.isEmpty()) { // this may be the case in unit tests @@ -135,6 +135,7 @@ class RankProfilesConfigImporter { } // TODO: Move these 2 lines to FileReference + dir = new File(Defaults.getDefaults().underVespaHome("var/db/vespa/filedistribution"), fileReference); File file = dir.listFiles()[0]; // directory contains one file having the original name diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java new file mode 100644 index 00000000000..a823f16d727 --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -0,0 +1,82 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import com.yahoo.config.subscription.ConfigGetter; +import com.yahoo.config.subscription.FileSource; +import com.yahoo.path.Path; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import org.junit.Test; + +import java.io.File; +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Logger; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * Tests instantiating models from rank-profiles configs. + * + * @author bratseth + */ +public class MlModelsImportingTest { + + @Test + public void testImportingModels() { + ModelTester tester = new ModelTester("src/test/resources/config/models/"); + + assertEquals(4, tester.models().size()); + + // TODO: When we get type information in Models, replace the evaluator.context().names() check below by that + { + Model xgboost = tester.models().get("xgboost_2_2"); + tester.assertFunction("xgboost_2_2", + "(optimized sum of condition trees of size 192 bytes)", + xgboost); + FunctionEvaluator evaluator = xgboost.evaluatorOf(); + assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + } + + { + + Model onnxMnistSoftmax = tester.models().get("mnist_softmax"); + tester.assertFunction("default.add", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", + onnxMnistSoftmax); + assertEquals("tensor(d1[10],d2[784])", + onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); + FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available + assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + } + + { + Model tfMnistSoftmax = tester.models().get("mnist_softmax_saved"); + tester.assertFunction("serving_default.y", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", + tfMnistSoftmax); + FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available + assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + } + + { + Model tfMnist = tester.models().get("mnist_saved"); + tester.assertFunction("serving_default.y", + "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", + tfMnist); + // Macro: + tester.assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add", + "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))", + tfMnist); + FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument + assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + } + } + +} diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java new file mode 100644 index 00000000000..63e17e37bde --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java @@ -0,0 +1,77 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import com.yahoo.config.subscription.ConfigGetter; +import com.yahoo.config.subscription.FileSource; +import com.yahoo.path.Path; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; + +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Logger; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * Helper for testing model import and evaluation + * + * @author bratseth + */ +public class ModelTester { + + private final Map<String, Model> models; + + public ModelTester(String modelConfigDirectory) { + models = createModels(modelConfigDirectory); + } + + public Map<String, Model> models() { return models; } + + private static Map<String, Model> createModels(String path) { + Path configDir = Path.fromString(path); + RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), + RankProfilesConfig.class).getConfig(""); + RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), + RankingConstantsConfig.class).getConfig(""); + return new RankProfilesConfigImporterWithMockedConstants().importFrom(config, constantsConfig); + } + + public void assertFunction(String name, String expression, Model model) { + assertNotNull("Model is present in config", model); + ExpressionFunction function = model.function(name); + assertNotNull("Function '" + name + "' is in " + model, function); + assertEquals(name, function.getName()); + assertEquals(expression, function.getBody().getRoot().toString()); + } + + public void assertBoundFunction(String name, String expression, Model model) { + ExpressionFunction function = model.referencedFunctions().get(FunctionReference.fromSerial(name).get()); + assertNotNull("Function '" + name + "' is present", function); + assertEquals(name, function.getName()); + assertEquals(expression, function.getBody().getRoot().toString()); + } + + /** Allows us to provide canned tensor constants during import since file distribution does not work in tests */ + private static class RankProfilesConfigImporterWithMockedConstants extends RankProfilesConfigImporter { + + private static final Logger log = Logger.getLogger(RankProfilesConfigImporterWithMockedConstants.class.getName()); + + Map<String, Tensor> constants = new HashMap<>(); + + @Override + Tensor readTensorFromFile(String name, TensorType type, String fileReference) { + if ( ! constants.containsKey(name)) { + log.warning("Missing a mocked tensor constant for '" + name + "': Returning an empty tensor"); + return Tensor.from(type, "{}"); + } + return constants.get(name); + } + + } + +} diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java new file mode 100644 index 00000000000..210ffb823b2 --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java @@ -0,0 +1,36 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import org.junit.Test; + +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class RankProfileImportingTest { + + @Test + public void testImportingRankExpressions() { + ModelTester tester = new ModelTester("src/test/resources/config/rankexpression/"); + + assertEquals(18, tester.models().size()); + + Model macros = tester.models().get("macros"); + assertEquals("macros", macros.name()); + assertEquals(4, macros.functions().size()); + tester.assertFunction("fourtimessum", "4 * (var1 + var2)", macros); + tester.assertFunction("firstphase", "match + fieldMatch(title) + rankingExpression(myfeature)", macros); + tester.assertFunction("secondphase", "rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", macros); + tester.assertFunction("myfeature", + "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + " + + "30 * pow(0 - fieldMatch(description).earliness,2)", + macros); + assertEquals(4, macros.referencedFunctions().size()); + tester.assertBoundFunction("rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", + "4 * (match + rankBoost)", macros); + } + +} diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java deleted file mode 100644 index 2cb9602dfa7..00000000000 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.models.evaluation; - -import com.yahoo.config.subscription.ConfigGetter; -import com.yahoo.config.subscription.FileSource; -import com.yahoo.path.Path; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.vespa.config.search.RankProfilesConfig; -import com.yahoo.vespa.config.search.core.RankingConstantsConfig; -import org.junit.Test; - -import java.io.File; -import java.util.Map; -import java.util.stream.Collectors; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - -/** - * Tests instantiating models from rank-profiles configs. - * - * @author bratseth - */ -public class RankProfilesImporterTest { - - @Test - public void testImportingModels() { - Map<String, Model> models = createModels("src/test/resources/config/models/"); - - assertEquals(4, models.size()); - - // TODO: When we get type information in Models, replace the evaluator.context().names() check below by that - { - Model xgboost = models.get("xgboost_2_2"); - assertFunction("xgboost_2_2", - "(optimized sum of condition trees of size 192 bytes)", - xgboost); - FunctionEvaluator evaluator = xgboost.evaluatorOf(); - assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); - } - - { - - Model onnxMnistSoftmax = models.get("mnist_softmax"); - assertFunction("default.add", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", - onnxMnistSoftmax); - assertEquals("tensor(d1[10],d2[784])", - onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); - FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available - assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); - } - - { - Model tfMnistSoftmax = models.get("mnist_softmax_saved"); - assertFunction("serving_default.y", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", - tfMnistSoftmax); - FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available - assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); - } - - { - Model tfMnist = models.get("mnist_saved"); - assertFunction("serving_default.y", - "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", - tfMnist); - // Macro: - assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add", - "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))", - tfMnist); - FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument - assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); - } - } - - @Test - public void testImportingRankExpressions() { - Map<String, Model> models = createModels("src/test/resources/config/rankexpression/"); - - assertEquals(18, models.size()); - - Model macros = models.get("macros"); - assertEquals("macros", macros.name()); - assertEquals(4, macros.functions().size()); - assertFunction("fourtimessum", "4 * (var1 + var2)", macros); - assertFunction("firstphase", "match + fieldMatch(title) + rankingExpression(myfeature)", macros); - assertFunction("secondphase", "rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", macros); - assertFunction("myfeature", - "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + " + - "30 * pow(0 - fieldMatch(description).earliness,2)", - macros); - assertEquals(4, macros.referencedFunctions().size()); - assertBoundFunction("rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", - "4 * (match + rankBoost)", macros); - } - - private void assertFunction(String name, String expression, Model model) { - assertNotNull("Model is present in config", model); - ExpressionFunction function = model.function(name); - assertNotNull("Function '" + name + "' is in " + model, function); - assertEquals(name, function.getName()); - assertEquals(expression, function.getBody().getRoot().toString()); - } - - private void assertBoundFunction(String name, String expression, Model model) { - ExpressionFunction function = model.referencedFunctions().get(FunctionReference.fromSerial(name).get()); - assertNotNull("Function '" + name + "' is present", function); - assertEquals(name, function.getName()); - assertEquals(expression, function.getBody().getRoot().toString()); - } - - private Map<String, Model> createModels(String path) { - Path configDir = Path.fromString(path); - RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), - RankProfilesConfig.class).getConfig(""); - RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), - RankingConstantsConfig.class).getConfig(""); - return new RankProfilesConfigImporter().importFrom(config, constantsConfig); - } - -} |