summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-23 11:00:34 +0200
committerJon Bratseth <bratseth@oath.com>2018-08-23 11:00:34 +0200
commit64f528e36db2ced78535fc8e93ea2617fff55921 (patch)
tree5bd87ca8bd490b81c528f4e3b2084fa1d62ff938 /config-model
parentd9cf52f4d552f065501258253414908a1b9a4ab6 (diff)
Make XGBoost models first class citizens
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java27
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXGBoostTestCase.java (renamed from config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java)43
4 files changed, 54 insertions, 20 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
index 61419918f2a..e6b08ab0350 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
@@ -1,8 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.expressiontransforms;
+import com.yahoo.path.Path;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.integration.ml.XgboostImporter;
+import com.yahoo.searchlib.rankingexpression.integration.ml.XGBoostImporter;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -10,6 +11,8 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import java.io.UncheckedIOException;
+import java.util.HashMap;
+import java.util.Map;
/**
* Replaces instances of the xgboost(model-path)
@@ -17,10 +20,12 @@ import java.io.UncheckedIOException;
* the same computation.
*
* @author grace-lam
+ * @author bratseth
*/
public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
- private final XgboostImporter xgboostImporter = new XgboostImporter();
+ /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
+ private final Map<Path, ConvertedModel> convertedXGBoostModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -33,26 +38,22 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr
}
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
- if (!feature.getName().equals("xgboost")) return feature;
+ if ( ! feature.getName().equals("xgboost")) return feature;
try {
- ConvertedModel.FeatureArguments arguments = asFeatureArguments(feature.getArguments());
- ConvertedModel.ModelStore store = new ConvertedModel.ModelStore(context.rankProfile().getSearch().sourceApplication(),
- arguments.modelPath());
- RankingExpression expression = xgboostImporter.parseModel(store.sourceModelFile().toString());
- return expression.getRoot();
+ Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
+ ConvertedModel convertedModel =
+ convertedXGBoostModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context));
+ return convertedModel.expression(asFeatureArguments(feature.getArguments()));
} catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use XGBoost model from " + feature, e);
}
}
private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) {
- if (arguments.isEmpty())
- throw new IllegalArgumentException("An xgboost node must take an argument pointing to " +
+ if (arguments.size() != 1)
+ throw new IllegalArgumentException("An xgboost node must take a single argument pointing to " +
"the xgboost model directory under [application]/models");
- if (arguments.expressions().size() > 1)
- throw new IllegalArgumentException("An xgboost feature can have at most 1 argument");
-
return new ConvertedModel.FeatureArguments(arguments);
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 90137ddde49..77d20657f64 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -36,7 +36,7 @@ public class RankingExpressionWithOnnxTestCase {
private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_Variable), f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))";
@After
- public void removeGeneratedConstantTensorFiles() {
+ public void removeGeneratedModelFiles() {
IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 2804b92767a..cf37864b73a 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -51,7 +51,7 @@ public class RankingExpressionWithTensorFlowTestCase {
private final String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_layer_Variable_1_read), f(a,b)(a + b))";
@After
- public void removeGeneratedConstantTensorFiles() {
+ public void removeGeneratedModelFiles() {
IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXGBoostTestCase.java
index 2e109553560..832a974082c 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXGBoostTestCase.java
@@ -1,31 +1,64 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.processing;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.searchdefinition.parser.ParseException;
+import org.junit.After;
import org.junit.Test;
+import java.io.IOException;
+
/**
* @author grace-lam
+ * @author bratseth
*/
-public class RankingExpressionWithXgboostTestCase {
+public class RankingExpressionWithXGBoostTestCase {
private final Path applicationDir = Path.fromString("src/test/integration/xgboost/");
- private final static String vespaExpression = "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + " +
+
+ private final static String vespaExpression =
+ "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + " +
"if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)";
+ @After
+ public void removeGeneratedModelFiles() {
+ IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ }
+
@Test
- public void testXgboostReference() {
+ public void testXGBoostReference() {
RankProfileSearchFixture search = fixtureWith("xgboost('xgboost.2.2.json')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@Test
- public void testNestedXgboostReference() {
+ public void testNestedXGBoostReference() {
RankProfileSearchFixture search = fixtureWith("5 + sum(xgboost('xgboost.2.2.json'))");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
}
+ @Test
+ public void testImportingFromStoredExpressions() throws IOException {
+ RankProfileSearchFixture search = fixtureWith("xgboost('xgboost.2.2.json')");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+
+ // At this point the expression is stored - copy application to another location which do not have a models dir
+ Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
+ try {
+ storedApplicationDirectory.toFile().mkdirs();
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage storedApplication = new RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage(storedApplicationDirectory);
+ RankProfileSearchFixture searchFromStored = fixtureWith("xgboost('xgboost.2.2.json')");
+ searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
+ }
+ }
+
private RankProfileSearchFixture fixtureWith(String firstPhaseExpression) {
return fixtureWith(firstPhaseExpression, null, null,
new RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage(applicationDir));
@@ -46,7 +79,7 @@ public class RankingExpressionWithXgboostTestCase {
" }",
constant,
field);
- fixture.compileRankProfile("my_profile", applicationDir);
+ fixture.compileRankProfile("my_profile", applicationDir.append("models"));
return fixture;
} catch (ParseException e) {
throw new IllegalArgumentException(e);