diff options
Diffstat (limited to 'config-model/src/test')
8 files changed, 360 insertions, 8 deletions
diff --git a/config-model/src/test/cfg/application/stateless_eval/lightgbm_regression.json b/config-model/src/test/cfg/application/stateless_eval/lightgbm_regression.json new file mode 100644 index 00000000000..cf0488ecd8b --- /dev/null +++ b/config-model/src/test/cfg/application/stateless_eval/lightgbm_regression.json @@ -0,0 +1,275 @@ +{ + "name": "tree", + "version": "v3", + "num_class": 1, + "num_tree_per_iteration": 1, + "label_index": 0, + "max_feature_idx": 3, + "average_output": false, + "objective": "regression", + "feature_names": [ + "numerical_1", + "numerical_2", + "categorical_1", + "categorical_2" + ], + "monotone_constraints": [], + "tree_info": [ + { + "tree_index": 0, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 1, + "tree_structure": { + "split_index": 0, + "split_feature": 1, + "split_gain": 68.5353012084961, + "threshold": 0.46643291586559305, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "leaf_index": 0, + "leaf_value": 2.1594397038037663, + "leaf_weight": 469, + "leaf_count": 469 + }, + "right_child": { + "split_index": 1, + "split_feature": 3, + "split_gain": 41.27640151977539, + "threshold": "2||3||4", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0.246035, + "internal_weight": 531, + "internal_count": 531, + "left_child": { + "leaf_index": 1, + "leaf_value": 2.235297305276056, + "leaf_weight": 302, + "leaf_count": 302 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 2.1792953471546546, + "leaf_weight": 229, + "leaf_count": 229 + } + } + } + }, + { + "tree_index": 1, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 2, + "split_gain": 64.22250366210938, + "threshold": "3||4", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "leaf_index": 0, + "leaf_value": 0.03070842919354316, + "leaf_weight": 399, + "leaf_count": 399 + }, + "right_child": { + "split_index": 1, + "split_feature": 0, + "split_gain": 36.74250030517578, + "threshold": 0.5102250691730842, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": -0.204906, + "internal_weight": 601, + "internal_count": 601, + "left_child": { + "leaf_index": 1, + "leaf_value": -0.04439151147520909, + "leaf_weight": 315, + "leaf_count": 315 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 0.005117411709368601, + "leaf_weight": 286, + "leaf_count": 286 + } + } + } + }, + { + "tree_index": 2, + "num_leaves": 3, + "num_cat": 0, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 1, + "split_gain": 57.1327018737793, + "threshold": 0.668665477622446, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "split_index": 1, + "split_feature": 1, + "split_gain": 40.859100341796875, + "threshold": 0.008118820676863816, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": -0.162926, + "internal_weight": 681, + "internal_count": 681, + "left_child": { + "leaf_index": 0, + "leaf_value": -0.15361238490967524, + "leaf_weight": 21, + "leaf_count": 21 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": -0.01192330846157292, + "leaf_weight": 660, + "leaf_count": 660 + } + }, + "right_child": { + "leaf_index": 1, + "leaf_value": 0.03499044894987518, + "leaf_weight": 319, + "leaf_count": 319 + } + } + }, + { + "tree_index": 3, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 0, + "split_gain": 54.77090072631836, + "threshold": 0.5201391072644542, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "leaf_index": 0, + "leaf_value": -0.02141000620783247, + "leaf_weight": 543, + "leaf_count": 543 + }, + "right_child": { + "split_index": 1, + "split_feature": 2, + "split_gain": 27.200700759887695, + "threshold": "0||1", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0.255704, + "internal_weight": 457, + "internal_count": 457, + "left_child": { + "leaf_index": 1, + "leaf_value": -0.004121485787596721, + "leaf_weight": 191, + "leaf_count": 191 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 0.04534090904886873, + "leaf_weight": 266, + "leaf_count": 266 + } + } + } + }, + { + "tree_index": 4, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 3, + "split_gain": 51.84349822998047, + "threshold": "2||3||4", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "split_index": 1, + "split_feature": 1, + "split_gain": 39.352699279785156, + "threshold": 0.27283279016959255, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0.188414, + "internal_weight": 593, + "internal_count": 593, + "left_child": { + "leaf_index": 0, + "leaf_value": -0.01924803254356527, + "leaf_weight": 184, + "leaf_count": 184 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 0.03643772842347651, + "leaf_weight": 409, + "leaf_count": 409 + } + }, + "right_child": { + "leaf_index": 1, + "leaf_value": -0.02701711918923075, + "leaf_weight": 407, + "leaf_count": 407 + } + } + } + ], + "pandas_categorical": [ + [ + "a", + "b", + "c", + "d", + "e" + ], + [ + "i", + "j", + "k", + "l", + "m" + ] + ] +}
\ No newline at end of file diff --git a/config-model/src/test/java/com/yahoo/config/model/MockModelContext.java b/config-model/src/test/java/com/yahoo/config/model/MockModelContext.java index 59af3193b79..8c4c6aa7fc0 100644 --- a/config-model/src/test/java/com/yahoo/config/model/MockModelContext.java +++ b/config-model/src/test/java/com/yahoo/config/model/MockModelContext.java @@ -2,6 +2,7 @@ package com.yahoo.config.model; import com.yahoo.component.Version; +import com.yahoo.concurrent.InThreadExecutorService; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.application.api.FileRegistry; @@ -18,6 +19,7 @@ import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.config.model.test.MockApplicationPackage; import java.util.Optional; +import java.util.concurrent.ExecutorService; /** * @author hmusum @@ -83,4 +85,8 @@ public class MockModelContext implements ModelContext { return new TestProperties(); } + @Override + public ExecutorService getExecutor() { + return new InThreadExecutorService(); + } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java index 8f3fbfc9de9..69789d09dc2 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java @@ -118,8 +118,13 @@ public class RankProfileTestCase extends SchemaTestCase { @Test public void requireThatSidewaysInheritanceIsImpossible() throws ParseException { + verifySidewaysInheritance(false); + verifySidewaysInheritance(true); + } + private void verifySidewaysInheritance(boolean enforce) throws ParseException { RankProfileRegistry registry = new RankProfileRegistry(); - SearchBuilder builder = new SearchBuilder(registry, setupQueryProfileTypes()); + SearchBuilder builder = new SearchBuilder(registry, setupQueryProfileTypes(), + new TestProperties().enforceRankProfileInheritance(enforce)); builder.importString(joinLines( "schema child1 {", " document child1 {", @@ -163,7 +168,15 @@ public class RankProfileTestCase extends SchemaTestCase { "}")); try { builder.build(true); + if (enforce) { + fail("Sideways inheritance should have been enforced"); + } else { + assertNotNull(builder.getSearch("child2")); + assertNotNull(builder.getSearch("child1")); + assertTrue(registry.get("child1", "child").inherits("parent")); + } } catch (IllegalArgumentException e) { + if (!enforce) fail("Sideways inheritance should have been allowed"); assertEquals("rank-profile 'child' inherits 'parent', but it does not exist anywhere in the inheritance of search 'child1'.", e.getMessage()); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java index d7143281977..d87278a9ca1 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java @@ -2,8 +2,10 @@ package com.yahoo.searchdefinition; import com.yahoo.collections.Pair; +import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.application.provider.MockFileRegistry; import com.yahoo.config.model.deploy.TestProperties; +import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; @@ -11,7 +13,9 @@ import com.yahoo.searchdefinition.parser.ParseException; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; import org.junit.Test; +import java.util.ArrayList; import java.util.Optional; +import java.util.logging.Level; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -26,7 +30,7 @@ public class RankingExpressionInliningTestCase extends SchemaTestCase { RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); SearchBuilder builder = new SearchBuilder(rankProfileRegistry); builder.importString( - "search test {\n" + + "search test {\n" + " document test { \n" + " field a type double { \n" + " indexing: attribute \n" + @@ -186,6 +190,39 @@ public class RankingExpressionInliningTestCase extends SchemaTestCase { assertEquals("attribute(b) + 1", getRankingExpression("D", test, s)); } + @Test + public void testFunctionInliningWithReplacement() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + MockDeployLogger deployLogger = new MockDeployLogger(); + SearchBuilder builder = new SearchBuilder(MockApplicationPackage.createEmpty(), + new MockFileRegistry(), + deployLogger, + new TestProperties(), + rankProfileRegistry, + new QueryProfileRegistry()); + builder.importString( + "search test {\n" + + " document test { }\n" + + " rank-profile test {\n" + + " first-phase {\n" + + " expression: foo\n" + + " }\n" + + " function foo(x) {\n" + + " expression: x + x\n" + + " }\n" + + " function inline foo() {\n" + // replaces previous "foo" during parsing + " expression: foo(2)\n" + + " }\n" + + " }\n" + + "}\n"); + builder.build(); + Search s = builder.getSearch(); + RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels()); + assertEquals("foo(2)", test.getFirstPhaseRanking().getRoot().toString()); + assertTrue("Does not contain expected warning", deployLogger.contains("Function 'foo' replaces " + + "a previous function with the same name in rank profile 'test'")); + } + /** * Expression evaluation has no stack so function arguments are bound at config time creating a separate version of * each function for each binding, using hashes to name the bound variants of the function. @@ -221,4 +258,17 @@ public class RankingExpressionInliningTestCase extends SchemaTestCase { return censorBindingHash(rankExpression.get()); } + private static class MockDeployLogger implements DeployLogger { + private final ArrayList<String> msgs = new ArrayList<>(); + + @Override + public void log(Level level, String message) { + msgs.add(message); + } + + public boolean contains(String expected) { + return msgs.stream().anyMatch(msg -> msg.equals(expected)); + } + } + } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java index 8ef04752800..12263521dcb 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java @@ -9,7 +9,6 @@ import org.junit.Test; import java.io.IOException; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; /** * Tests exporting @@ -106,7 +105,7 @@ public class ExportingTestCase extends AbstractExportingTestCase { @Test public void testRankExpression() throws IOException, ParseException { assertCorrectDeriving("rankexpression", null, - new TestProperties().useExternalRankExpression(true).largeRankExpressionLimit(1024), new TestableDeployLogger()); + new TestProperties().largeRankExpressionLimit(1024), new TestableDeployLogger()); } @Test diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 010b33597f3..9c363ea0628 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -24,6 +24,8 @@ import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import static org.junit.Assert.assertEquals; @@ -43,6 +45,7 @@ class RankProfileSearchFixture { private final QueryProfileRegistry queryProfileRegistry; private final Search search; private final Map<String, RankProfile> compiledRankProfiles = new HashMap<>(); + private final ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); public RankProfileRegistry getRankProfileRegistry() { return rankProfileRegistry; @@ -105,7 +108,7 @@ class RankProfileSearchFixture { public RankProfile compileRankProfile(String rankProfile, Path applicationDir) { RankProfile compiled = rankProfileRegistry.get(search, rankProfile) .compile(queryProfileRegistry, - new ImportedMlModels(applicationDir.toFile(), importers)); + new ImportedMlModels(applicationDir.toFile(), executor, importers)); compiledRankProfiles.put(rankProfile, compiled); return compiled; } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java index 00ac5ac5405..b81fe7a02cc 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java @@ -36,7 +36,7 @@ public class RankingExpressionsTestCase extends SchemaTestCase { @Test public void testFunctions() throws IOException, ParseException { - ModelContext.Properties deployProperties = new TestProperties().useExternalRankExpression(true); + ModelContext.Properties deployProperties = new TestProperties(); RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); Search search = createSearch("src/test/examples/rankingexpressionfunction", deployProperties, rankProfileRegistry); RankProfile functionsRankProfile = rankProfileRegistry.get(search, "macros"); @@ -115,7 +115,7 @@ public class RankingExpressionsTestCase extends SchemaTestCase { @Test public void testLargeInheritedFunctions() throws IOException, ParseException { - ModelContext.Properties properties = new TestProperties().useExternalRankExpression(true).largeRankExpressionLimit(50); + ModelContext.Properties properties = new TestProperties().largeRankExpressionLimit(50); RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); LargeRankExpressions largeExpressions = new LargeRankExpressions(new MockFileRegistry()); QueryProfileRegistry queryProfiles = new QueryProfileRegistry(); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java index 771cba673bc..e6d3b5dc140 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java @@ -18,7 +18,7 @@ public class ModelsEvaluatorTest { @Test public void testModelsEvaluatorTester() { ModelsEvaluator modelsEvaluator = ModelsEvaluatorTester.create("src/test/cfg/application/stateless_eval"); - assertEquals(2, modelsEvaluator.models().size()); + assertEquals(3, modelsEvaluator.models().size()); // ONNX model evaluation FunctionEvaluator mul = modelsEvaluator.evaluatorOf("mul"); @@ -27,6 +27,12 @@ public class ModelsEvaluatorTest { Tensor output = mul.bind("input1", input1).bind("input2", input2).evaluate(); assertEquals(6.0, output.sum().asDouble(), 1e-9); + // LightGBM model evaluation + FunctionEvaluator lgbm = modelsEvaluator.evaluatorOf("lightgbm_regression"); + lgbm.bind("numerical_1", 0.1).bind("numerical_2", 0.2).bind("categorical_1", "a").bind("categorical_2", "i"); + output = lgbm.evaluate(); + assertEquals(2.0547, output.sum().asDouble(), 1e-4); + // Vespa model evaluation FunctionEvaluator foo1 = modelsEvaluator.evaluatorOf("example", "foo1"); input1 = Tensor.from("tensor(name{},x[3]):{{name:n,x:0}:1,{name:n,x:1}:2,{name:n,x:2}:3 }"); |