summaryrefslogtreecommitdiffstats
path: root/config-model/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/test')
-rw-r--r--config-model/src/test/cfg/application/stateless_eval/lightgbm_regression.json275
-rw-r--r--config-model/src/test/java/com/yahoo/config/model/MockModelContext.java6
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java15
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java52
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java5
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java4
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java8
8 files changed, 360 insertions, 8 deletions
diff --git a/config-model/src/test/cfg/application/stateless_eval/lightgbm_regression.json b/config-model/src/test/cfg/application/stateless_eval/lightgbm_regression.json
new file mode 100644
index 00000000000..cf0488ecd8b
--- /dev/null
+++ b/config-model/src/test/cfg/application/stateless_eval/lightgbm_regression.json
@@ -0,0 +1,275 @@
+{
+ "name": "tree",
+ "version": "v3",
+ "num_class": 1,
+ "num_tree_per_iteration": 1,
+ "label_index": 0,
+ "max_feature_idx": 3,
+ "average_output": false,
+ "objective": "regression",
+ "feature_names": [
+ "numerical_1",
+ "numerical_2",
+ "categorical_1",
+ "categorical_2"
+ ],
+ "monotone_constraints": [],
+ "tree_info": [
+ {
+ "tree_index": 0,
+ "num_leaves": 3,
+ "num_cat": 1,
+ "shrinkage": 1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 1,
+ "split_gain": 68.5353012084961,
+ "threshold": 0.46643291586559305,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 1000,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": 2.1594397038037663,
+ "leaf_weight": 469,
+ "leaf_count": 469
+ },
+ "right_child": {
+ "split_index": 1,
+ "split_feature": 3,
+ "split_gain": 41.27640151977539,
+ "threshold": "2||3||4",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": 0.246035,
+ "internal_weight": 531,
+ "internal_count": 531,
+ "left_child": {
+ "leaf_index": 1,
+ "leaf_value": 2.235297305276056,
+ "leaf_weight": 302,
+ "leaf_count": 302
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": 2.1792953471546546,
+ "leaf_weight": 229,
+ "leaf_count": 229
+ }
+ }
+ }
+ },
+ {
+ "tree_index": 1,
+ "num_leaves": 3,
+ "num_cat": 1,
+ "shrinkage": 0.1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 2,
+ "split_gain": 64.22250366210938,
+ "threshold": "3||4",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 1000,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": 0.03070842919354316,
+ "leaf_weight": 399,
+ "leaf_count": 399
+ },
+ "right_child": {
+ "split_index": 1,
+ "split_feature": 0,
+ "split_gain": 36.74250030517578,
+ "threshold": 0.5102250691730842,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": -0.204906,
+ "internal_weight": 601,
+ "internal_count": 601,
+ "left_child": {
+ "leaf_index": 1,
+ "leaf_value": -0.04439151147520909,
+ "leaf_weight": 315,
+ "leaf_count": 315
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": 0.005117411709368601,
+ "leaf_weight": 286,
+ "leaf_count": 286
+ }
+ }
+ }
+ },
+ {
+ "tree_index": 2,
+ "num_leaves": 3,
+ "num_cat": 0,
+ "shrinkage": 0.1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 1,
+ "split_gain": 57.1327018737793,
+ "threshold": 0.668665477622446,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 1000,
+ "left_child": {
+ "split_index": 1,
+ "split_feature": 1,
+ "split_gain": 40.859100341796875,
+ "threshold": 0.008118820676863816,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": -0.162926,
+ "internal_weight": 681,
+ "internal_count": 681,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": -0.15361238490967524,
+ "leaf_weight": 21,
+ "leaf_count": 21
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": -0.01192330846157292,
+ "leaf_weight": 660,
+ "leaf_count": 660
+ }
+ },
+ "right_child": {
+ "leaf_index": 1,
+ "leaf_value": 0.03499044894987518,
+ "leaf_weight": 319,
+ "leaf_count": 319
+ }
+ }
+ },
+ {
+ "tree_index": 3,
+ "num_leaves": 3,
+ "num_cat": 1,
+ "shrinkage": 0.1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 0,
+ "split_gain": 54.77090072631836,
+ "threshold": 0.5201391072644542,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 1000,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": -0.02141000620783247,
+ "leaf_weight": 543,
+ "leaf_count": 543
+ },
+ "right_child": {
+ "split_index": 1,
+ "split_feature": 2,
+ "split_gain": 27.200700759887695,
+ "threshold": "0||1",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": 0.255704,
+ "internal_weight": 457,
+ "internal_count": 457,
+ "left_child": {
+ "leaf_index": 1,
+ "leaf_value": -0.004121485787596721,
+ "leaf_weight": 191,
+ "leaf_count": 191
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": 0.04534090904886873,
+ "leaf_weight": 266,
+ "leaf_count": 266
+ }
+ }
+ }
+ },
+ {
+ "tree_index": 4,
+ "num_leaves": 3,
+ "num_cat": 1,
+ "shrinkage": 0.1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 3,
+ "split_gain": 51.84349822998047,
+ "threshold": "2||3||4",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 1000,
+ "left_child": {
+ "split_index": 1,
+ "split_feature": 1,
+ "split_gain": 39.352699279785156,
+ "threshold": 0.27283279016959255,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": 0.188414,
+ "internal_weight": 593,
+ "internal_count": 593,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": -0.01924803254356527,
+ "leaf_weight": 184,
+ "leaf_count": 184
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": 0.03643772842347651,
+ "leaf_weight": 409,
+ "leaf_count": 409
+ }
+ },
+ "right_child": {
+ "leaf_index": 1,
+ "leaf_value": -0.02701711918923075,
+ "leaf_weight": 407,
+ "leaf_count": 407
+ }
+ }
+ }
+ ],
+ "pandas_categorical": [
+ [
+ "a",
+ "b",
+ "c",
+ "d",
+ "e"
+ ],
+ [
+ "i",
+ "j",
+ "k",
+ "l",
+ "m"
+ ]
+ ]
+} \ No newline at end of file
diff --git a/config-model/src/test/java/com/yahoo/config/model/MockModelContext.java b/config-model/src/test/java/com/yahoo/config/model/MockModelContext.java
index 59af3193b79..8c4c6aa7fc0 100644
--- a/config-model/src/test/java/com/yahoo/config/model/MockModelContext.java
+++ b/config-model/src/test/java/com/yahoo/config/model/MockModelContext.java
@@ -2,6 +2,7 @@
package com.yahoo.config.model;
import com.yahoo.component.Version;
+import com.yahoo.concurrent.InThreadExecutorService;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.application.api.FileRegistry;
@@ -18,6 +19,7 @@ import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.config.model.test.MockApplicationPackage;
import java.util.Optional;
+import java.util.concurrent.ExecutorService;
/**
* @author hmusum
@@ -83,4 +85,8 @@ public class MockModelContext implements ModelContext {
return new TestProperties();
}
+ @Override
+ public ExecutorService getExecutor() {
+ return new InThreadExecutorService();
+ }
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
index 8f3fbfc9de9..69789d09dc2 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
@@ -118,8 +118,13 @@ public class RankProfileTestCase extends SchemaTestCase {
@Test
public void requireThatSidewaysInheritanceIsImpossible() throws ParseException {
+ verifySidewaysInheritance(false);
+ verifySidewaysInheritance(true);
+ }
+ private void verifySidewaysInheritance(boolean enforce) throws ParseException {
RankProfileRegistry registry = new RankProfileRegistry();
- SearchBuilder builder = new SearchBuilder(registry, setupQueryProfileTypes());
+ SearchBuilder builder = new SearchBuilder(registry, setupQueryProfileTypes(),
+ new TestProperties().enforceRankProfileInheritance(enforce));
builder.importString(joinLines(
"schema child1 {",
" document child1 {",
@@ -163,7 +168,15 @@ public class RankProfileTestCase extends SchemaTestCase {
"}"));
try {
builder.build(true);
+ if (enforce) {
+ fail("Sideways inheritance should have been enforced");
+ } else {
+ assertNotNull(builder.getSearch("child2"));
+ assertNotNull(builder.getSearch("child1"));
+ assertTrue(registry.get("child1", "child").inherits("parent"));
+ }
} catch (IllegalArgumentException e) {
+ if (!enforce) fail("Sideways inheritance should have been allowed");
assertEquals("rank-profile 'child' inherits 'parent', but it does not exist anywhere in the inheritance of search 'child1'.", e.getMessage());
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java
index d7143281977..d87278a9ca1 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java
@@ -2,8 +2,10 @@
package com.yahoo.searchdefinition;
import com.yahoo.collections.Pair;
+import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.application.provider.MockFileRegistry;
import com.yahoo.config.model.deploy.TestProperties;
+import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
@@ -11,7 +13,9 @@ import com.yahoo.searchdefinition.parser.ParseException;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
import org.junit.Test;
+import java.util.ArrayList;
import java.util.Optional;
+import java.util.logging.Level;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@@ -26,7 +30,7 @@ public class RankingExpressionInliningTestCase extends SchemaTestCase {
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
SearchBuilder builder = new SearchBuilder(rankProfileRegistry);
builder.importString(
- "search test {\n" +
+ "search test {\n" +
" document test { \n" +
" field a type double { \n" +
" indexing: attribute \n" +
@@ -186,6 +190,39 @@ public class RankingExpressionInliningTestCase extends SchemaTestCase {
assertEquals("attribute(b) + 1", getRankingExpression("D", test, s));
}
+ @Test
+ public void testFunctionInliningWithReplacement() throws ParseException {
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ MockDeployLogger deployLogger = new MockDeployLogger();
+ SearchBuilder builder = new SearchBuilder(MockApplicationPackage.createEmpty(),
+ new MockFileRegistry(),
+ deployLogger,
+ new TestProperties(),
+ rankProfileRegistry,
+ new QueryProfileRegistry());
+ builder.importString(
+ "search test {\n" +
+ " document test { }\n" +
+ " rank-profile test {\n" +
+ " first-phase {\n" +
+ " expression: foo\n" +
+ " }\n" +
+ " function foo(x) {\n" +
+ " expression: x + x\n" +
+ " }\n" +
+ " function inline foo() {\n" + // replaces previous "foo" during parsing
+ " expression: foo(2)\n" +
+ " }\n" +
+ " }\n" +
+ "}\n");
+ builder.build();
+ Search s = builder.getSearch();
+ RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels());
+ assertEquals("foo(2)", test.getFirstPhaseRanking().getRoot().toString());
+ assertTrue("Does not contain expected warning", deployLogger.contains("Function 'foo' replaces " +
+ "a previous function with the same name in rank profile 'test'"));
+ }
+
/**
* Expression evaluation has no stack so function arguments are bound at config time creating a separate version of
* each function for each binding, using hashes to name the bound variants of the function.
@@ -221,4 +258,17 @@ public class RankingExpressionInliningTestCase extends SchemaTestCase {
return censorBindingHash(rankExpression.get());
}
+ private static class MockDeployLogger implements DeployLogger {
+ private final ArrayList<String> msgs = new ArrayList<>();
+
+ @Override
+ public void log(Level level, String message) {
+ msgs.add(message);
+ }
+
+ public boolean contains(String expected) {
+ return msgs.stream().anyMatch(msg -> msg.equals(expected));
+ }
+ }
+
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java
index 8ef04752800..12263521dcb 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/ExportingTestCase.java
@@ -9,7 +9,6 @@ import org.junit.Test;
import java.io.IOException;
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.fail;
/**
* Tests exporting
@@ -106,7 +105,7 @@ public class ExportingTestCase extends AbstractExportingTestCase {
@Test
public void testRankExpression() throws IOException, ParseException {
assertCorrectDeriving("rankexpression", null,
- new TestProperties().useExternalRankExpression(true).largeRankExpressionLimit(1024), new TestableDeployLogger());
+ new TestProperties().largeRankExpressionLimit(1024), new TestableDeployLogger());
}
@Test
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index 010b33597f3..9c363ea0628 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -24,6 +24,8 @@ import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
import static org.junit.Assert.assertEquals;
@@ -43,6 +45,7 @@ class RankProfileSearchFixture {
private final QueryProfileRegistry queryProfileRegistry;
private final Search search;
private final Map<String, RankProfile> compiledRankProfiles = new HashMap<>();
+ private final ExecutorService executor = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
public RankProfileRegistry getRankProfileRegistry() {
return rankProfileRegistry;
@@ -105,7 +108,7 @@ class RankProfileSearchFixture {
public RankProfile compileRankProfile(String rankProfile, Path applicationDir) {
RankProfile compiled = rankProfileRegistry.get(search, rankProfile)
.compile(queryProfileRegistry,
- new ImportedMlModels(applicationDir.toFile(), importers));
+ new ImportedMlModels(applicationDir.toFile(), executor, importers));
compiledRankProfiles.put(rankProfile, compiled);
return compiled;
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
index 00ac5ac5405..b81fe7a02cc 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
@@ -36,7 +36,7 @@ public class RankingExpressionsTestCase extends SchemaTestCase {
@Test
public void testFunctions() throws IOException, ParseException {
- ModelContext.Properties deployProperties = new TestProperties().useExternalRankExpression(true);
+ ModelContext.Properties deployProperties = new TestProperties();
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
Search search = createSearch("src/test/examples/rankingexpressionfunction", deployProperties, rankProfileRegistry);
RankProfile functionsRankProfile = rankProfileRegistry.get(search, "macros");
@@ -115,7 +115,7 @@ public class RankingExpressionsTestCase extends SchemaTestCase {
@Test
public void testLargeInheritedFunctions() throws IOException, ParseException {
- ModelContext.Properties properties = new TestProperties().useExternalRankExpression(true).largeRankExpressionLimit(50);
+ ModelContext.Properties properties = new TestProperties().largeRankExpressionLimit(50);
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
LargeRankExpressions largeExpressions = new LargeRankExpressions(new MockFileRegistry());
QueryProfileRegistry queryProfiles = new QueryProfileRegistry();
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
index 771cba673bc..e6d3b5dc140 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
@@ -18,7 +18,7 @@ public class ModelsEvaluatorTest {
@Test
public void testModelsEvaluatorTester() {
ModelsEvaluator modelsEvaluator = ModelsEvaluatorTester.create("src/test/cfg/application/stateless_eval");
- assertEquals(2, modelsEvaluator.models().size());
+ assertEquals(3, modelsEvaluator.models().size());
// ONNX model evaluation
FunctionEvaluator mul = modelsEvaluator.evaluatorOf("mul");
@@ -27,6 +27,12 @@ public class ModelsEvaluatorTest {
Tensor output = mul.bind("input1", input1).bind("input2", input2).evaluate();
assertEquals(6.0, output.sum().asDouble(), 1e-9);
+ // LightGBM model evaluation
+ FunctionEvaluator lgbm = modelsEvaluator.evaluatorOf("lightgbm_regression");
+ lgbm.bind("numerical_1", 0.1).bind("numerical_2", 0.2).bind("categorical_1", "a").bind("categorical_2", "i");
+ output = lgbm.evaluate();
+ assertEquals(2.0547, output.sum().asDouble(), 1e-4);
+
// Vespa model evaluation
FunctionEvaluator foo1 = modelsEvaluator.evaluatorOf("example", "foo1");
input1 = Tensor.from("tensor(name{},x[3]):{{name:n,x:0}:1,{name:n,x:1}:2,{name:n,x:2}:3 }");