diff options
8 files changed, 315 insertions, 2 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java index 6ca16c1559d..34721ee4da1 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java @@ -22,6 +22,7 @@ public class ExpressionTransforms { private final List<ExpressionTransformer> transforms = ImmutableList.of(new TensorFlowFeatureConverter(), new OnnxFeatureConverter(), + new XgboostFeatureConverter(), new ConstantDereferencer(), new ConstantTensorTransformer(), new MacroInliner(), diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java new file mode 100644 index 00000000000..4ae223ec3a5 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java @@ -0,0 +1,58 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.searchdefinition.expressiontransforms;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.XgboostImporter;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+
+import java.io.UncheckedIOException;
+
+/**
+ * Expression transformer which substitutes every xgboost(model-path)
+ * pseudofeature reference with the root node of the ranking expression
+ * that the XGBoost importer produces for that model.
+ *
+ * @author grace-lam
+ */
+public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
+
+    private final XgboostImporter xgboostImporter = new XgboostImporter();
+
+    /**
+     * Dispatches on the node type: reference nodes are candidates for replacement,
+     * composite nodes have their children transformed, anything else is returned as-is.
+     */
+    @Override
+    public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
+        if (node instanceof ReferenceNode) {
+            return transformFeature((ReferenceNode) node, context);
+        }
+        if (node instanceof CompositeNode) {
+            return super.transformChildren((CompositeNode) node, context);
+        }
+        return node;
+    }
+
+    /**
+     * Returns the imported model expression for a feature named "xgboost",
+     * or the feature itself when it has any other name.
+     *
+     * @throws IllegalArgumentException if the arguments are invalid or the model cannot be imported
+     */
+    private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
+        boolean isXgboostFeature = feature.getName().equals("xgboost");
+        if (!isXgboostFeature) {
+            return feature;
+        }
+
+        try {
+            // Argument validation happens inside the try so that its failures are wrapped as well
+            ConvertedModel.FeatureArguments featureArguments = asFeatureArguments(feature.getArguments());
+            ConvertedModel.ModelStore modelStore =
+                    new ConvertedModel.ModelStore(context.rankProfile().getSearch().sourceApplication(),
+                                                  featureArguments);
+            String modelPath = modelStore.modelDir().toString();
+            RankingExpression imported = xgboostImporter.parseModel(modelPath);
+            return imported.getRoot();
+        } catch (IllegalArgumentException | UncheckedIOException e) {
+            throw new IllegalArgumentException("Could not use XGBoost model from " + feature, e);
+        }
+    }
+
+    /**
+     * Wraps the feature arguments after checking that exactly one argument was given.
+     *
+     * @throws IllegalArgumentException if there are no arguments or more than one
+     */
+    private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) {
+        if (arguments.isEmpty()) {
+            throw new IllegalArgumentException("An xgboost node must take an argument pointing to " +
+                                               "the xgboost model directory under [application]/models");
+        }
+        int argumentCount = arguments.expressions().size();
+        if (argumentCount > 1) {
+            throw new IllegalArgumentException("An xgboost feature can have at most 1 argument");
+        }
+        return new ConvertedModel.FeatureArguments(arguments);
+    }
+
+}
diff --git a/config-model/src/test/integration/xgboost/models/xgboost.2.2.json b/config-model/src/test/integration/xgboost/models/xgboost.2.2.json
new file mode 100644
index 00000000000..f8949b47e52
--- /dev/null
+++ b/config-model/src/test/integration/xgboost/models/xgboost.2.2.json
@@ -0,0 +1,19 @@
+[
+  { "nodeid": 0, "depth": 0, "split": "f29", "split_condition": -0.1234567, "yes": 1, "no": 2, "missing": 1, "children": [
+    { "nodeid": 1, "depth": 1, "split": "f56", "split_condition": -0.242398, "yes": 3, "no": 4, "missing": 3, "children": [
+      { "nodeid": 3, "leaf": 1.71218 },
+      { "nodeid": 4, "leaf": -1.70044 }
+    ]},
+    { "nodeid": 2, "depth": 1, "split": "f109", "split_condition": 0.8723473, "yes": 5, "no": 6, "missing": 5, "children": [
+      { "nodeid": 5, "leaf": -1.94071 },
+      { "nodeid": 6, "leaf": 1.85965 }
+    ]}
+  ]},
+  { "nodeid": 0, "depth": 0, "split": "f60", "split_condition": -0.482947, "yes": 1, "no": 2, "missing": 1, "children": [
+    { "nodeid": 1, "depth": 1, "split": "f29", "split_condition": -4.2387498, "yes": 3, "no": 4, "missing": 3, "children": [
+      { "nodeid": 3, "leaf": 0.784718 },
+      { "nodeid": 4, "leaf": -0.96853 }
+    ]},
+    { "nodeid": 2, "leaf": -6.23624 }
+  ]}
+]
\ No newline at end of file
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java
new file mode 100644
index 00000000000..b65cb0b3d5f
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java
@@ -0,0 +1,55 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.path.Path;
+import com.yahoo.searchdefinition.parser.ParseException;
+import org.junit.Test;
+
+/**
+ * Checks that xgboost(...) references in a first-phase expression are replaced by
+ * the ranking expression generated from the referenced model file.
+ *
+ * @author grace-lam
+ */
+public class RankingExpressionWithXgboostTestCase {
+
+    private final Path applicationDir = Path.fromString("src/test/integration/xgboost/");
+    private final static String vespaExpression = "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + " +
+            "if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)";
+
+    /** A bare xgboost reference becomes the generated expression. */
+    @Test
+    public void testXgboostReference() {
+        fixtureWith("xgboost('xgboost.2.2.json')")
+                .assertFirstPhaseExpression(vespaExpression, "my_profile");
+    }
+
+    /** An xgboost reference nested in a larger expression is replaced in place. */
+    @Test
+    public void testNestedXgboostReference() {
+        String expected = "5 + reduce(" + vespaExpression + ", sum)";
+        fixtureWith("5 + sum(xgboost('xgboost.2.2.json'))")
+                .assertFirstPhaseExpression(expected, "my_profile");
+    }
+
+    /** Creates a fixture over the test application directory, with no constant and no field. */
+    private RankProfileSearchFixture fixtureWith(String firstPhaseExpression) {
+        RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage application =
+                new RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage(applicationDir);
+        return fixtureWith(firstPhaseExpression, null, null, application);
+    }
+
+    /** Creates a fixture whose "my_profile" rank profile uses the given first-phase expression. */
+    private RankProfileSearchFixture fixtureWith(String firstPhaseExpression,
+                                                 String constant,
+                                                 String field,
+                                                 RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage application) {
+        String rankProfile = " rank-profile my_profile {\n" +
+                             " first-phase {\n" +
+                             " expression: " + firstPhaseExpression +
+                             " }\n" +
+                             " }";
+        try {
+            return new RankProfileSearchFixture(application,
+                                                application.getQueryProfiles(),
+                                                rankProfile,
+                                                constant,
+                                                field);
+        } catch (ParseException e) {
+            throw new IllegalArgumentException(e);
+        }
+    }
+
+}
+
diff --git a/searchlib/pom.xml b/searchlib/pom.xml
index 0202f8510bb..8037f1d399a 100644
--- a/searchlib/pom.xml
+++ b/searchlib/pom.xml
@@ -51,12 +51,10 @@
         <dependency>
             <groupId>com.fasterxml.jackson.core</groupId>
             <artifactId>jackson-core</artifactId>
