diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-09-04 22:06:11 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-04 22:06:11 +0200 |
commit | f6076b70ef85bc6c0c10c373307acb1b1f456702 (patch) | |
tree | 2d3c7bad7f9157aba94c9119c6967ee34ecbec57 /model-evaluation/src/test/java/ai | |
parent | 793aac11a3cbefd24535595665fc3c3104c1e043 (diff) |
Revert "Bratseth/handle large constants take 2"
Diffstat (limited to 'model-evaluation/src/test/java/ai')
-rw-r--r-- | model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java | 22 | ||||
-rw-r--r-- | model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java | 51 |
2 files changed, 15 insertions, 58 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index d94e5b2af1b..60cf0d25ded 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -3,10 +3,8 @@ package ai.vespa.models.evaluation; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.config.subscription.FileSource; -import com.yahoo.path.Path; import com.yahoo.tensor.Tensor; import com.yahoo.vespa.config.search.RankProfilesConfig; -import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import org.junit.Test; import java.io.File; @@ -20,9 +18,15 @@ public class ModelsEvaluatorTest { private static final double delta = 0.00000000001; + private ModelsEvaluator createModels() { + String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg"; + RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); + return new ModelsEvaluator(config); + } + @Test public void testTensorEvaluation() { - ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); + ModelsEvaluator models = createModels(); FunctionEvaluator function = models.evaluatorOf("macros", "fourtimessum"); function.bind("var1", Tensor.from("{{x:0}:3,{x:1}:5}")); function.bind("var2", Tensor.from("{{x:0}:7,{x:1}:11}")); @@ -31,7 +35,7 @@ public class ModelsEvaluatorTest { @Test public void testEvaluationDependingOnMacroTakingArguments() { - ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); + ModelsEvaluator models = createModels(); FunctionEvaluator function = models.evaluatorOf("macros", "secondphase"); function.bind("match", 3); function.bind("rankBoost", 5); @@ -42,14 +46,6 @@ public class ModelsEvaluatorTest { // TODO: Test that binding nonexisting variable doesn't work // TODO: Test that rebinding doesn't work // TODO: Test with nested macros - - private ModelsEvaluator createModels(String path) { - Path configDir = Path.fromString(path); - RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), - RankProfilesConfig.class).getConfig(""); - RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), - RankingConstantsConfig.class).getConfig(""); - return new ModelsEvaluator(config, constantsConfig); - } + // TODO: Test TF/ONNX model } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java index 84e01e58280..d45372fc7da 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java @@ -3,10 +3,8 @@ package ai.vespa.models.evaluation; import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.config.subscription.FileSource; -import com.yahoo.path.Path; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.vespa.config.search.RankProfilesConfig; -import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import org.junit.Test; import java.io.File; @@ -23,41 +21,14 @@ import static org.junit.Assert.assertNotNull; public class RankProfilesImporterTest { @Test - public void testImportingModels() { - Map<String, Model> models = createModels("src/test/resources/config/models/"); - - assertEquals(4, models.size()); - - Model xgboost = models.get("xgboost_2_2"); - assertFunction("xgboost_2_2", - "(optimized sum of condition trees of size 192 bytes)", - xgboost); - - Model onnxMnistSoftmax = models.get("mnist_softmax"); - assertFunction("default.add", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", - onnxMnistSoftmax); - assertEquals("tensor(d1[10],d2[784])", - onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); - - Model tfMnistSoftmax = models.get("mnist_softmax_saved"); - assertFunction("serving_default.y", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", - tfMnistSoftmax); - - Model tfMnist = models.get("mnist_saved"); - assertFunction("serving_default.y", - "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", - tfMnist); - } - - @Test - public void testImportingRankExpressions() { - Map<String, Model> models = createModels("src/test/resources/config/rankexpression/"); - + public void testImporting() { + String configPath = "src/test/resources/config/rankexpression/rank-profiles.cfg"; + RankProfilesConfig config = new ConfigGetter<>(new FileSource(new File(configPath)), RankProfilesConfig.class).getConfig(""); + Map<String, Model> models = new RankProfilesConfigImporter().importFrom(config); assertEquals(18, models.size()); Model macros = models.get("macros"); + assertNotNull(macros); assertEquals("macros", macros.name()); assertEquals(4, macros.functions().size()); assertFunction("fourtimessum", "4 * (var1 + var2)", macros); @@ -73,9 +44,8 @@ public class RankProfilesImporterTest { } private void assertFunction(String name, String expression, Model model) { - assertNotNull("Model is present in config", model); ExpressionFunction function = model.function(name); - assertNotNull("Function '" + name + "' is in " + model, function); + assertNotNull(function); assertEquals(name, function.getName()); assertEquals(expression, function.getBody().getRoot().toString()); } @@ -87,13 +57,4 @@ public class RankProfilesImporterTest { assertEquals(expression, function.getBody().getRoot().toString()); } - private Map<String, Model> createModels(String path) { - Path configDir = Path.fromString(path); - RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), - RankProfilesConfig.class).getConfig(""); - RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), - RankingConstantsConfig.class).getConfig(""); - return new RankProfilesConfigImporter().importFrom(config, constantsConfig); - } - } |