summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-application-package/src/main/java/com/yahoo/config/model/application/provider/MockFileRegistry.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java58
-rw-r--r--config-model/src/test/cfg/application/stateless_eval/lightgbm_regression.json275
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java8
4 files changed, 330 insertions, 13 deletions
diff --git a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/MockFileRegistry.java b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/MockFileRegistry.java
index 1817f09ae46..e8a53b95566 100644
--- a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/MockFileRegistry.java
+++ b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/MockFileRegistry.java
@@ -3,8 +3,6 @@ package com.yahoo.config.model.application.provider;
import com.yahoo.config.FileReference;
import com.yahoo.config.application.api.FileRegistry;
-import com.yahoo.net.HostName;
-import net.jpountz.xxhash.XXHashFactory;
import java.nio.ByteBuffer;
import java.util.ArrayList;
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java
index b98cabb6f33..6d723961acb 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java
@@ -9,7 +9,10 @@ import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import ai.vespa.rankingexpression.importer.vespa.VespaImporter;
import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
import com.google.common.collect.ImmutableList;
+import com.yahoo.config.FileReference;
import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.application.api.FileRegistry;
+import com.yahoo.config.model.application.provider.MockFileRegistry;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
@@ -21,10 +24,13 @@ import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import com.yahoo.vespa.model.VespaModel;
+import net.jpountz.lz4.LZ4FrameOutputStream;
import org.xml.sax.SAXException;
import java.io.File;
+import java.io.FileOutputStream;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
@@ -66,13 +72,14 @@ public class ModelsEvaluatorTester {
File temporaryApplicationDir = null;
try {
temporaryApplicationDir = createTemporaryApplicationDir(modelsPath);
- RankProfileList rankProfileList = createRankProfileList(temporaryApplicationDir);
+ MockFileRegistry fileRegistry = new MockFileBlobRegistry(temporaryApplicationDir);
+ RankProfileList rankProfileList = createRankProfileList(temporaryApplicationDir, fileRegistry);
RankProfilesConfig rankProfilesConfig = getRankProfilesConfig(rankProfileList);
RankingConstantsConfig rankingConstantsConfig = getRankingConstantConfig(rankProfileList);
RankingExpressionsConfig rankingExpressionsConfig = getRankingExpressionsConfig(rankProfileList);
OnnxModelsConfig onnxModelsConfig = getOnnxModelsConfig(rankProfileList);
- FileAcquirer files = createFileAcquirer(rankingConstantsConfig, onnxModelsConfig, temporaryApplicationDir);
+ FileAcquirer files = createFileAcquirer(fileRegistry, temporaryApplicationDir);
return new ModelsEvaluator(rankProfilesConfig, rankingConstantsConfig, rankingExpressionsConfig, onnxModelsConfig, files);
@@ -93,12 +100,16 @@ public class ModelsEvaluatorTester {
return temporaryApplicationDir;
}
- private static RankProfileList createRankProfileList(File appDir) throws IOException, SAXException {
+ private static RankProfileList createRankProfileList(File appDir, FileRegistry registry) throws IOException, SAXException {
ApplicationPackage app = new MockApplicationPackage.Builder()
.withEmptyHosts()
.withServices(modelEvaluationServices)
.withRoot(appDir).build();
- DeployState deployState = new DeployState.Builder().applicationPackage(app).modelImporters(importers).build();
+ DeployState deployState = new DeployState.Builder()
+ .applicationPackage(app)
+ .fileRegistry(registry)
+ .modelImporters(importers).build();
+
VespaModel vespaModel = new VespaModel(deployState);
return vespaModel.rankProfileList();
}
@@ -127,13 +138,10 @@ public class ModelsEvaluatorTester {
return builder.build();
}
- private static FileAcquirer createFileAcquirer(RankingConstantsConfig constantsConfig, OnnxModelsConfig onnxModelsConfig, File appDir) {
+ private static FileAcquirer createFileAcquirer(MockFileRegistry fileRegistry, File appDir) {
Map<String, File> fileMap = new HashMap<>();
- for (RankingConstantsConfig.Constant constant : constantsConfig.constant()) {
- fileMap.put(constant.fileref().value(), relativePath(appDir, constant.fileref().value()));
- }
- for (OnnxModelsConfig.Model model : onnxModelsConfig.model()) {
- fileMap.put(model.fileref().value(), relativePath(appDir, model.fileref().value()));
+ for (FileRegistry.Entry entry : fileRegistry.export()) {
+ fileMap.put(entry.reference.value(), relativePath(appDir, entry.reference.value()));
}
return MockFileAcquirer.returnFiles(fileMap);
}
@@ -142,4 +150,34 @@ public class ModelsEvaluatorTester {
return new File(root.getAbsolutePath() + File.separator + subpath);
}
+ private static class MockFileBlobRegistry extends MockFileRegistry {
+
+ private final File appDir;
+
+ MockFileBlobRegistry(File appdir) {
+ this.appDir = appdir;
+ }
+
+ @Override
+ public FileReference addBlob(String name, ByteBuffer blob) {
+ writeBlob(blob, name);
+ return addFile(name);
+ }
+
+ private void writeBlob(ByteBuffer blob, String relativePath) {
+ try (FileOutputStream fos = new FileOutputStream(new File(appDir, relativePath))) {
+ if (relativePath.endsWith(".lz4")) {
+ LZ4FrameOutputStream lz4 = new LZ4FrameOutputStream(fos);
+ lz4.write(blob.array(), blob.arrayOffset(), blob.remaining());
+ lz4.close();
+ } else {
+ fos.write(blob.array(), blob.arrayOffset(), blob.remaining());
+ }
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Failed writing temp file", e);
+ }
+ }
+
+ }
+
}
diff --git a/config-model/src/test/cfg/application/stateless_eval/lightgbm_regression.json b/config-model/src/test/cfg/application/stateless_eval/lightgbm_regression.json
new file mode 100644
index 00000000000..cf0488ecd8b
--- /dev/null
+++ b/config-model/src/test/cfg/application/stateless_eval/lightgbm_regression.json
@@ -0,0 +1,275 @@
+{
+ "name": "tree",
+ "version": "v3",
+ "num_class": 1,
+ "num_tree_per_iteration": 1,
+ "label_index": 0,
+ "max_feature_idx": 3,
+ "average_output": false,
+ "objective": "regression",
+ "feature_names": [
+ "numerical_1",
+ "numerical_2",
+ "categorical_1",
+ "categorical_2"
+ ],
+ "monotone_constraints": [],
+ "tree_info": [
+ {
+ "tree_index": 0,
+ "num_leaves": 3,
+ "num_cat": 1,
+ "shrinkage": 1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 1,
+ "split_gain": 68.5353012084961,
+ "threshold": 0.46643291586559305,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 1000,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": 2.1594397038037663,
+ "leaf_weight": 469,
+ "leaf_count": 469
+ },
+ "right_child": {
+ "split_index": 1,
+ "split_feature": 3,
+ "split_gain": 41.27640151977539,
+ "threshold": "2||3||4",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": 0.246035,
+ "internal_weight": 531,
+ "internal_count": 531,
+ "left_child": {
+ "leaf_index": 1,
+ "leaf_value": 2.235297305276056,
+ "leaf_weight": 302,
+ "leaf_count": 302
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": 2.1792953471546546,
+ "leaf_weight": 229,
+ "leaf_count": 229
+ }
+ }
+ }
+ },
+ {
+ "tree_index": 1,
+ "num_leaves": 3,
+ "num_cat": 1,
+ "shrinkage": 0.1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 2,
+ "split_gain": 64.22250366210938,
+ "threshold": "3||4",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 1000,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": 0.03070842919354316,
+ "leaf_weight": 399,
+ "leaf_count": 399
+ },
+ "right_child": {
+ "split_index": 1,
+ "split_feature": 0,
+ "split_gain": 36.74250030517578,
+ "threshold": 0.5102250691730842,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": -0.204906,
+ "internal_weight": 601,
+ "internal_count": 601,
+ "left_child": {
+ "leaf_index": 1,
+ "leaf_value": -0.04439151147520909,
+ "leaf_weight": 315,
+ "leaf_count": 315
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": 0.005117411709368601,
+ "leaf_weight": 286,
+ "leaf_count": 286
+ }
+ }
+ }
+ },
+ {
+ "tree_index": 2,
+ "num_leaves": 3,
+ "num_cat": 0,
+ "shrinkage": 0.1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 1,
+ "split_gain": 57.1327018737793,
+ "threshold": 0.668665477622446,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 1000,
+ "left_child": {
+ "split_index": 1,
+ "split_feature": 1,
+ "split_gain": 40.859100341796875,
+ "threshold": 0.008118820676863816,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": -0.162926,
+ "internal_weight": 681,
+ "internal_count": 681,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": -0.15361238490967524,
+ "leaf_weight": 21,
+ "leaf_count": 21
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": -0.01192330846157292,
+ "leaf_weight": 660,
+ "leaf_count": 660
+ }
+ },
+ "right_child": {
+ "leaf_index": 1,
+ "leaf_value": 0.03499044894987518,
+ "leaf_weight": 319,
+ "leaf_count": 319
+ }
+ }
+ },
+ {
+ "tree_index": 3,
+ "num_leaves": 3,
+ "num_cat": 1,
+ "shrinkage": 0.1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 0,
+ "split_gain": 54.77090072631836,
+ "threshold": 0.5201391072644542,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 1000,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": -0.02141000620783247,
+ "leaf_weight": 543,
+ "leaf_count": 543
+ },
+ "right_child": {
+ "split_index": 1,
+ "split_feature": 2,
+ "split_gain": 27.200700759887695,
+ "threshold": "0||1",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": 0.255704,
+ "internal_weight": 457,
+ "internal_count": 457,
+ "left_child": {
+ "leaf_index": 1,
+ "leaf_value": -0.004121485787596721,
+ "leaf_weight": 191,
+ "leaf_count": 191
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": 0.04534090904886873,
+ "leaf_weight": 266,
+ "leaf_count": 266
+ }
+ }
+ }
+ },
+ {
+ "tree_index": 4,
+ "num_leaves": 3,
+ "num_cat": 1,
+ "shrinkage": 0.1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 3,
+ "split_gain": 51.84349822998047,
+ "threshold": "2||3||4",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 1000,
+ "left_child": {
+ "split_index": 1,
+ "split_feature": 1,
+ "split_gain": 39.352699279785156,
+ "threshold": 0.27283279016959255,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": 0.188414,
+ "internal_weight": 593,
+ "internal_count": 593,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": -0.01924803254356527,
+ "leaf_weight": 184,
+ "leaf_count": 184
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": 0.03643772842347651,
+ "leaf_weight": 409,
+ "leaf_count": 409
+ }
+ },
+ "right_child": {
+ "leaf_index": 1,
+ "leaf_value": -0.02701711918923075,
+ "leaf_weight": 407,
+ "leaf_count": 407
+ }
+ }
+ }
+ ],
+ "pandas_categorical": [
+ [
+ "a",
+ "b",
+ "c",
+ "d",
+ "e"
+ ],
+ [
+ "i",
+ "j",
+ "k",
+ "l",
+ "m"
+ ]
+ ]
+} \ No newline at end of file
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
index 771cba673bc..e6d3b5dc140 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
@@ -18,7 +18,7 @@ public class ModelsEvaluatorTest {
@Test
public void testModelsEvaluatorTester() {
ModelsEvaluator modelsEvaluator = ModelsEvaluatorTester.create("src/test/cfg/application/stateless_eval");
- assertEquals(2, modelsEvaluator.models().size());
+ assertEquals(3, modelsEvaluator.models().size());
// ONNX model evaluation
FunctionEvaluator mul = modelsEvaluator.evaluatorOf("mul");
@@ -27,6 +27,12 @@ public class ModelsEvaluatorTest {
Tensor output = mul.bind("input1", input1).bind("input2", input2).evaluate();
assertEquals(6.0, output.sum().asDouble(), 1e-9);
+ // LightGBM model evaluation
+ FunctionEvaluator lgbm = modelsEvaluator.evaluatorOf("lightgbm_regression");
+ lgbm.bind("numerical_1", 0.1).bind("numerical_2", 0.2).bind("categorical_1", "a").bind("categorical_2", "i");
+ output = lgbm.evaluate();
+ assertEquals(2.0547, output.sum().asDouble(), 1e-4);
+
// Vespa model evaluation
FunctionEvaluator foo1 = modelsEvaluator.evaluatorOf("example", "foo1");
input1 = Tensor.from("tensor(name{},x[3]):{{name:n,x:0}:1,{name:n,x:1}:2,{name:n,x:2}:3 }");