diff options
4 files changed, 330 insertions, 13 deletions
diff --git a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/MockFileRegistry.java b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/MockFileRegistry.java index 1817f09ae46..e8a53b95566 100644 --- a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/MockFileRegistry.java +++ b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/MockFileRegistry.java @@ -3,8 +3,6 @@ package com.yahoo.config.model.application.provider; import com.yahoo.config.FileReference; import com.yahoo.config.application.api.FileRegistry; -import com.yahoo.net.HostName; -import net.jpountz.xxhash.XXHashFactory; import java.nio.ByteBuffer; import java.util.ArrayList; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java index b98cabb6f33..6d723961acb 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTester.java @@ -9,7 +9,10 @@ import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; import ai.vespa.rankingexpression.importer.vespa.VespaImporter; import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import com.google.common.collect.ImmutableList; +import com.yahoo.config.FileReference; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.application.api.FileRegistry; +import com.yahoo.config.model.application.provider.MockFileRegistry; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; @@ -21,10 +24,13 @@ import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.config.search.core.RankingExpressionsConfig; import com.yahoo.vespa.model.VespaModel; +import net.jpountz.lz4.LZ4FrameOutputStream; import org.xml.sax.SAXException; import java.io.File; +import java.io.FileOutputStream; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.file.Files; import java.nio.file.Path; import java.util.HashMap; @@ -66,13 +72,14 @@ public class ModelsEvaluatorTester { File temporaryApplicationDir = null; try { temporaryApplicationDir = createTemporaryApplicationDir(modelsPath); - RankProfileList rankProfileList = createRankProfileList(temporaryApplicationDir); + MockFileRegistry fileRegistry = new MockFileBlobRegistry(temporaryApplicationDir); + RankProfileList rankProfileList = createRankProfileList(temporaryApplicationDir, fileRegistry); RankProfilesConfig rankProfilesConfig = getRankProfilesConfig(rankProfileList); RankingConstantsConfig rankingConstantsConfig = getRankingConstantConfig(rankProfileList); RankingExpressionsConfig rankingExpressionsConfig = getRankingExpressionsConfig(rankProfileList); OnnxModelsConfig onnxModelsConfig = getOnnxModelsConfig(rankProfileList); - FileAcquirer files = createFileAcquirer(rankingConstantsConfig, onnxModelsConfig, temporaryApplicationDir); + FileAcquirer files = createFileAcquirer(fileRegistry, temporaryApplicationDir); return new ModelsEvaluator(rankProfilesConfig, rankingConstantsConfig, rankingExpressionsConfig, onnxModelsConfig, files); @@ -93,12 +100,16 @@ public class ModelsEvaluatorTester { return temporaryApplicationDir; } - private static RankProfileList createRankProfileList(File appDir) throws IOException, SAXException { + private static RankProfileList createRankProfileList(File appDir, FileRegistry registry) throws IOException, SAXException { ApplicationPackage app = new MockApplicationPackage.Builder() .withEmptyHosts() .withServices(modelEvaluationServices) .withRoot(appDir).build(); - DeployState deployState = new DeployState.Builder().applicationPackage(app).modelImporters(importers).build(); + DeployState deployState = new DeployState.Builder() + .applicationPackage(app) + .fileRegistry(registry) + .modelImporters(importers).build(); + VespaModel vespaModel = new VespaModel(deployState); return vespaModel.rankProfileList(); } @@ -127,13 +138,10 @@ public class ModelsEvaluatorTester { return builder.build(); } - private static FileAcquirer createFileAcquirer(RankingConstantsConfig constantsConfig, OnnxModelsConfig onnxModelsConfig, File appDir) { + private static FileAcquirer createFileAcquirer(MockFileRegistry fileRegistry, File appDir) { Map<String, File> fileMap = new HashMap<>(); - for (RankingConstantsConfig.Constant constant : constantsConfig.constant()) { - fileMap.put(constant.fileref().value(), relativePath(appDir, constant.fileref().value())); - } - for (OnnxModelsConfig.Model model : onnxModelsConfig.model()) { - fileMap.put(model.fileref().value(), relativePath(appDir, model.fileref().value())); + for (FileRegistry.Entry entry : fileRegistry.export()) { + fileMap.put(entry.reference.value(), relativePath(appDir, entry.reference.value())); } return MockFileAcquirer.returnFiles(fileMap); } @@ -142,4 +150,34 @@ public class ModelsEvaluatorTester { return new File(root.getAbsolutePath() + File.separator + subpath); } + private static class MockFileBlobRegistry extends MockFileRegistry { + + private final File appDir; + + MockFileBlobRegistry(File appdir) { + this.appDir = appdir; + } + + @Override + public FileReference addBlob(String name, ByteBuffer blob) { + writeBlob(blob, name); + return addFile(name); + } + + private void writeBlob(ByteBuffer blob, String relativePath) { + try (FileOutputStream fos = new FileOutputStream(new File(appDir, relativePath))) { + if (relativePath.endsWith(".lz4")) { + LZ4FrameOutputStream lz4 = new LZ4FrameOutputStream(fos); + lz4.write(blob.array(), blob.arrayOffset(), blob.remaining()); + lz4.close(); + } else { + fos.write(blob.array(), blob.arrayOffset(), blob.remaining()); + } + } catch (IOException e) { + throw new IllegalArgumentException("Failed writing temp file", e); + } + } + + } + } diff --git a/config-model/src/test/cfg/application/stateless_eval/lightgbm_regression.json b/config-model/src/test/cfg/application/stateless_eval/lightgbm_regression.json new file mode 100644 index 00000000000..cf0488ecd8b --- /dev/null +++ b/config-model/src/test/cfg/application/stateless_eval/lightgbm_regression.json @@ -0,0 +1,275 @@ +{ + "name": "tree", + "version": "v3", + "num_class": 1, + "num_tree_per_iteration": 1, + "label_index": 0, + "max_feature_idx": 3, + "average_output": false, + "objective": "regression", + "feature_names": [ + "numerical_1", + "numerical_2", + "categorical_1", + "categorical_2" + ], + "monotone_constraints": [], + "tree_info": [ + { + "tree_index": 0, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 1, + "tree_structure": { + "split_index": 0, + "split_feature": 1, + "split_gain": 68.5353012084961, + "threshold": 0.46643291586559305, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "leaf_index": 0, + "leaf_value": 2.1594397038037663, + "leaf_weight": 469, + "leaf_count": 469 + }, + "right_child": { + "split_index": 1, + "split_feature": 3, + "split_gain": 41.27640151977539, + "threshold": "2||3||4", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0.246035, + "internal_weight": 531, + "internal_count": 531, + "left_child": { + "leaf_index": 1, + "leaf_value": 2.235297305276056, + "leaf_weight": 302, + "leaf_count": 302 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 2.1792953471546546, + "leaf_weight": 229, + "leaf_count": 229 + } + } + } + }, + { + "tree_index": 1, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 2, + "split_gain": 64.22250366210938, + "threshold": "3||4", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "leaf_index": 0, + "leaf_value": 0.03070842919354316, + "leaf_weight": 399, + "leaf_count": 399 + }, + "right_child": { + "split_index": 1, + "split_feature": 0, + "split_gain": 36.74250030517578, + "threshold": 0.5102250691730842, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": -0.204906, + "internal_weight": 601, + "internal_count": 601, + "left_child": { + "leaf_index": 1, + "leaf_value": -0.04439151147520909, + "leaf_weight": 315, + "leaf_count": 315 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 0.005117411709368601, + "leaf_weight": 286, + "leaf_count": 286 + } + } + } + }, + { + "tree_index": 2, + "num_leaves": 3, + "num_cat": 0, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 1, + "split_gain": 57.1327018737793, + "threshold": 0.668665477622446, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "split_index": 1, + "split_feature": 1, + "split_gain": 40.859100341796875, + "threshold": 0.008118820676863816, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": -0.162926, + "internal_weight": 681, + "internal_count": 681, + "left_child": { + "leaf_index": 0, + "leaf_value": -0.15361238490967524, + "leaf_weight": 21, + "leaf_count": 21 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": -0.01192330846157292, + "leaf_weight": 660, + "leaf_count": 660 + } + }, + "right_child": { + "leaf_index": 1, + "leaf_value": 0.03499044894987518, + "leaf_weight": 319, + "leaf_count": 319 + } + } + }, + { + "tree_index": 3, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 0, + "split_gain": 54.77090072631836, + "threshold": 0.5201391072644542, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "leaf_index": 0, + "leaf_value": -0.02141000620783247, + "leaf_weight": 543, + "leaf_count": 543 + }, + "right_child": { + "split_index": 1, + "split_feature": 2, + "split_gain": 27.200700759887695, + "threshold": "0||1", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0.255704, + "internal_weight": 457, + "internal_count": 457, + "left_child": { + "leaf_index": 1, + "leaf_value": -0.004121485787596721, + "leaf_weight": 191, + "leaf_count": 191 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 0.04534090904886873, + "leaf_weight": 266, + "leaf_count": 266 + } + } + } + }, + { + "tree_index": 4, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 3, + "split_gain": 51.84349822998047, + "threshold": "2||3||4", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "split_index": 1, + "split_feature": 1, + "split_gain": 39.352699279785156, + "threshold": 0.27283279016959255, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0.188414, + "internal_weight": 593, + "internal_count": 593, + "left_child": { + "leaf_index": 0, + "leaf_value": -0.01924803254356527, + "leaf_weight": 184, + "leaf_count": 184 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 0.03643772842347651, + "leaf_weight": 409, + "leaf_count": 409 + } + }, + "right_child": { + "leaf_index": 1, + "leaf_value": -0.02701711918923075, + "leaf_weight": 407, + "leaf_count": 407 + } + } + } + ], + "pandas_categorical": [ + [ + "a", + "b", + "c", + "d", + "e" + ], + [ + "i", + "j", + "k", + "l", + "m" + ] + ] +}
\ No newline at end of file diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java index 771cba673bc..e6d3b5dc140 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java @@ -18,7 +18,7 @@ public class ModelsEvaluatorTest { @Test public void testModelsEvaluatorTester() { ModelsEvaluator modelsEvaluator = ModelsEvaluatorTester.create("src/test/cfg/application/stateless_eval"); - assertEquals(2, modelsEvaluator.models().size()); + assertEquals(3, modelsEvaluator.models().size()); // ONNX model evaluation FunctionEvaluator mul = modelsEvaluator.evaluatorOf("mul"); @@ -27,6 +27,12 @@ public class ModelsEvaluatorTest { Tensor output = mul.bind("input1", input1).bind("input2", input2).evaluate(); assertEquals(6.0, output.sum().asDouble(), 1e-9); + // LightGBM model evaluation + FunctionEvaluator lgbm = modelsEvaluator.evaluatorOf("lightgbm_regression"); + lgbm.bind("numerical_1", 0.1).bind("numerical_2", 0.2).bind("categorical_1", "a").bind("categorical_2", "i"); + output = lgbm.evaluate(); + assertEquals(2.0547, output.sum().asDouble(), 1e-4); + // Vespa model evaluation FunctionEvaluator foo1 = modelsEvaluator.evaluatorOf("example", "foo1"); input1 = Tensor.from("tensor(name{},x[3]):{{name:n,x:0}:1,{name:n,x:1}:2,{name:n,x:2}:3 }"); |