diff options
author | Lester Solbakken <lesters@oath.com> | 2020-02-02 17:39:44 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-02-02 17:39:44 +0100 |
commit | f656ff5c15d95905f48d5829278ec241f1941577 (patch) | |
tree | 41d1fd4f8bc22df172acac42bfc39abd136036c0 /config-model | |
parent | 99f3a7193090cfcd6b5fdbbe612f53d892f9d86b (diff) |
Add support for importing LightGBM models
Diffstat (limited to 'config-model')
11 files changed, 994 insertions, 6 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java index 6fdf448a39b..a6707ec7ac0 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java @@ -27,6 +27,7 @@ public class ExpressionTransforms { ImmutableList.of(new TensorFlowFeatureConverter(), new OnnxFeatureConverter(), new XgboostFeatureConverter(), + new LightGBMFeatureConverter(), new ConstantDereferencer(), new ConstantTensorTransformer(), new FunctionInliner(), diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/LightGBMFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/LightGBMFeatureConverter.java new file mode 100644 index 00000000000..5bde627dc0a --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/LightGBMFeatureConverter.java @@ -0,0 +1,59 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.expressiontransforms; + +import com.yahoo.path.Path; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.vespa.model.ml.ConvertedModel; +import com.yahoo.vespa.model.ml.FeatureArguments; + +import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.Map; + +/** + * Replaces instances of the lightgbm(model-path) pseudofeature with the + * native Vespa ranking expression implementing the same computation. + * + * @author lesters + */ +public class LightGBMFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { + + /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ + private final Map<Path, ConvertedModel> convertedLightGBMModels = new HashMap<>(); + + @Override + public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + if (node instanceof ReferenceNode) + return transformFeature((ReferenceNode) node, context); + else if (node instanceof CompositeNode) + return super.transformChildren((CompositeNode) node, context); + else + return node; + } + + private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { + if ( ! feature.getName().equals("lightgbm")) return feature; + + try { + FeatureArguments arguments = asFeatureArguments(feature.getArguments()); + ConvertedModel convertedModel = + convertedLightGBMModels.computeIfAbsent(arguments.path(), + path -> ConvertedModel.fromSourceOrStore(path, true, context)); + return convertedModel.expression(arguments, context); + } catch (IllegalArgumentException | UncheckedIOException e) { + throw new IllegalArgumentException("Could not use LightGBM model from " + feature, e); + } + } + + private FeatureArguments asFeatureArguments(Arguments arguments) { + if (arguments.size() != 1) + throw new IllegalArgumentException("A lightgbm node must take a single argument pointing to " + + "the LightGBM model file under [application]/models"); + return new FeatureArguments(arguments); + } + +} diff --git a/config-model/src/test/cfg/application/ml_models/models/lightgbm_regression.json b/config-model/src/test/cfg/application/ml_models/models/lightgbm_regression.json new file mode 100644 index 00000000000..cf0488ecd8b --- /dev/null +++ b/config-model/src/test/cfg/application/ml_models/models/lightgbm_regression.json @@ -0,0 +1,275 @@ +{ + "name": "tree", + "version": "v3", + "num_class": 1, + "num_tree_per_iteration": 1, + "label_index": 0, + "max_feature_idx": 3, + "average_output": false, + "objective": "regression", + "feature_names": [ + "numerical_1", + "numerical_2", + "categorical_1", + "categorical_2" + ], + "monotone_constraints": [], + "tree_info": [ + { + "tree_index": 0, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 1, + "tree_structure": { + "split_index": 0, + "split_feature": 1, + "split_gain": 68.5353012084961, + "threshold": 0.46643291586559305, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "leaf_index": 0, + "leaf_value": 2.1594397038037663, + "leaf_weight": 469, + "leaf_count": 469 + }, + "right_child": { + "split_index": 1, + "split_feature": 3, + "split_gain": 41.27640151977539, + "threshold": "2||3||4", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0.246035, + "internal_weight": 531, + "internal_count": 531, + "left_child": { + "leaf_index": 1, + "leaf_value": 2.235297305276056, + "leaf_weight": 302, + "leaf_count": 302 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 2.1792953471546546, + "leaf_weight": 229, + "leaf_count": 229 + } + } + } + }, + { + "tree_index": 1, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 2, + "split_gain": 64.22250366210938, + "threshold": "3||4", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "leaf_index": 0, + "leaf_value": 0.03070842919354316, + "leaf_weight": 399, + "leaf_count": 399 + }, + "right_child": { + "split_index": 1, + "split_feature": 0, + "split_gain": 36.74250030517578, + "threshold": 0.5102250691730842, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": -0.204906, + "internal_weight": 601, + "internal_count": 601, + "left_child": { + "leaf_index": 1, + "leaf_value": -0.04439151147520909, + "leaf_weight": 315, + "leaf_count": 315 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 0.005117411709368601, + "leaf_weight": 286, + "leaf_count": 286 + } + } + } + }, + { + "tree_index": 2, + "num_leaves": 3, + "num_cat": 0, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 1, + "split_gain": 57.1327018737793, + "threshold": 0.668665477622446, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "split_index": 1, + "split_feature": 1, + "split_gain": 40.859100341796875, + "threshold": 0.008118820676863816, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": -0.162926, + "internal_weight": 681, + "internal_count": 681, + "left_child": { + "leaf_index": 0, + "leaf_value": -0.15361238490967524, + "leaf_weight": 21, + "leaf_count": 21 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": -0.01192330846157292, + "leaf_weight": 660, + "leaf_count": 660 + } + }, + "right_child": { + "leaf_index": 1, + "leaf_value": 0.03499044894987518, + "leaf_weight": 319, + "leaf_count": 319 + } + } + }, + { + "tree_index": 3, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 0, + "split_gain": 54.77090072631836, + "threshold": 0.5201391072644542, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "leaf_index": 0, + "leaf_value": -0.02141000620783247, + "leaf_weight": 543, + "leaf_count": 543 + }, + "right_child": { + "split_index": 1, + "split_feature": 2, + "split_gain": 27.200700759887695, + "threshold": "0||1", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0.255704, + "internal_weight": 457, + "internal_count": 457, + "left_child": { + "leaf_index": 1, + "leaf_value": -0.004121485787596721, + "leaf_weight": 191, + "leaf_count": 191 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 0.04534090904886873, + "leaf_weight": 266, + "leaf_count": 266 + } + } + } + }, + { + "tree_index": 4, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 3, + "split_gain": 51.84349822998047, + "threshold": "2||3||4", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "split_index": 1, + "split_feature": 1, + "split_gain": 39.352699279785156, + "threshold": 0.27283279016959255, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0.188414, + "internal_weight": 593, + "internal_count": 593, + "left_child": { + "leaf_index": 0, + "leaf_value": -0.01924803254356527, + "leaf_weight": 184, + "leaf_count": 184 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 0.03643772842347651, + "leaf_weight": 409, + "leaf_count": 409 + } + }, + "right_child": { + "leaf_index": 1, + "leaf_value": -0.02701711918923075, + "leaf_weight": 407, + "leaf_count": 407 + } + } + } + ], + "pandas_categorical": [ + [ + "a", + "b", + "c", + "d", + "e" + ], + [ + "i", + "j", + "k", + "l", + "m" + ] + ] +}
\ No newline at end of file diff --git a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd index ab5e42f983d..247df8a0241 100644 --- a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd +++ b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd @@ -33,8 +33,12 @@ search test { expression: xgboost("xgboost_2_2") } + function my_lightgbm() { + expression: lightgbm("lightgbm_regression") + } + first-phase { - expression: mnist_tensorflow + mnist_softmax_tensorflow + mnist_softmax_onnx + my_xgboost + expression: mnist_tensorflow + mnist_softmax_tensorflow + mnist_softmax_onnx + my_xgboost + my_lightgbm } } diff --git a/config-model/src/test/cfg/application/ml_serving/models/lightgbm_regression.json b/config-model/src/test/cfg/application/ml_serving/models/lightgbm_regression.json new file mode 100644 index 00000000000..cf0488ecd8b --- /dev/null +++ b/config-model/src/test/cfg/application/ml_serving/models/lightgbm_regression.json @@ -0,0 +1,275 @@ +{ + "name": "tree", + "version": "v3", + "num_class": 1, + "num_tree_per_iteration": 1, + "label_index": 0, + "max_feature_idx": 3, + "average_output": false, + "objective": "regression", + "feature_names": [ + "numerical_1", + "numerical_2", + "categorical_1", + "categorical_2" + ], + "monotone_constraints": [], + "tree_info": [ + { + "tree_index": 0, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 1, + "tree_structure": { + "split_index": 0, + "split_feature": 1, + "split_gain": 68.5353012084961, + "threshold": 0.46643291586559305, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "leaf_index": 0, + "leaf_value": 2.1594397038037663, + "leaf_weight": 469, + "leaf_count": 469 + }, + "right_child": { + "split_index": 1, + "split_feature": 3, + "split_gain": 41.27640151977539, + "threshold": "2||3||4", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0.246035, + "internal_weight": 531, + "internal_count": 531, + "left_child": { + "leaf_index": 1, + "leaf_value": 2.235297305276056, + "leaf_weight": 302, + "leaf_count": 302 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 2.1792953471546546, + "leaf_weight": 229, + "leaf_count": 229 + } + } + } + }, + { + "tree_index": 1, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 2, + "split_gain": 64.22250366210938, + "threshold": "3||4", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "leaf_index": 0, + "leaf_value": 0.03070842919354316, + "leaf_weight": 399, + "leaf_count": 399 + }, + "right_child": { + "split_index": 1, + "split_feature": 0, + "split_gain": 36.74250030517578, + "threshold": 0.5102250691730842, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": -0.204906, + "internal_weight": 601, + "internal_count": 601, + "left_child": { + "leaf_index": 1, + "leaf_value": -0.04439151147520909, + "leaf_weight": 315, + "leaf_count": 315 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 0.005117411709368601, + "leaf_weight": 286, + "leaf_count": 286 + } + } + } + }, + { + "tree_index": 2, + "num_leaves": 3, + "num_cat": 0, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 1, + "split_gain": 57.1327018737793, + "threshold": 0.668665477622446, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "split_index": 1, + "split_feature": 1, + "split_gain": 40.859100341796875, + "threshold": 0.008118820676863816, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": -0.162926, + "internal_weight": 681, + "internal_count": 681, + "left_child": { + "leaf_index": 0, + "leaf_value": -0.15361238490967524, + "leaf_weight": 21, + "leaf_count": 21 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": -0.01192330846157292, + "leaf_weight": 660, + "leaf_count": 660 + } + }, + "right_child": { + "leaf_index": 1, + "leaf_value": 0.03499044894987518, + "leaf_weight": 319, + "leaf_count": 319 + } + } + }, + { + "tree_index": 3, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 0, + "split_gain": 54.77090072631836, + "threshold": 0.5201391072644542, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "leaf_index": 0, + "leaf_value": -0.02141000620783247, + "leaf_weight": 543, + "leaf_count": 543 + }, + "right_child": { + "split_index": 1, + "split_feature": 2, + "split_gain": 27.200700759887695, + "threshold": "0||1", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0.255704, + "internal_weight": 457, + "internal_count": 457, + "left_child": { + "leaf_index": 1, + "leaf_value": -0.004121485787596721, + "leaf_weight": 191, + "leaf_count": 191 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 0.04534090904886873, + "leaf_weight": 266, + "leaf_count": 266 + } + } + } + }, + { + "tree_index": 4, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 3, + "split_gain": 51.84349822998047, + "threshold": "2||3||4", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "split_index": 1, + "split_feature": 1, + "split_gain": 39.352699279785156, + "threshold": 0.27283279016959255, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0.188414, + "internal_weight": 593, + "internal_count": 593, + "left_child": { + "leaf_index": 0, + "leaf_value": -0.01924803254356527, + "leaf_weight": 184, + "leaf_count": 184 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 0.03643772842347651, + "leaf_weight": 409, + "leaf_count": 409 + } + }, + "right_child": { + "leaf_index": 1, + "leaf_value": -0.02701711918923075, + "leaf_weight": 407, + "leaf_count": 407 + } + } + } + ], + "pandas_categorical": [ + [ + "a", + "b", + "c", + "d", + "e" + ], + [ + "i", + "j", + "k", + "l", + "m" + ] + ] +}
\ No newline at end of file diff --git a/config-model/src/test/integration/lightgbm/models/regression.json b/config-model/src/test/integration/lightgbm/models/regression.json new file mode 100644 index 00000000000..cf0488ecd8b --- /dev/null +++ b/config-model/src/test/integration/lightgbm/models/regression.json @@ -0,0 +1,275 @@ +{ + "name": "tree", + "version": "v3", + "num_class": 1, + "num_tree_per_iteration": 1, + "label_index": 0, + "max_feature_idx": 3, + "average_output": false, + "objective": "regression", + "feature_names": [ + "numerical_1", + "numerical_2", + "categorical_1", + "categorical_2" + ], + "monotone_constraints": [], + "tree_info": [ + { + "tree_index": 0, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 1, + "tree_structure": { + "split_index": 0, + "split_feature": 1, + "split_gain": 68.5353012084961, + "threshold": 0.46643291586559305, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "leaf_index": 0, + "leaf_value": 2.1594397038037663, + "leaf_weight": 469, + "leaf_count": 469 + }, + "right_child": { + "split_index": 1, + "split_feature": 3, + "split_gain": 41.27640151977539, + "threshold": "2||3||4", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0.246035, + "internal_weight": 531, + "internal_count": 531, + "left_child": { + "leaf_index": 1, + "leaf_value": 2.235297305276056, + "leaf_weight": 302, + "leaf_count": 302 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 2.1792953471546546, + "leaf_weight": 229, + "leaf_count": 229 + } + } + } + }, + { + "tree_index": 1, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 2, + "split_gain": 64.22250366210938, + "threshold": "3||4", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "leaf_index": 0, + "leaf_value": 0.03070842919354316, + "leaf_weight": 399, + "leaf_count": 399 + }, + "right_child": { + "split_index": 1, + "split_feature": 0, + "split_gain": 36.74250030517578, + "threshold": 0.5102250691730842, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": -0.204906, + "internal_weight": 601, + "internal_count": 601, + "left_child": { + "leaf_index": 1, + "leaf_value": -0.04439151147520909, + "leaf_weight": 315, + "leaf_count": 315 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 0.005117411709368601, + "leaf_weight": 286, + "leaf_count": 286 + } + } + } + }, + { + "tree_index": 2, + "num_leaves": 3, + "num_cat": 0, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 1, + "split_gain": 57.1327018737793, + "threshold": 0.668665477622446, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "split_index": 1, + "split_feature": 1, + "split_gain": 40.859100341796875, + "threshold": 0.008118820676863816, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": -0.162926, + "internal_weight": 681, + "internal_count": 681, + "left_child": { + "leaf_index": 0, + "leaf_value": -0.15361238490967524, + "leaf_weight": 21, + "leaf_count": 21 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": -0.01192330846157292, + "leaf_weight": 660, + "leaf_count": 660 + } + }, + "right_child": { + "leaf_index": 1, + "leaf_value": 0.03499044894987518, + "leaf_weight": 319, + "leaf_count": 319 + } + } + }, + { + "tree_index": 3, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 0, + "split_gain": 54.77090072631836, + "threshold": 0.5201391072644542, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "leaf_index": 0, + "leaf_value": -0.02141000620783247, + "leaf_weight": 543, + "leaf_count": 543 + }, + "right_child": { + "split_index": 1, + "split_feature": 2, + "split_gain": 27.200700759887695, + "threshold": "0||1", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0.255704, + "internal_weight": 457, + "internal_count": 457, + "left_child": { + "leaf_index": 1, + "leaf_value": -0.004121485787596721, + "leaf_weight": 191, + "leaf_count": 191 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 0.04534090904886873, + "leaf_weight": 266, + "leaf_count": 266 + } + } + } + }, + { + "tree_index": 4, + "num_leaves": 3, + "num_cat": 1, + "shrinkage": 0.1, + "tree_structure": { + "split_index": 0, + "split_feature": 3, + "split_gain": 51.84349822998047, + "threshold": "2||3||4", + "decision_type": "==", + "default_left": false, + "missing_type": "NaN", + "internal_value": 0, + "internal_weight": 0, + "internal_count": 1000, + "left_child": { + "split_index": 1, + "split_feature": 1, + "split_gain": 39.352699279785156, + "threshold": 0.27283279016959255, + "decision_type": "<=", + "default_left": true, + "missing_type": "NaN", + "internal_value": 0.188414, + "internal_weight": 593, + "internal_count": 593, + "left_child": { + "leaf_index": 0, + "leaf_value": -0.01924803254356527, + "leaf_weight": 184, + "leaf_count": 184 + }, + "right_child": { + "leaf_index": 2, + "leaf_value": 0.03643772842347651, + "leaf_weight": 409, + "leaf_count": 409 + } + }, + "right_child": { + "leaf_index": 1, + "leaf_value": -0.02701711918923075, + "leaf_weight": 407, + "leaf_count": 407 + } + } + } + ], + "pandas_categorical": [ + [ + "a", + "b", + "c", + "d", + "e" + ], + [ + "i", + "j", + "k", + "l", + "m" + ] + ] +}
\ No newline at end of file diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 08dd5148b29..0cd6674751e 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -15,6 +15,7 @@ import com.yahoo.searchdefinition.parser.ParseException; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; +import ai.vespa.rankingexpression.importer.lightgbm.LightGBMImporter; import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import java.util.HashMap; @@ -33,6 +34,7 @@ class RankProfileSearchFixture { private final ImmutableList<MlModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), new OnnxImporter(), + new LightGBMImporter(), new XGBoostImporter()); private RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); private final QueryProfileRegistry queryProfileRegistry; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithLightGBMTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithLightGBMTestCase.java new file mode 100644 index 00000000000..79d19371f1c --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithLightGBMTestCase.java @@ -0,0 +1,88 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.processing; + +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.searchdefinition.parser.ParseException; +import org.junit.After; +import org.junit.Test; + +import java.io.IOException; + +/** + * @author lesters + */ +public class RankingExpressionWithLightGBMTestCase { + + private final Path applicationDir = Path.fromString("src/test/integration/lightgbm/"); + + private final static String lightGBMExpression = + "if (!(numerical_2 >= 0.46643291586559305), 2.1594397038037663, if (categorical_2 in [\"k\", \"l\", \"m\"], 2.235297305276056, 2.1792953471546546)) + if (categorical_1 in [\"d\", \"e\"], 0.03070842919354316, if (!(numerical_1 >= 0.5102250691730842), -0.04439151147520909, 0.005117411709368601)) + if (!(numerical_2 >= 0.668665477622446), if (!(numerical_2 >= 0.008118820676863816), -0.15361238490967524, -0.01192330846157292), 0.03499044894987518) + if (!(numerical_1 >= 0.5201391072644542), -0.02141000620783247, if (categorical_1 in [\"a\", \"b\"], -0.004121485787596721, 0.04534090904886873)) + if (categorical_2 in [\"k\", \"l\", \"m\"], if (!(numerical_2 >= 0.27283279016959255), -0.01924803254356527, 0.03643772842347651), -0.02701711918923075)"; + + @After + public void removeGeneratedModelFiles() { + IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + } + + @Test + public void testLightGBMReference() { + RankProfileSearchFixture search = fixtureWith("lightgbm('regression.json')"); + search.assertFirstPhaseExpression(lightGBMExpression, "my_profile"); + } + + @Test + public void testNestedLightGBMReference() { + RankProfileSearchFixture search = fixtureWith("5 + sum(lightgbm('regression.json'))"); + search.assertFirstPhaseExpression("5 + reduce(" + lightGBMExpression + ", sum)", "my_profile"); + } + + @Test + public void testImportingFromStoredExpressions() throws IOException { + RankProfileSearchFixture search = fixtureWith("lightgbm('regression.json')"); + search.assertFirstPhaseExpression(lightGBMExpression, "my_profile"); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); + try { + storedApplicationDirectory.toFile().mkdirs(); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage storedApplication = new RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage(storedApplicationDirectory); + RankProfileSearchFixture searchFromStored = fixtureWith("lightgbm('regression.json')"); + searchFromStored.assertFirstPhaseExpression(lightGBMExpression, "my_profile"); + } + finally { + IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); + } + } + + private RankProfileSearchFixture fixtureWith(String firstPhaseExpression) { + return fixtureWith(firstPhaseExpression, null, null, + new RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage(applicationDir)); + } + + private RankProfileSearchFixture fixtureWith(String firstPhaseExpression, + String constant, + String field, + RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage application) { + try { + RankProfileSearchFixture fixture = new RankProfileSearchFixture( + application, + application.getQueryProfiles(), + " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: " + firstPhaseExpression + + " }\n" + + " }", + constant, + field); + fixture.compileRankProfile("my_profile", applicationDir.append("models")); + return fixture; + } catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + +} + diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java index ce36ecc4a1c..7a3b76db7f8 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; -import ai.vespa.rankingexpression.importer.vespa.VespaImporter; import com.google.common.collect.ImmutableList; import com.yahoo.config.model.ApplicationPackageTester; import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter; @@ -10,8 +9,10 @@ import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.searchdefinition.RankingConstant; +import ai.vespa.rankingexpression.importer.lightgbm.LightGBMImporter; import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; +import ai.vespa.rankingexpression.importer.vespa.VespaImporter; import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.serialization.TypedBinaryFormat; @@ -35,6 +36,7 @@ public class ImportedModelTester { private final ImmutableList<MlModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), new OnnxImporter(), + new LightGBMImporter(), new XGBoostImporter(), new VespaImporter()); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java index c5c475360a3..ced7243adf5 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java @@ -45,7 +45,7 @@ public class MlModelsTest { private void verify(VespaModel model) { assertEquals("Global models are created (although not used directly here", - 4, model.rankProfileList().getRankProfiles().size()); + 5, model.rankProfileList().getRankProfiles().size()); RankProfilesConfig.Builder builder = new RankProfilesConfig.Builder(); model.getSearchClusters().get(0).getConfig(builder); @@ -71,8 +71,9 @@ public class MlModelsTest { "rankingExpression(mnist_softmax_tensorflow).rankingScript: join(reduce(join(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))\n" + "rankingExpression(mnist_softmax_onnx).rankingScript: join(reduce(join(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))\n" + "rankingExpression(my_xgboost).rankingScript: if (f29 < -0.1234567, if (!(f56 >= -0.242398), 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (!(f60 >= -0.482947), if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)\n" + + "rankingExpression(my_lightgbm).rankingScript: if (!(numerical_2 >= 0.46643291586559305), 2.1594397038037663, if (categorical_2 in [\"k\", \"l\", \"m\"], 2.235297305276056, 2.1792953471546546)) + if (categorical_1 in [\"d\", \"e\"], 0.03070842919354316, if (!(numerical_1 >= 0.5102250691730842), -0.04439151147520909, 0.005117411709368601)) + if (!(numerical_2 >= 0.668665477622446), if (!(numerical_2 >= 0.008118820676863816), -0.15361238490967524, -0.01192330846157292), 0.03499044894987518) + if (!(numerical_1 >= 0.5201391072644542), -0.02141000620783247, if (categorical_1 in [\"a\", \"b\"], -0.004121485787596721, 0.04534090904886873)) + if (categorical_2 in [\"k\", \"l\", \"m\"], if (!(numerical_2 >= 0.27283279016959255), -0.01924803254356527, 0.03643772842347651), -0.02701711918923075)\n" + "vespa.rank.firstphase: rankingExpression(firstphase)\n" + - "rankingExpression(firstphase).rankingScript: rankingExpression(mnist_tensorflow) + rankingExpression(mnist_softmax_tensorflow) + rankingExpression(mnist_softmax_onnx) + rankingExpression(my_xgboost)\n" + + "rankingExpression(firstphase).rankingScript: rankingExpression(mnist_tensorflow) + rankingExpression(mnist_softmax_tensorflow) + rankingExpression(mnist_softmax_onnx) + rankingExpression(my_xgboost) + rankingExpression(my_lightgbm)\n" + "vespa.type.attribute.argument: tensor<float>(d0[],d1[784])\n"; } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index 3d4ac7f2eeb..2d3ddc33afb 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -96,9 +96,10 @@ public class ModelEvaluationTest { cluster.getConfig(cb); RankingConstantsConfig constantsConfig = new RankingConstantsConfig(cb); - assertEquals(4, config.rankprofile().size()); + assertEquals(5, config.rankprofile().size()); Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet()); assertTrue(modelNames.contains("xgboost_2_2")); + assertTrue(modelNames.contains("lightgbm_regression")); assertTrue(modelNames.contains("mnist_saved")); assertTrue(modelNames.contains("mnist_softmax")); assertTrue(modelNames.contains("mnist_softmax_saved")); @@ -112,13 +113,18 @@ public class ModelEvaluationTest { ModelsEvaluator evaluator = new ModelsEvaluator(new ToleratingMissingConstantFilesRankProfilesConfigImporter(MockFileAcquirer.returnFile(null)) .importFrom(config, constantsConfig)); - assertEquals(4, evaluator.models().size()); + assertEquals(5, evaluator.models().size()); Model xgboost = evaluator.models().get("xgboost_2_2"); assertNotNull(xgboost); assertNotNull(xgboost.evaluatorOf()); assertNotNull(xgboost.evaluatorOf("xgboost_2_2")); + Model lightgbm = evaluator.models().get("lightgbm_regression"); + assertNotNull(lightgbm); + assertNotNull(lightgbm.evaluatorOf()); + assertNotNull(lightgbm.evaluatorOf("lightgbm_regression")); + Model tensorflow_mnist = evaluator.models().get("mnist_saved"); assertNotNull(tensorflow_mnist); assertEquals(1, tensorflow_mnist.functions().size()); |