summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author: Jon Bratseth <bratseth@oath.com> 2018-08-15 09:46:39 +0200
committer: GitHub <noreply@github.com> 2018-08-15 09:46:39 +0200
commit: 9a402fd71acf3af060302600ab436ed256069fcf (patch)
tree: d4295359f3c303014e89c5a4b09e9af85bea4bbf
parent: 9fb31660d34357e7640ab808f65d10b7c2dcddba (diff)
parent: 88f31d5f5af6af44f28de9f363c33044220bd611 (diff)
Merge pull request #6577 from grace-lam/add-xgboost-converter
Add XGBoost converter
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java1
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java58
-rw-r--r--config-model/src/test/integration/xgboost/models/xgboost.2.2.json19
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java55
-rw-r--r--searchlib/pom.xml2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XgboostImporter.java28
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java77
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java77
8 files changed, 315 insertions, 2 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
index 6ca16c1559d..34721ee4da1 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
@@ -22,6 +22,7 @@ public class ExpressionTransforms {
private final List<ExpressionTransformer> transforms =
ImmutableList.of(new TensorFlowFeatureConverter(),
new OnnxFeatureConverter(),
+ new XgboostFeatureConverter(),
new ConstantDereferencer(),
new ConstantTensorTransformer(),
new MacroInliner(),
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
new file mode 100644
index 00000000000..4ae223ec3a5
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
@@ -0,0 +1,58 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.expressiontransforms;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.XgboostImporter;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+
+import java.io.UncheckedIOException;
+
+/**
+ * Replaces instances of the xgboost(model-path)
+ * pseudofeature with the native Vespa ranking expression implementing
+ * the same computation.
+ *
+ * @author grace-lam
+ */
+public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
+
+ private final XgboostImporter xgboostImporter = new XgboostImporter();
+
+ @Override
+ public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
+ if (node instanceof ReferenceNode)
+ return transformFeature((ReferenceNode) node, context);
+ else if (node instanceof CompositeNode)
+ return super.transformChildren((CompositeNode) node, context);
+ else
+ return node;
+ }
+
+ private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
+ if (!feature.getName().equals("xgboost")) return feature;
+
+ try {
+ ConvertedModel.FeatureArguments arguments = asFeatureArguments(feature.getArguments());
+ ConvertedModel.ModelStore store = new ConvertedModel.ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
+ RankingExpression expression = xgboostImporter.parseModel(store.modelDir().toString());
+ return expression.getRoot();
+ } catch (IllegalArgumentException | UncheckedIOException e) {
+ throw new IllegalArgumentException("Could not use XGBoost model from " + feature, e);
+ }
+ }
+
+ private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) {
+ if (arguments.isEmpty())
+ throw new IllegalArgumentException("An xgboost node must take an argument pointing to " +
+ "the xgboost model directory under [application]/models");
+ if (arguments.expressions().size() > 1)
+ throw new IllegalArgumentException("An xgboost feature can have at most 1 argument");
+
+ return new ConvertedModel.FeatureArguments(arguments);
+ }
+
+}
diff --git a/config-model/src/test/integration/xgboost/models/xgboost.2.2.json b/config-model/src/test/integration/xgboost/models/xgboost.2.2.json
new file mode 100644
index 00000000000..f8949b47e52
--- /dev/null
+++ b/config-model/src/test/integration/xgboost/models/xgboost.2.2.json
@@ -0,0 +1,19 @@
+[
+ { "nodeid": 0, "depth": 0, "split": "f29", "split_condition": -0.1234567, "yes": 1, "no": 2, "missing": 1, "children": [
+ { "nodeid": 1, "depth": 1, "split": "f56", "split_condition": -0.242398, "yes": 3, "no": 4, "missing": 3, "children": [
+ { "nodeid": 3, "leaf": 1.71218 },
+ { "nodeid": 4, "leaf": -1.70044 }
+ ]},
+ { "nodeid": 2, "depth": 1, "split": "f109", "split_condition": 0.8723473, "yes": 5, "no": 6, "missing": 5, "children": [
+ { "nodeid": 5, "leaf": -1.94071 },
+ { "nodeid": 6, "leaf": 1.85965 }
+ ]}
+ ]},
+ { "nodeid": 0, "depth": 0, "split": "f60", "split_condition": -0.482947, "yes": 1, "no": 2, "missing": 1, "children": [
+ { "nodeid": 1, "depth": 1, "split": "f29", "split_condition": -4.2387498, "yes": 3, "no": 4, "missing": 3, "children": [
+ { "nodeid": 3, "leaf": 0.784718 },
+ { "nodeid": 4, "leaf": -0.96853 }
+ ]},
+ { "nodeid": 2, "leaf": -6.23624 }
+ ]}
+] \ No newline at end of file
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java
new file mode 100644
index 00000000000..b65cb0b3d5f
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java
@@ -0,0 +1,55 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.path.Path;
+import com.yahoo.searchdefinition.parser.ParseException;
+import org.junit.Test;
+
+/**
+ * Checks that xgboost(...) references in a first-phase expression are replaced
+ * by the ranking expression generated from the referenced test model.
+ *
+ * @author grace-lam
+ */
+public class RankingExpressionWithXgboostTestCase {
+
+    private final Path applicationDir = Path.fromString("src/test/integration/xgboost/");
+
+    /** The expression expected from the two trees in xgboost.2.2.json */
+    private static final String vespaExpression = "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + " +
+                                                  "if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)";
+
+    @Test
+    public void testXgboostReference() {
+        RankProfileSearchFixture fixture = fixtureWith("xgboost('xgboost.2.2.json')");
+        fixture.assertFirstPhaseExpression(vespaExpression, "my_profile");
+    }
+
+    @Test
+    public void testNestedXgboostReference() {
+        RankProfileSearchFixture fixture = fixtureWith("5 + sum(xgboost('xgboost.2.2.json'))");
+        String expected = "5 + reduce(" + vespaExpression + ", sum)";
+        fixture.assertFirstPhaseExpression(expected, "my_profile");
+    }
+
+    /** Creates a fixture over the test application, with no constant and no field */
+    private RankProfileSearchFixture fixtureWith(String firstPhaseExpression) {
+        RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage application =
+                new RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage(applicationDir);
+        return fixtureWith(firstPhaseExpression, null, null, application);
+    }
+
+    /**
+     * Creates a fixture with a rank profile "my_profile" using the given first-phase expression.
+     * A ParseException from the fixture is rethrown as an IllegalArgumentException.
+     */
+    private RankProfileSearchFixture fixtureWith(String firstPhaseExpression,
+                                                 String constant,
+                                                 String field,
+                                                 RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage application) {
+        String rankProfile =
+                " rank-profile my_profile {\n" +
+                " first-phase {\n" +
+                " expression: " + firstPhaseExpression +
+                " }\n" +
+                " }";
+        try {
+            return new RankProfileSearchFixture(application,
+                                                application.getQueryProfiles(),
+                                                rankProfile,
+                                                constant,
+                                                field);
+        } catch (ParseException e) {
+            throw new IllegalArgumentException(e);
+        }
+    }
+
+}
+
diff --git a/searchlib/pom.xml b/searchlib/pom.xml
index 0202f8510bb..8037f1d399a 100644
--- a/searchlib/pom.xml
+++ b/searchlib/pom.xml
@@ -51,12 +51,10 @@
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
- <scope>test</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
- <scope>test</scope>
</dependency>
</dependencies>
<build>
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XgboostImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XgboostImporter.java
new file mode 100644
index 00000000000..f9717c39a8b
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/XgboostImporter.java
@@ -0,0 +1,28 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.ml;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost.XGBoostParser;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+
+import java.io.IOException;
+
+/**
+ * Converts a saved XGBoost model into a ranking expression.
+ *
+ * @author grace-lam
+ */
+public class XgboostImporter {
+
+    /**
+     * Reads the XGBoost JSON model file at the given path and converts it to a ranking expression.
+     *
+     * @param modelPath the path of the XGBoost JSON model file
+     * @return the ranking expression parsed from the text generated for the model
+     * @throws IllegalArgumentException if the file cannot be read or mapped to trees, or if the
+     *                                  generated expression text cannot be parsed. The underlying
+     *                                  exception is always attached as the cause.
+     */
+    public RankingExpression parseModel(String modelPath) {
+        try {
+            XGBoostParser parser = new XGBoostParser(modelPath);
+            return new RankingExpression(parser.toRankingExpression());
+        } catch (IOException e) {
+            throw new IllegalArgumentException("Could not import XGBoost model from '" + modelPath + "'", e);
+        } catch (ParseException e) {
+            // Pass the ParseException as the cause so its stack trace is not lost
+            throw new IllegalArgumentException("Could not parse ranking expression: " + e, e);
+        }
+    }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java
new file mode 100644
index 00000000000..fef8bfec81d
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostParser.java
@@ -0,0 +1,77 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+/**
+ * @author grace-lam
+ */
+public class XGBoostParser {
+
+ private List<XGBoostTree> xgboostTrees;
+
+ /**
+ * Constructor stores parsed JSON trees.
+ *
+ * @param filePath XGBoost JSON output file.
+ * @throws JsonProcessingException Fails JSON parsing.
+ * @throws IOException Fails file reading.
+ */
+ public XGBoostParser(String filePath) throws JsonProcessingException, IOException {
+ this.xgboostTrees = new ArrayList<>();
+ ObjectMapper mapper = new ObjectMapper();
+ JsonNode forestNode = mapper.readTree(new File(filePath));
+ for (JsonNode treeNode : forestNode) {
+ this.xgboostTrees.add(mapper.treeToValue(treeNode, XGBoostTree.class));
+ }
+ }
+
+ /**
+ * Converts parsed JSON trees to Vespa ranking expressions.
+ *
+ * @return Vespa ranking expressions.
+ */
+ public String toRankingExpression() {
+ StringBuilder ret = new StringBuilder();
+ for (int i = 0; i < xgboostTrees.size(); i++) {
+ ret.append(treeToRankExp(xgboostTrees.get(i)));
+ if (i != xgboostTrees.size() - 1) {
+ ret.append(" + \n");
+ }
+ }
+ return ret.toString();
+ }
+
+ /**
+ * Recursive helper function for toRankingExpression().
+ *
+ * @param node XGBoost tree node to convert.
+ * @return Vespa ranking expression for input node.
+ */
+ public String treeToRankExp(XGBoostTree node) {
+ if (node.isLeaf()) {
+ return Double.toString(node.getLeaf());
+ } else {
+ assert node.getChildren().size() == 2;
+ String trueExp;
+ String falseExp;
+ if (node.getYes() == node.getChildren().get(0).getNodeid()) {
+ trueExp = treeToRankExp(node.getChildren().get(0));
+ falseExp = treeToRankExp(node.getChildren().get(1));
+ } else {
+ trueExp = treeToRankExp(node.getChildren().get(1));
+ falseExp = treeToRankExp(node.getChildren().get(0));
+ }
+ return "if (" + node.getSplit() + " < " + Double.toString(node.getSplit_condition()) + ", " + trueExp + ", "
+ + falseExp + ")";
+ }
+ }
+
+} \ No newline at end of file
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java
new file mode 100644
index 00000000000..6bbc9abe8ae
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/xgboost/XGBoostTree.java
@@ -0,0 +1,77 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.ml.importer.xgboost;
+
+import java.util.List;
+
+/**
+ * Outlines the JSON representation used for parsing the XGBoost output file.
+ * Each instance is one node of a tree: either a split node, which has children,
+ * or a leaf node, which has none (see isLeaf()).
+ *
+ * The fields are private and this class declares no setters or constructors.
+ * NOTE(review): they are presumably populated reflectively by the JSON mapper,
+ * which would make the field and getter names part of the file format contract;
+ * confirm before renaming anything.
+ *
+ * @author grace-lam
+ */
+public class XGBoostTree {
+
+ // ID of current node. The parent's "yes" ID is compared against this to find the true branch.
+ private int nodeid;
+ // Depth of current node w.r.t. the tree's root.
+ private int depth;
+ // Feature name used for split.
+ private String split;
+ // Feature value threshold to split on.
+ private double split_condition;
+ // ID of the next node if feature value < split_condition.
+ private int yes;
+ // ID of the next node if feature value >= split_condition.
+ private int no;
+ // ID of the next node if feature value is missing.
+ private int missing;
+ // Response value for leaf node. Keeps the Java default of 0.0 if never set.
+ private double leaf;
+ // List of child nodes. A node whose list is null is treated as a leaf by isLeaf().
+ private List<XGBoostTree> children;
+
+ public int getNodeid() {
+ return nodeid;
+ }
+
+ public int getDepth() {
+ return depth;
+ }
+
+ public String getSplit() {
+ return split;
+ }
+
+ public double getSplit_condition() {
+ return split_condition;
+ }
+
+ public int getYes() {
+ return yes;
+ }
+
+ public int getNo() {
+ return no;
+ }
+
+ public int getMissing() {
+ return missing;
+ }
+
+ public double getLeaf() {
+ return leaf;
+ }
+
+ public List<XGBoostTree> getChildren() {
+ return children;
+ }
+
+ /**
+ * Check if current node is a leaf node.
+ * A node counts as a leaf exactly when its children list is null;
+ * the leaf value itself is not inspected.
+ *
+ * @return True if leaf, false otherwise.
+ */
+ public boolean isLeaf() {
+ return children == null;
+ }
+
+} \ No newline at end of file