diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-11-21 19:21:22 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-11-21 19:21:22 +0100 |
commit | 99ca9b2907ff637fc6e4e0a61860923ac1c9dee5 (patch) | |
tree | d5a5e408d56e9165cd716e9531ab9bcec6a29e4a /config-model/src | |
parent | 61cae2609740b51c180b2f507b5e4d0eb399fedc (diff) |
Separate model integration into a separate module
This allows us to access model importers (such as TensorFlow)
in config models without loading one instance per config model
instance, which is not possible with TensorFlow because it depends
on JNI code.
Diffstat (limited to 'config-model/src')
26 files changed, 107 insertions, 54 deletions
diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java index 574c25a2f84..d784722f77a 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java @@ -21,8 +21,9 @@ import com.yahoo.config.provision.Zone; import com.yahoo.io.reader.NamedReader; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.SearchBuilder; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter; import com.yahoo.vespa.config.ConfigDefinition; import com.yahoo.vespa.config.ConfigDefinitionBuilder; import com.yahoo.vespa.config.ConfigDefinitionKey; @@ -39,6 +40,7 @@ import java.io.IOException; import java.io.Reader; import java.time.Instant; import java.util.Collection; +import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; @@ -80,11 +82,23 @@ public class DeployState implements ConfigDefinitionStore { return new Builder().applicationPackage(applicationPackage).build(); } - private DeployState(ApplicationPackage applicationPackage, SearchDocumentModel searchDocumentModel, RankProfileRegistry rankProfileRegistry, - FileRegistry fileRegistry, DeployLogger deployLogger, Optional<HostProvisioner> hostProvisioner, DeployProperties properties, - Optional<ApplicationPackage> permanentApplicationPackage, Optional<ConfigDefinitionRepo> configDefinitionRepo, - java.util.Optional<Model> previousModel, Set<Rotation> rotations, Zone zone, QueryProfiles queryProfiles, - SemanticRules semanticRules, Instant now, Version wantedNodeVespaVersion) { + private DeployState(ApplicationPackage applicationPackage, + SearchDocumentModel searchDocumentModel, + RankProfileRegistry rankProfileRegistry, + FileRegistry fileRegistry, + DeployLogger deployLogger, + Optional<HostProvisioner> hostProvisioner, + DeployProperties properties, + Optional<ApplicationPackage> permanentApplicationPackage, + Optional<ConfigDefinitionRepo> configDefinitionRepo, + java.util.Optional<Model> previousModel, + Set<Rotation> rotations, + Collection<ModelImporter> modelImporters, + Zone zone, + QueryProfiles queryProfiles, + SemanticRules semanticRules, + Instant now, + Version wantedNodeVespaVersion) { this.logger = deployLogger; this.fileRegistry = fileRegistry; this.rankProfileRegistry = rankProfileRegistry; @@ -100,7 +114,8 @@ public class DeployState implements ConfigDefinitionStore { this.zone = zone; this.queryProfiles = queryProfiles; // TODO: Remove this by seeing how pagetemplates are propagated this.semanticRules = semanticRules; // TODO: Remove this by seeing how pagetemplates are propagated - this.importedModels = new ImportedModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR)); + this.importedModels = new ImportedModels(applicationPackage.getFileReference(ApplicationPackage.MODELS_DIR), + modelImporters); this.validationOverrides = applicationPackage.getValidationOverrides().map(ValidationOverrides::fromXml).orElse(ValidationOverrides.empty); this.wantedNodeVespaVersion = wantedNodeVespaVersion; @@ -232,6 +247,7 @@ public class DeployState implements ConfigDefinitionStore { private Optional<ConfigDefinitionRepo> configDefinitionRepo = Optional.empty(); private Optional<Model> previousModel = Optional.empty(); private Set<Rotation> rotations = new HashSet<>(); + private Collection<ModelImporter> modelImporters = Collections.emptyList(); private Zone zone = Zone.defaultZone(); private Instant now = Instant.now(); private Version wantedNodeVespaVersion = Vtag.currentVersion; @@ -281,6 +297,11 @@ public class DeployState implements ConfigDefinitionStore { return this; } + public Builder modelImporters(Collection<ModelImporter> modelImporters) { + this.modelImporters = modelImporters; + return this; + } + public Builder zone(Zone zone) { this.zone = zone; return this; @@ -305,9 +326,23 @@ public class DeployState implements ConfigDefinitionStore { QueryProfiles queryProfiles = new QueryProfilesBuilder().build(applicationPackage); SemanticRules semanticRules = new SemanticRuleBuilder().build(applicationPackage); SearchDocumentModel searchDocumentModel = createSearchDocumentModel(rankProfileRegistry, logger, queryProfiles, validationParameters); - return new DeployState(applicationPackage, searchDocumentModel, rankProfileRegistry, fileRegistry, logger, hostProvisioner, - properties, permanentApplicationPackage, configDefinitionRepo, previousModel, rotations, - zone, queryProfiles, semanticRules, now, wantedNodeVespaVersion); + return new DeployState(applicationPackage, + searchDocumentModel, + rankProfileRegistry, + fileRegistry, + logger, + hostProvisioner, + properties, + permanentApplicationPackage, + configDefinitionRepo, + previousModel, + rotations, + modelImporters, + zone, + queryProfiles, + semanticRules, + now, + wantedNodeVespaVersion); } private SearchDocumentModel createSearchDocumentModel(RankProfileRegistry rankProfileRegistry, diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 78f61d7192d..a0cbc4271f3 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -16,7 +16,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java index 0c5c7733dda..eb221a80638 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/DerivedConfiguration.java @@ -12,7 +12,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.derived.validation.Validation; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import java.io.IOException; import java.io.Writer; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index fcbfb47c597..c955cd9956b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.RankingConstants; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.Search; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index a1b0e72051b..fd468c8eca7 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -9,7 +9,7 @@ import com.yahoo.searchdefinition.document.RankType; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java index 2fe2dacf2ce..f52eedc8d08 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java @@ -4,7 +4,7 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.searchlib.rankingexpression.transform.TransformContext; import java.util.HashMap; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index fef198f7939..b4de5e925f6 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -33,8 +33,8 @@ import com.yahoo.searchdefinition.derived.RankProfileList; import com.yahoo.searchdefinition.processing.Processing; import com.yahoo.vespa.model.container.search.QueryProfiles; import com.yahoo.vespa.model.ml.ConvertedModel; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.vespa.config.ConfigDefinitionKey; import com.yahoo.vespa.config.ConfigKey; import com.yahoo.vespa.config.ConfigPayload; @@ -150,7 +150,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri this(configModelRegistry, deployState, true, null); } - private VespaModel(ConfigModelRegistry configModelRegistry, DeployState deployState, boolean complete, FileDistributor fileDistributor) throws IOException, SAXException { + private VespaModel(ConfigModelRegistry configModelRegistry, DeployState deployState, boolean complete, FileDistributor fileDistributor) + throws IOException, SAXException { super("vespamodel"); this.validationOverrides = deployState.validationOverrides(); configModelRegistry = new VespaConfigModelRegistry(configModelRegistry); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java index 2af9b297e9e..ff64b6753f9 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java @@ -21,6 +21,7 @@ import com.yahoo.config.model.deploy.DeployProperties; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.provision.Version; import com.yahoo.config.provision.Zone; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter; import com.yahoo.vespa.config.VespaVersion; import com.yahoo.vespa.model.application.validation.Validation; @@ -29,6 +30,8 @@ import org.xml.sax.SAXException; import java.io.IOException; import java.time.Clock; import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.logging.Logger; @@ -41,19 +44,17 @@ public class VespaModelFactory implements ModelFactory { private static final Logger log = Logger.getLogger(VespaModelFactory.class.getName()); private final ConfigModelRegistry configModelRegistry; + private final Collection<ModelImporter> modelImporters; private final Zone zone; private final Clock clock; private final Version version; /** Creates a factory for vespa models for this version of the source */ @Inject - public VespaModelFactory(ComponentRegistry<ConfigModelPlugin> pluginRegistry, Zone zone) { - this(Version.fromIntValues(VespaVersion.major, VespaVersion.minor, VespaVersion.micro), pluginRegistry, zone); - } - - /** Creates a factory for vespa models of a particular version */ - public VespaModelFactory(Version version, ComponentRegistry<ConfigModelPlugin> pluginRegistry, Zone zone) { - this.version = version; + public VespaModelFactory(ComponentRegistry<ConfigModelPlugin> pluginRegistry, + ComponentRegistry<ModelImporter> modelImporters, + Zone zone) { + this.version = Version.fromIntValues(VespaVersion.major, VespaVersion.minor, VespaVersion.micro); List<ConfigModelBuilder> modelBuilders = new ArrayList<>(); for (ConfigModelPlugin plugin : pluginRegistry.allComponents()) { if (plugin instanceof ConfigModelBuilder) { @@ -61,6 +62,7 @@ public class VespaModelFactory implements ModelFactory { } } this.configModelRegistry = new MapConfigModelRegistry(modelBuilders); + this.modelImporters = modelImporters.allComponents(); this.zone = zone; this.clock = Clock.systemUTC(); } @@ -79,6 +81,7 @@ public class VespaModelFactory implements ModelFactory { } else { this.configModelRegistry = configModelRegistry; } + this.modelImporters = Collections.emptyList(); this.zone = Zone.defaultZone(); this.clock = clock; } @@ -137,6 +140,7 @@ public class VespaModelFactory implements ModelFactory { .properties(createDeployProperties(modelContext.properties())) .modelHostProvisioner(createHostProvisioner(modelContext)) .rotations(modelContext.properties().rotations()) + .modelImporters(modelImporters) .zone(zone) .now(clock.instant()) .wantedNodeVespaVersion(modelContext.wantedNodeVespaVersion()); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 30586b1e677..18f3cd6e088 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -19,7 +19,7 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; @@ -41,7 +41,6 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat; import java.io.BufferedReader; import java.io.File; import java.io.IOException; -import java.io.Reader; import java.io.StringReader; import java.io.UncheckedIOException; import java.util.ArrayList; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java index 07a36832094..ad1676ee7e1 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/IncorrectRankingExpressionFileRefTestCase.java @@ -4,7 +4,7 @@ package com.yahoo.searchdefinition; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.yolean.Exceptions; import org.junit.Test; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java index 4df3add13c5..14847938f68 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java @@ -16,7 +16,7 @@ import com.yahoo.searchdefinition.document.RankType; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.util.Iterator; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java index 8df3985fd24..4f46d813ba5 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankPropertiesTestCase.java @@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java index 150469cc928..551ef6968d2 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java @@ -3,7 +3,7 @@ package com.yahoo.searchdefinition; import com.yahoo.collections.Pair; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.yolean.Exceptions; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java index e507a6c48e4..79adb1cf6bf 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionInliningTestCase.java @@ -6,8 +6,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; -import com.yahoo.yolean.Exceptions; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.util.Optional; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 6a1e5b207c6..d97dfb0cbf9 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -9,7 +9,7 @@ import com.yahoo.search.query.profile.types.QueryProfileType; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.util.List; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java index 5e649c2e551..dedfbce6ae8 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionValidationTestCase.java @@ -4,9 +4,8 @@ package com.yahoo.searchdefinition; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.yolean.Exceptions; -import org.junit.Ignore; import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java index f39896f5779..437f43976b7 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java @@ -3,12 +3,11 @@ package com.yahoo.searchdefinition.derived; import com.yahoo.document.DocumenttypesConfig; import com.yahoo.document.config.DocumentmanagerConfig; -import com.yahoo.io.IOUtils; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.vespa.configmodel.producers.DocumentManager; import com.yahoo.vespa.configmodel.producers.DocumentTypes; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java index f4344c9b03c..409b4236a9c 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/EmptyRankProfileTestCase.java @@ -9,12 +9,9 @@ import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; -import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; -import java.io.IOException; - /** * Tests deriving rank for files from search definitions * diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java index 2bae285301c..f2dd16577d1 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/LiteralBoostTestCase.java @@ -11,7 +11,7 @@ import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.processing.Processing; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java index 723cd58a34a..93b366c8d08 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/SimpleInheritTestCase.java @@ -5,7 +5,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.io.File; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java index 8941b07101d..3a44c123f05 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/TypeConversionTestCase.java @@ -10,7 +10,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.processing.Processing; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java index d38bce04617..1c368ff6f10 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/ImplicitSearchFieldsTestCase.java @@ -6,13 +6,11 @@ import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.derived.DerivedConfiguration; -import com.yahoo.searchdefinition.derived.Deriver; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; -import org.junit.Ignore; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; -import java.io.File; + import java.io.IOException; import static org.junit.Assert.assertEquals; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index cff9abb08ed..7f590d4b230 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.processing; +import com.google.common.collect.ImmutableList; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.path.Path; @@ -10,7 +11,11 @@ import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; +import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; +import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import java.util.HashMap; import java.util.List; @@ -26,6 +31,9 @@ import static org.junit.Assert.assertEquals; */ class RankProfileSearchFixture { + private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), + new OnnxImporter(), + new XGBoostImporter()); private RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); private final QueryProfileRegistry queryProfileRegistry; private Search search; @@ -83,7 +91,8 @@ class RankProfileSearchFixture { public RankProfile compileRankProfile(String rankProfile, Path applicationDir) { RankProfile compiled = rankProfileRegistry.get(search, rankProfile) - .compile(queryProfileRegistry, new ImportedModels(applicationDir.toFile())); + .compile(queryProfileRegistry, + new ImportedModels(applicationDir.toFile(), importers)); compiledRankProfiles.put(rankProfile, compiled); return compiled; } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java index fd048737b43..78dcae53f7e 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java @@ -8,7 +8,7 @@ import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.io.IOException; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index 6e3a227e2a9..7eee18c1059 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -17,7 +17,7 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModels; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModels; import org.junit.Test; import java.util.List; diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java index 2ae629562d0..4d8bdb7a6c6 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java @@ -1,11 +1,17 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; +import com.google.common.collect.ImmutableList; import com.yahoo.config.model.ApplicationPackageTester; +import com.yahoo.config.model.deploy.DeployState; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; +import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; +import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.model.VespaModel; @@ -26,6 +32,10 @@ import static org.junit.Assert.assertEquals; */ public class ImportedModelTester { + private final ImmutableList<ModelImporter> importers = ImmutableList.of(new TensorFlowImporter(), + new OnnxImporter(), + new XGBoostImporter()); + private final String modelName; private final Path applicationDir; @@ -36,7 +46,10 @@ public class ImportedModelTester { public VespaModel createVespaModel() { try { - return new VespaModel(ApplicationPackageTester.create(applicationDir.toString()).app()); + DeployState.Builder state = new DeployState.Builder(); + state.applicationPackage(ApplicationPackageTester.create(applicationDir.toString()).app()); + state.modelImporters(importers); + return new VespaModel(state.build()); } catch (SAXException | IOException e) { throw new RuntimeException(e); |