diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-11-21 19:21:22 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-11-21 19:21:22 +0100 |
commit | 99ca9b2907ff637fc6e4e0a61860923ac1c9dee5 (patch) | |
tree | d5a5e408d56e9165cd716e9531ab9bcec6a29e4a /config-model/src/test/java | |
parent | 61cae2609740b51c180b2f507b5e4d0eb399fedc (diff) |
Separate model integration into a separate module
This allows us to access model importers (such as TensorFlow)
in config models without loading one instance per config model
instance, which is not possible with TensorFlow because it depends
on JNI code.
Diffstat (limited to 'config-model/src/test/java')
17 files changed, 41 insertions, 27 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java index 07a36832094..ad1676ee7e1 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java @@ -4,7 +4,7 @@ package com.yahoo.searchdefinition; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.yolean.Exceptions; import org.junit.Test; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java index 4df3add13c5..14847938f68 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java @@ -16,7 +16,7 @@ import com.yahoo.searchdefinition.document.RankType; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.util.Iterator; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java index 8df3985fd24..4f46d813ba5 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java @@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java index 150469cc928..551ef6968d2 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java @@ -3,7 +3,7 @@ package com.yahoo.searchdefinition; import com.yahoo.collections.Pair; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.yolean.Exceptions; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java index e507a6c48e4..79adb1cf6bf 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java @@ -6,8 +6,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; -import com.yahoo.yolean.Exceptions; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.util.Optional; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 6a1e5b207c6..d97dfb0cbf9 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -9,7 +9,7 @@ import com.yahoo.search.query.profile.types.QueryProfileType; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.util.List; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java index 5e649c2e551..dedfbce6ae8 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java @@ -4,9 +4,8 @@ package com.yahoo.searchdefinition; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.yolean.Exceptions; -import org.junit.Ignore; import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java index f39896f5779..437f43976b7 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java @@ -3,12 +3,11 @@ package com.yahoo.searchdefinition.derived; import com.yahoo.document.DocumenttypesConfig; import com.yahoo.document.config.DocumentmanagerConfig; -import com.yahoo.io.IOUtils; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.vespa.configmodel.producers.DocumentManager; import com.yahoo.vespa.configmodel.producers.DocumentTypes; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java index f4344c9b03c..409b4236a9c 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java @@ -9,12 +9,9 @@ import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; -import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; -import java.io.IOException; - /** * Tests deriving rank for files from search definitions * diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java index 2bae285301c..f2dd16577d1 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java @@ -11,7 +11,7 @@ import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.processing.Processing; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java index 723cd58a34a..93b366c8d08 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java @@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.io.File; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java index 8941b07101d..3a44c123f05 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java @@ -10,7 +10,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.processing.Processing; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java index d38bce04617..1c368ff6f10 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java @@ -6,13 +6,11 @@ import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.derived.DerivedConfiguration; -import com.yahoo.searchdefinition.derived.Deriver; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; -import org.junit.Ignore; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; -import java.io.File; + import java.io.IOException; import static org.junit.Assert.assertEquals; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index cff9abb08ed..7f590d4b230 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; +import com.google.common.collect.ImmutableList; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.path.Path; @@ -10,7 +11,11 @@ import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; +import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; +import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import java.util.HashMap; import java.util.List; @@ -26,6 +31,9 @@ import static org.junit.Assert.assertEquals; */ class RankProfileSearchFixture { + private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), + new OnnxImporter(), + new XGBoostImporter()); private RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); private final QueryProfileRegistry queryProfileRegistry; private Search search; @@ -83,7 +91,8 @@ class RankProfileSearchFixture { public RankProfile compileRankProfile(String rankProfile, Path applicationDir) { RankProfile compiled = rankProfileRegistry.get(search, rankProfile) - .compile(queryProfileRegistry, new ImportedModels(applicationDir.toFile())); + .compile(queryProfileRegistry, + new ImportedModels(applicationDir.toFile(), importers)); compiledRankProfiles.put(rankProfile, compiled); return compiled; } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java index fd048737b43..78dcae53f7e 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java @@ -8,7 +8,7 @@ import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.io.IOException; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index 6e3a227e2a9..7eee18c1059 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -17,7 +17,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.util.List; diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java index 2ae629562d0..4d8bdb7a6c6 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java @@ -1,11 +1,17 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; +import com.google.common.collect.ImmutableList; import com.yahoo.config.model.ApplicationPackageTester; +import com.yahoo.config.model.deploy.DeployState; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; +import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; +import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.model.VespaModel; @@ -26,6 +32,10 @@ import static org.junit.Assert.assertEquals; */ public class ImportedModelTester { + private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), + new OnnxImporter(), + new XGBoostImporter()); + private final String modelName; private final Path applicationDir; @@ -36,7 +46,10 @@ public class ImportedModelTester { public VespaModel createVespaModel() { try { - return new VespaModel(ApplicationPackageTester.create(applicationDir.toString()).app()); + DeployState.Builder state = new DeployState.Builder(); + state.applicationPackage(ApplicationPackageTester.create(applicationDir.toString()).app()); + state.modelImporters(importers); + return new VespaModel(state.build()); } catch (SAXException | IOException e) { throw new RuntimeException(e); |