diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-08-23 11:00:34 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-08-23 11:00:34 +0200 |
commit | 64f528e36db2ced78535fc8e93ea2617fff55921 (patch) | |
tree | 5bd87ca8bd490b81c528f4e3b2084fa1d62ff938 /config-model | |
parent | d9cf52f4d552f065501258253414908a1b9a4ab6 (diff) |
Make XGBoost models first class citizens
Diffstat (limited to 'config-model')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java | 27 | ||||
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java | 2 | ||||
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java | 2 | ||||
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXGBoostTestCase.java (renamed from config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java) | 43 |
4 files changed, 54 insertions, 20 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java index 61419918f2a..e6b08ab0350 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java @@ -1,8 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; +import com.yahoo.path.Path; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.integration.ml.XgboostImporter; +import com.yahoo.searchlib.rankingexpression.integration.ml.XGBoostImporter; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -10,6 +11,8 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.Map; /** * Replaces instances of the xgboost(model-path) @@ -17,10 +20,12 @@ import java.io.UncheckedIOException; * the same computation. * * @author grace-lam + * @author bratseth */ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - private final XgboostImporter xgboostImporter = new XgboostImporter(); + /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ + private final Map<Path, ConvertedModel> convertedXGBoostModels = new HashMap<>(); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -33,26 +38,22 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr } private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { - if (!feature.getName().equals("xgboost")) return feature; + if ( ! feature.getName().equals("xgboost")) return feature; try { - ConvertedModel.FeatureArguments arguments = asFeatureArguments(feature.getArguments()); - ConvertedModel.ModelStore store = new ConvertedModel.ModelStore(context.rankProfile().getSearch().sourceApplication(), - arguments.modelPath()); - RankingExpression expression = xgboostImporter.parseModel(store.sourceModelFile().toString()); - return expression.getRoot(); + Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); + ConvertedModel convertedModel = + convertedXGBoostModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context)); + return convertedModel.expression(asFeatureArguments(feature.getArguments())); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use XGBoost model from " + feature, e); } } private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) { - if (arguments.isEmpty()) - throw new IllegalArgumentException("An xgboost node must take an argument pointing to " + + if (arguments.size() != 1) + throw new IllegalArgumentException("An xgboost node must take a single argument pointing to " + "the xgboost model directory under [application]/models"); - if (arguments.expressions().size() > 1) - throw new IllegalArgumentException("An xgboost feature can have at most 1 argument"); - return new ConvertedModel.FeatureArguments(arguments); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 90137ddde49..77d20657f64 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -36,7 +36,7 @@ public class RankingExpressionWithOnnxTestCase { private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_Variable), f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))"; @After - public void removeGeneratedConstantTensorFiles() { + public void removeGeneratedModelFiles() { IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 2804b92767a..cf37864b73a 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -51,7 +51,7 @@ public class RankingExpressionWithTensorFlowTestCase { private final String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_layer_Variable_1_read), f(a,b)(a + b))"; @After - public void removeGeneratedConstantTensorFiles() { + public void removeGeneratedModelFiles() { IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXGBoostTestCase.java index 2e109553560..832a974082c 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXgboostTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithXGBoostTestCase.java @@ -1,31 +1,64 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.searchdefinition.parser.ParseException; +import org.junit.After; import org.junit.Test; +import java.io.IOException; + /** * @author grace-lam + * @author bratseth */ -public class RankingExpressionWithXgboostTestCase { +public class RankingExpressionWithXGBoostTestCase { private final Path applicationDir = Path.fromString("src/test/integration/xgboost/"); - private final static String vespaExpression = "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + " + + + private final static String vespaExpression = + "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + " + "if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)"; + @After + public void removeGeneratedModelFiles() { + IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + } + @Test - public void testXgboostReference() { + public void testXGBoostReference() { RankProfileSearchFixture search = fixtureWith("xgboost('xgboost.2.2.json')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test - public void testNestedXgboostReference() { + public void testNestedXGBoostReference() { RankProfileSearchFixture search = fixtureWith("5 + sum(xgboost('xgboost.2.2.json'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); } + @Test + public void testImportingFromStoredExpressions() throws IOException { + RankProfileSearchFixture search = fixtureWith("xgboost('xgboost.2.2.json')"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); + try { + storedApplicationDirectory.toFile().mkdirs(); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage storedApplication = new RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage(storedApplicationDirectory); + RankProfileSearchFixture searchFromStored = fixtureWith("xgboost('xgboost.2.2.json')"); + searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); + } + finally { + IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); + } + } + private RankProfileSearchFixture fixtureWith(String firstPhaseExpression) { return fixtureWith(firstPhaseExpression, null, null, new RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage(applicationDir)); @@ -46,7 +79,7 @@ public class RankingExpressionWithXgboostTestCase { " }", constant, field); - fixture.compileRankProfile("my_profile", applicationDir); + fixture.compileRankProfile("my_profile", applicationDir.append("models")); return fixture; } catch (ParseException e) { throw new IllegalArgumentException(e); |