-            <scope>test</scope>
         </dependency>
         <dependency>
             <groupId>com.fasterxml.jackson.core</groupId>
             <artifactId>jackson-databind</artifactId>
-            <scope>test</scope>
         </dependency>
     </dependencies>
     <build>
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XgboostImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XgboostImporter.java
new file mode 100644
index 00000000000..f9717c39a8b
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XgboostImporter.java
@@ -0,0 +1,28 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.ml;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost.XGBoostParser;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+
+import java.io.IOException;
+
+/**
+ * Converts a saved XGBoost model into a ranking expression.
+ *
+ * @author grace-lam
+ */
+public class XgboostImporter {
+
+    /**
+     * Reads the XGBoost model at the given path and returns it as a ranking expression.
+     *
+     * @param modelPath path of the XGBoost JSON model file
+     * @return the ranking expression built from the model's trees
+     * @throws IllegalArgumentException if the model cannot be read, or if the expression
+     *                                  text generated from it cannot be parsed; the
+     *                                  underlying exception is always attached as the cause
+     */
+    public RankingExpression parseModel(String modelPath) {
+        try {
+            XGBoostParser parser = new XGBoostParser(modelPath);
+            return new RankingExpression(parser.toRankingExpression());
+        } catch (IOException e) {
+            throw new IllegalArgumentException("Could not import XGBoost model from '" + modelPath + "'", e);
+        } catch (ParseException e) {
+            // Keep the ParseException as the cause so its stack trace is not lost
+            throw new IllegalArgumentException("Could not parse ranking expression generated from XGBoost model '" +
+                                               modelPath + "'", e);
+        }
+    }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java
new file mode 100644
index 00000000000..fef8bfec81d
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java
@@ -0,0 +1,77 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+/**
+ * @author grace-lam
+ */
+public class XGBoostParser {
+
+    private List<XGBoostTree> xgboostTrees;
+
+    /**
+     * Constructor stores parsed JSON trees.
+     *
+     * @param filePath XGBoost JSON output file.
+     * @throws JsonProcessingException Fails JSON parsing.
+     * @throws IOException Fails file reading.
+     */
+    public XGBoostParser(String filePath) throws JsonProcessingException, IOException {
+        this.xgboostTrees = new ArrayList<>();
+        ObjectMapper mapper = new ObjectMapper();
+        JsonNode forestNode = mapper.readTree(new File(filePath));
+        // The model must be a JSON array with one element per tree. Reject anything else
+        // (including empty content) up front instead of failing later with an obscure error.
+        if (forestNode == null || !forestNode.isArray()) {
+            throw new IOException("Expected a JSON array of trees in XGBoost model file '" + filePath + "'");
+        }
+        for (JsonNode treeNode : forestNode) {
+            this.xgboostTrees.add(mapper.treeToValue(treeNode, XGBoostTree.class));
+        }
+    }
+
+    /**
+     * Converts parsed JSON trees to a Vespa ranking expression: the expressions of the
+     * individual trees joined by " + " followed by a line break.
+     *
+     * @return Vespa ranking expression text; the empty string if there are no trees.
+     */
+    public String toRankingExpression() {
+        List<String> treeExpressions = new ArrayList<>(xgboostTrees.size());
+        for (XGBoostTree tree : xgboostTrees) {
+            treeExpressions.add(treeToRankExp(tree));
+        }
+        return String.join(" + \n", treeExpressions);
+    }
+
+    /**
+     * Recursive helper function for toRankingExpression().
+     *
+     * A leaf becomes its numeric value. A split node becomes
+     * "if (split < split_condition, yes-branch, no-branch)", where the yes-branch is the
+     * child whose node id equals the node's "yes" id. The node's "missing" id is not
+     * used by this conversion.
+     *
+     * @param node XGBoost tree node to convert.
+     * @return Vespa ranking expression for input node.
+     * @throws IllegalArgumentException if a split node does not have exactly two children.
+     */
+    public String treeToRankExp(XGBoostTree node) {
+        if (node.isLeaf()) {
+            return Double.toString(node.getLeaf());
+        }
+
+        List<XGBoostTree> children = node.getChildren();
+        // Explicit check rather than 'assert': assertions are disabled unless the JVM runs with -ea
+        if (children.size() != 2) {
+            throw new IllegalArgumentException("XGBoost split node " + node.getNodeid() +
+                                               " must have exactly 2 children, but has " + children.size());
+        }
+
+        XGBoostTree first = children.get(0);
+        XGBoostTree second = children.get(1);
+        boolean firstIsYesBranch = node.getYes() == first.getNodeid();
+        String trueExp = treeToRankExp(firstIsYesBranch ? first : second);
+        String falseExp = treeToRankExp(firstIsYesBranch ? second : first);
+        return "if (" + node.getSplit() + " < " + Double.toString(node.getSplit_condition()) + ", " + trueExp + ", "
+                + falseExp + ")";
+    }
+
+}
\ No newline at end of file
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java
new file mode 100644
index 00000000000..6bbc9abe8ae
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java
@@ -0,0 +1,77 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost;
+
+import java.util.List;
+
+/**
+ * One node of a tree in the XGBoost JSON output file; the field names mirror
+ * the JSON representation used for parsing that file.
+ *
+ * @author grace-lam
+ */
+public class XGBoostTree {
+
+    /** ID of current node. */
+    private int nodeid;
+    /** Depth of current node w.r.t. the tree's root. */
+    private int depth;
+    /** Feature name used for split. */
+    private String split;
+    /** Feature value threshold to split on. */
+    private double split_condition;
+    /** Next node if feature value < split_condition. */
+    private int yes;
+    /** Next node if feature value >= split_condition. */
+    private int no;
+    /** Next node if feature value is missing. */
+    private int missing;
+    /** Response value for leaf node. */
+    private double leaf;
+    /** List of child nodes. */
+    private List<XGBoostTree> children;
+
+    /** @return ID of this node. */
+    public int getNodeid() {
+        return this.nodeid;
+    }
+
+    /** @return depth of this node. */
+    public int getDepth() {
+        return this.depth;
+    }
+
+    /** @return name of the feature this node splits on. */
+    public String getSplit() {
+        return this.split;
+    }
+
+    /** @return threshold the split feature is compared against. */
+    public double getSplit_condition() {
+        return this.split_condition;
+    }
+
+    /** @return ID of the "yes" node. */
+    public int getYes() {
+        return this.yes;
+    }
+
+    /** @return ID of the "no" node. */
+    public int getNo() {
+        return this.no;
+    }
+
+    /** @return ID of the "missing" node. */
+    public int getMissing() {
+        return this.missing;
+    }
+
+    /** @return response value of this node. */
+    public double getLeaf() {
+        return this.leaf;
+    }
+
+    /** @return child nodes, or null if none were set. */
+    public List<XGBoostTree> getChildren() {
+        return this.children;
+    }
+
+    /**
+     * Tells whether this node is a leaf, i.e. has no list of children.
+     *
+     * @return True if leaf, false otherwise.
+     */
+    public boolean isLeaf() {
+        return null == this.children;
+    }
+
+}
\ No newline at end of file |