aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java')
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java196
1 files changed, 196 insertions, 0 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java
new file mode 100644
index 00000000000..2c0620a0c52
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java
@@ -0,0 +1,196 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.processing;
+
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
+import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
+import ai.vespa.rankingexpression.importer.lightgbm.LightGBMImporter;
+import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
+import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
+import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
+import com.google.common.collect.ImmutableList;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.application.provider.BaseDeployLogger;
+import com.yahoo.config.model.deploy.TestProperties;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchdefinition.SearchBuilder;
+import com.yahoo.searchdefinition.derived.DerivedConfiguration;
+import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage;
+import com.yahoo.vespa.model.VespaModel;
+import com.yahoo.vespa.model.ml.ImportedModelTester;
+import com.yahoo.yolean.Exceptions;
+import org.junit.After;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Optional;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+
+public class RankingExpressionWithBertTestCase {
+
+ private final Path applicationDir = Path.fromString("src/test/integration/bert/");
+
+ /** The model name */
+ private final static String name = "bertsquad8";
+
+ private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_Variable), f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))";
+
+ @After
+ public void removeGeneratedModelFiles() {
+ IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ }
+
+
+ @Ignore
+ @Test
+ public void testGlobalBertModel() throws IOException {
+ ImportedModelTester tester = new ImportedModelTester(name, applicationDir);
+ VespaModel model = tester.createVespaModel();
+// tester.assertLargeConstant(name + "_Variable_1", model, Optional.of(10L));
+// tester.assertLargeConstant(name + "_Variable", model, Optional.of(7840L));
+
+ // At this point the expression is stored - copy application to another location which do not have a models dir
+ Path storedAppDir = applicationDir.append("copy");
+ try {
+ storedAppDir.toFile().mkdirs();
+ IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString());
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir);
+ VespaModel storedModel = storedTester.createVespaModel();
+// tester.assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L));
+// tester.assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L));
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedAppDir.toFile());
+ }
+ }
+
+ @Ignore
+ @Test
+ public void testBertRankProfile() throws Exception {
+ StoringApplicationPackage application = new StoringApplicationPackage((applicationDir));
+
+ ImmutableList<MlModelImporter> importers = ImmutableList.of(new TensorFlowImporter(),
+ new OnnxImporter(),
+ new LightGBMImporter(),
+ new XGBoostImporter());
+
+ String rankProfiles = " rank-profile my_profile {\n" +
+ " first-phase {\n" +
+ " expression: onnx('bertsquad8.onnx', 'default', 'unstack')" +
+ " }\n" +
+ " }";
+
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ QueryProfileRegistry queryProfileRegistry = application.getQueryProfiles();
+
+ SearchBuilder builder = new SearchBuilder(application, rankProfileRegistry, queryProfileRegistry);
+ String sdContent = "search test {\n" +
+ " document test {\n" +
+ " field unique_ids type tensor(d0[1]) {\n" +
+ " indexing: summary | attribute\n" +
+ " }\n" +
+ " field input_ids type tensor(d0[1],d1[256]) {\n" +
+ " indexing: summary | attribute\n" +
+ " }\n" +
+ " field input_mask type tensor(d0[1],d1[256]) {\n" +
+ " indexing: summary | attribute\n" +
+ " }\n" +
+ " field segment_ids type tensor(d0[1],d1[256]) {\n" +
+ " indexing: summary | attribute\n" +
+ " }" +
+ " }\n" +
+ " rank-profile my_profile inherits default {\n" +
+ " function inline unique_ids_raw_output___9() {\n" +
+ " expression: attribute(unique_ids)\n" +
+ " }\n" +
+ " function inline input_ids() {\n" +
+ " expression: attribute(input_ids)\n" +
+ " }\n" +
+ " function inline input_mask() {\n" +
+ " expression: attribute(input_mask)\n" +
+ " }\n" +
+ " function inline segment_ids() {\n" +
+ " expression: attribute(segment_ids)\n" +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: onnx(\"bertsquad8.onnx\", \"default\", \"unstack\") \n" +
+ " }\n" +
+ " }" +
+ "}";
+ builder.importString(sdContent);
+ builder.build();
+ Search search = builder.getSearch();
+
+ RankProfile compiled = rankProfileRegistry.get(search, "my_profile")
+ .compile(queryProfileRegistry,
+ new ImportedMlModels(applicationDir.toFile(), importers));
+
+ DerivedConfiguration config = new DerivedConfiguration(search,
+ new BaseDeployLogger(),
+ new TestProperties(),
+ rankProfileRegistry,
+ queryProfileRegistry,
+ new ImportedMlModels());
+
+ config.export("/Users/lesters/temp/bert/idea/");
+
+// fixture.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ System.out.println("Joda");
+ }
+
+ private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression,
+ String constant, String field) {
+ return fixtureWith(placeholderExpression, firstPhaseExpression, constant, field, "Placeholder",
+ new StoringApplicationPackage(applicationDir));
+ }
+
+ private RankProfileSearchFixture uncompiledFixtureWith(String rankProfile, StoringApplicationPackage application) {
+ try {
+ return new RankProfileSearchFixture(application, application.getQueryProfiles(),
+ rankProfile, null, null);
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+ private RankProfileSearchFixture fixtureWith(String functionExpression,
+ String firstPhaseExpression,
+ String constant,
+ String field,
+ String functionName,
+ StoringApplicationPackage application) {
+ try {
+ RankProfileSearchFixture fixture = new RankProfileSearchFixture(
+ application,
+ application.getQueryProfiles(),
+ " rank-profile my_profile {\n" +
+ " function " + functionName + "() {\n" +
+ " expression: " + functionExpression +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: " + firstPhaseExpression +
+ " }\n" +
+ " }",
+ constant,
+ field);
+ fixture.compileRankProfile("my_profile", applicationDir.append("models"));
+ return fixture;
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+}