summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-02-02 17:39:44 +0100
committerLester Solbakken <lesters@oath.com>2020-02-02 17:39:44 +0100
commitf656ff5c15d95905f48d5829278ec241f1941577 (patch)
tree41d1fd4f8bc22df172acac42bfc39abd136036c0 /model-integration
parent99f3a7193090cfcd6b5fdbbe612f53d892f9d86b (diff)
Add support for importing LightGBM models
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/config/model-integration.xml3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java54
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMNode.java67
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMParser.java146
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/package-info.java5
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImportEvaluationTestCase.java49
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMTestBase.java42
-rw-r--r--model-integration/src/test/models/lightgbm/classification.json275
-rw-r--r--model-integration/src/test/models/lightgbm/regression.json275
-rwxr-xr-xmodel-integration/src/test/models/lightgbm/train_lightgbm_classification.py54
-rwxr-xr-xmodel-integration/src/test/models/lightgbm/train_lightgbm_regression.py53
11 files changed, 1022 insertions, 1 deletions
diff --git a/model-integration/src/main/config/model-integration.xml b/model-integration/src/main/config/model-integration.xml
index 90ec7d7275e..34f5f0ce31a 100644
--- a/model-integration/src/main/config/model-integration.xml
+++ b/model-integration/src/main/config/model-integration.xml
@@ -1,11 +1,12 @@
<!-- Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
<!-- Component which can import some ml model.
This is included into the config server services.xml to enable it to translate
- model pseudofeatures in ranking expressions during config mddel building.
+ model pseudo features in ranking expressions during config model building.
It is provided as separate bundles instead of being included in the config model
because some of these (TensorFlow) includes
JNI code, and so can only exist in one instance in the server. -->
<component id="ai.vespa.rankingexpression.importer.onnx.OnnxImporter" bundle="model-integration" />
<component id="ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter" bundle="model-integration" />
<component id="ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter" bundle="model-integration" />
+<component id="ai.vespa.rankingexpression.importer.lightgbm.LightGBMImporter" bundle="model-integration" />
<component id="ai.vespa.rankingexpression.importer.vespa.VespaImporter" bundle="model-integration" />
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java
new file mode 100644
index 00000000000..76caa652ad2
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java
@@ -0,0 +1,54 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.lightgbm;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.ModelImporter;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+
+import java.io.File;
+import java.io.IOException;
+
+/**
+ * Converts a LightGBM model into a ranking expression.
+ *
+ * @author lesters
+ */
+public class LightGBMImporter extends ModelImporter {
+
+ @Override
+ public boolean canImport(String modelPath) {
+ File modelFile = new File(modelPath);
+ if ( ! modelFile.isFile()) return false;
+ return modelFile.toString().endsWith(".json") && probe(modelFile);
+ }
+
+ /**
+ * Returns true if the give file looks like a LightGBM json file.
+ * Currently, we just check if the json has an element called "tree_info"
+ */
+ private boolean probe(File modelFile) {
+ try {
+ return new ObjectMapper().readTree(modelFile).has("tree_info");
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Could not read '" + modelFile + "'", e);
+ }
+ }
+
+ @Override
+ public ImportedModel importModel(String modelName, String modelPath) {
+ try {
+ ImportedModel model = new ImportedModel(modelName, modelPath);
+ LightGBMParser parser = new LightGBMParser(modelPath);
+ RankingExpression expression = new RankingExpression(parser.toRankingExpression());
+ model.expression(modelName, expression);
+ return model;
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Could not import LightGBM model from '" + modelPath + "'", e);
+ } catch (ParseException e) {
+ throw new IllegalArgumentException("Could not parse ranking expression resulting from '" + modelPath + "'", e);
+ }
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMNode.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMNode.java
new file mode 100644
index 00000000000..dc76ed8cb6f
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMNode.java
@@ -0,0 +1,67 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.lightgbm;
+
+/**
+ * @author lesters
+ */
+public class LightGBMNode {
+
+ // split nodes
+ private int split_feature;
+ private String threshold; // double for numerical, string for categorical
+ private String decision_type;
+ private boolean default_left;
+ private String missing_type;
+ private int internal_count;
+ private LightGBMNode left_child;
+ private LightGBMNode right_child;
+
+ // leaf nodes
+ private double leaf_value;
+ private int leaf_count;
+
+ public int getSplit_feature() {
+ return split_feature;
+ }
+
+ public String getThreshold() {
+ return threshold;
+ }
+
+ public String getDecision_type() {
+ return decision_type;
+ }
+
+ public boolean isDefault_left() {
+ return default_left;
+ }
+
+ public String getMissing_type() {
+ return missing_type;
+ }
+
+ public int getInternal_count() {
+ return internal_count;
+ }
+
+ public LightGBMNode getLeft_child() {
+ return left_child;
+ }
+
+ public LightGBMNode getRight_child() {
+ return right_child;
+ }
+
+ public double getLeaf_value() {
+ return leaf_value;
+ }
+
+ public int getLeaf_count() {
+ return leaf_count;
+ }
+
+ public boolean isLeaf() {
+ return left_child == null && right_child == null;
+ }
+
+} \ No newline at end of file
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMParser.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMParser.java
new file mode 100644
index 00000000000..996343674ce
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMParser.java
@@ -0,0 +1,146 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.lightgbm;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.stream.Collectors;
+
+/**
+ * @author lesters
+ */
+class LightGBMParser {
+
+ private final String objective;
+ private final List<LightGBMNode> nodes;
+ private final List<String> featureNames;
+ private final Map<Integer, List<String>> categoryValues; // pr feature index
+
+ LightGBMParser(String filePath) throws JsonProcessingException, IOException {
+ ObjectMapper mapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+ JsonNode root = mapper.readTree(new File(filePath));
+
+ objective = root.get("objective").asText("regression");
+ featureNames = parseFeatureNames(root);
+ nodes = parseTrees(mapper, root);
+ categoryValues = parseCategoryValues(root);
+ }
+
+ private List<String> parseFeatureNames(JsonNode root) {
+ List<String> features = new ArrayList<>();
+ for (JsonNode name : root.get("feature_names")) {
+ features.add(name.textValue());
+ }
+ return features;
+ }
+
+ private List<LightGBMNode> parseTrees(ObjectMapper mapper, JsonNode root) throws JsonProcessingException {
+ List<LightGBMNode> nodes = new ArrayList<>();
+ for (JsonNode treeNode : root.get("tree_info")) {
+ nodes.add(mapper.treeToValue(treeNode.get("tree_structure"), LightGBMNode.class));
+ }
+ return nodes;
+ }
+
+ private Map<Integer, List<String>> parseCategoryValues(JsonNode root) {
+ Map<Integer, List<String>> categoryValues = new HashMap<>();
+
+ // Since the JSON format does not explicitly tell which features are
+ // categorical, we traverse the decision tree looking for categorical
+ // decisions and use that to determine which categorical features.
+ Set<Integer> categoricalFeatures = new TreeSet<>();
+ nodes.forEach(node -> findCategoricalFeatures(node, categoricalFeatures));
+
+ // Again, the LightGBM JSON format does not explicitly tell which
+ // categorical values map to each categorical feature. The assumption
+ // here is that the order they appear in the "pandas_categorical"
+ // structure is the same order as the "feature_names".
+ var pandasFeatureIterator = root.get("pandas_categorical").iterator();
+ var categoricalFeatureIterator = categoricalFeatures.iterator();
+ while (pandasFeatureIterator.hasNext() && categoricalFeatureIterator.hasNext()) {
+ List<String> values = new ArrayList<>();
+ pandasFeatureIterator.next().forEach(value -> values.add(value.textValue()));
+ categoryValues.put(categoricalFeatureIterator.next(), values);
+ }
+
+ return categoryValues;
+ }
+
+ private void findCategoricalFeatures(LightGBMNode node, Set<Integer> categoricalFeatures) {
+ if (node == null || node.isLeaf()) {
+ return;
+ }
+ if (node.getDecision_type().equals("==")) {
+ categoricalFeatures.add(node.getSplit_feature());
+ }
+ findCategoricalFeatures(node.getLeft_child(), categoricalFeatures);
+ findCategoricalFeatures(node.getRight_child(), categoricalFeatures);
+ }
+
+ String toRankingExpression() {
+ return applyObjective(nodes.stream().map(this::nodeToRankingExpression).collect(Collectors.joining(" + \n")));
+ }
+
+ // See https://lightgbm.readthedocs.io/en/latest/Parameters.html#objective
+ private String applyObjective(String expression) {
+ if (objective.startsWith("binary") || objective.equals("cross_entropy")) {
+ return "sigmoid(" + expression + ")";
+ }
+ if (objective.equals("poisson") || objective.equals("gamma") || objective.equals("tweedie")) {
+ return "exp(" + expression + ")";
+ }
+ return expression; // else: use expression directly
+ }
+
+ private String nodeToRankingExpression(LightGBMNode node) {
+ if (node.isLeaf()) {
+ return Double.toString(node.getLeaf_value());
+ } else {
+ String condition;
+ String feature = featureNames.get(node.getSplit_feature());
+ if (node.getDecision_type().equals("==")) {
+ String values = transformCategoryIndexesToValues(node);
+ if (node.isDefault_left()) { // means go left (true) when isNan
+ condition = "isNan(" + feature + ") || (" + feature + " in [ " + values + "])";
+ } else {
+ condition = feature + " in [" + values + "]";
+ }
+ } else { // assumption: all other decision types are <=
+ double value = Double.parseDouble(node.getThreshold());
+ if (node.isDefault_left()) {
+ condition = "!(" + feature + " >= " + value + ")";
+ } else {
+ condition = feature + " < " + value;
+ }
+ }
+ String left = nodeToRankingExpression(node.getLeft_child());
+ String right = nodeToRankingExpression(node.getRight_child());
+ return "if (" + condition + ", " + left + ", " + right + ")";
+ }
+ }
+
+ private String transformCategoryIndexesToValues(LightGBMNode node) {
+ return Arrays.stream(node.getThreshold().split("\\|\\|"))
+ .map(index -> "\"" + transformCategoryIndexToValue(node.getSplit_feature(), index) + "\"")
+ .collect(Collectors.joining(","));
+ }
+
+ private String transformCategoryIndexToValue(int featureIndex, String valueIndex) {
+ if ( ! categoryValues.containsKey(featureIndex) ) {
+ return valueIndex; // We don't have a pandas categorical lookup table
+ }
+ return categoryValues.get(featureIndex).get(Integer.parseInt(valueIndex));
+ }
+
+} \ No newline at end of file
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/package-info.java
new file mode 100644
index 00000000000..b29145ee21b
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/package-info.java
@@ -0,0 +1,5 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+package ai.vespa.rankingexpression.importer.lightgbm;
+
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImportEvaluationTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImportEvaluationTestCase.java
new file mode 100644
index 00000000000..d2ef4b984ff
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImportEvaluationTestCase.java
@@ -0,0 +1,49 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.lightgbm;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTForestNode;
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+
+/**
+ * @author lesters
+ */
+public class LightGBMImportEvaluationTestCase extends LightGBMTestBase {
+
+ @Test
+ public void testRegression() {
+ RankingExpression expression = importModel("src/test/models/lightgbm/regression.json");
+ ArrayContext context = new ArrayContext(expression, true, DoubleValue.NaN);
+
+ assertEvaluation(1.91300868, expression, features(context));
+ assertEvaluation(2.05469776, expression, features(context).add("numerical_1", 0.1).add("numerical_2", 0.2).add("categorical_1", "a").add("categorical_2", "i"));
+ assertEvaluation(2.0745534, expression, features(context).add("numerical_2", 0.5).add("categorical_1", "b").add("categorical_2", "j"));
+ assertEvaluation(2.3571838, expression, features(context).add("numerical_1", 0.7).add("numerical_2", 0.8).add("categorical_2", "m"));
+
+ ExpressionOptimizer optimizer = new ExpressionOptimizer();
+ optimizer.optimize(expression, (ContextIndex)context);
+ assertTrue(expression.getRoot() instanceof GBDTForestNode);
+
+ assertEvaluation(1.91300868, expression, features(context));
+ assertEvaluation(2.05469776, expression, features(context).add("numerical_1", 0.1).add("numerical_2", 0.2).add("categorical_1", "a").add("categorical_2", "i"));
+ assertEvaluation(2.0745534, expression, features(context).add("numerical_2", 0.5).add("categorical_1", "b").add("categorical_2", "j"));
+ assertEvaluation(2.3571838, expression, features(context).add("numerical_1", 0.7).add("numerical_2", 0.8).add("categorical_2", "m"));
+ }
+
+ @Test
+ public void testClassification() {
+ RankingExpression expression = importModel("src/test/models/lightgbm/classification.json");
+ ArrayContext context = new ArrayContext(expression, DoubleValue.NaN);
+ assertEvaluation(0.37464997, expression, features(context));
+ assertEvaluation(0.37464997, expression, features(context).add("numerical_1", 0.1).add("numerical_2", 0.2).add("categorical_1", "a").add("categorical_2", "i"));
+ assertEvaluation(0.38730827, expression, features(context).add("numerical_2", 0.5).add("categorical_1", "b").add("categorical_2", "j"));
+ assertEvaluation(0.5647872, expression, features(context).add("numerical_1", 0.7).add("numerical_2", 0.8).add("categorical_2", "m"));
+ }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMTestBase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMTestBase.java
new file mode 100644
index 00000000000..80c2ce68394
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMTestBase.java
@@ -0,0 +1,42 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.lightgbm;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author lesters
+ */
+class LightGBMTestBase {
+
+ RankingExpression importModel(String path) {
+ return new LightGBMImporter().importModel("lightgbm", path).expressions().get("lightgbm");
+ }
+
+ void assertEvaluation(double expected, RankingExpression expr, TestFeatures features) {
+ assertEquals(expected, expr.evaluate(features.context).asDouble(), 1e-6);
+ }
+
+ TestFeatures features(ArrayContext context) {
+ return new TestFeatures(context.clone());
+ }
+
+ static class TestFeatures {
+ private final ArrayContext context;
+ TestFeatures(ArrayContext context) {
+ this.context = context;
+ }
+ TestFeatures add(String name, double value) {
+ context.put(name, value);
+ return this;
+ }
+ TestFeatures add(String name, String value) {
+ context.put(name, new StringValue(value));
+ return this;
+ }
+ }
+
+}
diff --git a/model-integration/src/test/models/lightgbm/classification.json b/model-integration/src/test/models/lightgbm/classification.json
new file mode 100644
index 00000000000..1087446519d
--- /dev/null
+++ b/model-integration/src/test/models/lightgbm/classification.json
@@ -0,0 +1,275 @@
+{
+ "name": "tree",
+ "version": "v3",
+ "num_class": 1,
+ "num_tree_per_iteration": 1,
+ "label_index": 0,
+ "max_feature_idx": 3,
+ "average_output": false,
+ "objective": "binary sigmoid:1",
+ "feature_names": [
+ "numerical_1",
+ "numerical_2",
+ "categorical_1",
+ "categorical_2"
+ ],
+ "monotone_constraints": [],
+ "tree_info": [
+ {
+ "tree_index": 0,
+ "num_leaves": 3,
+ "num_cat": 2,
+ "shrinkage": 1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 3,
+ "split_gain": 13080.099609375,
+ "threshold": "2||3||4",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 100000,
+ "left_child": {
+ "split_index": 1,
+ "split_feature": 2,
+ "split_gain": 8303.599609375,
+ "threshold": "2||3||4",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": 0.598248,
+ "internal_weight": 14841.2,
+ "internal_count": 59371,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": 0.10149588882231209,
+ "leaf_weight": 8812.104370772839,
+ "leaf_count": 35252
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": -0.05076009488472203,
+ "leaf_weight": 6029.137221112847,
+ "leaf_count": 24119
+ }
+ },
+ "right_child": {
+ "leaf_index": 1,
+ "leaf_value": -0.1075553310531564,
+ "leaf_weight": 10156.217760130763,
+ "leaf_count": 40629
+ }
+ }
+ },
+ {
+ "tree_index": 1,
+ "num_leaves": 3,
+ "num_cat": 0,
+ "shrinkage": 0.1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 1,
+ "split_gain": 12144.5,
+ "threshold": 0.4932456977560694,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 100000,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": -0.07039230856418545,
+ "leaf_weight": 12362.572675153613,
+ "leaf_count": 49561
+ },
+ "right_child": {
+ "split_index": 1,
+ "split_feature": 0,
+ "split_gain": 6445.509765625,
+ "threshold": 0.4026061210695467,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": 0.691647,
+ "internal_weight": 12581.6,
+ "internal_count": 50439,
+ "left_child": {
+ "leaf_index": 1,
+ "leaf_value": -0.016713933964828474,
+ "leaf_weight": 5157.183633238077,
+ "leaf_count": 20675
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": 0.12881836794307533,
+ "leaf_weight": 7424.385220557451,
+ "leaf_count": 29764
+ }
+ }
+ }
+ },
+ {
+ "tree_index": 2,
+ "num_leaves": 3,
+ "num_cat": 2,
+ "shrinkage": 0.1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 2,
+ "split_gain": 11470.099609375,
+ "threshold": "3||4",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 100000,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": 0.0837843210726433,
+ "leaf_weight": 9858.360527098179,
+ "leaf_count": 39612
+ },
+ "right_child": {
+ "split_index": 1,
+ "split_feature": 3,
+ "split_gain": 8077.8701171875,
+ "threshold": "3||4",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": -0.549408,
+ "internal_weight": 15039.7,
+ "internal_count": 60388,
+ "left_child": {
+ "leaf_index": 1,
+ "leaf_value": 0.035561394754096094,
+ "leaf_weight": 5955.117423638701,
+ "leaf_count": 23896
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": -0.11424082565448186,
+ "leaf_weight": 9084.538012728095,
+ "leaf_count": 36492
+ }
+ }
+ }
+ },
+ {
+ "tree_index": 3,
+ "num_leaves": 3,
+ "num_cat": 0,
+ "shrinkage": 0.1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 0,
+ "split_gain": 11022.599609375,
+ "threshold": 0.5135386524711826,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 100000,
+ "left_child": {
+ "split_index": 1,
+ "split_feature": 1,
+ "split_gain": 5789.919921875,
+ "threshold": 0.6237474076885036,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": -0.641438,
+ "internal_weight": 12881.9,
+ "internal_count": 51907,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": -0.11613056205533928,
+ "leaf_weight": 8044.6355674266815,
+ "leaf_count": 32426
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": 0.022313103333779363,
+ "leaf_weight": 4837.266924858093,
+ "leaf_count": 19481
+ }
+ },
+ "right_child": {
+ "leaf_index": 1,
+ "leaf_value": 0.06927713686880098,
+ "leaf_weight": 11923.512641906738,
+ "leaf_count": 48093
+ }
+ }
+ },
+ {
+ "tree_index": 4,
+ "num_leaves": 3,
+ "num_cat": 2,
+ "shrinkage": 0.1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 3,
+ "split_gain": 9828.9501953125,
+ "threshold": "3||4",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 100000,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": 0.07771712562582928,
+ "leaf_weight": 9804.427681803703,
+ "leaf_count": 39586
+ },
+ "right_child": {
+ "split_index": 1,
+ "split_feature": 2,
+ "split_gain": 6332.2900390625,
+ "threshold": "3||4",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": -0.51112,
+ "internal_weight": 14922.7,
+ "internal_count": 60414,
+ "left_child": {
+ "leaf_index": 1,
+ "leaf_value": 0.029062142260340918,
+ "leaf_weight": 5933.120021238923,
+ "leaf_count": 23922
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": -0.10400033924773491,
+ "leaf_weight": 8989.602796778083,
+ "leaf_count": 36492
+ }
+ }
+ }
+ }
+ ],
+ "pandas_categorical": [
+ [
+ "a",
+ "b",
+ "c",
+ "d",
+ "e"
+ ],
+ [
+ "i",
+ "j",
+ "k",
+ "l",
+ "m"
+ ]
+ ]
+} \ No newline at end of file
diff --git a/model-integration/src/test/models/lightgbm/regression.json b/model-integration/src/test/models/lightgbm/regression.json
new file mode 100644
index 00000000000..cf0488ecd8b
--- /dev/null
+++ b/model-integration/src/test/models/lightgbm/regression.json
@@ -0,0 +1,275 @@
+{
+ "name": "tree",
+ "version": "v3",
+ "num_class": 1,
+ "num_tree_per_iteration": 1,
+ "label_index": 0,
+ "max_feature_idx": 3,
+ "average_output": false,
+ "objective": "regression",
+ "feature_names": [
+ "numerical_1",
+ "numerical_2",
+ "categorical_1",
+ "categorical_2"
+ ],
+ "monotone_constraints": [],
+ "tree_info": [
+ {
+ "tree_index": 0,
+ "num_leaves": 3,
+ "num_cat": 1,
+ "shrinkage": 1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 1,
+ "split_gain": 68.5353012084961,
+ "threshold": 0.46643291586559305,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 1000,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": 2.1594397038037663,
+ "leaf_weight": 469,
+ "leaf_count": 469
+ },
+ "right_child": {
+ "split_index": 1,
+ "split_feature": 3,
+ "split_gain": 41.27640151977539,
+ "threshold": "2||3||4",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": 0.246035,
+ "internal_weight": 531,
+ "internal_count": 531,
+ "left_child": {
+ "leaf_index": 1,
+ "leaf_value": 2.235297305276056,
+ "leaf_weight": 302,
+ "leaf_count": 302
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": 2.1792953471546546,
+ "leaf_weight": 229,
+ "leaf_count": 229
+ }
+ }
+ }
+ },
+ {
+ "tree_index": 1,
+ "num_leaves": 3,
+ "num_cat": 1,
+ "shrinkage": 0.1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 2,
+ "split_gain": 64.22250366210938,
+ "threshold": "3||4",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 1000,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": 0.03070842919354316,
+ "leaf_weight": 399,
+ "leaf_count": 399
+ },
+ "right_child": {
+ "split_index": 1,
+ "split_feature": 0,
+ "split_gain": 36.74250030517578,
+ "threshold": 0.5102250691730842,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": -0.204906,
+ "internal_weight": 601,
+ "internal_count": 601,
+ "left_child": {
+ "leaf_index": 1,
+ "leaf_value": -0.04439151147520909,
+ "leaf_weight": 315,
+ "leaf_count": 315
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": 0.005117411709368601,
+ "leaf_weight": 286,
+ "leaf_count": 286
+ }
+ }
+ }
+ },
+ {
+ "tree_index": 2,
+ "num_leaves": 3,
+ "num_cat": 0,
+ "shrinkage": 0.1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 1,
+ "split_gain": 57.1327018737793,
+ "threshold": 0.668665477622446,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 1000,
+ "left_child": {
+ "split_index": 1,
+ "split_feature": 1,
+ "split_gain": 40.859100341796875,
+ "threshold": 0.008118820676863816,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": -0.162926,
+ "internal_weight": 681,
+ "internal_count": 681,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": -0.15361238490967524,
+ "leaf_weight": 21,
+ "leaf_count": 21
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": -0.01192330846157292,
+ "leaf_weight": 660,
+ "leaf_count": 660
+ }
+ },
+ "right_child": {
+ "leaf_index": 1,
+ "leaf_value": 0.03499044894987518,
+ "leaf_weight": 319,
+ "leaf_count": 319
+ }
+ }
+ },
+ {
+ "tree_index": 3,
+ "num_leaves": 3,
+ "num_cat": 1,
+ "shrinkage": 0.1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 0,
+ "split_gain": 54.77090072631836,
+ "threshold": 0.5201391072644542,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 1000,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": -0.02141000620783247,
+ "leaf_weight": 543,
+ "leaf_count": 543
+ },
+ "right_child": {
+ "split_index": 1,
+ "split_feature": 2,
+ "split_gain": 27.200700759887695,
+ "threshold": "0||1",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": 0.255704,
+ "internal_weight": 457,
+ "internal_count": 457,
+ "left_child": {
+ "leaf_index": 1,
+ "leaf_value": -0.004121485787596721,
+ "leaf_weight": 191,
+ "leaf_count": 191
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": 0.04534090904886873,
+ "leaf_weight": 266,
+ "leaf_count": 266
+ }
+ }
+ }
+ },
+ {
+ "tree_index": 4,
+ "num_leaves": 3,
+ "num_cat": 1,
+ "shrinkage": 0.1,
+ "tree_structure": {
+ "split_index": 0,
+ "split_feature": 3,
+ "split_gain": 51.84349822998047,
+ "threshold": "2||3||4",
+ "decision_type": "==",
+ "default_left": false,
+ "missing_type": "NaN",
+ "internal_value": 0,
+ "internal_weight": 0,
+ "internal_count": 1000,
+ "left_child": {
+ "split_index": 1,
+ "split_feature": 1,
+ "split_gain": 39.352699279785156,
+ "threshold": 0.27283279016959255,
+ "decision_type": "<=",
+ "default_left": true,
+ "missing_type": "NaN",
+ "internal_value": 0.188414,
+ "internal_weight": 593,
+ "internal_count": 593,
+ "left_child": {
+ "leaf_index": 0,
+ "leaf_value": -0.01924803254356527,
+ "leaf_weight": 184,
+ "leaf_count": 184
+ },
+ "right_child": {
+ "leaf_index": 2,
+ "leaf_value": 0.03643772842347651,
+ "leaf_weight": 409,
+ "leaf_count": 409
+ }
+ },
+ "right_child": {
+ "leaf_index": 1,
+ "leaf_value": -0.02701711918923075,
+ "leaf_weight": 407,
+ "leaf_count": 407
+ }
+ }
+ }
+ ],
+ "pandas_categorical": [
+ [
+ "a",
+ "b",
+ "c",
+ "d",
+ "e"
+ ],
+ [
+ "i",
+ "j",
+ "k",
+ "l",
+ "m"
+ ]
+ ]
+} \ No newline at end of file
diff --git a/model-integration/src/test/models/lightgbm/train_lightgbm_classification.py b/model-integration/src/test/models/lightgbm/train_lightgbm_classification.py
new file mode 100755
index 00000000000..ac00437d192
--- /dev/null
+++ b/model-integration/src/test/models/lightgbm/train_lightgbm_classification.py
@@ -0,0 +1,54 @@
+#! /usr/bin/env python3
+# coding: utf-8
+
+import json
+import random
+
+import lightgbm as lgb
+import numpy as np
+import pandas as pd
+
+
+def category_value(arr):
+ values = { np.NaN: 0, "a":1, "b":2, "c":3, "d":4, "e":5, "i":1, "j":2, "k":3, "l":4, "m":5 }
+ return np.array([ 0.21 * values[i] for i in arr ])
+
+# Create training set
+num_examples = 100000
+missing_prob = 0.01
+features = pd.DataFrame({
+ "numerical_1": np.random.random(num_examples),
+ "numerical_2": np.random.random(num_examples),
+ "categorical_1": pd.Series(np.random.permutation(["a", "b", "c", "d", "e"] * int(num_examples/5)), dtype="category"),
+ "categorical_2": pd.Series(np.random.permutation(["i", "j", "k", "l", "m"] * int(num_examples/5)), dtype="category"),
+ })
+
+# randomly insert missing values
+for i in range(int(num_examples * len(features.columns) * missing_prob)):
+ features.loc[random.randint(0, num_examples-1), features.columns[random.randint(0, len(features.columns)-1)]] = None
+
+# create targets (with 0.0 as default for missing values)
+target = features["numerical_1"] + features["numerical_2"] + category_value(features["categorical_1"]) + category_value(features["categorical_2"])
+target = (target > 2.24) * 1.0
+lgb_train = lgb.Dataset(features, target)
+
+# Train model
+params = {
+ 'objective': 'binary',
+ 'metric': 'binary_logloss',
+ 'num_leaves': 3,
+}
+model = lgb.train(params, lgb_train, num_boost_round=5)
+
+# Save model
+with open("classification.json", "w") as f:
+ json.dump(model.dump_model(), f, indent=2)
+
+# Predict (for comparison with Vespa evaluation)
+predict_features = pd.DataFrame({
+ "numerical_1": pd.Series([ None, 0.1, None, 0.7]),
+ "numerical_2": pd.Series([np.NaN, 0.2, 0.5, 0.8]),
+ "categorical_1": pd.Series([ None, "a", "b", None], dtype="category"),
+ "categorical_2": pd.Series([ None, "i", "j", "m"], dtype="category"),
+ })
+print(model.predict(predict_features))
diff --git a/model-integration/src/test/models/lightgbm/train_lightgbm_regression.py b/model-integration/src/test/models/lightgbm/train_lightgbm_regression.py
new file mode 100755
index 00000000000..3e74e38da35
--- /dev/null
+++ b/model-integration/src/test/models/lightgbm/train_lightgbm_regression.py
@@ -0,0 +1,53 @@
+#! /usr/bin/env python3
+# coding: utf-8
+
+import json
+import random
+
+import lightgbm as lgb
+import numpy as np
+import pandas as pd
+
+
+def category_value(arr):
+ values = { np.NaN: 0, "a":1, "b":2, "c":3, "d":4, "e":5, "i":1, "j":2, "k":3, "l":4, "m":5 }
+ return np.array([ 0.21 * values[i] for i in arr ])
+
+# Create training set
+num_examples = 100000
+missing_prob = 0.01
+features = pd.DataFrame({
+ "numerical_1": np.random.random(num_examples),
+ "numerical_2": np.random.random(num_examples),
+ "categorical_1": pd.Series(np.random.permutation(["a", "b", "c", "d", "e"] * int(num_examples/5)), dtype="category"),
+ "categorical_2": pd.Series(np.random.permutation(["i", "j", "k", "l", "m"] * int(num_examples/5)), dtype="category"),
+ })
+
+# randomly insert missing values
+for i in range(int(num_examples * len(features.columns) * missing_prob)):
+ features.loc[random.randint(0, num_examples-1), features.columns[random.randint(0, len(features.columns)-1)]] = None
+
+# create targets (with 0.0 as default for missing values)
+target = features["numerical_1"] + features["numerical_2"] + category_value(features["categorical_1"]) + category_value(features["categorical_2"])
+lgb_train = lgb.Dataset(features, target)
+
+# Train model
+params = {
+ 'objective': 'mse',
+ 'metric': {'l2', 'l1'},
+ 'num_leaves': 3,
+}
+model = lgb.train(params, lgb_train, num_boost_round=2)
+
+# Save model
+with open("regression.json", "w") as f:
+ json.dump(model.dump_model(), f, indent=2)
+
+# Predict (for comparison with Vespa evaluation)
+predict_features = pd.DataFrame({
+ "numerical_1": pd.Series([ None, 0.1, None, 0.7]),
+ "numerical_2": pd.Series([np.NaN, 0.2, 0.5, 0.8]),
+ "categorical_1": pd.Series([ None, "a", "b", None], dtype="category"),
+ "categorical_2": pd.Series([ None, "i", "j", "m"], dtype="category"),
+ })
+print(model.predict(predict_features))